# Author Qian Chenglong
import tensorflow as tf
import  numpy as np#生成100个随机数据点
x_date=np.random.rand(100)
y_date=x_date*0.1+0.2#构造一个线性模型
k=tf.Variable(0.)
b=tf.Variable(0.)
y=k*x_date+b# 二次代价函数
loss=tf.reduce_mean(tf.square(y-y_date))#最小二乘

my_optimizer=tf.train.GradientDescentOptimizer(0.2)#定义一个使用梯度下降算法的训练器
train=my_optimizer.minimize(loss)#训练目标loss最小

init=tf.global_variables_initializer()#初始化变量

with tf.Session() as sess:sess.run(init)for step in range(201):sess.run(train)if step%20==0:print(step, '[k,b]:', sess.run([k, b]))

 API说明:

np.random.rand(100)生成100个0~1之间的随机数

tf.square():计算元素的平方

tf.reduce_mean(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None)

计算张量的各个维度上的元素的平均值。

axis是tf.reduce_mean函数中的参数,按照函数中axis给定的维度减少input_tensor。除非keep_dims是true,否则张量的秩将在axis的每个条目中减少1。如果keep_dims为true,则缩小的维度将保留为1。 如果axis没有条目,则减少所有维度,并返回具有单个元素的张量。

参数:

  • input_tensor:要减少的张量。应该有数字类型。
  • axis:要减小的尺寸。如果为None(默认),则减少所有维度。必须在[-rank(input_tensor), rank(input_tensor))范围内。
  • keep_dims:如果为true,则保留长度为1的缩小尺寸。
  • name:操作的名称(可选)。
  • reduction_indices:axis的不支持使用的名称。
tf.Variable(initializer, name):initializer是初始化参数,可以有tf.random_normal,tf.constant,tf.constant等,name就是变量的名字,用法如下:
a1 = tf.Variable(tf.random_normal(shape=[2,3], mean=0, stddev=1), name='a1')
a2 = tf.Variable(tf.constant(1), name='a2')
a3 = tf.Variable(tf.ones(shape=[2,3]), name='a3')
!

运行session.run()可以:

  1. 获得你要得到的运算结果;
  2. 你所要运算的部分;
 
#qiancl 666
import tensorflow as tf
import numpy as np
#学习率
learning_rate=0.01
#最大训练步数
max_train_step=1000
#np.array()矩阵
train_X_date=np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[5.654],[9.27],[3.1]],dtype=np.float32)
train_Y_date=np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[2.42],[2.94],[1.3]],dtype=np.float32)
#样本个数
tolal_samples=train_X_date.shape[0]
#输入数据占位
x=tf.placeholder(tf.float32,[None,1])
y_=tf.placeholder(tf.float32,[None,1])
#tf.random_normal([1,1])生成【1,1】的符合正态分布的随机数
w=tf.Variable(tf.random_normal([1,1]),name="weight")
b=tf.Variable(tf.zeros([1]),name="bias")
y=tf.matmul(x,w)+b
loss=tf.reduce_sum(tf.pow(y-y_,2))/tolal_samples#创建优化器
optimizer=tf.train.GradientDescentOptimizer(learning_rate)#训练目标
train_op=optimizer.minimize(loss)#训练
with tf.Session() as sess:sess.run(tf.global_variables_initializer())print("开始训练")for step in range(max_train_step):sess.run(train_op, feed_dict={x: train_X_date, y_: train_Y_date})if step % 100 == 0:c = sess.run(loss, feed_dict={x: train_Y_date, y_: train_Y_date})print("Step:%d, loss==%0.4f, w==%0.4f, b==%0.4f" % (step, c, sess.run(w), sess.run(b)))

转载于:https://www.cnblogs.com/long5683/p/10045957.html

