去年年底学习了深度学习的相关知识,但是寒假回来之后忘得也差不多了。。。为了巩固下所学知识,近期利用卷积神经网络做了一个小实例。卷积神经网络是一种多层神经网络,擅长处理图像特别是大图像的相关机器学习问题。卷积网络通过一系列方法,成功将数据量庞大的图像识别问题不断降维,最终使其能够被训练。为了测试卷积神经网络的性能,特地选择了猴子和狒狒这两种长得差不多的动物图片进行训练。

【step1:数据准备】

首先为了方便之后制作TFRecord数据集和训练,先在一个文件下建立两个文件夹"train" 和 "test":

然后分别进入这两个文件夹中,创建存放两种动物图片的文件夹"monkey"和"baboon":

现在文件夹中还是空的,没有现成的数据集,于是我通过百度爬虫获取了大约2000张猴子和狒狒的图片,爬虫脚本如下:

# -*- coding: utf-8 -*-
"""根据搜索词下载百度图片"""
import re
import sys
import urllib
import requestskeyWord = '狒狒'
savePath = 'E:/python/2019_02_24/train/baboon/'def get_onepage_urls(onepageurl):"""获取单个翻页的所有图片的urls+当前翻页的下一翻页的url"""if not onepageurl:print('已到最后一页, 结束')return [], ''try:html = requests.get(onepageurl)html.encoding = 'utf-8'html = html.textexcept Exception as e:print(e)pic_urls = []fanye_url = ''return pic_urls, fanye_urlpic_urls = re.findall('"objURL":"(.*?)",', html, re.S)fanye_urls = re.findall(re.compile(r'<a href="(.*)" class="n">下一页</a>'), html, flags=0)fanye_url = 'http://image.baidu.com' + fanye_urls[0] if fanye_urls else ''return pic_urls, fanye_urldef down_pic(pic_urls):"""给出图片链接列表, 下载所有图片"""for i, pic_url in enumerate(pic_urls):try:pic = requests.get(pic_url, timeout=15)string = savePath + str(i + 1) + '.jpg'with open(string, 'wb') as f:f.write(pic.content)print('成功下载第%s张图片: %s' % (str(i + 1), str(pic_url)))except Exception as e:print('下载第%s张图片时失败: %s' % (str(i + 1), str(pic_url)))print(e)continueif __name__ == '__main__':keyword = keyWord  # 关键词, 改为你想输入的词即可, 相当于在百度图片里搜索一样url_init_first = 'http://image.baidu.com/search/flip?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1497491098685_R&pv=&ic=0&nc=1&z=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&ie=utf-8&ctd=1497491098685%5E00_1519X735&word='url_init = url_init_first + urllib.parse.quote(keyword, safe='/')all_pic_urls = []onepage_urls, fanye_url = get_onepage_urls(url_init)all_pic_urls.extend(onepage_urls)fanye_count = 0  # 累计翻页数while True:onepage_urls, fanye_url = get_onepage_urls(fanye_url)fanye_count += 1# print('第页' % str(fanye_count))if fanye_url == '' and onepage_urls == []:breakall_pic_urls.extend(onepage_urls)down_pic(list(set(all_pic_urls)))

其中keyWord就是你想要搜索的关键词,savePath就是存放路径,运行脚本后,程序会自动从百度上自动下载该类图片到你设置的路径下。下载完毕后,将某些无法显示,或者无关的图片删除。最后为了方便,直接剪切小部分图片放入对应类别的test文件夹中。

【step2:制作TFRecord数据】

制作TFRecord数据的脚本网上有很多,大致都是差不多的,我用的脚本如下:

