%matplotlib inline
import random
import torch
from d2l import torch as d2l
```python根据带有噪声的线性模型构造一个人造数据集。 我们使用线性模型参数w=[2,−3.4]⊤、b=4.2和噪声项ϵ生成数据集及其标签:\
def synthetic_data(w, b, num_examples):  """生成 y = Xw + b + 噪声。"""X = torch.normal(0, 1, (num_examples, len(w)))
# 返回一个张量,包含从给定参数means,std的离散正态分布中抽取随机数
#                           行数是样本数量,列是w的长度y = torch.matmul(X, w) + b
#    两个张量矩阵相乘y += torch.normal(0, 0.01, y.shape)
#           随机噪声,均值为0,方差为0.01return X, y.reshape((-1, 1))
# 返回构造的X和y
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
# 构造了1000个训练数据
print('features:', features[0], '\nlabel:', labels[0])

features: tensor([ 1.2445, -0.1673])
label: tensor([7.2539])

d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(),labels.detach().numpy(), 1);
d2l.set_figsize()
d2l.plt.scatter(features[:, 0].detach().numpy(),labels.detach().numpy(), 1);
```python
# 定义一个data_iter 函数, 该函数接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量
def data_iter(batch_size, features, labels):num_examples = len(features)
# 训练数据总量。其中len函数返回矩阵的行数indices = list(range(num_examples))
# 对所有的数据进行标号random.shuffle(indices)
#     shuffle() 方法将序列的所有元素随机排序。for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i+batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]
batch_size = 10
for X, y in data_iter(batch_size, features, labels):print(X, '\n', y)break

tensor([[-0.6489, -1.2992],
[-1.4567, 0.6944],
[ 0.6492, 0.5349],
[ 0.6975, 0.7926],
[-0.0284, -0.3099],
[ 0.0941, -0.3360],
[ 0.0608, 0.4613],
[-0.8575, -0.3315],
[ 0.6550, -0.9486],
[ 0.1608, -0.1536]])
tensor([[ 7.3299],
[-1.0834],
[ 3.6786],
[ 2.8856],
[ 5.2030],
[ 5.5153],
[ 2.7383],
[ 3.6112],
[ 8.7448],
[ 5.0419]])

