用 tf重写BP,并增加SGD:

# coding=utf-8
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error import numpy as np
import tensorflow as tf
import random
#import matplotlib.pyplot as pltlogs_path=r'c:/temp/log_mnist_softmax'
learning_rate=5.0 #当>0.05时误差很大
training_epochs=100
batch_size=3trainData_in=np.array([[1.0,1.0,0.0,0.0],\[1.0,0.0,1.0,0.0],\[1.0,0.0,0.0,1.0],\[1.0,0.0,0.0,0.0],\[0.0,1.0,0.0,0.0],\[0.0,0.0,0.0,1.0]])
trainData_out=np.array([[0.0,1.0],\[0.0,1.0],\[0.0,1.0],\[1.0,0.0],\[1.0,0.0],\[1.0,0.0]])testData_in=np.array([[0.0,0.0,0.6,0.8],\[0.0,0.0,0.0,0.0],[0.0,0.0,0.0,1.0],\[0.0,1.0,1.0,0.0],\[0.0,1.0,0.0,1.0],\[0.0,0.0,1.0,1.0]])
testData_out=np.array([[0.0,1.0],\[1.0,0.0],\[1.0,0.0],\[0.0,1.0],\[0.0,1.0],\[0.0,1.0]])print(np.shape(trainData_in))
print(np.shape(trainData_out))x_input=tf.placeholder(tf.float32, [None,4], name='x_input')
y_desired=tf.placeholder(tf.float32,[None,2],name='y_desired')
w1=tf.Variable(tf.truncated_normal([4,3],stddev=0.1),name='w1')
b1=tf.Variable(tf.zeros([3]),name='b1')
z1=tf.matmul(x_input,w1)+b1
y1=tf.nn.sigmoid(z1)w=tf.Variable(tf.truncated_normal([3,2],stddev=0.1),name='w')
b=tf.Variable(tf.zeros([2]),name='b')
z=tf.matmul(y1,w)+b
y_output=tf.nn.softmax(z,name='y_output')
lossFun_crossEntropy=-tf.reduce_mean(y_desired*tf.log(y_output)) #交叉熵均值#BP:
delta=tf.add(y_output,-y_desired)  #BP1
nabla_b=tf.reduce_sum(delta,axis=0,name='nabla_b')#在列方向上求和delta #BP3
nabla_w=tf.matmul(y1,delta,transpose_a=True,name='nabla_w') #BP4
dSigmod_z1=tf.nn.sigmoid(z1)*(1-tf.nn.sigmoid(z1))
delta=tf.matmul(delta,w,transpose_b=True)*dSigmod_z1 #BP2!!!
nabla_b1=tf.reduce_sum(delta,axis=0,name='nabla_b1')#在列方向上求和delta #BP3
nabla_w1=tf.matmul(x_input,delta,transpose_a=True,name='nabla_w1')  #BP4feed_dict_trainData={x_input:trainData_in,y_desired:trainData_out}
feed_dict_testData={x_input:testData_in,y_desired:testData_out}correct_prediction=tf.equal(tf.argmax(y_output,1),\tf.argmax(y_desired,1)) #1:按行索引,每行得一索引值
accuracy=tf.reduce_mean(tf.cast(correct_prediction,\tf.float32))#将逻辑型变成数字型,再求均值
###
#train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(lossFun_crossEntropy)
###
tf.summary.scalar('cost',lossFun_crossEntropy)
tf.summary.scalar('accuracy',accuracy)
summary_op=tf.summary.merge_all()with tf.Session() as sess:sess.run(tf.global_variables_initializer())logs_writer=tf.summary.FileWriter(logs_path,graph=tf.get_default_graph())for epoch in range(training_epochs):
#        _,summary=sess.run([train_step,summary_op],feed_dict=feed_dict_trainData)#######SGD:trainData=list(zip(trainData_in,trainData_out))random.shuffle(trainData)trainData_in,trainData_out=zip(*trainData)batch_count=int(len(trainData_in)/batch_size)for i in range(batch_count):batch_x=trainData_in[batch_size*i:batch_size*(i+1)]batch_y=trainData_out[batch_size*i:batch_size*(i+1)]feed_dict_batch={x_input:batch_x,y_desired:batch_y}#update:w1_temp,b1_temp,w_temp,b_temp,\nabla_w1_temp,nabla_b1_temp,nabla_w_temp,nabla_b_temp=\sess.run([w1,b1,w,b,nabla_w1,nabla_b1,nabla_w,nabla_b],\feed_dict=feed_dict_batch)m,n=np.shape(batch_y)update_w1=tf.assign(w1,w1_temp-learning_rate/m/n*nabla_w1_temp)update_b1=tf.assign(b1,b1_temp-learning_rate/m/n*nabla_b1_temp)update_w=tf.assign(w,w_temp-learning_rate/m/n*nabla_w_temp)update_b=tf.assign(b,b_temp-learning_rate/m/n*nabla_b_temp)sess.run([update_w1,update_b1,update_w,update_b])summary=sess.run(summary_op,feed_dict=feed_dict_trainData)logs_writer.add_summary(summary,epoch)print('Epoch',epoch)print('Accuracy_trainData:',accuracy.eval\(feed_dict=feed_dict_trainData))print('Accuracy_testData:',accuracy.eval\(feed_dict=feed_dict_testData))print('Done')try_input=testData_in[0] try_desired=testData_out[0]  print(try_desired)print(y_output.eval(feed_dict={x_input:[try_input]}))

