#实践中,受限于数据集规模的约束,我们很少从头开始端到端的训练一个神经网络。通常情况下,
# 我们会选择在ImageNet数据集上预训练好的网络模型上进行适当的修改,使其适用于目标数据集。#首先,修改网络模型的最后一个全连接层,使其适应于目标数据集,
# 使用预训练的网络权重来初始化网络模型的权重,用自己的图像数据来微调训练网络。微调网络主要有以下两种做法:#1.只训练最后一个全连接层,冻结除最后一个全连接层外的所有层的权重。
#2.所有网络层都参与训练,不过最后一个全连接层在训练时使用更大的学习率,通常最后一个全连接层的学习率是前面层学习率的10倍。#下面基于迁移学习实现一个ResNet18来对蜜蜂和蚂蚁分类,点击这里下载数据集。蚂蚁和蜜蜂大约均有120幅训练图像。每个类别有75幅验证图像。from __future__ import print_function, divisionimport torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import time
import os
import copy# 是否使用gpu运算
use_gpu = torch.cuda.is_available()
# 数据预处理,Pytorch提供了一个数据预处理的操作对象。定义如下:
data_transforms = {'train': transforms.Compose([# 随机在图像上裁剪出224*224大小的图像transforms.RandomResizedCrop(224),# 将图像随机翻转transforms.RandomHorizontalFlip(),# 将图像数据,转换为网络训练所需的tensor向量transforms.ToTensor(),# 图像归一化处理# 个人理解,前面是3个通道的均值,后面是3个通道的方差transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 读取数据
# 这种数据读取方法,需要有train和val两个文件夹,
# 每个文件夹下一类图像存在一个文件夹下
#在对分类的数据进行处理的时候,可以使用Pytorch提供的ImageFolder类来实现数据预处理。
#首先需要定义数据集的根目录:
data_dir = '../data/hymenoptera_data'
#然后,对于train和val这两个分别使用ImageFolder处理.这时,ImageFolder已经完成了照片数据的分类,并将这些图片的分类信息放倒了image_datasets变量中,
#可以看到,ImageFolder类已经将ants,bees做好了分类,并赋值为0和1。并且,训练数据以及测试数据被很好的分开。
#data_transforms对象在ImageFolder进行数据处理的时候作为参数传入,可以将上面数据处理的代码改为如下形式:
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}#有了ImageFolder获取到的image_datasets,这里只是找到了数据的路径以及相对应的类别,
# Pytorch还提供了DataLoader类,用于在训练时,实时获取数据对应的训练数据。代码如下:
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
#DataLoader的第一个参数为上面获取到的image_datasets,第二个参数为batch_size,
#表示的是批训练时每批样本的数量。参数shuffle表示的是是否打乱数据的顺序,True表示打乱。参数num_workers表示参与计算的CPU核心数。# 读取数据集大小 train:244,val:153
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# 数据类别 ['ants','bees']
class_names = image_datasets['train'].classes# 训练与验证网络(所有层都参加训练)
def train_model(model, criterion, optimizer, scheduler, num_epochs):since = time.time() #返回的是毫秒# 保存网络训练最好的权重best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每训练一个epoch,测试一下网络模型的准确率for phase in ['train', 'val']: #phase=='train'if phase == 'train':# 学习率更新方式scheduler.step()#  调用模型训练model.train(True)else:# 调用模型测试model.train(False)running_loss = 0.0running_corrects = 0# 依次获取所有图像,参与模型训练或测试for data in dataloaders[phase]:# 获取输入inputs, labels = data# 判断是否使用gpuif use_gpu:inputs = inputs.cuda()labels = labels.cuda()# 梯度清零optimizer.zero_grad()# 网络前向运行outputs = model(inputs)_, preds = torch.max(outputs.data, 1) #获取最大值索引# 计算Loss值,交叉熵损失函数,其内部会自动加上Sofrmax层loss = criterion(outputs, labels)# 反传梯度,更新权重if phase == 'train':# 反传梯度loss.backward()# 更新权重optimizer.step()# 计算一个epoch的loss值和准确率,inputs.size(0)=4,running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)# 计算Loss和准确率的均值epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = float(running_corrects) / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 保存测试阶段,准确率最高的模型if phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 网络导入最好的网络权重model.load_state_dict(best_model_wts)return model# 微调网络
if __name__ ==  '__main__':# 导入Pytorch中自带的resnet18网络模型model_ft = models.resnet18(pretrained=True)# 将网络模型的各层的梯度更新置为Falsefor param in model_ft.parameters():param.requires_grad = False# 修改网络模型的最后一个全连接层# 获取最后一个全连接层的输入通道数num_ftrs = model_ft.fc.in_features# 修改最后一个全连接层的的输出数为2model_ft.fc = nn.Linear(num_ftrs, 2)# 是否使用gpuif use_gpu:model_ft = model_ft.cuda()# 定义网络模型的损失函数criterion = nn.CrossEntropyLoss()# 只训练最后一个层# 采用随机梯度下降的方式,来优化网络模型optimizer_ft = torch.optim.SGD(model_ft.fc.parameters(), lr=0.001, momentum=0.9)# 定义学习率的更新方式,每5个epoch修改一次学习率exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.1)# 训练网络模型model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=10)# 存储网络模型的权重torch.save(model_ft.state_dict(),"model_only_fc.pkl")

迁移学习(基于ResNet18的蜜蜂和蚂蚁分类)相关推荐

  1. Tensorflow 2.1 迁移学习 基于VGG

    1. 什么是迁移学习 迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中.迁移学习是通过从已学习的相关任务 ...

  2. Deeplearning4j 实战 (10):迁移学习--ImageNet比赛预训练网络VGG16分类花卉图片

    Eclipse Deeplearning4j GitChat课程:https://gitbook.cn/gitchat/column/5bfb6741ae0e5f436e35cd9f Eclipse ...

  3. 迁移学习resnet_ResNet-V1-50卷积神经网络迁移学习进行不同品种的花的分类识别

    运行环境 python3.6.3.tensorflow1.10.0 Intel@AIDevCloud:Intel Xeon Gold 6128 processors集群 数据和模型来源 思路 数据集分 ...

  4. 猿创征文丨深度学习基于双向LSTM模型完成文本分类任务

    大家好,我是猿童学,本期猿创征文的第三期,也是最后一期,给大家带来神经网络中的循环神经网络案例,基于双向LSTM模型完成文本分类任务,数据集来自kaggle,对电影评论进行文本分类. 电影评论可以蕴含 ...

  5. 深度学习基于双向 LSTM 模型完成文本分类任务

    大家好,本期给大家带来神经网络中的循环神经网络案例,基于双向LSTM模型完成文本分类任务,数据集来自kaggle,对电影评论进行文本分类. 电影评论可以蕴含丰富的情感:比如喜欢.讨厌.等等.情感分析( ...

  6. 机器学习工程师 — Udacity 基于CNN和迁移学习创建狗品种分类器

    卷积神经网络(Convolutional Neural Network, CNN) 项目:实现一个狗品种识别算法App 推荐你阅读以下材料来加深对 CNN和Transfer Learning的理解: ...

  7. 基于特征的对抗迁移学习论文_学界 | 综述论文:四大类深度迁移学习

    选自arXiv 作者:Chuanqi Tan.Fuchun Sun.Tao Kong. Wenchang Zhang.Chao Yang.Chunfang Liu 机器之心编译 参与:乾树.刘晓坤 本 ...

  8. Python 迁移学习实用指南:6~11

    原文:Hands-On Transfer Learning with Python 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学习 译文集],采用译后编辑(MT ...

  9. 整理学习之深度迁移学习

    迁移学习(Transfer Learning)通俗来讲就是学会举一反三的能力,通过运用已有的知识来学习新的知识,其核心是找到已有知识和新知识之间的相似性,通过这种相似性的迁移达到迁移学习的目的.世间万 ...

最新文章

  1. 提交PR后修改内容并合并commit
  2. Asp.net中防止用户多次登录的方法
  3. 单片机流星灯_51单片机拖尾灯实现
  4. 一个简单的Java计时器项目,附源码
  5. (238)数字IC工程师核心技能树(一)
  6. IOS微信API异常:unrecognized selector sent to instance 0x17005c9b0‘
  7. pythonarp攻击_python通过scapy模块进行arp断网攻击
  8. 从零开始搭二维激光SLAM --- 写作计划
  9. php合并两个有序链表,合并两个排序的链表
  10. 【转】临界区、互斥对象
  11. Software Engineering at Google翻译-III-9-Code Review(代码审查)
  12. React15中的栈调和diff算法
  13. PCB焊接——原理篇
  14. 软工实践第二次作业之个人项目
  15. 赞美CSDN 我去年买了个登山包超耐磨。
  16. c语言(练习for循环、字母大写转小写)
  17. VirtualBOX 虚拟机安装 OS X 10.9 Mavericks 及 Xcode 5,本人X220亲测
  18. 服务器搭建微信会员卡系统,智络会员管理系统如何与微信对接
  19. c++操作Office之ppt
  20. OpenAI肩负使命,宣布AI新计划

热门文章

  1. Kylin源码解析——从CubingJob的构建过程看Kylin的工作原理
  2. 哪个城市美女最多?OPPO R11开启“谁是拍照King·仲夏之梦”活动
  3. Exp9 Web安全基础实践 20164323段钊阳
  4. 在线数据迁移,数字化时代的必修课——京东云数据迁移实践
  5. Jetson Nano 关闭开启图形界面减少内存占用
  6. hdwiki下model目录功能
  7. 战神笔记本如何打开/关闭关机状态下USB供电
  8. webRtc-streamer简单使用-备份
  9. 数据库原理(三):Sql Server操作语句
  10. 1740 蜂巢迷宫(模拟,暴力,剪枝)