1. 前言

一般我们在构建CNN的时候都是以32位浮点数为主,这样在网络规模很大的情况下就会占用非常大的内存资源。然后我们这里来理解一下浮点数的构成,一个float32类型的浮点数由一个符号位,8个指数位以及23个尾数为构成,即:

符号位[ ] + 指数位[ ] [ ] [ ] [ ] [ ] [ ] [ ] [ ] + 尾数[ ]*23

我们可以看到,每个float32浮点数里面一共有223=838860872^{23}=83886087223=83886087个二进制对应表示2232^{23}223个数,又106<223<10810^6<2^{23}<10^8106<223<108,所以我们一般可以精确地表示6位有效数字,但是无法表示888位有效数字。浮点数有正负所以需要一个符号位来表示,还有888个指数位来表示指数(指数也是要存储的)由于有正负也就是[−127,128][-127,128][−127,128]。

然后完成一个浮点数的加减运算至少有如下过程:

  • 检查操作数,即如果有一个参与运算的数为000,那么直接得出结果。
  • 比较阶码大小完成对阶。
  • 对尾数进行加或减运算。
  • 将结果规格化并进行舍入处理。

关于这一节,可以看这个表述得非常清楚的博客:http://www.cppblog.com/jianjianxiaole/articles/float.html

从上面对浮点数的介绍来看,如果我们使用全浮点数的CNN,那么不仅仅会占用大量的内存,还需要大量的计算资源。在这种背景下,低比特量化的优势就体验出来了。接下来,我们就先看一下2016年NIPS提出的《Binarized Neural Networks: Training Neural Networks with Weights andActivations Constrained to +1 or −1》这篇论文,简称BNN,然后再对BNN的Pytorch代码做一个解析。

2. BNN的原理

2.1 二值化方案

在训练BNN时,我们要把网络层的权重和输出值设为1或者-1,下面是论文提出的222种二值化方法。

第一种是直接将大于等于0的参数设置为1,小于0的设置为-1,即:

第二种是将绝对值大于1的参数设为1,将绝对值小于1的参数根据距离+/−1+/-1+/−1的远近按概率随机置为+/−1+/-1+/−1:

其中σ(x)\sigma(x)σ(x)是一个clip函数,公式如下:

论文中提到,第二种方法似乎更加合理,但它也引入了按概率分布的随机比特数,因此硬件实现会消耗很多时间,所以通常会选定第一种方法来对权重和输出值进行量化。

2.2 如何反向传播?

将CNN的权重和输出值二值化以后,梯度信息应当怎么办呢?论文指出梯度仍然不得不用较高精度的实数来存储,因为梯度很小,所以无法使用低精度来正确表达梯度,同时梯度是有高斯白噪声的,累加梯度才能抵消噪声。另外,二值化相当于给权重和输出值添加了噪声,这种噪声具有正则化的作用,可以防止模型过拟合,即它可以让权重更加稀疏。

由于signsignsign函数的导数在非0处都是0,所以在梯度回传的时候使用tanh来替代sign进行求导,这里假设损失函数是CCC,输入是rrr,对rrr做二值化可得:

q=sign(r)q=sign(r)q=sign(r)

CCC对qqq的导数可以用gqg_qgq​表示,那么CCC对rrr的导数为:

gr=gq1∣r∣<=1g_r=g_q1_{|r|<=1}gr​=gq​1∣r∣<=1​

其中1∣r∣<=11_{|r|<=1}1∣r∣<=1​是tanh的梯度,这样就可以进行梯度回传,然后就可以根据梯度不断优化并训练参数了。

这里需要注意的是我们需要使用BatchNorm层,BN层最大的作用就是可以加速学习并减少权重尺度的影响,带来一定量的正则化并提高CNN的性能,但是BN设计了很多的矩阵运算会降低运算速度。因此,论文提出了一种Shift-based Batch Normalization(SBN)层。SBN的优点是几乎不需要矩阵运算,并且不会带来性能损失。SBN的操作过程如下:

这个函数实现了在不使用乘法的情况下近似计算BN,可以提高计算效率。

同样也是为了加速二值网络的训练,改进了AdaMax优化器。具体算法如下图所示。

2.3 第一层怎么办?

由于网络除了输入以外,全部都是二值化的,所以需要对第一层进行处理,将其二值化,整个二值化网络的处理流程如下:

整个过程可以表示为:初始化第一层->计算前一层点积的Xnor->计算BatchNorm的符号->执行网络到倒数第二层->计算输出…

