在第3篇文章中,我们构建并训练了第一个神经网络,接下来可以处理一些更复杂的样本了。

最顶尖的深度学习模型通常都复杂到让人难以置信。其中可能包含数百层,就算用不了数周,往往也要数天时间来使用海量数据进行训练。这类模型的构建和优化需要大量经验。

好在这些模型的使用还是很简单的,通常只需要编写几行代码。本文将使用一个名为Inception v3的预训练模型进行图片分类。

Inception v3

诞生于2015年12月的Inception v3是GoogleNet模型(曾赢得2014年度ImageNet挑战赛)的改进版。本文不准备深入介绍该模型的研究论文,不过打算强调一下论文的结论:相比当时最棒的模型,Inception v3的准确度高出了15%–25%,同时计算的经济性方面低六倍,并且至少将参数的数量减少了五倍(例如使用该模型对内存的要求更低)。

简直就是神器!那么我们该如何使用?

MXNet model zoo

Model zoo提供了一系列可直接使用的预训练模型,并且通常还会提供模型定义模型参数(例如神经元权重),(也许还会提供)使用说明。

首先来下载定义和参数(你也许需要更改文件名)。第一个文件可以直接打开:其中包含了每一层的定义。第二个文件是一个二进制文件,请不要打开 ;)

$ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-symbol.json
$ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-0126.params
$ mv Inception-BN-0126.params Inception-BN-0000.params

该模型已通过ImageNet数据集进行了训练,因此我们还需要下载对应的图片分类清单(共有1000个分类)。

$ wget http://data.dmlc.ml/models/imagenet/synset.txt
$ wc -l synset.txt1000 synset.txt
$ head -5 synset.txt
n01440764 tench, Tinca tinca
n01443537 goldfish, Carassius auratus
n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
n01491361 tiger shark, Galeocerdo cuvieri
n01494475 hammerhead, hammerhead shark

搞定,开始实战。

加载模型

我们需要:

  • 加载处于保存状态的模型:MXNet将其称之为检查点(Checkpoint)。随后即可得到输入的Symbol和模型参数。

    import mxnet as mxsym, arg_params, aux_params = mx.model.load_checkpoint('Inception-BN', 0)
  • 新建一个Module并为其指派输入Symbol。我们还可以使用一个Context参数决定要在哪里运行该模型:默认值为cpu(0),但也可改为gpu(0)以便通过GPU运行。
    mod = mx.mod.Module(symbol=sym)
  • 将输入Symbol绑定至输入数据。将其称之为“数据”是因为在网络的输入层中就使用了这样的名称(可以从JSON文件的前几行代码中看到)。

  • 将“数据”的形态(Shape)定义为1x3x224x224。别慌 ;),“224x224”是图片的分辨率,模型就是这样训练出来的。“3”是通道数量:红绿蓝(严格按照这样的顺序),“1”是批大小:我们将一次预测一张图片。

    mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
  • 设置模型参数。
    mod.set_params(arg_params, aux_params)

这样就可以了。只需要四行代码!随后可以放入一些数据看看会发生什么。嗯……先别急。

准备数据

数据准备:从七十年代以来,这一直是个痛苦的过程……从关系型数据库到机器学习,再到深度学习,这方面没有任何改进。虽然乏味但很必要。开始吧。

还记得吗,这个模型需要通过四维NDArray来保存一张224x224分辨率图片的红、绿、蓝通道数据。我们将使用流行的OpenCV库从输入图片中构建这样的NDArray。如果还没安装OpenCV,考虑到本例的要求,直接运行pip install opencv-python就够了 :)。

随后的步骤如下:

  • 读取图片:将返回一个Numpy数组,其形态为(图片高度, 图片宽度, 3),按顺序代表BGR(蓝、绿、红)三个通道。

    img = cv2.imread(filename)
  • 将图片转换为RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  • 将图片调整大小224x224
    img = cv2.resize(img, (224, 224,))
  • 重塑数组的形态,从(图片高度, 图片宽度, 3)重塑为(3, 图片高度, 图片宽度)。
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)
  • 添加一个第四维度并构建NDArray
    img = img[np.newaxis, :]
    array = mx.nd.array(img)
    >>> print array.shape
    (1L, 3L, 224L, 224L)

