线性回归的从零开始实现

我们将从零开始实现整个方法,包括数据流水线、模型、损失函数和小批量随机梯度下降优化器

%matplotlib inline
import random
import torch
from d2l import torch as d2l

数据集产生

根据带有噪声的线性模型构造一个人造数据集。
我们使用线性模型参数w=[2,−3.4]⊤\mathbf{w} = [2, -3.4]^\topw=[2,−3.4]⊤、b=4.2b = 4.2b=4.2和噪声项ϵ\epsilonϵ生成数据集及其标签:

y=Xw+b+ϵ\mathbf{y}= \mathbf{X} \mathbf{w} + b + \mathbf\epsilony=Xw+b+ϵ

def synthetic_data(w, b, num_examples):  """生成 y = Xw + b + 噪声。"""X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

features 中的每一行都包含一个二维数据样本,labels 中的每一行都包含一维标签值(一个标量)

print('features:', features[0], '\nlabel:', labels[0])
features: tensor([-0.5956, -0.5598])
label: tensor([4.9206])
d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(),labels.detach().numpy(), 1);

定义一个data_iter 函数,
该函数接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量

batch_size设置

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i +batch_size, num_examples)])print(batch_indices)yield features[batch_indices], labels[batch_indices]batch_size = 10for X, y in data_iter(batch_size, features, labels):print(X, '\n', y)break
tensor([896, 973, 201, 842, 366, 358, 284, 388,  52, 698])
tensor([[ 1.3842,  0.6962],[ 1.5565, -0.0568],[-1.1065, -1.3475],[-1.2078, -1.7032],[-0.7833, -1.7924],[ 0.2734, -0.1906],[-0.8841, -0.9256],[-1.4493,  1.1998],[-0.7115, -1.2369],[ 0.6746,  1.0624]]) tensor([[ 4.5883],[ 7.4929],[ 6.5568],[ 7.5712],[ 8.7204],[ 5.3840],[ 5.5644],[-2.7621],[ 6.9495],[ 1.9339]])

定义
初始化模型参数

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

定义模型

def linreg(X, w, b):  """线性回归模型。"""return torch.matmul(X, w) + b

定义损失函数

def squared_loss(y_hat, y):  """均方损失。"""return (y_hat - y.reshape(y_hat.shape))**2 / 2

定义优化算法

