tensorflow学习笔记二——建立一个简单的神经网络

2016-09-23 16:04 2973人阅读 评论(2) 收藏 举报
 分类:
tensorflow(4) 

目录(?)[+]

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):# add one more layer and return the output of this layerWeights = tf.Variable(tf.random_normal([in_size, out_size]))biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)Wx_plus_b = tf.matmul(inputs, Weights) + biasesif activation_function is None:outputs = Wx_plus_belse:outputs = activation_function(Wx_plus_b)return outputs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。 
numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop]. 

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
  • 1
  • 2
  • 3
  • 1
  • 2
  • 3

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step =  tf.train.GradientDescentOptimizer(0.1).minimize(loss)
  • 1
  • 2
  • 3
  • 1
  • 2
  • 3

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as pltdef add_layer(inputs, in_size, out_size, activation_function=None):# add one more layer and return the output of this layerWeights = tf.Variable(tf.random_normal([in_size, out_size]))biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)Wx_plus_b = tf.matmul(inputs, Weights) + biasesif activation_function is None:outputs = Wx_plus_belse:outputs = activation_function(Wx_plus_b)return outputs# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()for i in range(1000):# trainingsess.run(train_step, feed_dict={xs: x_data, ys: y_data})if i % 50 == 0:# to visualize the result and improvementtry:ax.lines.remove(lines[0])except Exception:passprediction_value = sess.run(prediction, feed_dict={xs: x_data})# plot the predictionlines = ax.plot(x_data, prediction_value, 'r-', lw=5)plt.pause(0.1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

运行结果: 

tensorflow学习笔记二——建立一个简单的神经网络拟合二次函数相关推荐

  1. 炼数成金Tensorflow学习笔记之2.4_Tensorflow简单示例

    炼数成金Tensorflow学习笔记之2.4_Tensorflow简单示例 代码及分析 代码及分析 # -*- coding: utf-8 -*- """ Created ...

  2. 神经网络和深度学习(二)——一个简单的手写数字分类网络

    本文转自:https://blog.csdn.net/qq_31192383/article/details/77198870 一个简单的手写数字分类网络 接上一篇文章,我们定义了神经网络,现在我们开 ...

  3. Django学习笔记2:一个简单的开发实例

    Technorati 标签: Python,Django 目标:通过开发一个简单的Todo管理应用,熟悉Django的基本概念.和使用. 运行环境 Windows Vista + Python 2.7 ...

  4. ROS2学习笔记13--编写一个简单的发布器和侦听器(C++)

    概要:这篇主要介绍编写发布器和侦听器的简单套路(C++) 环境:ubuntu20.04,ros2-foxy,vscode 最后如果没有陈述实操过程中碰到问题的话,则表示该章节都可被本人正常复现. 2. ...

  5. CAD二次开发学习笔记二(创建一个对话框)

    打开资源视图->右击->添加资源->Dialog 双击对话框,弹出MFC类向导,输入类名FirstClass, 确定,创建对话框类.FirstClass.h与FirstClass.c ...

  6. JSP/Servlet Web 学习笔记 DayFour —— 实现一个简单的JSP/Servlet交互

    小实例说明: a)实现一个由JSP负责前台显示,Servlet负责后台处理的交互小实例 b)JSP页面由表单获取一个开始数字,一个结束数字,交给Servlet打印响应的乘法表. 未解决的问题: a)跳 ...

  7. 《Kubernetes权威指南第2版》学习(二)一个简单的例子

    1: 安装VirtualBox, 并下载CentOS-7-x86_64-DVD-1708.iso, 安装centOS7,具体过程可以百度. 2:开启centOS的SSH, 步骤如下: (1) yum ...

  8. TensorFlow学习笔记(九)tf搭建神经网络基本流程

    1. 搭建神经网络基本流程 定义添加神经层的函数 1.训练的数据 2.定义节点准备接收数据 3.定义神经层:隐藏层和预测层 4.定义 loss 表达式 5.选择 optimizer 使 loss 达到 ...

  9. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

最新文章

  1. FB邮件服务器测试smtp,pop3
  2. 【Linux开发】V4L2应用程序框架
  3. 我的征程是未来!带你展望2015年最重要的网页设计趋势
  4. 今年最惨的交易:做空特斯拉
  5. 远控免杀专题(24)-CACTUSTORCH免杀
  6. 拥抱创新,持续探索——对话阿里云MVP胡逢法
  7. 使用CSS控制段落首行缩进
  8. iOS核心动画高级技术(九) 图层时间
  9. 专为Mac用户设计的创建图形模式软件:Patternodes 2.4.4
  10. Url...................哈哈哈哈哈哈哈哈哈
  11. 阶段1 语言基础+高级_1-3-Java语言高级_06-File类与IO流_09 序列化流_3_对象的反序列化流_ObjectInputStream...
  12. go的编程哲学和设计理念
  13. java ojdbc7_ojdbc7 / ojdbc8中的charset问题与ojdbc6
  14. 理工科专业精品书系列
  15. 如何刻录服务器安装系统光盘启动盘,如何刻录系统光盘
  16. 外包公司的运作模式和赚钱之道-聊聊IT外包公司
  17. 无线网卡驱动正常却搜索不到无线信号
  18. antv,图表和地图
  19. PR放入视频音频后没声音,及提示MME设备内部错误的解决办法
  20. 学习网络攻防,有什么渠道?

热门文章

  1. php截取剩余部分,PHP从字串中截取一部分,支持使用(*)模糊截取
  2. Django 无法添加新字段,django.db.utils.OperationalError: (1050, Table app already exists)
  3. java 计算移动平均线_基于Java语言开发的个性化股票分析技术:移动平均线(MA)...
  4. php7 thinkphp5,thinkphp5+phpstudy+php7.0连接SQL Server 2008 | 睿客网
  5. 中南大学计算机有网络安全,中南大学2019年大学生网络安全知识竞赛(复赛)成功举行...
  6. mysql中两种备份方法的优缺点_Mysql两种存储引擎的优缺点
  7. linux操作系统应急方案,服务器操作系统应急预案
  8. 秒懂 CountDownLatch 与 CyclicBarrier 使用场景
  9. 一起来看看Fastjson的三种漏洞利用链
  10. 5 年开发搞不定 MySQL !