本文介绍如何使用keras作图片分类(2分类与多分类,其实就一个参数的区别。。。呵呵)

先来看看解决的问题:从一堆图片中分出是不是书本,也就是最终给图片标签上:“书本“、“非书本”,简单吧。

先来看看网络模型,用到了卷积和全连接层,最后套上SOFTMAX算出各自概率,输出ONE-HOT码,主要部件就是这些,下面的nb_classes就是用来控制分类数的,本文是2分类:

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD  def Net_model(nb_classes, lr=0.001,decay=1e-6,momentum=0.9):  model = Sequential()  model.add(Convolution2D(filters=10, kernel_size=(5,5),padding='valid',  input_shape=(200, 200, 3)))  model.add(Activation('tanh'))  model.add(MaxPooling2D(pool_size=(2, 2)))  model.add(Convolution2D(filters=20, kernel_size=(10,10)))model.add(Activation('tanh'))  model.add(MaxPooling2D(pool_size=(2, 2)))  model.add(Dropout(0.25))  model.add(Flatten())  model.add(Dense(1000))model.add(Activation('tanh'))  model.add(Dropout(0.5))  model.add(Dense(nb_classes))  model.add(Activation('softmax'))  sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)  model.compile(loss='categorical_crossentropy', optimizer=sgd)  return model

上面的input_shape=(200, 200, 3)代表图片像素大小为宽高为200,200,并且包含RGB 3通道的图片,不是灰度图片(只要1个通道)

也就是说送入此网络的图片宽高必须200*200*3;如果不是这个shape就需要resize到这个shape

下面来看看训练程序,首先肯定是要收集些照片,书本、非书本的照片,我是分别放在了0文件夹和1文件夹下了,再带个验证用途的文件夹validate:

  

训练程序涉及到几个地方:照片文件的读取、模型加载训练与保存、可视化训练过程中的损失函数value

照片文件的读取

import cv2
import os
import numpy as np
import kerasdef loadImages():imageList=[]labelList=[]rootdir="d:\\books\\0"list =os.listdir(rootdir)for item in list:path=os.path.join(rootdir,item)if(os.path.isfile(path)):f=cv2.imread(path)f=cv2.resize(f, (200, 200))#resize到网络input的shapeimageList.append(f)labelList.append(0)#类别0rootdir="d:\\books\\1"list =os.listdir(rootdir)for item in list:path=os.path.join(rootdir,item)if(os.path.isfile(path)):f=cv2.imread(path)f=cv2.resize(f, (200, 200))#resize到网络input的shapeimageList.append(f)labelList.append(1)#类别1return np.asarray(imageList), keras.utils.to_categorical(labelList, 2)

关于(200,200)这个shape怎么得来的,只是几月前开始玩opencv时随便写了个数值,后来想利用那些图片,就适应到这个shape了

keras.utils.to_categorical函数类似numpy.onehot、tf.one_hot这些,只是one hot的keras封装

模型加载训练与保存

nb_classes = 2
nb_epoch = 30
nb_step = 6
batch_size = 3x,y=loadImages()from keras.preprocessing.image import ImageDataGenerator
dataGenerator=ImageDataGenerator()
dataGenerator.fit(x)
data_generator=dataGenerator.flow(x, y, batch_size, True)#generator函数,用来生成批处理数据(从loadImages中)model=NetModule.Net_model(nb_classes=nb_classes, lr=0.0001) #加载网络模型history=model.fit_generator(data_generator, epochs=nb_epoch, steps_per_epoch=nb_step, shuffle=True)#训练网络,并且返回每次epoch的损失valuemodel.save_weights('D:\\Documents\\Visual Studio 2017\\Projects\\ConsoleApp9\\PythonApplication1\\书本识别\\trained_model_weights.h5')#保存权重
print("DONE, model saved in path-->D:\\Documents\\Visual Studio 2017\\Projects\\ConsoleApp9\\PythonApplication1\\书本识别\\trained_model_weights.h5")

