----接上次的鸟的图像分类,其acc为84%。

这次依然使用此数据集,并用resenet网络进行finetune,然后进行鸟的图像分类。

1、什么是finetune?

利用已训练好的模型进行重构(自己的理解)。 对给定的预训练模型(用数据训练好的模型)进行微调,直接利用预训练模型进行微调可以节省许多的时间,能在比较小的epoch下就达到比较好的效果。通常进行微调,1、自己构建模型效果差,所以采用一些常用的模型,别人用数据训好的。2、数据量不够大,所以采用微调。以下是模型微调的例子:

2、数据为鸟类的数据集,其一共有4个类别,如下所示:

 1、数据的类别

2、图像数据

数据的前期处理和划分,划分可以用random.shuffule直接进行打乱,然后划分。

2、数据导入

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as nprandom_seed = 1
torch.manual_seed(random_seed)transform = transforms.Compose([# transforms.RandomRotation(1),transforms.Resize(224), #transforms.CenterCrop(224), #transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5]),      # 标准化至[-1, 1],规定均值和标准差])transform1 = transforms.Compose([# transforms.RandomRotation(1),# transforms.Resize(224), ## transforms.CenterCrop(224), # 从图片中间切出224*224的图片transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5]), # 标准化至[-1, 1],规定均值和标准差# transforms.Normalize(mean=[.5], std=[.5])   # 通道数为1#input[channel] = (input[channel] - mean[channel]) / std[channel]
])class DogLoader(torch.utils.data.Dataset):def __init__(self,root,train=True, img_transform = None, target_transform=None,transform = None):  # 负责将输入的数据做成list形式# 这样不会调用父类中的init方法,这个是重载了子类的init,并且这里的属性都属于全局的属性self.root = rootself.transform = img_transformself.target_transform = target_transformself.train = trainself.transforms = transformself.data = []self.label = []if self.train:with open('trainImg.txt') as fr:fr = fr.readlines()for imgPath in fr:img = imgPath.split('\t')[0]label = imgPath.split('\t')[1]self.data.append(root+img)self.label.append(int(label.strip()))else:with open('testImg.txt') as fr1:fr1 = fr1.readlines()for imgPath in fr1:img = imgPath.split('\t')[0]label = imgPath.split('\t')[1]self.data.append(root+img)self.label.append(int(label.strip()))def __getitem__(self, index):   # 对数据进行编码,然后转换成我们想要的格式img,label = self.data[index],self.label[index]img = Image.open(img).convert('RGB')   # 将图片转为RGB图像,为了有一些图像不是RGB的# img = Image.open(img_path).convert('L')    # 将RGB的三通道转为一通道的数if self.transforms:img=self.transforms(img)     # 传入transforms是PIL数据array = np.asarray(img)img = torch.from_numpy(array)return img, labeldef __len__(self):return len(self.data)train_data = DogLoader('data/',train=True,transform=transform)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True, drop_last=False, num_workers=0)test_data = DogLoader('data/',train=False,transform=transform1)   # 测试不进行数据增强,训练继续宁数据增
test_loader = DataLoader(test_data, batch_size=16, shuffle=True, drop_last=False, num_workers=0)# for i, trainData in enumerate(train_loader):  # 将一个类的对象可以像list那样调用
#     print("第 {} 个Batch \n{}".format(i, trainData))
#
# for i, testData in enumerate(test_loader):  # 将一个类的对象可以像list那样调用
#     print("第 {} 个Batch \n{}".format(i, testData))

__init__() :类的属性形式,这里主要获取图像的地址和图像的标签。

__getitem__():  当使用这个函数时,它的实例对象(假设为P)就可以以P[key]形式取值,当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。此处使用__getitem__()函数主要是通过后面训练时for循环得到图像的数据和标签。

2、预训练模型

# -*- coding: utf-8 -*-
from torchvision import models
from torch import nn
from global_config import *def fine_tune_resnet18(): # 这里表示为model_ft = models.resnet18(pretrained=True)'''这里写为True,会自动下载模型的参数,并加载到模型中。当然也可以手动下载模型的参数,然后将模型的参数加载到模型中'''# 把前面的特征进行了拼接print('num_features', model_ft)num_features = model_ft.fc.in_features# fine tune we change original fc layer into classes num of our ownmodel_ft.fc = nn.Linear(num_features, 4)if USE_GPU:model_ft = model_ft.cuda()return model_ftdef fine_tune_vgg16():model_ft = models.vgg16(pretrained=True)print('fine_tune_vgg16() = ',model_ft)num_features = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_features, 4)if USE_GPU:model_ft = model_ft.cuda()return model_ftdef fine_tune_resnet18_():"""'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth','resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth','resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth','resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth','resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',"""model_ft = models.resnet18(pretrained=False)model_ft.load_state_dict(torch.load('resnet18-5c106cde.pth'))   # 加载已经下载好的模型参数print('model_ft: ',model_ft)print('resnet18-5c106cde.pth',model_ft.load_state_dict(torch.load('resnet18-5c106cde.pth')))num_features = model_ft.fc.in_features# fine tune we change original fc layer into classes num of our ownmodel_ft.fc = nn.Linear(num_features, 4)if USE_GPU:model_ft = model_ft.cuda()return model_ftdef fine_tune_resnet50():# 实际任务中这个挺重要的resNet50 = models.resnet50(pretrained=True) # 调用的预训练网络ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)   # 自己定义的网络# 读取参数pretrained_dict = resNet50.state_dict()   # 读取预训练网络模型的参数model_dict = ResNet50.state_dict()   # 读自定义模型的参数# 将pretained_dict里不属于model_dict的键剔除掉pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}   # 剔除一些不同的网络模型参数# 更新现有的model_dictmodel_dict.update(pretrained_dict)# 加载真正需要的state_dictResNet50.load_state_dict(model_dict)

