pytorch线性回归(笔记一)
代码部分:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.autograd import Variablex_data=np.random.rand(100)#数组中加入在[0,1]之间均匀分布的随机样本。构建x的值
noise=np.random.normal(0,0.01,x_data.shape)#生成高斯分布的概率密度随机数
y_data=x_data*0.01+0.2+noise #构建y的值
plt.scatter(x_data,y_data)#装载数据,看看坐标系下的数据分布
plt.show() #看图#把数据变成俩维度的
x_data=np.reshape(x_data,(-1,1))
y_data=np.reshape(y_data,(-1,1))#把俩维度的变成一个tensor
x_data=torch.FloatTensor(x_data)
y_data=torch.FloatTensor(y_data)#tensor变成变量
input=Variable(x_data)#表示输入
target=Variable(y_data)#表示输出#构建神经网络的模型
class LinearRegression(nn.Module):def __init__(self):#定义网络结构super(LinearRegression, self).__init__()#父类的初始化self.fc=nn.Linear(1,1)#线性回归,表示输入一个神经元,输出一个神经元,全连接层def forward(self,x):#定义网络计算,表示前向传递,pytorch默认做了后向传递out=self.fc(x) #表示把x值传给全连接层,返回一个y值,然后returnreturn out#定义模型(实例化模型)
model=LinearRegression()
#定义代价函数(均方差二次函数)
mes_loss =nn.MSELoss()
#定义优化器(随即梯度下降法)
optimizer=optim.SGD(model.parameters(),lr=0.1)#lr表示学习率,model.parameters()表示把参数传递进优化器#如何查看参数for name,parameters in model.named_parameters():print("name:{},param:{}".format(name,parameters))for i in range(1001):#训练1001次out=model(input) #把输入传递进去#计算loss损失函数loss=mes_loss(out,target) #根据输入和输出计算损失#梯度清零optimizer.zero_grad() #先把梯度清零,防止缓存#计算梯度loss.backward() #计算梯度#修改权值optimizer.step() #根据梯度,修改w和b的值if i%200 == 0: #每训练200次就打印一次损失函数print(i,loss.item())y_pred=model(input) #输入模型,得到预测值
plt.scatter(x_data,y_data) #装载训练数据
plt.plot(x_data,y_pred.data.numpy(),'-r',lw=3)#绘制预测数据图形,红色
plt.show()
F:\开发工具\pythonProject\tools\venv\Scripts\python.exe F:/开发工具/pythonProject/tools/bys/pychartools.py
name:fc.weight,param:Parameter containing:
tensor([[0.9746]], requires_grad=True)
name:fc.bias,param:Parameter containing:
tensor([0.2897], requires_grad=True)Process finished with exit code 0
fc.bias表示截取和偏置,fc.weight表示斜率
结果:
0 0.8948118686676025
200 0.0001185736691695638
400 0.00010005112562794238
600 0.00010000772454077378
800 0.00010000762267736718
1000 0.00010000763722928241
初始值是0.8948118686676025,经过一次次的训练,斜率约来越小,最后变成0.00010000763722928241
最终结果:
pytorch线性回归(笔记一)相关推荐
- 【从线性回归到 卷积神经网络CNN 循环神经网络RNN Pytorch 学习笔记 目录整合 源码解读 B站刘二大人 绪论(0/10)】
深度学习 Pytorch 学习笔记 目录整合 数学推导与源码详解 B站刘二大人 目录传送门: 线性模型 Linear-Model 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人( ...
- 深度学习入门之PyTorch学习笔记:卷积神经网络
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...
- 深度学习入门之PyTorch学习笔记:多层全连接网络
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...
- 深度学习入门之PyTorch学习笔记:深度学习框架
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...
- 深度学习入门之PyTorch学习笔记:深度学习介绍
深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...
- 深度学习入门之PyTorch学习笔记
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 5 循环神经网络 6 生成对抗网络 7 深度学习实战 参考资料 绪论 深度学习如今 ...
- PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard
文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...
- PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call
您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...
- pytorch学习笔记(二):gradien
pytorch学习笔记(二):gradient 2017年01月21日 11:15:45 阅读数:17030
- PyTorch学习笔记(二)——回归
PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...
最新文章
- 机器学习算法与Python实践之(六)二分k均值聚类
- Java常用API(六)Date 日期类介绍及使用
- Sherman-Morrison公式及其应用
- wps多人协作后怎么保存_蜂蜜开封后能放多久?蜂蜜开封后怎么保存?蜂蜜存放要注意事项...
- hapi 获取 请求地址 url
- 阿里Java编程规约(集合)
- YEAH!!距离拿回touch4倒计时:7days
- BGP ——路由过滤+路由聚合(讲解+配置)
- 卡尔曼滤波原理(二):扩展卡尔曼
- SAP License:委外业务产生的ML结算问题思考
- 无线通信定位一体化进展及其在煤矿井下应用分析
- java 如何执行dig 命令_如何在cmd下直接执行Dig命令
- maya mentray_新手快速掌握Maya Mental ray
- 手机号码变成空号导致亚马逊账号登陆两步验证失败的恢复网址及方法
- led的伏安特性曲线 matlab实现_小灯泡伏安特性曲线实验报告
- excel导出java不完整_有关Java POI导出excel表格中,单元格合并之后显示不全的解决方法。...
- 计算机网络 - 应用层
- Excel通过身份证号提取出生年月日(生日)/计算截至当前年龄
- 和菜鸟一起学android4.0.3源码之USB wifi移植心得
- 卸载事件off()方法
热门文章
- 数据结构(二十)二叉树的递归遍历算法
- Chapter 3 Phenomenon——6
- dynamic 找不到编译动态表达式所需的一种或多种类型。是否缺少引用?
- 牛根生--蒙牛创业故事
- 安装VS2008错误解决
- java struts2下载zip_Struts2多文件下载
- 启动白屏处理_App启动优化一顿操作猛如虎
- php 5.2.17 中文乱码,php5.2 Json中文乱码解决方法
- 如何查询以太信道接口_浅谈百兆千兆以太网物理层
- 疫情海报模板|光效显微传播大数据必备psd素材