使用scikit learn的内建iris数据集。用数据点(x代表花瓣宽度,y代表花瓣长度)找到最优直线。

1.导入必要的编程库,创建计算图,加载数据集。

>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> import tensorflow as tf
>>> from sklearn import datasets
>>> from tensorflow.python.framework import ops
>>> ops.reset_default_graph()

>>> sess=tf.Session()

>>> iris=datasets.load_iris()

>>> x_vals=np.array([x[3] for x in iris.data])

>>> y_vals=np.array([y[0] for y in iris.data])

2.声明学习率、批量大小、占位符和模型变量

>>> x_vals=np.array([x[3] for x in iris.data])
>>> y_vals=np.array([y[0] for y in iris.data])
>>> learning_rate=0.05
>>> batch_size=25
>>> x_data=tf.placeholder(shape=[None,1],dtype=tf.float32)
>>> y_target=tf.placeholder(shape=[None,1],dtype=tf.float32)
>>> A=tf.Variable(tf.random_normal(shape=[1,1]))

>>> b=tf.Variable(tf.random_normal(shape=[1,1]))

3.增加线性模型,y=Ax+b

>>> model_output=tf.add(tf.matmul(x_data,A),b)

4.声明L2损失函数,其为批量损失的平均值。初始化变量,声明优化器。

>>> loss=tf.reduce_mean(tf.square(y_target-model_output))
>>> init=tf.global_variables_initializer()
>>> sess.run(init)
>>> my_opt=tf.train.GradientDescentOptimizer(learning_rate)

>>> train_step=my_opt.minimize(loss)

5.遍历迭代,并在随机选择的批量数据上进行模型训练。迭代100次,每25次迭代输出变量值和损失值。

注意:保存每次迭代的损失值,将其用于后续的可视化

>>> loss_vec=[]

>>> for i in range(100):
...   rand_index=np.random.choice(len(x_vals),size=batch_size)
...   rand_x=np.transpose([x_vals[rand_index]])
...   rand_y=np.transpose([y_vals[rand_index]])
...   sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y})
...   temp_loss=sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})
...   loss_vec.append(temp_loss)
...   if (i+1)%25==0:
...     print('Step #'+str(i+1)+'A='+str(sess.run(A))+'b='+str(sess.run(b)))
...     print('Loss='+str(temp_loss))
...
Step #25A=[[2.1689417]]b=[[2.9067767]]
Loss=0.9575598
Step #50A=[[1.6556607]]b=[[3.6350703]]
Loss=0.5116888
Step #75A=[[1.3509697]]b=[[4.1133633]]
Loss=0.39227816
Step #100A=[[1.1751562]]b=[[4.356544]]

Loss=0.31387034

6.抽取系数,创建最佳拟合直线

>>> [slope]=sess.run(A)
>>> [y_intercept]=sess.run(b)
>>> best_fit=[]
>>> for i in x_vals:
...   best_fit.append(slope*i+y_intercept)

...

7.绘制你喝的直线和L2正则损失函数

>>> plt.plot(x_vals,y_vals,'o',label='Data Points')
[<matplotlib.lines.Line2D object at 0x000002216124B518>]
>>> plt.plot(x_vals,best_fit,'r-',label='Best fit line',linewidth=3)
[<matplotlib.lines.Line2D object at 0x0000022159E49128>]
>>> plt.legend(loc='upper left')
<matplotlib.legend.Legend object at 0x000002216124BD30>
>>> plt.title('Sepal Length vs Pedal Width')
Text(0.5,1,'Sepal Length vs Pedal Width')
>>> plt.xlabel('Pedal Width')
Text(0.5,0,'Pedal Width')
>>> plt.ylabel('Sepal Length')
Text(0,0.5,'Sepal Length')

>>> plt.show()

>>> plt.plot(loss_vec,'k-')
[<matplotlib.lines.Line2D object at 0x0000022160D0AEF0>]
>>> plt.title('L2 Loss per Generation')
Text(0.5,1,'L2 Loss per Generation')
>>> plt.xlabel('Generation')
Text(0.5,0,'Generation')
>>> plt.ylabel('L2 Loss')
Text(0,0.5,'L2 Loss')

>>> plt.show()

