pytorch 创建神经网络
标准的训练神经网络流程是:
1.定义包含权重的神经网络
2.遍历所有输入数据
3.处理数据
4.计算损失值
5.向前传播
6.更新权重
定义神经网络
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 3) #1 inputChannel,6 outputChannel,3*3 kernelself.conv2 = nn.Conv2d(6, 16, 3) #6 inputChannel,16 outputChannel,3*3 kernelself.fc1 = nn.Linear(16 * 6 * 6, 120) #16*6*6 inputFeatures(16 channel 6*6 image dimension), 120 outputFeaturesself.fc2 = nn.Linear(120, 84) #120 inputFeatures 84outputFeaturesself.fc3 = nn.Linear(84, 10) #84 inputFeatures 10outputFeaturesdef forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # 卷积池化 (b,1,w,h) ->(b,6,(w-3+1)/2,(b-3+1/2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)# 卷积池化(b,6,w,h) ->(b,16,(w-3+1)/2,(b-3+1/2))x = x.view(-1, x.size()[1:].numel()) # (b,16,w,h) -> (b,16*w*h)x = F.relu(self.fc1(x)) #(b,16*6*6) ->(b,120)x = F.relu(self.fc2(x))#(b,120) ->(b,84)x = self.fc3(x)#(b,84) ->(b,10)return x
net = Net()
print(net)# https://stackoverflow.com/questions/53784998/how-are-the-pytorch-dimensions-for-linear-layers-calculated/53787076#53787076
----------------------------------------------------------------
Net((conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))(fc1): Linear(in_features=576, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)
我们只需要定义forward
方法,backward
方法会自动定义(因为autograd
的存在。
可以使用net.parameters()
来查看权重系数。
params = list(net.parameters())
print(len(params))
print(params[0].size()) # conv1's .weight
---------------------
10
torch.Size([6, 1, 3, 3])
让我们使用32*32的输入。
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
----------------------------------------------
tensor([[ 0.0158, 0.0992, -0.1584, -0.0231, 0.0408, -0.0601, -0.0561, 0.0461,0.0854, 0.0818]], grad_fn=<AddmmBackward>)
Zero the gradient buffers of all parameters and backprops with random gradients:
net.zero_grad()
out.backward(torch.randn(1, 10))
torch.nn
只能支持mini-batches
,而不支持single sample
。比如nn.Conv2d
会将4D tensor作为输入,(nSamples * nChannels * Height * Width)
。如果只有 a single sample 需要使用input.unsqueeze(0)
来假装添加了batch dimension
。
2.损失函数
将 (output, target)作为参数,并计算 output和target的距离。
有几种不同的loss Functions,最简单的是nn.MSELoss
计算mean-squared error
output = net(input)
target = torch.randn(10) # a dummy target, for example
target = target.view(1, -1) # make it the same shape as output
criterion = nn.MSELoss()loss = criterion(output, target)
print(loss)
-------------------------------------------------------------
tensor(0.6443, grad_fn=<MseLossBackward>)
如果查看loss
的反向传播,使用.grad_fn
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss
当我们使用loss.backward()
,整个graph
会计算loss,所有graph
中带有requires_grad=True
的tensor,将累加他们的gradient。
print(loss.grad_fn) # MSELoss
print(loss.grad_fn.next_functions[0][0]) # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU
----------------------------------------------------------
<MseLossBackward object at 0x7f16dfd30ba8>
<AddmmBackward object at 0x7f16dfd550b8>
<AccumulateGrad object at 0x7f16dfd30ba8>
Backprop
我们需要清除已经存在的gradients。
net.zero_grad() # zeroes the gradient buffers of all parametersprint('conv1.bias.grad before backward')
print(net.conv1.bias.grad)loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)
------------------------------------------------
conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([ 0.0083, 0.0066, 0.0212, -0.0175, -0.0130, 0.0090])
更新权重
最简单的更新规则是Stochastic Gradient Descent (SGD):
weight = weight - learning_rate * gradient
手动实现
learning_rate = 0.01
for f in net.parameters(): # 遍历图中每个节点的参数f.data.sub_(f.grad.data * learning_rate) # 将节点的参数-(学习速率*梯度),单下划线表示替换
pytorch中已经实现了SGD等一系列的更新方法
import torch.optim as optim# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)# in your training loop:
optimizer.zero_grad() # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step() # Does the update
参考:
https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#sphx-glr-beginner-blitz-neural-networks-tutorial-py
https://stackoverflow.com/questions/53784998/how-are-the-pytorch-dimensions-for-linear-layers-calculated/53787076#53787076
pytorch 创建神经网络相关推荐
- 使用PyTorch创建神经网络
2019年年初,ApacheCN组织志愿者翻译了PyTorch1.0版本中文文档(github地址),同时也获得了PyTorch官方授权,我相信已经有许多人在中文文档官网上看到了.不过目前校对还缺人手 ...
- 使用pytorch创建神经网络并解决线性拟合和分类问题
#线性拟合 import torch import torch.nn.functional as F import matplotlib.pyplot as plt torch.manual_seed ...
- pytorch神经网络因素预测_实战:使用PyTorch构建神经网络进行房价预测
微信公号:ilulaoshi / 个人网站:lulaoshi.info 本文将学习一下如何使用PyTorch创建一个前馈神经网络(或者叫做多层感知机,Multiple-Layer Perceptron ...
- 基于 PyTorch 和神经网络给 GirlFriend 制作漫画风头像
摘要:本文中我们介绍的 AnimeGAN 就是 GitHub 上一款爆火的二次元漫画风格迁移工具,可以实现快速的动画风格迁移. 本文分享自华为云社区<AnimeGANv2 照片动漫化:如何基于 ...
- 使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类
文章目录 1 数据准备 2 数据预览 3 简单神经网络创建 3.1 设计网络结构 3.2 损失函数相关设置 3.3 向网络传递信息 3.4 定义训练函数train 4 函数汇总 1 数据准备 神经网络 ...
- PyTorch搭建神经网络求解二分类问题
PyTorch搭建全连接神经网络求解二分类问题 在求解线性回归问题的时候,我们已经学习了如何使用梯度下降算法来不断更新权重矩阵,使误差函数不断减小,这一节我们将使用PyTorch搭建一个简单的神经网络 ...
- 【深度学习】基于Torch的Python开源机器学习库PyTorch卷积神经网络
[深度学习]基于Torch的Python开源机器学习库PyTorch卷积神经网络 文章目录 1 CNN概述 2 PyTorch实现步骤2.1 加载数据2.2 CNN模型2.3 训练2.4 可视化训练 ...
- matlab神经网络工具箱创建神经网络,matlab神经网络工具箱创建神经网络
matlab神经网络工具箱创建神经网络 为了看懂师兄的文章中使用的方法,研究了一下神经网络 昨天花了一天的时间查怎么写程序,但是费了半天劲,不能运行,百度知道里倒是有一个,可以运行的,先贴着做标本 % ...
- eclipse创建神经网络_使用Eclipse Deeplearning4j构建简单的神经网络
eclipse创建神经网络 神经网络导论 深度学习包含深度神经网络和深度强化学习,它们是机器学习的子集,而机器学习本身就是人工智能的子集. 广义地说,深度神经网络执行机器感知,该机器感知从原始数据中提 ...
最新文章
- HDU1162(Prim算法)
- SAP MM 自定义条件类型出现在采购信息记录的'条件'界面里 ?
- Apache2.2+tomcat7 负载均衡配置
- Dependency injection in ASP.NET Core
- BZOJ4475: [Jsoi2015]子集选取【找规律】【数学】
- 【做题记录】人类智慧
- SpringMVC错误:No mapping found for HTTP request with URI [xxxx] in DispatcherServlet
- (07)FPGA基本组成单元
- Ubuntu 1604 升级 1804 记录
- CAM350简单编辑gerber文件(【增加一条线】 【复制元素】 【删除元素】)
- Linux的操作系统原理详解
- 基于ObjectArx进行cad二次开发总结
- 计算机考研 机械设计,2019考研专业:机械设计制造及其自动化
- 使用Python进行文件快速搜索(建立文件搜索索引)
- 笔记本光驱位固态硬盘安装及系统迁移
- codeforces 863B Kayaking
- 一文理解UDS安全访问服务(0x27)
- IntelliJ IDEA 创建Spring+SpringMVC+hibernate+maven项目
- 据说优秀的程序员都是这样送新年祝福的?
- 想备战 2022 ‘金三银四’ 必备超多软件测试面试题全在这里
热门文章
- Python文件操作与matplotlib数据可视化案例一则
- 使用Python+turtle绘制动画重现龟兔赛跑现场
- Python把汉字转换成拼音
- c++ vector常用用法总结
- 中职计算机英语教师教学总结,中职计算机教师教学工作总结 (3000字).doc
- linux终端打开浏览器_终端可以放电影,一行代码就能实现
- 力扣35,搜索插入位置(JavaScript)
- 华住数据库_华住内控人系列故事(四)技术领先篇——搭建大数据风险数据仓,实现自助取数...
- hive3新增資料_Hive表新增字段后,新字段无法写入值问题总结
- mysql查询只能是等式连接_mysql连接查询