Why Deep Learning Works Even Though It Shouldn’t

This is a big question, and I’m not a particularly big person. As such, these are all likely to be obvious observations to someone deep in the literature and theory. What I find, however, is that every field has a base of unspoken intuitions underlying expert understanding, intuitions that are never directly stated in the literature because they can’t easily be proved with the rigor the literature demands. As a result, the insights exist only in conversation and subtext, which makes them inaccessible to the casual reader.

Because I have no need of rigor (or even a need to be correct) to post on the internet, I’m going to write up some of those intuitions here as I (not an expert) understand them. Since the best way to get the right answer on the internet is to post the wrong one, I’ve gotten a lot of good feedback from people about this and updated it accordingly. If this is all obvious to you, skip to the section on “Suggestions for Research”, because there are a lot of ways in which I think typical papers ignore things that most researchers believe to be true.

In particular, I find that people from a statistics background tend to throw up their hands at deep learning, because from a traditional statistics perspective none of it can possibly work. This makes it very frustrating that it does. As a result they tend to have a much dimmer view of its results and methods than their continued success warrants, so I hope here that I can bridge some of that gap.

The key thing I’m going to try to intuitively explain is why models always get better when they are bigger and deeper, even when the amount of data they consume stays the same or gets smaller. Some of this might turn out to be wrong, but I think it’s much more likely to be incomplete than wrong. The effects I describe here likely matter, even if it’s possible they aren’t the dominant causes. There is going to be nothing terribly formal here, which will madden some people and relieve others. If you find this all irritatingly hand-wavy, go read papers about the lottery ticket hypothesis instead, because I think that’s the closest thing to a formal theory that encapsulates most of this and is currently making progress.

So here goes.

If you start your parameters in a reasonable place, they’re already close to good ones, even though they’re totally random.

In high dimensional spaces, distance is a statistical concept. Squared Euclidean distance is just a big sum, and statistics tells us what happens to big sums: they become normal distributions, and they concentrate relatively tighter and tighter around their mean as the number of terms in the sum increases. This means that when there is any amount of well behaved randomness involved, all distances in high dimensions are about the same. In a model whose parameters begin as random variables due to initialization and end as random variables due to the nature of the data, the central limit theorem applies to these sums. So all sets of parameters in a high dimensional model are about equally close to/far from each other.
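If you want to see that concentration for yourself, here’s a quick sketch (mine, nothing formal; plain numpy, with random Gaussian points standing in for parameter vectors):

```python
# Sample pairs of random points in d dimensions and look at the spread of
# their pairwise distances. The relative spread (std/mean) shrinks as d grows:
# the CLT pulls every distance toward the same value.
import numpy as np

rng = np.random.default_rng(0)

for d in [2, 100, 10_000]:
    a = rng.standard_normal((1000, d))   # 1,000 random "parameter vectors"
    b = rng.standard_normal((1000, d))   # 1,000 more
    dist = np.linalg.norm(a - b, axis=1)
    print(f"d={d:>6}: mean distance={dist.mean():8.2f}  std/mean={dist.std()/dist.mean():.4f}")
```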

In the dimensions we live in, we’re used to the idea that some things are closer together than other things, so we mentally carve parameter space into “regions” and worry about bad regions and good regions to start in. But high dimensional spaces are extremely well connected. You can get to anywhere with a short jump from anywhere else. There are no bad places to start. If the magnitudes of the random initialization are about right, all places are reasonably good. No matter where you start, you’re close to good parameters, and you’re as likely to be close to good parameters as to any others. The only assumption we need for this to be formally true is that there are lots of good sets of parameters, and that they come from roughly the same distribution as the parameters at initialization, which is pretty mild.

Loss landscapes in high dimensional spaces are unlikely to have local optima, and probably don’t have any optima at all.

Just recall what is necessary for a set of parameters to be at an optimum. All the gradients need to be zero, and the Hessian needs to be positive semidefinite. In other words, you need to be surrounded by walls. In 4 dimensions, you can walk through walls. GPT-3 has 175 billion parameters. In 175 billion dimensions, walls are so far beneath your notice that if you observe them at all it is like God looking down upon individual protons.

If there’s any randomness at all in the loss landscape, which of course there is, it’s vanishingly unlikely that all of the millions or billions of directions the model has to choose from will be simultaneously uphill. With so many directions to choose from, you will always have at least one direction to escape. It’s just completely implausible that any big model comes close to any optimum at all. In fact it’s implausible that an optimum exists: unless your loss function has a finite minimum value, like squared loss (not cross entropy or softmax), or you add explicit regularization that bounds the magnitude of the weights, forces positive curvature, and hurts the performance of the model, all real models diverge.
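To put a toy number on the “surrounded by walls” picture: if you model the curvature at a random point as a random symmetric matrix (a strong simplifying assumption on my part, not something the theory hands you for free), the chance that every direction curves upward dies off fast with dimension.

```python
# Model the Hessian at a random point as a random symmetric matrix and check
# how often every eigenvalue is positive, i.e. how often you are truly
# "surrounded by walls" with no downhill direction left.
import numpy as np

rng = np.random.default_rng(0)

def fraction_trapped(d, trials=200):
    trapped = 0
    for _ in range(trials):
        m = rng.standard_normal((d, d))
        h = (m + m.T) / 2                     # symmetric stand-in for a Hessian
        if np.linalg.eigvalsh(h).min() > 0:   # all directions uphill
            trapped += 1
    return trapped / trials

for d in [1, 2, 5, 10, 20]:
    print(f"d={d:>2}: fraction with all-positive curvature = {fraction_trapped(d):.3f}")
```

By 20 dimensions the fraction is already essentially zero; now extrapolate that to 175 billion.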

Look at this scrub walking through walls in 4 dimensions. Hah. You are like little baby.

When Gradient Descent can’t tell two things apart, it averages them together.

There are two regimes of deep learning training, depending on the amount of data available and the size of the model: the “classical” regime, where the model eventually overfits on the training set and starts getting worse, and the “second descent” regime (often called double descent), where it doesn’t. The classical regime is where most industrial models operate, where data is plentiful, compute is $$$, and the model eventually has to be deployed to serve real QPS. The second descent regime is where most research models operate, as datasets are standard and small, and the models never have to be deployed.

But despite the big differences between them I believe the same underlying principle describes both. When gradient descent can’t tell two things apart, it tends to divide credit equally between them.

Classical Regime

Half of an introductory statistics textbook is about how to deal with collinearity. But collinearity only matters if you care about attribution, which you don’t. Suppose you have two nearly collinear inputs. At the beginning of gradient descent, every correlated input’s coefficient is getting nearly the same gradients. In the “classical” regime they eventually won’t, and at some point some of the difference between their gradients will be noise, and thus overfitting. So when that happens you just stop!

This means that the difference between having n parameters and n+1 parameters is something that is effectively controlled by early stopping. You might have a lot of effectively unused parameters, but again, you don’t care. Forcing a low dimensional model just biases the model, because if a low dimensional model worked, early stopping would find it automatically due to the way gradients work. The validation set tells you when most of the gradients you’re getting are just noise, so you just ignore them and stop. Early stopping is better regularization than any hand picked a priori regularization, including implicit regularization like model size.
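Here’s a toy version of the credit-splitting story (my own construction, a two-feature linear model rather than a deep net): exact least squares happily chases the noise that distinguishes two nearly identical inputs, while gradient descent stopped early splits the credit between them almost evenly.

```python
# Two nearly collinear inputs. Exact least squares can assign them large,
# opposite-sign weights by fitting the tiny difference between them, while
# early-stopped gradient descent from zero divides the credit roughly equally.
import numpy as np

rng = np.random.default_rng(0)
n = 200
x1 = rng.standard_normal(n)
x2 = x1 + 1e-3 * rng.standard_normal(n)    # nearly a copy of x1
X = np.column_stack([x1, x2])
y = x1 + 0.1 * rng.standard_normal(n)      # the true signal is just x1

# Exact solution: exploits the noisy x1/x2 difference.
w_exact = np.linalg.lstsq(X, y, rcond=None)[0]

# Plain gradient descent from zero, stopped well before it resolves that difference.
w = np.zeros(2)
lr = 0.1
for _ in range(200):
    w -= lr * X.T @ (X @ w - y) / n

print("exact least squares:", w_exact)   # can be large and opposite in sign
print("early-stopped GD:   ", w)         # roughly [0.5, 0.5]
```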

Second Descent Regime

In the “second descent” regime, where the model trains until it fits the training set entirely, something else is going on, but I believe the underlying cause is the same. Early in training, when the model fits poorly, it settles into a set of good features in the lower layers. Later in training, once those low level features are learned, it has a very wide variety of models to choose from that all fit the training set entirely, but gradient descent does what it does and gives them all about equal weight. From its perspective, these equally performing models are collinear.

If you want to pick a function that will generalize well to new data points, a good strategy is to 1. define a class of good functions that fit the data well, and 2. average the predictions of every function in that class. Gradient descent does both of these things automatically. And as the model gets deeper and wider, the class of models becomes less biased by the model structure, and that average becomes more and more accurate.
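The closest linear analogue I know of (my analogy, not a claim about what deep nets literally do): with more parameters than data points there are infinitely many zero-loss weight vectors, and gradient descent started from zero quietly picks out the minimum-norm one, a single well-behaved representative of that whole class.

```python
# Overparameterized least squares: many weight vectors fit the data exactly,
# and gradient descent from zero converges to the minimum-norm interpolator.
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 200                      # far more parameters than data points
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)

# The minimum-norm zero-loss solution, via the pseudoinverse.
w_min_norm = np.linalg.pinv(X) @ y

# Plain gradient descent on squared loss, started at zero.
w = np.zeros(d)
lr = 1e-2
for _ in range(20_000):
    w -= lr * X.T @ (X @ w - y) / n

print("max training residual:", np.abs(X @ w - y).max())          # ~0: fits exactly
print("distance to min-norm w:", np.linalg.norm(w - w_min_norm))  # ~0: same solution
```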

Putting this all together.

  1. There’s a good set of params somewhere nearby.
  2. When we start walking to it, we can’t ever get stuck along the way, because there are no local optima.
  3. a. In the classical regime, once we’ve stumbled upon a good set of parameters, we’ll know it and we can just stop.
    b. In the double descent regime, once we’ve stumbled upon a good set of parameters for the lower layers, gradient descent will tend to find an average of all the good models in the higher layers.

That’s it. Once you believe all of these things, then it becomes intuitive that big models are always* better even without more data, because all of these trends become more and more true the more dimensions you have.

*Except in the transition region between classical and double descent, where neither 3a nor 3b applies as well.

Suggestions for Research

Stop talking about minima. Stop talking about how your optimization algorithm behaves around a minimum. Nobody ever trains their model remotely close to convergence. Adam doesn’t even provably converge. All real models diverge! You are nowhere close to a minimum! Stop talking about minima already goddamnit! Why even think about minima?! Minima are a myth! Even if you’ve hit zero loss on your training set, you can likely walk through zero loss models in an enormous chaotic basin all around you. Everybody proves their results for minima of a convex function. What really needs further research is how algorithms behave very far from minima. If you assume it’s easy to go downhill forever, which the huge number of degrees of freedom guarantees, then what you really need to establish is that the direction it picks generalizes well.

An optimization algorithm is best thought of as a priority queue of things to learn, and the thing that’s important to prove is that your algorithm learns the good things first. When your model peaks on the validation set, it is starting to learn more bad things than good things. When your model is reaching zero training loss in the second descent regime, it is building a random/average function out of the good things it learned early in training.

If you have an optimization algorithm that is better at learning the good things before the bad things, then it will achieve a lower loss. For me this way of thinking explains a lot of otherwise confusing phenomena, like why distillation works so well. Distillation causes the student model to learn more of the good things before it learns the bad things, because the teacher model knows more good things than bad things already. This is why it’s possible for a more powerful model to learn usefully from a weaker one.
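As a concrete illustration of that last point, here’s the standard distillation objective in miniature (a generic sketch, not code from any particular paper or library): the student is pushed toward the teacher’s softened output distribution, which front-loads the good things the teacher already knows.

```python
# Bare-bones distillation loss: cross-entropy between the teacher's softened
# (temperature > 1) output distribution and the student's.
import numpy as np

def softmax(logits, temperature=1.0):
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)     # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    teacher_probs = softmax(teacher_logits, temperature)
    student_log_probs = np.log(softmax(student_logits, temperature) + 1e-12)
    return -(teacher_probs * student_log_probs).sum(axis=-1).mean()

# Toy usage: two 4-class examples with made-up logits.
teacher = np.array([[4.0, 1.0, 0.5, -2.0], [0.2, 3.0, 0.1, 0.0]])
student = np.array([[2.0, 1.5, 0.0, -1.0], [0.0, 1.0, 0.5, 0.2]])
print("distillation loss:", distillation_loss(student, teacher))
```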
