点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:来源丨AI科技评论

编辑丨极市平台

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

今天为大家介绍一个GitHub上最新开源的一个基于强化学习的自动化剪枝模型,本模型在图像识别的实验证明了能够有效减少计算量,同时还能提高模型的精度。

项目地址:

https://github.com/freefuiiismyname/cv-automatic-pruning-transformer

1

介绍

目前的强化学习工作很多集中在利用外部环境的反馈训练agent,忽略了模型本身就是一种能够获得反馈的环境。本项目的核心思想是:将模型视为环境,构建附生于模型的 agent ,以辅助模型进一步拟合真实样本。

大多数领域的模型都可以采用这种方式来优化,如cv/多模态等。它至少能够以三种方式工作:

1.过滤噪音信息,如删减语音或图像特征;

2.进一步丰富表征信息,如高效引用外部信息;

3.实现记忆、联想、推理等复杂工作,如构建重要信息的记忆池。

这里推出一款早期完成的裁剪机制transformer版本(后面称为APT),实现了一种更高效的训练模式,能够优化模型指标;此外,可以使用动态图丢弃大量的不必要单元,在指标基本不变的情况下,大幅降低计算量。

该项目希望为大家抛砖引玉。

2

为什么要做自动剪枝

在具体任务中,往往存在大量毫无价值的信息和过渡性信息,有时不但对任务无益,还会成为噪声。比如:表述会存在冗余/无关片段以及过渡性信息;动物图像识别中,有时候背景无益于辨别动物主体,即使是动物部分图像,也仅有小部分是关键的特征。

以transformer为例,在进行self-attention计算时其复杂度与序列长度平方成正比。长度为10,复杂度为100;长度为9,复杂度为81。

利用强化学习构建agent,能够精准且自动化地动态裁剪已丧失意义部分,甚至能将长序列信息压缩到50-100之内(实验中有从500+的序列长度压缩到个位数的示例),以大幅减少计算量。

实验中,发现与裁剪agent联合训练的模型比普通方法训练的模型效果要更好。

3

模型介绍及实验

模型主体

基于transformer的视觉预训练模型ViT是本项目的模型主体,具体细节可以查看论文:《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》

自动化裁剪的智能体

对于强化学习agent来说,最关键的问题之一是如何衡量动作带来的反馈。为了评估单次动作所带来的影响,使用了以下三步骤:

1、使用一个普通模型(无裁剪模块)进行预测;

2、使用一个带裁剪器的模型(执行一次裁剪动作)进行预测;

3、对比两次预测的结果,若裁剪后损失相对更小,则说明该裁剪动作帮助了模型进一步拟合真实状况,应该得到奖励;反之,应该受到惩罚。

但是在实际预测过程中,模型是同时裁剪多个单元的,这或将因为多个裁剪的连锁反应而导致模型失效。训练过程中需要构建一个带裁剪器的模型(可执行多次裁剪动作),以减小该问题所带来的影响。

综上,本模型使用的是三通道模式进行训练。

关于裁剪器的模型结构设计,本模型中认为如何衡量一个信息单元是否对模型有意义,建立于其自身的信息及它与任务的相关性上。

因此以信息单元本身及它与CLS单元的交互作为agent的输入信息。

实验

数据集

ViT

APT(pruning)

APT(no pruning)

CIFAR-100

92.3

92.6

93.03

CIFAR-10

99.08

98.93

98.92

以上加载的均为ViT-B_16,resolution为224*224。

4

使用说明

环境

下载经过预先训练的模型(来自Google官方)

本项目使用的型号:ViT-B_16(您也可以选择其它型号进行测试)

训练与推理

下载好预训练模型就可以跑了。

CIFAR-10和CIFAR-100会自动下载和培训。如果使用其他数据集,您需要自定义data_utils.py。

在裁剪模式的推理过程中,预期您将看到如下格式的输出。

默认的batch size为72、gradient_accumulation_steps为3。当GPU内存不足时,您可以通过它们来进行训练。

注:相较于原始的ViT,APT(Automatic pruning transformer)的训练步数、训练耗时都会上升。原因是使用pruning agent的模型由于总会丢失部分信息,使得收敛速度变慢,同时为了训练pruning agent,也需要多次的观测、行动、反馈。

致谢

