[Paper Review & Implementation] Deep Residual Learning for Image Recognition (ResNet, 2015) + Interpretation of ResNet as Ensembles of Networks (2016)

2023, Jun 01    

Issue of Interest - Degradation


  • As the depth of a network increases, a degradation problem becomes more pronounced: the deeper network reaches lower training accuracy than its shallower counterpart.

  • Previously, the main concern with stacking more layers was the vanishing/exploding gradient problem, which has largely been addressed by normalized initialization (e.g. Kaiming initialization) and batch normalization.

  • Even without vanishing gradients, accuracy saturates as the network gets deeper and then degrades, and the deeper network shows higher training error than the shallower one. Since the gap appears in the training error, adding layers here leads to under-fitting rather than over-fitting.

      


  • The result in Figure 1 above, where the 56-layer model records higher training error than the 20-layer model, is counterintuitive because there is always a solution in which the 56-layer network copies the 20-layer network and the added layers are just identity mappings.

  • The occurrence of the degradation problem despite the existence of this simple solution indicates that current solvers may have difficulty approximating identity mappings with a deep stack of nonlinear layers.

  • Hence, the paper introduces “deep residual learning”, where the network learns to fit only the residual (the difference between the desired underlying mapping and the input) instead of directly fitting the underlying mapping.


      image


  • $\large H(x) = F(x) + x$, i.e. $\large F(x) = H(x) - x$   $\large H(x)$ : desired underlying mapping,  $\large F(x)$ : residual mapping,  $\large x$ : input

  • This approach seems reasonable in that, as the depth of a network increases, the optimal role of a single layer tends to be limited to making small, accurate perturbations to its input, with the identity x serving as a reasonable preconditioning. (This assumption is supported by Figure 7 of the paper, which shows that the standard deviations (std) of layer responses are generally smaller in ResNets than in plain networks.)


      image


  • Under this formulation, realizing an identity mapping only requires driving the residual function to zero, instead of fitting a completely new function that reproduces the input.
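As a minimal sketch of this point, the toy example below uses plain Python lists and a linear map in place of real layers. The names `residual_block` and `plain_block` are illustrative and do not come from the paper. It only shows that an all-zero residual branch already yields the identity, while a plain block has to learn an identity matrix to do the same.

```python
def linear(x, weights):
    """Toy layer: matrix-vector product with a list-of-lists weight matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def plain_block(x, weights):
    """Plain block: the layer has to produce H(x) directly."""
    return linear(x, weights)


def residual_block(x, weights):
    """Residual block: the layer produces F(x) and the input is added back."""
    return [f + xi for f, xi in zip(linear(x, weights), x)]


x = [1.0, -2.0, 3.0]
zero_weights = [[0.0] * 3 for _ in range(3)]
identity_weights = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]

# With all-zero weights the residual block is already the identity mapping.
print(residual_block(x, zero_weights))

# The plain block needs exactly the identity matrix to reproduce its input.
print(plain_block(x, identity_weights))
print(plain_block(x, zero_weights))
```

In this toy setting the target weights for an identity mapping are all zeros in the residual form, which is the situation the bullet above describes.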


Skip Connection and Projection Connection


  • To implement residual learning, the authors add a direct shortcut from the input to the output of the residual mapping function.

     $\large y = F(x, \{W_i\}) + x$

  • The function $F(x, \{W_i\})$ represents the residual mapping to be learned, which consists of two or more layers.

  • The shortcut connection requires nothing more than an element-wise addition, which introduces neither extra parameters nor meaningful computational complexity.

  • When the dimensions of F(x) differ from those of x because of the computations in the residual mapping (e.g. a convolution with stride 2 or a convolution that increases the number of feature maps), a projection shortcut can be used to match the dimensions (typically a 1x1 convolution with the same stride and the same number of output feature maps).

     $\large y = F(x, \{W_i\}) + W_s x$
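The dimension mismatch that motivates the projection shortcut can be checked with simple shape arithmetic. The sketch below assumes a two-layer residual branch of 3x3 convolutions with padding 1 and a 1x1 projection with no padding. The helper names `conv_out_size` and `branch_shapes` are hypothetical, and the size formula is the standard floor-based one for convolutions.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution, using floor division."""
    return (size + 2 * padding - kernel) // stride + 1


def branch_shapes(in_channels, out_channels, size, stride):
    """Return (channels, height, width) of F(x), of x itself, and of W_s x."""
    # Residual branch F(x): 3x3 conv (given stride, padding 1), then 3x3 conv (stride 1).
    h = conv_out_size(size, kernel=3, stride=stride, padding=1)
    h = conv_out_size(h, kernel=3, stride=1, padding=1)
    residual = (out_channels, h, h)
    # Identity shortcut: x is passed through unchanged.
    identity = (in_channels, size, size)
    # Projection shortcut W_s x: 1x1 conv with the same stride and output channels.
    p = conv_out_size(size, kernel=1, stride=stride, padding=0)
    projection = (out_channels, p, p)
    return residual, identity, projection


residual, identity, projection = branch_shapes(64, 128, size=56, stride=2)
print("F(x):  ", residual)
print("x:     ", identity)
print("W_s x: ", projection)
```

With stride 2 and a change from 64 to 128 channels, F(x) and x can no longer be added element-wise, while F(x) and W_s x have the same shape.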


ResNet Architecture


1. Comparison with VGG-19

  


  • solid lines represent identity shortcuts (skip connections) and dashed lines represent projection shortcuts, which are used where the dimensions change

2. Bottleneck Building Block

  • to reduce the number of parameters, first apply a 1x1 convolution that reduces the number of feature maps, then a 3x3 convolution, followed by another 1x1 convolution that restores the feature dimension.
    • replace the 2-layer block (3x3 -> 3x3) with a 3-layer bottleneck block (1x1 -> 3x3 -> 1x1)
  • By using bottleneck blocks, the 152-layer ResNet (11.3 billion FLOPs) still has lower complexity than VGG-16/19 (15.3/19.6 billion FLOPs)


      image
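The saving from the bottleneck design can be made concrete by counting convolution weights. The sketch below ignores biases and batch normalization parameters, and the helper `conv_weights` is an illustrative name. It compares a 2-layer block on 64 and on 256 channels with a 3-layer bottleneck that takes 256 channels in and out and uses 64 channels in the middle.

```python
def conv_weights(in_channels, out_channels, kernel):
    """Number of weights in a bias-free 2D convolution."""
    return in_channels * out_channels * kernel * kernel


# 2-layer block (3x3 -> 3x3) operating on 64 channels.
basic_64 = 2 * conv_weights(64, 64, 3)

# The same 2-layer block if it had to operate on 256 channels.
basic_256 = 2 * conv_weights(256, 256, 3)

# 3-layer bottleneck (1x1 -> 3x3 -> 1x1): 256 -> 64 -> 64 -> 256 channels.
bottleneck_256 = (conv_weights(256, 64, 1)
                  + conv_weights(64, 64, 3)
                  + conv_weights(64, 256, 1))

print(f"basic block, 64-d:      {basic_64:>9,}")
print(f"basic block, 256-d:     {basic_256:>9,}")
print(f"bottleneck block, 256-d:{bottleneck_256:>9,}")
```

The bottleneck on 256 channels costs about as many weights as the 2-layer block on 64 channels, and far fewer than a 2-layer block on 256 channels.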


3. Downsampling

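  • in the original ResNet, down-sampling is performed by a stride-2 convolution in the first block of each stage (conv3_x, conv4_x, conv5_x); in the bottleneck block this is the first 1x1 convolution
    • ResNet v1.5 moves the stride 2 to the 3x3 convolution instead of the 1x1 to mitigate the loss of spatial information (the Bottleneck implementation below follows this variant)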
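One way to see the loss of spatial information is to list which input positions a strided convolution ever reads. The sketch below works along one spatial axis and assumes a 56-pixel input, as at the start of conv3_x. The function name `touched_positions` is hypothetical.

```python
def touched_positions(size, kernel, stride, padding):
    """Input indices along one axis that fall under at least one kernel window."""
    out_size = (size + 2 * padding - kernel) // stride + 1
    touched = set()
    for o in range(out_size):
        start = o * stride - padding
        for j in range(kernel):
            pos = start + j
            if 0 <= pos < size:  # skip positions that lie in the zero padding
                touched.add(pos)
    return touched


size = 56
one_by_one = touched_positions(size, kernel=1, stride=2, padding=0)
three_by_three = touched_positions(size, kernel=3, stride=2, padding=1)

# Fraction of the 2D feature map that is read (same behaviour along both axes).
fraction_1x1 = (len(one_by_one) / size) ** 2
fraction_3x3 = (len(three_by_three) / size) ** 2

print(f"1x1, stride 2 reads {fraction_1x1:.0%} of the input positions")
print(f"3x3, stride 2 reads {fraction_3x3:.0%} of the input positions")
```

Under these assumptions the strided 1x1 convolution ignores three quarters of the feature map, whereas the strided 3x3 convolution still covers every position.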


Architectures of ResNet

image


Implementation with PyTorch


import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1, projection=False):
        super().__init__()
        
        self.residual = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
                                    nn.BatchNorm2d(out_channels),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(out_channels, out_channels*self.expansion, 3, padding=1, bias=False),
                                    nn.BatchNorm2d(out_channels*self.expansion))    # enters into relu activation 
        if projection:
            self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*self.expansion, 1, stride=stride),
                                          nn.BatchNorm2d(out_channels*self.expansion))
        else:
            self.shortcut = nn.Identity()

        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.residual(x)
        shortcut = self.shortcut(x)

        return self.relu(residual + shortcut)


class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1, projection=False):
        super().__init__()

        self.residual = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1, bias=False),
                                   nn.BatchNorm2d(out_channels),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(out_channels, out_channels*self.expansion, 1, bias=False),
                                   nn.BatchNorm2d(out_channels*self.expansion))
        if projection:
            self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels*self.expansion, 1, stride=stride),
                                          nn.BatchNorm2d(out_channels*self.expansion))
        else:
            self.shortcut = nn.Identity()

        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.residual(x)
        shortcut = self.shortcut(x)

        return self.relu(residual + shortcut)


class ResNet(nn.Module):
    def __init__(self, in_channels, block, expansion, block_repeats, num_classes, zero_init_residual=True):
        super().__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, 7, stride=2, padding=3),
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True))
        self.conv2_pool = nn.MaxPool2d(3, stride=2, padding=1)

        self.expansion = expansion
        out_channels, self.conv2_blocks = self.stack_blocks(block, 64, block_repeats[0], 1)
        out_channels, self.conv3_blocks = self.stack_blocks(block, out_channels, block_repeats[1], 2)
        out_channels, self.conv4_blocks = self.stack_blocks(block, out_channels, block_repeats[2], 2)
        out_channels, self.conv5_blocks = self.stack_blocks(block, out_channels, block_repeats[3], 2)

        self.gap = nn.AdaptiveAvgPool2d((1,1))
        self.classifier = nn.Linear(out_channels, num_classes)    # use num_classes instead of a hard-coded 1000

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # zero-initialize the last BN in each residual branch -> every residual branch starts at zero
        # improves model by 0.2~0.3%p (https://arxiv.org/abs/1706.02677)
        # done in a second pass: modules() yields a block before its children, so zeroing inside
        # the loop above would be overwritten when the BN weights are reset to 1
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, block):
                    nn.init.constant_(m.residual[-1].weight, 0)


    def stack_blocks(self, block, in_channel, block_repeat, stride):
        stacked = []

        c, repeats = block_repeat
        for _ in range(repeats):
            if stride == 2 or in_channel != c*self.expansion:
                stacked += [block(in_channel, c, stride, True)]
                in_channel = c*self.expansion
                stride = 1
            else:
                stacked += [block(in_channel, c)]
                in_channel = c*self.expansion

        return c*self.expansion, nn.Sequential(*stacked)  

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2_pool(x)
        x = self.conv2_blocks(x)
        x = self.conv3_blocks(x)
        x = self.conv4_blocks(x)
        x = self.conv5_blocks(x)
        x = self.gap(x)
        x = torch.flatten(x, start_dim=1)
        out = self.classifier(x)

        return out


Model Summary


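from torchinfo import summary    # assumed source of summary(); the output format below matches torchinfo

# (channels, number of blocks) per stage; resnet34 is built from BasicBlock, the others from Bottleneck
block_repeats = {'resnet34' : [(64,3), (128,4), (256,6), (512,3)],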
                'resnet101' : [(64,3), (128,4), (256,23), (512,3)],
                'resnet152' : [(64,3), (128,8), (256,36), (512,3)]}

model = ResNet(3, Bottleneck, 4, block_repeats['resnet152'], 1000)
summary(model, input_size=(2, 3, 224, 224), device='cpu')


==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [2, 1000]                 --
├─Sequential: 1-1                        [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [2, 64, 112, 112]         9,472
│    └─BatchNorm2d: 2-2                  [2, 64, 112, 112]         128
│    └─ReLU: 2-3                         [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [2, 64, 56, 56]           --
├─Sequential: 1-3                        [2, 256, 56, 56]          --
│    └─Bottleneck: 2-4                   [2, 256, 56, 56]          --
│    │    └─Sequential: 3-1              [2, 256, 56, 56]          58,112
│    │    └─Sequential: 3-2              [2, 256, 56, 56]          17,152
│    │    └─ReLU: 3-3                    [2, 256, 56, 56]          --
│    └─Bottleneck: 2-5                   [2, 256, 56, 56]          --
│    │    └─Sequential: 3-4              [2, 256, 56, 56]          70,400
│    │    └─Identity: 3-5                [2, 256, 56, 56]          --
│    │    └─ReLU: 3-6                    [2, 256, 56, 56]          --
│    └─Bottleneck: 2-6                   [2, 256, 56, 56]          --
│    │    └─Sequential: 3-7              [2, 256, 56, 56]          70,400
│    │    └─Identity: 3-8                [2, 256, 56, 56]          --
│    │    └─ReLU: 3-9                    [2, 256, 56, 56]          --
├─Sequential: 1-4                        [2, 512, 28, 28]          --
│    └─Bottleneck: 2-7                   [2, 512, 28, 28]          --
│    │    └─Sequential: 3-10             [2, 512, 28, 28]          247,296
│    │    └─Sequential: 3-11             [2, 512, 28, 28]          132,608
│    │    └─ReLU: 3-12                   [2, 512, 28, 28]          --
│    └─Bottleneck: 2-8                   [2, 512, 28, 28]          --
│    │    └─Sequential: 3-13             [2, 512, 28, 28]          280,064
│    │    └─Identity: 3-14               [2, 512, 28, 28]          --
│    │    └─ReLU: 3-15                   [2, 512, 28, 28]          --
│    └─Bottleneck: 2-9                   [2, 512, 28, 28]          --
│    │    └─Sequential: 3-16             [2, 512, 28, 28]          280,064
│    │    └─Identity: 3-17               [2, 512, 28, 28]          --
│    │    └─ReLU: 3-18                   [2, 512, 28, 28]          --
│    └─Bottleneck: 2-10                  [2, 512, 28, 28]          --
│    │    └─Sequential: 3-19             [2, 512, 28, 28]          280,064
│    │    └─Identity: 3-20               [2, 512, 28, 28]          --
│    │    └─ReLU: 3-21                   [2, 512, 28, 28]          --
│    └─Bottleneck: 2-11                  [2, 512, 28, 28]          --
│    │    └─Sequential: 3-22             [2, 512, 28, 28]          280,064
│    │    └─Identity: 3-23               [2, 512, 28, 28]          --
│    │    └─ReLU: 3-24                   [2, 512, 28, 28]          --
│    └─Bottleneck: 2-12                  [2, 512, 28, 28]          --
│    │    └─Sequential: 3-25             [2, 512, 28, 28]          280,064
│    │    └─Identity: 3-26               [2, 512, 28, 28]          --
│    │    └─ReLU: 3-27                   [2, 512, 28, 28]          --
│    └─Bottleneck: 2-13                  [2, 512, 28, 28]          --
│    │    └─Sequential: 3-28             [2, 512, 28, 28]          280,064
│    │    └─Identity: 3-29               [2, 512, 28, 28]          --
│    │    └─ReLU: 3-30                   [2, 512, 28, 28]          --
│    └─Bottleneck: 2-14                  [2, 512, 28, 28]          --
│    │    └─Sequential: 3-31             [2, 512, 28, 28]          280,064
│    │    └─Identity: 3-32               [2, 512, 28, 28]          --
│    │    └─ReLU: 3-33                   [2, 512, 28, 28]          --
├─Sequential: 1-5                        [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-15                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-34             [2, 1024, 14, 14]         986,112
│    │    └─Sequential: 3-35             [2, 1024, 14, 14]         527,360
│    │    └─ReLU: 3-36                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-16                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-37             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-38               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-39                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-17                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-40             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-41               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-42                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-18                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-43             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-44               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-45                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-19                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-46             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-47               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-48                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-20                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-49             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-50               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-51                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-21                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-52             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-53               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-54                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-22                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-55             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-56               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-57                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-23                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-58             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-59               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-60                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-24                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-61             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-62               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-63                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-25                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-64             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-65               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-66                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-26                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-67             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-68               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-69                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-27                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-70             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-71               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-72                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-28                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-73             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-74               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-75                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-29                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-76             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-77               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-78                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-30                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-79             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-80               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-81                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-31                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-82             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-83               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-84                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-32                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-85             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-86               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-87                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-33                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-88             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-89               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-90                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-34                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-91             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-92               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-93                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-35                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-94             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-95               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-96                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-36                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-97             [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-98               [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-99                   [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-37                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-100            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-101              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-102                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-38                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-103            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-104              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-105                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-39                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-106            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-107              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-108                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-40                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-109            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-110              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-111                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-41                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-112            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-113              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-114                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-42                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-115            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-116              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-117                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-43                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-118            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-119              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-120                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-44                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-121            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-122              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-123                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-45                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-124            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-125              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-126                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-46                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-127            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-128              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-129                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-47                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-130            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-131              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-132                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-48                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-133            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-134              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-135                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-49                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-136            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-137              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-138                  [2, 1024, 14, 14]         --
│    └─Bottleneck: 2-50                  [2, 1024, 14, 14]         --
│    │    └─Sequential: 3-139            [2, 1024, 14, 14]         1,117,184
│    │    └─Identity: 3-140              [2, 1024, 14, 14]         --
│    │    └─ReLU: 3-141                  [2, 1024, 14, 14]         --
├─Sequential: 1-6                        [2, 2048, 7, 7]           --
│    └─Bottleneck: 2-51                  [2, 2048, 7, 7]           --
│    │    └─Sequential: 3-142            [2, 2048, 7, 7]           3,938,304
│    │    └─Sequential: 3-143            [2, 2048, 7, 7]           2,103,296
│    │    └─ReLU: 3-144                  [2, 2048, 7, 7]           --
│    └─Bottleneck: 2-52                  [2, 2048, 7, 7]           --
│    │    └─Sequential: 3-145            [2, 2048, 7, 7]           4,462,592
│    │    └─Identity: 3-146              [2, 2048, 7, 7]           --
│    │    └─ReLU: 3-147                  [2, 2048, 7, 7]           --
│    └─Bottleneck: 2-53                  [2, 2048, 7, 7]           --
│    │    └─Sequential: 3-148            [2, 2048, 7, 7]           4,462,592
│    │    └─Identity: 3-149              [2, 2048, 7, 7]           --
│    │    └─ReLU: 3-150                  [2, 2048, 7, 7]           --
├─AdaptiveAvgPool2d: 1-7                 [2, 2048, 1, 1]           --
├─Linear: 1-8                            [2, 1000]                 2,049,000
==========================================================================================
Total params: 60,196,712
Trainable params: 60,196,712
Non-trainable params: 0
Total mult-adds (G): 23.03
==========================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 721.75
Params size (MB): 240.79
Estimated Total Size (MB): 963.74
==========================================================================================
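The totals in the summary above can be cross-checked by hand. The sketch below is a plain-Python parameter count that mirrors the implementation in this post: bias-free convolutions inside the residual branch, a bias on conv1 and on the projection shortcut convolutions, and two parameters per batch normalization channel. The function names are illustrative, and the count is only meant to match this particular implementation.

```python
def bottleneck_params(in_c, c, projection, expansion=4):
    """Parameters of one Bottleneck block as implemented above."""
    out_c = c * expansion
    convs = in_c * c + 9 * c * c + c * out_c        # 1x1, 3x3, 1x1 convs without bias
    norms = 2 * c + 2 * c + 2 * out_c                # weight and bias of the three BN layers
    shortcut = 0
    if projection:
        shortcut = in_c * out_c + out_c + 2 * out_c  # 1x1 conv with bias, followed by BN
    return convs + norms + shortcut


def resnet_params(stages, num_classes=1000, in_channels=3):
    """Total parameters of the bottleneck ResNet defined by (channels, repeats) stages."""
    total = in_channels * 64 * 7 * 7 + 64 + 2 * 64   # conv1 (with bias) and its BN
    in_c = 64
    for c, repeats in stages:
        for i in range(repeats):
            # Only the first block of each stage changes dimensions and needs a projection.
            total += bottleneck_params(in_c, c, projection=(i == 0))
            in_c = c * 4
    return total + in_c * num_classes + num_classes  # final linear classifier


def weighted_depth(stages, layers_per_block=3):
    """Count conv1, every conv inside the blocks, and the final linear layer."""
    return 1 + layers_per_block * sum(repeats for _, repeats in stages) + 1


resnet101 = [(64, 3), (128, 4), (256, 23), (512, 3)]
resnet152 = [(64, 3), (128, 8), (256, 36), (512, 3)]

print(bottleneck_params(64, 64, projection=True))   # first block of conv2_x, incl. shortcut
print(f"{resnet_params(resnet152):,}")
print(weighted_depth(resnet101), weighted_depth(resnet152))
```

The first printed value equals the sum of the two Sequential entries of the first Bottleneck in the summary (58,112 + 17,152), and the second equals the reported total parameter count.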


Forward Pass


x = torch.randn(2,3,224,224)
out = model(x)
print(out.shape)
out


torch.Size([2, 1000])
tensor([[ 1.7212,  0.7408, -0.7578,  ..., -0.3376,  0.8207, -0.8635],
        [ 1.6384,  0.6766, -0.7177,  ..., -0.4407,  0.7628, -0.8020]],
       grad_fn=<AddmmBackward0>)


Interpretation of ResNet as Ensembles of Multiple Networks with Varying Depth


  • A follow-up study proposes interpreting a residual network as a collection of many paths of differing length (depth) (https://arxiv.org/abs/1605.06431).

  • A ResNet with n blocks contains $\large O(2^{n})$ implicit paths connecting input and output, because each block can either be passed through or skipped; adding one block to the network doubles the number of paths.


   Figure 1.

   


  • ResNet shows ensemble-like behavior in that the paths in the network do not strongly depend on each other: deleting blocks from a residual network with 54 blocks does not cause a severe increase in test classification error. The performance of VGG networks, on the other hand, is significantly affected by deleting a single layer, because the only path from input to output runs through every layer.


   Figure 2.

   


   Figure 3.

   


  • The ensemble-like behavior of paths in residual networks is also shown by the following experiment, in which randomly deleting several modules degrades performance smoothly rather than abruptly. This result indicates that network performance correlates smoothly with the number of valid paths, as is also the case for an ensemble.


   Figure 5.
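The path-counting argument can be checked by brute force on a tiny network. The sketch below assumes a deliberately simplified setting in which each residual branch is a scalar linear map with a hypothetical gain, so that the output is exactly the sum of the contributions of all paths. With nonlinear branches the paths share intermediate activations and this clean additive split is only an intuition. All names are illustrative.

```python
from itertools import product
from math import prod


def forward(x, gains, deleted=()):
    """Run x through residual blocks y = x + a * x, skipping the deleted block indices."""
    for i, a in enumerate(gains):
        if i not in deleted:
            x = x + a * x
    return x


def paths(n, deleted=()):
    """All paths as 0/1 tuples (1 = pass through the residual branch of that block)."""
    return [p for p in product((0, 1), repeat=n) if not any(p[i] for i in deleted)]


def sum_over_paths(x, gains, deleted=()):
    """Add up the contribution of every individual path."""
    return sum(prod(a for a, used in zip(gains, p) if used) * x
               for p in paths(len(gains), deleted))


gains = [2, 3, 5]   # hypothetical integer gains of three residual branches
x = 1

print(len(paths(3)), forward(x, gains), sum_over_paths(x, gains))
# Deleting one block removes every path that used its residual branch.
print(len(paths(3, deleted=(1,))), forward(x, gains, deleted=(1,)),
      sum_over_paths(x, gains, deleted=(1,)))
```

Three blocks give eight paths, and deleting one block leaves four valid paths while the rest of the network keeps working, which is the behaviour the deletion experiments rely on.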

   


Importance of Shorter Paths


  • Define the length of a path as the number of residual modules it passes through.

  • Experiments in the paper show that shorter paths play a more crucial role in training the network than longer paths.

  • Distribution of path lengths in a 54-block ResNet (Figure 6.)

    • Binomial distribution with mean path length np = 54*(1/2) = 27
    • more than 95% of paths go through 19 to 35 modules.
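The figures quoted for the path-length distribution follow directly from binomial coefficients: a path of length k chooses k of the n blocks to pass through. The sketch below recomputes the mean and the share of paths with length between 19 and 35 for n = 54, using only the standard library. The variable names are illustrative.

```python
from math import comb

n = 54                                              # number of residual blocks
counts = [comb(n, k) for k in range(n + 1)]         # number of paths of each length k
total = sum(counts)                                 # all paths: 2 ** n

mean_length = sum(k * c for k, c in enumerate(counts)) / total
share_19_to_35 = sum(counts[19:36]) / total         # lengths 19, 20, ..., 35

print(total == 2 ** n)
print(mean_length)
print(f"{share_19_to_35:.3f}")
```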


     


  • Magnitude (norm) of the gradient along a path of length k

    • To empirically investigate the effect of vanishing gradients on residual networks, the researchers start from a network with 54 blocks, sample individual paths of a given length k, and measure the norm of the gradient that arrives at the input. To sample a path of length k, they first feed a batch forward through the whole network. During the backward pass, k residual blocks are randomly sampled; for those k blocks the gradient is propagated only through the residual module, and for the remaining n-k blocks only through the skip connection. The norm of the gradient of this single length-k path is then measured, and the result is plotted for k from 1 to 54.

    • To find the total gradient magnitude contributed by each path length, multiply the number of paths of that length by the expected gradient magnitude of a path of that length.


      


  • The results show that the gradient magnitude of a path decays roughly exponentially with the number of residual modules it went through in the backward pass.

  • To summarize, the paths in a residual network show ensemble-like behavior, contributing largely independently to the model, and, surprisingly, shorter paths (those that pass through fewer residual modules during back-propagation) are more important to the overall performance of the network.
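The interplay between path frequency and gradient decay described above can be illustrated with a small calculation. The sketch below multiplies the number of paths of each length by a gradient magnitude that shrinks geometrically with length. The decay factor of 0.5 per module is a made-up value chosen only for illustration; it is not measured and not taken from the paper, and the names are hypothetical.

```python
from math import comb


def contribution_by_length(n, decay):
    """Number of paths of length k times an assumed gradient magnitude decay ** k."""
    return [comb(n, k) * decay ** k for k in range(n + 1)]


n = 54
decay = 0.5   # hypothetical per-module decay factor, for illustration only

contrib = contribution_by_length(n, decay)
total = sum(contrib)

most_common_length = n // 2                              # where the path count peaks
peak_length = max(range(n + 1), key=lambda k: contrib[k])
share_up_to_27 = sum(contrib[: most_common_length + 1]) / total

print(most_common_length, peak_length)
print(f"{share_up_to_27:.3f}")
```

Even though paths of length 27 are the most numerous, under this assumed decay the largest total contribution comes from noticeably shorter paths, and paths of length 27 or less account for nearly all of the total. This is qualitatively the effect the summary describes.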