摘要:

在tensorflow里面,已经封装了一些常用的网络结构模型,比如CNN,DNN这些,有时候我们需要自己搭一个网络结构或者需要了解每个网络是如何搭建起来的,相信这篇文章会对你有一点点所帮助。
已有模型

正文:

model_fn:

我们首先需要写一个函数描述我们的网络,这个也是自己搭建模型最为核心的一步,通常需要做以下的事情

1.基本的网络结构(卷积、池化、全连接等等)

卷积层:
tf.layers.conv2d(inputs=input_layer,filters=32,kernel_size=[5, 5],padding="same",activation=tf.nn.relu)
inputs:输入层
filters:卷积核的个数
kernel_size:卷积核的尺寸
padding这个参数决定了卷积后的tensor是否与输入tensor有相同的width和height。
当padding='same'时表示卷积后的tensor与输入tensor尺寸不变。
当padding='valid'(默认),卷积后的W和H按照公式计算公式—第6点
strides:步长
activation,选择激活函数。
tensorflow常用激活函数

池化层:

tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
inputs:输入层
pool_size:池化窗口大小
strides:步长
在接全连接层时一般有一个“拍平”tensor的过程,具体如下:
pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])

全连接:

tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)

inputs:输入层

units:神经元个数
activation:激活函数

drop-out层:
tf.layers.dropout(inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)

rate:drop-out的比例。

traininig:判断当前是否为训练模式(因为只有在training的时候才会进行drop-out,eval和predict都不会drop-out)

2.计算Loss层。

每个训练任务都会有一个损失函数,下面是关于loss function的定义:

这里需要注意的是,对于回归任务loss层应该如下:

output_layer = tf.layers.dense(second_hidden_layer, 1)
predictions = tf.reshape(output_layer, [-1])
loss = tf.losses.mean_squared_error(labels, predictions)

可以看到,最后一层是一个没有激活函数单个神经元的全连接的层,接着改变tensor的维度为1。之后才计算Loss。

对于分类任务层应该如下:

logits = tf.layers.dense(inputs=dropout, units=10) #(也是没有神经元的)
onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)

可以看到,对于分类任务,需要先对标签做一次One-hot(二分类任务也需要)。之后再送入Loss层计算loss。

3.定义training op,具体形式如下:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"])
train_op=optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

4.return

在model_fn最后一部分为return一个 tf.estimator.EstimatorSpec(),根据训练、评价、预测可以分为3中情况进行return。
对于PREDICT模式,返回mode(当前模式)和预测结果(prediction),预测结果以字典形式。

    return tf.estimator.EstimatorSpec(mode=mode,predictions={"ages": predictions})

(predictions其实就是最后一层的输出)

对于TRAIN模式来说,需要返回mode、loss、train_op。
其中train_op其实就是训练时使用的优化器。例如,

optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"])
train_op=optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

loss就是损失函数:

loss = tf.losses.mean_squared_error(labels, predictions)
具体写法:
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

对于EVAL模式来说,需要返回mode、loss。还有一个参数可以选择eval_metric_ops。这个可用于保存某个指标的预测结果。比如说,希望得到准确率的值可以:
eval_metric_ops = { "accuracy": tf.metrics.accuracy(labels, predictions) }
如果不指定eval_metric_ops那么返回的指标里面只有损失loss。

具体写法:
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

可以看到,其实TRAIN和EVAL可以放在一起返回。
具体写法:
  return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,eval_metric_ops=eval_metric_ops)

具体参考下面官方原版说明:

model_fn基本上需要由上面提到的东西构成,最后说一个关于数据格式的问题。

5.tensor维度问题

如果网络对tensor的维度有要求,比如需要做卷积,那么在model_fn中输入层可以如下写
input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])
其中-1的意思是说这一维的数值是动态计算的。其会根据feature["x"]来进行计算。其实这个就是我们所知道的batch_size。

最后input_layer会转成一个4维的tensor。

注意,默认feature["x"]是一个2维的tensor。
在上面,模型搭建完毕之后就可以开始训练-评价-预测。

6.训练

训练前需要先实例化一个学习器,具体如下:
classifier = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/mnist_convnet_model",params=param)

model_fn 模型函数,这个就是我们上面所写的model_fn
model_dir 保存模型的地方
params 模型可能用到的参数。