晕了?一起用个例子看看吧。输入下列这张图片:

输入448x336的图片(来源:metaltraveller.com)

处理完毕后,该图会被缩小尺寸并拆分为RGB通道,存储在array[0]中(生成下文图片的代码可参阅这里)。

array[0][0]:224x224,红色通道

array0:224x224,绿色通道

array0:224x224,蓝色通道

如果批大小大于1,那么可以通过array1指定第二张图片,使用array2指定第三张图片,以此类推。

无论这个过程是乏味还是有趣,接下来我们开始预测吧!

开始预测

你可能还记得第3篇文章中提到,Module对象必须以为单位向模型提供数据:最常见的做法是使用数据迭代器(因此我们使用了NDArrayIter对象)。

在这里我们想要预测一张图片,因此尽管可以使用数据迭代器,不过也没啥必要。但我们可以创建一个名为Batch的具名元组(Named tuple),它可以充当假的迭代器,在引用数据属性时返回输入的NDArray。

from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

随后即可将这个“Batch”传递给模型开始预测。

mod.forward(Batch([array]))

这个模型会输出一个包含1000个可能性的NDArray,每个可能性对应一个分类。由于批大小等于1,因此只需要一行代码。

prob = mod.get_outputs()[0].asnumpy()
>>> prob.shape
(1, 1000)

使用squeeze()将其转换为数组,随后使用argsort()创建第二个数组,其中保存了这些可能性按照降序排列的指数

prob = np.squeeze(prob)
>>> prob.shape
(1000,)
>> prob
[  4.14978594e-08   1.31608676e-05   2.51907986e-05   2.24045834e-052.30327873e-06   3.40798979e-05   7.41563645e-06   3.04062659e-08 etc.
sortedprob = np.argsort(prob)[::-1]
>> sortedprob.shape
(1000,)

根据模型的计算,这张图片最可能的分类是#546,可能性为58%

>> sortedprob
[546 819 862 818 542 402 650 420 983 632 733 644 513 875 776 917 795
etc.
>> prob[546]
0.58039135

这个分类叫什么名字呢?我们可以使用synset.txt文件构建分类清单,并找出546号的名称。

synsetfile = open('synset.txt', 'r')
categorylist = []
for line in synsetfile:categorylist.append(line.rstrip())
>>> categorylist[546]
'n03272010 electric guitar'

可能性第二大的分类是什么?

>>> prob[819]
0.27168664
>>> categorylist[819]
'n04296562 stage

挺棒的,你说呢?

就是这样,我们已经了解了如何使用预训练的顶尖模型进行图片分类。而这一切只需要4行代码……除此之外只要准备好数据就够了。

完整代码如下,请自行尝试并继续保持关注 ??

代码已发布至GitHub:mxnet_example2.py

后续内容:

  • 第5篇:进一步了解预训练模型(VGG16和ResNet-152)
  • 第6篇:通过树莓派进行实时物体检测(并让它讲话!)

mxnet入门--第4篇相关推荐

  1. Mxnet入门--第1篇

    MXNet教程 这一系列文章将概括介绍深度学习库MXNet,将介绍该库的主要功能及其Python API(可能会成为该库的首选API).随后还将提供一些有关MXNet的在线教程和笔记,希望能帮助大家更 ...

  2. linux usb3.0改2.0,TX1入门教程硬件篇-切换USB2.0与USB3.0

    TX1入门教程硬件篇-切换USB2.0与USB3.0 说明: 介绍如何切换TX1USB口的为2.0或3.0版本 步骤: 编辑extlinux.conf文件,修改usb_port_owner_info= ...

  3. 微信公众号开发入门教程第一篇

    微信公众号开发入门教程第一篇 关键字:微信公众平台开发 作者:方倍工作室 在这篇微信公众平台开发教程中,我们假定你已经有了PHP语言程序.MySQL数据库.计算机网络通讯.及HTTP/XML/CSS/ ...

  4. React入门看这篇就够了

    2019独角兽企业重金招聘Python工程师标准>>> 摘要: 很多值得了解的细节. 原文:React入门看这篇就够了 作者:Random Fundebug经授权转载,版权归原作者所 ...

  5. 一看就明白的爬虫入门讲解-基础理论篇(下篇)

    文/孔淼 上篇我分享了爬虫入门中的"我们的目的是什么"."内容从何而来"."了解网络请求"这三部分的内容,这一篇我继续分享以下内容: 1) 一些常见的限制方式 2) 尝试解决问题的思路 3) 效率问题 ...

  6. 一看就明白的爬虫入门讲解-基础理论篇(上篇)

    作者:孔淼 关于爬虫内容的分享,我会分成两篇,六个部分来分享,分别是: 1)  我们的目的是什么 2)  内容从何而来 3)  了解网络请求 4)  一些常见的限制方式 5)  尝试解决问题的思路 6 ...

  7. 怎么安装python_零基础入门必看篇:浅析python,PyCharm,Anaconda三者之间关系

    今天为大家带来的内容是:零基础入门必看篇:浅析python ,PyCharm,Anaconda三者之间关系 众所周知,Python是一种跨平台的计算机程序设计语言,简单来说,python就是类似于C, ...

  8. Java快速入门-01-基础篇

    Java快速入门-01-基础篇 如果基础不好或者想学的很细,请参看:菜鸟教程-JAVA 本笔记适合快速学习,文章后面也会包含一些常见面试问题,记住快捷键操作,一些内容我就不转载了,直接附上链接,嘻嘻 ...

  9. .NET Core实战项目之CMS 第五章 入门篇-Dapper的快速入门看这篇就够了

    写在前面 上篇文章我们讲了如在在实际项目开发中使用Git来进行代码的版本控制,当然介绍的都是比较常用的功能.今天我再带着大家一起熟悉下一个ORM框架Dapper,实例代码的演示编写完成后我会通过Git ...

  10. Spring Cloud 入门 之 Config 篇(六)

    一.前言 随着业务的扩展,为了方便开发和维护项目,我们通常会将大项目拆分成多个小项目做成微服务,每个微服务都会有各自配置文件,管理和修改文件起来也会变得繁琐.而且,当我们需要修改正在运行的项目的配置时 ...

