复现TSM,碰到以下语句,之前在学习中是浅尝辄止,现在为了更好学透深度学习,遂决定直接搞懂这个代码

if self.new_fc is None:normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)

这里主要有normal_方法和constant_方法,下面逐一解释

提取网络层的信息

这里先解释一下提取网络层的信息

每个层的键的名称是由该层的类型决定的,和该层在网络中的位置无关。所以,每个层的键应该是:

  • nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1): 0

  • nn.ReLU(inplace=True): 1

  • nn.MaxPool2d(kernel_size=2, stride=2): 2

  • nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1): 3

  • nn.ReLU(inplace=True): 4

  • nn.MaxPool2d(kernel_size=2, stride=2): 5

  • nn.AdaptiveAvgPool2d((1, 1)): 6

  • nn.Flatten(): 7

  • nn.Linear(128, 64): 8

  • nn.ReLU(inplace=True): 9

  • nn.Linear(64, 10): 10

请注意,这里的键只是一个数字标识符,用于在模型中唯一地标识每个层的权重和偏差。这些键的名称没有具体含义,只要它们在整个模型中是唯一的,就可以任意选择键的值。

例如:

import torch.nn as nnbase_model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), #0nn.ReLU(inplace=True),#1nn.MaxPool2d(kernel_size=2, stride=2),#2nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),#3nn.ReLU(inplace=True),#4nn.MaxPool2d(kernel_size=2, stride=2),#5nn.AdaptiveAvgPool2d((1, 1)),#6nn.Flatten(),#7nn.Linear(128, 64),#8nn.ReLU(inplace=True),#9nn.Linear(64, 10)#10
)last_layer = getattr(base_model, '8')
print(last_layer)

结果

Linear(in_features=128, out_features=64, bias=True)

normal_

具体来说,该行代码中的 normal_() 方法会对 self.base_model 模型中的 self.base_model.last_layer_name 层的权重进行高斯分布初始化,其中第一个参数是权重张量,第二个参数 0 表示均值为 0,第三个参数 std 表示标准差。这里使用下划线后缀的 normal_ 方法表示在原地修改权重张量,而不是返回一个新的张量。

例如,假设我们有一个名为 base_model 的模型,其中包含一个名为 fc 的全连接层,我们可以使用以下代码对该层的权重进行高斯分布初始化:

import torch.nn as nnbase_model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(128, 64),nn.ReLU(inplace=True),nn.Linear(64, 10)
)last_layer_name = getattr(base_model, '10')  # 最后一层
# last_layer = getattr(base_model, 'fc')
print(last_layer_name)
std = 0.01
# nn.init.normal_(getattr(base_model, last_layer_name).weight, 0, std)
print(nn.init.normal_(getattr(base_model, '10').weight, 0, std))  # 对最后一层进行初始化

结果:

