Proc Natl Acad Sci U S A. 2021 Mar 2;118(9):e2015617118. doi: 10.1073/pnas.2015617118.

The inverse variance-flatness relation in stochastic gradient descent is critical for finding flat minima

Yu Feng et al.

Abstract

Despite the tremendous success of the stochastic gradient descent (SGD) algorithm in deep learning, little is known about how SGD finds generalizable solutions at flat minima of the loss function in high-dimensional weight space. Here, we investigate the connection between SGD learning dynamics and the loss function landscape. A principal component analysis (PCA) shows that SGD dynamics follow a low-dimensional drift-diffusion motion in the weight space. Around a solution found by SGD, the loss function landscape can be characterized by its flatness in each PCA direction. Remarkably, our study reveals a robust inverse relation between the weight variance and the landscape flatness in all PCA directions, which is the opposite of the fluctuation-response relation (aka Einstein relation) in equilibrium statistical physics. To understand the inverse variance-flatness relation, we develop a phenomenological theory of SGD based on statistical properties of the ensemble of minibatch loss functions. We find that both the anisotropic SGD noise strength (temperature) and its correlation time depend inversely on the landscape flatness in each PCA direction. Our results suggest that SGD serves as a landscape-dependent annealing algorithm. The effective temperature decreases with the landscape flatness, so the system seeks out (prefers) flat minima over sharp ones. Based on these insights, an algorithm with landscape-dependent constraints is developed to mitigate catastrophic forgetting efficiently when learning multiple tasks sequentially. In general, our work provides a theoretical framework to understand learning dynamics, which may eventually lead to better algorithms for different learning tasks.
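
As a concrete illustration of the kind of learning-dynamics data this analysis starts from, the sketch below records the flattened parameter vector after every SGD update of a small network. The tiny network, the random data, and the use of all parameters (rather than only the weights between the two hidden layers) are stand-ins, not the paper's exact setup; only the quoted hyperparameters B = 50 and α = 0.1 are taken from Fig. 1. The snapshot matrix W feeds the PCA sketch given after the Fig. 1 caption below.

import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(1000, 20)                     # toy inputs (stand-in data)
y = torch.randint(0, 2, (1000,))              # toy binary labels

model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 2))
opt = torch.optim.SGD(model.parameters(), lr=0.1)   # plain SGD, alpha = 0.1
loss_fn = nn.CrossEntropyLoss()

B = 50                                        # minibatch size
snapshots = []                                # one flattened parameter vector per SGD step
for step in range(500):
    idx = torch.randint(0, len(X), (B,))
    opt.zero_grad()
    loss_fn(model(X[idx]), y[idx]).backward()
    opt.step()
    with torch.no_grad():
        snapshots.append(torch.cat([p.flatten() for p in model.parameters()]).clone())

W = torch.stack(snapshots).numpy()            # (steps, N_params): input to the PCA analysis
print(W.shape)
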

Keywords: generalization; loss landscape; machine learning; statistical physics; stochastic gradient descent.

Conflict of interest statement

The authors declare no competing interest.

Figures

Fig. 1. The PCA results and the drift–diffusion motion in SGD. (A) The rank-ordered variance σ_i² in different principal component (PC) directions i. For i ≥ 20, σ_i² decreases with i as a power law i^(−γ) with γ ≈ 2–3. (B) The normalized accumulative variance of the top (n − 1) PCs excluding i = 1. It reaches 90% at n = 35, which is much smaller than the total number of weights N = 2,500 between the two hidden layers. (C) The SGD weight trajectory projected onto the (θ_1, θ_2) plane. The persistent drift motion in θ_1 and the diffusive random motion in θ_2 are clearly shown. (D) The diffusive motion in the (θ_i, θ_j) plane with j > i (≫ 1) randomly chosen (i = 49 and j = 50 shown here). Unless otherwise stated, the hyperparameters used are B = 50 and α = 0.1.
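
A minimal numpy sketch of the PCA analysis behind Fig. 1: diagonalize the covariance of the recorded weight trajectory, read off the rank-ordered variances σ_i², and project the trajectory onto the PC coordinates θ_i(t). The synthetic trajectory below is only a placeholder; in practice W would be the matrix of recorded SGD weight snapshots (for example, the one assembled in the sketch after the abstract), restricted to the exploration phase.

import numpy as np

rng = np.random.default_rng(0)

# Placeholder trajectory: T recorded weight vectors, one row per SGD step.
T, N = 1000, 2500
W = np.cumsum(rng.normal(scale=0.01, size=(T, N)), axis=0)

# Center the trajectory and diagonalize its covariance via SVD.
Wc = W - W.mean(axis=0)
U, S, Vt = np.linalg.svd(Wc, full_matrices=False)

# Rank-ordered variances sigma_i^2 along the PC directions (Fig. 1A).
sigma2 = S**2 / T

# Normalized accumulative variance of the PCs excluding i = 1 (Fig. 1B).
cum_var = np.cumsum(sigma2[1:]) / sigma2[1:].sum()
n_90 = int(np.argmax(cum_var >= 0.9)) + 2     # n such that the top (n - 1) PCs reach 90%

# PC coordinates theta_i(t) of the trajectory (Fig. 1 C and D).
theta = Wc @ Vt.T                             # column i is theta_{i+1}(t)
print(sigma2[:5], n_90, theta.shape)
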
Fig. 2. The loss function landscape and the inverse variance–flatness relation. (A) The loss function profile L_i along the i-th PCA direction. (B) The loss landscape (in log scale). L_i can be fitted better by an inverse Gaussian (the red dashed line) than by a quadratic function (the green dashed line). The definition of the flatness F_i (≡ θ_i^r − θ_i^l) is also shown (see text for details). (C and D) The flatness F_i for different PCA directions i (C) and the inverse relation between the variance σ_i² and the flatness F_i for different choices of minibatch size B and learning rate α (D).
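
The flatness F_i in Fig. 2 is the width θ_i^r − θ_i^l of the region around the solution where the one-dimensional loss profile along the i-th PCA direction stays low. A hedged sketch is below: the toy quadratic loss and the threshold factor c (a fixed multiple of the minimum loss value) are assumptions, since the paper's exact criterion for θ_i^l and θ_i^r is given in its text.

import numpy as np

def loss(w):
    # Toy anisotropic quadratic loss standing in for the training loss L(w);
    # the small offset keeps the minimum value nonzero.
    curv = np.array([10.0, 1.0, 0.1])
    return 0.5 * np.sum(curv * w**2) + 1e-3

def flatness(loss_fn, w_star, v_i, c=np.e, span=5.0, num=2001):
    """Width theta_i^r - theta_i^l of the interval around w_star (along the
    unit direction v_i) where the 1D loss profile stays below c times its
    minimum value.  The factor c is an assumed stand-in for the paper's
    exact criterion."""
    deltas = np.linspace(-span, span, num)
    profile = np.array([loss_fn(w_star + d * v_i) for d in deltas])
    inside = deltas[profile <= c * profile.min()]
    return inside.max() - inside.min()

w_star = np.zeros(3)
for i, v_i in enumerate(np.eye(3)):           # stand-ins for PCA directions p_i
    print(f"F_{i + 1} = {flatness(loss, w_star, v_i):.3f}")

Flatter directions (smaller curvature in the toy loss) come out with larger F_i, which is the ordering the figure relies on.
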
Fig. 3. Statistical properties of the MLF (minibatch loss function) ensemble. (A) Profiles of the overall loss function ln(L_i) (red line) and a set of randomly chosen MLFs ln(L_i^μ) (blue dashed lines) in a given PCA direction i. (B) The inverse dependence of D_i and τ_i on the flatness F_i.
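
A sketch of how the MLF profiles in Fig. 3A can be sampled numerically: evaluate the loss of each minibatch, and of the full data set, along a fixed PCA direction through the solution. The toy least-squares problem, the scan range, and the number of sampled minibatches below are stand-ins for the paper's network and data.

import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 20))                         # toy data set
y = X @ rng.normal(size=20)                            # toy targets
w_star = np.linalg.lstsq(X, y, rcond=None)[0]          # stand-in for an SGD solution
v_i = np.eye(20)[0]                                    # stand-in for the i-th PCA direction

def batch_loss(w, idx):
    r = X[idx] @ w - y[idx]
    return 0.5 * np.mean(r**2)

deltas = np.linspace(-1.0, 1.0, 101)                   # assumed scan range
batches = [rng.choice(len(X), 50, replace=False) for _ in range(10)]   # B = 50

# One loss profile per minibatch, L_i^mu (blue dashed lines in Fig. 3A),
# plus the overall loss profile L_i (red line).
mlf_profiles = np.array([[batch_loss(w_star + d * v_i, b) for d in deltas]
                         for b in batches])
full_profile = np.array([batch_loss(w_star + d * v_i, np.arange(len(X)))
                         for d in deltas])
print(mlf_profiles.shape, full_profile.shape)
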
Fig. 4. The landscape-dependent constraints for avoiding catastrophic forgetting. (A) The test errors for task 1 (ε_1) and task 2 (ε_2) versus training time for task 2 in the absence of the constraints (λ = 0). (B) The weight displacements q_i in different PCA directions p_i^1 from task 1 in the absence of the constraints (λ = 0). (C and D) The same as A and B but in the presence of the constraints with λ = 10 and N_c = 200. The red dashed line in D shows the upper bound q_i ≤ 0.008 F_i^1 for the modes (i ≤ N_c) that are under constraint. (E) The tradeoff between the saturated test errors (ε_1 and ε_2) when varying λ for the LDC (landscape-dependent constraint; blue circles) and EWC (elastic weight consolidation; red squares) algorithms. (F) The overall performance (ε_1 + ε_2) versus the number of constraints N_c for the LDC (blue circles) and EWC (red squares) algorithms. The two tasks are classifying two separate digit pairs [(0,1) for task 1 and (2,3) for task 2] in MNIST.
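
The LDC idea in Fig. 4 is to restrict, while training on task 2, the weight displacement q_i along each of the top N_c PCA directions p_i^1 of task 1, with the allowed displacement set by the task-1 flatness F_i^1. The sketch below uses a quadratic penalty whose strength scales as 1/(F_i^1)²; this specific functional form, and the defaults λ = 10 and N_c = 200 taken from the figure, are assumptions rather than the paper's exact constraint, which is specified in its text.

import numpy as np

def ldc_penalty_grad(w, w1_star, pc_dirs, flatness, lam=10.0, Nc=200):
    """Gradient of an assumed landscape-dependent penalty
    lam * sum_{i<=Nc} (q_i / F_i^1)^2, with q_i = (w - w1_star) . p_i^1.
    Sharper task-1 directions (smaller F_i^1) are penalized more strongly."""
    P = np.asarray(pc_dirs)[:Nc]              # (Nc, N) unit directions p_i^1 from task 1
    F = np.asarray(flatness)[:Nc]             # (Nc,) task-1 flatness values F_i^1
    q = P @ (w - w1_star)                     # displacements q_i along p_i^1
    return 2.0 * lam * P.T @ (q / F**2)

# Toy usage with random stand-ins.  Inside an SGD step on task 2 one would use
#   w -= alpha * (grad_task2 + ldc_penalty_grad(w, w1_star, pc_dirs, flatness))
rng = np.random.default_rng(0)
N = 400
w, w1_star = rng.normal(size=N), rng.normal(size=N)
pc_dirs = np.linalg.qr(rng.normal(size=(N, N)))[0].T[:200]    # orthonormal rows
flatness = np.sort(rng.uniform(0.1, 5.0, size=200))
print(ldc_penalty_grad(w, w1_star, pc_dirs, flatness).shape)
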
Fig. 5. Profiles and dynamics of the anisotropic active temperature. (A) The active temperature profile T_i(δθ, t) in the i-th PCA direction at t = 200. (B) The minimum active temperature T_i(0) in different PCA directions i. Inset shows the inverse dependence of T_i on the flatness F_i. (C) The active temperature profiles T_i(δθ, t) at different times for i = 10. (D) The active temperature T_i for all directions decreases with time in sync with the loss function (red line) dynamics. The shaded region highlights the transition between the fast-learning phase and the exploration phase. Inset shows the correlation between T_i and L.
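
The anisotropic "active temperature" in Fig. 5 comes from the paper's MLF-based theory, but its qualitative ingredient, a direction-dependent SGD noise strength, can be probed numerically. The sketch below simply measures the variance of minibatch gradients projected onto the PCA directions at a fixed weight vector; it is a rough proxy for that ingredient, not the paper's definition of T_i(δθ, t).

import numpy as np

def noise_strength_per_direction(minibatch_grads, pc_dirs):
    """minibatch_grads: (M, N) gradients from M minibatches at a fixed weight vector.
    pc_dirs: (K, N) unit PCA directions.
    Returns the variance of the projected gradient along each direction."""
    proj = np.asarray(minibatch_grads) @ np.asarray(pc_dirs).T   # (M, K)
    return proj.var(axis=0)

# Toy usage with random stand-in gradients whose fluctuations differ by direction:
rng = np.random.default_rng(4)
grads = rng.normal(size=(200, 100)) * np.linspace(2.0, 0.1, 100)
dirs = np.eye(100)
print(noise_strength_per_direction(grads, dirs)[:5])
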
Fig. 6. The flatness spectrum and the effective dimension (D_s) of the solution. (A) The flatness spectra (rank-ordered flatness) for networks with different widths (H). (B) The effective dimension of the solution D_s, defined as the number of directions whose flatness is below a threshold set to be roughly half of the L2 norm of the weights (the dashed line in A), increases only weakly as the number of parameters (weights) N_p (∝ H²) increases. The error bars are obtained from 10 different solutions generated by 10 random initializations with the same norm for each network size.
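
A short sketch of the effective-dimension count D_s used in Fig. 6: given a flatness spectrum and a solution weight vector, count the directions whose flatness falls below a threshold of roughly half the L2 norm of the weights, as stated in the caption. The toy numbers below are placeholders.

import numpy as np

def effective_dimension(flatness_spectrum, w_star):
    """Count the directions whose flatness is below half the L2 norm of the
    solution weights (the threshold quoted in the Fig. 6 caption)."""
    threshold = 0.5 * np.linalg.norm(w_star)
    return int(np.sum(np.asarray(flatness_spectrum) < threshold))

# Toy usage with made-up numbers:
rng = np.random.default_rng(2)
F = np.sort(rng.lognormal(mean=0.0, sigma=1.0, size=500))     # stand-in flatness spectrum
w_star = 0.05 * rng.normal(size=2500)                         # stand-in solution weights
print(effective_dimension(F, w_star))
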
