1. 写在前面

之前一直不太搞明白浅拷贝和赋值、深拷贝到底有什么区别,直到被pytorch的model.state_dict()给坑了

今天在和实验室同学讨论联邦学习框架代码的时候,终于明白了他们之间的区别,这里做个记录。

2. 先说结论

(1)直接赋值:给变量取个别名,原来叫张三,现在我给他取个小名,叫小张

  • b = a (b是a的别名)

(2)浅拷贝(shadow copy):拷贝最外层的数值和指针,不拷贝更深层次的对象,即只拷贝了父对象

  • copy.copy(xxx)
  • model.state_dict()也是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。model这个对象实际上是指向各个参数矩阵的,而浅拷贝只会拷贝最外层的这些“指针”。具体可以看下文的示例

题外话:浅拷贝为什么叫“浅”,因为他只拷贝最外层的东西,不会去拷贝最外层“指针”所指向的内层的东西,所以浅。而深拷贝则会拷贝全部层的东西,所以深

(3)深拷贝(deepcopy):拷贝数值、指针和指针指向的深层次内存空间,拷贝了父对象及其子对象。

  • copy.deepcopy(xxx)
  • model.load_state_dict(xxx) 是深拷贝

3. 一图胜前言

这一小节主要来自:一个工作三年的同事,居然还搞不清深拷贝、浅拷贝…

2021年10月24日 更新:下面这个图其实是以Java语言而言的,我一开始以为Python字符串和int数值应该也是直接赋值的,后来经过验证,发现python中的字符串其实是引用(地址),所以若a=“hello”,则b=a是把"hello"的地址赋值给b。另外-5到256这个范围内的整数是公用一块内存空间的,具体请看我的博客:Python中容易被忽视的知识点:字符串是传引用以及整数-5到256共享内存空间

浅拷贝

深拷贝

深拷贝相较于上面所示的浅拷贝,除了值类型字段会复制一份,引用类型字段所指向的对象,会在内存中也创建一个副本,就像这个样子:

4. Pytorch的model load_state_dict()和state_dict()有坑点

pytorch在获取模型参数和加载模型参数时是有坑点的,而且这个bug一般不太容易发现,因为他不会报错,有时你很难通过实验结果注意到这个问题,我自己写框架时也是被坑过。

  • model.state_dict()实际上是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。model这个对象实际上是指向各个参数矩阵的,而浅拷贝只会拷贝最外层的这些“指针”。
  • model.load_state_dict(xxx) 是深拷贝

用代码验证以上观点,可以结合上文的两张示意图来理解下面代码

import torch
import copym1 = torch.nn.Linear(in_features=5, out_features=1, bias=True)
m2 = torch.nn.Linear(in_features=5, out_features=1, bias=True)# m1是引用指向某块内存空间
# 浅拷贝相当于拷贝一个引用,所以他们“引用”变量的id是不一样的,指向的内存空间是一样的
ck = copy.copy(m1)
print(id(m1) == id(ck)) # Falseprint(m1.weight)
# Parameter containing:
# tensor([[ 0.0171,  0.4382, -0.4297,  0.4098, -0.3954]], requires_grad=True)# state_dict is shadow copy
p = m1.state_dict()
print(id(m1.state_dict()) == id(p)) # False# 通过引用p去修改内存空间
p['weight'][0][0] = 8.8888
# 可以看到m1指向的内存空间也被修改了
print(m1.state_dict())
# OrderedDict([('weight', tensor([[ 8.8888,  0.4382, -0.4297,  0.4098, -0.3954]])), ('bias', tensor([0.3964]))])# deepcopy
m2.load_state_dict(p)
m2.weight[0][0] = 2.0
print(p)
# OrderedDict([('weight', tensor([[ 8.8888,  0.4382, -0.4297,  0.4098, -0.3954]])), ('bias', tensor([0.3964]))])
print(m2.state_dict())
# OrderedDict([('weight', tensor([[ 2.0000,  0.4382, -0.4297,  0.4098, -0.3954]])), ('bias', tensor([0.3964]))])

在我的联邦学习框架中本地模型参数确实是浅拷贝,但是我们没有去修改这个local_params,我们只是把不同客户端的local_params加权平均去更新global_params而已,所以不用deepcopy也没事

但如果想保存最优模型的参数,则必须要用deepcopy

best_state changes with the model during training in pytorch 这位提问者想保存最佳模型参数,结果因为浅拷贝,导致保存的都是最后一轮的模型参数,下面是他的错误代码:

def train():  #training steps …  if acc > best_acc:  best_state = model.state_dict()  best_acc = accreturn best_state

5. 实战演练

来源:Python 直接赋值、浅拷贝和深度拷贝解析