以上是假设输入的每个数字只有8位的情况,如果我们希望使用任意nnn位的整数,那么我们可以对公式进行推广,即:

LinearQuant(x,bitwidth)=clip(round(xbitwidth)×bitwidth,minV,maxV)LinearQuant(x,bitwidth)=clip(round(\frac{x}{bitwidth})\times bitwidth, minV, maxV)LinearQuant(x,bitwidth)=clip(round(bitwidthx​)×bitwidth,minV,maxV)

或者

LogQuant(x,bitwidth)=clip(AP2(x),minV,maxV)LogQuant(x,bitwidth)=clip(AP2(x), minV, maxV)LogQuant(x,bitwidth)=clip(AP2(x),minV,maxV)

3. 代码实现

接下来我们来解析一下Pytorch实现一个BNN,需要注意的是代码实现和上面介绍的原理有很多不同,首先第一个卷积层没有做二值化,也就是说第一个卷积层是普通的卷积层。对于输入也没有做定点化,即输入仍然为Float。另外,对于BN层和优化器也没有按照论文中的方法来做优化,代码地址如下:https://github.com/666DZY666/model-compression/blob/master/quantization/WbWtAb/models/nin.py

3.1 定义网络结构

下面的代码定义了支持权重和输出值分别可选二值或者三值量化,可以看到核心函数即为Conv2d_Q

import torch.nn as nn
from .util_wt_bab import Conv2d_Q# *********************量化(三值、二值)卷积*********************
class Tnn_Bin_Conv2d(nn.Module):# 参数:last_relu-尾层卷积输入激活def __init__(self, input_channels, output_channels,kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, A=2, W=2):super(Tnn_Bin_Conv2d, self).__init__()self.A = Aself.W = Wself.last_relu = last_relu# ********************* 量化(三/二值)卷积 *********************self.tnn_bin_conv = Conv2d_Q(input_channels, output_channels,kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, A=A, W=W)self.bn = nn.BatchNorm2d(output_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.tnn_bin_conv(x)x = self.bn(x)if self.last_relu:x = self.relu(x)return xclass Net(nn.Module):def __init__(self, cfg = None, A=2, W=2):super(Net, self).__init__()# 模型结构与搭建if cfg is None:cfg = [192, 160, 96, 192, 192, 192, 192, 192]self.tnn_bin = nn.Sequential(nn.Conv2d(3, cfg[0], kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(cfg[0]),Tnn_Bin_Conv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, A=A, W=W),Tnn_Bin_Conv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, A=A, W=W),nn.MaxPool2d(kernel_size=3, stride=2, padding=1),Tnn_Bin_Conv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, A=A, W=W),Tnn_Bin_Conv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, A=A, W=W),Tnn_Bin_Conv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, A=A, W=W),nn.MaxPool2d(kernel_size=3, stride=2, padding=1),Tnn_Bin_Conv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, A=A, W=W),Tnn_Bin_Conv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, last_relu=1, A=A, W=W),nn.Conv2d(cfg[7],  10, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(10),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=8, stride=1, padding=0),)def forward(self, x):x = self.tnn_bin(x)x = x.view(x.size(0), -1)return x

3.2 具体实现

