点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达

来源丨AI科技评论

编辑丨极市平台

导读

GitHub上最新开源的一个基于强化学习的自动化剪枝模型,本模型在图像识别的实验证明了能够有效减少计算量,同时还能提高模型的精度。

今天为大家介绍一个GitHub上最新开源的一个基于强化学习的自动化剪枝模型,本模型在图像识别的实验证明了能够有效减少计算量,同时还能提高模型的精度。

项目地址:

https://github.com/freefuiiismyname/cv-automatic-pruning-transformer

1

介绍

目前的强化学习工作很多集中在利用外部环境的反馈训练agent,忽略了模型本身就是一种能够获得反馈的环境。本项目的核心思想是:将模型视为环境,构建附生于模型的 agent ,以辅助模型进一步拟合真实样本。

大多数领域的模型都可以采用这种方式来优化,如cv/多模态等。它至少能够以三种方式工作:

1.过滤噪音信息,如删减语音或图像特征;

2.进一步丰富表征信息,如高效引用外部信息;

3.实现记忆、联想、推理等复杂工作,如构建重要信息的记忆池。

这里推出一款早期完成的裁剪机制transformer版本(后面称为APT),实现了一种更高效的训练模式,能够优化模型指标;此外,可以使用动态图丢弃大量的不必要单元,在指标基本不变的情况下,大幅降低计算量。

该项目希望为大家抛砖引玉。

2

为什么要做自动剪枝

在具体任务中,往往存在大量毫无价值的信息和过渡性信息,有时不但对任务无益,还会成为噪声。比如:表述会存在冗余/无关片段以及过渡性信息;动物图像识别中,有时候背景无益于辨别动物主体,即使是动物部分图像,也仅有小部分是关键的特征。

以transformer为例,在进行self-attention计算时其复杂度与序列长度平方成正比。长度为10,复杂度为100;长度为9,复杂度为81。

利用强化学习构建agent,能够精准且自动化地动态裁剪已丧失意义部分,甚至能将长序列信息压缩到50-100之内(实验中有从500+的序列长度压缩到个位数的示例),以大幅减少计算量。

实验中,发现与裁剪agent联合训练的模型比普通方法训练的模型效果要更好。

3

模型介绍及实验

模型主体

基于transformer的视觉预训练模型ViT是本项目的模型主体,具体细节可以查看论文:《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》

自动化裁剪的智能体

对于强化学习agent来说,最关键的问题之一是如何衡量动作带来的反馈。为了评估单次动作所带来的影响,使用了以下三步骤:

1、使用一个普通模型(无裁剪模块)进行预测;

2、使用一个带裁剪器的模型(执行一次裁剪动作)进行预测;

3、对比两次预测的结果,若裁剪后损失相对更小,则说明该裁剪动作帮助了模型进一步拟合真实状况,应该得到奖励;反之,应该受到惩罚。

但是在实际预测过程中,模型是同时裁剪多个单元的,这或将因为多个裁剪的连锁反应而导致模型失效。训练过程中需要构建一个带裁剪器的模型(可执行多次裁剪动作),以减小该问题所带来的影响。

综上,本模型使用的是三通道模式进行训练。

关于裁剪器的模型结构设计,本模型中认为如何衡量一个信息单元是否对模型有意义,建立于其自身的信息及它与任务的相关性上。

因此以信息单元本身及它与CLS单元的交互作为agent的输入信息。

实验

数据集

ViT

APT(pruning)

APT(no pruning)

CIFAR-100

92.3

92.6

93.03

CIFAR-10

99.08

98.93

98.92

以上加载的均为ViT-B_16,resolution为224*224。

4

使用说明

环境

下载经过预先训练的模型(来自Google官方)

本项目使用的型号:ViT-B_16(您也可以选择其它型号进行测试)

训练与推理

下载好预训练模型就可以跑了。

CIFAR-10和CIFAR-100会自动下载和培训。如果使用其他数据集,您需要自定义data_utils.py。

在裁剪模式的推理过程中,预期您将看到如下格式的输出。

默认的batch size为72、gradient_accumulation_steps为3。当GPU内存不足时,您可以通过它们来进行训练。

注:相较于原始的ViT,APT(Automatic pruning transformer)的训练步数、训练耗时都会上升。原因是使用pruning agent的模型由于总会丢失部分信息,使得收敛速度变慢,同时为了训练pruning agent,也需要多次的观测、行动、反馈。

致谢