import os
import tensorflow as tf
from PIL import Image  #注意Image,后面会用到
import matplotlib.pyplot as plt
import numpy as npcwd='E:/python/2019_02_24/train/'  #图片存放路径
classes={'monkey','baboon'} #人为 设定 2 类
writer= tf.python_io.TFRecordWriter("20190224_train.tfrecords") #要生成的文件路径for index, name in enumerate(classes):class_path = cwd + name +'/'for img_name in os.listdir(class_path): img_path = class_path + img_name #每一个图片的地址img = Image.open(img_path)img = img.convert("RGB")  # 将图片转成3通道的RGB图片img = img.resize((24, 24))img_raw = img.tobytes() #将图片转化为二进制格式example = tf.train.Example(features=tf.train.Features(feature={"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))})) #example对象对label和image数据进行封装writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

其中classes中的类别名称一定要和文件夹的名称一样,我这里把每张图片统一缩小为24*24*3的大小。运行之后就可以得到训练集和测试集的TFRecord数据:

【step3 训练】

我的网络结构为:输入-->卷积核为1,步长为1的卷积层-->窗口大小为2,步长为2的最大池化层-->卷积核为1,步长为1的卷积层-->窗口大小为2,步长为2的最大池化层-->卷积核为1,步长为1的卷积层-->窗口大小为6,步长为6的均值池化层-->全连接层-->输出

为了提高模型精度,我特意加入了批量归一化和学习率退化处理。程序如下:

import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm''' 1.数据集准备'''
# 取出数据集
filename_queue1 = tf.train.string_input_producer(["20190224_train.tfrecords"]) #读入流中
reader1 = tf.TFRecordReader()
_, serialized_example1 = reader1.read(filename_queue1)   #返回文件名和文件
features1 = tf.parse_single_example(serialized_example1,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string),})  #取出包含image和label的feature对象
image1 = tf.decode_raw(features1['img_raw'], tf.uint8)
image1 = tf.reshape(image1, [24, 24, 3])
label1 = tf.cast(features1['label'], tf.int32)# 取出训练集 一定要使用shuffle_batch打乱顺序,否则训练过程中会出现精度0,1之间交替的情况
image_batch, label_batch = tf.train.shuffle_batch([image1, label1],batch_size = 128,capacity=2000,min_after_dequeue=1000)filename_queue2 = tf.train.string_input_producer(["20190224_test.tfrecords"]) #读入流中
reader2 = tf.TFRecordReader()
_, serialized_example2 = reader2.read(filename_queue2)   #返回文件名和文件
features2 = tf.parse_single_example(serialized_example2,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string),})  #取出包含image和label的feature对象
image2 = tf.decode_raw(features2['img_raw'], tf.uint8)
image2 = tf.reshape(image2, [24, 24, 3])
label2 = tf.cast(features2['label'], tf.int32)# 取出测试集
images_test, labels_test = tf.train.shuffle_batch([image2, label2],batch_size = 512,capacity=2000,min_after_dequeue=1000)''' 2.网络搭建 '''
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')  def avg_pool_6x6(x):return tf.nn.avg_pool(x, ksize=[1, 6, 6, 1],strides=[1, 6, 6, 1], padding='SAME')def batch_norm_layer(value,train = None, name = 'batch_norm'): if train is not None:       return batch_norm(value, decay = 0.9,updates_collections=None, is_training = True)else:return batch_norm(value, decay = 0.9,updates_collections=None, is_training = False)# 定义占位符
x = tf.placeholder(tf.float32, [None, 24, 24, 3]) # 输入为128*128*3
y = tf.placeholder(tf.float32, [None, 2]) # 2类
train = tf.placeholder(tf.float32)# 定义网络结构
W_conv1 = weight_variable([5, 5, 3, 32])
b_conv1 = bias_variable([32])x_image = tf.reshape(x, [-1,24,24,3])h_conv1 = tf.nn.relu(batch_norm_layer((conv2d(x_image, W_conv1) + b_conv1),train))
h_pool1 = max_pool_2x2(h_conv1)W_conv2 = weight_variable([5, 5, 32, 32])
b_conv2 = bias_variable([32])h_conv2 = tf.nn.relu(batch_norm_layer((conv2d(h_pool1, W_conv2) + b_conv2),train))
h_pool2 = max_pool_2x2(h_conv2)W_conv3 = weight_variable([5, 5, 32, 2])
b_conv3 = bias_variable([2])
h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)nt_hpool3=avg_pool_6x6(h_conv3)#2
nt_hpool3_flat = tf.reshape(nt_hpool3, [-1, 2])y_conv = tf.contrib.layers.fully_connected(nt_hpool3_flat,2,activation_fn=tf.nn.softmax)# 定义交叉熵
cross_entropy = -tf.reduce_sum(y * tf.log(y_conv))
#cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_conv))#加入学习率退化
global_step = tf.Variable(0, trainable=False)
decaylearning_rate = tf.train.exponential_decay(0.04, global_step,1000, 0.9)#定义优化器
train_step = tf.train.AdamOptimizer(decaylearning_rate).minimize(cross_entropy,global_step=global_step)
#train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))''' 3.开始训练'''
sess = tf.Session()
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess)
for i in range(15000):image_bth, label_b = sess.run([image_batch, label_batch])label_bth = np.eye(2, dtype=float)[label_b] #one hot#print(label_bth)train_step.run(feed_dict={x:image_bth, y: label_bth, train:1}, session=sess)if i % 200 == 0:train_accuracy = accuracy.eval(feed_dict={x:image_bth, y: label_bth}, session=sess)print( "step %d, training accuracy %g"%(i, train_accuracy))
image_bth, label_b = sess.run([images_test, labels_test])
label_bth = np.eye(2,dtype=float)[label_b]
print ("finished! test accuracy %g"%accuracy.eval(feed_dict={x:image_bth, y: label_bth},session=sess))

