Pytorch 神经网络训练过程
文章目录
- 1. 定义模型
- 1.1 绘制模型
- 1.2 模型参数
- 2. 前向传播
- 3. 反向传播
- 4. 计算损失
- 5. 更新参数
- 6. 完整简洁代码
参考 http://pytorch123.com/
1. 定义模型
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net_model(nn.Module):def __init__(self):super(Net_model, self).__init__()self.conv1 = nn.Conv2d(1,6,5) # 卷积# in_channels, out_channels, kernel_size, stride=1,# padding=0, dilation=1, groups=1,# bias=True, padding_mode='zeros'self.conv2 = nn.Conv2d(6,16,5)self.fc1 = nn.Linear(16*5*5, 120) # FC层self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = F.max_pool2d(x, (2,2))x = self.conv2(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = x.view(-1, self.num_flat_features(x)) # 展平x = self.fc1(x)x = F.relu(x)x = self.fc2(x)x = F.relu(x)x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:] # 除了batch 维度外的维度num_features = 1for s in size:num_features *= sreturn num_featuresmodel = Net_model()
print(model)
输出:
Net_model((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)
1.1 绘制模型
from torchviz import make_dot
vis_graph = make_dot(model(input),params=dict(model.named_parameters()))
vis_graph.view()
1.2 模型参数
params = list(model.parameters())
print(len(params))
for i in range(len(params)):print(params[i].size())
输出:
10
torch.Size([6, 1, 5, 5])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([120, 400])
torch.Size([120])
torch.Size([84, 120])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])
2. 前向传播
input = torch.randn(1,1,32,32)
out = model(input)
print(out)
输出:
tensor([[-0.1100, 0.0273, 0.1260, 0.0713, -0.0744, -0.1442, -0.0068, -0.0965,-0.0601, -0.0463]], grad_fn=<AddmmBackward>)
3. 反向传播
# 清零梯度缓存器
model.zero_grad()
out.backward(torch.randn(1,10)) # 使用随机的梯度反向传播
4. 计算损失
output = model(input)
target = torch.randn(10) # 举例用
target = target.view(1,-1) # 形状匹配 output
criterion = nn.MSELoss() # 定义损失类型
loss = criterion(output, target)
print(loss)
# tensor(0.5048, grad_fn=<MseLossBackward>)
- 测试
.zero_grad()
清零梯度缓存作用
model.zero_grad()
print(model.conv1.bias.grad)
loss.backward()
print(model.conv1.bias.grad)
输出:
tensor([0., 0., 0., 0., 0., 0.])
tensor([-0.0067, 0.0114, 0.0033, -0.0013, 0.0076, 0.0010])
5. 更新参数
learning_rate = 0.01
for f in model.parameters():f.data.sub_(f.grad.data*learning_rate)
6. 完整简洁代码
criterion = nn.MSELoss() # 定义损失类型
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.1)# 优化目标,学习率# 循环执行以下内容 训练
optimizer.zero_grad() # 清空梯度缓存
output = model(input) # 输入,输出,前向传播loss = criterion(output, target) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数
Pytorch 神经网络训练过程相关推荐
- TensorFlow游乐园介绍及其神经网络训练过程
TensorFlow游乐场是一个通过网页浏览器就可以训练简单神经网络.并实现了可视化训练过程的工具.游乐场地址为http://playground.tensorflow.org/ 一.TensorFl ...
- 练习推导一个最简单的BP神经网络训练过程【个人作业/数学推导】
写在前面: 各式资料中关于BP神经网络的讲解已经足够全面详尽,故不在此过多赘述.本文重点在于由一个"最简单"的神经网络练习推导其训练过程,和大家一起在练习中一起更好理解神经网络训练 ...
- 神经网络的三种训练方法,神经网络训练过程图解
如何训练神经网络 1.先别着急写代码训练神经网络前,别管代码,先从预处理数据集开始.我们先花几个小时的时间,了解数据的分布并找出其中的规律. Andrej有一次在整理数据时发现了重复的样本,还有一次发 ...
- 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...
(文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...
- 神经网络测试集loss不变_神经网络训练过程中不收敛或者训练失败的原因
在面对模型不收敛的时候,首先要保证训练的次数够多.在训练过程中,loss并不是一直在下降,准确率一直在提升的,会有一些震荡存在.只要总体趋势是在收敛就行.若训练次数够多(一般上千次,上万次,或者几十个 ...
- 神经网络训练的一般步骤,神经网络训练过程详解
1.想要学习人工神经网络,需要什么样的基础知识? 人工神经网络理论百度网盘下载: 链接:https://pan.baidu.com/s/1Jcg4s2ETCrag2Vo-OA57Og 提取码:rxlc ...
- 神经网络的三种训练方法,神经网络训练过程详解
如何训练神经网络 1.先别着急写代码训练神经网络前,别管代码,先从预处理数据集开始.我们先花几个小时的时间,了解数据的分布并找出其中的规律. Andrej有一次在整理数据时发现了重复的样本,还有一次发 ...
- 人工神经网络的训练步骤,神经网络训练过程图解
如何通过人工神经网络实现图像识别 . 人工神经网络(ArtificialNeuralNetworks)(简称ANN)系统从20世纪40年代末诞生至今仅短短半个多世纪,但由于他具有信息的分布存储.并行处 ...
- 神经网络常用的训练方式,神经网络训练过程详解
神经网络参数如何确定 神经网络各个网络参数设定原则:①.网络节点 网络输入层神经元节点数就是系统的特征因子(自变量)个数,输出层神经元节点数就是系统目标个数.隐层节点选按经验选取,一般设为输入层节点 ...
最新文章
- Linxu终端gcc与gcc -c的区别
- CentOS 7 添加系统开机服务
- python3.8.5怎么用-Python 3.8 新功能大揭秘【新手必学】
- 【JUnit 报错】 method initializationerror not found:JUnit4单元测试报错问题
- 模板方法(钩子函数)设计模式
- 前端学习(977):本地存储导读
- 软件工程实践2017结对第二次作业
- mipi 调试经验【转】
- display:block jquery.sort()
- [Devcpp]为Devc自定义编译器及Devcpp路径读取的Bug
- FLUKE OTDR光纤断点测试仪OFP2-100-Q特色功能及亮点分析
- mysql登录其他电脑_如何连接另一台电脑的mysql数据库
- 神经网络中前向传播和反向传播解析
- Springboot mysql访问异常:User does not have access to metadata required to determine stored procedure
- 大创小组讨论会议纪要
- linux命令行界面上滑,获得Linux命令行平滑体验的5条技巧
- 使用测试客户端「玩转」MQTT 5.0
- 阿里妈妈佣金转换API接口代码对接教程
- linux直流电机测试,电机与动力系统测试
- Java System.currentTimeMillis()