CNN卷积神经网络是人工智能的开端,CNN卷积神经网络让计算机能够认识图片,文字,甚至音频与视频。CNN卷积神经网络的基础知识,可以参考:CNN卷积神经网络

LetNet体系结构是卷积神经网络的“第一个图像分类器”。最初设计用于对手写数字进行分类,上期文章我们分享了如何使用keras来进行手写数字的神经网络搭建:Keras人工智能神经网络 Classifier 分类 神经网络搭建

我们也可以轻松地将其扩展到其他类型的图像上,本期使用小雪人的照片,来让神经网络识别雪人

雪人的图片大家可以到网络上自行下载,当然也可以使用爬虫技术来下载

搭建keras神经网络识别图片

from keras.models import Sequentialfrom keras.layers.convolutional import Conv2Dfrom keras.layers.convolutional import MaxPooling2Dfrom keras.layers.core import Activationfrom keras.layers.core import Flattenfrom keras.layers.core import Densefrom keras import backend as K

首先导入需要的模块,建立一个神经网络以便后期使用,在一个单独的文件中,命名此神经网络类(lenet.py)

class LeNet:@staticmethoddef build(width, height, depth, classes):# 使用Sequential()初始化modelmodel = Sequential()inputShape = (height, width, depth) #tensorflow默认设置    #宽度 :输入图像的宽度    #高度 :输入图像的高度    #深度 :输入图像中的频道数(1个 对于灰度单通道图像, 3 标准RGB图像)    # 若是其他的(Theano),则使用((depth, height, width)if K.image_data_format() == "channels_first":inputShape = (depth, height, width)#建立卷积神经网络  =>然后是 RELU => 然后是max pooling(跟前期分享的tensorflow教程类似)model.add(Conv2D(20, (5, 5), padding="same",input_shape=inputShape))model.add(Activation("relu"))model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))# 建立卷积神经网络  =>然后是 RELU => 然后是max pooling(第二层)model.add(Conv2D(50, (5, 5), padding="same"))model.add(Activation("relu"))model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))# 增加全连接层model.add(Flatten())model.add(Dense(500))model.add(Activation("relu"))# softmax classifier 来进行神经网络的分类model.add(Dense(classes))model.add(Activation("softmax"))# return the modelreturn model

训练keras神经网络

以上建立了keras 的神经网络模型,我们就使用预先下载好的图片来训练神经模型

建立一个train.py文件,插入如下代码,来训练神经网络模型(图片数据里面分成如下2类)

  1. snowman #我们训练的图片
  2. notsnowman 增加非雪人图片的训练
import matplotlibmatplotlib.use("Agg")from keras.preprocessing.image import ImageDataGeneratorfrom keras.optimizers import Adamfrom sklearn.model_selection import train_test_splitfrom keras.preprocessing.image import img_to_arrayfrom keras.utils import to_categoricalfrom lenet import LeNetfrom imutils import pathsimport matplotlib.pyplot as pltimport numpy as npimport randomimport cv2import os

初始化参数

EPOCHS = 25 #学习的步数INIT_LR = 1e-3# 学习效率BS = 32# 每步学习的个数data = []# 存放图片数据labels = []# 存放图片标签imagePaths = sorted(list(paths.list_images("dataset")))# 遍历所有的图片random.seed(42)random.shuffle(imagePaths) # 打乱图片顺序

初始化参数完成后,需要把所有的图片加载,进行图片数据的整理

for imagePath in imagePaths:    # 加载图片    image = cv2.imread(imagePath)    image = cv2.resize(image, (28, 28)) # resize 到28*28 LeNet所需的空间尺寸    image = img_to_array(image) # 图片转换成array    data.append(image) # 保存图片数据    label = imagePath.split(os.path.sep)[-2] #获取图片标签    label = 1 if label == "snowman" else 0    labels.append(label) # 获取图片标签

预先处理图片

# 把图片数据变成【0.1】data = np.array(data, dtype="float") / 255.0labels = np.array(labels)# 设置测试数据与训练数据#使用75%的数据将数据划分为训练和测试#用于训练的数据,其余25%用于测试(trainX, testX, trainY, testY) = train_test_split(data, labels, test_size=0.25, random_state=42)# 标签转换成向量trainY = to_categorical(trainY, num_classes=2)testY = to_categorical(testY, num_classes=2)# 创建一个图像生成器对象,该对象在图像数据集上执行随机旋转,平移,翻转,修剪和剪切。#这使我们可以使用较小的数据集,但仍然可以获得较高的结果aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,                         height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,                         horizontal_flip=True, fill_mode="nearest")

建立神经网络,进行神经网络训练

#建立modelmodel = LeNet.build(width=28, height=28, depth=3, classes=2)opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])# 训练神经网络H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),                        validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,                        epochs=EPOCHS, verbose=1)

神经网络训练完成后,对神经网络训练的结果进行保存,以便后期使用预训练模型进行图片识别

保存模型,显示训练结果

