keras 多分类一些函数参数设置
用Lenet-5 识别Mnist数据集为例子:
采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。
首先import所需要的模块
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K
定义图像数据信息及训练参数
img_width, img_height = 28, 28
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000
nb_validation_samples = 10000
epochs = 50
batch_size = 32
判断使用的后台
if K.image_dim_ordering() == 'th':input_shape = (3, img_width, img_height)
else:input_shape = (img_width, img_height, 3)
网络模型定义
主要注意最后的输出层定义
比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:
x = Dense(10,activation=’softmax’)(x)
此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid
img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)model.compile(loss='binary_crossentropy',optimizer='rmsprop',metrics=['accuracy'])
model.summary()
利用ImageDataGenerator传入图像数据集
注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical’,如果是二分类问题该值可设为‘binary’,另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。
train_datagen = ImageDataGenerator(rescale=1. / 255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True)# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(train_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical', #多分类问题设为'categorical'classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字)validation_generator = test_datagen.flow_from_directory(validation_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical')
训练和保存模型及权值
model.fit_generator(train_generator,samples_per_epoch=nb_train_samples,nb_epoch=epochs,validation_data=validation_generator,nb_val_samples=nb_validation_samples)model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')
至此训练结束
图片预测
注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。
此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:img_addr=input('Please input your image address:')if img_addr=="exit":breakelse:img = load_img(img_addr, False, target_size=(28, 28))x = img_to_array(img) / 255.0x = np.expand_dims(x, axis=0)result = model.predict(x)ind=np.argmax(result,1)print('this is a ', classes[ind])
keras 多分类一些函数参数设置相关推荐
- oracle数据库初始化参数分类,oracle初始化参数设置
oracle初始化参数设置 时间:2007-11-09 来源:不详 作者:迈克DB ALTER DATABASE DATAFILE'd:ORANTDATABASEUSER1ORCL.ORA' RE ...
- keras中的fit函数参数_keras的fit_generator与callback函数
fit_generator函数 fit_generator函数 callback类 每一个epoch结束(on_epoch_end)时,都要调用callback函数,callback函数(类)都要集成 ...
- ecall函数调用系统函数参数设置教程
系统调用的ecall指令会使用a0和a7寄存器,其中a7寄存器保存的是系统调用号,a0寄存器保存的是系统调用参数,返回值会保存在a0寄存器中.为了能让系统调用指令能被集成进当前的流水线,ecall指令 ...
- c语言 为参数设置默认值,js函数参数设置默认值
前端学HTTP之网关.隧道和中继 前面的话 Web是一种强大的内容发布工具.人们已经从只在网上发送静态的在线文档,发展到共享更复杂的资源,比如数据库内容或动态生成的HTML页面.Web浏览器为用户提供 ...
- 改mysql修改界定符_dbvisualizer参数设置
6.13 可否完全禁用数据编辑? 可以. 方法: 在文本编辑器里打开文件 DBVIS-HOME/resources/dbvis-custom.prefs. 找出 dbvis.disabledataed ...
- js 的函数参数的默认值问题
js函数参数设置默认值 php有个很方便的用法是在定义函数时可以直接给参数设默认值,如: function simue ($a=1,$b=2){ return $a+$b; } echo simu ...
- 前端提高篇(十一)JS进阶8函数参数及arguments
形参与实参 基础点可以看这篇文章 获取形参个数:函数名.length function add(a,b,c,d,e){console.log('形参个数:' + add.length);} 运行效果: ...
- Caffe学习(四)数据层及参数设置
caffe的各种数据层在caffe.proto文件中有定义.通过对定义的caffe.proto文件进行编译,产生支持各种层操作的c++代码.后面将会详细解读caffe.proto文件(在caffe里就 ...
- WordPress分类列表函数:wp_list_categories用法及参数详解举例
http://www.511yj.com/wordpress-wp-categories.html 注意: 1. wp_list_categories() 和 list_cats() 以及 wp_li ...
- vlc 详细使用方法:libvlc_media_add_option 函数中的参数设置
[转载自]tinyle的专栏 [原文链接地址]http://blog.csdn.net/myaccella/article/details/7027962 [手记] 下面列出的参数可以在命令行中执行, ...
最新文章
- Java内部类及其实例化
- 在Salesforce中调用外部系统所提供的的Web Service
- 站内搜索引擎初探:haystack全文检索,whoosh搜索引擎,jieba中文分词
- Android 系统(254)---Android libphonenumber Demo 手机号码归属地
- 说说WeakReference弱引用
- 一题多解(八)—— 矩阵上三角(下三角)的访问
- linux操作系统死机处理办法
- 安卓版的水经注地图_水经注万能地图下载器
- 电脑桌面数字时钟c语言,DesktopDigitalClock(桌面数字时钟)
- ATAT-mcsqs- 运行后出现报错:段错误(吐核/core dumped)
- java pgp 加密_加密软件PGP的使用
- mysql 获取两个月前的日期
- DELPHI盒子上的RAD studio 2010安装过程及体验(超多图)
- 读周爱民《javascript语言精髓与编程实践》有感
- 数据挖掘实战:员工离职预测(训练赛)
- 服务器内存占用太高如何解决及知识点介绍
- python:蒙特卡罗方法计算圆周率
- 问卷链接怎么做二维码?如何使用二维码做问卷调查?
- 逍遥模拟器配合fiddler爬取抖音视频!开源免费!
- [Unity官方文档翻译]Downloading and Installing Unity下载和安装unity教程
热门文章
- Wine 开发版 4.6 发布,Windows 应用的兼容层
- java B2B2C 源码 多级分销Springcloud多租户电子商城系统-SpringCloud配置中心内容加密...
- 从零开始webpack搭建项目
- oracle VM manager 3.1 试验备忘录
- 【小贴士】工作中的”闭包“与事件委托的”阻止冒泡“
- [洛谷1681]最大正方形II
- Nancy之Cache的简单使用
- 目前M院M师的教学乱象
- 老去的80后忆当年-致80后的朋友们
- #paragma详解