PyTorch: Initialization
4.4 Initialization Strategies
Parameter initialization matters a great deal in deep learning: a good initialization lets a model converge faster and reach a better optimum, while a bad one can break training outright. The modules in PyTorch's nn.Module already use reasonable default initialization strategies, so usually we do not need to worry about it, though we can replace the defaults with a custom initialization. Custom initialization is especially important when using Parameter directly, because t.Tensor() returns whatever happens to be in memory, possibly extreme values that cause overflow or vanishing gradients during training. PyTorch's nn.init module is designed specifically for initialization; if nn.init does not provide a particular strategy, we can also initialize the tensor directly ourselves.
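As a minimal sketch of the point above (the shapes here are arbitrary), a hand-made nn.Parameter starts from uninitialized memory and should be explicitly initialized with nn.init before use:

```python
import torch
import torch.nn as nn
from torch.nn import init

# torch.empty returns uninitialized memory, so the values are arbitrary
# and must be overwritten before the parameter is used in training.
w = nn.Parameter(torch.empty(4, 3))

# nn.init functions work in place (note the trailing underscore) and run
# under torch.no_grad(), so gradient tracking on the Parameter is preserved.
init.normal_(w, mean=0.0, std=0.01)

print(w.shape)          # torch.Size([4, 3])
print(w.requires_grad)  # True
```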
In [55]:
# Initialize with nn.init
import torch as t
import torch.nn as nn
from torch.nn import init

linear = nn.Linear(3, 4)
t.manual_seed(1)
# equivalent to linear.weight.data.normal_(0, std)
init.xavier_normal_(linear.weight)
Out[55]:
Parameter containing:
tensor([[ 0.3535,  0.1427,  0.0330],
        [ 0.3321, -0.2416, -0.0888],
        [-0.8140,  0.2040, -0.5493],
        [-0.3010, -0.4769, -0.0311]])
In [56]:
# Initialize directly
import math
t.manual_seed(1)
# the formula behind xavier initialization
std = math.sqrt(2) / math.sqrt(7.)
linear.weight.data.normal_(0, std)
Out[56]:
tensor([[ 0.3535,  0.1427,  0.0330],
        [ 0.3321, -0.2416, -0.0888],
        [-0.8140,  0.2040, -0.5493],
        [-0.3010, -0.4769, -0.0311]])
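The std used above comes from Glorot's formula std = gain * sqrt(2 / (fan_in + fan_out)); for this nn.Linear(3, 4) layer, fan_in = 3 and fan_out = 4, which is exactly where sqrt(2)/sqrt(7) comes from. A quick sanity check:

```python
import math

fan_in, fan_out = 3, 4  # shapes of nn.Linear(3, 4)
gain = 1.0              # default gain of xavier_normal_
std = gain * math.sqrt(2.0 / (fan_in + fan_out))
print(std)  # ≈ 0.5345, i.e. sqrt(2)/sqrt(7)
```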
In [57]:
# Initialize all parameters of a model
for name, params in net.named_parameters():
    if name.find('linear') != -1:
        # init linear
        params[0]  # weight
        params[1]  # bias
    elif name.find('conv') != -1:
        pass
    elif name.find('norm') != -1:
        pass
A concrete initialization example
# -*- coding: utf-8 -*-
"""
Created on 2019
@author: fancp
"""
import torch
import torch.nn as nn

w = torch.empty(3, 5)

# 1. Uniform distribution - U(a, b)
#torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
print(nn.init.uniform_(w))
# =============================================================================
# tensor([[0.9160, 0.1832, 0.5278, 0.5480, 0.6754],
# [0.9509, 0.8325, 0.9149, 0.8192, 0.9950],
# [0.4847, 0.4148, 0.8161, 0.0948, 0.3787]])
# =============================================================================

# 2. Normal distribution - N(mean, std)
#torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
print(nn.init.normal_(w))
# =============================================================================
# tensor([[ 0.4388, 0.3083, -0.6803, -1.1476, -0.6084],
# [ 0.5148, -0.2876, -1.2222, 0.6990, -0.1595],
# [-2.0834, -1.6288, 0.5057, -0.5754, 0.3052]])
# =============================================================================

# 3. Constant - fixed value val
#torch.nn.init.constant_(tensor, val)
print(nn.init.constant_(w, 0.3))
# =============================================================================
# tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
# [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
# [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])
# =============================================================================

# 4. All ones
#torch.nn.init.ones_(tensor)
print(nn.init.ones_(w))
# =============================================================================
# tensor([[1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1.]])
# =============================================================================

# 5. All zeros
#torch.nn.init.zeros_(tensor)
print(nn.init.zeros_(w))
# =============================================================================
# tensor([[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]])
# =============================================================================

# 6. Ones on the diagonal, zeros elsewhere
#torch.nn.init.eye_(tensor)
print(nn.init.eye_(w))
# =============================================================================
# tensor([[1., 0., 0., 0., 0.],
# [0., 1., 0., 0., 0.],
# [0., 0., 1., 0., 0.]])
# =============================================================================

# 7. xavier_uniform initialization
#torch.nn.init.xavier_uniform_(tensor, gain=1.0)
# From - Understanding the difficulty of training deep feedforward neural networks - Bengio 2010
print(nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')))
# =============================================================================
# tensor([[-0.1270, 0.3963, 0.9531, -0.2949, 0.8294],
# [-0.9759, -0.6335, 0.9299, -1.0988, -0.1496],
# [-0.7224, 0.2181, -1.1219, 0.8629, -0.8825]])
# =============================================================================

# 8. xavier_normal initialization
#torch.nn.init.xavier_normal_(tensor, gain=1.0)
print(nn.init.xavier_normal_(w))
# =============================================================================
# tensor([[ 1.0463, 0.1275, -0.3752, 0.1858, 1.1008],
# [-0.5560, 0.2837, 0.1000, -0.5835, 0.7886],
# [-0.2417, 0.1763, -0.7495, 0.4677, -0.1185]])
# =============================================================================

# 9. kaiming_uniform initialization
#torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
# From - Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - Kaiming He 2015
print(nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.7712, 0.9344, 0.8304, 0.2367, 0.0478],
# [-0.6139, -0.3916, -0.0835, 0.5975, 0.1717],
# [ 0.3197, -0.9825, -0.5380, -1.0033, -0.3701]])
# =============================================================================

# 10. kaiming_normal initialization
#torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
print(nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.0210, 0.5532, -0.8647, 0.9813, 0.0466],
# [ 0.7713, -1.0418, 0.7264, 0.5547, 0.7403],
# [-0.8471, -1.7371, 1.3333, 0.0395, 1.0787]])
# =============================================================================

# 11. (Semi-)orthogonal matrix
#torch.nn.init.orthogonal_(tensor, gain=1)
# From - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe 2013
print(nn.init.orthogonal_(w))
# =============================================================================
# tensor([[-0.0346, -0.7607, -0.0428, 0.4771, 0.4366],
# [-0.0412, -0.0836, 0.9847, 0.0703, -0.1293],
# [-0.6639, 0.4551, 0.0731, 0.1674, 0.5646]])
# =============================================================================

# 12. Sparse matrix
#torch.nn.init.sparse_(tensor, sparsity, std=0.01)
# From - Deep learning via Hessian-free optimization - Martens 2010
print(nn.init.sparse_(w, sparsity=0.1))
# =============================================================================
# tensor([[ 0.0000, 0.0000, -0.0077, 0.0000, -0.0046],
# [ 0.0152, 0.0030, 0.0000, -0.0029, 0.0005],
# [ 0.0199, 0.0132, -0.0088, 0.0060, 0.0000]])
# =============================================================================
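A quick way to see what orthogonal_ actually produces: for a 3x5 tensor the rows come out orthonormal, so w @ w.T is (numerically) the 3x3 identity. A small check along those lines:

```python
import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.orthogonal_(w)

# Rows are orthonormal: w @ w.T should be close to the identity matrix.
eye = torch.eye(3)
print(torch.allclose(w @ w.t(), eye, atol=1e-5))  # True
```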