1. 新数据集和原始数据集合类似,那么直接可以微调一个最后的FC层或者重新指定一个新的分类器
  2. 新数据集比较小和原始数据集合差异性比较大,那么可以使用从模型的中部开始训练,只对最后几层进行fine-tuning
  3. 新数据集比较小和原始数据集合差异性比较大,如果上面方法还是不行的化那么最好是重新训练,只将预训练的模型作为一个新模型初始化的数据
  4. 新数据集的大小一定要与原始数据集相同,比如CNN中输入的图片大小一定要相同,才不会报错
  5. 如果数据集大小不同的话,可以在最后的fc层之前添加卷积或者pool层,使得最后的输出与fc层一致,但这样会导致准确度大幅下降,所以不建议这样做
  6. 对于不同的层可以设置不同的学习率,一般情况下建议,对于使用的原始数据做初始化的层设置的学习率要小于(一般可设置小于10倍)初始化的学习率,这样保证对于已经初始化的数据不会扭曲的过快,而使用初始化学习率的新层可以快速的收敛。
%matplotlib inline
import torch,os,torchvision
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
torch.__version__

这里面我们使用官方训练好的resnet50来参加kaggle上面的 dog breed 狗的种类识别来做一个简单微调实例。

首先我们需要下载官方的数据解压,只要保持数据的目录结构即可,这里指定一下目录的位置,并且看下内容

DATA_ROOT = '/home/huangshaobo/dataset/dog'
all_labels_df = pd.read_csv(os.path.join(DATA_ROOT,'labels.csv'))
all_labels_df.head()


获取狗的分类,根据分类进行编号。这里定义了两个字典,分别以名字和id作为对应,方便后面处理:

breeds = all_labels_df.breed.unique()
breed2idx = dict((breed,idx) for idx,breed in enumerate(breeds))
idx2breed = dict((idx,breed) for idx,breed in enumerate(breeds))
len(breeds)


添加到列表中:

all_labels_df['label_idx'] = [breed2idx[b] for b in all_labels_df.breed]
all_labels_df.head()


由于我们的数据集不是官方指定的格式,我们自己定义一个数据集:

class DogDataset(Dataset):def __init__(self, labels_df, img_path, transform=None):self.labels_df = labels_dfself.img_path = img_pathself.transform = transformdef __len__(self):return self.labels_df.shape[0]def __getitem__(self, idx):image_name = os.path.join(self.img_path, self.labels_df.id[idx]) + '.jpg'img = Image.open(image_name)label = self.labels_df.label_idx[idx]if self.transform:img = self.transform(img)return img, label
# 定义一些超参数:
IMG_SIZE = 224 # resnet50的输入是224的所以需要将图片统一大小
BATCH_SIZE= 256 #这个批次大小需要占用4.6-5g的显存,如果不够的化可以改下批次,如果内存超过10G可以改为512
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]
CUDA=torch.cuda.is_available()
DEVICE = torch.device("cuda" if CUDA else "cpu")# 定义训练和验证数据的图片变换规则:
train_transforms = transforms.Compose([transforms.Resize(IMG_SIZE),transforms.RandomResizedCrop(IMG_SIZE),transforms.RandomHorizontalFlip(),transforms.RandomRotation(30),transforms.ToTensor(),transforms.Normalize(IMG_MEAN, IMG_STD)
])val_transforms = transforms.Compose([transforms.Resize(IMG_SIZE),transforms.CenterCrop(IMG_SIZE),transforms.ToTensor(),transforms.Normalize(IMG_MEAN, IMG_STD)
])

我们这里只分割10%的数据作为训练时的验证数据:

dataset_names = ['train', 'valid']
stratified_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=0)
train_split_idx, val_split_idx = next(iter(stratified_split.split(all_labels_df.id, all_labels_df.breed)))
train_df = all_labels_df.iloc[train_split_idx].reset_index()
val_df = all_labels_df.iloc[val_split_idx].reset_index()
print(len(train_df))
print(len(val_df))


使用官方的dataloader载入数据:

