模型微调(finetune)
----接上次的鸟的图像分类,其acc为84%。
这次依然使用此数据集,并用resenet网络进行finetune,然后进行鸟的图像分类。
1、什么是finetune?
利用已训练好的模型进行重构(自己的理解)。 对给定的预训练模型(用数据训练好的模型)进行微调,直接利用预训练模型进行微调可以节省许多的时间,能在比较小的epoch下就达到比较好的效果。通常进行微调,1、自己构建模型效果差,所以采用一些常用的模型,别人用数据训好的。2、数据量不够大,所以采用微调。以下是模型微调的例子:
2、数据为鸟类的数据集,其一共有4个类别,如下所示:
1、数据的类别
2、图像数据
数据的前期处理和划分,划分可以用random.shuffule直接进行打乱,然后划分。
2、数据导入
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as nprandom_seed = 1
torch.manual_seed(random_seed)transform = transforms.Compose([# transforms.RandomRotation(1),transforms.Resize(224), #transforms.CenterCrop(224), #transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5]), # 标准化至[-1, 1],规定均值和标准差])transform1 = transforms.Compose([# transforms.RandomRotation(1),# transforms.Resize(224), ## transforms.CenterCrop(224), # 从图片中间切出224*224的图片transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5]), # 标准化至[-1, 1],规定均值和标准差# transforms.Normalize(mean=[.5], std=[.5]) # 通道数为1#input[channel] = (input[channel] - mean[channel]) / std[channel]
])class DogLoader(torch.utils.data.Dataset):def __init__(self,root,train=True, img_transform = None, target_transform=None,transform = None): # 负责将输入的数据做成list形式# 这样不会调用父类中的init方法,这个是重载了子类的init,并且这里的属性都属于全局的属性self.root = rootself.transform = img_transformself.target_transform = target_transformself.train = trainself.transforms = transformself.data = []self.label = []if self.train:with open('trainImg.txt') as fr:fr = fr.readlines()for imgPath in fr:img = imgPath.split('\t')[0]label = imgPath.split('\t')[1]self.data.append(root+img)self.label.append(int(label.strip()))else:with open('testImg.txt') as fr1:fr1 = fr1.readlines()for imgPath in fr1:img = imgPath.split('\t')[0]label = imgPath.split('\t')[1]self.data.append(root+img)self.label.append(int(label.strip()))def __getitem__(self, index): # 对数据进行编码,然后转换成我们想要的格式img,label = self.data[index],self.label[index]img = Image.open(img).convert('RGB') # 将图片转为RGB图像,为了有一些图像不是RGB的# img = Image.open(img_path).convert('L') # 将RGB的三通道转为一通道的数if self.transforms:img=self.transforms(img) # 传入transforms是PIL数据array = np.asarray(img)img = torch.from_numpy(array)return img, labeldef __len__(self):return len(self.data)train_data = DogLoader('data/',train=True,transform=transform)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True, drop_last=False, num_workers=0)test_data = DogLoader('data/',train=False,transform=transform1) # 测试不进行数据增强,训练继续宁数据增
test_loader = DataLoader(test_data, batch_size=16, shuffle=True, drop_last=False, num_workers=0)# for i, trainData in enumerate(train_loader): # 将一个类的对象可以像list那样调用
# print("第 {} 个Batch \n{}".format(i, trainData))
#
# for i, testData in enumerate(test_loader): # 将一个类的对象可以像list那样调用
# print("第 {} 个Batch \n{}".format(i, testData))
__init__() :类的属性形式,这里主要获取图像的地址和图像的标签。
__getitem__(): 当使用这个函数时,它的实例对象(假设为P)就可以以P[key]形式取值,当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。此处使用__getitem__()函数主要是通过后面训练时for循环得到图像的数据和标签。
2、预训练模型
# -*- coding: utf-8 -*-
from torchvision import models
from torch import nn
from global_config import *def fine_tune_resnet18(): # 这里表示为model_ft = models.resnet18(pretrained=True)'''这里写为True,会自动下载模型的参数,并加载到模型中。当然也可以手动下载模型的参数,然后将模型的参数加载到模型中'''# 把前面的特征进行了拼接print('num_features', model_ft)num_features = model_ft.fc.in_features# fine tune we change original fc layer into classes num of our ownmodel_ft.fc = nn.Linear(num_features, 4)if USE_GPU:model_ft = model_ft.cuda()return model_ftdef fine_tune_vgg16():model_ft = models.vgg16(pretrained=True)print('fine_tune_vgg16() = ',model_ft)num_features = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_features, 4)if USE_GPU:model_ft = model_ft.cuda()return model_ftdef fine_tune_resnet18_():"""'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth','resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth','resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth','resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth','resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',"""model_ft = models.resnet18(pretrained=False)model_ft.load_state_dict(torch.load('resnet18-5c106cde.pth')) # 加载已经下载好的模型参数print('model_ft: ',model_ft)print('resnet18-5c106cde.pth',model_ft.load_state_dict(torch.load('resnet18-5c106cde.pth')))num_features = model_ft.fc.in_features# fine tune we change original fc layer into classes num of our ownmodel_ft.fc = nn.Linear(num_features, 4)if USE_GPU:model_ft = model_ft.cuda()return model_ftdef fine_tune_resnet50():# 实际任务中这个挺重要的resNet50 = models.resnet50(pretrained=True) # 调用的预训练网络ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2) # 自己定义的网络# 读取参数pretrained_dict = resNet50.state_dict() # 读取预训练网络模型的参数model_dict = ResNet50.state_dict() # 读自定义模型的参数# 将pretained_dict里不属于model_dict的键剔除掉pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 剔除一些不同的网络模型参数# 更新现有的model_dictmodel_dict.update(pretrained_dict)# 加载真正需要的state_dictResNet50.load_state_dict(model_dict)
预训练模型的调用,models.resnet18(),表示的是调用resnet18模型,当models.resnet18(pretrained=True)的时候,则表示直接下载了模型的参数。当models.resnet18(pretrained=False)的时候,可以手动下载好模型torch.load('resnet18-5c106cde.pth')。由于不同的数据可能的类别不同所以通常对最后的一层更改。
4、实验结果:
经过2个epoch就可以达到99的准确率。
模型微调(finetune)相关推荐
- pytorch模型微调(Finetune)
Transfer Learning & Model Finetune 模型微调 **Transfer Learning:**机器学习分支,研究源域(source domain)的知识如何应用到 ...
- PyTorch框架学习二十——模型微调(Finetune)
PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...
- Pytorch之模型微调(Finetune)——用Resnet18进行蚂蚁蜜蜂二分类为例
Pytorch之模型微调(Finetune)--手写数字集为例 文章目录 Pytorch之模型微调(Finetune)--手写数字集为例 前言 一.Transfer Learning and Mode ...
- 模型微调迁移学习Finetune方法大全
迁移学习广泛地应用于NLP.CV等各种领域,通过在源域数据上学习知识,再迁移到下游其他目标任务上,提升目标任务上的效果.其中,Pretrain-Finetune(预训练+精调)模式是最为常见的一种迁移 ...
- 大模型微调技术(Adapter-Tuning、Prefix-Tuning、Prompt-Tuning(P-Tuning)、P-Tuning v2、LoRA)
2022年11月30日,ChatGPT发布至今,国内外不断涌现出了不少大模型,呈现"百模大战"的景象,比如ChatGLM-6B.LLAMA.Alpaca等模型及在此模型基础上进一步 ...
- 【pytorch笔记】(五)自定义损失函数、学习率衰减、模型微调
本文目录: 1. 自定义损失函数 2. 动态调整学习率 3. 模型微调-torchvision 3.1 使用已有模型 3.2 训练特定层 1. 自定义损失函数 虽然pytorch提供了许多常用的损失函 ...
- BERT微调finetune笔记
参考: 什么是BERT? - 知乎 (zhihu.com) 词向量之BERT - 知乎 (zhihu.com) BERT 详解 - 知乎 (zhihu.com) 详解Transformer (Atte ...
- 最新ChatGPT GPT-4 NLU应用之实体分类识别与模型微调(附ipynb与python源码及视频)——开源DataWhale发布入门ChatGPT技术新手从0到1必备使用指南手册(六)
目录 前言 最新ChatGPT GPT-4 自然语言理解NLU实战之实体分类识别与模型微调 主题分类 精准分类解决手段 模型微调步骤 核心代码 其它NLU应用及实战 相关文献 参考资料 其它资料下载 ...
- 在 Amazon SageMaker 上玩转 Stable Diffusion: 基于 Dreambooth 的模型微调
本文将以 Stable Diffusion Quick Kit 为例,详细讲解如何利用 Dreambooth 对 Stable Diffusion 模型进行微调,包括基础的 Stable Diffus ...
最新文章
- call、apply、bind
- GNU ARM汇编--(二)汇编编译链接与运行
- mysql没法安装_mysql没法使用、没法启动服务的解决方法
- HDU 1054 Strategic Game 最小点覆盖
- 区块链基础知识系列 第三课 区块链中的默克尔树
- 详细分析JVM内存模型
- spring ref historydesign philosophy
- SQLServer DBA 三十问(加强版)
- SimpleITK使用深度学习识别肺癌CT DICOM数据集
- stucts2 页面上的值如何与Action的属性值对应
- Spring Web MVC 的工作流程
- varchar(10)与nvarchar(10)有什么区别
- Power BI DAX 之日期函数
- LCD1602芯片的使用——简单易懂
- 小学初中数据常用定理公式总结-------复习一下
- 基于微信小程序+JavaWeb+SSM开发的图书借阅小程序
- Parcel打包React
- 全面了解风控策略体系
- SVN怎么去掉版本控制,去除调svn绿色图标显示
- 迭代开发中的微服务拆分