Applications模块解析(一)
文章目录
- 说明
- 官方模型
- 使用与下载
- 存储文件与位置
- 预测
- 完整代码
- 构建新网络
- 特征提取
- 提取任意中间层特征
- 微调神经网络
- 自动输入张量
- 其他推荐
前言
阅读该文档需要二十分钟,完成后你将会学会使用applications模块的核心功能,并能够使用该模块中与训练模型进行预测,或在对该模块提供的神经网络进行微调,提取任意中间层特征。以下为类库版本:
keras 2.3.1
keras-applications 1.0.8
keras-base 2.3.1
keras-preprocessing 1.1.0
tensorflow 2.2.0
tensorflow-base 2.2.0
tensorflow-estimator 2.2.0
tensorflow-gpu 2.2.0
说明
Application是keras中的一个特殊模块,其为我们提供了已经构建好的多种经典神经网络以及在特定数据集上训练好的模型。同时借助该模块,我们也可以抽取其中的神经网络结构,直接用于或者调整后用于训练自己的模型。以下是keras.applications
模块中的初始化文件(__init__.py
)中的代码。
from .vgg16 import VGG16
from .vgg19 import VGG19
from .resnet50 import ResNet50
from .inception_v3 import InceptionV3
from .inception_resnet_v2 import InceptionResNetV2
from .xception import Xception
from .mobilenet import MobileNet
from .mobilenet_v2 import MobileNetV2
from .densenet import DenseNet121, DenseNet169, DenseNet201
from .nasnet import NASNetMobile, NASNetLarge
from .resnet import ResNet101, ResNet152
from .resnet_v2 import ResNet50V2, ResNet101V2, ResNet152V2
这些代码直接告诉了我们当前版本可以使用的神经网络。
官方模型
使用与下载
上面的代码告诉了我们Applications
模块中神经网络的引用格式。引用后的函数则可以通过参数weights
获取官方模型。
from keras.applications.resnet50 import ResNet50
model = ResNet50(weights='imagenet')
使用以上代码,keras将会检测模型是否存在,不存在则会前往github下载模型,以上代码所需模型的下载地址为:https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5
存储文件与位置
模型下载后的默认存储位置是当前账户下的隐藏文件夹,路径为~/.keras/models
,绝对路径(以本人为例)/home/fonttian/.keras/models
。我们也可以直接下载模型然后存在该位置。不过除了模型之外,还会有其他文件存在,比如imagenet_class_index.json
。字如其名,该文件是所使用的数据集的class_index.json
,该文件中一共有一千类,与函数ResNet50
中的参数classes=1000
也是对应的。
预测
首先展示代码,这里使用的是官方例子。
img_path = '../Images/elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)preds = model.predict(x)
# 将结果解码为元组列表 (class, description, probability)
# (一个列表代表批次中的一个样本)
print('Predicted:', decode_predictions(preds, top=3)[0])
# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]
上面的代码中可以分为两个模块来看,首先是加载数据并处理,这里加载数据的函数image.load_img
是官方提供的函数,size
参数的作用是对图片进行尺寸的调整。加载后的图片继续使用numpy
进行调整。而第二个模块则是预测模块,使用model.predict
进行预测,预测后则使用官方自带的解码函数decode_predictions
将其转化为元组列表(class, description, probability)
进行输出,同时借助参数top
选择要输出的项多少个。
完整代码
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as npmodel = ResNet50(weights='imagenet')img_path = '../Images/elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)preds = model.predict(x)
# 将结果解码为元组列表 (class, description, probability)
# (一个列表代表批次中的一个样本)
print('Predicted:', decode_predictions(preds, top=3)[0])
# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]
从完整代码我们可以很清楚的看出,Applications模块本质返回的是一个model对象,所以具体怎么操作这个model其实应该是很灵活的。下边就是一些更加灵活使用该model对象的内容。同时,很显然这些代码也可以移植到我们自己创建的model对象上。
构建新网络
特征提取
Keras自带的神经网络都是已经构建的经典神经网络,比如VGG16,这是一个经典的分类神经网络。而特征提取需要的则是没有分类器的神经网络,这点我们可以通过参数include_top
来控制。keras
中对该参数的解释为:
include_top: whether to include the ? fully-connectedlayers at the top of the network.
由于不同的神经网络使用的分类器不同,有时候会有个数字表示全连接神经网络的层数,所以我在这用一个问号代替。
官方给的例子有多个,首先VGG16 提取特征的例子:
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import numpy as npmodel = VGG16(weights='imagenet', include_top=False)img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)features = model.predict(x)
此时出书的例子就变为了特征,本质是矩阵。而非之前的类别。
提取任意中间层特征
除此之外,本质上来说所有applications中的神经网络返回的都是定义好的keras中的神经网络。因此我们也可以按照一般神经网络那样获取任意的中间层特征。具体方法如下,这里使用的是官方给的VGG19提取任意中间层特征的例子:
from keras.applications.vgg19 import VGG19
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as npbase_model = VGG19(weights='imagenet')
model = Model(inputs=base_model.input, outputs=base_model.get_layer('block4_pool').output)img_path = '../Images/elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)block4_pool_features = model.predict(x)
print(block4_pool_features)
这个很容易看懂,就是借助了Model参数对原来的完整的model进行了处理,关键参数是outputs
,
这里真真的个问题在于如何获取核心参数,也就是函数get_layer
的参数name
(另外get_layer
还有一个参数为index
)这里确实没有很好地方法,而直接最有效的方法肯定还是直接查看keras的源码,该部分代码位于keras_applications
中,如果在Pycharm中可以直接ctrl
点击进去查看,或者前往github下载源码,这里是keras-applications的github地址。
微调神经网络
通过刚刚的代码我们已经非常接近applications
模块的本质了,我们可以在其上做更多更加纯粹的事情,比如只使用该模块提供的神经网络的一部分,而其他重要部分则由我们自行构建。具体方法则刚刚所做一致。首先借用include_top
参数去除分类器,然后借助Model
函数获取我们想要的输出,比如上一部分的提取任意中间层的特征,其实说到底就是把截取并运行了一个正常的神经网络一部分。之后再进行微调,获取我们需要的模型。
下面是keras的例子-微调 InceptionV3 ,我们以此为讲解:
(1)构建神经网络
# 构建不带分类器的预训练模型
base_model = InceptionV3(weights='imagenet', include_top=False)# 添加全局平均池化层
x = base_model.output
x = GlobalAveragePooling2D()(x)# 添加一个全连接层
x = Dense(1024, activation='relu')(x)# 添加一个分类器,假设我们有200个类
predictions = Dense(200, activation='softmax')(x)# 构建我们需要训练的完整模型
model = Model(inputs=base_model.input, outputs=predictions)
这一步前半部分与之前的一致,都是先获取不带分类的神经网络。之后我们则自己来添加一个分类器,此处是先添加全剧平均池化层,然后添加一个两层神经网络作为分类器(200类)。之后再使用Model构建完整的模型。
(2)锁层,训练新分类器
这里我们使用了预训练的模型,之后我们则需要使用锁层。锁层的方法是设置.trainable=False
,该属性是keras中Network
的固有属性,其继承于keras.engine.base_layer
中的layer
。这是我们非常常用的keras中的类。我们通过该方法可以将v3的所有卷积层都锁定,然后训练未锁定的层。
# 首先,我们只训练顶部的几层(随机初始化的层)
# 锁住所有 InceptionV3 的卷积层
for layer in base_model.layers:layer.trainable = False# 编译模型(一定要在锁层以后操作)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')# 在新的数据集上训练几代
model.fit_generator(...)# 现在顶层应该训练好了,让我们开始微调 Inception V3 的卷积层。
(3)训练顶层
通过刚刚的操作我们就训练好了我们的顶层(top layers),现在我们则可以逐渐放开,并训练后边的神经元了。具体代码如下:
# 现在顶层应该训练好了,让我们开始微调 Inception V3 的卷积层。
# 我们会锁住底下的几层,然后训练其余的顶层。# 让我们看看每一层的名字和层号,看看我们应该锁多少层呢:
for i, layer in enumerate(base_model.layers):print(i, layer.name)# 我们选择训练最上面的两个 Inception block
# 也就是说锁住前面249层,然后放开之后的层。
for layer in model.layers[:249]:layer.trainable = False
for layer in model.layers[249:]:layer.trainable = True# 我们需要重新编译模型,才能使上面的修改生效
# 让我们设置一个很低的学习率,使用 SGD 来微调
from keras.optimizers import SGD
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy')# 我们继续训练模型,这次我们训练最后两个 Inception block
# 和两个全连接层
model.fit_generator(...)
(4)全部代码
from keras.applications.inception_v3 import InceptionV3
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras import backend as K# 构建不带分类器的预训练模型
base_model = InceptionV3(weights='imagenet', include_top=False)# 添加全局平均池化层
x = base_model.output
x = GlobalAveragePooling2D()(x)# 添加一个全连接层
x = Dense(1024, activation='relu')(x)# 添加一个分类器,假设我们有200个类
predictions = Dense(200, activation='softmax')(x)# 构建我们需要训练的完整模型
model = Model(inputs=base_model.input, outputs=predictions)# 首先,我们只训练顶部的几层(随机初始化的层)
# 锁住所有 InceptionV3 的卷积层
for layer in base_model.layers:layer.trainable = False# 编译模型(一定要在锁层以后操作)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')# 在新的数据集上训练几代
model.fit_generator(...)# 现在顶层应该训练好了,让我们开始微调 Inception V3 的卷积层。
# 我们会锁住底下的几层,然后训练其余的顶层。# 让我们看看每一层的名字和层号,看看我们应该锁多少层呢:
for i, layer in enumerate(base_model.layers):print(i, layer.name)# 我们选择训练最上面的两个 Inception block
# 也就是说锁住前面249层,然后放开之后的层。
for layer in model.layers[:249]:layer.trainable = False
for layer in model.layers[249:]:layer.trainable = True# 我们需要重新编译模型,才能使上面的修改生效
# 让我们设置一个很低的学习率,使用 SGD 来微调
from keras.optimizers import SGD
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy')# 我们继续训练模型,这次我们训练最后两个 Inception block
# 和两个全连接层
model.fit_generator(...)
自动输入张量
from keras.applications.inception_v3 import InceptionV3
from keras.layers import Input# 这也可能是不同的 Keras 模型或层的输出
input_tensor = Input(shape=(224, 224, 3)) # 假定 K.image_data_format() == 'channels_last'model = InceptionV3(input_tensor=input_tensor, weights='imagenet', include_top=True)
其他推荐
- Python ML环境搭建与学习资料推荐
- PyTorch官方中文文档CSDN引流
- Hyperopt官方中文文档导读
- pyhanlp用户指引CSDN博客专栏(官方推荐)
Applications模块解析(一)相关推荐
- cuDNN 功能模块解析
cuDNN 功能模块解析 Abstract 本cuDNN 8.0.4开发人员指南概述了cuDNN功能,如可自定义的数据布局.支持灵活的dimension ordering,striding,4D张量的 ...
- 实现一个webpack模块解析器
最近在学习 webpack源码,由于源码比较复杂,就先梳理了一下整体流程,就参考官网的例子,手写一个最基本的 webpack 模块解析器. 代码很少,github地址:手写webpack模块解析器 整 ...
- python中json模块_Python使用内置json模块解析json格式数据的方法
本文实例讲述了Python使用内置json模块解析json格式数据的方法.分享给大家供大家参考,具体如下: Python中解析json字符串非常简单,直接用内置的json模块就可以,不需要安装额外的模 ...
- TypeScript 素描 - 模块解析、声明合并
模块解析 模块解析有两种方式 相对方式 也就是以/或 ./或-/开头的,比如import jq from "/jq" 非相对方式 比如 import model from ...
- Spring的核心模块解析
转载自 Spring的核心模块解析 Spring框架是一个轻量级的集成式开发框架,可以和任何一种框架集成在一起使用,可以说是一个大的全家桶.Spring从1.x发展到现在的5.x可以说是越来越强大,下 ...
- php验证密码后跳转_php-laravel框架用户验证(Auth)模块解析(四)忘记密码
一.忘记密码模块路由 二.控制器解析 跟注册.登录的控制器一样,大部分的逻辑使用trait引入. ForgotPasswordController:负责忘记密码页面,以及邮件发送 四.扩展开发:自定义 ...
- Nginx 静态文件服务器搭建及autoindex模块解析
ngx_http_autoindex_module ngx_http_autoindex_module模块处理以斜杠字符('/')结尾的请求,并生成目录列表. 当ngx_http_index_modu ...
- 追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems)
追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems) PROJECT http://winsty.net/tracker_di ...
- 解析html xml最好的模块,解析--import--htmllib--xml
--import ConfigParser 模块------解析配置文件---------------------------------------------------------------- ...
最新文章
- python输出去空格_python输出怎么取消空格
- 高斯混合模型Gaussian Mixture Model (GMM)——通过增加 Model 的个数,我们可以任意地逼近任何连续的概率密分布...
- linux(CentOS)下安装mongodb
- Pycharm 解决pip遇到的错误:module 'pip' has no attribute 'main'
- 微服务实践沙龙-上海站
- 直播PK短视频?直播+短视频才是王道
- java基础—线程间的通讯 生产者与消费者
- kafka消息处理失败后如何处理_面试题:Kafka 会不会丢消息?怎么处理的?
- SharePoint 删除废弃站点步骤
- 标准模板库(STL)之 map 列传 (三)
- verilog之按键消抖的理解
- android js 子线程,Android学习笔记:Android中的线程:MainThread 和 WorkerThread
- 关于KX混响插件:REVERB R详解
- 四参数拟合曲线_Origin进行体外释药规律的拟合
- 必看 logit回归分析步骤汇总
- Golang的反射机制(The Laws of Reflection)
- android双系统切换软件,可一键切换安卓/Win!双系统设备涌现CES
- python做饼图出现重影_解决echarts中饼图标签重叠的问题
- 【面试总结】JNI层MediaScanner的分析,挥泪整理面经
- 【FlashDB】第二步 FlashDB 移植 STM32L475 使用QSPI驱动外部 flash W25Q64之 SFUD 移植
热门文章
- Hopfield神经网络和TSP问题
- 《LoadRunner性能测试巧匠训练营》——3.3 场景监控实战
- Java面试题阶段汇总
- 我会手动创建线程,为什么让我使用线程池?
- Nomad技术手册:整体架构(Architecture)
- 数据库的事务隔离技术 之 MVCC
- 公共基础知识计算机,公共基础知识计算机基础知识试题
- 转盘抽奖php,使用PHP实现转盘抽奖算法案例解析
- android 补间动画重复次数,9.1.5 setRepeatCount方法:设置重复次数
- mysql互为主从利弊_MySQL互为主从复制常见问题