4.4 初始化策略

在深度学习中参数的初始化十分重要,良好的初始化能让模型更快收敛,并达到更高水平,而糟糕的初始化则可能使得模型迅速瘫痪。PyTorch中nn.Module的模块参数都采取了较为合理的初始化策略,因此一般不用我们考虑,当然我们也可以用自定义初始化去代替系统的默认初始化。而当我们在使用Parameter时,自定义初始化则尤为重要,因t.Tensor()返回的是内存中的随机数,很可能会有极大值,这在实际训练网络中会造成溢出或者梯度消失。PyTorch中nn.init模块就是专门为初始化而设计,如果某种初始化策略nn.init不提供,用户也可以自己直接初始化。

In [55]:

# 利用nn.init初始化
from torch.nn import init
linear = nn.Linear(3, 4)t.manual_seed(1)
# 等价于 linear.weight.data.normal_(0, std)
init.xavier_normal_(linear.weight)

Out[55]:

Parameter containing:
tensor([[ 0.3535,  0.1427,  0.0330],[ 0.3321, -0.2416, -0.0888],[-0.8140,  0.2040, -0.5493],[-0.3010, -0.4769, -0.0311]])

In [56]:

# 直接初始化
import math
t.manual_seed(1)# xavier初始化的计算公式
std = math.sqrt(2)/math.sqrt(7.)
linear.weight.data.normal_(0,std)

Out[56]:

tensor([[ 0.3535,  0.1427,  0.0330],[ 0.3321, -0.2416, -0.0888],[-0.8140,  0.2040, -0.5493],[-0.3010, -0.4769, -0.0311]])

In [57]:

# 对模型的所有参数进行初始化
for name, params in net.named_parameters():if name.find('linear') != -1:# init linearparams[0] # weightparams[1] # biaselif name.find('conv') != -1:passelif name.find('norm') != -1:pass

初始化的具体例子

