CVPR 2018 的一篇少样本学习论文

Learning to Compare: Relation Network for Few-Shot Learning

源码地址:https://github.com/floodsung/LearningToCompare_FSL

在自己的破笔记本上跑了下这个源码,windows 系统,pycharm + Anaconda3 + pytorch-cpu 1.0.1

报了一堆bug, 总结如下:

procs_images.py里 ‘cp’报错

用procs_images.py处理 miniImangenet 数据集的时候:

报错信息:
/LearningToCompare_FSL-master/datas/miniImagenet/proc_images.py
'cp' �����ڲ����ⲿ���Ҳ���ǿ����еij������������ļ���

具体位置是

/datas/miniImagenet/procs_images.py  Line 48:os.system('cp images/' + image_name + ' ' + cur_dir)

这个‘cp’是linux环境运行的。

用windows系统的话要改成:

os.rename('images/' + image_name, cur_dir + image_name)

除此之外,所有的 os.system('mkdir ' + filename)

也要改成 os.mkdir(filename),虽然不一定会报错。

cpu RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False.

我的torch版本是是cpu, 所以把所有 .cuda(GPU)删了,另外

使用torch.load时添加 ,map_location ='cpu'

以miniImagenet_train_few_shots.py 为例
Line 150:
feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
改成
feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location = 'cpu'
))
Line:153:
relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
改成
relation_network.load_state_dict(torch.load(str("./models/miniimagenet_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"),map_location = 'cpu'))

KeyError: '..\\datas\\omniglot_resized'

报错信息:File "LearningToCompare_FSL-master/omniglot/omniglot_train_few_shot.py", line 163, in maintask = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 72, in <listcomp>self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
KeyError: '..\\datas\\omniglot_resized'

关键的地方其实是在:

 task_generator.py, line 74:
  def get_class(self, sample):return os.path.join(*sample.split('/')[:-1])

print (os.path.join(*sample.split('/')[:-1])) 结果是

..\datas\omniglot_resized

而labels是

{'../datas/omniglot_resized/Malay_(Jawi_-_Arabic)\\character25': 0, '../datas/omniglot_resized/Japanese_(hiragana)\\character15': 1, '…}

而 print(os.path.join(*sample.split('\\')[:-1]))  结果正是

../datas/omniglot_resized/Malay_(Jawi_-_Arabic)\character25

解决方法:把'/'改成'\\'即可 def get_class(self, sample):return os.path.join(*sample.split('\\')[:-1]) 

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'

报错信息:File "/LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 193, in maintorch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1))
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'

解决方法:在前面加一句

 batch_labels = batch_labels.long()

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'

报错信息:  File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_test_few_shot.py", line 247, in <listcomp>rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(batch_size)]
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'

解决方法:在前面加上

predict_labels = predict_labels.long()
test_labels = test_labels.long()

这两个好像是使用torch的数据格式问题

IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

报错信息:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in mainprint("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number按要求改成
print("episode:", episode + 1, "loss", loss.item())
就可以了

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

报错信息:
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 107, in __getitem__image = self.transform(image)File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__img = t(img)File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 163, in __call__return F.normalize(tensor, self.mean, self.std, self.inplace)File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\functional.py", line 208, in normalizetensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

这个是使用Omniglot数据集时的报错,主要原因在于

"\omniglot\task_generator.py", line 139:def get_data_loader(task, num_per_class=1, split='train',shuffle=True,rotation=0):    normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation),transforms.ToTensor(),normalize]))

使用 torch.transforms 中 normalize 用了 3 通道,而实际使用的数据集Omniglot 图片大小是 [1, 28, 28]

解决方法:

把normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
改成normalize = transforms.Normalize(mean=[0.92206], std=[0.08426]) 

UserWarning: nn.functional.sigmoid is deprecated.

类似的warning 还有

UserWarning : torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.

按要求改就行

torch.nn.utils.clip_grad_norm(feature_encoder.parameters(), 0.5)
改成
torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)def forward里的
out = F.sigmoid(self.fc2(out))
改成
out = F.torch.sigmoid(self.fc2(out))

转载于:https://www.cnblogs.com/smartweed/p/10750065.html

