191214 说明:

很抱歉,突然发现图中第三行多画了一列叉,事实上,生成 output(0,0) 数据只用到了input[:,0] 以及 weights[0,:]。比较懒,就不再画了,图中第三行的第一个矩阵应该和第二行的第一个矩阵相同。

此外至于评论区中有人提到得到的结果一样。为此我做了一个小实验,验证经过一步简单优化后,模型参数之间的差异。使用的代码如下:

import torch
import torch.nn as nn
from copy import deepcopySMALL = 1e-7class Model(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv1d(3, 2, 1)self.linear = nn.Linear(3, 2)params = deepcopy(self.conv.state_dict())params['weight'] = params['weight'].reshape(2, 3)self.linear.load_state_dict(params)self.optim0 = torch.optim.Adam(lr=0.5, params=self.conv.parameters())self.optim1 = torch.optim.Adam(lr=0.5, params=self.linear.parameters())def test(self, tensor):assert (self.conv.weight.flatten().eq(self.linear.weight.flatten())).all(), 'Initial weights are different'assert (self.conv.bias.flatten().eq(self.linear.bias.flatten())).all(), 'Initial biases are different'# get out tensorc_out = self.conv(tensor.transpose(1, 2)).transpose(1, 2)l_out = self.linear(tensor)# set optim targetc_target = c_out.sum(2).mean()l_target = l_out.sum(2).mean()# apply one step optimizationself.optim0.zero_grad()c_target.backward()self.optim0.step()self.optim1.zero_grad()l_target.backward()self.optim1.step()# if conv1d(kernel=1) behaves the same as linear, # their parameters should be the same after applying # one-step optimizationa = self.conv.weight.data.flatten()b = self.linear.weight.data.flatten()c = self.conv.bias.data.flatten()d = self.linear.bias.data.flatten()assert torch.add(a, -b).abs().lt(SMALL).all(), 'After a step, weights are different'assert torch.add(c, -d).abs().lt(SMALL).all(), 'After a step, biases are different'if __name__ == "__main__":gen = torch.Generator()for seed in range(1000):gen.manual_seed(seed)tensor = torch.rand((10, 3, 3), generator=gen)model = Model()model.test(tensor)

代码做的事情主要是:1)给定随机种子,生成随机的 tensor,2)建模并使得两个矩阵的初始参数相等。3)用 Adam 进行一次简单的优化,并比较优化后的参数。

这里做最后的参数比较的时候,用 torch.eq 是肯定无法通过的,经过观察,发现最后的weights 确实很接近。经过几次测试发现,两者在一步优化后,在 1e-6 误差内可以通过测试,但是无法通过 1e-7 误差的测试。

因此,这里有两种可能,第一,这个weights的误差仅仅是因为计算误差导致的,pytorch 在计算两者的时候,本质上是一样。第二,pytorch 在计算 conv1d 的时候确实如文档说的使用了cross-relation operation,但是这个operation在简单的case中,带来的gradients和linear确实存在微小的区别,从而使得其行为不一样。但是得说明的是,经过一个大 N 级别的优化过程,conv1d 和 linear 带来的区别会是显着的。因此,最起码的,在使用 pytorch 进行计算时,不可将两者视为等同。

通过本次实验以及基于自己在复现 VRP-RL 的经验,我偏向认为 conv1d 的行为和 linear 是不同的。欢迎进一步讨论。

-------

最近在复现VRP下的DRL算法,当考虑C个顾客的问题,以及batch的大小为N,相应的地图数据的shape是(N, C, 2),其中第三维分别存储物理坐标(x,y)信息。

原文使用Conv1d with kernel_size=1来作为encoder,将原始数据映射到embedding_size=M的维度上去,得到数据形状为(N, C, M)。

作为一个调包侠,从来都只在乎输入和输出的形状,怎么方便怎么来。因为pytorch的Conv1d的API的输入数据需要将1维和2维调换,即(N, 2, C),觉得麻烦,而且误以为kernel=1的时候的Conv1d和Linear是完全一样的,然后就顺手用了一个Linear Layer 去做为embedding。唯一的区别仅仅在于这个encoder的选择,结果就是和benchmark对比,花费时间更长且效果更差。

然后去Stack Overflow上面去找找看答案,发现遇到这种问题的不仅我一个(见帖子),这里就根据pytorch的API一起探索一下conv1d(kernel=1)和linear分别究竟做了什么,以及产生区别的原因。

首先我们看最简单的Linear layer如下,

然后我们看Conv1d的API如下

有兴趣的小伙伴可以推导一下公式,不难发现,假设考虑都是叉乘操作,结合轴变化,当kernel=1的时候,Conv1d和Linear的output中各元素是共源的,也就是说,对于entry(i,j),生成他们数据的原始数据来源是一样的,结果应该没什么区别。神级画手只能帮到这里了(1为Linear,2为假设的Conv1d,3为实际的Conv1d):

如此这般,那问题很可能就出在集结方式了!Linear这边,确实就是普通的叉乘操作。这时候,Conv1d中红色的这个cross-correlation很可能就是问题的关键了。wiki和百度链接如下,具体计算公式可以在链接里面找到:

wiki​www.wikiwand.com

Insignal processing,cross-correlationis ameasure of similarityof two series as a function of the displacement of one relative to the other. This is also known as aslidingdot productorsliding inner-product.

互相关函数_百度百科​baike.baidu.com

互相关函数是信号分析里的概念,表示的是两个时间序列之间的相关程度,即描述信号 x (t),y (t) 在任意两个不同时刻 t1,t2 的取值之间的相关程度。描述两个不同的信号之间的相关性时,这两个信号可以是随机信号,也可以是确知信号。

