2020年10月4号,依然在家学习。
今天是我写的第一个 Pytorch程序,从今天起也算是入门了。
就从简单的线性回归开始吧。

话不多说,我就直接上代码实例,代码的注释我都是用中文直接写的。

import torch# Step 1: ========创建模型========
# 定义一个类,继承自 torch.nn.Module,torch.nn.Module是callable的类
# 在整个类里面重新定义一个线性回归模型 y = wx+b
# 整个子类需要重写forward函数,
class LinearRegressionModel(torch.nn.Module):def __init__(self):# 调用父类的初始化函数,必须要的super(LinearRegressionModel, self).__init__()# 创建一个线性层,也是实例化一个torch.nn.Linear对象,输入数据是一维的,输出数据也是一维的,默认包含偏置参数# torch.nn.Linear也是callable的类self.linearLayer = torch.nn.Linear(1, 1)def forward(self, x):y_out = self.linearLayer(x)return y_out# 创建和实例化一个整个模型类的对象
LR_Model = LinearRegressionModel()
# 打印出整个模型
print(LR_Model)# Step 2: ========定义损失函数和优化器========
# 定义一个均方差误差损失函数 mean square error loss
LR_Criterion = torch.nn.MSELoss(size_average=True)
# 创建一个优化器,是用来做参数训练的,或者说是反向传播后更新参数,线性回归一般选择随机梯度下降,当然还有其他的梯度下降的方式。
# lr 就是learning rate,把模型的所有参数都交给优化器,反向传播中,优化器会递归地计算参数的偏导数以及做参数更新。
LR_Optimizer = torch.optim.SGD(LR_Model.parameters(), lr=0.1)# Step 3: ========得到数据========
# 为了方便演示和学习,这里我随意构造几个数据
# 大致w=2, b=1
x_data = torch.Tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]])
y_data = torch.Tensor([[3.1], [5.0], [6.9], [9.1], [11.01], [13.1], [15.1], [16.9]])# Step 4: ========开始训练========
# 迭代进行训练
iteration = 30
for itr in range(iteration):# 计算前向传播,也就是计算输出y_output = LR_Model(x_data)# 得到损失值loss = LR_Criterion(y_output, y_data)print("in the ", itr, "th iteration, loss is", loss.item())# 反向传播,也就是对参数进行训练,需要注意的是,需要把优化其中上一次计算的梯度值清0LR_Optimizer.zero_grad()# 计算反向的各个参数的偏导数loss.backward()# 更新参数LR_Optimizer.step()# 打印出参数
print("w is: ", LR_Model.linearLayer.weight.item())
print("b is: ", LR_Model.linearLayer.bias.item())

测试结果如下:
跟预期的值一样。

之前我都是用matlab直接实现手写算法的,感觉用pytorch后,生了好多事儿。哈哈,今儿入门了。

