介绍

AutoML这个topic在机器学习领域越来越火,新的研究成果也是层出不穷。在网络架构(NAS),模型压缩(AMC),数据增强(AutoAugment),优化器设计(Neural Optimizer Search),平台相关优化(AutoTVM)等领域,我们都可以看到相应的研究成果表明机器学习可以达到比人肉调参更优的结果。自动化方法正在逐步替代调参工。相信不久的将来,我们面对一个场景,只要喂数据,其他的由机器一条龙完成,而且还能比人肉优化出来人牛X。它们本质上都是大规模空间中的最优参数搜索问题,因此很多方法基本都是相通的。其中网络架构的搜索(Neural Architecture Search, NAS)一直是业界主要研究对象,各种云服务也开始推出相关的服务。一开始的时候,这项技术需要依赖巨大的计算成本。随着各种改进,它的算力需求大大减少,使得它开始有可能飞入寻常百姓家。之前写过一篇简单讨论了相关的技术《神经网络架构搜索(Neural Architecture Search)杂谈》,这里就不重复了。本文主要学习一下Auto-Keras这个开源项目中的实现,使用它我们可以在本地进行网络架构搜索。官方有一篇相关的paper介绍-Auto-Keras: An Efficient Neural Architecture Search System。

安装过程很简单,详见官方介绍。图方便的话直接用pip安装即可:

pip install autokeras

比较核心的搜索过程如下示意图。首先生成几个初始网络架构作为种子并训练得到准确率指标,然后不断通过network morphism产生新的网络架构,这样就形成一个树形的搜索结构。搜索过程中会试图找最有『潜力』(用acquisition function衡量)的网络架构,这个架构会通过训练得到其准确率。这些挑选出来的架构及其准确率会用来拟合高斯过程模型,这个模型又会帮助下一轮树型结构中搜索时预测生成网络架构的准确率指标。

流程

下面沿着官方自带的例子mnist.py来看一下大体实现过程。

    clf = ImageClassifier(verbose=True)clf.fit(x_train, y_train, time_limit=12 * 60 * 60)clf.final_fit(x_train, y_train, x_test, y_test, retrain=True)y = clf.evaluate(x_test, y_test)

因为是图片分类任务,这里先创建ImageClassifier类,接着fit()函数进行网络架构搜索,final_fit()函数对搜索出来的最优模型训练,最后evaluate()函数进行模型评估。其中涉及到的几个主要类的关系:

ImageClassifier继承自ImageSupervised类(实现位于image_supervised.py文件)。它用来在图片分类任务中针对数据集搜索最优的CNN架构。ImageSupervised本身的构造函数很简单,它会调用其父类DeepTaskSupervised(实现位于supervised.py)的构造函数。父类构造函数中会检测是否要继续上次的任务。如果是,则从文件中把分类器和CNN模块恢复回来,否则就新创建CnnModule对象(它的实现在net_module.py文件)。CnnModule构造时会往generators数组里放入三个生成器CnnGenerator, ResNetGeneratorDenseNetGenerator,分别用于生成普通CNN,ResNet和DenseNet。

接下来大头是fit()函数。该函数先计算数据集中size的中值,再把数据都统一resize成这个尺寸。然后就是调用父类DeepTaskSupervisedfit()函数(位于supervised.py文件)。这里边先是准备数据集(切分验证集,数据转换等),最后调用CnnModulefit()函数。它本质上是调用其基类NetworkModulefit()函数。这个函数首次会根据参数创建相应的Searcher对象(默认是BayesianSearcher)。然后进入循环,每次迭代中调用BayesianSearchersearch()函数。在BayesianSearchersearch()函数中:

  1. 首次会调用init_search()函数初始化。这里会用到之前的三个网络生成器,分别对于数据集的类别数量和输入shape实例化,并通过它们的generate()函数生成相应的网络。网络计算图用Graph类表示。生成的网络加入到training_queue中等待训练和评估。这个队列中的元素为三元组(graph, other_info, model_id)。对于BayesianSearcher来说other_info就是其父节点id。
  2. training_queue中弹出一个元素。这个元素代表一个网络结构。调用mp_search()函数(非Colab环境)进行训练和网络的搜索。