def sgd(params, lr, batch_size):  """小批量随机梯度下降。"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()#梯度置为0

训练过程

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)l.sum().backward()sgd([w, b], lr, batch_size)with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
epoch 1, loss 0.033103
epoch 2, loss 0.000124
epoch 3, loss 0.000054

比较真实参数和通过训练学到的参数来评估训练的成功程度

print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')
w的估计误差: tensor([ 0.0006, -0.0004], grad_fn=<SubBackward0>)
b的估计误差: tensor([0.0003], grad_fn=<RsubBackward1>)

线性回归的从零开始实现-08-p3相关推荐

  1. 动手深度学习:08 线性回归(线性回归的从零开始实现)(二)

    1.线性回归的从零开始实现 我们将从零开始实现整个方法,包括数据流水线.模型.损失函数和小批量随机梯度下降优化器 d2l包可以直接在conda的prompt里面输入命令 pip install -U ...

  2. 3.23.3 线性回归的从零开始实现|Pytorch简洁实现

    学习链接:李沐老师的动手深度学习v2书.视频链接 代码部分的理解笔记. 1.生成数据 2.读取数据集 3.初始化模型参数 4.定义模型 5.定义损失函数 6.定义优化算法 7.训练 import ra ...

  3. mysql数据库随机生成数据库_MySQL 从零开始:08 番外:随机生成数据库数据

    学习数据库时,难免需要一些数据进行实验,对于小数据量的数据来说,我们自己想一些数据并插入到数据库即可,但是如果需要大量的数据时,手动输入将是一项繁琐的工作,我们也不一定能编那么多数据.基于以上,自动生 ...

  4. [iTyran翻译]OpenGL ES 从零开始系列08:交叉存取顶点数据

    Technote 2230提出了很多用OpenGL ES来提升iphone程序性能的建议.我们现在远远不能深刻理解OpenGL ES所以你需要学习以下内容.不信?是真的,试试看,我等着你的读后感. 好 ...

  5. OpenGL ES 从零开始系列08:交叉存取顶点数据

    Technote 2230提出了很多用OpenGL ES来提升iphone程序性能的建议.我们现在远远不能深刻理解OpenGL ES所以你需要学习以下内容.不信?是真的,试试看,我等着你的读后感. 好 ...

  6. 从零开始学Pytorch之线性回归

    线性回归 主要内容包括: 线性回归的基本要素 线性回归模型从零开始的实现 线性回归模型使用pytorch的简洁实现 线性回归的基本要素 模型 为了简单起见,这里我们假设价格只取决于房屋状况的两个因素, ...

  7. 华南理工深度学习与神经网络期末考试_深度学习基础:单层神经网络之线性回归...

    3.1 线性回归 线性回归输出是一个连续值,因此适用于回归问题.回归问题在实际中很常见,如预测房屋价格.气温.销售额等连续值的问题.与回归问题不同,分类问题中模型的最终输出是一个离散值.我们所说的图像 ...

  8. 机器学习:线性回归I 最小二乘法

    原文在此 线性回归是最基础和常见的算法,属于监督学习的一种,是讲述算法开始的地方.我们在中学.大学学过很多次,虽然我已完全不记得.线性回归作为基础,虽然simple但不意味着easy,对其掌握很重要的 ...

  9. 《动手学深度学习》 第二天 (线性回归)

    3.2 线性回归的从零开始实现 只利用NDArray和autograd来实现一个线性回归的训练. 首先,导入本节中实验所需的包或模块,其中的matplotlib包可用于作图,且设置成嵌入显示. %ma ...

最新文章

  1. 某知名大学学生毕业设计,Java学好了就是厉害
  2. 51nod 2006 飞行员配对(二分图最大匹配) 裸匈牙利算法 求二分图最大匹配题
  3. RocketMQ的Producer详解之分布式事务消息(代码实现以及过程分析)
  4. 为DataList和GridView内容项添加序号
  5. 15.IDA-查看XREF列表(Ctrl+x)
  6. 为了防止程序重排序,慎用volatile
  7. Netty工作笔记0023---NIO服务器客户端总结
  8. Linux下Ipython安装
  9. EverWeb for Mac(网页设计软件)v3.5.1中文版
  10. 实验记录一 初步接触cortex-M3
  11. android 公式编辑器,公式编辑器
  12. 如何实现基于 RADIUS 协议的双因子认证 MFA?
  13. 操作系统:信号量机制之生产者与消费者实验
  14. CISSP考点拾遗——关于道德
  15. 前端学习-案例:制作一个超简单的静态页面
  16. UVa 10105 - Polynomial Coefficients
  17. 就是上来吐槽一下树莓派上,编译个模块都过不去的郁闷。 欢迎使用CSDN-markdown编辑器
  18. java机顶盒_Java技术在数字电视机顶盒中的应用
  19. 2020年国考申论备考:评价类(观点)题和理解类题目的辨析
  20. python 创建画板_OpenCV +Python 制作画板

热门文章

  1. Call to localhost/127.0.0.1:9000 failed on connection exception:java.net.ConnectException的解决方案
  2. Codeblocks中文乱码解决方法。
  3. 线上 ELK 集群健康值 red 状态问题排查与解决
  4. 清除或重新创建Ruby on Rails数据库
  5. 如何为项目中的单个文件禁用ARC?
  6. 如何知道对象在Python中是否具有属性
  7. (配置消息转换器)解决后台返回json数据到前台时页面时中文显示乱码问题
  8. Java代码TkMyBatis通用Mapper中新增数据时同时获取自增主键ID,与适用uuid 做主键时获取 id
  9. java 时间格式化 星期_Java SimpleDateFormate时间格式化
  10. ❤️《Mybatis从基础到高级》(建议收藏)❤️