非完全原创,但遗失了原文链接,看到可更改

迁移学习

一、迁移学习方法介绍

1. 微调网络的方法

微调网络的方法实现迁移学习,更改最后一层全连接,并且微调训练网络

2. 将模型看成特征提取器

模型看成特征提取器,如果一个模型的预训练模型非常的好,那完全就把前面的层看成特征提取器冻结所有层并且更改最后一层只训练最后一层,这样我们只训练了最后一层,训练会非常的快速

两者区别:1. 前者是会在更改网络之后微调训练全部的网络参数

  1. 后者是将前面的层冻结,只训练最后一层,训练速度会很快

二、迁移基本步骤

  1. 数据的准备

  2. 选择数据增广的方式 : 扩展数据集

  3. 选择合适的模型

  4. 更换最后一层全连接

  5. 冻结层,开始训练

  6. 选择预测结果最好的模型保存

是先冻结部分层,此时已经将这些层的权重传入了,然后再修改网络的结构,最后是训练修改了的结构的权重

三、简单迁移学习案例

3.1 数据准备

# 解压数据到指定文件
def unzip(filename, dst_dir):z = zipfile.ZipFile(filename)z.extractall(dst_dir)
unzip('./data/hymenoptera_data.zip', './data/')
# 实现自己的Dataset方法,主要实现两个方法__len__和__getitem__
class MyDataset(Dataset):def __init__(self, dirname, transform=None):super(MyDataset, self).__init__()self.classes = os.listdir(dirname)self.images = []self.transform = transformfor i, classes in enumerate(self.classes):classes_path = os.path.join(dirname, classes)for image_name in os.listdir(classes_path):self.images.append((os.path.join(classes_path, image_name), i))def __len__(self):return len(self.images)def __getitem__(self, idx):image_name, classes = self.images[idx]image = Image.open(image_name)if self.transform:image = self.transform(image)return image, classesdef get_claesses(self):return self.classes
# 分布实现训练和预测的transform
train_transform = transforms.Compose([transforms.Grayscale(3),transforms.RandomResizedCrop(224), #随机裁剪一个area然后再resizetransforms.RandomHorizontalFlip(), #随机水平翻转transforms.Resize(size=(256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([transforms.Grayscale(3),transforms.Resize(size=(256, 256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 分别实现loader
train_dataset = MyDataset('./data/hymenoptera_data/train/', train_transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_dataset = MyDataset('./data/hymenoptera_data/val/', val_transform)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=32)

3.2 选择预训练的模型

这里我们选择了resnet18在ImageNet 1000类上进行了预训练的

model = models.resnet18(pretrained=True) # 使用预训练

使用model.buffers查看网络基本结构

然后修改网络结构

from torchvision.models.resnet import *
def get_net():model = resnet18(pretrained=True)model.avgpool = nn.AdaptiveAvgPool2d((1, 1))model.fc = nn.Sequential(nn.BatchNorm1d(512*1),nn.Linear(512*1, 你的分类类别数),)return model

代码简单解读一下:

首先,通过torchvision导入相关的函数

通过resnet18( )实例化一个模型,并使用imagenet预训练权重

将平均池化修改为自适应全局平均池化,避免输入特征尺寸不匹配

修改全连接层,主要是修改分类类别数,并加入BN1d

这样子,不仅可以根据自己的需求改造网络,还能最大限度的使用现成的预训练权重。需要注意的是,这里的nn.BatchNorm1d(512*1)是很必要的,初学者可以尝试删除这个部件感受一下区别。

总结两点:

1. 由于pytorch框架是先创建网络中会使用到的模块,然后一个个拼接(在forward方法中),init方法是我们定义的一个个模块(包括Conv2、pool池化等),所以当我们创建一个对象时,还能够通过修改对象里面的属性来修改网络的结构,然后在实际训练的时候调用forward()方法,此时用的已经是修改过的网络了。

2. 相对于tensorflow,tf是将网络结构直接定义死的,不容易在现成的网络结构中进行修改。

替换最后一层网络结构

only_train_fc = True
if only_train_fc:for param in model.parameters():param.requires_grad_(False)
fc_in_features = model.fc.in_features
model.fc = torch.nn.Linear(fc_in_features, 2, bias=True)

注释:only_train_fc如果我们设置为True那么就只训练最后的fc层 现在观察一下可导的参数有那些(在只训练最后一层的情况下)

for i in model.parameters():if i.requires_grad:print(i)

3.3 训练主体实现

epochs = 50
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(lr=0.01, params=model.parameters())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
model.to(device)
opt_step = torch.optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.1)
max_acc = 0
epoch_acc = []
epoch_loss = []
for epoch in range(epochs):for type_id, loader in enumerate([train_loader, val_loader]):mean_loss = []mean_acc = []for images, labels in loader:if type_id == 0:# opt_step.step()model.train()else:model.eval()images = images.to(device)labels = labels.to(device).long()opt.zero_grad()with torch.set_grad_enabled(type_id==0):outputs = model(images)_, pre_labels = torch.max(outputs, 1)loss = loss_fn(outputs, labels)if type_id == 0:loss.backward()opt.step()acc = torch.sum(pre_labels==labels) / torch.tensor(labels.shape[0], dtype=torch.float32)        mean_loss.append(loss.cpu().detach().numpy())mean_acc.append(acc.cpu().detach().numpy())if type_id == 1:epoch_acc.append(np.mean(mean_acc))epoch_loss.append(np.mean(mean_loss))if max_acc < np.mean(mean_acc):max_acc = np.mean(mean_acc)print(type_id, np.mean(mean_loss),np.mean(mean_acc))
print(max_acc)

四、总结

使用了预训练模型,发现大概10个epoch就可以很快的得到较好的结果了,即使在使用cpu情况下训练,这也是迁移学习为什么这么受欢迎的原因之一了,如果读者有兴趣可以自己试一试在不冻结层的情况下,使用方法一能否得到更好的结果

迁移学习一、基本使用相关推荐

  1. PyTorch迁移学习

    PyTorch迁移学习 实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集].通常的做法是在一个很大的数据集上进 ...

  2. VGG16迁移学习实现

    VGG16迁移学习实现 本文讨论迁移学习,它是一个非常强大的深度学习技术,在不同领域有很多应用.动机很简单,可以打个比方来解释.假设想学习一种新的语言,比如西班牙语,那么从已经掌握的另一种语言(比如英 ...

  3. 优达学城《DeepLearning》2-2:迁移学习

    目录 加载和预处理数据 转换数据 数据加载器和数据可视化 定义模型 最终分类器层 指定损失函数和优化器 训练 测试 可视化样本测试结果 大多数时候,你不会想自己训练一个完整的卷积网络.像ImageNe ...

  4. 【图像分类案例】(2) DenseNet 天气图片四分类(权重迁移学习),附Tensorflow完整代码

    各位同学好,今天和大家分享一下使用 Tensorflow 构建 DenseNet 卷积神经网络模型,并使用预训练模型的权重,完成对四种天气图片的分类. 完整代码在我的 Gitee 中,有需要的自取: ...

  5. 【神经网络】(7) 迁移学习(CNN-MobileNetV2),案例:乳腺癌二分类

    各位同学好,今天和大家分享一下Tensorflow2.0中如何使用迁移学习的方法构造神经网络.需要数据集的在评论区留个言. 1. 迁移学习 官方文档:Module: tf.keras.applicat ...

  6. 一、迁移学习与fine-tuning有什么区别?

    一.迁移学习 举个例子,假设boss让你做一下目标检测,这个数据集是关于光纤箱喷码字符检测的.问题是,数据集很少(只有1000张数据),里面有多干扰的信息,你发现从零训练开始训练yolo的效果很差,很 ...

  7. pytorch与keras_Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者

    pytorch与keras by Patryk Miziuła 通过PatrykMiziuła Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者 (Keras vs PyTorch ...

  8. 读懂深度迁移学习,看这文就够了 | 赠书

    百度前首席科学家.斯坦福大学副教授吴恩达(Andrew Ng)曾经说过:迁移学习将是继监督学习之后的下一个促使机器学习成功商业化的驱动力. 本文选自<深度学习500问:AI工程师面试宝典> ...

  9. 杂谈 | 当前知识蒸馏与迁移学习有哪些可用的开源工具?

    所有参与投票的 CSDN 用户都参加抽奖活动 群内公布奖项,还有更多福利赠送 作者&编辑 | 言有三 来源 | 有三AI(ID:yanyousan_ai) [导读]知识蒸馏与迁移学习不仅仅属于 ...

  10. 迁移学习前沿研究亟需新鲜血液,深度学习理论不能掉链子

    作者 | Frederico Guth,Teófilo Emidio de Campos 编译 | 夕颜 出品 | AI科技大本营(ID:rgznai100) [导读]人类可以从很少的样本中学习,显示 ...

最新文章

  1. pyqt setStyleSheet用法
  2. linux图机界面机制
  3. 利用RTL2832u电视棒芯片追踪民航飞机轨迹
  4. [html] 网页上的验证码是为了解决什么问题?说说你了解的验证码种类有哪些
  5. Windows和linux提权方法,Windows与Linux本地用户提权体验(一)
  6. jdbctemplate无where条件查询_多表查询
  7. opencv_判断两张图片是否相同
  8. mysql数据库引擎问题
  9. 中国医学史(第三章 中医药理论体系的初步形成)
  10. Java锁原理与应用
  11. excel中相对引用、绝对引用、混合引用
  12. 论文解读 | CenterNet:Keypoint Triplets for Object Detection
  13. 用Python批量下载视频
  14. 非法经营?USDT涉刑分析
  15. ubuntu top命令详解
  16. 1004: 惠民工程
  17. 关于正则表达式里含有空格的问题
  18. ide编辑器 android,从 IDE 到终端 + 文本编辑器
  19. python 建筑结构设计_新手进入建筑设计院做结构设计,主要看哪些书籍?
  20. 看视频学编程的一点小建议

热门文章

  1. Win10添加简体中文美式键盘的方法
  2. 百度搜索引擎高级指令使用及实例
  3. 计算机主机光驱弹不出来怎么办,台式机光驱弹不出来怎么办
  4. AI绘画火爆,到现在还只是冰山一角?AIGC掀起当代新艺术浪潮
  5. Java 获取 n个 工作日【前】或【后】的日期
  6. 软件定义广域网和即将到来的网络洪流
  7. 个人 OKR目标如何实际伤害你的团队
  8. java 微博客户端_[置顶] java新浪微博客户端
  9. 计算机四级英语翻译,全国英语四级考试翻译特训题
  10. python coding style_python coding style guide 的高速落地实践