链接

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Function# ********************* 二值(+-1) ***********************
# A
class Binary_a(Function):@staticmethoddef forward(self, input):self.save_for_backward(input)output = torch.sign(input)return output@staticmethoddef backward(self, grad_output):input, = self.saved_tensors#*******************ste*********************grad_input = grad_output.clone()#****************saturate_ste***************grad_input[input.ge(1)] = 0grad_input[input.le(-1)] = 0return grad_input
# W
class Binary_w(Function):@staticmethoddef forward(self, input):output = torch.sign(input)return output@staticmethoddef backward(self, grad_output):#*******************ste*********************grad_input = grad_output.clone()return grad_input
# ********************* 三值(+-1、0) ***********************
class Ternary(Function):@staticmethoddef forward(self, input):# **************** channel级 - E(|W|) ****************E = torch.mean(torch.abs(input), (3, 2, 1), keepdim=True)# **************** 阈值 ****************threshold = E * 0.7# ************** W —— +-1、0 **************output = torch.sign(torch.add(torch.sign(torch.add(input, threshold)),torch.sign(torch.add(input, -threshold))))return output, threshold@staticmethoddef backward(self, grad_output, grad_threshold):#*******************ste*********************grad_input = grad_output.clone()return grad_input# ********************* A(特征)量化(二值) ***********************
class activation_bin(nn.Module):def __init__(self, A):super().__init__()self.A = Aself.relu = nn.ReLU(inplace=True)def binary(self, input):output = Binary_a.apply(input)return outputdef forward(self, input):if self.A == 2:output = self.binary(input)# ******************** A —— 1、0 *********************#a = torch.clamp(a, min=0)else:output = self.relu(input)return output
# ********************* W(模型参数)量化(三/二值) ***********************
def meancenter_clampConvParams(w):mean = w.data.mean(1, keepdim=True)w.data.sub(mean) # W中心化(C方向)w.data.clamp(-1.0, 1.0) # W截断return w
class weight_tnn_bin(nn.Module):def __init__(self, W):super().__init__()self.W = Wdef binary(self, input):output = Binary_w.apply(input)return outputdef ternary(self, input):output = Ternary.apply(input)return outputdef forward(self, input):if self.W == 2 or self.W == 3:# **************************************** W二值 *****************************************if self.W == 2:output = meancenter_clampConvParams(input) # W中心化+截断# **************** channel级 - E(|W|) ****************E = torch.mean(torch.abs(output), (3, 2, 1), keepdim=True)# **************** α(缩放因子) ****************alpha = E# ************** W —— +-1 **************output = self.binary(output)# ************** W * α **************output = output * alpha # 若不需要α(缩放因子),注释掉即可# **************************************** W三值 *****************************************elif self.W == 3:output_fp = input.clone()# ************** W —— +-1、0 **************output, threshold = self.ternary(input)# **************** α(缩放因子) ****************output_abs = torch.abs(output_fp)mask_le = output_abs.le(threshold)mask_gt = output_abs.gt(threshold)output_abs[mask_le] = 0output_abs_th = output_abs.clone()output_abs_th_sum = torch.sum(output_abs_th, (3, 2, 1), keepdim=True)mask_gt_sum = torch.sum(mask_gt, (3, 2, 1), keepdim=True).float()alpha = output_abs_th_sum / mask_gt_sum # α(缩放因子)# *************** W * α ****************output = output * alpha # 若不需要α(缩放因子),注释掉即可else:output = inputreturn output# ********************* 量化卷积(同时量化A/W,并做卷积) ***********************
class Conv2d_Q(nn.Conv2d):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True,A=2,W=2):super().__init__(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias)# 实例化调用A和W量化器self.activation_quantizer = activation_bin(A=A)self.weight_quantizer = weight_tnn_bin(W=W)def forward(self, input):# 量化A和Wbin_input = self.activation_quantizer(input)tnn_bin_weight = self.weight_quantizer(self.weight)    #print(bin_input)#print(tnn_bin_weight)# 用量化后的A和W做卷积output = F.conv2d(input=bin_input, weight=tnn_bin_weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)return output# *********************量化(三值、二值)卷积*********************
class Tnn_Bin_Conv2d(nn.Module):# 参数:last_relu-尾层卷积输入激活def __init__(self, input_channels, output_channels,kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, A=2, W=2):super(Tnn_Bin_Conv2d, self).__init__()self.A = Aself.W = Wself.last_relu = last_relu# ********************* 量化(三/二值)卷积 *********************self.tnn_bin_conv = Conv2d_Q(input_channels, output_channels,kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, A=A, W=W)self.bn = nn.BatchNorm2d(output_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.tnn_bin_conv(x)x = self.bn(x)if self.last_relu:x = self.relu(x)return xclass Net(nn.Module):def __init__(self, cfg = None, A=2, W=2):super(Net, self).__init__()# 模型结构与搭建if cfg is None:cfg = [192, 160, 96, 192, 192, 192, 192, 192]self.tnn_bin = nn.Sequential(nn.Conv2d(3, cfg[0], kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(cfg[0]),Tnn_Bin_Conv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, A=A, W=W),Tnn_Bin_Conv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, A=A, W=W),nn.MaxPool2d(kernel_size=3, stride=2, padding=1),Tnn_Bin_Conv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, A=A, W=W),Tnn_Bin_Conv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, A=A, W=W),Tnn_Bin_Conv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, A=A, W=W),nn.MaxPool2d(kernel_size=3, stride=2, padding=1),Tnn_Bin_Conv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, A=A, W=W),Tnn_Bin_Conv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, last_relu=1, A=A, W=W),nn.Conv2d(cfg[7],  10, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(10),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=8, stride=1, padding=0),)def forward(self, x):x = self.tnn_bin(x)x = x.view(x.size(0), -1)return x
import sys
import math
import numpy as np
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import osdevice = torch.device('cuda:0')# 随机种子——训练结果可复现
def setup_seed(seed):torch.manual_seed(seed)                                 torch.cuda.manual_seed_all(seed)           np.random.seed(seed)                       torch.backends.cudnn.deterministic = True# 训练lr调整
def adjust_learning_rate(optimizer, epoch):update_list = [10,20,30,40,50]if epoch in update_list:for param_group in optimizer.param_groups:param_group['lr'] = param_group['lr'] * 0.5return# 模型训练
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(trainloader):# 前向传播data, target = data.cuda(), target.cuda()data, target = Variable(data), Variable(target)output = model(data)loss = criterion(output, target)# 反向传播optimizer.zero_grad()loss.backward() # 求梯度optimizer.step() # 参数更新# 显示训练集loss(/100个batch)if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR: {}'.format(epoch, batch_idx * len(data), len(trainloader.dataset),100. * batch_idx / len(trainloader), loss.data.item(),optimizer.param_groups[0]['lr']))return# 模型测试
def test():global best_accmodel.eval()test_loss = 0average_test_loss = 0correct = 0for data, target in testloader:data, target = data.cuda(), target.cuda()data, target = Variable(data), Variable(target)# 前向传播output = model(data)test_loss += criterion(output, target).data.item()pred = output.data.max(1, keepdim=True)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum()# 测试准确率acc = 100. * float(correct) / len(testloader.dataset)print(acc)if __name__=='__main__':setup_seed(1)#随机种子——训练结果可复现# 训练集:随机裁剪 + 水平翻转 + 归一化transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 测试集:归一化transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])# 数据加载trainset = torchvision.datasets.CIFAR10(root='./data',train = True, download = True, transform = transform_train)trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) # 训练集数据testset = torchvision.datasets.CIFAR10(root='./data',train = False, download = True, transform = transform_test)testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=2) # 测试集数据# cifar10类别classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')print('******Initializing model******')# ******************** 在model的量化卷积中同时量化A(特征)和W(模型参数) ************************model = Net(A=2, W=2)best_acc = 0for m in model.modules():if isinstance(m, nn.Conv2d):nn.init.xavier_uniform_(m.weight.data)m.bias.data.zero_()elif isinstance(m, nn.Linear):m.weight.data.normal_(0, 0.01)m.bias.data.zero_()# cpu、gpumodel.to(device)# 打印模型结构print(model)# 超参数param_dict = dict(model.named_parameters())params = []for key, value in param_dict.items():params += [{'params':[value], 'lr': 0.01, 'weight_decay':0.0}]# 损失函数criterion = nn.CrossEntropyLoss()# 优化器optimizer = optim.Adam(params, lr=0.01, weight_decay=0.0)# 训练模型for epoch in range(1, 300):adjust_learning_rate(optimizer, epoch)train(epoch)test()

