Backpropagation in RNN and LSTM models: errors at loss.backward()
After upgrading PyTorch, training code that used to run is prone to the following problems.
Problem 1. Calling loss.backward() raises this error:
Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
(torchenv) star@lab407-1:~/POIRec/STPRec/Flashback_code-master$ python train.py
Problem 2. Calling loss.backward(retain_graph=True) instead raises a different error:
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 10]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
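The "is at version 2; expected version 1" wording may be easier to read with a small reproduction. This is only a minimal sketch, not the code from train.py: the tensors w and h are made up for illustration. autograd records a version counter for every tensor it saves for the backward pass, and an in-place change after the forward pass bumps that counter.

```python
import torch

w = torch.ones(3, requires_grad=True)
h = torch.sigmoid(w)      # sigmoid saves its output h for the backward pass
loss = h.sum()
h.add_(1.0)               # in-place change: h's version counter goes up

try:
    loss.backward()       # autograd notices the saved tensor was modified
    inplace_error_raised = False
    message = ""
except RuntimeError as err:
    inplace_error_raised = True
    message = str(err)

print("error raised:", inplace_error_raised)
print(message[:90])
```

Replacing h.add_(1.0) with the out-of-place h = h + 1.0 leaves the saved tensor untouched, and backward() then runs normally.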
Solution:
About loss.backward() and some pitfalls of its retain_graph parameter
First, loss.backward() itself is simple: it computes the gradient of the current tensor (the loss) with respect to the leaf nodes of the graph.
In the simplest case it is used directly, as follows:
optimizer.zero_grad()  # clear the gradients from the previous step
loss.backward()        # backpropagate to compute the current gradients
optimizer.step()       # update the network parameters from the gradients

or, when the loss is accumulated over several items:

for i in range(num):
    loss += Loss(input, target)
optimizer.zero_grad()  # clear the gradients from the previous step
loss.backward()        # backpropagate to compute the current gradients
optimizer.step()       # update the network parameters from the gradients
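For readers who want to run the pattern, here is a self-contained sketch of the accumulated-loss variant. The model, data and learning rate are hypothetical stand-ins chosen only so the script runs; they are not from the original project.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and data, used only for illustration.
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

weight_before = model.weight.detach().clone()

# Accumulate the loss over several chunks, then backpropagate once.
loss = torch.zeros(())
for chunk_x, chunk_y in zip(inputs.split(2), targets.split(2)):
    loss = loss + criterion(model(chunk_x), chunk_y)

optimizer.zero_grad()   # clear gradients left over from earlier steps
loss.backward()         # one backward pass through the whole graph
optimizer.step()        # update parameters from the fresh gradients

weights_changed = not torch.equal(weight_before, model.weight.detach())
print("weights changed:", weights_changed)
```

Because all the chunk losses are summed into one tensor, a single backward() call covers the whole graph and no retain_graph is needed.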
Sometimes, however, this error occurs: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed
The error comes from how PyTorch works: every time .backward() is called, the buffers holding the graph's saved intermediate values are freed. If the model ends up running backward() more than once through the same graph, the buffers that a later call needs have already been freed by the earlier call. That is what the retain_graph=True parameter is for.
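The behaviour just described can be checked on a tiny graph. This is a minimal sketch with a made-up tensor x; x ** 2 is used because its backward pass needs a saved copy of x, so there really is a buffer to free.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)

# First graph: the default backward() frees the saved buffers.
y = (x ** 2).sum()
y.backward()
try:
    y.backward()                  # second pass through the same graph
    second_call_failed = False
except RuntimeError as err:
    second_call_failed = True
    print("RuntimeError:", str(err)[:60])

# Second graph: retain_graph=True keeps the buffers for one more pass.
x.grad = None
z = (x ** 2).sum()
z.backward(retain_graph=True)
z.backward()                      # works; gradients accumulate in x.grad
print(x.grad)                     # tensor([8.]), i.e. 2*x added twice
```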
With this parameter, the buffers used by the previous backward() are kept until the update is finished. Note that if you write this:
optimizer.zero_grad()               # clear the gradients from the previous step
loss1.backward(retain_graph=True)   # backpropagate, keeping the graph's buffers
loss2.backward(retain_graph=True)   # backpropagate, keeping the graph's buffers again
optimizer.step()                    # update the network parameters from the gradients
then you may run out of memory, and each iteration will be slower than the one before it, getting slower and slower over time (because the buffers are kept and never freed).
The solution is, of course:
optimizer.zero_grad()               # clear the gradients from the previous step
loss1.backward(retain_graph=True)   # backpropagate, keeping the graph's buffers
loss2.backward()                    # backpropagate and let the buffers be freed
optimizer.step()                    # update the network parameters from the gradients
That is: do not pass retain_graph to the last backward(). The memory is then released after each update, and training does not get slower and slower.
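A runnable sketch of the two-loss case follows. The shared encoder and the two heads are hypothetical, picked only so that loss1 and loss2 share part of the graph. It also compares the result with adding the two losses and calling backward() once, which is an alternative I would consider when both losses are available at the same time, since it avoids retain_graph altogether.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shared encoder with two heads, for illustration only.
encoder = nn.Linear(4, 3)
head1 = nn.Linear(3, 1)
head2 = nn.Linear(3, 1)
x = torch.randn(5, 4)

def two_losses():
    h = torch.tanh(encoder(x))        # shared part of the graph
    return head1(h).pow(2).mean(), head2(h).abs().mean()

# Variant 1: two backward calls, retaining the graph only for the first.
encoder.zero_grad()
loss1, loss2 = two_losses()
loss1.backward(retain_graph=True)     # keep the shared buffers for loss2
loss2.backward()                      # last call: let the buffers be freed
grad_two_calls = encoder.weight.grad.clone()

# Variant 2: add the losses and backpropagate once; no retain_graph needed.
encoder.zero_grad()
loss1, loss2 = two_losses()
(loss1 + loss2).backward()
grad_single_call = encoder.weight.grad.clone()

same = torch.allclose(grad_two_calls, grad_single_call, atol=1e-6)
print("same gradients:", same)
```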
Someone will ask: I don't have that many losses, so how can this error happen? The cause may be the model itself, for example an LSTM or GRU. The problem lies with the hidden state, which also takes part in backpropagation and so leads to multiple backward() passes.
At first I did not understand why there would be multiple backward() calls. My LSTM network is N to N: it takes n inputs, produces n outputs, computes the loss against n labels, and backpropagates. Think about BPTT (backpropagation through time) here. In the N-to-1 case, the gradient update needs all the inputs and hidden states of the time series, and the gradient is passed back from the last time step, so there is only one backward(). In the N-to-N and N-to-M cases, several losses have to be backpropagated. Each one propagates in two directions (from output to input, and backward along time), so the paths overlap. The solution is then clear: use detach() to cut off the overlapping backpropagation. (This is only my personal understanding. If there is any error, please point it out in the comments and we can discuss it.) There are three ways to cut it off, as follows:
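hidden.detach_()                                    # 1. detach in place
hidden = hidden.detach()                            # 2. new tensor, cut off from the old graph
hidden = Variable(hidden.data, requires_grad=True)  # 3. legacy style; Variable is deprecated in current PyTorch
Reference: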
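To show where the detach goes in a training loop, here is a minimal sketch that carries a hidden state across chunks of one long sequence. Everything in it (the GRU sizes, the readout layer, the random data, the chunk length of 4) is a hypothetical toy setup, not the Flashback code. As far as I can tell, removing the detach line reproduces Problem 1 on the second chunk, and adding retain_graph=True on top of that is likely to lead to Problem 2, because optimizer.step() updates the weights in place while the old graph still refers to them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: one long sequence processed in chunks.
rnn = nn.GRU(input_size=3, hidden_size=5, batch_first=True)
readout = nn.Linear(5, 1)
params = list(rnn.parameters()) + list(readout.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
criterion = nn.MSELoss()

sequence = torch.randn(2, 12, 3)   # (batch, time, features)
targets = torch.randn(2, 12, 1)

hidden = torch.zeros(1, 2, 5)      # (num_layers, batch, hidden_size)
losses = []
for chunk_x, chunk_y in zip(sequence.split(4, dim=1), targets.split(4, dim=1)):
    # Cut the link to the previous chunk's graph; the values are kept.
    hidden = hidden.detach()
    output, hidden = rnn(chunk_x, hidden)
    loss = criterion(readout(output), chunk_y)

    optimizer.zero_grad()
    loss.backward()                # no retain_graph needed
    optimizer.step()
    losses.append(loss.item())

print(len(losses), "chunks trained")
```

With nn.LSTM the state is a tuple (h, c), so both tensors need detaching, for example hidden = tuple(s.detach() for s in hidden).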
https://blog.csdn.net/a845717607/article/details/104598278/