整体学习目标

  • 建立属于你自己的深度学习框架
  • Python创建线性回归模型,L1损失函数,L2损失函数
  • 参数初始化
  • 掌握梯度下降算法,创建优化器函数
  • 学会设置学习率以避免梯度爆炸
  • 掌握多个常用激活函数,Sigmoid, Relu,Tanh,Leaky_Relu,避免梯度消失
  • 掌握链式法则,计算图,拓扑,前馈/反向网络

创建线性回归模型

概念:首先该模型主要解决的情况是:你有一堆线性数据,你需要根据已知的样本数据,去拟合出一个模型或者说一条线,这样当你有新的数据点的时候,你就可以根据之前拟合出的模型也就是线,来进行预测,比如根据房屋面积来预测房价。

解析:

  1. 你要寻找到一条线,你就需要知道这条线的斜率k和截距b,这样你才能画出这条线
  2. 如何找到或者说怎么算好的斜率k和截距b(也就是我们常说的参数),你需要损失函数,对应的损失函数越小,证明参数越好
  3. 损失函数可以选则,L2-loss和L1-loss,一会儿后面会解释什么是l1和l2
  4. 然后你还需要梯度下降法来更新参数即可
  5. 如下图,左图为l1-loss,右图为l2-loss

      

整体流程:

随机初始化参数,斜率k和截距b,然后通过y=kx+b这一公式,将样本点(x,y)中的x带入,然后会得到一个预测值y_predict,通过损失函数获得损失值,然后再反向梯度求导,获得参数的更新值,即可完成参数一次更新。


代码:

#Linear-regression 线性回归代码#加载数据,波士顿-房价预测数据
from sklearn.datasets import load_boston
data = load_boston()
#这是对应的训练数据,X和y
X, y = data['data'], data['target']#将数据以散点图呈现,这里X数据有13个特征,我们只用了第5个特征,room_size,也就是房间面积
%matplotlib inline
import matplotlib.pyplot as plt
plt.scatter(X[:, 5], y)

画出的散点图如下:

#定义损失函数,来衡量参数是否好,这里是l2-loss
def loss(y, y_hat):sum_ = sum([(y_i - y_hat_i) ** 2 for y_i, y_hat_i in zip(y, y_hat)])return sum_ / len(y)#下面是分别对k和b求梯度,这样保证每次,k和b都能朝着loss减小的方向更新
def partial_k(x, y, y_hat):gradient = 0 for x_i, y_i, y_hat_i in zip(list(x), list(y), list(y_hat)):gradient += (y_i - y_hat_i) * x_ireturn -2 / len(y) * gradientdef partial_b(y, y_hat):gradient = 0for y_i, y_hat_i in zip(list(y), list(y_hat)):gradient += (y_i - y_hat_i)return -2 / len(y) * gradient

对应的梯度公式推导如下图:

#最后训练即可#该函数计算房价
def price(x, k, b): # Operation : CNN, RNN, LSTM, Attention 比KX+B更复杂的对应关系return k*x + b#训练次数
trying_times = 50000
#初始loss值
min_cost = float('inf')losses = []scala = 0.3# 参数初始化问题! Weight Initizalition 问题!这也是一个大问题,这里我们先这样
k, b = random.random() * 100 - 200, random.random() * 100 - 200best_k, best_b = None, None#学习率
learning_rate = 1e-3  # Optimizer Ratefor i in range(trying_times):price_by_random_k_and_b = [price(r, k, b) for r in X_rm]cost = loss(list(y), price_by_random_k_and_b)if cost < min_cost: # print('在第{}, k和b更新了'.format(i))min_cost = cost#获得最好的k和bbest_k, best_b = k, blosses.append((i, min_cost))#获得参数需要更新的梯度k_gradient = partial_k(X_rm, y, price_by_random_k_and_b) # 变化的方向b_gradient = partial_b(y, price_by_random_k_and_b)#更新参数k = k + (-1 * k_gradient) * learning_rate## 优化器: Optimizer 这块也是一个研究方向## Adam 动量 momentumb = b + (-1 * b_gradient) * learning_rate

总结

至此,我们就完成了线性回归模型,大家感兴趣可以去尝试将loss函数修改成l1-loss并进行调试,降低loss,欢迎大家随时交流后,后面会陆续把这个部分更完,此次内容来自开课吧的训练营。


