线性回归从0实现

代码的实现需要这么几个过程。

  • 数据
  • 读入随机打乱的数据,然后要分epoch。
  • 定义我们的模型,损失函数,优化算法
  • 定义好超参数
  • 开始 for epoch …这个过程需要根据超参数,predict pre_label,然后计算出损失的反向传播,根据优化算法去更新参数。
  • 最后记得打印每次的loss,acc,auc等参数。

我们读取 ⼀小批量训练样本,并通过我们的模型来获得⼀组预测。 计算完损失后,我们开始反向传播,存储每个参数 的梯度。最后, 我们调⽤优化算法sgd来更新模型参数

构造数据

#随机生成一批数据 。y = Xw + b + ϵ. ϵ代表噪音值。
true_b = 4.2
true_w = torch.tensor([2,-3.4])
num_examples = 1000
def synthetic_data(w,b,num_examples):X = torch.normal(0,1,(num_examples,len(w)))y = torch.matmul(X,w)+by+= torch.normal(0,0.01,y.shape)return X,y
features, labels = synthetic_data(true_w,true_b,num_examples)
labels.shape
#torch.Size([1000])

看一下构造的数据。

# 保证图片在浏览器内正常显示
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure(figsize=(10,10))
plt.scatter(features[:,0],labels,marker='o',c='b')
plt.scatter(features[:,1],labels,marker='^',c='r')
plt.show()

定义数据随机读取函数:

特别注意yield字段。:yield 的函数在 Python 中被称之为 generator(生成器)

def data_iter(batch_size, features, labels):num_examples = len(features) indices = list(range(num_examples)) # 这些样本是随机读取的,没有特定的顺序 random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])#print(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]
import random
batch_size = 10for X, y in data_iter(batch_size, features, labels):print(X, '\n', y) break

初始化超参数

batch_size = 10
lr = 0.03
num_epochs = 3
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

定义模型

def linreg(X,w,b):"""线性回归模型"""return torch.matmul(X,w) + b

定义损失函数

def squared_loss(y_hat, y):#@save"""均⽅损失。"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

定义优化算法

def sgd(params,lr,batch_size):#不track 梯度with torch.no_grad():for param in params:param -= lr*param.grad / batch_sizeparam.grad.zero_()

train

net = linreg
loss = squared_loss
for epoch in range(num_epochs):for X, y in data_iter(batch_size,features,labels):l = loss(net(X,w,b),y)l.sum().backward()sgd([w,b],lr,batch_size)# 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features,w,b),labels)print(f'epoch {epoch + 1},loss {float(train_l.mean()):f}')

test

总结

pytorch backward
若在 torch 中 对定义的变量 requires_grad 的属性赋为 True ,那么此变量即可进行梯度以及导数的求解。
当c.backward() 语句执行后,会自动对 c 表达式中 的可求导变量进行方向导数的求解,并将对每个变量的导数表达存储到 变量名.grad 中

loss.sum().backward()中对于sum()的理解
一个向量是不进行backward操作的,而sum()后,由于梯度为1,所以对结果不产生影响。反向传播算法一定要是一个标量才能进行计算。
torch.no_grad()
是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度

参考

线性回归的从零开始实现
pytorch backward() 的一点简单的理解
loss.sum().backward()中对于sum()的理解
with torch.no_grad() 详解
Python yield 使用浅析