image_transforms = {'train':train_transforms, 'valid':val_transforms}train_dataset = DogDataset(train_df, os.path.join(DATA_ROOT,'train'), transform=image_transforms['train'])
val_dataset = DogDataset(val_df, os.path.join(DATA_ROOT,'train'), transform=image_transforms['valid'])
image_dataset = {'train':train_dataset, 'valid':val_dataset}image_dataloader = {x:DataLoader(image_dataset[x],batch_size=BATCH_SIZE,shuffle=True,num_workers=0) for x in dataset_names}
dataset_sizes = {x:len(image_dataset[x]) for x in dataset_names}
# 开始配置网络,由于ImageNet是识别1000个物体,我们的狗的分类一共只有120,
# 所以需要对模型的最后一层全连接层进行微调,将输出从1000改为120:model_ft = models.resnet50(pretrained=True) # 这里自动下载官方的预训练模型,并且
# 将所有的参数层进行冻结
for param in model_ft.parameters():param.requires_grad = False
# 这里打印下全连接层的信息
print(model_ft.fc)
num_fc_ftr = model_ft.fc.in_features #获取到fc层的输入
model_ft.fc = nn.Linear(num_fc_ftr, len(breeds)) # 定义一个新的FC层
model_ft=model_ft.to(DEVICE)# 放到设备中
print(model_ft) # 最后再打印一下新的模型
# 设置训练参数:criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam([{'params':model_ft.fc.parameters()}
], lr=0.001)#指定 新加的fc层的学习率

定义训练函数:

def train(model,device, train_loader, epoch):model.train()for batch_idx, data in enumerate(train_loader):x,y= datax=x.to(device)y=y.to(device)optimizer.zero_grad()y_hat= model(x)loss = criterion(y_hat, y)loss.backward()optimizer.step()print ('Train Epoch: {}\t Loss: {:.6f}'.format(epoch,loss.item()))

定义测试函数:

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for i,data in enumerate(test_loader):          x,y= datax=x.to(device)y=y.to(device)optimizer.zero_grad()y_hat = model(x)test_loss += criterion(y_hat, y).item() # sum up batch losspred = y_hat.max(1, keepdim=True)[1] # get the index of the max log-probabilitycorrect += pred.eq(y.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(val_dataset),100. * correct / len(val_dataset)))

训练9次,看看效果:

for epoch in range(1, 10):%time train(model=model_ft,device=DEVICE, train_loader=image_dataloader["train"],epoch=epoch)test(model=model_ft, device=DEVICE, test_loader=image_dataloader["valid"])


我们看到只训练了9次就达到了80%的准确率,效果还是可以的。

但是每次训练都需要将一张图片在全部网络中进行计算,而且计算的结果每次都是一样的,这样浪费了很多计算的资源。 下面我们就将这些不进行反向传播或者说不更新网络权重参数层的计算结果保存下来, 这样我们以后使用的时候就可以直接将这些结果输入到FC层或者以这些结果构建新的网络层, 省去了计算的时间,并且这样如果只训练全连接层,CPU就可以完成了。

PyTorch论坛中说到可以使用自己手动实现模型中的forward参数,这样看起来是很简便的,但是这样处理起来很麻烦,不建议这样使用。

这里我们就要采用PyTorch比较高级的API,hook来处理了,我们要先定义一个hook函数

in_list= [] # 这里存放所有的输出
def hook(module, input, output):#input是一个tuple代表顺序代表每一个输入项,我们这里只有一项,所以直接获取#需要全部的参数信息可以使用这个打印#for val in input:#    print("input val:",val)for i in range(input[0].size(0)):in_list.append(input[0][i].cpu().numpy())
# 在相应的层注册hook函数,保证函数能够正常工作,我们这里直接hook 全连接层前面的pool层,获取pool层的输入数据,这样会获得更多的特征:model_ft.avgpool.register_forward_hook(hook)
# 开始获取输出,这里我们因为不需要反向传播,所以直接可以使用no_grad嵌套:%%time
with torch.no_grad():for batch_idx, data in enumerate(image_dataloader["train"]):x,y= datax=x.to(DEVICE)y=y.to(DEVICE)y_hat = model_ft(x)

