线性拟合?叶子的长宽:

# Linear Regression: TensorFlow Way
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
#  y = Sepal Length
#  x = Petal Widthimport matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()# Create graph
sess = tf.Session()# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])# Declare batch size
batch_size = 25# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)# Declare loss function (L2 loss)
loss = tf.reduce_mean(tf.square(y_target - model_output))# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)# Training loop
loss_vec = []
for i in range(100):rand_index = np.random.choice(len(x_vals), size=batch_size)rand_x = np.transpose([x_vals[rand_index]])rand_y = np.transpose([y_vals[rand_index]])sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})loss_vec.append(temp_loss)if (i+1)%25==0:print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))print('Loss = ' + str(temp_loss))# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)# Get best fit line
best_fit = []
for i in x_vals:best_fit.append(slope*i+y_intercept)# Plot the result
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

转载于:https://www.cnblogs.com/bonelee/p/8996380.html

tensorflow 线性回归 iris相关推荐

  1. tensorflow 线性回归

    https://blog.csdn.net/zhangpengzp/article/details/81384154 学习tensorflow,希望与大家共同进步,下面讲述的是如何利用tensorfl ...

  2. tensorflow线性回归--拟合iris花瓣数据

    思路:用线性回归拟合鸢尾花花瓣长度和宽度之间的关系:y = Ax + b,其中 y 时花瓣长度,x是花瓣宽度. 建议有一点 tensorflow 基础再往下看. 下面是代码具体讲解. 先放结果吧 代码 ...

  3. 机器学习算法 09-02 TensorFlow核心概念 TensorFlow基础代码、TensorFlow线性回归解析解和BGD求法

    目录 1 核心概念 2 代码流程 3 基础代码: 3.1 tf的版本 定义常量 理解tensor 了解session 3. 2   指定设备.  Variable 初始化 .  with块创建sess ...

  4. [tensorflow] 线性回归模型实现

    在这一篇博客中大概讲一下用tensorflow如何实现一个简单的线性回归模型,其中就可能涉及到一些tensorflow的基本概念和操作,然后因为我只是入门了点tensorflow,所以我只能对部分代码 ...

  5. tensorflow实现iris分类

    内容来自MOOC<人工智能实践:Tensorflow笔记2> 八股搭建网络:我觉得是指按照固定模式搭建神经网络,八股只是呆板的意思,并不是按照八个步骤搭建.我们抨击的应该是八股取士制度,而 ...

  6. python 网页樱花动态图_python,tensorflow线性回归Django网页显示Gif动态图

    1.工程组成 2.urls.py """Django_machine_learning_linear_regression URL Configuration The ` ...

  7. python网页动图_python,tensorflow线性回归Django网页显示Gif动态图

    1.工程组成 2.urls.py """Django_machine_learning_linear_regression URL Configuration The ` ...

  8. tensorflow线性回归基础函数

    以下先使用tensorflow 的矩阵乘积,注意不是内积,然后使用基础函数求平方 .平方和.均值,这是使用tensorflow 使用线性回归分析的基础,基础好了,才能走出下一步,要不然怎么数据分析呢, ...

  9. tensorflow Elastic Net回归,拟合 iris 数据

    引言: 之前写过一篇 tensorflow线性回归–拟合iris花瓣数据.今天的Elastic Net Regression和它差不多,只不过是损失函数变了一下.我对Elastic Net 回归的理解 ...

最新文章

  1. 必读:ICLR 2020 的50篇推荐阅读论文
  2. 精通python要多久-小白到精通python要多久
  3. 详细分析 apache httpd 反向代理的用法
  4. 【thymeleaf】data-*
  5. float排版c语言,如何解决因float带来的排版问题?
  6. [置顶] mkdir函数-linux
  7. 动态创建 Web 服务器控件模板
  8. 混血网站诞生-公司相互嫁接成就新商业模式(转贴)
  9. Python机器学习(基础篇---监督学习(k近邻))
  10. 拓端tecdat|MATLAB用深度学习长短期记忆 (LSTM) 神经网络对智能手机传感器时间序列数据进行分类
  11. HDU 3695 / POJ 3987 Computer Virus on Planet Pandora
  12. x11 matlab仿真,基于MATLABSimulink的弹道仿真方法.pdf
  13. 从一个PHP数据生成 CSV 文件
  14. 微信小程序也可以实现定位打卡/签到打卡了(附源码)
  15. 解决Mac无法识别移动硬盘以及无法识别BootCamp Windows分区的问题
  16. Java代理模式概述及应用场景
  17. 【mysql】复制一张表的数据到另一张表
  18. iMovie 裁剪视频
  19. 2022-2027(新版)全球与中国鱼藤酮行业发展动态及前景展望报告
  20. 网页链接分享到微信里的海报制作

热门文章

  1. 和linux关系_Linux内核Page Cache和Buffer Cache关系及演化历史
  2. idea提交git差件_多人合作使用git,推送代码、和并分支
  3. python判断括号有效,在Python中检查括号是否平衡
  4. php mysql集群_PHP如何访问数据库集群
  5. maven 如何看jar是否被修改_如何在线修改jar文件
  6. linux下,每次git pull 或者git push都需要输入账号密码的问题以及git remote 的一些基本操作
  7. 十分钟就能回顾Spring常问的知识点,带你突击面试没问题!
  8. 【深度学习入门到精通系列】图片OCR讲解
  9. Android移动开发之【Android实战项目】Recyclerview添加花色分割线
  10. 并查集详解(从引入到代码)