标准的训练神经网络流程是:
1.定义包含权重的神经网络
2.遍历所有输入数据
3.处理数据
4.计算损失值
5.向前传播
6.更新权重

定义神经网络

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 3) #1 inputChannel,6 outputChannel,3*3 kernelself.conv2 = nn.Conv2d(6, 16, 3) #6 inputChannel,16 outputChannel,3*3 kernelself.fc1 = nn.Linear(16 * 6 * 6, 120) #16*6*6 inputFeatures(16 channel 6*6 image dimension), 120 outputFeaturesself.fc2 = nn.Linear(120, 84) #120 inputFeatures 84outputFeaturesself.fc3 = nn.Linear(84, 10) #84 inputFeatures 10outputFeaturesdef forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # 卷积池化 (b,1,w,h) ->(b,6,(w-3+1)/2,(b-3+1/2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)# 卷积池化(b,6,w,h) ->(b,16,(w-3+1)/2,(b-3+1/2))x = x.view(-1, x.size()[1:].numel()) # (b,16,w,h) -> (b,16*w*h)x = F.relu(self.fc1(x)) #(b,16*6*6) ->(b,120)x = F.relu(self.fc2(x))#(b,120) ->(b,84)x = self.fc3(x)#(b,84) ->(b,10)return x
net = Net()
print(net)# https://stackoverflow.com/questions/53784998/how-are-the-pytorch-dimensions-for-linear-layers-calculated/53787076#53787076
----------------------------------------------------------------
Net((conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))(fc1): Linear(in_features=576, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

我们只需要定义forward方法,backward方法会自动定义(因为autograd的存在。
可以使用net.parameters()来查看权重系数。

params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight
---------------------
10
torch.Size([6, 1, 3, 3])

让我们使用32*32的输入。

input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
----------------------------------------------
tensor([[ 0.0158,  0.0992, -0.1584, -0.0231,  0.0408, -0.0601, -0.0561,  0.0461,0.0854,  0.0818]], grad_fn=<AddmmBackward>)

Zero the gradient buffers of all parameters and backprops with random gradients:

net.zero_grad()
out.backward(torch.randn(1, 10))

torch.nn只能支持mini-batches,而不支持single sample。比如nn.Conv2d会将4D tensor作为输入,(nSamples * nChannels * Height * Width)。如果只有 a single sample 需要使用input.unsqueeze(0)来假装添加了batch dimension

2.损失函数

将 (output, target)作为参数,并计算 output和target的距离。
有几种不同的loss Functions,最简单的是nn.MSELoss计算mean-squared error

output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()loss = criterion(output, target)
print(loss)
-------------------------------------------------------------
tensor(0.6443, grad_fn=<MseLossBackward>)

如果查看loss的反向传播,使用.grad_fn

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss

当我们使用loss.backward(),整个graph会计算loss,所有graph中带有requires_grad=True的tensor,将累加他们的gradient。

print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU
----------------------------------------------------------
<MseLossBackward object at 0x7f16dfd30ba8>
<AddmmBackward object at 0x7f16dfd550b8>
<AccumulateGrad object at 0x7f16dfd30ba8>

Backprop

我们需要清除已经存在的gradients。

net.zero_grad()     # zeroes the gradient buffers of all parametersprint('conv1.bias.grad before backward')
print(net.conv1.bias.grad)loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)
------------------------------------------------
conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([ 0.0083,  0.0066,  0.0212, -0.0175, -0.0130,  0.0090])

更新权重

最简单的更新规则是Stochastic Gradient Descent (SGD):

weight = weight - learning_rate * gradient

手动实现

learning_rate = 0.01
for f in net.parameters(): # 遍历图中每个节点的参数f.data.sub_(f.grad.data * learning_rate) # 将节点的参数-(学习速率*梯度),单下划线表示替换

pytorch中已经实现了SGD等一系列的更新方法

import torch.optim as optim# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)# in your training loop:
optimizer.zero_grad()   # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()    # Does the update

参考:
https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#sphx-glr-beginner-blitz-neural-networks-tutorial-py
https://stackoverflow.com/questions/53784998/how-are-the-pytorch-dimensions-for-linear-layers-calculated/53787076#53787076

pytorch 创建神经网络相关推荐

  1. 使用PyTorch创建神经网络

    2019年年初,ApacheCN组织志愿者翻译了PyTorch1.0版本中文文档(github地址),同时也获得了PyTorch官方授权,我相信已经有许多人在中文文档官网上看到了.不过目前校对还缺人手 ...

  2. 使用pytorch创建神经网络并解决线性拟合和分类问题

    #线性拟合 import torch import torch.nn.functional as F import matplotlib.pyplot as plt torch.manual_seed ...

  3. pytorch神经网络因素预测_实战:使用PyTorch构建神经网络进行房价预测

    微信公号:ilulaoshi / 个人网站:lulaoshi.info 本文将学习一下如何使用PyTorch创建一个前馈神经网络(或者叫做多层感知机,Multiple-Layer Perceptron ...

  4. 基于 PyTorch 和神经网络给 GirlFriend 制作漫画风头像

    摘要:本文中我们介绍的 AnimeGAN 就是 GitHub 上一款爆火的二次元漫画风格迁移工具,可以实现快速的动画风格迁移. 本文分享自华为云社区<AnimeGANv2 照片动漫化:如何基于 ...

  5. 使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类

    文章目录 1 数据准备 2 数据预览 3 简单神经网络创建 3.1 设计网络结构 3.2 损失函数相关设置 3.3 向网络传递信息 3.4 定义训练函数train 4 函数汇总 1 数据准备 神经网络 ...

  6. PyTorch搭建神经网络求解二分类问题

    PyTorch搭建全连接神经网络求解二分类问题 在求解线性回归问题的时候,我们已经学习了如何使用梯度下降算法来不断更新权重矩阵,使误差函数不断减小,这一节我们将使用PyTorch搭建一个简单的神经网络 ...

  7. 【深度学习】基于Torch的Python开源机器学习库PyTorch卷积神经网络

    [深度学习]基于Torch的Python开源机器学习库PyTorch卷积神经网络 文章目录 1 CNN概述 2 PyTorch实现步骤2.1 加载数据2.2 CNN模型2.3 训练2.4 可视化训练 ...

  8. matlab神经网络工具箱创建神经网络,matlab神经网络工具箱创建神经网络

    matlab神经网络工具箱创建神经网络 为了看懂师兄的文章中使用的方法,研究了一下神经网络 昨天花了一天的时间查怎么写程序,但是费了半天劲,不能运行,百度知道里倒是有一个,可以运行的,先贴着做标本 % ...

  9. eclipse创建神经网络_使用Eclipse Deeplearning4j构建简单的神经网络

    eclipse创建神经网络 神经网络导论 深度学习包含深度神经网络和深度强化学习,它们是机器学习的子集,而机器学习本身就是人工智能的子集. 广义地说,深度神经网络执行机器感知,该机器感知从原始数据中提 ...

最新文章

  1. HDU1162(Prim算法)
  2. SAP MM 自定义条件类型出现在采购信息记录的'条件'界面里 ?
  3. Apache2.2+tomcat7 负载均衡配置
  4. Dependency injection in ASP.NET Core
  5. BZOJ4475: [Jsoi2015]子集选取【找规律】【数学】
  6. 【做题记录】人类智慧
  7. SpringMVC错误:No mapping found for HTTP request with URI [xxxx] in DispatcherServlet
  8. (07)FPGA基本组成单元
  9. Ubuntu 1604 升级 1804 记录
  10. CAM350简单编辑gerber文件(【增加一条线】 【复制元素】 【删除元素】)
  11. Linux的操作系统原理详解
  12. 基于ObjectArx进行cad二次开发总结
  13. 计算机考研 机械设计,2019考研专业:机械设计制造及其自动化
  14. 使用Python进行文件快速搜索(建立文件搜索索引)
  15. 笔记本光驱位固态硬盘安装及系统迁移
  16. codeforces 863B Kayaking
  17. 一文理解UDS安全访问服务(0x27)
  18. IntelliJ IDEA 创建Spring+SpringMVC+hibernate+maven项目
  19. 据说优秀的程序员都是这样送新年祝福的?
  20. 想备战 2022 ‘金三银四’ 必备超多软件测试面试题全在这里

热门文章

  1. Python文件操作与matplotlib数据可视化案例一则
  2. 使用Python+turtle绘制动画重现龟兔赛跑现场
  3. Python把汉字转换成拼音
  4. c++ vector常用用法总结
  5. 中职计算机英语教师教学总结,中职计算机教师教学工作总结 (3000字).doc
  6. linux终端打开浏览器_终端可以放电影,一行代码就能实现
  7. 力扣35,搜索插入位置(JavaScript)
  8. 华住数据库_华住内控人系列故事(四)技术领先篇——搭建大数据风险数据仓,实现自助取数...
  9. hive3新增資料_Hive表新增字段后,新字段无法写入值问题总结
  10. mysql查询只能是等式连接_mysql连接查询