# -*- coding: utf-8 -*-
"""
Created on 2019@author: fancp
"""import torch
import torch.nn as nnw = torch.empty(3,5)#1.均匀分布 - u(a,b)
#torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
print(nn.init.uniform_(w))
# =============================================================================
# tensor([[0.9160, 0.1832, 0.5278, 0.5480, 0.6754],
#         [0.9509, 0.8325, 0.9149, 0.8192, 0.9950],
#         [0.4847, 0.4148, 0.8161, 0.0948, 0.3787]])
# =============================================================================#2.正态分布 - N(mean, std)
#torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
print(nn.init.normal_(w))
# =============================================================================
# tensor([[ 0.4388,  0.3083, -0.6803, -1.1476, -0.6084],
#         [ 0.5148, -0.2876, -1.2222,  0.6990, -0.1595],
#         [-2.0834, -1.6288,  0.5057, -0.5754,  0.3052]])
# =============================================================================#3.常数 - 固定值 val
#torch.nn.init.constant_(tensor, val)
print(nn.init.constant_(w, 0.3))
# =============================================================================
# tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#         [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#         [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])
# =============================================================================#4.全1分布
#torch.nn.init.ones_(tensor)
print(nn.init.ones_(w))
# =============================================================================
# tensor([[1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1.]])
# =============================================================================#5.全0分布
#torch.nn.init.zeros_(tensor)
print(nn.init.zeros_(w))
# =============================================================================
# tensor([[0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])
# =============================================================================#6.对角线为 1,其它为 0
#torch.nn.init.eye_(tensor)
print(nn.init.eye_(w))
# =============================================================================
# tensor([[1., 0., 0., 0., 0.],
#         [0., 1., 0., 0., 0.],
#         [0., 0., 1., 0., 0.]])
# =============================================================================#7.xavier_uniform 初始化
#torch.nn.init.xavier_uniform_(tensor, gain=1.0)
#From - Understanding the difficulty of training deep feedforward neural networks - Bengio 2010
print(nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')))
# =============================================================================
# tensor([[-0.1270,  0.3963,  0.9531, -0.2949,  0.8294],
#         [-0.9759, -0.6335,  0.9299, -1.0988, -0.1496],
#         [-0.7224,  0.2181, -1.1219,  0.8629, -0.8825]])
# =============================================================================#8.xavier_normal 初始化
#torch.nn.init.xavier_normal_(tensor, gain=1.0)
print(nn.init.xavier_normal_(w))
# =============================================================================
# tensor([[ 1.0463,  0.1275, -0.3752,  0.1858,  1.1008],
#         [-0.5560,  0.2837,  0.1000, -0.5835,  0.7886],
#         [-0.2417,  0.1763, -0.7495,  0.4677, -0.1185]])
# =============================================================================#9.kaiming_uniform 初始化
#torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
#From - Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - HeKaiming 2015
print(nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.7712,  0.9344,  0.8304,  0.2367,  0.0478],
#         [-0.6139, -0.3916, -0.0835,  0.5975,  0.1717],
#         [ 0.3197, -0.9825, -0.5380, -1.0033, -0.3701]])
# =============================================================================#10.kaiming_normal 初始化
#torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
print(nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.0210,  0.5532, -0.8647,  0.9813,  0.0466],
#         [ 0.7713, -1.0418,  0.7264,  0.5547,  0.7403],
#         [-0.8471, -1.7371,  1.3333,  0.0395,  1.0787]])
# =============================================================================#11.正交矩阵 - (semi)orthogonal matrix
#torch.nn.init.orthogonal_(tensor, gain=1)
#From - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe 2013
print(nn.init.orthogonal_(w))
# =============================================================================
# tensor([[-0.0346, -0.7607, -0.0428,  0.4771,  0.4366],
#         [-0.0412, -0.0836,  0.9847,  0.0703, -0.1293],
#         [-0.6639,  0.4551,  0.0731,  0.1674,  0.5646]])
# =============================================================================#12.稀疏矩阵 - sparse matrix
#torch.nn.init.sparse_(tensor, sparsity, std=0.01)
#From - Deep learning via Hessian-free optimization - Martens 2010
print(nn.init.sparse_(w, sparsity=0.1))
# =============================================================================
# tensor([[ 0.0000,  0.0000, -0.0077,  0.0000, -0.0046],
#         [ 0.0152,  0.0030,  0.0000, -0.0029,  0.0005],
#         [ 0.0199,  0.0132, -0.0088,  0.0060,  0.0000]])
# =============================================================================

Pytorch:初始化相关推荐

  1. pytorch初始化

    转载自 https://blog.csdn.net/dss_dssssd/article/details/83959474 本文内容: 1. Xavier 初始化 2. nn.init 中各种初始化函 ...

  2. pytorch 初始化权重

    一般的网络初始化方法: from torch.nn import functional as F, initdef init_params(self):for m in self.modules(): ...

  3. pytorch默认初始化_PyTorch的初始化

    背景 在使用PyTorch深度学习框架的时候,不管是训练还是测试,代码中引入PyTorch的第一句总是: import torch 在Gemfield前述专栏文章里,我们已经得知,torch/csrc ...

  4. 树莓派4B (aarch64) 安装PyTorch 1.8 的可行方案

    树莓派4B (aarch64) 安装PyTorch 1.8 的可行方案 最终可行方案 试了一堆方案(源码编译.Fast.ai的安装文件等)之后,终于找到一个可行的方案.是在 PyTorch 官方讨论社 ...

  5. pytorch中的激励函数(详细版)

    初学神经网络和pytorch,这里参考大佬资料来总结一下有哪些激活函数和损失函数(pytorch表示) 首先pytorch初始化:   import torch import torch.nn.fun ...

  6. PyTorch中文文档阅读笔记-day1

    写在开头(重复的) 1.课程来源:torch中文教程1.7版. torch中文文档. 2.笔记目的:个人学习+增强记忆+方便回顾 3.时间:2021年4月29日 4.仅作为个人笔记,如有需要请务必按照 ...

  7. (五)使用生成对抗网络 (GAN)生成新的时装设计

    目录 介绍 预测新时尚形象的力量 构建GAN 初始化GAN参数和加载数据 从头开始构建生成器 从头开始构建鉴别器 初始化GAN的损失和优化器 下一步 下载源 - 120.7 MB 介绍 DeepFas ...

  8. 【KG】TransE 及其实现

    原文:https://yubincloud.github.io/notebook/pages/paper/kg/TransE/ TransE 及其实现 1. What is TransE? Trans ...

  9. 中科大+快手出品 CIRS: Bursting Filter Bubbles by Counterfactual Interactive Recommender System 代码解析

    文章目录 前言 论文介绍: 代码介绍: 代码: 一. CIRS-UserModel-kuaishou.py 0. get_args() 解析参数 1. create_dir() 2. Prepare ...

最新文章

  1. “计算机艺术之父”、现代计算机技术先驱查理斯·苏黎去世,享年99岁
  2. Flutter入门:application、module、package、plugin
  3. selenium===使用docker搭建selenium分布式测试环境
  4. 手机技巧:手机关掉这个开关,一下能省2G内存,再也不怕卡顿死机
  5. C++生成简单WAV文件(一)
  6. python真是最烂的语言_在大型项目上,Python 是个烂语言吗?
  7. Nginx设置日志打印post请求参数
  8. selenium_java
  9. Ubuntu 18.04安装: failed to load ldlinux.c32
  10. PostgreSQL常用的客户端工具
  11. html5网页制作代码 大学生网页制作作业代码 (旅游网站官网滚动模板)
  12. 计算机视觉大型攻略 —— 立体视觉(4)立体匹配算法简介与SGM
  13. 记录MySQL中JSON_EXTRACT JSON_UNQUOTE函数的使用方式
  14. mysql 表的详细_MySQL的库表详细操作
  15. html5代码验证电话号码,这个我觉得挺重要的!
  16. CR-Fill: Generative Image Inpainting with Auxiliary Contexutal Reconstruction
  17. OGG REPA进程 Error ORA-01031报错处理
  18. 本地项目与Git项目关联
  19. FFplay退出分析
  20. biopython:2:序列组成

热门文章

  1. 更新字典 (Updating a Dictionary,UVa12504)
  2. 笨方法学python - 04
  3. windows下手动安装composer并配置环境变量
  4. div的水平居中和垂直居中
  5. 软件测试--接口测试入门
  6. linux运维高频命令汇总
  7. mysql 21天_把整个Mysql拆分成21天,轻松掌握,搞定(中)
  8. 我是学渣,但是我零基础自学web前端成功了
  9. 云计算开源软件有哪些?
  10. 零基础学python看什么书好?