记录一下怎样pytorch框架下怎样获得模型的梯度

文章目录

  • 引入所需要的库
  • 一个简单的函数
  • 模型梯度获取
    • 先定义一个model
    • 如下定义两个获取梯度的函数
    • 定义一些过程与调用上述函数的方法
    • 可视化一下梯度的histogram

引入所需要的库

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

一个简单的函数

z=3y2=3(x+2)2z = 3y^2 = 3(x+2)^2z=3y2=3(x+2)2

out=mean(z)\text{out} = \text{mean}(z)out=mean(z)

∂z∂x=6(x+2)2\frac{\partial z}{\partial x} = \frac{6(x+2)}{2}∂x∂z​=26(x+2)​
大家想想,这里的公式为什么要除2?

代码如下:

x = torch.tensor([[1., 2.]], requires_grad=True)
y = x + 2
z = 3 * y.pow(2)
out = z.mean()  # you can try sum() to see what is the result.
out.backward()print(f"x: {x}")
print(f"y->x: {y}")
print(f"z->y->x: {z}")
print(f"out: {out}")
print(f"out->z->y->x: {x.grad}")

输出如下

x: tensor([[1., 2.]], requires_grad=True)
y->x: tensor([[3., 4.]], grad_fn=<AddBackward0>)
z->y->x: tensor([[27., 48.]], grad_fn=<MulBackward0>)
out: 37.5
out->z->y->x: tensor([[ 9., 12.]])

这里解释一下,x 是定义一个tensor,并设置requires_grad=True,这个意思就是x需要计算梯度。其它的注释已经标注挺清楚的了

模型梯度获取

先定义一个model

class ToyModel(nn.Module):def __init__(self, in_channels, out_channels, num_classes=2):super().__init__()# tmp only for testing, not validself.tmp = nn.Conv2d(in_channels, in_channels * 2, (3, 3))self.dim = out_channelsself.conv1 = nn.Conv2d(in_channels=in_channels,out_channels=in_channels * 2,kernel_size=(3, 3),stride=(1, 1),padding=0)self.conv2 = nn.Conv2d(in_channels=in_channels * 2,out_channels=out_channels,kernel_size=(3, 3),stride=(1, 1),padding=0)self.pool = nn.AdaptiveAvgPool2d(output_size=(1))self.fc = nn.Linear(out_channels, num_classes, bias=False)def forward(self, x):x = F.relu(self.conv1(x))x = F.relu(self.conv2(x))x = self.pool(x)x = self.fc(x.view(-1, self.dim))return x

如下定义两个获取梯度的函数

def get_model_histogram(model):"""Description:- get norm gradients from model, and store in a OrderDictArgs:- model: (torch.nn.Module), torch modelReturns:- grads in OrderDict"""grads = OrderedDict()for name, params in model.named_parameters():grad = params.gradif grad is not None:tmp = {}params_np = grad.numpy()histogram, bins = np.histogram(params_np.flatten())tmp['histogram'] = list(histogram)tmp['bins'] = list(bins)grads[name] = tmpreturn grads
def get_model_norm_gradient(model):"""Description:- get norm gradients from model, and store in a OrderDictArgs:- model: (torch.nn.Module), torch modelReturns:- grads in OrderDict"""grads = OrderedDict()for name, params in model.named_parameters():grad = params.gradif grad is not None:grads[name] = grad.norm().item()return grads

定义一些过程与调用上述函数的方法

torch.manual_seed(0)
num_data = 40
toy_model = ToyModel(3, 64, 2)
data = torch.randn(num_data, 3, 224, 224)
label = torch.randint(0, 2, (num_data,))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-3)
toy_model.train()
for i, data in enumerate(data):data = data.unsqueeze(0)out = toy_model(data)target = label[i].unsqueeze(0)loss = criterion(out, target)loss.backward()if (i + 1) % 10 == 0:print('=' * 80)print(get_model_norm_gradient(toy_model))optimizer.step()optimizer.zero_grad()

get model norm gradient的输出如下

