[Paper Review & Implementation] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ViT, 2021)

2023, Jun 18    

Outlines


Implementation with PyTorch



Reference


Transformer for Computer Vision


  • This paper develops an image classification model built purely from a transformer applied directly to a sequence of image patches, completely removing the convolutional networks that have typically been used in vision tasks.

  • There had been several earlier attempts to bring transformer networks into computer vision. However, those works relied on specialized attention patterns that had not yet been scaled effectively on modern hardware accelerators, which limited their practical use.

    • Parmar et al. (2018) : Applied self-attention only in local neighborhoods for each query pixel (not patchified) instead of globally.

    • Child et al. (2019) : Sparse Transformer, which restricts attention to a sparse subset of positions as a scalable approximation to global self-attention.


       

  • Vision Transformer (ViT), proposed in this paper, on the other hand, successfully uses the transformer structure without any complex modifications.

  • The model compensates for the relative simplicity of its architecture with extensive pre-training on larger datasets (larger than ImageNet), and in the end outperforms CNN-based ResNet-like networks while requiring substantially less compute to pre-train.


Transformer vs CNN : Lack of Inductive Bias


Convolutional Layer


  • A CNN is designed specifically for image data, and thus has a strong built-in inductive bias about the structure of images, such as locality and translational invariance.

    • Locality : Fixed-size 2-dimensional filters capture neighborhood structure, under the assumption that the elements composing the patterns important for recognizing an image are located close together in a small area rather than spread broadly over the image.

      • A filter that captures a specific feature shares identical parameters regardless of that feature’s absolute position in the image.


       


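    • Translational Invariance : Because the filters are locally focused and shared across positions, the extracted features are hardly affected by a global translation of the image along its axes. (Strictly speaking, the convolution itself is translation equivariant: shifting the input shifts the feature map by the same amount.)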

      


  • Transformer, however, lacks some of these inductive biases associated with the characteristics of image data.
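The two convolutional properties described above can be checked numerically. The following is a minimal sketch, not code from the paper: it assumes circular padding (so that a shifted image wraps around and the comparison is exact at the borders) and uses arbitrary layer sizes chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One convolutional layer: the same 3x3 filters are applied at every position.
# Circular padding is an assumption made here so that shifts wrap around and
# the equality below also holds at the image borders.
conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3,
                 padding=1, padding_mode="circular", bias=False)

image = torch.randn(1, 1, 8, 8)
shifted_image = torch.roll(image, shifts=(2, 3), dims=(2, 3))

with torch.no_grad():
    features = conv(image)
    features_of_shifted = conv(shifted_image)

# Shifting the input shifts the feature map by the same amount.
is_equivariant = torch.allclose(
    torch.roll(features, shifts=(2, 3), dims=(2, 3)),
    features_of_shifted,
    atol=1e-5,
)

# Weight sharing: the parameter count depends only on the filter size
# (out_channels * in_channels * 3 * 3), not on the image size.
num_conv_params = sum(p.numel() for p in conv.parameters())

print(is_equivariant, num_conv_params)
```

A self-attention layer has no such built-in constraint: every patch can attend to every other patch from the first layer on, so the spatial structure has to be learned from data.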


Adding Image-Specific Inductive Bias : Positional Embedding (PE)


  • In addition to pre-training on large datasets, ViT does have its own way to overcome its insufficient inductive bias.


Learnable Position Embedding


  • The PE matrix holds learnable weights of shape sequence_length x embedding_dimension.

  • Unlike the fixed positional encoding, none of these parameters are fixed; they are optimized during training just like the other learnable parameters in the network.

  • Through this process, the model can learn a sense of the relative position (order) of each patch across the image.

  • The effect of position embeddings in spatial representation will be further addressed later in this post.
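To make the contrast between a fixed positional encoding and a learnable position embedding concrete, here is a small sketch. The sinusoidal formula is the one from the original Transformer paper; the function name `sinusoidal_encoding`, the sizes, and the 0.02 initialization scale are assumptions chosen for illustration.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_encoding(seq_len, embed_dim):
    """Fixed positional encoding (embed_dim is assumed to be even)."""
    position = torch.arange(seq_len).unsqueeze(1).float()
    div_term = torch.exp(
        torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
    )
    pe = torch.zeros(seq_len, embed_dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


seq_len, embed_dim = 196, 64

# Fixed: a plain tensor computed once, never updated by the optimizer.
fixed_pe = sinusoidal_encoding(seq_len, embed_dim)

# Learnable: an nn.Parameter that receives gradients like any other weight.
learned_pe = nn.Parameter(torch.randn(seq_len, embed_dim) * 0.02)

print(fixed_pe.requires_grad, learned_pe.requires_grad)
```

Both tensors have the same shape (sequence length x embedding dimension) and are added to the patch embeddings in the same way; the only difference is whether the values are trained.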


Architecture of ViT


  Figure 1: Model overview

  


  


Embeddings


  


  • 1. Image Input : n_batch (B) x channel (C) x Height (H) x Width (W)

  • 2. Patch + Position Embedding :

    • Add the position embedding to the patch embedding (element-wise sum) to get the final input for the encoder of ViT.

    • Patch Embedding :

      • Patchify : B x C x H x W → B x N x (C*P*P) where P is the patch size and N = (H*W) / (P*P) is the number of patches

      • Embedding : Input_dim (C*P*P) → embed_dim (D)

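        import torch
        import torch.nn as nn
        from einops.layers.torch import Rearrange

        class PatchEmbedding(nn.Module):

            def __init__(self, p, input_dim, embed_dim):
                """
                Embed image input (n_batch, channel, H, W) into (n_batch, N, embed_dim) where N = H*W / (P*P)
                Args :
                    - p : patch size
                    - input_dim : flattened patch dimension (C*P*P)
                    - embed_dim : embedding dimension (D)
                """
                super(PatchEmbedding, self).__init__()
                self.p = p
                # Split the image into non-overlapping P x P patches, flatten each patch, then project it linearly
                self.patch_embedding = nn.Sequential(Rearrange('b c (h1 p1) (w1 p2) -> b (h1 w1) (c p1 p2)', p1 = self.p, p2 = self.p),
                                                    nn.Linear(input_dim, embed_dim))

            def forward(self, x):
                x = self.patch_embedding(x)
                return x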
      


    • Positional Embedding :

      • Shape : N x D
        class PositionalEmbedding(nn.Module):

            def __init__(self, embed_dim, max_len=5000):
                """
                Construct the PositionalEmbedding layer.
                Args:
                - embed_dim: the size of the embed dimension
                - max_len: the maximum possible length of the incoming sequence
                """
                super(PositionalEmbedding, self).__init__()

                # Learnable position embedding, optimized together with the rest of the network
                self.pos_emb = nn.Parameter(torch.randn(max_len, embed_dim))

            def forward(self, x):
                n_batch, N, embed_dim = x.shape
                # Add the embedding of the first N positions (broadcast over the batch)
                pe_output = x + self.pos_emb[:N, :]
                return pe_output
      


  • 3. Classification Token Embedding :

    • Add a special token to the sequence dimension (N → N + 1) that holds the classification information

    • In the final step, the output at this class token position is linearly transformed into a score for each class (D → num_classes)

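        class ClassTokenEmbedding(nn.Module):

            def __init__(self, n_batch, embed_dim):
                """
                Add classification token to the sequence of embedded patches. (n_batch, N, embed_dim) -> (n_batch, N+1, embed_dim)
                Args :
                    - n_batch : batch size (unused; the token is shared and expanded to the incoming batch size)
                    - embed_dim : patch embedded dimension
                """
                super(ClassTokenEmbedding, self).__init__()
                # A single learnable token shared by every image in the batch
                self.classtoken = nn.Parameter(torch.randn(1, 1, embed_dim))

            def forward(self, x):
                # Prepend the token so that it sits at index 0, where the classification head reads it
                classtoken = self.classtoken.expand(x.shape[0], -1, -1)
                return torch.cat([classtoken, x], dim=1)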
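The three embedding steps above can also be traced end to end in one place. This is a sketch under stated assumptions rather than the implementation used in this post: `ViTEmbedding` is a hypothetical helper, patchifying is done with plain `reshape`/`permute` instead of einops, and it follows the ordering of Eq. 1 in the paper, where the class token is prepended first and the position embedding then covers all N + 1 tokens.

```python
import torch
import torch.nn as nn


class ViTEmbedding(nn.Module):
    """Hypothetical helper: patchify -> project -> prepend class token -> add position embedding."""

    def __init__(self, image_size, patch_size, channels, embed_dim):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size"
        self.p = patch_size
        self.num_patches = (image_size // patch_size) ** 2  # N = H*W / P^2
        self.proj = nn.Linear(channels * patch_size * patch_size, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position embedding per token, including the class token (N + 1 rows)
        self.pos_emb = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)

    def patchify(self, x):
        # (B, C, H, W) -> (B, N, C*P*P)
        b, c, h, w = x.shape
        p = self.p
        x = x.reshape(b, c, h // p, p, w // p, p)
        x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H/P, W/P, C, P, P)
        return x.reshape(b, (h // p) * (w // p), c * p * p)

    def forward(self, x):
        tokens = self.proj(self.patchify(x))                  # (B, N, D)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)  # (B, 1, D)
        tokens = torch.cat([cls, tokens], dim=1)              # (B, N + 1, D)
        return tokens + self.pos_emb                          # broadcast over the batch


embed = ViTEmbedding(image_size=224, patch_size=16, channels=3, embed_dim=768)
images = torch.randn(2, 3, 224, 224)
patches = embed.patchify(images)
tokens = embed(images)
print(patches.shape)  # torch.Size([2, 196, 768])
print(tokens.shape)   # torch.Size([2, 197, 768])
```

With a 224 x 224 RGB image and P = 16, there are (224 / 16)^2 = 196 patches, each flattened to 3 * 16 * 16 = 768 values, and the class token brings the sequence length to 197.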
      


Transformer Encoder Block


  • Almost identical to the transformer encoder structure in the previous post, except that the layernorm layer is applied before the MHA and MLP sub-layers

  • Consists of 2 sub-layers, MHA and an MLP with GELU non-linearity, with a residual connection around each sub-layer.

  • Repeated for L times
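The block described above can be sketched as follows. This is a minimal illustration that leans on PyTorch's built-in `nn.MultiheadAttention` instead of the MHA module from the previous post; the class name `EncoderBlock` and the small sizes in the usage lines are assumptions made for the example.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Pre-LN encoder block: LayerNorm -> MHA -> residual, LayerNorm -> MLP -> residual."""

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Normalization comes before each sub-layer; the residual skips around it.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


# Stack the block L times (L = 2 here, only to keep the example small).
encoder = nn.Sequential(*[EncoderBlock(embed_dim=64, num_heads=4, mlp_dim=128)
                          for _ in range(2)])
tokens = torch.randn(2, 17, 64)  # (n_batch, N + 1, embed_dim)
encoded = encoder(tokens)
print(encoded.shape)  # torch.Size([2, 17, 64])
```

Because every sub-layer preserves the (n_batch, N + 1, embed_dim) shape, the block can be repeated any number of times without extra reshaping.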


MLP Head for Classification Token Sequence


  • A linear transformation applied to the class token, preceded by a normalization step, computes the final classification scores.

      class ClassificationHead(nn.Module):
          """
          Final MLP to get classification head : either mean or first element
          """
          def __init__(self, embed_dim, num_classes, pool):
              super(ClassificationHead, self).__init__()

              self.pool = pool
              self.layernorm = nn.LayerNorm(embed_dim)
              self.mlp = nn.Sequential(self.layernorm,
                                      nn.Linear(embed_dim, num_classes))

          def forward(self,x):
              """
              Args
                  - x : output of encoder (n_batch, N+1, embed_dim)
              """
              # Pool over all tokens ('mean') or take the class token at index 0
              classhead = x.mean(dim=1) if self.pool == 'mean' else x[:, 0, :]

              classhead = self.mlp(classhead)
              return classhead
    


Image Representation Encoded in Vision Transformer


  1. Figure 7 (left) : RGB values of first-layer embedding filters of ViT-L/32

  

  • The top 28 principal components of the filters of the initial linear embedding layer. (Extracted through PCA)

  • Each filter seems to represent certain spatial patterns such as lines, edges, circles, and rectangles, which are similar to the low-level features captured in the early stages of a CNN.


  2. Figure 7 (center) : Cosine similarity between the position embeddings

  


  • A learned position embedding is added to the linear projections above.

  • Each box represents the cosine similarity between the position embedding of the patch at the indicated row and column and the position embeddings of all other patches.

  • Closer patches tend to share similar positional embeddings, showing that position embedding can encode the relative distance between patches.
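The similarity maps described above can be computed with a few tensor operations. This is a sketch only: `pos_emb` below is a random stand-in for a trained position embedding (class token row excluded), and the 7 x 7 grid and embedding size are assumptions, so the values will not show the distance structure seen in the figure.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

grid = 7        # patches per side (assumed for illustration)
embed_dim = 32  # embedding dimension (assumed for illustration)

# Stand-in for the learned position embedding of the N = grid * grid patches.
pos_emb = torch.randn(grid * grid, embed_dim)

# Cosine similarity between every pair of position embeddings: (N, N)
normed = F.normalize(pos_emb, dim=-1)
similarity = normed @ normed.T

# sim_maps[i, j] is the (grid, grid) map for the patch at row i, column j,
# i.e. one box in the figure.
sim_maps = similarity.reshape(grid, grid, grid, grid)

print(sim_maps.shape)                # torch.Size([7, 7, 7, 7])
print(float(sim_maps[2, 3, 2, 3]))   # similarity of a patch with itself, close to 1.0
```

With a trained model, the same computation applied to its position embedding would show high values near position (i, j) in each box and lower values farther away.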


  3. Figure 7 (right) : Size of attended area by head and network depth

  

  • Attended area is analogous to the receptive field size in a CNN, indicating the pixel distance across the image over which a single attention head integrates information.

  • Each dot shows the average distance spanned by the attention weights of one of the 16 heads at one layer.

  • Earlier layers show high variation in the average distance: some heads already attend to the entire image globally, while others attend to small localized areas close to the query patch.

  • As depth increases, the attention distance increases for all heads until it covers most of the image.

  • This suggests that attention captures increasingly high-level features that reflect the overall representation of the image as depth increases, which is quite similar to a CNN.
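One way to compute such an average attention distance for a single layer is sketched below. This is an illustration under assumptions, not the paper's exact procedure: `mean_attention_distance` is a hypothetical helper, the attention matrix is assumed to exclude the class token, and distances are measured between patch positions scaled by the patch size.

```python
import torch


def mean_attention_distance(attn, grid, patch_size):
    """
    attn : (heads, N, N) attention weights with rows summing to 1, N = grid * grid
    Returns the attention-weighted mean pixel distance per head, shape (heads,).
    """
    ys, xs = torch.meshgrid(torch.arange(grid), torch.arange(grid), indexing="ij")
    # Pixel coordinates of each patch on the grid, in row-major order: (N, 2)
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float() * patch_size
    dist = torch.cdist(coords, coords)  # (N, N) pairwise distances in pixels
    # Weight distances by attention, sum over keys, then average over queries.
    return (attn * dist).sum(dim=-1).mean(dim=-1)


heads, grid, patch_size = 16, 4, 16
n = grid * grid

# A head that only attends to its own patch has distance 0 (fully local).
local_attn = torch.eye(n).expand(heads, n, n)
# A head that attends uniformly to every patch integrates globally.
global_attn = torch.full((heads, n, n), 1.0 / n)

local_dist = mean_attention_distance(local_attn, grid, patch_size)
global_dist = mean_attention_distance(global_attn, grid, patch_size)
print(local_dist[0].item(), global_dist[0].item())
```

Plotting this value for every head at every layer gives a scatter of the kind shown in the figure, with local heads near the bottom and global heads near the top.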



  

  • ViT configurations are based on those used for BERT (Devlin et al., 2019). The base and large models are directly adopted from BERT, and the larger huge model is added in this paper.
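For reference, the three variants can be written down as a small config table. The numbers follow Table 1 of the paper; the `ViTConfig` dataclass and the `sequence_length` helper are hypothetical names introduced here for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ViTConfig:
    layers: int      # number of encoder blocks (L)
    hidden_dim: int  # embedding dimension (D)
    mlp_dim: int     # hidden size of the MLP sub-layer
    heads: int       # number of attention heads


VIT_CONFIGS = {
    "ViT-Base": ViTConfig(layers=12, hidden_dim=768, mlp_dim=3072, heads=12),
    "ViT-Large": ViTConfig(layers=24, hidden_dim=1024, mlp_dim=4096, heads=16),
    "ViT-Huge": ViTConfig(layers=32, hidden_dim=1280, mlp_dim=5120, heads=16),
}


def sequence_length(image_size, patch_size):
    """Number of input tokens: one per patch plus the class token."""
    return (image_size // patch_size) ** 2 + 1


for name, cfg in VIT_CONFIGS.items():
    print(name, "dimension per head:", cfg.hidden_dim // cfg.heads)

# The suffix in names such as ViT-L/16 or ViT-L/32 is the patch size.
print(sequence_length(224, 16), sequence_length(224, 32))
```

Since the sequence length scales inversely with the square of the patch size, models with a smaller patch size are more expensive to run at the same image resolution.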

  Table 2.


  • Large and huge ViT models pre-trained on the JFT-300M dataset outperform ResNet-based baselines on all fine-tuning datasets, while taking substantially fewer computational resources to pre-train. ViT pre-trained on the smaller public ImageNet-21k dataset performs well too, though slightly worse.

  • The results show that vision transformers pre-trained with larger dataset can achieve better performance in spite of weak inductive biases compared to CNN.


Effect of Larger Pre-Training Datasets on the Performance of ViT


  Figure 3, 4.

  • When pre-trained on the smallest dataset, ImageNet, ViT-Large models underperform compared to ViT-Base models and BiT ResNets (gray shaded area). With ImageNet-21k pre-training, their performances are similar.

  • Only with JFT-300M, the largest dataset, do the large ViTs start to overtake smaller ViT and BiT.




  • To summarize, the vision transformer effectively uses the transformer structure in computer vision with minimal modifications, surpassing the performance of previous state-of-the-art ResNet-based CNN models while significantly reducing the computational cost of pre-training.

  • The performance of ViT after fine-tuning depends largely on the size of the pre-training dataset.