#coding=utf-8
import tensorflow as tf
import numpy as np
import pandas as pd
import cv2 as cv
import os
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as pltdef draw_line(x_data, y_data, k, c):#这是离散的点图,同时含有高斯噪声plt.plot(x_data, y_data, "ro", label="points")#X和Y的范围在0-15,因为之前构造的X范围是从0-14plt.axis([0, 15, 0, 15])#准备一个列表curr_y = []#从我们通过训练模型得到的斜率和截距画出一个线性回归的曲线for i in range(len(x_data)):curr_y.append(x_data[i]*k+c)plt.plot(x_data, curr_y, label="fit-line")#设置显示labelplt.legend()#显示plt.show()def line_regression():# 定义线性方程y=wx+b的斜率和截距的变量,后面我们会通过训练得到这2个值w = tf.Variable(0.1, dtype=tf.float32)b = tf.Variable(0.1, dtype=tf.float32)# 产生30个随机数,数值从0-14,类型是浮点数X = tf.random_uniform([30], minval=0, maxval=14, dtype=tf.float32)"""1.这一步会生成一个含有30对坐标(x, y), 同时所有的坐标都满足线性方程Y = 1.2X + 0.82.我们的目的就是要将w逼近1.2,b逼近0.83.最后line_model是一个包含30个元素的一维矩阵"""line_model = tf.add(tf.multiply(X, 1.2), 0.8)# 生成30个满足高斯分布的随机噪声noise = tf.random_normal([30], mean=0, stddev=0.5, dtype=tf.float32)# 将30个随机噪声叠加在line_model(Y分量)上y_labels = line_model + noise"""1.以上所有操作的目的就是为了构造用于测试线性回归的输入数据参数2.现在我们得到了30组能够拟合到线性方程的输入数据参数"""# 通过输入参数和训练模型(训练w和b)得到输出y_ = tf.add(tf.multiply(w, X), b)# 计算平方diff = tf.square(y_ - y_labels)# 计算均值,目标是均值最小loss = tf.reduce_mean(diff)# 设置学习力optimizer = tf.train.GradientDescentOptimizer(0.01)# 设置优化器的优化方向step = optimizer.minimize(loss)# 初始化tensorflow的变量init = tf.global_variables_initializer()# 为tensorflow的会话设置别名with tf.Session() as sess:# run变量sess.run(init)# 训练次数设置for i in range(5000):sess.run(step)# 每隔多少次打印一次均值,这个值应当逐步收敛if (i + 1) % 100 == 0:curr_loss = sess.run(loss)print("current loss : ", curr_loss)# 训练结束后,得到训练后的变量w和bk, c, y = sess.run([w, b, y_])# 得到训练的输入参数,用于作图x_input, y_input = sess.run([X, y_labels])# 打印拟合后线性方程的斜率和截距print(k, c)draw_line(x_input, y_input, k, c)line_regression()

结果如下:

...
('current loss : ', 0.46891606)
('current loss : ', 0.27399457)
('current loss : ', 0.23788354)
('current loss : ', 0.28854463)
('current loss : ', 0.27047512)
('current loss : ', 0.19254531)
('current loss : ', 0.17524074)
('current loss : ', 0.3638595)
('current loss : ', 0.12371392)
('current loss : ', 0.42371178)
('current loss : ', 0.22886035)
('current loss : ', 0.26392117)
('current loss : ', 0.29713854)
('current loss : ', 0.29254925)
('current loss : ', 0.32334834)
('current loss : ', 0.17560533)
('current loss : ', 0.20040913)
('current loss : ', 0.40767363)
('current loss : ', 0.29662272)
(1.220254, 0.8045528)

图片展示:

很明显,k逼近到了1.2,b已经逼近到了0.8

