深度学习导论(2)深度学习案例:回归问题
深度学习导论(2)深度学习案例:回归问题
- 问题分析
- 优化方法
- 代码
- 采样数据
- 计算误差
- 计算梯度
- 梯度更新
- main函数
- 结果输出
这篇文章将介绍深度学习的小案例:回归问题的问题分析、优化以及实现代码。
问题分析
如果只采样两个点则会存在较大偏差(如蓝色线),为减小估计偏差,可通过采样多组样本点:
{(x(1),y(1)),(x(2),y(2)),...,(x(n),y(n))}\{(x^{(1)},y^{(1)}),(x^{(2)},y^{(2)}),...,(x^{(n)},y^{(n)})\}{(x(1),y(1)),(x(2),y(2)),...,(x(n),y(n))}
然后找出一条“最好”的直线,使得它尽可能地让所有采样点到该直线的误差(Error,或损失Loss)之和最小。
如:求出当前模型的所有采样点上的预测值wx(i)+bwx^{(i)}+bwx(i)+b与真实值y(i)y^{(i)}y(i)之间差的平方和作为总误差L:
L=1n∑i=1n(wx(i)+b−y(i))2L=\frac{1}{n}\sum_{i=1}^{n}(wx^{(i)}+b-y^{(i)})^{2}L=n1i=1∑n(wx(i)+b−y(i))2
即:均方差误差(Mean Squared Error,MSE)
优化方法
- 最简单的方法:暴力搜索,随机试验
- 常用方法:梯度下降方法(Gradient Descent)
(a)函数导数为0的点即为f(x)的驻点(函数取得极大、极小值时对应的自变量点)
函数$f(x,y)=-(cos^{2}x+cos^{2}y)^2$及其梯度
(b)函数的梯度(Gradient)定义为对各个自变量的偏导数(Partial Derivative)组成的向量
图中xy平面的红色箭头的长度表示梯度向量的模,箭头的方向表示梯度向量的方向。可以看到,箭头的方向总是指向当前位置函数值增速最大的方向,函数曲面越陡峭,箭头的长度也就越长,梯度的模也越大。
函数在各处的梯度方向▽f总是指向函数值增大的方向,那么梯度的反方向-▽f应指向函数值减少的方向。
按照:
x′=x−η⋅∇fx'=x-\eta\cdot\nabla{f}x′=x−η⋅∇f
迭代更新x,就能获得越来越小的函数值。
代码
采样数据
import numpy as np# 采样数据data = [] # 保存样本集的列表for i in range(100): # 循环采样100个点x = np.random.uniform(-10., 10.) # 随机采样输入xeps = np.random.normal(0., 0.1) # 采样高斯噪声y = 1.477 * x + 0.089 +eps # 得到模型的输出data.append([x, y]) # 保存样本点data = np.array(data) # 转换为2D Numpy数组
计算误差
# 计算误差def mse(b, w, points):# 根据当前的 w,b参数计算均方差损失totalError = 0for i in range(0, len(points)): # 循环迭代所有点x = points[i, 0] # 获得 i号点的输入 xy = points[i, 1] # 获得 i号点的输出 y# 计算差的平方,并累加totalError += (y - (w * x + b)) ** 2# 将累加的误差求平均,得到均方差return totalError / float(len(points))
计算梯度
# 计算梯度
# b_current:当前 b的值
# w_current:当前 w的值
# points:样本点集合
# lr:学习率def step_gradient(b_current, w_current, points, lr):# 计算误差函数在所有点上的导数,并更新 w,bb_gradient = 0w_gradient = 0M = float(len(points)) # 总样本数for i in range(0, len(points)):x = points[i, 0]y = points[i, 1]# 误差函数对 b的导数:grand_b = 2(wx+b-y)b_gradient += (2/M) * ((w_current * x +b_current) - y)# 误差函数对 w的导数:grand_w = 2(wx+b-y)*xw_gradient += (2/M) * x * ((w_current * x +b_current) - y)# 根据梯度下降算法更新 w,b,其中lr为学习率new_b = b_current - (lr * b_gradient)new_w = w_current - (lr * w_gradient)return [new_b, new_w]
梯度更新
# 梯度更新def gradient_descent(points, starting_b, starting_w, lr, num_iterations):# 循环更新 w,b多次b = starting_b # b的初始值w = starting_w # w的初始值# 根据梯度下降算法更新多次for step in range(num_iterations):# 计算梯度并更新一次b, w = step_gradient(b, w, np.array(points), lr)loss = mse(b, w, points) # 计算当前的均方差,用于监控训练进度if step%50 == 0: # 打印误差和实时的 w,b值print(f"iteration:{step}, loss:{loss}, w:{w}, b:{b}")return [b, w] # 返回最后一次的 w,b
main函数
# main函数
if __name__ == '__main__':# 加载训练集数据,这些数据是通过真实模型添加观测误差采样得到的lr = 0.01 # 学习率initial_b = 0 # 初始化 b为 0initial_w = 0 # 初始化 w为 0num_iterations = 1000# 训练优化1000次,返回最优 w*,b*和训练 Loss的下降过程[b, w] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)losses = gradient_descent(data, initial_b, initial_w, lr, num_iterations)loss = mse(b, w, data) # 计算最优数值解 w,b上的均方差print(f'Final loss:{loss}, w:{w}, b:{b}')
结果输出
结果输出如下图所示:
在上述迭代过程中,迭代次数与均方差误差MSE之间的关系如下图:
由上图可以看出,虽然我们迭代了1000次,但是在100轮左右就已经收敛了,设在第n轮收敛,则我们只需要记录第n轮时的w和b的值就行,就是我们要求的w和b的值。有了w和b的值后,模型 y=wx+by=wx+by=wx+b就有了。
深度学习导论(2)深度学习案例:回归问题相关推荐
- 统计学习导论_统计学习导论 | 读书笔记15 | 广义可加模型
ISLR 7.7 广义可加模型 要点: 0.广义可加模型介绍 1.用于回归问题的GAM -- 多元线性回归的推广 2.用于分类问题的GAM -- 逻辑回归的推广 3.GAM的优点与不足 0. Gene ...
- 统计学习导论_统计学习导论 | 读书笔记11 | 多项式回归和阶梯函数
ISLR(7)- 非线性回归分析 多项式回归和阶梯函数 Note Summary: 0.从理想的线性到现实的非线性 1.多项式回归 2.Step Function 3.参考 0. Moving Bey ...
- 《强化学习导论》经典课程10讲,DeepMind大神David Silver主讲
点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要5分钟 Follow小博主,每天更新前沿干货 这个经典的10部分课程,由强化学习(RL)的驱David Silver教授,虽然录制于2015年 ...
- 增强学习导论 中文版
增强学习导论 中文版 个人学习,进行翻译,放到GitHub与各位交流. https://github.com/holazzer/RL-book-zh-cn https://holazzer.githu ...
- 免费教材丨第56期:《深度学习导论及案例分析》、《谷歌黑板报-数学之美》
小编说 离春节更近了! 本期教材 本期为大家发放的教材为:<深度学习导论及案例分析>.<谷歌黑板报-数学之美>两本书,大家可以根据自己的需要阅读哦! < ...
- 《深度学习导论及案例分析》一2.11概率图模型的推理
本节书摘来自华章出版社<深度学习导论及案例分析>一书中的第2章,第2.11节,作者李玉鑑 张婷,更多章节内容可以访问云栖社区"华章计算机"公众号查看. 2.11概率图模 ...
- 深度学习导论(6)误差计算
深度学习导论(6)误差计算 一. 再谈误差计算 二. 神经网络类型 三. 模型的容量和泛化能力 四. 过拟合与欠拟合 1. 过拟合(Overfitting) 2. 欠拟合(Underfitting) ...
- 深度学习导论(4)神经网络基础
深度学习导论(4)神经网络基础 一. 训练深度学习模型的步骤 二. 线性层(或叫全链接层)(Linear layer(dense or fully connected layers)) 1. 定义一个 ...
- 深度学习导论(3)PyTorch基础
深度学习导论(3)PyTorch基础 一. Tensor-Pytorch基础数据结构 二. Tensor索引及操作 1. Tensor索引类型 2. Tensor基础操作 3. 数值类型 4. 数值类 ...
最新文章
- 《统一沟通-微软-技巧》-20-Lync 2010如何在我的联系人列表中添加非联盟联系人...
- golang 获取昨天日期
- 皮一皮:这个单人可玩推理真是太好玩了...
- 直方图中最大的矩形(遍历与单调栈)
- Wireshark介绍 与 过滤器表达式语法
- Java使用独立数据库连接池(DBCP为例)
- Android 系统(226)---Android 阿拉伯语适配
- java socket 抓包_linux下用socket的抓包程序
- oracle中命令,oracle中常用命令汇总(一)
- 解决webpack4版本在打包时候出现Cannot read property ‘bindings‘ of null 或 Cannot find module ‘@babel/core‘问题
- 每天一算法(一)——用链表实现加减乘运算
- 报告发现最新版Java存在一个安全漏洞
- SVN下载安装及入门使用教程,详细到不能再详细了
- 亚马逊利润_大流行给亚马逊带来了创纪录的利润
- 问卷调查的数据分析怎么做
- 小米无线路由器服务器用户名和密码忘了,小米路由器无线密码(wifi密码)忘记了怎么办? | 192路由网...
- [ 2204阅读 ] 题型专项 - 句子简化题
- 布局福建市场,维也纳酒店欧暇·地中海酒店能否为投资人带来信心与底气?
- Android 监听wifi总结
- ESP32 HTTP 使用入门
热门文章
- linux内核驱动开发 培训,嵌入式Linux驱动开发培训 - 华清远见教育集团官网
- 槽函数会被执行多次的问题原因及解决方法
- thinkphp四种url访问方式详解
- php中var_dump()函数
- Mac下Nginx、PHP、MySQL 和 PHP-fpm安装配置
- tp5.0行为的用法,可以存入json数据,方便读取数据。
- Eclipse使用添加tomcat后,默认部署目录不是tomcat/webapps,修改方法如下
- poj 1995 Raising Modulo Numbers 二分快速幂
- JUnit简单使用教程
- [游泳] Sun Yang 1500 Swimming Stroke Analysis London 2012