LLM/NLP Foundations Code Reference
Complete code implementations for Chapter 5: LLM/NLP Foundations.
Table of Contents
Tokenization
Token Budget Manager
Complete implementation for managing token budgets in prompts:
class TokenBudget:
    """Helper for managing token budgets in prompts.

    Wraps a tokenizer together with a model's total token capacity so callers
    can count tokens, check whether prompt components fit, and trim an
    over-long component down to the budget that remains.

    The tokenizer only needs two methods: ``encode(text)`` returning a
    sliceable sequence of tokens and ``decode(tokens)`` returning a string.
    """

    def __init__(self, model_max_tokens: int, tokenizer):
        """Store the token capacity and the tokenizer used for counting.

        Args:
            model_max_tokens: Total token capacity (prompt plus output).
            tokenizer: Object exposing ``encode`` and ``decode``.
        """
        self.max_tokens = model_max_tokens
        self.tokenizer = tokenizer

    def count(self, text: str) -> int:
        """Count tokens in text."""
        return len(self.tokenizer.encode(text))

    def _count_all(self, components: dict[str, str]) -> int:
        """Return the summed token count of every component value."""
        return sum(self.count(text) for text in components.values())

    def fits(self, components: dict[str, str], reserve_output: int = 500) -> bool:
        """Check if all components fit within budget.

        Args:
            components: Dict mapping component names to text content
            reserve_output: Tokens to reserve for model output

        Returns:
            True if total tokens + reserved output fits in max_tokens
        """
        return self._count_all(components) + reserve_output <= self.max_tokens

    def truncate_to_fit(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit within token limit.

        Args:
            text: Text to potentially truncate
            max_tokens: Maximum tokens allowed; values below zero are
                treated as zero

        Returns:
            Original text if within limit, otherwise the decoded first
            ``max_tokens`` tokens
        """
        tokens = self.tokenizer.encode(text)
        if len(tokens) <= max_tokens:
            return text
        # A negative limit must not reach the slice: tokens[:-n] would keep
        # everything except the last n tokens instead of keeping nothing.
        limit = max(0, max_tokens)
        return self.tokenizer.decode(tokens[:limit])

    def allocate_budget(
        self,
        fixed_components: dict[str, str],
        variable_component: str,
        reserve_output: int = 500
    ) -> str:
        """Allocate remaining budget to a variable-length component.

        Args:
            fixed_components: Components with fixed content (system prompt, query)
            variable_component: Content to truncate if needed (e.g., retrieved docs)
            reserve_output: Tokens to reserve for model output

        Returns:
            Truncated variable component that fits remaining budget. When the
            fixed components and the reserve already use up the budget, the
            variable component is truncated to zero tokens.
        """
        available = self.max_tokens - self._count_all(fixed_components) - reserve_output
        return self.truncate_to_fit(variable_component, max(0, available))
# Usage example {.unnumbered}
def build_rag_prompt(
    tokenizer,
    system_prompt: str,
    retrieved_docs: list[str],
    user_query: str,
    model_max_tokens: int = 4096,
    reserve_output: int = 1000
) -> dict[str, str]:
    """Build a RAG prompt that fits within token limits.

    The system prompt and user query are kept verbatim. The retrieved
    documents are joined with a horizontal-rule separator and then trimmed
    to whatever token budget is left after the fixed parts and the reserved
    output tokens.

    Args:
        tokenizer: Tokenizer handed to ``TokenBudget``
        system_prompt: System instructions, never truncated
        retrieved_docs: Retrieved passages, in the order they should appear
        user_query: The user's question, never truncated
        model_max_tokens: Total token capacity of the model
        reserve_output: Tokens to reserve for model output

    Returns:
        Dict with keys "system", "context" and "user_query"
    """
    token_budget = TokenBudget(model_max_tokens, tokenizer)

    # Everything retrieved becomes one context string before trimming.
    doc_separator = "\n\n---\n\n"
    joined_docs = doc_separator.join(retrieved_docs)

    # Only the context is allowed to shrink; the other parts are fixed.
    fitted_context = token_budget.allocate_budget(
        {"system": system_prompt, "user_query": user_query},
        joined_docs,
        reserve_output,
    )

    prompt_parts = {"system": system_prompt}
    prompt_parts["context"] = fitted_context
    prompt_parts["user_query"] = user_query
    return prompt_parts


# Embeddings
Embedding Pooling Strategies
Complete implementation for pooling token embeddings into sequence embeddings:
import torch
def pool_embeddings(
    token_embeddings: torch.Tensor,
    attention_mask: torch.Tensor,
    strategy: str = "mean"
) -> torch.Tensor:
    """
    Pool token embeddings into a single sequence embedding.

    Args:
        token_embeddings: (batch_size, seq_length, hidden_dim), floating point
        attention_mask: (batch_size, seq_length) - 1 for real tokens, 0 for padding
        strategy: "cls", "mean", "max", or "mean_max"

    Returns:
        Pooled embeddings of shape (batch_size, hidden_dim), or
        (batch_size, 2 * hidden_dim) for "mean_max" (mean followed by max).
        The input tensor is never modified.

    Raises:
        ValueError: If strategy is not one of the supported names.
    """
    if strategy == "cls":
        # Use the [CLS] token embedding (position 0)
        # Common for BERT-style models
        return token_embeddings[:, 0, :]

    if strategy == "mean_max":
        # Concatenate mean and max pooling
        mean_pooled = pool_embeddings(token_embeddings, attention_mask, "mean")
        max_pooled = pool_embeddings(token_embeddings, attention_mask, "max")
        return torch.cat([mean_pooled, max_pooled], dim=-1)

    if strategy not in ("mean", "max"):
        raise ValueError(f"Unknown pooling strategy: {strategy}")

    dtype = token_embeddings.dtype
    mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size())

    if strategy == "mean":
        # Average non-padding tokens. The mask is cast to the embedding
        # dtype so integer and boolean masks behave the same way.
        weights = mask_expanded.to(dtype)
        sum_embeddings = torch.sum(token_embeddings * weights, dim=1)
        # Avoid division by zero for fully padded rows; 1e-9 is raised to the
        # smallest normal value for dtypes that cannot represent it.
        min_count = max(1e-9, torch.finfo(dtype).tiny)
        sum_mask = weights.sum(dim=1).clamp(min=min_count)
        return sum_embeddings / sum_mask

    # strategy == "max": max-pool over non-padding tokens.
    # Padding is replaced by a large negative value, limited to what the
    # dtype can represent (-1e9 does not fit in float16).
    fill_value = max(-1e9, torch.finfo(dtype).min)
    masked = token_embeddings.masked_fill(mask_expanded == 0, fill_value)
    return torch.max(masked, dim=1).values
def normalize_embeddings(embeddings: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """L2 normalize embeddings to unit length.

    Normalized embeddings allow using dot product for cosine similarity,
    which is faster and enables efficient nearest neighbor search.

    Args:
        embeddings: Tensor whose last dimension holds the embedding vector
        eps: Lower bound applied to each vector norm before dividing

    Returns:
        Tensor of the same shape with each vector scaled to unit L2 norm.
        All-zero vectors are returned as zeros instead of NaN.
    """
    # Without the lower bound an all-zero vector divides 0 by 0 and the
    # resulting NaN spreads into every similarity computed from it.
    norms = embeddings.norm(dim=-1, keepdim=True).clamp_min(eps)
    return embeddings / norms
def compute_similarity_matrix(
    query_embeddings: torch.Tensor,
    document_embeddings: torch.Tensor,
    normalize: bool = True
) -> torch.Tensor:
    """Score every query against every document with a dot product.

    Args:
        query_embeddings: (num_queries, hidden_dim)
        document_embeddings: (num_documents, hidden_dim)
        normalize: When True both inputs are L2 normalized first, so the
            scores are cosine similarities; when False they are raw dot
            products

    Returns:
        Similarity matrix of shape (num_queries, num_documents)
    """
    queries, documents = query_embeddings, document_embeddings
    if normalize:
        queries = normalize_embeddings(queries)
        documents = normalize_embeddings(documents)

    # Row i, column j holds the score of query i against document j.
    return torch.mm(queries, documents.t())


# Embedding Analogy Solver
Complete implementation for solving word analogies using embeddings:
import torch
def analogy(
    embeddings: torch.Tensor,
    word_to_idx: dict[str, int],
    idx_to_word: dict[int, str],
    a: str,
    b: str,
    c: str,
    top_k: int = 5
) -> list[str]:
    """Complete a word analogy using vector arithmetic.

    The query vector is ``a - b + c``. The classic example
    "king - man + woman = queen" is therefore called with
    ``a="king", b="man", c="woman"``; in analogy notation the function
    looks for d such that ``b:a :: c:d``.

    This works because embedding spaces encode semantic relationships
    as geometric directions.

    Args:
        embeddings: (vocab_size, embedding_dim) tensor of word embeddings
        word_to_idx: Mapping from words to embedding indices
        idx_to_word: Mapping from indices to words
        a, b, c: Input words combined as a - b + c
        top_k: Number of candidates to return

    Returns:
        Up to top_k words ranked by cosine similarity to the query vector.
        The input words are never returned, so fewer than top_k words come
        back when the vocabulary is too small.

    Raises:
        KeyError: If a, b or c is missing from word_to_idx.
    """
    # Get embedding rows for the input words
    idx_a, idx_b, idx_c = (word_to_idx[word] for word in (a, b, c))

    # Offset from b to a, applied to c
    target = embeddings[idx_a] - embeddings[idx_b] + embeddings[idx_c]

    # Cosine similarity via unit vectors. Norms are bounded below so that
    # all-zero rows (e.g. a padding entry) score 0 instead of NaN, which
    # would otherwise corrupt the ranking.
    eps = 1e-12
    target_unit = target / torch.norm(target).clamp_min(eps)
    vocab_unit = embeddings / embeddings.norm(dim=1, keepdim=True).clamp_min(eps)
    similarities = torch.matmul(vocab_unit, target_unit)

    # Exclude input words from results
    excluded = {idx_a, idx_b, idx_c}
    similarities[list(excluded)] = -float('inf')

    # Never ask for more candidates than there are non-excluded words
    k = max(0, min(top_k, similarities.numel() - len(excluded)))
    top_indices = similarities.topk(k).indices
    return [idx_to_word[i] for i in top_indices.tolist()]
def find_similar_words(
    embeddings: torch.Tensor,
    word_to_idx: dict[str, int],
    idx_to_word: dict[int, str],
    query_word: str,
    top_k: int = 10
) -> list[tuple[str, float]]:
    """Find words most similar to a query word.

    Args:
        embeddings: (vocab_size, embedding_dim) tensor
        word_to_idx: Mapping from words to indices
        idx_to_word: Mapping from indices to words
        query_word: Word to find similar words for
        top_k: Number of similar words to return

    Returns:
        List of (word, similarity_score) tuples sorted by descending cosine
        similarity. The query word itself is never returned, so at most
        vocab_size - 1 entries come back.

    Raises:
        KeyError: If query_word is missing from word_to_idx.
    """
    query_idx = word_to_idx[query_word]
    query_vec = embeddings[query_idx]

    # Cosine similarity via unit vectors. Norms are bounded below so that
    # all-zero rows score 0 instead of NaN.
    eps = 1e-12
    query_unit = query_vec / torch.norm(query_vec).clamp_min(eps)
    vocab_unit = embeddings / embeddings.norm(dim=1, keepdim=True).clamp_min(eps)
    similarities = torch.matmul(vocab_unit, query_unit)

    # Exclude the query word itself
    similarities[query_idx] = -float('inf')

    # Never ask for more neighbours than there are other words
    k = max(0, min(top_k, similarities.numel() - 1))
    top_scores, top_indices = similarities.topk(k)
    return [
        (idx_to_word[idx], score)
        for idx, score in zip(top_indices.tolist(), top_scores.tolist())
    ]


# Attention
Scaled Dot-Product Attention
Complete implementation of scaled dot-product attention:
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor = None,
    dropout: float = 0.0,
    training: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute scaled dot-product attention.

    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

    Args:
        Q: Queries, shape (batch, num_heads, seq_len, d_k)
        K: Keys, shape (batch, num_heads, seq_len, d_k)
        V: Values, shape (batch, num_heads, seq_len, d_v)
        mask: Optional attention mask broadcastable to the score matrix,
            e.g. (batch, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len).
            Entries equal to 0 are blocked.
        dropout: Dropout probability for attention weights
        training: Whether in training mode (dropout is skipped otherwise)

    Returns:
        output: Attention output, shape (batch, num_heads, seq_len, d_v)
        attention_weights: Attention weights, shape (batch, num_heads, seq_len, seq_len)
    """
    # Entry (i, j) of the logits says how strongly position i attends to
    # position j. Dividing by sqrt(d_k) keeps the magnitude of the dot
    # products independent of the head dimension, so softmax does not
    # saturate as d_k grows.
    scale = math.sqrt(Q.shape[-1])
    logits = torch.matmul(Q, K.transpose(-2, -1)) / scale

    # Blocked positions become -inf and therefore get zero weight. This is
    # how causal, padding and custom attention patterns are enforced.
    if mask is not None:
        logits = logits.masked_fill(mask == 0, float('-inf'))

    # Each row is normalised to sum to 1. A row with every position blocked
    # is all -inf, which softmax turns into NaN; those become zeros.
    attention_weights = torch.nan_to_num(F.softmax(logits, dim=-1), nan=0.0)

    # Regularise the attention pattern itself, training only.
    if dropout > 0.0 and training:
        attention_weights = F.dropout(attention_weights, p=dropout)

    # Every output vector is a weighted mix of the value vectors.
    return torch.matmul(attention_weights, V), attention_weights
def create_causal_mask(seq_length: int, device: torch.device = None) -> torch.Tensor:
    """Build the lower-triangular mask used for autoregressive attention.

    Row i contains ones in columns 0..i and zeros afterwards, so position i
    may look at itself and earlier positions but never at later ones.

    Args:
        seq_length: Length of the sequence
        device: Device to create tensor on

    Returns:
        Tensor of ones and zeros with shape (1, 1, seq_len, seq_len)
    """
    all_ones = torch.ones(seq_length, seq_length, device=device)
    # Two leading singleton axes stand in for the batch and head dimensions.
    return all_ones.tril()[None, None]
def create_padding_mask(
    attention_mask: torch.Tensor
) -> torch.Tensor:
    """Reshape a per-token padding mask so it broadcasts over attention scores.

    Args:
        attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding

    Returns:
        The same values viewed as (batch, 1, 1, seq_len)
    """
    # Singleton axes are inserted after the batch axis; they broadcast over
    # the head and query-position dimensions of the score tensor.
    return attention_mask[:, None, None]


# Multi-Head Attention
Complete implementation of multi-head attention:
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
"""Multi-head attention mechanism.
Multiple attention heads allow the model to jointly attend to information
from different representation subspaces at different positions.
Different heads can learn different types of relationships:
- Syntactic patterns (subject-verb agreement)
- Semantic relationships (coreference)
- Positional patterns (nearby words)
"""
def __init__(
self,
d_model: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True
):
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension per head
self.dropout = dropout
# Projections for Q, K, V (combined into single matrices for efficiency)
self.W_q = nn.Linear(d_model, d_model, bias=bias)
self.W_k = nn.Linear(d_model, d_model, bias=bias)
self.W_v = nn.Linear(d_model, d_model, bias=bias)
# Output projection
self.W_o = nn.Linear(d_model, d_model, bias=bias)
self._init_weights()
def _init_weights(self):
"""Initialize weights with scaled initialization."""
for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None,
return_attention: bool = False
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for multi-head attention.
Args:
query: (batch, seq_len, d_model)
key: (batch, seq_len, d_model) - can be from different sequence for cross-attention
value: (batch, seq_len, d_model)
mask: Optional attention mask
return_attention: Whether to return attention weights
Returns:
output: (batch, seq_len, d_model)
attention_weights: Optional, (batch, num_heads, seq_len, seq_len)
"""
batch_size = query.shape[0]
# Project and reshape to (batch, num_heads, seq_len, d_k)
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Compute attention for all heads in parallel
attn_output, attn_weights = scaled_dot_product_attention(
Q, K, V, mask,
dropout=self.dropout,
training=self.training
)
# Concatenate heads: (batch, seq_len, num_heads * d_k) = (batch, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# Final projection
output = self.W_o(attn_output)
if return_attention:
return output, attn_weights
return output
class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) - used in LLaMA 2, Mistral, etc.

    Instead of separate K, V projections per head, groups of query heads
    share the same K, V. This reduces KV cache memory during inference
    with minimal quality loss.

    - num_heads query heads
    - num_kv_heads key-value heads (num_heads must be divisible by num_kv_heads)
    - When num_kv_heads = 1, this is Multi-Query Attention (MQA)
    - When num_kv_heads = num_heads, this is standard Multi-Head Attention
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0
    ):
        """Create the query, key, value and output projections (all bias-free).

        Args:
            d_model: Model width; must be divisible by num_heads
            num_heads: Number of query heads
            num_kv_heads: Number of key/value heads; must divide num_heads
            dropout: Dropout probability applied to attention weights
        """
        super().__init__()
        # Checked up front: otherwise the reshape in forward() fails with a
        # far less helpful size-mismatch error.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.d_k = d_model // num_heads
        # Stored so forward() can apply it; it used to be accepted and ignored.
        self.dropout = dropout

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        # K and V are narrower than Q: one d_k slice per key/value head.
        self.W_k = nn.Linear(d_model, self.num_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, self.num_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Apply grouped-query self-attention.

        Args:
            x: (batch, seq_len, d_model)
            mask: Optional attention mask; entries equal to 0 are blocked

        Returns:
            Tensor of shape (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)

        # Repeat K, V for each group of query heads so that consecutive query
        # heads share one key/value head.
        K = K.repeat_interleave(self.num_queries_per_kv, dim=1)
        V = V.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_output, _ = scaled_dot_product_attention(
            Q, K, V, mask,
            dropout=self.dropout,
            training=self.training
        )

        # Concatenate heads back into (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(attn_output)


# Transformer Architecture
Transformer Block
Complete implementation of a transformer block:
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer.

    The layer applies two sublayers in sequence, each wrapped in a residual
    connection:
    1. Multi-head self-attention
    2. Position-wise feed-forward network

    LayerNorm is applied to the input of each sublayer (pre-norm) rather
    than to its output, which is more stable for training deep networks
    than post-norm.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "gelu"
    ):
        """Build both sublayers.

        Args:
            d_model: Model width
            num_heads: Number of attention heads
            d_ff: Hidden width of the feed-forward network
            dropout: Dropout probability, shared by attention weights, the
                feed-forward hidden layer and both residual branches
            activation: Feed-forward activation name
        """
        super().__init__()

        # Sublayer 1: self-attention and its input norm
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Sublayer 2: feed-forward and its input norm
        self.feed_forward = FeedForward(d_model, d_ff, dropout, activation)
        self.norm2 = nn.LayerNorm(d_model)

        # Applied to each sublayer's output before it joins the residual
        self.dropout = nn.Dropout(dropout)

    def _attention_branch(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Normalise x and attend from it to itself."""
        attn_input = self.norm1(x)
        return self.attention(attn_input, attn_input, attn_input, mask)

    def _feed_forward_branch(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise x and run it through the feed-forward network."""
        return self.feed_forward(self.norm2(x))

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (batch, seq_len, d_model)
            mask: Optional attention mask

        Returns:
            Output tensor of shape (batch, seq_len, d_model)
        """
        x = x + self.dropout(self._attention_branch(x, mask))
        x = x + self.dropout(self._feed_forward_branch(x))
        return x
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    FFN(x) = activation(xW1 + b1)W2 + b2

    Expands each position to d_ff features, applies the non-linearity and
    dropout, then projects back to d_model. The hidden dimension is
    typically 4x the model dimension.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "gelu"
    ):
        """Create the two linear layers and pick the activation.

        Args:
            d_model: Input and output width
            d_ff: Hidden width
            dropout: Dropout probability applied after the activation
            activation: One of "gelu", "relu" or "silu"

        Raises:
            ValueError: If activation is not a supported name.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

        activation_classes = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}
        if activation not in activation_classes:
            raise ValueError(f"Unknown activation: {activation}")
        self.activation = activation_classes[activation]()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activate -> dropout -> project to x."""
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class SwiGLU(nn.Module):
    """Gated feed-forward layer used in LLaMA and other modern architectures.

    SwiGLU(x) = (xW1 * SiLU(xW_gate)) W2

    One projection carries the values, a second one passed through SiLU
    acts as a gate, and their element-wise product is projected back to
    d_model. The gating mechanism improves performance at similar compute
    cost.
    """

    def __init__(self, d_model: int, d_ff: int, bias: bool = False):
        """Create the value, output and gate projections.

        Args:
            d_model: Input and output width
            d_ff: Hidden width
            bias: Whether the linear layers carry a bias term
        """
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=bias)
        self.w2 = nn.Linear(d_ff, d_model, bias=bias)
        self.w_gate = nn.Linear(d_model, d_ff, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate the value projection of x and project the result back."""
        gate = F.silu(self.w_gate(x))
        values = self.w1(x)
        return self.w2(gate * values)


# Complete Transformer Language Model
Complete implementation of a decoder-only transformer language model:
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransformerLM(nn.Module):
    """Decoder-only transformer language model (GPT-style).

    This is the architecture used by GPT, LLaMA, and most modern LLMs.
    It uses causal (autoregressive) attention so each position can only
    attend to itself and previous positions.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        num_heads: int,
        num_layers: int,
        d_ff: int,
        max_seq_length: int,
        dropout: float = 0.1,
        tie_weights: bool = True
    ):
        """Build embeddings, the stack of transformer blocks and the output head.

        Args:
            vocab_size: Number of token IDs
            d_model: Model width
            num_heads: Attention heads per block
            num_layers: Number of transformer blocks
            d_ff: Feed-forward hidden width
            max_seq_length: Number of learned positions; the longest
                sequence forward() can embed
            dropout: Dropout probability for embeddings and blocks
            tie_weights: Share one weight matrix between the token
                embedding and the output projection
        """
        super().__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_length, d_model)
        self.embedding_dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm and output projection
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: share weights between token embedding and output projection
        # This reduces parameters and often improves performance
        if tie_weights:
            self.lm_head.weight = self.token_embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize one submodule: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Forward pass for language modeling.

        Args:
            token_ids: Input token IDs, shape (batch, seq_len). seq_len must
                not exceed max_seq_length because positions are looked up
                in a table of that size.
            attention_mask: Optional mask for padding, shape (batch, seq_len)

        Returns:
            Logits over vocabulary, shape (batch, seq_len, vocab_size)
        """
        _, seq_length = token_ids.shape
        device = token_ids.device

        # Get embeddings; the position embedding broadcasts over the batch
        positions = torch.arange(seq_length, device=device)
        x = self.token_embedding(token_ids) + self.position_embedding(positions)
        x = self.embedding_dropout(x)

        # Create causal mask (each position attends only to itself and earlier positions)
        causal_mask = torch.tril(
            torch.ones(seq_length, seq_length, device=device)
        ).unsqueeze(0).unsqueeze(0)

        # Combine with padding mask if provided
        if attention_mask is not None:
            # attention_mask: (batch, seq_len) -> (batch, 1, 1, seq_len)
            padding_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            mask = causal_mask * padding_mask
        else:
            mask = causal_mask

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final norm and output projection
        x = self.norm(x)
        return self.lm_head(x)  # (batch, seq_len, vocab_size)

    @staticmethod
    def _sample_next_token(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Sample one token per row from (batch, vocab) logits after top-k / top-p filtering."""
        # Top-k filtering: drop everything below the k-th best logit.
        # k is capped at the vocabulary size so a large top_k cannot crash topk.
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))

        # Top-p (nucleus) filtering: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p.
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the most likely token.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            # Map the decision from sorted order back to vocabulary order
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float('-inf'))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: int = 50
    ) -> torch.Tensor:
        """Generate text autoregressively.

        The model runs in eval mode while generating and is switched back to
        its previous mode afterwards.

        Args:
            input_ids: Starting token IDs, shape (batch, seq_len)
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0 = greedy)
            top_p: Nucleus sampling threshold (1.0 disables it)
            top_k: Top-k sampling limit (0 disables it)

        Returns:
            Generated token IDs including input, shape (batch, seq_len + generated).
            When the sequence grows past max_seq_length only the most recent
            max_seq_length tokens are fed to the model, but the full sequence
            is still returned.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the model input is windowed; truncating input_ids
                # itself would silently drop the start of the result.
                context = input_ids[:, -self.max_seq_length:]

                with torch.no_grad():
                    logits = self(context)

                next_token_logits = logits[:, -1, :]  # Last position only

                if temperature > 0:
                    next_token = self._sample_next_token(
                        next_token_logits / temperature, top_k, top_p
                    )
                else:
                    # Greedy decoding
                    next_token = next_token_logits.argmax(dim=-1, keepdim=True)

                input_ids = torch.cat([input_ids, next_token], dim=-1)
        finally:
            # Do not leave a model that was being trained stuck in eval mode
            self.train(was_training)

        return input_ids
def count_parameters(model: nn.Module) -> dict:
    """Tally trainable parameters of a model, grouped by component.

    Each parameter is assigned to the first category whose keyword occurs in
    its dotted name; parameters matching no keyword land in 'other'. Frozen
    parameters (requires_grad False) are skipped.

    Args:
        model: Module whose named parameters are inspected

    Returns:
        Dict mapping each category that occurred to its parameter count,
        plus a 'total' entry with the overall sum.
    """
    # Order matters: the first matching rule wins.
    category_rules = (
        ('embeddings', ('embedding',)),
        ('attention', ('attention', 'W_q', 'W_k', 'W_v', 'W_o')),
        ('feed_forward', ('feed_forward', 'linear')),
        ('layer_norm', ('norm',)),
        ('lm_head', ('lm_head',)),
    )

    counts = {}
    total = 0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        category = next(
            (
                label
                for label, keywords in category_rules
                if any(keyword in name for keyword in keywords)
            ),
            'other',
        )

        size = param.numel()
        counts[category] = counts.get(category, 0) + size
        total += size

    counts['total'] = total
    return counts


# Inference and Generation
Beam Search
Complete implementation of beam search for text generation:
import torch
import torch.nn.functional as F
from dataclasses import dataclass
@dataclass
class BeamHypothesis:
    """One candidate sequence tracked during beam search.

    Attributes:
        tokens: Token IDs generated so far
        log_prob: Cumulative log probability
        is_finished: Whether EOS has been generated
    """

    tokens: torch.Tensor
    log_prob: float
    is_finished: bool
def beam_search(
    model,
    input_ids: torch.Tensor,
    beam_width: int = 5,
    max_length: int = 50,
    eos_token_id: int = None,
    length_penalty: float = 1.0,
    early_stopping: bool = True
) -> torch.Tensor:
    """Beam search for text generation.

    Maintains multiple hypotheses and selects the highest-scoring
    sequence. Good for tasks requiring coherent, high-quality
    output like translation or summarization.

    Hypotheses are ranked by ``log_prob / length ** length_penalty``, where
    length counts every token including the prompt.

    Args:
        model: Callable mapping token IDs of shape (1, seq_len) to logits of
            shape (1, seq_len, vocab_size)
        input_ids: Starting tokens, shape (1, seq_len)
        beam_width: Number of hypotheses to maintain
        max_length: Maximum total sequence length
        eos_token_id: End of sequence token ID
        length_penalty: Penalty for longer sequences (>1 favors longer)
        early_stopping: Stop when all beams have finished. The search also
            ends once no active beam remains, so this flag currently does
            not change the result.

    Returns:
        Best sequence, shape (1, seq_len). The prompt is returned unchanged
        when max_length leaves no room for new tokens.
    """
    batch_size = input_ids.shape[0]
    assert batch_size == 1, "Beam search currently supports batch_size=1"

    # Defined before the loop: it is also needed for the final selection,
    # even when the loop body never runs.
    def score(beam):
        length = beam.tokens.shape[1]
        return beam.log_prob / (length ** length_penalty)

    # Initialize beams
    beams = [BeamHypothesis(
        tokens=input_ids.clone(),
        log_prob=0.0,
        is_finished=False
    )]
    finished_beams = []

    for _ in range(max_length - input_ids.shape[1]):
        if not beams:
            break

        all_candidates = []

        # `beams` only ever holds unfinished hypotheses, so each one is expanded
        for beam in beams:
            # Get next token probabilities
            with torch.no_grad():
                logits = model(beam.tokens)[:, -1, :]
            log_probs = F.log_softmax(logits, dim=-1)

            # Get top candidates for this beam; topk cannot return more
            # entries than the vocabulary holds
            num_candidates = min(beam_width, log_probs.shape[-1])
            top_log_probs, top_indices = torch.topk(log_probs, num_candidates, dim=-1)
            candidate_log_probs = top_log_probs[0].tolist()
            candidate_ids = top_indices[0].tolist()

            for i in range(num_candidates):
                new_token = top_indices[:, i:i + 1]  # shape (1, 1)
                is_finished = (eos_token_id is not None and
                               candidate_ids[i] == eos_token_id)
                all_candidates.append(BeamHypothesis(
                    tokens=torch.cat([beam.tokens, new_token], dim=-1),
                    log_prob=beam.log_prob + candidate_log_probs[i],
                    is_finished=is_finished
                ))

        # Select top beams by score (with length penalty)
        all_candidates.sort(key=score, reverse=True)
        survivors = all_candidates[:beam_width]
        beams = [b for b in survivors if not b.is_finished]

        # Add finished beams to collection
        finished_beams.extend(b for b in survivors if b.is_finished)

        # Early stopping: all beams finished
        if early_stopping and not beams:
            break

    # Add any remaining unfinished beams
    finished_beams.extend(beams)

    if not finished_beams:
        return input_ids

    # Return the best beam
    best_beam = max(finished_beams, key=score)
    return best_beam.tokens
def diverse_beam_search(
    model,
    input_ids: torch.Tensor,
    num_groups: int = 3,
    beam_width_per_group: int = 3,
    diversity_penalty: float = 0.5,
    max_length: int = 50,
    eos_token_id: int = None
) -> list[torch.Tensor]:
    """Diverse beam search generates multiple distinct outputs.

    Divides beams into groups and penalizes selecting tokens already
    chosen by earlier groups in the same step, encouraging diversity.

    Args:
        model: Callable mapping token IDs of shape (1, seq_len) to logits of
            shape (1, seq_len, vocab_size)
        input_ids: Starting tokens, shape (1, seq_len)
        num_groups: Number of beam groups
        beam_width_per_group: Beams per group
        diversity_penalty: Amount subtracted from a token's logit for each
            beam of an earlier group that chose it in the current step
        max_length: Maximum sequence length
        eos_token_id: End of sequence token ID (0 is a valid ID)

    Returns:
        List with one sequence per group: that group's hypothesis with the
        highest cumulative log probability
    """
    # Each group starts from a single copy of the prompt. Seeding a group
    # with several identical hypotheses would make the top candidates exact
    # duplicates, collapsing the group to one greedy path.
    beams = [[BeamHypothesis(
        tokens=input_ids.clone(),
        log_prob=0.0,
        is_finished=False
    )] for _ in range(num_groups)]

    for _ in range(max_length - input_ids.shape[1]):
        # Track tokens selected by previous groups this step
        tokens_selected = []

        for group_idx in range(num_groups):
            all_candidates = []

            for beam in beams[group_idx]:
                if beam.is_finished:
                    # Finished hypotheses keep competing unchanged
                    all_candidates.append(beam)
                    continue

                with torch.no_grad():
                    # Cloned so the penalty below never writes into a tensor
                    # owned by the model
                    logits = model(beam.tokens)[:, -1, :].clone()

                # Apply diversity penalty, once per earlier selection
                for prev_token in tokens_selected:
                    logits[0, prev_token] -= diversity_penalty

                log_probs = F.log_softmax(logits, dim=-1)
                # topk cannot return more entries than the vocabulary holds
                num_candidates = min(beam_width_per_group, log_probs.shape[-1])
                top_log_probs, top_indices = torch.topk(
                    log_probs, num_candidates, dim=-1
                )
                candidate_log_probs = top_log_probs[0].tolist()
                candidate_ids = top_indices[0].tolist()

                for i in range(num_candidates):
                    new_token = top_indices[:, i:i + 1]  # shape (1, 1)
                    # Compare against None explicitly: a truthiness test
                    # would treat token ID 0 as "no EOS configured"
                    is_finished = (eos_token_id is not None and
                                   candidate_ids[i] == eos_token_id)
                    all_candidates.append(BeamHypothesis(
                        tokens=torch.cat([beam.tokens, new_token], dim=-1),
                        log_prob=beam.log_prob + candidate_log_probs[i],
                        is_finished=is_finished
                    ))

            # Select top beams for this group
            all_candidates.sort(key=lambda b: b.log_prob, reverse=True)
            beams[group_idx] = all_candidates[:beam_width_per_group]

            # Record tokens selected by this group
            tokens_selected.extend(
                beam.tokens[0, -1].item()
                for beam in beams[group_idx]
                if not beam.is_finished
            )

        # Nothing left to expand: further steps would change nothing
        if all(beam.is_finished for group in beams for beam in group):
            break

    # Return best beam from each group
    return [max(group, key=lambda b: b.log_prob).tokens for group in beams]