1. Training a Custom Classifier based on a Quantized Feature Extractor量化特征提取器

在本节中,您将使用**“冻结”量化特征提取器**,并在其顶部训练自定义分类器头。 与浮点模型不同,您不需要为量化模型设置require_grad = False,因为它没有可训练的参数

加载预训练的模型:在本练习中,您将使用ResNet-18。

import torchvision.models.quantization as models# You will need the number of filters in the `fc` for future use.
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_fe = models.resnet18(pretrained=True, progress=True, quantize=True)
num_ftrs = model_fe.fc.in_features

此时,您需要修改预训练模型(pretrained model)。 该模型在开始和结束时都有量化/去量化块(quantize/dequantize blocks)。 但是,由于只使用feature 提取器,因此去量化层(dequantizatioin layer)必须在线性层(linear layer (the head))之前移动。 最简单的方法是将模型包装在nn.Sequential模块中。

第一步是在ResNet模型中隔离(isolate )特征提取器。 尽管在本示例中,我们使用all layers (除了except fc)作为feature 提取器,但实际上,您可以根据需要选择任意数量的零件。 如果您也想替换一些卷积层,这将很有用。

When separating the feature extractor from the rest of a quantized model, you have to manually place the quantizer/dequantized in the beginning and the end of the parts you want to keep quantized.

将特征提取器与量化模型的其余部分分开时,您必须手动将,量化器/去量化器(quantizer/dequantized)放置在要保持量化的部分(the parts you want to keep quantized)的开头和结尾。

The function below creates a model with a custom head.

from torch import nndef create_combined_model(model_fe):# Step 1. Isolate the feature extractor.model_fe_features = nn.Sequential(model_fe.quant,  # Quantize the inputmodel_fe.conv1,model_fe.bn1,model_fe.relu,model_fe.maxpool,model_fe.layer1,model_fe.layer2,model_fe.layer3,model_fe.layer4,model_fe.avgpool,model_fe.dequant,  # Dequantize the output)# Step 2. Create a new "head"new_head = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(num_ftrs, 2),)# Step 3. Combine, and don't forget the quant stubs.new_model = nn.Sequential(model_fe_features,nn.Flatten(1),new_head,)return new_model

当前,量化模型只能在CPU上运行。 但是,可以将模型的未量化部分发送到GPU。

import torch.optim as optim
new_model = create_combined_model(model_fe)
new_model = new_model.to('cpu')criterion = nn.CrossEntropyLoss()# Note that we are only training the head.
optimizer_ft = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Train and evaluate

This step takes around 15-25 min on CPU. Because the quantized model can only run on the CPU, you cannot run the training on GPU.

new_model = train_model(new_model, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25, device='cpu')visualize_model(new_model)
plt.tight_layout()

2. Finetuning the Quantizable Model

我们将微调用于迁移学习的特征提取器,并对特征提取器进行量化。

请注意,在第1节和第2节中,特征提取器都是量化的。
不同之处在于,在第1节中,我们使用了预训练的量化模型。
在这一部分中,我们将在对感兴趣的数据集进行微调之后创建一个量化的特征提取器

因此这是一种在具有量化优势的同时,通过转移学习,获得更好的准确性的方法。

  • get better accuracy
  • transfer learning
  • having the benefits of quantization.

请注意,在我们的特定示例中,训练集非常小(120张图像),因此微调整个模型的好处并不明显。 但是,此处显示的过程将提高使用较大数据集进行传递学习的准确性。

预训练特征提取器必须是可量化的。 为确保其可量化,请执行以下步骤:

  • 使用torch.quantization.fuse_modules融合(Conv,BN,ReLU),(Conv,BN)和(Conv,ReLU)。
  • 将特征提取器与自定义头部连接。 这需要对特征提取器的输出进行反量化。
  • 在特征提取器中的适当位置插入伪量化模块,以模拟训练期间的量化。