TensorFlow入门:线性回归相关推荐

  1. Tensorflow入门——训练结果的保存与加载

    2019独角兽企业重金招聘Python工程师标准>>> 训练完成以后我们就可以直接使用训练好的模板进行预测了 但是每次在预测之前都要进行训练,不是一个常规操作,毕竟有些复杂的模型需要 ...

  2. TensorFlow入门:第一个机器学习Demo

    TensorFlow入门:第一个机器学习Demo 2017年12月13日 20:10:23 阅读数:8604 本文主要通过一个简单的 Demo 介绍 TensorFlow 初级 API 的使用方法,因 ...

  3. 从 TensorFlow 入门机器学习

    写在前面:紧跟时代步伐,开始学习机器学习,抱着争取在毕业之前多看看各个方向是什么样子的心态,发现这是一个很有潜力也很有趣的领域(keng).// 然后就开始补数学了-- 0 TensorFlow 介绍 ...

  4. 一文带你看懂!TensorFlow入门

    个人博客导航页(点击右侧链接即可打开个人博客):大牛带你入门技术栈 TensorFlow入门 本文将初步向码农和程序媛们介绍如何使用TensorFlow进行编程.在阅读之前请先 安装TensorFlo ...

  5. TensorFlow简单线性回归

    TensorFlow简单线性回归 将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.e ...

  6. tensorflow 入门

    基本使用 使用 TensorFlow, 你必须明白 TensorFlow: 使用图 (graph) 来表示计算任务. 在被称之为 会话 (Session) 的上下文 (context) 中执行图. 使 ...

  7. Tensorflow 入门教程

    Tensorflow 入门教程  http://tensornews.cn/ 深度学习发展史 特征工程 深度学习之激活函数 损失函数 反向传播算法 [上] 反向传播算法 [下] Tensorflow ...

  8. tensorflow入门_TensorFlow法律和统计入门

    tensorflow入门 by Daniel Deutsch 由Daniel Deutsch TensorFlow法律和统计入门 (Get started with TensorFlow on law ...

  9. 【深度学习】Tensorflow完成线性回归对比机器学习LinearRegression()

    首先构建一个线性的点状图 import warnings warnings.filterwarnings('ignore') import numpy as np import matplotlib. ...

  10. Tensorflow入门--图与会话

    目录 第1关:Hello,Tensorflow 第2关:计算图与会话 第3关:Tensorflow实现线性回归 第1关:Hello,Tensorflow 本关任务:编写使用python一个Tensor ...

最新文章

  1. python字符串和字节串有什么区别_对于Python中的字节串bytes和字符串以及转义字符的新的认识...
  2. 驾照考试:六百公里考试流程与注意事项
  3. 初学者学用Github
  4. 混合云:公共云和私有云之间取得平衡的方式?
  5. 【C++】带空格输入
  6. liquibase mysql_Liquibase MySQL:语法错误附近'????????????????'
  7. 【bzoj4653】[Noi2016]区间 双指针法+线段树
  8. 解决Visual Studio 2019未能从“https://www.nuget.org/api/v2/package..“下载包问题
  9. 《酒吧圣经》学习笔记1
  10. P1183 多边形的面积
  11. 应用统计学考研笔记1:数据整理与抽样
  12. js 区分中英文输入法(如中英文括号)
  13. 微信小程序跳转公众号图文内容
  14. 【C++】-- C++11基础常用知识点(下)
  15. css来回摆动,css3 animation(左右摆动) (放大缩小)
  16. 5G NR标准 第9章 传输信道处理
  17. 仿热血江湖帮战客方血帮战 开始对战记时器结束事件
  18. 单片机矩阵式键盘扫描程序
  19. 精通脚本黑客--电骡下载
  20. 服务器状态502 503 504,服务器错误500/502/503/504详解

热门文章

  1. C语言使用SQLite3数据库
  2. 为ashx文件启用session管理
  3. JAVA线程池shutdown和shutdownNow的区别
  4. 浅谈JS中的原型对象和原型链
  5. TOMCAT报错解决
  6. phpunit 测试指定目录下的测试类
  7. Intel Optane P4800X评测(序):不用缓存和电容保护的SSD?
  8. 李瑾博士:信誉的建立是否“不计成本”?
  9. MySQL · 特性分析 · 执行计划缓存设计与实现
  10. apt-get install