预训练模型的调用,models.resnet18(),表示的是调用resnet18模型,当models.resnet18(pretrained=True)的时候,则表示直接下载了模型的参数。当models.resnet18(pretrained=False)的时候,可以手动下载好模型torch.load('resnet18-5c106cde.pth')。由于不同的数据可能的类别不同所以通常对最后的一层更改。

4、实验结果:

经过2个epoch就可以达到99的准确率。

模型微调(finetune)相关推荐

  1. pytorch模型微调(Finetune)

    Transfer Learning & Model Finetune 模型微调 **Transfer Learning:**机器学习分支,研究源域(source domain)的知识如何应用到 ...

  2. PyTorch框架学习二十——模型微调(Finetune)

    PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...

  3. Pytorch之模型微调(Finetune)——用Resnet18进行蚂蚁蜜蜂二分类为例

    Pytorch之模型微调(Finetune)--手写数字集为例 文章目录 Pytorch之模型微调(Finetune)--手写数字集为例 前言 一.Transfer Learning and Mode ...

  4. 模型微调迁移学习Finetune方法大全

    迁移学习广泛地应用于NLP.CV等各种领域,通过在源域数据上学习知识,再迁移到下游其他目标任务上,提升目标任务上的效果.其中,Pretrain-Finetune(预训练+精调)模式是最为常见的一种迁移 ...

  5. 大模型微调技术(Adapter-Tuning、Prefix-Tuning、Prompt-Tuning(P-Tuning)、P-Tuning v2、LoRA)

    2022年11月30日,ChatGPT发布至今,国内外不断涌现出了不少大模型,呈现"百模大战"的景象,比如ChatGLM-6B.LLAMA.Alpaca等模型及在此模型基础上进一步 ...

  6. 【pytorch笔记】(五)自定义损失函数、学习率衰减、模型微调

    本文目录: 1. 自定义损失函数 2. 动态调整学习率 3. 模型微调-torchvision 3.1 使用已有模型 3.2 训练特定层 1. 自定义损失函数 虽然pytorch提供了许多常用的损失函 ...

  7. BERT微调finetune笔记

    参考: 什么是BERT? - 知乎 (zhihu.com) 词向量之BERT - 知乎 (zhihu.com) BERT 详解 - 知乎 (zhihu.com) 详解Transformer (Atte ...

  8. 最新ChatGPT GPT-4 NLU应用之实体分类识别与模型微调(附ipynb与python源码及视频)——开源DataWhale发布入门ChatGPT技术新手从0到1必备使用指南手册(六)

    目录 前言 最新ChatGPT GPT-4 自然语言理解NLU实战之实体分类识别与模型微调 主题分类 精准分类解决手段 模型微调步骤 核心代码 其它NLU应用及实战 相关文献 参考资料 其它资料下载 ...

  9. 在 Amazon SageMaker 上玩转 Stable Diffusion: 基于 Dreambooth 的模型微调

    本文将以 Stable Diffusion Quick Kit 为例,详细讲解如何利用 Dreambooth 对 Stable Diffusion 模型进行微调,包括基础的 Stable Diffus ...

最新文章

  1. call、apply、bind
  2. GNU ARM汇编--(二)汇编编译链接与运行
  3. mysql没法安装_mysql没法使用、没法启动服务的解决方法
  4. HDU 1054 Strategic Game 最小点覆盖
  5. 区块链基础知识系列 第三课 区块链中的默克尔树
  6. 详细分析JVM内存模型
  7. spring ref historydesign philosophy
  8. SQLServer DBA 三十问(加强版)
  9. SimpleITK使用深度学习识别肺癌CT DICOM数据集
  10. stucts2 页面上的值如何与Action的属性值对应
  11. Spring Web MVC 的工作流程
  12. varchar(10)与nvarchar(10)有什么区别
  13. Power BI DAX 之日期函数
  14. LCD1602芯片的使用——简单易懂
  15. 小学初中数据常用定理公式总结-------复习一下
  16. 基于微信小程序+JavaWeb+SSM开发的图书借阅小程序
  17. Parcel打包React
  18. 全面了解风控策略体系
  19. SVN怎么去掉版本控制,去除调svn绿色图标显示
  20. 迭代开发中的微服务拆分

热门文章

  1. 浅谈我对python学习的心得
  2. 信息论 | 计算离散信源的信息量和熵的MATLAB实现(函数封装调用)
  3. 老旧笔记本安装(升级)黑群晖7.1
  4. stricmp linux 头文件,Windows下程序向Linux下移植细节
  5. 股市风云:价值成长投资 稳健赢利之道
  6. 您为什么要加入CSDN个人空间
  7. 农村将迎来重大爆发!传统农业链条正在重塑,关键一步已经迈出
  8. 锚定物决定成败?四国央行数字货币对比
  9. BroadCastReceiver 简介
  10. 水晶报表(Crystal Report)- 水晶报表常见问题总结