TF之BN:BN算法对多层中的每层神经网络加快学习QuadraticFunction_InputData+Histogram+BN的Error_curve

目录

输出结果

代码设计


输出结果

代码设计

# 23 Batch Normalizationimport numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltACTIVATION = tf.nn.tanh
N_LAYERS = 7
N_HIDDEN_UNITS = 30   def fix_seed(seed=1): # reproduciblenp.random.seed(seed)tf.set_random_seed(seed)def plot_his(inputs, inputs_norm):# plot histogram for the inputs of every layerfor j, all_inputs in enumerate([inputs, inputs_norm]):for i, input in enumerate(all_inputs):plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1))plt.cla()if i == 0:the_range = (-7, 10)else:the_range = (-1, 1)plt.hist(input.ravel(), bins=15, range=the_range, color='#0000FF')plt.yticks(())if j == 1:plt.xticks(the_range)else:plt.xticks(())ax = plt.gca()ax.spines['right'].set_color('none')ax.spines['top'].set_color('none')plt.title("%s normalizing" % ("Without" if j == 0 else "With"))plt.title('Matplotlib,BN,histogram--Jason Niu')plt.draw()plt.pause(0.001)def built_net(xs, ys, norm): def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):# weights and biases (bad initialization for this case)Weights = tf.Variable(tf.random_normal([in_size, out_size], mean=0., stddev=1.))biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)# fully connected productWx_plus_b = tf.matmul(inputs, Weights) + biases# normalize fully connected productif norm:# Batch Normalizefc_mean, fc_var = tf.nn.moments( Wx_plus_b,axes=[0],  )scale = tf.Variable(tf.ones([out_size]))shift = tf.Variable(tf.zeros([out_size]))epsilon = 0.001# apply moving average for mean and var when train on batchema = tf.train.ExponentialMovingAverage(decay=0.5)def mean_var_with_update():ema_apply_op = ema.apply([fc_mean, fc_var])with tf.control_dependencies([ema_apply_op]):return tf.identity(fc_mean), tf.identity(fc_var)mean, var = mean_var_with_update()Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, shift, scale, epsilon)# Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)  #进行BN一下# Wx_plus_b = Wx_plus_b * scale + shift# activationif activation_function is None:outputs = Wx_plus_belse:outputs = activation_function(Wx_plus_b)return outputs  #输出激活结果fix_seed(1)if norm:# BN for the first inputfc_mean, fc_var = tf.nn.moments(xs,axes=[0],)scale = tf.Variable(tf.ones([1]))shift = tf.Variable(tf.zeros([1]))epsilon = 0.001# apply moving average for mean and var when train on batchema = tf.train.ExponentialMovingAverage(decay=0.5)def mean_var_with_update():ema_apply_op = ema.apply([fc_mean, fc_var])with tf.control_dependencies([ema_apply_op]):return tf.identity(fc_mean), tf.identity(fc_var)mean, var = mean_var_with_update()xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon)# record inputs for every layerlayers_inputs = [xs] # build hidden layersfor l_n in range(N_LAYERS):layer_input = layers_inputs[l_n]in_size = layers_inputs[l_n].get_shape()[1].valueoutput = add_layer(layer_input,    # inputin_size,        # input sizeN_HIDDEN_UNITS, # output sizeACTIVATION,     # activation functionnorm,           # normalize before activation)layers_inputs.append(output)  # build output layerprediction = add_layer(layers_inputs[-1], 30, 1, activation_function=None)cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost)return [train_op, cost, layers_inputs] fix_seed(1)
x_data = np.linspace(-7, 10, 2500)[:, np.newaxis]  #水平轴-7~10
np.random.shuffle(x_data)
noise = np.random.normal(0, 8, x_data.shape)
y_data = np.square(x_data) - 5 + noisexs = tf.placeholder(tf.float32, [None, 1])  # [num_samples, num_features]
ys = tf.placeholder(tf.float32, [None, 1])#建立两个神经网络作对比
train_op, cost, layers_inputs = built_net(xs, ys, norm=False)
train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True)sess = tf.Session()
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:init = tf.initialize_all_variables()
else:init = tf.global_variables_initializer()
sess.run(init)# record cost
cost_his = []
cost_his_norm = []
record_step = 5     plt.ion()
plt.figure(figsize=(7, 3))
for i in range(250):if i % 50 == 0: # plot histogramall_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], feed_dict={xs: x_data, ys: y_data})plot_his(all_inputs, all_inputs_norm)# train on batch每一步都run一下sess.run([train_op, train_op_norm], feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]})if i % record_step == 0:# record costcost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data}))cost_his_norm.append(sess.run(cost_norm, feed_dict={xs: x_data, ys: y_data}))#以下是绘制误差值Cost误差曲线的方法
plt.ioff()
plt.figure()
plt.title('Matplotlib,BN,Error_curve--Jason Niu')
plt.plot(np.arange(len(cost_his))*record_step, np.array(cost_his), label='no BN')     # no norm
plt.plot(np.arange(len(cost_his))*record_step, np.array(cost_his_norm), label='BN')   # norm
plt.legend()
plt.show()

