现在深度学习中一般我们学习的参数都是连续的,因为这样在反向传播的时候才可以对梯度进行更新。但是有的时候我们也会遇到参数是离>散的情况,这样就没有办法进行反向传播了,比如二值神经网络。本文中讲解了如何用pytorch对二值化的参数进行梯度更新的straight-through estimator算法。
Question:
STE核心的思想就是我们的参数初始化的时候就是float这样的连续值,当我们forward的时候就将原来的连续的参数映射到{-1, 1}带入到网络进行计算,这样就可以计算网络的输出。然后backward的时候直接对原来float的参数进行更新,而不是对二值化的参数更新。这样可以完成对整个网络的更新了。
首先我们对上面问题进行一下数学的讲解。

Example:
首先我们验证一下使用torch.sign会是参数的梯度基本上都是0:

>>> input = torch.randn(4, requires_grad = True)
>>> output = torch.sign(input)
>>> loss = output.mean()
>>> loss.backward()
>>> input
tensor([-0.8673, -0.0299, -1.1434, -0.6172], requires_grad=True)
>>> input.grad
tensor([0., 0., 0., 0.])

我们需要重写sign这个函数,就好像写一个激活函数一样。

import torchclass LBSign(torch.autograd.Function):@staticmethoddef forward(ctx, input):return torch.sign(input)@staticmethoddef backward(ctx, grad_output):return grad_output.clamp_(-1, 1)
import torch
from LBSign import LBSignif __name__ == '__main__':sign = LBSign.applyparams = torch.randn(4, requires_grad = True)                                                                           output = sign(params)loss = output.mean()loss.backward()

测试梯度:

>>> params
tensor([-0.9143,  0.8993, -1.1235, -0.7928], requires_grad=True)
>>> params.grad
tensor([0.2500, 0.2500, 0.2500, 0.2500])

文章转载:https://segmentfault.com/a/1190000020993594?utm_source=tag-newest仅供参考学习,如有侵权则请联系博主。

参考文献:

  • https://segmentfault.com/a/1190000020993594?utm_source=tag-newest

pytorch实现straight-through estimator(STE)相关推荐

  1. 开源项目|基于darknet实现量化感知训练,已实现yolov3-tiny所有算子

    ◎本文为极市开发者「ArtyZe」原创投稿,转载请注明来源. ◎极市「项目推荐」专栏,帮助开发者们推广分享自己的最新工作,欢迎大家投稿.联系极市小编(fengcall19)即可投稿~ 量化简介 在实际 ...

  2. QAT(Quantization Aware Training)量化感知训练(二)【详解】

    文章目录 1.QAT(Quantization Aware Training)的建议 1.QAT(Quantization Aware Training)的建议 Quantization Aware ...

  3. 性能不打折,内存占用减少90%,Facebook提出极致模型压缩方法Quant-Noise

    对于动辄上百 M 大小的神经网络来说,模型压缩能够减少它们的内存占用.通信带宽和计算复杂度等,以便更好地进行应用部署.最近,来自 Facebook AI 的研究者提出了一种新的模型量化压缩技术 Qua ...

  4. java list 占用内存不释放_性能不打折,内存占用减少90%,Facebook提出极致模型压缩方法Quant-Noise...

    对于动辄上百 M 大小的神经网络来说,模型压缩能够减少它们的内存占用.通信带宽和计算复杂度等,以便更好地进行应用部署.最近,来自 Facebook AI 的研究者提出了一种新的模型量化压缩技术 Qua ...

  5. 闲话模型压缩之量化(Quantization)篇

    1. 前言 这些年来,深度学习在众多领域亮眼的表现使其成为了如今机器学习的主流方向,但其巨大的计算量仍为人诟病.尤其是近几年,随着端设备算力增强,业界涌现出越来越多基于深度神经网络的智能应用.为了弥补 ...

  6. 收藏 | 一文总结70篇论文,帮你透彻理解神经网络的剪枝算法

    来源:DeepHub IMBA本文约9500字,建议阅读10+分钟 本文为你详细介绍神经网络剪枝结构.剪枝标准和剪枝方法. 无论是在计算机视觉.自然语言处理还是图像生成方面,深度神经网络目前表现出来的 ...

  7. 深度学习量化总结(PTQ、QAT)

    背景  目前神经网络在许多前沿领域的应用取得了较大进展,但经常会带来很高的计算成本,对内存带宽和算力要求高.另外降低神经网络的功率和时延在现代网络集成到边缘设备时也极其关键,在这些场景中模型推理具有严 ...

  8. 我总结了70篇论文的方法,帮你透彻理解神经网络的剪枝算法

    无论是在计算机视觉.自然语言处理还是图像生成方面,深度神经网络目前表现出来的性能都是最先进的.然而,它们在计算能力.内存或能源消耗方面的成本可能令人望而却步,这使得大部份公司的因为有限的硬件资源而完全 ...

  9. 初入神经网络剪枝量化4(大白话)

    二. 量化 简单介绍目前比较SOTA的量化方法,也是最近看的. 2.1 DSQ   Differentiable Soft Quantization:Bridging Full-Precision a ...

  10. 高糊图片可以做什么?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者:David Berthelot.Peyman Milanfar ...

最新文章

  1. CSDN 的文化衫寄送到啦
  2. Python PIL.Image和OpenCV图像格式相互转换
  3. 【转载】OI生涯结束……在逸夫楼那些的日子里
  4. LeetCode 1403. 非递增顺序的最小子序列(排序)
  5. nagios监控配置错误汇总
  6. jQuery 鼠标事件
  7. matlab程序设计课件,《MATLAB程序设计》PPT课件.ppt
  8. ads1110程序实测好用
  9. 【每晚20点红包雨】2018天猫聚划算99大促欢聚盛典活动介绍
  10. java 汇率换算_[java] 汇率换算器实现(2)
  11. 弗雷德里克·特曼:硅谷之父、斯坦福大学前副校长——(转自新浪网)
  12. 线程池ExecutorService
  13. 大数据Spark(三十九):SparkStreaming实战案例四 窗口函数
  14. Frenetic HelloSDNWorld
  15. 163电子邮箱注册登录入口是?企业邮箱和163邮箱有什么区别?
  16. 全球及中国粉煤灰PFA行业行业发展动态与前景趋势预测报告2022-2028年
  17. 计算机网络:HTTP相关知识
  18. 关于CC2541蓝牙开发板的学习笔记-1
  19. 人参产业,炒作者与投资者之辩
  20. 微信小程序与H5的区别?

热门文章

  1. ios动态效果实现翻页_iOS实现翻页效果动画
  2. Activities介绍
  3. 做一个业务中台你到底会踩多少坑?
  4. 用python实现监听微信撤回消息
  5. linux常用命令-part3
  6. LSTM预测多支股票的收盘价
  7. UWB基本原理分析2
  8. manjaro配置输入法
  9. pod中mysql配置文件修改_Pod中的secret,configmap,downwardapi的使用记录
  10. 挣脱注意力经济:为什么应该练习数字极简主义?