Github

源代码与数据文件均在github上,对识别感兴趣的小伙伴点个star啦,共同学习共同进步,谢谢!!!
只需要把文件完整下载,改变文件的目录,然后把数据改为你想要识别的物品,修改全连接层(最后一层)的输出即可完成识别。
Github地址

数据

数据是由老师手底下几个学生帮忙拍摄获取,大约两千张数据,数据参差不齐,勉强用来训练耍耍,毕竟大数据咱这破电脑也带不动的哇!
数据分类放在drink_data下的7个文件夹中,文件夹为其分类。drink_data下有一些自己手工拍摄用于验证的照片(事实证明过拟合严重。So 需要大量的数据去解决这个问题。这里挖个坑,先通过改变一些RGB的数字,然后加上裁剪+反转/镜像之类的方法加大数据集,最后了解生成对抗网络去(GANs加大数据集!!

训练

ml的日常包与一些基本参数。

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time
import datetime#数据集地址
path='C:/Users/Administrator/Desktop/ML/bottle/drink_data/'
#图片集合保存地址
#模型保存地址
#C:\Users\Administrator\Desktop\ML\bottle\model
#model_path='C:/Users/Administrator/Desktop/ML/bottle/model/model.ckpt'#将所有的图片resize成100*100
w=100
h=100
c=3

读取图片,并把图片尺寸修改然后乱序分割数据集

#读取图片
def read_img(path):cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]imgs=[]labels=[]for idx,folder in enumerate(cate):for im in glob.glob(folder+'/*.jpg'):#输出读取了的模型的图片#print('reading the images:%s'%(im))img=io.imread(im)img=transform.resize(img,(w,h))imgs.append(img)labels.append(idx)return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

图片初始设置为128*128*3
搭建网络:这里卷积与池化的’SAME’/’ VALID’,一个是当卷积核无法找到足够大小的尺寸进行卷积时补0,一个是尺寸不够的时候直接将剩余多余的部分扔掉。
第一次卷积厚度从3->32,W与H不变。
第二次卷积厚度从32->64,W与H不变。
第三次卷积厚度从64->128,W与H不变。
第四次卷积厚度从128->128,W与H不变。
其中每次卷积后面跟一次池化,池化大小为1*2*2*1。
那么图片在最后得到的结果就是8*8*3
然后其中有3次全连接层降参数,从8*8*3->1024->512->7(这里我识别的种数为7

#-----------------构建网络----------------------
#占位符,因为这里喂的数据组数不确定,
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')def inference(input_tensor, train, regularizer):with tf.variable_scope('layer1-conv1'):conv1_weights = tf.get_variable("weight",[5,5,3,32],initializer=tf.truncated_normal_initializer(stddev=0.1))conv1_biases = tf.get_variable("bias", [32], initializer=tf.constant_initializer(0.0))conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))with tf.name_scope("layer2-pool1"): #池化到一半pool1 = tf.nn.max_pool(relu1, ksize = [1,2,2,1],strides=[1,2,2,1],padding="VALID")with tf.variable_scope("layer3-conv2"):conv2_weights = tf.get_variable("weight",[5,5,32,64],initializer=tf.truncated_normal_initializer(stddev=0.1))conv2_biases = tf.get_variable("bias", [64], initializer=tf.constant_initializer(0.0))conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))with tf.name_scope("layer4-pool2"):#多出来的直接不pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')with tf.variable_scope("layer5-conv3"):conv3_weights = tf.get_variable("weight",[3,3,64,128],initializer=tf.truncated_normal_initializer(stddev=0.1))conv3_biases = tf.get_variable("bias", [128], initializer=tf.constant_initializer(0.0))conv3 = tf.nn.conv2d(pool2, conv3_weights, strides=[1, 1, 1, 1], padding='SAME')relu3 = tf.nn.relu(tf.nn.bias_add(conv3, conv3_biases))with tf.name_scope("layer6-pool3"):pool3 = tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')with tf.variable_scope("layer7-conv4"):conv4_weights = tf.get_variable("weight",[3,3,128,128],initializer=tf.truncated_normal_initializer(stddev=0.1))conv4_biases = tf.get_variable("bias", [128], initializer=tf.constant_initializer(0.0))conv4 = tf.nn.conv2d(pool3, conv4_weights, strides=[1, 1, 1, 1], padding='SAME')relu4 = tf.nn.relu(tf.nn.bias_add(conv4, conv4_biases))with tf.name_scope("layer8-pool4"):pool4 = tf.nn.max_pool(relu4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')nodes = 6*6*128reshaped = tf.reshape(pool4,[-1,nodes])with tf.variable_scope('layer9-fc1'):fc1_weights = tf.get_variable("weight", [nodes, 1024],initializer=tf.truncated_normal_initializer(stddev=0.1))if regularizer != None: tf.add_to_collection('losses', regularizer(fc1_weights))fc1_biases = tf.get_variable("bias", [1024], initializer=tf.constant_initializer(0.1))fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_weights) + fc1_biases)if train: fc1 = tf.nn.dropout(fc1, 0.5)with tf.variable_scope('layer10-fc2'):fc2_weights = tf.get_variable("weight", [1024, 512],initializer=tf.truncated_normal_initializer(stddev=0.1))if regularizer != None: tf.add_to_collection('losses', regularizer(fc2_weights))fc2_biases = tf.get_variable("bias", [512], initializer=tf.constant_initializer(0.1))fc2 = tf.nn.relu(tf.matmul(fc1, fc2_weights) + fc2_biases)if train: fc2 = tf.nn.dropout(fc2, 0.5)with tf.variable_scope('layer11-fc3'):fc3_weights = tf.get_variable("weight", [512, 7],initializer=tf.truncated_normal_initializer(stddev=0.1))if regularizer != None: tf.add_to_collection('losses', regularizer(fc3_weights))fc3_biases = tf.get_variable("bias", [7], initializer=tf.constant_initializer(0.1))logit = tf.matmul(fc2, fc3_weights) + fc3_biasesreturn logit
#激活函数relu改成tanh,可能对你的网络是个优化方法
#---------------------------网络结束---------------------------

