使用Tensoflow实现梯度下降算法的一次线性拟合
# Author Qian Chenglong import tensorflow as tf import numpy as np#生成100个随机数据点 x_date=np.random.rand(100) y_date=x_date*0.1+0.2#构造一个线性模型 k=tf.Variable(0.) b=tf.Variable(0.) y=k*x_date+b# 二次代价函数 loss=tf.reduce_mean(tf.square(y-y_date))#最小二乘 my_optimizer=tf.train.GradientDescentOptimizer(0.2)#定义一个使用梯度下降算法的训练器 train=my_optimizer.minimize(loss)#训练目标loss最小 init=tf.global_variables_initializer()#初始化变量 with tf.Session() as sess:sess.run(init)for step in range(201):sess.run(train)if step%20==0:print(step, '[k,b]:', sess.run([k, b]))
API说明:
np.random.rand(100)生成100个0~1之间的随机数 tf.square():计算元素的平方 tf.reduce_mean(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None)
计算张量的各个维度上的元素的平均值。
axis是tf.reduce_mean函数中的参数,按照函数中axis给定的维度减少input_tensor。除非keep_dims是true,否则张量的秩将在axis的每个条目中减少1。如果keep_dims为true,则缩小的维度将保留为1。 如果axis没有条目,则减少所有维度,并返回具有单个元素的张量。
参数:
- input_tensor:要减少的张量。应该有数字类型。
- axis:要减小的尺寸。如果为None(默认),则减少所有维度。必须在[-rank(input_tensor), rank(input_tensor))范围内。
- keep_dims:如果为true,则保留长度为1的缩小尺寸。
- name:操作的名称(可选)。
- reduction_indices:axis的不支持使用的名称。
tf.Variable(initializer, name):initializer是初始化参数,可以有tf.random_normal,tf.constant,tf.constant等,name就是变量的名字,用法如下:
a1 = tf.Variable(tf.random_normal(shape=[2,3], mean=0, stddev=1), name='a1') a2 = tf.Variable(tf.constant(1), name='a2') a3 = tf.Variable(tf.ones(shape=[2,3]), name='a3') !
运行session.run()可以:
- 获得你要得到的运算结果;
- 你所要运算的部分;
#qiancl 666 import tensorflow as tf import numpy as np #学习率 learning_rate=0.01 #最大训练步数 max_train_step=1000 #np.array()矩阵 train_X_date=np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[5.654],[9.27],[3.1]],dtype=np.float32) train_Y_date=np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[2.42],[2.94],[1.3]],dtype=np.float32) #样本个数 tolal_samples=train_X_date.shape[0] #输入数据占位 x=tf.placeholder(tf.float32,[None,1]) y_=tf.placeholder(tf.float32,[None,1]) #tf.random_normal([1,1])生成【1,1】的符合正态分布的随机数 w=tf.Variable(tf.random_normal([1,1]),name="weight") b=tf.Variable(tf.zeros([1]),name="bias") y=tf.matmul(x,w)+b loss=tf.reduce_sum(tf.pow(y-y_,2))/tolal_samples#创建优化器 optimizer=tf.train.GradientDescentOptimizer(learning_rate)#训练目标 train_op=optimizer.minimize(loss)#训练 with tf.Session() as sess:sess.run(tf.global_variables_initializer())print("开始训练")for step in range(max_train_step):sess.run(train_op, feed_dict={x: train_X_date, y_: train_Y_date})if step % 100 == 0:c = sess.run(loss, feed_dict={x: train_Y_date, y_: train_Y_date})print("Step:%d, loss==%0.4f, w==%0.4f, b==%0.4f" % (step, c, sess.run(w), sess.run(b)))
转载于:https://www.cnblogs.com/long5683/p/10045957.html
使用Tensoflow实现梯度下降算法的一次线性拟合相关推荐
- 深度学习:梯度下降算法改进
学习目标 目标 了解深度学习遇到的一些问题 知道批梯度下降与MiniBatch梯度下降的区别 知道指数加权平均的意义 知道动量梯度.RMSProp.Adam算法的公式意义 知道学习率衰减方式 知道参数 ...
- Udacity机器人软件工程师课程笔记(二十四) - 控制(其二) - PID优化,梯度下降算法,带噪声的PID控制
7.非理想情况 (1)积分饱和 到目前为止,我们一直使用的"理想"形式的PID控制器很少用于工业中."时间常数"形式更为常见. 当前说明了理想形式的一些重大缺陷 ...
- 梯度下降算法_神经网络梯度下降算法
神经网络梯度下降算法 2018, SEPT 13 梯度下降(Gradient Descent) 是神经网络比较重要的部分,因为我们通常利用梯度来利用Cost function(成本函数) 进行back ...
- 机器学习中,梯度下降算法的问题引入
来源 | 动画讲编程 今天讲解的内容是梯度下降算法. 梯度下降算法在机器学习中的应用十分广泛,该算法的最主要目的是通过迭代的方法找到目标函数的最小值,经常用来解决线性回归和逻辑回归等相关问题.本节课主 ...
- 一文清晰讲解机器学习中梯度下降算法(包括其变式算法)
本篇文章向大家介绍梯度下降(Gradient Descent)这一特殊的优化技术,我们在机器学习中会频繁用到. 前言 无论是要解决现实生活中的难题,还是要创建一款新的软件产品,我们最终的目标都是使其达 ...
- 把梯度下降算法变成酷炫游戏,这有一份深度学习通俗讲义
公众号关注 "视学算法" 设为"星标",第一时间知晓最新干货~ 晓查 发自 凹非寺 转载自量子位 | 公众号 QbitAI 让小球滚下山坡,找到它们分别落在哪个 ...
- 简单的梯度下降算法,你真的懂了吗?
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 梯度下降算法的公式非常简单,"沿着梯度的反方向(坡度最陡 ...
- 梯度下降算法的正确步骤是什么?
梯度下降算法的正确步骤是什么? a.用随机值初始化权重和偏差 b.把输入传入网络,得到输出值 c.计算预测值和真实值之间的误差 d.对每一个产生误差的神经元,调整相应的(权重)值以减小误差 e.重复迭 ...
- python底层代码里面的参数_梯度下降算法讲解及python底层实现
梯度下降法思路就是,开始随机选择参数组合,计算代价函数,寻找到下一个能让代价函数下降最快的参数组合(对某一参数的偏导方向),然后不断重复这一过程,直到找到一个局部最小值.因为并没有计算过所有的参数组合 ...
最新文章
- 网页瀑布流效果实现的几种方式
- Tensorflow2.6更新cuda11.2
- 第七章 PX4-Pixhawk-Mavlink解析
- mysql数据库replace写入_MySQL数据库replace into 用法(insert into 的增强版)
- uni-app(从零开始)
- 【转】C# 网络连接中异常断线的处理:ReceiveTimeout, SendTimeout 及 KeepAliveValues(设置心跳)
- Google对Gmail的所有通信进行SSL加密
- 【EF】Entity Framework Core 2.0 特性介绍和使用指南
- 22. 栈的压入、弹出队列(C++版本)
- shell 小米system锁adb_小米/红米系列手机解system分区锁方法详解
- Python3爬虫之咪咕音乐
- Android P WMS addwindow流程
- 用excel将有规律的数据随机打乱
- tony的js学习笔记--基础知识(随时更新)
- 4.图灵学院-----阿里/京东/滴滴/美团整理----高频MQ消息队列篇
- 蓝桥杯 2014-2 切面条
- java过滤汉字和英文,java判断及过滤汉字
- 实现一个简单的压测工具
- 团队管理9--新经理角色认知和角色转换
- 利用人性弱点的互联网产品(三)虚荣