================================================================================
OrderedDict([('conv1.weight', 0.1473149210214615), ('conv1.bias', 0.16713829338550568), ('conv2.weight', 0.9203198552131653), ('conv2.bias', 0.5442095994949341), ('fc.weight', 1.7258217334747314)])
================================================================================
OrderedDict([('conv1.weight', 0.0349930003285408), ('conv1.bias', 0.03801438584923744), ('conv2.weight', 0.20729205012321472), ('conv2.bias', 0.12616902589797974), ('fc.weight', 0.39913201332092285)])
================================================================================
OrderedDict([('conv1.weight', 0.07812522351741791), ('conv1.bias', 0.08833323419094086), ('conv2.weight', 0.49012720584869385), ('conv2.bias', 0.2875416576862335), ('fc.weight', 0.9168939590454102)])
================================================================================
OrderedDict([('conv1.weight', 0.14530049264431), ('conv1.bias', 0.16511967778205872), ('conv2.weight', 0.9190732836723328), ('conv2.bias', 0.5398930907249451), ('fc.weight', 1.7265493869781494)])
torch.manual_seed(0)
num_data = 40
toy_model = ToyModel(3, 64, 2)
data = torch.randn(num_data, 3, 224, 224)
label = torch.randint(0, 2, (num_data,))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-3)
toy_model.train()
for i, data in enumerate(data):data = data.unsqueeze(0)out = toy_model(data)target = label[i].unsqueeze(0)loss = criterion(out, target)loss.backward()if (i + 1) % 10 == 0:print('=' * 80)print(str(get_model_histogram(toy_model)))optimizer.step()optimizer.zero_grad()

get model histogram 输出如下,输出太多,只显示最后一条输入了

================================================================================
OrderedDict([('conv1.weight', {'histogram': [4, 2, 13, 27, 76, 22, 11, 5, 1, 1], 'bins': [-0.036256444, -0.028072663, -0.019888882, -0.0117051015, -0.0035213209, 0.0046624597, 0.012846241, 0.021030022, 0.029213801, 0.037397582, 0.045581363]}), ('conv1.bias', {'histogram': [1, 2, 0, 0, 1, 0, 1, 0, 0, 1], 'bins': [-0.028756114, -0.012518765, 0.0037185834, 0.019955931, 0.03619328, 0.05243063, 0.06866798, 0.08490533, 0.101142675, 0.11738002, 0.13361737]}), ('conv2.weight', {'histogram': [15, 10, 35, 245, 1828, 970, 230, 68, 40, 15], 'bins': [-0.07653718, -0.060686104, -0.044835035, -0.028983962, -0.013132891, 0.0027181804, 0.018569252, 0.034420323, 0.050271396, 0.066122465, 0.08197354]}), ('conv2.bias', {'histogram': [1, 0, 1, 8, 12, 28, 5, 6, 0, 3], 'bins': [-0.21087514, -0.16971013, -0.1285451, -0.0873801, -0.04621508, -0.005050063, 0.036114953, 0.07727997, 0.11844498, 0.15961, 0.20077501]}), ('fc.weight', {'histogram': [1, 7, 11, 12, 33, 33, 12, 11, 7, 1], 'bins': [-0.41966814, -0.33573452, -0.2518009, -0.16786726, -0.08393363, 0.0, 0.08393363, 0.16786726, 0.2518009, 0.33573452, 0.41966814]})])

可视化一下梯度的histogram

import matplotlib.pyplot as plt
import matplotlib%matplotlib inline
  • 可视化conv2.weight
data = histo['conv2.weight']
bins = data['bins']
histogram = data['histogram']
max_idx = np.argmax(histogram)
min_idx = np.argmin(histogram)
width = abs(bins[max_idx] - bins[min_idx])plt.figure(figsize=(9, 6))
plt.bar(bins[:-1], histogram, width=width)
plt.show()

  • 可视化conv2.bias
data = histo['conv2.bias']
bins = data['bins']
histogram = data['histogram']
max_idx = np.argmax(histogram)
min_idx = np.argmin(histogram)
width = abs(bins[max_idx] - bins[min_idx])plt.figure(figsize=(9, 6))
plt.bar(bins[:-1], histogram, width=width)
plt.show()