pytorch 模型微调相关推荐

  1. pytorch模型微调(Finetune)

    Transfer Learning & Model Finetune 模型微调 **Transfer Learning:**机器学习分支,研究源域(source domain)的知识如何应用到 ...

  2. PyTorch框架学习二十——模型微调(Finetune)

    PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...

  3. PyTorch 1.0 中文官方教程:Torchvision 模型微调

    译者:ZHHAYO 作者: Nathan Inkawhich 在本教程中,我们将深入探讨如何微调和特征提取torchvision 模型,所有这些模型都已经预先在1000类的magenet数据集上训练完 ...

  4. 【pytorch笔记】(五)自定义损失函数、学习率衰减、模型微调

    本文目录: 1. 自定义损失函数 2. 动态调整学习率 3. 模型微调-torchvision 3.1 使用已有模型 3.2 训练特定层 1. 自定义损失函数 虽然pytorch提供了许多常用的损失函 ...

  5. Hugging Face实战(NLP实战/Transformer实战/预训练模型/分词器/模型微调/模型自动选择/PyTorch版本/代码逐行解析)下篇之模型训练

    模型训练的流程代码是不是特别特别多啊?有的童鞋看过Bert那个源码写的特别特别详细,参数贼多,运行一个模型百八十个参数的. Transformer对NLP的理解是一个大道至简的感觉,Hugging F ...

  6. Pytorch之模型微调(Finetune)——用Resnet18进行蚂蚁蜜蜂二分类为例

    Pytorch之模型微调(Finetune)--手写数字集为例 文章目录 Pytorch之模型微调(Finetune)--手写数字集为例 前言 一.Transfer Learning and Mode ...

  7. TensorFlow与PyTorch模型部署性能比较

    TensorFlow与PyTorch模型部署性能比较 前言 2022了,选 PyTorch 还是 TensorFlow?之前有一种说法:TensorFlow 适合业界,PyTorch 适合学界.这种说 ...

  8. 在 Amazon SageMaker 上玩转 Stable Diffusion: 基于 Dreambooth 的模型微调

    本文将以 Stable Diffusion Quick Kit 为例,详细讲解如何利用 Dreambooth 对 Stable Diffusion 模型进行微调,包括基础的 Stable Diffus ...

  9. 自然语言处理NLP星空智能对话机器人系列:深入理解Transformer自然语言处理 基于BERT模型微调实现句子分类

    自然语言处理NLP星空智能对话机器人系列:深入理解Transformer自然语言处理 基于BERT模型微调实现句子分类 目录 基于BERT模型微调实现句子分类案例实战 Installing the H ...

  10. 模型微调(finetune)

    ----接上次的鸟的图像分类,其acc为84%. 这次依然使用此数据集,并用resenet网络进行finetune,然后进行鸟的图像分类. 1.什么是finetune? 利用已训练好的模型进行重构(自 ...

最新文章

  1. [翻译] AKKA笔记- ACTORSYSTEM (配置CONFIGURATION 与调度SCHEDULING) - 4(二)
  2. 用实例证明dll中new的内存不能在exe中释放
  3. 如何看待雅虎套现760亿美元从阿里巴巴退出?
  4. mysql工厂模式_设计模式-三种工厂模式实例
  5. oracle内置函数 wmsys.wm_concat使用
  6. 如何设置mysql的运行目录_如何修改mysql数据库文件的路径 | 学步园
  7. 北大计算机最好的班叫什么,中国大学计算机最好的班,再次迎来“图灵奖”导师,赶超“姚班”...
  8. 对话CDN巨头Akamai:携手金山云,意欲何为?
  9. 如何自动生成SpringBoot项目代码
  10. Unity3d 好友管理系统
  11. PC蛋蛋 按键精灵手机助手 安装+拉代码(视频)
  12. 这一篇彻底说清楚了!乐高,编程,机器人到底要不要学?
  13. 清明节出行客流 人山人海
  14. Linux下C语言开发
  15. 智能家居 “孤岛”:群雄并起 标准混战
  16. 【钉钉-场景化能力包】制造业考勤数据多维分析
  17. 面试官:知道你的接口QPS是多少么?
  18. 职业生涯规划之自我探索论文
  19. 【python 程序题:火车票购买程序】
  20. 免费天气预报查询 API、历史天气查询 API 接口使用示例【源码可用】

热门文章

  1. crx插件转换火狐插件_我的Firefox插件
  2. Nginx中传输带宽限制
  3. 如何快速成为CSDN的博客专家,以及「博客专家」申请及审核执行标准
  4. linux是基于什么的开源操作系统,什么是开源操作系统
  5. php写phalapi,PhalApi框架
  6. Web前端鼠标变小手CSS和JS(Vue)两种实现
  7. XP的定时关机命令?
  8. 经济学论文素材之美国浮动汇率制度
  9. 入侵检测系统的原理与应用
  10. 会让你变得与众不同的22个技巧