《Pytorch - 线性回归模型》相关推荐

  1. ComeFuture英伽学院——2020年 全国大学生英语竞赛【C类初赛真题解析】(持续更新)

    视频:ComeFuture英伽学院--2019年 全国大学生英语竞赛[C类初赛真题解析]大小作文--详细解析 课件:[课件]2019年大学生英语竞赛C类初赛.pdf 视频:2020年全国大学生英语竞赛 ...

  2. ComeFuture英伽学院——2019年 全国大学生英语竞赛【C类初赛真题解析】大小作文——详细解析

    视频:ComeFuture英伽学院--2019年 全国大学生英语竞赛[C类初赛真题解析]大小作文--详细解析 课件:[课件]2019年大学生英语竞赛C类初赛.pdf 视频:2020年全国大学生英语竞赛 ...

  3. 信息学奥赛真题解析(玩具谜题)

    玩具谜题(2016年信息学奥赛提高组真题) 题目描述 小南有一套可爱的玩具小人, 它们各有不同的职业.有一天, 这些玩具小人把小南的眼镜藏了起来.小南发现玩具小人们围成了一个圈,它们有的面朝圈内,有的 ...

  4. 信息学奥赛之初赛 第1轮 讲解(01-08课)

    信息学奥赛之初赛讲解 01 计算机概述 系统基本结构 信息学奥赛之初赛讲解 01 计算机概述 系统基本结构_哔哩哔哩_bilibili 信息学奥赛之初赛讲解 02 软件系统 计算机语言 进制转换 信息 ...

  5. 信息学奥赛一本通习题答案(五)

    最近在给小学生做C++的入门培训,用的教程是信息学奥赛一本通,刷题网址 http://ybt.ssoier.cn:8088/index.php 现将部分习题的答案放在博客上,希望能给其他有需要的人带来 ...

  6. 信息学奥赛一本通习题答案(三)

    最近在给小学生做C++的入门培训,用的教程是信息学奥赛一本通,刷题网址 http://ybt.ssoier.cn:8088/index.php 现将部分习题的答案放在博客上,希望能给其他有需要的人带来 ...

  7. 信息学奥赛一本通 提高篇 第六部分 数学基础 相关的真题

    第1章   快速幂 1875:[13NOIP提高组]转圈游戏 信息学奥赛一本通(C++版)在线评测系统 第2 章  素数 第 3 章  约数 第 4 章  同余问题 第 5 章  矩阵乘法 第 6 章 ...

  8. 信息学奥赛一本通题目代码(非题库)

    为了完善自己学c++,很多人都去读相关文献,就比如<信息学奥赛一本通>,可又对题目无从下手,从今天开始,我将把书上的题目一 一的解析下来,可以做参考,如果有错,可以告诉我,将在下次解析里重 ...

  9. 信息学奥赛一本通(C++版) 刷题 记录

    总目录详见:https://blog.csdn.net/mrcrack/article/details/86501716 信息学奥赛一本通(C++版) 刷题 记录 http://ybt.ssoier. ...

  10. 最近公共祖先三种算法详解 + 模板题 建议新手收藏 例题: 信息学奥赛一本通 祖孙询问 距离

    首先什么是最近公共祖先?? 如图:红色节点的祖先为红色的1, 2, 3. 绿色节点的祖先为绿色的1, 2, 3, 4. 他们的最近公共祖先即他们最先相交的地方,如在上图中黄色的点就是他们的最近公共祖先 ...

最新文章

  1. 作为一名合格的前端开发工程师需要会哪些
  2. 用Swift实现一款天气预报APP(三)
  3. 我拷贝大文件的时候报“超过文件大小限制”错误,怎样突破这个限制?
  4. Nginx代理功能与负载均衡详解
  5. 原型设计20条军规(转)
  6. matlab八节点六面体程序,平面8节点等参元完整程序
  7. delphi 串口通信发送_关于串口通信232、485、422和常见问题,就没见过能讲这么清楚的...
  8. android viewpager 间隔,viewpager 系统兼容 clipChildren 页卡间距
  9. 使用Navicat为数据库表建立触发器
  10. cte公用表表达式_在SQL Server中使用CTE进行插入和更新(公用表表达式)
  11. ThinkPHP叫号系统
  12. 施工日志管理软件app_启用ERP装修管理软件的必要性
  13. Echarts数据可视化event图表事件的相关操作,开发全解+完美注释
  14. apache cgi python
  15. Android SharedPreferences
  16. 从转载阿里开源项目 Egg.js 技术文档引发的“版权纠纷”,看宽松的 MIT 许可该如何用?
  17. 计算机房精密空调术语,机房精密空调参数及含义
  18. 优化elelment ui 的 dialog 样式
  19. NekoHTML 和 XPath
  20. 从0开始建设saas - 优化篇(session访问的问题)

热门文章

  1. 2017 最值得关注的十大 APP、Web 界面设计趋势
  2. 物化视图常用维护操作
  3. iptables之NAT
  4. 数据状态更新时的差异 diff 及 patch 机制
  5. ES6知识点汇总(全)
  6. Nginx出现500 Internal Server Error 错误的解决方案
  7. 容器编排技术 -- Kubernetes kubectl create serviceaccount 命令详解
  8. HomeBrew 更换为国内源--提高brew命令操作速度
  9. Kubernetes如何赋能可再生能源产业提升10倍效率
  10. LayoutInflater.inflate()方法两个参数和三个参数