pytorch 正向与反向传播的过程 获取模型的梯度(gradient),并绘制梯度的直方图相关推荐

  1. PyTorch学习笔记(11)——论nn.Conv2d中的反向传播实现过程

    0. 前言 众所周知,反向传播(back propagation)算法 (Rumelhart et al., 1986c),经常简称为backprop,它允许来自代价函数的信息通过网络向后流动,以便计 ...

  2. “反向传播算法”过程及公式推导

    一.定义 首先来一个反向传播算法的定义(转自维基百科):反向传播(英语:Backpropagation,缩写为BP)是"误差反向传播"的简称,是一种与最优化方法(如梯度下降法)结合 ...

  3. 仅使用numpy从头开始实现神经网络,包括反向传播公式推导过程

    仅使用numpy从头开始实现神经网络,包括反向传播公式推导过程: https://www.ctolib.com/yizt-numpy_neural_network.html

  4. pytorch 入门学习反向传播-4

    pytorch 入门学习反向传播 反向传播 import numpy as np import matplotlib.pyplot as plt import torchdef forward(x): ...

  5. 前馈神经网络--前向传播与反向传播计算过程

    目录 2.多层感知机(前馈神经网络) 2.1 定义 2.2 神经元 2.3 激活函数 2.3.1 sigmoid函数 2.3.2 tanh函数 2.3.3 relu函数 2.4 计算 2.4.1 前向 ...

  6. CNN反向传播算法过程

    主模块 规格数据输入(加载,调格式,归一化) 定义网络结构 设置训练参数 调用初始化模块 调用训练模块 调用测试模块 画图 初始化模块 设置初始化参数(输入通道,输入尺寸) 遍历层(计算尺寸,输入输出 ...

  7. 反向传播算法(过程及公式推导)_一文讲透神经网络的反向传播,要点介绍与公式推导...

    神经网络的反向传播是什么 神经网络的反向传播,实际上就是逐层计算梯度下降所需要的$w$向量的"变化量"(代价函数$J(w1,b1,w2,b2,w3,b3...wn,bn)$对于$w ...

  8. 损失函数与优化器理解+【PyTorch】在反向传播前为什么要手动将梯度清零?optimizer.zero_grad()

    目录 回答一: 回答二: 回答三: 传统的训练函数,一个batch是这么训练的: 使用梯度累加是这么写的: 回答一: 一句话,用来更新和计算影响模型训练和模型输出的网络参数,使其逼近或达到最优值,从而 ...

  9. 神经网络正向与反向传播

    一.神经网络的前向传播原理 在全连接神经网络中,每一层的每个神经元都会与前一层的所有神经元或者输入数据相连,例如图中的 f1(e)f _1 ( e )f1​(e)就与x1x_1x1​ 和 x2x_2x ...

最新文章

  1. css小技巧 -- 单标签实现单行文字居中,多行文字居左
  2. 【重磅】谷歌2021博士奖研金完整名单出炉,13个方向共75人获奖
  3. 如何用图表控件实现点击图例图标隐藏图表序列
  4. JSPatch近期新特性解析
  5. aspx mysql类_aspx中的mysql操作类sqldatasource使用示例分享
  6. 【Linux】crontab命令详解
  7. IOS流水布局UICollectionView使用FlowLayout进行自由灵活组合
  8. python logging模块的作用_【python】【logging】python日志模块logging常用功能
  9. 写在校招季,谈谈机器学习岗的Offer选择问题
  10. 区块链技术的发展趋势
  11. php对smarty的使用,[ php ] php smarty使用!
  12. 手册-网站-仙客传奇团队博客
  13. 2021年南宁二中高考成绩查询,2021年广西南宁二中高考物理冲刺试卷(一).docx...
  14. Oracle分区交换
  15. 尔雅 科学通史(吴国盛) 个人笔记及课后习题 2018 第一章 科学通史绪论
  16. 郝斌c语言96-99,《祁连山Photoshop CS3专家讲堂系列教程》[ISO]
  17. 英语思维导图大全 基础语法(二)
  18. mac上的微信小助手WeChatPlugin
  19. 常见apn类型说明及配置
  20. 这可是全网EVE安装最完整,最详细的图解,没有之一【安装图解】

热门文章

  1. 高并发的理解和使用场景-----特意区别和多线程的关系
  2. 系统调用软中断处理程序system_call分析
  3. Uncaught TypeError: Cannot read property 'length' of null错误怎么处理?
  4. etcd raft library设计原理和使用
  5. Java反射机制的使用方法
  6. 第十七章 我国农业科学技术
  7. DataTable的计算功能(转)
  8. 简单使用Git和Github来管理自己的代码和读书笔记
  9. 学习File API用于前端读取文件
  10. javascript --- 瀑布流的实现