1 函数的增益值

torch.nn.init.calculate_gain(nonlinearity, param=None)提供了对非线性函数增益值的计算。

增益值gain是一个比例值,来调控输入数量级和输出数量级之间的关系。

常见的非线性函数的增益值(gain)有:

2 fan_in和fan_out

以下是pytorch计算fan_in和fan_out的源码

def _calculate_fan_in_and_fan_out(tensor):dimensions = tensor.ndimension()if dimensions < 2:raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")#如果tensor的维度小于两维,那么报错if dimensions == 2:  # Linearfan_in = tensor.size(1)fan_out = tensor.size(0)else:num_input_fmaps = tensor.size(1)num_output_fmaps = tensor.size(0)receptive_field_size = 1if tensor.dim() > 2:receptive_field_size = tensor[0][0].numel()#tensor[0][0].numel():tensor[0][0]元素的个数fan_in = num_input_fmaps * receptive_field_sizefan_out = num_output_fmaps * receptive_field_sizereturn fan_in, fan_out
  • 对于全连接层,fan_in是输入维度,fan_out是输出维度;
  • 对于卷积层,设其维度为,其中H × W为kernel规模。则fan_in是,fan_out是

3 Xavier初始化

xavier初始化可以使得输入值x的方差和经过网络层后的输出值y的方差一致。

3.1 xavier均匀分布

torch.nn.init.xavier_uniform_(tensor,gain=1)

填充一个tensor,使得这个tensor满足

其中

import torch
w = torch.empty(3, 5)
torch.nn.init.xavier_uniform_(w, gain=torch.nn.init.calculate_gain('relu'))
w
'''
tensor([[-0.3435, -0.4432,  0.1063,  0.6324,  0.3240],[ 0.6966,  0.6453, -1.0706, -0.9017, -1.0325],[ 1.2083,  0.5733,  0.7945, -0.6761, -0.9595]])
'''

3.2 xavier正态分布

torch.nn.init.xavier_normal_(tensor, gain=1)

填充一个tensor,使得这个tensor满足
其中,std满足

import torch
w = torch.empty(3, 5)
torch.nn.init.xavier_normal_(w, gain=torch.nn.init.calculate_gain('relu'))
w
'''
tensor([[ 0.2522, -1.3404, -0.7371, -0.0280, -0.9147],[-0.1330, -1.4434, -0.2913, -0.1084, -0.9338],[ 0.8631,  0.1877,  0.8003, -0.0865,  0.9891]])
'''

4 Kaiming 分布

Xavier在tanh中表现的很好,但在Relu激活函数中表现的很差,所何凯明提出了针对于relu的初始化方法。

pytorch默认使用kaiming正态分布初始化卷积层参数

4.1 kaiming均匀分布

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

填充一个tensor,使得这个tensor满足U(−bound,bound)

其中,bound满足

a

激活函数的负斜率(对于leaky_relu来说)

如果激活函数是relu的话,a为0

mode

默认为fan_in模式,可以设置为fan_out模式

fan_in可以保持前向传播的权重方差的数量级,fan_out可以保持反向传播的权重方差的数量级

