时间:2021-05-23
clone() 与 detach() 对比
Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab。如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝,
首先我们来打印出来clone()操作后的数据类型定义变化:
(1). 简单打印类型
import torcha = torch.tensor(1.0, requires_grad=True)b = a.clone()c = a.detach()a.data *= 3b += 1print(a) # tensor(3., requires_grad=True)print(b)print(c)'''输出结果:tensor(3., requires_grad=True)tensor(2., grad_fn=<AddBackward0>)tensor(3.) # detach()后的值随着a的变化出现变化'''grad_fn=<CloneBackward>,表示clone后的返回值是个中间变量,因此支持梯度的回溯。clone操作在一定程度上可以视为是一个identity-mapping函数。
detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变。
注意:在pytorch中我们不要直接使用id是否相等来判断tensor是否共享内存,这只是充分条件,因为也许底层共享数据内存,但是仍然是新的tensor,比如detach(),如果我们直接打印id会出现以下情况。
import torch as ta = t.tensor([1.0,2.0], requires_grad=True)b = a.detach()#c[:] = a.detach()print(id(a))print(id(b))#140568935450520140570337203616显然直接打印出来的id不等,我们可以通过简单的赋值后观察数据变化进行判断。
(2). clone()的梯度回传
detach()函数可以返回一个完全相同的tensor,与旧的tensor共享内存,脱离计算图,不会牵扯梯度计算。
而clone充当中间变量,会将梯度传给源张量进行叠加,但是本身不保存其grad,即值为None
import torcha = torch.tensor(1.0, requires_grad=True)a_ = a.clone()y = a**2z = a ** 2+a_ * 3y.backward()print(a.grad) # 2z.backward()print(a_.grad) # None. 中间variable,无gradprint(a.grad) '''输出:tensor(2.) Nonetensor(7.) # 2*2+3=7'''使用torch.clone()获得的新tensor和原来的数据不再共享内存,但仍保留在计算图中,clone操作在不共享数据内存的同时支持梯度梯度传递与叠加,所以常用在神经网络中某个单元需要重复使用的场景下。
通常如果原tensor的requires_grad=True,则:
另一个比较特殊的是当源张量的 require_grad=False,clone后的张量 require_grad=True,此时不存在张量回传现象,可以得到clone后的张量求导。
如下:
import torcha = torch.tensor(1.0)a_ = a.clone()a_.requires_grad_() #require_grad=Truey = a_ ** 2y.backward()print(a.grad) # Noneprint(a_.grad) '''输出:Nonetensor(2.)'''了解了两者的区别后我们常与其他函数进行搭配使用,实现数据拷贝后的其他需要。
比如我们经常使用view()函数对tensor进行reshape操作。返回的新Tensor与源Tensor可能有不同的size,但是是共享data的,即其中的一个发生变化,另外一个也会跟着改变。
需要注意的是view返回的Tensor与源Tensor是共享data的,但是依然是一个新的Tensor(因为Tensor除了包含data外还有一些其他属性),两者id(内存地址)并不一致。
x = torch.rand(2, 2)y = x.view(4)x += 1print(x)print(y) # 也加了1view() 仅仅是改变了对这个张量的观察角度,内部数据并未改变。这时候想返回一个真正新的副本(即不共享data内存)该怎么办呢?Pytorch还提供了一个reshape()可以改变形状,但是此函数并不能保证返回的是其拷贝,所以不推荐使用。推荐先用clone创造一个副本然后再使用view。参考此处
x = torch.rand(2, 2)x_cp = x.clone().view(4)x += 1print(id(x))print(id(x_cp))print(x)print(x_cp)'''140568935036464140568935035816tensor([[0.4963, 0.7682], [0.1320, 0.3074]])tensor([[1.4963, 1.7682, 1.1320, 1.3074]]) '''另外使用clone()会被记录在计算图中,即梯度回传到副本时也会传到源Tensor。在上一篇中有总结。
总结:
引用官方文档的话:如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。另外尽量避免in-place的使用。
像y = x + y这样的运算会新开内存,然后将y指向新内存。我们可以使用Python自带的id函数进行验证:如果两个实例的ID相同,则它们所对应的内存地址相同。
到此这篇关于PyTorch中clone()、detach()及相关扩展详解的文章就介绍到这了,更多相关PyTorch中clone()、detach()及相关扩展内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
PyTorch0.4中,.data仍保留,但建议使用.detach(),区别在于.data返回和x的相同数据tensor,但不会加入到x的计算历史里,且requ
java深拷贝与浅拷贝机制详解概要:在Java中,拷贝分为深拷贝和浅拷贝两种。java在公共超类Object中实现了一种叫做clone的方法,这种方法clone
详解phpmyadmin相关配置与错误解决缺少mcrypt扩展sudoapt-getinstallphp5-mcryptsudophp5enmodmcrypt检
Android注解相关文章:AndroidAOP注解Annotation详解(一)AndroidAOP之注解处理解释器详解(二)AndroidAOP注解详解及简
torch.Tensor.detach()的使用detach()的官方说明如下:ReturnsanewTensor,detachedfromthecurrent