用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000
nb_validation_samples = 10000
epochs = 50
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':input_shape = (3, img_width, img_height)
else:input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义
比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:
x = Dense(10,activation=’softmax’)(x)
此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)model.compile(loss='binary_crossentropy',optimizer='rmsprop',metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集
注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical’,如果是二分类问题该值可设为‘binary’,另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(rescale=1. / 255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True)# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(train_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical',    #多分类问题设为'categorical'classes=['0','1','2','3','4','5','6','7','8','9']  #十种数字图片所在文件夹的名字)validation_generator = test_datagen.flow_from_directory(validation_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical')

训练和保存模型及权值

model.fit_generator(train_generator,samples_per_epoch=nb_train_samples,nb_epoch=epochs,validation_data=validation_generator,nb_val_samples=nb_validation_samples)model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束


图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。
此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:img_addr=input('Please input your image address:')if img_addr=="exit":breakelse:img = load_img(img_addr, False, target_size=(28, 28))x = img_to_array(img) / 255.0x = np.expand_dims(x, axis=0)result = model.predict(x)ind=np.argmax(result,1)print('this is a ', classes[ind])

keras 多分类一些函数参数设置相关推荐

  1. oracle数据库初始化参数分类,oracle初始化参数设置

    oracle初始化参数设置 时间:2007-11-09  来源:不详  作者:迈克DB ALTER DATABASE DATAFILE'd:ORANTDATABASEUSER1ORCL.ORA' RE ...

  2. keras中的fit函数参数_keras的fit_generator与callback函数

    fit_generator函数 fit_generator函数 callback类 每一个epoch结束(on_epoch_end)时,都要调用callback函数,callback函数(类)都要集成 ...

  3. ecall函数调用系统函数参数设置教程

    系统调用的ecall指令会使用a0和a7寄存器,其中a7寄存器保存的是系统调用号,a0寄存器保存的是系统调用参数,返回值会保存在a0寄存器中.为了能让系统调用指令能被集成进当前的流水线,ecall指令 ...

  4. c语言 为参数设置默认值,js函数参数设置默认值

    前端学HTTP之网关.隧道和中继 前面的话 Web是一种强大的内容发布工具.人们已经从只在网上发送静态的在线文档,发展到共享更复杂的资源,比如数据库内容或动态生成的HTML页面.Web浏览器为用户提供 ...

  5. 改mysql修改界定符_dbvisualizer参数设置

    6.13 可否完全禁用数据编辑? 可以. 方法: 在文本编辑器里打开文件 DBVIS-HOME/resources/dbvis-custom.prefs. 找出 dbvis.disabledataed ...

  6. js 的函数参数的默认值问题

    js函数参数设置默认值 php有个很方便的用法是在定义函数时可以直接给参数设默认值,如: function simue ($a=1,$b=2){   return $a+$b; } echo simu ...

  7. 前端提高篇(十一)JS进阶8函数参数及arguments

    形参与实参 基础点可以看这篇文章 获取形参个数:函数名.length function add(a,b,c,d,e){console.log('形参个数:' + add.length);} 运行效果: ...

  8. Caffe学习(四)数据层及参数设置

    caffe的各种数据层在caffe.proto文件中有定义.通过对定义的caffe.proto文件进行编译,产生支持各种层操作的c++代码.后面将会详细解读caffe.proto文件(在caffe里就 ...

  9. WordPress分类列表函数:wp_list_categories用法及参数详解举例

    http://www.511yj.com/wordpress-wp-categories.html 注意: 1. wp_list_categories() 和 list_cats() 以及 wp_li ...

  10. vlc 详细使用方法:libvlc_media_add_option 函数中的参数设置

    [转载自]tinyle的专栏 [原文链接地址]http://blog.csdn.net/myaccella/article/details/7027962 [手记] 下面列出的参数可以在命令行中执行, ...

最新文章

  1. Java内部类及其实例化
  2. 在Salesforce中调用外部系统所提供的的Web Service
  3. 站内搜索引擎初探:haystack全文检索,whoosh搜索引擎,jieba中文分词
  4. Android 系统(254)---Android libphonenumber Demo 手机号码归属地
  5. 说说WeakReference弱引用
  6. 一题多解(八)—— 矩阵上三角(下三角)的访问
  7. linux操作系统死机处理办法
  8. 安卓版的水经注地图_水经注万能地图下载器
  9. 电脑桌面数字时钟c语言,DesktopDigitalClock(桌面数字时钟)
  10. ATAT-mcsqs- 运行后出现报错:段错误(吐核/core dumped)
  11. java pgp 加密_加密软件PGP的使用
  12. mysql 获取两个月前的日期
  13. DELPHI盒子上的RAD studio 2010安装过程及体验(超多图)
  14. 读周爱民《javascript语言精髓与编程实践》有感
  15. 数据挖掘实战:员工离职预测(训练赛)
  16. 服务器内存占用太高如何解决及知识点介绍
  17. python:蒙特卡罗方法计算圆周率
  18. 问卷链接怎么做二维码?如何使用二维码做问卷调查?
  19. 逍遥模拟器配合fiddler爬取抖音视频!开源免费!
  20. [Unity官方文档翻译]Downloading and Installing Unity下载和安装unity教程

热门文章

  1. Wine 开发版 4.6 发布,Windows 应用的兼容层
  2. java B2B2C 源码 多级分销Springcloud多租户电子商城系统-SpringCloud配置中心内容加密...
  3. 从零开始webpack搭建项目
  4. oracle VM manager 3.1 试验备忘录
  5. 【小贴士】工作中的”闭包“与事件委托的”阻止冒泡“
  6. [洛谷1681]最大正方形II
  7. Nancy之Cache的简单使用
  8. 目前M院M师的教学乱象
  9. 老去的80后忆当年-致80后的朋友们
  10. #paragma详解