Backend Engineering Code Reference

Complete code implementations for Chapter 50: Backend Engineering for AI.


Table of Contents


Fine-Tuning Patterns

Training Data Validation

Complete implementation for validating fine-tuning datasets:

import json
from collections import Counter

def validate_training_data(file_path: str) -> dict:
    """Validate fine-tuning dataset quality.

    Reads a JSONL file (one JSON value per line), checks the structure of
    every example, and summarises duplicates and assistant-response lengths.
    Malformed examples are reported in ``issues`` rather than raising.

    Args:
        file_path: Path to the JSONL training file (read as UTF-8).

    Returns:
        A dict with:
          * ``total_examples``: number of lines that parsed as JSON.
          * ``issues``: problem descriptions, each prefixed with ``Line N:``.
          * ``duplicate_count``: number of distinct message lists that occur
            more than once.
          * ``avg_assistant_length`` / ``min_assistant_length`` /
            ``max_assistant_length``: character lengths of assistant
            messages, or 0 when there are none.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    issues = []
    examples = []
    # Only structurally sound message lists feed the duplicate and length
    # statistics; examples that fail the structure checks would otherwise
    # crash those steps.
    valid_messages = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            # Keep the try body to the one call that can raise.
            try:
                example = json.loads(line)
            except json.JSONDecodeError as e:
                issues.append(f"Line {line_num}: Invalid JSON - {e}")
                continue

            examples.append(example)

            # Check structure
            if not isinstance(example, dict) or 'messages' not in example:
                issues.append(f"Line {line_num}: Missing 'messages' key")
                continue

            messages = example['messages']
            if not isinstance(messages, list) or not messages:
                issues.append(
                    f"Line {line_num}: 'messages' must be a non-empty list"
                )
                continue
            if not all(isinstance(m, dict) for m in messages):
                issues.append(
                    f"Line {line_num}: Every message must be a JSON object"
                )
                continue

            valid_messages.append(messages)

            # Check roles
            if messages[-1].get('role') != 'assistant':
                issues.append(f"Line {line_num}: Last message should be assistant")

            # Check for empty content (a missing or null content counts as empty)
            for i, msg in enumerate(messages):
                content = msg.get('content')
                if content is None or (isinstance(content, str) and not content.strip()):
                    issues.append(f"Line {line_num}: Empty content in message {i}")
                elif not isinstance(content, str):
                    issues.append(f"Line {line_num}: Non-string content in message {i}")

    # Check for duplicates
    content_counts = Counter(json.dumps(messages) for messages in valid_messages)
    duplicates = [item for item, count in content_counts.items() if count > 1]

    # Analyze distribution
    assistant_lengths = []
    for messages in valid_messages:
        for msg in messages:
            if msg.get('role') == 'assistant':
                content = msg.get('content')
                # Non-text content contributes a length of 0.
                assistant_lengths.append(len(content) if isinstance(content, str) else 0)

    return {
        'total_examples': len(examples),
        'issues': issues,
        'duplicate_count': len(duplicates),
        'avg_assistant_length': sum(assistant_lengths) / len(assistant_lengths) if assistant_lengths else 0,
        'min_assistant_length': min(assistant_lengths) if assistant_lengths else 0,
        'max_assistant_length': max(assistant_lengths) if assistant_lengths else 0,
    }

Testing LLM Applications

Statistical Test Suite

Complete implementation for statistical testing of LLM outputs:

import numpy as np
from scipy import stats

class LLMTestSuite:
    """Statistical checks for non-deterministic LLM outputs.

    Every test calls ``llm_client.generate(prompt, ...)``.

    NOTE(review): ``parse_classification`` and ``calculate_similarity`` are
    called on ``self`` but are not defined in this class; presumably a
    subclass or mixin provides them - confirm before using the suite as-is.
    """

    def __init__(self, llm_client, num_samples: int = 30):
        """Store the client and the number of generations drawn per test."""
        self.client = llm_client
        self.num_samples = num_samples

    def test_classification_accuracy(
        self,
        prompt_template: str,
        test_cases: list[tuple[str, str]],  # (input, expected_label)
        min_accuracy: float = 0.95
    ) -> dict:
        """Test classification prompt achieves minimum accuracy.

        Args:
            prompt_template: Template formatted with ``input=<case input>``.
            test_cases: ``(input, expected_label)`` pairs; must be non-empty.
            min_accuracy: Threshold compared against the lower bound of the
                95% Wilson confidence interval.

        Returns:
            Dict with ``accuracy``, ``ci_95`` (low, high), ``passed`` and the
            per-case ``results``.

        Raises:
            ValueError: If ``test_cases`` is empty.
        """
        if not test_cases:
            raise ValueError("test_cases must not be empty")

        correct = 0
        results = []

        for input_text, expected in test_cases:
            prompt = prompt_template.format(input=input_text)
            response = self.client.generate(prompt)
            predicted = self.parse_classification(response)

            is_correct = bool(predicted == expected)
            correct += is_correct
            results.append({
                'input': input_text,
                'expected': expected,
                'predicted': predicted,
                'correct': is_correct
            })

        total = len(test_cases)
        accuracy = correct / total

        # Calculate confidence interval. scipy.stats has no
        # ``proportion_confint`` (that function lives in statsmodels); the
        # Wilson interval is available through ``binomtest``.
        interval = stats.binomtest(correct, total).proportion_ci(
            confidence_level=0.95, method='wilson'
        )
        ci_low, ci_high = float(interval.low), float(interval.high)

        return {
            'accuracy': accuracy,
            'ci_95': (ci_low, ci_high),
            'passed': ci_low >= min_accuracy,  # Conservative: use lower bound
            'results': results
        }

    def test_output_consistency(
        self,
        prompt: str,
        consistency_threshold: float = 0.8
    ) -> dict:
        """Test that repeated calls produce consistent outputs.

        Generates ``num_samples`` responses at ``temperature=0.7`` and scores
        every unordered pair with ``calculate_similarity``.

        Returns:
            Dict with ``avg_similarity``, ``min_similarity``,
            ``max_similarity``, ``passed`` and up to five
            ``sample_responses``.

        Raises:
            ValueError: If ``num_samples`` is below 2 (no pair to compare).
        """
        if self.num_samples < 2:
            raise ValueError("num_samples must be at least 2 to compare responses")

        responses = [
            self.client.generate(prompt, temperature=0.7)
            for _ in range(self.num_samples)
        ]

        # Calculate pairwise similarity over each unordered pair once
        similarities = [
            self.calculate_similarity(first, second)
            for i, first in enumerate(responses)
            for second in responses[i + 1:]
        ]

        # Plain float/bool so the result stays JSON-serialisable.
        avg_similarity = float(np.mean(similarities))

        return {
            'avg_similarity': avg_similarity,
            'min_similarity': min(similarities),
            'max_similarity': max(similarities),
            'passed': bool(avg_similarity >= consistency_threshold),
            'sample_responses': responses[:5]
        }

    def test_format_compliance(
        self,
        prompt: str,
        validator: callable,
        min_compliance: float = 0.99
    ) -> dict:
        """Test that outputs comply with expected format.

        Args:
            prompt: Prompt sent ``num_samples`` times.
            validator: Called with each response; any exception it raises
                marks that response as non-compliant.
            min_compliance: Minimum fraction of compliant responses.

        Returns:
            Dict with ``compliance_rate``, ``passed`` and up to five
            ``failures`` (response plus error text).

        Raises:
            ValueError: If ``num_samples`` is below 1.
        """
        if self.num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        compliant = 0
        failures = []

        for _ in range(self.num_samples):
            response = self.client.generate(prompt)
            try:
                validator(response)
            except Exception as e:
                # The validator's exception type is caller-defined, so any
                # exception counts as a format failure.
                failures.append({'response': response, 'error': str(e)})
            else:
                compliant += 1

        compliance_rate = compliant / self.num_samples

        return {
            'compliance_rate': compliance_rate,
            'passed': compliance_rate >= min_compliance,
            'failures': failures[:5]  # Sample of failures for debugging
        }

Regression Test Suite

Complete implementation for detecting quality regressions:

import json

class RegressionTestSuite:
    """Compare a model's metrics against a stored JSON baseline.

    ``evaluate_model`` is called on ``self`` but is not defined in this
    class; it has to be supplied elsewhere (for example by a subclass).
    """

    def __init__(self, baseline_results_path: str):
        self.baseline = self.load_baseline(baseline_results_path)

    def load_baseline(self, path: str) -> dict:
        """Read the baseline metrics saved by an earlier successful run."""
        with open(path, 'r') as handle:
            stored_metrics = json.load(handle)
        return stored_metrics

    def run_regression_tests(
        self,
        current_model,
        test_cases: list[dict],
        tolerance: float = 0.02  # Absolute change allowed in either direction
    ) -> dict:
        """Evaluate the current model and sort metric changes into buckets.

        A metric whose value dropped by more than ``tolerance`` is a
        regression; one that rose by more than ``tolerance`` is an
        improvement. Metrics absent from the current results count as 0.
        The run passes only when there are no regressions.
        """
        current_results = self.evaluate_model(current_model, test_cases)

        regressions, improvements = [], []

        for metric in self.baseline:
            reference = self.baseline[metric]
            observed = current_results.get(metric, 0)
            change = observed - reference
            record = {
                'metric': metric,
                'baseline': reference,
                'current': observed,
                'delta': change
            }

            if change < -tolerance:
                regressions.append(record)
            elif change > tolerance:
                improvements.append(record)

        return {
            'passed': not regressions,
            'regressions': regressions,
            'improvements': improvements,
            'current_results': current_results
        }

    def update_baseline(self, results: dict, path: str):
        """Overwrite the baseline file with ``results`` as indented JSON."""
        with open(path, 'w') as handle:
            json.dump(results, handle, indent=2)

Contract Test Suite

Complete implementation for API contract testing:

from pydantic import BaseModel, ValidationError
from typing import Optional

class ClassificationResponse(BaseModel):
    """Pydantic schema describing a classification response payload."""
    # Required string field.
    category: str
    # Required float field; no range constraint is declared here.
    confidence: float
    # Optional field; defaults to None when absent.
    reasoning: Optional[str] = None

class ContractTestSuite:
    """Contract tests that check API responses against a pydantic schema.

    The client only needs a ``post(endpoint, json=...)`` method returning an
    object with ``json()`` and ``status_code``.
    """

    def __init__(self, api_client):
        self.client = api_client

    @staticmethod
    def _decode_json(response):
        """Return ``(payload, None)`` or ``(None, error_text)`` for a non-JSON body."""
        try:
            return response.json(), None
        except ValueError as e:
            # JSON decode errors from the common HTTP clients subclass ValueError.
            return None, f"Response body is not valid JSON: {e}"

    def test_response_contract(
        self,
        endpoint: str,
        test_inputs: list[dict],
        response_schema: type[BaseModel]
    ) -> dict:
        """Test that API responses conform to schema.

        Args:
            endpoint: Endpoint passed to ``client.post``.
            test_inputs: JSON bodies to send, one request each.
            response_schema: Pydantic model each response must validate against.

        Returns:
            Dict with ``total``, ``valid``, ``invalid``, ``compliance_rate``
            (0.0 when there are no inputs) and the list of ``failures``.
            A response whose body is not JSON is recorded as a failure with
            ``response`` set to ``None``.
        """
        results = []

        for test_input in test_inputs:
            response = self.client.post(endpoint, json=test_input)

            # Parse the body once and reuse it for validation and reporting.
            payload, decode_error = self._decode_json(response)
            if decode_error is not None:
                results.append({
                    'input': test_input,
                    'valid': False,
                    'error': decode_error,
                    'response': None
                })
                continue

            try:
                # Validate response structure
                response_schema.model_validate(payload)
            except ValidationError as e:
                results.append({
                    'input': test_input,
                    'valid': False,
                    'error': str(e),
                    'response': payload
                })
            else:
                results.append({
                    'input': test_input,
                    'valid': True,
                    'response': payload
                })

        total = len(results)
        valid_count = sum(1 for r in results if r['valid'])

        return {
            'total': total,
            'valid': valid_count,
            'invalid': total - valid_count,
            # Guard the empty case instead of dividing by zero.
            'compliance_rate': valid_count / total if total else 0.0,
            'failures': [r for r in results if not r['valid']]
        }

    def test_error_handling(self, endpoint: str) -> list[dict]:
        """Test API error handling for edge cases.

        Posts a fixed set of malformed or unusual bodies and records, for
        each, the status code, whether the JSON body is an object containing
        an ``error`` key, and the decoded body (``None`` if it is not JSON).

        Returns:
            One result dict per edge case, in the order they were sent.
        """
        edge_cases = [
            {'input': ''},  # Empty input
            {'input': None},  # Null input
            {'input': 'x' * 100000},  # Very long input
            {},  # Missing required field
            {'input': 'valid', 'extra_field': 'ignored'},  # Extra fields
        ]

        results = []
        for case in edge_cases:
            response = self.client.post(endpoint, json=case)
            # Error responses are the ones most likely to carry a non-JSON
            # body, so a decode failure must not abort the run.
            payload, _ = self._decode_json(response)

            results.append({
                'case': case,
                'status_code': response.status_code,
                'has_error_message': isinstance(payload, dict) and 'error' in payload,
                'response': payload
            })

        return results

Snapshot Tester

Complete implementation for snapshot testing:

import difflib
from pathlib import Path

class SnapshotTester:
    """Compare LLM output against snapshots stored as ``<test_name>.txt`` files.

    Snapshots are stored byte-exact as UTF-8. Going through text mode would
    rewrite line endings on read, so an output containing ``\\r\\n`` could
    never match its own snapshot, and the encoding would depend on the
    machine's locale.
    """

    def __init__(self, snapshot_dir: str):
        """Remember the snapshot directory, creating it if necessary."""
        self.snapshot_dir = Path(snapshot_dir)
        self.snapshot_dir.mkdir(parents=True, exist_ok=True)

    def test_or_update(
        self,
        test_name: str,
        prompt: str,
        llm_client,
        update: bool = False
    ) -> dict:
        """Test output against snapshot, or update snapshot.

        Args:
            test_name: Snapshot file stem inside the snapshot directory.
            prompt: Prompt passed to ``llm_client.generate``.
            llm_client: Object exposing ``generate(prompt, temperature=0)``.
            update: When true, overwrite the snapshot with the new output.

        Returns:
            ``{'status': 'updated', 'output': ...}`` when the snapshot was
            written (requested, or none existed yet); ``{'status': 'passed'}``
            on an exact match; otherwise ``{'status': 'failed', ...}`` with
            ``expected``, ``current`` and a unified ``diff``.
        """
        snapshot_path = self.snapshot_dir / f"{test_name}.txt"

        # Generate current output (temperature 0 requested for repeatability)
        current_output = llm_client.generate(prompt, temperature=0)

        if update or not snapshot_path.exists():
            # Create or update snapshot
            snapshot_path.write_bytes(current_output.encode('utf-8'))
            return {
                'status': 'updated',
                'output': current_output
            }

        # Compare with snapshot
        expected_output = snapshot_path.read_bytes().decode('utf-8')

        if current_output == expected_output:
            return {'status': 'passed'}

        # Generate diff for debugging
        diff = difflib.unified_diff(
            expected_output.splitlines(keepends=True),
            current_output.splitlines(keepends=True),
            fromfile='expected',
            tofile='current'
        )

        return {
            'status': 'failed',
            'expected': expected_output,
            'current': current_output,
            'diff': ''.join(diff)
        }

Debugging LLM Applications

Debug Logger

Complete implementation for LLM interaction logging:

import logging
import json
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class LLMInteraction:
    """One request/response exchange with an LLM, as written to the debug log."""
    timestamp: str
    request_id: str
    model: str
    system_prompt: str
    user_message: str
    full_prompt: str  # The complete prompt sent to API
    prompt_tokens: int
    response: str
    completion_tokens: int
    latency_ms: float
    temperature: float
    # The three optional fields default to None.
    retrieved_documents: Optional[list[dict]] = None
    tool_calls: Optional[list[dict]] = None
    error: Optional[str] = None  # a non-empty value marks the interaction as failed

class LLMDebugLogger:
    """Append LLM interactions to a JSONL file and search them for failures."""

    def __init__(self, log_file: str = "llm_interactions.jsonl"):
        self.log_file = log_file
        self.logger = logging.getLogger("llm_debug")

    def log_interaction(self, interaction: LLMInteraction):
        """Log complete interaction for debugging.

        Appends the interaction as one JSON object on its own line.
        """
        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(asdict(interaction)) + '\n')

    def create_interaction(
        self,
        request_id: str,
        model: str,
        system_prompt: str,
        user_message: str,
        **kwargs
    ) -> LLMInteraction:
        """Create interaction record with timestamp.

        ``kwargs`` must supply the remaining required ``LLMInteraction``
        fields (``full_prompt``, ``prompt_tokens``, ``response``,
        ``completion_tokens``, ``latency_ms``, ``temperature``) and may
        supply the optional ones.

        Raises:
            TypeError: If a required field is missing or an unknown keyword
                is passed.
        """
        from datetime import timezone

        # ``datetime.utcnow()`` is deprecated; take an aware UTC time and drop
        # the tzinfo so the stored string keeps the same naive-UTC format.
        utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
        return LLMInteraction(
            timestamp=utc_now.isoformat(),
            request_id=request_id,
            model=model,
            system_prompt=system_prompt,
            user_message=user_message,
            **kwargs
        )

    def find_similar_failures(self, current_input: str, n: int = 5) -> list[dict]:
        """Find past interactions with similar inputs that failed.

        Scans the log file for records with a non-empty ``error`` and ranks
        them by textual similarity between their ``user_message`` and
        ``current_input``.

        Args:
            current_input: The input to compare past failures against.
            n: Maximum number of records to return.

        Returns:
            Up to ``n`` logged records (as dicts), most similar first; ties
            keep log order. Empty when the log file does not exist or holds
            no failures.
        """
        from difflib import SequenceMatcher

        scored = []
        try:
            with open(self.log_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        # One corrupt line should not hide the rest of the log.
                        self.logger.warning("Skipping malformed line in %s", self.log_file)
                        continue
                    if not isinstance(record, dict) or not record.get('error'):
                        continue
                    past_input = record.get('user_message') or ''
                    score = SequenceMatcher(None, current_input, past_input).ratio()
                    scored.append((score, record))
        except FileNotFoundError:
            return []

        # list.sort is stable, so equally similar records stay in log order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [record for _, record in scored[:max(n, 0)]]

Data Pipeline Architecture

Training Data Pipeline

Complete implementation for preparing fine-tuning data:

from dataclasses import dataclass
from typing import Iterator
import hashlib
import json
import random

@dataclass
class TrainingExample:
    id: str
    messages: list[dict]
    metadata: dict

class TrainingDataPipeline:
    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.seen_hashes = set()

    def process(self, raw_data: Iterator[dict]) -> None:
        """Process raw data through the pipeline."""
        examples = []

        for item in raw_data:
            # Extract
            extracted = self.extract(item)
            if not extracted:
                continue

            # Clean
            cleaned = self.clean(extracted)
            if not cleaned:
                continue

            # Deduplicate
            content_hash = self.hash_content(cleaned)
            if content_hash in self.seen_hashes:
                continue
            self.seen_hashes.add(content_hash)

            # Format
            formatted = self.format_for_training(cleaned)

            # Validate
            if self.validate(formatted):
                examples.append(formatted)

        # Split and save
        self.split_and_save(examples)

    def extract(self, item: dict) -> dict | None:
        """Extract relevant fields from raw data."""
        try:
            return {
                'input': item['user_query'],
                'output': item['assistant_response'],
                'context': item.get('context', ''),
                'metadata': {
                    'source': item.get('source'),
                    'timestamp': item.get('timestamp')
                }
            }
        except KeyError:
            return None

    def clean(self, item: dict) -> dict | None:
        """Clean and filter data."""
        # Remove low-quality examples
        if len(item['output']) < 10:
            return None

        # Clean text
        item['input'] = item['input'].strip()
        item['output'] = item['output'].strip()

        # Remove PII (simplified)
        item['input'] = self.remove_pii(item['input'])
        item['output'] = self.remove_pii(item['output'])

        return item

    def hash_content(self, item: dict) -> str:
        """Create hash for deduplication."""
        content = f"{item['input']}|||{item['output']}"
        return hashlib.sha256(content.encode()).hexdigest()

    def format_for_training(self, item: dict) -> TrainingExample:
        """Convert to training format."""
        messages = [
            {"role": "user", "content": item['input']},
            {"role": "assistant", "content": item['output']}
        ]

        if item.get('context'):
            messages.insert(0, {
                "role": "system",
                "content": f"Context: {item['context']}"
            })

        return TrainingExample(
            id=self.hash_content(item)[:12],
            messages=messages,
            metadata=item['metadata']
        )

    def validate(self, example: TrainingExample) -> bool:
        """Validate example meets requirements."""
        # Check structure
        if not example.messages:
            return False

        # Check token count
        total_tokens = sum(
            count_tokens(m['content']) for m in example.messages
        )
        if total_tokens > 4096:  # Max sequence length
            return False

        return True

    def split_and_save(
        self,
        examples: list[TrainingExample],
        train_ratio: float = 0.8,
        val_ratio: float = 0.1
    ):
        """Split data and save to files."""
        random.shuffle(examples)

        n = len(examples)
        train_end = int(n * train_ratio)
        val_end = train_end + int(n * val_ratio)

        splits = {
            'train': examples[:train_end],
            'val': examples[train_end:val_end],
            'test': examples[val_end:]
        }

        for split_name, split_examples in splits.items():
            path = f"{self.output_dir}/{split_name}.jsonl"
            with open(path, 'w') as f:
                for ex in split_examples:
                    f.write(json.dumps({
                        'id': ex.id,
                        'messages': ex.messages
                    }) + '\n')

Document Ingestion Pipeline

Complete implementation for RAG document ingestion:

class DocumentIngestionPipeline:
    """Extract, chunk, embed and store documents for retrieval.

    NOTE(review): ``SemanticChunker`` is not defined in the visible source;
    it must be importable where this class is used.
    """

    def __init__(self, vector_store, embedding_model):
        self.vector_store = vector_store
        self.embedder = embedding_model
        self.chunker = SemanticChunker(chunk_size=512, overlap=50)

    async def ingest_document(self, doc_path: str) -> dict:
        """Ingest a document into the vector store.

        Returns:
            Dict with the ``doc_path``, the ``num_chunks`` produced and the
            stored ``chunk_ids``.
        """
        # Extract
        text, metadata = await self.extract(doc_path)

        # Chunk
        chunks = self.chunker.chunk(text)

        # Embed
        embeddings = await self.embed_batch(chunks)

        # Store
        ids = await self.store(chunks, embeddings, metadata)

        return {
            'doc_path': doc_path,
            'num_chunks': len(chunks),
            'chunk_ids': ids
        }

    async def extract(self, doc_path: str) -> tuple[str, dict]:
        """Extract text and metadata from document.

        This is an extension point: override it with an extractor suited to
        the file type. The returned metadata must contain ``doc_id``, which
        ``store`` uses to build chunk ids.

        Raises:
            NotImplementedError: Always, until overridden. Returning ``None``
                here would only surface later as an unpacking error in
                ``ingest_document``.
        """
        raise NotImplementedError(
            "extract() must be implemented for the document types being ingested"
        )

    async def embed_batch(
        self,
        texts: list[str],
        batch_size: int = 32
    ) -> list[list[float]]:
        """Embed texts in batches for efficiency.

        Args:
            texts: Texts to embed.
            batch_size: Number of texts sent to the embedder per call.

        Returns:
            One embedding per text, in input order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        embeddings = []

        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            batch_embeddings = await self.embedder.embed(batch)
            embeddings.extend(batch_embeddings)

        return embeddings

    async def store(
        self,
        chunks: list[str],
        embeddings: list[list[float]],
        metadata: dict
    ) -> list[str]:
        """Store chunks with embeddings in vector store.

        Each chunk is upserted under the id ``<doc_id>_chunk_<index>`` with
        the document metadata plus its ``chunk_index``.

        Returns:
            The chunk ids, in chunk order.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
            KeyError: If ``metadata`` has no ``doc_id``.
        """
        # Check up front: zip() would silently drop the unmatched tail, and
        # failing midway would leave a partially stored document.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"Got {len(chunks)} chunks but {len(embeddings)} embeddings"
            )

        doc_id = metadata['doc_id']
        ids = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            chunk_id = f"{doc_id}_chunk_{i}"
            await self.vector_store.upsert(
                id=chunk_id,
                embedding=embedding,
                text=chunk,
                metadata={**metadata, 'chunk_index': i}
            )
            ids.append(chunk_id)

        return ids

