由于MAML作者提供的源码比较混乱,而且是由tensorflow写成。所以我写了一篇用Pytorch复现MAML的博客:MAML模型无关的元学习代码完整复现(Pytorch版)。那篇博客中的复现细节已经很详尽了,但是在omniglot数据集上的准确率只有0.92,考虑到omniglot算是比较简单的数据集了,因此0.92的准确率实在是太低了。

因此,我后来又对模型和数据的读取方法进行了一些调整,最近的实验表明在5-way-1-shot任务上,我的复现准确率已经达到了0.972,算是基本匹配上了作者在论文中给出的准确率区间。

在这篇文章中,我将总结一下我复现MAML时的一些经验和教训以及对原来代码的更改。

1 数据读取方式

我之前的数据读取方式是将omniglot中images_backgroud和images_evaluation这两个文件夹中的数据一次性读取出来,然后再对数据集进行划分。

img_list = np.load(os.path.join(root_dir, 'omniglot.npy')) # (1623, 20, 1, 28, 28)
x_train = img_list[:1200]
x_test = img_list[1200:]

这一次我使用通用的数据划分方法,即:images_backgroud中的数据作为训练数据,images_evaluation中的数据作为测试数据。

img_list_train = np.load(os.path.join(root_dir, 'omniglot_train.npy')) # (964, 20, 1, 28, 28)
img_list_test = np.load(os.path.join(root_dir, 'omniglot_test.npy')) # (659, 20, 1, 28, 28)x_train = img_list_train
x_test = img_list_test

具体代码见我的github。

2 模型构造

原来的模型卷积层的padding为2,stride也为2;我将它们修改为1之后,实验结果直接从0.92提升到了0.975。由此可见模型架构的微小调整也会严重影响模型的性能。大家平时在做实验时应该注意一下。

原来的模型架构为:

#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), #             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 2, stride = 2),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), #             FlattenLayer(),
#             nn.Linear(64,5)
#         )

修改后的模型架构为:

#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2),#             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), #             nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = (3,3), padding = 1, stride = 1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(),
#             nn.MaxPool2d(2), #             FlattenLayer(),
#             nn.Linear(64,5)
#         )

3 降低对计算资源的要求

在进行20-way-1-shot的实验时,发现用原来的代码将会消耗大量的资源。我修改了一下原来的代码,在不需要记录梯度的位置加上"with torch.no_grad()",从而将计算资源的需求降到了原来的1/5.

原来的代码为:

            for k in range(1, self.update_step):y_hat = self.net(x_spt[i], params = fast_weights, bn_training=True)loss = F.cross_entropy(y_hat, y_spt[i])grad = torch.autograd.grad(loss, fast_weights)tuples = zip(grad, fast_weights) fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)loss_qry = F.cross_entropy(y_hat, y_qry[i])loss_list_qry[k+1] += loss_qrywith torch.no_grad():pred_qry = F.softmax(y_hat,dim=1).argmax(dim=1)correct = torch.eq(pred_qry, y_qry[i]).sum().item()correct_list[k+1] += correct

修改后的代码为:

            for k in range(1, self.update_step):y_hat = self.net(x_spt[i], params = fast_weights, bn_training=True)loss = F.cross_entropy(y_hat, y_spt[i])grad = torch.autograd.grad(loss, fast_weights)tuples = zip(grad, fast_weights) fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))if k < self.update_step - 1:with torch.no_grad():        y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)loss_qry = F.cross_entropy(y_hat, y_qry[i])loss_list_qry[k+1] += loss_qryelse:y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)loss_qry = F.cross_entropy(y_hat, y_qry[i])loss_list_qry[k+1] += loss_qry                    with torch.no_grad():pred_qry = F.softmax(y_hat,dim=1).argmax(dim=1)correct = torch.eq(pred_qry, y_qry[i]).sum().item()correct_list[k+1] += correct

4 关于20-way-1-shot实验