我们跟进一下Conv2d_Q函数,来看一下二值化的具体代码实现,注意我将代码里面和三值化有关的部分省略了。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function# ********************* 二值(+-1) ***********************
# A 对激活值进行二值化的具体实现,原理中的第一个公式
class Binary_a(Function):@staticmethoddef forward(self, input):self.save_for_backward(input)output = torch.sign(input)return output@staticmethoddef backward(self, grad_output):input, = self.saved_tensors#*******************ste*********************grad_input = grad_output.clone()#****************saturate_ste***************grad_input[input.ge(1)] = 0grad_input[input.le(-1)] = 0'''#******************soft_ste*****************size = input.size()zeros = torch.zeros(size).cuda()grad = torch.max(zeros, 1 - torch.abs(input))#print(grad)grad_input = grad_output * grad'''return grad_input
# W 对权重进行二值化的具体实现
class Binary_w(Function):@staticmethoddef forward(self, input):output = torch.sign(input)return output@staticmethoddef backward(self, grad_output):#*******************ste*********************grad_input = grad_output.clone()return grad_input# ********************* A(特征)量化(二值) ***********************
# 因为我们使用的网络结构不是完全的二值化,第一个卷积层是普通卷积接的ReLU激活函数,所以要判断一下
class activation_bin(nn.Module):def __init__(self, A):super().__init__()self.A = Aself.relu = nn.ReLU(inplace=True)def binary(self, input):output = Binary_a.apply(input)return outputdef forward(self, input):if self.A == 2:output = self.binary(input)# ******************** A —— 1、0 *********************#a = torch.clamp(a, min=0)else:output = self.relu(input)return output
# ********************* W(模型参数)量化(三/二值) ***********************
def meancenter_clampConvParams(w):mean = w.data.mean(1, keepdim=True)w.data.sub(mean) # W中心化(C方向)w.data.clamp(-1.0, 1.0) # W截断return w
# 对激活值进行二值化
class weight_tnn_bin(nn.Module):def __init__(self, W):super().__init__()self.W = Wdef binary(self, input):output = Binary_w.apply(input)return outputdef forward(self, input):# **************************************** W二值 *****************************************output = meancenter_clampConvParams(input) # W中心化+截断# **************** channel级 - E(|W|) ****************E = torch.mean(torch.abs(output), (3, 2, 1), keepdim=True)# **************** α(缩放因子) ****************alpha = E# ************** W —— +-1 **************output = self.binary(output)# ************** W * α **************output = output * alpha # 若不需要α(缩放因子),注释掉即可# **************************************** W三值 *****************************************else:output = inputreturn output# ********************* 量化卷积(同时量化A/W,并做卷积) ***********************
class Conv2d_Q(nn.Conv2d):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True,A=2,W=2):super().__init__(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)# 实例化调用A和W量化器self.activation_quantizer = activation_bin(A=A)self.weight_quantizer = weight_tnn_bin(W=W)def forward(self, input):# 量化A和Wbin_input = self.activation_quantizer(input)tnn_bin_weight = self.weight_quantizer(self.weight)    #print(bin_input)#print(tnn_bin_weight)# 用量化后的A和W做卷积output = F.conv2d(input=bin_input, weight=tnn_bin_weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)return output

上面的代码比较好理解,因为它将BNN论文中最难实现的SBN和改进后的AdaMax优化算法省略掉了,并且没有对输入进行定电化,所以编码难度小了很多,这个代码可以验证一下使用BNN之后精度损失。

4. 实验结果

这里贴一下使用上面的网络训练Cifar10图像分类数据集的准确率对比:

可以看到如果将除了第一层卷积之外的卷积层均换成二值化卷积之后,模型的压缩率高达92%并且准确率也只有1个点的下降,这说明在Cifar10数据集上这种方法确实是有效的。笔者跑了一下这个代码,测试结果和代码作者是类似的。

5. 思考

我们看一下论文给出的BNN在MNIST/CIFAR-10等数据集上的测试结果:

可以看到这些简单网络的分类误差还在可接受的范围之内,但这种二值化网络在ImageNet上的测试结果却是比较差的,出现了很大的误差。虽然还存在很多的优化技巧比如放开Tanh的边界,用2Bit的激活函数可以提升一些精度,但在复杂模型下效果仍然不太好。因此,二值化模型的最大缺点应该是不适合复杂模型。另外,新定义的算子在部署时也是一件麻烦的事,还需要专门的推理框架或者定制的硬件来支持。不然就只能像我们介绍的代码实现那样,使用矩阵乘法来模拟这个二值化计算过程,但加速是非常有限的。


欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,每天分享我们学习到的新鲜知识。( • ̀ω•́ )✧

有对文章相关的问题,或者想要加入交流群,欢迎添加BBuf微信:


为了方便读者获取资料以及我们公众号的作者发布一些Github工程的更新,我们成立了一个QQ群,二维码如下,感兴趣可以加入。

