文章目录

  • 1.前言
  • 2.什么是预训练和微调
  • 3.预训练和微调的作用
  • 4.在一个新任务上微调一个预训练的模型代码实现

1.前言

预训练(pre-training/trained)和微调(fine tuning)这两个词经常在论文中见到,今天主要按以下两点来说明。

什么是预训练和微调?
它俩有什么作用?

2.什么是预训练和微调

你需要搭建一个网络模型来完成一个特定的图像分类的任务。首先,你需要随机初始化参数,然后开始训练网络,不断调整直到网络的损失越来越小。在训练的过程中,一开始初始化的参数会不断变化。当你觉得结果很满意的时候,你就可以将训练模型的参数保存下来,以便训练好的模型可以在下次执行类似任务时获得较好的结果。这个过程就是 pre-training。
之后,你又接收到一个类似的图像分类的任务。这时候,你可以直接使用之前保存下来的模型的参数来作为这一任务的初始化参数,然后在训练的过程中,依据结果不断进行一些修改。这时候,你使用的就是一个 pre-trained 模型,而过程就是 fine tuning。
所以,预训练 就是指预先训练的一个模型或者指预先训练模型的过程;微调 就是指将预训练过的模型作用于自己的数据集,并使参数适应自己数据集的过程。

3.预训练和微调的作用

在 CNN 领域中,实际上,很少人自己从头训练一个 CNN 网络。主要原因是自己很小的概率会拥有足够大的数据集,基本是几百或者几千张,不像 ImageNet 有 120 万张图片这样的规模。拥有的数据集不够大,而又想使用很好的模型的话,很容易会造成过拟合。

所以,一般的操作都是在一个大型的数据集上(ImageNet)训练一个模型,然后使用该模型作为类似任务的初始化或者特征提取器。比如 VGG,Inception 等模型都提供了自己的训练参数,以便人们可以拿来微调。这样既节省了时间和计算资源,又能很快的达到较好的效果。

4.在一个新任务上微调一个预训练的模型代码实现

# -*- coding: utf-8 -*-
""" Finetuning Example. Using weights from model trained in
convnet_cifar10.py to retrain network for a new task (your own dataset).
All weights are restored except last layer (softmax) that will be retrained
to match the new task (finetuning).
"""from __future__ import division, print_function, absolute_importimport tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression# Data loading
# Note: You input here any dataset you would like to finetune
X, Y = your_dataset()
num_classes = 10# Redefinition of convnet_cifar10 network
network = input_data(shape=[None, 32, 32, 3])
network = conv_2d(network, 32, 3, activation='relu')
network = max_pool_2d(network, 2)
network = dropout(network, 0.75)
network = conv_2d(network, 64, 3, activation='relu')
network = conv_2d(network, 64, 3, activation='relu')
network = max_pool_2d(network, 2)
network = dropout(network, 0.5)
network = fully_connected(network, 512, activation='relu')
network = dropout(network, 0.5)
# Finetuning Softmax layer (Setting restore=False to not restore its weights)
softmax = fully_connected(network, num_classes, activation='softmax', restore=False)
regression = regression(softmax, optimizer='adam',loss='categorical_crossentropy',learning_rate=0.001)model = tflearn.DNN(regression, checkpoint_path='model_finetuning',max_checkpoints=3, tensorboard_verbose=0)
# Load pre-existing model, restoring all weights, except softmax layer ones
model.load('cifar10_cnn')# Start finetuning
model.fit(X, Y, n_epoch=10, validation_set=0.1, shuffle=True,show_metric=True, batch_size=64, snapshot_step=200,snapshot_epoch=False, run_id='model_finetuning')model.save('model_finetuning')

