pytorch 模型model 的一些常用属性和函数说明
首先创建一个简单的网络,用来举例说明后来的例子。
class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(6)self.conv2 = nn.Conv2d(6,8,kernel_size=3,padding=1)self.relu = nn.ReLU(inplace=True)def forward(self,x):x = self.conv1(1)x = self.bn1(x)x = self.conv2(x)x = self.relu(x)return xnet = Net()
net.parameters(),可以得到net这个具体的模型中的参数:
for para in net.parameters():print(para)print(para.shape)print()'''
输出为:Parameter containing:
tensor([[[[ 0.0794, 0.1070, 0.0415],[ 0.0037, -0.0850, 0.0919],[ 0.0039, 0.0899, -0.1446]],[[-0.0642, 0.0251, -0.1055],[ 0.1085, 0.0627, 0.0388],[-0.0878, -0.1305, 0.1335]],[[-0.0907, 0.0113, 0.1400],[ 0.0051, -0.0605, -0.1085],[ 0.0544, -0.0649, 0.0847]]],[[[-0.1319, 0.0152, -0.0736],[ 0.1796, 0.0857, 0.1668],[ 0.0586, -0.1508, -0.1571]],[[-0.1053, 0.0372, 0.1596],[ 0.1509, 0.1125, -0.1773],[ 0.0960, -0.0507, 0.0569]],[[-0.0640, 0.0070, -0.1253],[-0.1739, 0.0552, 0.1892],[ 0.1232, -0.0811, 0.1263]]],[[[ 0.0483, -0.1212, -0.0870],[-0.0915, 0.0072, 0.1581],[ 0.1184, -0.0907, 0.1109]],[[-0.0024, 0.0980, -0.1080],[-0.0311, -0.1013, -0.0581],[ 0.1855, -0.0202, 0.0950]],[[-0.1640, -0.0848, 0.0254],[ 0.0318, 0.0538, 0.0277],[ 0.0641, 0.0298, 0.0352]]],[[[-0.0955, 0.0569, -0.0565],[-0.1186, -0.0177, 0.0604],[ 0.0305, -0.0398, -0.1165]],[[-0.1532, 0.0179, 0.0317],[ 0.0910, 0.1470, -0.1013],[-0.0165, 0.0095, -0.0887]],[[-0.0314, 0.1790, -0.1142],[ 0.1710, -0.1628, 0.1342],[-0.0781, 0.0194, -0.0568]]],[[[-0.1903, -0.1659, -0.1797],[ 0.1109, 0.0686, 0.1767],[-0.0777, -0.0341, -0.1549]],[[ 0.0615, -0.1309, -0.1492],[ 0.1291, -0.1705, 0.1749],[ 0.0173, -0.1587, 0.0072]],[[-0.1669, -0.0803, 0.0378],[ 0.1880, -0.0338, 0.1056],[-0.0171, 0.0892, -0.0090]]],[[[-0.1615, 0.1901, -0.1313],[-0.0775, -0.0043, -0.0902],[-0.0786, 0.0501, 0.0921]],[[ 0.1332, 0.1698, 0.1657],[ 0.0244, 0.0792, -0.1830],[ 0.0519, -0.1610, 0.0821]],[[-0.1437, 0.0229, -0.0810],[-0.1200, 0.1311, 0.0776],[ 0.0772, -0.0238, -0.0981]]]], requires_grad=True)
torch.Size([6, 3, 3, 3])Parameter containing:
tensor([ 0.0599, -0.1511, -0.0591, 0.1000, 0.1050, 0.0743],requires_grad=True)
torch.Size([6])Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], requires_grad=True)
torch.Size([6])Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)
torch.Size([6])Parameter containing:
tensor([[[[-1.0620e-01, 5.6997e-02, -7.9542e-03],[-6.6638e-02, -1.0529e-02, 1.3376e-01],[ 7.1680e-02, 1.3388e-01, 1.2293e-01]],[[ 9.2092e-02, 2.4215e-02, -1.2708e-01],[ 1.9943e-03, -8.7654e-02, 1.0564e-01],[-1.2967e-01, -1.2077e-01, -4.4365e-02]],[[ 9.9798e-04, -7.9709e-02, 2.7571e-02],[-1.4309e-02, 1.1243e-01, -1.1661e-01],[ 7.5213e-02, 7.6132e-02, 1.4844e-02]],[[ 1.2713e-01, -7.3697e-02, 9.4301e-02],[ 7.7325e-02, 9.6845e-02, -1.0990e-01],[ 6.2486e-02, 1.0107e-01, 3.0378e-02]],[[-1.0599e-01, 2.7444e-02, -8.8193e-02],[-1.0384e-01, 1.2580e-01, 4.1619e-02],[ 1.3596e-01, -1.2098e-01, 8.2317e-02]],[[-1.0979e-01, 9.2484e-02, -5.2828e-03],[ 7.7915e-02, 6.0981e-02, 9.0634e-02],[ 8.3001e-02, 7.1535e-02, -1.6206e-02]]],[[[ 1.1561e-01, -2.1935e-02, -8.5694e-03],[-4.9740e-03, -2.1594e-02, 9.7255e-02],[ 1.2904e-01, 7.2028e-02, 9.6564e-02]],[[-7.6498e-02, -1.2666e-01, -3.2563e-02],[ 9.0076e-02, -8.3288e-02, 1.1785e-01],[-4.3596e-02, 3.6950e-03, -5.0087e-02]],[[-2.9787e-02, -5.2824e-02, -9.9231e-02],[ 9.1963e-02, 7.7965e-02, 1.1397e-01],[ 1.3667e-02, 1.1007e-01, -4.1288e-02]],[[ 9.4790e-02, -6.8296e-02, -4.3310e-02],[-6.3128e-02, 2.3350e-02, -6.3908e-02],[-1.2005e-01, -6.2899e-02, -7.2392e-02]],[[-1.1934e-01, -4.5716e-02, -5.7582e-02],[ 8.1211e-06, 9.6752e-02, -4.1839e-02],[ 9.9383e-02, -4.9952e-02, -4.1875e-02]],[[ 1.0271e-01, -9.7970e-02, -2.5481e-02],[ 1.2039e-01, 1.7195e-02, -2.2504e-02],[ 6.3394e-02, -1.0446e-02, 9.7013e-02]]],[[[-6.2230e-02, -8.0188e-02, -4.3593e-02],[ 9.6622e-02, 7.5777e-02, 1.9751e-02],[ 4.6756e-02, 8.1505e-02, 2.1734e-02]],[[-4.0420e-02, -4.7027e-02, 2.7860e-02],[-4.5530e-04, 1.0848e-01, -9.7263e-02],[ 4.0441e-02, -2.3740e-03, -1.1751e-01]],[[-1.0342e-01, 1.4509e-02, 3.5800e-02],[-7.3109e-02, -4.4676e-02, 1.1477e-01],[ 1.0436e-01, -1.1468e-01, 1.1279e-01]],[[ 1.2757e-01, -5.4175e-02, 3.9229e-02],[ 1.2238e-01, -4.1751e-02, 1.0329e-02],[ 1.1175e-01, -1.3469e-01, 9.0738e-02]],[[-1.2890e-01, 1.0985e-01, -3.5065e-02],[-1.0353e-02, -1.1117e-01, -1.0932e-01],[ 2.3825e-02, -5.1328e-02, 1.0952e-01]],[[-1.2119e-01, -1.1721e-01, 3.9911e-02],[-9.3294e-02, 3.6181e-02, -9.2453e-02],[-1.0519e-01, 5.3727e-02, 4.4648e-03]]],...,[[[ 6.6163e-02, -1.0531e-01, -1.0589e-01],[ 7.9671e-02, -3.3005e-02, -1.0760e-01],[ 1.4868e-02, 1.4420e-02, -9.6573e-02]],[[ 2.2414e-02, -1.5715e-02, 2.4232e-02],[ 2.3479e-02, -8.7212e-02, -1.8911e-02],[ 9.3712e-02, 1.0342e-01, 5.4269e-02]],[[-9.8044e-02, 7.1834e-02, -1.0760e-01],[-9.7597e-02, 9.9367e-02, -9.9010e-02],[ 2.6155e-02, -1.3208e-01, 1.0316e-02]],[[ 7.7097e-02, 1.0838e-01, 2.7527e-02],[-4.3391e-02, 1.3416e-01, -1.1440e-01],[-3.8224e-02, -2.7650e-03, -5.9436e-03]],[[ 6.5886e-02, 1.1016e-02, -1.0989e-01],[ 4.2206e-02, -9.2878e-02, 7.4586e-02],[ 1.1299e-01, -1.1260e-01, -7.2581e-02]],[[ 8.6093e-03, 3.0288e-02, 7.8243e-02],[-6.7512e-03, -8.5671e-02, 8.3012e-02],[-2.4528e-02, 1.7389e-02, 2.0112e-02]]],[[[ 3.9985e-02, 6.4231e-03, 1.3579e-01],[ 8.8007e-02, -1.8449e-02, 2.9483e-02],[-5.8890e-02, 3.1275e-02, 1.1129e-01]],[[ 9.9826e-02, -1.0343e-01, 1.7781e-02],[-1.5528e-02, -1.2074e-01, -5.4819e-02],[-8.1487e-02, 3.7535e-02, -6.7128e-02]],[[-2.2612e-02, -4.7612e-02, -1.3335e-01],[ 3.7972e-02, -1.2762e-01, 5.4009e-02],[ 9.0579e-02, 5.4727e-02, -9.1461e-02]],[[ 8.0858e-02, 1.4411e-03, -1.2739e-01],[ 1.0097e-01, 8.3857e-02, -8.0914e-02],[-1.9743e-02, 1.1509e-01, 8.2933e-02]],[[-3.0184e-02, 1.0409e-01, 2.2486e-02],[-7.8506e-02, -7.7744e-02, -2.8042e-02],[-3.3265e-02, 9.1861e-02, 4.7874e-02]],[[ 3.1688e-02, 1.2607e-01, 8.8575e-02],[ 1.0217e-01, 2.8618e-02, 8.4546e-02],[ 2.8103e-02, 1.2679e-01, 2.4444e-02]]],[[[ 7.9484e-02, -1.1017e-02, -2.9063e-02],[ 5.4235e-02, 1.1226e-01, -1.0663e-01],[ 9.8365e-02, -2.1643e-02, 6.3686e-02]],[[ 3.0368e-03, 1.2335e-03, 1.3460e-02],[-5.6941e-02, -9.9266e-02, 3.3269e-02],[ 8.6997e-02, 1.1879e-01, -1.2027e-02]],[[ 3.4441e-02, 1.3346e-01, 1.4495e-03],[ 6.1219e-02, 8.4678e-02, -4.3233e-02],[ 1.3061e-01, -1.1880e-01, -1.2782e-01]],[[ 3.4226e-02, 7.5535e-02, -7.4717e-02],[ 8.2468e-03, -9.3862e-02, -5.3166e-02],[ 1.3202e-01, 7.6724e-02, 6.3903e-02]],[[-5.8022e-02, -7.8344e-02, -4.7197e-02],[ 3.7977e-02, 8.6118e-02, 1.1670e-02],[-1.3180e-01, -3.9207e-02, 1.3028e-01]],[[-5.4157e-03, -7.3742e-02, 4.5027e-02],[-2.8969e-02, -2.3086e-02, -3.3792e-02],[ 7.5957e-02, 3.4847e-02, 1.3248e-01]]]], requires_grad=True)
torch.Size([32, 6, 3, 3])Parameter containing:
tensor([-0.0801, 0.0075, 0.0469, -0.0886, 0.0583, -0.0399, -0.0551, 0.0094,-0.0457, 0.1121, 0.0496, -0.0684, 0.1093, 0.0834, -0.0910, -0.1112,-0.0711, -0.0641, -0.0981, 0.0356, 0.1234, -0.0284, 0.0813, 0.0188,-0.0063, -0.0851, 0.1308, -0.0041, -0.0926, -0.0906, 0.1180, 0.0142],requires_grad=True)
torch.Size([32])'''
net.named_parameters()会返回两部分内容,分别是模型中的属性名称和对应的参数值:
for name, para in net.named_parameters():print(name)print(para)print(para.shape)print()'''
输出例子:
conv1.weight
torch.Size([6, 3, 3, 3])
Parameter containing:
tensor([[[[ 0.0215, 0.1517, 0.1218],[ 0.1887, -0.0702, 0.1366],[-0.0947, 0.0794, -0.1096]],[[-0.0045, 0.0683, -0.0814],[ 0.0367, -0.0305, -0.1630],[ 0.0413, 0.0197, 0.1726]],[[ 0.0212, 0.1100, 0.0536],[ 0.1513, 0.0163, 0.1070],[-0.1378, -0.1698, 0.1431]]],'''
输出的conv1.weight,正好对应着最上面定义的self.conv1 = nn.Conv2d(3,6,kernal_size=1,padding=1)中的conv1中的weight参数。
指定参数的更新方式:
net = Net()
ignored_params = list( map(id, net.conv1.parameters()) )
base_params = filter ( lambda p: id(p) not in ignored_params, net.parameters() )
optimizer = torch.optim.SGD( [{'params':base_params},{'params':net.conv1.parameters(), 'lr':1e-3}], lr = 1e-2, momentum=0.9 )
在{}中可以对某些参数指定更新方式,没有设置的更新细节,可以在()后面统一规定
pytorch 模型model 的一些常用属性和函数说明相关推荐
- 【pandas-汇总3】DataFrame常用属性、函数以及索引方式
1.DataFrame常用属性.函数以及索引方式 1.1DataFrame简介 DataFrame是一个表格型的数据结构,它含有一组有序的列,每列可以是不同的值类型(数值.字符串.布尔值等).Data ...
- python series函数,【Python】【pandas-汇总2】series常用属性和函数
1.Series常用属性 属性说明 values获取数组 index获取索引 namevalues的name index.name索引的name 2.Series常用函数 Series可使用ndarr ...
- jquery 常用属性和函数(part I)
Attribute: $("p").addClass(css中定义的样式类型); 给某个元素添加样式 $("img").attr({src:"test ...
- 第四次网页前端培训(CSS常用属性与盒子模型)
CSS常用属性 背景 <head><meta charset="utf-8"><title>常用属性设置</title><st ...
- Pytorch模型中的GPU运算详解与实践
前言 什么是GPU? GPU(Graphic Process Units,图形处理器).是一种单芯片处理器,主要用于管理和提高视频和图形的性能.GPU 加速计算是指同时利用图形处理器 (GPU) 和 ...
- 基于C++的PyTorch模型部署
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...
- 在C++平台上部署PyTorch模型流程+踩坑实录
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 本文主要讲解如何将pytorch的模型部署到c++平台上的模 ...
- 保存和加载pytorch模型
当保存和加载模型时,需要熟悉三个核心功能: torch.save:将序列化对象保存到磁盘.此函数使用Python的pickle模块进行序列化.使用此函数可以保存如模型.tensor.字典等各种对象. ...
- Intel发布神经网络压缩库Distiller:快速利用前沿算法压缩PyTorch模型
Intel发布神经网络压缩库Distiller:快速利用前沿算法压缩PyTorch模型 原文:https://blog.csdn.net/u011808673/article/details/8079 ...
最新文章
- 路由守卫 AJAX,vue路由导航守卫 和 请求拦截以及基于node的token认证
- 新基建之数据中心2020
- QT绘制饼状图,自定义切片。
- 使用squid代理时出现“The requested URL could not be retrieved”
- 用SpringGraph制作拓扑图和关系图
- jQuery清空div内容
- 开课吧:哪些人适合转行做Web前端?
- 2. Javascript 数据类型
- Java全套学习资料
- 爬虫自动定时获取查重结果并将结果发送至指定邮箱
- 25岁,上帝找你谈一次灵魂。——送给女孩,也送给男孩
- 韩咏梅:幸福只需要七分饱(转自新加坡联合早报)
- 原力计划S5上榜博主名单公布(第四期已更新)
- 人工智能技术知识图谱
- zk和quartz实现分布式定时调度
- 使用dns服务器信息的方法,保护DNS服务器十大最有效方法
- Unity Editor修改分辨率
- 经典 90 坦克大战 Python 版实现(支持单双人模式)
- stm32ad测量范围_用STM32的AD测电压,范围是0~3.3V,但是输入电压可能高于3.3,怎么保护STM32?...
- 锐捷网络:引领地铁移动互联网快捷交付2.0时代到来