2020/5/10更新:

Reptile这篇论文中说,MAML的实验使用到了transductive Learning的实验设定。关于transductive Learning你可以理解成MAML作者汇报的是训练中query集的结果,而不是我们通常意义的测试集中query集的结果。
这个图表来自Reptile的那篇论文。

以下是原文:


我在复现这个实验的过程中,在测试集的query集中的最好结果也只有0.843。但是作者宣称她取得了0.95的实验结果,但是作者的源码中并没有给出20-way-1-shot的实验结果或者logs。

我找到了另一个网友(github账号名:katerkelly)的复现代码,这个人宣称他复现出来的结果是0.92。

20-way 1-shot training, best performance 92%

但是我实际运行以及查看了他的代码后发现,他报告的其实是训练集中query集的结果,而不是测试集中query集的结果。我们都知道在元学习中有support集和query集两者集合,其中:

  • 训练集:分为support集和query集,其中support集用于训练,query集用于更新参数。
  • 测试集:分为support集和query集,其中support集用于fine-tune,query集用于评估元学习模型的效果。

而那位网友报告的是训练集中support集的结果,真正的实验结果应该是测试集中support集的实验结果,也就是0.83。

你可以查看那位网友给出的实验结果展示图(下图)。中间那条橙黄色的线是0.92左右,那位网友报告的也是橙黄色这条线的结果,但是实际的实验结果应该是下面这条红色的线。也就是0.83左右,跟我得出的实验结果比较吻合。

有意思的是,MAML作者声称她的实验结果实0.95,而我自己复现的结果中,在测试集的support集上的结果也是0.95-0.96。为了跑出0.9以上的实验结果,我已经做了好几天的实验了,模型架构和超参数改动了几十次,最好的结果还是只有0.843。如果哪位网友能够复现出0.9以上的实验结果,麻烦告诉我一下。

5 实验数据

以下展示在60000轮epoch中,query集的测试集中出现的最好结果:

  1. 20 way 1 shot 4 batch meta_lr = 0.0002, base_lr = 0.1 : acc: 0.84

  2. 20 way 1 shot 8 batch meta_lr = 0.0001, base_lr = 0.1 : acc: 0.835

  3. 20 way 1 shot 8 batch meta_lr = 0.0001, base_lr = 0.1 : acc: 0.843

  4. 20 way 1 shot 8 batch meta_lr = 0.0005, base_lr = 0.3 : acc: 0.79

  5. 20 way 1 shot 8 batch meta_lr = 0.001, base_lr = 0.1 : acc: 0.82

  6. 20 way 1 shot 8 batch meta_lr = 0.001, base_lr = 0.2 : acc: 0.785

  7. 5 way 1 shot 4 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.96

  8. 5 way 1 shot 8 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.972

  9. 5 way 1 shot 16 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.969

  10. 5 way 1 shot 32 batch 10 range meta_lr = 0.001, base_lr = 0.1 : acc: 0.975

自己想要复现的朋友,可以参考一下我的实验结果,免得继续做无用功。

6 关于我自己的源码

你可以在我的github上找到我的全部代码(miguealanmath)。喜欢的朋友可以点下小星星。

