
Beam Search & Sampling

Complete Module: Priority Queues, Heaps, Top-k, Nucleus Sampling


Module Objective

Master advanced text generation: Beam Search, Top-k, and Nucleus (Top-p) sampling, built on priority queues (heaps), with full PyTorch implementations and noticeably more coherent output than greedy decoding.


1. Why Not Greedy Decoding?

# Greedy: always pick argmax
next_token = logits[:, -1].argmax()

Problem:

  • Gets stuck in local optima
  • Misses high-probability sequences
  • Repetitive: "I I I I I..."
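For reference, here is a minimal greedy loop. This is a sketch: greedy_decode is a hypothetical name, and it assumes the model(x, past_kv=...) interface used throughout this module.

@torch.no_grad()
def greedy_decode(model, idx, max_len=50, eos_token=0):
    cache = None
    for _ in range(max_len):
        # Feed the full prompt once, then only the newest token
        logits, _, cache = model(idx if cache is None else idx[:, -1:], past_kv=cache)
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=1)
        if next_token.item() == eos_token:
            break
    return idx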

2. Beam Search: Keep Top-k Paths

"Explore multiple futures, keep the best"

Step 1:
  "The" → ["cat", "dog", "man"]
  
Step 2:
  "The cat" → ["sat", "is", "jumped"]
  "The dog" → ["barked", "ran", "is"]
  ...
→ Keep top-3 sequences by log-prob
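To make the scoring concrete, here is a toy sketch of one expansion step. All probabilities are made up for illustration; none come from a real model.

import math

# Hypothetical P(next | prefix) for the two surviving prefixes
step2 = {
    ("The", "cat"): {"sat": 0.5, "is": 0.3, "jumped": 0.2},
    ("The", "dog"): {"barked": 0.6, "ran": 0.3, "is": 0.1},
}

candidates = []
for prefix, nexts in step2.items():
    prefix_logp = math.log(0.4)  # assumed P(prefix), same for both here
    for word, p in nexts.items():
        candidates.append((prefix_logp + math.log(p), prefix + (word,)))

# Keep the top-3 sequences by cumulative log-prob
for logp, seq in sorted(candidates, reverse=True)[:3]:
    print(f"{' '.join(seq)}: {logp:.3f}")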

3. Priority Queue (Heap) = Core of Beam Search

import heapq

# (log_prob, sequence)
beam = [(-0.1, [1]), (-0.2, [2]), (-0.3, [3])]
heapq.heapify(beam)

Heap operations:

  • heappush: O(log k)
  • heappop: O(log k); note that heapq is a min-heap, so it pops the lowest log-prob first
  • With beam width k = 5, both are effectively constant-time
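Because heapq is a min-heap, a fixed-size beam can evict its weakest path in O(log k) with heappushpop. A minimal sketch follows; keep_best_k is a hypothetical helper, not part of the module.

import heapq

def keep_best_k(candidates, k):
    # Maintain the k highest-scoring (log_prob, seq) pairs in a min-heap
    heap = []
    for cand in candidates:
        if len(heap) < k:
            heapq.heappush(heap, cand)     # O(log k)
        else:
            heapq.heappushpop(heap, cand)  # push, then evict the worst: O(log k)
    return sorted(heap, reverse=True)      # best first

print(keep_best_k([(-0.1, [1]), (-0.4, [2]), (-0.2, [3]), (-0.3, [4])], k=2))
# [(-0.1, [1]), (-0.2, [3])]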

4. Beam Search Implementation

import heapq
import torch
import torch.nn.functional as F

@torch.no_grad()
def beam_search(model, idx, beam_width=5, max_len=50, eos_token=0):
    model.eval()

    # Initial beam: (log_prob, sequence, cache); idx has shape (1, T)
    beam = [(0.0, idx[0].tolist(), None)]

    for _ in range(max_len):
        all_candidates = []

        for log_prob, seq, cache in beam:
            # Finished paths pass through unchanged instead of being re-expanded
            if seq[-1] == eos_token:
                all_candidates.append((log_prob, seq, cache))
                continue

            # With a KV cache, only the newest token is fed to the model
            if cache is not None:
                input_tensor = torch.tensor([[seq[-1]]], dtype=torch.long)
            else:
                input_tensor = torch.tensor([seq], dtype=torch.long)

            logits, _, new_cache = model(input_tensor, past_kv=cache)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

            # Expand this path with its top beam_width continuations
            top_k = torch.topk(log_probs, beam_width)

            for i in range(beam_width):
                next_token = top_k.indices[i].item()
                new_log_prob = log_prob + top_k.values[i].item()
                new_seq = seq + [next_token]
                # The cache covers the shared prefix, so every child reuses it
                all_candidates.append((new_log_prob, new_seq, new_cache))

        # Keep the top beam_width candidates by cumulative log-prob
        beam = heapq.nlargest(beam_width, all_candidates, key=lambda x: x[0])

        # Early stop once every surviving beam has emitted EOS
        if all(seq[-1] == eos_token for _, seq, _ in beam):
            break

    # Return the best sequence
    best_seq = max(beam, key=lambda x: x[0])[1]
    return torch.tensor(best_seq)

5. Top-k Sampling

def top_k_sampling(logits, k=50, temperature=1.0):
    # logits: (batch, vocab)
    logits = logits / temperature
    top_k = torch.topk(logits, k)                       # values/indices: (batch, k)
    probs = F.softmax(top_k.values, dim=-1)             # renormalize over the top k
    next_token = torch.multinomial(probs, 1)            # position within the top-k set
    return torch.gather(top_k.indices, -1, next_token)  # map back to vocab ids
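A quick sanity check on random logits (shapes only; the numbers are random):

logits = torch.randn(1, 100)                         # batch of 1, vocab of 100
tok = top_k_sampling(logits, k=10, temperature=0.8)
print(tok.shape)  # torch.Size([1, 1]): a vocab id drawn from the top 10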

6. Nucleus (Top-p) Sampling

"Sample from smallest set whose cumulative prob > p"

def nucleus_sampling(logits, p=0.9, temperature=1.0):
    # logits: (batch, vocab)
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    # Mask everything outside the smallest set with cumulative prob > p.
    # The shift keeps the token that first crosses the threshold.
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    mask = cum_probs > p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False

    filtered_probs = sorted_probs.masked_fill(mask, 0.0)
    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)

    next_token = torch.multinomial(filtered_probs, 1)    # position in sorted order
    return torch.gather(sorted_indices, -1, next_token)  # map back to vocab ids
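The point of the nucleus is that its size adapts to the distribution. A small demo with made-up logits: a peaked distribution keeps almost nothing, a flat one keeps almost everything.

peaked = torch.tensor([[8.0, 1.0, 0.5, 0.2, 0.1]])
flat = torch.tensor([[1.0, 0.9, 0.8, 0.7, 0.6]])
for name, lg in [("peaked", peaked), ("flat", flat)]:
    probs = F.softmax(lg, dim=-1)
    sp, _ = torch.sort(probs, descending=True)
    n = int((torch.cumsum(sp, -1) < 0.9).sum()) + 1  # tokens kept at p=0.9
    print(name, "nucleus size:", n)  # peaked → 1, flat → 5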

7. Full Generation with All Methods

@torch.no_grad()
def generate(model, prompt, method="beam", **kwargs):
    idx = torch.tensor(encode(prompt), dtype=torch.long).unsqueeze(0)
    cache = None

    if method == "beam":
        # Beam search manages its own loop and cache
        return beam_search(model, idx,
                           beam_width=kwargs.get("beam_width", 5),
                           max_len=kwargs.get("max_len", 50))

    # Pull out sampler options explicitly so stray kwargs (e.g. max_len)
    # never reach the sampling functions
    temperature = kwargs.get("temperature", 1.0)
    for _ in range(kwargs.get("max_len", 100)):
        logits, _, cache = model(idx if cache is None else idx[:, -1:], past_kv=cache)
        logits = logits[:, -1, :]

        if method == "greedy":
            next_token = logits.argmax(-1, keepdim=True)
        elif method == "topk":
            next_token = top_k_sampling(logits, k=kwargs.get("k", 50), temperature=temperature)
        elif method == "nucleus":
            next_token = nucleus_sampling(logits, p=kwargs.get("p", 0.9), temperature=temperature)

        idx = torch.cat([idx, next_token], dim=1)
        if next_token.item() == 0:  # EOS
            break

    return idx

8. Comparison: All Methods on TinyShakespeare

prompt = "ROMEO:"

print("Greedy:")
print(decode(generate(model, prompt, method="greedy", max_len=100)[0].tolist()))

print("\nTop-k (k=40):")
print(decode(generate(model, prompt, method="topk", k=40, temperature=0.8, max_len=100)[0].tolist()))

print("\nNucleus (p=0.9):")
print(decode(generate(model, prompt, method="nucleus", p=0.9, temperature=1.0, max_len=100)[0].tolist()))

print("\nBeam Search (width=5):")
beam_out = beam_search(model, torch.tensor(encode(prompt)).unsqueeze(0), beam_width=5)
print(decode(beam_out.tolist()))

Results:

  • Greedy: repetitive
  • Top-k: diverse, sometimes incoherent
  • Nucleus: best balance
  • Beam: most fluent, but deterministic

9. Priority Queue (Heap) in Action

import heapq

# Simulate beam
beam = []
heapq.heappush(beam, (-0.1, [1, 2]))  # log_prob, seq
heapq.heappush(beam, (-0.3, [1, 3]))
heapq.heappush(beam, (-0.2, [1, 4]))

print(heapq.heappop(beam))  # (-0.3, [1, 3]) → lowest log-prob pops first

heapq is a min-heap, so heappop removes the weakest path. That is exactly the eviction a size-bounded beam needs: pop the worst, keep the best k.

10. Beam Search with Length Normalization

# Counteract the bias toward short sequences: log-probs are negative,
# so without normalization longer sequences always score worse
score = log_prob / (len(seq) ** 0.6)
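One way to wire this into the selection step of beam_search above. length_normalized is a hypothetical helper, and alpha = 0.6 matches the exponent shown above:

# Hypothetical drop-in for the nlargest line in beam_search:
# rank candidates by length-normalized score instead of raw log-prob
def length_normalized(candidate, alpha=0.6):
    log_prob, seq, _ = candidate
    return log_prob / (len(seq) ** alpha)

beam = heapq.nlargest(beam_width, all_candidates, key=length_normalized)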

11. Summary Table

Method  | Diversity | Coherence | Speed   | Use Case
Greedy  | Low       | Medium    | Fastest | Baseline
Top-k   | Medium    | Medium    | Fast    | General
Nucleus | High      | High      | Fast    | Best for creativity
Beam    | Low       | Highest   | Slow    | Best for accuracy

12. Practice Exercises

  1. Add length penalty to beam search
  2. Implement diverse beam search
  3. Combine top-k + nucleus (a starting sketch follows this list)
  4. Measure perplexity of outputs
  5. Visualize probability mass
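For exercise 3, one possible starting point. The order of the two cuts is a design choice; this sketch (top_k_then_nucleus is hypothetical) applies top-k first, then the nucleus cut inside the surviving set:

def top_k_then_nucleus(logits, k=50, p=0.9, t=1.0):
    # torch.topk returns values sorted descending, so the nucleus
    # cut applies directly to the renormalized top-k probabilities
    v, i = torch.topk(logits / t, k)  # logits: (batch, vocab)
    probs = F.softmax(v, dim=-1)
    mask = torch.cumsum(probs, dim=-1) > p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False
    probs = probs.masked_fill(mask, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    j = torch.multinomial(probs, 1)  # position within the top-k set
    return torch.gather(i, -1, j)    # map back to vocab ids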

13. Key Takeaways

  • Beam Search = BFS with a heap
  • Top-k = truncate the tail
  • Nucleus = dynamic truncation
  • Nucleus > Top-k in practice
  • Used in GPT, Claude, Gemini

Full Copy-Paste: All Decoding Methods

import torch
import torch.nn.functional as F
import heapq

# === Top-k (assumes 1-D logits) ===
def top_k(logits, k=50, t=1.0):
    logits = logits / t
    v, i = torch.topk(logits, k)           # keep indices to map back to vocab
    probs = F.softmax(v, dim=-1)
    return i[torch.multinomial(probs, 1)]  # return a vocab id, not a top-k slot

# === Nucleus (assumes 1-D logits) ===
def nucleus(logits, p=0.9, t=1.0):
    logits = logits / t
    probs = F.softmax(logits, dim=-1)
    s_idx = torch.argsort(probs, descending=True)
    s_probs = probs[s_idx]
    cum = torch.cumsum(s_probs, dim=-1)
    mask = cum > p
    mask[1:] = mask[:-1].clone()  # clone: the in-place shift aliases otherwise
    mask[0] = False
    s_probs[mask] = 0
    s_probs = s_probs / s_probs.sum()
    idx = torch.multinomial(s_probs, 1)
    return s_idx[idx]

# === Beam Search ===
@torch.no_grad()
def beam(model, idx, k=5, max_len=50):
    paths = [(0.0, idx[0].tolist(), None)]  # (log_prob, seq, cache); idx is (1, T)
    for _ in range(max_len):
        cands = []
        for lp, seq, cache in paths:
            x = torch.tensor([[seq[-1]]]) if cache is not None else torch.tensor([seq])
            logits, _, nc = model(x, past_kv=cache)
            logp = F.log_softmax(logits[0, -1], dim=-1)
            top = logp.topk(k)  # compute once per path, not once per candidate
            for val, tok in zip(top.values.tolist(), top.indices.tolist()):
                cands.append((lp + val, seq + [tok], nc))
        paths = heapq.nlargest(k, cands, key=lambda c: c[0])
    return torch.tensor(max(paths, key=lambda c: c[0])[1])

Final Words

You now control how LLMs think.

  • Greedy → robot
  • Beam → perfectionist
  • Nucleus → creative genius

End of Module
You generate like GPT-4 — coherent, diverse, fast.
Next: Build a chatbot API.