Linear(in_features=64, out_features=10, bias=True)
Parameter containing:
tensor([[-1.3272e-02, -2.2359e-03,  1.0343e-02, -1.0792e-02, -1.2242e-02,-1.7024e-02, -5.8491e-03,  2.3467e-03,  4.8827e-03, -8.6968e-03,1.2297e-03, -1.2535e-03,  1.2796e-02, -2.5463e-03, -3.9475e-03,-7.0661e-03, -1.1183e-03, -2.8131e-03,  1.6252e-02, -1.2049e-03,8.7546e-03, -6.3224e-03, -4.5075e-03, -3.3601e-03, -5.9193e-03,5.8264e-03,  4.2756e-03,  6.2804e-03,  1.7132e-02,  9.2935e-03,-1.7752e-02, -1.3216e-03, -3.6371e-04,  5.7609e-04, -6.0026e-03,-6.1346e-03, -2.7561e-03, -1.4461e-02, -1.1557e-02,  3.0463e-03,-1.0108e-02, -4.7012e-03,  9.2200e-04, -2.0413e-04, -6.1705e-03,-3.6149e-03, -1.0387e-02, -2.1668e-03,  2.4830e-02,  1.4840e-02,6.0030e-03,  4.5739e-03, -8.8294e-03, -9.5166e-05, -1.7639e-03,2.0190e-02, -3.2003e-03, -1.2733e-02, -7.3871e-04, -1.3978e-03,-3.9033e-03,  6.4459e-03, -1.1231e-02,  8.8730e-03],[-2.3404e-02,  5.0958e-03, -2.8636e-03,  7.4566e-03, -1.3217e-03,9.6338e-03,  4.0845e-03, -9.4511e-03, -1.5848e-02,  9.6535e-03,-1.9223e-02,  5.6821e-03, -9.5851e-03,  6.4113e-03,  1.2916e-02,4.4283e-03, -1.4849e-02,  1.7074e-04, -6.9544e-03, -2.0033e-02,1.1659e-03,  8.3090e-03,  1.0376e-02, -1.3552e-03,  5.1678e-04,-1.1280e-02, -1.5385e-03,  2.3204e-03,  2.1148e-02, -8.2046e-03,-3.1950e-03, -2.6627e-03, -8.0619e-03,  6.8565e-03, -1.5831e-03,-5.3802e-03,  4.9502e-03, -1.0993e-03,  1.5999e-03,  1.6042e-02,1.7201e-02,  1.0180e-02,  1.6592e-03, -5.1286e-03, -9.9063e-03,-1.0357e-02, -1.0337e-02, -6.0520e-03,  2.3609e-02, -5.8884e-03,9.9698e-03,  1.3698e-02, -1.2772e-02,  8.2779e-03, -7.6448e-03,1.1789e-02,  1.1248e-02, -7.1244e-03,  4.6928e-03, -1.1212e-03,2.8188e-03, -1.2279e-02,  6.7612e-03, -3.8642e-03],[ 3.1708e-03,  4.8198e-03, -1.3153e-03,  5.7256e-03,  2.9513e-03,6.0120e-03, -1.2716e-02, -2.9583e-02, -1.3539e-02,  9.6046e-03,-2.3146e-03, -5.9442e-03, -3.5330e-03,  6.3374e-03, -2.2096e-03,-3.5567e-03, -1.0496e-02, -9.1474e-03,  1.6573e-02, -2.7625e-03,7.2689e-03,  5.5843e-03,  1.9446e-03,  3.9445e-03, -1.0196e-02,6.3983e-03, -1.1957e-03,  1.9038e-03, -3.2439e-03, -9.9891e-03,1.7751e-03, -1.2842e-02, -1.0921e-02,  7.6490e-03, -6.7258e-03,2.7367e-03, -5.8537e-03, -4.4515e-03, -1.8622e-03,  1.8290e-03,-2.1976e-02, -1.0761e-02,  8.5432e-03, -7.4048e-04,  9.6255e-03,4.9710e-03,  2.6487e-03,  1.1278e-02, -2.2165e-02, -5.3400e-03,-1.2628e-02, -9.0693e-03,  1.1717e-04, -3.9173e-03,  7.0556e-03,-3.2840e-03, -2.0703e-02,  1.2574e-02, -1.9498e-03,  6.0815e-03,2.0596e-03, -3.7182e-03, -1.2596e-02, -1.2627e-02],[-9.8718e-03,  5.4179e-03,  1.6379e-03,  5.9276e-03,  3.1523e-04,-8.7425e-03, -7.6249e-03, -4.3939e-03, -6.2318e-03,  1.4632e-02,2.2710e-03,  9.9909e-03, -1.1965e-02,  1.3041e-02, -9.0492e-03,-1.1099e-03,  1.7039e-03,  1.1821e-02,  3.0194e-03, -6.0026e-03,1.6889e-02, -5.0959e-03, -1.9247e-03,  1.8685e-03, -1.7513e-02,1.6138e-02, -7.7140e-03,  1.0976e-02,  1.6791e-02, -1.3144e-02,1.0865e-03,  7.4303e-03, -3.3156e-03,  2.6016e-03, -2.5576e-03,6.6795e-03,  3.3842e-03,  8.1809e-03, -1.4860e-02, -9.7519e-03,-1.7827e-03,  3.6443e-03, -1.1853e-02, -1.8691e-02,  1.5285e-02,-1.9333e-02,  1.2217e-02, -1.9054e-02,  9.4615e-03, -2.2510e-03,4.9319e-03,  3.0562e-03,  1.0821e-02,  1.0460e-02,  2.2750e-03,1.5753e-02,  3.2901e-03, -2.3684e-02, -2.2839e-03, -2.0044e-03,1.7136e-02,  1.1856e-02, -6.5554e-04, -5.5353e-03],[-1.7280e-02, -6.1747e-04, -2.2409e-02, -4.7504e-03,  5.8930e-05,2.0783e-02, -2.3622e-03, -5.2687e-03, -5.8447e-03, -1.0608e-02,-1.0498e-02, -9.5139e-03,  1.3671e-02,  5.2071e-03,  6.2913e-03,1.5897e-02,  1.9130e-03, -1.7698e-02, -1.1381e-02, -2.4493e-03,-9.0661e-03,  3.0104e-04, -3.1901e-03, -1.4501e-02,  6.6054e-04,5.9062e-03,  9.6727e-03, -1.7085e-02, -7.1479e-03,  1.1499e-02,2.7148e-02, -1.6385e-02, -2.7822e-03,  1.5409e-02, -1.6170e-03,2.0706e-03, -1.0137e-03, -6.9971e-03,  4.5867e-03, -1.7533e-03,-1.5169e-02,  7.1676e-03,  7.3481e-04,  3.9260e-03,  2.2786e-02,-4.4119e-03, -5.0950e-04,  1.8821e-02, -9.7902e-03, -1.6165e-02,1.0704e-02,  1.1139e-03,  1.0848e-02,  8.7063e-03, -1.5427e-02,6.0734e-03, -6.5893e-03, -6.0677e-03,  3.6704e-03, -6.9275e-03,-6.3169e-03,  5.0168e-03,  5.0394e-03,  9.5354e-03],[-2.5179e-04,  3.3235e-03, -2.6350e-03, -2.8805e-03,  1.7708e-02,7.3398e-03,  1.8750e-03, -1.1013e-02, -1.4283e-03, -2.4949e-03,1.8975e-02, -6.0564e-03, -9.4183e-03,  7.4389e-03,  1.9309e-03,-7.5402e-03,  1.4321e-02, -1.4126e-02,  8.1422e-03,  5.2280e-03,5.0304e-03, -5.1896e-03,  5.6178e-03,  2.7014e-02,  8.8836e-04,-7.5104e-03, -2.3235e-03,  1.8136e-04, -2.9145e-02,  1.4800e-02,-1.2273e-04,  6.1381e-03, -1.5862e-02, -6.1995e-04, -2.7211e-03,-1.9053e-02,  5.0899e-03, -5.4222e-03,  1.0337e-02,  7.6167e-03,-3.2999e-03,  1.0699e-03,  1.3340e-02, -3.8160e-03,  3.0833e-03,-5.2089e-03,  9.0612e-04,  1.4491e-02, -4.5410e-03,  5.4649e-03,-3.3898e-03,  2.8065e-03,  1.5491e-02,  7.5988e-03,  1.5773e-02,1.5484e-02,  5.4282e-03,  2.5454e-03, -2.1613e-02, -1.5429e-02,-1.2897e-02, -5.3088e-03,  1.1335e-02, -9.0223e-03],[ 1.0117e-02, -3.3639e-03, -5.0150e-03,  3.2073e-03,  1.0271e-02,-6.6959e-03, -2.1131e-03, -1.2018e-02,  4.2930e-03, -1.6980e-04,-7.7216e-03,  5.6076e-03,  1.4555e-02,  1.7849e-02,  8.0165e-03,4.1849e-03,  5.5320e-03,  1.5881e-02, -6.8613e-03,  1.2461e-03,-9.2352e-03,  1.1187e-02, -7.9894e-03, -1.6583e-02,  1.4254e-02,7.2171e-04,  9.4763e-03, -8.1024e-03, -1.9460e-02,  7.6837e-03,2.7487e-04,  1.1689e-02,  4.6567e-03, -1.1756e-02,  2.4855e-03,-4.1040e-03, -9.1597e-03,  1.5789e-02, -1.8456e-03,  2.1223e-02,9.2912e-03, -1.5335e-02, -4.1271e-03, -6.4253e-04,  9.4843e-03,-3.3001e-03, -4.1901e-03, -1.0254e-02,  8.9056e-03,  9.9998e-04,-1.8859e-03,  3.7849e-03, -2.8724e-03,  7.0505e-04, -7.0398e-03,-2.5400e-03, -5.2459e-03,  7.9450e-03,  2.2319e-02, -1.0812e-02,1.5204e-02, -4.8161e-03, -7.8047e-04, -8.8051e-03],[-5.7175e-03,  2.0120e-02, -1.1453e-02, -1.2576e-02,  2.1838e-02,-1.5232e-02,  3.1967e-03,  1.2660e-02, -2.0898e-02,  6.2341e-03,4.6589e-04, -1.0194e-02,  7.9064e-03,  8.0972e-03,  5.9819e-03,1.2644e-02,  1.6601e-02, -9.0595e-04, -2.5263e-02, -7.0945e-03,-9.3729e-03,  9.0826e-03, -5.7136e-03,  4.9057e-03,  1.1597e-02,1.1955e-02, -7.8716e-04,  9.9030e-04,  5.8346e-03, -3.2973e-04,-7.7189e-03, -3.7571e-03, -1.3204e-02, -1.0440e-02,  2.9435e-03,-2.8907e-04,  1.2057e-03, -1.0351e-03, -5.7206e-03,  1.6962e-03,-3.2682e-03,  1.6592e-03, -8.5040e-03,  2.2232e-02, -5.1094e-04,1.2425e-02, -3.0755e-03,  1.1618e-02,  3.6595e-03, -1.5270e-02,4.9968e-03,  4.4446e-03,  1.1779e-02,  1.5565e-02, -1.2305e-02,2.8609e-03, -3.9866e-03,  1.3608e-02, -2.0619e-02, -4.7859e-03,-1.6829e-02, -7.4733e-03, -1.9138e-02,  1.7258e-03],[-1.7574e-02,  1.2523e-02,  2.0962e-03,  4.2426e-03,  8.0573e-03,2.2357e-02,  1.2657e-02,  3.5991e-03,  4.3030e-03, -7.8645e-03,-2.4566e-03,  5.2175e-03, -8.6353e-03, -8.3184e-03, -1.1575e-02,4.3127e-03, -5.1229e-03,  5.4804e-03, -8.9780e-03,  1.5546e-04,-1.4716e-02, -2.0695e-02, -7.6624e-03, -1.1113e-02, -5.7346e-03,6.1758e-03, -5.6148e-03,  2.0378e-05,  1.0982e-02,  8.3183e-04,-9.1416e-04, -3.9552e-03, -9.1955e-03,  1.1313e-03, -2.9609e-04,-1.1788e-02,  3.7255e-04,  1.0457e-02, -1.8796e-02,  2.8319e-03,-5.6307e-03, -6.3487e-04, -1.1184e-02, -1.4268e-02,  1.1114e-02,-1.2992e-02, -1.8135e-03, -1.0604e-02,  9.8879e-03, -2.3624e-03,1.3414e-02,  8.9875e-03,  2.2747e-02,  3.8558e-03, -3.3536e-03,9.5849e-03, -2.1084e-03,  6.9714e-04, -6.0838e-03, -8.5648e-03,-1.2717e-02, -5.1056e-04,  8.7422e-03, -4.1311e-03],[ 1.5267e-02, -1.0984e-02, -2.0117e-02,  1.1511e-02,  7.1848e-03,-4.4533e-03,  4.0069e-03,  2.2058e-02,  2.6312e-02, -1.1855e-02,-5.7762e-03, -2.0696e-03,  1.5378e-02,  1.8970e-02, -2.1807e-02,9.5076e-03,  5.6942e-03,  9.4878e-03,  6.0089e-05, -2.0919e-02,-2.0222e-03,  2.0285e-02,  2.7224e-03, -1.0630e-02,  5.5314e-03,-5.6111e-03,  1.1118e-02,  3.5967e-03,  8.3760e-03, -3.1836e-03,-9.6490e-03,  1.3429e-02,  4.2757e-03,  1.0438e-02, -2.9009e-03,2.3239e-03,  5.5037e-03,  1.1511e-02,  1.2017e-02,  1.3676e-03,1.5089e-02, -1.3869e-02, -8.1693e-04,  1.3668e-03,  4.1932e-03,-2.5586e-04,  9.7049e-05, -1.4709e-03,  8.5214e-03,  1.2624e-02,1.6878e-03, -2.0740e-03,  8.4441e-03, -2.2894e-02,  1.7953e-03,1.9498e-02,  2.4912e-02, -2.5735e-03, -5.9107e-03, -1.7209e-03,1.8648e-02,  5.6585e-03, -1.8484e-03, -9.2792e-03]],requires_grad=True)

