XCiT: Cross-Covariance Image Transformers

This is the champion code for the paper "XCiT: Cross-Covariance Image Transformers" in the fifth PaddlePaddle paper reproduction challenge. The GitHub repository is https://github.com/BrilliantYuKaimin/XCiT-PaddlePaddle . The official PyTorch implementation is at https://github.com/facebookresearch/xcit.

Because the XCiT code has been integrated into PaddlePaddle's PASSL library, the code of this project runs on top of PASSL. The network code of XCiT can be found in PASSL/passl/modeling/backbones/xcit.py.

This project provides two PaddlePaddle weight files:

  • xcit_nano_12_p8_224_dist.pdparams is the weight retrained with PaddlePaddle. Its top-1 accuracy on the ImageNet-1k validation set is 77.28%, exceeding the official result of 76.3%.
  • regnety_160.pdparams is the teacher model weight used for knowledge distillation, which is converted from the corresponding PyTorch weight.

Brief introduction

Following their success in natural language processing, Transformers have also shown great promise in computer vision. The self-attention operation in a Transformer produces global interactions between all items of a sequence (for example, words in a sentence or patches in an image), which allows flexible modeling of image data beyond the local interactions of convolutions. However, this flexibility comes at the cost of quadratic complexity in time and memory, which hinders the application of Transformers to long sequences and high-resolution images. The authors propose a transposed version of self-attention that operates across feature channels (rather than sequence items), with interactions based on the cross-covariance matrix between keys and queries. This cross-covariance attention (XCA) has linear complexity in the sequence length, allowing efficient processing of high-resolution images. Built on XCA, the Cross-Covariance Image Transformer (XCiT) combines the accuracy of conventional Transformers with the scalability of convolutional architectures.

Cross-covariance attention

For an input $\bm X$ of shape $N\times d$, where $N$ is the length of the sequence, three different linear transformations give three matrices
$$\bm Q=\bm X\bm W_{\mathrm q},\quad \bm K=\bm X\bm W_{\mathrm k},\quad \bm V=\bm X\bm W_{\mathrm v},$$
each still of shape $N\times d$. The original attention mechanism computes
$$\mathrm{softmax}\left(\bm Q\bm K^{\mathrm T}/d^{1/2}\right)\bm V,$$
whereas the cross-covariance attention mechanism computes
$$\bm V\,\mathrm{softmax}\left(\bm K^{\mathrm T}\bm Q/\tau\right),$$
where $\tau$ is a temperature parameter whose role is similar to that of $d^{1/2}$. A significant difference between the two is that computing $\bm Q\bm K^{\mathrm T}$ takes $N^2d$ multiplications, while computing $\bm K^{\mathrm T}\bm Q$ takes $Nd^2$ multiplications. The complexity of the cross-covariance attention mechanism is therefore linear in the length of the sequence.
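
The two formulas above can be compared side by side in a minimal NumPy sketch. This is an illustration only, not the PASSL implementation: it uses a single head, random weights, and τ = 1, and it applies the softmax over each column of the d × d matrix, which is an assumption made here because the text does not say along which axis the softmax runs. The sizes N = 784 and d = 128 are taken from the configuration in the log below (224 × 224 images, patch size 8, embedding dimension 128).

```python
import numpy as np


def softmax(z, axis):
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(X, Wq, Wk, Wv):
    """Original attention: softmax(Q K^T / sqrt(d)) V, with an N x N map."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    d = Q.shape[1]
    attn = softmax(Q @ K.T / np.sqrt(d), axis=-1)  # shape (N, N)
    return attn @ V                                # shape (N, d)


def cross_covariance_attention(X, Wq, Wk, Wv, tau=1.0):
    """Cross-covariance attention: V softmax(K^T Q / tau), with a d x d map."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    # Assumption: normalize each column, so every output feature is a
    # convex combination of the features in V.
    attn = softmax(K.T @ Q / tau, axis=0)          # shape (d, d)
    return V @ attn                                # shape (N, d)


rng = np.random.default_rng(0)
N, d = 784, 128  # (224 / 8) ** 2 tokens, embedding dimension 128
X = rng.standard_normal((N, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

out_sa = self_attention(X, Wq, Wk, Wv)
out_xca = cross_covariance_attention(X, Wq, Wk, Wv)

# Multiplications needed to form the attention map in each case.
mults_sa = N * N * d   # Q K^T
mults_xca = N * d * d  # K^T Q
print(out_sa.shape, out_xca.shape, mults_sa, mults_xca)
```

Both functions return an N × d output, but the intermediate map is N × N in the first case and d × d in the second, so doubling N quadruples `mults_sa` while only doubling `mults_xca`.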

On the other hand,
$$\bm K^{\mathrm T}\bm Q=\bm W_{\mathrm k}^{\mathrm T}\bm X^{\mathrm T}\bm X\bm W_{\mathrm q},$$
where $\bm X^{\mathrm T}\bm X$ is the (unnormalized) covariance matrix of the row vectors of $\bm X$. This is also the origin of the name cross-covariance attention.
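
The identity above can be checked numerically with a small NumPy sketch. The sizes and random matrices are arbitrary choices for illustration. The second check relates the product of the transposed input with itself to the usual sample covariance: the two agree only after the columns are centered and the result is divided by N − 1, which is why the text calls the matrix unnormalized.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 50, 8  # arbitrary small sizes for the check
X = rng.standard_normal((N, d))
Wq = rng.standard_normal((d, d))
Wk = rng.standard_normal((d, d))

# Left-hand side: K^T Q computed from the projected keys and queries.
lhs = (X @ Wk).T @ (X @ Wq)

# Right-hand side: W_k^T (X^T X) W_q, built around the d x d matrix X^T X.
gram = X.T @ X
rhs = Wk.T @ gram @ Wq

identity_holds = np.allclose(lhs, rhs)

# After centering each column and dividing by N - 1, X^T X becomes the
# sample covariance of the feature channels (rows of X are the samples).
Xc = X - X.mean(axis=0)
matches_cov = np.allclose(Xc.T @ Xc / (N - 1), np.cov(X, rowvar=False))

print(identity_holds, matches_cov)
```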

Quick start

The following command evaluates the trained model.

!python PASSL/tools/train.py -c PASSL/configs/xcit/xcit_nano_12_p8_224.yaml \
                             --load xcit_nano_12_p8_224_dist.pdparams \
[03/01 19:47:20] passl INFO: Configs: {'epochs': 400, 'output_dir': 'outputs', 'seed': 0, 'device': 'gpu', 'model': {'name': 'SwinWrapper', 'architecture': {'name': 'XCiT', 'patch_size': 8, 'embed_dim': 128, 'depth': 12, 'num_heads': 4, 'eta': 1.0, 'tokens_norm': False}, 'head': {'name': 'SwinTransformerClsHead', 'in_channels': 128, 'num_classes': 1000}}, 'dataloader': {'train': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': True, 'drop_last': True}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/train/', 'return_label': True, 'transforms': [{'name': 'RandomResizedCrop', 'size': 224, 'scale': [0.08, 1.0], 'interpolation': 'bicubic'}, {'name': 'RandomHorizontalFlip'}, {'name': 'AutoAugment', 'config_str': 'rand-m9-mstd0.5-inc1', 'interpolation': 'bicubic', 'img_size': 224}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}, {'name': 'RandomErasing', 'prob': 0.25, 'mode': 'pixel', 'max_count': 1}], 'batch_transforms': [{'name': 'Mixup', 'mixup_alpha': 0.8, 'prob': 1.0, 'switch_prob': 0.5, 'mode': 'batch', 'cutmix_alpha': 1.0}]}}, 'val': {'loader': {'num_workers': 8, 'use_shared_memory': True}, 'sampler': {'batch_size': 128, 'shuffle': False, 'drop_last': False}, 'dataset': {'name': 'ImageNet', 'dataroot': 'data/ILSVRC2012/val', 'return_label': True, 'transforms': [{'name': 'Resize', 'size': 224, 'interpolation': 'bicubic'}, {'name': 'CenterCrop', 'size': 224}, {'name': 'Transpose'}, {'name': 'Normalize', 'mean': [123.675, 116.28, 103.53], 'std': [58.395, 57.12, 57.375]}]}}}, 'lr_scheduler': {'name': 'LinearWarmup', 'learning_rate': {'name': 'CosineAnnealingDecay', 'learning_rate': 0.0005, 'T_max': 400, 'eta_min': 1e-05}, 'warmup_steps': 5, 'start_lr': 1e-06, 'end_lr': 0.0005}, 'optimizer': {'name': 'AdamW', 'beta1': 0.9, 'beta2': 0.999, 'weight_decay': 0.05, 'exclude_from_weight_decay': ['temperature', 'pos_embed', 'cls_token', 
'dist_token']}, 'log_config': {'name': 'LogHook', 'interval': 10}, 'checkpoint': {'name': 'CheckpointHook', 'by_epoch': True, 'interval': 1, 'max_keep_ckpts': 50}, 'custom_config': [{'name': 'EvaluateHook'}], 'is_train': False, 'timestamp': '-2022-03-01-19-47'}
[03/01 19:47:20] passl.engine.trainer INFO: train with paddle 2.2.2 on CUDAPlace(0) device
W0301 19:47:20.903514 23036 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0301 19:47:20.908339 23036 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[03/01 19:47:25] passl.engine.trainer INFO: Number of Parameters is 3.05M.
[03/01 19:47:30] passl.engine.trainer INFO: start evaluate on epoch 1 ..
[03/01 19:47:30] passl.engine.trainer INFO: Evaluate total samples 50000
100%|█████████████████████████████████████████| 391/391 [02:41<00:00,  2.43it/s]
[03/01 19:50:11] passl.engine.trainer INFO: Validate Epoch [1] acc1 (77.276), acc5 (93.248)
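
The figures in the log can be cross-checked with a few lines of arithmetic. This sketch assumes that acc1 is the fraction of correctly classified images over all 50000 validation samples (not an average of per-batch accuracies); the variable names are illustrative and do not come from PASSL.

```python
import math

total_samples = 50000   # "Evaluate total samples 50000"
batch_size = 128        # validation sampler batch size in the config
acc1_percent = 77.276   # "acc1 (77.276)"

# Number of validation batches; the last batch is kept (drop_last is False).
num_batches = math.ceil(total_samples / batch_size)

# Number of correctly classified images implied by the top-1 accuracy.
correct_top1 = round(acc1_percent / 100 * total_samples)

print(num_batches, correct_top1)
```

The batch count agrees with the 391/391 shown by the progress bar.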

