前言

大四的时候大致的看过一本基于tensorflow的实战Google深度学习框架的书,目前看论文源码也好,修改代码做改进也好,很多基本知识还是源于那个时候。这是远远不够的,为此,我在github上找了一个基于tensorflow的实例管理教程,来再细致的学习一下tensorflow,希望能够增强自己读代码,写代码的能力,对深度学习也有更好的理解。

一.数据准备

具体的学习过程,因为有之前的一些基础,为此直接从各种神经网络模型入手,来学习tensorflow框架,并且还可以对模型进一步的进行理解。在模型的搭建训练之前,首先就是训练测试数据的输入是如何实现的。下面结合代码,分块讲解。

数据集获取。建议先下载到本地,通过函数去下载相对不太稳定,耗时较长。

from __future__ import print_function
import gzip
import os
import urllib
import numpy
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):"""如果数据集不存在,从yann's的网站下载所需的数据集."""if not os.path.exists(work_directory):os.mkdir(work_directory)filepath = os.path.join(work_directory, filename)if not os.path.exists(filepath):filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)statinfo = os.stat(filepath)print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')return filepath

定义了所获取到的字节流的存储,采用大尾端的方式,返回值就是字节流中的前四个:magic,num_images,rows,cols。

def _read32(bytestream):dt = numpy.dtype(numpy.uint32).newbyteorder('>')return numpy.frombuffer(bytestream.read(4), dtype=dt)

对图片进行提取并将其转化为一个四维的numpy数组。

def extract_images(filename):"""将输入的图片转化为一个uint8类型四维的numpy数组 [index, y, x, depth]."""print('Extracting', filename)with gzip.open(filename) as bytestream:""" magic是指对应的幻数,相当于该部分数据集的专属标识一样,通过判定magic值是否与应该的值相等,来判断读取的数据集是否正确"""magic = _read32(bytestream)  if magic != 2051:raise ValueError('Invalid magic number %d in MNIST image file: %s' %(magic, filename))num_images = _read32(bytestream)rows = _read32(bytestream)cols = _read32(bytestream)buf = bytestream.read(rows * cols * num_images)data = numpy.frombuffer(buf, dtype=numpy.uint8) #以流的形式将需要的数据转化为uint8类型data = data.reshape(num_images, rows, cols, 1)return data

常用的one-hot矢量转换:

def dense_to_one_hot(labels_dense, num_classes=10):"""将类标签由标量转换为one-hot矢量."""num_labels = labels_dense.shape[0]index_offset = numpy.arange(num_labels) * num_classeslabels_one_hot = numpy.zeros((num_labels, num_classes)) labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1return labels_one_hot

为了帮助大家更好的理解每一步,把每一个变量的值进行一个测试输出来直观理解:

>>> labels_dense = np.array([0,1,2])
>>> labels_dense
array([0, 1, 2])
>>> num_classes = 3
>>> labels_dense
array([0, 1, 2])
>>> num_labels = labels_dense.shape[0]
>>> num_labels
3
>>> index_offset = np.arange(num_labels)*num_classes
>>> index_offset
array([0, 3, 6])
>>> labels_one_hot = np.zeros((num_labels, num_classes))
>>> labels_one_hot
array([[ 0.,  0.,  0.],[ 0.,  0.,  0.],[ 0.,  0.,  0.]])
>>> labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
>>> labels_one_hot
array([[ 1.,  0.,  0.],[ 0.,  1.,  0.],[ 0.,  0.,  1.]])

对标签的转换过程类似于图片:

def extract_labels(filename, one_hot=False):"""提取标签并将其转换为一个一维的uint8类型的numpy数组."""print('Extracting', filename)with gzip.open(filename) as bytestream:magic = _read32(bytestream)if magic != 2049:raise ValueError('Invalid magic number %d in MNIST label file: %s' %(magic, filename))num_items = _read32(bytestream)buf = bytestream.read(num_items)labels = numpy.frombuffer(buf, dtype=numpy.uint8)if one_hot:return dense_to_one_hot(labels)return labels

定义一个数据集:

class DataSet(object):def __init__(self, images, labels, fake_data=False):if fake_data:self._num_examples = 10000else:assert images.shape[0] == labels.shape[0], ("images.shape: %s labels.shape: %s" % (images.shape,labels.shape))self._num_examples = images.shape[0]# Convert shape from [num examples, rows, columns, depth]# to [num examples, rows*columns] (assuming depth == 1)assert images.shape[3] == 1images = images.reshape(images.shape[0],images.shape[1] * images.shape[2])# Convert from [0, 255] -> [0.0, 1.0].images = images.astype(numpy.float32)images = numpy.multiply(images, 1.0 / 255.0)self._images = imagesself._labels = labelsself._epochs_completed = 0self._index_in_epoch = 0"""@property装饰器负责把一个方法当作属性来调用,这里只定义了getter方法,没有定义setter方法"""@propertydef images(self):return self._images@propertydef labels(self):return self._labels@propertydef num_examples(self):return self._num_examples@propertydef epochs_completed(self):return self._epochs_completed
  def next_batch(self, batch_size, fake_data=False):"""从这个数据集中获得下一个batch大小的示例样本."""if fake_data:fake_image = [1.0 for _ in xrange(784)]fake_label = 0return [fake_image for _ in xrange(batch_size)], [fake_label for _ in xrange(batch_size)]start = self._index_in_epochself._index_in_epoch += batch_sizeif self._index_in_epoch > self._num_examples:# 结束这个epochself._epochs_completed += 1# 对数据进行随机清洗perm = numpy.arange(self._num_examples)numpy.random.shuffle(perm)self._images = self._images[perm]self._labels = self._labels[perm]# 开始下一个epochstart = 0self._index_in_epoch = batch_sizeassert batch_size <= self._num_examplesend = self._index_in_epochreturn self._images[start:end], self._labels[start:end]

读取所需要用到的数据集,所用到的mnist数据集可以从http://yann.lecun.com/exdb/mnist/下载:

def read_data_sets(train_dir, fake_data=False, one_hot=False):class DataSets(object):passdata_sets = DataSets()if fake_data:data_sets.train = DataSet([], [], fake_data=True)data_sets.validation = DataSet([], [], fake_data=True)data_sets.test = DataSet([], [], fake_data=True)return data_setsTRAIN_IMAGES = 'train-images-idx3-ubyte.gz'TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'TEST_IMAGES = 't10k-images-idx3-ubyte.gz'TEST_LABELS = 't10k-labels-idx1-ubyte.gz'VALIDATION_SIZE = 5000local_file = maybe_download(TRAIN_IMAGES, train_dir)train_images = extract_images(local_file)local_file = maybe_download(TRAIN_LABELS, train_dir)train_labels = extract_labels(local_file, one_hot=one_hot)local_file = maybe_download(TEST_IMAGES, train_dir)test_images = extract_images(local_file)local_file = maybe_download(TEST_LABELS, train_dir)test_labels = extract_labels(local_file, one_hot=one_hot)validation_images = train_images[:VALIDATION_SIZE]validation_labels = train_labels[:VALIDATION_SIZE]train_images = train_images[VALIDATION_SIZE:]train_labels = train_labels[VALIDATION_SIZE:]data_sets.train = DataSet(train_images, train_labels)data_sets.validation = DataSet(validation_images, validation_labels)data_sets.test = DataSet(test_images, test_labels)return data_sets

这就是一个基本的使用tensorflow对数据进行处理准备的过程,接下来具体讲到用tensorflow框架如何实现不同的网络模型的搭建训练及测试。