在上述代码中,我们使用 nn.init.normal_() 方法对 base_model 模型中的 最后一层也就是nn.Linear(64, 10) 层的权重进行高斯分布初始化,其中第一个参数 getattr(base_model, last_layer_name).weight 表示获取 base_model 模型中名为 last_layer_name 的层的权重张量;第二个参数 0 表示均值为 0;第三个参数 std 表示标准差为 0.01。最终得到的结果是对 fc 层的权重进行了高斯分布初始化。

constant_

这行代码使用 PyTorch 提供的 constant_() 方法将指定层的偏置设置为常量值。

具体来说,该行代码中的 constant_() 方法会将 self.base_model 模型中的 self.base_model.last_layer_name 层的偏置设置为常量值 0,其中第一个参数是偏置张量,第二个参数 0 表示常量值。这里使用下划线后缀的 constant_ 方法表示在原地修改偏置张量,而不是返回一个新的张量。

例如,假设我们有一个名为 base_model 的模型,其中包含一个名为 fc 的全连接层,我们可以使用以下代码将该层的偏置设置为常量值 0:

import torch.nn as nnbase_model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(),nn.Linear(128, 64),nn.ReLU(inplace=True),nn.Linear(64, 10)
)last_layer_name = getattr(base_model, '10')  # 最后一层
print(last_layer_name)# nn.init.constant_(getattr(base_model, '10').bias, 0)
print(nn.init.constant_(getattr(base_model, '10').bias, 0))