ImageDataGenerator构造函数有很多参数,主要用来提升数据质量,比如要不要标准化数字

lr=0.001这个参数要看经验,大了会导致不收敛,训练的时候经常由于这个参数的问题导致重复训练,这在没有GPU的情况下很是痛苦。。痛苦。。。痛苦。。。

model.save_weights是保存权重,但是不保存网络模型 ,对应的是model.load_weights方法

model.save是保存网络+权重,只是。。。。此例中用save_weights保存的h5文件是125M,但用save方法保存后,h5文件就增大为280M了。。。

上面2个save方法都能finetune,只是灵活度不一样。

可视化训练过程中的损失函数value

import matplotlib.pyplot as pltplt.plot(history.history['loss'])
plt.show()

  

貌似没啥好补充的。。。

AND。。。。看看预测部分吧,这部分加载图片、加载模型,似乎都和训练部分雷同:

def loadImages():imageList=[]rootdir="d:\\books\\validate"list =os.listdir(rootdir)for item in list:path=os.path.join(rootdir,item)if(os.path.isfile(path)):f=cv2.imread(path)f=cv2.resize(f, (200, 200))imageList.append(f)return np.asarray(imageList)x=loadImages()x=np.asarray(x)model=NetModule.Net_model(nb_classes=2, lr=0.0001)
model.load_weights('D:\\Documents\\Visual Studio 2017\\Projects\\ConsoleApp9\\PythonApplication1\\书本识别\\trained_model_weights.h5')print(model.predict(x))
print(model.predict_classes(x))
y=convert2label(model.predict_classes(x))
print(y)

predict的返回其实是softmax层返回的概率数值,是<=1的float

predict_classes返回的是经过one-hot处理后的数值,此时只有0、1两种数值(最大的value会被返回称为1,其他都为0)  

convert2label:

def convert2label(vector):string_array=[]for v in vector:if v==1:string_array.append('BOOK')else:string_array.append('NOT BOOK')return string_array

这个函数是用来把0、1转换成文本的,小插曲:

本来这里是中文的“书本”、“非书本”,后来和女儿一起调试时发现都显示成了问号,应该是中文字符问题,就改成了英文显示,和女儿一起写代码是种乐趣啊!

本来只是显示文本,感觉太无聊了,因此加上了opencv显示图片+分类文本的代码段:

for i in range(len(x)):cv2.putText(x[i], y[i], (50,50), cv2.FONT_HERSHEY_SIMPLEX, 1, 255, 2)cv2.imshow('image'+str(i), x[i])cv2.waitKey(-1)

  

OK, 2018年继续学习,继续科学信仰。

转载于:https://www.cnblogs.com/aarond/p/CNN.html