# notice  quantize=False
model = models.resnet18(pretrained=True, progress=True, quantize=False)
num_ftrs = model.fc.in_features# Step 1
model.train()
model.fuse_model() # Step 2
model_ft = create_combined_model(model)
model_ft[0].qconfig = torch.quantization.default_qat_qconfig # Use default QAT configuration
# Step 3
model_ft = torch.quantization.prepare_qat(model_ft, inplace=True)

对于步骤(1),我们使用来自torchvision / models / quantization的模型,这些模型具有成员方法fuse_model。 此功能将所有conv,bn和relu模块融合在一起。 对于自定义模型,这需要使用模块列表调用torch.quantization.fuse_modules API进行手动融合。

步骤(2)由上一部分中使用的create_combined_model函数执行。

通过使用torch.quantization.prepare_qat(插入假量化模块)来实现步骤(3)。

在步骤(4)中,您可以开始“微调”模型,然后将其转换为完全量化的版本(步骤5)。

要将微调模型转换为量化模型,可以调用torch.quantization.convert函数(在本例中,仅对特征提取器进行量化)。

Finetuning the model
在当前教程中,整个模型都经过了微调。 通常,这将导致更高的精度。 但是,由于此处使用的培训集很小,最终导致我们过度适应了培训集。

# Step 4. Fine tune the model
for param in model_ft.parameters():param.requires_grad = Truemodel_ft.to(device)  # We can fine-tune on GPU if availablecriterion = nn.CrossEntropyLoss()# Note that we are training everything, so the learning rate is lower
# Notice the smaller learning rate
optimizer_ft = optim.SGD(model_ft.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.1)# Decay LR by a factor of 0.3 every several epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.3)model_ft_tuned = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25, device=device)
# Step 5. Convert to quantized modelfrom torch.quantization import convert
model_ft_tuned.cpu()model_quantized_and_trained = convert(model_ft_tuned, inplace=False)
isualize_model(model_quantized_and_trained)plt.ioff()
plt.tight_layout()
plt.show()

