本算法是基于tensorflow,使用python语言进行的一种图像分类算法,参考于谷歌的mnist手写识别,包括以下几个模块:图像读取,图像处理,图像增强。卷积神经网络部分包括:卷积层1,汇合层1(部分文献也有叫池化层的),卷积层2,汇合层2,全连接层1,全连接层2,共6层神经网络。损失函数采用交叉熵,优化则采用adam优化法,由于数据集大小较小,只有200张图片,故没有采用MBGD梯度下降算法,直接采用BGD梯度下降算法。

首先是图像读取模块,由于本方案是应用于识别渣土车顶棚是否遮盖好的算法,所以没有网上现成的数据库,目前只有从网上收集图片,并转换成数据。

"""
定义一个遍历文件夹下所有图片,并转化为矩阵,压缩为特定大小,并传入一个总
的矩阵中去的函数
"""
def creat_x_database(rootdir,resize_row,resize_col):#列出文件夹下所有的,目录和文件list = os.listdir(rootdir)#创建一个随机矩阵,作为多个图片转换为矩阵后传入其中database=np.arange(len(list)*resize_row*resize_col*3).reshape(len(list),resize_row,resize_col,3)for i in range(0,len(list)):path = os.path.join(rootdir,list[i])    #把目录和文件名合成一个路径if os.path.isfile(path):                ##判断路径是否为文件image_raw_data = tf.gfile.FastGFile(path,'rb').read()#读取图片with tf.Session() as sess:img_data = tf.image.decode_jpeg(image_raw_data)#图片解码#压缩图片矩阵为指定大小resized=tf.image.resize_images(img_data,[resize_row,resize_col],method=0)database[i]=resized.eval()return database                          

以上是创建数据集,还有标签集,网上部分创建标签集的方法是直接读取已经设定好标签的图片名字创建,比较适合分类数目较多的且在一个文件夹下的图片,这里笔者因为图片数量不多加之图片全是网上搜索,故采用以下方法

def creat_y_database(length,classfication_value,one_hot_value):#创建一个适当大小的矩阵来接收array=np.arange(length*classfication_value).reshape(length,classfication_value)for i in range(0,length):array[i]=one_hot_value #这里采用one hot值来区别合格与不合格return array

下面是卷积神经网络的搭建,这里参考的mnist手写输入识别的分类的神经网络

'''
初步打造一个卷积神经网,使其能够对输入图片进行二分类
'''
#计算准确率def compute_accuracy(v_xs, v_ys):global predictiony_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1})return result
#定义各参数变量并初始化
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)def conv2d(x, W):# stride [1, x_movement, y_movement, 1]# Must have strides[0] = strides[3] = 1return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool_2x2(x):# stride [1, x_movement, y_movement, 1]return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')#创建训练集
fail_x_data=creat_x_database('E:/data/muck truck pic/failed',128,128)
true_x_data=creat_x_database('E:/data/muck truck pic/qualified',128,128)
x_data=np.vstack((fail_x_data,true_x_data))  #两个矩阵在列上进行合并
#创建标签集
fail_y_data=creat_y_database(fail_x_data.shape[0],2,[0,1])
true_y_data=creat_y_database(true_x_data.shape[0],2,[1,0])
y_data=np.vstack((fail_y_data,true_y_data))
#划分训练集和测试集
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data,test_size=0.1,random_state=0)
#train_test_split函数用于将矩阵随机划分为训练子集和测试子集,并返回划分好的训练集测试集样本和训练集测试集标签。xs = tf.placeholder(tf.float32, [None, 128,128,3])/255 #归一化
ys = tf.placeholder(tf.float32, [None, 2])
keep_prob = tf.placeholder(tf.float32)
#x_image = tf.reshape(xs, [-1, 50, 50, 3])W_conv1 = weight_variable([5,5, 3,32]) # patch 5x5, in size 3, out size 32
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(xs, W_conv1) + b_conv1) # output size 128x128x32
h_pool1 = max_pool_2x2(h_conv1)                          # output size 64x64x32W_conv2 = weight_variable([5,5, 32, 64]) # patch 5x5, in size 32, out size 64
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) # output size 64x64x64
h_pool2 = max_pool_2x2(h_conv2) #32x32x64W_fc1 = weight_variable([32*32*64, 1024])
b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 32*32*64])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)W_fc2 = weight_variable([1024, 2])
b_fc2 = bias_variable([2])
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction+ 1e-10), reduction_indices=[1]))
#由于 prediction 可能为 0, 导致 log 出错,最后结果会出现 NA
#train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
sess=tf.Session()
init=tf.global_variables_initializer()
sess.run(init)for i in range(1000):
#    batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={xs: x_train, ys: y_train, keep_prob: 0.5})if i % 50 == 0:print(compute_accuracy(x_test,y_test))print(sess.run(cross_entropy, feed_dict={xs: x_train, ys: y_train, keep_prob: 1}))

这里有一个坑点要说明一下,在修改神经网络的过程中,发现精确度一直不变,这说明模型不收敛,一层一层的查找问题才发现是prediction的值可能为0,导致交叉熵出现-Naf这种情况。解决办法如上所示,在prediction的后面加一个极小数,防止log出现负无穷的情况。同时参考其他文章,在倒数第二层的激活函数上由ReLU改为sigmoid,原因是ReLU输出可能相差很大(比如0和几十),这时再经过softmax就会出现一个节点为1其它全0的情况。softmax的cost function里包含一项log(y),如果y正好是0就没法算了。