最新文章

  1. 奇数页分节符什么意思_删除分节符问题
  2. 《Linux防火墙(第4版)》——1.3 传输层机制
  3. 分页的limit_分页场景(limit,offset)为什么会慢
  4. 数据结构实验之排序七:选课名单(卡内存的一道题。。坑)
  5. vue中v-show指令的使用之Vue知识点归纳(五)
  6. 博为峰Java技术题 ——JavaSE Java 方法Ⅰ
  7. 二进制漏洞利用原理--栈溢出
  8. 代码随笔——点阵汉字在LCD上的显示
  9. LINUX下载编译YASM
  10. 菜鸟学习日志3.界面控件的设置
  11. 解决Not all parameters were used in the SQL statement问题
  12. word文档图片画红线_word文档怎么画线条
  13. 【弄nèng - Elasticsearch】运维篇 —— ES分片unassigned解决方案(ALLOCATION_FAILED,REPLICA_ADDED等
  14. 看《墨攻》理解IoC
  15. 独家对话AAAI、ACM、ACL三会会士Raymond J. Mooney | 香侬专栏
  16. ZYNQ - 嵌入式Linux开发 -10- ZYNQ启动流程分析
  17. 指针java_Java中的指针
  18. RDKit|分子修改与编辑
  19. HTTP服务器的文件缓存
  20. iOS Keychain和keychain share

热门文章

  1. python用蓝牙发文件_用pybluez进行python蓝牙发现
  2. win10计算机本地无法连接,Win10没有本地连接怎么办?
  3. presto读取oracle,Presto数据接入方式
  4. java经纬度转地址_经纬度转地址示例代码
  5. datetime只要年月python_Python 的日期和时间处理
  6. c语言形式参数若为b 4,4月全国计算机等级二级C笔试考试题目
  7. Linux平台代码覆盖率测试-.gcda/.gcno文件及其格式分析
  8. 全局变量 局部变量 静态变量
  9. 在u-boot中自定义的命令
  10. docker 系列 - 基础镜像环境和Docker常用命令整理