使用Tensorflow和VGG16预训模型进行预测

from:https://zhuanlan.zhihu.com/p/28997549

fast.ai的入门教程中使用了kaggle: dogs vs cats作为例子来让大家入门Computer Vision。不过并未应用到最近很火的Tensorflow。Keras虽然可以调用Tensorflow作为backend,不过既然可以少走一层直接走Tensorflow,那秉着学习的想法,就直接用Tensorflow来一下把。

听说工程上普遍的做法并不是从头开始训练模型,而是直接用已经训练完的模型稍加改动(这个过程叫finetune)来达到目的。那么这里就需要用Tensorflow还原出VGG16的模型。这里借鉴了frossard的python代码和他转化的权重。架构具体如下:(cs231n有更详细的介绍)

INPUT: [224x224x3]        memory:  224*224*3=150K   weights: 0
CONV3-64: [224x224x64]  memory:  224*224*64=3.2M   weights: (3*3*3)*64 = 1,728
CONV3-64: [224x224x64]  memory:  224*224*64=3.2M   weights: (3*3*64)*64 = 36,864
POOL2: [112x112x64]  memory:  112*112*64=800K   weights: 0
CONV3-128: [112x112x128]  memory:  112*112*128=1.6M   weights: (3*3*64)*128 = 73,728
CONV3-128: [112x112x128]  memory:  112*112*128=1.6M   weights: (3*3*128)*128 = 147,456
POOL2: [56x56x128]  memory:  56*56*128=400K   weights: 0
CONV3-256: [56x56x256]  memory:  56*56*256=800K   weights: (3*3*128)*256 = 294,912
CONV3-256: [56x56x256]  memory:  56*56*256=800K   weights: (3*3*256)*256 = 589,824
CONV3-256: [56x56x256]  memory:  56*56*256=800K   weights: (3*3*256)*256 = 589,824
POOL2: [28x28x256]  memory:  28*28*256=200K   weights: 0
CONV3-512: [28x28x512]  memory:  28*28*512=400K   weights: (3*3*256)*512 = 1,179,648
CONV3-512: [28x28x512]  memory:  28*28*512=400K   weights: (3*3*512)*512 = 2,359,296
CONV3-512: [28x28x512]  memory:  28*28*512=400K   weights: (3*3*512)*512 = 2,359,296
POOL2: [14x14x512]  memory:  14*14*512=100K   weights: 0
CONV3-512: [14x14x512]  memory:  14*14*512=100K   weights: (3*3*512)*512 = 2,359,296
CONV3-512: [14x14x512]  memory:  14*14*512=100K   weights: (3*3*512)*512 = 2,359,296
CONV3-512: [14x14x512]  memory:  14*14*512=100K   weights: (3*3*512)*512 = 2,359,296
POOL2: [7x7x512]  memory:  7*7*512=25K  weights: 0
FC: [1x1x4096]  memory:  4096  weights: 7*7*512*4096 = 102,760,448
FC: [1x1x4096]  memory:  4096  weights: 4096*4096 = 16,777,216
FC: [1x1x1000]  memory:  1000 weights: 4096*1000 = 4,096,000

具体实现移步VGG16。这里要注意的一点就是最后的输出是不需要经过Relu的。

预测猫和狗不能照搬这个架构,因为VGG16是用来预测ImageNet上1000个不同种类的。用来预测猫和狗两种类别,需要在这个架构的基础上再加一层FC把1000转化成2个。(也可以把最后一层替换掉,直接输出成2个)。我还在VGG16之后多加了一层BN,原来VGG16的时候并不存在BN。我也并没有在每个CONV后面加,因为不想算...

FC的输出在训练的时候使用Cross Entropy损失函数,预测的时候使用Softmax。这样就可以识别出给定图片是猫还是狗了。具体代码移步cats_model.py

我们来看一下效果如何。完整的:Jupyter Notebook

未经过Finetune直接运行VGG16改模型(加上了最后一层FC)的结果(预测非常不准,因为最后一层的权重都是随机的)。这么做的目的是看一下模型是否能运行,顺便看看能蒙对几个。

经过一次迭代,准确率就达到95%了(重复过几次,这次并不是最高的)。

再看一下同样的图片预测结果,似乎准确了很多。

Final Thoughts

图像识别非常有趣,是一个非常有挑战的领域。

转载于:https://www.cnblogs.com/bonelee/p/9017114.html

