Noise Stability Optimization for Flat Minima with Tight Rates
Abstract
We consider minimizing a perturbed function $F(W) = \mathbb{E}_{U}[f(W + U)]$, given a function $f: \mathbb{R}^d \rightarrow \mathbb{R}$ and a random sample $U$ from a distribution $\mathcal{P}$ with mean zero. When $\mathcal{P}$ is the isotropic Gaussian, $F(W)$ is roughly equal to $f(W)$ plus a penalty on the trace of $\nabla^2 f(W)$, scaled by the variance of $\mathcal{P}$. This penalty on the Hessian has the benefit of improving generalization, through PAC-Bayes analysis. It is useful in low-sample regimes, for instance, when a (large) pre-trained model is fine-tuned on a small data set. One way to minimize $F$ is to add $U$ to $W$ and then run SGD. We observe, empirically, that this noise injection does not provide significant gains over SGD in our fine-tuning experiments on three image classification data sets. We design a simple, practical algorithm that adds noise along both $U$ and $-U$, with the option of adding several perturbations and taking their average. We analyze the convergence of this algorithm, showing tight rates on the norm of the output's gradient. We provide a comprehensive empirical analysis of our algorithm, first showing that in an overparameterized matrix sensing problem, it can find solutions with lower test loss than naive noise injection. Then, we compare our algorithm with four sharpness-reducing training methods (such as Sharpness-Aware Minimization; Foret et al., 2021). We find that our algorithm can outperform them by up to 1.8% test accuracy when fine-tuning ResNet on six image classification data sets. It leads to a 17.7% (and 12.8%) reduction in the trace (and largest eigenvalue) of the Hessian matrix of the loss surface. This form of regularization on the Hessian is compatible with $\ell_2$ weight decay (and data augmentation), in the sense that combining both can lead to improved empirical performance.
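The two-sided noise injection described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: it assumes an exact gradient oracle `grad_f` (the paper works with SGD), isotropic Gaussian noise with standard deviation `sigma`, and `k` perturbations averaged per step. All function and parameter names here are hypothetical.

```python
import numpy as np

def two_point_noisy_grad(grad_f, w, sigma, k, rng):
    """Average the gradients at w + u_i and w - u_i over k Gaussian draws u_i."""
    g = np.zeros_like(w)
    for _ in range(k):
        u = sigma * rng.standard_normal(w.shape)  # u ~ N(0, sigma^2 I)
        g += 0.5 * (grad_f(w + u) + grad_f(w - u))
    return g / k

def noise_stability_descent(grad_f, w0, sigma=0.1, k=2, lr=0.1, steps=200, seed=0):
    """Plain gradient descent driven by the two-point noisy gradient estimate."""
    rng = np.random.default_rng(seed)
    w = np.array(w0, dtype=float)
    for _ in range(steps):
        w -= lr * two_point_noisy_grad(grad_f, w, sigma, k, rng)
    return w

# Toy example (an assumption for illustration): f(w) = 0.5 * w^T A w.
A = np.diag([1.0, 4.0])
grad_quadratic = lambda w: A @ w
w_final = noise_stability_descent(grad_quadratic, [1.0, -2.0])
```

On a quadratic, the perturbations along $U$ and $-U$ cancel, so the estimate equals the unperturbed gradient; for a general $f$, the first-order noise term cancels in the same way and only higher-order terms remain.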
 Publication:

arXiv e-prints
 Pub Date:
 June 2023
 DOI:
 10.48550/arXiv.2306.08553
 arXiv:
 arXiv:2306.08553
 Bibcode:
 2023arXiv230608553J
 Keywords:

 Computer Science - Machine Learning;
 Computer Science - Data Structures and Algorithms;
 Mathematics - Optimization and Control;
 Statistics - Machine Learning
 E-Print:
 36 pages, 3 tables