感谢基于pytorch的图像分类项目(https://github.com/jeonsworld/ViT-pytorch),本项目是在此基础上做的研发。

最后再附上一次项目地址,欢迎感兴趣的读者Star✨

https://github.com/freefuiiismyname/cv-automatic-pruning-transformer

-------------------

END

--------------------

我是王博Kings,985AI博士,华为云专家、CSDN博客专家(人工智能领域优质作者)。单个AI开源项目现在已经获得了2100+标星。现在在做AI相关内容,欢迎一起交流学习、生活各方面的问题,一起加油进步!

我们微信交流群涵盖以下方向(但并不局限于以下内容):人工智能,计算机视觉,自然语言处理,目标检测,语义分割,自动驾驶,GAN,强化学习,SLAM,人脸检测,最新算法,最新论文,OpenCV,TensorFlow,PyTorch,开源框架,学习方法...

这是我的私人微信,位置有限,一起进步!

王博的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点分享

点收藏

点点赞

点在看

GitHub|基于强化学习自动化剪枝相关推荐

  1. 基于强化学习的自动化剪枝模型

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 来源丨AI科技评论 编辑丨极市平台 导读 GitHub上最新开源的一 ...

  2. 【实践】基于强化学习的 Contextual Bandits 算法在推荐场景中的应用

    文章作者:杨梦月.张露露 内容来源:滴滴科技合作 出品平台:DataFunTalk 导读:本文是对滴滴 AI Labs 和中科院大学联合提出的 WWW 2020 Research Track 的 Or ...

  3. 基于强化学习的图像配准 - Image Registration: Reinforcement Learning Approaches

    配准定义 给定参考图像 I_f 和浮动图像 I_m ,所谓的配准就是寻找一个图像变换T,将浮动图像I_m变换到和 I_f 相同的坐标空间下,使得两个图像中对应的点处于同一坐标下,从而达到信息聚合的目的 ...

  4. [论文]基于强化学习的无模型水下机器人深度控制

    基于强化学习的无模型水下机器人深度控制 摘要 介绍 问题公式 A.水下机器人的坐标框架 B.深度控制问题 马尔科夫模型 A.马尔科夫决策 B.恒定深度控制MDP C.弯曲深度控制MDP D.海底追踪的 ...

  5. 基于强化学习的质量AI在淘系互动业务的实践之路

    导读:AI人工智能的概念由来已久,因为alphago在围棋领域击败李世石掀起了全世界范围内的AI热潮,最近又随着DeepMind破解蛋白质折叠难题这一诺奖级成果再次让我们发现AI已经进化到了如此强大的 ...

  6. 华为诺亚ICLR 2020满分论文:基于强化学习的因果发现算法

    2019-12-30 13:04:12 人工智能顶会 ICLR 2020 将于明年 4 月 26 日于埃塞俄比亚首都亚的斯亚贝巴举行,不久之前,大会官方公布论文接收结果:在最终提交的 2594 篇论文 ...

  7. 智能城市dqn算法交通信号灯调度_博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型...

    原标题:博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型 国际数据挖掘领域的顶级会议 KDD 2018 在伦敦举行,今年 KDD 吸引了全球范围内共 1480 篇论文投递,共 ...

  8. 今晚8点:基于强化学习的关系抽取和文本分类 | PhD Talk #18

    「PhD Talk」是 PaperWeekly 的学术直播间,旨在帮助更多的青年学者宣传其最新科研成果.我们一直认为,单向地输出知识并不是一个最好的方式,而有效地反馈和交流可能会让知识的传播更加有意义 ...

  9. 直播预告:基于强化学习的关系抽取和文本分类 | PhD Talk #18

    「PhD Talk」是 PaperWeekly 的学术直播间,旨在帮助更多的青年学者宣传其最新科研成果.我们一直认为,单向地输出知识并不是一个最好的方式,而有效地反馈和交流可能会让知识的传播更加有意义 ...

最新文章

  1. 你不可不知的9种Lisp语言思想
  2. 面试题整理 4 合并两个排序的数组
  3. 【推荐系统】基于知识图谱的推荐系统总结
  4. zipkin brave mysql_zipkin mysql表结构
  5. UE4笔记-UStructToJsonObjectString首字母自动转换为小写的问题及解决方法
  6. 详解nginx 代理多个服务器(多个server方式)
  7. java 分布式 定时任务_Java中实现分布式定时任务的方法
  8. PF_PACKET说开去
  9. YY一下淘宝商品模型
  10. tp-link与台式计算机连接教程,【详细图解】TP-Link TL-WDR6510路由器电脑设置教程...
  11. ffmpeg(六)视频缩放及像素格式转换
  12. oracle怎么建立物化视图,Oracle 建立物化视图步骤
  13. DVWA 不跳转_终于开通!小红书图文、直播可跳转淘宝链接!
  14. java interop,服务器程序的Xamarin-Java.Interop体验(一)
  15. 关于spoolsv.exe程序问题
  16. Unity Shader - ddx/ddy偏导函数测试,实现:锐化、高度图、Flat shading应用、高度生成法线
  17. 【C++】引用、内联函数、函数重载、函数默认参数(缺省参数)与占位参数、extern “C“ 浅析
  18. 如何混迹程序猿江湖,你得懂程序员黑话暗语!
  19. 踩坑-填坑之 : vue打包上线,页面无法显示
  20. LangChain vs Semantic Kernel

热门文章

  1. java 线程参数 用final,JAVA 关于final修饰变量参数
  2. 网站访客系统php,PHP实现网站访客来访显示访客IP浏览器操作系统
  3. python简单体育竞技模拟_python初体验 —— 模拟体育竞技
  4. Tomcat Manager服务启用
  5. CentOS单用户模式及进入后只读处理,开机修改为文字界面
  6. cassss服务未启动_Mysql无法启动情况下,如何恢复数据呢?
  7. 恢复出厂设置android手机号码,安卓手机怎么恢复出厂设置
  8. 大数据:技术与应用实践指南_大数据技术与应用社团 社会实践总结篇
  9. 华硕服务器安装完系统起不来,w10安装后启动不起来的具体处理办法【图文】
  10. Django-C003-视图