mp_search()函数主要有这么几步:

  1. 这个函数首先会新建一个进程来作模型的训练(通过train()函数)和评估。这主要是由ModelTrainertrain_model()函数(实现在model_trainer.py文件)完成。由于这里只是为了评估网络架构,没必要训练得很充分,所以会采用early stop节约资源,默认如果5次epoch loss不再减少则提前退出训练。
  2. 同时,在当前进程调用_search_common()函数进行网络结构的搜索。这个函数中当发现当前已没有待训练模型时,会首先调用generate()函数生成下一个网络结构。这里就用到BayesianOptimizer了,它会涉及到三个关键数据结构:1) IncrementalGaussianProcess是高斯过程模型用来对网络的指标进行建模和预测;2) SearchTree为搜索树,用来组织搜索过的历史网络结构,A通过network morphism生成B,则A为B的父结点; 3) PriorityQueue将网络结构按其指标进行排序,它用于指导优先从哪些网络结构开始拓展。回到generate()函数。它首先将历史中的网络模型及指标拿出来进行排序,并放入PriorityQueue。。接着进入循环。每次迭代中,从优先级队列中取元素(即指标最好),并通过模拟退火算法确定是否要对该元素对应的网络进行变形。模拟退火是为了考虑exploration。模拟退火中的acceptance function中的energe function这里也是acquisition function,它某种程度上指示搜索过程中对结点的偏好。因为我们对网络指标的估计不是准确的,还得考虑估计的不确定性,因此这里采用的是UCB。它的计算不仅需要指标预测的均值,还需要方差,这些信息是由高斯过程模型给出的。如果某结点被选定为“潜力股”,那会调用transform()函数(实现在net_transformer.py文件)通过network morphism对其进行扩展。通过变形形成的新网络如果不是重复的,则也会被加入PriorityQueue进行下一轮拓展。这个循环结束后,变量opt_acqtarget_graph、和father_id分别记录了最优的acquisition function值,对应的图结构和其父结点ID。这个生成的图在_search_common()中会加入到training_queue中待训练。注意这个生成的图中的权重是可以和父结点重用的(通过load_modoel_by_id加载),这也是network morphism的最大好处。
  3. 从训练的进程拿评估后的指标。如果指标有效,先调用add_model()函数记录相关信息,如将网络图dump到文件,记录到历史,记录最好模型ID等。然后调用update()函数。该函数调用BayesianOptimizerfit()add_child()函数。前者主要用于更新高斯过程模型参数;后者主要将该网络结构加入到搜索树当中,以便以后评估是否要对其进行扩展。

以上就是网络搜索的主要过程,大致流程图如下。

例子中,后面分别调用了final_fit()evaluate()函数进行最优模型的训练和评估。final_fit()函数对前面搜索到的最优网络架构进行权重的训练。在搜索过程中的训练只是为了在候选网络间对比,因此训练得并不充分。这里会将max_no_improvement_num设为30,即30次epoch loss没有减少才退出。最后evaluate()函数将前一步训练好的模型拿来在测试集作预测并得出最终准确率。

实验

MNIST因为数据集太简单,初始时就很高了,貌似不是很能体现出网络架构搜索的价值。我们拿稍微复杂一些的CIFAR10来做下实验。经过在我的低端GPU上搜索约20个小时,可以看到,验证集上准确率从一开始的65%左右到90%以上。最优模型出现在第38轮迭代,验证集准确率为93.16%。基于该模型经过充分训练后测试集上准确率为95.56%。

这里搜索到的最优模型的网络架构如下图。自动搜索出来的网络架构看起来是比较诡异一些。。。

