李弘毅机器学习笔记:回归演示
李弘毅机器学习笔记:回归演示
现在假设有10个x_data和y_data,x和y之间的关系是y_data=b+w*x_data。b,w都是参数,是需要学习出来的。现在我们来练习用梯度下降找到b和w。
x_data = [338., 333., 328., 207., 226., 25., 179., 60., 208., 606.]
y_data = [640., 633., 619., 393., 428., 27., 193., 66., 226., 1591.]
x_d = np.asarray(x_data)
y_d = np.asarray(y_data)
先给b和w一个初始值,计算出b和w的偏微分
# linear regression
b = -120
w = -4
lr = 0.0000001
iteration = 100000b_history = [b]
w_history = [w]import time
start = time.time()
for i in range(iteration):b_grad=0.0w_grad=0.0for n in range(len(x_data))b_grad=b_grad-2.0*(y_data[n]-n-w*x_data[n])*1.0w_grad= w_grad-2.0*(y_data[n]-n-w*x_data[n])*x_data[n]# update paramb -= lr * b_gradw -= lr * w_gradb_history.append(b)w_history.append(w)
# plot the figure
plt.subplot(1, 2, 1)
C = plt.contourf(x, y, Z, 50, alpha=0.5, cmap=plt.get_cmap('jet')) # 填充等高线
# plt.clabel(C, inline=True, fontsize=5)
plt.plot([-188.4], [2.67], 'x', ms=12, mew=3, color="orange")
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5, 5)
plt.xlabel(r'$b$')
plt.ylabel(r'$w$')
plt.title("线性回归")plt.subplot(1, 2, 2)
loss = np.asarray(loss_history[2:iteration])
plt.plot(np.arange(2, iteration), loss)
plt.title("损失")
plt.xlabel('step')
plt.ylabel('loss')
plt.show()
输出结果如图
横坐标是b,纵坐标是w,标记×位最优解,显然,在图中我们并没有运行得到最优解,最优解十分的遥远。那么我们就调大learning rate,lr = 0.000001(调大10倍),得到结果如下图。
我们再调大learning rate,lr = 0.00001(调大10倍),得到结果如下图。
结果发现learning rate太大了,结果很不好。
所以我们给b和w特制化两种learning rate
# linear regression
b = -120
w = -4
lr = 1
iteration = 100000b_history = [b]
w_history = [w]lr_b=0
lr_w=0
import time
start = time.time()
for i in range(iteration):b_grad=0.0w_grad=0.0for n in range(len(x_data))b_grad=b_grad-2.0*(y_data[n]-n-w*x_data[n])*1.0w_grad= w_grad-2.0*(y_data[n]-n-w*x_data[n])*x_data[n]lr_b=lr_b+b_grad**2lr_w=lr_w+w_grad**2# update paramb -= lr/np.sqrt(lr_b) * b_gradw -= lr np.sqrt(lr_w) * w_gradb_history.append(b)w_history.append(w)
# plot the figure
plt.subplot(1, 2, 1)
C = plt.contourf(x, y, Z, 50, alpha=0.5, cmap=plt.get_cmap('jet')) # 填充等高线
# plt.clabel(C, inline=True, fontsize=5)
plt.plot([-188.4], [2.67], 'x', ms=12, mew=3, color="orange")
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5, 5)
plt.xlabel(r'$b$')
plt.ylabel(r'$w$')
plt.title("线性回归")plt.subplot(1, 2, 2)
loss = np.asarray(loss_history[2:iteration])
plt.plot(np.arange(2, iteration), loss)
plt.title("损失")
plt.xlabel('step')
plt.ylabel('loss')
plt.show()
有了新的特制化两种learning rate就可以在10w次迭代之内到达最优点了。
李弘毅机器学习笔记:回归演示相关推荐
- 李弘毅机器学习笔记:第二章
李弘毅机器学习笔记:第二章 回归定义和应用例子 回归定义 应用举例 模型步骤 Step 1:模型假设 - 线性模型 一元线性模型(单个特征) 多元线性模型(多个特征) Step 2:模型评估 - 损失 ...
- 李弘毅机器学习笔记:第五章—分类
李弘毅机器学习笔记:第五章-分类 例子(神奇宝贝属性预测) 分类概念 神奇宝贝的属性(水.电.草)预测 回归模型 vs 概率模型 回归模型 其他模型(理想替代品) 概率模型实现原理 盒子抽球概率举例 ...
- 李弘毅机器学习笔记:第六章—Logistic Regression
李弘毅机器学习笔记:第六章-Logistic Regression logistic回归 Step1 逻辑回归的函数集 Step2 定义损失函数 Step3 寻找最好的函数 损失函数:为什么不学线性回 ...
- 李弘毅机器学习笔记:第七章—深度学习的发展趋势
李弘毅机器学习笔记:第七章-深度学习的发展趋势 回顾一下deep learning的历史: 1958: Perceptron (linear model) 1969: Perceptron has l ...
- 李弘毅机器学习笔记:第十二章—Recipe of Deep Learning
李弘毅机器学习笔记:第十二章-Recipe of Deep Learning 神经网络的表现 如何改进神经网络? 新的激活函数 梯度消失 怎么样去解决梯度消失? Adaptive Learning R ...
- [机器学习入门] 李弘毅机器学习笔记-7 (Brief Introduction of Deep Learning;深度学习简介)
[机器学习入门] 李弘毅机器学习笔记-7 (Brief Introduction of Deep Learning:深度学习简介) PDF VIDEO Ups and downs of Deep Le ...
- 李弘毅机器学习笔记:第十三章—CNN
李弘毅机器学习笔记:第十三章-CNN 为什么用CNN Small region Same Patterns Subsampling CNN架构 Convolution Propetry1 Propet ...
- 李弘毅机器学习笔记:第八章—Backprogation
李弘毅机器学习笔记:第八章-Backprogation 背景 梯度下降 链式法则 反向传播 取出一个Neuron进行分析 Forward Pass Backward Pass case 1 : Out ...
- 李弘毅机器学习笔记:第十六章—无监督学习
李弘毅机器学习笔记:第十六章-无监督学习 1-of-N Encoding 词嵌入 基于计数的词嵌入 基于预测的词嵌入 具体步骤 共享参数 训练 Various Architectures 多语言嵌入 ...
最新文章
- Java调用Oracle存储Package
- 解决Tomcat文件上传超时问题.
- 关于报表在移动端展现需你需要知道哪些?
- google js cdn_「效率工具」模拟CDN的浏览器扩展程序,改善在线隐私
- Flask笔记-通过Model访问数据库
- 【AI视野·今日Robot 机器人论文速览 第十二期】Tue, 22 Jun 2021
- iOS如何退出测试软件,如何继续测试iOS应用程序,使用UIAutomation仪器,甚至应用程序退出后?(How to c...
- 计算机云客户端技术指标,云服务器技术指标
- cshop是什么开发语言_C语言是用什么语言编写出来的?
- RocketMQ源码解析:Message存储
- HX710_24位电子秤AD采集
- Ph0thon字符串
- Unity渲染(四):Shader着色器基础入门之获取当前屏幕贴图
- 家用投影机预埋布线图_家庭影院装修如何布线(装修前必看·附图)
- japanhr日语小工具 日文汉字转平假名-japankana
- 《Java语言程序设计与数据结构》编程练习答案(第三章)(三)
- 用c语言做一个五子棋程序,C语言制作简单五子棋游戏
- 网盘搜索神器php源码,127网盘搜索源码|网盘资源搜索神器|thinkphp3.1.3框架开发的...
- opencv将Mat读入的图像的像素值打印在控制台上
- 《20岁无资本无未来》