一,nn.functional 和 nn.Module

前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API。

利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模型层,损失函数)。

Pytorch和神经网络相关的功能组件大多都封装在 torch.nn模块下。

这些功能组件的绝大部分既有函数形式实现,也有类形式实现。

其中nn.functional(一般引入后改名为F)有各种功能组件的函数实现。例如:

(激活函数)

  • F.relu
  • F.sigmoid
  • F.tanh
  • F.softmax

(模型层)

  • F.linear
  • F.conv2d
  • F.max_pool2d
  • F.dropout2d
  • F.embedding

(损失函数)

  • F.binary_cross_entropy
  • F.mse_loss
  • F.cross_entropy

为了便于对参数进行管理,一般通过继承 nn.Module 转换成为类的实现形式,并直接封装在 nn 模块下。例如:

(激活函数)

  • nn.ReLU
  • nn.Sigmoid
  • nn.Tanh
  • nn.Softmax

(模型层)

  • nn.Linear
  • nn.Conv2d
  • nn.MaxPool2d
  • nn.Dropout2d
  • nn.Embedding

(损失函数)

  • nn.BCELoss
  • nn.MSELoss
  • nn.CrossEntropyLoss

实际上nn.Module除了可以管理其引用的各种参数,还可以管理其引用的子模块,功能十分强大。

二,使用nn.Module来管理参数

在Pytorch中,模型的参数是需要被优化器训练的,因此,通常要设置参数为 requires_grad = True 的张量。

同时,在一个模型中,往往有许多的参数,要手动管理这些参数并不是一件容易的事情。

Pytorch一般将参数用nn.Parameter来表示,并且用nn.Module来管理其结构下的所有参数。

import torch
from torch import nn
import torch.nn.functional  as F
from matplotlib import pyplot as plt# nn.Parameter 具有 requires_grad = True 属性
w = nn.Parameter(torch.randn(2,2))
print(w)
print(w.requires_grad)

# nn.ParameterList 可以将多个nn.Parameter组成一个列表
params_list = nn.ParameterList([nn.Parameter(torch.rand(8,i)) for i in range(1,3)])
print(params_list)
print(params_list[0].requires_grad)

# nn.ParameterDict 可以将多个nn.Parameter组成一个字典params_dict = nn.ParameterDict({"a":nn.Parameter(torch.rand(2,2)),"b":nn.Parameter(torch.zeros(2))})
print(params_dict)
print(params_dict["a"].requires_grad)
ParameterDict((a): Parameter containing: [torch.FloatTensor of size 2x2](b): Parameter containing: [torch.FloatTensor of size 2]
)
True
# 可以用Module将它们管理起来
# module.parameters()返回一个生成器,包括其结构下的所有parametersmodule = nn.Module()
module.w = w
module.params_list = params_list
module.params_dict = params_dictnum_param = 0
for param in module.parameters():print(param,"\n")num_param = num_param + 1
print("number of Parameters =",num_param)

#实践当中,一般通过继承nn.Module来构建模块类,并将所有含有需要学习的参数的部分放在构造函数中。#以下范例为Pytorch中nn.Linear的源码的简化版本
#可以看到它将需要学习的参数放在了__init__构造函数中,并在forward中调用F.linear函数来实现计算逻辑。class Linear(nn.Module):__constants__ = ['in_features', 'out_features']def __init__(self, in_features, out_features, bias=True):super(Linear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = nn.Parameter(torch.Tensor(out_features, in_features))if bias:self.bias = nn.Parameter(torch.Tensor(out_features))else:self.register_parameter('bias', None)def forward(self, input):return F.linear(input, self.weight, self.bias)

三,使用nn.Module来管理子模块

一般情况下,我们都很少直接使用 nn.Parameter来定义参数构建模型,而是通过一些拼装一些常用的模型层来构造模型。

这些模型层也是继承自nn.Module的对象,本身也包括参数,属于我们要定义的模块的子模块。

nn.Module提供了一些方法可以管理这些子模块。

  • children() 方法: 返回生成器,包括模块下的所有子模块。

  • named_children()方法:返回一个生成器,包括模块下的所有子模块,以及它们的名字。

  • modules()方法:返回一个生成器,包括模块下的所有各个层级的模块,包括模块本身。

  • named_modules()方法:返回一个生成器,包括模块下的所有各个层级的模块以及它们的名字,包括模块本身。

其中chidren()方法和named_children()方法较多使用。

modules()方法和named_modules()方法较少使用,其功能可以通过多个named_children()的嵌套使用实现。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.embedding = nn.Embedding(num_embeddings = 10000,embedding_dim = 3,padding_idx = 1)self.conv = nn.Sequential()self.conv.add_module("conv_1",nn.Conv1d(in_channels = 3,out_channels = 16,kernel_size = 5))self.conv.add_module("pool_1",nn.MaxPool1d(kernel_size = 2))self.conv.add_module("relu_1",nn.ReLU())self.conv.add_module("conv_2",nn.Conv1d(in_channels = 16,out_channels = 128,kernel_size = 2))self.conv.add_module("pool_2",nn.MaxPool1d(kernel_size = 2))self.conv.add_module("relu_2",nn.ReLU())self.dense = nn.Sequential()self.dense.add_module("flatten",nn.Flatten())self.dense.add_module("linear",nn.Linear(6144,1))self.dense.add_module("sigmoid",nn.Sigmoid())def forward(self,x):x = self.embedding(x).transpose(1,2)x = self.conv(x)y = self.dense(x)return ynet = Net()i = 0
for child in net.children():i+=1print(child,"\n")
print("child number",i)