迁移学习——使用Tensorflow和VGG16预训模型进行预测相关推荐

  1. 迁移学习(Transfer learning)、重用预训练图层、预训练模型库

    迁移学习(Transfer learning).重用预训练图层.预训练模型库 目录 迁移学习(Transfer learning).重用预训练图层.预训练模型库 迁移学

  2. 平潭迁移库是什么意思_迁移学习》第四章总结---基于模型的迁移学习

    基于模型的迁移学习可以简单理解为就是基于模型参数的迁移学习,如何使我们构建的模型可以学习到域之间的通用知识. 1. 基于共享模型成分的迁移学习 在模型中添加先验知识. 1.1 利用高斯过程的迁移学习 ...

  3. 使用迁移学习和TensorFlow.js在浏览器中进行AI情感检测

    目录 KNN分类器 迁移学习 我们的技术栈 配置 使用KNN分类器 将代码放在一起 测试结果 下一步是什么? 下载源-10.6 MB 在上一篇文章中,我们已经看到了加载预训练模型有多么容易.在本文中, ...

  4. 使用迁移学习和 TensorFlow 进行食品分类

    摘要 在今天的报告中,我们将分析食品以预测它们是否可以食用.我们应用最先进的 迁移学习方法和 Tensorflow 框架来构建用于食品分类的机器学习模型. 介绍 图像分类是机器预测图片属于哪个类别的工 ...

  5. 【迁移学习】深度域自适应网络DANN模型

    DANN Domain-Adversarial Training of Neural Networks in Tensorflow 域适配:目标域与源域的数据分布不同但任务相同下的迁移学习. 模型建立 ...

  6. pytorch迁移学习后使用微调策略再次提高模型训练结果

    1.使迁移的模型解冻 for param in model.parameters():param.requires_grad=True 2.此时学习速率设置再小些 optimizer=torch.op ...

  7. PaddlePaddle迁移学习做图像分类,数十种高精度模型任意切换

    向AI转型的程序员都关注了这个号

  8. slim php dd model,第二十四节,TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)...

    在介绍这一节之前,需要你对slim模型库有一些基本了解,具体可以参考第二十二节,TensorFlow中的图片分类模型库slim的使用.数据集处理,这一节我们会详细介绍slim模型库下面的一些函数的使用 ...

  9. tensorflow实现宝可梦数据集迁移学习

    目录 一.迁移学习简介 二.构建预训练模型 1.调用内置模型 2.修改模型 3.构建模型 三.导入数据和预处理 1.设置batch size 2.读取训练数据 3.读取验证数据 4.读取测试数据 5. ...

最新文章

  1. windows 2003 下oracle从10.2.0.1升级到10.2.0.4
  2. android9.0不能用4g定位,Android 9.0新特性:让用户认为4G信号更强
  3. 敏捷项目向组合级看齐
  4. 【setup.py编译出错】——提示无法查找到powershell.exe
  5. 《科学:无尽的前沿》分享会在京举办,助力中国企业打造“科研的应许之地”
  6. 4乘4方格走的路线_苏州周边4个冷门自驾游路线景点推荐
  7. 终结者:具体解释Nginx(一)
  8. Web Application Security 网络应用程序安全 - (二)2010年网络安全威胁排行榜TOP 10...
  9. Spring boot 配置array,list,map
  10. Day04:循环结构(while、do-while、for)
  11. java 调用 easypr_Java程序执行Linux命令调用EasyPR程序识别车牌号
  12. Xv6 traps and system calls
  13. 玩机技巧|去除Windows桌面快捷方式图标左下角上的小箭头
  14. SAP采购中若干价格表的梳理
  15. 对路径“C:\inetpub\wwwroot\Test\Temper\”的访问被拒绝 【已解决】
  16. ai动漫生成软件哪个好?这篇文章告诉你
  17. 平台会员卡券源码文档
  18. WPS:将彻底关闭广告
  19. 游戏服务器多少钱一个月 游戏服务器配置怎么选择
  20. java 虚拟机内存类_java 虚拟机类加载 及内存结构

热门文章

  1. openresty require报错
  2. 大数据SQL日常学习——NVL函数
  3. ubuntu下打开matlab_ubuntu终端命令启动matlab方法
  4. python坐标轴刻度为经纬度_python各类经纬度转换
  5. python中json模块_Python使用内置json模块解析json格式数据的方法
  6. 镜像和linux关系,Docker中容器和镜像的关系【通俗易懂】
  7. 页面jlabel背景色设置_(六)使用elementUI搭建管理员页面布局
  8. mysql注入反弹_Discuz!x xss反弹后台无防御sql注入getshell(附带exploit)
  9. 口碑好的mysql数据监控平台_构建狂拽炫酷屌的 MySQL 监控平台
  10. Android性能优化常见问题,附架构师必备技术详解