基于Pytorch构建一个可训练的BNN相关推荐

  1. 【神经网络】Pytorch构建自己的训练数据集

    [神经网络]Pytorch构建自己的训练数据集 ​ 最近参加了一个比赛,需要对给定的图像数据进行分类,之前使用Pytorch进行神经网络模型的构建与训练过程中,都是使用的Pytorch内置的数据集,直 ...

  2. 基于pytorch的模型稀疏训练与模型剪枝示例

    基于pytorch的模型稀疏训练与模型剪枝示例 稀疏训练+模型剪枝代码下载地址:下载地址 CIFAR10-VGG16BN Baseline Trained with Sparsity (1e-4) P ...

  3. 深度学总结:RNN训练需要注意地方:pytorch每一个batch训练之前需要把hidden = hidden.data,否者反向传播的梯度会遍历以前的timestep

    pytorch每一个batch训练之前需要把hidden = hidden.data,否者反向传播的梯度会遍历以前的timestep tensorflow也有把new_state更新,但是没有明显de ...

  4. 基于ForkJoin构建一个简单易用的并发组件

    2019独角兽企业重金招聘Python工程师标准>>> 基于ForkJoin构建一个简单易用的并发组件 在实际的业务开发中,需要用到并发编程的知识,实际使用线程池来异步执行任务的场景 ...

  5. [carla入门教程]-6 小项目:基于carla-ros-bridge构建一个小型比赛赛道

    本专栏教程将记录从安装carla到调用carla的pythonAPI进行车辆操控并采集数据的全流程,带领大家从安装carla开始,到最终能够熟练使用carla仿真环境进行传感器数据采集和车辆控制. 第 ...

  6. 用Pytorch构建一个喵咪识别模型

    本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 目录 一.前言 二.问题阐述及理论流程 2.1问题阐述 2.2猫咪图片识别原 ...

  7. 使用Pytorch构建一个分类器(CIFAR10模型)

    分类器任务和数据介绍 ·构建一个将不同图像进行分类的神经网络分类器,对输入的的图片进行判别并完成分类. ·本案例采用CIFAR10数据集作为原始图片数据 ·CIFAR10数据集介绍:数据集中每张图片的 ...

  8. 使用pytorch构建一个神经网络、损失函数、反向传播、更新网络参数

    关于torch.nn: 使用Pytorch来构建神经网络, 主要的工具都在torch.nn包中. nn依赖于autograd来定义模型, 并对其自动求导. 构建神经网络的典型流程: 定义一个拥有可学习 ...

  9. 我的实践:通过蚂蚁、蜜蜂二分类问题了解如何基于Pytorch构建分类模型

    文章目录 1.数据集准备 2.pytorch Dataset 处理图片数据 3.网络模型设计 4.模型的训练与测试 1.数据集准备 本例采用了pytorch教程提供的蜜蜂.蚂蚁二分类数据集(点击可直接 ...

  10. 基于pytorch构建双向LSTM(Bi-LSTM)文本情感分类实例(使用glove词向量)

    学长给的代码,感觉结构清晰,还是蛮不错的,想以后就照着这样的结构走好了,记录一下. 首先配置环境 matplotlib==3.4.2 numpy==1.20.3 pandas==1.3.0 sklea ...

最新文章

  1. 人工智能技术给教育行业带来哪些主要影响?
  2. Youtube-dl调用外部Aria2多线程加速下载
  3. php swoole yii,yii2-swoole
  4. getseconds补0_Java Duration类| getSeconds()方法与示例
  5. python常用小技巧(四)——批量图片改名
  6. 菜鸟的学习之路(12) —HashSet类详解
  7. python文件拆分_python – 在几个文件中拆分views.py.
  8. My sql 常用函数
  9. 论文笔记_S2D.52_CMRNet++_运行记录
  10. linux bash错误,linux bash错误重定向输出
  11. 计算机密码突然不正确,win10开机密码明明正确,win10密码突然不对了
  12. opencv安装教程
  13. 3ds Max2021软件安装包+安装教程
  14. 三、Snapman多人协作电子表格之——软件的基本功能
  15. 自动删除QQ空间指定好友的留言
  16. Startbbs YouBBS等轻论坛程序折腾过程
  17. 汇正财经骗局?科创50大涨
  18. 装配一台计算机有哪些安全注意事项,挤出机注意事项
  19. Appium:配置华为手机鸿蒙HarmonyOS系统参数
  20. HTML那些可爱的“表情包”

热门文章

  1. c语言大作业实现程序功能描述,C语言程序设计大作业——员工管理系统(代码超详细内含实验报告)...
  2. 保姆级教程 树莓派4b ubuntu20.04 的 linux 之旅
  3. 2-2 李宏毅2021春季机器学习教程-类神经网络训练不起来怎么办(一)局部最小值与鞍点(Local Minima and Saddle Point)
  4. 2022.9.13 手机验证码登录功能
  5. Pytorch下基于lstm的股价预测
  6. 一般家用路由器买多大的合适_家用路由器多少兆合适
  7. plc的毕业设计冷门题目_PLC毕业设计----PLC毕业设计题目汇总
  8. 百度 LBS 开放平台,开发者众测计划正式启动
  9. 批处理bat schtasks 启动远程应用
  10. java offset什么意思_java – “offset或count可能接近-1 1”这是什么意思