# 定义 初始化模型参数
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# 定义模型
def linreg(X, w, b):  """线性回归模型。"""return torch.matmul(X, w) + b# 定义损失函数
def squared_loss(y_hat, y):  """均方损失。"""return (y_hat - y.reshape(y_hat.shape))**2 / 2# 定义优化算法
def sgd(params, lr, batch_size):  """小批量随机梯度下降。"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()# 训练过程
lr = 1
num_epochs =10
net = linreg
loss = squared_lossfor epoch in range(num_epochs):
# 对训练集迭代多次for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)
#         计算一个批量大小的损失函数l.sum().backward()
#        有个批量大小的损失函数求和sgd([w, b], lr, batch_size)
#         更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)
# 一个epoch迭代完成后,算损失print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

epoch 1, loss 0.000062
epoch 2, loss 0.000078
epoch 3, loss 0.000049
epoch 4, loss 0.000055
epoch 5, loss 0.000059
epoch 6, loss 0.000116
epoch 7, loss 0.000087
epoch 8, loss 0.000059
epoch 9, loss 0.000054
epoch 10, loss 0.000078

```python
print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

w的估计误差: tensor([-0.0005, -0.0034], grad_fn=)
b的估计误差: tensor([0.0078], grad_fn=)

沐神动手深度学习 06线性回归从0开始实现相关推荐

  1. 动手深度学习13——计算机视觉:数据增广、图片分类

    文章目录 一.数据增广 1.1 为何进行数据增广? 1.2 常见图片增广方式 1.2.1 翻转 1.2.2 切割(裁剪) 1.2.3 改变颜色 1.2.4 综合使用 1.3 使用图像增广进行训练 1. ...

  2. 沐神-动手学深度学习-环境的配置

    在本次学习中,我使用Anaconda3进行环境的配置,使用Jupyter Notebook进行编程. 软件的安装我是在网上搜了个教程装的,结果很好可以使用.我着重对环境配置以及库函数的安装过程进行记录 ...

  3. 动手深度学习13:计算机视觉——语义分割、风格迁移

    文章目录 一.语义分割 1.1 语义分割简介 1.2 Pascal VOC2012 语义分割数据集 1.2.1下载.读取数据集 1.2.2 构建字典(RGB颜色值和类名互相映射) 1.2.3 数据预处 ...

  4. 基于Pycharm运行李沐老师的深度学习课程代码

    最近在b站看李沐老师的深度学习课程,受益颇多.不过觉得光看视频实在是不过瘾,最好还是能实际的玩起来.鉴于我还是习惯使用pycharm,且不需要过多的中间过程展示,所以代码的编写基本都是在pycharm ...

  5. 【动手深度学习-笔记】注意力机制(一)注意力机制框架

    生物学中的注意力提示 非自主性提示: 在没有主观意识的干预下,眼睛会不自觉地注意到环境中比较突出和显眼的物体. 比如我们自然会注意到一堆黑球中的一个白球,马路上最酷的跑车等. 自主性提示: 在主观意识 ...

  6. 动手深度学习笔记(四十)7.4. 含并行连结的网络(GoogLeNet)

    动手深度学习笔记(四十)7.4. 含并行连结的网络(GoogLeNet) 7.4. 含并行连结的网络(GoogLeNet) 7.4.1. Inception块 7.4.2. GoogLeNet模型 7 ...

  7. 深度学习原理-----线性回归+梯度下降法

    系列文章目录 深度学习原理-----线性回归+梯度下降法 深度学习原理-----逻辑回归算法 深度学习原理-----全连接神经网络 深度学习原理-----卷积神经网络 深度学习原理-----循环神经网 ...

  8. 动手深度学习笔记(一)2.1数据操作

    动手深度学习笔记(一) 2. 预备知识 2.1. 数据操作 2.1.1. 入门 2.1.2. 运算符 2.1.3. 广播机制 2.1.4. 索引和切片 2.1.5. 节省内存 2.1.6. 转换为其他 ...

  9. 动手深度学习笔记(四十五)8.1. 序列模型

    动手深度学习笔记(四十五)8.1. 序列模型 8.1. 序列模型 8.1.1. 统计工具 8.1.1.1. 自回归模型 8.1.1.2. 马尔可夫模型 8.1.1.3. 因果关系 8.1.2. 训练 ...

最新文章

  1. 销售流程管理-leangoo
  2. 英特尔用ViT做密集预测效果超越卷积,性能提高28%,mIoU直达SOTA|在线可玩
  3. android触摸外部关闭键盘,如何隐藏Android上的软键盘,点击外部EditText?
  4. [SDOI2009]HH去散步(矩阵)
  5. SQL Server 2008 R2:快速清除日志文件的方法
  6. 9-3:C++多态之多态的实现原理之虚函数表,虚函数表指针静态绑定和动态绑定
  7. scp复制linux系统的文件文件到本机(windows)以及本机文件复制到远程的命令
  8. .net core 介绍好文章
  9. android点击监听,android基础之点击监听器的2种监听实现
  10. day078_鼠标动起来
  11. java nfc_如何使用java创建简单的NFC程序?
  12. 【离散数学】「离散数学引论」学习笔记
  13. opnet共享代码开发
  14. 科学计算机怎么计算电工学向量,电工学常用单位计算与换算公式大全
  15. *帅帅老师,编写函数,将999-9999整数放入一个数组当中
  16. 博客-需求说明答辩总结
  17. Hololens学习(三)打包编译安装HoloLens2应用
  18. CINTA作业一:加减乘除
  19. 分享几个在线生成头像的网站
  20. 计算机网络实验:无线组网

热门文章

  1. 儿童玩具出口欧盟CE认证测试标准
  2. idea import javafx.util.Pair 异常原因
  3. 《剑指Offer》刷题之最小的K个数
  4. i.MX6ULL驱动开发 | 12 - 基于 Linux I2C 驱动读取AP3216C传感器
  5. 芭芭拉冲鸭~(dfs树两点最大距离)
  6. 执行git stash pop时的冲突解决
  7. 在Android Studio中删除module的方法
  8. python错误解决TypeError: () must be callable
  9. 小羊的暑假博客计划教程索引
  10. 删除字符串中出现次数最少的字符,汽水瓶,简单密码