两个主要的迁移学习场景:
- Finetuning the convnet: 我们使用预训练网络初始化网络,而不是随机初始化,就像在imagenet 1000数据集上训练的网络一样。其余训练看起来像往常一样。
- ConvNet as fixed feature extractor: 在这里,我们将冻结除最终完全连接层之外的所有网络的权重。最后一个全连接层被替换为具有随机权重的新层,并且仅训练该层。
参考链接:
1 https://pytorch.apachecn.org/docs/1.0/transfer_learning_tutorial.html
2 https://pytorch.apachecn.org/docs/1.0/finetuning_torchvision_models_tutorial.html
导包
1 | from __future__ import print_function, division |
加载数据
我们将使用 torchvision 和 torch.utils.data 包来加载数据。
我们今天要解决的问题是训练一个模型来对 蚂蚁
和 蜜蜂
进行分类。我们有大约120个训练图像,每个图像用于 蚂蚁
和 蜜蜂
。每个类有75个验证图像。通常,如果从头开始训练,这是一个非常小的数据集。由于我们正在使用迁移学习,我们应该能够合理地推广。
该数据集是 imagenet 的一个非常小的子集。
注意
从 此处 下载数据并将其解压缩到当前目录。
1 | # Data augmentation and normalization for training |
{'train': 244, 'val': 153}
['ants', 'bees']
可视化图像
1 | def imshow(inp, title=None): |
训练模型的通用函数
函数有以下功能:
- 训练模型
- 调整学习率
- 保存最佳的学习模型
函数中, scheduler
参数是 torch.optim.lr_scheduler
中的 LR scheduler 对象
1 | def train_model(model, criterion, optimizer, scheduler, num_epochs=25): |
可视化模型函数
用于显示少量图像预测的通用功能
1 | def visualize_model(model, num_images=6): |
微调卷积网络
加载预训练模型并重置最终的全连接层
1 | model_ft = models.resnet18(pretrained=True) |
训练和评估
1 | model_ft, val_acc_history = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler) |
Epoch 1/25
----------
train Loss: 0.6349 Acc: 0.7049
val Loss: 0.5844 Acc: 0.7451
...
Epoch 25/25
----------
train Loss: 0.3083 Acc: 0.8484
val Loss: 0.2406 Acc: 0.9085
Training complete in 29m 38s
Best val Acc: 0.9346
1 | visualize_model(model_ft) |
ConvNet 作为固定特征提取器
在这里,我们需要冻结除最后一层之外的所有网络。我们需要设置 requires_grad == False
冻结参数,以便在 backward()
中不计算梯度。
1 | model_conv = models.resnet18(pretrained=True) |
训练和评估
在CPU上,与前一个场景相比,这将花费大约一半的时间。这是预期的,因为不需要为大多数网络计算梯度。但是,前向传递需要计算梯度。
1 | model_conv, val_acc_hietory = train_model(model_conv, criterion, optimizer_conv, |
Epoch 1/10
----------
train Loss: 0.6322 Acc: 0.6148
val Loss: 0.2261 Acc: 0.9346
...
Epoch 10/10
----------
train Loss: 0.4004 Acc: 0.8197
val Loss: 0.1821 Acc: 0.9477
Training complete in 6m 5s
Best val Acc: 0.9477
1 | hist = [h.cpu().numpy() for h in val_acc_hietory] |
1 | visualize_model(model_conv) |