Integration Patterns

Circuit Breaker Pattern

Complete implementation for protecting against LLM API failures:

from datetime import datetime, timedelta
from enum import Enum

class CircuitState(Enum):
    """States of the circuit breaker."""
    CLOSED = "closed"      # Normal operation
    OPEN = "open"          # Blocking calls
    HALF_OPEN = "half_open"  # Testing recovery

class LLMCircuitBreaker:
    """Circuit breaker that stops calling a failing LLM API for a while.

    The breaker opens after ``failure_threshold`` consecutive failures,
    blocks calls until ``recovery_timeout`` has passed since the last
    failure, then lets trial calls through (half-open) and closes again
    after ``half_open_max_calls`` successes.

    NOTE(review): ``CircuitOpenError`` is raised by ``execute`` but is not
    defined in the visible source; it must be defined or imported elsewhere.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: timedelta = timedelta(seconds=30),
        half_open_max_calls: int = 3
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls

        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.last_failure_time = None
        self.half_open_calls = 0

    def can_execute(self) -> bool:
        """Check if call should proceed.

        As a side effect, an open breaker whose recovery timeout has passed
        is moved to half-open.
        """
        if self.state == CircuitState.CLOSED:
            return True

        if self.state == CircuitState.OPEN:
            # Check if recovery timeout has passed
            if datetime.now() - self.last_failure_time > self.recovery_timeout:
                self.state = CircuitState.HALF_OPEN
                self.half_open_calls = 0
                return True
            return False

        if self.state == CircuitState.HALF_OPEN:
            return self.half_open_calls < self.half_open_max_calls

        return False

    def record_success(self):
        """Record successful call."""
        if self.state == CircuitState.HALF_OPEN:
            self.half_open_calls += 1
            if self.half_open_calls >= self.half_open_max_calls:
                # Recovery successful
                self.state = CircuitState.CLOSED
                self.failure_count = 0
        elif self.state == CircuitState.CLOSED:
            # A success ends the current failure streak. Without this reset,
            # isolated failures spread over the process lifetime would add up
            # and eventually trip the breaker on a healthy service.
            self.failure_count = 0

    def record_failure(self):
        """Record failed call."""
        self.failure_count += 1
        self.last_failure_time = datetime.now()

        if self.state == CircuitState.HALF_OPEN:
            # Recovery failed
            self.state = CircuitState.OPEN
        elif self.failure_count >= self.failure_threshold:
            self.state = CircuitState.OPEN

    async def execute(self, func, *args, **kwargs):
        """Execute function with circuit breaker protection.

        Args:
            func: Async callable to invoke with ``*args`` and ``**kwargs``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitOpenError: If the breaker is currently blocking calls.
            Exception: Any exception from ``func`` is recorded as a failure
                and re-raised unchanged.
        """
        if not self.can_execute():
            raise CircuitOpenError("Circuit breaker is open")

        try:
            result = await func(*args, **kwargs)
        except Exception:
            self.record_failure()
            raise

        self.record_success()
        return result

Queue-Based Processing

Complete implementation for asynchronous LLM job processing:

import redis
from rq import Queue
import json

class LLMJobProcessor:
    """Queue LLM requests on an RQ queue and report job status."""

    def __init__(self, redis_url: str, llm_service):
        self.redis_conn = redis.from_url(redis_url)
        self.queue = Queue('llm_jobs', connection=self.redis_conn)
        self.llm = llm_service

    def enqueue_request(self, request_id: str, prompt: str, callback_url: str):
        """Enqueue LLM request for async processing.

        Returns:
            The id of the queued job (5 minute job timeout).
        """
        job = self.queue.enqueue(
            self.process_request,
            request_id,
            prompt,
            callback_url,
            job_timeout='5m'
        )
        return job.id

    @staticmethod
    def process_request(request_id: str, prompt: str, callback_url: str):
        """Process request (runs in worker).

        Generates a response and POSTs the outcome to ``callback_url``.

        Returns:
            The generated response, so it is available as the job result,
            or ``None`` when generation failed (the failure is reported
            through the error callback).

        NOTE(review): ``llm_service`` and ``requests`` are used here but are
        neither defined nor imported in the visible source, and this static
        method cannot reach ``self.llm``. Confirm both names exist at module
        level in the worker process.
        """
        try:
            response = llm_service.generate(prompt)
        except Exception as e:
            requests.post(callback_url, json={
                'request_id': request_id,
                'status': 'error',
                'error': str(e)
            })
            return None

        # Callback with result. This sits outside the try block so that a
        # failing callback is not reported as a generation error through a
        # second, contradictory callback.
        requests.post(callback_url, json={
            'request_id': request_id,
            'status': 'success',
            'response': response
        })
        return response

    def get_job_status(self, job_id: str) -> dict:
        """Check job status.

        Returns:
            ``{'status': 'not_found'}`` for an unknown job, otherwise the
            job status with its ``result`` (finished jobs only) and ``error``
            text (failed jobs only).
        """
        job = self.queue.fetch_job(job_id)
        if not job:
            return {'status': 'not_found'}

        return {
            'status': job.get_status(),
            'result': job.result if job.is_finished else None,
            'error': str(job.exc_info) if job.is_failed else None
        }

Caching Layer

Complete implementation for LLM response caching:

import hashlib
from functools import wraps

class LLMCache:
    def __init__(self, redis_client, ttl_seconds: int = 3600):
        self.redis = redis_client
        self.ttl = ttl_seconds

    def cache_key(self, prompt: str, model: str, temperature: float) -> str:
        """Generate cache key from request parameters."""
        # Only cache deterministic requests (temperature=0)
        if temperature > 0:
            return None

        content = f"{model}:{prompt}"
        return f"llm_cache:{hashlib.sha256(content.encode()).hexdigest()}"

    def get(self, key: str) -> str | None:
        """Get cached response."""
        if not key:
            return None
        return self.redis.get(key)

    def set(self, key: str, value: str):
        """Cache response."""
        if key:
            self.redis.setex(key, self.ttl, value)

def cached_llm_call(cache: LLMCache):
    """Decorator for caching LLM calls.

    Wraps an async function called as ``func(prompt, model=..., temperature=...,
    **kwargs)``. The cache key is built from ``prompt``, ``model`` and
    ``temperature`` only.

    NOTE(review): extra keyword arguments are not part of the cache key, so
    calls that differ only in ``kwargs`` share one cached response.

    Args:
        cache: Cache providing ``cache_key``, ``get`` and ``set``.

    Returns:
        A decorator producing the caching async wrapper.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(prompt: str, model: str = "default", temperature: float = 0, **kwargs):
            cache_key = cache.cache_key(prompt, model, temperature)

            # Check cache. Compare against None rather than truthiness: an
            # empty response is a valid cached value and must count as a hit.
            cached = cache.get(cache_key)
            if cached is not None:
                return cached

            # Make actual call
            result = await func(prompt, model=model, temperature=temperature, **kwargs)

            # Cache result
            cache.set(cache_key, result)

            return result
        return wrapper
    return decorator