损失函数,验证函数,训练等函数的源码:

regularizer = tf.contrib.layers.l2_regularizer(0.001)
logits = inference(x,False,regularizer)#(小处理)将logits乘以1赋值给logits_eval,定义name,方便在后续调用模型时通过tensor名字调用输出tensor
b = tf.constant(value=1,dtype=tf.float32)
logits_eval = tf.multiply(logits,b,name='logits_eval')loss=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_)
#定义损失函数train_op=tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss) #这里调整每次的步长,步长太小,验证集达不到效果,很快就过拟合
#每次步长1e-4
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
#进行验证
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):assert len(inputs) == len(targets)if shuffle:indices = np.arange(len(inputs))np.random.shuffle(indices)for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):if shuffle:excerpt = indices[start_idx:start_idx + batch_size]else:excerpt = slice(start_idx, start_idx + batch_size)yield inputs[excerpt], targets[excerpt]

训练网络,把for下面的代码注释掉,为使用之前保存过的模型再次训练。

#训练和测试数据,可将Times设置更大一些Times =20
batch_size=64
saver=tf.train.Saver()
sess=tf.Session()
sess.run(tf.global_variables_initializer())#下面两句话为调用已有的模型,去再次训练
model_path = "model/model.ckpt"
#saver.restore(sess, model_path)
start_time = datetime.datetime.now()
print('开始训练时间为:  ' ,start_time.strftime('%Y-%m-%d %H:%M:%S '))for epoch in range(Times):#加下面2行是调用之前的模型#model_path = "model/model.ckpt"#saver.restore(sess, model_path)#trainingtrain_loss, train_acc, n_batch = 0, 0, 0#这里是把所有的数据都训练一遍,然后输出一次结果for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):_,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})train_loss += err; train_acc += ac; n_batch += 1current_time = datetime.datetime.now()print('现在时间是:',current_time.strftime('%Y-%m-%d %H:%M:%S '),'第',epoch,'次的结果为 : '  )print("   train loss: %f" % (np.sum(train_loss)/ n_batch))print("   train acc: %f" % (np.sum(train_acc)/ n_batch))#saver = tf.train.Saver()#validationval_loss, val_acc, n_batch = 0, 0, 0for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})val_loss += err; val_acc += ac; n_batch += 1print("   validation loss: %f" % (np.sum(val_loss)/ n_batch))print("   validation acc: %f" % (np.sum(val_acc)/ n_batch))saver.save(sess, model_path)print('保存模型')
sess.close()

优化方向

场景:基于无人售货机的一个物品识别基础版本

