References:

https://zhuanlan.zhihu.com/p/100672008

https://www.jianshu.com/p/2b94da24af3b

https://github.com/ptrblck/pytorch_misc

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: test2.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site:
# @Time: May 19, 2021
# ---
import torch
import torch.nn as nn
import numpy as np
np.random.seed(10)
torch.manual_seed(10)

data = np.array([[1, 2, 7],
                 [1, 3, 9],
                 [1, 4, 6]]).astype(np.float32)
bn_torch = nn.BatchNorm1d(num_features=3)
data_torch = torch.from_numpy(data)
bn_output_torch = bn_torch(data_torch)
print("bn_output_torch:", bn_output_torch)def fowardbn(x, gam, beta, ):'''x:(N,D)维数据'''momentum = 0.1eps = 1e-05running_mean = 0running_var = 1running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)mean = x.mean(dim=0)var = x.var(dim=0,unbiased=False)# bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).datax_hat = (x - mean) / torch.sqrt(var + eps)out = gam * x_hat + betaprint("x_mean:", mean, "x_var:", var, "self._gamma:", gam, "self._beta:", beta)cache = (x, gam, beta, x_hat, mean, var, eps)return out, cacheclass MyBN:def __init__(self, momentum, eps, num_features):"""初始化参数值:param momentum: 追踪样本整体均值和方差的动量:param eps: 防止数值计算错误:param num_features: 特征数量"""# 对每个batch的mean和var进行追踪统计self._running_mean = 0self._running_var = 1# 更新self._running_xxx时的动量self._momentum = momentum# 防止分母计算为0self._eps = eps# 对应论文中需要更新的beta和gamma,采用pytorch文档中的初始化值self._beta = np.zeros(shape=(num_features, ))self._gamma = np.ones(shape=(num_features, ))def batch_norm(self, x):"""BN向传播:param x: 数据:return: BN输出"""x_mean = x.mean(axis=0)x_var = x.var(axis=0)# 对应running_mean的更新公式self._running_mean = (1-self._momentum)*x_mean + self._momentum*self._running_meanself._running_var = (1-self._momentum)*x_var + self._momentum*self._running_var# 对应论文中计算BN的公式x_hat = (x-x_mean)/np.sqrt(x_var+self._eps)y = self._gamma*x_hat + self._betaprint("x_mean:", x_mean, "x_var:", x_var, "self._gamma:", self._gamma, "self._beta:", self._beta)return ymy_bn = MyBN(momentum=0.1, eps=1e-05, num_features=3)
my_bn._beta = bn_torch.bias.detach().numpy()
my_bn._gamma = bn_torch.weight.detach().numpy()
bn_output = my_bn.batch_norm(data)
print("MyBN bn_output:", bn_output)out, cache = fowardbn(data_torch.detach(), bn_torch.weight.detach(), bn_torch.bias.detach())
print("fowardbn out2: ", out)
# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: test.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site:
# @Time: May 19, 2021
# ---
import numpy as np
np.set_printoptions(suppress=True)
import torch
import torch.nn as nn
np.random.seed(10)
torch.manual_seed(10)

# import pprint
# np.random.seed(10)
# norm = np.random.normal(size=(5, 5))
# pprint.pprint(norm)

data = [[0.1, 0.3, 0.4],
        [0.5, 0.3, 0.2],
        [0.4, 0.6, 0.1],
        [0.5, 0.3, 0.2],
        ]
data_np = np.array(data, dtype=np.float32) * 10
print("data_np.shape:", data_np.shape)
data_np = data_np.reshape((3, -1))
print("data_np.shape:", data_np.shape)
t_data = torch.from_numpy(data_np)
t_data = torch.unsqueeze(t_data, dim=0)
print("t_data.shape:", t_data.shape)
print(t_data)


class PointNet(nn.Module):
    def __init__(self):
        super(PointNet, self).__init__()
        # nn.Conv1d explained: https://blog.csdn.net/sunny_xsc1994/article/details/82969867
        self.conv1 = torch.nn.Conv1d(3, 5, 1)
        self.bn1 = nn.BatchNorm1d(5)
        # PyTorch weight initialization and parameter grouping:
        # https://blog.csdn.net/Bear_Kai/article/details/99302341
        # Implementing weight initialization in PyTorch: https://www.jb51.net/article/177617.htm
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                m.weight.data.normal_(0, 1)
                if m.bias is not None:
                    m.bias.data.zero_()
                self.weight = np.asarray(m.weight.data)
                # print("nn.Conv1d:", m.weight.data)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(5)  # 1
                m.bias.data.zero_()

    def forward(self, x):
        result1 = self.conv1(x)
        result2 = self.bn1(result1)
        return result1, result2, self.weight


pn = PointNet()
result1, result2, weight = pn(t_data)
weight = torch.from_numpy(weight)
print("weight.shape:", weight.shape)
print("weight:", weight)
print("result1.shape:", result1.shape); print(result1)
print("result2.shape:", result2.shape); print(result2)
#print("result1_end:", pn.bn1(result1))
# PointNet paper reproduction and code walkthrough: https://zhuanlan.zhihu.com/p/86331508
for n in range(t_data.shape[2]):
    sums = []
    for m in range(weight.shape[0]):
        # A summary of multiplication ops in PyTorch: https://zhuanlan.zhihu.com/p/212461087
        # sums += torch.mul(t_data[0, :, 0], weight[m, :, 0])  # element-wise product
        sums.append(torch.dot(t_data[0, :, n], weight[m, :, 0]))  # dot product
    print("sum:", sums)
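The double loop reproduces conv1 by hand: with kernel size 1, Conv1d is a per-point linear map, so the whole computation collapses to a single matrix product. A minimal sketch of the same check, assuming t_data, weight, and result1 from the script are in scope (conv1's bias was zeroed above, so it can be ignored):

# weight is (5, 3, 1); dropping the kernel axis gives a (5, 3) matrix.
# (5, 3) @ (3, 4) applies it to all 4 points at once -> (5, 4).
manual_conv = weight.squeeze(-1) @ t_data[0]
print(torch.allclose(manual_conv, result1[0]))  # expected: True (bias is zero)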
# Why nn.BatchNorm1d differs from a hand-rolled Python implementation, and the fix: https://www.jianshu.com/p/2b94da24af3b
#https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
def fowardbn(x, gam, beta, dim=0):
    '''x: (N, D) data; dim selects the axis the statistics are taken over'''
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 5  # 1
    # The running stats here are illustrative only; they do not affect `out`
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    mean = x.mean(dim=dim)
    var = x.var(dim=dim, unbiased=False)
    # middle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    print("x_mean:", mean, "x_var:", var, "self._gamma:", gam, "self._beta:", beta)
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache


# For a B*C*(H*W) layout:
# 1, 3_Iup, 4            input (batch=1, in_channels=3, length=4)
# 5_Out, 3_Iup, 1        convolution kernel (out_channels, in_channels, kernel_size)
# 1, 5_Out(channel), 4   output
bn_re = result1.permute(0, 2, 1)
out, cache = fowardbn(bn_re, pn.bn1.weight, pn.bn1.bias, dim=1)
out = out.permute(0, 2, 1)
print("out1", out)bn_re = result1.squeeze()
bn_re = bn_re.permute(1, 0)
out, cache = fowardbn(bn_re, pn.bn1.weight, pn.bn1.bias, dim=0)
out = out.permute(1, 0)
print("out2", out)x = np.array([[-1.2089,  6.8342, -0.3317, -5.2298],[ 2.5075,  9.6109,  8.8057,  9.0995],[ 4.2763,  1.2605,  6.7774, 11.4138],[ 1.0103,  1.0549,  0.3408,  0.0656],[-2.2381,  1.9428, -3.6522, -7.8491]])
x = x.mean(axis=1)
y = np.array([-2.2381,  1.9428, -3.6522, -7.8491])
y = y.mean(axis=0)
print(x)
print(y)
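Both permute routes feed fowardbn (N, D)-style slices, and both should reproduce result2: for a (B, C, L) input, nn.BatchNorm1d computes one mean and one variance per channel, pooled over the batch and length dimensions together (B*L values per channel). A minimal sketch of that check without any permuting, assuming result1, result2, and pn from the script are in scope:

# One statistic per channel, pooled over dims (0, 2): B*L = 1*4 values each.
mean = result1.mean(dim=(0, 2), keepdim=True)
var = result1.var(dim=(0, 2), unbiased=False, keepdim=True)
x_hat = (result1 - mean) / torch.sqrt(var + 1e-05)
manual_bn = pn.bn1.weight.view(1, -1, 1) * x_hat + pn.bn1.bias.view(1, -1, 1)
print(torch.allclose(manual_bn, result2, atol=1e-5))  # expected: True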