这里需要注意的是,在取出一批数据集的时候,最好要用tf.train.shuffle_batch函数进行打乱顺序,如果使用按顺序取出批次的方法,在训练过程中,你的精度会一直显示为0或者1。另外,测试集必须要打乱顺序,否则在最后进行测试的时候你的测试精度不管迭代多少次都会是1。理想情况下我们想得到的模型精度确实是1,但现实情况下,具有很好的泛化能力的模型是不可能达到1的,一开始我以为是过拟合的原因,加入正则化项后,发现仍然没有改善,直到改变了迭代次数观察到测试精度始终为1时才发现了这个问题,这里要mark一下。

最后我训练得到的模型精度为:

精度为0.5449,确实有点低,因为二选一的概率都是50%,说明模型训练得极其不理想。于是我查看了一下爬虫所得到的图片,不得不说猴子和狒狒真的长得太像了。。。有些图片我自己还区分不了狒狒还是猴子,狒狒小时候和猴子好像是一毛一样的。。。另外图片里还要很多是卡通漫画,这应该也是影响训练精度的原因。

当然还可以略微提高下精度的:

1.可以加入多通道卷积技术,设置多个不同大小的卷积核进行卷积,这样提取的特征也会多一些,对精度提高有一定帮助。

2.输出的feature map个数我设置的是32,为了提高精度可以设置成64甚至更大,当然这也会加大训练的时间。我源程序中迭代了15000次,输入仅为24*24*3的图片,训练时间用了797s。

3.在制作数据集TFRecord数据集时可以把图片大小放大一些,保留更多的细节特征,对精度提升也有帮助。

4.加大数据集的样本数量

第二天又仔细查看了下结果,发现输出的训练精度有些奇怪,波动幅度有些大,不知道是不是有什么问题,欢迎大佬批评指正!