最终训练结果如图所示:

训练结果并不好,最高有75%,可能原因有超参数的调整没到位,图像样本太少,训练层数太少,训练次数太少等。后期需要改进这个模型。

参考文献:【1】https://blog.csdn.net/huangbo10/article/details/24941079

【2】https://github.com/MorvanZhou/tutorials/blob/master/tensorflowTUT/tf18_CNN3/full_code.py

【3】http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html

基于cnn的图像二分类算法(一)相关推荐

  1. 基于SVM的图像二分类算法

    本实验是用的python代码实现图像的二分类问题,我是在eclipse中搭建python环境. 一.数据集处理 我采用的是甜椒叶数据集(我忘了下载地址在哪里,我只用了一部分数据集),其中健康叶片有37 ...

  2. python中文文本分析_基于CNN的中文文本分类算法(可应用于垃圾邮件过滤、情感分析等场景)...

    基于cnn的中文文本分类算法 简介 参考IMPLEMENTING A CNN FOR TEXT CLASSIFICATION IN TENSORFLOW实现的一个简单的卷积神经网络,用于中文文本分类任 ...

  3. 基于CNN的图像缺陷分类

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自|新机器视觉 来源:博客园  原文地址:https://w ...

  4. 基于CNN的中文文本分类算法(可应用于垃圾文本过滤、情感分析等场景)

    向AI转型的程序员都关注了这个号

  5. 基于CNN的海面舰船图像二分类

    基于CNN的海面舰船图像二分类 1. 模型依赖的环境和硬件配置 Python3.8 Pillow==8.2.0 torch-1.5.1(cuda9.2) torchfile==0.1.0 torchv ...

  6. 基于deap脑电数据集的脑电情绪识别二分类算法(附代码)

    想尝试一下脑电情绪识别的各个二分类算法. 代码主要分为三部分:快速傅里叶变换处理(fft).数据预处理.以及各个模型处理. 采用的模型包括:决策树.SVM.KNN三个模型(模型采用的比较简单,可以直接 ...

  7. cnn生成图像显著图_基于CNN与图像前背景分离的显著目标检测

    基于 CNN 与图像前背景分离的显著目标检测 东野长磊 ; 万文鑫 [期刊名称] <软件导刊> [年 ( 卷 ), 期] 2020(019)001 [ 摘 要 ] 为 了 解 决 计 算 ...

  8. matlab人工选择阈值进行分割,基于MATLAB的图像阈值分割算法的研究

    [摘要]:图像分割是一种重要的数字图像处理技术.本文首先介绍了图像分割技术,其次总结了目前图像分割技术中所用到的阈值.边缘检测.区域提取等方法以及分水岭算法.针对各种阈值分割算法,本文在最后做了详细的 ...

  9. 支持向量机的近邻理解:图像二分类为例(1)

    前言: 机器学习在是否保留原始样本的层面划分为两类:参数学习和非参数学习.参数学习使用相对固定框架,把样本分布通过训练的方式回归到一个使用参数描述的数学模型里面,最终使用的是归纳方法:非参数模型保留了 ...

最新文章

  1. oracle获取 表名,Oracle获取当前数据库的所有表名字段名和注释
  2. (二)git常用基本概念
  3. string contains不区分大小写_String基础复习
  4. 如何在数字化转型战略中真正获得价值?浅谈数字化转型的四个层级
  5. 微信小程序 环形进度条_微信小程序:实时圆形进度条实现
  6. 全网首发:把一个bit数组矩阵旋转90度
  7. 多个微服务的接口依赖如何测试_一文看懂微服务
  8. 物理增强的深度学习模型改善卫星图像对热带气旋强度和大小估计(翻译)
  9. 微软ime日文输入法在假名输入模式下怎么快速输入英文
  10. Linux 删除多余内核
  11. Java 统计连续签到天数
  12. 江苏电信用户将体验iPhone6s的极速4G+网络
  13. 笔记本电脑外接显示器无信号 其实是主板静电积压 完全可以不拆机放电
  14. 计算机开关电源基本原理,开关电源基本原理与设计介绍——第一讲
  15. 互联网因特网计算机网络的区别,因特网和互联网的区别?
  16. VS C++ error LNK2005 1169报错
  17. Failed to start [powershell.exe](idea 启动前端页面,open in terminal时报错)
  18. 【转】:金龙鱼等品牌花生油全线涨价每瓶最高涨12.8元
  19. (三)如何实现多节目轮播。——安卓智能广告机
  20. Golang 023. 《孙子算经》之鸡兔同笼

热门文章

  1. python3中正确代码报红显示Indent expected
  2. 微信小程序如何设计实现
  3. 入门C语言第二话:函数(上)之锻体篇,带你玩转函数(内有汉诺塔,青蛙跳台阶等经典问题,建议收藏和分享)
  4. 线上展厅打造视觉亮点
  5. python怎样使用各个日期赤纬_天文数据处理笔记之python(3)
  6. maven oracle 10.2.0.4.0,马文介绍说ojdbc:ojdbc14-10.2.0.4.0.jar,Maven,引入,ojdbcojdbc14102040jar...
  7. SAP世界生存指南(2017版)
  8. SQL语句在dos操作MySQL数据库
  9. matlab中(),[],与{}的使用区别
  10. 记 计算机 科学学院 教师,永做学生的操作系统——记计算机科学技术学院、软件学院教师金虎...