import copya = [1, 2, 3, 4, ['a', 'b']]  # 原始对象b = a  # 赋值,传对象的引用
c = copy.copy(a)  # 对象拷贝,浅拷贝
d = copy.deepcopy(a)  # 对象拷贝,深拷贝a.append(5)  # 修改对象a
a[4].append('c')  # a[4]是指针,修改对象a中的['a', 'b']数组对象print('a = ', a)
print('b = ', b)
print('c = ', c) # 浅拷贝,只会拷贝最外层的数值或指针
print('d = ', d)
a =  [1, 2, 3, 4, ['a', 'b', 'c'], 5]
b =  [1, 2, 3, 4, ['a', 'b', 'c'], 5]
c =  [1, 2, 3, 4, ['a', 'b', 'c']]
d =  [1, 2, 3, 4, ['a', 'b']]

现在你看下面这段代码的输出结果应该就不奇怪了吧

import copyA = [1, 2, 3]
print(A)  # [1, 2, 3]B = copy.copy(A) # 浅拷贝(最外层"值"会拷贝,"引用"会拷贝)
B.append(5)
print(A)  # [1, 2, 3]
print(B)  # [1, 2, 3, 5]

6. Deep copy VS Shadow copy

深拷贝示例:

# Python code to demonstrate copy operations# importing "copy" for copy operations
import copy# initializing list 1
li1 = [1, 2, [3, 5], 4]# using deepcopy to deep copy
li2 = copy.deepcopy(li1)# original elements of list
print("The original elements before deep copying")
for i in range(0, len(li1)):print(li1[i], end=" ")print("\r")# adding and element to new list
li2[2][0] = 7# Change is reflected in l2
print("The new list of elements after deep copying ")
for i in range(0, len(li1)):print(li2[i], end=" ")print("\r")
The original elements before deep copying
1 2 [3, 5] 4
The new list of elements after deep copying
1 2 [7, 5] 4
The original elements after deep copying
1 2 [3, 5] 4

浅拷贝示例:

# Python code to demonstrate copy operations# importing "copy" for copy operations
import copy# initializing list 1
li1 = [1, 2, [3,5], 4]# using copy to shallow copy
li2 = copy.copy(li1)# original elements of list
print ("The original elements before shallow copying")
for i in range(0,len(li1)):print (li1[i],end=" ")print("\r")# adding and element to new list
li2[2][0] = 7# checking if change is reflected
print ("The original elements after shallow copying")
for i in range(0,len( li1)):print (li1[i],end=" ")
The original elements before shallow copying
1 2 [3, 5] 4
The original elements after shallow copying
1 2 [7, 5] 4

注意:上面用了li2[2][0] = 7,相当于是在修改引用的内存空间;如果是li2[1] = 7,那么l1[1]不会改变

7. 参考资料

i. Numpy中的浅拷贝和深拷贝问题

ii. copy in Python (Deep Copy and Shallow Copy) (geeksforgeeks的文章还是挺清楚的)

iii. Python 直接赋值、浅拷贝和深度拷贝解析

iv. pytorch的state_dict()拷贝问题

v. 一个工作三年的同事,居然还搞不清深拷贝、浅拷贝… (图解挺不错的)

vi. best_state changes with the model during training in pytorch (这位老哥想保存最佳模型参数,结果因为浅拷贝,导致保存的都是最后一轮的模型参数)

vii. Python中的赋值(复制)、浅拷贝与深拷贝 (这篇文章关于可变对象和不可对象的拷贝的id是否会改变进行了讨论)

写在最后

✨原创不易,还希望各位大佬支持一下\textcolor{blue}{原创不易,还希望各位大佬支持一下}原创不易,还希望各位大佬支持一下