pytorch实现bnn相关推荐

  1. BNN Pytorch代码阅读笔记

    BNN Pytorch代码阅读笔记 这篇博客来写一下我对BNN(二值化神经网络)pytorch代码的理解,我是第一次阅读项目代码,所以想仔细的自己写一遍,把细节理解透彻,希望也能帮到大家! 论文链接: ...

  2. BNN领域开山之作——不得错过的训练二值化神经网络的方法

    作者| cocoon 编辑| 3D视觉开发者社区 文章目录 导读 概述 方法 确定二值化以及随机二值化 梯度计算以及累加 离散化梯度传播 乘法运算优化 基于位移(shift)的BN 基于位移的AdaM ...

  3. PyTorch 深度学习模型压缩开源库(含量化、剪枝、轻量化结构、BN融合)

    点击我爱计算机视觉标星,更快获取CVML新技术 本文为52CV群友666dzy666投稿,介绍了他最近开源的PyTorch模型压缩库,该库开源不到20天已经收获 219 颗星,是最近值得关注的模型压缩 ...

  4. GNN PyTorch functions

    目录 PyTorch fundamental functions torch.nonzero()方法 torch.no_grad()函数 torch.state_dict()函数 torch.wher ...

  5. 深度学习pytorch常见编程技巧

    文章目录 一.画图.路径.csv.txt.导模块.类继承调用方法 写进日志log里面 pytorch ,可视化1,输出每一层的名字,输出shape,参数量 获取当前文件夹路径 python 获取当前目 ...

  6. 通过anaconda2安装python2.7和安装pytorch

    ①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...

  7. 记录一次简单、高效、无错误的linux上安装pytorch的过程

    1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...

  8. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  9. PyTorch代码调试利器_TorchSnooper

    GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...

  10. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