下面我们通过named_children方法找到embedding层,并将其参数设置为不可训练(相当于冻结embedding层)

children_dict = {name:module for name,module in net.named_children()}print(children_dict)
embedding = children_dict["embedding"]
embedding.requires_grad_(False) #冻结其参数
#可以看到其第一层的参数已经不可以被训练了。
for param in embedding.parameters():print(param.requires_grad)print(param.numel())
False
30000
from torchkeras import summary
summary(net,input_shape = (200,),input_dtype = torch.LongTensor)
# 不可训练参数数量增加

速成pytorch学习——5天nn.functional 和 nn.Module相关推荐

  1. nn.functional 和 nn.Module入门讲解

    本文来自<20天吃透Pytorch> 一,nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API. 利用这些张量的AP ...

  2. 速成pytorch学习——7天模型层layers

    深度学习模型一般由各种模型层组合而成. torch.nn中内置了非常丰富的各种模型层.它们都属于nn.Module的子类,具备参数管理功能. 例如: nn.Linear, nn.Flatten, nn ...

  3. 速成pytorch学习——8天损失函数

    一般来说,监督学习的目标函数由损失函数和正则化项组成.(Objective = Loss + Regularization) Pytorch中的损失函数一般在训练模型时候指定. 注意Pytorch中内 ...

  4. 速成pytorch学习——4天中阶API示范

    使用Pytorch的中阶API实现线性回归模型和和DNN二分类模型. Pytorch的中阶API主要包括各种模型层,损失函数,优化器,数据管道等等. 一,线性回归模型 1,准备数据 import nu ...

  5. 速成pytorch学习——1天

    一.Pytorch的建模流程 使用Pytorch实现神经网络模型的一般流程包括: 1,准备数据 2,定义模型 3,训练模型 4,评估模型 5,使用模型 6,保存模型. 对新手来说,其中最困难的部分实际 ...

  6. 速成pytorch学习——11天. 使用GPU训练模型

    深度学习的训练过程常常非常耗时,一个模型训练几个小时是家常便饭,训练几天也是常有的事情,有时候甚至要训练几十天. 训练过程的耗时主要来自于两个部分,一部分来自数据准备,另一部分来自参数迭代. 当数据准 ...

  7. 速成pytorch学习——10天.训练模型的3种方法

    Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异. 有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环. 下面以minist数据集的分类模型的训练 ...

  8. 速成pytorch学习——6天Dataset和DataLoader

    Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道. Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素. ...

  9. 速成pytorch学习——3天自动微分机制

    神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情. 而深度学习框架可以帮助我们自动地完成这种求梯度运算. Pytorch一般通过反向传播 backward 方 ...

最新文章

  1. 正在通过iTunes Store 进行鉴定
  2. ASP.NET常见错误,原因及解决方法(2003版)_不断更新.....
  3. 2020-12-09 深度学习 卷积核/过滤器、特征图(featue map)、卷积层
  4. dbscan算法_DBSCAN聚类算法探索
  5. python n个人围成一圈,Python练习代码实例69-有n个人围成一圈,顺序排号。从第一个人开始报数(从1到3报数),凡报到3的人退出圈子,问最后留下的...
  6. kylin与superset集成实现数据可视化
  7. 07-OSPF区域类型--NSSA区域/完全NSSA区域
  8. 8类网线利弊_知识积累 | 千兆网线和百兆网线有何区别?
  9. 随机过程中的功率谱密度
  10. 810B - 牛人是如何工作的
  11. java对列_JAVA实现EXCEL行列号解析(一)——单格解析
  12. 前端随心记---------WebSocket
  13. 微信小程序简单爱心点赞动画
  14. 淘客菜鸟百度贴吧怎么发帖子操作淘宝客
  15. 计算机中c盘是什么分区,电脑C盘怎么分区
  16. reincarnation server
  17. HDWiki软件包结构
  18. Tp5 实现 think-queue 队列操作
  19. CentOS如何拓展swap分区
  20. ae中计算机打字预设,Typewriter Pro(AE电脑打字动画特效预设)

热门文章

  1. 谷歌趋势:“比特币”热度远不及2017年高点
  2. SAP License:第三只眼看财务-现金流量表编制
  3. 智慧发电厂+智能发电厂web端平台管理系统+Axure高保真智慧电厂系统+能耗管理+告警管理+生产监控+安防设备管理+运维设备管理+监控面板+系统管理+智慧电厂+电厂系统+axure源文件+rp原型
  4. MapInfo格式到ArcInfo格式的转换
  5. Bootstrap学习(一)
  6. 血眼龙王萧沙传-翠花篇
  7. poj 1005 I Think I Need a Houseboat
  8. managed code和unmanaged code混合debug
  9. ASP.NET深入浅出系列3- Page类
  10. 8.18 NOIP模拟测试25(B) 字符串+乌鸦喝水+所驼门王的宝藏