.. _sec_bptt:

Backpropagation Through Time
============================
So far we have repeatedly alluded to things like *exploding gradients*,
*vanishing gradients*, *truncating backprop*, and the need to *detach
the computational graph*. For instance, in the previous section we
invoked ``s.detach()`` on the sequence. None of this was fully
explained, in the interest of being able to build a model quickly and to
see how it works. In this section we will delve a bit more deeply into
the details of backpropagation for sequence models and why (and how) the
math works. For a more detailed discussion about randomization and
backpropagation, see also the paper by :cite:`Tallec.Ollivier.2017`.

We encountered some of the effects of gradient explosion when we first
implemented recurrent neural networks (:numref:`sec_rnn_scratch`). In
particular, if you solved the exercises in that section, you would have
seen that gradient clipping is vital to ensure proper convergence. To
provide a better understanding of this issue, this section reviews how
gradients are computed for sequence models. Note that there is nothing
conceptually new here. After all, we are still merely applying the chain
rule to compute gradients. Nonetheless, it is worthwhile to review
backpropagation (:numref:`sec_backprop`) again.

Forward propagation in a recurrent neural network is relatively
straightforward. *Backpropagation through time* (BPTT) is simply
backpropagation applied to recurrent neural networks. It requires us to
unroll the recurrent neural network one timestep at a time to obtain the
dependencies between model variables and parameters. Then, based on the
chain rule, we apply backpropagation to compute and store gradients.
Since sequences can be rather long, this chain of dependencies can be
very long as well. For instance, for a sequence of 1000 characters, the
first symbol could potentially have a significant influence on the
symbol at position 1000. Computing this influence exactly is not really
feasible (it takes too long and requires too much memory): it requires
over 1000 matrix-vector products before we would arrive at that very
elusive gradient. This is a process fraught with computational and
statistical uncertainty. In the following we will elucidate what
happens and how to address this in practice.

A Simplified Recurrent Network
------------------------------
We start with a simplified model of how an RNN works. This model ignores
details about the specifics of the hidden state and how it is updated.
These details are immaterial to the analysis and would only clutter the
notation and make it look more intimidating. In this simplified model,
we denote :math:`h_t` as the hidden state, :math:`x_t` as the input,
and :math:`o_t` as the output at timestep :math:`t`. In addition,
:math:`w_h` and :math:`w_o` denote the weights of the hidden state and
of the output layer, respectively. As a result, the hidden state and the
output at each timestep can be written as

.. math:: h_t = f(x_t, h_{t-1}, w_h) \text{ and } o_t = g(h_t, w_o).

Hence, we have a chain of values
:math:`\{\ldots, (h_{t-1}, x_{t-1}, o_{t-1}), (h_{t}, x_{t}, o_t), \ldots\}`
that depend on each other via recursive computation. The forward pass is
fairly straightforward. All we need is to loop through the
:math:`(x_t, h_t, o_t)` triples one step at a time. The discrepancy
between the outputs :math:`o_t` and the desired targets :math:`y_t` is
then evaluated by an objective function as

.. math:: L(x, y, w_h, w_o) = \sum_{t=1}^T l(y_t, o_t).
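To make the forward pass concrete, here is a minimal plain-Python sketch. The concrete choices are assumptions made only for illustration: scalar values, :math:`f(x, h, w_h) = \tanh(x + w_h h)`, :math:`g(h, w_o) = w_o h`, the squared loss :math:`l(y, o) = \frac{1}{2}(y - o)^2`, and :math:`h_0 = 0`. The function name ``forward`` and the toy data are hypothetical.

```python
import math

# Illustrative assumptions (not fixed by the text):
#   f(x, h, w_h) = tanh(x + w_h * h),  g(h, w_o) = w_o * h,
#   l(y, o) = 0.5 * (y - o) ** 2,      h_0 = 0.
def forward(xs, ys, w_h, w_o, h0=0.0):
    """Loop over the timesteps; return hidden states, outputs, and the loss L."""
    hs, outputs, loss = [h0], [], 0.0
    for x, y in zip(xs, ys):
        h = math.tanh(x + w_h * hs[-1])  # h_t = f(x_t, h_{t-1}, w_h)
        o = w_o * h                      # o_t = g(h_t, w_o)
        loss += 0.5 * (y - o) ** 2       # add l(y_t, o_t) to L
        hs.append(h)
        outputs.append(o)
    return hs, outputs, loss

xs = [0.5, -1.0, 0.25, 0.8]
ys = [0.1, -0.3, 0.2, 0.4]
hs, outputs, loss = forward(xs, ys, w_h=0.7, w_o=1.3)
print(len(hs), len(outputs))  # 5 4 (hs also stores h_0)
print(loss)
```

Each iteration consumes one :math:`(x_t, y_t)` pair and produces one :math:`(h_t, o_t)` pair, so the forward pass costs time linear in :math:`T`.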
For backpropagation, matters are a bit more tricky, especially when we
compute the gradients of the objective function :math:`L` with respect
to the parameters :math:`w_h`. To be specific, by the chain rule,

.. math::

   \begin{aligned}
   \partial_{w_h} L & = \sum_{t=1}^T \partial_{w_h} l(y_t, o_t) \\
   & = \sum_{t=1}^T \partial_{o_t} l(y_t, o_t) \partial_{h_t} g(h_t, w_o) \left[ \partial_{w_h} h_t\right].
   \end{aligned}

The first and the second parts of the derivative are easy to compute.
The third part :math:`\partial_{w_h} h_t` is where things get tricky,
since we need to compute the effect of the parameters on :math:`h_t`
through the whole recurrence.

To derive the above gradient, assume that we have three sequences
:math:`\{a_{t}\},\{b_{t}\},\{c_{t}\}` satisfying :math:`a_{0}=0` and
:math:`a_{t}=b_{t}+c_{t}a_{t-1}` for :math:`t=1, 2,\ldots` (so that
:math:`a_{1}=b_{1}`). Then for :math:`t\geq 1`, it is easy to show by
induction that

.. math:: a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.
   :label: eq_bptt_at
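The closed form is easy to sanity-check numerically. The sketch below draws arbitrary sequences :math:`\{b_t\}` and :math:`\{c_t\}`, unrolls the recursion, and compares each :math:`a_t` with the sum of products. The helper names ``unroll`` and ``closed_form`` are hypothetical.

```python
import math
import random

def unroll(b, c):
    """Apply a_t = b_t + c_t * a_{t-1} with a_0 = 0; b[0] and c[0] hold b_1 and c_1."""
    a, out = 0.0, []
    for b_t, c_t in zip(b, c):
        a = b_t + c_t * a
        out.append(a)
    return out

def closed_form(b, c, t):
    """Evaluate b_t + sum_{i=1}^{t-1} (prod_{j=i+1}^{t} c_j) b_i for 1-based t."""
    total = b[t - 1]
    for i in range(1, t):
        prod = 1.0
        for j in range(i + 1, t + 1):
            prod *= c[j - 1]
        total += prod * b[i - 1]
    return total

random.seed(0)
b = [random.uniform(-1, 1) for _ in range(8)]
c = [random.uniform(-1, 1) for _ in range(8)]
a = unroll(b, c)
print(all(math.isclose(a[t - 1], closed_form(b, c, t), abs_tol=1e-12)
          for t in range(1, 9)))  # True
```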
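Now let’s apply :eq:`eq_bptt_at` with

.. math::

   \begin{aligned}
   a_t &= \partial_{w_h}h_{t},\\
   b_t &= \partial_{w_h}f(x_{t},h_{t-1},w_h),\\
   c_t &= \partial_{h_{t-1}}f(x_{t},h_{t-1},w_h),
   \end{aligned}

where :math:`b_t` is the partial derivative of :math:`f` with respect to
its third argument while :math:`h_{t-1}` is held fixed. Assuming that
the initial state :math:`h_0` does not depend on :math:`w_h`, we have
:math:`a_0 = 0`. Since :math:`h_t` depends on :math:`w_h` both directly
and through :math:`h_{t-1}`, the chain rule gives exactly the form
:math:`a_{t}=b_{t}+c_{t}a_{t-1}`, namely the recursion

.. math::

   \partial_{w_h}h_{t}=\partial_{w_h}f(x_{t},h_{t-1},w_h)+\partial_{h_{t-1}}f(x_{t},h_{t-1},w_h)\partial_{w_h}h_{t-1}.

By :eq:`eq_bptt_at`, the third part is therefore

.. math::

   \partial_{w_h}h_{t}=\partial_{w_h}f(x_{t},h_{t-1},w_h)+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}\partial_{h_{j-1}}f(x_{j},h_{j-1},w_h)\right)\partial_{w_h}f(x_{i},h_{i-1},w_h).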
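The recursion can be checked against a finite difference. Assume, for illustration only, the scalar recurrence :math:`f(x, h, w_h) = \tanh(x + w_h h)` with :math:`h_0 = 0`. Then :math:`b_t = (1 - h_t^2) h_{t-1}` and :math:`c_t = (1 - h_t^2) w_h`, so :math:`\partial_{w_h} h_t` can be accumulated alongside the forward pass, in the style of forward-mode differentiation. The helper names ``states`` and ``dh_dw`` are hypothetical.

```python
import math

def states(xs, w_h, h0=0.0):
    """Forward pass of h_t = tanh(x_t + w_h * h_{t-1}); returns [h_1, ..., h_T]."""
    h, out = h0, []
    for x in xs:
        h = math.tanh(x + w_h * h)
        out.append(h)
    return out

def dh_dw(xs, w_h, h0=0.0):
    """Accumulate a_t = b_t + c_t * a_{t-1} during the forward pass; return a_T."""
    h_prev, a = h0, 0.0                # a_0 = 0 because h_0 does not depend on w_h
    for x in xs:
        h = math.tanh(x + w_h * h_prev)
        b = (1 - h * h) * h_prev       # b_t: derivative w.r.t. w_h, h_{t-1} held fixed
        c = (1 - h * h) * w_h          # c_t: derivative w.r.t. h_{t-1}
        a = b + c * a
        h_prev = h
    return a

xs = [0.5, -1.0, 0.25, 0.8, -0.3, 0.9]
w_h, eps = 0.7, 1e-6
analytic = dh_dw(xs, w_h)
numeric = (states(xs, w_h + eps)[-1] - states(xs, w_h - eps)[-1]) / (2 * eps)
print(analytic, numeric)  # the two values should agree to several digits
```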
While we can use the chain rule to compute :math:`\partial_{w_h} h_t`
recursively, this chain can get very long whenever :math:`t` is large.
Let’s discuss a number of strategies for dealing with this problem.

- **Compute the full sum.** This is very slow and gradients can blow
  up, since subtle changes in the initial conditions can potentially
  affect the outcome a lot. That is, we could see things similar to the
  butterfly effect, where minimal changes in the initial conditions lead
  to disproportionate changes in the outcome. This is quite undesirable
  in terms of the model that we want to estimate. After all, we are
  looking for robust estimators that generalize well. Hence this
  strategy is almost never used in practice.

- **Truncate the sum after** :math:`\tau` **steps.** This is what we
  have been discussing so far. It leads to an *approximation* of the
  true gradient, simply by terminating the sum above at
  :math:`\partial_{w_h} h_{t-\tau}`, i.e., by treating
  :math:`h_{t-\tau}` as a constant. The approximation error is the
  dropped remainder
  :math:`\left(\prod_{j=t-\tau+1}^{t}\partial_{h_{j-1}}f(x_{j},h_{j-1},w_h)\right)\partial_{w_h}h_{t-\tau}`.
  In practice this works quite well. It is what is commonly referred to
  as truncated BPTT (backpropagation through time). One consequence is
  that the model focuses primarily on short-term influence rather than
  long-term consequences. This is actually *desirable*, since it biases
  the estimate towards simpler and more stable models.

- **Randomized truncation.** Lastly, we can replace
  :math:`\partial_{w_h} h_t` by a random variable which is correct in
  expectation but which truncates the sequence. This is achieved by
  using a sequence of random variables :math:`\xi_t` with
  :math:`P(\xi_t = 0) = 1-\pi` and :math:`P(\xi_t = \pi^{-1}) = \pi`
  for some :math:`0 < \pi \leq 1`, so that :math:`E[\xi_t] = 1`. We use
  this to replace the gradient:

  .. math:: z_t = \partial_{w_h} f(x_t, h_{t-1}, w_h) + \xi_t \partial_{h_{t-1}} f(x_t, h_{t-1}, w_h) \partial_{w_h} h_{t-1}.

  It follows from the definition of :math:`\xi_t` that
  :math:`E[z_t] = \partial_{w_h} h_t`. Whenever :math:`\xi_t = 0`, the
  expansion terminates at that point. This leads to a weighted sum of
  sequences of varying lengths where long sequences are rare but
  appropriately overweighted. This scheme was proposed by
  :cite:`Tallec.Ollivier.2017`.

Unfortunately, while appealing in theory, randomized truncation does not
work much better than simple truncation, most likely due to a number of
factors. First, a modest number of backpropagation steps into the past
is usually sufficient to capture the dependencies that matter in
practice. Second, the increased variance offsets the gain from having a
gradient that is correct in expectation. Third, we actually *want*
models that have only a short range of interaction. Hence, truncated
BPTT has a slight regularizing effect, which can be desirable.
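The three strategies can be compared on a toy problem. The sketch below assumes, for illustration only, the scalar recurrence :math:`h_t = \tanh(x_t + w_h h_{t-1})` with :math:`h_0 = 0`, so that :math:`b_t = (1 - h_t^2) h_{t-1}` and :math:`c_t = (1 - h_t^2) w_h`. It applies the randomized estimator recursively, i.e., each step reuses the previous estimate in place of :math:`\partial_{w_h} h_{t-1}`, which is our reading of how it is used in practice. Because the sequence is short, the expectation over :math:`\xi_1, \ldots, \xi_T` is computed exactly by enumerating all :math:`2^T` patterns rather than by sampling. All function names are hypothetical.

```python
import itertools
import math

def coefficients(xs, w_h, h0=0.0):
    """b_t and c_t for the toy recurrence h_t = tanh(x_t + w_h * h_{t-1})."""
    h_prev, bs, cs = h0, [], []
    for x in xs:
        h = math.tanh(x + w_h * h_prev)
        bs.append((1 - h * h) * h_prev)  # partial of f w.r.t. w_h
        cs.append((1 - h * h) * w_h)     # partial of f w.r.t. h_{t-1}
        h_prev = h
    return bs, cs

def estimate(bs, cs, scales):
    """Run a_t = b_t + scale_t * c_t * a_{t-1} from a_0 = 0 and return a_T.

    All scales equal to 1 give the full gradient, a single 0 cuts the chain
    (truncation), and random xi_t values give the randomized estimator.
    """
    a = 0.0
    for b, c, s in zip(bs, cs, scales):
        a = b + s * c * a
    return a

xs = [0.5, -1.0, 0.25, 0.8, -0.3, 0.9]
T, tau, pi = len(xs), 2, 0.5
bs, cs = coefficients(xs, w_h=0.7)

full = estimate(bs, cs, [1.0] * T)

# Truncated BPTT: a zero scale at step T - tau + 1 treats h_{T - tau} as a constant.
truncated = estimate(bs, cs, [1.0] * (T - tau) + [0.0] + [1.0] * (tau - 1))

# Randomized truncation: xi_t = 0 with prob. 1 - pi and 1 / pi with prob. pi.
# Enumerate all 2**T patterns to get the exact expectation instead of sampling.
expected = 0.0
for xi in itertools.product([0.0, 1.0 / pi], repeat=T):
    prob = math.prod(pi if s else 1 - pi for s in xi)
    expected += prob * estimate(bs, cs, xi)

print(full, truncated, expected)  # expected matches full up to rounding
```

For :math:`\tau = 2` the gap ``full - truncated`` equals the dropped remainder :math:`c_{T-1} c_T a_{T-2}`, and ``expected`` agrees with ``full``, while any single draw of :math:`\xi` can be far from it; that variance is the price paid for unbiasedness.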
.. _fig_truncated_bptt:

.. figure:: ../img/truncated-bptt.svg

   From top to bottom: randomized BPTT, regularly truncated BPTT and
   full BPTT

:numref:`fig_truncated_bptt` illustrates the three cases when
analyzing the first few words of *The Time Machine*:

- The first row is the randomized truncation which partitions the text
  into segments of varying length.
- The second row is the regular truncated BPTT which breaks it into
  sequences of the same length.
- The third row is the full BPTT that leads to a computationally
  infeasible expression.

The Computational Graph
-----------------------
In order to visualize the dependencies between model variables and
parameters during computation in a recurrent neural network, we can draw
a computational graph for the model, as shown in
:numref:`fig_rnn_bptt`. For example, the computation of the hidden
state at timestep 3, :math:`\mathbf{h}_3`, depends on the model
parameters :math:`\mathbf{W}_{hx}` and :math:`\mathbf{W}_{hh}`, the
hidden state of the previous timestep :math:`\mathbf{h}_2`, and the
input of the current timestep :math:`\mathbf{x}_3`.

.. _fig_rnn_bptt:

.. figure:: ../img/rnn-bptt.svg

   Computational dependencies for a recurrent neural network model with
   three timesteps. Boxes represent variables (not shaded) or parameters
   (shaded) and circles represent operators.

BPTT in Detail
--------------
After discussing the general principle, let’s discuss BPTT in detail. By
decomposing :math:`\mathbf{W}` into different sets of weight matrices
(:math:`\mathbf{W}_{hx}, \mathbf{W}_{hh}` and :math:`\mathbf{W}_{oh}`)
and omitting activation functions, we obtain a simple linear latent
variable model:

.. math::

   \mathbf{h}_t = \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1} \text{ and }
   \mathbf{o}_t = \mathbf{W}_{oh} \mathbf{h}_t.

Following the discussion in :numref:`sec_backprop`, we compute the
gradients :math:`\frac{\partial L}{\partial \mathbf{W}_{hx}}`,
:math:`\frac{\partial L}{\partial \mathbf{W}_{hh}}`, and
:math:`\frac{\partial L}{\partial \mathbf{W}_{oh}}` for

.. math:: L(\mathbf{x}, \mathbf{y}, \mathbf{W}) = \sum_{t=1}^T l(\mathbf{o}_t, y_t),

where :math:`l(\cdot)` denotes the chosen loss function. Taking the
derivative with respect to :math:`\mathbf{W}_{oh}` is fairly
straightforward and we obtain

.. math::

   \partial_{\mathbf{W}_{oh}} L = \sum_{t=1}^T \mathrm{prod}
   \left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{h}_t\right),

where :math:`\mathrm{prod} (\cdot)` indicates the product of two or more
matrices, carried out after any transpositions or swaps of the argument
order needed to make the shapes match.
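As a sketch of this formula, assume (for illustration only) the squared loss :math:`l(\mathbf{o}_t, \mathbf{y}_t) = \frac{1}{2}\|\mathbf{o}_t - \mathbf{y}_t\|^2`, so that :math:`\partial_{\mathbf{o}_t} l = \mathbf{o}_t - \mathbf{y}_t`, and an initial state :math:`\mathbf{h}_0 = \mathbf{0}`. Here :math:`\mathrm{prod}` becomes the outer product :math:`(\mathbf{o}_t - \mathbf{y}_t)\mathbf{h}_t^\top`. The plain-Python code below, with hypothetical helper names and random toy data, accumulates these outer products and spot-checks one entry against a central finite difference.

```python
import random

def matvec(M, v):
    """Matrix-vector product for lists of lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def hidden_states(W_hx, W_hh, xs):
    """Linear model h_t = W_hx x_t + W_hh h_{t-1} with h_0 = 0."""
    h, hs = [0.0] * len(W_hh), []
    for x in xs:
        h = [p + q for p, q in zip(matvec(W_hx, x), matvec(W_hh, h))]
        hs.append(h)
    return hs

def loss(W_oh, hs, ys):
    """L = sum_t 0.5 * ||W_oh h_t - y_t||^2 (squared loss assumed)."""
    return sum(0.5 * sum((o - y) ** 2 for o, y in zip(matvec(W_oh, h), y_t))
               for h, y_t in zip(hs, ys))

def grad_W_oh(W_oh, hs, ys):
    """Accumulate the outer products (o_t - y_t) h_t^T over all timesteps."""
    G = [[0.0] * len(hs[0]) for _ in W_oh]
    for h, y_t in zip(hs, ys):
        err = [o - y for o, y in zip(matvec(W_oh, h), y_t)]  # dl/do_t
        for i, e in enumerate(err):
            for j, h_j in enumerate(h):
                G[i][j] += e * h_j
    return G

def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

def perturbed(W, i, j, delta):
    """Copy of W with entry (i, j) shifted by delta."""
    W2 = [row[:] for row in W]
    W2[i][j] += delta
    return W2

random.seed(1)
d_in, d_h, d_out, T = 2, 3, 2, 4
W_hx, W_hh, W_oh = rand_matrix(d_h, d_in), rand_matrix(d_h, d_h), rand_matrix(d_out, d_h)
xs, ys = rand_matrix(T, d_in), rand_matrix(T, d_out)
hs = hidden_states(W_hx, W_hh, xs)
G = grad_W_oh(W_oh, hs, ys)

# Spot-check one entry against a central finite difference.
i, j, eps = 1, 2, 1e-6
numeric = (loss(perturbed(W_oh, i, j, eps), hs, ys)
           - loss(perturbed(W_oh, i, j, -eps), hs, ys)) / (2 * eps)
print(G[i][j], numeric)
```

Since :math:`\mathbf{W}_{oh}` only enters the output layer, no chain through earlier timesteps is needed here; the hidden states are computed once and reused.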
The dependency on :math:`\mathbf{W}_{hx}` and :math:`\mathbf{W}_{hh}` is
a bit more tricky since it involves a chain of derivatives. We begin
with

.. math::

   \begin{aligned}
   \partial_{\mathbf{W}_{hh}} L & = \sum_{t=1}^T \mathrm{prod}
   \left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{W}_{oh}, \partial_{\mathbf{W}_{hh}} \mathbf{h}_t\right), \\
   \partial_{\mathbf{W}_{hx}} L & = \sum_{t=1}^T \mathrm{prod}
   \left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{W}_{oh}, \partial_{\mathbf{W}_{hx}} \mathbf{h}_t\right).
   \end{aligned}

After all, hidden states depend on each other and on past inputs. The
key quantity is how past hidden states affect future hidden states:

.. math::

   \partial_{\mathbf{h}_t} \mathbf{h}_{t+1} = \mathbf{W}_{hh}^\top
   \text{ and thus }
   \partial_{\mathbf{h}_t} \mathbf{h}_T = \left(\mathbf{W}_{hh}^\top\right)^{T-t}.

Chaining terms together, and noting that :math:`\mathbf{h}_j` depends
directly on :math:`\mathbf{W}_{hh}` through :math:`\mathbf{h}_{j-1}` and
on :math:`\mathbf{W}_{hx}` through :math:`\mathbf{x}_j`, yields

.. math::

   \begin{aligned}
   \partial_{\mathbf{W}_{hh}} \mathbf{h}_t & = \sum_{j=1}^t \left(\mathbf{W}_{hh}^\top\right)^{t-j} \mathbf{h}_{j-1} \\
   \partial_{\mathbf{W}_{hx}} \mathbf{h}_t & = \sum_{j=1}^t \left(\mathbf{W}_{hh}^\top\right)^{t-j} \mathbf{x}_j.
   \end{aligned}
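In the scalar case each weight matrix is a single number, so the transposes disappear and the expression for :math:`\partial_{\mathbf{W}_{hh}} \mathbf{h}_t` reduces to :math:`\sum_{j=1}^t w_{hh}^{t-j} h_{j-1}`. The sketch below, with toy values, hypothetical helper names, and :math:`h_0 = 0` assumed, compares this sum with a central finite difference.

```python
def linear_states(xs, w_hx, w_hh, h0=0.0):
    """Scalar linear RNN h_t = w_hx * x_t + w_hh * h_{t-1}; returns [h_0, ..., h_T]."""
    hs = [h0]
    for x in xs:
        hs.append(w_hx * x + w_hh * hs[-1])
    return hs

def dh_dw_hh(xs, w_hx, w_hh, t):
    """Closed form: sum_{j=1}^{t} w_hh^(t-j) * h_{j-1}."""
    hs = linear_states(xs, w_hx, w_hh)
    return sum(w_hh ** (t - j) * hs[j - 1] for j in range(1, t + 1))

xs = [1.0, -0.5, 2.0, 0.3, -1.2]
w_hx, w_hh, eps, t = 0.8, 0.9, 1e-6, 5
closed = dh_dw_hh(xs, w_hx, w_hh, t)
numeric = (linear_states(xs, w_hx, w_hh + eps)[t]
           - linear_states(xs, w_hx, w_hh - eps)[t]) / (2 * eps)
print(closed, numeric)  # both close to 4.4808
```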
A number of things follow from this potentially very intimidating
expression. First, it pays to store intermediate results, i.e., powers
of :math:`\mathbf{W}_{hh}`, as we work our way through the terms of the
loss function :math:`L`. Second, this simple linear example already
exhibits some key problems of long sequence models: it involves
potentially very large powers :math:`\mathbf{W}_{hh}^j`. In these
powers, eigenvalues with magnitude smaller than :math:`1` vanish for
large :math:`j`, and eigenvalues with magnitude larger than :math:`1`
diverge. This is numerically unstable and gives undue importance to
potentially irrelevant past detail. One way to address this is to
truncate the sum at a computationally convenient size. Later on in
:numref:`chap_modern_rnn` we will see how more sophisticated sequence
models such as LSTMs can alleviate this further. In practice, this
truncation is effected by *detaching* the gradient after a given number
of steps.
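The effect of eigenvalue magnitudes can be seen in a small experiment. The sketch below assumes, for illustration, a :math:`2 \times 2` matrix of the form :math:`s\mathbf{R}`, where :math:`\mathbf{R}` is a rotation. Its eigenvalues are complex with magnitude :math:`s`, and because rotations preserve length, :math:`\|(s\mathbf{R})^j \mathbf{v}\| = s^j \|\mathbf{v}\|` exactly. Repeatedly multiplying a vector by such a matrix, as a backward pass through :math:`j` timesteps of the linear model does, makes its norm shrink geometrically for :math:`s < 1` and blow up for :math:`s > 1`. The helper names are hypothetical.

```python
import math

def scaled_rotation(s, theta):
    """s * R(theta): both (complex) eigenvalues have magnitude s."""
    c, d = math.cos(theta), math.sin(theta)
    return [[s * c, -s * d], [s * d, s * c]]

def norm_after_powers(M, v, j):
    """Return ||M^j v|| by applying M to v a total of j times."""
    for _ in range(j):
        v = [M[0][0] * v[0] + M[0][1] * v[1], M[1][0] * v[0] + M[1][1] * v[1]]
    return math.hypot(v[0], v[1])

v = [1.0, 0.0]
for s in (0.9, 1.0, 1.1):
    M = scaled_rotation(s, theta=0.3)
    print(s, [round(norm_after_powers(M, v, j), 4) for j in (1, 10, 50)])
```

The printed norms track :math:`0.9^j`, :math:`1`, and :math:`1.1^j`; for instance, :math:`1.1^{50}` is roughly 117, while :math:`0.9^{50}` is about 0.005. Truncating after a fixed number of steps caps the exponent and thereby bounds both effects.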
Summary
-------
- Backpropagation through time is merely an application of backprop to
  sequence models with a hidden state.
- Truncation is needed for computational convenience and numerical
  stability.
- High powers of matrices can lead to divergent or vanishing
  eigenvalues. This manifests itself in the form of exploding or
  vanishing gradients.
- For efficient computation, intermediate values are cached.

Exercises
---------

1. Assume that we have a symmetric matrix
   :math:`\mathbf{M} \in \mathbb{R}^{n \times n}` with eigenvalues
   :math:`\lambda_i`. Without loss of generality, assume that they are
   sorted in ascending order, :math:`\lambda_i \leq \lambda_{i+1}`. Show
   that :math:`\mathbf{M}^k` has eigenvalues :math:`\lambda_i^k`.
2. Prove that for a random vector :math:`\mathbf{x} \in \mathbb{R}^n`,
   with high probability :math:`\mathbf{M}^k \mathbf{x}` will be very
   much aligned with the eigenvector :math:`\mathbf{v}_n` of
   :math:`\mathbf{M}` associated with the largest eigenvalue. Formalize
   this statement.
3. What does the above result mean for gradients in a recurrent neural
   network?
4. Besides gradient clipping, can you think of any other methods to cope
   with gradient explosion in recurrent neural networks?