[Paper Review & Implementation] Aggregated Residual Transformations for Deep Neural Networks (ResNeXt, 2017) & Depthwise Separable Convolution

2023, Jun 04    


Split-Transform-Merge Strategy


  • Grouped convolution splits the input into a set of low-dimensional embeddings, transforms each one separately, and aggregates the outputs.
  • All transformations share the same topology. This distinguishes ResNeXt from Inception-ResNet (introduced alongside Inception-v4), which also branches and concatenates inside the residual function but uses differently shaped branches.

   Figure 1. : Original bottleneck block in ResNet vs ResNeXt block with cardinality = 32, at roughly the same complexity

   


Aggregated Transformation


    - $\large y = x + \sum_{i=1}^{C} T_i(x)$,   where $\large C$ is the cardinality and $\large T_i(x)$ has the same topology for all $\large i$
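
The aggregation formula can be sketched in plain Python on a single feature vector. This is only a toy illustration under stated assumptions: each `T_i` is reduced to "embed D -> d, ReLU, project d -> D" with random weights, whereas the paper's branches are 1x1 -> 3x3 -> 1x1 convolutions. The helper names (`matvec`, `make_branch`, `aggregated_block`) and the sizes are hypothetical.

```python
import random


def matvec(W, v):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def relu(v):
    return [max(0.0, x) for x in v]


def make_branch(D, d, rng):
    """One transformation T_i: embed D -> d, apply ReLU, project d -> D."""
    W_in = [[rng.gauss(0.0, 0.1) for _ in range(D)] for _ in range(d)]
    W_out = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(D)]
    return lambda x: matvec(W_out, relu(matvec(W_in, x)))


def aggregated_block(x, branches):
    """y = x + sum_i T_i(x): skip connection plus the sum of all branches."""
    y = list(x)
    for T in branches:
        y = [a + b for a, b in zip(y, T(x))]
    return y


rng = random.Random(0)
D, d, C = 8, 2, 4  # toy sizes, not the paper's 256 / 4 / 32
branches = [make_branch(D, d, rng) for _ in range(C)]  # same topology, different weights
x = [rng.gauss(0.0, 1.0) for _ in range(D)]
y = aggregated_block(x, branches)
print(len(y))  # 8
```

Every branch is built by the same factory, so the only thing that differs between them is the learned weights; that is what "same topology" means here.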


Equivalent Building Blocks for ResNeXt


   Figure 3. : Equivalent blocks for ResNeXt (Figure 3-(c) is the final implementation with grouped convolution)

   


Basic Building Block (Bottleneck)


  • Example of basic bottleneck building block in Conv2 stage of ResNeXt architecture

   

   


  • The grouped convolution layer is a wider but sparsely connected version of the 3x3 layer in the original ResNet bottleneck in Figure 1 (left).
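
To make "wider but sparsely connected" concrete, here is a small sketch that counts the weights of a 3x3 convolution with and without grouping. The function `conv_params` is a hypothetical helper; it assumes a bias-free convolution whose weight has shape `[out_ch, in_ch / groups, k, k]`, which is how `nn.Conv2d` stores grouped weights.

```python
def conv_params(in_ch, out_ch, kernel, groups=1):
    """Weight count of a bias-free 2D convolution split into `groups` groups."""
    assert in_ch % groups == 0 and out_ch % groups == 0
    # Each output channel only sees in_ch / groups input channels.
    return out_ch * (in_ch // groups) * kernel * kernel


# 3x3 layer of the conv2 stage in a 32 x 4d model: 128 channels in total
dense = conv_params(128, 128, 3)
grouped = conv_params(128, 128, 3, groups=32)
print(dense, grouped, dense // grouped)  # 147456 4608 32
```

The grouped value (4,608) is the contribution of the 3x3 layer to the 46,592 parameters that the model summary further down reports for the first residual branch.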


Comparison with Other Neural Architectures


1. Cardinality vs Width


   Table 2. : Trade-Off between cardinality and width of block (d) with preserved complexity

   

  • Example of how to choose a combination of cardinality (C) and bottleneck width (d) that preserves the complexity

    • 1) The original ResNet bottleneck block has 256 x 64 + 3 x 3 x 64 x 64 + 64 x 256 ≈ 70k parameters
    • 2) With bottleneck width = d, parameters = C · (256 · d + 3 · 3 · d · d + d · 256)
    • 3) Find the combination of C and d that gives roughly the same complexity (for example, C = 32 and d = 4 give 70,144 ≈ 70k parameters).
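
The three steps above can be sketched as a small search. This is a minimal illustration, not the paper's procedure: `best_width` is a hypothetical helper that, for a given cardinality, picks the width whose parameter count is closest to the ResNet block.

```python
def resnet_block_params(in_out=256, width=64):
    """Step 1: parameters of the original ResNet bottleneck (1x1, 3x3, 1x1)."""
    return in_out * width + 3 * 3 * width * width + width * in_out


def resnext_block_params(C, d, in_out=256):
    """Step 2: parameters of C parallel paths, each of bottleneck width d."""
    return C * (in_out * d + 3 * 3 * d * d + d * in_out)


def best_width(C, target, max_d=128):
    """Step 3: width d whose parameter count is closest to `target`."""
    return min(range(1, max_d + 1),
               key=lambda d: abs(resnext_block_params(C, d) - target))


target = resnet_block_params()  # 69632
for C in (1, 2, 4, 8, 32):
    d = best_width(C, target)
    print(C, d, resnext_block_params(C, d))
# 1 64 69632
# 2 40 69760
# 4 24 69888
# 8 14 71456
# 32 4 70144
```

For these five cardinalities the search lands on widths 64, 40, 24, 14, and 4, all within a few percent of the 70k target.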


   Figure 5. Comparison of train and validation accuracy between ResNet 50, 101 (1 x 64d) and ResNeXt-50, 101 (32 x 4d)

   

   Table 3. : Results of increasing cardinality at the expense of bottleneck width (d) (Experiments on ImageNet-1K)

   


  • For both the 50-layer and 101-layer networks, every ResNeXt variant, with cardinality increasing from 2 to 32, outperforms its ResNet counterpart.

  • Because the gain from this trade-off saturates as cardinality increases, the paper adopts a bottleneck width no smaller than 4d.


2. Cardinality vs Depth/Width with Increased Complexity


   Table 4. : Comparison of deeper, wider, and higher-cardinality networks when the complexity is doubled relative to ResNet-101

   


  • The highlighted part indicates where the complexity is increased (baseline : 1x complexity reference).

  • ResNeXt-101 models with increased cardinality (2 x 64d, 64 x 4d) record lower top-1 and top-5 error rates than both the deeper and the wider ResNet.

  • Notably, ResNeXt-101 at 1x complexity (32 x 4d) performs better than ResNet-200 and the wider ResNet-101, even though it has only about half their complexity.


   Figure 7. Effects of Increasing Complexity (Parameters) with Cardinality vs Width (on CIFAR-10)

   


  • In these experiments, increasing cardinality reduces the test error more effectively than increasing the transformation width for the same growth in parameters.


3. ResNet vs Inception Net vs ResNeXt


   Table 5. : Experiments on ImageNet-1K

   


4. With or Without Residual (Skip) Connections


   

  • Introducing grouped convolution consistently improves performance, with or without skip connections.

  • Residual connections are known to reduce the non-convexity of the loss surface, which makes optimization easier and faster.

  • Taken together, residual connections mainly help the optimization process, whereas aggregated transformations yield stronger representations: the network learns to capture more informative patterns and features of the input images.


ResNeXt Architecture


   


  • Downsampling is performed by the stride-2 3x3 grouped convolution in the first block of a conv stage (conv3, conv4, and conv5 in the paper).
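
As a quick check of how stage strides determine the feature-map sizes, the sketch below traces the spatial resolution through the stem and the four stages using the standard convolution output-size formula. The helpers `conv_out` and `trace` are hypothetical; the stem settings (7x7 stride 2 padding 3, then 3x3 max pooling stride 2 padding 1) are the ones used in the implementation further down.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace(input_size, stage_strides):
    """Spatial size after conv2..conv5, given the stride of each stage's first block."""
    size = conv_out(input_size, 7, 2, 3)  # conv1: 7x7, stride 2, padding 3
    size = conv_out(size, 3, 2, 1)        # max pool: 3x3, stride 2, padding 1
    sizes = []
    for s in stage_strides:
        size = conv_out(size, 3, s, 1)    # 3x3 conv in the first block of the stage
        sizes.append(size)
    return sizes


print(trace(224, (1, 2, 2, 2)))  # [56, 28, 14, 7]  downsampling in conv3-5 only
print(trace(224, (2, 2, 2, 2)))  # [28, 14, 7, 4]   stride 2 in every stage
```

The second line corresponds to the shapes in the model summary below, where stride 2 is also passed to the conv2 stage and the final feature map is therefore 4x4 rather than 7x7.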


Implementation with PyTorch


import torch
from torch import nn


class Bottleneck(nn.Module):
    expansion = 2
    def __init__(self, in_channels, out_channels, cardinality=32, stride=1, projection=False):
        super().__init__()

        self.residual = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
                                    nn.BatchNorm2d(out_channels),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1, groups=cardinality, bias=False),
                                    nn.BatchNorm2d(out_channels),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(out_channels, out_channels*self.expansion, 1, bias=False),
                                    nn.BatchNorm2d(out_channels*self.expansion))
        if projection:
            self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*self.expansion, 1, stride=stride),
                                          nn.BatchNorm2d(out_channels*self.expansion))
        else:
            self.shortcut = nn.Identity()

        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.residual(x)
        shortcut = self.shortcut(x)

        return self.relu(residual + shortcut)


class ResNeXt(nn.Module):
    def __init__(self, in_channels, block, expansion, cardinality, block_repeats, num_classes, zero_init_residual=True):
        super().__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, 7, stride=2, padding=3),
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True))
        self.conv2_pool = nn.MaxPool2d(3, stride=2, padding=1)

        self.expansion = expansion
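        # NOTE: stride 2 is also applied to conv2 here, so a 224x224 input ends at 4x4 (see the summary below).
        # The paper keeps conv2 at 56x56 and downsamples only in conv3-5; passing stride 1 here would match it.
        out_channels, self.conv2_blocks = self.stack_blocks(block, 64, block_repeats[0], cardinality, 2)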
        out_channels, self.conv3_blocks = self.stack_blocks(block, out_channels, block_repeats[1], cardinality, 2)
        out_channels, self.conv4_blocks = self.stack_blocks(block, out_channels, block_repeats[2], cardinality, 2)
        out_channels, self.conv5_blocks = self.stack_blocks(block, out_channels, block_repeats[3], cardinality, 2)

        self.gap = nn.AdaptiveAvgPool2d((1,1))   # 1x1x2048
        self.classifier = nn.Linear(out_channels, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # zero-initialize the last BN in each residual branch -> each block starts as an identity mapping
        # improves model by 0.2~0.3%p (https://arxiv.org/abs/1706.02677)
        # done in a second pass: modules() yields a block before its child BN layers,
        # so zeroing inside the loop above would be overwritten by the BN init
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, block):
                    nn.init.constant_(m.residual[-1].weight, 0)

    def stack_blocks(self, block, in_channel, block_repeat, cardinality, stride):
        stacked = []

        c, repeats = block_repeat
        for _ in range(repeats):
            if stride == 2 or in_channel != c*self.expansion:
                stacked += [block(in_channel, c, cardinality, stride, True)]
                in_channel = c*self.expansion
                stride = 1
            else:
                stacked += [block(in_channel, c, cardinality)]
                in_channel = c*self.expansion

        return c*self.expansion, nn.Sequential(*stacked)  

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2_pool(x)
        x = self.conv2_blocks(x)
        x = self.conv3_blocks(x)
        x = self.conv4_blocks(x)
        x = self.conv5_blocks(x)
        x = self.gap(x)
        x = torch.flatten(x, start_dim=1)
        out = self.classifier(x)

        return out


Model Summary


from torchinfo import summary  # assumed: the output format below matches torchinfo

block_repeats = [(128,3), (256,4), (512,6), (1024,3)]
expansion = 2
cardinality = 32

model = ResNeXt(3, Bottleneck, expansion, cardinality, block_repeats, 1000)
summary(model, input_size=(2, 3, 224, 224), device='cpu')


==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNeXt                                  [2, 1000]                 --
├─Sequential: 1-1                        [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [2, 64, 112, 112]         9,472
│    └─BatchNorm2d: 2-2                  [2, 64, 112, 112]         128
│    └─ReLU: 2-3                         [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [2, 64, 56, 56]           --
├─Sequential: 1-3                        [2, 256, 28, 28]          --
│    └─Bottleneck: 2-4                   [2, 256, 28, 28]          --
│    │    └─Sequential: 3-1              [2, 256, 28, 28]          46,592
│    │    └─Sequential: 3-2              [2, 256, 28, 28]          17,152
│    │    └─ReLU: 3-3                    [2, 256, 28, 28]          --
│    └─Bottleneck: 2-5                   [2, 256, 28, 28]          --
│    │    └─Sequential: 3-4              [2, 256, 28, 28]          71,168
│    │    └─Identity: 3-5                [2, 256, 28, 28]          --
│    │    └─ReLU: 3-6                    [2, 256, 28, 28]          --
│    └─Bottleneck: 2-6                   [2, 256, 28, 28]          --
│    │    └─Sequential: 3-7              [2, 256, 28, 28]          71,168
│    │    └─Identity: 3-8                [2, 256, 28, 28]          --
│    │    └─ReLU: 3-9                    [2, 256, 28, 28]          --
├─Sequential: 1-4                        [2, 512, 14, 14]          --
│    └─Bottleneck: 2-7                   [2, 512, 14, 14]          --
│    │    └─Sequential: 3-10             [2, 512, 14, 14]          217,088
│    │    └─Sequential: 3-11             [2, 512, 14, 14]          132,608
│    │    └─ReLU: 3-12                   [2, 512, 14, 14]          --
│    └─Bottleneck: 2-8                   [2, 512, 14, 14]          --
│    │    └─Sequential: 3-13             [2, 512, 14, 14]          282,624
│    │    └─Identity: 3-14               [2, 512, 14, 14]          --
│    │    └─ReLU: 3-15                   [2, 512, 14, 14]          --
│    └─Bottleneck: 2-9                   [2, 512, 14, 14]          --
│    │    └─Sequential: 3-16             [2, 512, 14, 14]          282,624
│    │    └─Identity: 3-17               [2, 512, 14, 14]          --
│    │    └─ReLU: 3-18                   [2, 512, 14, 14]          --
│    └─Bottleneck: 2-10                  [2, 512, 14, 14]          --
│    │    └─Sequential: 3-19             [2, 512, 14, 14]          282,624
│    │    └─Identity: 3-20               [2, 512, 14, 14]          --
│    │    └─ReLU: 3-21                   [2, 512, 14, 14]          --
├─Sequential: 1-5                        [2, 1024, 7, 7]           --
│    └─Bottleneck: 2-11                  [2, 1024, 7, 7]           --
│    │    └─Sequential: 3-22             [2, 1024, 7, 7]           864,256
│    │    └─Sequential: 3-23             [2, 1024, 7, 7]           527,360
│    │    └─ReLU: 3-24                   [2, 1024, 7, 7]           --
│    └─Bottleneck: 2-12                  [2, 1024, 7, 7]           --
│    │    └─Sequential: 3-25             [2, 1024, 7, 7]           1,126,400
│    │    └─Identity: 3-26               [2, 1024, 7, 7]           --
│    │    └─ReLU: 3-27                   [2, 1024, 7, 7]           --
│    └─Bottleneck: 2-13                  [2, 1024, 7, 7]           --
│    │    └─Sequential: 3-28             [2, 1024, 7, 7]           1,126,400
│    │    └─Identity: 3-29               [2, 1024, 7, 7]           --
│    │    └─ReLU: 3-30                   [2, 1024, 7, 7]           --
│    └─Bottleneck: 2-14                  [2, 1024, 7, 7]           --
│    │    └─Sequential: 3-31             [2, 1024, 7, 7]           1,126,400
│    │    └─Identity: 3-32               [2, 1024, 7, 7]           --
│    │    └─ReLU: 3-33                   [2, 1024, 7, 7]           --
│    └─Bottleneck: 2-15                  [2, 1024, 7, 7]           --
│    │    └─Sequential: 3-34             [2, 1024, 7, 7]           1,126,400
│    │    └─Identity: 3-35               [2, 1024, 7, 7]           --
│    │    └─ReLU: 3-36                   [2, 1024, 7, 7]           --
│    └─Bottleneck: 2-16                  [2, 1024, 7, 7]           --
│    │    └─Sequential: 3-37             [2, 1024, 7, 7]           1,126,400
│    │    └─Identity: 3-38               [2, 1024, 7, 7]           --
│    │    └─ReLU: 3-39                   [2, 1024, 7, 7]           --
├─Sequential: 1-6                        [2, 2048, 4, 4]           --
│    └─Bottleneck: 2-17                  [2, 2048, 4, 4]           --
│    │    └─Sequential: 3-40             [2, 2048, 4, 4]           3,448,832
│    │    └─Sequential: 3-41             [2, 2048, 4, 4]           2,103,296
│    │    └─ReLU: 3-42                   [2, 2048, 4, 4]           --
│    └─Bottleneck: 2-18                  [2, 2048, 4, 4]           --
│    │    └─Sequential: 3-43             [2, 2048, 4, 4]           4,497,408
│    │    └─Identity: 3-44               [2, 2048, 4, 4]           --
│    │    └─ReLU: 3-45                   [2, 2048, 4, 4]           --
│    └─Bottleneck: 2-19                  [2, 2048, 4, 4]           --
│    │    └─Sequential: 3-46             [2, 2048, 4, 4]           4,497,408
│    │    └─Identity: 3-47               [2, 2048, 4, 4]           --
│    │    └─ReLU: 3-48                   [2, 2048, 4, 4]           --
├─AdaptiveAvgPool2d: 1-7                 [2, 2048, 1, 1]           --
├─Linear: 1-8                            [2, 1000]                 2,049,000
==========================================================================================
Total params: 25,032,808
Trainable params: 25,032,808
Non-trainable params: 0
Total mult-adds (G): 2.44
==========================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 145.72
Params size (MB): 100.13
Estimated Total Size (MB): 247.05
==========================================================================================


x = torch.randn(2,3,224,224)
out = model(x)
print(out.shape)
out


torch.Size([2, 1000])

tensor([[ 1.3636,  1.0331, -0.1438,  ...,  0.6048,  0.3152,  0.3505],
        [ 1.3330,  0.8973, -0.1342,  ...,  0.5798,  0.0394,  0.3434]],
       grad_fn=<AddmmBackward0>)


Depthwise Separable Convolution


  • In computer vision, a model needs enough contextual information to classify an input image accurately. One way to obtain it is to enlarge the receptive field, either with a larger kernel or by stacking convolution layers (two stacked 3x3 layers cover the same receptive field as one 5x5 layer). Both approaches, however, increase the computational cost considerably. To address this, researchers have developed convolution variants that reduce computation while preserving spatial information. Depthwise separable convolution is one of them.

  • A traditional convolution layer performs spatial and channel-wise convolution at the same time, so each output feature map requires its own filter spanning all input channels.

  • Depthwise separable convolution, on the other hand, splits the convolution operation into two separate stages: depthwise convolution and pointwise convolution.


     


1. Depthwise Convolution :

  • Identical to grouped convolution with the number of groups (cardinality) equal to the number of input channels.
  • Performs only spatial convolution: each input channel is convolved with its own single filter, and the resulting feature maps are concatenated.
  • Each feature map contains the spatial representation of exactly one channel.
  • This reduces the network complexity, as each filter is only responsible for convolving with its corresponding input channel.

   


2. Pointwise Convolution :

  • Applies a 1x1 convolution with stride 1, merging the output of the depthwise convolution into the desired number of feature maps.
  • Creates new representations by linear combination of all channels with learned weights.

   

  • The parameters and computation saved by the depthwise convolution can be reinvested elsewhere, for example in more channels or layers, to improve model capacity.
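
The saving can be estimated with a short parameter count. This is a sketch under stated assumptions: bias-free layers, a k x k kernel, and hypothetical helper names; the channel sizes are arbitrary examples rather than values from any particular network.

```python
def standard_conv_params(c_in, c_out, k):
    """One k x k filter spanning all input channels for every output channel."""
    return k * k * c_in * c_out


def depthwise_separable_params(c_in, c_out, k):
    """Depthwise stage (one k x k filter per input channel) plus 1x1 pointwise stage."""
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise


c_in, c_out, k = 128, 256, 3
std = standard_conv_params(c_in, c_out, k)
sep = depthwise_separable_params(c_in, c_out, k)
print(std, sep, round(sep / std, 3))  # 294912 33920 0.115
```

The ratio simplifies to 1/c_out + 1/k², so with a 3x3 kernel and many output channels the separable version needs a little over one ninth of the weights.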