pytorch学习:xavier分布和kaiming分布
1 函数的增益值
torch.nn.init.calculate_gain(nonlinearity, param=None)提供了对非线性函数增益值的计算。
增益值gain是一个比例值,来调控输入数量级和输出数量级之间的关系。
常见的非线性函数的增益值(gain)有:
2 fan_in和fan_out
以下是pytorch计算fan_in和fan_out的源码
def _calculate_fan_in_and_fan_out(tensor):dimensions = tensor.ndimension()if dimensions < 2:raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")#如果tensor的维度小于两维,那么报错if dimensions == 2: # Linearfan_in = tensor.size(1)fan_out = tensor.size(0)else:num_input_fmaps = tensor.size(1)num_output_fmaps = tensor.size(0)receptive_field_size = 1if tensor.dim() > 2:receptive_field_size = tensor[0][0].numel()#tensor[0][0].numel():tensor[0][0]元素的个数fan_in = num_input_fmaps * receptive_field_sizefan_out = num_output_fmaps * receptive_field_sizereturn fan_in, fan_out
- 对于全连接层,fan_in是输入维度,fan_out是输出维度;
- 对于卷积层,设其维度为,其中H × W为kernel规模。则fan_in是,fan_out是
3 Xavier初始化
xavier初始化可以使得输入值x的方差和经过网络层后的输出值y的方差一致。
3.1 xavier均匀分布
torch.nn.init.xavier_uniform_(tensor,gain=1)
填充一个tensor,使得这个tensor满足
其中
import torch
w = torch.empty(3, 5)
torch.nn.init.xavier_uniform_(w, gain=torch.nn.init.calculate_gain('relu'))
w
'''
tensor([[-0.3435, -0.4432, 0.1063, 0.6324, 0.3240],[ 0.6966, 0.6453, -1.0706, -0.9017, -1.0325],[ 1.2083, 0.5733, 0.7945, -0.6761, -0.9595]])
'''
3.2 xavier正态分布
torch.nn.init.xavier_normal_(tensor, gain=1)
填充一个tensor,使得这个tensor满足
其中,std满足
import torch
w = torch.empty(3, 5)
torch.nn.init.xavier_normal_(w, gain=torch.nn.init.calculate_gain('relu'))
w
'''
tensor([[ 0.2522, -1.3404, -0.7371, -0.0280, -0.9147],[-0.1330, -1.4434, -0.2913, -0.1084, -0.9338],[ 0.8631, 0.1877, 0.8003, -0.0865, 0.9891]])
'''
4 Kaiming 分布
Xavier在tanh中表现的很好,但在Relu激活函数中表现的很差,所何凯明提出了针对于relu的初始化方法。
pytorch默认使用kaiming正态分布初始化卷积层参数。
4.1 kaiming均匀分布
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
填充一个tensor,使得这个tensor满足U(−bound,bound)
其中,bound满足
a |
激活函数的负斜率(对于leaky_relu来说) 如果激活函数是relu的话,a为0 |
mode |
默认为fan_in模式,可以设置为fan_out模式 fan_in可以保持前向传播的权重方差的数量级,fan_out可以保持反向传播的权重方差的数量级 |
import torch
w = torch.empty(3, 5)
torch.nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')'''
tensor([[ 0.8828, 0.0301, 0.9511, -0.0795, -0.9353],[ 1.0642, 0.8425, 0.1968, 0.9409, -0.7710],[ 0.3363, 0.9057, -0.1552, 0.5057, 1.0035]])
'''import torch
w = torch.empty(3, 5)
torch.nn.init.kaiming_uniform_(w, mode='fan_out', nonlinearity='relu')
w
'''
tensor([[-0.0280, -0.5491, -0.4809, -0.3452, -1.1690],[-1.1383, 0.6948, -0.3656, 0.8951, -0.3688],[ 0.4570, -0.5588, -1.0084, -0.8209, 1.1934]])
'''
4.2 kaiming正态分布
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
参数的意义同4.1 kaiming均匀分布
填充一个tensor,使得这个tensor满足
其中,std满足
import torch
w = torch.empty(3, 5)
torch.nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
w
'''
tensor([[ 0.9705, 1.6935, -0.4261, 1.1065, 1.0238],[-0.3599, -0.8007, 1.3687, 0.1199, 0.4077],[ 0.5240, -0.5721, -0.2794, 0.3618, -1.1206]])
'''
pytorch学习:xavier分布和kaiming分布相关推荐
- pytorch 学习: STGCN
1 main.ipynb 1.1 导入库 import random import torch import numpy as np import pandas as pd from sklearn. ...
- pytorch 学习笔记目录
1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...
- Pytorch 文本数据分析方法(标签数量分布、句子长度分布、词频统计、关键词词云)、文本特征处理(n-gram特征、文本长度规范)、文本数据增强(回译数据增强法)
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 文本数据分析 学习目标: 了解文本数据分析的作用. 掌握常用的 ...
- “强化学习说白了,是建立分布到分布之间的映射”?数学角度谈谈个人观点
简介:F学长是我数模竞赛.科研方法道路上的最重要的启蒙人之一. 去年他成功进入清华大学.巧的是,他的研究方向也是强化学习. 疫情期间,我们打过好几轮长长的电话,讨论强化学习,其中给我印象最为深刻的是, ...
- 深度学习:用于multinoulli输出分布的softmax单元
首先说明Bernoulli分布对应sigmoid单元,Multinoulli分布对应softmax单元.了解multinoulli分布请看:机器学习:Multinoulli分布与多项式分布. soft ...
- MNL——多项Logit模型学习笔记(三)二项Logit模型、Gumble分布以及Logistic分布
上一节最后一部分,介绍了Provit模型,从建模的角度来说,Probit模型假设随机项服从正态分布,这是具有一定的合理性的--也是其优点:但是Probit模型没有闭合解--每次算P(n)i 的值的时候 ...
- Lesson 13.5 Xavier方法与kaiming方法(HE初始化)
Lesson 13.5 Xavier方法与kaiming方法(HE初始化) 在进行了一系列的理论推导和代码准备工作之后,接下来,我们介绍参数初始化优化方法,也就是针对tanh和Sigmoid激活函 ...
- 【Pytorch学习笔记2】Pytorch的主要组成模块
个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...
- 深度学习入门之PyTorch学习笔记:多层全连接网络
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...
最新文章
- Meta祭出元宇宙「阿拉丁神灯」!LeCun称世界模型将带来像人一样的AI
- 如何在NLP领域应用卷积神经网络CNN
- vc mysql utf8_C/C++ 连接 MySQL (VC 版)
- Linux网络服务:Samba服务与实现
- hadoop 学习笔记:mapreduce框架详解
- 用ul li实现边框重合并附带鼠标经过效果
- linux常用网络命令
- PyQt5应用与实践
- 封装js千分位加逗号和删除逗号
- 陈一舟:我们花了大力气找合适团队接力人人网 任务完成
- spring 集成mybatis——多数据源切换(附带定时器的配置)
- linux 安装apache
- php发送sql,php学习笔记(二)php与mysql连接与用php发送SQL查询
- SqlServer2008 R2 自动备份和自动清除过期备份
- Pawn Storm网络间谍行动再度现身
- php留言板系统制作,php制作留言板讲解
- LaTex 英文期刊论文模板
- matlab 多项式排序,MATLAB多项式
- 电脑一会,电脑一会黑屏一会正常怎么回事
- OpenLayers3基础教程——OL3 介绍control