Pytorch踩坑记:赋值、浅拷贝、深拷贝三者的区别以及model.state_dict()和model.load_state_dict()的坑点
1. 写在前面
之前一直不太搞明白浅拷贝和赋值、深拷贝到底有什么区别,直到被pytorch的model.state_dict()给坑了
今天在和实验室同学讨论联邦学习框架代码的时候,终于明白了他们之间的区别,这里做个记录。
2. 先说结论
(1)直接赋值:给变量取个别名,原来叫张三,现在我给他取个小名,叫小张
- b = a (b是a的别名)
(2)浅拷贝(shadow copy):拷贝最外层的数值和指针,不拷贝更深层次的对象,即只拷贝了父对象
- copy.copy(xxx)
model.state_dict()
也是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。model这个对象实际上是指向各个参数矩阵的,而浅拷贝只会拷贝最外层的这些“指针”。具体可以看下文的示例
题外话:浅拷贝为什么叫“浅”,因为他只拷贝最外层的东西,不会去拷贝最外层“指针”所指向的内层的东西,所以浅。而深拷贝则会拷贝全部层的东西,所以深
(3)深拷贝(deepcopy):拷贝数值、指针和指针指向的深层次内存空间,拷贝了父对象及其子对象。
- copy.deepcopy(xxx)
model.load_state_dict(xxx)
是深拷贝
3. 一图胜前言
这一小节主要来自:一个工作三年的同事,居然还搞不清深拷贝、浅拷贝…
2021年10月24日 更新:下面这个图其实是以Java语言而言的,我一开始以为Python字符串和int数值应该也是直接赋值的,后来经过验证,发现python中的字符串其实是引用(地址),所以若a=“hello”,则b=a是把"hello"的地址赋值给b。另外-5到256这个范围内的整数是公用一块内存空间的,具体请看我的博客:Python中容易被忽视的知识点:字符串是传引用以及整数-5到256共享内存空间
浅拷贝
深拷贝
深拷贝相较于上面所示的浅拷贝,除了值类型字段会复制一份,引用类型字段所指向的对象,会在内存中也创建一个副本,就像这个样子:
4. Pytorch的model load_state_dict()和state_dict()有坑点
pytorch在获取模型参数和加载模型参数时是有坑点的,而且这个bug一般不太容易发现,因为他不会报错,有时你很难通过实验结果注意到这个问题,我自己写框架时也是被坑过。
model.state_dict()
实际上是浅拷贝,如果令param=model.state_dict(),那么当你修改param,相应地也会修改model的参数。model这个对象实际上是指向各个参数矩阵的,而浅拷贝只会拷贝最外层的这些“指针”。model.load_state_dict(xxx)
是深拷贝
用代码验证以上观点,可以结合上文的两张示意图来理解下面代码
import torch
import copym1 = torch.nn.Linear(in_features=5, out_features=1, bias=True)
m2 = torch.nn.Linear(in_features=5, out_features=1, bias=True)# m1是引用指向某块内存空间
# 浅拷贝相当于拷贝一个引用,所以他们“引用”变量的id是不一样的,指向的内存空间是一样的
ck = copy.copy(m1)
print(id(m1) == id(ck)) # Falseprint(m1.weight)
# Parameter containing:
# tensor([[ 0.0171, 0.4382, -0.4297, 0.4098, -0.3954]], requires_grad=True)# state_dict is shadow copy
p = m1.state_dict()
print(id(m1.state_dict()) == id(p)) # False# 通过引用p去修改内存空间
p['weight'][0][0] = 8.8888
# 可以看到m1指向的内存空间也被修改了
print(m1.state_dict())
# OrderedDict([('weight', tensor([[ 8.8888, 0.4382, -0.4297, 0.4098, -0.3954]])), ('bias', tensor([0.3964]))])# deepcopy
m2.load_state_dict(p)
m2.weight[0][0] = 2.0
print(p)
# OrderedDict([('weight', tensor([[ 8.8888, 0.4382, -0.4297, 0.4098, -0.3954]])), ('bias', tensor([0.3964]))])
print(m2.state_dict())
# OrderedDict([('weight', tensor([[ 2.0000, 0.4382, -0.4297, 0.4098, -0.3954]])), ('bias', tensor([0.3964]))])
在我的联邦学习框架中本地模型参数确实是浅拷贝,但是我们没有去修改这个local_params,我们只是把不同客户端的local_params加权平均去更新global_params而已,所以不用deepcopy也没事
但如果想保存最优模型的参数,则必须要用deepcopy
best_state changes with the model during training in pytorch 这位提问者想保存最佳模型参数,结果因为浅拷贝,导致保存的都是最后一轮的模型参数,下面是他的错误代码:
def train(): #training steps … if acc > best_acc: best_state = model.state_dict() best_acc = accreturn best_state
5. 实战演练
来源:Python 直接赋值、浅拷贝和深度拷贝解析
import copya = [1, 2, 3, 4, ['a', 'b']] # 原始对象b = a # 赋值,传对象的引用
c = copy.copy(a) # 对象拷贝,浅拷贝
d = copy.deepcopy(a) # 对象拷贝,深拷贝a.append(5) # 修改对象a
a[4].append('c') # a[4]是指针,修改对象a中的['a', 'b']数组对象print('a = ', a)
print('b = ', b)
print('c = ', c) # 浅拷贝,只会拷贝最外层的数值或指针
print('d = ', d)
a = [1, 2, 3, 4, ['a', 'b', 'c'], 5]
b = [1, 2, 3, 4, ['a', 'b', 'c'], 5]
c = [1, 2, 3, 4, ['a', 'b', 'c']]
d = [1, 2, 3, 4, ['a', 'b']]
现在你看下面这段代码的输出结果应该就不奇怪了吧
import copyA = [1, 2, 3]
print(A) # [1, 2, 3]B = copy.copy(A) # 浅拷贝(最外层"值"会拷贝,"引用"会拷贝)
B.append(5)
print(A) # [1, 2, 3]
print(B) # [1, 2, 3, 5]
6. Deep copy VS Shadow copy
深拷贝示例:
# Python code to demonstrate copy operations# importing "copy" for copy operations
import copy# initializing list 1
li1 = [1, 2, [3, 5], 4]# using deepcopy to deep copy
li2 = copy.deepcopy(li1)# original elements of list
print("The original elements before deep copying")
for i in range(0, len(li1)):print(li1[i], end=" ")print("\r")# adding and element to new list
li2[2][0] = 7# Change is reflected in l2
print("The new list of elements after deep copying ")
for i in range(0, len(li1)):print(li2[i], end=" ")print("\r")
The original elements before deep copying
1 2 [3, 5] 4
The new list of elements after deep copying
1 2 [7, 5] 4
The original elements after deep copying
1 2 [3, 5] 4
浅拷贝示例:
# Python code to demonstrate copy operations# importing "copy" for copy operations
import copy# initializing list 1
li1 = [1, 2, [3,5], 4]# using copy to shallow copy
li2 = copy.copy(li1)# original elements of list
print ("The original elements before shallow copying")
for i in range(0,len(li1)):print (li1[i],end=" ")print("\r")# adding and element to new list
li2[2][0] = 7# checking if change is reflected
print ("The original elements after shallow copying")
for i in range(0,len( li1)):print (li1[i],end=" ")
The original elements before shallow copying
1 2 [3, 5] 4
The original elements after shallow copying
1 2 [7, 5] 4
注意:上面用了li2[2][0] = 7
,相当于是在修改引用的内存空间;如果是li2[1] = 7
,那么l1[1]
不会改变
7. 参考资料
i. Numpy中的浅拷贝和深拷贝问题
ii. copy in Python (Deep Copy and Shallow Copy) (geeksforgeeks的文章还是挺清楚的)
iii. Python 直接赋值、浅拷贝和深度拷贝解析
iv. pytorch的state_dict()拷贝问题
v. 一个工作三年的同事,居然还搞不清深拷贝、浅拷贝… (图解挺不错的)
vi. best_state changes with the model during training in pytorch (这位老哥想保存最佳模型参数,结果因为浅拷贝,导致保存的都是最后一轮的模型参数)
vii. Python中的赋值(复制)、浅拷贝与深拷贝 (这篇文章关于可变对象和不可对象的拷贝的id是否会改变进行了讨论)
写在最后
✨原创不易,还希望各位大佬支持一下\textcolor{blue}{原创不易,还希望各位大佬支持一下}原创不易,还希望各位大佬支持一下
Pytorch踩坑记:赋值、浅拷贝、深拷贝三者的区别以及model.state_dict()和model.load_state_dict()的坑点相关推荐
- Interview:算法岗位面试—10.11下午—上海某公司算法岗位(偏机器学习,互联网数字行业)技术面试考点之XGBoost的特点、python的可变不可变的数据类型、赋值浅拷贝深拷贝区别
ML岗位面试:10.11下午-上海某公司算法岗位(偏机器学习,互联网数字行业)技术面试考点之XGBoost的特点.python的可变不可变的数据类型.赋值浅拷贝深拷贝区别 Interview:算法岗位 ...
- 一文搞懂JS中的赋值·浅拷贝·深拷贝
前言 为什么写拷贝这篇文章?同事有一天提到了拷贝,他说赋值就是一种浅拷贝方式,另一个同事说赋值和浅拷贝并不相同.我也有些疑惑,于是我去MDN搜一下拷贝相关内容,发现并没有关于拷贝的实质概念,没有办法只 ...
- python3 赋值 浅拷贝 深拷贝 简介
目录 一.赋值 二.浅拷贝(shallow copy) 三.深拷贝(deep copy) 四.关于拷贝操作的警告 一.赋值 在python中,对象的赋值就是简单的对象引用,这点和C++不同.如下: a ...
- 拷贝,浅拷贝与深拷贝三者的区别
拷贝,浅拷贝与深拷贝的区别如下: 如果拷贝的对象里的元素只有值,没有引用类型,那浅拷贝和深拷贝没有差别,新对象和原对象相互独立,不受影响: 如果拷贝的对象里的元素包含引用类型, 对于浅拷贝,它虽然将原 ...
- python赋值浅拷贝和深拷贝的区别_python赋值、浅拷贝、深拷贝区别
在写Python过程中,经常会遇到对象的拷贝,如果不理解浅拷贝和深拷贝的概念,你的代码就可能出现一些问题.所以,在这里按个人的理解谈谈它们之间的区别. 一.赋值(assignment) 在<Py ...
- Map的putAll方法踩坑实记(对象深拷贝浅拷贝)
文章目录 问题描述 编写测试代码模拟问题场景 场景1:Map中不包含对象 场景2:Map中包含对象 什么是对象的浅拷贝深拷贝 如何实现深拷贝 问题描述 在一个产品管理系统中,产品信息需要封装一份同步业 ...
- go nil json.marshal 完是null_字节跳动踩坑记#3:Go服务灵异panic
这个坑比较新鲜,刚填完,还冒着冷气. - 1 - 在字节跳动,我们服务的所有 log 都通过统一的日志库采集到流式日志服务.落地 ES 集群,配上字节云超(sang)级(xin)强(bing)大(ku ...
- js 浅拷贝直接赋值_js 深拷贝 vs 浅拷贝
本文主要讲一下 js 的基本数据类型以及一些堆和栈的知识和什么是深拷贝.什么是浅拷贝.深拷贝与浅拷贝的区别,以及怎么进行深拷贝和怎么进行浅拷贝. 本文思维导图如下:本文思维导图 堆和栈的区别 其实深拷 ...
- python从入门到实践django看不懂_Python编程:从入门到实践踩坑记 Django
<>踩坑记 Django Django Python 19.1.1.5 模板new_topic 做完书上的步骤后,对主题添加页面经行测试,但是浏览器显示 服务器异常. 个人采用的开发环境是 ...
- 东八区转为0时区_踩坑记 | Flink 天级别窗口中存在的时区问题
❝ 本系列每篇文章都是从一些实际的 case 出发,分析一些生产环境中经常会遇到的问题,抛砖引玉,以帮助小伙伴们解决一些实际问题.本文介绍 Flink 时间以及时区问题,分析了在天级别的窗口时会遇到的 ...
最新文章
- 委托、事件、事件访问器
- spring中的BeanPostProcessor
- js获取上传文件内容
- 解决HDFS NameNode启动时Loading edits时间超长的问题(NameNode数据同步机制介绍)
- 《黑客大曝光:移动应用安全揭秘及防护措施》一2.2 攻击与对策
- 面试官 | 什么是递归算法?它有什么用?
- 互联网日报 | 6月4日 星期五 | 蚂蚁消费金融获批开业;腾讯云四个国际数据中心同步开服;滴滴App上线“老人打车”模式...
- cordova混合开发流程
- clear linux安装教程,Clear Linux OS特性介绍,附下载地址
- azure云数据库_如何使用SQL Data Sync同步Azure SQL数据库和本地数据库
- 【分布式WebSocket - 1】超详细!WebSocket协议详解
- gitlab使用教程
- 扩展卡尔曼滤波(EKF)
- springboot整合es实现聚合搜索(api搜索版)
- spyder配置说明_Spyder学习使用总结
- Keras:ModelCheckpoint和model.fit的verbose有什么差异?
- html中如何设计圆形图案,纯CSS绘制漂亮的圆形图案效果
- 74ls175四人抢答器电路图_四人智力竞赛抢答器电路原理及设计.doc
- 《UE4蓝图完全学习》笔记
- 学术英语理工(第二版)Unit1课文翻译
热门文章
- 读取nginx的conf文件_nginx.conf配置文件
- h5计时器(requestAnimationFrame)
- ssms mysql_SQL Server免费版的安装以及使用SQL Server Management Studio(SSMS)连接数据库的图文方法...
- 新手如何推广优化自己的网站
- 互联网金融革命已让银行家们彻夜难眠
- 90后迎来30岁,比升职更重要的是这8件事
- html 调用es2015模块,在浏览器中懒加载ES2015模块
- 台式计算机网卡驱动不能正常使用,系统提示“您的网卡驱动程序不正常!”怎么办 是什么原因...
- 2549. 删除他们! 解题报告
- 苹果6p计算机在哪里设置方法,苹果手机怎么设置铃声【图文教程,不用电脑,1分钟完成】...