感谢基于pytorch的图像分类项目(https://github.com/jeonsworld/ViT-pytorch),本项目是在此基础上做的研发。

最后再附上一次项目地址,欢迎感兴趣的读者Star✨

https://github.com/freefuiiismyname/cv-automatic-pruning-transformer

如果觉得有用,就请分享到朋友圈吧!

点个在看 paper不断!

基于强化学习的自动化剪枝模型相关推荐

  1. 人工智能AI实战100讲(五)-基于强化学习的自动化剪枝模型

    1介绍 文中涉及代码请参见: 人工智能AI-图像处理cv-基于强化学习的自动化裁剪 目前的强化学习工作很多集中在利用外部环境的反馈训练agent,忽略了模型本身就是一种能够获得反馈的环境.本项目的核心 ...

  2. 应用实践 | 电商应用——一种基于强化学习的特定规则学习模型

    本文转载自公众号:浙大KG. 作者:汪寒,浙江大学硕士,主要研究方向为知识图谱和自然语言处理. 应用场景 在电商实际应用中,每个商品都会被挂载到若干个场景,以图结构中的节点形式存在.商品由结构化信息表 ...

  3. GitHub|基于强化学习自动化剪枝

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:来源丨AI科技评论 编辑丨极市平台 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打 ...

  4. [论文]基于强化学习的无模型水下机器人深度控制

    基于强化学习的无模型水下机器人深度控制 摘要 介绍 问题公式 A.水下机器人的坐标框架 B.深度控制问题 马尔科夫模型 A.马尔科夫决策 B.恒定深度控制MDP C.弯曲深度控制MDP D.海底追踪的 ...

  5. 智能城市dqn算法交通信号灯调度_博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型...

    原标题:博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型 国际数据挖掘领域的顶级会议 KDD 2018 在伦敦举行,今年 KDD 吸引了全球范围内共 1480 篇论文投递,共 ...

  6. 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型

    国际数据挖掘领域的顶级会议 KDD 2018 在伦敦举行,今年 KDD 吸引了全球范围内共 1480 篇论文投递,共收录 293 篇,录取率不足 20%.其中滴滴共有四篇论文入选 KDD 2018,涵 ...

  7. 基于强化学习的自我完善聊天机器人

    Elena Ricciardelli, Debmalya Biswas 埃琳娜·里恰德利(Elena Ricciardelli) Abstract. We present a Reinforcemen ...

  8. 基于强化学习的质量AI在淘系互动业务的实践之路

    导读:AI人工智能的概念由来已久,因为alphago在围棋领域击败李世石掀起了全世界范围内的AI热潮,最近又随着DeepMind破解蛋白质折叠难题这一诺奖级成果再次让我们发现AI已经进化到了如此强大的 ...

  9. 华为诺亚ICLR 2020满分论文:基于强化学习的因果发现算法

    2019-12-30 13:04:12 人工智能顶会 ICLR 2020 将于明年 4 月 26 日于埃塞俄比亚首都亚的斯亚贝巴举行,不久之前,大会官方公布论文接收结果:在最终提交的 2594 篇论文 ...

最新文章

  1. json vue 对象转数组_vue 基础入门(一)修改
  2. 用C语言编写贪吃蛇项目描述,刚学C语言,想写一个贪吃蛇的代码
  3. 【VC++技术杂谈005】如何与程控仪器通过GPIB接口进行通信
  4. ajax获取get请求,get请求
  5. 客户端级别的渲染分析工具 dynaTrace
  6. CSS 定位 (Positioning) 实例
  7. python opencv 界面按钮_如何使用Python构建简单的UI?
  8. 深信服桌面云取消聚合口后的影响
  9. 单片机c语言1小时视频教程,1小时学会C语言51单片机C语言入门教程.doc
  10. 【UCSC Genome Browser】- ClinGen剂量敏感性分析
  11. 3种方法设置和取消Excel文件的打开密码
  12. OSChina 周日乱弹 ——愿你在天堂也能写代码
  13. 计算机图像相关应用研究,计算机图像处理技术的应用探讨.pdf
  14. 支持5G WIFI的串口服务器
  15. 自用备份 Unity 获取 两个点的中心点
  16. Navicat 被投毒了 | 调查结果来了
  17. 微信小程序|考试系统|基于微信小程序和SpringBoot+VUE的智能在线考试系统毕业设计
  18. 认识字符集、ASCII、GBK、Unicode、UTF-8
  19. 共射极放大电路静态工作点自动调整分析
  20. 安装 VMware tools时报错:不在 sudoers 文件中。此事将被报告。

热门文章

  1. 40个出色的Wordpress cms插件
  2. 5.1软件升级的小阳春
  3. 【CTF】实验吧 传统知识+古典密码
  4. centos 默认mysql_centos改变mysql默认目录
  5. 内含福利|CSDN 携手字节跳动:云原生Meetup北京站报名热烈启动,1月8日见!
  6. 关于Transformer,那些的你不知道的事
  7. 阿里再次主办大数据世界杯, KDD Cup2020正式开赛
  8. 小团队如何玩转物联网开发?
  9. 200行代码解读TDEngine背后的定时器
  10. 李开复口中的“联邦学习” 到底是什么?| 技术头条