结果

Linear(in_features=64, out_features=10, bias=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

在上述代码中,我们使用 nn.init.constant_() 方法将 base_model 模型中的 nn.Linear(64, 10) 层的偏置设置为常量值 0,其中第一个参数 getattr(base_model, last_layer_name).bias 表示获取 base_model 模型最后一层的偏置张量;第二个参数 0 表示常量值。最终得到的结果是将 最后一层的偏置设置为常量值 0。

TSM-normal_方法相关推荐

  1. 联想X86服务器重启管理控制器(XClarity Controller)或TSM的方法

    当设备运行较长时间时,服务器的管理控制器(或称服务处理器,Service Processor)可能由于内存或空间等问题响应缓慢,如果机器上运行ESXi,有可能会在Vcenter报出部件的"s ...

  2. temporal shift module(TSM)

    [官方]Paddle2.1实现视频理解经典模型 - TSM - 飞桨AI Studio本项目将带大家深入理解视频理解领域经典模型TSM.从模型理论讲解入手,深入到代码实践.实践部分基于TSM模型在UC ...

  3. ​MMIT冠军方案 | 用于行为识别的时间交错网络,商汤公开视频理解代码库

    作者 | 商汤 出品 | AI科技大本营(ID:rgznai100) 本文主要介绍三个部分: 一个高效的SOTA视频特征提取网络TIN,发表于AAAI2020 ICCV19 MMIT多标签视频理解竞赛 ...

  4. Lesson 13.4 Dead ReLU Problem与学习率优化

    Lesson 13.4 Dead ReLU Problem与学习率优化   和Sigmoid.tanh激活函数不同,ReLU激活函数的叠加并不会出现梯度消失或者梯度爆炸,但ReLU激活函数中使得部分数 ...

  5. AAAI 2020 时间交错网络 | ICCV19多标签视频理解冠军方案

    本文主要介绍三个部分: 一个高效的 SOTA 视频特征提取网络 TIN,发表于 AAAI 2020 ICCV19 MMIT 多标签视频理解竞赛冠军方案,基于 TIN 和 SlowFast 一个基于 P ...

  6. 【Pytorch神经网络实战案例】14 构建条件变分自编码神经网络模型生成可控Fashon-MNST模拟数据

    1 条件变分自编码神经网络生成模拟数据案例说明 在实际应用中,条件变分自编码神经网络的应用会更为广泛一些,因为它使得模型输出的模拟数据可控,即可以指定模型输出鞋子或者上衣. 1.1 案例描述 在变分自 ...

  7. 【Pytorch神经网络实战案例】13 构建变分自编码神经网络模型生成Fashon-MNST模拟数据

    1 变分自编码神经网络生成模拟数据案例说明 变分自编码里面真正的公式只有一个KL散度. 1.1 变分自编码神经网络模型介绍 主要由以下三个部分构成: 1.1.1 编码器 由两层全连接神经网络组成,第一 ...

  8. 【汇总】行为识别、时序行为检测、弱监督行为检测、时空行为定位论文代码(持续更新!!!)

    视频行为识别与轻量化网络的前沿论文.代码等 https://zhuanlan.zhihu.com/c_1207774575393865728 CVPR 2020 行为识别/视频理解论文汇总 https ...

  9. 神经网络-常见函数、定义

    梯度下降: 通过不断沿着反梯度方向更新参数求解 小批量随机梯度下降是深度学习默认的求解方法 两个重要的超参数 批量大小和学习率 improt random: random 包主要用来生成随机数 X = ...

最新文章

  1. 让你彻底明白什么叫游戏引擎(1)
  2. 导师发现我刷短视频,给我发了一条链接
  3. 5秒到1秒,记一次效果“非常”显著的性能优化
  4. 关于SAP云平台的Identity Authentication tenant
  5. rds 数据库营销报告_《营销自动化从入门到精通》第五章 集成营销自动化工具与CRM...
  6. json 和 数组的区别
  7. SQL Server 2012 中 SSAS 多维数据浏览器已经废除
  8. ubuntu 下WebStorm 无法输入中文
  9. QQ扫码登录实现与原理
  10. hibernate整合openGauss
  11. pycharm英语怎么读_pycharm快捷键翻译
  12. 小学听力测试英语软件,亲测:好用的小学英语软件有哪些?这6款通通安利给大家!...
  13. Python herhan学习 day2
  14. flac格式转mp3
  15. php下lua的运行,phpStudy中起用lua脚本
  16. #数据结构:家谱管理
  17. 【C/C++】scanf,printf 函数
  18. 数仓01-概念的理解和方法论
  19. NUCLEO STM32H743购买和使用说明
  20. [引爆流行]Meme Engine话题(一)

热门文章

  1. 如何在Mac上使用Kigo Netflix Video Downloader从Netflix 下载视频?
  2. mysql 事务排他锁_[数据库事务与锁]详解六: MySQL中的共享锁与排他锁
  3. Mysql中的视图是什么?有什么作用?
  4. led调光原理c语言,最牛的LED遥控控制器---调整无闪烁(C语言)
  5. 达人评测 苹果 M1 iPad Pro 2021 怎么样
  6. ICT 2017 | 航天信息於亮:数据智能助力政府转型,推动企业发展
  7. android 打开公众号页面_微信公众号页面适配
  8. 程序员的第一款表情包,你值得拥有
  9. Fluent截取局部面
  10. Android开发-WebView的缓存处理和性能优化 实现H5页面秒开【四】