沐神的 《动手学深度学习》 课程中的 3.2节 线性回归的从零实现相关推荐

  1. 沐神《动手学深度学习》使用笔记

    1.引言 沐神提到自己看的三本书: <算法导论> <模式识别和机器学习>即PRML <统计学习基础> 一个教训:计算机科学是动手的学科,没有足够的动手能力难以取得很 ...

  2. Colab运行沐神《动手学深度学习》:ImportError: cannot import name ‘_check_savefig_extra_args‘ from ‘matplotlib.back

    原语句: num_epochs = 10 train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater) 报错: I ...

  3. 沐神《动手学深度实战Kaggle比赛:狗的品种识别(ImageNet Dogs)

    沐神<动手学深度学习>飞桨版课程公开啦! hello各位飞桨的开发者,大家好!李沐老师的<动手学深度学习>飞桨版课程已经公开啦.本课程由PPSIG和飞桨工程师共同建设,将原书中 ...

  4. 李沐d2l《动手学深度学习》第二版——风格迁移源码详解

    本文是对李沐Dive to DL<动手学深度学习>第二版13.12节风格迁移的源码详解,整体由Jupyter+VSCode完成,几乎所有重要代码均给出了注释,一看就懂.需要的同学可以在文末 ...

  5. 动手学深度学习课程笔记ch02

    ch_02 线性代数 线性代数李老师讲得比较少,需要自己下去多看看书,后期还是需要一些矩阵论的知识. 基本知识 标量:由只有一个元素的张量表示(一般为数据的标签). # 创建标量进行运算 import ...

  6. 李沐老师《动手学深度学习》课程总结1

    数据操作 1. 创建数组:形状.数据类型.数据值 2. 访问元素:[1, 2] 访问第二行第三列 [1, :] 访问第二行 [1:3, 1:] 子区域:第二行至第四行前一行到第二列之后所有 [::3, ...

  7. 【Dive into Deep Learning / 动手学深度学习】第二章 - 第六节:概率

    目录 前言 2.6. 概率 2.6.1. 模拟扔骰子 2.6.2. 处理多个随机变量 2.6.2.1. 联合概率 2.6.2.2. 条件概率 2.6.2.3. 贝叶斯定理 2.6.2.4. 边际化 2 ...

  8. 动手学深度学习(十四)——权重衰退

    文章目录 1. 如何缓解过拟合? 2. 如何衡量模型的复杂度? 3. 通过限制参数的选择范围来控制模型容量(复杂度) 4. 正则化如何让权重衰退? 5. 可视化地看看正则化是如何利用权重衰退来达到缓解 ...

  9. 沐神-动手学深度学习-引言

    2022年暑假,本科毕业,准研究生的我准备在漫长的假期中学习些自己感兴趣的知识,恰好看到了B站中沐神的动手学深度学习系列视频,之后便开始跟着教程开始学习.在之后的学习过程中,渐渐发现自己的学习效率开始 ...

最新文章

  1. 马云牛啊 从骑自行车到坐迈巴赫只用20年
  2. cmd pc如何开多个微信_Win10下个人微信与企业微信多开
  3. 普及一下equals和==的区别的误区
  4. 距离高考出成绩,一年了、、、
  5. ppt生成器_小米发布会ppt词云怎么做的
  6. 缓存穿透、缓存击穿和缓存雪崩实践附源码
  7. Webydo:一款在线自由创建网站的 Web 应用
  8. Hibernate中hbm.xml文件的inverse、cascade、fetch、outer-join、lazy
  9. kswapd进程与swap、swappiness之间的关系及原理
  10. Excel数据的快速填充
  11. git如何拉去开发的 最新代码_git拉取代码到本地
  12. php实现推箱子游戏,C语言实现推箱子游戏的代码示例
  13. 声速的测量数据处理代码
  14. 华为手机怎么用云歌_华为手机语音助手怎么使用 看完你就知道了
  15. 计算机中的位,字节,字,字长的概念
  16. 中南大学计算机学院研究生录取分数线,2021中南大学
  17. 我在CSDN的2022:突破零粉丝,4个月涨粉4000+,2023年目标5万+
  18. fedora 安装与系统升级
  19. VScode快捷键(win + mac)
  20. 计算机应用技术博士,全国新增所大学计算机应用技术博士点

热门文章

  1. 快排解Top-K问题
  2. apicloud图片缓存的使用和查看清除缓存
  3. 判断一个整数是否为回文数
  4. 41、【斯纳克图书馆管理系统】编目流程 [ 准备工作]
  5. 存储一万亿张图片,需要怎样的架构?
  6. 关于RedisTemplate的ERR EXEC without MULTI错误
  7. 【硬十宝典】——1.2【基础知识】开关电源各种拓扑结构的特点
  8. 黑马韩前成linux从入门到精通之UNIX发展史
  9. 音频转文字怎么操作?快来看看这几个方法吧
  10. 小扎、马斯克宣战ChatGPT!Meta和推特组建顶级AI团队,硅谷硝烟四起