相关文章
TF之BN:BN算法对多层中的每层神经网络加快学习QuadraticFunction_InputData+Histogram+BN的Error_curve

TF之BN:BN算法对多层中的每层神经网络加快学习QuadraticFunction_InputData+Histogram+BN的Error_curve相关推荐

  1. 根据《关于“k-means算法在流式细胞仪中细胞分类的应用”的学习笔记总结》撰写的中期报告...

    XXXX大学2014届本科毕业设计(论文)中期报告 毕业设计(论文)题目:K-means算法在流式细胞仪中细胞分类的应用 专业(方向):生物医学工程 学生信息:XXXXXX.XX.生医XXX 指导教师 ...

  2. 卷积神经网络CNN(2)—— BN(Batch Normalization) 原理与使用过程详解

    前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...

  3. 关于《k-means算法在流式细胞仪中细胞分类的应用》的学习笔记总结

    k-means算法在流式细胞仪中细胞分类的应用之学习总结 关键字:流式细胞仪,T淋巴细胞,k-means聚类,数据挖掘应用 一.课题简介 随着信息技术和计算机技术的迅猛发展,人们面临着越来越多的文本. ...

  4. caffe中的batchNorm层(caffe 中为什么bn层要和scale层一起使用)

    caffe中的batchNorm层 链接: http://blog.csdn.net/wfei101/article/details/78449680 caffe 中为什么bn层要和scale层一起使 ...

  5. dropout层_深度学习两大基础Tricks:Dropout和BN详解

    深度学习 Author:louwill Machine Learning Lab Dropout dropout作为目前神经网络训练的一项必备技术,自从被Hinton提出以来,几乎是进行深度学习训练时 ...

  6. 我眼中的算法导论 | 第一章——算法在计算中的作用、第二章——算法基础

    一个小白的算法学习之路.读<算法导论>第一天.本文仅作为学习的心得记录. 算法(Algorithm) 对于一个程序员来说,无论资历深浅,对算法一词的含义一定会或多或少有自己的体会,在< ...

  7. bm25算法Java代码_BM25算法在Lucene中的应用

    Lucene是apache软件基金会jakarta项目组的一个子项目,是一个用Java写的全文检索引擎工具包,可以方便的集成到系统中提以提供高效的检索能力,Lucene核心功能分为建索和检索两部分.而 ...

  8. (各种均衡算法在MIMO中的应用对比试验)最小均方误差(MMSE)原理推导以及在MIMO系统中对性能的改善。

    文档和程序地址:下载地址 各种均衡算法在MIMO中的应用对比试验,内附原理推导,对比实验说明和结果等.包括MMSE,ZF,ZF-SIC等.代码附有原理推导小论文.仅供参考

  9. 计算机视觉:Bag of words算法实现过程中出现错误及解决方案

    Bag of words算法实现过程中出现错误及解决方案 出现的问题 IndexError: list index out of range OSError:x.sift not found sqli ...

最新文章

  1. 组原,汇编语言关于代码段的定义
  2. C++知识点62——模板实参推断与函数模板的特化
  3. JSON.parse(text[, reviver])
  4. 1、在Linux虚拟机上安装 docker
  5. QT实现在图表顶部绘制一个附加元素(标注)
  6. javabeans_膨胀的JavaBeans –不要在您的API中添加“ Getters”
  7. 第七期:Python 从入门到精通:一个月就够了!
  8. mysql 5.5免安装配置_mysql的参考文档mysql5.5.21免安装版的配置方法
  9. mysql ssh错误_通过SSH隧道连接时,MySQL访问被拒绝错误
  10. 第二章 XHTML简介
  11. 防止sql注入:替换危险字符
  12. 光环PMP 二模错题知识点
  13. 怎么学习PLC技术?
  14. HTML5七夕情人节表白网页制作【一款乾坤八卦风水罗盘旋转CSS3动画特效代码,给人一种玄机重重的感觉】HTML+CSS+JavaScript
  15. win10创建局域网服务器
  16. paperwhite3翻页_Kindle vs. Paperwhite vs. Voyage vs. Oasis:您应该购买哪种Kindle?
  17. MATLAB中关于复矩阵的操作,新手易错
  18. 虎克哈克环槽铆钉机 铆接回收机振动筛设备 钢结构集装箱铆接机
  19. Ymodem协议介绍
  20. java中ssh测试接口方法_SSH入门---框架搭建(eclipse环境下)

热门文章

  1. Linux 发行版与Linux内核
  2. [unity3d]导出安卓版设置
  3. 《慕课React入门》总结
  4. Java类集框架 —— HashMap源码分析
  5. 通讯传输--全双工和半双工
  6. 天啊,为什么我的 Redis 变慢了。。
  7. 谷歌和 Facebook 是如何给工程师定职级和薪水的?
  8. Python r‘‘, b‘‘, u‘‘, f‘‘ 的含义
  9. Android——通知栏提示 app 更新的进度,更新完可以访问授权进行安装。适配 8.0 版本
  10. Java “Resource leak: ‘scanner‘ is never closed“警告的解决办法