Why a hook registered with register_forward_hook() in PyTorch fails to execute

Hooks are a very useful feature of PyTorch. With them, we can easily read and modify the values and gradients of intermediate-layer variables without changing the structure of the network's inputs and outputs. They are widely used to visualize the features and gradients of a neural network's intermediate layers, which helps diagnose possible problems in the network and analyze how effective it is.

How hook functions work: they attach extra functionality to a module without changing the module's body, much like a pendant hung on it.
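To make the "pendant" idea concrete, here is a minimal sketch. The layer sizes and the name zero_output_hook are made up for illustration. It relies on the fact that a forward hook which returns a value replaces the module's output, and that removing the hook leaves the module itself untouched.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # hypothetical layer, sizes chosen arbitrarily

def zero_output_hook(module, inputs, output):
    # A forward hook that returns a value replaces the module's output.
    return torch.zeros_like(output)

x = torch.randn(3, 4)

handle = layer.register_forward_hook(zero_output_hook)
hooked = layer(x)   # output replaced by the hook
handle.remove()     # take the "pendant" off again

plain = layer(x)    # the layer itself was never modified

print(hooked)
print(plain.shape)
```

A hook that returns None (like the ones in the rest of this article) only observes the output and leaves it unchanged.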

Hook functions themselves are not the focus of this article; many introductions are already available online. This article mainly records a problem I ran into while using hook functions, and how I solved it.

register_forward_hook

First, let's look at one of the simplest examples of using register_forward_hook:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        out = F.relu(self.conv1(x))     #1 
        out = F.max_pool2d(out, 2)      #2
        out = F.relu(self.conv2(out))   #3
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
 
features = []
def hook(module, input, output):
    # module: the hooked layer, here net.conv2
    # input:  tuple of inputs to conv2, i.e. the result of step #2 in forward
    # output: the result of self.conv2(out) in step #3, before F.relu is applied
    print('*'*100)
    # Save a detached copy of the output in the list
    features.append(output.clone().detach())


net = LeNet()                  # instantiate the model
x = torch.randn(2, 3, 32, 32)  # dummy input
handle = net.conv2.register_forward_hook(hook)  # capture the output of conv2
y = net(x)                     # the forward pass triggers the hook

print(features[0].size())      # output of self.conv2(out) in step #3
handle.remove()                # remove the hook so later forward passes do not keep appending to features

Output:

****************************************************************************************************
torch.Size([2, 16, 10, 10])

The shape is what we expect. The row of * characters is printed only to confirm visually that the hook function was called.
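Because a hook stays registered until handle.remove() is called, it can help to wrap the register, run, remove sequence in one helper. The following is only a sketch: capture_output and the toy model are hypothetical names, not part of PyTorch. It removes the hook in a finally clause and raises an explicit error if the hook never fired.

```python
import torch
import torch.nn as nn

def capture_output(model, layer, x):
    """Run model(x) once and return a detached copy of layer's output."""
    captured = []

    def hook(module, inputs, output):
        captured.append(output.clone().detach())

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(x)
    finally:
        handle.remove()  # always unregister, even if the forward pass fails

    if not captured:
        # The layer was never called during the forward pass.
        raise RuntimeError("hook never fired; is this module actually called?")
    return captured[0]

# Hypothetical stand-in model for illustration.
toy = nn.Sequential(nn.Conv2d(3, 6, 5), nn.ReLU(), nn.Conv2d(6, 16, 5))
feat = capture_output(toy, toy[2], torch.randn(2, 3, 32, 32))
print(feat.shape)
```

The explicit RuntimeError turns a hook that silently never runs into a visible failure.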

Here conv2 is the name of the module. To find the name of the module you want, print the keys of the model's state_dict():

for k in net.state_dict():
    print(k)

Output:

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias

The example above uses conv2.
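Note that state_dict() lists parameters, not modules. As an alternative sketch, named_modules() lists every module with its dotted name, and printing the type next to the name shows which entries are containers. The Toy class below is a hypothetical stand-in, not the LeNet above.

```python
import torch.nn as nn

class Toy(nn.Module):
    # Hypothetical model: one plain layer and one ModuleList container.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.blocks = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 4)])

toy = Toy()

# Map each dotted module name to its class name.
# The empty name '' is the model itself.
names = {name: type(m).__name__ for name, m in toy.named_modules()}
for name, cls in names.items():
    print(repr(name), cls)
```

In this listing, 'blocks' shows up as a ModuleList, while 'blocks.0' and 'blocks.1' are the Linear layers inside it.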

Problems arising

In practice, I wanted to inspect the positional encoding of a recent transformer model, alt_gvt_large, but ran into a problem.

I listed the modules in the model and found the one I wanted:

import torch
import timm
import numpy as np
import cv2
import seaborn as sns
import gvt
from PIL import Image
from torchvision import transforms

fmap_block = []
def forward_hook(module, data_input, data_output):
    print('*'*100)
    fmap_block.append(data_output.clone().detach())

model = timm.create_model(
        'alt_gvt_large',
        pretrained=False,
        num_classes=1000,
        drop_rate=0.1,
        drop_path_rate=0.1,
        drop_block_rate=None,
    )
pipeline = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    ])

for k in model.state_dict():
    print(k)

Output:

# ...
patch_embeds.3.norm.weight
patch_embeds.3.norm.bias
norm.weight
norm.bias
head.weight
head.bias
pos_block.0.proj.0.weight
pos_block.0.proj.0.bias
pos_block.1.proj.0.weight
pos_block.1.proj.0.bias
pos_block.2.proj.0.weight
pos_block.2.proj.0.bias
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias
blocks.0.0.norm1.weight
blocks.0.0.norm1.bias
# ...

The module we want must be pos_block.

Register the hook:

image = Image.open('125.jpg')
image = pipeline(image).unsqueeze(dim=0)

handle = model.pos_block.register_forward_hook(forward_hook)
  
pred = model(image)
print(fmap_block[0].shape)
handle.remove()

Something is badly wrong: the hook produces no output at all. Even the row of * that we print to verify that the hook runs does not appear, so the hook function was evidently never executed. What is going on?

Solving process

Carefully compare the parameter names from the successful attempt and the failed one:

conv2.bias
conv2.weight
--------
pos_block.3.proj.0.weight
pos_block.3.proj.0.bias

From this simple comparison, a reasonable first guess is that the hook should be registered on a module whose name is directly followed by .weight and .bias (here pos_block.3.proj.0), rather than on the pos_block container above it.

However, if we paste the name from the output directly into the code, we get:

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)

This immediately raises a syntax error, because a number cannot follow a dot as an attribute name:

handle = model.pos_block.3.proj.0.register_forward_hook(forward_hook)
                            ^
SyntaxError: invalid syntax

So I checked the model layer by layer:

for k in model.pos_block:
    print(k)
    for _k in k.proj.state_dict():
        print(_k)
        break
    break 
print(type(model.pos_block))

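It turns out that the type of model.pos_block, the level followed by a number in the names above, is actually <class 'torch.nn.modules.container.ModuleList'>, that is, a list, so it can be indexed directly with []. This also explains the silent failure: a forward hook runs only when the module it is registered on is itself called, and a ModuleList is just a container that is never called itself. Only its elements are called.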
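The behaviour can be reproduced with a small toy model. This is a sketch with hypothetical names (Toy, make_hook), not the GVT model itself. One hook is registered on the ModuleList container and one on an element of it, and only the second one fires.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])

    def forward(self, x):
        for block in self.blocks:  # each element is called, the list itself is not
            x = block(x)
        return x

calls = []

def make_hook(tag):
    def hook(module, inputs, output):
        calls.append(tag)  # record which hook fired
    return hook

toy = Toy()
h_container = toy.blocks.register_forward_hook(make_hook("container"))
h_element = toy.blocks[1].register_forward_hook(make_hook("blocks[1]"))

toy(torch.randn(3, 4))

h_container.remove()
h_element.remove()

print(calls)  # ['blocks[1]']
```

Registering a hook on the container raises no error. It simply never runs, which matches the missing row of * seen earlier.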

So the hook registration can be changed to:

handle = model.pos_block[3].proj[0].register_forward_hook(forward_hook)

Output:

****************************************************************************************************
torch.Size([1, 256, 28, 28])

It finally works.
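As an alternative to translating the dotted name into [] indexing by hand, the name can be used as a string. The sketch below assumes PyTorch 1.9 or later, where nn.Module.get_submodule() is available. On older versions, a lookup in dict(model.named_modules()) gives the same module. The Toy model and its layer sizes are hypothetical.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    # Hypothetical model that mimics the nesting ModuleList -> Sequential -> Conv2d.
    def __init__(self):
        super().__init__()
        self.pos_block = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(4, 4, 3, padding=1)) for _ in range(2)]
        )

    def forward(self, x):
        for blk in self.pos_block:
            x = blk(x)
        return x

toy = Toy()

# Look the module up by its dotted name (assumes PyTorch >= 1.9).
target = toy.get_submodule("pos_block.1.0")
# Equivalent lookup that also works on older versions.
same = dict(toy.named_modules())["pos_block.1.0"]

shapes = []
handle = target.register_forward_hook(
    lambda module, inputs, output: shapes.append(tuple(output.shape))
)
toy(torch.randn(1, 4, 8, 8))
handle.remove()

print(shapes)  # [(1, 4, 8, 8)]
```

Both lookups return the same object as toy.pos_block[1][0], so the hook behaves exactly as with manual indexing.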

Summary

In short, I did not understand models, modules, and child modules in PyTorch well enough. Knowing only the most basic usage, I get stuck as soon as I try anything slightly more advanced. I will sort this out properly when I have time. PyTorch is widely regarded as an easy-to-use open-source framework, but to implement your own ideas freely you still need to spend time understanding each component and the relationships between them.
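As a starting point for that sorting out, here is a small sketch of the difference between children() and modules(), using a hypothetical nested nn.Sequential. children() yields only the direct child modules, while modules() walks the whole tree and includes the model itself.

```python
import torch.nn as nn

# Hypothetical nested model for illustration.
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
)

direct = [type(m).__name__ for m in toy.children()]     # direct children only
everything = [type(m).__name__ for m in toy.modules()]  # the model and all descendants

print(direct)      # ['Conv2d', 'Sequential']
print(everything)  # ['Sequential', 'Conv2d', 'Sequential', 'Linear', 'ReLU']
```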

Keywords: Python Machine Learning AI Pytorch Deep Learning

Added by Pobega on Mon, 03 Jan 2022 21:36:54 +0200