基于无人售货机的一个对货品的图像识别 用了11层网络
源码包括了数据集,数据处理,训练过程,验证过程,并有部分手工照片来验证结果(并且可以通过结果看到神经网络过拟合的严重程度
感兴趣的小伙伴可以直接git clone 下载下来直接跑源码即可正常出结果(需要改变部分源码中的目录
源码train中的session中有两句话是可以通过删掉注释,让代码跑上一次存下的模型,而不需要每次都从头开始跑
需要应用到自己的东西上,仅仅需要改变目录与自己图片的数据,然后改变全连接层的分类结果即可。

优化方向:
1、过拟合考虑加改变层数与厚度,加大数据集
2、激活函数可以relu改成tanh试试
3、步伐的改变
4、单通道图片可以把c=3改为c=1
5、迁移学习获取模型再用自己的数据进行训练
6、考虑一些新的网络,google最近出了不少好东西哇

如果报错:
1、loss的结果为nan(说明你最后一个全连接层的分类结果没设置正确
2、验证结果出现错误,找不到正确的文件(把我源码中的model/model.ckpt.meta修改为model/model.ckpt即可

一个11层的CNN(基于无人售货机的货物识别相关推荐

  1. 基于JAVA无人售货机管理系统计算机毕业设计源码+数据库+lw文档+系统+部署

    基于JAVA无人售货机管理系统计算机毕业设计源码+数据库+lw文档+系统+部署 基于JAVA无人售货机管理系统计算机毕业设计源码+数据库+lw文档+系统+部署 本源码技术栈: 项目架构:B/S架构 开 ...

  2. 几乎等于一个小超市的新型无人售货机

    随着科技的发展,自助购物,移动支付,都是大家现在非常喜欢做的事情-无人零售时代已经来临,它是人们展望未来社会所衍生的产物,自发明至今已经得到全面发展.而随着科技的进步,无人售货机的售卖也更加丰富多彩. ...

  3. 基于SSM的智能无人售货机系统的设计与实现

    开发工具(eclipse/idea/vscode等): 数据库(sqlite/mysql/sqlserver等): 功能模块(请用文字描述,至少200字): 基于SSM的智能无人售货机系统的设计与实现 ...

  4. java毕业设计无人售货机管理系统源码+lw文档+mybatis+系统+mysql数据库+调试

    java毕业设计无人售货机管理系统源码+lw文档+mybatis+系统+mysql数据库+调试 java毕业设计无人售货机管理系统源码+lw文档+mybatis+系统+mysql数据库+调试 本源码技 ...

  5. 计算机毕业设计Java无人售货机管理系统(源码+系统+mysql数据库+Lw文档)

    计算机毕业设计Java无人售货机管理系统(源码+系统+mysql数据库+Lw文档) 计算机毕业设计Java无人售货机管理系统(源码+系统+mysql数据库+Lw文档) 本源码技术栈: 项目架构:B/S ...

  6. java计算机毕业设计无人售货机管理系统源码+lw文档+系统+数据库

    java计算机毕业设计无人售货机管理系统源码+lw文档+系统+数据库 java计算机毕业设计无人售货机管理系统源码+lw文档+系统+数据库 本源码技术栈: 项目架构:B/S架构 开发语言:Java语言 ...

  7. java/php/net/python无人售货机管理系统设计

    本系统带文档lw万字以上+答辩PPT+查重 如果这个题目不合适,可以去我上传的资源里面找题目,找不到的话,评论留下题目,或者站内私信我, 有时间看到机会给您发 1.关于无人售货机管理系统的基本要求 ( ...

  8. 无人便利店和无人售货机的区别在哪

    在很多业界观点看来,无人便利店和无人售货机的主要特质就在于"无人",也就是通过无人零售的方式从而降低人力成本,同时仍能够实现24小时开店的需求.下面37号仓小编给大家讲解下无人便利 ...

  9. 37号仓:Cbox单门无人售货机

    Cbox单门无人售货机是37号仓设计出了新一代更智能的基于视觉识别的无人售货设备,用户通过扫码开门-随意选取货品-关门实时结算,实现了高效.便捷的购物体验. 每一台Cbox都搭载了一台高效率的识别大脑 ...

最新文章

  1. Microsoft Visual C++ 14.0 is required
  2. 超有用的,从此vi变得友好了
  3. PAT1010 一元多项式求导 (25 分)
  4. TCP/IP协议各层首部汇总
  5. StringEscapeUtils类的使用
  6. 你是菜鸡是有原因的 谈谈提问的艺术
  7. 前端系统化学习【JS篇】:(二)Javascript、变量和值的简述
  8. Repeater控件里面取不到CheckBox的值
  9. 远程服务器时Ubuntu报错:qt.qpa.xcb: could not connect to display
  10. --legacy-peer-deps 作用
  11. KeilC51使用详解 (三)
  12. 如何查看一个期刊是sci几区以及影响因子 入藏号 ISSN等信息
  13. Mac录屏方法:无需下载软件
  14. 【渝粤题库】陕西师范大学292251 公司金融学Ⅰ 作业(高起专)
  15. QQ代替;teamviewer检测为商业用途 5分钟后关闭解决方法
  16. Android 仿微信通讯录 导航分组列表-上】使用ItemDecoration为RecyclerView打造带悬停头部的分组列表
  17. 淘宝新店一个流量没有如何是好
  18. 安科瑞变电所运维云平台AcrelCloud-1000实时监测
  19. prefetch()
  20. 设计模式--静态工厂、简单工厂方法案例分析

热门文章

  1. 计算机未连接到网络但是可以上网,win10系统能上网但图标显示未连接的解决办法...
  2. GPU计算能力和性能指标
  3. csharp:百度翻译
  4. arduino tft 方向_Arduino库教程-TFT Library
  5. python语言能够跨平台使用吗_中国大学MOOC: Python语言能够跨平台使用。
  6. 保存登录信息的Cookie加密技术
  7. toad for oracle 11 手册,toad for oracle 11
  8. 一周心结总,拨开云雾见青天
  9. 海南省大数据管理局项目建设处刘雄:区块链技术在海南政务服务领域的典型应用
  10. [AFCTF2018]一道有趣的题目