[Paper Review] Sharpness-Aware Minimization for Efficiently Improving Generalization (SAM, 2021)
Outlines
- References
- 1. Weak Generalization Power of Sharp Minima
- 2. Sharpness-Aware Minimization (SAM)
- 3. Empirical Evaluation
References
- Sharpness-Aware Minimization for Efficiently Improving Generalization, Foret et al., 2021
- Proving that the dual of the lp norm is the lq norm
- An Introduction to PAC-Bayes
1. Weak Generalization Power of Sharp Minima
- Overparameterizing a model with the sole objective of minimizing the training loss can result in a suboptimal model that fails to generalize over the entire data distribution.
- Visualizing the loss landscape of such an overfitted model reveals sharp minima, where the curvature of the loss surface becomes extreme near the minimum.
- An overly peaked loss minimum indicates that the model may fail to perform stably under deviations from the data it was originally fitted to during training.
- Thus, leveraging the geometry of the landscape to reach flatter minima can yield better generalization to the test set.
- The figure above gives visual intuition that a flat minimum tends to be more robust to the deviation between the training and test loss functions, with a smaller generalization gap compared to a sharp minimum.
- In order to alleviate loss sharpness and achieve better generalization, this paper proposes extending the minimization objective to the neighborhoods of the parameters, not just the parameters themselves.
2. Sharpness-Aware Minimization (SAM)
- Notations
    - Training dataset : $\large S = \{(x_{i}, y_{i})\}_{i=1}^{n}$, drawn i.i.d. from the population distribution
    - Population distribution : $\large \mathcal{D}$
    - Training loss : $\large L_{S}(w) = \frac{1}{n}\sum_{i=1}^{n} l(w, x_{i}, y_{i})$
    - Population loss : $\large L_{\mathcal{D}}(w) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\left[\, l(w, x, y) \,\right]$
- The typical approach to finding the parameters is to solve $\large \min_{w} L_{S}(w)$, which can easily result in suboptimal performance at test time.
- Instead, SAM seeks out parameters whose entire bounded neighborhoods have uniformly low training loss, so that the loss landscape around them is wide and flat.
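- Concretely, this corresponds to a min-max problem of the form below (the full objective, with an additional L2 regularization term, is formalized in Section 2.2) :
$\large \min_{w} \; \max_{\|\epsilon\|_{p} \leq \rho} L_{S}(w+\epsilon)$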
2.1. PAC Bayesian Generalization Bound
- Probably Approximately Correct (PAC) Bayes Bound
- PAC Goal : with high probability (at least $1-\delta$), the empirical loss is approximately correct, i.e., its deviation from the true risk is bounded by a small value.
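- One standard statement of such a bound (McAllester's PAC-Bayes bound, given here as a reference form; the paper builds on a variant of it) : with probability at least $1-\delta$ over the draw of $\large S$,
$\large \mathbb{E}_{w \sim Q}\left[L_{\mathcal{D}}(w)\right] \;\leq\; \mathbb{E}_{w \sim Q}\left[L_{S}(w)\right] + \sqrt{\frac{KL(Q \,\|\, P) + \log \frac{n}{\delta}}{2(n-1)}}$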
- This inequality holds for any prior P over the parameters and any posterior Q over the parameters.
- $\large n$ : the size of the training dataset $\large S$
- $\large Q$ : posterior over hypotheses that depends on the training dataset $\large S$
- $\large P$ : prior over hypotheses that does NOT depend on the training dataset $\large S$
- Assuming that the prior and the posterior each follow a (distinct) normal distribution, the KL divergence between them is as follows.
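- For reference, with $\large Q = \mathcal{N}(\mu_{Q}, \Sigma_{Q})$ and $\large P = \mathcal{N}(\mu_{P}, \Sigma_{P})$ in $\large k$ dimensions, the standard closed form is (the linked derivation may use a more specific parameterization, e.g. isotropic covariances) :
$\large KL(Q \,\|\, P) = \frac{1}{2}\left[\, \mathrm{tr}(\Sigma_{P}^{-1}\Sigma_{Q}) + (\mu_{P}-\mu_{Q})^{T}\Sigma_{P}^{-1}(\mu_{P}-\mu_{Q}) - k + \log\frac{\det \Sigma_{P}}{\det \Sigma_{Q}} \,\right]$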
- The full derivation can be checked HERE.
2.2. SAM Objective
- SAM aims to minimize the training loss not only at the parameters themselves but also over their bounded neighborhoods.
- Starting from the PAC-Bayes bound derived over the parameter set, SAM extends the generalization bound to account for the sharpness of the loss.
- More formally, for any $\large \rho > 0$, with probability at least $1-\delta$ over the training set $\large S$, the inequality is as follows :
$\large L_{\mathcal{D}}(w) \;\leq\; \max_{\|\epsilon\|_{2} \leq \rho} L_{S}(w+\epsilon) + h\!\left(\|w\|_{2}^{2}/\rho^{2}\right)$
- The proof required to derive the inequality above is provided in the paper (just focus on the fact that $\large h$ here is still an increasing function of $\large \|w\|_{2}^{2}$).
- Intuition for why SAM leverages the sharpness of the loss landscape :
- Rewriting the RHS :
$\large \left[\, \max_{\|\epsilon\|_{2} \leq \rho} L_{S}(w+\epsilon) - L_{S}(w) \,\right] + L_{S}(w) + h\!\left(\|w\|_{2}^{2}/\rho^{2}\right)$
- The term in square brackets captures the loss sharpness, so minimizing the sum of the first two terms, i.e. $\large \max_{\|\epsilon\|_{2} \leq \rho} L_{S}(w + \epsilon)$, flattens the curvature around the neighborhoods of the parameters.
- Given that the specific function $\large h$ (an increasing function of $\large \|w\|_{2}^{2}$) is heavily influenced by the details of the proof, it is substituted with $\large \lambda \|w\|_{2}^{2}$ for a hyperparameter $\large \lambda$, yielding a standard L2 regularization term.
- Then the SAM objective simplifies to :
$\large \min_{w} \; L_{S}^{SAM}(w) + \lambda\|w\|_{2}^{2}, \quad \text{where} \;\; L_{S}^{SAM}(w) \triangleq \max_{\|\epsilon\|_{p} \leq \rho} L_{S}(w+\epsilon)$
- To sum up, the SAM objective minimizes the maximum of the loss over a bounded neighborhood of the parameters (summed over a batch), together with an L2 regularization on the magnitude of the parameters, and this improves the generalization of the predictor by limiting the upper bound on the true loss ($\large L_{\mathcal{D}}(w)$).
- Given the SAM objective, the parameters can be optimized to minimize $\large L_{S}^{SAM}(w)$ by stochastically updating them with gradient descent, using the gradient $\large g = \nabla_{w} L_{S}^{SAM}(w)$.
- First, to express the inner maximizer $\large \hat{\epsilon}(w)$ in a solvable form w.r.t. $\large \epsilon$, approximate it via a first-order Taylor expansion (this strategy is valid as $\large \epsilon$ is close to 0) :
$\large \hat{\epsilon}(w) \triangleq \underset{\|\epsilon\|_{p} \leq \rho}{\text{argmax}} \; L_{S}(w+\epsilon) \;\approx\; \underset{\|\epsilon\|_{p} \leq \rho}{\text{argmax}} \; L_{S}(w) + \epsilon^{T}\nabla_{w}L_{S}(w)$
- As $\large L_{S}(w)$ is a constant with respect to $\large \epsilon$ (it is determined by the data), the problem reduces to finding the $\large \epsilon$ that maximizes the second term.
- The solution $\large \hat{\epsilon}(w)$ to this problem is given by the solution to a classical dual p-norm problem (with $\large 1/p + 1/q = 1$) :
$\large \hat{\epsilon}(w) = \rho \, \text{sign}\!\left(\nabla_{w} L_{S}(w)\right) \frac{\left|\nabla_{w} L_{S}(w)\right|^{q-1}}{\left(\|\nabla_{w} L_{S}(w)\|_{q}^{q}\right)^{1/p}}$
- Proof for Solving the Dual p-Norm Problem : see the reference "Proving that the dual of the lp norm is the lq norm" above.
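- In practice the paper uses $\large p = 2$, for which the solution above reduces to a simple rescaling of the gradient :
$\large \hat{\epsilon}(w) = \rho \, \frac{\nabla_{w} L_{S}(w)}{\|\nabla_{w} L_{S}(w)\|_{2}}$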
- Thus, substituting $\large \hat{\epsilon}(w)$ back gives $\large L_{S}^{SAM}(w) \approx L_{S}(w + \hat{\epsilon}(w))$.
- Then differentiating $\large L_{S}^{SAM}(w)$ gives
$\large \nabla_{w} L_{S}^{SAM}(w) \approx \nabla_{w} L_{S}(w+\hat{\epsilon}(w)) = \left.\nabla_{w} L_{S}(w)\right|_{w+\hat{\epsilon}(w)} + \frac{d\hat{\epsilon}(w)}{dw} \left.\nabla_{w} L_{S}(w)\right|_{w+\hat{\epsilon}(w)}$
- As the second term contains $\large \frac{d\hat{\epsilon}(w)}{dw}$, which involves the Hessian of the loss with respect to $\large w$ and is too expensive to compute, SAM drops it from the gradient formula.
- Then the final approximation of the SAM gradient is
$\large \nabla_{w} L_{S}^{SAM}(w) \approx \left.\nabla_{w} L_{S}(w)\right|_{w+\hat{\epsilon}(w)}$
- SGD with the SAM Objective
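To make the update concrete, below is a minimal PyTorch-style sketch of one SAM step with $p = 2$ (not the authors' implementation; `model`, `loss_fn`, `x`, `y`, `base_optimizer`, and `rho` are illustrative names):

```python
import torch

def sam_step(model, loss_fn, x, y, base_optimizer, rho=0.05):
    """A sketch of one SAM update with p = 2:
    1) gradient at w, 2) eps_hat = rho * grad / ||grad||_2,
    3) gradient at w + eps_hat, 4) restore w and apply the base optimizer."""
    # 1) Gradient of the training loss at the current weights w
    base_optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    params = [p for p in model.parameters() if p.grad is not None]

    with torch.no_grad():
        # 2) Dual-norm solution for p = 2: eps_hat = rho * g / ||g||_2
        grad_norm = torch.norm(torch.stack([p.grad.norm() for p in params]))
        eps = [rho * p.grad / (grad_norm + 1e-12) for p in params]
        for p, e in zip(params, eps):   # perturb: w <- w + eps_hat
            p.add_(e)

    # 3) SAM gradient: gradient of L_S evaluated at w + eps_hat
    base_optimizer.zero_grad()
    loss_fn(model(x), y).backward()

    with torch.no_grad():
        for p, e in zip(params, eps):   # restore: w <- w - eps_hat
            p.sub_(e)

    # 4) Update the original weights w using the SAM gradient
    base_optimizer.step()
```

Each step therefore costs two forward-backward passes; the $\large \lambda \|w\|_{2}^{2}$ term can be realized through the base optimizer's weight decay.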
3. Empirical Evaluation
- m-Sharpness
- The generalization power of the model grows as the batch size m used for the SAM inner maximization decreases, as demonstrated by the experiment showing that the correlation between m-sharpness and generalization is higher for smaller m.
- Smaller m also tends to be more sensitive to changes in $\large \rho$, the bound on the magnitude of $\large \epsilon$, which means that the effect of the SAM objective becomes more significant at smaller batch sizes.
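- Roughly speaking (this is a loose formalization, not necessarily the paper's exact notation), m-sharpness replaces the single inner maximization over the full training set with independent maximizations over disjoint subsets $\large S_{j}$ of size $\large m$, i.e. something like
$\large \sum_{j} \max_{\|\epsilon_{j}\|_{2} \leq \rho} L_{S_{j}}(w + \epsilon_{j})$
so that smaller m imposes a stricter, more local notion of sharpness.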
- Hessian Spectra
- The spectrum of the Hessian (specifically the ratio $\large \lambda_{max}/\lambda_{5}$) is a widely used measure of loss sharpness ($\large \lambda_{i}$ here stands for the $i$-th largest eigenvalue of the Hessian matrix at convergence).
- (figure) left : standard SGD / right : SGD with the SAM objective
- As the number of epochs increases, the distribution of eigenvalues becomes more left-shifted and the Hessian spectrum ratio decreases.
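Such eigenvalues are typically approximated with Lanczos-type methods on Hessian-vector products; as a rough illustration of the idea (a sketch, not the paper's tooling), the snippet below estimates only $\large \lambda_{max}$ via power iteration:

```python
import torch

def top_hessian_eigenvalue(loss, params, iters=50):
    """Estimate lambda_max of the Hessian of `loss` w.r.t. `params` by power
    iteration on Hessian-vector products (assumes every param affects `loss`)."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]
    eig = None
    for _ in range(iters):
        # Normalize the current direction v
        v_norm = torch.sqrt(sum((vi ** 2).sum() for vi in v))
        v = [vi / v_norm for vi in v]
        # Hessian-vector product: Hv = d(grad . v)/dw
        gv = sum((g * vi).sum() for g, vi in zip(grads, v))
        hv = torch.autograd.grad(gv, params, retain_graph=True)
        # Rayleigh quotient v^T H v (with ||v|| = 1) approximates lambda_max
        eig = sum((h * vi).sum() for h, vi in zip(hv, v))
        v = [h.detach() for h in hv]
    return eig.item()
```

Obtaining $\large \lambda_{5}$ as well would require deflation or a Lanczos routine, which is omitted here.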
- Performance Comparison
- Error rates for fine-tuning various SOTA models (EfficientNet, BiT, ViT, etc.)
- SAM uniformly improves performance relative to fine-tuning without SAM.
- SAM, derived from a PAC-Bayesian bound, successfully improves the generalization power of predictors by leveraging the sharpness of the loss geometry.