Learning to Compare: Relation Network 源码调试相关推荐

  1. 小样本学习 | Learning to Compare: Relation Network for Few-Shot Learning

    博主github:https://github.com/MichaelBeechan 博主CSDN:https://blog.csdn.net/u011344545 Learning to Compa ...

  2. 【关系网络】Learning to Compare: Relation Network for Few-Shot Learning

    one-shot 关系网络 (RN) 由两个模块组成:一个嵌入模块 fφ 和一个关系模块 gφ,如图 1 所示.查询集 Q 中的样本 xj 和样本集 S 中的样本 xi 送到嵌入模块 fφ ,该模块生 ...

  3. The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorf

    The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorfl ...

  4. iOS之深入解析WKWebView的WebKit源码调试与分析

    一.前言 移动互联网时代,网页依旧是内容展示的重要媒介,这离不开 WebKit 浏览内核技术的支持与发展.在 iOS 平台下开发者们需要通过 WKWebView 框架来与 WebKit 打交道. 虽然 ...

  5. Android FrameWork学习(二)Android系统源码调试

    点击打开链接 通过上一篇 Android FrameWork学习(一)Android 7.0系统源码下载\编译 我们了解了如何进行系统源码的下载和编译工作. 为了更进一步地学习跟研究 Android ...

  6. 在Eclipse中进行HotSpot的源码调试--转

    原文地址:http://www.linuxidc.com/Linux/2015-05/117250.htm 在阅读OpenJDK源码的过程中,经常需要运行.调试程序来帮助理解.我们现在已经可以编译出一 ...

  7. webuploader 怎么在react中_另辟蹊径搭建阅读React源码调试环境支持所有React版本细分文件断点调试...

    引言(为什么写这篇文章) 若要高效阅读和理解React源码,搭建调试环境是必不可少的一步.而常规方法:使用react.development.js和react-dom.development.js调试 ...

  8. 使用vs2005进行(wince)DLL源码调试

    调试Dll也需要进到源码里面,进行单步调试.下面是使用vs2005进行wince DLL源码调试的步骤(可能我的方法麻烦了). ------------------------------------ ...

  9. Mac下下载android4.2源码,进行源码调试

    星期天在家研究了一下如何在mac下下载android4.2的源码并通过eclipse进行源码级别调试来更清晰的研究一下android的运行原理,具体步骤如下: 最后下下来了,但是我进行编译却没有通过, ...

最新文章

  1. 【DND图形库】三、创建窗口和绘制精灵
  2. 水文-接口和抽象类有什么不同
  3. windows下anaconda环境激活报错CommandNotFoundError: Your shell has not been properly configured to use ‘con
  4. NPU 2015年陕西省程序设计竞赛网络预赛(正式赛)F题 和谐的比赛(递推 ||卡特兰数(转化成01字符串))...
  5. 大前端技术选型 Native原生iOS, Android, React-Native, Flutter, 微信小程序, HTML5
  6. SQL语句:联合查询
  7. 设置android应用闪屏图片_android 闪屏设计
  8. python未知长度数组,python – 从具有未知维数的numpy数组中提取超立方体块
  9. BGA集成电路脚位识别
  10. 小米手机fastboot模式出现Press any key to shutdown字样解决方法
  11. MQTT协议的智能家居之指纹锁
  12. 口袋理财:“来了就是深圳人?”全国均价最高的房租了解一下
  13. [附源码]计算机毕业设计JAVA社区生鲜电商平台
  14. JAVA中的getBytes()方法
  15. GetLastError返回值大全(英文最新版)(1000-4000)
  16. DSDS/DSDA/DR-DSDS/DR-DSDA场景介绍和关键Log分析
  17. wpf 骚搞 新浪微博
  18. java数独流程图_九宫格数独游戏C语言解法
  19. 提名推荐!15个2019年最佳CSS框架
  20. acrobat PDF删除部分_墙裂推荐!功能强大的PDF编辑器最新免安装版!

热门文章

  1. Landsat5数据下载中国地区1991年
  2. 使用VMware安装黑苹果
  3. Android R(11)文件读写适配
  4. Hadoop学习之路(五):Hadoop交互关系型数据库(MySQL)
  5. Friborg NG8280 粉红噪声发生器测试方法
  6. 太吾绘卷加载卡54_太吾绘卷存档卡99怎么办?卡存档解决办法介绍
  7. 计算机设备供配电,IDC机房供配电系统解决方案
  8. 【天光学术】企业管理论文:基于供应链的企业物流成本控制优化分析(节选)
  9. 私人网盘树洞外链源码
  10. electron开发windows驱动程序