MAML复现全部细节和经验教训(Pytorch)相关推荐

  1. 论文返修与校对的经验教训

    论文返修与校对的经验教训 科技论文(Decoupling the dynamics of bacterial taxonomy and antibiotic resistance function i ...

  2. 下拉多选择框 实现方式_物体检测之旅(三)|设计选择,经验教训和物体检测的趋势...

    作者:Jonathan Hui编译:ronghuaiyang 物体检测器,像基于区域的检测或者一阶段的检测器,从不同的起点起步,最后越来越相似,都是朝着更快更准的目的地在前进.事实上,有些性能的差距可 ...

  3. 程序编写经验教训_编写您永远都不会忘记的有效绩效评估的经验教训。

    程序编写经验教训 This article is intended for two audiences: people who need to write self-evaluations, and ...

  4. 像程序员一样思考_如何像程序员一样思考-解决问题的经验教训

    像程序员一样思考 by Richard Reis 理查德·里斯(Richard Reis) 如何像程序员一样思考-解决问题的经验教训 (How to think like a programmer - ...

  5. 2年工作经验进 初创公司_沟通是关键:通过两家初创公司获得的成长经验教训+找工作...

    2年工作经验进 初创公司 by Niki Agrawal 通过尼基·阿格劳瓦尔(Niki Agrawal) 沟通是关键:通过两家初创公司获得的成长经验教训+找工作 (Communication is ...

  6. 回顾:我们从2次主要API中断中汲取的经验教训

    by Cory Kennedy-Darby 通过科里·肯尼迪·达比 回顾:我们从2次主要API中断中汲取的经验教训 (Retrospective: lessons we learned from 2 ...

  7. 结对编程——经验教训总结

    结对编程之经验教训总结 "宝剑锋从磨砺出,梅花香自苦寒来." 整整做了一个星期,终于将结对编程项目做完了,多少心酸只有自己知道,多少成就感也只有自己知道.这是真正自己动手,从最初的 ...

  8. 项目经验教训总结(教育软件)

    今年主要做的项目已经验收通过,做了一些经验教训的总结,记录于此: 一.项目的描述     这个项目是一款院校的实验室软件,因牵涉到几个因素:政府拨款.新学院和新专业的设立.市场是的新产品,所以备受关注 ...

  9. 几点项目里的经验教训

      年后上班,正是项目新的一个开发阶段开始了.果然年后是人员流动的高峰期,有几个同事辞职了.为了工作的顺利衔接,又有新的人员开始补充进来.项目总是时间紧,任务重.第一批任务即将开始,已经没有退路了.回 ...

最新文章

  1. ICRA2021| 自动驾驶相关论文汇总
  2. Semaphore(信号量)
  3. graphviz画图
  4. 网页脚本基本java语法_JSP 基础语法
  5. 在 Mac 安装Docker
  6. pythonsqlite视图_SQLite VIEW/视图
  7. 重构(Refactoring)技巧读书笔记 之二
  8. mysql 数据库定时备份 增量/全备份
  9. 02-linux安装nodejs
  10. 连续一个月,每天只吃一个苹果,身体会怎么样?
  11. linux下tar gz bz2 tgz z等众多压缩文件的解压方法
  12. 使用threeJS根据点的坐标绘制曲线
  13. javascript面试题(一)
  14. python爬虫下载付费音乐包有什么用_听歌音乐还要我付费?看我用Python批量下载!...
  15. 用WPS Office下五子棋(转)
  16. 供应商关系管理系统SRM
  17. (很容易懂,你把代码复制粘贴即可解决问题)高等代数/线性代数-基于python实现矩阵法求解齐次方程组
  18. 罗马帝国 Ancient Rome 简易修改器
  19. 【2022感恩节活动营销理念】跨境电商卖家必知 !
  20. [题解] 洛谷 P3603 雪辉

热门文章

  1. Jedis实现抽奖功能
  2. 戴尔电脑遭香港高校联合抵制
  3. win10桌面图标和任务栏图标一直闪烁,就和刷新一样,怎么解决?
  4. 《Unity3D网络游戏实战》第7章
  5. matlab idft 二维,idft matlab
  6. Python爬虫实战 | (7) 爬取万方数据库文献摘要
  7. 云鲸扫拖一体机器人说明书_让做家务变的更简单:云鲸智能扫拖一体机器人测评...
  8. 穿越解密: Intel X86迎来小型机的春天
  9. SqlSugar 首次使用以及遇到的问题
  10. 在线客服系统源码开发实战总结:动态加载js文件实现粘贴一段js的sdk代码,直接引入插件效果...