TensorFlow第六步: 继续挖坑 用tf重写BP并增加SGD相关推荐

  1. STM32学习100步之第七十二-七十六步——U盘、TF卡与单片机的通信(利用SPI总线通信)

    SPI通信 由图中可以看出,SPI有四条主要的信号线,即MISO(主机输入从机输出).MOSI(主机输出从机输入),CS是对于从机而言的,当为0时,允许通信,由主机控制是否选通,另外可以使用单片机的I ...

  2. 实践操作:六步教你如何用开源框架Tensorflow对象检测API构建一个玩具检测器

    TensorFlow对象检测API是一个建立在TensorFlow之上的开源框架,可以轻松构建,训练和部署对象检测模型. 到目前为止,API的性能给我留下了深刻的印象.在这篇文章中,我将API的对象设 ...

  3. tfr 计算机硬件,实践操作:六步教你如何用开源框架Tensorflow对象检测API构建一个玩具检测器...

    TensorFlow对象检测API是一个建立在TensorFlow之上的开源框架,可以轻松构建,训练和部署对象检测模型. 到目前为止,API的性能给我留下了深刻的印象.在这篇文章中,我将API的对象设 ...

  4. 掌握这六步,搭建完美的机器学习项目

    全文共7778字,预计学习时长15分钟 上图白板展示了一系列机器学习项目启动 机器学习覆盖的范围十分广泛.这篇文章将整体描述机器学习适用的典型问题,提供实现机器学习项目雏形的框架. 首先厘清一些定义. ...

  5. 六步使用ICallbackEventHandler实现无刷新回调

    AJAX技术所提倡的无刷新回调,在原来的技术中需要写大量的JavaScript代码或使用一些AJAX框架,使得开发效率和可维护性大大降低.其实ASP.NET2.0中,已经提供了这样的接口,这就是ICa ...

  6. 重磅 | TensorFlow 2.0即将发布,所有tf.contrib将被弃用

    作者 | 阿司匹林 出品 | AI科技大本营(公众号ID:rgznai100) 上周,谷歌刚刚发布了 TensorFlow 1.10.0 版本(详见<TensorFlow 版本 1.10.0 发 ...

  7. 【新书】用Python3六步掌握机器学习第二版,469页pdf,Mastering Machine Learning

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 https://www.apress.com/gp/book/978148424946 ...

  8. bin文件怎么转换成文本文档_怎么把视频文件转换成MP3?这款工具六步帮你实现...

    在我们的生活当中,如果我们在看视频的时候,碰到了一则带有背景音乐的视频,此时我们不需要视频上的画面内容,只想要保留视频里的音乐的话,那么我们就需要通过将视频转换成MP3音频格式文件,才能够获得视频当中 ...

  9. SQL Server六步改善安全规划全攻略

    SQL Server六步改善安全规划全攻略 1.验证方法选择 本文对验证(authentication)和授权(authorization)这两个概念作不同的解释.验证是指检验用户的身份标识:授权是指 ...

最新文章

  1. 计算机专业英语chapter012,计算机专业英语 chapter_1.ppt
  2. java web运行的快慢_WebAssembly执行速度真的很强悍吗?对微软Edge很无语
  3. 时隔四年回归的澎湃芯片,能为雷军赌上一生荣耀的造车创业带来什么?
  4. web后端轻量级框架flask基础调用程序模板
  5. emacs 跳转到指定行
  6. 移动端input 无法获取焦点的问题
  7. 在NAS上基础构建云存储系统的两种解决方案
  8. Presto性能调优的五大技巧
  9. java语言设计论文_(C)论文(JAVA语言考试系统的设计与实现)
  10. 记事本如何运行python代码_利用Python开发实现简单的记事本
  11. linux文本编辑器
  12. paypal java sdk_PayPal-Java SDK /信用卡付款问题
  13. 计算机二级考试模拟表单答题,2015年计算机二级考试Visual FoxPro练习题
  14. 数独问题流程图_数独游戏的难度等级分析及求解算法研究
  15. 服务器系统开机黑屏只有个鼠标,Win7系统开机却显示黑屏并只有鼠标光标该怎么办...
  16. mysql的密码破解
  17. python读取tif文件与png文件
  18. 21.VIVO: Visual Vocabulary Pre-Training for Novel Object Captioning
  19. 岁月的剪影【十一月无需要太多】
  20. 什么是瀑布流布局?瀑布流布局的实现方法

热门文章

  1. Vue开启Gzip打包异常:webpack打包报错Cannot read property ‘emit‘ of undefined
  2. 容器编排技术 -- Kubernetes入门概述
  3. Dos批处理编程常用命令
  4. AOP Aspect Oriented Programming 面向切面编程 Spring
  5. 【C语言】输入一个字符串,统计其中的单词个数,将第一个单词的首字母改为大写,并输出改写后的字符串...
  6. redux 函数式组件_如何从函数式编程的角度学习Redux
  7. 在线编码工具_我希望在开始编码时就已经知道的工具
  8. 原型和原型链原型继承_原型还是不原型:这就是问题所在。
  9. 一年前端开发工程师简历_一年前,我开始学习编码,专注于前端开发。
  10. mysql 硬盘缓存_paip.mysql性能跟iops的以及硬盘缓存的关系_MySQL