LLM/NLP Foundations Code Reference

Complete code implementations for Chapter 5: LLM/NLP Foundations.


Table of Contents


Tokenization

Token Budget Manager

Complete implementation for managing token budgets in prompts:

class TokenBudget:
    """Helper for managing token budgets in prompts.

    The tokenizer is any object exposing ``encode(text)`` (returning a
    sequence of tokens) and ``decode(tokens)`` (returning text); those are
    the only two methods this class calls.
    """

    def __init__(self, model_max_tokens: int, tokenizer):
        """Store the model's context size and the tokenizer used for counting.

        Args:
            model_max_tokens: Total token capacity of the model's context
            tokenizer: Object providing ``encode`` and ``decode``
        """
        self.max_tokens = model_max_tokens
        self.tokenizer = tokenizer

    def count(self, text: str) -> int:
        """Count tokens in text."""
        return len(self.tokenizer.encode(text))

    def fits(self, components: dict[str, str], reserve_output: int = 500) -> bool:
        """Check if all components fit within budget.

        Args:
            components: Dict mapping component names to text content
            reserve_output: Tokens to reserve for model output

        Returns:
            True if total tokens + reserved output fits in max_tokens
        """
        total = sum(self.count(text) for text in components.values())
        return total + reserve_output <= self.max_tokens

    def truncate_to_fit(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit within token limit.

        Args:
            text: Text to potentially truncate
            max_tokens: Maximum tokens allowed; negative values are treated
                as zero

        Returns:
            Original text if within limit, otherwise truncated version
        """
        # A negative limit would otherwise turn ``tokens[:max_tokens]`` into
        # "drop the last N tokens" and return text far over the limit.
        limit = max(0, max_tokens)
        tokens = self.tokenizer.encode(text)
        if len(tokens) <= limit:
            return text
        return self.tokenizer.decode(tokens[:limit])

    def allocate_budget(
        self,
        fixed_components: dict[str, str],
        variable_component: str,
        reserve_output: int = 500
    ) -> str:
        """Allocate remaining budget to a variable-length component.

        Args:
            fixed_components: Components with fixed content (system prompt, query)
            variable_component: Content to truncate if needed (e.g., retrieved docs)
            reserve_output: Tokens to reserve for model output

        Returns:
            Truncated variable component that fits remaining budget
        """
        fixed_tokens = sum(self.count(text) for text in fixed_components.values())
        available = self.max_tokens - fixed_tokens - reserve_output
        # If the fixed parts already exceed the budget, nothing is left for
        # the variable component.
        return self.truncate_to_fit(variable_component, max(0, available))


# Usage example
def build_rag_prompt(
    tokenizer,
    system_prompt: str,
    retrieved_docs: list[str],
    user_query: str,
    model_max_tokens: int = 4096,
    reserve_output: int = 1000
) -> dict[str, str]:
    """Assemble the parts of a RAG prompt, trimming the context to fit.

    The system prompt and user query are kept verbatim. The retrieved
    documents are joined with a ``---`` separator and then truncated so that
    system prompt + query + context + reserved output tokens stay within
    ``model_max_tokens``.

    Args:
        tokenizer: Tokenizer handed to ``TokenBudget`` for counting/truncating
        system_prompt: System instructions (never truncated)
        retrieved_docs: Retrieved passages, in the order they should appear
        user_query: The user's question (never truncated)
        model_max_tokens: Total context size of the target model
        reserve_output: Tokens to keep free for the model's answer

    Returns:
        Dict with keys "system", "context" and "user_query".
    """
    doc_separator = "\n\n---\n\n"
    joined_docs = doc_separator.join(retrieved_docs)

    # Only the context is allowed to shrink; these two count against the
    # budget at full length.
    untouchable_parts = {"system": system_prompt, "user_query": user_query}

    token_budget = TokenBudget(model_max_tokens, tokenizer)
    fitted_context = token_budget.allocate_budget(
        untouchable_parts, joined_docs, reserve_output
    )

    prompt_parts = {}
    prompt_parts["system"] = system_prompt
    prompt_parts["context"] = fitted_context
    prompt_parts["user_query"] = user_query
    return prompt_parts

Embeddings

Embedding Pooling Strategies

Complete implementation for pooling token embeddings into sequence embeddings:

import torch

def pool_embeddings(
    token_embeddings: torch.Tensor,
    attention_mask: torch.Tensor,
    strategy: str = "mean"
) -> torch.Tensor:
    """
    Collapse per-token embeddings into one vector per sequence.

    Args:
        token_embeddings: (batch_size, seq_length, hidden_dim)
        attention_mask: (batch_size, seq_length) - 1 for real tokens, 0 for padding
        strategy: One of
            "cls"      - take the embedding at position 0
            "mean"     - average over non-padding tokens
            "max"      - element-wise maximum over non-padding tokens
            "mean_max" - concatenation of the "mean" and "max" results

    Returns:
        Pooled embeddings of shape (batch_size, hidden_dim); for "mean_max"
        the last dimension is 2 * hidden_dim.

    Raises:
        ValueError: If ``strategy`` is not one of the names above.
    """
    if strategy == "cls":
        # Position 0 holds the [CLS] token in BERT-style models.
        return token_embeddings[:, 0, :]

    if strategy == "mean_max":
        # Reuse the two basic strategies and stack them feature-wise.
        pooled_parts = [
            pool_embeddings(token_embeddings, attention_mask, part)
            for part in ("mean", "max")
        ]
        return torch.cat(pooled_parts, dim=-1)

    if strategy not in ("mean", "max"):
        raise ValueError(f"Unknown pooling strategy: {strategy}")

    # Broadcast the (batch, seq) mask across the hidden dimension.
    broadcast_mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size())

    if strategy == "mean":
        masked_total = (token_embeddings * broadcast_mask).sum(dim=1)
        # Clamp so an all-padding row does not divide by zero.
        token_counts = broadcast_mask.sum(dim=1).clamp(min=1e-9)
        return masked_total / token_counts

    # strategy == "max": work on a copy so the caller's tensor is untouched,
    # and push padding positions far below any real activation.
    scratch = token_embeddings.clone()
    scratch[broadcast_mask == 0] = -1e9
    return scratch.max(dim=1).values


def normalize_embeddings(embeddings: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """L2 normalize embeddings to unit length along the last dimension.

    Normalized embeddings allow using dot product for cosine similarity,
    which is faster and enables efficient nearest neighbor search.

    Args:
        embeddings: Tensor whose last dimension is the embedding dimension
        eps: Lower bound applied to each norm before dividing, so that an
            all-zero vector maps to zeros instead of NaN

    Returns:
        Tensor of the same shape with each vector scaled to unit L2 norm
        (zero vectors stay zero).
    """
    # Clamp the norm so zero vectors do not cause 0/0 = NaN, which would
    # otherwise propagate into every similarity computed from the result.
    norms = embeddings.norm(dim=-1, keepdim=True).clamp_min(eps)
    return embeddings / norms


def compute_similarity_matrix(
    query_embeddings: torch.Tensor,
    document_embeddings: torch.Tensor,
    normalize: bool = True
) -> torch.Tensor:
    """Score every query against every document with a dot product.

    Args:
        query_embeddings: (num_queries, hidden_dim)
        document_embeddings: (num_documents, hidden_dim)
        normalize: If True, both inputs are L2-normalized first, so the
            scores are cosine similarities; if False, raw dot products

    Returns:
        Similarity matrix of shape (num_queries, num_documents)
    """
    queries, documents = query_embeddings, document_embeddings
    if normalize:
        queries = normalize_embeddings(queries)
        documents = normalize_embeddings(documents)

    # (num_queries, dim) x (dim, num_documents)
    documents_transposed = documents.T
    return torch.mm(queries, documents_transposed)

Embedding Analogy Solver

Complete implementation for solving word analogies using embeddings:

import torch

def analogy(
    embeddings: torch.Tensor,
    word_to_idx: dict[str, int],
    idx_to_word: dict[int, str],
    a: str,
    b: str,
    c: str,
    top_k: int = 5
) -> list[str]:
    """Find words d closest to the vector a - b + c.

    In words: "a is to b as d is to c". The classic example is
    ``analogy(..., "king", "man", "woman")`` -> "queen", because
    king - man + woman lands near queen. This works because embedding
    spaces encode semantic relationships as geometric directions.

    Args:
        embeddings: (vocab_size, embedding_dim) tensor of word embeddings
        word_to_idx: Mapping from words to embedding indices
        idx_to_word: Mapping from indices to words
        a, b, c: Words combined as a - b + c
        top_k: Maximum number of candidates to return

    Returns:
        Up to top_k words ranked by cosine similarity to a - b + c. The
        input words a, b and c are never returned, so fewer than top_k
        words come back when the vocabulary is too small.

    Raises:
        KeyError: If a, b or c is not in word_to_idx.
    """
    # Get embeddings for input words
    vec_a = embeddings[word_to_idx[a]]
    vec_b = embeddings[word_to_idx[b]]
    vec_c = embeddings[word_to_idx[c]]

    # Remove b's direction from a, then add c's.
    target = vec_a - vec_b + vec_c

    # Unit-normalize both sides so the dot product is cosine similarity.
    # The clamp keeps zero-length vectors from producing NaN scores.
    target_norm = target / torch.norm(target).clamp_min(1e-12)
    all_norms = embeddings / embeddings.norm(dim=1, keepdim=True).clamp_min(1e-12)

    # Compute similarities with all words
    similarities = torch.matmul(all_norms, target_norm)

    # Exclude input words from results
    exclude_idx = {word_to_idx[w] for w in (a, b, c)}
    for idx in exclude_idx:
        similarities[idx] = -float('inf')

    # topk raises if k exceeds the number of entries, and any k beyond the
    # non-excluded count would hand back the -inf input words, so cap k.
    num_candidates = similarities.numel() - len(exclude_idx)
    k = min(top_k, num_candidates)

    # Return top-k most similar words
    top_indices = similarities.topk(k).indices
    return [idx_to_word[i.item()] for i in top_indices]


def find_similar_words(
    embeddings: torch.Tensor,
    word_to_idx: dict[str, int],
    idx_to_word: dict[int, str],
    query_word: str,
    top_k: int = 10
) -> list[tuple[str, float]]:
    """Find words most similar to a query word by cosine similarity.

    Args:
        embeddings: (vocab_size, embedding_dim) tensor
        word_to_idx: Mapping from words to indices
        idx_to_word: Mapping from indices to words
        query_word: Word to find similar words for
        top_k: Maximum number of similar words to return

    Returns:
        List of (word, similarity_score) tuples, best match first. The
        query word itself is never included, so at most vocab_size - 1
        entries are returned.

    Raises:
        KeyError: If query_word is not in word_to_idx.
    """
    query_idx = word_to_idx[query_word]
    query_vec = embeddings[query_idx]

    # Normalize for cosine similarity; the clamp prevents NaN scores when a
    # vector has zero length.
    query_norm = query_vec / torch.norm(query_vec).clamp_min(1e-12)
    all_norms = embeddings / embeddings.norm(dim=1, keepdim=True).clamp_min(1e-12)

    similarities = torch.matmul(all_norms, query_norm)

    # Exclude the query word itself
    similarities[query_idx] = -float('inf')

    # topk raises when k exceeds the vocabulary, and k == vocab_size would
    # return the masked query word with a -inf score, so cap k.
    k = min(top_k, similarities.numel() - 1)
    top_scores, top_indices = similarities.topk(k)

    return [
        (idx_to_word[idx.item()], score.item())
        for idx, score in zip(top_indices, top_scores)
    ]

Attention

Scaled Dot-Product Attention

Complete implementation of scaled dot-product attention:

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor = None,
    dropout: float = 0.0,
    training: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.

    Args:
        Q: Queries, shape (batch, num_heads, seq_len, d_k)
        K: Keys, shape (batch, num_heads, seq_len, d_k)
        V: Values, shape (batch, num_heads, seq_len, d_v)
        mask: Optional tensor broadcastable to the score matrix, e.g.
            (batch, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len);
            positions where it equals 0 are blocked
        dropout: Dropout probability for attention weights
        training: Dropout is only applied when this is True

    Returns:
        output: Attention output, shape (batch, num_heads, seq_len, d_v)
        attention_weights: Attention weights, shape (batch, num_heads, seq_len, seq_len)
    """
    # Dividing by sqrt(d_k) keeps the score variance independent of the head
    # size, so softmax does not saturate for large d_k.
    scale = math.sqrt(Q.shape[-1])

    # Entry (i, j) is the compatibility of query i with key j.
    logits = torch.matmul(Q, K.transpose(-2, -1)) / scale

    if mask is not None:
        # -inf turns into a weight of exactly 0 after softmax.
        blocked = mask == 0
        logits = logits.masked_fill(blocked, float('-inf'))

    # Rows sum to 1. A row that is entirely -inf yields NaN from softmax;
    # map those to 0 so fully masked queries attend to nothing.
    weights = F.softmax(logits, dim=-1).nan_to_num(0.0)

    # Regularize by randomly dropping attention links (training only).
    if dropout > 0.0 and training:
        weights = F.dropout(weights, p=dropout)

    # Each output row is the weight-averaged mix of the value vectors.
    mixed_values = torch.matmul(weights, V)
    return mixed_values, weights


def create_causal_mask(seq_length: int, device: torch.device = None) -> torch.Tensor:
    """Build the lower-triangular mask used for autoregressive attention.

    Row i has ones in columns 0..i and zeros afterwards, so position i may
    look at itself and earlier positions but never at later ones.

    Args:
        seq_length: Length of the sequence
        device: Device to create tensor on

    Returns:
        Lower triangular matrix of shape (1, 1, seq_len, seq_len)
    """
    all_ones = torch.ones(seq_length, seq_length, device=device)
    lower_triangle = all_ones.tril()
    # Two leading singleton axes broadcast over batch and heads.
    return lower_triangle[None, None]


def create_padding_mask(
    attention_mask: torch.Tensor
) -> torch.Tensor:
    """Reshape a per-token padding mask so it broadcasts over attention scores.

    Args:
        attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding

    Returns:
        Attention mask of shape (batch, 1, 1, seq_len) for broadcasting
    """
    # Insert singleton head and query axes right after the batch axis.
    return attention_mask[:, None, None]

Multi-Head Attention

Complete implementation of multi-head attention:

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism.

    Multiple attention heads allow the model to jointly attend to information
    from different representation subspaces at different positions.
    Different heads can learn different types of relationships:
    - Syntactic patterns (subject-verb agreement)
    - Semantic relationships (coreference)
    - Positional patterns (nearby words)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        self.dropout = dropout

        # Projections for Q, K, V (combined into single matrices for efficiency)
        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights with scaled initialization."""
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
        return_attention: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for multi-head attention.

        Args:
            query: (batch, seq_len, d_model)
            key: (batch, seq_len, d_model) - can be from different sequence for cross-attention
            value: (batch, seq_len, d_model)
            mask: Optional attention mask
            return_attention: Whether to return attention weights

        Returns:
            output: (batch, seq_len, d_model)
            attention_weights: Optional, (batch, num_heads, seq_len, seq_len)
        """
        batch_size = query.shape[0]

        # Project and reshape to (batch, num_heads, seq_len, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Compute attention for all heads in parallel
        attn_output, attn_weights = scaled_dot_product_attention(
            Q, K, V, mask,
            dropout=self.dropout,
            training=self.training
        )

        # Concatenate heads: (batch, seq_len, num_heads * d_k) = (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Final projection
        output = self.W_o(attn_output)

        if return_attention:
            return output, attn_weights
        return output


class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) - used in LLaMA 2, Mistral, etc.

    Instead of separate K, V projections per head, groups of query heads
    share the same K, V. This reduces KV cache memory during inference
    with minimal quality loss.

    - num_heads query heads
    - num_kv_heads key-value heads (num_heads must be divisible by num_kv_heads)
    - When num_kv_heads = 1, this is Multi-Query Attention (MQA)
    - When num_kv_heads = num_heads, this is standard Multi-Head Attention
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0
    ):
        """
        Args:
            d_model: Model dimension; must be divisible by num_heads
            num_heads: Number of query heads
            num_kv_heads: Number of key/value heads; must divide num_heads
            dropout: Dropout probability applied to the attention weights
                during training
        """
        super().__init__()
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        # Without this check a bad d_model only surfaces later as a shape
        # error inside view() during forward().
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.d_k = d_model // num_heads
        # Kept so forward() can actually apply the requested dropout.
        self.dropout = dropout

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        # K and V are narrower than Q: one d_k slice per KV head.
        self.W_k = nn.Linear(d_model, self.num_kv_heads * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, self.num_kv_heads * self.d_k, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Self-attention over ``x`` with shared key/value heads.

        Args:
            x: Input of shape (batch, seq_len, d_model)
            mask: Optional attention mask, passed through to the attention
                function

        Returns:
            Tensor of shape (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape

        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_kv_heads, self.d_k).transpose(1, 2)

        # Repeat each KV head so consecutive groups of num_queries_per_kv
        # query heads see the same keys and values.
        K = K.repeat_interleave(self.num_queries_per_kv, dim=1)
        V = V.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_output, _ = scaled_dot_product_attention(
            Q, K, V, mask,
            dropout=self.dropout,
            training=self.training
        )

        # Merge heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(attn_output)

Transformer Architecture

Transformer Block

Complete implementation of a transformer block:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerBlock(nn.Module):
    """One pre-norm transformer layer.

    Two sublayers, each wrapped as ``x + dropout(sublayer(layer_norm(x)))``:
    1. Multi-head self-attention
    2. Position-wise feed-forward network

    Normalizing *before* each sublayer (pre-norm) leaves the residual path
    untouched, which is more stable for training deep networks than
    post-norm.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "gelu"
    ):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
            d_ff: Hidden width of the feed-forward network
            dropout: Dropout probability, used inside attention, inside the
                feed-forward network and on both residual branches
            activation: Activation name forwarded to FeedForward
        """
        super().__init__()

        # Sublayer 1: self-attention and its layer norm
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Sublayer 2: feed-forward network and its layer norm
        self.feed_forward = FeedForward(d_model, d_ff, dropout, activation)
        self.norm2 = nn.LayerNorm(d_model)

        # Shared dropout for the two residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (batch, seq_len, d_model)
            mask: Optional attention mask

        Returns:
            Output tensor of shape (batch, seq_len, d_model)
        """
        # Self-attention: the normalized input serves as query, key and value.
        attn_in = self.norm1(x)
        attn_branch = self.attention(attn_in, attn_in, attn_in, mask)
        x = x + self.dropout(attn_branch)

        # Feed-forward on the normalized result of the attention sublayer.
        ffn_branch = self.feed_forward(self.norm2(x))
        return x + self.dropout(ffn_branch)


class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    FFN(x) = dropout(activation(xW1 + b1))W2 + b2

    Applied independently at every position. The hidden dimension is
    typically 4x the model dimension.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: str = "gelu"
    ):
        """
        Args:
            d_model: Input and output dimension
            d_ff: Hidden dimension
            dropout: Dropout probability applied after the activation
            activation: "gelu", "relu" or "silu"

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

        # Look the activation up by name; fall through to an error if no
        # entry matches.
        supported = (("gelu", nn.GELU), ("relu", nn.ReLU), ("silu", nn.SiLU))
        for name, factory in supported:
            if activation == name:
                self.activation = factory()
                break
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to d_ff, apply activation and dropout, project back to d_model."""
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)


class SwiGLU(nn.Module):
    """SwiGLU feed-forward layer - used in LLaMA and other modern architectures.

    SwiGLU(x) = (SiLU(xW_gate) * xW1) W2

    One projection (W_gate) passes through SiLU and multiplicatively gates a
    second projection (W1) of the same input; W2 maps the gated hidden state
    back to the model dimension.
    """

    def __init__(self, d_model: int, d_ff: int, bias: bool = False):
        """
        Args:
            d_model: Input and output dimension
            d_ff: Hidden dimension of the gate and value projections
            bias: Whether the three linear layers carry a bias term
        """
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=bias)
        self.w2 = nn.Linear(d_ff, d_model, bias=bias)
        self.w_gate = nn.Linear(d_model, d_ff, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to the last dimension of ``x``."""
        gate = F.silu(self.w_gate(x))
        candidate = self.w1(x)
        return self.w2(gate * candidate)

Complete Transformer Language Model

Complete implementation of a decoder-only transformer language model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerLM(nn.Module):
    """Decoder-only transformer language model (GPT-style).

    This is the architecture used by GPT, LLaMA, and most modern LLMs.
    It uses causal (autoregressive) attention so each position can only
    attend to itself and previous positions. Positions are encoded with a
    learned embedding table of size ``max_seq_length``.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        num_heads: int,
        num_layers: int,
        d_ff: int,
        max_seq_length: int,
        dropout: float = 0.1,
        tie_weights: bool = True
    ):
        """
        Args:
            vocab_size: Number of tokens in the vocabulary
            d_model: Model (embedding) dimension
            num_heads: Attention heads per transformer block
            num_layers: Number of transformer blocks
            d_ff: Hidden width of each block's feed-forward network
            max_seq_length: Longest sequence forward() accepts (size of the
                learned position table)
            dropout: Dropout probability for embeddings and blocks
            tie_weights: Share one weight matrix between the token embedding
                and the output projection
        """
        super().__init__()

        self.d_model = d_model
        self.max_seq_length = max_seq_length

        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_length, d_model)
        self.embedding_dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm and output projection
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: share weights between token embedding and output projection
        # This reduces parameters and often improves performance
        if tie_weights:
            self.lm_head.weight = self.token_embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize one submodule (called for every submodule via apply()).

        Linear and Embedding weights get N(0, 0.02); Linear biases are
        zeroed; LayerNorm is reset to weight 1, bias 0.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Forward pass for language modeling.

        Args:
            token_ids: Input token IDs, shape (batch, seq_len)
            attention_mask: Optional mask for padding, shape (batch, seq_len),
                nonzero for real tokens and 0 for padding

        Returns:
            Logits over vocabulary, shape (batch, seq_len, vocab_size)

        Raises:
            ValueError: If seq_len exceeds max_seq_length.
        """
        _, seq_length = token_ids.shape
        device = token_ids.device

        # The position table only has max_seq_length rows; fail with a clear
        # message instead of an opaque embedding index error.
        if seq_length > self.max_seq_length:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_seq_length "
                f"{self.max_seq_length}"
            )

        # Get embeddings
        positions = torch.arange(seq_length, device=device)
        x = self.token_embedding(token_ids) + self.position_embedding(positions)
        x = self.embedding_dropout(x)

        # Create causal mask (each position attends to itself and earlier positions)
        causal_mask = torch.tril(
            torch.ones(seq_length, seq_length, device=device)
        ).unsqueeze(0).unsqueeze(0)

        # Combine with padding mask if provided
        if attention_mask is not None:
            # attention_mask: (batch, seq_len) -> (batch, 1, 1, seq_len)
            padding_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # Product is nonzero only where both masks allow attention.
            mask = causal_mask * padding_mask
        else:
            mask = causal_mask

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final norm and output projection
        x = self.norm(x)
        logits = self.lm_head(x)  # (batch, seq_len, vocab_size)

        return logits

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: int = 50
    ) -> torch.Tensor:
        """Generate text autoregressively.

        The model is switched to eval mode while generating and returned to
        its previous train/eval mode afterwards.

        Args:
            input_ids: Starting token IDs, shape (batch, seq_len)
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature; values <= 0 select greedy
                decoding
            top_p: Nucleus sampling threshold (values >= 1.0 disable it)
            top_k: Top-k sampling limit (values <= 0 disable it; values above
                the vocabulary size are capped)

        Returns:
            Generated token IDs including input, shape
            (batch, seq_len + max_new_tokens)
        """
        was_training = self.training
        self.eval()

        try:
            for _ in range(max_new_tokens):
                # Feed only the most recent max_seq_length tokens to the
                # model, but keep the full sequence so the prompt and early
                # tokens are still part of the returned tensor.
                context = input_ids[:, -self.max_seq_length:]

                # Forward pass
                with torch.no_grad():
                    logits = self(context)
                    next_token_logits = logits[:, -1, :]  # Last position only

                if temperature > 0:
                    # Apply temperature
                    next_token_logits = next_token_logits / temperature

                    # Top-k filtering: drop everything below the k-th best logit
                    if top_k > 0:
                        # torch.topk raises if k exceeds the vocabulary size.
                        k = min(top_k, next_token_logits.size(-1))
                        kth_best = torch.topk(next_token_logits, k)[0][..., -1, None]
                        next_token_logits = next_token_logits.masked_fill(
                            next_token_logits < kth_best, float('-inf')
                        )

                    # Top-p (nucleus) filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right by one so the token that crosses the
                        # threshold is kept, and always keep the top token.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = False
                        # Map the removal flags from sorted order back to
                        # vocabulary order.
                        indices_to_remove = sorted_indices_to_remove.scatter(
                            1, sorted_indices, sorted_indices_to_remove
                        )
                        next_token_logits = next_token_logits.masked_fill(
                            indices_to_remove, float('-inf')
                        )

                    # Sample
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    # Greedy decoding
                    next_token = next_token_logits.argmax(dim=-1, keepdim=True)

                input_ids = torch.cat([input_ids, next_token], dim=-1)
        finally:
            # Do not leave a model that was being trained stuck in eval mode.
            self.train(was_training)

        return input_ids


def count_parameters(model: nn.Module) -> dict:
    """Tally trainable parameters per component, based on parameter names.

    Each parameter with ``requires_grad`` set is assigned to the first
    category whose marker substring occurs in its name, checked in this
    order: embeddings, attention, feed_forward, layer_norm, lm_head;
    anything unmatched goes to 'other'.

    Returns:
        Dict mapping each category that occurred to its parameter count,
        plus a 'total' entry with the overall sum.
    """
    # Order matters: the first rule with a matching marker wins.
    rules = (
        ('embeddings', ('embedding',)),
        ('attention', ('attention', 'W_q', 'W_k', 'W_v', 'W_o')),
        ('feed_forward', ('feed_forward', 'linear')),
        ('layer_norm', ('norm',)),
        ('lm_head', ('lm_head',)),
    )

    def classify(param_name: str) -> str:
        for label, markers in rules:
            if any(marker in param_name for marker in markers):
                return label
        return 'other'

    tally = {}
    grand_total = 0

    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue  # frozen parameters are not counted

        size = tensor.numel()
        bucket = classify(param_name)
        tally[bucket] = tally.get(bucket, 0) + size
        grand_total += size

    tally['total'] = grand_total
    return tally

Inference and Generation