PyTorch的nn.Linear()详解
1. nn.Linear()
nn.Linear():用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
一般形状为[batch_size, size],不同于卷积层要求输入输出是四维张量。其用法与形参说明如下:
in_features
指的是输入的二维张量的大小,即输入的[batch_size, size]
中的size。out_features
指的是输出的二维张量的大小,即输出的二维张量的形状为[batch_size,output_size]
,当然,它也代表了该全连接层的神经元个数。从输入输出的张量的shape角度来理解,相当于一个输入为
[batch_size, in_features]
的张量变换成了[batch_size, out_features]
的输出张量。
用法示例:
import torch as t
from torch import nn
from torch.nn import functional as F# 假定输入的图像形状为[3,64,64]
x = t.randn(10, 3, 64, 64) # 10张 3个channel 大小为64x64的图片x = nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0)(x)
print(x.shape)# 之前的特征图尺寸为多少,只要设置为(1,1),那么最终特征图大小都为(1,1)
# x = F.adaptive_avg_pool2d(x, [1,1]) # [b, 64, h, w] => [b, 64, 1, 1]
# print(x.shape)# 将四维张量转换为二维张量之后,才能作为全连接层的输入
x = x.view(x.size(0), -1)
print(x.shape)# in_features由输入张量的形状决定,out_features则决定了输出张量的形状
connected_layer = nn.Linear(in_features = 64*21*21, out_features = 10)# 调用全连接层
output = connected_layer(x)
print(output.shape)
torch.Size([10, 64, 21, 21])
torch.Size([10, 28224])
torch.Size([10, 10])
PyTorch的nn.Linear()详解相关推荐
- torch.nn.Linear详解
在学习transformer时,遇到过非常频繁的nn.Linear()函数,这里对nn.Linear进行一个详解. 参考:https://pytorch.org/docs/stable/_module ...
- PyTorch中的torch.nn.Parameter() 详解
PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...
- 【小白学PyTorch】12.SENet详解及PyTorch实现
<<小白学PyTorch>> 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小白学PyTorch | 10 pytorch常见运算详解 小白学Py ...
- 【小白学PyTorch】11.MobileNet详解及PyTorch实现
<<小白学PyTorch>> 小白学PyTorch | 10 pytorch常见运算详解 小白学PyTorch | 9 tensor数据结构与存储结构 小白学PyTorch | ...
- 【nn.LSTM详解】
参数详解 nn.LSTM是pytorch中的模块函数,调用如下: torch.nn.lstm(input_size,hidden_size,num_layers,bias,batch_first,dr ...
- [pytorch]yolov3.cfg参数详解(每层输出及route、yolo、shortcut层详解)
文章目录 Backbone(Darknet53) 第一次下采样(to 208) 第二次下采样(to 104) 第三次下采样(to 52) 第四次下采样(to 26) 第五次下采样(to 13) YOL ...
- 【Gans入门】Pytorch实现Gans代码详解【70+代码】
简述 由于科技论文老师要求阅读Gans论文并在网上找到类似的代码来学习. 文章目录 简述 代码来源 代码含义概览 代码分段解释 导入包: 设置参数: 给出标准数据: 构建模型: 构建优化器 迭代细节 ...
- 【小白学PyTorch】13.EfficientNet详解及PyTorch实现
<<小白学PyTorch>> 小白学PyTorch | 12 SENet详解及PyTorch实现 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小 ...
- pytorch中resnet_ResNet代码详解
代码学习第一天! fighting! import torch.nn as nn import math import torch.utils.model_zoo as model_zoo# 这个文件 ...
最新文章
- 优化案例(part5)--sparse subspace clustering via Low-Rank structure propagation
- jooq 入门_jOOQ,H2和Maven入门
- Exploring Pyramids【动态规划——区间DP】
- 13个美国大学生最常用的社交网络
- latext配置 vscode_新手关于在VScode上配置latex的事情
- .htaccess 基础教程(四)Apache RewriteCond 规则参数
- JDBC连接效率问题
- 第四章 可靠的请求-应答模式
- 国内顶尖团队的开源地址
- mysql定义过程_mysql定义和调用存储过程
- Gym - 102394I Interesting Permutation(思维)
- creo2.0+VS2010采用protoolkit二次开发环境配置(64位win7)
- python应用开发实战第一章 兽人之袭0.0.1
- errno!=EINTR是什么意思
- 在线tcp测试,tcp测试
- electron学习
- 物联网硬件安全分析基础-硬件分析初探
- 快速得到 Word2007 的 Docx 或 Docm 文档中的图片
- ImageMagic for win
- oracle网络认证,Oracle网络应用开发人员认证简介
热门文章
- Android ImageView的scaleType(图片比例类型)属性与adjustViewBounds(调整视图边界)属性
- 启动FastDFS服务,使用python客户端对接fastdfs完成上传测试
- android 手动回收对象,Android Studio Studio回收列表中的JSON对象
- 计算机应用基础精品课程申报表,《计算机应用基础》精品课程申报书(修改意见)...
- 九江机器人餐厅_机器人精通200道佳肴 九江学院来了多位机器厨神
- python数组的乘法_在Python中乘法非常大的2D数组
- python进程通信方式有几种_python全栈开发基础【第二十一篇】互斥锁以及进程之间的三种通信方式(IPC)以及生产者个消费者模型...
- JavaScript中的元素获取与操作
- Pycharm新建文件时自动添加基础信息
- python中函数的参数传递(传值还是传引用)