单机玩转神经网络架构搜索(NAS) - Auto-Keras学习笔记相关推荐

  1. 【CVPR 2020】神经网络架构搜索(NAS)论文和代码汇总

    关注上方"深度学习技术前沿",选择"星标公众号", 技术干货,第一时间送达! [导读]今天给大家整理了CVPR2020录用的几篇神经网络架构搜索方面的论文,神经 ...

  2. 自动化机器学习(三)神经网络架构搜索综述(NAS)简述

    文章目录 技术介绍 简介 技术栈 实现 数据 数据读取 创建模型并训练 模型预测与评估 模型的导出 技术介绍 简介 自动化机器学习就是能够自动建立机器学习模型的方法,其主要包含三个方面:方面一,超参数 ...

  3. 神经网络架构搜索(NAS)综述 | 附AutoML资料推荐

    本文是一篇神经网络架构搜索综述文章,从 Search Space.Search Strategy.Performance Estimation Strategy 三个方面对架构搜索的工作进行了综述,几 ...

  4. ICCV NAS Workshop 最佳论文提名:通过层级掩码实现高效神经网络架构搜索

    点击我爱计算机视觉标星,更快获取CVML新技术 机器之心发布 作者:Faen Zhang.Mi Zhang等 本文介绍了由创新奇智公司联合密歇根州立大学合作开发的高效神经网络架构搜索算法 HM-NAS ...

  5. 神经网络架构搜索(NAS)综述

    在阅读近期的CVPR2019时,看到一篇比较亮眼的图像分割论文.来自斯坦福 Li Fei-Fei组(Auto-deeplab),关于利用NAS策略进行图像分割,达到了较优的水平,仅仅比deeplabv ...

  6. 神经网络架构搜索(NAS)基础

    网络架构搜索(NAS)已成为机器学习领域的热门课题.商业服务(如谷歌的AutoML)和开源库(如Auto-Keras[1])使NAS可用于更广泛的机器学习环境.在这篇博客文章中,我们主要探讨NAS的思 ...

  7. 论文解读 Search to Distill: Pearls are Everywhere but not the Eyes,神经网络架构搜索+知识蒸馏

    目录 Search to Distill: Pearls are Everywhere but not the Eyes Motivation Method Experiments 结论 Search ...

  8. 卷积神经网络原理_怎样设计最优的卷积神经网络架构?| NAS原理剖析

    虽然,深度学习在近几年发展迅速.但是,关于如何才能设计出最优的卷积神经网络架构这个问题仍在处于探索阶段. 其中一大部分原因是因为当前那些取得成功的神经网络的架构设计原理仍然是一个黑盒.虽然我们有着关于 ...

  9. 神经架构搜索(NAS)2020最新综述:挑战与解决方案

    终于把这篇NAS最新的综述整理的survey放了上来,文件比较大,内容比较多.这个NAS的survey是A Comprehensive Survey of Neural Architecture Se ...

最新文章

  1. EF 查看生成的SQL语句
  2. python websocket实现消息推送_Python Websocket消息推送---GoEasy
  3. Apache Hive on Apache Tez
  4. android怎么注释代码块,Android.mk 代码注释
  5. 菜鸟的学习之路(12) —HashSet类详解
  6. 南京师范大学2021年硕士研究生入学考试高等代数试卷及参考答案
  7. 水果销售管理系统课程设计报告
  8. 混合云存储阵列与云存储网关的协同解决方案
  9. SAP MM 发货到成本中心场景下的批次确定
  10. SmartToast
  11. P3426 [POI2005]SZA-Template
  12. win10下yolov3训练自己的数据集
  13. VUE项目报错Error Cannot find module ‘webpacklibRuleSet‘_解决
  14. 质量流量计在油品计量中的应用
  15. 360n4手机可以装linux,360手机N4 root教程_360手机N4获取root权限的方法
  16. Django大咖之路: 如何对付学习Django过程中所遇到的挫败感?
  17. 破解分布式数据库全局死锁难题 GBase 8c引领数据库领域变革
  18. 蓝天采集器winds系统页面渲染设置教程
  19. Android小程序之自动发送短信
  20. CSS3新选择器,盒子模型,过渡动画transition,2D转换transform

热门文章

  1. MATLAB 平面线形变换 及验证多个点是否在同一直线
  2. 集成支付宝支付出现{resultStatus=4000, result=, memo=系统繁忙,请稍后再试}
  3. POJ-1236(有向图强连通分量 + 缩点 + 加边使得整个图强连通)
  4. 连接mongodb提示目标计算机拒绝,MongoDB 由于目标计算机积极拒绝,无法连接 2014-07-25T11:00:48.634+0...
  5. Seata源码走读分析
  6. Spring Boot 入门实战教程
  7. 华为RS 5.IP编址之VLSM
  8. 读《任正非在2012实验室的讲话》总结
  9. 金蝶k3显示加层服务器失败,金蝶k3提示:连接中间加密服务失败,请确认中间层加密服务已启动...
  10. Educational Codeforces Round 81 (Rated for Div. 2)