使用Tensoflow实现梯度下降算法的一次线性拟合相关推荐

  1. 深度学习:梯度下降算法改进

    学习目标 目标 了解深度学习遇到的一些问题 知道批梯度下降与MiniBatch梯度下降的区别 知道指数加权平均的意义 知道动量梯度.RMSProp.Adam算法的公式意义 知道学习率衰减方式 知道参数 ...

  2. Udacity机器人软件工程师课程笔记(二十四) - 控制(其二) - PID优化,梯度下降算法,带噪声的PID控制

    7.非理想情况 (1)积分饱和 到目前为止,我们一直使用的"理想"形式的PID控制器很少用于工业中."时间常数"形式更为常见. 当前说明了理想形式的一些重大缺陷 ...

  3. 梯度下降算法_神经网络梯度下降算法

    神经网络梯度下降算法 2018, SEPT 13 梯度下降(Gradient Descent) 是神经网络比较重要的部分,因为我们通常利用梯度来利用Cost function(成本函数) 进行back ...

  4. 机器学习中,梯度下降算法的问题引入

    来源 | 动画讲编程 今天讲解的内容是梯度下降算法. 梯度下降算法在机器学习中的应用十分广泛,该算法的最主要目的是通过迭代的方法找到目标函数的最小值,经常用来解决线性回归和逻辑回归等相关问题.本节课主 ...

  5. 一文清晰讲解机器学习中梯度下降算法(包括其变式算法)

    本篇文章向大家介绍梯度下降(Gradient Descent)这一特殊的优化技术,我们在机器学习中会频繁用到. 前言 无论是要解决现实生活中的难题,还是要创建一款新的软件产品,我们最终的目标都是使其达 ...

  6. 把梯度下降算法变成酷炫游戏,这有一份深度学习通俗讲义

    公众号关注 "视学算法" 设为"星标",第一时间知晓最新干货~ 晓查 发自 凹非寺 转载自量子位 | 公众号 QbitAI 让小球滚下山坡,找到它们分别落在哪个 ...

  7. 简单的梯度下降算法,你真的懂了吗?

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 梯度下降算法的公式非常简单,"沿着梯度的反方向(坡度最陡 ...

  8. 梯度下降算法的正确步骤是什么?

    梯度下降算法的正确步骤是什么? a.用随机值初始化权重和偏差 b.把输入传入网络,得到输出值 c.计算预测值和真实值之间的误差 d.对每一个产生误差的神经元,调整相应的(权重)值以减小误差 e.重复迭 ...

  9. python底层代码里面的参数_梯度下降算法讲解及python底层实现

    梯度下降法思路就是,开始随机选择参数组合,计算代价函数,寻找到下一个能让代价函数下降最快的参数组合(对某一参数的偏导方向),然后不断重复这一过程,直到找到一个局部最小值.因为并没有计算过所有的参数组合 ...

最新文章

  1. 网页瀑布流效果实现的几种方式
  2. Tensorflow2.6更新cuda11.2
  3. 第七章 PX4-Pixhawk-Mavlink解析
  4. mysql数据库replace写入_MySQL数据库replace into 用法(insert into 的增强版)
  5. uni-app(从零开始)
  6. 【转】C# 网络连接中异常断线的处理:ReceiveTimeout, SendTimeout 及 KeepAliveValues(设置心跳)
  7. Google对Gmail的所有通信进行SSL加密
  8. 【EF】Entity Framework Core 2.0 特性介绍和使用指南
  9. 22. 栈的压入、弹出队列(C++版本)
  10. shell 小米system锁adb_小米/红米系列手机解system分区锁方法详解
  11. Python3爬虫之咪咕音乐
  12. Android P WMS addwindow流程
  13. 用excel将有规律的数据随机打乱
  14. tony的js学习笔记--基础知识(随时更新)
  15. 4.图灵学院-----阿里/京东/滴滴/美团整理----高频MQ消息队列篇
  16. 蓝桥杯 2014-2 切面条
  17. java过滤汉字和英文,java判断及过滤汉字
  18. 实现一个简单的压测工具
  19. 团队管理9--新经理角色认知和角色转换
  20. 利用人性弱点的互联网产品(三)虚荣

热门文章

  1. sql server management studio 查询的临时文件路径
  2. Maven 使用代理下载依赖
  3. 为什么不需要对独立的jre进行环境变量配置
  4. 一个简单的synchronized多线程问题、梳理与思考
  5. 查看SQL Server Resource Database以及修改系统表
  6. uchome 模板引擎
  7. WebSphere Application Server中manageprofiles的使用
  8. 鼠标滑过GridView的数据行时修改行的背景颜色
  9. BPM实例分享:如何设置表单字体样式
  10. HTC VIVE 虚拟现实眼镜VR游戏体验