import torch
w = torch.empty(3, 5)
torch.nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')'''
tensor([[ 0.8828,  0.0301,  0.9511, -0.0795, -0.9353],[ 1.0642,  0.8425,  0.1968,  0.9409, -0.7710],[ 0.3363,  0.9057, -0.1552,  0.5057,  1.0035]])
'''import torch
w = torch.empty(3, 5)
torch.nn.init.kaiming_uniform_(w, mode='fan_out', nonlinearity='relu')
w
'''
tensor([[-0.0280, -0.5491, -0.4809, -0.3452, -1.1690],[-1.1383,  0.6948, -0.3656,  0.8951, -0.3688],[ 0.4570, -0.5588, -1.0084, -0.8209,  1.1934]])
'''

4.2 kaiming正态分布

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

参数的意义同4.1 kaiming均匀分布

填充一个tensor,使得这个tensor满足
其中,std满足

import torch
w = torch.empty(3, 5)
torch.nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
w
'''
tensor([[ 0.9705,  1.6935, -0.4261,  1.1065,  1.0238],[-0.3599, -0.8007,  1.3687,  0.1199,  0.4077],[ 0.5240, -0.5721, -0.2794,  0.3618, -1.1206]])
'''

pytorch学习:xavier分布和kaiming分布相关推荐

  1. pytorch 学习: STGCN

    1 main.ipynb 1.1 导入库 import random import torch import numpy as np import pandas as pd from sklearn. ...

  2. pytorch 学习笔记目录

    1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...

  3. Pytorch 文本数据分析方法(标签数量分布、句子长度分布、词频统计、关键词词云)、文本特征处理(n-gram特征、文本长度规范)、文本数据增强(回译数据增强法)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 文本数据分析 学习目标: 了解文本数据分析的作用. 掌握常用的 ...

  4. “强化学习说白了,是建立分布到分布之间的映射”?数学角度谈谈个人观点

    简介:F学长是我数模竞赛.科研方法道路上的最重要的启蒙人之一. 去年他成功进入清华大学.巧的是,他的研究方向也是强化学习. 疫情期间,我们打过好几轮长长的电话,讨论强化学习,其中给我印象最为深刻的是, ...

  5. 深度学习:用于multinoulli输出分布的softmax单元

    首先说明Bernoulli分布对应sigmoid单元,Multinoulli分布对应softmax单元.了解multinoulli分布请看:机器学习:Multinoulli分布与多项式分布. soft ...

  6. MNL——多项Logit模型学习笔记(三)二项Logit模型、Gumble分布以及Logistic分布

    上一节最后一部分,介绍了Provit模型,从建模的角度来说,Probit模型假设随机项服从正态分布,这是具有一定的合理性的--也是其优点:但是Probit模型没有闭合解--每次算P(n)i 的值的时候 ...

  7. Lesson 13.5 Xavier方法与kaiming方法(HE初始化)

    Lesson 13.5 Xavier方法与kaiming方法(HE初始化)   在进行了一系列的理论推导和代码准备工作之后,接下来,我们介绍参数初始化优化方法,也就是针对tanh和Sigmoid激活函 ...

  8. 【Pytorch学习笔记2】Pytorch的主要组成模块

    个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...

  9. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

最新文章

  1. Meta祭出元宇宙「阿拉丁神灯」!LeCun称世界模型将带来像人一样的AI
  2. 如何在NLP领域应用卷积神经网络CNN
  3. vc mysql utf8_C/C++ 连接 MySQL (VC 版)
  4. Linux网络服务:Samba服务与实现
  5. hadoop 学习笔记:mapreduce框架详解
  6. 用ul li实现边框重合并附带鼠标经过效果
  7. linux常用网络命令
  8. PyQt5应用与实践
  9. 封装js千分位加逗号和删除逗号
  10. 陈一舟:我们花了大力气找合适团队接力人人网 任务完成
  11. spring 集成mybatis——多数据源切换(附带定时器的配置)
  12. linux 安装apache
  13. php发送sql,php学习笔记(二)php与mysql连接与用php发送SQL查询
  14. SqlServer2008 R2 自动备份和自动清除过期备份
  15. Pawn Storm网络间谍行动再度现身
  16. php留言板系统制作,php制作留言板讲解
  17. LaTex 英文期刊论文模板
  18. matlab 多项式排序,MATLAB多项式
  19. 电脑一会,电脑一会黑屏一会正常怎么回事
  20. OpenLayers3基础教程——OL3 介绍control

热门文章

  1. iOS原生地图与高德地图的使用
  2. Filter的详解与配置应用
  3. Ubuntu共享WiFi(AP)给Android方法
  4. 如何在Visual Studio 2010中使用CppUTest建立TDD的Code Kata的环境
  5. UNPIVOT的详细说明
  6. C语言SHELL排序算法
  7. 802.11 参考手册
  8. Python logging动态调整日志等级
  9. Socket连接外网的思考
  10. 服务器拒绝接收office文件,Ghost Win7系统下Outlook设置拒绝接收垃圾文件的方法