昨天实现了一个简单的CNN网络。用了MNIST数据集,虽然看来对这个数据集用的很多,但是真正这个数据集是怎么在训练的时候被调用的,以及怎么把它换成自己的数据集都是一脸懵。

直接附上链接:MNIST数据集解析

作者给的代码是python2.x版本的,我用的python3.5,改了一些错误。

import numpy as np
import struct
import os
import matplotlib.pyplot as pltimport pickle
import gzip_tag = '>' #使用大端读取
_twoBytes = 'II' #读取数据格式是两个整数
_fourBytes =  'IIII' #读取的数据格式是四个整数
_pictureBytes =  '784B' #读取的图片的数据格式是784个字节,28*28
_lableByte = '1B' #标签是1个字节
_msb_twoBytes = _tag + _twoBytes
_msb_fourBytes = _tag + _fourBytes
_msb_pictureBytes = _tag + _pictureBytes
_msb_lableByte = _tag + _lableBytedef getImage(filename = None):binfile = open(filename, 'rb') #以二进制读取的方式打开文件buf = binfile.read() #获取文件内容缓存区binfile.close()index = 0 #偏移量numMagic, numImgs, numRows, numCols = struct.unpack_from(_msb_fourBytes, buf, index)index += struct.calcsize(_fourBytes)images = []for i in range(numImgs):imgVal  = struct.unpack_from(_msb_pictureBytes, buf, index)index += struct.calcsize(_pictureBytes)imgVal = list(imgVal)#for j in range(len(imgVal)):#   if imgVal[j] > 1:#       imgVal[j] = 1images.append(imgVal)return np.array(images)def getlable(filename=None) :binfile = open(filename, 'rb')buf = binfile.read() #获取文件内容缓存区binfile.close()index = 0 #偏移量numMagic, numItems = struct.unpack_from(_msb_twoBytes,buf, index)index += struct.calcsize(_twoBytes)labels = []for i in range(numItems):value = struct.unpack_from(_msb_lableByte, buf, index)index += struct.calcsize(_lableByte)labels.append(value[0]) #获取值的内容return np.array(labels)def outImg(arrX, arrY, order):#根据指定的order来获取集合中对应的图片和标签test1 = np.array([1,2,3])print(test1.shape)image = np.array(arrX[order])print(image.shape)image = image.reshape(28,28)label = arrY[order]print(label)outfile = str(order) + '_'+str(label) + '.png'plt.figure()plt.imshow(image, cmap="gray_r") # 在MNIST官网中有说道 “Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).”plt.show()#plt.savefig("./" + outfile) #保存图片"""
The second method
"""
def  load_data(filename = None):f = gzip.open(filename, 'rb')training_data, validation_data, test_data = pickle.load(f,encoding='bytes')return (training_data, validation_data, test_data)def test_cPickle():filename = 'MNIST_data/mnist.pkl.gz'training_data, validation_data, test_data = load_data(filename)print(len(test_data))outImg(training_data[0],training_data[1], 1000)#print len(training_data[1])def test():trainfile_X = 'MNIST_data/train-images.idx3-ubyte'trainfile_y = 'MNIST_data/train-labels.idx1-ubyte'arrX = getImage(trainfile_X)arrY = getlable(trainfile_y)outImg(arrX, arrY, 1000)if __name__  == '__main__':#test_cPickle() #test the second methodtest() #test the first method

附上百度百科中魔数的概念(magic number):

很多类型的文件,其起始几个字节的内容是固定的。根据这几个字节的内容可以确定文件的类型,因此这几个字节的内容被称为魔数。此外在一些程序代码中,程序员常常将在代码中出现但没有解释的数字常量或字符串称为魔数 (magic number)或魔字符串。

训练分类自己的图片

找了好几个博客做参考,但都有很多的错误,没改好

以下代码参考博文:tensorflow(六)训练分类自己的图片(CNN超详细入门版)

还有江湖人称星爷的博客

我从CK+表情数据库里选了一些原图作为我的数据集(此处代码和结果已删)

接下来的任务是要把昨天与今天的东西结合起来。

明天待续……


-------18/7/20更新-------

整理一下大致的流程思路

  • 上网爬取所需的图片(我把这个也算作一块知识点)
  • 将不同大小的图片转换成相同的大小(这部分还是有一点小问题没有改好,可以直接看星爷的博客:利用Tensorflow构建自己的图片数据集TFrecords这一节)
  • 对上步得到的图片进行处理(分类制作、banch处理)得到网络的输入(附上代码如下)

把前两天的代码做了修改,查看预处理效果的时候,可以显示彩色图像并直接标注属于何种表情

import os
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as pltangry = []
label_angry = []
disgusted = []
label_disgusted = []
fearful = []
label_fearful = []
happy = []
label_happy = []
sadness = []
label_sadness = []
surprised = []
label_surprised = []def get_file(file_dir):# step1:获取'F:/Python/PycharmProjects/DeepLearning/CK+_part'下所有的图片路径名,存放到# 对应的列表中,同时贴上标签,存放到label列表中。for file in os.listdir(file_dir + '/angry'):angry.append(file_dir + '/angry' + '/' + file)label_angry.append(0)for file in os.listdir(file_dir + '/disgusted'):disgusted.append(file_dir + '/disgusted' + '/' + file)label_disgusted.append(1)for file in os.listdir(file_dir + '/fearful'):fearful.append(file_dir + '/fearful' + '/' + file)label_fearful.append(2)for file in os.listdir(file_dir + '/happy'):happy.append(file_dir + '/happy' + '/' + file)label_happy.append(3)for file in os.listdir(file_dir + '/sadness'):sadness.append(file_dir + '/sadness' + '/' + file)label_sadness.append(4)for file in os.listdir(file_dir + '/surprised'):surprised.append(file_dir + '/surprised' + '/' + file)label_surprised.append(5)# 打印出提取图片的情况,检测是否正确提取print("There are %d angry\nThere are %d disgusted\nThere are %d fearful\n" %(len(angry), len(disgusted), len(fearful)),end="")print("There are %d happy\nThere are %d sadness\nThere are %d surprised\n" %(len(happy),len(sadness),len(surprised)))# step2:对生成的图片路径和标签List做打乱处理把所有的合起来组成一个list(img和lab)# 合并数据numpy.hstack(tup)# tup可以是python中的元组(tuple)、列表(list),或者numpy中数组(array),函数作用是将tup在水平方向上(按列顺序)合并image_list = np.hstack((angry, disgusted, fearful, happy, sadness, surprised))label_list = np.hstack((label_angry, label_disgusted, label_fearful, label_happy, label_sadness, label_surprised))# 利用shuffle,转置、随机打乱temp = np.array([image_list, label_list])   # 转换成2维矩阵temp = temp.transpose()     # 转置# numpy.transpose(a, axes=None) 作用:将输入的array转置,并返回转置后的arraynp.random.shuffle(temp)     # 按行随机打乱顺序函数# 将所有的img和lab转换成listall_image_list = list(temp[:, 0])    # 取出第0列数据,即图片路径all_label_list = list(temp[:, 1])    # 取出第1列数据,即图片标签label_list = [int(i) for i in label_list]   # 转换成int数据类型'''# 将所得List分为两部分,一部分用来训练tra,一部分用来测试valratio = 0.8n_sample = len(all_label_list)n_val = int(math.ceil(n_sample * ratio))  # 测试样本数, ratio是测试集的比例n_train = n_sample - n_val  # 训练样本数tra_images = all_image_list[0:n_train]tra_labels = all_label_list[0:n_train]tra_labels = [int(float(i)) for i in tra_labels]   # 转换成int数据类型val_images = all_image_list[n_train:-1]val_labels = all_label_list[n_train:-1]val_labels = [int(float(i)) for i in val_labels]   # 转换成int数据类型return tra_images, tra_labels, val_images, val_labels'''return image_list, label_list# 为了方便网络的训练,输入数据进行batch处理
# image_W, image_H, :图像高度和宽度
# batch_size:每个batch要放多少张图片
# capacity:一个队列最大多少
def get_batch(image, label, image_W, image_H, batch_size, capacity):# step1:将上面生成的List传入get_batch() ,转换类型,产生一个输入队列queue# tf.cast()用来做类型转换image = tf.cast(image, tf.string)   # 可变长度的字节数组.每一个张量元素都是一个字节数组label = tf.cast(label, tf.int32)# tf.train.slice_input_producer是一个tensor生成器# 作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。input_queue = tf.train.slice_input_producer([image, label])label = input_queue[1]image_contents = tf.read_file(input_queue[0])   # tf.read_file()从队列中读取图像# step2:将图像解码,使用相同类型的图像image = tf.image.decode_jpeg(image_contents, channels=3)# jpeg或者jpg格式都用decode_jpeg函数,其他格式可以去查看官方文档# step3:数据预处理,对图像进行旋转、缩放、裁剪、归一化等操作,让计算出的模型更健壮。image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)# 对resize后的图片进行标准化处理# image = tf.image.per_image_standardization(image)# step4:生成batch# image_batch: 4D tensor [batch_size, width, height, 3], dtype = tf.float32# label_batch: 1D tensor [batch_size], dtype = tf.int32image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, num_threads=16, capacity=capacity)# 重新排列label,行数为[batch_size]label_batch = tf.reshape(label_batch, [batch_size])image_batch = tf.cast(image_batch, tf.uint8)    # 显示彩色图像# image_batch = tf.cast(image_batch, tf.float32)    # 显示灰度图return image_batch, label_batch# 获取两个batch,两个batch即为传入神经网络的数据def PreWork():# 对预处理的数据进行可视化,查看预处理的效果IMG_W = 256IMG_H = 256BATCH_SIZE = 5CAPACITY = 64train_dir = 'F:/Python/PycharmProjects/DeepLearning/CK+_part'# image_list, label_list, val_images, val_labels = get_file(train_dir)image_list, label_list = get_file(train_dir)image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)lists = ('angry', 'disgusted', 'fearful', 'happy', 'sadness', 'surprised')with tf.Session() as sess:i = 0coord = tf.train.Coordinator()  # 创建一个线程协调器,用来管理之后在Session中启动的所有线程threads = tf.train.start_queue_runners(coord=coord)try:while not coord.should_stop() and i < 2:# 提取出两个batch的图片并可视化。img, label = sess.run([image_batch, label_batch])  # 在会话中取出img和label# img = tf.cast(img, tf.uint8)'''1、range()返回的是range object,而np.arange()返回的是numpy.ndarray()range(start, end, step),返回一个list对象,起始值为start,终止值为end,但不含终止值,步长为step。只能创建int型list。arange(start, end, step),与range()类似,但是返回一个array对象。需要引入import numpy as np,并且arange可以使用float型数据。2、range()不支持步长为小数,np.arange()支持步长为小数3、两者都可用于迭代range尽可用于迭代,而np.nrange作用远不止于此,它是一个序列,可被当做向量使用。'''for j in np.arange(BATCH_SIZE):# np.arange()函数返回一个有终点和起点的固定步长的排列print('label: %d' % label[j])plt.imshow(img[j, :, :, :])title = lists[int(label[j])]plt.title(title)plt.show()i += 1except tf.errors.OutOfRangeError:print('done!')finally:coord.request_stop()coord.join(threads)if __name__ == '__main__':PreWork()

随便挑着放了两张效果图

附言:我的代码一如既往有很多很多的注释,不谈标不标准,单纯是记录,方便自己理解

利用tensorflow训练自己的图片数据集——数据准备相关推荐

  1. 完整实现利用tensorflow训练自己的图片数据集

    经过差不多一个礼拜的时间的学习,终于把完整的一个利用自己爬取的图片做训练数据集的卷积神经网络的实现(基于tensorflow) 目录 总体思路 第三部分:预处理 第四部分:网络模型 第五部分:训练 2 ...

  2. 利用tensorflow训练自己的图片数据(1)——预处理

    一. 准备原始数据 首先,我们需要准备训练的原始数据,本次训练为图像分类识别,因而一开始,笔者从网上随机的下载了Dog的四种类别:husky,jiwawa,poodle,qiutian.每种类别30种 ...

  3. Tensorflow生成自己的图片数据集TFrecords(支持多标签label)

    Tensorflow生成自己的图片数据集TFrecords 尊重原创,转载请注明出处:https://blog.csdn.net/guyuealian/article/details/80857228 ...

  4. 利用Tensorflow构建RNN并对序列数据进行建模

    利用Tensorflow构建RNN并对序列数据进行建模 对文本处理处理任务的方法中,一般将TF-IDF向量作为特征输入.显然的缺陷是:这种方法丢失了输入的文本序列中每个单词的顺序. 对一般的前馈神经网 ...

  5. tensorflow训练自己的声音数据集进行声音分类

    ** tensorflow训练自己的声音数据集进行声音分类 ** 环境 win10 anaconda3.5 tensorflow 2.0 1.安装anaconda https://pan.baidu. ...

  6. CNN实现训练自己的图片数据集

    https://blog.csdn.net/wills798/article/details/80638151

  7. TensorFlow笔记(3)——利用TensorFlow和MNIST数据集训练一个最简单的手写数字识别模型...

    前言 当我们开始学习编程的时候,第一件事往往是学习打印"Hello World".就好比编程入门有Hello World,机器学习入门有MNIST. MNIST是一个入门级的计算机 ...

  8. 在服务器上利用mmdetection来训练自己的voc数据集

    在服务器上利用mmdetection来训练自己的voc数据集 服务器上配置mmdetection环境 在服务器上用anaconda配置自己的环境 进入自己的虚拟环境,开始配置mmdetection 跑 ...

  9. TensorFlow学习笔记——使用TFRecord进行数据保存和加载

    本篇文章主要介绍如何使用TensorFlow构建自己的图片数据集TFRecord的方法,并使用最新的数据处理Dataset API进行操作. TFRecord TFRecord数据文件是一种对任何数据 ...

最新文章

  1. 关于学习Python的一点学习总结(24->列表推导)
  2. 光纤通道速率查看_基于OM3/OM4的光纤通道连接方案
  3. linux 信号_Linux信号机制
  4. WPF 程序最大化,窗口,最小化
  5. 一步步编写操作系统 45 linux的elf可执行文件中的段和节
  6. 云+X案例展 | 民生类:中国电信天翼云携手国家天文台打造“大国重器”
  7. NLP 《马尔科夫链》
  8. 五步git操作搞定Github中fork的项目与原作者同步
  9. JDK自带观察者的使用
  10. Gym 101775 D (思维)
  11. 计算机网络第8版课后习题答案整理
  12. 谷歌升级街景采集车,用AI获取更佳图像
  13. 写给2018考研的你
  14. 使用ffmpeg批量合并flv文件
  15. Android 应用程序之间内容分享详解(一)
  16. python通过pypiwin32操作PPT
  17. 快速排序随机选取主元的重要性
  18. GreenPlum 大数据平台--segment 失效问题恢复
  19. Hitters数据集数据分析
  20. python import turtle as t_turtle库的学习笔记(python)

热门文章

  1. 解密Airbnb 自助BI神器:Superset
  2. 教师计算机西沃培训心得,学习使用希沃电子白板的心得体会
  3. windows虚机环境下,如何快速有效的删除大文件夹?
  4. jupyter和spider;Anaconda、Python、Jupyter、Pycharm、Spyder、conda、pip
  5. 闪存中的NorFlash、NandFlash及eMMC三者的区别
  6. 神兽传说JAVA下载_JAVA游戏神兽传说攻略
  7. 笔记系列---------sqlnet.ora维护
  8. IEEE754浮点数算数标准
  9. 外贸供应链ERP怎么选?全流程综合管理解析
  10. libssl-dev : 依赖: libssl1.0.0 (= 1.0.2g-1ubuntu4.13) 但是 1.0.2n-1ubuntu5.1 正要被安装