线性拟合?叶子的长宽:

# Linear Regression: TensorFlow Way
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
#  y = Sepal Length
#  x = Petal Widthimport matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()# Create graph
sess = tf.Session()# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])# Declare batch size
batch_size = 25# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)# Declare loss function (L2 loss)
loss = tf.reduce_mean(tf.square(y_target - model_output))# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)# Training loop
loss_vec = []
for i in range(100):rand_index = np.random.choice(len(x_vals), size=batch_size)rand_x = np.transpose([x_vals[rand_index]])rand_y = np.transpose([y_vals[rand_index]])sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})loss_vec.append(temp_loss)if (i+1)%25==0:print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))print('Loss = ' + str(temp_loss))# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)# Get best fit line
best_fit = []
for i in x_vals:best_fit.append(slope*i+y_intercept)# Plot the result
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

转载于:https://www.cnblogs.com/bonelee/p/8996380.html

tensorflow 线性回归 iris相关推荐

  1. tensorflow 线性回归

    https://blog.csdn.net/zhangpengzp/article/details/81384154 学习tensorflow,希望与大家共同进步,下面讲述的是如何利用tensorfl ...

  2. tensorflow线性回归--拟合iris花瓣数据

    思路:用线性回归拟合鸢尾花花瓣长度和宽度之间的关系:y = Ax + b,其中 y 时花瓣长度,x是花瓣宽度. 建议有一点 tensorflow 基础再往下看. 下面是代码具体讲解. 先放结果吧 代码 ...

  3. 机器学习算法 09-02 TensorFlow核心概念 TensorFlow基础代码、TensorFlow线性回归解析解和BGD求法

    目录 1 核心概念 2 代码流程 3 基础代码: 3.1 tf的版本 定义常量 理解tensor 了解session 3. 2   指定设备.  Variable 初始化 .  with块创建sess ...

  4. [tensorflow] 线性回归模型实现

    在这一篇博客中大概讲一下用tensorflow如何实现一个简单的线性回归模型,其中就可能涉及到一些tensorflow的基本概念和操作,然后因为我只是入门了点tensorflow,所以我只能对部分代码 ...

  5. tensorflow实现iris分类

    内容来自MOOC<人工智能实践:Tensorflow笔记2> 八股搭建网络:我觉得是指按照固定模式搭建神经网络,八股只是呆板的意思,并不是按照八个步骤搭建.我们抨击的应该是八股取士制度,而 ...

  6. python 网页樱花动态图_python,tensorflow线性回归Django网页显示Gif动态图

    1.工程组成 2.urls.py """Django_machine_learning_linear_regression URL Configuration The ` ...

  7. python网页动图_python,tensorflow线性回归Django网页显示Gif动态图

    1.工程组成 2.urls.py """Django_machine_learning_linear_regression URL Configuration The ` ...

  8. tensorflow线性回归基础函数

    以下先使用tensorflow 的矩阵乘积,注意不是内积,然后使用基础函数求平方 .平方和.均值,这是使用tensorflow 使用线性回归分析的基础,基础好了,才能走出下一步,要不然怎么数据分析呢, ...

  9. tensorflow Elastic Net回归,拟合 iris 数据

    引言: 之前写过一篇 tensorflow线性回归–拟合iris花瓣数据.今天的Elastic Net Regression和它差不多,只不过是损失函数变了一下.我对Elastic Net 回归的理解 ...

最新文章

  1. 【廖雪峰python进阶笔记】类的继承
  2. oracle找到引起账户锁定的ip,Oracle 找到引起账户锁定的IP
  3. Python打包PyPI上传实践
  4. gbdt 算法比随机森林容易_用Python实现随机森林算法
  5. OpenGL Tessellated Triangle镶嵌三角形的实例
  6. python下常用OpenCV代码
  7. gulp如何保存后自动刷新?看这里就够了
  8. hbase 学习(十二)非mapreduce生成Hfile,然后导入hbase当中
  9. windows存储空间清理,C盘空间清理教程,磁盘清理方法
  10. 【python】urlencode、quote、unquote
  11. WMS入库作业_核心业务流程
  12. GBase 8a 集群维护工具C3介绍
  13. Eclipse的乱码问题是如何解决的
  14. 【 CF1186D,E,F】Vus the Cossack and Numbers/Vus the Cossack and a Field/Vus the Cossack and a Graph
  15. kali 设置中文并安装输入法
  16. web课题(仿百度+个人所得税计算)
  17. 游览慕田峪长城、红螺寺
  18. C语言小项目Conway‘s_Game_of_Life
  19. 【BZOJ2246】[SDOI2011]迷宫探险(搜索,动态规划)
  20. 算法设计与分析-----贪心法

热门文章

  1. 信息与计算机科学好学吗,计算机科学与技术好学吗?
  2. 八、H.264中的熵编码基本方法、指数哥伦布编码
  3. 计算机机械应用,浅析计算机技术在机械自动化的应用(原稿)
  4. 二叉搜索树的第k个节点java_剑指Offer62:二叉搜索树的第k个结点(Java)
  5. linux hrtimer 绑定cpu,Linux hrtimer分析--未配置高精度模式
  6. 再见SpringMVC!linuxkafka安装单机集群
  7. TCP的三次握手、四次挥手,含泪整理面经
  8. python【数据结构与算法】B树概念解析和实现
  9. 【响应式Web前端设计】CSS3伪类与伪元素的区别
  10. 【Deep Learning笔记】感知机模型和学习策略