这不就是明摆着这个操作会体现两个序列间的相关性吗?在VRP问题中,相当于提前集结了x和y的信息,并通过学习weights与历史信息中其他点作对比,大致估计当前点所在的位置,对于VRP问题,各个物理点x以及y之间的相关程度对于问题的求解是一个很重要的信息,当两个位置的坐标更相近的时候,更有可能考虑这两点相连,因此,使用Conv1d layer对问题的求解有更好的帮助。

小结,当同组数据中,各数据在每一维度的相关性比较重要的时候,Conv1d能提取这些数据并反映出来,这是普通的Linear layer做不到的。同时,做一个naive的调包侠是不可取的,仔细研究每一个API的内部运行机制才能避免写低效甚至错误的代码。

union和union all有什么区别_Pytorch中Linear与Conv1d(kernel=1)的区别相关推荐

  1. html div p 区别,html中div br p三者有什么区别?

    本篇文章给大家带来的内容是关于html中div br p三者有什么区别,有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助. 一.语法不同 p和p是成对组合闭合标签: 是单一的闭合标签.以 ...

  2. shell和php区别,PHP中exec函数和shell_exec函数的区别

    这篇文章主要介绍了PHP中exec函数和shell_exec函数的区别,这两个函数是非常危险的函数,一般情况都是被禁用的,当然特殊情况下也会使用,需要的朋友可以参考下 这两个函数都是执行Linux命令 ...

  3. python新式类和经典类区别_Python中新式类和经典类的区别,钻石继承

    1)首先,写法不一样: class A: pass class B(object): 2)在多继承中,新式类采用广度优先搜索,而旧式类是采用深度优先搜索. 3)新式类更符合OOP编程思想,统一了pyt ...

  4. c#与html的区别,C#中Html.RenderPartial与Html.RenderAction的区别分析

    本文较为详细的讲解了C#中Html.RenderPartial与Html.RenderAction的区别,具体分析如下: Html.RenderPartial与Html.RenderAction这两个 ...

  5. mysql数据库blob区别_MySQL中TEXT与BLOB字段类型的区别

    在MySQL中有两个字段类型容易让人感觉混淆,那就是TEXT与BLOB,特别是自己写博客程序的博主不知道改为自己的博客正文字段选择TEXT还是BLOB类型. 下面给出几点区别: 一.主要差别 TEXT ...

  6. Java传统的io和nio区别_Java中IO和NIO的本质和区别

    简介 终于要写到java中最最让人激动的部分了IO和NIO.IO的全称是input output,是java程序跟外部世界交流的桥梁,IO指的是java.io包中的所有类,他们是从java1.0开始就 ...

  7. HTML中href src区别,html中 href 和 src 的定义与区别

    看了几天的html,一直觉得没有把href和src的区别和用的地方搞清楚,今晚就来捋一捋. 才开始觉得href和src是没有有区别的,只是不能用在同一个标签中. 后来发现href和src是有区别的. ...

  8. java中add和addall区别,java中list的add与addall方法区别

    在做项目时我遇到过这样的问题,java.lang.ClassCastException: java.util.ArrayList cannot be cast to com.alibaba.gette ...

  9. dma和通道的区别_Java中IO和NIO的本质和区别

    简介 终于要写到java中最最让人激动的部分了IO和NIO.IO的全称是input output,是java程序跟外部世界交流的桥梁,IO指的是http://java.io包中的所有类,他们是从jav ...

最新文章

  1. swift 加载gif 框架图片
  2. visual studio 怎么生成coredump文件_玩游戏丢失dll文件别着急 认识这些就妥了
  3. python【Matlibplot绘图库】多图合并显示(真の能看懂~!)
  4. python大数据分析实例-python大数据分析代码案例
  5. 化工原理 蒸馏(下)
  6. 2021牛客多校9 - Cells(推公式+NTT)
  7. Linux各发行版本 优缺点 简介
  8. JDK 5.0 中的泛型类型学习
  9. case 日期when 范围_亚马逊运营干货:开case最全路径和各种实用链接,赶紧收藏...
  10. 利用Resource Hacker简单去除WinRar广告
  11. 假期无聊 就来试试用Python做一个智能识别 包教会哦 多图预警:配置Pyqt5超详细解说(designer.exe和pyuic.exe)以及项目:Python实现百度智能识别,识别各种实物
  12. Log4j2日志记录框架的使用教程与简单实例
  13. 中间继电器DZY-204/DC110V
  14. TM4C123GXL驱动安装
  15. 考研复试-传输层-计算机网络面试题
  16. c android显示gif动画,MFC显示GIF动画图片
  17. * web H5 网页 浏览器 蓝牙 Bluetooth
  18. 蓝魔法师——树形DP
  19. java基于微信小程序的游戏外包管理信息系统 uniapp 小程序
  20. 轻松搞定java高薪

热门文章

  1. 20应用统计考研复试要点(part41)--概率论与数理统计
  2. matlab图像去毛刺_信号去毛刺,去零漂
  3. oa 中会议推送 实现_揭秘“OA与ERP高端融合方案”三大亮点
  4. 一行SQL代码能做什么?
  5. 读书笔记 —《钱从哪里来》
  6. SAP Spartacus如何为不同的environment设置不同的baseUrl
  7. 如何使用schematics快速创建全新的SAP Spartacus Storefront并启用SSR
  8. Angular开发模式下的setNgReflectProperties函数
  9. Angular里的structural directive的一个例子
  10. Kyma registration of webservices and event endpoints