Since we do not want blanks in our attention model, we simply need to
limit :math:`\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i`
to :math:`\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i`,
where :math:`l \leq n` is the actual length of the sentence. Since this
is such a common problem, it has a name: the *masked softmax operation*.
Let’s implement it. The implementation cheats ever so slightly: rather
than discarding the terms :math:`\mathbf{v}_i` with :math:`i > l`
outright, it sets their attention logits (the inputs to the softmax) to
a very large negative number, such as :math:`-10^{6}`, so that their
weights, and hence their contribution to values and gradients, vanish in
practice. This is done because linear algebra kernels and operators are
heavily optimized for GPUs, and it is faster to be slightly wasteful in
computation than to write code with conditional (if-then-else)
statements.
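Before turning to the framework-specific implementations, a quick
numeric illustration of the :math:`-10^{6}` trick may help (a minimal
sketch using PyTorch, matching the first tab below): logits set to a
very large negative number receive essentially zero probability under
the softmax, so the corresponding values drop out of the weighted sum.

.. code:: python

    import torch

    # Logits of -1e6 underflow to zero probability after the softmax,
    # so masked positions contribute nothing to the attention output.
    logits = torch.tensor([2.0, 1.0, -1e6, -1e6])
    print(torch.softmax(logits, dim=-1))  # ~[0.7311, 0.2689, 0.0000, 0.0000]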
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.size(1)
mask = torch.arange((maxlen), dtype=torch.float32,
device=X.device)[None, :] < valid_len[:, None]
X[~mask] = value
return X
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
if valid_lens is None:
return npx.softmax(X)
else:
shape = X.shape
if valid_lens.ndim == 1:
valid_lens = valid_lens.repeat(shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
value=-1e6, axis=1)
return npx.softmax(X).reshape(shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.shape[1]
mask = jnp.arange((maxlen),
dtype=jnp.float32)[None, :] < valid_len[:, None]
return jnp.where(mask, X, value)
if valid_lens is None:
return nn.softmax(X, axis=-1)
else:
shape = X.shape
if valid_lens.ndim == 1:
valid_lens = jnp.repeat(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return nn.softmax(X.reshape(shape), axis=-1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.shape[1]
mask = tf.range(start=0, limit=maxlen, dtype=tf.float32)[
None, :] < tf.cast(valid_len[:, None], dtype=tf.float32)
if len(X.shape) == 3:
return tf.where(tf.expand_dims(mask, axis=-1), X, value)
else:
return tf.where(mask, X, value)
if valid_lens is None:
return tf.nn.softmax(X, axis=-1)
else:
shape = X.shape
if len(valid_lens.shape) == 1:
valid_lens = tf.repeat(valid_lens, repeats=shape[1])
else:
valid_lens = tf.reshape(valid_lens, shape=-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(tf.reshape(X, shape=(-1, shape[-1])), valid_lens,
value=-1e6)
return tf.nn.softmax(tf.reshape(X, shape=shape), axis=-1)
.. raw:: html
.. raw:: html
To illustrate how this function works, consider a minibatch of two
examples, each a :math:`2 \times 4` matrix, whose valid lengths are
:math:`2` and :math:`3`, respectively. As a result of the masked softmax
operation, the values beyond each example's valid length are all masked
as zero.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[0.4448, 0.5552, 0.0000, 0.0000],
[0.4032, 0.5968, 0.0000, 0.0000]],
[[0.2795, 0.2805, 0.4400, 0.0000],
[0.2798, 0.3092, 0.4110, 0.0000]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[0.488994 , 0.511006 , 0. , 0. ],
[0.43654838, 0.56345165, 0. , 0. ]],
[[0.28817102, 0.3519408 , 0.3598882 , 0. ],
[0.29034293, 0.25239873, 0.45725834, 0. ]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)), jnp.array([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[[0.2914798 , 0.7085202 , 0. , 0. ],
[0.5130609 , 0.48693904, 0. , 0. ]],
[[0.17453432, 0.4599773 , 0.36548832, 0. ],
[0.3574293 , 0.3150612 , 0.32750952, 0. ]]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(tf.random.uniform(shape=(2, 2, 4)), tf.constant([2, 3]))
.. raw:: html
.. raw:: html
If we need more fine-grained control to specify the valid length for
each of the two vectors of every example, we simply use a
two-dimensional tensor of valid lengths. This yields:
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
[0.4109, 0.2794, 0.3097, 0.0000]],
[[0.3960, 0.6040, 0.0000, 0.0000],
[0.2557, 0.1833, 0.2420, 0.3190]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(np.random.uniform(size=(2, 2, 4)),
np.array([[1, 3], [2, 4]]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[1. , 0. , 0. , 0. ],
[0.35848376, 0.36588794, 0.2756283 , 0. ]],
[[0.54370314, 0.45629686, 0. , 0. ],
[0.19598779, 0.25580424, 0.19916737, 0.34904057]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)),
jnp.array([[1, 3], [2, 4]]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[[1. , 0. , 0. , 0. ],
[0.31556115, 0.28214547, 0.40229338, 0. ]],
[[0.5613054 , 0.43869466, 0. , 0. ],
[0.29578257, 0.20095006, 0.2151548 , 0.28811258]]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(tf.random.uniform((2, 2, 4)), tf.constant([[1, 3], [2, 4]]))
.. raw:: html
.. raw:: html
.. _subsec_batch_dot:
Batch Matrix Multiplication
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Another commonly used operation is to multiply batches of matrices by
one another. This comes in handy when we have minibatches of queries,
keys, and values. More specifically, assume that
.. math::
\mathbf{Q} = [\mathbf{Q}_1, \mathbf{Q}_2, \ldots, \mathbf{Q}_n] \in \mathbb{R}^{n \times a \times b}, \\
\mathbf{K} = [\mathbf{K}_1, \mathbf{K}_2, \ldots, \mathbf{K}_n] \in \mathbb{R}^{n \times b \times c}.
Then batch matrix multiplication (BMM) computes the product of each
pair of corresponding matrices and stacks the results:
.. math:: \textrm{BMM}(\mathbf{Q}, \mathbf{K}) = [\mathbf{Q}_1 \mathbf{K}_1, \mathbf{Q}_2 \mathbf{K}_2, \ldots, \mathbf{Q}_n \mathbf{K}_n] \in \mathbb{R}^{n \times a \times c}.
:label: eq_batch-matrix-mul
Let’s see this in action in a deep learning framework.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = torch.ones((2, 3, 4))
K = torch.ones((2, 4, 6))
d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))
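As a quick sanity check (a minimal sketch), the result of ``torch.bmm``
agrees with multiplying the per-example matrices in an explicit Python
loop:

.. code:: python

    # BMM is just n independent matrix products stacked along the batch axis.
    A = torch.randn(2, 3, 4)
    B = torch.randn(2, 4, 6)
    looped = torch.stack([A[i] @ B[i] for i in range(A.shape[0])])
    assert torch.allclose(torch.bmm(A, B), looped)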
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = np.ones((2, 3, 4))
K = np.ones((2, 4, 6))
d2l.check_shape(npx.batch_dot(Q, K), (2, 3, 6))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = jnp.ones((2, 3, 4))
K = jnp.ones((2, 4, 6))
d2l.check_shape(jax.lax.batch_matmul(Q, K), (2, 3, 6))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = tf.ones((2, 3, 4))
K = tf.ones((2, 4, 6))
d2l.check_shape(tf.matmul(Q, K).numpy(), (2, 3, 6))
.. raw:: html
.. raw:: html
Scaled Dot Product Attention
----------------------------
Let’s return to the dot product attention introduced in
:eq:`eq_dot_product_attention`. In general, it requires that both
the query and the key have the same vector length, say :math:`d`, even
though this can be addressed easily by replacing
:math:`\mathbf{q}^\top \mathbf{k}` with
:math:`\mathbf{q}^\top \mathbf{M} \mathbf{k}` where :math:`\mathbf{M}`
is a matrix suitably chosen for translating between both spaces. For now
assume that the dimensions match.
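For concreteness, the following minimal sketch (using PyTorch, with
hypothetical dimensions) shows how such a matrix :math:`\mathbf{M}` can
translate between query and key spaces when their dimensions differ; in
the rest of this section the dimensions are assumed to match.

.. code:: python

    # Hypothetical example: score a 3-dimensional query against a
    # 5-dimensional key via a translation matrix M (in practice M is learned).
    q = torch.randn(3)
    k = torch.randn(5)
    M = torch.randn(3, 5)
    score = q @ M @ k  # the scalar q^T M k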
In practice, we often think of minibatches for efficiency, such as
computing attention for :math:`n` queries and :math:`m` key-value pairs,
where queries and keys are of length :math:`d` and values are of length
:math:`v`. The scaled dot product attention of queries
:math:`\mathbf Q\in\mathbb R^{n\times d}`, keys
:math:`\mathbf K\in\mathbb R^{m\times d}`, and values
:math:`\mathbf V\in\mathbb R^{m\times v}` thus can be written as
.. math:: \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.
:label: eq_softmax_QK_V
Note that when applying this to a minibatch, we need the batch matrix
multiplication introduced in :eq:`eq_batch-matrix-mul`. In the
following implementation of the scaled dot product attention, we use
dropout for model regularization.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(nn.Module): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# Swap the last two dimensions of keys with keys.transpose(1, 2)
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
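As a quick check of :eq:`eq_softmax_QK_V` (a minimal sketch, assuming
the class above with small random tensors and no masking), the class
output with dropout disabled matches computing
:math:`\mathrm{softmax}(\mathbf Q \mathbf K^\top/\sqrt{d})\mathbf V`
directly:

.. code:: python

    Q = torch.randn(2, 3, 5)
    K = torch.randn(2, 4, 5)
    V = torch.randn(2, 4, 6)
    attn = DotProductAttention(dropout=0.5)
    attn.eval()  # disable dropout so the comparison is deterministic
    manual = torch.bmm(
        torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(5), dim=-1), V)
    assert torch.allclose(attn(Q, K, V), manual, atol=1e-6)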
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(nn.Block): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# Set transpose_b=True to swap the last two dimensions of keys
scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return npx.batch_dot(self.dropout(self.attention_weights), values)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(nn.Module): #@save
"""Scaled dot product attention."""
dropout: float
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
@nn.compact
def __call__(self, queries, keys, values, valid_lens=None,
training=False):
d = queries.shape[-1]
# Swap the last two dimensions of keys with keys.swapaxes(1, 2)
scores = queries@(keys.swapaxes(1, 2)) / math.sqrt(d)
attention_weights = masked_softmax(scores, valid_lens)
dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
return dropout_layer(attention_weights)@values, attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(tf.keras.layers.Layer): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = tf.keras.layers.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def call(self, queries, keys, values, valid_lens=None, **kwargs):
d = queries.shape[-1]
scores = tf.matmul(queries, keys, transpose_b=True)/tf.math.sqrt(
tf.cast(d, dtype=tf.float32))
self.attention_weights = masked_softmax(scores, valid_lens)
return tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
.. raw:: html
.. raw:: html
To illustrate how the ``DotProductAttention`` class works, we set up a
toy example whose keys, values, and valid lengths we will reuse later
when discussing additive attention. We assume a minibatch size of
:math:`2`, a total of :math:`10` keys and values, and a value
dimensionality of :math:`4`. Lastly, we assume that the valid lengths
per observation are :math:`2` and :math:`6`, respectively. Given that,
we expect the output to be a :math:`2 \times 1 \times 4` tensor, i.e.,
one row per example of the minibatch.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 4))
valid_lens = torch.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = np.random.normal(0, 1, (2, 1, 2))
keys = np.random.normal(0, 1, (2, 10, 2))
values = np.random.normal(0, 1, (2, 10, 4))
valid_lens = np.array([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = jax.random.normal(d2l.get_key(), (2, 1, 2))
keys = jax.random.normal(d2l.get_key(), (2, 10, 2))
values = jax.random.normal(d2l.get_key(), (2, 10, 4))
valid_lens = jnp.array([2, 6])
attention = DotProductAttention(dropout=0.5)
(output, attention_weights), params = attention.init_with_output(
d2l.get_key(), queries, keys, values, valid_lens)
print(output)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[[[ 0.75924027 -0.4776329 0.19306126 0.15036084]]
[[-0.07728005 1.1064801 -0.839485 -0.36051023]]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = tf.random.normal(shape=(2, 1, 2))
keys = tf.random.normal(shape=(2, 10, 2))
values = tf.random.normal(shape=(2, 10, 4))
valid_lens = tf.constant([2, 6])
attention = DotProductAttention(dropout=0.5)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
(2, 1, 4))
.. raw:: html
.. raw:: html
Let’s check whether the attention weights actually vanish for anything
beyond the second and sixth columns, respectively (because the valid
lengths were set to :math:`2` and :math:`6`).
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_108_0.svg
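Besides the visual check, we can also assert numerically that the
weights beyond the valid lengths are zero (a minimal sketch, using the
PyTorch ``attention`` object from above):

.. code:: python

    # Weights beyond positions 2 and 6 (the valid lengths) should be exactly
    # zero, since their logits of -1e6 underflow in the softmax.
    weights = attention.attention_weights.squeeze(1)  # shape: (2, 10)
    assert torch.all(weights[0, 2:] == 0) and torch.all(weights[1, 6:] == 0)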
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_111_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_114_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_117_0.svg
.. raw:: html
.. raw:: html
.. _subsec_additive-attention:
Additive Attention
------------------
When queries :math:`\mathbf{q}` and keys :math:`\mathbf{k}` are vectors
of different dimensions, we can either use a matrix to address the
mismatch via :math:`\mathbf{q}^\top \mathbf{M} \mathbf{k}`, or we can
use additive attention as the scoring function. A further benefit of the
latter is that, as its name indicates, the attention is additive, which
can lead to some minor computational savings. Given a query
:math:`\mathbf{q} \in \mathbb{R}^q` and a key
:math:`\mathbf{k} \in \mathbb{R}^k`, the *additive attention* scoring
function :cite:`Bahdanau.Cho.Bengio.2014` is given by
.. math:: a(\mathbf q, \mathbf k) = \mathbf w_v^\top \textrm{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},
:label: eq_additive-attn
where :math:`\mathbf W_q\in\mathbb R^{h\times q}`,
:math:`\mathbf W_k\in\mathbb R^{h\times k}`, and
:math:`\mathbf w_v\in\mathbb R^{h}` are the learnable parameters. This
term is then fed into a softmax to ensure both nonnegativity and
normalization. An equivalent interpretation of
:eq:`eq_additive-attn` is that the query and key are concatenated
and fed into an MLP with a single hidden layer. Using :math:`\tanh` as
the activation function and disabling bias terms, we implement additive
attention as follows:
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(nn.Module): #@save
"""Additive attention."""
def __init__(self, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.LazyLinear(num_hiddens, bias=False)
self.W_q = nn.LazyLinear(num_hiddens, bias=False)
self.w_v = nn.LazyLinear(1, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = queries.unsqueeze(2) + keys.unsqueeze(1)
features = torch.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return torch.bmm(self.dropout(self.attention_weights), values)
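To see the equivalence with a one-hidden-layer MLP on the concatenated
pair mentioned above, note that
:math:`\mathbf W_q\mathbf q + \mathbf W_k \mathbf k = [\mathbf W_q\; \mathbf W_k]\,[\mathbf q; \mathbf k]`.
A minimal numeric sketch (using PyTorch, with hypothetical dimensions):

.. code:: python

    # The sum of the two projections equals one projection of the
    # concatenated vector [q; k] by the concatenated matrix [W_q  W_k].
    h, q_dim, k_dim = 8, 20, 2
    W_q, W_k = torch.randn(h, q_dim), torch.randn(h, k_dim)
    q, k = torch.randn(q_dim), torch.randn(k_dim)
    lhs = W_q @ q + W_k @ k
    rhs = torch.cat([W_q, W_k], dim=1) @ torch.cat([q, k])
    assert torch.allclose(lhs, rhs, atol=1e-5)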
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(nn.Block): #@save
"""Additive attention."""
def __init__(self, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
# Use flatten=False to only transform the last axis so that the
# shapes for the other axes are kept the same
self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
self.w_v = nn.Dense(1, use_bias=False, flatten=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1,
# no. of key-value pairs, num_hiddens). Sum them up with
# broadcasting
features = np.expand_dims(queries, axis=2) + np.expand_dims(
keys, axis=1)
features = np.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores:
# (batch_size, no. of queries, no. of key-value pairs)
scores = np.squeeze(self.w_v(features), axis=-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return npx.batch_dot(self.dropout(self.attention_weights), values)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(nn.Module): #@save
num_hiddens: int
dropout: float
def setup(self):
self.W_k = nn.Dense(self.num_hiddens, use_bias=False)
self.W_q = nn.Dense(self.num_hiddens, use_bias=False)
self.w_v = nn.Dense(1, use_bias=False)
@nn.compact
def __call__(self, queries, keys, values, valid_lens, training=False):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = jnp.expand_dims(queries, axis=2) + jnp.expand_dims(keys, axis=1)
features = nn.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = self.w_v(features).squeeze(-1)
attention_weights = masked_softmax(scores, valid_lens)
dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return dropout_layer(attention_weights)@values, attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(tf.keras.layers.Layer): #@save
"""Additive attention."""
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super().__init__(**kwargs)
self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=False)
self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=False)
self.w_v = tf.keras.layers.Dense(1, use_bias=False)
self.dropout = tf.keras.layers.Dropout(dropout)
def call(self, queries, keys, values, valid_lens, **kwargs):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = tf.expand_dims(queries, axis=2) + tf.expand_dims(
keys, axis=1)
features = tf.nn.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = tf.squeeze(self.w_v(features), axis=-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return tf.matmul(self.dropout(
self.attention_weights, **kwargs), values)
.. raw:: html
.. raw:: html
Let’s see how ``AdditiveAttention`` works. In our toy example we pick
queries, keys, and values of shape :math:`(2, 1, 20)`, :math:`(2, 10, 2)`,
and :math:`(2, 10, 4)`, respectively. This is identical to our choice
for ``DotProductAttention``, except that now the queries are
:math:`20`-dimensional. Likewise, we keep :math:`2` and :math:`6` as the
valid lengths for the sequences in the minibatch.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = torch.normal(0, 1, (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = np.random.normal(0, 1, (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = jax.random.normal(d2l.get_key(), (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
(output, attention_weights), params = attention.init_with_output(
d2l.get_key(), queries, keys, values, valid_lens)
print(output)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[[[ 0.8057054 -0.45312855 0.233752 0.32691044]]
[[-0.23993565 0.23599407 0.04756263 0.13463953]]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = tf.random.normal(shape=(2, 1, 20))
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
dropout=0.1)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
(2, 1, 4))
.. raw:: html
.. raw:: html
When reviewing the attention weights, we see behavior that is
qualitatively quite similar to that of ``DotProductAttention``: only the
terms within the chosen valid lengths :math:`(2, 6)` are nonzero.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_153_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_156_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_159_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. figure:: output_attention-scoring-functions_722781_162_0.svg
.. raw:: html
.. raw:: html
Summary
-------
In this section we introduced the two key attention scoring functions:
dot product and additive attention. They are effective tools for
aggregating across sequences of variable length. In particular, the dot
product attention is the mainstay of modern Transformer architectures.
When queries and keys are vectors of different lengths, we can use the
additive attention scoring function instead. Optimizing these layers is
one of the key areas of advancement in recent years. For instance,
NVIDIA’s Transformer Library
and Megatron :cite:`shoeybi2019megatron` crucially rely on efficient
variants of the attention mechanism. We will dive into this in quite a
bit more detail as we review Transformers in later sections.
Exercises
---------
1. Implement distance-based attention by modifying the
``DotProductAttention`` code. Note that you only need the squared
norms of the keys :math:`\|\mathbf{k}_i\|^2` for an efficient
implementation.
2. Modify the dot product attention to allow for queries and keys of
different dimensionalities by employing a matrix to adjust
dimensions.
3. How does the computational cost scale with the dimensionality of the
keys, queries, values, and their number? What about the memory
bandwidth requirements?