TensorFlow学习笔记(九)tf搭建神经网络基本流程
1. 搭建神经网络基本流程
定义添加神经层的函数
1.训练的数据
2.定义节点准备接收数据
3.定义神经层:隐藏层和预测层
4.定义 loss 表达式
5.选择 optimizer 使 loss 达到最小
然后对所有变量进行初始化,通过 sess.run optimizer,迭代 1000 次进行学习:
import tensorflow as tf
import numpy as np# 添加层
def add_layer(inputs, in_size, out_size, activation_function=None):# add one more layer and return the output of this layerWeights = tf.Variable(tf.random_normal([in_size, out_size]))biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)Wx_plus_b = tf.matmul(inputs, Weights) + biasesif activation_function is None:outputs = Wx_plus_belse:outputs = activation_function(Wx_plus_b)return outputs# 1.训练的数据
# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise# 2.定义节点准备接收数据
# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])# 3.定义神经层:隐藏层和预测层
# add hidden layer 输入值是 xs,在隐藏层有 10 个神经元
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer 输入值是隐藏层 l1,在预测层输出 1 个结果
prediction = add_layer(l1, 10, 1, activation_function=None)# 4.定义 loss 表达式
# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))# 5.选择 optimizer 使 loss 达到最小
# 这一行定义了用什么方式去减少 loss,学习率是 0.1
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)# important step 对所有变量进行初始化
#init = tf.initialize_all_variables()
init = tf.global_variables_initializer()
sess = tf.Session()
# 上面定义的都没有运算,直到 sess.run 才会开始运算
sess.run(init)# 迭代 1000 次学习,sess.run optimizer
for i in range(1000):# training train_step 和 loss 都是由 placeholder 定义的运算,所以这里要用 feed 传入参数sess.run(train_step, feed_dict={xs: x_data, ys: y_data})if i % 50 == 0:# to see the step improvementprint(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
2. 主要步骤的解释:
import tensorflow as tf
import numpy as np
- 导入或者随机定义训练的数据 x 和 y:
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data*0.1 + 0.3
- 先定义出参数 Weights,biases,拟合公式 y,误差公式 loss:
Weights = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
biases = tf.Variable(tf.zeros([1]))
y = Weights*x_data + biases
loss = tf.reduce_mean(tf.square(y-y_data))
- 选择 Gradient Descent 这个最基本的 Optimizer:
optimizer = tf.train.GradientDescentOptimizer(0.5)
- 神经网络的 key idea,就是让 loss 达到最小:
train = optimizer.minimize(loss)
- 前面是定义,在运行模型前先要初始化所有变量:
init = tf.initialize_all_variables()
- 接下来把结构激活,sesseion像一个指针指向要处理的地方:
sess = tf.Session()
- init 就被激活了,不要忘记激活:
sess.run(init)
- 训练201步:
for step in range(201):
- 要训练 train,也就是 optimizer:
sess.run(train)
- 每 20 步打印一下结果,sess.run 指向 Weights,biases 并被输出:
if step % 20 == 0:
print(step, sess.run(Weights), sess.run(biases))
所以关键的就是 y,loss,optimizer 是如何定义的。
TensorFlow学习笔记(九)tf搭建神经网络基本流程相关推荐
- 【拔刀吧 TensorFlow】TensorFlow学习笔记八——何为卷积神经网络
TensorFlow直接以官方手册作为切入点,在趣味性和快速性上优势很大,但是对于学习深入理论的理解产生了巨大的阻碍. 在"深入MNIST"这一节中,遇到了卷积神经网络的构建,涉及 ...
- tensorflow学习笔记九:将 TensorFlow 移植到 Android手机,实现物体识别、行人检测和图像风格迁移详细教程
2017/02/23 更新 贴一个TensorFlow 2017开发者大会的Mobile专题演讲 移动和嵌入式TensorFlow 这里面有重点讲到本文介绍的三个例子,以及其他的移动和嵌入式方面的TF ...
- TensorFlow学习笔记之五(卷积神经网络)
文章目录 1. 图片识别问题简介以及经典数据集 1.1 图片识别问题简介 1.2 经典数据集 1. 单通道图片求卷积 1.1 基本的图片求卷积 1.2 填充的图片求卷积 1.3 TensorFlow计 ...
- tensorflow学习笔记:tf.control_dependencies,tf.GraphKeys.UPDATE_OPS,tf.get_collection
tf.control_dependencies(control_inputs): control_dependencies(control_inputs) ARGS: control_inputs:在 ...
- TensorFlow学习笔记(一): tf.Variable() 和tf.get_variable()详解
对于tf.Variable和tf.get_variable,这两个都是在我们训练模型的时候常遇到的函数,我们首先要知道懂得它的语法格式.常用的语法格式的作用以及在实际代码中是如何调用.如何运行的,运行 ...
- TensorFlow学习笔记之六(循环神经网络RNN)
文章目录 1. 循环神经网络简介 1. 循环神经网络简介 循环神经网络源自于1982年由 Saratha Sathasivam提出的霍普菲尔德网络. 循环神经网络的主要用途是处理和预测序列数据,循环神 ...
- tensorflow学习笔记:tf.data.Dataset,from_tensor_slices(),shuffle(),batch()的用法
tf.data.Dataset.from_tensor_slices: 它的作用是切分传入Tensor的第一个维度,生成相应的dataset. 例1: dataset = tf.data.Datase ...
- python自训练神经网络_tensorflow学习笔记之简单的神经网络训练和测试
本文实例为大家分享了用简单的神经网络来训练和测试的具体代码,供大家参考,具体内容如下 刚开始学习tf时,我们从简单的地方开始.卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第 ...
- TensorFlow学习笔记之--[compute_gradients和apply_gradients原理浅析]
我们都知道,TensorFlow为我们提供了丰富的优化函数,例如GradientDescentOptimizer.这个方法会自动根据loss计算对应variable的导数.示例如下: loss = . ...
最新文章
- python 对象拷贝
- 一文讲清,MySQL主从架构
- SpringCloud之分布式配置中心(六)
- SWT让耗时的操作后台运行
- [转] 数学符号英文拼写及发音
- mysql死锁语句_记一次神奇的Mysql死锁排查
- 10 步让你成为更优秀的Coder
- (二十二)美萍酒店管理系统:系统维护_系统设置_房间设置_其他测试
- 八数码难题(启发式搜索)
- 从Activiti切换到Camunda的5个理由
- android msf 漏洞,MSF之ms各种漏洞
- 大比分领先!ACCV 2022 国际细粒度图像分析挑战赛冠军方案
- 深度学习需要多强的数学基础?
- Kubernetes API Aggregation在 Master 的 API Server 中启用 API 聚合功能注册自定义 APIService 资源实现和部署自定义的 API Serv
- mysql数据表损坏的常见原因是_MYSQL数据表损坏的分析
- 树莓派centos踩坑之旅,解决每次重启都需要route add才能有网络
- HighlightingSystem(边缘发光插件)的简单使用(一)
- WebLogicServer BEA-000386 Weblogic启动报错
- 【CV】SiamFC:用于目标跟踪的全卷积孪生网络
- html把毫秒转换成年月日,如何使用JavaScript将毫秒转换为日期格式?
热门文章
- 算法竞赛入门经典(第二版) | 程序3-10 生成元 (UVa1584,Circular Sequence)
- Python绘制三维散点图
- 线程中这么调用类_这些线程知识总结是真的到位!java开发两年的我看的目瞪口呆
- python中乘法和除法_python – NumPy的性能:uint8对比浮动和乘法与除法?
- 计算机不同用户信息互通吗,迷你世界电脑版和手机版通用吗 二者账号数据互通吗...
- rust石头墙几个c4_哪个房间需要清扫 石头扫地机器人T6可能比你还清楚
- paradox 修改字段长度_400字的作文就只能写400字?刘强东:这不是笑话
- limux php启动_linux下nginx与php设置开机启动代码
- notepad php格式,notepad怎么格式xml
- java pid 获取句柄_获取进程pid、根据进程pid获取线程pid、获取线程进程句柄