Beam Search
===========
In :numref:`sec_seq2seq`, we discussed how to train an encoder-decoder
with input and output sequences that are both of variable length. In
this section, we are going to introduce how to use the encoder-decoder
to predict sequences of variable length.
As in :numref:`sec_machine_translation`, when preparing the training
dataset, we normally append a special symbol “<eos>” to each sentence
to indicate the termination of the sequence. We will continue to use
this symbol in the discussion below. For ease of
discussion, we assume that the output of the decoder is a sequence of
text. Let the size of the output text dictionary :math:`\mathcal{Y}`
(which contains the special symbol “<eos>”) be :math:`\left|\mathcal{Y}\right|`,
and the maximum length of the output sequence be :math:`T'`. There are a
total of :math:`\mathcal{O}(\left|\mathcal{Y}\right|^{T'})`
possible output sequences. Any subsequence after the special symbol
“<eos>” in these output sequences is discarded. In addition, we still
denote the context vector as :math:`\mathbf{c}`, which encodes
information from all the hidden states of the input.
Greedy Search
-------------
First, we will take a look at a simple solution: greedy search. For any
time step :math:`t'` of the output sequence, we search for
the word with the highest conditional probability among the
:math:`|\mathcal{Y}|` words in the dictionary, with
.. math:: y_{t'} = \operatorname*{argmax}_{y \in \mathcal{Y}} P(y \mid y_1, \ldots, y_{t'-1}, \mathbf{c})
as the output. Once the “<eos>” symbol is detected, or the output
sequence has reached its maximum length :math:`T'`, the output is
completed.
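The procedure above can be sketched in a few lines of Python. This is only a minimal illustration: ``next_token_probs`` is a hypothetical stand-in for a trained decoder, and the probabilities in its table are made-up toy values rather than the output of any real model.

```python
# Toy dictionary; "<eos>" marks the end of a sequence.
VOCAB = ["A", "B", "C", "<eos>"]

def next_token_probs(prefix):
    """Hypothetical decoder: returns P(y | prefix, c) for every y in VOCAB.

    The numbers are invented for illustration only.
    """
    table = {
        (): [0.5, 0.2, 0.2, 0.1],
        ("A",): [0.1, 0.4, 0.3, 0.2],
        ("A", "B"): [0.2, 0.2, 0.4, 0.2],
        ("A", "B", "C"): [0.0, 0.2, 0.2, 0.6],
    }
    # Any prefix not listed above always ends the sequence.
    return table.get(tuple(prefix), [0.0, 0.0, 0.0, 1.0])

def greedy_search(next_probs, vocab, max_len, eos="<eos>"):
    """Pick the most probable word at every time step."""
    output = []
    for _ in range(max_len):
        probs = next_probs(output)
        # argmax over the dictionary at this time step
        best = max(range(len(vocab)), key=lambda i: probs[i])
        output.append(vocab[best])
        if vocab[best] == eos:
            break  # stop as soon as "<eos>" is generated
    return output

print(greedy_search(next_token_probs, VOCAB, max_len=4))
```

The loop stops either when “<eos>” is produced or when ``max_len`` words have been generated, mirroring the two stopping conditions described above.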
As we mentioned in our discussion of the decoder, the conditional
probability of generating an output sequence based on the input sequence
is
:math:`\prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c})`.
We will take the output sequence with the highest conditional
probability as the optimal sequence. The main problem with greedy search
is that there is no guarantee that the optimal sequence will be
obtained.
Take a look at the example below. We assume that there are four words
“A”, “B”, “C”, and “<eos>” in the output dictionary. The four numbers
under each time step in :numref:`fig_s2s-prob1` represent the
conditional probabilities of generating “A”, “B”, “C”, and “<eos>” at
that time step, respectively. At each time step, greedy search selects
the word with the highest conditional probability. Therefore, the output
sequence “A”, “B”, “C”, and “<eos>” will be generated in
:numref:`fig_s2s-prob1`. The conditional probability of this output
sequence is :math:`0.5\times0.4\times0.4\times0.6 = 0.048`.
.. _fig_s2s-prob1:
.. figure:: ../img/s2s-prob1.svg
The four numbers under each time step represent the conditional
probabilities of generating “A”, “B”, “C”, and “<eos>” at that time
step, respectively. At each time step, greedy search selects the word
with the highest conditional probability.
Now, we will look at another example shown in :numref:`fig_s2s-prob2`.
Unlike in :numref:`fig_s2s-prob1`, at time step 2 we select the word
“C”, which has the second highest conditional probability. Since the
output subsequence of time steps 1 and 2, on which time step 3 is based,
changes from “A” and “B” in :numref:`fig_s2s-prob1` to “A” and “C” in
:numref:`fig_s2s-prob2`, the conditional probability of each word
at time step 3 also changes in :numref:`fig_s2s-prob2`.
We choose the word “B”, which has the highest conditional probability.
Now, the output subsequence on which time step 4 is based is
“A”, “C”, and “B”, which differs from “A”, “B”, and
“C” in :numref:`fig_s2s-prob1`. Therefore, the conditional probability
of generating each word at time step 4 in :numref:`fig_s2s-prob2` is
also different from that in :numref:`fig_s2s-prob1`. We find that the
conditional probability of the output sequence “A”, “C”, “B”, and
“<eos>” is
:math:`0.5\times0.3\times0.6\times0.6=0.054`, which is higher than the
conditional probability of the output sequence obtained by greedy
search. Therefore, the output sequence “A”, “B”, “C”, and “<eos>”
obtained by greedy search is not an optimal sequence.
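The comparison can be checked numerically. The short sketch below multiplies the per-step conditional probabilities quoted in the text for the two output sequences; the variable names are ours and serve only this illustration.

```python
import math

# Conditional probabilities of the selected words, as quoted in the text.
greedy_steps = [0.5, 0.4, 0.4, 0.6]       # "A", "B", "C", "<eos>"
alternative_steps = [0.5, 0.3, 0.6, 0.6]  # "A", "C", "B", "<eos>"

# The probability of a sequence is the product of its per-step terms.
p_greedy = math.prod(greedy_steps)
p_alternative = math.prod(alternative_steps)

print(f"greedy:      {p_greedy:.3f}")
print(f"alternative: {p_alternative:.3f}")
print("greedy found the better sequence:", p_greedy >= p_alternative)
```

Although greedy search picks the larger probability at time step 2 (0.4 rather than 0.3), the alternative sequence ends up with the larger product.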
.. _fig_s2s-prob2:
.. figure:: ../img/s2s-prob2.svg
The four numbers under each time step represent the conditional
probabilities of generating “A”, “B”, “C”, and “<eos>” at that time
step. At time step 2, the word “C”, which has the second highest
conditional probability, is selected.
.. _beam-search-1:
Exhaustive Search
-----------------
If the goal is to obtain the optimal sequence, we may consider using
exhaustive search: an exhaustive examination of all possible output
sequences, which outputs the sequence with the highest conditional
probability.
Although we can use an exhaustive search to obtain the optimal sequence,
its computational overhead
:math:`\mathcal{O}(\left|\mathcal{Y}\right|^{T'})` is likely to be
excessively high. For example, when :math:`|\mathcal{Y}|=10000` and
:math:`T'=10`, we will need to evaluate :math:`10000^{10} = 10^{40}`
sequences. This is next to impossible to complete. The computational
overhead of greedy search is
:math:`\mathcal{O}(\left|\mathcal{Y}\right|T')`, which is usually
significantly less than the computational overhead of exhaustive
search. For example, when :math:`|\mathcal{Y}|=10000` and :math:`T'=10`,
we only need to evaluate :math:`10000\times10=10^5` candidate words.
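To make the cost concrete, here is a minimal sketch of exhaustive search on a tiny problem. The three-word dictionary and the probabilities returned by ``next_token_probs`` are hypothetical toy values chosen for illustration; a real decoder would take their place.

```python
import itertools
import math

VOCAB = ["A", "B", "<eos>"]

def next_token_probs(prefix):
    """Hypothetical toy model: made-up P(y | prefix, c) over VOCAB."""
    if len(prefix) == 0:
        return [0.5, 0.4, 0.1]
    if prefix[-1] == "A":
        return [0.4, 0.3, 0.3]
    return [0.1, 0.1, 0.8]  # after "B", ending the sequence is very likely

def sequence_prob(seq):
    """Product of the conditional probabilities along seq."""
    return math.prod(
        next_token_probs(seq[:t])[VOCAB.index(seq[t])] for t in range(len(seq))
    )

def exhaustive_search(max_len, eos="<eos>"):
    best_seq, best_prob, num_evaluated = None, -1.0, 0
    # Enumerate all |Y| ** max_len sequences.
    for seq in itertools.product(VOCAB, repeat=max_len):
        num_evaluated += 1
        if eos in seq:
            # Discard everything after the first "<eos>".
            seq = seq[: seq.index(eos) + 1]
        prob = sequence_prob(seq)
        if prob > best_prob:
            best_seq, best_prob = seq, prob
    return best_seq, best_prob, num_evaluated

best_seq, best_prob, num_evaluated = exhaustive_search(max_len=3)
print(best_seq, round(best_prob, 3), num_evaluated)
```

Even in this toy setting the loop visits :math:`3^3 = 27` sequences, and the count grows exponentially with the maximum length. In this particular toy model, the best sequence starts with “B”, whereas greedy search would have started with “A”.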
Beam Search
-----------
*Beam search* is an improved algorithm based on greedy search. It has a
hyperparameter named *beam size*, :math:`k`. At time step 1, we select the
:math:`k` words with the highest conditional probability to be the first
word of the :math:`k` candidate output sequences. For each subsequent
time step, we select the :math:`k` output sequences with
the highest conditional probability from the total of
:math:`k\left|\mathcal{Y}\right|` possible output sequences based on the
:math:`k` candidate output sequences from the previous time step. These
will be the candidate output sequences for that time step. Finally, from
the candidate output sequences of each time step, we pick out the
sequences containing the special symbol “<eos>” and discard all the
subsequences after it to obtain a set of final candidate output
sequences.
.. _fig_beam-search:
.. figure:: ../img/beam-search.svg
The beam search process. The beam size is 2 and the maximum length of
the output sequence is 3. The candidate output sequences are
:math:`A`, :math:`C`, :math:`AB`, :math:`CE`, :math:`ABD`, and
:math:`CED`.
:numref:`fig_beam-search` demonstrates the process of beam search with
an example. Suppose that the vocabulary of the output sequence contains
only five elements: :math:`\mathcal{Y} = \{A, B, C, D, E\}`, where one of
them is the special symbol “<eos>”. Set the beam size to 2 and the maximum length
of the output sequence to 3. At time step 1 of the output sequence,
suppose the words with the highest conditional probability
:math:`P(y_1 \mid \mathbf{c})` are :math:`A` and :math:`C`. At time step
2, for all :math:`y_2 \in \mathcal{Y},` we compute
.. math:: P(A, y_2 \mid \mathbf{c}) = P(A \mid \mathbf{c})P(y_2 \mid A, \mathbf{c})
and
.. math:: P(C, y_2 \mid \mathbf{c}) = P(C \mid \mathbf{c})P(y_2 \mid C, \mathbf{c}),
and pick the largest two among these 10 values, say
.. math:: P(A, B \mid \mathbf{c}) \text{ and } P(C, E \mid \mathbf{c}).
Then at time step 3, for all :math:`y_3 \in \mathcal{Y}`, we compute
.. math:: P(A, B, y_3 \mid \mathbf{c}) = P(A, B \mid \mathbf{c})P(y_3 \mid A, B, \mathbf{c})
and
.. math:: P(C, E, y_3 \mid \mathbf{c}) = P(C, E \mid \mathbf{c})P(y_3 \mid C, E, \mathbf{c}),
and pick the largest two among these 10 values, say
.. math:: P(A, B, D \mid \mathbf{c}) \text{ and } P(C, E, D \mid \mathbf{c}).
As a result, we obtain 6 candidate output sequences: (1) :math:`A`; (2)
:math:`C`; (3) :math:`A`, :math:`B`; (4) :math:`C`, :math:`E`; (5)
:math:`A`, :math:`B`, :math:`D`; and (6) :math:`C`, :math:`E`,
:math:`D`. In the end, we will get the set of final candidate output
sequences based on these 6 sequences.
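The walk-through above can be reproduced with a short sketch. The transition probabilities in ``TRANSITIONS`` are hypothetical values that we chose so that the beams match the example; they do not come from a trained model. For brevity, the toy model conditions only on the last word, and the final step of handling “<eos>” is omitted.

```python
import math

VOCAB = ["A", "B", "C", "D", "E"]

# Hypothetical P(next word | last word); None marks the start of a sequence.
TRANSITIONS = {
    None: [0.4, 0.1, 0.3, 0.1, 0.1],
    "A": [0.1, 0.5, 0.1, 0.2, 0.1],
    "B": [0.1, 0.1, 0.1, 0.6, 0.1],
    "C": [0.1, 0.1, 0.1, 0.1, 0.6],
    "D": [0.2, 0.2, 0.2, 0.2, 0.2],
    "E": [0.1, 0.1, 0.1, 0.6, 0.1],
}

def next_token_log_probs(prefix):
    """Toy stand-in for the decoder: log P(y | prefix, c) over VOCAB."""
    last = prefix[-1] if prefix else None
    return [math.log(p) for p in TRANSITIONS[last]]

def beam_search(next_log_probs, vocab, beam_size, max_len):
    beams = [((), 0.0)]  # (prefix, log-probability) pairs
    candidates = []      # candidates collected from every time step
    for _ in range(max_len):
        expanded = []
        for prefix, score in beams:
            # Extend each beam by every word in the vocabulary.
            for token, lp in zip(vocab, next_log_probs(prefix)):
                expanded.append((prefix + (token,), score + lp))
        # Keep the k most probable expansions as the new beams.
        beams = sorted(expanded, key=lambda c: c[1], reverse=True)[:beam_size]
        candidates.extend(beams)
    return candidates

for prefix, score in beam_search(next_token_log_probs, VOCAB, beam_size=2, max_len=3):
    print("".join(prefix), round(math.exp(score), 3))
```

Working with log-probabilities turns the products in the equations above into sums, which avoids numerical underflow for long sequences. With ``beam_size=1`` the same function reduces to greedy search.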
In the set of final candidate output sequences, we take the
sequence with the highest value of the following score as the output sequence:
.. math:: \frac{1}{L^\alpha} \log P(y_1, \ldots, y_{L} \mid \mathbf{c}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c}).
Here, :math:`L` is the length of the final candidate sequence and
:math:`\alpha` is usually set to 0.75. A longer sequence has more
logarithmic terms in the summation, each of which is non-positive, so
its unnormalized score tends to be lower. Dividing by :math:`L^\alpha`
adjusts the score for this effect of sequence length. The computational
overhead of beam search is
:math:`\mathcal{O}(k\left|\mathcal{Y}\right|T')`, which lies between the
computational overhead of greedy search and that of exhaustive search.
In addition, greedy
search can be treated as a beam search with a beam size of 1. Beam
search strikes a balance between computational overhead and search
quality using a flexible beam size of :math:`k`.
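As a small illustration of the scoring rule, the sketch below computes the score for two hypothetical final candidates. The candidate sequences and their per-step probabilities are made up for this example, and ``length_normalized_score`` is a helper name of our own.

```python
import math

def length_normalized_score(step_probs, alpha=0.75):
    """Return (1 / L**alpha) * sum of log conditional probabilities."""
    length = len(step_probs)
    log_prob = sum(math.log(p) for p in step_probs)
    return log_prob / length ** alpha

# Hypothetical final candidates with made-up per-step probabilities.
candidates = {
    ("A", "<eos>"): [0.5, 0.4],
    ("A", "B", "C", "<eos>"): [0.5, 0.6, 0.7, 0.8],
}

for seq, probs in candidates.items():
    print(seq,
          "probability:", round(math.prod(probs), 3),
          "score:", round(length_normalized_score(probs), 3))

# The output sequence is the candidate with the highest score.
best = max(candidates, key=lambda seq: length_normalized_score(candidates[seq]))
print("selected:", best)
```

In this toy case the shorter candidate has the higher raw probability, but the longer one obtains the higher score once the sum is divided by :math:`L^\alpha`. Setting ``alpha=0`` recovers the plain log-probability.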
Summary
-------
- Methods for predicting variable-length sequences include greedy
search, exhaustive search, and beam search.
- Beam search strikes a balance between computational overhead and
search quality using a flexible beam size.
Exercises
---------
1. Can we treat an exhaustive search as a beam search with a special
beam size? Why?
2. We used language models to generate sentences in
:numref:`sec_rnn_scratch`. Which kind of search did we use to generate
that output? Can you improve it?