如何从0-1构建自己的”pytorch“(自己专属的深度学习框架)——part01相关推荐

  1. 如何从0-1构建自己的”pytorch“(自己专属的深度学习框架)——part02

    今日份学习目标 掌握激活函数 学习激活函数的意义 激活函数和线性变化之间能产生的作用 链式求导 反向传播 传播的顺序由拓扑排序来决定 拓扑排序的原理和实现过程 激活函数 世界中的很多真实关系都不是简单 ...

  2. Keras vs PyTorch:谁是第一深度学习框架?

    「第一个深度学习框架该怎么选」对于初学者而言一直是个头疼的问题.本文中,来自 deepsense.ai 的研究员给出了他们在高级框架上的答案.在 Keras 与 PyTorch 的对比中,作者还给出了 ...

  3. 从TensorFlow到PyTorch:九大深度学习框架哪款最适合你?

    人工智能AI与大数据技术实战  公众号: weic2c 开源的深度学习神经网络正步入成熟,而现在有许多框架具备为个性化方案提供先进的机器学习和人工智能的能力.那么如何决定哪个开源框架最适合你呢?本文试 ...

  4. Keras与PyTorch全方位比较 哪一个深度学习框架更适合初学者?

    Keras或PyTorch作为您的第一个深度学习框架 你想学习深度学习吗?无论您是想开始将其应用于您的业务,建立您的下一个项目,还是仅仅获得当下热门的技能 – 选择合适的深度学习框架来学习是实现目标的 ...

  5. 除了TensorFlow、PyTorch,还有哪些深度学习框架值得期待?

    分布式技术是深度学习技术的加速器. 同时利用多个工作节点,分布式地.高效地训练出性能优良的神经网络模型,能够显著提高深度学习的训练效率.进一步增大其应用范围. <首席AI架构师--分布式高性能深 ...

  6. 除了 Tensorflow、PyTorch ,还有哪些深度学习框架值得期待?

    分布式技术是深度学习技术的加速器. 同时利用多个工作节点,分布式地.高效地训练出性能优良的神经网络模型,能够显著提高深度学习的训练效率.进一步增大其应用范围. <首席AI架构师--分布式高性能深 ...

  7. 终极之战!TensorFlow与PyTorch谁最适合深度学习

    选自builtin 本文经机器之心授权转载,禁止二次转载 (微信公众号:almosthuman2014) 参与:吴攀.杜伟 谷歌的 Tensorflow 与 Facebook 的 PyTorch 一直 ...

  8. DL:深度学习框架Pytorch、 Tensorflow各种角度对比

    DL:深度学习框架Pytorch. Tensorflow各种角度对比 目录 先看两个框架实现同样功能的代码 1.Pytorch.Tensorflow代码比较 2.Tensorflow(数据即是代码,代 ...

  9. numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践

    <<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗  读完<<深度学习框架PyTorc ...

最新文章

  1. Oracle不能在本地计算机启动,Windows 不能在本地计算机启动 OracleDBConsoleorcl的问题解决方法...
  2. The XOR Largest Pair(01trie模板题)
  3. LeetCode 832 Flipping an Image
  4. python画图哆啦a梦-Python—turtle画图(哆啦A梦)
  5. 并发编程中的GIL锁(全局解释器锁)自己理解的他为啥存在
  6. collection的iterator()方法
  7. Java多线程之实现多线程的三种方法
  8. Nginx负载均衡策略之轮询与加权轮询
  9. ASP.NET Core 中的静态文件
  10. jQuery dataTables四种数据来源[转]-原文地址:http://xqqing79.iteye.com/blog/1219425
  11. 100多个很有用的JavaScript函数以及基础写法大集合
  12. 苹果修复已遭在野利用的 iOS 和 macOS 0day
  13. 计算机怎么看网络密码,怎么查看电脑网络连接密码 - 卡饭网
  14. 【论文阅读】2018-基于深度学习的网络流量分类及异常检测方法研究_王伟
  15. 想留长发没那么难,30个让头发快速生长的秘诀~
  16. 程序员的量化交易之路(22)--Cointrader值货币集合Currencies(10)
  17. 中英离线翻译mac_Instant Translate for Mac-即时翻译Mac版下载 V1.3.0-PC6苹果网
  18. 除尘机器人毕业_【干货】焊接机器人除尘方式
  19. pandas 库简介
  20. Linux简介及常用命令

热门文章

  1. 吴恩达“官宣”荣升准爸爸~
  2. AI聚变:寻找2018最佳人工智能应用案例
  3. 今晚8点开播 | 深度解析知识图谱发展关键阶段技术脉络
  4. 网易有道周枫:AI正带来革命性变化,但在线教育的核心是内容
  5. 听完李厂长和雷布斯在乌镇讲AI段子,我突然理解为什么这两个男人选择在一起了
  6. 系统、应用监控的缜密思路,性能瓶颈的克星
  7. 用不惯VMware?试试这款更轻量级的虚拟机!
  8. MyBatis 框架下 SQL 注入攻击的 3 种方式,真是防不胜防!
  9. 8种方案解决重复提交问题
  10. 最常用的决策树算法!Random Forest、Adaboost、GBDT 算法