09月28日 pytorch与resnet(五) 转移学习相关推荐

  1. 面试经历---阿里游戏(2020年09月28日晚上7点视频面试)

    9月28日晚上进行了一次视频面试,阿里广州游戏部门,下面说下这次面试的情况 1.自我介绍 介绍了做过的项目,面试官就围绕做过的项目进行深挖. 2.redis的集群方式 如果节点挂掉怎么办? 单个节点的 ...

  2. 2005年09月28日  日本,东京  晴朗

    又一次故地重游,来这里出差,前段日子公司限制上网,所以一直都没有机会来这里看望大家. 咳,在日本觉得没有国内那么随便,也没有朝鲜冷面可以吃. 这次来日本,也赶上了国内吃月饼的节日,而我去吃不到,公司发 ...

  3. 扒一扒HTTPS网站的内幕[2015年09月29日]

    扒一扒HTTPS网站的内幕 野狗 2015年09月28日发布 作者:王继波  野狗科技运维总监,曾在360.TP-Link从事网络运维相关工作,在网站性能优化.网络协议研究上经验丰富. 野狗官博:ht ...

  4. 第五人格6月28日服务器维护到极点,第五人格6月28日更新了什么内容_深渊的呼唤规则玩法介绍...

    今天小编为大家带来的是<第五人格>6月28日更新内容,感兴趣的小伙伴赶紧一起来看看吧,祝各位游戏愉快. 活动 暑期盛典活动[深渊的呼唤]来袭,以战队为单位开启报名,报名之后战队所有成员在新 ...

  5. 解密谷歌机器学习工程最佳实践——机器学习43条军规 翻译 2017年09月19日 10:54:58 98310 本文是对Rules of Machine Learning: Best Practice

    解密谷歌机器学习工程最佳实践--机器学习43条军规 翻译 2017年09月19日 10:54:58 983 1 0 本文是对Rules of Machine Learning: Best Practi ...

  6. 个人空间岁末大回报活动12月28日获奖名单

    个人空间岁末大回报: 动手就有C币拿!活动已于15日启动,非常感谢各位网友的大力支持和积极参与,个人空间的所有工作人员在这祝大家好运,希望你们每天都能拿到C币存入社区银行! 欢迎各位获奖者去自己的银行 ...

  7. 8月2日Pytorch笔记——梯度、全连接层、GPU加速、Visdom

    文章目录 前言 一.常见函数的梯度 二.激活函数及其梯度 1.Sigmoid 2.Tanh 3.ReLU 三.Loss 函数及其梯度 1.Mean Squared Error(MSE) 2.Softm ...

  8. 10月28日人工智能讲师叶梓为各工科院校老师进行了为期三天的人工智能培训

    10月28日人工智能讲师叶梓为各工科院校老师进行了为期三天的人工智能培训,培训过程中人工智能讲师叶梓与各高校老师就人工智能前沿热点进行热烈的讨论. 根据人力资源和社会保障部办公厅<关于印发专业技 ...

  9. PANDAS 数据合并与重塑(concat篇) 原创 2016年09月13日 19:26:30 47784 pandas作者Wes McKinney 在【PYTHON FOR DATA ANALYS

    PANDAS 数据合并与重塑(concat篇) 原创 2016年09月13日 19:26:30 标签: 47784 编辑 删除 pandas作者Wes McKinney 在[PYTHON FOR DA ...

  10. 分享Silverlight/WPF/Windows Phone/HTML5一周学习导读(11月28日-12月4日)

    分享Silverlight/WPF/Windows Phone/HTML5一周学习导读(11月28日-12月4日) 本周Silverlight学习资源更新 Silverlight HttpUtil 封 ...

最新文章

  1. python pymysql实例_python笔记-mysql命令使用示例(使用pymysql执行)
  2. 团队项目第一阶段冲刺站立会议04
  3. Yii的beforeAction
  4. Android 程序打包及签名
  5. 已知线性表最多可能有20个元素,存储每个元素需要8字节,存储每个指针需要4字节。当元素个数为( )时使用单链表比使用数组存储此线性表更加节约空间。
  6. k8s灰度更新_通过rancher部署k8s过程实战分享
  7. Linux下git的使用——将已有项目放到github上
  8. #if, #ifdef, #ifndef, #else, #elif, #endif的用法
  9. Spring-Data-JPA--增删改查2——自定义接口查询
  10. SSIS(2012版本)连接MongoDB,使用SSIS2012导入MongoDB
  11. Hadoop报错 Failed to locate the winutils binary in the hadoop
  12. kotlin埋点_GitHub - shajinyang/ilvdo-event-track: 埋点框架
  13. 每当Xcode升级之后,都会导致原有的Xcode插件不能使用,解决办法
  14. python 实现的huffman 编码压缩,解码解压缩
  15. java脚本语言 dim_写给新手windows脚本的入门
  16. ollvm源码分析之指令替换(1)
  17. 一天搞懂深度学习—学习笔记3(RNN)
  18. TFTLCD显示实验_STM32F1开发指南_第十八章
  19. Google Chrome 扩展程序
  20. 砸盘、销号、解散社群,Merlin Lab“跑路三连”暴露了DeFi哪些问题?

热门文章

  1. Python实现WGS 84坐标与web墨卡托投影坐标的转换
  2. java钝化_session的活化与钝化 (转)
  3. JavaEE核心API--Servlet
  4. Android自定义View【实战教程】4⃣️----BitmapShader详解及圆形、圆角、多边形实现
  5. PHP tcp短链接,示例:建立TCP链接
  6. mysql 用户列表数据结构_MySQL数据结构-行结构
  7. iphone mac地址是否随机_iPad 的 Mac 地址是否会随机更换,如何关闭呢
  8. springboot mybatis如何打印出查询语句_Java 面试,如何坐等 offer?
  9. cryptapi双向认证_2019 08 28 netty案例,netty4.1中级拓展篇十三《Netty基于SSL实现信息传输过程中双向加密验证》...
  10. PingInfoView,中文,以及ping包+描述的使用。