model.save("lenet.model") # 保存模型# 显示结果plt.style.use("ggplot")plt.figure()N = EPOCHSplt.plot(np.arange(0, N), H.history["loss"], label="train_loss")plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")plt.title("Training Loss and Accuracy on snowman/Notsnowman")plt.xlabel("Epoch #")plt.ylabel("Loss/Accuracy")plt.legend(loc="lower left")plt.savefig("plot1.JPG")

从训练结果可以看出,loss越来越小,精度越来越高,表明我们的神经网络模型是完全ok。

若想得到更好的训练数据,当然是使用大量的数据进行训练

以上便是我们训练的神经网络模型,下期我们使用预训练模型,对图片进行识别

plt保存图片_人工智能Keras CNN卷积神经网络的图片识别模型训练相关推荐

  1. cnn图像二分类 python_人工智能Keras图像分类器(CNN卷积神经网络的图片识别篇)...

    上期文章我们分享了人工智能Keras图像分类器(CNN卷积神经网络的图片识别的训练模型),本期我们使用预训练模型对图片进行识别:Keras CNN卷积神经网络模型训练 导入第三方库 from kera ...

  2. CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用

    CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用 目录 基于tensorflow框架采用CNN(改进 ...

  3. cnn 预测过程代码_代码实践 | CNN卷积神经网络之文本分类

    学习目录阿力阿哩哩:深度学习 | 学习目录​zhuanlan.zhihu.com 前面我们介绍了:阿力阿哩哩:一文掌握CNN卷积神经网络​zhuanlan.zhihu.com阿力阿哩哩:代码实践|全连 ...

  4. cnn神经网络可以用于数据拟合吗_使用Keras搭建卷积神经网络进行手写识别的入门(包含代码解读)...

    本文是发在Medium上的一篇博客:<Handwritten Equation Solver using Convolutional Neural Network>.本文是原文的翻译.这篇 ...

  5. tensorflow机器学习之利用CNN卷积神经网络进行面部表情识别的实例代码

    本例通过 TensorFlow 构造卷积神经网络,做表情识别的测试. 输入数据可以从http://download.csdn.net/user/shinian1987上下载FER-2013 这个数据库 ...

  6. 基于CNN卷积神经网络的人脸识别

    一.利用卷积神经网络进行人脸检测,称作CFF(卷积人脸搜索) 卷积神经网络人脸识别的大致流程: 1)对本地人脸进行特征提取 2)打开摄像头(opencv) 3)从cap获取信息 4)找人脸 5)对人脸 ...

  7. Tensorflow【实战Google深度学习框架】用卷积神经网络打造图片识别应用

    文章目录 1 Tensorflow model 2 卷积神经网络的基础单元 2.1 卷积 2.2 激活函数 2.3 池化 2.4 批归一化 2.5 Dropout 3 主流的25个深度学习模型 4 训 ...

  8. python机器学习库keras——CNN卷积神经网络识别手写体

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 keras使用CNN识别手写体 其 ...

  9. python机器学习库keras——CNN卷积神经网络人脸识别

    全栈工程师开发手册 (作者:栾鹏) python教程全解 github地址:https://github.com/626626cdllp/kears/tree/master/Face_Recognit ...

最新文章

  1. 让思维导图改变我们的工作和生活吧
  2. Global Average Pooling对全连接层的可替代性分析
  3. 使用Flex4容器若干技巧
  4. JPA中实现单向一对多的关联关系
  5. 打开计算机后 无法最小化,最小化窗口后无法在任务栏中显示的三种解决方法...
  6. 【思科】GNS3模拟静态NAT/动态NAT
  7. 一次MySQL线上慢查询分析及索引使用
  8. pdf在线翻译_如何将英文的PDF文档翻译成中文简体?
  9. Tough Days
  10. 计算机单词 硬件类、软件类、网络类、其他
  11. 通过串口波特率计算数据传输速率(每秒字节数)
  12. 怎样安装2003服务器系统安装,Windows 2003系统详细安装教程图解
  13. vue3.0项目打包后,由于vender.js 文件过大引起的首页加载时间缓慢的解决方式
  14. python怎么重复画圆_重画圆Python
  15. 不撞南墙不回头——深度优先搜索
  16. 影视网站导航PHP源码
  17. petya病毒分析_首先是WannaCry,现在是Petya –防范大规模勒索软件攻击
  18. 数值计算 - Richardson外推法求一阶导数(C++实现)
  19. ❤️UI自动化轻松解决微信手工群发消息的烦恼❤️
  20. UnityShader学习笔记:Caustic水纹焦散与鱼群制作水族馆

热门文章

  1. Adobe Premiere Pro CC 2015.0 已停止工作【解决方案】
  2. js解决iframe跨域问题
  3. 我应该在CSS中使用px或rem值单位吗?
  4. 如何在ImageView中缩放图像以保持纵横比
  5. 为什么Java的+ =,-=,* =,/ =复合赋值运算符不需要强制转换?
  6. 如何正确强制执行Git推送?
  7. Python是否具有字符串“包含”子字符串方法?
  8. “ INSERT IGNORE”与“ INSERT…ON DUPLICATE KEY UPDATE”
  9. “最少惊讶”和可变默认参数
  10. 安装node和pm2