最近在学习pytorch的内容,写几篇博客记录一下学习过程。

熟悉pytorch的朋友知道nn.Module是nn中最重要的类,可以把它看作一个网络的封装,包含网络各层定义及forward方法,调用forward(input)方法,可返回前向传播的结果。

从最早的卷积神经网络LeNet为例,看看如何用nn.Module实现。LeNet的网络结构如图所示。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import  Variableclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 1代表输入图片未单通道,6表示输出通道数,5表示卷积核为5*5self.conv1 = nn.Conv2d(1, 6, 5) # 这里只能写三个参数,但是要理解输入其实是4维的,也必须是4维的# 卷积层self.conv2 = nn.Conv2d(6, 16, 5)# 仿射层/全连接层,y=Wx+bself.fc1 = nn.Linear(16*5*5, 120)print(666)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)# self.fc4 = nn.Linear(10,1) #zhouzdef forward(self, x):# 卷积->激活->池化x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)# reshape '-1'表示自适应x = x.view(x.size()[0], -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))# x = F.relu(self.fc3(x)) #zhouz 为了与原书中代码保持一致,暂时改回x = self.fc3(x)return x# print(7777)net = Net().cuda()
# print(net)# '测试输入为tensor时,输出size为(1,10)'
# input = Variable(torch.arange(0,10))
input = Variable(torch.randn(1, 1, 32, 32)).cuda()
# out = net(input)
# print(out)
# print(out.size())# '第一次测试结束'# '第二次测试开始'# for name,parameters in net.named_parameters():
#     print(name,':',parameters.size())# '第二次测试结束'# '第三次测试开始' 输出网络的可学习参数params = list(net.parameters())
print(len(params))# '第三次测试结束'# '第四次测试开始'
# net.zero_grad()
# out.backward(Variable(torch.ones(1,10))) # 反向传播
# '第四次测试结束'output = net(input)
target = Variable(torch.arange(1, 11)).float().cuda()
print(output.type == target.type)
print('output', output)
# print(target)criterion = nn.MSELoss()
loss = criterion(output, target)
print(loss)net.zero_grad()
print('反向传播之前的conv1.bias的梯度:')
print(net.conv1.bias.grad)
loss.backward()print('反向传播之后conv1.bias的梯度:')
print(net.conv1.bias.grad)# 在cpu跑会出错,在gpu跑能运行

调试的时候遇到了arange和range的问题,参考以下链接:https://www.jianshu.com/p/438c09259220

【12.27补充】 我在cpu上测试会出问题,拿到ubuntu上跑代码能正常运行,要改的地方包括float()和cuda(),改了之后就能直接运行了。

【Pytorch】LeNet的pytorch写法相关推荐

  1. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  2. Pytorch LeNet 3:网络输出可视化

    修改网络实现方式 为了是的LeNet可视化,我们需要修改下Pytorch/LeNet/LeNet.py的实现,将前面的卷积池化层,和后面的全连接层分开,便于可以独立获取卷积核池化层的所有特征,实现后的 ...

  3. Pytorch学习 - Task5 PyTorch卷积层原理和使用

    Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...

  4. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  5. 【pytorch速成】Pytorch图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...

  6. python pytorch fft_看PyTorch源代码的心路历程

    1. 起因 曾经碰到过别人的模型prelu在内部的推理引擎算出的结果与其在原始框架PyTorch中不一致的情况,虽然理论上大家实现的都是一个算法,但是从参数上看,因为经过了模型转换,中间做了一些调整. ...

  7. pytorch forward_【Pytorch部署】TorchScript

    TorchScript是什么? TorchScript - PyTorch master documentation​pytorch.org TorchScript是一种从PyTorch代码创建可序列 ...

  8. PyTorch系列 (二): pytorch数据读取自制数据集并

    PyTorch系列 (二): pytorch数据读取 PyTorch 1: How to use data in pytorch Posted by WangW on February 1, 2019 ...

  9. PyTorch入门(一)--PyTorch基础

    PyTorch基础 1. PyTorch与TensorFlow的区别 2. PyTorch基本数学形式 3. 关于Tensor 1. PyTorch与TensorFlow的区别 PyTorch和Ten ...

  10. PyTorch学习记录——PyTorch生态

    Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...

最新文章

  1. [蓝桥杯][历届试题]网络寻路(DFS)
  2. http --- 公开密钥加密(非对称加密)的几个概念
  3. vue delete删除json数组_vue面试题总结(二)
  4. android启动其他app的服务器,Android中通过外部程序启动App的三种方法
  5. python独立图形_在networkx中查找图形对象中的独立图形
  6. 每天Leetcode 刷题 初级算法篇-位1的个数
  7. ImportError: No module named ‘numpy‘的解决办法
  8. OpenCV-Python教程(10、直方图均衡化)
  9. 炮灰模型:对女生选择追求者的数学模型的建立-转
  10. LVGL使用华为鸿蒙字体
  11. 在线答题助手c语言源码,开源的在线答题小程序
  12. 山寨智能机多采用盗版Windows Mobile系统
  13. 利用WinEdt修改图片格式为eps
  14. 如何批量提取多个 PDF 文档中的图片
  15. 免费微信登陆界面html模板,微信小程序:使用微信授权登录以及页面模板
  16. java调用存储过程 sql server,Sql Server的存储过程与Java代码相连接调用(二)
  17. 我的青春谁做主经典台词
  18. matlab调用海康威视摄像头_招聘|海康威视招聘一批算法、图像等AI工程师
  19. Mycat概述及基本使用
  20. 一篇文章让你彻底理解java中抽象类和接口

热门文章

  1. 第八届“图灵杯”NEUQ-ACM程序设计竞赛个人赛——D题 Seek the Joker I
  2. 洛谷 P1983 [NOIP2013 普及组] 车站分级
  3. 实验6.2 定义一个基类BaseClass,观察构造函数和析构函数的执行情况。
  4. 《南溪的目标检测学习笔记》——主干网络backbone设计的学习笔记
  5. R语言正则表达式[stringr package]
  6. 给你出道题---N个数字的静态决策区分问题
  7. LeetCode 340. 至多包含 K 个不同字符的最长子串
  8. Unicode和ASCII的区别
  9. java.lang.InstantiationException:
  10. Python初学者的资源总结