On the problem of a hook function registered via register_forward_hook() in PyTorch failing to execute
Hooks are a very useful feature in PyTorch. With them, we can easily obtain and modify the values and gradients of variables in the intermediate layers of a network without changing the network's input/output structure. This is widely used to visualize the features and gradients of intermediate layers, to diagnose possible problems in a neural network, and to analyze its effectiveness.
The hook mechanism does not change the main body of the computation; it attaches extra functionality, like a pendant.
The hook function itself is not the focus of this article; there are plenty of introductions to it online. This article mainly records some problems I ran into when using hook functions, and how I solved them.
register_forward_hook
First, let's look at one of the simplest uses of register_forward_hook:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))    #1
        out = F.max_pool2d(out, 2)     #2
        out = F.relu(self.conv2(out))  #3
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

features = []
def hook(module, input, output):
    # module: net.conv2
    # input: the input of conv2 in forward(), i.e. the result of line #2
    # output: self.conv2(out) on line #3
    print('*'*100)
    features.append(output.clone().detach())  # save the output in a list

net = LeNet()                                   # instantiate the model
x = torch.randn(2, 3, 32, 32)                   # input
handle = net.conv2.register_forward_hook(hook)  # grab the intermediate result of conv2
y = net(x)                                      # run forward; the hook fires on conv2
print(features[0].size())                       # i.e. self.conv2(out) on line #3
handle.remove()                                 # remove the hook to avoid accumulating saved outputs
Output:
****************************************************************************************************
torch.Size([2, 16, 10, 10])
The shape is exactly the result we want; the string of * is printed to confirm visually that the hook function was called.
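As an aside on the "modify the value" part of the intro: if a forward hook returns a non-None value, PyTorch substitutes it for the module's output. A minimal sketch, reusing the net and x from above:

def scale_hook(module, input, output):
    # a non-None return value from a forward hook replaces the module's output
    return output * 2.0

h = net.conv2.register_forward_hook(scale_hook)
y_scaled = net(x)   # conv2's activations are doubled before the next layer
h.remove()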
Here conv2 is the module's name. We can print the model's state_dict() to see which module we want:
for k in net.state_dict():
    print(k)
Output:
conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
Let's take conv2 as an example.
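Incidentally, state_dict() lists parameter tensors rather than the modules themselves. If the goal is to find something hookable, named_modules() walks the actual submodule tree, which is arguably more direct. A small sketch on the LeNet above (the '(root)' label is just my own placeholder for the empty top-level name):

for name, module in net.named_modules():
    # the first entry (empty name) is the model itself; every other name is hookable
    print(name if name else '(root)', '->', type(module).__name__)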
The problem
In practice, I wanted to hook the latest Transformer model alt_gvt_large to look at its positional encodings, but ran into a problem.
I listed the modules in the model and found the one I wanted:
import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transforms

fmap_block = []
def forward_hook(module, data_input, data_output):
    print('*'*100)
    fmap_block.append(data_output.clone().detach())

model = timm.create_model(
    'alt_gvt_large',
    pretrained=False,
    num_classes=1000,
    drop_rate=0.1,
    drop_path_rate=0.1,
    drop_block_rate=None,
)

pipeline = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.ToTensor(),
])

for k in model.state_dict():
    print(k)
Output:
# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...
That must be pos_block.
Register the hook:
image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)
handle = model.pos_block.register_forward_hook(forward_hook)
pred = model(image)
print(fmap_block[0].shape)
handle.remove()
And there is a big problem: no output at all. Even the string of * we print to confirm the hook is running never appears, so the hook function must not have executed. What is going on?
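In hindsight, a small guard would have made this failure mode clearer: when the hook never fires, fmap_block stays empty and fmap_block[0] raises an IndexError on top of the silent failure. A defensive variant of the call above:

handle = model.pos_block.register_forward_hook(forward_hook)
pred = model(image)
if fmap_block:
    print(fmap_block[0].shape)
else:
    print('hook never fired: forward() of the hooked module was never called')
handle.remove()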
Solving process
Carefully comparing the successful and the failed hook attempts above:
conv2.weight
conv2.bias
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
It is not hard to guess that only a module whose weight and bias can be reached directly with dot (.) access can be hooked directly.
However, pasting the name from the output directly produces the following:
handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
Python reports a syntax error right away: you cannot follow a dot with a number.
handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
                         ^
SyntaxError: invalid syntax
So I checked it layer by layer:
for k in model.pos_block:
    print(k)
    for _k in k.proj.state_dict():
        print(_k)
        break
    break
print(type(model.pos_block))
It turns out that the object where the numbers appear is actually of type <class 'torch.nn.modules.container.ModuleList'>, i.e. a list-like container that can be indexed directly with []. This also explains the silent failure: a forward hook only fires when the module's forward() is called, and nn.ModuleList is just a container whose forward() is never invoked, so the hook registered on pos_block never ran.
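The ModuleList behavior is easy to demonstrate in isolation; a minimal sketch, independent of the gvt model:

import torch
import torch.nn as nn

ml = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
print(ml[0])           # indexes like a list
x = torch.randn(1, 4)
y = ml[1](ml[0](x))    # the parent module calls each entry individually in its forward()
# ml(x) would raise an error: ModuleList defines no forward(),
# which is exactly why a forward hook registered on it never fires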
So we can change it to:
handle = model.pos_block[3].proj[0].register_forward_hook(forward_hook)
Output:
****************************************************************************************************
torch.Size([1, 256, 28, 28])
Finally succeeded.
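A related tip: the dotted names printed by state_dict() can also be resolved to modules programmatically, so the dots never have to be translated into brackets by hand. A sketch (the dict lookup works on any version; get_submodule() exists on PyTorch 1.9+):

# resolve a dotted submodule path without hand-editing it
target = dict(model.named_modules())['pos_block.3.proj.0']
handle = target.register_forward_hook(forward_hook)
# equivalent one-liner on PyTorch 1.9+:
# target = model.get_submodule('pos_block.3.proj.0')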
Summary
In the end, I still don't understand PyTorch's Model, Module, and child-module mechanisms well enough; I have only been using the most basic methods, and anything slightly more advanced meets resistance. I will sort this out properly when I have time. PyTorch is widely recognized as an easy-to-use open-source framework, but to implement your ideas freely you still need to spend some time understanding each component and the relationships between them.
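For future reference, a quick sketch of the traversal methods worth sorting out, run on the LeNet from the beginning:

net = LeNet()
print(list(net.children()))          # immediate child modules only
print(dict(net.named_children()))    # the same, keyed by attribute name
for name, m in net.named_modules():  # recursive walk of the whole module tree
    print(name, type(m).__name__)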