.. _sec_self-attention-and-positional-encoding:
Self-Attention and Positional Encoding
======================================
In deep learning, we often use CNNs or RNNs to encode sequences. Now
with attention mechanisms in mind, imagine feeding a sequence of tokens
into an attention mechanism such that at every step, each token has its
own query, keys, and values. Here, when computing the value of a token’s
representation at the next layer, the token can attend (via its query
vector) to any other token (matching based on their key vectors).
Using the full set of query-key compatibility scores, we can compute,
for each token, a representation by building the appropriate weighted
sum over the other tokens. Because every token is attending to each
other token (unlike the case where decoder steps attend to encoder
steps), such architectures are typically described as *self-attention*
models :cite:`Lin.Feng.Santos.ea.2017,Vaswani.Shazeer.Parmar.ea.2017`,
and elsewhere described as *intra-attention* models
:cite:`Cheng.Dong.Lapata.2016,Parikh.Tackstrom.Das.ea.2016,Paulus.Xiong.Socher.2017`.
In this section, we will discuss sequence encoding using self-attention,
including using additional information for the sequence order.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import math
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l
.. raw:: html
.. raw:: html
Self-Attention
--------------
Given a sequence of input tokens
:math:`\mathbf{x}_1, \ldots, \mathbf{x}_n` where each
:math:`\mathbf{x}_i \in \mathbb{R}^d` (:math:`1 \leq i \leq n`), its
self-attention outputs a sequence of the same length
:math:`\mathbf{y}_1, \ldots, \mathbf{y}_n`, where
.. math:: \mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d
according to the definition of attention pooling in
:eq:`eq_attention_pooling`. Using multi-head attention, the
following code snippet computes the self-attention of a tensor with
shape (batch size, number of time steps or sequence length in tokens,
:math:`d`). The output tensor has the same shape.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[22:10:14] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
.. raw:: html
.. raw:: html
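To make the definition concrete without the learned projections and multiple heads of ``d2l.MultiHeadAttention``, here is a minimal NumPy sketch of single-head self-attention in which the queries, keys, and values are all the raw inputs. The function name ``simple_self_attention`` is hypothetical, and omitting projection matrices, masking, and dropout is an assumption made for illustration; this is not how the library class is implemented.

.. code:: python

    import numpy as np

    def simple_self_attention(X):
        """Scaled dot product self-attention with Q = K = V = X.

        X has shape (n, d). Returns (Y, weights) with shapes (n, d), (n, n).
        """
        n, d = X.shape
        scores = X @ X.T / np.sqrt(d)                  # (n, n) query-key scores
        scores -= scores.max(axis=1, keepdims=True)    # stabilize the softmax
        weights = np.exp(scores)
        weights /= weights.sum(axis=1, keepdims=True)  # each row sums to 1
        Y = weights @ X                                # weighted sum of values
        return Y, weights

    rng = np.random.default_rng(0)
    X = rng.normal(size=(4, 8))    # n = 4 tokens, d = 8
    Y, weights = simple_self_attention(X)
    print(Y.shape, weights.shape)  # (4, 8) (4, 4)

Each output row is a weighted average of all input rows, which is why the output has the same shape as the input.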
.. _subsec_cnn-rnn-self-attention:
Comparing CNNs, RNNs, and Self-Attention
----------------------------------------
Let’s compare architectures for mapping a sequence of :math:`n` tokens
to another one of equal length, where each input or output token is
represented by a :math:`d`-dimensional vector. Specifically, we will
consider CNNs, RNNs, and self-attention. We will compare their
computational complexity, sequential operations, and maximum path
lengths. Note that sequential operations prevent parallel computation,
while a shorter path between any combination of sequence positions makes
it easier to learn long-range dependencies within the sequence
:cite:`Hochreiter.Bengio.Frasconi.ea.2001`.
.. _fig_cnn-rnn-self-attention:
.. figure:: ../img/cnn-rnn-self-attention.svg
Comparing CNN (padding tokens are omitted), RNN, and self-attention
architectures.
Let’s regard any text sequence as a “one-dimensional image”. Just as
two-dimensional CNNs extract local features from images,
one-dimensional CNNs can process local features such as :math:`n`-grams
in text. Given a sequence of length :math:`n`, consider a convolutional
layer whose kernel size is :math:`k`, and whose numbers of input and
output channels are both :math:`d`. The computational complexity of the
convolutional layer is :math:`\mathcal{O}(knd^2)`. As
:numref:`fig_cnn-rnn-self-attention` shows, CNNs are hierarchical, so
there are :math:`\mathcal{O}(1)` sequential operations and the maximum
path length is :math:`\mathcal{O}(n/k)`. For example,
:math:`\mathbf{x}_1` and :math:`\mathbf{x}_5` are within the receptive
field of a two-layer CNN with kernel size 3 in
:numref:`fig_cnn-rnn-self-attention`.
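The receptive field arithmetic behind this example can be sketched as follows, assuming stride-1 convolutions without dilation (the helper names ``receptive_field`` and ``layers_to_connect`` are hypothetical). Each additional layer widens the receptive field by :math:`k - 1` tokens, so the number of layers needed to connect the two ends of a sequence grows roughly as :math:`n/k`.

.. code:: python

    import math

    def receptive_field(num_layers, kernel_size):
        """Number of input tokens seen by one output of a stack of
        stride-1, undilated 1D convolutions."""
        return num_layers * (kernel_size - 1) + 1

    def layers_to_connect(n, kernel_size):
        """Fewest layers so that one output covers all n input tokens."""
        return math.ceil((n - 1) / (kernel_size - 1))

    print(receptive_field(2, 3))       # 5
    print(layers_to_connect(5, 3))     # 2
    print(layers_to_connect(1000, 3))  # 500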
When updating the hidden state of RNNs, multiplication of the
:math:`d \times d` weight matrix and the :math:`d`-dimensional hidden
state has a computational complexity of :math:`\mathcal{O}(d^2)`. Since
the sequence length is :math:`n`, the computational complexity of the
recurrent layer is :math:`\mathcal{O}(nd^2)`. According to
:numref:`fig_cnn-rnn-self-attention`, there are :math:`\mathcal{O}(n)`
sequential operations that cannot be parallelized and the maximum path
length is also :math:`\mathcal{O}(n)`.
In self-attention, the queries, keys, and values are all
:math:`n \times d` matrices. Consider the scaled dot product attention
in :eq:`eq_softmax_QK_V`, where an :math:`n \times d` matrix is
multiplied by a :math:`d \times n` matrix, then the output
:math:`n \times n` matrix is multiplied by an :math:`n \times d` matrix.
As a result, self-attention has an :math:`\mathcal{O}(n^2d)`
computational complexity. As we can see from
:numref:`fig_cnn-rnn-self-attention`, each token is directly connected
to any other token via self-attention. Therefore, computation can be
parallelized with :math:`\mathcal{O}(1)` sequential operations and the
maximum path length is also :math:`\mathcal{O}(1)`.
All in all, both CNNs and self-attention enjoy parallel computation and
self-attention has the shortest maximum path length. However, the
quadratic computational complexity with respect to the sequence length
makes self-attention prohibitively slow for very long sequences.
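To get a feel for these trade-offs, the sketch below plugs numbers into the three leading-order terms from this section. The function ``layer_costs`` is hypothetical, and the values are growth terms with constants dropped, not measured operation counts.

.. code:: python

    def layer_costs(n, d, k=3):
        """Leading-order operation counts from the text (constants dropped)."""
        return {
            'cnn': k * n * d ** 2,
            'rnn': n * d ** 2,
            'self-attention': n ** 2 * d,
        }

    short = layer_costs(n=50, d=512)    # sequence shorter than d
    long = layer_costs(n=5000, d=512)   # sequence much longer than d
    print(short)
    print(long)

Under these assumptions the self-attention term is the smallest of the three when :math:`n < d` and becomes the largest once :math:`n` exceeds :math:`kd`.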
.. _subsec_positional-encoding:
Positional Encoding
-------------------
Unlike RNNs, which recurrently process tokens of a sequence one-by-one,
self-attention ditches sequential operations in favor of parallel
computation. Note that self-attention by itself does not preserve the
order of the sequence. What do we do if it really matters that the model
knows in which order the input sequence arrived?
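The order-agnostic behavior can be checked numerically. In the NumPy sketch below, a projection-free, single-head version of self-attention (a simplification assumed here for illustration, with the hypothetical name ``attend``) is applied to a shuffled sequence: the outputs are simply shuffled in the same way, so nothing in the result reveals the original order.

.. code:: python

    import numpy as np

    def attend(X):
        """Self-attention with Q = K = V = X (no projections, no masking)."""
        scores = X @ X.T / np.sqrt(X.shape[1])
        weights = np.exp(scores - scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)
        return weights @ X

    rng = np.random.default_rng(0)
    X = rng.normal(size=(5, 8))   # 5 tokens, 8 dimensions
    perm = rng.permutation(5)

    # Shuffling the tokens only shuffles the outputs in the same way
    print(np.allclose(attend(X[perm]), attend(X)[perm]))  # True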
The dominant approach for preserving information about the order of
tokens is to represent this to the model as an additional input
associated with each token. These inputs are called *positional
encodings*, and they can either be learned or fixed *a priori*. We now
describe a simple scheme for fixed positional encodings based on sine
and cosine functions :cite:`Vaswani.Shazeer.Parmar.ea.2017`.
Suppose that the input representation
:math:`\mathbf{X} \in \mathbb{R}^{n \times d}` contains the
:math:`d`-dimensional embeddings for :math:`n` tokens of a sequence. The
positional encoding outputs :math:`\mathbf{X} + \mathbf{P}` using a
positional embedding matrix
:math:`\mathbf{P} \in \mathbb{R}^{n \times d}` of the same shape, whose
element on the :math:`i^\textrm{th}` row and the
:math:`(2j)^\textrm{th}` or the :math:`(2j + 1)^\textrm{th}` column is
.. math:: \begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).\end{aligned}
:label: eq_positional-encoding-def
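To make the indexing in :eq:`eq_positional-encoding-def` concrete, the following sketch fills a small matrix entry by entry using only the standard library. The function name and the name ``P_demo`` are hypothetical, and the sketch assumes that :math:`d` is even so that sine and cosine columns pair up.

.. code:: python

    import math

    def positional_encoding(n, d):
        """Return P as a list of n rows, filled entry by entry."""
        P = [[0.0] * d for _ in range(n)]
        for i in range(n):            # row: position in the sequence
            for j in range(d // 2):   # each j fills columns 2j and 2j + 1
                angle = i / 10000 ** (2 * j / d)
                P[i][2 * j] = math.sin(angle)
                P[i][2 * j + 1] = math.cos(angle)
        return P

    P_demo = positional_encoding(n=3, d=4)
    print(P_demo[0])  # [0.0, 1.0, 0.0, 1.0]
    print(P_demo[1])

Position 0 always encodes to alternating zeros and ones, since :math:`\sin(0) = 0` and :math:`\cos(0) = 1`.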
At first glance, this trigonometric function design looks weird. Before
we give explanations of this design, let’s first implement it in the
following ``PositionalEncoding`` class.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class PositionalEncoding(nn.Module):  #@save
        """Positional encoding."""
        def __init__(self, num_hiddens, dropout, max_len=1000):
            super().__init__()
            self.dropout = nn.Dropout(dropout)
            # Create a long enough P
            self.P = torch.zeros((1, max_len, num_hiddens))
            X = torch.arange(max_len, dtype=torch.float32).reshape(
                -1, 1) / torch.pow(10000, torch.arange(
                0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
            self.P[:, :, 0::2] = torch.sin(X)
            self.P[:, :, 1::2] = torch.cos(X)

        def forward(self, X):
            X = X + self.P[:, :X.shape[1], :].to(X.device)
            return self.dropout(X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class PositionalEncoding(nn.Block):  #@save
        """Positional encoding."""
        def __init__(self, num_hiddens, dropout, max_len=1000):
            super().__init__()
            self.dropout = nn.Dropout(dropout)
            # Create a long enough P
            self.P = np.zeros((1, max_len, num_hiddens))
            X = np.arange(max_len).reshape(-1, 1) / np.power(
                10000, np.arange(0, num_hiddens, 2) / num_hiddens)
            self.P[:, :, 0::2] = np.sin(X)
            self.P[:, :, 1::2] = np.cos(X)

        def forward(self, X):
            X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx)
            return self.dropout(X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class PositionalEncoding(nn.Module):  #@save
        """Positional encoding."""
        num_hiddens: int
        dropout: float
        max_len: int = 1000

        def setup(self):
            # Create a long enough P
            self.P = jnp.zeros((1, self.max_len, self.num_hiddens))
            X = jnp.arange(self.max_len, dtype=jnp.float32).reshape(
                -1, 1) / jnp.power(10000, jnp.arange(
                0, self.num_hiddens, 2, dtype=jnp.float32) / self.num_hiddens)
            self.P = self.P.at[:, :, 0::2].set(jnp.sin(X))
            self.P = self.P.at[:, :, 1::2].set(jnp.cos(X))

        @nn.compact
        def __call__(self, X, training=False):
            # Flax sow API is used to capture intermediate variables
            self.sow('intermediates', 'P', self.P)
            X = X + self.P[:, :X.shape[1], :]
            return nn.Dropout(self.dropout)(X, deterministic=not training)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class PositionalEncoding(tf.keras.layers.Layer):  #@save
        """Positional encoding."""
        def __init__(self, num_hiddens, dropout, max_len=1000):
            super().__init__()
            self.dropout = tf.keras.layers.Dropout(dropout)
            # Create a long enough P
            self.P = np.zeros((1, max_len, num_hiddens))
            X = np.arange(max_len, dtype=np.float32).reshape(
                -1, 1) / np.power(10000, np.arange(
                0, num_hiddens, 2, dtype=np.float32) / num_hiddens)
            self.P[:, :, 0::2] = np.sin(X)
            self.P[:, :, 1::2] = np.cos(X)

        def call(self, X, **kwargs):
            X = X + self.P[:, :X.shape[1], :]
            return self.dropout(X, **kwargs)
.. raw:: html
.. raw:: html
In the positional embedding matrix :math:`\mathbf{P}`, rows correspond
to positions within a sequence and columns represent different
positional encoding dimensions. In the example below, we can see that
the :math:`6^{\textrm{th}}` and the :math:`7^{\textrm{th}}` columns of
the positional embedding matrix have a higher frequency than the
:math:`8^{\textrm{th}}` and the :math:`9^{\textrm{th}}` columns. The
offset between the :math:`6^{\textrm{th}}` and the
:math:`7^{\textrm{th}}` (same for the :math:`8^{\textrm{th}}` and the
:math:`9^{\textrm{th}}`) columns is due to the alternation of sine and
cosine functions.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_48_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.initialize()
X = pos_encoding(np.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_51_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
params = pos_encoding.init(d2l.get_key(), jnp.zeros((1, num_steps, encoding_dim)))
X, inter_vars = pos_encoding.apply(params, jnp.zeros((1, num_steps, encoding_dim)),
mutable='intermediates')
P = inter_vars['intermediates']['P'][0] # retrieve intermediate value P
P = P[:, :X.shape[1], :]
d2l.plot(jnp.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in jnp.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_54_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(tf.zeros((1, num_steps, encoding_dim)), training=False)
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_57_0.svg
.. raw:: html
.. raw:: html
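The frequency ordering visible in the plot can also be checked numerically. As a small sketch (the helper ``column_period`` is hypothetical), the period of the sinusoid in a column follows from :eq:`eq_positional-encoding-def` as :math:`2\pi \cdot 10000^{2j/d}`, where columns :math:`2j` and :math:`2j+1` share the same :math:`j`.

.. code:: python

    import math

    def column_period(col, d):
        """Period, in positions, of the sinusoid stored in column `col`."""
        j = col // 2                      # columns 2j and 2j + 1 share omega_j
        omega = 1 / 10000 ** (2 * j / d)  # angular frequency per position
        return 2 * math.pi / omega

    for col in range(6, 10):
        print(f'Col {col}: period = {column_period(col, 32):.1f} positions')

With an encoding dimension of 32, columns 6 and 7 repeat roughly every 35 positions, while columns 8 and 9 repeat roughly every 63 positions, which matches the slower oscillation of the latter pair over the 60 plotted positions.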
Absolute Positional Information
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
To see how the monotonically decreasing frequency along the encoding
dimension relates to absolute positional information, let’s print out
the binary representations of :math:`0, 1, \ldots, 7`. As we can see,
the lowest bit, the second-lowest bit, and the third-lowest bit
alternate on every number, every two numbers, and every four numbers,
respectively.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    for i in range(8):
        print(f'{i} in binary is {i:>03b}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
.. raw:: html
.. raw:: html
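Reading the printed table column by column makes the alternation pattern explicit. The sketch below (with the hypothetical helper ``bit_column``) extracts a single bit position across the integers :math:`0, 1, \ldots, 7`.

.. code:: python

    def bit_column(k, count=8):
        """Bit k (0 = lowest) of the integers 0, ..., count - 1, as a string."""
        return ''.join(str((i >> k) & 1) for i in range(count))

    for k in range(3):
        print(f'bit {k}: {bit_column(k)}')

The lowest bit reads ``01010101``, the next ``00110011``, and the third ``00001111``: each higher bit flips half as often as the one below it.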
In binary representations, a higher bit has a lower frequency than a
lower bit. Similarly, as demonstrated in the heat map below, the
positional encoding decreases frequencies along the encoding dimension
by using trigonometric functions. Since the outputs are floating-point numbers,
such continuous representations are more space-efficient than binary
representations.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_78_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
P = np.expand_dims(np.expand_dims(P[0, :, :], 0), 0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_81_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
P = jnp.expand_dims(jnp.expand_dims(P[0, :, :], axis=0), axis=0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_84_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
P = tf.expand_dims(tf.expand_dims(P[0, :, :], axis=0), axis=0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
.. figure:: output_self-attention-and-positional-encoding_ce9eb6_87_0.svg
.. raw:: html
.. raw:: html
Relative Positional Information
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Besides capturing absolute positional information, the above positional
encoding also allows a model to easily learn to attend by relative
positions. This is because for any fixed position offset :math:`\delta`,
the positional encoding at position :math:`i + \delta` can be
represented by a linear projection of that at position :math:`i`.
This projection can be explained mathematically. Denoting
:math:`\omega_j = 1/10000^{2j/d}`, any pair of
:math:`(p_{i, 2j}, p_{i, 2j+1})` in
:eq:`eq_positional-encoding-def` can be linearly projected to
:math:`(p_{i+\delta, 2j}, p_{i+\delta, 2j+1})` for any fixed offset
:math:`\delta`:
.. math::
\begin{aligned}
\begin{bmatrix} \cos(\delta \omega_j) & \sin(\delta \omega_j) \\ -\sin(\delta \omega_j) & \cos(\delta \omega_j) \\ \end{bmatrix}
\begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \\ \end{bmatrix}
=&\begin{bmatrix} \cos(\delta \omega_j) \sin(i \omega_j) + \sin(\delta \omega_j) \cos(i \omega_j) \\ -\sin(\delta \omega_j) \sin(i \omega_j) + \cos(\delta \omega_j) \cos(i \omega_j) \\ \end{bmatrix}\\
=&\begin{bmatrix} \sin\left((i+\delta) \omega_j\right) \\ \cos\left((i+\delta) \omega_j\right) \\ \end{bmatrix}\\
=&
\begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \\ \end{bmatrix},
\end{aligned}
where the :math:`2\times 2` projection matrix does not depend on any
position index :math:`i`.
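This identity is easy to confirm numerically. The NumPy sketch below uses arbitrary illustrative values for :math:`d`, :math:`j`, :math:`i`, and :math:`\delta` (none of them come from the text) and applies the same :math:`2\times 2` matrix at two different positions.

.. code:: python

    import numpy as np

    d, j, i, delta = 32, 3, 7, 5   # arbitrary illustrative values
    omega = 1 / 10000 ** (2 * j / d)

    def pair(pos):
        """(p_{pos, 2j}, p_{pos, 2j+1}) from the sinusoidal definition."""
        return np.array([np.sin(pos * omega), np.cos(pos * omega)])

    # Projection matrix: depends on delta and j, but not on the position
    R = np.array([[np.cos(delta * omega), np.sin(delta * omega)],
                  [-np.sin(delta * omega), np.cos(delta * omega)]])

    print(np.allclose(R @ pair(i), pair(i + delta)))    # True
    print(np.allclose(R @ pair(40), pair(40 + delta)))  # True

The same matrix ``R`` shifts the encoding by :math:`\delta` regardless of the starting position, which is what lets a linear layer express attention by relative offset.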
Summary
-------
In self-attention, the queries, keys, and values all come from the same
place. Both CNNs and self-attention enjoy parallel computation and
self-attention has the shortest maximum path length. However, the
quadratic computational complexity with respect to the sequence length
makes self-attention prohibitively slow for very long sequences. To use
the sequence order information, we can inject absolute or relative
positional information by adding positional encoding to the input
representations.
Exercises
---------
1. Suppose that we design a deep architecture to represent a sequence by
stacking self-attention layers with positional encoding. What could
the possible issues be?
2. Can you design a learnable positional encoding method?
3. Can we assign different learned embeddings according to different
offsets between queries and keys that are compared in self-attention?
Hint: you may refer to relative position embeddings
:cite:`shaw2018self,huang2018music`.