最新文章

  1. 分数阶累加的Python实现
  2. java 不能反序列化_不能将“Java.Lang.Studio”实例反序列化到StaskObl对象令牌中
  3. Linux 小知识翻译 - 「Linux的吉祥物企鹅叫什么名字?」
  4. android磁场传感器页面布局在哪,基于磁场检测的寻线小车传感器布局研究
  5. python获取天气数据_python获取天气数据
  6. elementUI 学习入门之 Button 按钮
  7. String写时拷贝实现
  8. postfix发送邮件报错:mail for xxxxx.com loops back to myself
  9. Java实现简单工厂模式
  10. Mockito + JUnit 单元测试实例
  11. 记一次带有FSG壳的熊猫烧香病毒分析过程
  12. 通信专业顶刊_通信类期刊排名_2016通信期刊排名_2016通信前沿新技术
  13. vb.net 教程 5-21 拓展 如何给IE浏览器截图
  14. maximo工作流底色更改
  15. iOS 下载器实现-ASDownload
  16. 13. 中国古代数学家张丘建在他的《算经》中提出了一个著名的“百钱百鸡问题”:一只公鸡值5钱,一只母鸡值3钱,三只小鸡值1钱,现在要用百钱买百鸡,请问公鸡、母鸡、小鸡各多少只?
  17. MySQL8.0 OCP最新版1Z0-908认证考试题库整理-004
  18. 怎么会是lucene?
  19. GPON OMCI简介
  20. GCN图卷积神经网络总结笔记

热门文章

  1. 嵌入式系统开发笔记88:认识51微控制器系统架构
  2. win7计算机属性资源管理器停止工作,Win7系统Windows资源管理器已停止工作怎么解决?...
  3. 一张专家推荐的最健康的作息时间表
  4. svg中marker元素的理解
  5. MySQL添加删除账户及授予权限
  6. C -- OC with RunTime
  7. 计算机语言编码常用英文,MQ4语言编程 EA常用英文词汇
  8. html语言制作折线图,html5绘制折线图
  9. 网站盈利模式分析分类
  10. BF算法(暴力算法)