训练模型的时候需要先定义输入数据。一般可用tf.estimator.inputs.numpy_input_fn()或者tf.estimator.inputs.pandas_input_fn()两种方法来加载数据。
两个参数是一样的,下面介绍下tf.estimator.inputs.numpy_input_fn()的参数

numpy_input_fn(x,y=None,batch_size=128,num_epochs=1,shuffle=None,queue_capacity=1000,num_threads=1
)

x是训练数据,使用numpy_input_fn时就是numpy array的格式。使用Pandas的时候就是dataframe的格式。注意训练数据不包含标签值。
y是标签。
batch_size,每一次迭代的时候选择的样本个数(这个涉及到梯度的计算,有疑惑的可以读读这篇文章),默认选择是128个样本。
num_epochs参数解释:
epoch是一个很重要的参数,这里详细的解释一下。首先一个epoch是指把整个数据集训练一次。举个例子,比如训练集理由100张图片,1个ephoch就是指每张图片都被训练了一次。2个epoch就是指 先对整个数据集训练一次后,再来对整个数据集训练一次。3个epoch就依次类推。

除了一个epoch这个概念外,还有一个就是steps(迭代次数)。
下面具体结合epoch和steps,batch_size这三个参数的设置来看看对最后的训练次数有什么影响:
(以下均假设整个数据集大小为3328条)
epoch=None,steps=100,batch_size=128:
当epoch=None的时候就是说,训练的停止条件是达到迭代次数100。(这个时候其实可以算得到整个数据集被训练了100/(3328/128)=3.84次)
epoch=1,steps=100,batch_size=128:
整个数据集共3328条数据,batch_size为128,所以迭代26次(3328/128)时可以实现整个数据集被训练了一次,所以实际上迭代26次就停止训练了。
epoch=4,steps=100,batch_size=128:
和上面类似,只不过这里的epoch=4,故数据集总共需要被训练4次,故迭代次数总共需要4*(3328/128)=104次,但是104<100次,所以100次的时候训练也停止了。
epoch=100,steps=None,batch_size=128:
这个时候steps不指定意味着停止条件是达到epoch的次数。
所以当整个数据集被训练了100次的时候停止训练。此时的迭代次数其实是(3328/128)*100=2600次。
相信举了几个例子后大家对epoch、steps(迭代次数)、batch_size有了点认识。

shuffle打乱数据集。
其他参数这里就不一一解释了。给出官方的原版解释:
queue_capacity: Integer, size of queue to accumulate.
num_threads: Integer, number of threads used for reading and enqueueing. In order to have predicted and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, num_threads should be 
上面讲到输入数据指定,具体指定如下:

train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": train_data},y=train_labels,batch_size=100,num_epochs=None,shuffle=True)

指定完之后可以开始训练了,具体如下:

classifier.train(input_fn=train_input_fn,steps=20000)

指定训练数据以及迭代的次数。

7.评价

eval_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": eval_data},y=eval_labels,num_epochs=1,shuffle=False)

这里的num_epochs=1相信大家可以理解了,就是把整个测试集测试一次。

eval_results =classifier.evaluate(input_fn=eval_input_fn)

8.预测

predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": prediction_set.data},num_epochs=1,shuffle=False)
predictions = classifier.predict(input_fn=predict_input_fn)

这里需要注意的是,评价/预测返回的结果是字典格式的,这个部分我们在model_fn里面都设置了。
对于评价的返回结果是eval_metric_ops
对于预测的返回结果是predictions

---------------------------------------------------------------------------------------------------------------

至此,这个流程都介绍的差不多了。下面给出两个tensorflow的demo,一个是回归任务,一个是分类任务。
回归任务:https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/tutorials/estimators/abalone.py
分类任务:https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/tutorials/layers/cnn_mnist.py
相关文档:
https://www.tensorflow.org/extend/estimators
https://www.tensorflow.org/tutorials/layers#training_and_evaluating_the_cnn_mnist_classifier
https://www.tensorflow.org/api_docs/python/tf/estimator/inputs/numpy_input_fn
https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator

