[Paper Review & Partial Implementation] YOLOv4: Optimal Speed and Accuracy of Object Detection (YOLOv4, 2020)


2023, Jul 07    


From YOLOv1 to YOLOv3


YOLO v1:

  • Review & Implementation
  • YOLO v1 introduced the concept of the YOLO architecture. It divided the input image into a grid and assigned each grid cell the responsibility of predicting bounding boxes and class probabilities.
  • Predicted a fixed number of bounding boxes per grid cell, leading to potential localization errors.
  • Used a single-scale feature map for detection, limiting its ability to detect objects at multiple scales.


YOLO v2 (YOLO9000):


  • YOLO v2 made significant improvements over YOLO v1, addressing its limitations.
  • Introduces a new architecture with anchor boxes, which allows the network to predict bounding box offsets relative to these anchors. This improves localization accuracy and enables better handling of objects with different scales and aspect ratios.
  • Uses k-means clustering over the training-set bounding boxes to determine anchor box priors that maximize the average IoU (a minimal sketch follows this list).
  • Limits the range of the predicted box center coordinates to 0 ~ 1 by applying a logistic (sigmoid) activation to the regression output, accelerating convergence.
  • Implements a multi-scale approach, where the network combines feature maps of different scales (26x26 and 13x13) using a skip connection. This facilitates information flow from low-level (larger scale) to higher-level (smaller scale) features, enhancing the detection of objects of various sizes.
  • Trains the network with multi-scale image inputs from 320 x 320 to 608 x 608.
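
A minimal sketch of the anchor-prior selection mentioned above: k-means over (width, height) pairs with 1 - IoU as the distance, as in the YOLOv2 paper. It assumes boxes is a float NumPy array of shape (N, 2); the function names are illustrative, not from any official code.

import numpy as np

def wh_iou(boxes, anchors):
    # IoU between (w, h) pairs, assuming boxes and anchors share the same center
    inter = np.minimum(boxes[:, None, 0], anchors[None, :, 0]) * \
            np.minimum(boxes[:, None, 1], anchors[None, :, 1])
    union = boxes[:, 0:1] * boxes[:, 1:2] + (anchors[:, 0] * anchors[:, 1])[None] - inter
    return inter / union

def kmeans_anchors(boxes, k=9, iters=100):
    # cluster (w, h) pairs with distance = 1 - IoU to obtain k anchor priors
    anchors = boxes[np.random.choice(len(boxes), k, replace=False)].astype(float)
    for _ in range(iters):
        assign = np.argmax(wh_iou(boxes, anchors), axis=1)    # nearest (highest-IoU) anchor per box
        for i in range(k):
            if np.any(assign == i):
                anchors[i] = boxes[assign == i].mean(axis=0)  # update cluster centers
    return anchors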


YOLO v3:


  • YOLO v3 further improves upon the previous versions, focusing on better detection accuracy and handling a larger number of object categories.
  • Utilizes the concept of feature pyramid networks (FPN) to handle feature maps at different scales more efficiently.
  • Extracts multi-scale feature maps (52, 26, 13) from different levels of the feature pyramid and makes separate predictions from each level.
  • Employs the Darknet-53 backbone, a deeper CNN architecture that enhances the network’s feature extraction capabilities.


YOLO v4 : Designing the Optimal Model


Figure 1: Comparison of the proposed YOLOv4 and other state-of-the-art object detectors.


  • YOLOv4 performed extensive experiments to find the optimal combination of existing deep learning techniques for constructing each component (backbone, neck, head) of the YOLOv4 architecture.

  • Additionally, YOLOv4 focuses on two groups of methods, Bag of Freebies (BoF) and Bag of Specials (BoS), to further improve object detection performance in terms of accuracy and speed.

  • As a result, YOLOv4 runs twice as fast as EfficientDet with comparable performance and shows a significant improvement over the previous YOLOv3.

  • While the paper introduced several other recent deep learning techniques, I will specifically focus on the techniques that are actually adopted in YOLOv4.


Object Detection Model



Backbone : CSPDarknet53


  • A single architecture can show varying performance depending on the selection of multiple options, including the choice of dataset.

  • The paper thoroughly compared several backbone architectures to find the optimal balance among input resolution, receptive field size, depth of convolutional layers, number of parameters, number of outputs (channels), and computational load. (FPS: frames per second, used to measure how quickly a model processes image data.)


  • The paper explains that CSPResNext50 is considerably better than CSPDarknet53 for object classification on the ILSVRC2012 (ImageNet) dataset. Conversely, CSPDarknet53 is better than CSPResNext50 for detecting objects on the MS COCO dataset.

  • Considering all these options, CSPDarknet53 is selected as the backbone architecture of YOLOv4.


Cross Stage Partial DenseNet (CSPDenseNet)



  • The basic structure of CSPDarknet is similar to the figure above.

  • A cross stage partial network (CSPNet) consists of a base layer, two separate paths, and a final transition layer that merges the two paths.


  


  • The Part 1 path (left) is a simple convolution - batch normalization - activation stack.

  • The Part 2 path (right) is referred to as the partial dense block: a typical dense block composed of repeated dense layers and a transition layer with a certain growth rate. Note that the output size is the same for both paths.

  • Transition 1 : a transition layer applied only to the partial dense block.

  • Transition 2 : receives the combined outputs (channel-wise concatenation) of the two paths as input and performs the final transition.


PyTorch Implementation of Basic CSPBlock


import torch
import torch.nn as nn

# NOTE: Mish and ResidualBlock are assumed to be defined elsewhere in this post's code.

class CSPBlock(nn.Module):
    def __init__(self, in_channel, is_first=False, num_blocks=1):
        super().__init__()
        # Part 1 path: single 1x1 conv - batch norm - activation
        self.part1_conv = nn.Sequential(nn.Conv2d(in_channel, in_channel//2, 1, stride=1, padding=0, bias=False),
                                        nn.BatchNorm2d(in_channel//2),
                                        Mish())
        # Part 2 path: entry conv followed by the partial dense (residual) block
        self.part2_conv = nn.Sequential(nn.Conv2d(in_channel, in_channel//2, 1, stride=1, padding=0, bias=False),
                                        nn.BatchNorm2d(in_channel//2),
                                        Mish())
        self.features = nn.Sequential(*[ResidualBlock(in_channel=in_channel//2) for _ in range(num_blocks)])
        # Transition 1: applied only to the Part 2 (partial dense block) output
        self.transition1_conv = nn.Sequential(nn.Conv2d(in_channel//2, in_channel//2, 1, stride=1, padding=0, bias=False),
                                              nn.BatchNorm2d(in_channel//2),
                                              Mish())
        # Transition 2: applied to the channel-wise concatenation of the two paths
        self.transition2_conv = nn.Sequential(nn.Conv2d(in_channel, in_channel, 1, stride=1, padding=0, bias=False),
                                              nn.BatchNorm2d(in_channel),
                                              Mish())
        # The first CSP block keeps the full channel width in both paths
        if is_first:
            self.part1_conv = nn.Sequential(nn.Conv2d(in_channel, in_channel, 1, stride=1, padding=0, bias=False),
                                            nn.BatchNorm2d(in_channel),
                                            Mish())
            self.part2_conv = nn.Sequential(nn.Conv2d(in_channel, in_channel, 1, stride=1, padding=0, bias=False),
                                            nn.BatchNorm2d(in_channel),
                                            Mish())
            self.features = nn.Sequential(*[ResidualBlock(in_channel=in_channel,
                                                          hidden_channel=in_channel//2) for _ in range(num_blocks)])
            self.transition1_conv = nn.Sequential(nn.Conv2d(in_channel, in_channel, 1, stride=1, padding=0, bias=False),
                                                  nn.BatchNorm2d(in_channel),
                                                  Mish())
            self.transition2_conv = nn.Sequential(nn.Conv2d(2 * in_channel, in_channel, 1, stride=1, padding=0, bias=False),
                                                  nn.BatchNorm2d(in_channel),
                                                  Mish())

    def forward(self, x):
        part1 = self.part1_conv(x)                                          # Part 1: plain conv path
        part2 = self.transition1_conv(self.features(self.part2_conv(x)))    # Part 2 + Transition 1
        return self.transition2_conv(torch.cat([part1, part2], dim=1))      # Transition 2 on the concatenation

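A quick usage sketch of the block above, assuming Mish and ResidualBlock are defined as in the rest of this post (with ResidualBlock preserving its input channel count); the channel count and spatial size are illustrative.

block = CSPBlock(in_channel=128, is_first=False, num_blocks=2)
out = block(torch.randn(1, 128, 52, 52))   # both paths use 64 channels; concatenation restores 128
print(out.shape)                           # torch.Size([1, 128, 52, 52])
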

Purposes of Designing CSPNet


  • There are several purposes behind this kind of partial structure.
  1. Increase gradient paths

    • By separating the dense block and the transition layer, one can double the number of gradient paths, preventing gradients from being duplicated and reused across both paths.

    • In the figure above, the standard dense block sequentially feeds each previous output into the next dense layer, which leads to excessive reuse of duplicated gradients across all depths.

    • The partial dense block, however, has two separate paths that do not share the gradients generated on each side, which in turn doubles the gradient flow through the network.

    • Separating the transition layer further maximizes the difference in gradient combinations.

      • Fusing the transition layer after concatenation of the two paths (fusion first) significantly drops performance (-1.5%) compared to CSPPeleeNet, whereas applying the transition only to Part 2 (fusion last) is barely affected (-0.1%). Computational cost decreases in both cases, obviously.

      • These results demonstrate that enriched gradient flow is the key factor behind the performance of CSPDenseNet.

  2. Prevent computational bottlenecks

    • Since the number of feature maps (channels) entering the dense block becomes half that of the original dense block, the computational bottleneck caused by a large gap between the number of channels and the growth rate is also alleviated.
  3. Reduce memory traffic

    • The computation required for a dense block is $\large (c \times n) + \{n \times (n + 1) \times k\}$, where c is the number of channels, n is the number of dense layers, and k is the growth rate.

    • As the number of channels c, which is usually far greater than n and k, is reduced by half, memory traffic can be saved by nearly half as well.


Neck : SPP, SAM, PAN


  • There are several layers inserted between the backbone and the head (which makes the class and box predictions).

  • These layers are typically used to integrate and re-organize the feature maps extracted from the backbone into more comprehensive and semantically strong features that are robust to scale changes of objects.


Spatial Pyramid Pooling (SPP)


  • YOLOv4 adds an SPP block on top of the backbone to process the topmost feature maps (512 x 13 x 13).

  • The SPP block applies multiple MaxPooling layers with different kernel sizes (5, 9, 13) to the input in parallel and concatenates each output (together with the unpooled input) to enlarge the receptive field (final output : 512*4 x 13 x 13).


class SPPNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
                        nn.BatchNorm2d(512),
                        Mish(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(2048, 2048, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(2048),
            Mish(),
        )

        self.maxpool5 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5//2)
        self.maxpool9 = nn.MaxPool2d(kernel_size=9, stride=1, padding=9//2)
        self.maxpool13 = nn.MaxPool2d(kernel_size=13, stride=1, padding=13//2)

    def forward(self, x):
        x = self.conv1(x)             # [B, 512, 13, 13] for a 416 x 416 input
        maxpool5 = self.maxpool5(x)   # stride-1 pooling keeps the spatial size
        maxpool9 = self.maxpool9(x)
        maxpool13 = self.maxpool13(x)
        x = torch.cat([x, maxpool5, maxpool9, maxpool13], dim=1)   # [B, 2048, 13, 13]
        x = self.conv2(x)
        return x


Spatial Attention Module (SAM)




  • Instead of applying max and average pooling to produce the attention (as in the original spatial attention module), YOLOv4's modified SAM applies a 3 x 3 convolution to produce point-wise (pixel-wise) attention with the same shape as the input feature map.

  • A sigmoid activation is applied to the computed attention to obtain probabilistic attention scores (0 ~ 1), which are then multiplied element-wise with the target feature maps (a minimal sketch is shown below).

  • There are other methods for assigning attention to backbone features, such as the Squeeze-and-Excitation (SE) module, but SE increases inference time by about 10%, whereas SAM only costs 0.1% extra computation while slightly improving (0.5%) the SE-based ResNet50 model.

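A minimal sketch of the modified (point-wise) SAM described above; the kernel size and channel handling are illustrative assumptions, not taken from the official implementation, and the module name is hypothetical.

import torch
import torch.nn as nn

class ModifiedSAM(nn.Module):
    # point-wise spatial attention: conv -> sigmoid -> element-wise multiplication
    def __init__(self, channels):
        super().__init__()
        # 3 x 3 convolution keeps the spatial size and channel count of the input
        self.conv = nn.Conv2d(channels, channels, 3, stride=1, padding=1, bias=False)

    def forward(self, x):
        attention = torch.sigmoid(self.conv(x))   # attention scores in (0, 1), same shape as x
        return x * attention                      # re-weight every position of every channel

# usage: the output has the same shape as the input feature map
# sam = ModifiedSAM(512)
# out = sam(torch.randn(1, 512, 13, 13))
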

Path Aggregation Networks (PAN)



  • YOLOv4 adopts PANet to build stronger feature maps that combine features from multiple levels of the pyramid, which are far more helpful for the following predictions than features made from a single level.

  • A detailed explanation of the PANet architecture is HERE.

  • A modification from the original PANet : concatenation is used instead of addition in the aggregation pathway.


class PANet(nn.Module):
    def __init__(self):
        super(PANet, self).__init__()

        self.p52d5 = nn.Sequential(nn.Conv2d(2048, 512, 1, stride=1, padding=0, bias=False),
                                   nn.BatchNorm2d(512),
                                   Mish(),
                                   nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias=False),
                                   nn.BatchNorm2d(1024),
                                   Mish(),
                                   nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
                                   nn.BatchNorm2d(512),
                                   Mish(),
                                   )

        self.p42p4_ = nn.Sequential(nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                    nn.BatchNorm2d(256),
                                    Mish(),
                                    )

        self.p32p3_ = nn.Sequential(nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False),
                                    nn.BatchNorm2d(128),
                                    Mish(),
                                    )

        self.d5_p4_2d4 = nn.Sequential(nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                       nn.BatchNorm2d(256),
                                       Mish(),
                                       nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(512),
                                       Mish(),
                                       nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                       nn.BatchNorm2d(256),
                                       Mish(),
                                       nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(512),
                                       Mish(),
                                       nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                       nn.BatchNorm2d(256),
                                       Mish(),
                                       )

        self.d4_p3_2d3 = nn.Sequential(nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False),
                                       nn.BatchNorm2d(128),
                                       Mish(),
                                       nn.Conv2d(128, 256, 3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(256),
                                       Mish(),
                                       nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False),
                                       nn.BatchNorm2d(128),
                                       Mish(),
                                       nn.Conv2d(128, 256, 3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(256),
                                       Mish(),
                                       nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False),
                                       nn.BatchNorm2d(128),
                                       Mish(),
                                       )

        self.d52d5_ = nn.Sequential(nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                    nn.BatchNorm2d(256),
                                    Mish(),
                                    nn.Upsample(scale_factor=2)
                                    )

        self.d42d4_ = nn.Sequential(nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False),
                                    nn.BatchNorm2d(128),
                                    Mish(),
                                    nn.Upsample(scale_factor=2)
                                    )

        self.u32u3_ = nn.Sequential(nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
                                    nn.BatchNorm2d(256),
                                    Mish())

        self.u42u4_ = nn.Sequential(nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
                                    nn.BatchNorm2d(512),
                                    Mish())

        self.d4u3_2u4 = nn.Sequential(nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                      nn.BatchNorm2d(256),
                                      Mish(),

                                      nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=False),
                                      nn.BatchNorm2d(512),
                                      Mish(),

                                      nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                      nn.BatchNorm2d(256),
                                      Mish(),

                                      nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=False),
                                      nn.BatchNorm2d(512),
                                      Mish(),

                                      nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
                                      nn.BatchNorm2d(256),
                                      Mish(),
                                      )

        self.d5u4_2u5 = nn.Sequential(nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
                                      nn.BatchNorm2d(512),
                                      Mish(),

                                      nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias=False),
                                      nn.BatchNorm2d(1024),
                                      Mish(),

                                      nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
                                      nn.BatchNorm2d(512),
                                      Mish(),

                                      nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias=False),
                                      nn.BatchNorm2d(1024),
                                      Mish(),

                                      nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
                                      nn.BatchNorm2d(512),
                                      Mish(),
                                      )

    def forward(self, P5, P4, P3):
        D5 = self.p52d5(P5)    # [B, 512, 13, 13]
        D5_ = self.d52d5_(D5)  # [B, 256, 26, 26]
        P4_ = self.p42p4_(P4)  # [B, 256, 26, 26]
        D4 = self.d5_p4_2d4(torch.cat([D5_, P4_], dim=1))   # [B, 256, 26, 26]
        D4_ = self.d42d4_(D4)                               # [B, 128, 52, 52]
        P3_ = self.p32p3_(P3)                               # [B, 128, 52, 52]
        D3 = self.d4_p3_2d3(torch.cat([D4_, P3_], dim=1))   # [B, 128, 52, 52]

        U3 = D3                                             # [B, 128, 52, 52]   V
        U3_ = self.u32u3_(U3)
        U4 = self.d4u3_2u4(torch.cat([D4, U3_], dim=1))     # [B, 256, 26, 26]   V
        U4_ = self.u42u4_(U4)                               # [B, 512, 13, 13]
        U5 = self.d5u4_2u5(torch.cat([D5, U4_], dim=1))     # [B, 512, 13, 13]   V

        return [U5, U4, U3]


Head : YOLOv3 (Full YOLOv4 Model)


  • Extracts features at 3 scales (13, 26, 52), for large, medium, and small objects respectively.


class YOLOv4(nn.Module):
    def __init__(self, backbone, num_classes=80):
        super(YOLOv4, self).__init__()
        self.num_classes = num_classes
        self.backbone = backbone
        self.SPP = SPPNet()
        self.PANet = PANet()

        self.pred_s = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1, bias=False),
                                    nn.BatchNorm2d(256),
                                    Mish(),
                                    nn.Conv2d(256, 3 * (1 + 4 + self.num_classes), 1, stride=1, padding=0))

        self.pred_m = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=False),
                                    nn.BatchNorm2d(512),
                                    Mish(),
                                    nn.Conv2d(512, 3 * (1 + 4 + self.num_classes), 1, stride=1, padding=0))

        self.pred_l = nn.Sequential(nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias=False),
                                    nn.BatchNorm2d(1024),
                                    Mish(),
                                    nn.Conv2d(1024, 3 * (1 + 4 + self.num_classes), 1, stride=1, padding=0))

        print("num_params : ", self.count_parameters())

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x):

        P3 = x = self.backbone.features1(x)  # [B, 256, 52, 52]
        P4 = x = self.backbone.features2(x)  # [B, 512, 26, 26]
        P5 = x = self.backbone.features3(x)  # [B, 1024, 13, 13]

        P5 = self.SPP(P5)
        U5, U4, U3 = self.PANet(P5, P4, P3)

        p_l = self.pred_l(U5).permute(0, 2, 3, 1)  # B, 13, 13, 255
        p_m = self.pred_m(U4).permute(0, 2, 3, 1)  # B, 26, 26, 255
        p_s = self.pred_s(U3).permute(0, 2, 3, 1)  # B, 52, 52, 255

        return [p_l, p_m, p_s]


Selection of BoF and BoS


1. Bag of Freebies (BoF)


  • BoF refers to methods that only change the training strategy (or increase only the training cost) to enhance model performance without increasing inference time.

  • Here are several BoFs that can be utilized; the ones adopted in this paper are marked in red in the original post.


1.1. Data Augmentation


  • Pixel-wise adjustments : photometric distortions (brightness, contrast, hue, saturation, noise, etc.) and geometric distortions (rotation, scaling, cropping, flipping, etc.)

  • Simulating object occlusion :

    • Random Erase, CutOut, GridMask : randomly select single or multiple rectangular region(s) and replace them with random values or zeros.

    • DropOut, DropConnect, DropBlock : similar approaches applied to feature maps instead of input images.

    • MixUp, CutMix : combine multiple images (and their labels) together; a minimal CutMix sketch is shown below.

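A minimal CutMix sketch for classification-style training of the backbone, assuming image tensors of shape [B, C, H, W] with values in [0, 1] and integer class labels; mixing the one-hot targets by patch area follows the CutMix paper, and the helper name is illustrative.

import torch
import torch.nn.functional as F

def cutmix(images, labels, num_classes, alpha=1.0):
    # paste a random patch from a shuffled copy of the batch and mix the labels by area
    B, _, H, W = images.shape
    lam = torch.distributions.Beta(alpha, alpha).sample().item()   # mixing ratio
    perm = torch.randperm(B)

    # sample a box whose area is roughly (1 - lam) of the image
    cut_h, cut_w = int(H * (1 - lam) ** 0.5), int(W * (1 - lam) ** 0.5)
    cy, cx = torch.randint(H, (1,)).item(), torch.randint(W, (1,)).item()
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)

    mixed = images.clone()
    mixed[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]

    # adjust lambda to the actual pasted area and mix one-hot targets accordingly
    lam = 1 - ((y2 - y1) * (x2 - x1) / (H * W))
    one_hot = F.one_hot(labels, num_classes).float()
    targets = lam * one_hot + (1 - lam) * one_hot[perm]
    return mixed, targets
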

1.2. Semantic Distribution


  • In object detection, it is common to encounter a significant data imbalance between object/no-object and different classes. This imbalance can lead the model to be biased towards over-represented categories and negatively impact its ability to generalize well for the rare categories.

  • The issue can be handled differently depending on the type of object detector used.


1.2.1. Two-Stage Object Detector


  • In a two-stage object detector, where there is a separate network for generating region proposals, one can use hard negative example mining or online hard example mining (OHEM).

  • Hard Negative Example Mining

    • Running through batches, false positive examples predicted in the previous batch are added to the following batches.

    • By feeding the model challenging examples that it failed to classify correctly, one encourages the model to improve its ability to distinguish between positive and negative examples.

    • However, this method can be computationally inefficient, as it repeats the processes of selecting false positives, adding them to the dataset, and re-organizing the mini-batch.

  • OHEM

    • Forward pass over all extracted RoIs to compute losses, then backward pass and update only for the RoIs with the highest losses (top B out of N examples) picked by the hard RoI sampler (a minimal sketch follows this list).


  • Alternatively, sampling heuristics that fix the ratio (e.g. 1:3) of positive (object) and negative (background) examples per mini-batch can also be used.

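A minimal sketch of the OHEM selection step described above, assuming per-RoI losses have already been computed with reduction='none'; B (the number of hard examples kept) and the loss function in the usage lines are illustrative choices.

import torch

def ohem_loss(roi_losses, B=128):
    # keep only the top-B highest per-RoI losses so gradients come from hard examples
    B = min(B, roi_losses.numel())
    hard_losses, _ = torch.topk(roi_losses, B)   # "hard RoI sampler": highest-loss RoIs
    return hard_losses.mean()

# usage sketch:
# per_roi = torch.nn.functional.cross_entropy(cls_logits, cls_targets, reduction='none')
# loss = ohem_loss(per_roi, B=128)
# loss.backward()   # only the selected hard RoIs contribute gradients
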

1.2.2. One-Stage Object Detector


  • Methods used in two-stage detectors are not directly applicable to one-stage detectors, which follow a dense prediction architecture where every grid cell is automatically assigned as an RoI and predicts a fixed number of bounding boxes.

  • One cannot limit the number of negative examples or add extra RoIs that were misclassified in a previous batch.

  • Focal Loss

    • Focal loss was proposed in RetinaNet to deal with the problem of data imbalance.

    • It is a modification of the ordinary cross entropy (CE) loss that puts more focus on hard examples with low ground-truth class probability.

    $\large FL(p_{t}) = -(1\,-\,p_{t})^{\gamma}\,\log(p_{t})$

    • The modulating factor $\large (1\,-\,p_{t})^{\gamma}$ automatically adjusts the relative loss according to the ground-truth class probability, reducing the contribution of easy examples (high probability) and putting greater focus on hard, rare examples.

    • $\large \gamma$ is a focusing parameter that lies between 0 and 5; experimenting with several values, 2 seems to work best (a minimal implementation follows this list).

  • Label Smoothing

    • Also modifies CE, converting the hard one-hot encoding of ground-truth labels into a soft encoding (adding an extra uniform-distribution term).

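A minimal focal loss sketch for binary targets (e.g. objectness), following the formula above; the alpha balancing term from the RetinaNet paper is included as an optional assumption, and targets are expected as float 0/1 values.

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    # down-weights easy examples via the (1 - p_t)^gamma modulating factor
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = p * targets + (1 - p) * (1 - targets)               # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)   # class-balancing weight
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()
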

1.3. BBox Regression


  • Traditional object detection applies MSE to the coordinates of predicted boxes and ground truths, which ignores the integrity of the object itself by treating each coordinate of the bbox independently.

  • Recently, IoU-based losses that consider the coverage of the predicted bbox and the ground-truth bbox have been proposed to handle the problems of the traditional MSE-based loss.

    • IoU loss = 1 - IoU

    image

  • While MSE (L2 loss) can give the same value for predictions with very different overlaps, IoU-based losses dynamically return different values depending on the relative localization of the boxes.

  • There are multiple variants of the standard IoU loss, such as GIoU, DIoU, and CIoU, each with a distinct formula for the final loss (a minimal sketch of IoU, GIoU, and DIoU follows this list).

    • Generalized IoU (GIoU) : considers the minimum rectangular area C that simultaneously covers both target bboxes (predicted A and ground truth B).

      • GIoU = IoU - |C \ (A ∪ B)| / |C|, where |C \ (A ∪ B)| is the area of C not covered by A and B.


      image

    • Distance IoU (DIoU) : considers the normalized distance between the centers of the two bboxes.

      • DIoU = IoU - $\large \frac{\rho^{2}(b, b^{gt})}{c^{2}}$, where $\rho$ is the distance between the box centers b and $b^{gt}$, and c is the diagonal length of the smallest enclosing box.


    • Complete IoU (CIoU) : adds an extra aspect-ratio consistency term to DIoU.

      • CIoU = IoU - $\large \frac{\rho^{2}(b, b^{gt})}{c^{2}} - \alpha v$, where $\large v = \frac{4}{\pi^{2}} \left( \arctan\frac{w^{gt}}{h^{gt}} - \arctan\frac{w}{h} \right)^{2}$ measures aspect-ratio consistency and $\alpha$ is a positive trade-off weight.

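A minimal sketch of IoU, GIoU, and DIoU losses for axis-aligned boxes in (x1, y1, x2, y2) format, following the definitions above; a small epsilon is added for numerical stability, and the function name is illustrative.

import torch

def iou_losses(pred, gt, eps=1e-7):
    # returns (iou_loss, giou_loss, diou_loss) for boxes given as (x1, y1, x2, y2)
    x1, y1 = torch.max(pred[..., 0], gt[..., 0]), torch.max(pred[..., 1], gt[..., 1])
    x2, y2 = torch.min(pred[..., 2], gt[..., 2]), torch.min(pred[..., 3], gt[..., 3])
    inter = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    area_p = (pred[..., 2] - pred[..., 0]) * (pred[..., 3] - pred[..., 1])
    area_g = (gt[..., 2] - gt[..., 0]) * (gt[..., 3] - gt[..., 1])
    union = area_p + area_g - inter
    iou = inter / (union + eps)

    # smallest enclosing box C
    cx1, cy1 = torch.min(pred[..., 0], gt[..., 0]), torch.min(pred[..., 1], gt[..., 1])
    cx2, cy2 = torch.max(pred[..., 2], gt[..., 2]), torch.max(pred[..., 3], gt[..., 3])
    c_area = (cx2 - cx1) * (cy2 - cy1)
    giou = iou - (c_area - union) / (c_area + eps)

    # squared center distance over squared diagonal of C
    center_dist = ((pred[..., 0] + pred[..., 2]) - (gt[..., 0] + gt[..., 2])) ** 2 / 4 + \
                  ((pred[..., 1] + pred[..., 3]) - (gt[..., 1] + gt[..., 3])) ** 2 / 4
    diag = (cx2 - cx1) ** 2 + (cy2 - cy1) ** 2
    diou = iou - center_dist / (diag + eps)

    return 1 - iou, 1 - giou, 1 - diou
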

2. Bag of Specials (BoS)


  • BoS represents the modules and post-processing methods that slightly increase inference cost but significantly improve the model's accuracy.

  • It includes the previously explained modules used in the model architecture, such as SPP, SAM, and PAN.


2.1. Activation Functions


  • Choosing a good activation function further enhances gradient flow through the network and increases the expressiveness of the entire model.

  • Starting from ReLU, many modified versions have been proposed, including LReLU, PReLU, ReLU6, Scaled Exponential Linear Unit (SELU), Swish, hard-Swish, and Mish.

  • Among these, YOLOv4 adopts Mish, which has no upper bound (preventing saturation) and allows negative values within a limited range, producing better gradient flow (a minimal implementation is shown below).

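A minimal Mish module matching the Mish() used in the code blocks above: Mish(x) = x · tanh(softplus(x)).

import torch
import torch.nn as nn
import torch.nn.functional as F

class Mish(nn.Module):
    # smooth activation, unbounded above (no saturation) and only slightly bounded below
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

Recent PyTorch versions also provide this activation directly as nn.Mish.
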

2.2. DIoU Non-Maximal Suppression (NMS)


  • Use DIoU in NMS to filter out bboxes that capture the same object as an already-selected bbox.

  • By using DIoU instead of the standard IoU, the model considers the center-point distance as well as the overlap area (a minimal sketch is shown below).

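A minimal greedy DIoU-NMS sketch, reusing the DIoU idea above; boxes are (x1, y1, x2, y2) and the threshold value is an illustrative default.

import torch

def diou_nms(boxes, scores, threshold=0.5, eps=1e-7):
    # greedy NMS that suppresses boxes whose DIoU with a kept box exceeds the threshold
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        b, rest = boxes[i], boxes[order[1:]]

        # IoU between the kept box and the remaining boxes
        x1, y1 = torch.max(b[0], rest[:, 0]), torch.max(b[1], rest[:, 1])
        x2, y2 = torch.min(b[2], rest[:, 2]), torch.min(b[3], rest[:, 3])
        inter = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
        union = (b[2] - b[0]) * (b[3] - b[1]) + \
                (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1]) - inter
        iou = inter / (union + eps)

        # center-distance penalty over the enclosing-box diagonal
        center_dist = ((b[0] + b[2]) - (rest[:, 0] + rest[:, 2])) ** 2 / 4 + \
                      ((b[1] + b[3]) - (rest[:, 1] + rest[:, 3])) ** 2 / 4
        cw = torch.max(b[2], rest[:, 2]) - torch.min(b[0], rest[:, 0])
        ch = torch.max(b[3], rest[:, 3]) - torch.min(b[1], rest[:, 1])
        diou = iou - center_dist / (cw ** 2 + ch ** 2 + eps)

        order = order[1:][diou <= threshold]   # drop boxes that overlap a kept box too much
    return keep
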

Ablation Studies of BoF and BoS


  • Further, to investigate the practical contribution of each feature in BoF and BoS to detector accuracy, the authors performed ablation studies, removing certain components and observing how the model performance is affected.

image

image

  • S : Eliminate grid sensitivity
    • YOLOv3 used the equations bx = σ(tx) + cx, by = σ(ty) + cy, where cx and cy are always whole numbers, to compute the object center coordinates.
    • Extremely large absolute values of tx are required for bx to approach cx or cx + 1 (as the sigmoid output approaches 0 or 1).
    • This is solved by multiplying the sigmoid by a factor exceeding 1.0 (e.g. 2), eliminating the effect of the grid on which the object would otherwise be undetectable (a minimal decoding sketch follows this list).
  • M : Mosaic data augmentation - using a 4-image mosaic during training instead of a single image.
  • IT : IoU threshold - using multiple anchors for a single ground truth if IoU(truth, anchor) > IoU threshold.
  • GA : Genetic algorithms - using genetic algorithms to select optimal hyperparameters during the first 10% of training time.
  • LS : Class label smoothing
  • CBN : CmBN - using Cross mini-Batch Normalization to collect statistics over the entire batch, instead of a single mini-batch.
  • CA : Cosine annealing scheduler - altering the learning rate following a cosine schedule during training.
  • DM : Dynamic mini-batch size - automatically increasing the mini-batch size during small-resolution training when using random training shapes.
  • OA : Optimized Anchors - using the optimized anchors for training with the 512x512 network resolution.
  • GIoU, CIoU, DIoU, MSE - using different loss functions for bounding box regression.

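A minimal sketch of the grid-sensitivity fix from the S item above: the sigmoid output is scaled (commonly by 2.0) and re-centered so the predicted center can actually reach the cell borders; the exact scale and offset here follow common YOLOv4 implementations and are stated as an assumption.

import torch

def decode_center(tx, ty, cx, cy, scale=2.0):
    # bx = scale * sigmoid(tx) - (scale - 1) / 2 + cx  (and likewise for by);
    # with scale = 1.0 this reduces to the original YOLOv3 decoding, where the center
    # can never quite reach cx or cx + 1 without extreme values of tx
    bx = scale * torch.sigmoid(tx) - (scale - 1) / 2 + cx
    by = scale * torch.sigmoid(ty) - (scale - 1) / 2 + cy
    return bx, by
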

Optimal Combination


  • To sum up, this is the final combination of deep learning techniques selected for each part of YOLOv4 after thorough experiments.


Model Architecture
  • Backbone : CSPDarknet53
  • Neck : SPP, PAN
  • Head : YOLOv3
Bag of Freebies(BoF) for backbone
  • CutMix and Mosaic data augmentation
  • Dropblock regularization
  • Class label smoothing
Bag of Specials(BoS) for backbone
  • Mish activation
  • Cross-stage partial connections(CSP)
  • Multi-input weighted residual connections(MiWRC)
Bag of Freebies(BoF) for detector
  • CIoU-loss
  • CmBN (Cross Mini-Batch Normalization)
  • Self-Adversarial Training (SAT)
    • Operates in two stages.
    • In the first stage, the network alters the original image instead of the network weights (an adversarial attack), modifying the image to create the deception that the desired object is not present in the image.
    • In the second stage, the network is trained in the normal way on the altered image from the first stage (a minimal sketch follows this list).
  • Eliminate grid sensitivity
  • Using multiple anchors for a single ground truth
  • Cosine annealing scheduler
Bag of Specials(BoS) for detector
  • Mish activation
  • SPP-block
  • SAM-block
  • PAN
  • DIoU-NMS

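A minimal SAT-style sketch following the description above: one FGSM-like step perturbs the image in the direction that makes the objects harder for the network to detect, then the network trains on the perturbed image with the original targets; the single perturbation step, the step size, and the assumption of pixel values in [0, 1] are illustrative choices.

import torch

def sat_step(model, images, targets, loss_fn, optimizer, epsilon=0.01):
    # Stage 1: adversarial attack on the input image (the network weights are not updated)
    images = images.clone().detach().requires_grad_(True)
    loss = loss_fn(model(images), targets)
    loss.backward()
    with torch.no_grad():
        # move the image in the direction that increases the detection loss
        adv_images = (images + epsilon * images.grad.sign()).clamp(0, 1)

    # Stage 2: ordinary training step on the altered image with the original targets
    optimizer.zero_grad()
    loss = loss_fn(model(adv_images), targets)
    loss.backward()
    optimizer.step()
    return loss.item()
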


image