Tensorflow【实战Google深度学习框架】预训练与微调含代码(看不懂你来打我)相关推荐

  1. 06.图像识别与卷积神经网络------《Tensorflow实战Google深度学习框架》笔记

    一.图像识别问题简介及经典数据集 图像识别问题希望借助计算机程序来处理.分析和理解图片中的内容,使得计算机可以从图片中自动识别各种不同模式的目标和对象.图像识别问题作为人工智能的一个重要领域,在最近几 ...

  2. (转)Tensorflow 实战Google深度学习框架 读书笔记

    本文大致脉络: 读书笔记的自我说明 对读书笔记的摘要 具体章节的摘要: 第一章 深度学习简介 第二章 TensorFlow环境搭建 第三章 TensorFlow入门 第四章 深层神经网络 第五章 MN ...

  3. 免费教材丨第55期:Python机器学习实践指南、Tensorflow 实战Google深度学习框架

    小编说  时间过的好快啊,小伙伴们是不是都快进入寒假啦?但是学习可不要落下哦!  本期教材  本期为大家发放的教材为:<Python机器学习实践指南>.<Tensorflow 实战G ...

  4. 《Tensorflow 实战google深度学习框架》第二版源代码

    <<Tensorflow 实战google深度学习框架–第二版>> 完整资料github地址: https://github.com/caicloud/tensorflow-t ...

  5. 学习《TensorFlow实战Google深度学习框架 (第2版) 》中文PDF和代码

    TensorFlow是谷歌2015年开源的主流深度学习框架,目前已得到广泛应用.<TensorFlow:实战Google深度学习框架(第2版)>为TensorFlow入门参考书,帮助快速. ...

  6. 说说TensorFlow实战Google深度学习框架

    说说TensorFlow实战Google深度学习框架 事情是这样的,博主买了这本书,但是碍于想在电脑上边看边码,想找找PDF版本,然后各种百度,Google,百度网盘,最后找到的都是很多200M的,百 ...

  7. TensorFlow实战Google深度学习框架

    TensorFlow是谷歌2015年开源的主流深度学习框架.科技届的聚光灯已经从"互联网+"转到了"AI+": 掌握深度学习需要较强的理论功底,用好Tensor ...

  8. TensorFlow实战Google深度学习框架5-7章学习笔记

    目录 第5章 MNIST数字识别问题 第6章 图像识别与卷积神经网络 第7章 图像数据处理 第5章 MNIST数字识别问题 MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会 ...

  9. tensorflow实战google深度学习框架在线阅读

    https://max.book118.com/html/2019/0317/7112141026002014.shtm

  10. Tensorflow 实战 Google 深度学习框架(第2版)---- 10.2.2节 P274 代码

    #-*-coding:utf-8-*- import keras from tflearn.layers.core import fully_connected from keras.datasets ...

最新文章

  1. ibmm,让思维导图回归本质
  2. docker,mysql,wordpress搭建个人博客
  3. php锁定文本框内容的方法
  4. html click事件 参数,vue 实现click同时传入事件对象和自定义参数
  5. 为什么超长列表数据的翻页技术实现复杂(二)
  6. cout不明确什么意思_不计免赔险是什么意思?弄不清楚要吃大亏
  7. RAID5中的“左、右循环”与“同步、异步”(2)
  8. A - 1 CodeForces - 500A
  9. pycharm调试GreenOdoo
  10. Baidu All Reduce
  11. icem划分网格步骤_ICEM CFD教程-icem网格划分教程
  12. CSDN数据泄密凸显数据安全黑洞 飞客提示注意数据库保护
  13. MikroTik路由器配置
  14. typecho图标_handsome+Typecho美化过程【持续更新】包括踩坑解决办法
  15. 解决制作FAT32格式的重装U盘中文件过大问题
  16. 8脚 tja1050t_TJA1050TD-T_PDF技术资料下载_货期信息(1/10)_NXP - 万联芯城
  17. 无限制神器aria2懒人包及Aria2配置/Web管理面板教程
  18. python循环结构教学设计_Python程序设计 循环结构说课稿
  19. php sku 代码编写,SKU代码生成规则
  20. 在JBuilder中生成EXE、可执行jar、带shell窗口的EXE

热门文章

  1. numpy 筛选面积最大
  2. pytorch 批量筛选
  3. C++ dll 类型与 C#类型对应关系
  4. cd 在windows下 无法切换盘符目录
  5. 双控专业就业机器人_工业机器人专业好就业吗?有哪些机器人技术岗位?
  6. gitlab合并分支后需要提交吗_阿里前端,如何基于 GitLab 进行「自动化」构建及发布...
  7. 某都计算机考研计算机组成原理,东北大学2000年考研真题-计算机组成原理
  8. python写简单购物车_python简单的购物车程序(含代码)
  9. VMWare 虚拟机启动报“内部错误”的解决办法
  10. Linux之grep命令