Pytorch踩坑记:赋值、浅拷贝、深拷贝三者的区别以及model.state_dict()和model.load_state_dict()的坑点相关推荐

  1. Interview:算法岗位面试—10.11下午—上海某公司算法岗位(偏机器学习,互联网数字行业)技术面试考点之XGBoost的特点、python的可变不可变的数据类型、赋值浅拷贝深拷贝区别

    ML岗位面试:10.11下午-上海某公司算法岗位(偏机器学习,互联网数字行业)技术面试考点之XGBoost的特点.python的可变不可变的数据类型.赋值浅拷贝深拷贝区别 Interview:算法岗位 ...

  2. 一文搞懂JS中的赋值·浅拷贝·深拷贝

    前言 为什么写拷贝这篇文章?同事有一天提到了拷贝,他说赋值就是一种浅拷贝方式,另一个同事说赋值和浅拷贝并不相同.我也有些疑惑,于是我去MDN搜一下拷贝相关内容,发现并没有关于拷贝的实质概念,没有办法只 ...

  3. python3 赋值 浅拷贝 深拷贝 简介

    目录 一.赋值 二.浅拷贝(shallow copy) 三.深拷贝(deep copy) 四.关于拷贝操作的警告 一.赋值 在python中,对象的赋值就是简单的对象引用,这点和C++不同.如下: a ...

  4. 拷贝,浅拷贝与深拷贝三者的区别

    拷贝,浅拷贝与深拷贝的区别如下: 如果拷贝的对象里的元素只有值,没有引用类型,那浅拷贝和深拷贝没有差别,新对象和原对象相互独立,不受影响: 如果拷贝的对象里的元素包含引用类型, 对于浅拷贝,它虽然将原 ...

  5. python赋值浅拷贝和深拷贝的区别_python赋值、浅拷贝、深拷贝区别

    在写Python过程中,经常会遇到对象的拷贝,如果不理解浅拷贝和深拷贝的概念,你的代码就可能出现一些问题.所以,在这里按个人的理解谈谈它们之间的区别. 一.赋值(assignment) 在<Py ...

  6. Map的putAll方法踩坑实记(对象深拷贝浅拷贝)

    文章目录 问题描述 编写测试代码模拟问题场景 场景1:Map中不包含对象 场景2:Map中包含对象 什么是对象的浅拷贝深拷贝 如何实现深拷贝 问题描述 在一个产品管理系统中,产品信息需要封装一份同步业 ...

  7. go nil json.marshal 完是null_字节跳动踩坑记#3:Go服务灵异panic

    这个坑比较新鲜,刚填完,还冒着冷气. - 1 - 在字节跳动,我们服务的所有 log 都通过统一的日志库采集到流式日志服务.落地 ES 集群,配上字节云超(sang)级(xin)强(bing)大(ku ...

  8. js 浅拷贝直接赋值_js 深拷贝 vs 浅拷贝

    本文主要讲一下 js 的基本数据类型以及一些堆和栈的知识和什么是深拷贝.什么是浅拷贝.深拷贝与浅拷贝的区别,以及怎么进行深拷贝和怎么进行浅拷贝. 本文思维导图如下:本文思维导图 堆和栈的区别 其实深拷 ...

  9. python从入门到实践django看不懂_Python编程:从入门到实践踩坑记 Django

    <>踩坑记 Django Django Python 19.1.1.5 模板new_topic 做完书上的步骤后,对主题添加页面经行测试,但是浏览器显示 服务器异常. 个人采用的开发环境是 ...

  10. 东八区转为0时区_踩坑记 | Flink 天级别窗口中存在的时区问题

    ❝ 本系列每篇文章都是从一些实际的 case 出发,分析一些生产环境中经常会遇到的问题,抛砖引玉,以帮助小伙伴们解决一些实际问题.本文介绍 Flink 时间以及时区问题,分析了在天级别的窗口时会遇到的 ...

最新文章

  1. 委托、事件、事件访问器
  2. spring中的BeanPostProcessor
  3. js获取上传文件内容
  4. 解决HDFS NameNode启动时Loading edits时间超长的问题(NameNode数据同步机制介绍)
  5. 《黑客大曝光:移动应用安全揭秘及防护措施》一2.2 攻击与对策
  6. 面试官 | 什么是递归算法?它有什么用?
  7. 互联网日报 | 6月4日 星期五 | 蚂蚁消费金融获批开业;腾讯云四个国际数据中心同步开服;滴滴App上线“老人打车”模式...
  8. cordova混合开发流程
  9. clear linux安装教程,Clear Linux OS特性介绍,附下载地址
  10. azure云数据库_如何使用SQL Data Sync同步Azure SQL数据库和本地数据库
  11. 【分布式WebSocket - 1】超详细!WebSocket协议详解
  12. gitlab使用教程
  13. 扩展卡尔曼滤波(EKF)
  14. springboot整合es实现聚合搜索(api搜索版)
  15. spyder配置说明_Spyder学习使用总结
  16. Keras:ModelCheckpoint和model.fit的verbose有什么差异?
  17. html中如何设计圆形图案,纯CSS绘制漂亮的圆形图案效果
  18. 74ls175四人抢答器电路图_四人智力竞赛抢答器电路原理及设计.doc
  19. 《UE4蓝图完全学习》笔记
  20. 学术英语理工(第二版)Unit1课文翻译

热门文章

  1. 读取nginx的conf文件_nginx.conf配置文件
  2. h5计时器(requestAnimationFrame)
  3. ssms mysql_SQL Server免费版的安装以及使用SQL Server Management Studio(SSMS)连接数据库的图文方法...
  4. 新手如何推广优化自己的网站
  5. 互联网金融革命已让银行家们彻夜难眠
  6. 90后迎来30岁,比升职更重要的是这8件事
  7. html 调用es2015模块,在浏览器中懒加载ES2015模块
  8. 台式计算机网卡驱动不能正常使用,系统提示“您的网卡驱动程序不正常!”怎么办 是什么原因...
  9. 2549. 删除他们! 解题报告
  10. 苹果6p计算机在哪里设置方法,苹果手机怎么设置铃声【图文教程,不用电脑,1分钟完成】...