Tensorflow学习-自定义模型相关推荐

  1. Tensorflow学习教程------模型参数和网络结构保存且载入,输入一张手写数字图片判断是几...

    首先是模型参数和网络结构的保存 #coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist impor ...

  2. tensorflow学习5----GAN模型初探

    生成模型: 通过观测学习样本和标签的联合概率分布P(X,Y)进行训练,训练好的模型能够生成符合样本分布的新数据,在无监督学习方面,生成式模型能够捕获数据的高阶相关性,通过学习真实数据的本质特征,刻画样 ...

  3. 深度学习TensorFlow学习-自定义网络

    上一篇文章讲解如何使用tf.keras 快速搭建网络,这篇讲解自定义一个神经网络结构. 使用 Sequential 可以快速搭建网络结构,但是如果网络包含跳连等其他复杂网络结构, Sequential ...

  4. 【tensorflow速成】Tensorflow图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [tensorflow速成]Tensorflow图像分类从模型自定义到测试 这是给大家准备的tensorflow速成例子 上一篇介绍了 Caffe , ...

  5. TensorFlow 2.0 - 自定义模型、训练过程

    文章目录 1. 自定义模型 2. 学习流程 学习于:简单粗暴 TensorFlow 2 1. 自定义模型 重载 call() 方法,pytorch 是重载 forward() 方法 import te ...

  6. TensorFlow学习笔记——实现经典LeNet5模型

    TensorFlow实现LeNet-5模型 文章目录 TensorFlow实现LeNet-5模型 前言 一.什么是TensorFlow? 计算图 Session 二.什么是LeNet-5? INPUT ...

  7. 深度学习利器:TensorFlow与NLP模型

    深度学习利器:TensorFlow与NLP模型 享到:微博微信FacebookTwitter有道云笔记邮件分享 稍后阅读 我的阅读清单 前言 自然语言处理(简称NLP),是研究计算机处理人类语言的一门 ...

  8. AI开发者大会之AI学习与进阶实践:2020年7月3日《如何转型搞AI?》、《基于AI行业价值的AI学习与进阶路径》、《自动机器学习与前沿AI开源项目》、《使用TensorFlow实现经典模型》

    AI开发者大会之AI学习与进阶实践:2020年7月3日<如何转型搞AI?>+<无行业不智能:基于AI行业价值的AI学习与进阶路径>.<自动机器学习与前沿AI开源项目> ...

  9. 人脸口罩检测现开源PyTorch、TensorFlow、MXNet等全部五大主流深度学习框架模型和代码...

    号外!号外! 现在,AIZOO开源PyTorch.TensorFlow.MXNet.Keras和Caffe五大主流深度学习框架的人脸检测模型和代码啦! 先附上Github链接为敬. https://g ...

  10. Pytorch学习记录(七):自定义模型 Auto-Encoders 使用numpy实现BP神经网络

    文章目录 1. 自定义模型 1.1 自定义数据集加载 1.2 自定义数据集数据预处理 1.3 图像数据存储结构 1.4 模型构建 1.5 训练模型 2. Auto-Encoders 2.1 无监督学习 ...

最新文章

  1. MySQL 5.7中的更多改进,包括计算列
  2. 26.angularJS $routeProvider
  3. uva673 Parentheses Balance
  4. 蓝桥杯-K好数(java)
  5. java写一个音乐播放器源码_求一个JAVA音乐播放器的源代码
  6. 前世档案 (15 分)
  7. this关键字 和 private关键字
  8. 全国计算机一级试题重难点,全国计算机等级考试一级MS选择题(重难点)部分.doc...
  9. CHM乱码解决方案!
  10. Kava下一阶段Kava 5主网将于3月4日上线
  11. 诹图系列(2): 堆积条形图
  12. java 微信卡券开发 --创建微信卡券
  13. 【kafka】kafka windows Invalid UTF-8 middle byte 0xfe
  14. 公司老总直接面试 我该如何准备
  15. 浅谈无人值守改造技术在矿山供电系统的应用研究
  16. 取消Editplus的自动备份
  17. EasyPlayer流媒体播放器播放HLS视频,起播速度慢的技术优化
  18. 代码批量删除QQ日志和说说
  19. 2021年陕西省安全员C证考试内容及陕西省安全员C证考试资料
  20. 2023年,如何管理你的绩效目标?

热门文章

  1. 非负数 正则表达式
  2. 红警地图编辑器的使用方法
  3. html360全景图原理,HTML5中Canvas如何实现360度全景图
  4. 红米5plus开发者选项怎么打开?
  5. [ctf misc][2021祥云杯初赛]层层取证
  6. 黑桃怎么用html代码,index.html
  7. 魔域充值卡表cq_card里chk_sum参数的算法
  8. 怒怼|扎克伯格到底是个怎样的人
  9. iPhone连接Mac电脑总是断开
  10. Python之父退休,龟叔与Python的渊源