"Attention is All You Need" — Build Scaled Dot-Product Attention from Scratch

A Complete One-Module Learning Tutorial with Graphs, Hashing, and Code


Module Objective

Implement and understand the Scaled Dot-Product Attention mechanism from the seminal paper "Attention is All You Need" (Vaswani et al., 2017) — with visualization, intuition, and efficiency tricks (hashing for large inputs).


1. Core Idea: Why Attention?

"Let every token talk to every other token — weighted by relevance."

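RNNs pass information along the sequence one step at a time, and CNNs only mix nearby tokens within each layer. Attention instead computes a direct, weighted dependency between every pair of input tokens in a single step.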


2. Scaled Dot-Product Attention — The Formula

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

Where:

  • $ Q \in \mathbb{R}^{n \times d_k} $: Queries
  • $ K \in \mathbb{R}^{m \times d_k} $: Keys
  • $ V \in \mathbb{R}^{m \times d_v} $: Values
  • $ d_k $: dimension of keys/queries
  • $ n $: number of queries (e.g., output sequence length)
  • $ m $: number of keys/values (e.g., input sequence length)
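To make the formula concrete, here is a minimal sketch in plain Python (no PyTorch, so every arithmetic step is visible) for a single query. The helper names `softmax` and `attention_single_query` and the toy numbers are assumptions chosen for illustration, not part of the original tutorial.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_single_query(q, K, V):
    """Attention output for one query vector q against keys K and values V."""
    d_k = len(q)
    # Dot product of q with every key, scaled by sqrt(d_k)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
    weights = softmax(scores)
    # Weighted sum of the value vectors
    d_v = len(V[0])
    output = [sum(w * v[j] for w, v in zip(weights, V)) for j in range(d_v)]
    return output, weights

q = [1.0, 0.0]                  # one query, d_k = 2
K = [[1.0, 0.0], [0.0, 1.0]]    # m = 2 keys
V = [[1.0, 2.0], [3.0, 4.0]]    # m = 2 values, d_v = 2

output, weights = attention_single_query(q, K, V)
print(weights)  # roughly [0.67, 0.33]: the first key matches the query better
print(output)   # roughly [1.66, 2.66]: pulled toward the first value row
```

The query points in the same direction as the first key, so the first value row receives about two thirds of the weight.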

3. Step-by-Step Breakdown

| Step | Operation | Shape / effect |
|------|-----------|----------------|
| 1 | $ QK^T $ | $ (n, d_k) \times (d_k, m) \to (n, m) $ |
| 2 | Scale: $ \div \sqrt{d_k} $ | Stabilizes gradients |
| 3 | Softmax over last dim | $ \to $ attention weights |
| 4 | Multiply by $ V $ | $ (n, m) \times (m, d_v) \to (n, d_v) $ |
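Why divide by $ \sqrt{d_k} $? The paper's argument is that if the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance $ d_k $, so large $ d_k $ pushes the softmax into regions with very small gradients. The sketch below checks this empirically with the standard library only; the sample size, the seed, and the helper name `dot_product_samples` are assumptions for illustration.

```python
import math
import random
import statistics

random.seed(0)

def dot_product_samples(d_k, n_samples=2000):
    """Dot products of random vectors with independent N(0, 1) components."""
    samples = []
    for _ in range(n_samples):
        q = [random.gauss(0, 1) for _ in range(d_k)]
        k = [random.gauss(0, 1) for _ in range(d_k)]
        samples.append(sum(a * b for a, b in zip(q, k)))
    return samples

d_k = 64
raw = dot_product_samples(d_k)
scaled = [s / math.sqrt(d_k) for s in raw]

raw_var = statistics.pvariance(raw)        # expected to be close to d_k
scaled_var = statistics.pvariance(scaled)  # expected to be close to 1
print(f"unscaled variance ~ {raw_var:.1f}, scaled variance ~ {scaled_var:.2f}")
```

After scaling, the scores keep roughly unit variance regardless of $ d_k $, which keeps the softmax inputs in a range where gradients are usable.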

4. PyTorch Implementation (From Scratch)

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, n, d_k)
    K: (batch, m, d_k)
    V: (batch, m, d_v)
    """
    d_k = Q.size(-1)
    
    # Step 1: QK^T
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, n, m)
    
    # Step 2: Scale
    scores = scores / (d_k ** 0.5)
    
    # Step 3: Optional Mask (for decoder)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 4: Softmax
    attn_weights = F.softmax(scores, dim=-1)
    
    # Step 5: Weighted sum of values
    output = torch.matmul(attn_weights, V)  # (batch, n, d_v)
    
    return output, attn_weights

5. Test with Dummy Data

batch_size = 1
seq_len = 4
d_k = d_v = 8

# Simulate learned projections
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)

output, attn = scaled_dot_product_attention(Q, K, V)

print("Output shape:", output.shape)        # (1, 4, 8)
print("Attention weights shape:", attn.shape)  # (1, 4, 4)

6. Visualize Attention Weights

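def plot_attention(attn_weights, title="Attention Weights"):
    # Plot the first batch element; infer tick labels from the matrix shape
    # instead of relying on a global seq_len
    weights = attn_weights[0].detach().cpu().numpy()  # (n, m)
    n_queries, n_keys = weights.shape
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        weights,
        cmap="Blues",
        annot=True,
        fmt=".2f",
        xticklabels=[f"Key {i}" for i in range(n_keys)],
        yticklabels=[f"Query {i}" for i in range(n_queries)]
    )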
    plt.title(title)
    plt.xlabel("Keys")
    plt.ylabel("Queries")
    plt.show()

plot_attention(attn, "Random Attention (Before Training)")

After training, attention weights typically become sharper and more interpretable: for example, the query for "it" may place most of its weight on the key for "cat".
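One way to put a number on how "sharp" a row of attention weights is, offered here as an assumption rather than something the tutorial prescribes, is its Shannon entropy: a uniform row has the maximum entropy $ \log m $, while a one-hot row has entropy 0. The helper name `attention_entropy` and the example rows are hypothetical.

```python
import math

def attention_entropy(weights):
    """Shannon entropy (in nats) of one row of attention weights."""
    # Zero weights contribute nothing (the limit of w * log(w) as w -> 0 is 0).
    return -sum(w * math.log(w) for w in weights if w > 0)

uniform = [0.25, 0.25, 0.25, 0.25]   # diffuse: every key gets equal weight
peaked = [0.97, 0.01, 0.01, 0.01]    # sharp: one key dominates

print(attention_entropy(uniform))    # equals log(4), the maximum for 4 keys
print(attention_entropy(peaked))     # much smaller
```

Applied to each row of `attn`, a measure like this would let you compare the random-weight heatmap with one from a trained model.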


7. Add Causal Mask (Decoder-Only)

def create_causal_mask(seq_len):
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0  # True where allowed

mask = create_causal_mask(seq_len)
mask = mask.unsqueeze(0)  # (1, seq_len, seq_len)

output_masked, attn_masked = scaled_dot_product_attention(Q, K, V, mask=mask)
plot_attention(attn_masked, "Causal (Autoregressive) Attention")

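The mask stops query position i from attending to any key position j > i, so each token sees only itself and earlier tokens. Autoregressive language generation depends on this.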
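The effect of the mask on the softmax can be sketched in plain Python. Setting a score to negative infinity makes its exponential exactly zero, so the remaining weights renormalize over the allowed positions. The helper name `masked_softmax` and the all-zero scores are assumptions chosen so the result is easy to read.

```python
import math

def masked_softmax(row, allowed):
    """Softmax over `row`, giving zero weight where `allowed` is False."""
    # Masked positions get -inf, so exp(-inf) = 0 after the softmax.
    masked = [s if ok else float("-inf") for s, ok in zip(row, allowed)]
    m = max(masked)  # finite as long as at least one position is allowed
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

seq_len = 4
# Causal mask: query i may attend to key j only when j <= i.
causal = [[j <= i for j in range(seq_len)] for i in range(seq_len)]

# Hypothetical scores: all equal, so each row spreads its weight
# uniformly over the positions it is allowed to see.
scores = [[0.0] * seq_len for _ in range(seq_len)]
weights = [masked_softmax(r, a) for r, a in zip(scores, causal)]

for row in weights:
    print([round(w, 2) for w in row])
```

Every row still sums to 1 and the upper triangle is exactly zero. A row in which every position is masked has no valid distribution; in PyTorch, a softmax over a row of all `-inf` yields NaN, so masks should leave at least one allowed key per query.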


8. Efficiency Problem: $ O(n^2) $ Memory & Time

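| Sequence length | Score matrix memory (float32, one head, one sequence) | Cost relative to 512 |
|-----------------|--------------------------------------------------------|----------------------|
| 512 | ~1 MB | 1× |
| 4096 | ~67 MB | 64× |
| 32768 | ~4.3 GB | 4096× |

These figures are $ n^2 \times 4 $ bytes for a single score matrix; the total grows further with batch size and number of heads.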
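The quadratic growth is easy to check by hand. The sketch below computes the size of a single $ n \times n $ float32 score matrix (one head, one sequence); a whole model multiplies this by batch size and number of heads, which is an assumption you would adjust to your own setup. The function name `score_matrix_bytes` is hypothetical.

```python
def score_matrix_bytes(seq_len, bytes_per_element=4):
    """Bytes needed to hold one (seq_len x seq_len) attention score matrix."""
    return seq_len * seq_len * bytes_per_element

for n in (512, 4096, 32768):
    mb = score_matrix_bytes(n) / 1e6
    print(f"n = {n:>6}: {mb:>10.1f} MB per head, per sequence")
```

Going from 512 to 32768 tokens multiplies the sequence length by 64 but the memory by 64² = 4096.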

9. Optimization: Hashing + Sparse Attention

Idea: Locality-Sensitive Hashing (LSH) for Attention

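Each query attends only to keys that hash into the same bucket (that is, keys likely to be similar to it), which reduces the cost from $ O(n^2) $ to roughly $ O(n \log n) $.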
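The hashing step can be sketched on its own. The version below follows the angular LSH scheme described in the Reformer paper (project with a random matrix R, then take the argmax over the concatenation of xR and -xR), written in plain Python for readability. The helper names `make_rotation` and `lsh_bucket`, the dimensions, and the seed are assumptions for illustration.

```python
import random

def make_rotation(d_model, n_buckets, seed=0):
    """Random Gaussian projection with n_buckets // 2 columns (hypothetical helper)."""
    rng = random.Random(seed)
    return [[rng.gauss(0, 1) for _ in range(n_buckets // 2)] for _ in range(d_model)]

def lsh_bucket(vec, rotation):
    """Angular LSH: project, concatenate [xR, -xR], and take the argmax."""
    n_cols = len(rotation[0])
    proj = [sum(vec[i] * rotation[i][c] for i in range(len(vec))) for c in range(n_cols)]
    both = proj + [-p for p in proj]
    return max(range(len(both)), key=lambda idx: both[idx])

rotation = make_rotation(d_model=4, n_buckets=8)
v = [0.9, 0.1, -0.3, 0.5]
same_direction = [2 * x for x in v]   # same angle, different length
opposite = [-x for x in v]            # opposite direction

print(lsh_bucket(v, rotation), lsh_bucket(same_direction, rotation), lsh_bucket(opposite, rotation))
```

The bucket depends only on a vector's direction, so `v` and `2 * v` always collide, while `-v` lands in a different bucket. Vectors at a small angle to each other collide with high probability, which is what lets attention be restricted to one bucket at a time.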

LSH Attention (Reformer-style)

import torch.nn as nn

class LSHAttention(nn.Module):
    """Simplified, Reformer-style LSH attention (illustrative, not the full algorithm).

    Assumes self-attention: Q, K, V share the same sequence length, and tokens
    are bucketed by their keys (Reformer itself ties Q and K).
    """
    def __init__(self, d_model, n_hashes=4, bucket_size=64):
        super().__init__()
        self.n_hashes = n_hashes
        self.bucket_size = bucket_size
        self.d_model = d_model

    def hash_vectors(self, vectors):
        # Random projection, then argmax as the bucket id.
        # A fresh projection is drawn on every call, so each round hashes differently.
        rotation_matrix = torch.randn(self.d_model, self.d_model, device=vectors.device)
        rotated_vecs = vectors @ rotation_matrix
        buckets = torch.argmax(rotated_vecs, dim=-1)
        return buckets  # (batch, seq_len)

    def forward(self, Q, K, V):
        batch_size, seq_len, d = Q.shape
        d_v = V.size(-1)

        # Multi-round LSH
        all_outputs = []
        for _ in range(self.n_hashes):
            buckets = self.hash_vectors(K)            # (batch, seq_len)
            _, indices = torch.sort(buckets, dim=-1)  # tokens grouped by bucket
            undo = torch.argsort(indices, dim=-1)     # inverse permutation

            # Chunk the sorted token indices into fixed-size groups
            chunks = torch.split(indices, self.bucket_size, dim=1)

            # Approximate attention within each chunk
            chunk_outs = []
            for chunk in chunks:
                idx = chunk.unsqueeze(-1)
                Q_chunk = Q.gather(1, idx.expand(-1, -1, d))
                K_chunk = K.gather(1, idx.expand(-1, -1, d))
                V_chunk = V.gather(1, idx.expand(-1, -1, d_v))

                out, _ = scaled_dot_product_attention(Q_chunk, K_chunk, V_chunk)
                chunk_outs.append(out)

            output = torch.cat(chunk_outs, dim=1)     # still in sorted order
            # Restore the original token order before averaging across rounds
            output = output.gather(1, undo.unsqueeze(-1).expand(-1, -1, d_v))
            all_outputs.append(output)

        final_output = torch.stack(all_outputs, dim=1).mean(dim=1)
        return final_output, None  # per-chunk weights are not returned

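Memory: $ O(n \cdot b) $ where $ b $ = bucket size
Used in: Reformer. Longformer and BigBird also avoid the full $ O(n^2) $ matrix, but with fixed sparse patterns (sliding-window, global, and, in BigBird, random attention) rather than hashing.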


10. Full Multi-Head Attention (Core of a Transformer Block)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def split_heads(self, x):
        batch, seq, _ = x.shape
        return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        batch, _, seq, d_k = x.shape
        return x.transpose(1, 2).contiguous().view(batch, seq, self.d_model)
    
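    def forward(self, Q, K, V, mask=None):
        # Project, then reshape to (batch, num_heads, seq, d_k)
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # A (batch, n, m) mask needs a head axis to broadcast over all heads
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        output = self.combine_heads(attn_output)
        return self.W_o(output), attn_weights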
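The reshapes in `split_heads` and `combine_heads` can look opaque, but for a single token they amount to slicing its `d_model`-long vector into `num_heads` contiguous pieces of size `d_k` and later concatenating them back. The plain-Python sketch below mirrors that for one token; these standalone functions are illustrative stand-ins, not the class methods themselves.

```python
def split_heads(token_vec, num_heads):
    """Slice one d_model-long vector into num_heads contiguous pieces of size d_k."""
    d_model = len(token_vec)
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    return [token_vec[h * d_k:(h + 1) * d_k] for h in range(num_heads)]

def combine_heads(heads):
    """Concatenate the per-head pieces back into one d_model-long vector."""
    return [x for head in heads for x in head]

token = list(range(16))            # d_model = 16
heads = split_heads(token, 4)      # 4 heads, d_k = 4
print(heads)
print(combine_heads(heads) == token)
```

Each head therefore runs attention on its own 4-dimensional slice, and the round trip through split and combine recovers the original layout before the output projection `W_o`.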

11. Full Example: Train Tiny Model

# Tiny dataset: learn to reproduce the input embeddings
X = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]], dtype=torch.long)

# Embed once, outside the loop. The embedding is frozen (detached) so the
# reconstruction target stays fixed instead of being re-randomized each epoch.
emb = nn.Embedding(10, 16)
x = emb(X).detach()  # (2, 4, 16)

model = MultiHeadAttention(d_model=16, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    optimizer.zero_grad()

    output, _ = model(x, x, x)
    loss = F.mse_loss(output, x)  # reconstruct the input embeddings

    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

12. Summary Cheat Sheet

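| Component | Formula | Purpose |
|-----------|---------|---------|
| Dot Product | $ QK^T $ | Similarity |
| Scaling | $ \div \sqrt{d_k} $ | Gradient stability |
| Softmax | $ \text{softmax}(\cdot) $ | Attention weights |

| Model | Attention Type | Efficiency |
|-------|----------------|------------|
| Transformer | Full | $ O(n^2) $ (baseline) |
| Reformer | LSH | $ O(n \log n) $ |
| Longformer | Sliding Window | $ O(n) $ |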

13. Graph: Attention Scaling

import numpy as np

seq_lens = [128, 512, 2048, 8192, 32768]
full_mem = np.array(seq_lens)**2 * 4 / 1e9  # GB (float32)
lsh_mem = np.array(seq_lens) * 64 * 4 / 1e9  # bucket_size=64

plt.figure(figsize=(8, 5))
plt.plot(seq_lens, full_mem, 'r-o', label="Full Attention (O(n²))")
# lsh_mem above grows as n * bucket_size, so label it with that bound
plt.plot(seq_lens, lsh_mem, 'g--s', label="LSH Attention (O(n·b), b = 64)")
plt.yscale('log')
plt.xlabel("Sequence Length")
plt.ylabel("Memory (GB)")
plt.title("Attention Memory Scaling")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Practice Exercises

  1. Implement masked multi-head attention for a decoder.
  2. Replace softmax with sparsemax or entmax.
  3. Add relative position encodings.
  4. Use Performer's FAVOR+ (linear attention).
  5. Visualize attention on real sentences using Hugging Face.

Key Takeaways

  • Attention = weighted sum of values, guided by query-key similarity
  • Scaling by $ \sqrt{d_k} $ keeps the softmax out of its saturated, small-gradient regime
  • Causal mask → autoregressive generation
  • Hashing (LSH) → long sequences
  • Multi-head → multiple perspectives

Final Words

"Attention is All You Need" — not just a paper, but a paradigm shift.

You now have:

  • Full mathematical understanding
  • Working PyTorch code
  • Visualization tools
  • Efficiency tricks (hashing)

Next Steps:
Build a mini-Transformer from scratch → train on text → generate poetry!


End of Module
You just built the heart of GPT, BERT, and every modern LLM.
Attention is yours.