pytorch detach,clone的区别

发布于:2025-03-09 ⋅ 阅读:(100) ⋅ 点赞:(0)

1.detach

返回原张量的view,但不保留计算图

a = torch.tensor([3.,2.],requires_grad=True)
a_detach = a.detach()
print(a_detach.requires_grad)    # 输出为 False
a_detach[0] = 1
print(a)  # a变成 tensor([1., 2.], requires_grad=True)

适用于想对原始张量操作,但又不想影响计算图的情况

2.clone

首先,运行如下代码:

x = torch.tensor([2.0, 3.0], requires_grad=True)
x_clone = x.clone()
print(x_clone.is_leaf)  # 输出结果为 False

可以看到,叶子节点对应的clone张量并不是一个叶子节点,因此,接下来为了查看clone张量的梯度值,需要使用方法.retain_grad()
(如果不了解这个方法,可以移步
https://blog.csdn.net/qq_45812220/article/details/146113524):

a = torch.tensor(1.,requires_grad=True)  # 创建叶子节点a
a_clone = a.clone()   # 创建clone张量a_clone
a_clone.retain_grad()   # 便于后续得到a_clone对应的梯度

y = a*3    # 对a进行计算
z = a_clone*4   # 对a_clone进行计算

y.backward()   #  对a的计算结果进行反向传播
print(a.grad)   #  tensor(3.)
print(a_clone.grad)   # None

可以发现,仅仅对a的计算结果y进行反向传播,不会影响到a_clone的梯度。但是,如果修改一下代码:

a = torch.tensor(1.,requires_grad=True)  # 创建叶子节点a
a_clone = a.clone()   # 创建clone张量a_clone
a_clone.retain_grad()   # 便于后续得到a_clone对应的梯度

y = a*3    # 对a进行计算
z = a_clone*4   # 对a_clone进行计算

z.backward()   #  对a_clone的计算结果进行反向传播
print(a.grad)   # tensor(4.)
print(a_clone.grad)   #tensor(4.)

观察输出结果,可以发现,此时a的梯度值也得到了更新。

再执行下面的代码:

a = torch.tensor([1.,2.],requires_grad=True)  # 创建叶子节点a
a_clone = a.clone()   # 创建clone张量a_clone
a_clone.retain_grad()   # 便于后续得到a_clone对应的梯度

a_clone[0] = 10   # tensor([10.,   2.], grad_fn=<CopySlices>)
z = (a_clone**2).sum()
z.backward()

print(a.grad)  # tensor([0., 4.])
print(a_clone.grad)   # tensor([20.,  4.])

可以看到,a_clonea并不共享内存,我们可以单独修改a_clone的值。另外,clone张量对应的值被修改以后,该值对应的梯度并不会被传回到a,仅仅只有a_clone中未被重新赋值的地方的张量可以传回a

3.总结

detach生成原始张量的view而不保留计算图;clone生成的张量附带计算图,计算该张量的梯度时,梯度会回传到原始张量。但clone张量不是原张量的view

另外,可以对clone张量的值进行局部修改,但局部修改以后,梯度更新时,该处的梯度不会回传。


网站公告

今日签到

点亮在社区的每一天
去签到