用tensorflow实现线性回归算法相关推荐

  1. TensorFlow——实现线性回归算法

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt#使用numpy生成200个随机点 x_data=n ...

  2. TF之LiR:基于tensorflow实现机器学习之线性回归算法

    TF之LiR:基于tensorflow实现机器学习之线性回归算法 目录 输出结果 代码设计 输出结果 代码设计 # -*- coding: utf-8 -*-#TF之LiR:基于tensorflow实 ...

  3. Pytorch和Tensorflow在10000*1000数据规模线性回归算法中的运算速度对比

    Pytorch和Tensorflow在10000*1000数据规模线性回归算法中的运算速度对比 因为在学习人工智能相关知识,于是将学习过程与程序放在这里,希望对大家有帮助,共同学习,共同进步(不喜勿喷 ...

  4. Tensorflow实现线性回归

    Tensorflow实现线性回归 线性回归理论以及公式: 目标公式: y=w1x1+w2x2+⋯+wnxn+by=w1x1+w2x2+⋯+wnxn+b y=w_1x_1+w_2x_2+\cdots+w ...

  5. 利用TensorFlow解决线性回归问题

    利用TensorFlow解决线性回归问题 1.导入必要的库 import tensorflow as tf 在之前的基础上,还需要导入TensorFlow的库. 2.创建一个训练函数 def trai ...

  6. 线性回归算法原理及实现

    我们之前介绍了几种机器学习算法,这些机器学习算法都是用来进行分类的.今天换换口味,我们来了解一下如何进行回归,回归是基于已有的数据对新的数据进行预测,比如预测产品销量. 我们来看看最简单的线性回归,基 ...

  7. TensorFlow反向传播算法实现

    TensorFlow反向传播算法实现 反向传播(BPN)算法是神经网络中研究最多.使用最多的算法之一,用于将输出层中的误差传播到隐藏层的神经元,然后用于更新权重. 学习 BPN 算法可以分成以下两个过 ...

  8. TensorFlow简单线性回归

    TensorFlow简单线性回归 将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.e ...

  9. 【机器学习入门】(8) 线性回归算法:正则化、岭回归、实例应用(房价预测)附python完整代码和数据集

    各位同学好,今天我和大家分享一下python机器学习中线性回归算法的实例应用,并介绍正则化.岭回归方法.在上一篇文章中我介绍了线性回归算法的原理及推导过程:[机器学习](7) 线性回归算法:原理.公式 ...

  10. 【机器学习入门】(7) 线性回归算法:原理、公式推导、损失函数、似然函数、梯度下降

    各位同学好,今天我和大家分享一下python机器学习中的线性回归算法.内容有: (1) 线性回归方程.(2) 损失函数推导.(3) 似然函数.(4) 三种梯度下降方法 1. 概念简述 线性回归是通过一 ...

最新文章

  1. 探索Java日志的奥秘:底层日志系统-log4j2
  2. 修改mysql编码方式
  3. strace监视系统调用
  4. MTK 驱动(66)---Android recovery UI实现分析
  5. 作为程序员,错过这次和以太坊V神的面基,你可能会痛失1个亿!
  6. 1.5封装数组之改进为泛型数组
  7. 设计鲁棒性的方法:输入一个链表的头结点,逆序遍历打印该链表出来
  8. 黑群晖drive套件的使用教程
  9. android 支付宝接口开发,android 实现支付宝wap接口编程
  10. Power BI 与企业数据安全
  11. 彻底了解HTTP模块
  12. python实现KD树
  13. STM32 HAL库学习笔记3-HAL库外设驱动框架概述
  14. 优秀后端架构师必会知识:史上最全MySQL大表优化方案总结
  15. 冯·诺依曼,天才中的天才
  16. 股票交易接口申请方式有哪几种?
  17. 实战:借助ucloud镜像加速功能下载镜像(亲测成功)-2022.1.1
  18. HTML5期末大作业:个人网页设计——薛之谦6页(代码质量好) 学生DW网页设计作业源码 web课程设计网页规划与设计
  19. 下载json文件,解决浏览器对JSON文件链接直接打开问题
  20. Android Studio 连接不上华为手机

热门文章

  1. asp.net页面去调用通过SSL加密的webservice报错
  2. 浅谈Event Loop
  3. Paper Read: Robust Deep Multi-modal Learning Based on Gated Information Fusion Network
  4. mysql Load Data InFile 的用法
  5. 大连理工大学计算机组织与结构实验,大连理工大学计算机系统结构实验-实验四.doc...
  6. 7-4 是不是顺子 (10 分)
  7. 字符串替换(NYOJ)
  8. mysql join 主表唯一_mysql left join 右表数据不唯一的情况解决方法
  9. idea 新建的java项目没发run_IntelliJ IDEA 如何创建一个普通的java项目,及创建java文件并运行...
  10. python中元组和列表的区别_Python 序列:列表、元组