When retraining a network, we may want to keep some parameters fixed and update only the rest, or we may want to train a branch network without letting its gradients flow back into the main network. In these cases we can use the detach() function to cut off gradient back-propagation along certain branches.
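As a minimal sketch of the branch-training idea (the two `nn.Linear` modules here are illustrative stand-ins for a "main" and a "branch" network, not from the original post), detaching the main network's features lets the branch train while the main network receives no gradient:

```python
import torch
import torch.nn as nn

# Hypothetical main network and branch head
main = nn.Linear(4, 8)
branch = nn.Linear(8, 2)

x = torch.randn(3, 4)
feat = main(x)

# detach() cuts the graph here: the branch loss cannot
# backpropagate into `main`'s parameters
loss = branch(feat.detach()).sum()
loss.backward()

print(main.weight.grad)                 # None: main got no gradient
print(branch.weight.grad is not None)   # True: branch did
```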

## 1. Tensor detach()

detach() returns a new tensor that is separated from the current computation graph but still points to the storage of the original tensor. The only difference is that its requires_grad is False. The returned tensor never needs its gradient computed and has no grad attribute.

Even if its requires_grad is later set back to True, the detached tensor still receives no gradient from the original graph.

We can then continue computing with the new tensor; when back-propagation is later performed, gradient flow stops at the tensor on which detach() was called and does not propagate any further back.

Note:

The tensor returned by detach() shares the same memory as the original tensor: when one is modified, the other changes as well.
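The shared storage can be checked directly with `data_ptr()` (a quick illustrative check, not from the original post):

```python
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
out = a.sigmoid()
c = out.detach()

# both tensors point at the same underlying storage
print(c.data_ptr() == out.data_ptr())  # True

c[1] = 0.
print(out)  # the change made through c is visible in out
```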

First, a normal example without detach():

```python
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()

out.sum().backward()
print(a.grad)
'''return:
None
tensor([0.1966, 0.1050, 0.0452])
'''
```

1.1 When detach() is used to separate a tensor, but the detached tensor is not modified, backward() is not affected:

```python
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# after detach(), c has requires_grad=False
c = out.detach()
print(c)

# c is not modified here, so backward() is not affected
out.sum().backward()
print(a.grad)

'''return:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])
'''
```

As shown above, tensor c is separated from out, but since c is never modified, calling backward() on the original out raises no error. The only difference between out and c is that c has no gradient while out does. Note, however, that the following two cases do raise errors.

1.2 When detach() is used to separate a tensor and backward() is then called on the detached tensor itself, an error occurs:

```python
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# after detach(), c has requires_grad=False
c = out.detach()
print(c)

# backpropagate from the newly generated detached tensor
c.sum().backward()
print(a.grad)

'''return:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    c.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
'''
```

1.3 When detach() is used to separate a tensor and the detached tensor is then modified in place, even calling backward() on the original out raises an error:

If c is modified at this point, the change is tracked by autograd, and calling backward() on out.sum() raises an error, because the gradient computed from the modified values would be wrong:

```python
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)

# after detach(), c has requires_grad=False
c = out.detach()
print(c)
c.zero_()  # modify c with an in-place function

# the change to c also affects the value of out
print(c)
print(out)

# c has been modified, so backward() on out will now fail
out.sum().backward()
print(a.grad)

'''return:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    out.sum().backward()
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified
by an inplace operation
'''
```

## 2. Tensor detach_()

detach_() separates a tensor, in place, from the graph that created it and makes it a leaf tensor.

Suppose the variable chain is x -> m -> y, where the leaf tensor is x. Calling m.detach_() on m actually performs two operations:

- m's grad_fn is set to None, so m is no longer connected to the previous node x. The relationship becomes x, m -> y, and m becomes a leaf node.
- m's requires_grad is set to False, so m's gradient is not computed when backward() is called on y.
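The two operations above can be checked with a minimal sketch of the x -> m -> y chain (variable names follow the description above):

```python
import torch

x = torch.tensor([1., 2., 3.], requires_grad=True)
m = x * 2              # intermediate node with grad_fn=<MulBackward0>
m.detach_()            # in-place detach

print(m.grad_fn)        # None: the x -> m edge is gone
print(m.requires_grad)  # False: backward() will skip m
print(m.is_leaf)        # True: m is now a leaf tensor

y = m.sum()
print(y.requires_grad)  # False: y no longer tracks gradients
```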

Summary: detach() and detach_() are very similar. The difference is that detach_() modifies the tensor itself, while detach() returns a new tensor.

For example, in x -> m -> y, if you call detach() on m, the original computation graph remains intact, so you can still backpropagate through it later if you want.

But if detach_() is performed, the original computation graph itself is changed, so there is no way to go back.
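This contrast can be demonstrated directly (an illustrative sketch using `m = x * 2` as the intermediate node):

```python
import torch

# With detach(): the original graph survives, so backward still works
x = torch.tensor([1., 2., 3.], requires_grad=True)
m = x * 2
c = m.detach()          # new tensor; m itself keeps its grad_fn
m.sum().backward()      # succeeds: gradient flows back to x
print(x.grad)           # tensor([2., 2., 2.])

# With detach_(): the x -> m edge is destroyed in place
x2 = torch.tensor([1., 2., 3.], requires_grad=True)
m2 = x2 * 2
m2.detach_()
try:
    m2.sum().backward()  # fails: m2 no longer requires grad
except RuntimeError as e:
    print("backward failed:", e)
```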

Original link: https://blog.csdn.net/qq_27825451/article/details/95498211