5.5. Generalization in Deep Learning
In Section 3 and Section 4, we tackled regression and classification problems by fitting linear models to training data. In both cases, we provided practical algorithms for finding the parameters that maximized the likelihood of the observed training labels. And then, towards the end of each chapter, we recalled that fitting the training data was only an intermediate goal. Our real quest all along was to discover general patterns on the basis of which we can make accurate predictions even on new examples drawn from the same underlying population. Machine learning researchers are consumers of optimization algorithms. Sometimes, we must even develop new optimization algorithms. But at the end of the day, optimization is merely a means to an end. At its core, machine learning is a statistical discipline and we wish to optimize training loss only insofar as some statistical principle (known or unknown) leads the resulting models to generalize beyond the training set.
On the bright side, it turns out that deep neural networks trained by stochastic gradient descent generalize remarkably well across myriad prediction problems, spanning computer vision; natural language processing; time series data; recommender systems; electronic health records; protein folding; value function approximation in video games and board games; and numerous other domains. On the downside, if you were looking for a straightforward account of either the optimization story (why we can fit them to training data) or the generalization story (why the resulting models generalize to unseen examples), then you might want to pour yourself a drink. While our procedures for optimizing linear models and the statistical properties of the solutions are both described well by a comprehensive body of theory, our understanding of deep learning still resembles the wild west on both fronts.
Both the theory and practice of deep learning are rapidly evolving, with theorists adopting new strategies to explain what’s going on, even as practitioners continue to innovate at a blistering pace, building arsenals of heuristics for training deep networks and a body of intuitions and folk knowledge that provide guidance for deciding which techniques to apply in which situations.
The summary of the present moment is that the theory of deep learning has produced promising lines of attack and scattered fascinating results, but still appears far from a comprehensive account of both (i) why we are able to optimize neural networks and (ii) how models learned by gradient descent manage to generalize so well, even on high-dimensional tasks. However, in practice, (i) is seldom a problem (we can always find parameters that will fit all of our training data) and thus understanding generalization is by far the bigger problem. On the other hand, even absent the comfort of a coherent scientific theory, practitioners have developed a large collection of techniques that may help you to produce models that generalize well in practice. While no pithy summary can possibly do justice to the vast topic of generalization in deep learning, and while the overall state of research is far from resolved, we hope, in this section, to present a broad overview of the state of research and practice.
5.5.1. Revisiting Overfitting and Regularization
According to the “no free lunch” theorem of Wolpert and Macready (1995), any learning algorithm generalizes better on data with certain distributions, and worse with other distributions. Thus, given a finite training set, a model relies on certain assumptions: to achieve human-level performance it may be useful to identify inductive biases that reflect how humans think about the world. Such inductive biases show preferences for solutions with certain properties. For example, a deep MLP has an inductive bias towards building up a complicated function by the composition of simpler functions.
With machine learning models encoding inductive biases, our approach to training them typically consists of two phases: (i) fit the training data; and (ii) estimate the generalization error (the true error on the underlying population) by evaluating the model on holdout data. The difference between our fit on the training data and our fit on the test data is called the generalization gap, and when this is large, we say that our models overfit to the training data. In extreme cases of overfitting, we might exactly fit the training data, even when the test error remains significant. In the classical view, the interpretation is that our models are too complex, requiring that we shrink the number of features, the number of nonzero parameters learned, or the size of the parameters (as quantified, for example, by a norm). Recall the plot of model complexity compared with loss (Fig. 3.6.1) from Section 3.6.
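The quantities in this paragraph can be written down compactly. The notation below is an assumption introduced for illustration rather than something defined in this section: $l$ is a loss function, $f$ is the fitted model, $(\mathbf{x}^{(i)}, y^{(i)})$ for $i = 1, \ldots, n$ are the training examples, and $P$ is the underlying data distribution.

```latex
\begin{aligned}
R_\textrm{emp}[f] &= \frac{1}{n} \sum_{i=1}^{n} l\big(f(\mathbf{x}^{(i)}), y^{(i)}\big)
  && \text{(training error)} \\
R[f] &= \mathbb{E}_{(\mathbf{x}, y) \sim P}\big[\, l(f(\mathbf{x}), y) \,\big]
  && \text{(generalization error)} \\
\text{generalization gap} &= R[f] - R_\textrm{emp}[f]
\end{aligned}
```

Because $P$ is unknown, $R[f]$ cannot be computed exactly. In practice it is estimated by the average loss on holdout data, so the measured gap is itself only an estimate.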
However, deep learning complicates this picture in counterintuitive ways. First, for classification problems, our models are typically expressive enough to perfectly fit every training example, even in datasets consisting of millions of examples (Zhang et al., 2021). In the classical picture, we might think that this setting lies on the far right extreme of the model complexity axis, and that any improvements in generalization error must come by way of regularization, either by reducing the complexity of the model class, or by applying a penalty, severely constraining the set of values that our parameters might take. But that is where things start to get weird.
Strangely, for many deep learning tasks (e.g., image recognition and text classification) we are typically choosing among model architectures, all of which can achieve arbitrarily low training loss (and zero training error). Because all models under consideration achieve zero training error, the only avenue for further gains is to reduce overfitting. Even stranger, it is often the case that despite fitting the training data perfectly, we can actually reduce the generalization error further by making the model even more expressive, e.g., by adding layers or nodes, or by training for a larger number of epochs. Stranger yet, the pattern relating the generalization gap to the complexity of the model (as captured, for example, in the depth or width of the networks) can be non-monotonic, with greater complexity hurting at first but subsequently helping in a so-called “double-descent” pattern (Nakkiran et al., 2021). Thus the deep learning practitioner possesses a bag of tricks, some of which seemingly restrict the model in some fashion and others that seemingly make it even more expressive, and all of which, in some sense, are applied to mitigate overfitting.
Complicating things even further, while the guarantees provided by
classical learning theory can be conservative even for classical models,
they appear powerless to explain why it is that deep neural networks
generalize in the first place. Because deep neural networks are capable
of fitting arbitrary labels even for large datasets, and despite the use
of familiar methods such as
5.5.2. Inspiration from Nonparametrics
Approaching deep learning for the first time, it is tempting to think of neural networks as parametric models. After all, the models do have millions of parameters. When we update the models, we update their parameters. When we save the models, we write their parameters to disk. However, mathematics and computer science are riddled with counterintuitive changes of perspective, and surprising isomorphisms between seemingly different problems. While neural networks clearly have parameters, in some ways it can be more fruitful to think of them as behaving like nonparametric models. So what precisely makes a model nonparametric? While the name covers a diverse set of approaches, one common theme is that nonparametric methods tend to have a level of complexity that grows as the amount of available data grows.
Perhaps the simplest example of a nonparametric model is the
Note that
In a sense, because neural networks are over-parametrized, possessing many more parameters than are needed to fit the training data, they tend to interpolate the training data (fitting it perfectly) and thus behave, in some ways, more like nonparametric models. More recent theoretical research has established deep connections between large neural networks and nonparametric methods, notably kernel methods. In particular, Jacot et al. (2018) demonstrated that in the limit, as multilayer perceptrons with randomly initialized weights grow infinitely wide, they become equivalent to (nonparametric) kernel methods for a specific choice of the kernel function (essentially, a distance function), which they call the neural tangent kernel. While current neural tangent kernel models may not fully explain the behavior of modern deep networks, their success as an analytical tool underscores the usefulness of nonparametric modeling for understanding the behavior of over-parametrized deep networks.
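To make the idea of interpolation concrete, here is a minimal sketch using a 1-nearest-neighbor classifier, chosen here purely as an illustrative nonparametric method. The function names and the toy one-dimensional dataset (including the deliberately flipped label) are assumptions made up for this example. The model simply stores the training set, so its effective complexity grows with the data, and it fits every training label exactly, noise included.

```python
import math


def nearest_neighbor_predict(train_points, train_labels, query):
    """Return the label of the training point closest to `query` (1-NN)."""
    best_index = min(
        range(len(train_points)),
        key=lambda i: math.dist(train_points[i], query),
    )
    return train_labels[best_index]


def error_rate(train_points, train_labels, points, labels):
    """Fraction of `points` whose 1-NN prediction differs from `labels`."""
    mistakes = sum(
        nearest_neighbor_predict(train_points, train_labels, p) != y
        for p, y in zip(points, labels)
    )
    return mistakes / len(points)


# Hypothetical toy data: the true rule is "label 1 iff x > 0", but the
# training label at x = 0.5 has been flipped to simulate label noise.
train_x = [(-2.0,), (-1.0,), (0.5,), (1.0,), (2.0,)]
train_y = [0, 0, 0, 1, 1]
test_x = [(-1.4,), (0.4,), (0.6,), (1.6,)]
test_y = [0, 1, 1, 1]

# Each training point is its own nearest neighbor, so the fit is perfect.
train_error = error_rate(train_x, train_y, train_x, train_y)
# Test points near the mislabeled example inherit its wrong label.
test_error = error_rate(train_x, train_y, test_x, test_y)
print(f"train error: {train_error:.2f}, test error: {test_error:.2f}")
```

On this toy data the training error is zero while the holdout error is not, which is the interpolation regime described above in miniature. Nothing here is specific to neural networks; the sketch only illustrates what it means for a model to interpolate its training set.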
5.5.3. Early Stopping
While deep neural networks are capable of fitting arbitrary labels, even when labels are assigned incorrectly or randomly (Zhang et al., 2021), this capability only emerges over many iterations of training. A new line of work (Rolnick et al., 2017) has revealed that in the setting of label noise, neural networks tend to fit cleanly labeled data first and only subsequently to interpolate the mislabeled data. Moreover, it has been established that this phenomenon translates directly into a guarantee on generalization: whenever a model has fitted the cleanly labeled data but not randomly labeled examples included in the training set, it has in fact generalized (Garg et al., 2021).
Together these findings help to motivate early stopping, a classic technique for regularizing deep neural networks. Here, rather than directly constraining the values of the weights, one constrains the number of epochs of training. The most common way to determine the stopping criterion is to monitor validation error throughout training (typically by checking once after each epoch) and to cut off training when the validation error has not decreased by more than some small amount for some number of epochs.
Notably, when there is no label noise and datasets are realizable (the classes are truly separable, e.g., distinguishing cats from dogs), early stopping tends not to lead to significant improvements in generalization. On the other hand, when there is label noise, or intrinsic variability in the label (e.g., predicting mortality among patients), early stopping is crucial. Training models until they interpolate noisy data is typically a bad idea.
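The stopping rule described in this section can be sketched in a few lines. This is only an illustration under stated assumptions: the function name `early_stopping_epoch`, the parameter names `min_delta` (the "small amount") and `patience` (how many consecutive non-improving epochs to tolerate), and the validation errors themselves are all hypothetical, and a real training loop would compute the validation error after each epoch rather than read it from a list.

```python
def early_stopping_epoch(val_errors, min_delta=1e-3, patience=3):
    """Return the index of the epoch at which training would be cut off.

    Training stops once the validation error has failed to improve on the
    best value seen so far by more than `min_delta` for `patience`
    consecutive epochs. If that never happens, the last epoch index is
    returned. Assumes `val_errors` is non-empty.
    """
    best_error = float("inf")
    epochs_without_improvement = 0
    for epoch, error in enumerate(val_errors):
        if best_error - error > min_delta:
            # Meaningful improvement: record it and reset the counter.
            best_error = error
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_errors) - 1


# Made-up validation errors: they improve at first, then rise as the
# model starts to fit noise in the training labels.
val_errors = [0.50, 0.40, 0.35, 0.34, 0.36, 0.37, 0.39, 0.41]
stop_epoch = early_stopping_epoch(val_errors, min_delta=0.005, patience=3)

# In practice one usually keeps the parameters from the best epoch seen
# before stopping, not those from the final epoch.
best_epoch = min(range(stop_epoch + 1), key=val_errors.__getitem__)
print(f"stopped at epoch {stop_epoch}, best epoch was {best_epoch}")
```

Larger values of `patience` make the rule more tolerant of noisy validation curves at the cost of extra training time; the right setting depends on the task.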
5.5.4. Classical Regularization Methods for Deep Networks
In Section 3, we described several classical regularization techniques for constraining the complexity of our models. In particular, Section 3.7 introduced a method called weight decay, which consists of adding a regularization term to the loss function in order to penalize large values of the weights. Depending on which weight norm is penalized, this technique is known either as ridge regularization (for the $\ell_2$ penalty) or lasso regularization (for the $\ell_1$ penalty).
In deep learning implementations, weight decay remains a popular tool.
However, researchers have noted that typical strengths of
Notably, deep learning researchers have also built on techniques first popularized in classical regularization contexts, such as adding noise to model inputs. In the next section we will introduce the famous dropout technique (invented by Srivastava et al. (2014)), which has become a mainstay of deep learning, even as the theoretical basis for its efficacy remains similarly mysterious.
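Returning to weight decay as described at the start of this section, the following minimal sketch shows how an $\ell_2$ penalty changes a single gradient step. The function name, the `(weight_decay / 2) * ||w||^2` form of the penalty, and the numbers are assumptions chosen for illustration; deep learning frameworks expose the same idea through an optimizer option, and details such as the factor of one half vary between presentations.

```python
def sgd_step_with_weight_decay(weights, grads, lr, weight_decay):
    """One gradient step on loss(w) + (weight_decay / 2) * ||w||^2.

    The penalty contributes weight_decay * w to the gradient, so every
    weight is pulled toward zero in addition to following the gradient
    of the data loss given in `grads`.
    """
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grads)]


weights = [1.0, -2.0, 0.5]

# With a zero data gradient, the step reduces to pure shrinkage: each
# weight is multiplied by (1 - lr * weight_decay) = 0.75.
zero_grads = [0.0, 0.0, 0.0]
decayed = sgd_step_with_weight_decay(
    weights, zero_grads, lr=0.5, weight_decay=0.5
)
print(decayed)
```

The zero-gradient case isolates the effect of the penalty, which is where the name "weight decay" comes from. With a nonzero data gradient, the two terms simply add.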
5.5.5. Summary
Unlike classical linear models, which tend to have fewer parameters than examples, deep networks tend to be over-parametrized, and for most tasks are capable of perfectly fitting the training set. This interpolation regime challenges many firmly held intuitions. Functionally, neural networks look like parametric models. But thinking of them as nonparametric models can sometimes be a more reliable source of intuition. Because it is often the case that all deep networks under consideration are capable of fitting all of the training labels, nearly all gains must come by mitigating overfitting (closing the generalization gap). Paradoxically, the interventions that reduce the generalization gap sometimes appear to increase model complexity and at other times appear to decrease complexity. However, these methods seldom decrease complexity sufficiently for classical theory to explain the generalization of deep networks, and why certain choices lead to improved generalization remains for the most part a massive open question despite the concerted efforts of many brilliant researchers.
5.5.6. Exercises
1. In what sense do traditional complexity-based measures fail to account for generalization of deep neural networks?
2. Why might early stopping be considered a regularization technique?
3. How do researchers typically determine the stopping criterion?
4. What important factor seems to differentiate cases when early stopping leads to big improvements in generalization?
5. Beyond generalization, describe another benefit of early stopping.