猴子?狒狒?傻傻分不清楚——制作tfrecord数据集并利用卷积神经网络训练实例相关推荐

  1. Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试

    使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...

  2. PaddleDetection——使用(jpg + xml)制作VOC数据集并建立PD包

    因为模型需要VOC训练集,而数据集只有图片和已制作好的xml文件,那么只能自己进行VOC数据集的再加工,好,开工! 文章目录 voc数据集格式 生成VOC数据集的txt文件 生成Main文件下的txt ...

  3. 计算机考试excel统计图怎么做,excel表格取数据做统计图-Excel如何制作统计数据...

    如何用一个excel表格上的数据做曲线图. 方法一个连续区域,通过"框"输入地选中单元格,如选中A列的A3:A8区域,在名称框中输入"A3:A8"后回车,即可选 ...

  4. 数字签名?电子签名?傻傻分不清楚

    数字签名?电子签名?傻傻分不清楚! 数字签名与电子签名是紧密地联系在一起的,2000年美国的<全球和国家商业电子签名(ESIGN)法案>.2005年我国的<中华人民共和国电子签名法& ...

  5. UX、UI、 IA和IxD傻傻分不清

    UX.UI. IA和IxD傻傻分不清 UX.UI. IA和IxD傻傻分不清 在以前,一般所说的设计多半是指平面设计.随着数字时代的快速发展,涌现了一批新的职位,因此一些外行人士或者刚入行的人对设计相关 ...

  6. UE、UI、 IA和IxD傻傻分不清

    一直对这些分类搞不清,整理下,来源:https://www.zhihu.com/question/19742332/answer/132833171 &lt;img src="htt ...

  7. 傻傻分不清楚的研究设计类型

    傻傻分不清楚的研究设计类型 1 故事概要 最近,经常被咨询研究设计类型,而画风往往是像脱轨的火车,和过去的学习有一定出入,时常让我陷入深深的自我怀疑.所以,特地整理了基本的研究设计类型和常见的问题. ...

  8. JS魔法堂:属性、特性,傻傻分不清楚

    一.前言 或许你和我一样都曾经被下面的代码所困扰 var el = document.getElementById('dummy'); el.hello = "test"; con ...

  9. ASP.NET MVC涉及到的5个同步与异步,你是否傻傻分不清楚?[下篇]

    关于ASP.NET MVC对请求的处理方式(同步或者异步)涉及到的五个组件,在<上篇>中我们谈了三个(MvcHandler.Controller和ActionInvoker),现在我们来谈 ...

最新文章

  1. 在DataGrid中显示图片
  2. C++ map的使用
  3. 使用OpenVAS 9进行漏洞扫描
  4. stl文件 python_STL文件,一种前处理网格划分技术??
  5. Nginx 0.8.5版本access.log日志分析shell命令
  6. CF1251F Red-White Fence(多项式/背包问题/组合数学)
  7. 面试题 17.16. 按摩师
  8. 程序员如何在 HTTPS 中高效配置通配符证书?| 技术头条
  9. OPPO以技术推动产品 获专利数首次挺近前十
  10. 什么是MySQL视图
  11. Shawn,别让我们失望
  12. php怎么把中文转,php如何把汉字转换成拼音
  13. 【BZOJ3572】【Hnoi2014】世界树 虚树
  14. 计算机基础知识教学反思,计算机基础课教学反思.doc
  15. anchor base和anchor free, 小物体检测, YOLO V1-3 9000 V4 V5 的区别,yolov5-8, yolox创新点
  16. java后台发送post请求 MultipartFile、json
  17. Kruskal(克鲁斯卡尔)算法(图+代码+例题)
  18. 【Echarts图例点击事件】自定义Echarts图例legend点击事件(已解决)
  19. “只有偏执狂才能生存”在中国是怎样变成病毒的
  20. MySQL登录时出现 Access denied for user 'root'@'xxx.xxx.xxx.xxx' (using password: YES) 的原因及解决办法

热门文章

  1. BeanDefinition用法
  2. html5svg地图api,提取 ECharts 中的svg地图信息
  3. 分布式切换历史库总结
  4. 冯小刚说公众人物只能骂不还口打不还手,这是当明星要承担的代价
  5. Leetcode学习笔记(974. 和可被 K 整除的子数组)
  6. 时间序列数据分析与预测之Python工具汇总
  7. Visual Studio SVN创建分支 合并分支 切换分支 vs 插件 visualsvn
  8. 千兆交换机网线制作方法
  9. linux centos7 系统内核参数调优
  10. TIOBE2017年5月编程语言排名