1、模型机制

tensor  代表数据,可以理解为多维数组

variable  代表变量,模型中表示定义的参数,是通过不断训练得到的值

placeholder  代表占位符,也可以理解为定义函数的参数

2、session 的两种使用方法(还有一种启动session的方式是sess = tf.InteractiveSession())

3、注入机制

4、指定gpu运算

5、保存模型与载入模型

示例完整代码如下可直接运行:

  1. import tensorflow as tf
  2. import numpy as np
  3. plotdata = { "batchsize":[], "loss":[] }
  4. #生成模拟数据
  5. train_X = np.linspace(-1, 1, 100)
  6. train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
  7. tf.reset_default_graph()  #注意需要添加一个重置图
  8. # 创建模型
  9. # 占位符
  10. X = tf.placeholder("float")
  11. Y = tf.placeholder("float")
  12. # 模型参数
  13. W = tf.Variable(tf.random_normal([1]), name="weight")
  14. b = tf.Variable(tf.zeros([1]), name="bias")
  15. # 前向结构
  16. z = tf.multiply(X, W)+ b
  17. #反向优化
  18. cost =tf.reduce_mean( tf.square(Y - z))
  19. learning_rate = 0.01
  20. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
  21. # 初始化变量
  22. init = tf.global_variables_initializer()
  23. # 训练参数
  24. training_epochs = 20
  25. display_step = 2
  26. saver = tf.train.Saver()
  27. savedir = './'
  28. # 启动session
  29. with tf.Session() as sess:
  30. sess.run(init)
  31. # Fit all training data
  32. for epoch in range(training_epochs):
  33. for (x, y) in zip(train_X, train_Y):
  34. sess.run(optimizer, feed_dict={X: x, Y: y})
  35. #显示训练中的详细信息
  36. if epoch % display_step == 0:
  37. loss = sess.run(cost, feed_dict={X: train_X, Y:train_Y})
  38. print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
  39. if not (loss == "NA" ):
  40. plotdata["batchsize"].append(epoch)
  41. plotdata["loss"].append(loss)
  42. print (" Finished!")
  43. saver.save(sess,savedir+'linemodel.cpkt') #模型保存
  44. print ("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
  45. #模型载入
  46. with tf.Session() as sess2:
  47. sess2.run(tf.global_variables_initializer())
  48. saver.restore(sess2,savedir+'linemodel.cpkt')
  49. print('x=0.1,z=',sess2.run(z,feed_dict={X:0.1}))

6、检查点,训练模型有时候会出现中断情况,可以将检查点保存起来

saver一个参数max_to_keep=1表明最多只保存一个检查点文件

载入时指定迭代次数load_epoch

完整代码如下:

  1. import tensorflow as tf
  2. import numpy as np
  3. plotdata = { "batchsize":[], "loss":[] }
  4. #生成模拟数据
  5. train_X = np.linspace(-1, 1, 100)
  6. train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
  7. tf.reset_default_graph()  #注意需要添加一个重置图
  8. # 创建模型
  9. # 占位符
  10. X = tf.placeholder("float")
  11. Y = tf.placeholder("float")
  12. # 模型参数
  13. W = tf.Variable(tf.random_normal([1]), name="weight")
  14. b = tf.Variable(tf.zeros([1]), name="bias")
  15. # 前向结构
  16. z = tf.multiply(X, W)+ b
  17. #反向优化
  18. cost =tf.reduce_mean( tf.square(Y - z))
  19. learning_rate = 0.01
  20. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
  21. # 初始化变量
  22. init = tf.global_variables_initializer()
  23. # 训练参数
  24. training_epochs = 20
  25. display_step = 2
  26. saver = tf.train.Saver(max_to_keep=1)       #表明最多只保存一个检查点文件
  27. savedir = './'
  28. # 启动session
  29. with tf.Session() as sess:
  30. sess.run(init)
  31. # Fit all training data
  32. for epoch in range(training_epochs):
  33. for (x, y) in zip(train_X, train_Y):
  34. sess.run(optimizer, feed_dict={X: x, Y: y})
  35. #显示训练中的详细信息
  36. if epoch % display_step == 0:
  37. loss = sess.run(cost, feed_dict={X: train_X, Y:train_Y})
  38. print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
  39. if not (loss == "NA" ):
  40. plotdata["batchsize"].append(epoch)
  41. plotdata["loss"].append(loss)
  42. saver.save(sess,savedir+'linemodel.cpkt',global_step=epoch)
  43. print (" Finished!")
  44. # saver.save(sess,savedir+'linemodel.cpkt') #模型保存
  45. print ("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
  46. #检查点载入
  47. with tf.Session() as sess2:
  48. load_epoch =  18
  49. sess2.run(tf.global_variables_initializer())
  50. saver.restore(sess2,savedir+'linemodel.cpkt-'+str(load_epoch))
  51. print('x=0.1,z=',sess2.run(z,feed_dict={X:0.1}))

模型操作常用函数

tf.train.Saver()  #创建存储器Saver

tf.train.Saver.save(sess,save_path) #保存

tf.train.Saver.restore(sess,save_path) #恢复

7、可视化tensorboard

在代码中加入模型相关操作tf.summary.., 代码后面有注释,这个不理解可以当作模版,这几句代码,放在不同代码相应位置即可

代码如下:

  1. import tensorflow as tf
  2. import numpy as np
  3. plotdata = { "batchsize":[], "loss":[] }
  4. #生成模拟数据
  5. train_X = np.linspace(-1, 1, 100)
  6. train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3 # y=2x,但是加入了噪声
  7. tf.reset_default_graph()  #注意需要添加一个重置图
  8. # 创建模型
  9. # 占位符
  10. X = tf.placeholder("float")
  11. Y = tf.placeholder("float")
  12. # 模型参数
  13. W = tf.Variable(tf.random_normal([1]), name="weight")
  14. b = tf.Variable(tf.zeros([1]), name="bias")
  15. # 前向结构
  16. z = tf.multiply(X, W)+ b
  17. tf.summary.histogram('z',z)#将预测值以直方图显示
  18. #反向优化
  19. cost =tf.reduce_mean( tf.square(Y - z))
  20. tf.summary.scalar('loss_function', cost)#将损失以标量显示
  21. learning_rate = 0.01
  22. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #Gradient descent
  23. # 初始化变量
  24. init = tf.global_variables_initializer()
  25. # 训练参数
  26. training_epochs = 20
  27. display_step = 2
  28. # 启动session
  29. with tf.Session() as sess:
  30. sess.run(init)
  31. merged_summary_op = tf.summary.merge_all()  # 合并所有summary
  32. # 创建summary_writer,用于写文件
  33. summary_writer = tf.summary.FileWriter('log/summaries', sess.graph)
  34. # Fit all training data
  35. for epoch in range(training_epochs):
  36. for (x, y) in zip(train_X, train_Y):
  37. sess.run(optimizer, feed_dict={X: x, Y: y})
  38. summary_str = sess.run(merged_summary_op, feed_dict={X: x, Y: y});
  39. summary_writer.add_summary(summary_str, epoch);  # 将summary 写入文件
  40. #显示训练中的详细信息
  41. if epoch % display_step == 0:
  42. loss = sess.run(cost, feed_dict={X: train_X, Y:train_Y})
  43. print ("Epoch:", epoch+1, "cost=", loss,"W=", sess.run(W), "b=", sess.run(b))
  44. if not (loss == "NA" ):
  45. plotdata["batchsize"].append(epoch)
  46. plotdata["loss"].append(loss)
  47. print (" Finished!")
  48. print ("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))

之后查看tensorboard,进入summary 日志的上级路径中,输入相关命令如下图所示:

看见端口号为6006,在浏览器中输入http://127.0.0.1:6006,就会看到下面界面

window系统下相关操作一样,进入日志文件目录,然后输入tensorboard相应的命令,在打开浏览器即可看到上图(tensorboard)

tensorflow 就该这么学--2相关推荐

  1. tensorflow 就该这么学--1

    深度学习主要有下面几个步骤: 1.获取数据 2.搭建模型 3.模型训练 4.使用模型解决实际问题 tensorflow是现在最火的深度学习框架,值得学习 简单的用tensorflow拟合二维数据 1. ...

  2. tensorflow和python先学哪个-前辈说先学会了这些Python知识点,再谈学习人工智能!...

    原标题:前辈说先学会了这些Python知识点,再谈学习人工智能! 首先我们看一看Python的优势: 开源,跨平台. 社区.不要小看这一点.社区意味着有很多教程.书籍,出了问题很容易google到,乃 ...

  3. tensorflow就该这么学--6(多层神经网络)

    一.线性问题和非线性问题 1.线性问题 某医院想用神经网络对已经有的病例进行分类,数据样本特征x包括病人的年龄x1和肿瘤的大小x2,(x[x1,x2]),对应的标签为良性或恶性(0.1) 二分类: ( ...

  4. tensorflow就该这么学--5( 神经网络基础)

    一.单个神经元 单个神经元输出时 y=w*x+b 1 .正向传播:输入数据,通过初始给定的参数w,b  计算出对应的函数值 2.反向传播:计算正向传播得到的函数值与真实标签之间的误差值,之后调整w,b ...

  5. tensorflow就该这么学--4(识别手写数字)

  6. tensorflow就该这么学--3

    一.张量及操作 1.张量介绍 (1)tensor类型 DT_FLOAT.DT_DOUBLE.DT_INT64.DT_INT32.DT_INT16.DT_INT8.DT_STRING.DT_BOOL ( ...

  7. TensorFlow王位不保?ICLR投稿论文PyTorch出镜率快要反超了

    自PyTorch出道以来,不断有人表示,发现了这样的趋势: "学术圈正在慢慢地抛弃TensorFlow,转投PyTorch." 如今,PyTorch 1.0发布,ICLR 2019 ...

  8. 用 Go 语言理解 Tensorflow

    原文:https://pgaleone.eu/tensorflow/go/2017/05/29/understanding-tensorflow-using-go/ Tensorflow 并不是一个严 ...

  9. Tensorflow神经网络框架 小例子 三层神经网络 卷积神经网络 循环神经网络 神经网络可视化

    Tensorflow神经网络框架 以前我们讲了神经网络基础,但是如果从头开始实现,那将是一个庞大且费时的工作,所以我们选择一条捷径---神经网络框架.我理解的神经网络框架就相当于一个工具包.就比如我们 ...

最新文章

  1. PTA数据结构与算法题目集(中文)7-25
  2. Andrew Ng 深度学习课后测试记录-01-week2-答案
  3. jquery 加法 乘法运算 精确计算函数
  4. linux上配置spark集群
  5. an导入html5,H5-FLASH:AN HTML5-BASED FLASH RUNTIME
  6. P3250 [HNOI2016]网络(整体二分)
  7. P4983-忘情【wqs二分,斜率优化】
  8. DecExpress 帮助网站
  9. 『TensorFlow』模型保存和载入方法汇总
  10. mac安装和使用boost库
  11. python编程(mysql操作)
  12. 单片机can通信可以接多少个设备_总结BMS上CAN收发器电路的几个要点
  13. C++动态内存管理好难怎么办?零基础图文讲解,小白轻松理解原理
  14. 笛卡尔树(知识总结+板子整理)
  15. VBA用CDO批量发送邮件
  16. 【问】前台销售时卡顿
  17. 软考报名资格审核要多久?证明材料要哪些?
  18. 员工激励:什么样的方法最合适?
  19. 初始化string对象时,申请空间的秘密
  20. Outlook 转发/回复邮件时如何不显示邮件地址而只显示联系人名字?

热门文章

  1. 聊天秒回的人都是生命之光 诉言网
  2. nginx的反向代理以及负载均衡模块的使用
  3. keynotes egestas,PPT 渐变背景下载-imsoft.cnblogs
  4. yum 自动使用光盘和网络源
  5. 高性能网站架构设计之缓存篇(5)- Redis 集群(上)
  6. JAVA 对象引用,以及对象赋值
  7. [转载] 七龙珠第一部——第094话 太阳拳
  8. WordPress自动升级插件时需要填写FTP信息的解决
  9. 原来流行也可以变成怀旧!
  10. 802.11协议精读9:初探节能模式(PS mode)与缓存机制