tensorflow框架精细讲解(一)相关推荐

  1. 掌握深度学习,为什么要用 PyTorch、TensorFlow 框架?

    全世界只有3.14 % 的人关注了 爆炸吧知识 自从2012年深度学习再一次声名鹊起以来,许多机器学习框架都争先恐后地要成为研究人员和行业从业者的新宠.面对如些众多的选择,人们很难判断最流行的框架到底 ...

  2. TensorFlow框架的这些操作你肯定不知道!

    谷歌在上周正式推出了深度学习框架TensorFlow 1.11.0 版本,那么TensorFlow框架到底是什么? TensorFlow™ 是一个采用数据流图(data flow graphs),用于 ...

  3. 借助TensorFlow框架,到底能做什么?

    谷歌在七月份正式推出了深度学习框架TensorFlow 1.9 版本,那么TensorFlow框架到底是什么? TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值 ...

  4. Vue全家桶入门精细讲解

    Vue入门精细讲解 感谢coderwhy老师的精心讲解,本笔记全部内容源于coderwhy老师的课堂笔记: 一. Hello Vuejs 1.1. 认识Vuejs 为什么学习Vuejs 可能你的公司正 ...

  5. Tensorflow框架是如何支持分布式训练的?

    参加 2019 Python开发者日,请扫码咨询 ↑↑↑ 作者 | 杨旭东 转载自知乎<算法工程师的自我修养>专栏 Methods that scale with computation ...

  6. AI工程师面试知识点:TensorFlow 框架

    AI工程师面试知识点:TensorFlow 框架

  7. Pycharm中tensorflow框架下tqdm的安装

    基本环境 win 10 tensorflow-cpu pycharm // tensorflow程序里错误结果显示from tqdm import tqdm ImportError: cannot i ...

  8. Uber开源TensorFlow框架工具箱Ludwig,无需编码即可进行深度学习开发

    日前,网约车服务商 Uber 开源并发布了它们开发的 Ludwig,这是一款基于 Google TensorFlow 框架上的开源工具箱.藉由 Ludwig,用户无需再编写任何代码即可进行深度学习的开 ...

  9. 【深度学习】Keras和Tensorflow框架使用区别辨析

    [深度学习]Keras和Tensorflow框架使用区别辨析 文章目录 1 概述 2 Keras简介 3 Tensorflow简介 4 使用tensorflow的几个小例子 5 Keras搭建CNN ...

最新文章

  1. 从Uber微服务看最佳实践如何炼成?
  2. ML之FE:数据处理—特征工程之特征三化(标准化【四大数据类型(数值型/类别型/字符串型/时间型)】、归一化、向量化)简介、代码实现、案例应用之详细攻略
  3. RabbitMq初探——安装
  4. python生成列表_python列表生成器与生成器
  5. 让sublime text显示空格,到底是点还是横杠TabError: inconsistent use of tabs and spaces in indentation
  6. 如何区分两列中不同数据_如何区分原装数据线和山寨数据线
  7. 04 Mysql之单表查询
  8. BlogEngine.Net架构与源代码分析系列part5:对象搜索——IPublishable与Search
  9. 116 Python GIL全局解释器锁
  10. 2011—2018年软考中级数据库系统工程师历年真题
  11. VFP开眼看世界的第一眼,就是学会真正的BS开发,走错一步费三年
  12. 复旦新生计算机考试及格率,复旦大学本科新生《计算机办公自动化》课程入学考试考核大.doc...
  13. 如何写出高性能SQL语句?-性能设计沉思录(6)
  14. 基于STM32的简易交通灯设计
  15. linux iptable 使用指南
  16. ntag213和215有什么区别_NTAG213、NTAG215和NTAG216NFC标签
  17. 复印身份证所引发的一系列问题与思考
  18. windows桌面动态主题_学习Windows 7:桌面主题和背景
  19. python编程游戏-9个Python编程小游戏,有趣又好玩,简直太棒了
  20. 2020总结 2021规划

热门文章

  1. C++ accumulate()的使用
  2. fstream,ifstream,ofstream 详解与用法
  3. 32位微型计算机中的32字的是,32位微型计算机中32指的是
  4. 苹果开始向全球扩展iAd平台
  5. OpenCV Mat与uchar*指针相互转换赋值
  6. mysql alter 改密码_MySql修改密码
  7. 名片管理系统java_java毕业设计_springboot框架的名片管理系统
  8. 字节跳动算法工程师总结:成功入职阿里月薪45K
  9. 查看并修改Linux主机名命令hostname
  10. Python 83道经典练习题,含答案!