Introduction

Identifying slow or relevant variables is an essential step in analyzing large-scale non-linear systems. In the context of deep neural networks (DNNs), these should be some combinations of the individual weights that are weakly fluctuating and obey a closed set of equations. One potential set of such variables is the DNNs’ outputs themselves. Indeed, in the limit of infinitely over-parameterized DNNs these provide an elegant picture of deep learning1,2,3 based on a mapping to Gaussian Processes (GPs). However, these GP limits miss out on several qualitative aspects, such as feature learning4,5 and the fact that real-world DNNs are not nearly as over-parameterized as required for the GP description to hold1,3,6,7. Obtaining a useful set of slow variables for describing deep learning at finite over-parameterization is thus an important open problem in the field.

Several works provide guidelines for this search. Noting that GP limits can have surprisingly good performance8 and that over-parameterization is natural to deep learning9,10, we are inclined to keep some elements of the GP picture. One such element is to work in function space and study pre-activations and outputs instead of weights, whose posterior distribution becomes complicated even in the GP limit11,12. Another element is the layer-wise composition of hidden-layer kernels13, which yields the output kernel of the GP14. Such a layer-wise picture is also harmonious with the idea that DNN layers should not correlate strongly, to prevent co-adaptation15. Recently, it was shown that in some limited settings, making the GP kernel “dynamical” or flexible, so that it adapts to the dataset, can account for differences between infinite and finite DNNs3,16,17,18,19. Still, the task of finding an explicit set of equations describing this flexibility in deep non-linear DNNs remains unsolved. Specifically, while in the GP limit we find tractable algebraic expressions, involving only basic matrix manipulations, for the DNN’s prediction, in the feature learning regime similar expressions only exist for deep linear16,20 or non-linear networks with one trainable layer21,22,23. In a related manner, while the DNN’s outputs provide a complete set of slow variables in the GP limit, in the finite-width feature learning regime it is not clear which subset of variables governs the trained DNN’s behavior other than the entire set of weights.

In this work, we identify such slow variables and use these to derive an effective theory for deep learning capable of capturing various finite channel/width (C/N) effects (such as feature learning) in convolutional neural networks (CNNs) and fully connected neural networks (FCNs). We argue that:

  1.

    For C, N ≫ 1 the erratic behavior of specific channels/neurons averages out, and hidden layers couple to each other only through two “slow” variables per layer: the second cumulant of the pre-activations (pre-kernel), K(l), and the second cumulant of the activations (post-kernel), Q(l), of the lth layer. Furthermore, for mean square error (MSE) loss, FCNs in the so-called mean-field (MF) scaling (where the last-layer weights are scaled down) or CNNs with a large read-out-layer fan-in behave effectively as a GP with a data-aware kernel determined by the second cumulant of pre-activations in the penultimate layer21. For a comparison between the classical GP limit and the effective GP we find, see Fig. 1.

    Fig. 1: Feature learning regime versus Gaussian process infinite limit.

    Learning as described by our effective theory for antisymmetric activation functions. Left: for infinite width, the pre-activations (h(l)) and the output (f) fluctuate according to a Gaussian distribution with fixed post-kernels \({Q}_{\infty }^{(l)}\) and Qf,∞, respectively. The complete set of slow variables are the outputs with fixed kernels. Right: for large but finite width and number of samples, we obtain an approximately Gaussian distribution for the pre-activations and outputs with learned pre-kernels K(l) and Kf = Qf + σ2In, respectively. The outputs follow a Gaussian Process Regression (GPR) with kernel Qf. The complete set of slow variables here are both the outputs and the pre- and post-kernels. For a general activation, an additional slow variable, corresponding to the mean of the pre-activations, needs to be tracked.

  2.

    In settings where the kernels have a large density of dominant eigenvalues, the posterior (or trained) pre-activations fluctuate in a nearly Gaussian manner. Following this, we use a multivariate Gaussian variational approximation for the posterior pre-activations and derive explicit matrix equations (Equations of State) for the covariance matrices governing these pre-activations and the DNNs predictions.

  3.

    We identify an emergent feature learning scale (FLS), denoted by χ, proportional to the train MSE times n2 over C (or N). This scale controls the difference between the finite C, N output kernel (Qf) and its C, N → ∞ limit and in this sense reflects feature learning. Due to the n2 factor, χ can be O(1) or larger even for C ≫ 1, e.g., for CNN architectures (see Fig. 2 panel c). The same holds, with C replaced by N, for FCNs in the MF scaling21. Unlike perturbation theory3,6,23,24, our theory tracks all orders of χ and treats only 1/C, 1/N perturbatively. The separation of scales between χ and 1/C, 1/N is thus central to our analysis. Its manifestation is the fact that feature learning shifts and stretches the dynamical variables in the theory (the pre-activations) in a considerable manner, yet barely spoils their Gaussianity.

    Fig. 2: Theory versus experiment.

    a, b Compare fully trained 3-layer FCNs in the MF scaling with our theoretical predictions. a Studies the most strongly fluctuating weight mode in the first layer. A QQ-plot of its empirical distribution against a normal one is shown, along with histograms together with Gaussian fits (insets). As the scale of the problem increases, these fluctuations become more and more Gaussian. b The experimental values and EoS predictions, each normalized by their values at Nl → ∞ with fixed σl (i.e. the GP value). Here we used n = 1024 and d = 128. c Plots the N (or channel number C) at which the feature learning scale (χ) equals 1. This dictates the crossover between weak and strong feature learning in our theory. The color gradient provides a qualitative separation between large and small N, C and thus correlates with the adequacy of our mean-field decoupling for the hidden layers. Evidently, FCNs with standard scaling (\({[{Q}_{f}]}_{\mu \mu }=O(1)\) at n = 0) require Nl = O(1) for appreciable feature learning. However, FCNs with MF scaling, as well as CNNs in standard scaling, can exhibit strong feature learning well within the regime of our mean-field decoupling.

  4.

    We provide what is, to the best of our knowledge, the first analytic predictions for the test performance of a non-linear network with two trainable layers in the feature learning regime. We do so both for an FCN and a CNN.

The predictions of our approach are tested on several toy and real-world examples using direct analytical approaches and numerical solutions to the equations of state. Our analysis takes a physics viewpoint on this complex non-linear problem. Rigorous mathematical proofs are left as an open problem for future research.

We note that several works show evidence that the spectrum of the empirical weight correlation matrix exhibits various tail effects and spikes25. While in deeper layers we focus on pre-activations, the distribution of input-layer weights we obtain is Gaussian but not independent, as in ref. 12. Hence, it can produce a variety of spectral distributions for the covariance matrix, similar to the aforementioned ones. We note a recent interesting work26 arguing that the test loss depends only on the mean and variance of hidden activations. There, however, the setting is of a fixed trained DNN and the statistics are over the input measure, rather than over the DNN parameters as in our case. While quantitatively different, our approach is similar in spirit to the layer-wise Gaussian Processes algorithm19 inspired by DNN experiments. However, our approach provides a more accurate first-principles description of trained DNNs. Additional approaches for finite width include perturbative corrections around the infinite-width limit to leading3,24 or higher orders27,28,29. There is, however, mounting evidence from bounds on GP limits1,28, numerical experiments3,5,6,23, as well as the current work, that such perturbative expansions converge slowly in practical regimes. In contrast, our EoS are useful both numerically (see sections “Methods” and “Numerical demonstration: 3 layer FCN”) and analytically (see section “Analytical solution of the EoS - two layer CNN”), and in addition allow us to model pre-activation distributions in the wild via our pre-kernels (see section “Extensions to deeper CNNs and subsets of real-world data-sets”).

Results

Problem statement

Our general setup consists of DNNs trained on a labeled training set of size \(n,{{{{{{{{\mathscr{D}}}}}}}}}_{n}={\{({{{{{{{\mathbf{{x}}}}}}}}}_{\mu },{y}_{\mu })\}}_{\mu=1}^{n}=\{{X}_{n},{{{{{{\mathbf{{y}}}}}}}}\}\) with MSE loss. The input vector is \({{{{{{{\mathbf{{x}}}}}}}}}_{\mu }\in {{\mathbb{R}}}^{d}\), and the target yμ is a scalar output. We denote vectors and tensors by boldface and use μ, ν to represent data point indices.

Our theory can be applied to any finite number of convolutional, dense, or pooling layers. To illustrate its main aspects, let us focus on an L-layer fully connected model with widths Nl (l ∈ [1..L − 1]),

$$f_{\boldsymbol{\theta }}({\mathbf{x}})=\mathop{\sum }\limits_{j=1}^{N_{L-1}}w_{j}^{(L)}\phi \left(h_{j}^{(L-1)}({\mathbf{x}})\right);\quad h_{j}^{(l)}({\mathbf{x}})=\mathop{\sum }\limits_{i=1}^{N_{l-1}}W_{ji}^{(l)}\phi \left(h_{i}^{(l-1)}({\mathbf{x}})\right);\quad {\mathbf{h}}^{(1)}({\mathbf{x}})={W}^{(1)}{\mathbf{x}}$$
(1)

where θ ≡ vec{W(1), . . , W(L−1), w(L)} denotes the trainable parameters of the network arranged in a vector, \({{{{{{{\mathbf{{w}}}}}}}}}^{(L)}\in {{\mathbb{R}}}^{{N}_{L-1}}\) and \({W}^{(l)}\in {{\mathbb{R}}}^{{N}_{l}\times {N}_{l-1}}\) are the weights of the network, and the input vector is \({{{{{{\mathbf{{x}}}}}}}}\in {{\mathbb{R}}}^{d}\) with d = N0. The activation function, \(\phi :{\mathbb{R}}\to {\mathbb{R}}\), is applied element-wise.

We take the Bayesian neural network perspective, i.e., the above network is a random object drawn from an ensemble of neural networks. The main object we analyze is the Bayesian posterior distribution of the DNN outputs \(p({{{{{{{\bf{f}}}}}}}}|{{{{{{{{\mathscr{D}}}}}}}}}_{n})\). The target yμ is the output of the network fθ(xμ) with additive i.i.d. centered Gaussian noise of variance σ2, and we place an i.i.d. Gaussian prior on the weights, \({W}_{ij}^{(l)} \sim {{{{{{{\mathscr{N}}}}}}}}(0,{\sigma }_{l}^{2}/{N}_{l})\). The posterior can be rewritten as \(p({{{{{{{\bf{f}}}}}}}}|{{{{{{{{\mathscr{D}}}}}}}}}_{n})\propto p({{{{{{{\rm{f}}}}}}}}|{X}_{n})\exp \left(-{\sum }_{\mu }{({y}_{\mu }-{f}_{{{{{{{\boldsymbol{\theta }}}}}}}}({{{{{{{\mathbf{{x}}}}}}}}}_{\mu }))}^{2}/2{\sigma }^{2}\right)\) where the prior distribution is

$$p({{{{{{{\rm{f}}}}}}}}|{X}_{n})={\left\langle \mathop{\prod}\limits_{\mu }\delta \left[{f}({{{{{{{\mathbf{{x}}}}}}}}}_{\mu })-{f}_{{{{{{{\boldsymbol{\theta }}}}}}}}({{{{{{{\mathbf{{x}}}}}}}}}_{\mu })\right]\right\rangle }_{{{{{{{\boldsymbol{\theta }}}}}}}},$$
(2)

where 〈. . . 〉θ denotes an average over the prior distribution of the weights. We note that there is a correspondence between certain gradient-based methods and Bayesian Neural Networks (BNNs)30,31,32. Recent work33 shows that BNNs can perform on par with or even better than DNNs trained with SGD. In particular, the equilibrium distribution of the Langevin-type training algorithm, i.e., full-batch gradient descent with a small learning rate, weight decay, and additive white noise with variance σ2 (ref. 3), is equivalent to the posterior distribution of a BNN, \(p({{{{{{{\rm{f}}}}}}}}|{{{{{{{{\mathscr{D}}}}}}}}}_{n})\), in function space. Adopting physics notation, it can be written as \(p({{{{{{{\rm{f}}}}}}}}|{{{{{{{{\mathscr{D}}}}}}}}}_{n})={e}^{-{{{{{{{\mathscr{S}}}}}}}}}/{{{{{{{\mathscr{Z}}}}}}}}({{{{{{{{\mathscr{D}}}}}}}}}_{n})\), where \({{{{{{{\mathscr{Z}}}}}}}}({{{{{{{{\mathscr{D}}}}}}}}}_{n})=\int{e}^{-{{{{{{{\mathscr{S}}}}}}}}}\) is the partition function and \({{{{{{{\mathscr{S}}}}}}}}\) is the action or negative log-posterior (see also Methods). As shown in Supplementary Material (1), this action is given by

$${{{{{{{\mathscr{S}}}}}}}}= \frac{1}{2}\mathop{\sum }\limits_{i=1}^{{N}_{1}}{({{{{{{{\mathbf{{h}}}}}}}}}_{i}^{(1)})}^{\top }{\left[{Q}^{(1)}\right]}^{-1}{{{{{{{\mathbf{{h}}}}}}}}}_{i}^{(1)}+\frac{1}{2}\mathop{\sum }\limits_{l=2}^{L-1}\mathop{\sum }\limits_{i=1}^{{N}_{l}}{({{{{{{{\mathbf{{h}}}}}}}}}_{i}^{(l)})}^{\top }{\left[{\tilde{Q}}^{(l)}({{{{{{{\rm{{h}}}}}}}^{(l-1)}}})\right]}^{-1}{{{{{{{\mathbf{{h}}}}}}}}}_{i}^{(l)} \\ +\frac{1}{2}{{{{{{{\mathbf{{f}}}}}}}}}^{\top }{\left[{\tilde{Q}}_{f}({{{{{{{{\rm{h}}}}}}}}}^{(L-1)})\right]}^{-1}{{{{{{\mathbf{{f}}}}}}}}+\frac{1}{2{\sigma }^{2}}\mathop{\sum}\limits_{\mu }{({f}_{\mu }-{y}_{\mu })}^{2}$$
(3)

where fμ = f(\({{{{{{\boldsymbol{{x }}}}}}}}\)μ), \({h}_{i\mu }^{(l)}={h}_{i}^{(l)}({{{{{{{\mathbf{{x}}}}}}}}}_{\mu })\) and

$${Q}_{\mu \nu }^{(1)}= \frac{{\sigma }_{1}^{2}}{d}{{{{{{{\mathbf{{x}}}}}}}}}_{\mu }^{T}{{{{{{{\mathbf{{x}}}}}}}}}_{\nu },\,\,\,\,\,\,\,{\tilde{Q}}^{(l+1)}{({{{{{{{{\rm{h}}}}}}}}}^{(l)})}_{\mu \nu }=\frac{{\sigma }_{l+1}^{2}}{{N}_{l}}\mathop{\sum }\limits_{i=1}^{{N}_{l}}\phi ({h}_{i\mu }^{(l)})\phi ({h}_{i\nu }^{(l)}),\\ {\tilde{Q}}_{f}{({{{{{{{{\rm{h}}}}}}}}}^{(L-1)})}_{\mu \nu }= \frac{{\sigma }_{L}^{2}}{{N}_{L-1}}\mathop{\sum }\limits_{j=1}^{{N}_{L-1}}\phi ({h}_{j\mu }^{(L-1)})\phi ({h}_{j\nu }^{(L-1)})$$
(4)

where \({{{{{{{\mathbf{{h}}}}}}}}}_{i}^{(l)}\in {{\mathbb{R}}}^{n}\). We comment that for rank-deficient matrices, the inverses are regularized by including a small positive regularizer to be taken to zero at the end of the computation.
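
For readers who prefer code, the snippet below is a minimal NumPy sketch (our own notation, not taken from the paper) of the quantities in Eq. (4): the input kernel Q(1) and the empirical, width-averaged post-kernel \({\tilde{Q}}^{(l+1)}({{{{{{{{\rm{h}}}}}}}}}^{(l)})\) computed from a batch of pre-activations. We use tanh as a stand-in antisymmetric activation.

```python
import numpy as np

def input_kernel(X, sigma1):
    """Q^{(1)}_{mu nu} = (sigma_1^2 / d) * x_mu . x_nu for inputs X of shape (n, d)."""
    d = X.shape[1]
    return (sigma1 ** 2 / d) * X @ X.T

def empirical_post_kernel(h, sigma_next, phi=np.tanh):
    """tilde{Q}^{(l+1)}(h^{(l)})_{mu nu} = (sigma_{l+1}^2 / N_l) * sum_i phi(h_{i mu}) phi(h_{i nu}).

    h has shape (N_l, n): one row per neuron/channel, one column per data point.
    """
    N_l = h.shape[0]
    A = phi(h)                                   # (N_l, n) activations
    return (sigma_next ** 2 / N_l) * A.T @ A     # (n, n) empirical post-kernel

# toy usage: sample h^{(1)} from its prior and form the empirical tilde{Q}^{(2)}
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 16))                              # n = 8 points in d = 16
Q1 = input_kernel(X, sigma1=1.0)
h1 = rng.multivariate_normal(np.zeros(8), Q1, size=32)        # 32 neurons, shape (32, 8)
Q2_emp = empirical_post_kernel(h1, sigma_next=1.0)
```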

To familiarize ourselves with the action in Eq. (3), let us see how it reproduces the standard Gaussian Processes picture at infinite channel/width3,14. Strictly speaking, this action is highly non-linear, since the \(\tilde{Q}\)’s matrix elements contain high powers of pre-activations and since their inverses enter the action. Crucially, however, the \(\tilde{Q}\)’s are width-averaged quantities. Thus, at Nl → ∞ one may replace them by their averages. Furthermore, upstream dependencies, wherein h(l) affects h(l−1), vanish (see Supplementary Material (1)). Roughly speaking, this is because h(l) only feels the collective effect of all the neurons in h(l−1), rendering its feedback on any specific neuron negligible.

Having these two simplifications in mind, we begin a layer-by-layer downstream analysis of the DNN: As there is no upstream feedback on h(1), the average of \({\tilde{Q}}^{(2)}({{{{{{{{\rm{h}}}}}}}}}^{(1)})\) (denoted Q(2)) can be carried out under the Gaussian action of the input layer alone (first term in the action). Replacing \({\tilde{Q}}^{(2)}({{{{{{{{\rm{h}}}}}}}}}^{(1)})\) by Q(2) in the second term in the action then implies that h(2) fluctuates in a Gaussian manner with Q(2) as its covariance matrix. Next, the average of \({\tilde{Q}}^{(3)}({{{{{{{{\rm{h}}}}}}}}}^{(2)})\) (denoted Q(3)) can be found based on the now-known Gaussian statistics of h(2). Repeating this process, the final kernel (Q(L) = Qf) is found and is exactly the one obtained using the method introduced by Cho and Saul14. Together with the MSE loss term (last term in the action), we find that the outputs (fμ) fluctuate in a Gaussian manner, leading to the standard GP picture of infinite-width trained DNNs3.
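
To make this recursion concrete, here is a short NumPy sketch (our notation; weight variances are passed explicitly and the arcsine formula quoted below Eq. (5) is used for erf activations) that propagates the kernel layer by layer at infinite width and then performs the resulting GP inference on the training set.

```python
import numpy as np

def arcsine_kernel(K):
    """G(K)_{mu nu} = (2/pi) * arcsin( 2 K_{mu nu} / sqrt((1 + 2 K_{mu mu})(1 + 2 K_{nu nu})) )."""
    d = np.sqrt(1.0 + 2.0 * np.diag(K))
    return (2.0 / np.pi) * np.arcsin(2.0 * K / np.outer(d, d))

def nngp_kernel(X, depth, sigmas):
    """Propagate the kernel through erf layers at infinite width (depth - 1 propagation steps)."""
    dim = X.shape[1]
    Q = (sigmas[0] ** 2 / dim) * X @ X.T          # Q^{(1)}
    for l in range(1, depth):
        Q = sigmas[l] ** 2 * arcsine_kernel(Q)    # Q^{(l+1)} = sigma_{l+1}^2 G(Q^{(l)})
    return Q

def gp_train_mean(Qf, y, sigma2):
    """bar{f} = Q_f [Q_f + sigma^2 I]^{-1} y on the training set."""
    return Qf @ np.linalg.solve(Qf + sigma2 * np.eye(len(y)), y)

# toy usage for a 3-layer network
rng = np.random.default_rng(1)
X = rng.standard_normal((64, 32))
y = X @ rng.standard_normal(32) / np.sqrt(32)
Qf = nngp_kernel(X, depth=3, sigmas=[1.0, 1.0, 1.0])
f_bar = gp_train_mean(Qf, y, sigma2=0.01)
```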

Here, however, our focus is on large but finite width (Nl ≫ 1). In this more complex regime, several corrections may appear: (i) The pre-activations’ average and covariance may deviate from those of a random DNN. (ii) Q(l), the covariance of activations in the l − 1 layer, would not solely determine the covariance of pre-activations in the downstream layer l, as upstream effects between h(l+1) and h(l) come into play. (iii) Inter-channel (or inter-neuron in the fully connected case) and inter-layer correlations may appear. (iv) The fluctuations of pre-activations may deviate from those of a Gaussian. A priori, all these corrections may play similarly dominant roles, thereby making the analysis cumbersome.

Effective GP description in the feature learning regime

The basic analytical insight underlying this work is that these four types of corrections scale differently with n, d, and Nl. This allows for a controlled mean-field treatment, which differs substantially from straightforward perturbation theory in one over the width. As shown in Supplementary Material (1.6), corrections of type (iii) are often much smaller than those of types (i) and (ii). This holds generally for hidden layers when the Nl’s (or channel numbers, for CNNs) are much larger than 1. For the output layer, this requires a large fan-in for CNNs, whereas for FCNs it holds when using an MF scaling21. Turning to corrections of type (iv) in the lth layer, these are suppressed when the average \({\tilde{Q}}^{(l)}\) has a large density of dominant eigenvalues, a situation relevant when n and the input dimension are both much larger than one.

This leaves us with corrections of types (i) and (ii). Interestingly, following these corrections to all orders leads to a tractable mean-field picture of learning. The latter is an augmentation of the standard correspondence between GPs and DNNs at infinite width (NNGP)14,34: Pre-activations in different layers or channels/neurons remain uncorrelated and Gaussian. Correlations only appear between different data points (and latent pixels, for CNNs) within the same layer and channel/neuron. We henceforth denote the covariance of pre-activations and activations at layer l (up to normalization by the variance of the weights) by K(l) and Q(l) and refer to these as the pre-kernel and post-kernel, respectively. However, in the NNGP34 viewpoint, Q(l) is simply proportional to K(l) and fully determined by the upstream kernel (Q(l−1)), whereas here K(l) and Q(l) differ and moreover depend both on the upstream and downstream kernels. Below, we present our theory for antisymmetric activation functions; generalizations to other activation functions such as ReLU, where one needs to track both the kernels and the means of the pre-activations, can be found in Supplementary Material (5). Specifically, for the above L-layer FCN with \(\phi={{{{{{{\rm{erf}}}}}}}}\) activation function, we obtain (see also Fig. 7 and Supplementary Material (1) for derivation and extension to CNNs and any number of layers)

$$\bar{{{{{{{\mathbf{{f}}}}}}}}}= {Q}_{f}{\left[{Q}_{f}+{\sigma }^{2}{I}_{n}\right]}^{-1}{{{{{{\mathbf{{y}}}}}}}}\\ {\left[{\left({K}^{(L-1)}\right)}^{-1}\right]}_{\mu \nu }= {\left[{\left({Q}^{(L-1)}\right)}^{-1}\right]}_{\mu \nu }-\frac{1}{{N}_{L-1}}{{{{{{{\rm{Tr}}}}}}}}\left\{{A}^{(L)}\frac{\partial {Q}_{f}}{\partial {\left[{K}^{(L-1)}\right]}_{\mu \nu }}\right\}\\ {\left[{\left[{K}^{(l-1)}\right]}^{-1}\right]}_{\mu \nu }= {\left[{\left[{Q}^{(l-1)}\right]}^{-1}\right]}_{\mu \nu }+\frac{2{N}_{l}}{{N}_{l-1}}\frac{\partial {D}_{{{{{{{{\rm{KL}}}}}}}}}({K}^{(l)}||{Q}^{(l)})}{\partial {\left[{K}^{(l-1)}\right]}_{\mu \nu }}\,\,{{\mbox{for all}}}\,\,l\in [2,L-1]\\ {\left[{\Sigma }^{-1}\right]}_{s{s}^{{\prime} }}= \frac{d}{{\sigma }_{1}^{2}}{\delta }_{s{s}^{{\prime} }}+\frac{2{N}_{2}}{{N}_{1}}\frac{\partial {D}_{{{{{{{{\rm{KL}}}}}}}}}({K}^{(2)}||{Q}^{(2)})}{\partial {\Sigma }_{s{s}^{{\prime} }}}\\ {A}^{(L)}= {\sigma }^{-4}({{{{{{\mathbf{{y}}}}}}}}-\bar{{{{{{{\mathbf{{f}}}}}}}}}){({{{{{{\mathbf{{y}}}}}}}}-\bar{{{{{{{\mathbf{{f}}}}}}}}})}^{\top }-{\left[{Q}_{f}+{\sigma }^{2}{I}_{n}\right]}^{-1}$$
(5)

where \({[{Q}_{f}]}_{\mu \nu }={\sigma }_{L}^{2}G{({K}^{(L-1)})}_{\mu \nu }\), \({[{Q}^{(l)}]}_{\mu \nu }={\sigma }_{l}^{2}G{({K}^{(l-1)})}_{\mu \nu }\) and \(G{(K)}_{\mu \nu }= \frac{2}{\pi }{\sin }^{-1}\left(\frac{2{K}_{\mu \nu }}{\sqrt{1+2{K}_{\mu \mu }}\sqrt{1+2{K}_{\nu \nu }}}\right)\)35 for a matrix \(K\in {{\mathbb{R}}}^{n\times n}\) (equivalent expressions are known for several common activation functions, such as ReLU14). Also, \(\bar{{{{{{{{\rm{f}}}}}}}}}\) is the average DNN output and DKL(K∣∣Q) is the KL-divergence between two Gaussians with covariance matrices K and Q. The input layer pre-kernel is \({K}^{(1)}={X}_{n}\Sigma {X}_{n}^{\top }\), where Σ is the covariance matrix of the input layer weights.

As their lack of dependence on width suggests, the first equation, together with the definitions of the post-kernels Q(l), Qf, is already present in the strict GP limit (Nl → ∞). They are, respectively, the GP inference formula and the standard kernel recursion of random DNNs14 with erf activation. The remaining equations are, to the best of our knowledge, novel and track the changes to the pre-kernels and post-kernels at finite Nl. They can be solved analytically in some simple cases (see subsection “Analytical solution of the EoS - two layer CNN” for the case of a two-layer CNN). We note that for non-anti-symmetric activations, one also needs to track the mean of each layer’s pre-activations (see Supplementary Material (5)).

To get a qualitative impression of their role, one can consider the case where the penultimate layer (l = L − 1) is linear, in which case \({Q}_{f}={\sigma }_{L}^{2}{K}^{(L-1)}\). Consequently, \(\frac{\partial {[{Q}_{f}]}_{{\mu }^{{\prime} },{\nu }^{{\prime} }}}{\partial {[{K}^{(L-1)}]}_{\mu,\nu }}={\sigma }_{L}^{2}{\delta }_{\mu {\mu }^{{\prime} }}{\delta }_{\nu {\nu }^{{\prime} }}\) (where δ with a double index refers to the Kronecker delta), and thus the second equation simplifies to

$${\sigma }_{L}^{2}{Q}_{f}^{-1}={({Q}^{(L-1)})}^{-1}-\frac{{\sigma }_{L}^{2}}{{N}_{L-1}}\left({{{{{{{\boldsymbol{\varepsilon }}}}}}}}{{{{{{{{\boldsymbol{\varepsilon }}}}}}}}}^{\top }-{[{Q}_{f}+{\sigma }^{2}{I}_{n}]}^{-1}\right),$$
(6)

where \({{{{{{{\boldsymbol{\varepsilon }}}}}}}}=({{{{{{{\rm{y}}}}}}}}-\bar{{{{{{{{\rm{f}}}}}}}}})/{\sigma }^{2}\). We note in passing that even for a non-linear penultimate layer, a similar term arises from the expansion of Qf in K(L−1) to linear order. From the above form, several insights can be drawn.

First, we argue that the above equation implies that the trained DNN is more susceptible to changes along ε than the DNN at NL−1 → ∞. Noting how \({Q}_{f}^{-1}\) enters the action (Eq. (3)), it controls the stiffness associated with fluctuations in f. Hence, \({Q}_{f}^{-1}\) makes fluctuations in the direction of ε more likely than they are according to \({({Q}^{(L-1)})}^{-1}\). Since ε measures the discrepancy in train predictions, this effect reduces the discrepancy by making the DNN more responsive in these directions than it is at NL−1 → ∞. The second term, proportional to \({\left[{Q}_{f}+{\sigma }^{2}{I}_{n}\right]}^{-1}\), amounts to a negligible reduction in fluctuations along eigenvectors of Qf corresponding to eigenvalues larger than σ2.

Using Eq. (6), one can also identify the aforementioned emergent feature learning scale (FLS), namely \(\chi={N}_{L-1}^{-1}{{{{{{{{\boldsymbol{\varepsilon }}}}}}}}}^{\top }{Q}^{(L-1)}{{{{{{{\boldsymbol{\varepsilon }}}}}}}}\). This scale represents the magnitude of the leading term when one Taylor expands Qf in 1/NL−1. When χ = O(1) or larger, there is a significant change in the eigenvalues of Qf compared to Q(L−1), which indicates feature learning. On the other hand, when this quantity is small, we are closer to the GP regime (see Supplementary Material (1.6)). To assess the scaling of χ, one can consider the common situation where ε has some non-negligible overlap with dominant eigenvectors of Q(L−1) whose eigenvalues are on the scale λ. Here we find χ ≈ λ MSE/σ2 ⋅ n/NL−1, where MSE denotes the mean train MSE, which enters here via ∥ε∥2/n. Due to its explicit n dependency, and since λ = O(n) at large n36, χ may be O(1) even at very large NL−1 and/or when the average MSE is rather small.
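
As an illustration, the FLS can be estimated directly from the train discrepancy and the penultimate post-kernel; a minimal sketch (our own notation, with ε = (y − f̄)/σ2 as above):

```python
import numpy as np

def feature_learning_scale(y, f_bar, Q_penult, sigma2, N):
    """chi = (1/N) * eps^T Q^{(L-1)} eps  with  eps = (y - f_bar) / sigma^2."""
    eps = (y - f_bar) / sigma2
    return float(eps @ Q_penult @ eps) / N
```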

Figure 2c shows the value of NL−1 (or CL−1) at which χ = 1 (i.e. the NL−1 or CL−1 at which feature learning becomes a dominant effect) as a function of n for several DNNs we study. The scale separation, demonstrated there by the fact that χ can be O(1) in regions where 1/Nl is negligible, is central to our analytical approach.

This scale χ is also the reason that naive perturbation theory in 1/Nl fails at large n3,6,23,27, as it treats χ and O(1/NL−1) on the same footing, since they both have a single negative power of NL−1. In contrast, our EoS treat the FLS non-perturbatively.
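
To illustrate what treating the FLS non-perturbatively means in practice, the following sketch iterates Eq. (6) to self-consistency rather than expanding it in 1/NL−1. This is our own damped fixed-point scheme, not the solver used for the figures, and it assumes Q(L−1) is invertible and that the iteration converges.

```python
import numpy as np

def solve_eq6(Q_penult, y, sigma2, sigmaL2, N, n_iter=200, damping=0.5):
    """Damped fixed-point iteration of Eq. (6) for a linear penultimate layer (a sketch)."""
    n = len(y)
    Qf = sigmaL2 * Q_penult                          # GP (N -> infinity) initialization
    Q_penult_inv = np.linalg.inv(Q_penult)
    for _ in range(n_iter):
        f_bar = Qf @ np.linalg.solve(Qf + sigma2 * np.eye(n), y)     # GPR train mean
        eps = (y - f_bar) / sigma2                   # train discrepancy
        rhs = Q_penult_inv - (sigmaL2 / N) * (np.outer(eps, eps)
                                              - np.linalg.inv(Qf + sigma2 * np.eye(n)))
        Qf_new = sigmaL2 * np.linalg.inv(rhs)        # invert  sigma_L^2 Q_f^{-1} = rhs
        Qf = (1.0 - damping) * Qf + damping * Qf_new # damped update
    return Qf
```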

Last, we stress that the EoS provide us with a concrete, effective GP description of the entire DNN as well as its hidden layers. A priori, one would expect that the normality of pre-activations, a large-C, N trait, would be lost at finite C, N. Yet, we find that pre-activations remain Gaussian and accommodate strong feature learning effects while maintaining accurate predictions. This unexpectedly simple behavior opens various reverse-engineering possibilities, wherein one infers the effective kernels from experiments and uses their spectrum and eigenvectors to rationalize about the DNN (see also Fig. 5). On this note, we comment that extending the EoS to test points is straightforward by formally treating the test point as an additional training point with its own “noise” parameter and then sending it to infinity (see Supplementary Material (1.4.1)). For concrete EoS expressions which include test points, see section “Analytical solution of the EoS - two layer CNN”.

Numerical demonstration: 3 layer FCN

Next, we test the agreement between the above results and actual trained DNNs, starting from the 3-layer FCN defined in Eq. (1) with L = 3. We focus here on a student-teacher setting with n = 512 or 1024 training data points drawn from i.i.d. Gaussian distributions with unit variance along each input dimension. The target was generated by a randomly drawn teacher FCN of the same type, only with N1 = N2 = 1. The student was trained using an analog of the MF scaling21, wherein the output layer weights are scaled down by a factor of \(1/\sqrt{{N}_{l}}\). Whereas for the CNNs discussed below this choice of scaling was not required, for FCNs we found it necessary for getting any appreciable feature learning at N1 = N2 = N ≫ 1 (Fig. 2c).

As described in the Methods section, we trained 20 FCNs using our Langevin algorithm until they reached equilibrium. We use these trained FCNs to calculate various average quantities under our partition function (Eq. (3)). Specifically, we focused on: (i) the normalized train loss on the scale of σ2, namely \(\,{{\mbox{MSE}}}\,/({\sigma }^{2}{\sum }_{\mu }{y}_{\mu }^{2})\); (ii) the eigenvalues (λi, i ∈ [1..d]) of the average Σ, where we average over neurons, training seeds, and training time (the latter within the equilibrium region); (iii) the normalized overlap (α) between the discrepancy in predictions on the training set and the target, namely \(\alpha={\sigma }^{-2}\mathop{\sum }\nolimits_{\mu=1}^{n}({\bar{f}}_{\mu }-{y}_{\mu }){y}_{\mu }/({\sum }_{\mu }{y}_{\mu }^{2})\). We then used a JAX-based37 numerical solver for the EoS and compared it with the experiment.
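
For concreteness, observables (i) and (iii) can be computed from the averaged outputs as in the following sketch (our notation; here we take MSE to be the mean squared error over the training set):

```python
import numpy as np

def normalized_train_loss(f_bar, y, sigma2):
    """Observable (i): train MSE normalized by sigma^2 * sum_mu y_mu^2."""
    return np.mean((f_bar - y) ** 2) / (sigma2 * np.sum(y ** 2))

def target_overlap(f_bar, y, sigma2):
    """Observable (iii): alpha = sigma^{-2} sum_mu (f_bar_mu - y_mu) y_mu / sum_mu y_mu^2."""
    return np.sum((f_bar - y) * y) / (sigma2 * np.sum(y ** 2))
```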

As the results of Fig. 2b show, the predictions of our EoS for all three quantities agree increasingly well with experiment as N increases. Furthermore, they do so in a region where they differ considerably from their associated GP limit. Indeed, as shown in the Supplementary Material (3), the top Σ eigenvalue came out 2–3 times larger than it is in the GP limit. The associated eigenvector corresponded to the first-layer weights of the teacher (w*). The rest of the eigenvalues remained at their GP-limit values. Put together, this is a clear sign of strong feature learning.

Notably, however, this notion of feature learning does not involve compression. Indeed, since Σ has the same variance as in the GP limit for directions perpendicular to \({{{{{{\mathbf{{w }}}}}}}}\)*, it does not compress the input by projecting it solely on the label-relevant direction (\({{{{{{\mathbf{{w }}}}}}}}\)*). Instead, it exaggerates the fluctuations of the student weights along \({{{{{{\mathbf{{w }}}}}}}}\)*, thereby making it statistically more likely that \({h}_{i\mu }^{(1)}\) and \({h}_{i\nu }^{(1)}\) with opposite signs of \({{{{{{\mathbf{{w }}}}}}}}\)* ⋅ \({{{{{{\mathbf{{x }}}}}}}}\)μ and \({{{{{{\mathbf{{w }}}}}}}}\)* ⋅ \({{{{{{\mathbf{{x }}}}}}}}\)ν will be further apart in the space of pre-activations.

Next, we study how χ behaves as a function of n and C (or N) for different architectures. Figure 2 shows the value of N (or C) at which χ = 1. As χ contains a single inverse power of C, at ten times this value χ would be 0.1 and would thus indicate only minor feature learning effects in our EoS. As N, C diminish from this latter value, our EoS yield increasingly stronger feature learning effects. We find that both for CNNs in the standard scaling and for FCNs with MF scaling, the crossover to feature learning happens well within the validity region of our mean-field decoupling (i.e. large N or C). In contrast, FCNs with standard scaling show this crossover only when N = O(1), which is outside the scope of our theory. In this aspect, we comment that there is evidence that FCNs with standard scaling are inferior to those with mean-field scaling38 and perform similarly to GPs34.

Analytical solution of the EoS - two layer CNN

Having tested our EoS numerically, we turn to show they lend themselves, in simple settings, to a fully analytical calculation. Amongst other things, this will flesh out the non-perturbative nature of our results. To this end, we consider a simple non-linear CNN with 2 layers. Though bounds have been derived39,40, we are not aware of any analytical predictions for the performance of finite non-linear 2-layer DNNs, let alone CNNs. It is therefore a natural first application of our approach. Specifically, we consider

$$f({{{{{{\mathbf{{x}}}}}}}})=\mathop{\sum }\limits_{i=1}^{N}\mathop{\sum }\limits_{c=1}^{C}{a}_{ic}{{{{{{{\rm{erf}}}}}}}}\left({{{{{{{\mathbf{{w}}}}}}}}}_{c}\cdot {{{{{{{\mathbf{{x}}}}}}}}}_{i}\right)$$
(7)

where \({{{{{{\mathbf{{x}}}}}}}}\in {{\mathbb{R}}}^{d}\) with d = NS and \({{{{{{{\mathbf{{w}}}}}}}}}_{c},{{{{{{{\mathbf{{x}}}}}}}}}_{i}\in {{\mathbb{R}}}^{S}\). The vector xi is given by coordinates iS, . . . , (i + 1)S − 1 of x. The dataset consists of n i.i.d. samples \({\{{{{{{{{\mathbf{{x}}}}}}}}}_{\mu }\}}_{\mu=1}^{n}\); each sample xμ is a centered Gaussian vector with covariance Id. We choose a linear target of the form \({y}_{\mu }={\sum }_{i}{a}_{i}^{*}({{{{{{{\mathbf{{w}}}}}}}}}^{*}\cdot {{{{{{{\mathbf{{x}}}}}}}}}_{\mu,i})\) where \({a}_{i}^{*} \sim {{{{{{{\mathscr{N}}}}}}}}(0,1/N)\) and \({w}_{s}^{*} \sim {{{{{{{\mathscr{N}}}}}}}}(0,1/S)\). This choice is not crucial, but it considerably simplifies the GP inference part of the computation. We train this DNN using our Langevin algorithm and tune the weight decay and gradient noise such that, without any data, \({a}_{ic} \sim {{{{{{{\mathscr{N}}}}}}}}(0,{\sigma }_{{{{{{{{\rm{a}}}}}}}}}^{2}{(NC)}^{-1})\) and \({[{{{{{{{{\rm{w}}}}}}}}}_{c}]}_{s} \sim {{{{{{{\mathscr{N}}}}}}}}(0,{\sigma }_{{{{{{{{\rm{w}}}}}}}}}^{2}{S}^{-1})\).
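
The following is a minimal NumPy sketch (our own variable names) of the model in Eq. (7), of a random draw of student weights matching the stated prior, and of the linear teacher used to generate the targets:

```python
import numpy as np
from scipy.special import erf

def cnn_forward(x, W, a):
    """f(x) = sum_{i,c} a[i,c] * erf(w_c . x_i) for the 2-layer CNN of Eq. (7).

    x: (N*S,) input, W: (C, S) conv filters, a: (N, C) readout weights."""
    N, C = a.shape
    S = W.shape[1]
    windows = x.reshape(N, S)        # x_i = i-th stride window of the input
    pre = windows @ W.T              # (N, C) pre-activations w_c . x_i
    return float(np.sum(a * erf(pre)))

rng = np.random.default_rng(2)
N, S, C, n = 20, 64, 128, 256
sigma_a2, sigma_w2 = 1.0, 1.0

# linear teacher target y = sum_i a*_i (w* . x_i)
a_star = rng.standard_normal(N) / np.sqrt(N)
w_star = rng.standard_normal(S) / np.sqrt(S)
X = rng.standard_normal((n, N * S))
y = (X.reshape(n, N, S) @ w_star) @ a_star

# random student weights drawn from the prior described in the text
W = rng.standard_normal((C, S)) * np.sqrt(sigma_w2 / S)
a = rng.standard_normal((N, C)) * np.sqrt(sigma_a2 / (N * C))
f0 = cnn_forward(X[0], W, a)
```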

The equations of state are given by (See Supplementary Material (3))

$$\bar{{{{{{{\mathbf{{f}}}}}}}}}= {Q}_{f}{[{\sigma }^{2}{I}_{n}+{Q}_{f}]}^{-1}{{{{{{\mathbf{{y}}}}}}}}\\ {\left[{Q}_{f}\right]}_{\mu \nu }= \frac{{\sigma }_{{{{{{{{\rm{a}}}}}}}}}^{2}}{N}\mathop{\sum}\limits_{i}G{\left({X}_{n}\Sigma {X}_{n}^{\top }\right)}_{\mu i,\nu i}\\ {\left[{\Sigma }^{-1}\right]}_{s{s}^{{\prime} }}= \frac{S}{{\sigma }_{{{{{{{{\rm{w}}}}}}}}}^{2}}{\delta }_{s{s}^{{\prime} }}-\frac{1}{C}{{{{{{{\rm{Tr}}}}}}}}\left\{\left({{{{{{{\boldsymbol{\varepsilon }}}}}}}}{{{{{{{{\boldsymbol{\varepsilon }}}}}}}}}^{\top }-{K}_{f}^{-1}\right)\frac{\partial {Q}_{f}}{\partial {\Sigma }_{s{s}^{{\prime} }}}\right\}$$
(8)

Here we denote the discrepancy from the target by \({{{{{{{\boldsymbol{\varepsilon }}}}}}}}=({{{{{{\mathbf{{y}}}}}}}}-\bar{{{{{{{\mathbf{{f}}}}}}}}})/{\sigma }^{2}\). The above equations for \({\Sigma }_{s{s}^{{\prime} }}\) and \({\bar{f}}_{\mu }\) can be solved numerically. Once \({\Sigma }_{s{s}^{{\prime} }}\) is obtained, generalizing the EoS to test points is straightforward (see Supplementary Material (1.4.1)) and amounts to doing GP inference with Qf. In particular, to obtain the required values of Qf between a test point x* and a train point xν, one calculates \(2{\sigma }_{{{{{{{{\rm{a}}}}}}}}}^{2}{\sum }_{i}{\sin }^{-1}\left(\frac{2{{{{{{{{\rm{x}}}}}}}}}_{*,i}\Sigma {{{{{{{{\rm{x}}}}}}}}}_{ \nu , i}}{\sqrt{1+2{{{{{{{{\rm{x}}}}}}}}}_{*,i}\Sigma {{{{{{{{\rm{x}}}}}}}}}_{*,i}}\sqrt{1+2{{{{{{{{\rm{x}}}}}}}}}_{\nu,i}\Sigma {{{{{{{{\rm{x}}}}}}}}}_{\nu,i}}}\right)/(\pi N)\). The results for both train and test losses are shown in Fig. 3 in solid lines and match the empirical values well.
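
As an illustration of the middle line of Eq. (8), the post-kernel Qf can be assembled from Σ window by window using the erf arcsine formula; a sketch (our notation):

```python
import numpy as np

def qf_from_sigma(X, Sigma, N, S, sigma_a2=1.0):
    """[Q_f]_{mu nu} = (sigma_a^2 / N) sum_i G(x_{mu,i}^T Sigma x_{nu,i}), with the erf arcsine G."""
    n = X.shape[0]
    Xw = X.reshape(n, N, S)                               # one (S,) window per (data point, i)
    K = np.einsum('mis,st,uit->imu', Xw, Sigma, Xw)       # (N, n, n) per-window pre-kernels
    d = np.sqrt(1.0 + 2.0 * np.einsum('imm->im', K))      # (N, n) diagonal factors
    G = (2.0 / np.pi) * np.arcsin(2.0 * K / (d[:, :, None] * d[:, None, :]))
    return (sigma_a2 / N) * G.sum(axis=0)                 # (n, n) output post-kernel
```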

Fig. 3: Theory and experiment comparison for the 2-layer non-linear CNN.

Left. Discrepancy measured along the target direction, normalized by that of the corresponding GP (α/αC). Dots denote empirical values for the test set. Dashed and solid lines are theoretical predictions via an approximate-analytical and exact-numerical solution of the equations of state, respectively. Train-set comparisons are shown in the inset. Right. Statistics of \({{{{{{\mathbf{{w }}}}}}}}\) ⋅ \({{{{{{\mathbf{{w }}}}}}}}\)*/∣\({{{{{{\mathbf{{w }}}}}}}}\)*∣ of the trained CNN (blue) compared with \({{{{{{\mathbf{{w }}}}}}}}\) projected on a normalized random vector (green) for the n = 1600 experiment with C = 640. Dashed lines show fits to a Gaussian.

To obtain fully analytical results, we proceed with several approximations for large n. First, we approximate the spectrum of the matrix \({[{Q}_{f}]}_{\mu \nu }\) based on its continuum kernel version \({Q}_{f}({{{{{{\mathbf{{x}}}}}}}},{{{{{{{\mathbf{{x}}}}}}}}}^{{\prime} })\). This is closely related to the equivalent kernel41 approximation, which we adopt here along with its leading order correction36. Similarly, we use large n to replace the double summation \({\sum }_{\nu \mu }{\varepsilon }_{\mu }{\varepsilon }_{\nu }{[{Q}_{f}]}_{\mu \nu }\) by two integrals over the measure from which  \({{{{{{\boldsymbol{{x }}}}}}}}\)μ are drawn (dμ). See Supplementary Material (3.1) for further details and a discussion of the fully connected case (N = 1).

Following our approximations, the equations acquire the full rotation symmetry of the data-set measure, which amounts to an independent orthogonal transformation of each xi. Furthermore, as shown in Supplementary Material (3), at large S, \(\bar{f}({{{{{{\mathbf{{x}}}}}}}})\) (the continuum function representing fμ) is linear given a linear target y(\({{{{{{\mathbf{{x }}}}}}}}\)), regardless of Σ, and hence so is ε(\({{{{{{\mathbf{{x }}}}}}}}\)). The above symmetry then implies that ε(\({{{{{{\mathbf{{x }}}}}}}}\)) is only a function of \({{{{{{\mathbf{{w }}}}}}}}\)* ⋅ \({{{{{{\mathbf{{x }}}}}}}}\)i and furthermore takes the simple form ε(\({{{{{{\mathbf{{x }}}}}}}}\)) = αy(\({{{{{{\mathbf{{x }}}}}}}}\)). The quantity α thus measures the overlap between the discrepancy in predictions (ε) and the target. Following this, the EoS reduce to a non-linear equation in a single variable α

$$\sigma^2 \alpha=1 - \frac{q_{{{\mathrm{train}}}} S \lambda_{\infty} l_*}{S \lambda_{\infty} l_*+\sigma^2/n}$$
(9)

where \(l_*=[1-\chi_2]^{-1}/S,\,\chi_2=C^{-1}\alpha^2n^2\lambda_{\infty}\), and λ∞ is the dominant eigenvalue of \({Q}_{f}({{{{{{\mathbf{{x}}}}}}}},{{{{{{{\mathbf{{x}}}}}}}}}^{{\prime} })\) associated with a linear function in the limit C → ∞, l* is the eigenvalue of Σ associated with w* (the remaining eigenvalues are inert and equal to 1/S), χ2 is the FLS for the 2-layer CNN, and we assumed \(||{{{{{{{\mathbf{{a}}}}}}}}}^{*}|{|}^{2}=||{{{{{{{\mathbf{{w}}}}}}}}}^{*}|{|}^{2}={\sigma }_{{{{{{{{\rm{a}}}}}}}}}^{2}={\sigma }_{{{{{{{{\rm{w}}}}}}}}}^{2}=1\) (see Supplementary Material (3) for more generic expressions). The quantity qtrain is exactly one in the equivalent kernel limit, and its perturbative correction (in 1/n) can be found in Supplementary Material (3).

Solving the above equation for α, one obtains l* and hence Σ, and also Qf (via Eq. (8)). Using the obtained Qf, one can calculate the DNN’s predictions on the test set. The effect of the FLS is evident in the expression for l*, where it controls the deviation from the GP limit. Here we also recall that α2 is the train MSE over σ4; thus χ2, as defined above, contains the MSE factor mentioned in the introduction.
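
A minimal numerical sketch of one way to solve Eq. (9) for α (our own root-finding scheme; qtrain is set to its equivalent-kernel value of one by default, and the search is restricted to χ2 < 1 so that l* stays positive):

```python
import numpy as np
from scipy.optimize import brentq

def solve_alpha(C, n, S, lam_inf, sigma2, q_train=1.0):
    """Solve Eq. (9):  sigma^2 alpha = 1 - q S lam l_* / (S lam l_* + sigma^2 / n),
    with l_* = [1 - chi_2]^{-1} / S and chi_2 = alpha^2 n^2 lam / C."""
    def residual(alpha):
        chi2 = alpha ** 2 * n ** 2 * lam_inf / C
        l_star = 1.0 / ((1.0 - chi2) * S)
        return sigma2 * alpha - 1.0 + q_train * S * lam_inf * l_star / (S * lam_inf * l_star + sigma2 / n)

    alpha_max = np.sqrt(C / lam_inf) / n * (1.0 - 1e-9)   # keep chi_2 < 1
    return brentq(residual, 1e-12, alpha_max)
```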

To test the theoretical predictions, we trained two such CNNs, with n = {800, 1600}, S = 64, N = 20, and varying channel number. Figure 3, left panel, shows the empirical test-set values for α (dots) compared with a numerical solution of the equations of state (solid lines) and their analytical solution (dashed lines). For the latter, we obtained Σ analytically and performed the resulting GP inference with Qf numerically. The inset tracks the train-set results which, in this case, are fully analytical and involve no numerical GP inference. Both predictions match the empirical values quite well, even in the regime where the test root MSE is roughly half that of a Gaussian Process (C → ∞). The right panel shows the input layer weights, dotted with w*/∣w*∣ and with a normalized random vector. These remain Gaussian up to minor statistical noise. Further details can be found in the Methods section.

To emphasize the non-perturbative nature of Eq. (9), let us assume, for the sake of contradiction, that it agrees with first-order perturbation theory in 1/C (as in refs. 3,6,23,24). If so, we may replace α in the above expression for χ2 by its GP value, as χ2 already contains one negative power of C and hence receives no further corrections at that order. Numerics show that this value is αGP = 0.558 for n = 1600. Plugging this in, one obtains \({l}_{*}=\frac{2}{S}{[1-633.2/C]}^{-1}\). Clearly, this logic leads to a contradiction unless C ≫ 633.2. In contrast, our theory provides highly accurate predictions for n = 1600, C = 320 and C = 640, well away from where \(\frac{2}{S}{[1-633.2/C]}^{-1}\) admits a perturbation theory in 1/C. In Supplementary Material (6.2) we report additional results on l* over its GP value.

Extensions to deeper CNNs and subsets of real-world data-sets

For truly deep CNNs and real-world datasets, obtaining fully analytical predictions for DNN performance is a challenging task, even in the C → ∞ limit. Still, the EoS can be solved numerically and compared with experimental values. Furthermore, the quantities which underlie them can be examined and reasoned upon. We do so here in two richer settings: a 3-layer CNN trained on a target generated by a teacher CNN, and the Myrtle-5 CNN42 trained on a subset of CIFAR-10. Our first setting extends that of the previous subsection by having an extra activated layer and a target function generated by a similar non-linear teacher network, having, however, a single channel. For more details, see section “Three layer CNN model definition”.

As a first test of our theory, we examine the fluctuations of pre-activations in the input and middle layers of the trained student CNN and check their normality. Specifically, for the input weights \({{{{{{\mathbf{{w }}}}}}}}\)c we obtain the histogram (over channels, equilibrium samples, and seeds) of \({{{{{{\mathbf{{w }}}}}}}}\)c ⋅ \({{{{{{\mathbf{{w }}}}}}}}\)*/∣\({{{{{{\mathbf{{w }}}}}}}}\)*∣, where \({{{{{{\mathbf{{w }}}}}}}}\)* is the teacher input weight, and the histogram of \({{{{{{\mathbf{{w }}}}}}}}\)c ⋅ \({{{{{{\mathbf{{w }}}}}}}}\)r/∣\({{{{{{\mathbf{{w }}}}}}}}\)r∣, where wr is a random vector. The teacher overlap has a variance of 0.254 here, whereas the random overlap variance, averaged over choices of \({{{{{{\mathbf{{w }}}}}}}}\)r’s, was 0.039 with a std of 0.0043. For the hidden layer, we obtain the histogram of \({{{{{{{\mathbf{h}}}}}}}}_{{c}^{{\prime} }}^{(2)} \cdot {{{{{{{\mathbf{h}}}}}}}}^{(2),* }/|{{{{{{{\mathbf{h}}}}}}}}^{(2),* }|\), where h(2),* are the teacher’s pre-activations, as well as \({{{{{{{\mathbf{{h}}}}}}}}}_{{c}^{{\prime} }}^{(2)}\cdot {{{{{{{\mathbf{{h}}}}}}}}}_{r}^{(2)}/|{{{{{{{\mathbf{{h}}}}}}}}}_{r}^{(2)}|\), where \({{{{{{{\mathbf{{h}}}}}}}}}_{r}^{(2)}\) is the pre-activation of a different, randomly chosen teacher. The teacher overlap variance here was 64.4, whereas the average random overlap variance was 2.3 with a std of 0.12. Figure 4 shows the associated histograms along with their fits to a Gaussian. The large and consistent differences in the variance of the fluctuations between teacher directions and random directions show that we are deep in the feature learning regime. Remarkably, the fluctuations remain almost perfectly Gaussian. The larger variance along teacher directions implies that by drawing DNNs from the trained DNN ensemble and diagonalizing their empirical covariance matrices, one is more likely to find dominant eigenvalues along these teacher directions.
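
The projections histogrammed here can be collected as in the following sketch (our notation; the Gaussian fit only requires the pooled mean and variance):

```python
import numpy as np

def overlap_samples(W_samples, v):
    """Projections w_c . v / ||v||, pooled over channels, seeds and equilibrium samples.

    W_samples: (n_samples, C, S0) input filters; v: (S0,) direction (teacher or random)."""
    v_hat = v / np.linalg.norm(v)
    return (W_samples @ v_hat).ravel()

def gaussian_fit(samples):
    """Mean and variance of the pooled projections (the Gaussian fits shown in Fig. 4)."""
    return samples.mean(), samples.var()
```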

Fig. 4: Pre-activation statistics in the 3-layer student-teacher setting.

Left. Histogram of the student input weight vector, dotted with the normalized teacher weight (blue) and a normalized random weight (green). Right. Histogram of the student hidden-layer pre-activations, dotted with the normalized teacher pre-activations (blue) and the normalized pre-activations of a random teacher (green). Dots are empirical values and dashed lines are Gaussian fits. Insets: 2d histograms along the same vectors before (left) and after (right) training. Within our framework, these variances are determined by \({{{{{{\mathbf{{v }}}}}}}}\)⊤K(l)\({{{{{{\mathbf{{v }}}}}}}}\), with \({{{{{{\mathbf{{v }}}}}}}}\) being either a random unit vector or h(l) of the single-channel teacher. Remarkably, despite strong changes to the kernels and various non-linearities in the action, the pre-activations are almost perfectly Gaussian.

We turn to verify the EoS and rationalize the behavior of the pre-kernels. To this end, we average the empirical pre-activations over channels and training seeds to obtain estimators for the pre-kernel and post-kernel of h(2) (i.e. K(2) and Q(2)) and for that of the input weights (Σ). We then obtain ∂DKL(K(2)∣∣Q(2))/∂Σ analytically using the 3rd equation from Eqs. (5), plugging in the empirical Σ. Finally, we compare the empirical Σ with that obtained from the last equation of Eqs. (5). Figure 5, left panel, plots the eigenvalues of Σ, our predictions (Σpred), and the post-kernel of the input layer, which is simply \({\sigma }_{{{{{{{{\rm{w}}}}}}}}}^{-2}{S}_{0}{I}_{{S}_{0}}\), showing a good match between the first two. Figure 5, right panel, plots the eigenvalues of Q(2) as predicted from K(2), compared with their empirical values.
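
The empirical kernel estimators used in this comparison amount to simple channel/seed averages of outer products of pre-activations; a sketch (our notation):

```python
import numpy as np

def empirical_prekernel(h_samples):
    """Estimate a pre-kernel such as K^{(2)} by averaging h h^T over channels, seeds and samples.

    h_samples: (n_samples, C, M) pre-activations, with M = n * (number of latent pixels)."""
    h = h_samples.reshape(-1, h_samples.shape[-1])   # pool everything but the data/pixel axis
    return h.T @ h / h.shape[0]
```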

Fig. 5: Theoretical, empirical, and GP spectrum of kernels for the 3-layer non-linear student-teacher CNN setting.

Left: the leading 50 eigenvalues making up the spectrum of Σ for the empirical Σ (blue), the predicted Σ based on the equations of state (red), and \(\Sigma\) at C → ∞ (green). Right: the leading 10 eigenvalues of the post-kernel of the hidden layer (Q(2)), again for the empirical, predicted, and GP post-kernels. The inset shows the next 90 trailing eigenvalues and reveals a gap separating the leading S1 eigenvalues from the rest. The dimension of Q(2) here is Nn = 2000; we focus on the large eigenvalues as these dominate the predictions. The outlier in the left panel is aligned with the teacher weight vector \({{{{{{\mathbf{{w }}}}}}}}\)*. The 30 (or N) outliers on the right panel are again a feature learning effect and represent the linear feature proportional to \({{{{{{\mathbf{{w }}}}}}}}\)* in each of the N latent pixels of the hidden layer.

Next, we trained the Myrtle-5 CNN42, capable of good performance and containing both pooling layers and ReLU activations, with C = 256 on a subset of CIFAR-10 (n = 2048). Barring some one-dimensional integrals requiring numerical evaluation, our EoS generalize straightforwardly to ReLU activation, in which case we are required to track both the kernels K (variances) and the means (see Supplementary Material 5). Thus, for ReLU, feature learning can manifest itself through both these quantities. Still, the results below suggest that changes to the kernels play a dominant role (see also Supplementary Material 6.3). Further analysis of the impact of activation functions on feature learning is left for future work.

Figure 6 shows histograms of various linear combinations of pre-activations. These show a strong deviation of trained DNNs from untrained DNNs or DNNs at infinite channel/width, and at the same time show quite a good fit to a Gaussian in most cases. This opens the possibility of reverse engineering the pre-kernels governing this trained network and using them to rationalize about the DNN, for instance by identifying their dominant eigenvectors.
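
The linear combinations shown in Fig. 6 are projections of pooled pre-activations on eigenvectors of the corresponding kernel matrix; a sketch (our notation):

```python
import numpy as np

def eigenvector_projections(h_samples, kernel):
    """Project pooled pre-activations on eigenvectors of the corresponding kernel matrix."""
    evals, evecs = np.linalg.eigh(kernel)
    order = np.argsort(evals)[::-1]                  # leading eigenvector first
    h = h_samples.reshape(-1, kernel.shape[0])       # pool channels and seeds
    return h @ evecs[:, order]                       # column k = projections on the (k+1)-th eigenvector
```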

Fig. 6: Pre-activation statistics for the myrtle-5 CNN with ReLU activation and C = 256, trained on n = 2048 CIFAR-10 images.

a, b Statistics of pre-activations of untrained (gray) and trained (blue) Myrtle-5 nets, in the 2nd layer. The histograms are aggregations, over channels and initialization seeds, of the pre-activations projected on various eigenvectors of the corresponding kernel matrix. a A projection on the 1st (leading) eigenvector, while b is the same for the 10th eigenvector. c, d Are the same as a, b, but for the 4th layer. The fits to a Gaussian (dashed lines) provide an accuracy measure for our variational Gaussian approximation. Notice that the equilibrium distribution, even when it is near Gaussian, can differ from the initial one in either or both its mean and variance, a signature of feature learning. Generally, we find that Gaussianity increases with the depth of the layer (i.e. downstream layers are more Gaussian than upstream layers) and with the index of the eigenvector we project on (e.g. here the projection on the 10th eigenvector is more Gaussian than that on the 1st eigenvector). More details can be found in Supplementary Material (6.3), where we also show that correlations are small across different layers and across channels within each layer, thus justifying neglecting these in our theory.

The 2nd layer (as well as the input layer; see Supplementary Material (6.3)) shows deviations from Gaussianity along the leading eigenvector. This is expected, since the kernels of these layers have quite a dilute dominant spectrum, whereas the VGA requires a contribution from many adjacent modes (see Supplementary Material (1.3)). Interestingly, despite this non-Gaussianity along the leading eigenvector of layers 1 and 2, Gaussianity is restored in the downstream layers 3 and 4. Correlations across layers and across channels within the same layer are very weak (largely on the order of 10−3) and fully consistent with the mean-field decoupling underlying this work. Further technical details are found in Supplementary Material (6.3).

Discussion

In this work, we presented what is, to the best of our knowledge, a novel mean-field framework for analyzing finite deep non-linear neural networks in the feature learning regime. Central to our analysis was a series of mean-field approximations, revealing that pre-activations are weakly correlated between layers and follow a Gaussian distribution within each layer with a pre-kernel K(l). Using the latter together with the post-kernel Q(l) induced by the upstream layer, explicit equations of state (EoS) governing the statistics of the hidden layers were given. These enabled us to derive, for the first time, analytical predictions for the performance of non-linear CNNs and deep non-linear FCNs in the feature learning regime. We further note that our EoS generalize straightforwardly to combined CNN-FCN architectures, ReLU activation functions (see Supplementary Material 5), pooling layers, and models with multiple outputs. The GP represented by the equations we find is a good approximation to the true posterior distribution generated by a large but finite-width Bayesian neural network. Thus, in addition to the other advantages of BNNs, such as providing reliable uncertainty estimates and principled model comparison33, they may also admit a concrete interpretation through an effective GP.

Various aspects of this work invite further study. Empirically, it would be interesting to better characterize the scope of models for which a Bayesian sampler (or potentially ensemble-averaged NTK dynamics) leads to Gaussian pre-activations and overall GP-like behavior. Probing the “feature-learning-load” of each layer, by experimentally measuring the differences between the kernels Q(l) and K(l), may also provide insights on generalization, transfer learning, and pruning, thus complementing other diagnostic tools suggested recently43. For instance, transferring a layer with a small feature learning load may provide little benefit, and pruning a channel having a large overlap with a leading eigenvalue of K(l) − Q(l) may be harmful.

From the theory side, it is desirable to develop analytical techniques for solving the EoS, as well as guarantees regarding the existence and uniqueness of solutions. In particular, it would be interesting to explore the possibility of spontaneous breaking of internal symmetries such as weight inversion. Providing a mathematical underpinning for the approximations involved here may lend itself to developing performance bounds on the Langevin algorithm and Bayesian neural networks12,44. Similarly, one can consider using the empirical effective kernel (Qf) as a starting point to develop GP-based bounds45 on performance. Last, it is interesting to explore the approach to equilibrium of the training dynamics and to adapt the approximations carried out here to the NTK setting1,27.

Methods

Mean field action

Here we present the main ingredients of our theory, leading to the EoS we find. Further details can be found in the Supplementary Material.

Decoupling of layers and neurons

First, we provide a decoupling of Eq. (3) into layer-wise, neuron-wise terms, wherein each term depends on the upstream and downstream layers only through channel-averaged second cumulants of activations and pre-activations (post-kernel and pre-kernel). Further details are found in Supplementary Material (1).

Consider the non-linear terms in the action in Eq. (3), which couple the different layers. This coupling is mediated through channel/width-averaged quantities: indeed, h(1) depends on h(2) through the channel/width-averaged square term in h(2), h(2) depends on h(1) through the average of ϕ(h(1))ϕ(h(1)), h(3) depends on h(2) through the average of ϕ(h(2))ϕ(h(2)), and so forth. For Nl ≫ 1, we expect these to be weakly fluctuating and well approximated by their mean-field values. This behavior propagates up to the output layer, and in particular implies that the outputs f fluctuate in a Gaussian manner, as previously conjectured23. As for the dependency of h(L−1) on the f variables, it is not through a channel/width-averaged quantity. However, we find that in various scenarios, such as FCNs with MF scaling or CNNs with large N, the fluctuations of f are suppressed, enabling us to replace f by its average (see Supplementary Material (1.7) and (2.2)). Following this, we obtain our mean-field action,

$${{{{{{{{\mathscr{S}}}}}}}}}_{{{{{{{{\rm{MF}}}}}}}}}=\mathop{\sum }\limits_{l=1}^{L-1}{{{{{{{{\mathscr{S}}}}}}}}}_{{{{{{{{\rm{MF}}}}}}}}}^{(l)}+{{{{{{{{\mathscr{S}}}}}}}}}_{f,{{{{{{{\rm{MF}}}}}}}}}.$$
(10)

This allows us to define the mean-field average 〈. . . 〉MF with respect to the distribution \({\pi }_{{{{{{{{\rm{MF}}}}}}}}}({\{{{{{{{{{\rm{h}}}}}}}}}^{(l)}\}}_{l=1}^{L-1},{{{{{{{\rm{f}}}}}}}})={e}^{-{{{{{{{{\mathscr{S}}}}}}}}}_{{{{{{{{\rm{MF}}}}}}}}}}/{{{{{{{{\mathscr{Z}}}}}}}}}_{{{{{{{{\rm{MF}}}}}}}}}\), with the partition function \({{{{{{{{\mathscr{Z}}}}}}}}}_{{{{{{{{\rm{MF}}}}}}}}}=\int\mathop{\prod }\nolimits_{l=1}^{L-1}d{{{{{{{{\rm{h}}}}}}}}}^{(l)}d{{{{{{{\rm{f}}}}}}}}\,{e}^{-{{{{{{{{\mathscr{S}}}}}}}}}_{{{{{{{{\rm{MF}}}}}}}}}}\). The individual terms are

$${{{{{{{{\mathscr{S}}}}}}}}}_{f,{{{{{{{\rm{MF}}}}}}}}}=\frac{1}{2{\sigma }^{2}}\mathop{\sum}\limits_{\mu }{({f}_{\mu }-{y}_{\mu })}^{2}+\mathop{\sum}\limits_{\mu \nu }\frac{1}{2}{f}_{\mu }{[{Q}_{f}]}_{\mu \nu }^{-1}{f}_{\nu }$$
(11)
$${{{{{{{{\mathscr{S}}}}}}}}}_{{{{{{{{\rm{MF}}}}}}}}}^{(l)}=\frac{1}{2}{N}_{l+1}\mathop{\sum}\limits_{\mu \nu }{A}_{\mu \nu }^{(l+1)}{\tilde{Q}}_{\mu \nu }^{(l+1)}({{{{{{{{\rm{h}}}}}}}}}^{(l)})+\mathop{\sum}\limits_{\mu \nu j}\frac{1}{2}{h}_{\mu j}^{(l)}{[{Q}^{(l)}]}_{\mu \nu }^{-1}{h}_{\nu j}^{(l)}\,\,\,\,\,\,\,{{\mbox{for}}}\,l\in [1,\, L-1]$$
(12)

where \({A}^{(L)}={\sigma }^{-4}({{{{{{\mathbf{{y}}}}}}}}-\bar{{{{{{{\mathbf{{f}}}}}}}}}){({{{{{{\mathbf{{y}}}}}}}}-\bar{{{{{{{\mathbf{{f}}}}}}}}})}^{\top }-{[{Q}_{f}+{\sigma }^{2}{I}_{n}]}^{-1}\), and \({A}^{(l)}={[{Q}^{(l)}]}^{-1}\big({I}_{n}- {\langle {{{{{{{\mathbf{{h}}}}}}}}}_{j}^{(l)}{({{{{{{{\mathbf{{h}}}}}}}}}_{j}^{(l)})}^{T}\rangle }_{{{{{{{{\rm{MF}}}}}}}}}{[{Q}^{(l)}]}^{-1}\big)\). The post-kernels are defined self-consistently as \({Q}_{f}={\langle {\tilde{Q}}_{f}\rangle }_{{{{{{{{\rm{MF}}}}}}}}}\), and \({Q}^{(l)}={\langle {\tilde{Q}}^{(l)}\rangle }_{{{{{{{{\rm{MF}}}}}}}}}\) for l ∈ [1, L − 1].

Notably, any coupling between the different layers is only through static mean-field quantities, namely the pre-kernels and post-kernels. In addition, all neuron-neuron couplings (and similarly, channel-channel couplings for CNNs) have been removed.

Intra-layer decoupling

Despite the simplified inter-layer coupling and intra-layer neuron coupling, the mean-field actions are still non-quadratic for all layers but the output layer. This non-linearity couples all the h(l) variables of the same neuron (channel, in the CNN case) in a way that is roughly all-to-all in the data-point index. In atomic and nuclear physics, similar circumstances are well described by self-consistent Hartree-Fock approximations46,47,48,49. In our setting, this approximation is directly analogous to a variational Gaussian approximation (VGA). In Supplementary Material (4) we argue that in the typical case where the diagonal of K(l) is much larger than the off-diagonal elements, the VGA is well controlled. Technically, we do so by showing, order by order in perturbation theory, that the diagrams accounted for by the VGA dominate all other perturbation-theory diagrams. In Supplementary Material (3) we also establish this, using different means, for S0 ≫ 1 in the specific case of a two-layer CNN with a single activated layer. We further comment that the VGA is exact for deep linear DNNs.

Accordingly, we now look for the Gaussian distribution, governed by a kernel K(l), which is closest to the above non-quadratic action. In models with many hidden layers, this leads to the following “inverse kernel shift” behavior for 1 < l < L − 2

$${[{[{K}^{(l-1)}]}^{-1}]}_{\mu \nu }={[{[{Q}^{(l-1)}]}^{-1}]}_{\mu \nu }+\frac{2{N}_{l}}{{N}_{l-1}}\frac{\partial {D}_{{{{{{{{\rm{KL}}}}}}}}}({K}^{(l)}||{Q}^{(l)})\,\,}{\partial {[{K}^{(l-1)}]}_{\mu \nu }}$$
(13)

where l denotes a layer index and DKL(A∣∣B) is the Kullback-Leibler (KL) divergence between two multivariate Gaussians with covariance matrices A and B. As shown in Supplementary Material (1), for antisymmetric activation functions, the derivative of the KL-divergence is explicitly given by Tr\([{[{Q}^{(l)}]}^{-1}({K}^{(l)}-{Q}^{(l)}){[{Q}^{(l)}]}^{-1}(\partial {Q}^{(l)}/\partial {K}_{i\mu,j\nu }^{(l-1)})]\). For non-anti-symmetric ones, see Supplementary Material (5).
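
Since the EoS only require derivatives of the KL term, one convenient way to evaluate them in practice is automatic differentiation; below is a sketch in JAX (our own code, assuming an antisymmetric activation so that only covariances enter, and using the erf arcsine relation Q(l) = σ2G(K(l−1)) quoted below Eq. (5)):

```python
import jax
import jax.numpy as jnp

def gaussian_kl(K, Q):
    """D_KL between zero-mean Gaussians N(0, K) and N(0, Q), both n x n covariances."""
    n = K.shape[0]
    return 0.5 * (jnp.trace(jnp.linalg.solve(Q, K)) - n
                  + jnp.linalg.slogdet(Q)[1] - jnp.linalg.slogdet(K)[1])

def kl_grad_wrt_upstream(K_lm1, K_l, sigma2):
    """d D_KL(K^{(l)} || Q^{(l)}) / d K^{(l-1)}, with Q^{(l)} = sigma^2 G(K^{(l-1)}), via autodiff."""
    def kl_of_upstream(K_up):
        d = jnp.sqrt(1.0 + 2.0 * jnp.diag(K_up))
        Q_l = sigma2 * (2.0 / jnp.pi) * jnp.arcsin(2.0 * K_up / jnp.outer(d, d))
        return gaussian_kl(K_l, Q_l)
    return jax.grad(kl_of_upstream)(K_lm1)
```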

Three layer CNN model definition

For the analysis in section “Extensions to deeper CNNs and subsets of real-world data-sets”, we consider a student 3-layer CNN defined by

$$f({{{{{{\mathbf{{x}}}}}}}})= \mathop{\sum }\limits_{j=0}^{N-1}\mathop{\sum }\limits_{{c}^{{\prime} }=1}^{{C}_{2}}{a}_{{c}^{{\prime} }j}\phi \left({h}_{{c}^{{\prime} }j}^{(2)}({{{{{{\mathbf{{x}}}}}}}})\right)\\ {h}_{{c}^{{\prime} }j}^{(2)}({{{{{{\mathbf{{x}}}}}}}})= \mathop{\sum }\limits_{i=0}^{{S}_{1}-1}\mathop{\sum }\limits_{c=1}^{{C}_{1}}{v}_{{c}^{{\prime} }ci}\phi \left({h}_{cji}^{(1)}({{{{{{\mathbf{{x}}}}}}}})\right)\\ {h}_{cji}^{(1)}({{{{{{\mathbf{{x}}}}}}}})= {{{{{{{\mathbf{{w}}}}}}}}}_{c} \cdot {{{{{{{\mathbf{{x}}}}}}}}}_{i+j{S}_{1}}$$
(14)

where \({{\mathbf{w}}}_{c},\,{{\mathbf{x}}}_{i+j{S}_{1}}\in {{\mathbb{R}}}^{{S}_{0}}\), \(a\in {{\mathbb{R}}}^{{C}_{2}\times N}\), \(v\in {{\mathbb{R}}}^{{C}_{2}\times {C}_{1}\times {S}_{1}}\), the input vector \({\mathbf{x}}\in {{\mathbb{R}}}^{d}\) with d = NS1S0, and the activation function \(\phi :{\mathbb{R}}\to {\mathbb{R}}\) is applied element-wise. See Fig. 7 for an illustration. As before, the regression target yμ is generated by a random teacher CNN, \({y}_{\mu }={f}^{*}({{\mathbf{x}}}_{\mu })\), having the same architecture as the student, only with C1 = C2 = 1. In addition, we chose S0 = 50, S1 = 30, N = 2, and C1 = C2 = 100 for the student. We denote teacher weights and pre-activations by * superscripts. Further details can be found in Supplementary Information, section 2.

Fig. 7: A 3-layer CNN architecture.
figure 7

An illustration with non-overlapping strides. The black squares represent the stride windows in each layer.
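For readers who prefer code, the following NumPy sketch evaluates Eq. (14) for a single input; tanh is used as a stand-in antisymmetric activation and the weight variances are illustrative.

```python
# A minimal NumPy sketch of the student CNN in Eq. (14), with non-overlapping
# strides: the input x (length d = N*S1*S0) is split into N*S1 windows of size S0.
import numpy as np

def forward(x, w, v, a, phi=np.tanh, S0=50, S1=30, N=2):
    """f(x) for the 3-layer CNN of Eq. (14). Shapes: w (C1,S0), v (C2,C1,S1), a (C2,N)."""
    windows = x.reshape(N, S1, S0)                # x_{i + j*S1}, j = 0..N-1, i = 0..S1-1
    h1 = np.einsum('cs,jis->cji', w, windows)     # h^{(1)}_{cji} = w_c . x_{i+j*S1}
    h2 = np.einsum('kci,cji->kj', v, phi(h1))     # h^{(2)}_{c'j} = sum_{c,i} v_{c'ci} phi(h1)
    return np.einsum('kj,kj->', a, phi(h2))       # f(x) = sum_{c',j} a_{c'j} phi(h2)

# toy usage with randomly drawn weights (the variances here are illustrative)
rng = np.random.default_rng(0)
S0, S1, N, C1, C2 = 50, 30, 2, 100, 100
x = rng.standard_normal(N * S1 * S0)
w = rng.standard_normal((C1, S0)) / np.sqrt(S0)
v = rng.standard_normal((C2, C1, S1)) / np.sqrt(C1 * S1)
a = rng.standard_normal((C2, N)) / np.sqrt(C2)
print(forward(x, w, v, a))
```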

Bayesian posterior sampling with Langevin-type dynamics

In this section, we give more details regarding the algorithm used to generate samples from the Bayesian posterior. We train the DNNs using full-batch gradient descent with weight decay and external white Gaussian noise. The discrete-time dynamics of the parameters are thus

$${{{{{{{\boldsymbol{\theta }}}}}}}}_{t+1}-{{{{{{{\boldsymbol{\theta}}}}}}}}_{t}=-\eta \left(\gamma {{{{{{{\boldsymbol{\theta}}}}}}}}_{t}+{\nabla }_{{{{{{{{\boldsymbol{\theta }}}}}}}}_{t}}{{{{{{{\mathscr{L}}}}}}}}\left({{{{{{{\boldsymbol{\theta }}}}}}}}_{t},\, {{{{{{{{\mathscr{D}}}}}}}}}_{n}\right)\right)+2\sigma \sqrt{\eta }{{{{{{{\boldsymbol{\xi}}}}}}}}_{t+1}$$
(15)

where \({{\boldsymbol{\theta }}}_{t}\) is the vector of all network parameters at time step t, γ is the strength of the weight decay (which, from the Bayesian perspective, is inversely proportional to the prior variance of the parameters), \({{\mathscr{L}}}({{\boldsymbol{\theta }}}_{t},\,{{\mathscr{D}}}_{n})\) is the loss as a function of the DNN parameters \({{\boldsymbol{\theta }}}_{t}\) and the data, σ is the magnitude of the noise, η is the learning rate, and \({\xi }_{t,i} \sim {\mathscr{N}}(0,1)\). As η → 0, these discrete-time dynamics converge to the continuous-time Langevin equation \(\dot{{{\boldsymbol{\theta }}}}(t)=-{\nabla }_{{{\boldsymbol{\theta }}}}\left(\frac{\gamma }{2}\parallel {{\boldsymbol{\theta }}}(t){\parallel }^{2}+{{\mathscr{L}}}\left({{\boldsymbol{\theta }}}(t),\,{{\mathscr{D}}}_{n}\right)\right)+2\sigma {{\boldsymbol{\xi }}}(t)\) with \(\left\langle {\xi }_{i}(t){\xi }_{j}({t}^{\prime})\right\rangle={\delta }_{ij}\delta (t-{t}^{\prime})\), so that as t → ∞ the DNN parameters \({{\boldsymbol{\theta }}}\) are sampled from the equilibrium Gibbs distribution in parameter space, \(p({{\boldsymbol{\theta }}}|{{\mathscr{D}}}_{n})\). In principle, this holds regardless of the initial condition of the parameters. In practice, however, to achieve reasonable convergence times we set a random initial condition with zero mean and a variance that matches that of the Bayesian prior.
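A minimal sketch of the update in Eq. (15) on a toy linear model with MSE loss is given below; the model, the loss reduction, and the hyperparameter values are illustrative and not those used in our experiments.

```python
# Unadjusted Langevin dynamics with weight decay, Eq. (15), on a toy linear model.
# All hyperparameter values and the model itself are illustrative.
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((32, 10))                 # toy dataset D_n
y = X @ rng.standard_normal(10)
eta, gamma, sigma = 1e-3, 0.5, 0.1                # learning rate, weight decay, noise level
theta = np.sqrt(2 * sigma**2 / gamma) * rng.standard_normal(10)  # init matching the prior variance

def grad_loss(theta):
    """Gradient of the loss 0.5 * ||X @ theta - y||^2."""
    return X.T @ (X @ theta - y)

samples = []
for t in range(50_000):
    xi = rng.standard_normal(theta.shape)         # white Gaussian noise xi_{t+1}
    theta = theta - eta * (gamma * theta + grad_loss(theta)) + 2.0 * sigma * np.sqrt(eta) * xi
    if t > 25_000 and t % 100 == 0:               # discard burn-in, thin the chain
        samples.append(theta.copy())

# `samples` now approximates draws from the Gibbs posterior p(theta | D_n)
```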

This algorithm, used to generate samples from the Bayesian posterior, corresponds to the Unadjusted Langevin Algorithm (ULA)50,51,52 with an added weight-decay term. Hyperparameters such as the learning rate, weight decay, and noise level (or mini-batch size for SGD) can be varied and compared across different training protocols, making our setup closely connected to the gradient-based algorithms practitioners use to train DNNs while still conforming to the Bayesian perspective. While this method of posterior sampling may converge more slowly than more sophisticated samplers (such as HMC33), it is simpler (e.g., it has no Metropolis acceptance step, hence the word “unadjusted”) and admits an intuitive correspondence with vanilla SGD, where the mini-batch noise is replaced with white additive noise31. Under rather mild conditions, ULA has been shown to have good convergence properties50.

Experimental details

Hyperparameters

For the 2-layer CNN experiments, we used S = 64, N = 20, and a varying number of channels. The training parameters (noise and weight decay) were tuned such that σ2 = 0.1 and the weight variance was 2.0 over the fan-in for both layers at n = 0. The target was drawn once for all experiments using i.i.d. centered Gaussian \({a}_{i}^{*}\) and \({w}_{s}^{*}\) with variances 1/N and 1/S, respectively.

For the 3-layer CNN experiments, we took S1 = 50, S0 = 30, and N = 2. The training parameters (noise and weight decay) were scaled such that σ2 = 0.005 and the weight variance was 2.0 over the fan-in for the input and hidden layers with no training data (i.e., at initialization). The weight variance of the read-out layer was 15 over the fan-in. The target was again drawn once for all experiments from a teacher CNN with C = 1.

For all the Myrtle-5 experiments, we used n = 2048, C = 256, and ReLU activations. The training parameters (noise and weight decay) were scaled such that σ2 = 0.005 and the weight variance was 2.0 over the fan-in for all layers with no training data (i.e., at initialization).

For all the FCN experiments, we used equal widths (N1 = N2) and a weight decay corresponding to the variance \({\sigma }_{{\rm{w}}}^{2}={\sigma }_{{\rm{a}}}^{2}=2\) (with no training data) in the regular scaling. For the MF scaling, we took \({\sigma }_{{\rm{a}}}^{2}=2/{N}_{2}\). The target was again drawn once for all experiments from a teacher network with N1 = N2 = 1. Specifically, when calculating the emergent scale, we used \({\sigma }_{{\rm{a}}}^{2}=2/256\) independent of N2.

Equilibrium sampling

To obtain weakly correlated samples from the equilibrium distribution of the trained CNNs we used the following procedure. For the 2- and 3-layer CNNs we used an adaptive learning-rate scheduler: for the first 100 epochs we used a learning rate of lr0/10, then we increased the learning rate to lr0. From epoch 5e3 onward, every 1e3 epochs we estimated the fluctuations of the train loss and checked for spikes, i.e., events in which the train loss exceeded five times its standard deviation over the past 500 epochs. If a spike was observed, the learning rate was reduced by a factor of 0.7. This continued until 5e4 epochs passed without any such event. The learning rate was then reduced again by a factor of two and remained fixed; samples from this final stage were treated as equilibrium samples. We further checked that (i) different initialization seeds trained with this protocol reached the same train-loss statistics, and (ii) no further reduction in train loss occurred after the final learning-rate reduction. For several runs, we also verified that lowering the final learning rate by an additional factor of 2 had no appreciable effect on the loss. The initial lr0 was ~1e−4 (with respect to a standard mean-reduction MSE loss) and the final learning rate was typically ~1e−5. The runs terminated at epoch 3e5.
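In code, this adaptive schedule can be sketched as follows; train_epoch is a placeholder for one full-batch Langevin epoch returning the current train loss, and the thresholds simply mirror the description above.

```python
# A schematic sketch of the spike-triggered learning-rate schedule described above.
# `train_epoch(lr)` is a placeholder (not implemented here) for one full-batch
# Langevin epoch that returns the current train loss.
import numpy as np

def adaptive_schedule(train_epoch, lr0=1e-4, total_epochs=300_000):
    lr, losses, last_event, final_stage = lr0 / 10, [], 0, False
    for epoch in range(total_epochs):
        if epoch == 100:
            lr = lr0                                  # warm-up over, crank up to lr0
        losses.append(train_epoch(lr))
        if not final_stage and epoch >= 5_000 and epoch % 1_000 == 0:
            recent_std = np.std(losses[-500:])
            if losses[-1] > 5.0 * recent_std:         # spike in the train loss
                lr *= 0.7
                last_event = epoch
            elif epoch - last_event >= 50_000:        # 5e4 quiet epochs
                lr *= 0.5                             # final reduction; lr then stays fixed
                final_stage = True                    # later samples treated as equilibrium
    return losses
```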

For the Myrtle-5 CNN trained on CIFAR-10, we first performed several runs of 3e5 epochs using the above procedure and examined those that reached the lowest train loss. We then generated a fixed scheduler based on these more successful instances, running for up to 4e5 epochs. We again verified that further lowering the final learning rate had no appreciable effect on the training loss, and that different seeds reached a similar final train loss. This ensures that we are indeed sampling from a valid equilibrium distribution.

For the 3-layer CNN and Myrtle-5, we found that the auto-correlation times of the pre-activations vary considerably between layers. While the read-out layer typically had an auto-correlation time of the order of 1e3 epochs (at the lowest learning rates), the auto-correlation times of the input layers could reach ~1e6 epochs or more. To overcome this issue, when analyzing pre-activations of these deeper DNNs we used ensembles of 98 and 234 different initialization seeds for the 3-layer CNN and Myrtle-5, respectively.
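For reference, the integrated auto-correlation time of a scalar time series (e.g., a single pre-activation recorded over epochs) can be estimated as in the sketch below; the truncation heuristic used here is a common choice rather than necessarily the one used in our analysis.

```python
# Estimate of the integrated auto-correlation time of a scalar observable.
import numpy as np

def integrated_autocorr_time(x):
    """tau_int = 1 + 2 * sum_t rho(t), truncated at the first negative ACF value."""
    x = np.asarray(x) - np.mean(x)
    acf = np.correlate(x, x, mode='full')[len(x) - 1:]   # lags 0 .. n-1
    acf = acf / acf[0]
    cutoff = np.argmax(acf < 0) if np.any(acf < 0) else len(acf)
    return 1.0 + 2.0 * np.sum(acf[1:cutoff])
```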

For the 3-layer FCN we used a fixed scheduler that started at 1/2 the maximal stable learning rate, reduced the learning rate by factors of 2 at 100, 1e5, 1e6, and 3e6 epochs, and by a factor of 4 at 4e6 and 5e6 epochs (a factor of 128 in total). Equilibrium sampling was done between 6e6 and 7e6 epochs.

Numerical solution of the equations of state

For the 2-layer CNN, the equations of state (EoS) were solved using the Newton-Krylov method53, which does not require explicit gradients. To facilitate convergence, we adopted an annealing procedure: for C ~ 1e3, we obtained the solution using a GP initial value (x0) for Σ. The optimization outcome was then used as x0 for the next, lower value of C. Using 12 CPU cores, this optimization took several hours. After obtaining \({\Sigma }_{s{s}^{\prime}}\) as a function of C, the resulting kernel \({[{Q}_{f}]}_{\mu \nu }\) was used in standard GP inference to obtain f on the test set. For the 3-layer FCN we used a more efficient JAX-based code to generate the kernels and kernel derivatives involved in the EoS, but otherwise followed the same procedure. The optimization took between several minutes and a few hours on a single Titan-X GPU, depending on the parameters.
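The annealing loop can be sketched as follows; eos_residual, Sigma_gp, and C_values are placeholders (the actual equations of state are not reproduced here), and the tolerance and C grid are illustrative.

```python
# A schematic sketch of the annealed Newton-Krylov solution of the equations of
# state: SciPy's solver is warm-started from the previous solution as the channel
# number C is lowered. `eos_residual(Sigma, C)` is a placeholder that should
# vanish at the solution of the EoS for the given C.
import numpy as np
from scipy.optimize import newton_krylov

def solve_eos_annealed(eos_residual, Sigma_gp, C_values):
    """Solve the EoS for each C in a decreasing sequence, starting from the GP kernel."""
    solutions, x0 = {}, Sigma_gp                 # GP (C -> infinity) solution as initial value
    for C in C_values:                           # e.g. np.geomspace(1e3, 1e1, 10)
        x0 = newton_krylov(lambda Sigma: eos_residual(Sigma, C), x0, f_tol=1e-8)
        solutions[C] = x0.copy()                 # warm start for the next, lower C
    return solutions
```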