目录

1  pytorch的版本:

2  数据下载地址:

3  原始版本代码下载:

4  直接上代码:


1  pytorch的版本:

2  数据下载地址:

<https://download.pytorch.org/tutorial/hymenoptera_data.zip>

3  原始版本代码下载:

https://pytorch.org/tutorials/_downloads/transfer_learning_tutorial.py

4  直接上代码:

# -*- coding: utf-8 -*-
# @File    : test4.py
# @Blog    : https://blog.csdn.net/caomin1haofrom __future__ import print_function, divisionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copydevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")plt.ion()   # interactive mode######################################################################
# 1.定义模型,  2.加载部分预训练数据,  3.冻结部分层
######################################
#1.定义模型
model_conv = models.resnet18()
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)'''
#打印模型的结构
print('###打印模型model_conv的结构####')
print(model_conv)
print('\n')print('###打印模型model_conv加载参数前的初始值####')
print(list(model_conv.parameters()))
print('\n')
'''#############################################
#2.加载部分预训练数据
pretrained_dict = torch.load('./08 transfer_learning/resnet18-5c106cde.pth')
'''
for k,v in pretrained_dict.items():print(k)
'''
#删除预训练模型跟当前模型层名称相同,层结构却不同的元素;这里有两个'fc.weight'、'fc.bias'
pretrained_dict.pop('fc.weight')
pretrained_dict.pop('fc.bias')#自己的模型参数变量
model_dict = model_conv.state_dict()
#去除一些不需要的参数
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#参数更新
model_dict.update(pretrained_dict)# 加载我们真正需要的state_dict
model_conv.load_state_dict(model_dict)'''
print('###打印模型model_conv加载参数后的参数值####')
print(list(model_conv.parameters()))
print('\n')
'''
#############################################
#3.冻结部分层
#将满足条件的参数的 requires_grad 属性设置为False
for name, value in model_conv.named_parameters():if (name != 'fc.weight') and (name != 'fc.bias'):value.requires_grad = False
'''
#打印各层的requires_grad属性
print('###打印模型model_conv参数的requires_grad属性####')
for name, param in model_conv.named_parameters():print(name,param.requires_grad)
'''# filter 函数将模型中属性 requires_grad = True 的参数选出来
params_conv = filter(lambda p: p.requires_grad, model_conv.parameters())
model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as
# opoosed to before.
optimizer_conv = optim.SGD(params_conv, lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)######################################################################
# Training the model
#编写一个通用函数来训练模型。
# 下面将说明: * 调整学习速率 * 保存最好的模型
#下面的参数scheduler是一个来自 torch.optim.lr_scheduler 的学习速率调整类的对象(LR scheduler object)。def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':scheduler.step()model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0#  迭代数据.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 后向+仅在训练阶段进行优化if phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 深度复制moif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 加载最佳模型权重model.load_state_dict(best_model_wts)return model######################################################################
# 可视化部分训练图像,以便了解数据扩充。def imshow(inp, title=None):"""Imshow for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated######################################################################
# Visualizing the model predictions
# 一个通用的展示少量预测图片的函数def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)######################################################################
#训练集数据扩充和归一化
#在验证集上仅需要归一化
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = './08 transfer_learning/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")if __name__ == '__main__':# Train and evaluate 2# 训练模型 在CPU上,与前一个场景相比,这将花费大约一半的时间,因为不需要为大多数网络计算梯度。但需要计算转发。model_conv = train_model(model_conv, criterion, optimizer_conv,exp_lr_scheduler, num_epochs=11)visualize_model(model_conv)plt.ioff()plt.show()

部分运行结果:

Pytorch 加载部分预训练模型并冻结某些层相关推荐

  1. python怎么使用预训练的模型_Tensorflow加载Vgg预训练模型操作

    很多深度神经网络模型需要加载预训练过的Vgg参数,比如说:风格迁移.目标检测.图像标注等计算机视觉中常见的任务.那么到底如何加载Vgg模型呢?Vgg文件的参数到底有何意义呢?加载后的模型该如何使用呢? ...

  2. python glove训练模型_gensim加载Glove预训练模型

    前言 之前一直用word2vec,今天在用gensim加载glove时发现gensim只提供了word2vec的接口,如果我们想用gensim加载Glove词向量怎么办呢? word2vec和Glov ...

  3. keras冻结_Keras 实现加载预训练模型并冻结网络的层

    在解决一个任务时,我会选择加载预训练模型并逐步fine-tune.比如,分类任务中,优异的深度学习网络有很多. ResNet, VGG, Xception等等... 并且这些模型参数已经在imagen ...

  4. pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

    问题 最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下.     KeyError: 'layer1.0.bn1.num_batches_tracked' 其实 ...

  5. pytorch加载之前训练模型中的部分参数以及冻结部分参数(实测,自己实际项目代码中的)

    我的需求是,由于我在不停的尝试各种模型,导致模型木块一直会变.如果每次重复重新开始训练要花费大把时间. 我之前运行的模型 ResNet ->                            ...

  6. Pytorch加载torchvision从本地下载好的预训练模型的简单解决方案

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.喜 ...

  7. 用pytorch加载训练模型

    用pytorch加载.pth格式的训练模型 在pytorch/vision/models网页上有很多现成的经典网络模型可以调用,其中包括alexnet.vgg.googlenet.resnet.inc ...

  8. Pytorch迁移学习加载部分预训练权重

    迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险. 如果用原生网络进行迁移学习非常简单,其核心是 mod ...

  9. Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

    需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...

最新文章

  1. php trace 图形,php 方便水印和缩略图的图形类
  2. 天呐!java生成DAT文件并写入数据
  3. java jframe 运行_java – 使用JProgressBar运行JFrame
  4. c语言cnn实现ocr字符,端到端的OCR:基于CNN的实现
  5. 【Boost】boost库中thread多线程详解1——thread入门与简介
  6. 反转!物联网火爆,开发者却很难入门?
  7. vueCli3中使用代理,点击页面的刷新按钮时报错
  8. .NPT 扩展名格式文件类型及打开方式分析:首次渗入 XR 内容领域
  9. express基本原理
  10. 微服务体系三维可缩放模型
  11. 解决办法:无法安装 /lib/x86_64-linux-gnu/libpng12.so.0 的新版本
  12. 利用python调用谷歌翻译API
  13. 数字电路:数据选择器与译码器
  14. 计算机验证菜单命令的各种特性,2017年CAD工程师认证单选题「附答案」
  15. InputStream的available()方法(读文件)
  16. python画兔子代码_Python基础练习实例11(兔子问题)
  17. 学习笔记之 初试Linux遇到的问题
  18. go文件服务器加密,gosignal: 使用 Golang 实现的端对端加密聊天软件 Signal 服务端...
  19. tcp 粘包 丢包 解决方案
  20. 计算机无线网卡,电脑如何无线上网 电脑无线网卡买什么好

热门文章

  1. C vector详解
  2. wordpress致命错误怎么解决_pppoe错误是什么意思 pppoe错误怎么解决
  3. listview控件在php的使用方法,Android_Android编程之控件ListView使用方法,本文实例讲述了Android编程之控 - phpStudy...
  4. 协议转换器是怎么分类的?主要有哪些类别?
  5. 工业以太网交换机的冗余功能及发展历程介绍
  6. 【渝粤题库】陕西师范大学164117 企业组网技术 作业 (高起专)
  7. 433M数传电台窄带无线通讯技术手册
  8. Filtration, σ-algebras
  9. 从拉格朗日乘数法到KKT条件
  10. 机器学习中的算法-支持向量机(SVM)基础