
For a full PyTorch tutorial, see the notebook.


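  • t.detach() returns a new tensor that is detached from the computation graph (requires_grad=False, no grad_fn). However, it is not a copy: it shares the same underlying storage as the original tensor t, so in-place changes to either one are visible in the other.
  • Calling detach() does not modify or destroy the computation graph. t keeps its place in the graph, and gradients still flow through t when you call backward() on a result computed from it. For example: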
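```python
import torch

x = torch.tensor([1., 2.], requires_grad=True)
xfang = x * x              # "fang": x squared
xlifang = x * xfang        # "lifang": x cubed
xfang_detached = xfang.detach()  # detaching xfang leaves the graph of xlifang intact
loss = xlifang.sum()
loss.backward()            # without this call, x.grad would still be None
print(x.grad)              # not None: the gradient 3 * x**2 = [3., 12.]
```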
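A minimal sketch of the storage sharing described above, assuming a recent PyTorch release; the variable names (`t`, `d`, `same_storage`, `first_value`, `t_requires_grad`) are illustrative only. Writing into the detached tensor changes `t` as well, while `t` itself stays a graph node with `requires_grad=True`:

```python
import torch

# Illustrative names only.
t = torch.tensor([1., 2., 3.], requires_grad=True)
d = t.detach()

# The detached tensor is outside the graph...
assert not d.requires_grad and d.grad_fn is None
# ...but it points at the same memory as t.
same_storage = d.data_ptr() == t.data_ptr()

# An in-place write through d is therefore visible in t.
d[0] = 100.
first_value = t[0].item()

# t is still part of the graph: detach() did not change it.
t_requires_grad = t.requires_grad
```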


  • If you want to modify the detached tensor without affecting t, use t.detach().clone(). clone() copies the data into new memory, so the mutation cannot change the value of t that the graph depends on.
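A minimal sketch contrasting the two options, assuming a recent PyTorch release (variable names are illustrative). `sigmoid()` saves its output for the backward pass, so zeroing a `.clone()` of that output is harmless, whereas zeroing the plain `.detach()` view overwrites the saved value and autograd raises a `RuntimeError` when `backward()` is called:

```python
import torch

# Case 1: detach().clone() gives an independent copy (illustrative names).
a = torch.tensor([1., 2., 3.], requires_grad=True)
out = a.sigmoid()          # the sigmoid backward node saves `out`
c = out.detach().clone()   # new memory, no link to the graph
c.zero_()                  # mutating the copy leaves `out` untouched
out.sum().backward()       # works: d(sigmoid)/da = out * (1 - out)
clone_grad_ok = torch.allclose(a.grad, out.detach() * (1 - out.detach()))

# Case 2: plain detach() shares storage with `out2`.
b = torch.tensor([1., 2., 3.], requires_grad=True)
out2 = b.sigmoid()
out2.detach().zero_()      # overwrites the value autograd saved
try:
    out2.sum().backward()
    detach_raised = False
except RuntimeError:       # autograd detects the in-place modification
    detach_raised = True
```

The error in the second case comes from autograd's in-place correctness check; the PyTorch documentation for `detach()` notes that in-place modifications on the detached tensor may trigger such errors.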


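  • You can call backward() on two separate graphs that share the same leaf tensor x (the gradients accumulate in x.grad), but you cannot call backward() twice on the same non-leaf output y, because the intermediate buffers saved for the backward pass are freed after the first call (unless you pass retain_graph=True). For example, this is possible:
    ```python
    import torch
    x = torch.tensor([1., 2.], requires_grad=True)
    y = (x * x).sum()
    z = (x * x).sum()  # a second, independent graph built from the same leaf x
    y.backward()
    z.backward()       # fine: x.grad accumulates 2 * x twice, giving [4., 8.]
    ```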
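A minimal sketch of the failing case and the usual workaround, assuming a recent PyTorch release (variable names are illustrative): the second `backward()` on the same `y` raises a `RuntimeError`, while `retain_graph=True` keeps the saved buffers so the graph can be traversed again, with gradients accumulating in `x.grad`:

```python
import torch

x = torch.tensor([1., 2.], requires_grad=True)

# A second backward through the same graph fails: its saved buffers were freed.
y = (x * x).sum()
y.backward()
try:
    y.backward()
    second_call_failed = False
except RuntimeError:
    second_call_failed = True

# retain_graph=True keeps the buffers so the same graph can be reused.
x.grad = None              # reset the accumulated gradient
w = (x * x).sum()
w.backward(retain_graph=True)
w.backward()
grad_retained = x.grad.clone()  # 2 * x accumulated twice
```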

Last update: October 29, 2024 20:50:57
Created: October 29, 2024 20:50:57