用keras作CNN卷积网络书本分类(书本、非书本)相关推荐

  1. 基于Keras搭建CNN、TextCNN文本分类模型

    基于Keras搭建CNN.TextCNN文本分类模型 一.CNN 1.1 数据读取分词 1.2.数据编码 1.3 数据序列标准化 1.4 构建模型 1.5 模型验证 二.TextCNN文本分类 2.1 ...

  2. 基于TensorFlow的CNN卷积网络模型花卉分类GUI版(2)

    一.项目描述 10类花的图片1100张,按{牡丹,月季,百合,菊花,荷花,紫荆花,梅花,-}标注,其中1000张作为训练样本,100张作为测试样本,设计一个CNN卷积神经网络花卉分类器进行花卉的分类, ...

  3. cnn卷积网络解决尺寸大小不确定的图片分类问题--银行凭证印章检查

    简 介 CNN卷积神经网络进行图像识别分类的技术很简单,一般神经网络初学者都做过mnist分类实验,简单的分类任务准确率跑到99%已经不成问题,这里不做过多的介绍.但是想用卷积神经网络解决实际问题,往 ...

  4. 【keras】一维卷积神经网络多分类

    刚刚接触到深度学习,前2个月的时间里,我用一维的卷积神经网络实现了对于一维数据集的分类和回归.由于在做这次课题之前,我对深度学习基本上没有过接触,所以期间走了很多弯路. 在刚刚收到题目的要求时,我选择 ...

  5. 【记录】本科毕设:基于树莓派的智能小车设计(使用Tensorflow + Keras 搭建CNN卷积神经网络 使用端到端的学习方法训练CNN)

    0 申明 这是本人2020年的本科毕业设计,内容多为毕设论文和答辩内容中挑选.最初的灵感来自于早前看过的一些项目(抱歉时间久远,只记录了这一个,见下),才让我萌生了做个机电(小车动力与驱动)和控制(树 ...

  6. CNN卷积神经网络(数字分类)

    CNN卷积神经网络--手写数字识别 import torch import torch.nn as nn from torch.autograd import Variable import torc ...

  7. 自然语言处理--keras实现一维卷积网络对IMDB 电影评论数据集构建情感分类器

    为什么在 NLP 分类任务中选择 CNN 呢? 1.CNN神经网络可以像处理图像一样处理文本并"理解"它们 2.主要好处是高效率 3.在许多方面,由于池化层和卷积核大小所造成的限制 ...

  8. m基于CNN卷积网络和GEI步态能量图的步态识别算法MATLAB仿真,测试样本采用现实拍摄的场景进行测试,带GUI界面

    目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 目前关于步态识别算法研究主要有两种:基于模型的方法和非基于模型的方法.基于模型的步态识别方法优点在于 ...

  9. 基于TensorFlow的CNN卷积网络模型花卉分类(1)

    一.项目描述 使用TensorFlow进行卷积神经网络实现花卉分类的项目,加载十种花分类,建立模型后进行预测分类图片 环境:win10 +TensorFlow gpu 1.12.0+pycharm 训 ...

最新文章

  1. 商汤提基于贪心超网络的One-Shot NAS,达到最新SOTA | CVPR 2020
  2. Android log 引发的血案
  3. 【PC工具】Windows10开始菜单增强工具Stardock Start10
  4. Python爬虫的框架有哪些?推荐这五个!
  5. WebBrowser控件打开https站点
  6. java高校教师工作量管理系统_基于ssh/bs/java/asp.net/php/web/安卓的高校教师工作量管理系统...
  7. python找出一个数的所有因子_python – 找到最大素因子的正确算法
  8. CERC2017 Gambling Guide,最短路变形,期望dp
  9. 对于mysql的用户权限管理
  10. (转载)Hadoop -- Map-Reduce入门
  11. Implement Trie (Prefix Tree)
  12. 百度地图行政区域划分镂空
  13. 纯css实现二级下拉菜单
  14. 羊皮卷之七:我要笑遍世界
  15. java使用Aspose.pdf实现pdf转图片
  16. Java教程——软件开发基础
  17. AI 智能头像生成神器|PhotoShot
  18. 计算机毕业设计Java房产中介管理系统(源码+系统+mysql数据库+lW文档)
  19. 五款高效率黑科技神器工具,炸裂好用,省时间
  20. 【10-11】PR调色+多机位剪辑

热门文章

  1. Java 利用InetAddress类确定特殊Ip地址
  2. 【机器视觉】机器视觉博客汇总
  3. 【Qt】数据库SQL接口层
  4. 虚拟机3种网络模式(桥接、nat、Host-only)
  5. android 约束布局的坑,android - 使用android约束布局2.0.0 Flow将项目放置一行 - 堆栈内存溢出...
  6. python就业前景如何_2020年Python就业前景如何?就业岗位多不多?薪资高不高?...
  7. qq飞车登陆服务器无响应,qq飞车手游进不去怎么回事 为什么进不去游戏
  8. 计算机一级查找同类型文件,如何快捷找出电脑内的重复文件
  9. 计算机中cmos设置程序,电脑主板上有CMOS设置是什么意思
  10. Docker Centos 7.X部署Tomcat 并且修改Server.xml配置文件方案 并设置时区 只要十一步