Attention Mechanism
Python Code
The listing below assumes `import torch` and `import torch.nn as nn`; its line numbers are referenced in the walkthrough that follows.

```python
 1 class MultiHeadAttention(nn.Module):
 2     def __init__(self, d_in, d_out,
 3                  block_size, dropout, num_heads, qkv_bias=False):
 4         super().__init__()
 5         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
 6
 7         self.d_out = d_out
 8         self.num_heads = num_heads
 9         self.head_dim = d_out // num_heads
10         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
11         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
12         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
13         self.out_proj = nn.Linear(d_out, d_out)
14         self.dropout = nn.Dropout(dropout)
15         self.register_buffer(
16             'mask',
17             torch.triu(torch.ones(block_size, block_size), diagonal=1)
18         )
19
20     def forward(self, x):
21         b, num_tokens, d_in = x.shape
22         keys = self.W_key(x)
23         queries = self.W_query(x)
24         values = self.W_value(x)
25
26         keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
27         values = values.view(b, num_tokens, self.num_heads, self.head_dim)
28         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
29
30         keys = keys.transpose(1, 2)
31         queries = queries.transpose(1, 2)
32         values = values.transpose(1, 2)
33
34         attn_scores = queries @ keys.transpose(2, 3)
35         mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
36         mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
37         attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
38
39         attn_weights = torch.softmax(
40             attn_scores / keys.shape[-1]**0.5, dim=-1)
41         attn_weights = self.dropout(attn_weights)
42
43         context_vec = (attn_weights @ values).transpose(1, 2)
44
45         context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
46         context_vec = self.out_proj(context_vec)
47         return context_vec
```
MultiHeadAttention Class Definition
- Line 1, Class Definition:
    - Defines a class `MultiHeadAttention` as a subclass of `nn.Module`.
    - This class implements the multi-head attention mechanism.
- Lines 2-3, Constructor:
    - Initializes the multi-head attention layer with the specified parameters.
    - Parameters include the input dimension (`d_in`), the output dimension (`d_out`), the block size used for masking (`block_size`), the dropout rate (`dropout`), the number of attention heads (`num_heads`), and an optional bias flag for the query, key, and value projections (`qkv_bias`).
- Line 4, Superclass Initialization:
    - Calls the constructor of the superclass `nn.Module` to properly initialize the module.
- Line 5, Dimension Assertion:
    - Ensures that the output dimension (`d_out`) is divisible by the number of heads (`num_heads`).
    - This is necessary to split the output dimension evenly across the heads.
- Lines 7-14, Parameter Definitions:
    - Sets up the internal parameters for the multi-head attention mechanism (a small sketch follows this item):
        - `self.d_out`: output dimension.
        - `self.num_heads`: number of attention heads.
        - `self.head_dim`: dimension of each head, computed as `d_out // num_heads`.
        - `self.W_query`, `self.W_key`, `self.W_value`: linear layers for projecting inputs to the query, key, and value spaces, respectively.
        - `self.out_proj`: linear layer for projecting the concatenated outputs from all heads.
        - `self.dropout`: dropout layer to prevent overfitting.
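To make the bookkeeping concrete, here is a minimal sketch, assuming the example values `d_in = d_out = 512` and `num_heads = 8` (these numbers are illustrative, not part of the listing):

```python
import torch.nn as nn

# Assumed example values: d_in = d_out = 512, num_heads = 8.
d_in, d_out, num_heads = 512, 512, 8
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
head_dim = d_out // num_heads        # 512 // 8 = 64 dimensions per head

# The three projection layers and the output projection, as in lines 10-13.
W_query = nn.Linear(d_in, d_out, bias=False)
W_key = nn.Linear(d_in, d_out, bias=False)
W_value = nn.Linear(d_in, d_out, bias=False)
out_proj = nn.Linear(d_out, d_out)

print(head_dim)                      # 64
print(W_query.weight.shape)          # torch.Size([512, 512])
```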
- Lines 15-18, Mask Buffer Registration:
    - Registers a buffer `mask` holding the causal attention mask, which is later used to ignore future tokens by setting their attention scores to `-inf`.
    - The mask is created with `torch.triu(..., diagonal=1)`: an upper triangular matrix with ones strictly above the diagonal and zeros on and below it (see the sketch below).
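A toy example of the registered mask, assuming `block_size = 4` (a value chosen only for illustration; real models use the full context length):

```python
import torch

block_size = 4  # assumed toy value
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
print(mask)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])
# Row i has ones in the columns of the tokens that come *after* token i.
```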
Forward Method
- Line 20, Forward Method Definition:
    - Defines the forward pass of the multi-head attention layer.
    - Takes an input tensor `x`; for the shape walkthrough below, assume a shape of `(30, 50, 512)` for batch size, sequence length, and embedding dimension, respectively (see the usage sketch below).
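A hedged usage sketch under these shape assumptions, assuming the `MultiHeadAttention` class from the listing above has been defined (together with `import torch` and `import torch.nn as nn`); the hyperparameter values `block_size=1024`, `dropout=0.1`, and `num_heads=8` are illustrative, not taken from the text:

```python
import torch

torch.manual_seed(0)
mha = MultiHeadAttention(d_in=512, d_out=512, block_size=1024,
                         dropout=0.1, num_heads=8)

x = torch.randn(30, 50, 512)   # (batch size, sequence length, embedding dim)
out = mha(x)
print(out.shape)               # torch.Size([30, 50, 512])
```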
- Line 21, Input Shape Extraction:
    - Extracts the batch size (`b`), number of tokens (`num_tokens`), and input dimension (`d_in`) from the input tensor `x`.
- Lines 22-24, Query/Key/Value Projections:
    - Applies linear transformations to the input tensor to obtain the queries, keys, and values.
    - Each resulting tensor keeps the shape `(30, 50, 512)`, as in the snippet below.
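For instance, a single projection in isolation under the assumed shapes (a sketch, not the module itself):

```python
import torch
import torch.nn as nn

x = torch.randn(30, 50, 512)              # assumed example input
W_query = nn.Linear(512, 512, bias=False)
queries = W_query(x)                      # Linear acts on the last dimension
print(queries.shape)                      # torch.Size([30, 50, 512])
```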
- Lines 26-28, Reshaping for Multi-Head Attention:
    - Reshapes the query, key, and value tensors to prepare for multi-head attention.
    - The new shape of each tensor is `(30, 50, num_heads, head_dim)`; the shape bookkeeping is sketched after the transposition item below.
- Lines 30-32, Transposition for Attention Calculation:
    - Transposes the reshaped tensors to move the `num_heads` dimension before the `num_tokens` dimension.
    - The new shape of each tensor is `(30, num_heads, 50, head_dim)`; the sketch below traces both the reshape and this transpose.
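Shape bookkeeping for the reshape and transpose steps, assuming the example values `b=30`, `num_tokens=50`, `num_heads=8`, `head_dim=64`:

```python
import torch

b, num_tokens, num_heads, head_dim = 30, 50, 8, 64
keys = torch.randn(b, num_tokens, num_heads * head_dim)   # stand-in for W_key(x)

keys = keys.view(b, num_tokens, num_heads, head_dim)      # split the last dimension
print(keys.shape)            # torch.Size([30, 50, 8, 64])

keys = keys.transpose(1, 2)  # move the head dimension before the token dimension
print(keys.shape)            # torch.Size([30, 8, 50, 64])
```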
- Line 34, Attention Score Calculation:
    - Computes the attention scores by performing a batched matrix multiplication of the queries and the transposed keys.
    - The resulting tensor has shape `(30, num_heads, 50, 50)`, one score per query/key pair and per head, as in the sketch below.
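The score computation in isolation, again under the assumed shapes:

```python
import torch

queries = torch.randn(30, 8, 50, 64)   # (b, num_heads, num_tokens, head_dim)
keys = torch.randn(30, 8, 50, 64)

attn_scores = queries @ keys.transpose(2, 3)   # dot products over head_dim
print(attn_scores.shape)               # torch.Size([30, 8, 50, 50])
```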
- Lines 35-37, Attention Mask Application:
    - Applies the causal attention mask to the attention scores, setting the scores of future tokens to `-inf` (see the toy example below).
    - This ensures that the model cannot attend to future tokens, preserving the autoregressive property.
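A toy example of the masking step for a single batch element, a single head, and 4 tokens (values chosen only for illustration):

```python
import torch

num_tokens = 4
attn_scores = torch.zeros(1, 1, num_tokens, num_tokens)   # (b, heads, tokens, tokens)
mask_bool = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

attn_scores.masked_fill_(mask_bool, -torch.inf)           # mask broadcasts over b and heads
print(attn_scores[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```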
- Lines 39-41, Softmax and Dropout on Attention Scores:
    - Scales the scores by `1 / sqrt(head_dim)` and applies the softmax function, normalizing them into attention weights (probabilities).
    - Applies dropout to the normalized attention weights for regularization; a small sketch follows.
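A small sketch of the scaling, softmax, and dropout steps on masked scores for a single 3-token query block (the numbers are made up for illustration):

```python
import torch

head_dim = 64
attn_scores = torch.tensor([[0.3, -torch.inf, -torch.inf],
                            [0.1,  0.4,       -torch.inf],
                            [0.2,  0.5,        0.1]])

attn_weights = torch.softmax(attn_scores / head_dim**0.5, dim=-1)
print(attn_weights.sum(dim=-1))   # each row sums to 1; masked positions get weight 0

dropout = torch.nn.Dropout(0.1)   # randomly zeroes some weights (training mode only)
attn_weights = dropout(attn_weights)
```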
- Line 43, Context Vector Calculation:
    - Computes the context vectors as a weighted sum of the values, using the attention weights; the result has shape `(30, num_heads, 50, head_dim)`.
    - The context vectors are then transposed back to the original token-first ordering, giving shape `(30, 50, num_heads, head_dim)` (see the sketch below).
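The weighted sum and the transpose back to token-first ordering, under the assumed shapes:

```python
import torch

attn_weights = torch.softmax(torch.randn(30, 8, 50, 50), dim=-1)
values = torch.randn(30, 8, 50, 64)

context_vec = attn_weights @ values        # weighted sum of value vectors
print(context_vec.shape)                   # torch.Size([30, 8, 50, 64])

context_vec = context_vec.transpose(1, 2)  # back to (b, num_tokens, num_heads, head_dim)
print(context_vec.shape)                   # torch.Size([30, 50, 8, 64])
```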
- Line 45, Reshaping Context Vectors:
    - Reshapes the context vectors to concatenate the outputs of all heads.
    - The resulting tensor has shape `(30, 50, 512)`; a short sketch follows.
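Flattening the heads back into a single 512-dimensional vector per token (8 heads × 64 dims under the assumed values); `.contiguous()` is needed because the preceding transpose leaves the tensor non-contiguous:

```python
import torch

context_vec = torch.randn(30, 8, 50, 64).transpose(1, 2)   # non-contiguous after transpose
context_vec = context_vec.contiguous().view(30, 50, 8 * 64)
print(context_vec.shape)    # torch.Size([30, 50, 512])
```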
- Lines 46-47, Output Projection and Return:
    - Applies the final linear transformation (`out_proj`) to the concatenated context vectors, as illustrated below.
    - The output tensor keeps the shape `(30, 50, 512)` and is returned as the final result of the multi-head attention mechanism.
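Finally, the output projection viewed in isolation (a sketch with the assumed dimensions):

```python
import torch
import torch.nn as nn

context_vec = torch.randn(30, 50, 512)
out_proj = nn.Linear(512, 512)       # mixes information across the concatenated heads
print(out_proj(context_vec).shape)   # torch.Size([30, 50, 512])
```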