12.3. Gradient Descent
In this section we are going to introduce the basic concepts underlying gradient descent. Although it is rarely used directly in deep learning, an understanding of gradient descent is key to understanding stochastic gradient descent algorithms. For instance, the optimization problem might diverge due to an overly large learning rate. This phenomenon can already be seen in gradient descent. Likewise, preconditioning is a common technique in gradient descent and carries over to more advanced algorithms. Let’s start with a simple special case.
12.3.1. One-Dimensional Gradient Descent
Gradient descent in one dimension is an excellent example to explain why the gradient descent algorithm may reduce the value of the objective function. Consider some continuously differentiable real-valued function $f: \mathbb{R} \rightarrow \mathbb{R}$. Using a Taylor expansion we obtain

$$f(x + \epsilon) = f(x) + \epsilon f'(x) + \mathcal{O}(\epsilon^2).$$

That is, in first-order approximation $f(x + \epsilon)$ is given by the function value $f(x)$ and the first derivative $f'(x)$ at $x$. It is not unreasonable to assume that for small $\epsilon$ moving in the direction of the negative gradient will decrease $f$. To keep things simple we pick a fixed step size $\eta > 0$ and choose $\epsilon = -\eta f'(x)$. Plugging this into the Taylor expansion above we get

$$f(x - \eta f'(x)) = f(x) - \eta f'^2(x) + \mathcal{O}(\eta^2 f'^2(x)).$$

If the derivative $f'(x) \neq 0$ does not vanish we make progress since $\eta f'^2(x) > 0$. Moreover, we can always choose $\eta$ small enough for the higher-order terms to become irrelevant. Hence we arrive at

$$f(x - \eta f'(x)) \lessapprox f(x).$$

This means that, if we use

$$x \leftarrow x - \eta f'(x)$$

to iterate $x$, the value of the function $f(x)$ might decline. Therefore, in gradient descent we first choose an initial value $x$ and a constant $\eta > 0$ and then use them to continuously iterate $x$ until a stopping condition is reached, e.g., when the magnitude of the gradient $|f'(x)|$ is small enough or the number of iterations has reached a certain value.

For simplicity we choose the objective function $f(x) = x^2$ to illustrate how to implement gradient descent. Although we know that $x = 0$ is the solution to minimize $f(x)$, we still use this simple function to observe how $x$ changes.

Next, we use $x = 10$ as the initial value and assume $\eta = 0.2$. Using gradient descent to iterate $x$ for 10 times we can see that, eventually, the value of $x$ approaches the optimal solution.
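The update itself fits in a few lines of code. Below is a minimal sketch in plain Python (the helper names f, f_grad, and gd are the ones used in the rest of this section; the exact notebook code may differ slightly). With $\eta = 0.2$ each step multiplies $x$ by $0.6$, so after 10 steps $x = 10 \cdot 0.6^{10} \approx 0.060466$, matching the printed output below.

def f(x):  # Objective function
    return x ** 2

def f_grad(x):  # Gradient (derivative) of the objective function
    return 2 * x

def gd(eta, f_grad):
    x = 10.0  # Initial value
    results = [x]
    for i in range(10):
        x -= eta * f_grad(x)  # Gradient descent update
        results.append(float(x))
    print(f'epoch 10, x: {x:f}')
    return results

results = gd(0.2, f_grad)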
epoch 10, x: 0.060466
The progress of optimizing over $x$ can be plotted as follows.
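A small plotting helper makes the trajectory visible. The sketch below uses matplotlib directly and assumes the results list and objective f from above; the book's own notebook relies on d2l plotting utilities instead.

import numpy as np
import matplotlib.pyplot as plt

def show_trace(results, f):
    """Show the trace of x during optimization."""
    n = max(abs(min(results)), abs(max(results)))
    f_line = np.arange(-n, n, 0.01)
    plt.plot(f_line, [f(x) for x in f_line])          # The objective
    plt.plot(results, [f(x) for x in results], '-o')  # The iterates
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.show()

show_trace(results, f)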
12.3.1.1. Learning Rate
The learning rate $\eta$ can be set by the algorithm designer. If we use a learning rate that is too small, it will cause $x$ to update very slowly, requiring more iterations to get a better solution. To show what happens in such a case, consider the progress in the same optimization problem for $\eta = 0.05$. As we can see, even after 10 steps we are still very far from the optimal solution.
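With the sketch from above this amounts to a single call; the printed value follows directly from $x_{k+1} = 0.9\,x_k$, i.e., $10 \cdot 0.9^{10} \approx 3.4868$.

show_trace(gd(0.05, f_grad), f)  # prints: epoch 10, x: 3.486784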
Conversely, if we use an excessively high learning rate, $\left|\eta f'(x)\right|$ might be too large for the first-order Taylor expansion formula. That is, the term $\mathcal{O}(\eta^2 f'^2(x))$ might become significant. In this case, we cannot guarantee that the iteration of $x$ will be able to lower the value of $f(x)$. For example, when we set the learning rate to $\eta = 1.1$, $x$ overshoots the optimal solution $x = 0$ and gradually diverges.
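Again a one-liner with the helpers above; here each step multiplies $x$ by $-1.2$, so $|x|$ grows to $10 \cdot 1.2^{10} \approx 61.917$.

show_trace(gd(1.1, f_grad), f)  # prints: epoch 10, x: 61.917364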
12.3.1.2. Local Minima
To illustrate what happens for nonconvex functions consider the case of $f(x) = x \cdot \cos(cx)$ for some constant $c$. This function has infinitely many local minima. Depending on our choice of the learning rate and depending on how well conditioned the problem is, we may end up with one of many solutions. The example below illustrates how an (unrealistically) high learning rate will lead to a poor local minimum.
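A sketch consistent with the printed result below, reusing the gd and show_trace helpers from above; the constant $c = 0.15\pi$ is the same one used for this function later in the section.

import math

c = 0.15 * math.pi

def f(x):  # Objective function
    return x * math.cos(c * x)

def f_grad(x):  # Gradient of the objective function
    return math.cos(c * x) - c * x * math.sin(c * x)

show_trace(gd(2, f_grad), f)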
epoch 10, x: -1.528166
12.3.2. Multivariate Gradient Descent
Now that we have a better intuition of the univariate case, let’s consider the situation where $\mathbf{x} = [x_1, x_2, \ldots, x_d]^\top$. That is, the objective function $f: \mathbb{R}^d \to \mathbb{R}$ maps vectors into scalars. Correspondingly its gradient is multivariate, too. It is a vector consisting of $d$ partial derivatives:

$$\nabla f(\mathbf{x}) = \left[\frac{\partial f(\mathbf{x})}{\partial x_1}, \frac{\partial f(\mathbf{x})}{\partial x_2}, \ldots, \frac{\partial f(\mathbf{x})}{\partial x_d}\right]^\top.$$

Each partial derivative element $\partial f(\mathbf{x})/\partial x_i$ in the gradient indicates the rate of change of $f$ at $\mathbf{x}$ with respect to the input $x_i$. As before in the univariate case we can use the corresponding Taylor approximation for multivariate functions to get some idea of what to do. In particular, we have that

$$f(\mathbf{x} + \boldsymbol{\epsilon}) = f(\mathbf{x}) + \boldsymbol{\epsilon}^\top \nabla f(\mathbf{x}) + \mathcal{O}(\|\boldsymbol{\epsilon}\|^2).$$

In other words, up to second-order terms in $\boldsymbol{\epsilon}$ the direction of steepest descent is given by the negative gradient $-\nabla f(\mathbf{x})$. Choosing a suitable learning rate $\eta > 0$ yields the prototypical gradient descent algorithm:

$$\mathbf{x} \leftarrow \mathbf{x} - \eta \nabla f(\mathbf{x}).$$

To see how the algorithm behaves in practice let’s construct an objective function $f(\mathbf{x}) = x_1^2 + 2 x_2^2$ with a two-dimensional vector $\mathbf{x} = [x_1, x_2]^\top$ as input and a scalar as output. The gradient is given by $\nabla f(\mathbf{x}) = [2 x_1, 4 x_2]^\top$. We will observe the trajectory of $\mathbf{x}$ obtained by gradient descent from the initial position $[-5, -2]$.

To begin with, we need two more helper functions. The first uses an update function and applies it 20 times to the initial value. The second helper visualizes the trajectory of $\mathbf{x}$.
%matplotlib inline
import numpy as np
import torch
from d2l import torch as d2l

def train_2d(trainer, steps=20, f_grad=None):  #@save
    """Optimize a 2D objective function with a customized trainer."""
    # `s1` and `s2` are internal state variables that will be used in
    # momentum, AdaGrad, and RMSProp later
    x1, x2, s1, s2 = -5, -2, 0, 0
    results = [(x1, x2)]
    for i in range(steps):
        if f_grad:
            x1, x2, s1, s2 = trainer(x1, x2, s1, s2, f_grad)
        else:
            x1, x2, s1, s2 = trainer(x1, x2, s1, s2)
        results.append((x1, x2))
    print(f'epoch {i + 1}, x1: {float(x1):f}, x2: {float(x2):f}')
    return results

def show_trace_2d(f, results):  #@save
    """Show the trace of 2D variables during optimization."""
    d2l.set_figsize()
    d2l.plt.plot(*zip(*results), '-o', color='#ff7f0e')
    x1, x2 = torch.meshgrid(torch.arange(-5.5, 1.0, 0.1),
                            torch.arange(-3.0, 1.0, 0.1), indexing='ij')
    d2l.plt.contour(x1, x2, f(x1, x2), colors='#1f77b4')
    d2l.plt.xlabel('x1')
    d2l.plt.ylabel('x2')
Next, we observe the trajectory of the optimization variable $\mathbf{x}$ for learning rate $\eta = 0.1$. We can see that after 20 steps the value of $\mathbf{x}$ approaches its minimum at $[0, 0]$. Progress is fairly well-behaved albeit rather slow.
def f_2d(x1, x2):  # Objective function
    return x1 ** 2 + 2 * x2 ** 2

def f_2d_grad(x1, x2):  # Gradient of the objective function
    return (2 * x1, 4 * x2)

def gd_2d(x1, x2, s1, s2, f_grad):
    g1, g2 = f_grad(x1, x2)
    return (x1 - eta * g1, x2 - eta * g2, 0, 0)

eta = 0.1
show_trace_2d(f_2d, train_2d(gd_2d, f_grad=f_2d_grad))
epoch 20, x1: -0.057646, x2: -0.000073
12.3.3. Adaptive Methods
As we could see in Section 12.3.1.1, getting the learning rate $\eta$ “just right” is tricky. If we pick it too small, we make little progress. If we pick it too large, the solution oscillates and in the worst case it might even diverge. What if we could determine $\eta$ automatically or get rid of having to select a learning rate at all? Second-order methods that look not only at the value and gradient of the objective function but also at its curvature can help in this case. While these methods cannot be applied to deep learning directly due to the computational cost, they provide useful intuition into how to design advanced optimization algorithms that mimic many of the desirable properties of the algorithms outlined below.
12.3.3.1. Newton’s Method
Reviewing the Taylor expansion of some function $f: \mathbb{R}^d \rightarrow \mathbb{R}$, there is no need to stop after the first term. In fact, we can write it as

$$f(\mathbf{x} + \boldsymbol{\epsilon}) = f(\mathbf{x}) + \boldsymbol{\epsilon}^\top \nabla f(\mathbf{x}) + \frac{1}{2} \boldsymbol{\epsilon}^\top \nabla^2 f(\mathbf{x}) \boldsymbol{\epsilon} + \mathcal{O}(\|\boldsymbol{\epsilon}\|^3).$$

To avoid cumbersome notation we define $\mathbf{H} \stackrel{\textrm{def}}{=} \nabla^2 f(\mathbf{x})$ to be the Hessian of $f$, which is a $d \times d$ matrix. For small $d$ and simple problems $\mathbf{H}$ is easy to compute. For deep neural networks, on the other hand, $\mathbf{H}$ may be prohibitively large, due to the cost of storing $\mathcal{O}(d^2)$ entries. Furthermore it may be too expensive to compute via backpropagation. For now let’s ignore such considerations and look at what algorithm we would get.

After all, the minimum of $f$ satisfies $\nabla f = 0$. Taking derivatives of the expansion above with regard to $\boldsymbol{\epsilon}$ and ignoring higher-order terms we arrive at

$$\nabla f(\mathbf{x}) + \mathbf{H} \boldsymbol{\epsilon} = 0 \textrm{ and hence } \boldsymbol{\epsilon} = -\mathbf{H}^{-1} \nabla f(\mathbf{x}).$$

That is, we need to invert the Hessian $\mathbf{H}$ as part of the optimization problem.

As a simple example, for $f(x) = \frac{1}{2} x^2$ we have $\nabla f(x) = x$ and $\mathbf{H} = 1$. Hence for any $x$ we obtain $\epsilon = -x$. In other words, a single step is sufficient to converge perfectly without the need for any adjustment! Alas, we got a bit lucky here: the Taylor expansion was exact since $f(x + \epsilon) = \frac{1}{2} x^2 + \epsilon x + \frac{1}{2} \epsilon^2$.

Let’s see what happens in other problems. Given a convex hyperbolic cosine function $f(x) = \cosh(cx)$ for some constant $c$, we can see that the global minimum at $x = 0$ is reached after a few iterations.
c = torch.tensor(0.5)

def f(x):  # Objective function
    return torch.cosh(c * x)

def f_grad(x):  # Gradient of the objective function
    return c * torch.sinh(c * x)

def f_hess(x):  # Hessian of the objective function
    return c**2 * torch.cosh(c * x)

def newton(eta=1):
    x = 10.0
    results = [x]
    for i in range(10):
        x -= eta * f_grad(x) / f_hess(x)  # Newton update
        results.append(float(x))
    print('epoch 10, x:', x)
    return results

show_trace(newton(), f)
epoch 10, x: tensor(0.)
Now let’s consider a nonconvex function, such as $f(x) = x \cos(cx)$ for some constant $c$. After all, note that in Newton’s method we end up dividing by the Hessian. This means that if the second derivative is negative we may walk into the direction of increasing the value of $f$. That is a fatal flaw of the algorithm. Let’s see what happens in practice.
c = torch.tensor(0.15 * np.pi)

def f(x):  # Objective function
    return x * torch.cos(c * x)

def f_grad(x):  # Gradient of the objective function
    return torch.cos(c * x) - c * x * torch.sin(c * x)

def f_hess(x):  # Hessian of the objective function
    return -2 * c * torch.sin(c * x) - x * c**2 * torch.cos(c * x)

show_trace(newton(), f)
epoch 10, x: tensor(26.8341)
This went spectacularly wrong. How can we fix it? One way would be to “fix” the Hessian by taking its absolute value instead. Another strategy is to bring back the learning rate. This seems to defeat the purpose, but not quite. Having second-order information allows us to be cautious whenever the curvature is large and to take longer steps whenever the objective function is flatter. Let’s see how this works with a slightly smaller learning rate, say $\eta = 0.5$. As we can see, the damped iteration converges to a rather decent outcome.
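With the newton helper defined above, damping is just a matter of passing the smaller step size.

show_trace(newton(0.5), f)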
12.3.3.2. Convergence Analysis
We only analyze the convergence rate of Newton’s method for some convex and three times differentiable objective function $f$, where the second derivative is nonzero, i.e., $f'' > 0$. The multivariate proof is a straightforward extension of the one-dimensional argument below and omitted since it does not help us much in terms of intuition.

Denote by $x^{(k)}$ the value of $x$ at the $k$-th iteration and let $e^{(k)} \stackrel{\textrm{def}}{=} x^{(k)} - x^*$ be the distance from optimality at the $k$-th iteration. By Taylor expansion we have that the condition $f'(x^*) = 0$ can be written as

$$0 = f'(x^{(k)} - e^{(k)}) = f'(x^{(k)}) - e^{(k)} f''(x^{(k)}) + \frac{1}{2} \left(e^{(k)}\right)^2 f'''(\xi^{(k)}),$$

which holds for some $\xi^{(k)} \in [x^{(k)} - e^{(k)}, x^{(k)}]$. Dividing the above expansion by $f''(x^{(k)})$ yields

$$e^{(k)} - \frac{f'(x^{(k)})}{f''(x^{(k)})} = \frac{1}{2} \left(e^{(k)}\right)^2 \frac{f'''(\xi^{(k)})}{f''(x^{(k)})}.$$

Recall that we have the update $x^{(k+1)} = x^{(k)} - f'(x^{(k)}) / f''(x^{(k)})$. Plugging in this update equation and taking the absolute value of both sides, we have

$$\left|e^{(k+1)}\right| = \frac{1}{2} \left(e^{(k)}\right)^2 \frac{\left|f'''(\xi^{(k)})\right|}{f''(x^{(k)})}.$$

Consequently, whenever we are in a region of bounded $\left|f'''(\xi^{(k)})\right| / \left(2 f''(x^{(k)})\right) \leq c$, we have a quadratically decreasing error $\left|e^{(k+1)}\right| \leq c \left(e^{(k)}\right)^2$.

As an aside, optimization researchers call this quadratic convergence, whereas a condition such as $\left|e^{(k+1)}\right| \leq \alpha \left|e^{(k)}\right|$ for some $\alpha < 1$ would be called a linear rate of convergence. Note that this analysis comes with a number of caveats. First, we do not really have much of a guarantee when we will reach the region of rapid convergence. Instead, we only know that once we reach it, convergence will be very fast. Second, this analysis requires that $f$ is well-behaved up to higher-order derivatives. It comes down to ensuring that $f$ does not have any “surprising” properties in terms of how it might change its values.
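The rapid decay of the error is easy to observe numerically. The snippet below (an illustration of ours, not the book’s code) re-instates the convex $\cosh$ example so that the newton helper from above minimizes it again, and prints the distance to the optimum $x^* = 0$ after each step; once the iterate is close to the optimum the error collapses at least quadratically.

c = torch.tensor(0.5)

def f_grad(x):  # Gradient of cosh(c * x)
    return c * torch.sinh(c * x)

def f_hess(x):  # Hessian of cosh(c * x)
    return c**2 * torch.cosh(c * x)

for k, x in enumerate(newton()):
    print(f'step {k}, error: {abs(float(x)):.3e}')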
12.3.3.3. Preconditioning
Quite unsurprisingly, computing and storing the full Hessian is very expensive. It is thus desirable to find alternatives. One way to improve matters is preconditioning. It avoids computing the Hessian in its entirety but only computes the diagonal entries. This leads to update algorithms of the form

$$\mathbf{x} \leftarrow \mathbf{x} - \eta \, \mathrm{diag}(\mathbf{H})^{-1} \nabla f(\mathbf{x}).$$
While this is not quite as good as the full Newton’s method, it is still much better than not using it. To see why this might be a good idea consider a situation where one variable denotes height in millimeters and the other one denotes height in kilometers. Assuming that for both the natural scale is in meters, we have a terrible mismatch in parametrizations. Fortunately, using preconditioning removes this. Effectively preconditioning with gradient descent amounts to selecting a different learning rate for each variable (coordinate of the vector $\mathbf{x}$), as the sketch below illustrates.
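As an illustration only (not the book’s code), here is a hypothetical diagonally preconditioned update for the quadratic $f(\mathbf{x}) = x_1^2 + 2x_2^2$ from above, whose Hessian diagonal is constant at $(2, 4)$; it reuses the train_2d, show_trace_2d, f_2d, and f_2d_grad helpers defined earlier.

def precond_gd_2d(x1, x2, s1, s2, f_grad):
    g1, g2 = f_grad(x1, x2)
    h1, h2 = 2.0, 4.0  # Diagonal entries of the Hessian of f_2d
    # Each coordinate is rescaled by its own curvature before the step
    return (x1 - eta * g1 / h1, x2 - eta * g2 / h2, 0, 0)

eta = 0.5
show_trace_2d(f_2d, train_2d(precond_gd_2d, f_grad=f_2d_grad))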
12.3.3.4. Gradient Descent with Line Search
One of the key problems in gradient descent is that we might overshoot the goal or make insufficient progress. A simple fix for the problem is to use line search in conjunction with gradient descent. That is, we use the direction given by $\nabla f(\mathbf{x})$ and then perform binary search as to which learning rate $\eta$ minimizes $f(\mathbf{x} - \eta \nabla f(\mathbf{x}))$. A sketch of this idea follows below.
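A possible rendering of line search for the 2D quadratic above (our own sketch, not the book’s code): each outer step brackets the best step size by ternary search, assuming $f(\mathbf{x} - \eta \nabla f(\mathbf{x}))$ is unimodal in $\eta$ on the search interval.

def line_search_gd(x1, x2, f, f_grad, steps=10, eta_max=2.0):
    results = [(x1, x2)]
    for _ in range(steps):
        g1, g2 = f_grad(x1, x2)
        lo, hi = 0.0, eta_max
        for _ in range(30):  # Shrink the bracket around the best step size
            m1, m2 = lo + (hi - lo) / 3, hi - (hi - lo) / 3
            if f(x1 - m1 * g1, x2 - m1 * g2) < f(x1 - m2 * g1, x2 - m2 * g2):
                hi = m2
            else:
                lo = m1
        eta = (lo + hi) / 2  # Near-optimal step size along -grad
        x1, x2 = x1 - eta * g1, x2 - eta * g2
        results.append((x1, x2))
    return results

show_trace_2d(f_2d, line_search_gd(-5.0, -2.0, f_2d, f_2d_grad))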
This algorithm converges rapidly (for an analysis and proof see e.g., Boyd and Vandenberghe (2004)). However, for the purpose of deep learning this is not quite so feasible, since each step of the line search would require us to evaluate the objective function on the entire dataset. This is way too costly to accomplish.
12.3.4. Summary
Learning rates matter. Too large and we diverge, too small and we do not make progress.
Gradient descent can get stuck in local minima.
In high dimensions adjusting the learning rate is complicated.
Preconditioning can help with scale adjustment.
In convex problems, Newton’s method is a lot faster once it has started working properly.
Beware of using Newton’s method without any adjustments for nonconvex problems.
12.3.5. Exercises
1. Experiment with different learning rates and objective functions for gradient descent.
2. Implement line search to minimize a convex function in the interval $[a, b]$.
   1. Do you need derivatives for binary search, i.e., to decide whether to pick $[a, (a+b)/2]$ or $[(a+b)/2, b]$?
   2. How rapid is the rate of convergence for the algorithm?
   3. Implement the algorithm and apply it to minimizing $\log(\exp(x) + \exp(-2x - 3))$.
3. Design an objective function defined on $\mathbb{R}^2$ where gradient descent is exceedingly slow. Hint: scale different coordinates differently.
4. Implement the lightweight version of Newton’s method using preconditioning:
   1. Use the diagonal Hessian as preconditioner.
   2. Use its absolute values rather than the actual (possibly signed) values.
   3. Apply this to the problem above.
5. Apply the algorithm above to a number of objective functions (convex or not). What happens if you rotate coordinates by $45$ degrees?