System Design at Scale Code Reference

Complete code implementations for Chapter 59: System Design at Scale.


Table of Contents


Scaling AI Inference

Dynamic Batcher

Complete implementation for batching requests dynamically for efficient GPU inference:

from dataclasses import dataclass, field
from typing import Optional
import asyncio
import time

@dataclass
class InferenceRequest:
    """A single inference request.

    Plain data record describing one unit of work handed to the batcher.
    ``DynamicBatcher`` charges a request ``len(input_tokens) +
    max_output_tokens`` against its per-batch token budget and orders
    requests by descending ``priority``, then ascending ``deadline``.
    """
    # Request identifier.
    id: str
    # Input token ids; only the length is used by the batcher.
    input_tokens: list[int]
    # Output-token allowance counted against the batch token budget.
    max_output_tokens: int
    # Larger values are batched first; defaults to 0.
    priority: int = 0
    deadline: Optional[float] = None  # Unix timestamp
    # Set from time.time() at construction unless supplied explicitly.
    arrival_time: float = field(default_factory=time.time)

class DynamicBatcher:
    """Batch requests dynamically for efficient inference.

    Requests are queued together with the future that will carry their
    result.  A batch is dispatched as soon as ``max_batch_size`` requests
    are queued, or once the oldest queued request has waited
    ``max_wait_ms`` milliseconds (enforced by a background flush timer, so
    a lone request is not left waiting for the next arrival).

    Args:
        max_batch_size: Maximum number of requests per batch (>= 1).
        max_wait_ms: Longest time a request may wait before a batch is
            dispatched, in milliseconds.
        max_batch_tokens: Token budget per batch; one request costs
            ``len(input_tokens) + max_output_tokens``.

    Raises:
        ValueError: If ``max_batch_size`` is smaller than 1.
    """

    def __init__(
        self,
        max_batch_size: int = 32,
        max_wait_ms: float = 50,
        max_batch_tokens: int = 4096
    ):
        if max_batch_size < 1:
            raise ValueError("max_batch_size must be at least 1")

        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.max_batch_tokens = max_batch_tokens
        # Each entry pairs a request with the future awaiting its result.
        self.pending: list[tuple["InferenceRequest", asyncio.Future]] = []
        self.lock = asyncio.Lock()
        # Strong references: the event loop only keeps weak references to
        # tasks, so un-referenced tasks may be garbage collected mid-flight.
        self._batch_tasks: set[asyncio.Task] = set()
        self._flush_task: Optional[asyncio.Task] = None

    async def add_request(self, request: "InferenceRequest") -> asyncio.Future:
        """Queue ``request`` and return the future for its result.

        A request whose token cost alone exceeds ``max_batch_tokens`` can
        never be placed in a batch; its future is failed immediately with
        ``ValueError`` instead of hanging forever.
        """
        future = asyncio.get_running_loop().create_future()

        if self._request_tokens(request) > self.max_batch_tokens:
            future.set_exception(ValueError(
                f"Request {request.id} exceeds "
                f"max_batch_tokens={self.max_batch_tokens}"
            ))
            return future

        async with self.lock:
            self.pending.append((request, future))

            # Dispatch right away when the size or wait limit is reached.
            if self._should_dispatch():
                self._dispatch()

            # Anything still queued must be flushed when its wait expires.
            self._ensure_flush_timer()

        return future

    @staticmethod
    def _request_tokens(request: "InferenceRequest") -> int:
        """Return the token cost charged against the batch budget."""
        return len(request.input_tokens) + request.max_output_tokens

    def _should_dispatch(self) -> bool:
        """Return True when the queue is full or the oldest request is due."""
        if not self.pending:
            return False

        if len(self.pending) >= self.max_batch_size:
            return True

        # The queue is re-sorted by priority when batches are formed, so
        # the oldest request is not necessarily the first element.
        oldest = min(request.arrival_time for request, _ in self.pending)
        return (time.time() - oldest) * 1000 >= self.max_wait_ms

    def _dispatch(self) -> bool:
        """Form a batch and start processing it; return False if none formed."""
        batch = self._form_batch()
        if not batch:
            return False

        task = asyncio.create_task(self._process_batch(batch))
        self._batch_tasks.add(task)
        task.add_done_callback(self._batch_tasks.discard)
        return True

    def _ensure_flush_timer(self) -> None:
        """Start the background flush task if requests are still queued."""
        if not self.pending:
            return
        if self._flush_task is None or self._flush_task.done():
            self._flush_task = asyncio.create_task(self._flush_when_due())

    async def _flush_when_due(self) -> None:
        """Dispatch queued requests once the oldest one has waited long enough."""
        while True:
            async with self.lock:
                if not self.pending:
                    return
                if self._should_dispatch():
                    if not self._dispatch():
                        return  # Nothing fits; avoid spinning.
                    continue
                oldest = min(request.arrival_time for request, _ in self.pending)
                remaining = self.max_wait_ms / 1000 - (time.time() - oldest)
            # Small floor so float rounding cannot cause a busy loop.
            await asyncio.sleep(max(remaining, 0.001))

    def _form_batch(self) -> list[tuple]:
        """Remove and return the next batch of ``(request, future)`` pairs.

        Requests are taken in order of descending priority, then ascending
        deadline (no deadline sorts last).  A request that does not fit the
        remaining token budget is skipped so smaller ones behind it can
        still be packed; skipped requests stay queued.
        """
        ordered = sorted(
            self.pending,
            key=lambda item: (
                -item[0].priority,
                item[0].deadline if item[0].deadline is not None else float('inf'),
            ),
        )

        batch = []
        leftover = []
        total_tokens = 0

        for index, item in enumerate(ordered):
            if len(batch) >= self.max_batch_size:
                leftover.extend(ordered[index:])
                break

            request_tokens = self._request_tokens(item[0])
            if total_tokens + request_tokens <= self.max_batch_tokens:
                batch.append(item)
                total_tokens += request_tokens
            else:
                leftover.append(item)

        # Single rebuild instead of repeated O(n) list.remove() calls.
        self.pending[:] = leftover
        return batch

    async def _process_batch(self, batch: list[tuple]):
        """Process a batch of requests."""
        # This would call the actual model inference
        # Results would be set on the futures
        pass

Inference Cache

Complete implementation for multi-level caching of inference results:

from typing import Optional
import hashlib

class InferenceCache:
    """Multi-level cache for inference results.

    Lookups try, in order: the local in-memory cache (L1), the distributed
    cache (L2), and finally a semantic-similarity lookup over previously
    indexed prompts.

    Args:
        l1_cache: Synchronous cache exposing ``get(key)`` and
            ``set(key, value, ttl=...)``.
        l2_cache: Asynchronous cache exposing the same two methods.
        embedding_cache: Object exposing async ``find_similar(prompt,
            threshold=...)`` and ``index(prompt, result)``.
    """

    def __init__(
        self,
        l1_cache,  # In-memory, per-node
        l2_cache,  # Distributed cache (Redis)
        embedding_cache  # Vector similarity cache
    ):
        self.l1 = l1_cache
        self.l2 = l2_cache
        self.embedding = embedding_cache

    def cache_key(self, prompt: str, params: dict) -> str:
        """Return the SHA-256 hex digest of the prompt and sorted parameters."""
        content = f"{prompt}|{sorted(params.items())}"
        return hashlib.sha256(content.encode()).hexdigest()

    async def get(
        self,
        prompt: str,
        params: dict,
        similarity_threshold: float = 0.95
    ) -> Optional[dict]:
        """Return a cached result, checking each cache level in turn.

        Returns:
            A dict with ``result`` and ``cache_hit`` (``'l1_exact'``,
            ``'l2_exact'`` or ``'semantic'``; semantic hits also carry
            ``similarity`` and ``original_prompt``), or ``None`` on a miss.
        """
        key = self.cache_key(prompt, params)

        # L1: exact match in local memory.  Compare against None so that a
        # legitimately cached falsy value (e.g. "") still counts as a hit.
        result = self.l1.get(key)
        if result is not None:
            return {'result': result, 'cache_hit': 'l1_exact'}

        # L2: exact match in distributed cache
        result = await self.l2.get(key)
        if result is not None:
            self.l1.set(key, result)  # Populate L1
            return {'result': result, 'cache_hit': 'l2_exact'}

        # L3: semantic similarity match.
        # NOTE(review): the lookup is keyed on the prompt only, so a hit may
        # have been produced with different params — confirm this is intended.
        similar = await self.embedding.find_similar(
            prompt, threshold=similarity_threshold
        )
        if similar:
            return {
                'result': similar['result'],
                'cache_hit': 'semantic',
                'similarity': similar['score'],
                'original_prompt': similar['prompt']
            }

        return None

    async def set(
        self,
        prompt: str,
        params: dict,
        result: str,
        ttl: int = 3600
    ):
        """Store ``result`` in L1 and L2 with ``ttl`` and index it semantically."""
        key = self.cache_key(prompt, params)

        self.l1.set(key, result, ttl=ttl)
        await self.l2.set(key, result, ttl=ttl)
        await self.embedding.index(prompt, result)

GPU-Aware Load Balancer

Complete implementation for GPU-aware request routing:

from enum import Enum
from dataclasses import dataclass

class LoadBalancingStrategy(Enum):
    """Identifiers for the routing strategies a ``LoadBalancerConfig`` can name."""
    ROUND_ROBIN = "round_robin"
    LEAST_CONNECTIONS = "least_connections"
    LEAST_LATENCY = "least_latency"
    WEIGHTED = "weighted"
    CONSISTENT_HASH = "consistent_hash"
    GPU_AWARE = "gpu_aware"

@dataclass
class LoadBalancerConfig:
    """Configuration for AI-aware load balancer.

    Descriptive record only: nothing in this file reads these fields to
    drive routing.
    """
    # Which routing strategy to apply.
    strategy: LoadBalancingStrategy
    # Interval between health checks, in milliseconds.
    health_check_interval_ms: int
    # presumably consecutive failed checks before a server is marked unhealthy — TODO confirm
    unhealthy_threshold: int
    # Free-form design notes for this configuration.
    considerations: list[str]

# Reference load-balancer configurations keyed by workload type.
AI_LOAD_BALANCING = {
    # General inference: latency-driven routing.
    'inference_general': LoadBalancerConfig(
        strategy=LoadBalancingStrategy.LEAST_LATENCY,
        health_check_interval_ms=1000,
        unhealthy_threshold=3,
        considerations=[
            "Route to server with lowest current latency",
            "Account for GPU memory utilization",
            "Consider request complexity (input length)"
        ]
    ),
    # Stateful (conversational) inference: sticky routing via consistent hashing.
    'inference_stateful': LoadBalancerConfig(
        strategy=LoadBalancingStrategy.CONSISTENT_HASH,
        health_check_interval_ms=1000,
        unhealthy_threshold=3,
        considerations=[
            "Route same user/session to same server",
            "Enables KV cache reuse",
            "Good for conversational AI"
        ]
    ),
    # Batch processing: GPU-aware routing with a slower health-check cadence.
    'batch_processing': LoadBalancerConfig(
        strategy=LoadBalancingStrategy.GPU_AWARE,
        health_check_interval_ms=5000,
        unhealthy_threshold=2,
        considerations=[
            "Balance GPU utilization across cluster",
            "Consider memory requirements per request",
            "Group similar-sized requests"
        ]
    ),
}

class GPUAwareLoadBalancer:
    """Load balancer that considers GPU state.

    Servers are scored from the statistics held in ``server_stats``
    (``healthy``, ``gpu_memory_free``, ``current_requests``,
    ``avg_latency_ms``); the highest score wins.

    Args:
        servers: Identifiers of the candidate servers.
    """

    # Floors that keep the score finite when a request needs no memory or a
    # server reports zero latency (e.g. no samples yet).
    _MIN_MEMORY_BYTES = 1
    _MIN_LATENCY_MS = 1.0

    def __init__(self, servers: list[str]):
        self.servers = servers
        self.server_stats = {}

    async def select_server(self, request: dict) -> str:
        """Return the best-scoring healthy server with enough free GPU memory.

        Raises:
            NoHealthyServerError: If no server is healthy with enough memory.
        """
        estimated_memory = max(
            self._estimate_memory(request), self._MIN_MEMORY_BYTES
        )

        await self._refresh_stats()

        scores = {}
        for server in self.servers:
            stats = self.server_stats.get(server, {})
            if not stats.get('healthy', False):
                continue

            available_memory = stats.get('gpu_memory_free', 0)
            current_load = stats.get('current_requests', 0)
            avg_latency = max(
                stats.get('avg_latency_ms', 1000), self._MIN_LATENCY_MS
            )

            # Skip if insufficient memory
            if available_memory < estimated_memory:
                continue

            # Weighted mix: memory headroom 40%, low load 30%, low latency 30%.
            scores[server] = (
                available_memory / estimated_memory * 0.4 +
                (1 / (current_load + 1)) * 0.3 +
                (1000 / avg_latency) * 0.3
            )

        if not scores:
            raise NoHealthyServerError("No servers available")

        return max(scores, key=scores.get)

    def _estimate_memory(self, request: dict) -> int:
        """Estimate GPU memory (bytes) needed for the request."""
        input_tokens = request.get('input_tokens', 0)
        max_output = request.get('max_tokens', 100)
        # Rough estimate: ~2KB per token for KV cache
        return (input_tokens + max_output) * 2048

    def _estimate_compute(self, request: dict) -> float:
        """Estimate compute needed for the request as its total token count."""
        return request.get('input_tokens', 0) + request.get('max_tokens', 100)

    async def _refresh_stats(self):
        """Refresh server statistics."""
        # Would poll servers for current stats
        pass

class NoHealthyServerError(Exception):
    """Raised when no healthy server with enough capacity can be selected."""

Speculative Sampler

Complete implementation for speculative decoding:

import random

class SpeculativeSampler:
    """Speculative decoding implementation.

    A draft model proposes up to ``gamma`` tokens per round; the target
    model scores them in one forward pass and a prefix of the proposal is
    accepted.  The tokenizer and probability helpers are placeholders to be
    supplied by a concrete implementation.

    Args:
        draft_model: Object with async ``generate(tokens, num_tokens=...)``.
        target_model: Object with async ``forward(tokens)``.
        gamma: Number of speculative tokens proposed per round.
    """

    def __init__(self, draft_model, target_model, gamma: int = 4):
        self.draft = draft_model
        self.target = target_model
        self.gamma = gamma  # Number of speculative tokens

    async def generate(self, prompt: str, max_tokens: int) -> str:
        """Generate at most ``max_tokens`` tokens with speculative decoding."""
        tokens = self._tokenize(prompt)
        generated = []

        while len(generated) < max_tokens:
            remaining = max_tokens - len(generated)
            context = tokens + generated

            # Never draft past the budget, so the output cannot overshoot
            # max_tokens.
            draft_tokens = await self.draft.generate(
                context,
                num_tokens=min(self.gamma, remaining)
            )
            draft_tokens = draft_tokens[:remaining]

            # Target model verifies in parallel
            target_logits = await self.target.forward(context + draft_tokens)

            # Accept matching tokens, reject from first mismatch
            accepted = self._verify_and_accept(draft_tokens, target_logits)
            generated.extend(accepted)

            rejected = len(accepted) < len(draft_tokens)
            # An empty draft would otherwise make no progress and loop
            # forever; fall back to one token from the target.
            if (rejected or not draft_tokens) and len(generated) < max_tokens:
                new_token = self._sample_from_target(target_logits, len(accepted))
                generated.append(new_token)

        return self._detokenize(generated)

    def _verify_and_accept(self, draft_tokens, target_logits) -> list:
        """Return the accepted prefix of ``draft_tokens``.

        Each token is accepted with probability ``min(1, target/draft)``;
        verification stops at the first rejection.
        """
        # NOTE(review): target_logits[i] assumes the logits are aligned with
        # the draft positions rather than the whole sequence — confirm
        # against the target model's forward() contract.
        accepted = []
        for i, token in enumerate(draft_tokens):
            target_prob = self._get_prob(target_logits[i], token)
            draft_prob = self._get_draft_prob(token)

            if draft_prob > 0:
                accept_prob = min(1.0, target_prob / draft_prob)
            else:
                # Ratio is unbounded: accept iff the target gives the token
                # any probability mass (avoids ZeroDivisionError).
                accept_prob = 1.0 if target_prob > 0 else 0.0

            if random.random() < accept_prob:
                accepted.append(token)
            else:
                break
        return accepted

    def _tokenize(self, text: str) -> list[int]:
        """Tokenize text."""
        pass

    def _detokenize(self, tokens: list[int]) -> str:
        """Detokenize tokens."""
        pass

    def _get_prob(self, logits, token) -> float:
        """Get probability of token from logits."""
        pass

    def _get_draft_prob(self, token) -> float:
        """Get draft model probability for token."""
        pass

    def _sample_from_target(self, logits, position) -> int:
        """Sample token from target distribution."""
        pass

AI Auto-Scaler

Complete implementation for auto-scaling AI inference workloads:

from dataclasses import dataclass

@dataclass
class AutoScalingPolicy:
    """Policy for auto-scaling inference servers.

    ``AIAutoScaler.evaluate`` reads ``name``, ``metric``,
    ``scale_up_threshold`` and ``scale_down_threshold``; the remaining
    fields are carried but not consulted by the code in this file.
    """
    # Label reported in scaling decisions.
    name: str
    # Key looked up in the metrics dict passed to the scaler.
    metric: str
    # Desired steady-state value of the metric (not read by AIAutoScaler).
    target_value: float
    # A metric value strictly above this votes to scale up.
    scale_up_threshold: float
    # A metric value strictly below this votes to scale down.
    scale_down_threshold: float
    # presumably the minimum delay between scaling actions — TODO confirm; not enforced here
    cooldown_seconds: int
    # presumably instance-count bounds — TODO confirm; not enforced here
    min_instances: int
    max_instances: int

class AIAutoScaler:
    """Auto-scaler for AI inference workloads.

    Each policy compares one metric against its thresholds and may vote to
    scale up or down.  Any scale-up vote wins; scaling down requires every
    policy that had a metric available to vote for it.

    NOTE(review): ``cooldown_seconds``, ``min_instances``/``max_instances``
    and ``last_scale_time`` are not enforced here; the caller is presumably
    expected to apply them.

    Args:
        policies: The scaling policies to evaluate.
    """

    def __init__(self, policies: "list[AutoScalingPolicy]"):
        self.policies = policies
        self.last_scale_time = 0

    async def evaluate(self, current_metrics: dict) -> dict:
        """Evaluate all policies against ``current_metrics``.

        Policies whose metric is absent (``None``) are skipped entirely.

        Returns:
            ``{'action': 'scale_up' | 'scale_down' | 'none',
            'decisions': [...]}`` where each decision carries ``policy``,
            ``action`` and ``reason``.
        """
        decisions = []
        evaluated = 0

        for policy in self.policies:
            current_value = current_metrics.get(policy.metric)
            if current_value is None:
                continue
            evaluated += 1

            if current_value > policy.scale_up_threshold:
                decisions.append({
                    'policy': policy.name,
                    'action': 'scale_up',
                    'reason': f"{policy.metric}={current_value} > {policy.scale_up_threshold}"
                })
            elif current_value < policy.scale_down_threshold:
                decisions.append({
                    'policy': policy.name,
                    'action': 'scale_down',
                    'reason': f"{policy.metric}={current_value} < {policy.scale_down_threshold}"
                })

        return self._resolve_decisions(decisions, evaluated)

    def _resolve_decisions(
        self,
        decisions: list[dict],
        evaluated_policies: "int | None" = None
    ) -> dict:
        """Resolve potentially conflicting decisions.

        Args:
            decisions: Votes produced by ``evaluate``.
            evaluated_policies: Number of policies that had a metric to
                judge.  Policies inside their normal band cast no vote, so
                this count is what lets them veto a scale-down.  Defaults
                to ``len(decisions)``, i.e. every voter counts.
        """
        if evaluated_policies is None:
            evaluated_policies = len(decisions)

        # Scale up takes priority over scale down
        if any(d['action'] == 'scale_up' for d in decisions):
            return {'action': 'scale_up', 'decisions': decisions}

        down_votes = sum(1 for d in decisions if d['action'] == 'scale_down')
        if down_votes and down_votes >= evaluated_policies:
            return {'action': 'scale_down', 'decisions': decisions}

        return {'action': 'none', 'decisions': decisions}

# Common AI scaling policies: GPU utilization, p99 request latency and queue
# depth, each with its own thresholds and cooldown.
AI_SCALING_POLICIES = [
    AutoScalingPolicy(
        name="gpu_utilization",
        metric="avg_gpu_utilization",
        target_value=0.7,
        scale_up_threshold=0.8,
        scale_down_threshold=0.4,
        cooldown_seconds=300,
        min_instances=2,
        max_instances=20
    ),
    AutoScalingPolicy(
        name="request_latency",
        metric="p99_latency_ms",
        target_value=500,
        scale_up_threshold=750,
        scale_down_threshold=200,
        cooldown_seconds=180,
        min_instances=2,
        max_instances=20
    ),
    AutoScalingPolicy(
        name="queue_depth",
        metric="pending_requests",
        target_value=10,
        scale_up_threshold=50,
        scale_down_threshold=5,
        cooldown_seconds=60,
        min_instances=2,
        max_instances=20
    ),
]

Multi-Region Deployment

Global Load Balancer

Complete implementation for multi-region routing and architecture patterns:

from dataclasses import dataclass
from enum import Enum

class MultiRegionPattern(Enum):
    """Identifiers for the multi-region deployment patterns."""
    ACTIVE_PASSIVE = "active_passive"
    ACTIVE_ACTIVE = "active_active"
    FOLLOW_THE_SUN = "follow_the_sun"
    DATA_SOVEREIGN = "data_sovereign"

@dataclass
class RegionConfig:
    """Configuration for a region.

    ``GlobalLoadBalancer`` keys regions by ``region`` and matches request
    residency requirements against ``data_residency``.
    """
    # Region identifier, e.g. "us-east-1".
    region: str
    role: str  # primary, secondary, read_replica
    # Model versions associated with this region.
    model_versions: list[str]
    # Capacity description, e.g. {"gpus": 100}.
    capacity: dict
    data_residency: list[str]  # What data can be stored/processed here

@dataclass
class MultiRegionArchitecture:
    """Multi-region deployment architecture.

    Descriptive record: the policy fields hold human-readable text.
    """
    # Which deployment pattern this architecture follows.
    pattern: MultiRegionPattern
    # Participating regions and their roles.
    regions: list[RegionConfig]
    # Prose description of how traffic is routed.
    routing_policy: str
    # Prose description of what happens on failure.
    failover_policy: str
    # Prose description of how data is replicated.
    data_replication: str
    # Free-form trade-off notes.
    considerations: list[str]

# Reference architectures keyed by pattern name.
MULTI_REGION_PATTERNS = {
    # One primary region serves traffic; a secondary stands by for failover.
    'active_passive': MultiRegionArchitecture(
        pattern=MultiRegionPattern.ACTIVE_PASSIVE,
        regions=[
            RegionConfig("us-east-1", "primary", ["v1.0"], {"gpus": 100}, ["us"]),
            RegionConfig("us-west-2", "secondary", ["v1.0"], {"gpus": 50}, ["us"]),
        ],
        routing_policy="Route all traffic to primary; failover to secondary",
        failover_policy="Automatic failover on health check failure",
        data_replication="Async replication from primary to secondary",
        considerations=[
            "Secondary resources often underutilized",
            "Failover adds latency for affected users",
            "Simpler to reason about consistency"
        ]
    ),
    # Every region is a primary and serves its nearest users.
    'active_active': MultiRegionArchitecture(
        pattern=MultiRegionPattern.ACTIVE_ACTIVE,
        regions=[
            RegionConfig("us-east-1", "primary", ["v1.0"], {"gpus": 75}, ["us"]),
            RegionConfig("eu-west-1", "primary", ["v1.0"], {"gpus": 75}, ["eu"]),
            RegionConfig("ap-northeast-1", "primary", ["v1.0"], {"gpus": 50}, ["apac"]),
        ],
        routing_policy="Route to nearest healthy region (GeoDNS)",
        failover_policy="Automatic rerouting on regional failure",
        data_replication="Multi-master with conflict resolution",
        considerations=[
            "Better resource utilization",
            "Lower latency for global users",
            "Conflict resolution complexity",
            "Higher operational complexity"
        ]
    ),
}

class GlobalLoadBalancer:
    """Route requests across regions.

    Args:
        regions: Region configurations; each must expose ``region`` and
            ``data_residency``.  All regions start out healthy.
    """

    def __init__(self, regions: "list[RegionConfig]"):
        self.regions = {r.region: r for r in regions}
        self.region_health = {r.region: True for r in regions}

    async def route(self, request: dict) -> str:
        """Return the name of the best region for ``request``.

        Reads ``geo_location`` and ``data_residency_required`` from the
        request; the latter may be a list of tags or a single tag string.

        Raises:
            NoEligibleRegionError: If no healthy region satisfies the
                residency requirements.
        """
        user_location = request.get('geo_location')
        data_requirements = request.get('data_residency_required')

        # Keep healthy regions that satisfy the residency requirements.
        # A region missing from region_health is treated as unhealthy.
        eligible = [
            r for r in self.regions.values()
            if self._meets_residency(r, data_requirements)
            and self.region_health.get(r.region, False)
        ]

        if not eligible:
            raise NoEligibleRegionError("No region meets requirements")

        # Select by latency (closest region); ties keep insertion order.
        return min(
            eligible,
            key=lambda r: self._estimated_latency(r.region, user_location)
        ).region

    def _meets_residency(
        self,
        region: "RegionConfig",
        required: "list[str] | str | None"
    ) -> bool:
        """Return True if ``region`` may handle data with every required tag."""
        if not required:
            return True
        if isinstance(required, str):
            # A bare string would otherwise be iterated character by
            # character and never match any residency tag.
            required = [required]
        return all(tag in region.data_residency for tag in required)

    def _estimated_latency(self, region: str, user_location: str) -> float:
        """Return the estimated latency from ``user_location`` to ``region``."""
        # Would use actual latency measurements or geographic distance
        return 0.0

class NoEligibleRegionError(Exception):
    """Raised when no healthy region satisfies a request's requirements."""

Cross-Region Replicator

Complete implementation for cross-region data replication:

from dataclasses import dataclass
import asyncio

@dataclass
class ReplicationConfig:
    """Cross-region replication configuration.

    ``CrossRegionReplicator`` reads ``conflict_resolution``; the other
    fields are not consulted by the code in this file.
    """
    mode: str  # sync, async, semi-sync
    # Acceptable replication lag, in milliseconds.
    lag_tolerance_ms: int
    conflict_resolution: str  # last_write_wins, merge, custom
    # presumably names of the data sets to replicate — TODO confirm
    what_to_replicate: list[str]

class CrossRegionReplicator:
    """Manage cross-region data replication.

    Args:
        primary: Name of the primary region.
        secondaries: Names of the secondary regions.
        config: Replication settings; ``conflict_resolution`` selects the
            strategy used by ``handle_conflict``.
    """

    _CONSISTENCY_LEVELS = ('strong', 'eventual', 'quorum')

    def __init__(
        self,
        primary: str,
        secondaries: list[str],
        config: "ReplicationConfig"
    ):
        self.primary = primary
        self.secondaries = secondaries
        self.config = config
        # Strong references to fire-and-forget replication tasks; the event
        # loop only holds weak references to tasks.
        self._background_tasks: set = set()

    async def write(self, key: str, value: object, options: "dict | None" = None) -> dict:
        """Write to the primary and replicate according to the consistency level.

        Args:
            key: Key to write.
            value: Value to store.
            options: Optional settings; ``consistency`` is ``'strong'``,
                ``'eventual'`` (default) or ``'quorum'``.

        Returns:
            A dict describing the write; its keys depend on the level.

        Raises:
            ValueError: If the consistency level is unknown (checked before
                anything is written).
            Exception: For ``'quorum'``, the last secondary error when too
                many secondaries fail for a majority to acknowledge.
        """
        options = options or {}
        consistency = options.get('consistency', 'eventual')
        if consistency not in self._CONSISTENCY_LEVELS:
            raise ValueError(f"Unknown consistency level: {consistency!r}")

        primary_result = await self._write_primary(key, value)

        if consistency == 'strong':
            # Wait for all secondaries (slow but consistent)
            secondary_results = await asyncio.gather(*[
                self._write_secondary(region, key, value)
                for region in self.secondaries
            ])
            return {
                'primary': primary_result,
                'secondaries': secondary_results,
                'consistency': 'strong'
            }

        if consistency == 'eventual':
            # Async replication (fast but eventual)
            for region in self.secondaries:
                self._track(asyncio.create_task(
                    self._write_secondary(region, key, value)
                ))
            return {
                'primary': primary_result,
                'consistency': 'eventual'
            }

        # 'quorum': wait for a majority of all regions (primary included).
        quorum_size = (len(self.secondaries) + 1) // 2 + 1
        results = [primary_result]
        last_error = None

        tasks = [
            asyncio.create_task(self._write_secondary(region, key, value))
            for region in self.secondaries
        ]
        for finished in asyncio.as_completed(tasks):
            try:
                result = await finished
            except Exception as exc:
                # One failed secondary must not abort the write while a
                # majority can still acknowledge.
                last_error = exc
                continue
            results.append(result)
            if len(results) >= quorum_size:
                break

        # Writes still in flight keep replicating in the background.
        for task in tasks:
            if not task.done():
                self._track(task)

        if len(results) < quorum_size and last_error is not None:
            raise last_error

        return {
            'acked_regions': len(results),
            'consistency': 'quorum'
        }

    def _track(self, task) -> None:
        """Hold a reference to a background task until it finishes."""
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def handle_conflict(self, key: str, versions: list[dict]) -> object:
        """Resolve a conflict between versions using the configured strategy.

        ``'last_write_wins'`` returns the ``value`` of the version with the
        greatest ``timestamp``; ``'merge'`` delegates to ``_merge_values``.

        Raises:
            ConflictError: For any other strategy.
        """
        if self.config.conflict_resolution == 'last_write_wins':
            return max(versions, key=lambda v: v['timestamp'])['value']
        elif self.config.conflict_resolution == 'merge':
            return self._merge_values(versions)
        else:
            raise ConflictError(f"Cannot resolve conflict for {key}")

    async def _write_primary(self, key: str, value: object) -> dict:
        """Write to primary region."""
        pass

    async def _write_secondary(self, region: str, key: str, value: object) -> dict:
        """Write to secondary region."""
        pass

    def _merge_values(self, versions: list[dict]) -> object:
        """Merge conflicting values."""
        pass

class ConflictError(Exception):
    """Raised when conflicting versions cannot be resolved."""

Traffic Management

Token Bucket Rate Limiter

Complete implementation for rate limiting with the token bucket algorithm:

from dataclasses import dataclass
from enum import Enum
import time
import asyncio

class RateLimitStrategy(Enum):
    """Identifiers for the rate-limiting algorithms a ``RateLimitConfig`` can name."""
    FIXED_WINDOW = "fixed_window"
    SLIDING_WINDOW = "sliding_window"
    TOKEN_BUCKET = "token_bucket"
    LEAKY_BUCKET = "leaky_bucket"

@dataclass
class RateLimitConfig:
    """Rate limiting configuration.

    ``TokenBucketRateLimiter`` reads ``requests_per_second`` as the refill
    rate and ``burst_size`` as the bucket capacity.
    """
    # Sustained rate: tokens added to a bucket per second.
    requests_per_second: float
    # Bucket capacity: the most requests allowed in one burst.
    burst_size: int
    # Which algorithm this configuration is meant for.
    strategy: RateLimitStrategy
    scope: str  # global, per_user, per_ip, per_api_key

class TokenBucketRateLimiter:
    """Token bucket rate limiter.

    Each scope key owns a bucket that holds up to ``config.burst_size``
    tokens and refills at ``config.requests_per_second``.  Every allowed
    request consumes one token.

    Args:
        config: Supplies ``requests_per_second`` and ``burst_size``.
    """

    def __init__(self, config: "RateLimitConfig"):
        self.config = config
        self.buckets = {}  # scope_key -> (tokens, last_update)

    async def allow(self, scope_key: str) -> tuple[bool, dict]:
        """Check whether a request for ``scope_key`` should be allowed.

        Returns:
            ``(True, {'remaining', 'reset_in'})`` when allowed, where
            ``reset_in`` is the number of seconds until the bucket is full
            again, or ``(False, {'remaining': 0, 'retry_after'})`` when
            denied.  Durations are ``inf`` when the refill rate is zero.
        """
        # Monotonic clock: wall-clock adjustments must not add or remove
        # tokens.
        now = time.monotonic()
        rate = self.config.requests_per_second
        capacity = self.config.burst_size

        tokens, last_update = self.buckets.get(scope_key, (capacity, now))

        # Add tokens based on time passed, capped at the bucket capacity.
        elapsed = max(0.0, now - last_update)
        tokens = min(capacity, tokens + elapsed * rate)

        if tokens >= 1:
            self.buckets[scope_key] = (tokens - 1, now)
            return True, {
                'remaining': int(tokens - 1),
                'reset_in': self._seconds_for(capacity - tokens + 1, rate)
            }

        # Denied: persist the accrued tokens so refill continues from now.
        self.buckets[scope_key] = (tokens, now)
        return False, {
            'remaining': 0,
            'retry_after': self._seconds_for(1 - tokens, rate)
        }

    @staticmethod
    def _seconds_for(tokens_needed: float, rate: float) -> float:
        """Return the seconds needed to refill ``tokens_needed`` at ``rate``."""
        if rate <= 0:
            # No refill ever happens; avoid dividing by zero.
            return float('inf')
        return tokens_needed / rate

class AdaptiveRateLimiter:
    """Adjust rate limits based on system health.

    Args:
        base_config: Limits applied when the system is healthy.
        health_checker: Object with async ``get_status()`` returning a
            mapping that contains ``'load'``.
    """

    def __init__(self, base_config: "RateLimitConfig", health_checker):
        self.base_config = base_config
        self.health = health_checker
        self.current_multiplier = 1.0

    async def get_limit(self, scope_key: str) -> "RateLimitConfig":
        """Return the rate limit scaled for the current system load.

        Load above 0.9 halves the limits, above 0.8 reduces them to 75%,
        and below 0.5 restores them.  In between, the previous multiplier
        is kept so the limit does not flap.  ``scope_key`` is accepted for
        interface symmetry but does not influence the result.
        """
        # Local import: the file only imports ``dataclass`` itself.
        from dataclasses import replace

        health_status = await self.health.get_status()
        load = health_status['load']

        if load > 0.9:
            self.current_multiplier = 0.5  # Cut limits in half
        elif load > 0.8:
            self.current_multiplier = 0.75
        elif load < 0.5:
            self.current_multiplier = 1.0  # Full limits

        scaled_burst = int(self.base_config.burst_size * self.current_multiplier)
        if self.base_config.burst_size >= 1:
            # int() truncation could otherwise yield a zero-capacity bucket,
            # which would reject every request instead of throttling.
            scaled_burst = max(1, scaled_burst)

        return replace(
            self.base_config,
            requests_per_second=self.base_config.requests_per_second * self.current_multiplier,
            burst_size=scaled_burst,
        )

AI Traffic Shaper

Complete implementation for AI workload traffic shaping:

import asyncio
import time

class AITrafficShaper:
    """Shape traffic for AI workloads.

    Requests are placed on one of three priority queues as
    ``(priority, enqueue_timestamp, request, future)`` tuples; lower
    priority numbers are served first and, within a priority, earlier
    timestamps first.

    Args:
        config: Shaper configuration (stored, not interpreted here).
    """

    def __init__(self, config: dict):
        self.priority_queues = {
            'realtime': asyncio.PriorityQueue(),  # Interactive users
            'batch': asyncio.PriorityQueue(),     # Background processing
            'internal': asyncio.PriorityQueue(),  # System operations
        }
        self.config = config
        # Last timestamp handed out; used to keep timestamps strictly
        # increasing.
        self._last_enqueue_ts = 0.0

    async def enqueue(self, request: dict) -> asyncio.Future:
        """Enqueue ``request`` and return the future for its result."""
        future = asyncio.get_running_loop().create_future()
        priority = self._calculate_priority(request)
        queue_name = self._select_queue(request)

        # Timestamps must be unique: if two entries tied on both priority
        # and timestamp, the queue would compare the request dicts and
        # raise TypeError.  Bump ties by one microsecond.
        enqueued_at = time.time()
        if enqueued_at <= self._last_enqueue_ts:
            enqueued_at = self._last_enqueue_ts + 1e-6
        self._last_enqueue_ts = enqueued_at

        await self.priority_queues[queue_name].put(
            (priority, enqueued_at, request, future)
        )
        return future

    def _calculate_priority(self, request: dict) -> int:
        """Calculate request priority (lower = higher priority).

        Starts at 50 and is adjusted by ``user_tier``, ``request_type``
        and, when ``enqueue_time`` is present, by one point per ten
        seconds already waited.
        """
        base_priority = 50

        # Adjust based on user tier
        tier_adjustment = {
            'enterprise': -20,
            'pro': -10,
            'free': 0
        }.get(request.get('user_tier', 'free'), 0)

        # Adjust based on request type
        type_adjustment = {
            'completion': 0,
            'embedding': -5,  # Fast operations get slight priority
            'fine_tuning': 20  # Background work is lower priority
        }.get(request.get('request_type', 'completion'), 0)

        # Adjust based on wait time (prevent starvation)
        wait_penalty = 0
        if 'enqueue_time' in request:
            wait_seconds = time.time() - request['enqueue_time']
            wait_penalty = -int(wait_seconds / 10)  # Boost priority over time

        return base_priority + tier_adjustment + type_adjustment + wait_penalty

    def _select_queue(self, request: dict) -> str:
        """Select the queue name for ``request``.

        Streaming requests and anything not flagged ``batch`` go to
        ``'realtime'``.
        """
        # NOTE(review): nothing is ever routed to the 'internal' queue —
        # confirm whether a request flag for it is missing.
        if request.get('stream', False):
            return 'realtime'
        elif request.get('batch', False):
            return 'batch'
        else:
            return 'realtime'

Deployment Patterns

Canary Deployment

Complete implementation for canary deployments:

from dataclasses import dataclass
import asyncio

@dataclass
class CanaryConfig:
    """Configuration for canary deployment."""
    initial_percentage: float  # Start with this % of traffic
    increment: float           # Increase by this % each step
    step_duration_minutes: int # How long before each increment
    success_criteria: dict     # Metrics that must be met
    rollback_criteria: dict    # Metrics that trigger rollback

class CanaryDeployment:
    """Manage canary deployments for AI models.

    After ``start`` a background task re-evaluates the canary every
    ``step_duration_minutes``: it rolls back when the rollback criteria
    are breached, advances traffic by ``increment`` when the success
    criteria are met, and otherwise holds the current split.

    Args:
        config: Canary settings (percentages, step duration, criteria).
        metrics_client: Object with async ``query(name, filters=...)``.
        router: Object with async ``set_traffic_split(split)`` and
            ``promote_canary_to_current()``.
    """

    def __init__(self, config: "CanaryConfig", metrics_client, router):
        self.config = config
        self.metrics = metrics_client
        self.router = router
        self.current_percentage = 0
        self.state = "not_started"
        self.new_version = None
        self.rollback_reason = None
        # Keep a reference so the monitor task is not garbage collected and
        # callers can await or cancel it.
        self._monitor_task = None

    async def start(self, new_version: str):
        """Route the initial traffic share to ``new_version`` and start monitoring."""
        self.new_version = new_version
        self.current_percentage = self.config.initial_percentage
        self.state = "in_progress"

        await self.router.set_traffic_split({
            'current': 100 - self.current_percentage,
            'canary': self.current_percentage
        })

        # Monitor and advance
        self._monitor_task = asyncio.create_task(self._monitor_loop())

    async def _monitor_loop(self):
        """Monitor metrics and advance, hold or roll back the canary."""
        while self.state == "in_progress":
            await asyncio.sleep(self.config.step_duration_minutes * 60)

            metrics = await self._collect_metrics()

            if self._should_rollback(metrics):
                await self._rollback("Metrics degraded")
                return

            if self._meets_success_criteria(metrics):
                self.current_percentage += self.config.increment

                if self.current_percentage >= 100:
                    await self._complete()
                    return

                await self.router.set_traffic_split({
                    'current': 100 - self.current_percentage,
                    'canary': self.current_percentage
                })

    async def _collect_metrics(self) -> dict:
        """Collect p99 latency and error rate for the canary and current versions."""
        return {
            'canary_latency_p99': await self.metrics.query(
                "latency_p99", filters={"version": self.new_version}
            ),
            'canary_error_rate': await self.metrics.query(
                "error_rate", filters={"version": self.new_version}
            ),
            'current_latency_p99': await self.metrics.query(
                "latency_p99", filters={"version": "current"}
            ),
            'current_error_rate': await self.metrics.query(
                "error_rate", filters={"version": "current"}
            ),
        }

    @staticmethod
    def _latency_ratio(metrics: dict):
        """Return canary/current p99 latency, or None without a usable baseline."""
        baseline = metrics['current_latency_p99']
        if not baseline or baseline <= 0:
            # No baseline traffic measured; the ratio is undefined.
            return None
        return metrics['canary_latency_p99'] / baseline

    def _should_rollback(self, metrics: dict) -> bool:
        """Return True when the canary breaches the rollback criteria.

        A missing latency baseline never triggers a rollback on its own.
        """
        criteria = self.config.rollback_criteria

        # Error rate too high
        if metrics['canary_error_rate'] > criteria.get('max_error_rate', 0.01):
            return True

        # Latency regression
        latency_increase = self._latency_ratio(metrics)
        return (
            latency_increase is not None
            and latency_increase > criteria.get('max_latency_increase', 1.5)
        )

    def _meets_success_criteria(self, metrics: dict) -> bool:
        """Return True when the canary is healthy enough to take more traffic.

        Without a latency baseline the canary cannot be verified, so it is
        held at its current share.
        """
        criteria = self.config.success_criteria

        # Error rate acceptable
        if metrics['canary_error_rate'] > criteria.get('max_error_rate', 0.005):
            return False

        # Latency acceptable
        latency_increase = self._latency_ratio(metrics)
        if latency_increase is None:
            return False
        return latency_increase <= criteria.get('max_latency_increase', 1.1)

    async def _rollback(self, reason: str):
        """Send all traffic back to the current version and record ``reason``."""
        self.state = "rolled_back"
        self.rollback_reason = reason
        await self.router.set_traffic_split({'current': 100, 'canary': 0})
        # Alert on-call

    async def _complete(self):
        """Send all traffic to the canary and promote it to current."""
        self.state = "completed"
        await self.router.set_traffic_split({'current': 0, 'canary': 100})
        await self.router.promote_canary_to_current()

Blue-Green Deployment

Complete implementation for blue-green deployments:

class BlueGreenDeployment:
    """Manage blue-green deployments.

    Two environments ("blue" and "green") exist side by side; exactly one is
    active. New versions are deployed and tested on the inactive one, then
    traffic is switched. The environment that was active before the last
    switch is remembered so that rollback returns to it specifically.
    """

    def __init__(self, blue_environment, green_environment, router):
        self.blue = blue_environment
        self.green = green_environment
        self.router = router
        self.active = "blue"
        # Name of the environment that served traffic before the last switch;
        # None when there is nothing to roll back to.
        self._previous = None

    @staticmethod
    def _other(name: str) -> str:
        """Return the name of the environment that is not *name*."""
        return "green" if name == "blue" else "blue"

    async def deploy(self, new_version: str):
        """Deploy new version to inactive environment.

        Args:
            new_version: Version identifier passed to the environment's
                ``deploy``.

        Returns:
            The name ("blue" or "green") of the environment deployed to.

        Raises:
            DeploymentError: If smoke or integration tests report failure.
        """
        inactive = self._other(self.active)
        inactive_env = self.green if inactive == "green" else self.blue

        # Deploy to inactive environment
        await inactive_env.deploy(new_version)

        # Run smoke tests
        test_results = await self._run_smoke_tests(inactive_env)
        if not test_results['passed']:
            raise DeploymentError(f"Smoke tests failed: {test_results['failures']}")

        # Run integration tests
        integration_results = await self._run_integration_tests(inactive_env)
        if not integration_results['passed']:
            raise DeploymentError("Integration tests failed")

        return inactive

    async def switch(self):
        """Switch traffic to the other environment.

        Returns:
            Dict with the new 'active' environment name and
            'rollback_available': True.
        """
        old_active = self.active
        new_active = self._other(old_active)

        await self.router.switch_to(new_active)
        self.active = new_active
        # Remember where we came from only once the router has switched.
        self._previous = old_active

        # Keep old environment warm for quick rollback
        return {
            'active': self.active,
            'rollback_available': True
        }

    async def rollback(self):
        """Rollback to previous environment.

        Routes traffic back to the environment that was active before the
        most recent ``switch``. A rollback consumes that record, so calling
        it twice cannot flip traffic back onto the version being rolled away
        from.

        Raises:
            DeploymentError: If no switch has happened since construction or
                since the last rollback.
        """
        target = self._previous
        if target is None:
            raise DeploymentError("No previous environment to roll back to")

        await self.router.switch_to(target)
        self.active = target
        self._previous = None

    async def _run_smoke_tests(self, env) -> dict:
        """Run smoke tests on environment.

        Hook to override. Must return a dict with a boolean 'passed' and,
        on failure, a 'failures' entry (both are read by ``deploy``).
        """
        # Fail loudly: returning None here made deploy() die with an opaque
        # "'NoneType' object is not subscriptable".
        raise NotImplementedError("_run_smoke_tests must be overridden")

    async def _run_integration_tests(self, env) -> dict:
        """Run integration tests on environment.

        Hook to override. Must return a dict with a boolean 'passed'
        (read by ``deploy``).
        """
        raise NotImplementedError("_run_integration_tests must be overridden")

class DeploymentError(Exception):
    """Raised when a deployment fails its tests or a rollback is impossible."""
    pass

Chaos Engineering

Chaos Monkey

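Reference implementation for chaos engineering experiments (resource-exhaustion injection, metric collection, and abort-condition evaluation are left as stubs):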

from dataclasses import dataclass
from enum import Enum
import random
import asyncio
import time

class ChaosType(Enum):
    """Kinds of fault a chaos experiment can request.

    ChaosMonkey.start_experiment in this file only selects an injector for
    LATENCY, ERROR and RESOURCE_EXHAUSTION.
    """
    LATENCY = "latency"
    ERROR = "error"
    RESOURCE_EXHAUSTION = "resource_exhaustion"
    DATA_CORRUPTION = "data_corruption"
    NETWORK_PARTITION = "network_partition"

@dataclass
class ChaosExperiment:
    """Definition of a chaos experiment.

    ChaosMonkey uses ``chaos_type`` to choose an injector, passes ``target``
    to the injector and to its metrics lookup, and checks every entry of
    ``abort_conditions`` while the experiment runs.
    """
    name: str  # Reported as 'experiment' in the result of stop_experiment
    chaos_type: ChaosType
    target: str
    # Keys read by ChaosMonkey in this file: 'latency_ms', 'error_type',
    # 'percentage' and 'production_approved'.
    parameters: dict
    hypothesis: str  # Descriptive; not read by the code visible here
    expected_behavior: str  # Descriptive; not read by the code visible here
    abort_conditions: list[str]

class ChaosMonkey:
    """Inject chaos into AI systems.

    Running experiments are tracked in ``active_experiments`` keyed by
    experiment ID. Each entry holds the experiment, its injector, the start
    time and the background monitor task.
    """

    # Fraction of calls affected when an experiment does not set
    # 'percentage'. Shared by the safety check and the injector factories so
    # the value that is validated is the value that actually runs.
    DEFAULT_PERCENTAGE = 0.1

    def __init__(self, config: dict):
        self.config = config
        self.active_experiments = {}

    async def start_experiment(self, experiment: ChaosExperiment) -> str:
        """Start a chaos experiment.

        Args:
            experiment: The experiment to run.

        Returns:
            The generated experiment ID.

        Raises:
            UnsafeExperimentError: If the experiment fails ``_is_safe``.
            NotImplementedError: If no injector exists for the chaos type.
        """
        experiment_id = self._generate_id()

        # Validate experiment is safe
        if not self._is_safe(experiment):
            raise UnsafeExperimentError("Experiment exceeds safety limits")

        injector = self._create_injector(experiment)

        exp_data = {
            'experiment': experiment,
            'injector': injector,
            'start_time': time.time()
        }

        # Start injection before registering, so a failed start does not
        # leave a stale entry in active_experiments.
        await injector.start()
        self.active_experiments[experiment_id] = exp_data

        # Monitor for abort conditions. Keep a reference to the task: the
        # event loop holds only weak references, so an unreferenced task may
        # be garbage-collected before it finishes.
        exp_data['monitor_task'] = asyncio.create_task(
            self._monitor_experiment(experiment_id)
        )

        return experiment_id

    def _create_injector(self, experiment: ChaosExperiment):
        """Build the injector matching the experiment's chaos type."""
        factories = {
            ChaosType.LATENCY: self._create_latency_injector,
            ChaosType.ERROR: self._create_error_injector,
            ChaosType.RESOURCE_EXHAUSTION: self._create_resource_injector,
        }
        factory = factories.get(experiment.chaos_type)
        if factory is None:
            # Previously fell through and failed later with UnboundLocalError.
            raise NotImplementedError(
                f"No injector available for chaos type {experiment.chaos_type}"
            )
        return factory(experiment)

    def _create_latency_injector(self, experiment: ChaosExperiment):
        """Create latency injection."""
        return LatencyInjector(
            target=experiment.target,
            latency_ms=experiment.parameters.get('latency_ms', 1000),
            percentage=experiment.parameters.get(
                'percentage', self.DEFAULT_PERCENTAGE
            )
        )

    def _create_error_injector(self, experiment: ChaosExperiment):
        """Create error injection."""
        return ErrorInjector(
            target=experiment.target,
            error_type=experiment.parameters.get('error_type', 'timeout'),
            percentage=experiment.parameters.get(
                'percentage', self.DEFAULT_PERCENTAGE
            )
        )

    def _create_resource_injector(self, experiment: ChaosExperiment):
        """Create resource exhaustion injector (not implemented)."""
        # Returning None here made start_experiment fail on `None.start()`.
        raise NotImplementedError(
            "Resource exhaustion injection is not implemented"
        )

    def _is_safe(self, experiment: ChaosExperiment) -> bool:
        """Check if experiment is safe to run.

        Returns:
            False when running in production without the
            'production_approved' parameter, or when the effective
            percentage exceeds the configured 'max_chaos_percentage'
            (default 0.1); True otherwise.
        """
        # Don't run in production without explicit approval
        if self.config.get('environment') == 'production':
            if not experiment.parameters.get('production_approved'):
                return False

        # Limit blast radius. Use the same default the injectors use; a
        # default of 0 here let an experiment without 'percentage' pass the
        # check and then run at 10%.
        max_percentage = self.config.get('max_chaos_percentage', 0.1)
        percentage = experiment.parameters.get(
            'percentage', self.DEFAULT_PERCENTAGE
        )
        if percentage > max_percentage:
            return False

        return True

    async def _monitor_experiment(self, experiment_id: str):
        """Monitor experiment and abort if needed.

        Polls every 10 seconds while the experiment is registered and stops
        it on the first abort condition that ``_should_abort`` reports.
        """
        exp_data = self.active_experiments.get(experiment_id)
        if exp_data is None:
            return
        experiment = exp_data['experiment']

        while experiment_id in self.active_experiments:
            await asyncio.sleep(10)

            # The experiment may have been stopped while we slept.
            if experiment_id not in self.active_experiments:
                return

            metrics = await self._get_metrics(experiment.target)

            for condition in experiment.abort_conditions:
                if self._should_abort(condition, metrics):
                    await self.stop_experiment(experiment_id, reason=condition)
                    return

    async def stop_experiment(self, experiment_id: str, reason: str = None):
        """Stop a chaos experiment.

        Args:
            experiment_id: ID returned by ``start_experiment``.
            reason: Why the experiment stopped; reported as 'manual' when
                omitted.

        Returns:
            Dict with 'experiment' (name), 'duration' (seconds) and
            'stop_reason', or None if the ID is not active.
        """
        exp_data = self.active_experiments.pop(experiment_id, None)
        if exp_data is None:
            return None

        await exp_data['injector'].stop()

        # Cancel the monitor so it does not linger in its sleep, unless
        # the monitor itself is the caller (cancelling the current task would
        # abort this very call).
        monitor = exp_data.get('monitor_task')
        if monitor is not None and monitor is not asyncio.current_task():
            monitor.cancel()

        # Log results
        return {
            'experiment': exp_data['experiment'].name,
            'duration': time.time() - exp_data['start_time'],
            'stop_reason': reason or 'manual'
        }

    def _generate_id(self) -> str:
        """Generate unique experiment ID."""
        import uuid
        return str(uuid.uuid4())

    async def _get_metrics(self, target: str) -> dict:
        """Get metrics for target.

        Placeholder: currently returns None; override to supply metrics.
        """
        pass

    def _should_abort(self, condition: str, metrics: dict) -> bool:
        """Check if abort condition is met.

        Placeholder: currently returns None (falsy), so no condition ever
        aborts until this is overridden.
        """
        pass

class LatencyInjector:
    """Add artificial delay to a share of calls while switched on."""

    def __init__(self, target: str, latency_ms: int, percentage: float):
        # Starts switched off; start() must be awaited before any delay
        # is injected.
        self.active = False
        self.target = target
        self.latency_ms = latency_ms
        self.percentage = percentage

    async def start(self):
        """Switch injection on."""
        self.active = True

    async def stop(self):
        """Switch injection off."""
        self.active = False

    async def maybe_inject(self):
        """Sleep for the configured latency on a random share of calls.

        Does nothing while inactive. When active, sleeps ``latency_ms``
        milliseconds if a uniform random draw falls below ``percentage``.
        """
        if not self.active:
            return
        if random.random() >= self.percentage:
            return
        delay_seconds = self.latency_ms / 1000
        await asyncio.sleep(delay_seconds)

class ErrorInjector:
    """Inject errors into service calls.

    Mirrors LatencyInjector: once started, ``maybe_inject`` fails a random
    share of calls.
    """

    def __init__(self, target: str, error_type: str, percentage: float):
        self.target = target
        self.error_type = error_type
        self.percentage = percentage
        self.active = False

    async def start(self):
        """Switch injection on."""
        self.active = True

    async def stop(self):
        """Switch injection off."""
        self.active = False

    async def maybe_inject(self):
        """Potentially raise an injected error.

        Does nothing while inactive. When active, raises on the share of
        calls given by ``percentage``.

        Raises:
            TimeoutError: When triggered and ``error_type`` is 'timeout'.
            RuntimeError: When triggered with any other ``error_type``.
        """
        # Without this method the injector could be started but had no
        # injection point, so it never produced an error.
        if self.active and random.random() < self.percentage:
            message = f"Injected {self.error_type} error for {self.target}"
            if self.error_type == 'timeout':
                raise TimeoutError(message)
            raise RuntimeError(message)

class _ErrorInjectorDocAnchor:
    """Unused placeholder removed below."""

class UnsafeExperimentError(Exception):
    """Raised by ChaosMonkey.start_experiment when the safety check fails."""
    pass

Data Architecture

Feature Store

Complete implementation for feature store operations:

from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional

@dataclass
class FeatureDefinition:
    """Definition of a feature.

    ``name`` is matched against the keys returned by the online store, and
    ``freshness_requirement`` selects the maximum allowed age in
    FeatureStore._check_freshness: 'real-time' is 60 s, 'hourly' is 3600 s,
    'daily' is 86400 s; any other value falls back to 86400 s.
    """
    name: str
    dtype: str
    description: str
    entity: str  # What entity this feature belongs to (user, item, etc.)
    freshness_requirement: str  # real-time, hourly, daily
    computation: str  # SQL, Python, streaming

@dataclass
class FeatureView:
    """A collection of related features.

    FeatureStore refuses online lookups when ``online`` is false and
    training lookups when ``offline`` is false.
    """
    name: str
    entities: list[str]
    features: list[FeatureDefinition]
    online: bool  # Available for real-time serving
    offline: bool  # Available for batch/training
    # Not read by the FeatureStore code in this file.
    ttl: Optional[int] = None  # Time-to-live in seconds

class FeatureStore:
    """Interface for feature store operations.

    Wraps an online store (real-time lookups), an offline store (historical
    lookups for training) and a registry that resolves feature view names.
    """

    def __init__(self, online_store, offline_store, registry):
        self.online = online_store
        self.offline = offline_store
        self.registry = registry

    async def get_online_features(
        self,
        feature_view: str,
        entity_keys: dict[str, Any]
    ) -> dict[str, Any]:
        """Get features for real-time inference.

        Args:
            feature_view: Name of the view to resolve through the registry.
            entity_keys: Keys passed to the online store.

        Returns:
            The mapping returned by the online store, unchanged.

        Raises:
            ValueError: If the view is not available online.
        """
        view = self.registry.get_view(feature_view)

        if not view.online:
            raise ValueError(f"Feature view {feature_view} not available online")

        features = await self.online.get(
            view=feature_view,
            keys=entity_keys
        )

        # Index definitions by name once (first definition wins, matching a
        # first-match scan) instead of rescanning the list per feature.
        definitions = {}
        for definition in view.features:
            definitions.setdefault(definition.name, definition)

        # Check freshness
        for feature_name, value in features.items():
            feature_def = definitions.get(feature_name)
            if feature_def is None:
                # The store returned a feature the view does not declare, so
                # there is no freshness requirement to check. A bare next()
                # here raised StopIteration, which a coroutine turns into
                # RuntimeError.
                continue
            if not self._check_freshness(value, feature_def):
                # Log stale feature, potentially fall back
                pass

        return features

    def get_training_features(
        self,
        feature_view: str,
        entity_df,  # DataFrame with entity keys and timestamps
        features: list[str]
    ):
        """Get point-in-time correct features for training.

        Raises:
            ValueError: If the view is not available offline.
        """
        view = self.registry.get_view(feature_view)

        if not view.offline:
            raise ValueError(f"Feature view {feature_view} not available offline")

        # Point-in-time join to prevent data leakage
        return self.offline.get_historical_features(
            view=feature_view,
            entity_df=entity_df,
            features=features
        )

    def _check_freshness(self, value: dict, feature_def: FeatureDefinition) -> bool:
        """Check if feature value meets freshness requirement.

        Args:
            value: Feature payload. Only dicts carrying a 'timestamp'
                (a datetime) can be checked; anything else is treated as
                fresh.
            feature_def: Definition whose ``freshness_requirement`` selects
                the maximum age.

        Returns:
            True if the value is no older than the allowed age, or if its
            age cannot be determined.
        """
        # Plain scalars (floats, ints, ...) have no timestamp to inspect.
        if not isinstance(value, dict) or 'timestamp' not in value:
            return True

        from datetime import timezone

        timestamp = value['timestamp']
        now = datetime.now(timezone.utc)
        if timestamp.tzinfo is None:
            # Naive timestamps are taken to be UTC, as before; drop tzinfo so
            # the subtraction stays naive-vs-naive.
            now = now.replace(tzinfo=None)

        age = now - timestamp
        max_age = {
            'real-time': 60,  # 1 minute
            'hourly': 3600,
            'daily': 86400
        }.get(feature_def.freshness_requirement, 86400)

        return age.total_seconds() <= max_age

Skew Detector

Complete implementation for detecting training-serving skew:

from dataclasses import dataclass
import math

@dataclass
class SkewSource:
    """Source of training-serving skew.

    A purely descriptive record; all four fields are free-form text.
    """
    category: str  # Short label for the kind of skew
    description: str  # What goes wrong
    detection: str  # How to notice it
    prevention: str  # How to avoid it

# Catalogue of common training-serving skew sources, each with how to detect
# and prevent it. Reference data only; not read by the SkewDetector code in
# this file.
SKEW_SOURCES = [
    SkewSource(
        category="Feature computation",
        description="Features computed differently in training vs serving",
        detection="Compare feature distributions between batch and online",
        prevention="Use feature store with shared computation logic"
    ),
    SkewSource(
        category="Data freshness",
        description="Training uses stale data that doesn't reflect production",
        detection="Monitor feature staleness at serving time",
        prevention="Regular retraining; real-time feature updates"
    ),
    SkewSource(
        category="Preprocessing",
        description="Different tokenization, normalization, etc.",
        detection="Hash preprocessing outputs; compare training vs serving",
        prevention="Package preprocessing with model; use same code path"
    ),
    SkewSource(
        category="Data leakage",
        description="Training features include future information",
        detection="Point-in-time validation; hold-out by time",
        prevention="Strict point-in-time feature joins"
    ),
]

class SkewDetector:
    """Detect training-serving skew.

    Compares per-feature training statistics (a dict keyed by feature name)
    with live statistics fetched from a serving monitor.
    """

    def __init__(self, training_stats: dict, serving_monitor):
        self.training_stats = training_stats
        self.serving = serving_monitor

    async def check_skew(self, feature_name: str) -> dict:
        """Check for skew in a feature.

        Args:
            feature_name: Feature to look up in the training stats and the
                serving monitor.

        Returns:
            Dict with 'psi', 'psi_warning' (psi > 0.1), 'psi_critical'
            (psi > 0.25), 'mean_shift_stds' (absolute mean shift in units of
            the training std) and 'missing_rate_change'.
        """
        training = self.training_stats.get(feature_name, {})
        serving = await self.serving.get_stats(feature_name)

        skew_metrics = {}

        # Distribution shift (PSI - Population Stability Index)
        psi = self._calculate_psi(
            training.get('histogram', []),
            serving.get('histogram', [])
        )
        skew_metrics['psi'] = psi
        skew_metrics['psi_warning'] = psi > 0.1
        skew_metrics['psi_critical'] = psi > 0.25

        # Mean shift
        mean_shift = abs(training.get('mean', 0) - serving.get('mean', 0))
        std = training.get('std', 1)
        if std > 0:
            skew_metrics['mean_shift_stds'] = mean_shift / std
        elif mean_shift > 0:
            # The training feature was constant, so any movement of the mean
            # is an unbounded number of standard deviations. Reporting 0 here
            # hid exactly this kind of skew.
            skew_metrics['mean_shift_stds'] = float('inf')
        else:
            skew_metrics['mean_shift_stds'] = 0.0

        # Missing rate change
        missing_rate_change = abs(
            training.get('missing_rate', 0) - serving.get('missing_rate', 0)
        )
        skew_metrics['missing_rate_change'] = missing_rate_change

        return skew_metrics

    def _calculate_psi(self, expected: list, actual: list) -> float:
        """Calculate Population Stability Index.

        Args:
            expected: Training histogram, one value per bin.
            actual: Serving histogram over the same bins.
            NOTE(review): the 0.0001 floor suggests both are bin proportions
            rather than raw counts; confirm with callers.

        Returns:
            Sum over bins of (a - e) * ln(a / e), with each value floored at
            0.0001; ``inf`` when the histograms have different lengths, and
            0.0 when both are empty.
        """
        if len(expected) != len(actual):
            return float('inf')

        psi = 0.0
        for e, a in zip(expected, actual):
            e = max(e, 0.0001)  # Avoid division by zero
            a = max(a, 0.0001)
            psi += (a - e) * math.log(a / e)

        return psi

Data Version Manager

Reference implementation for dataset versioning (schema hashing, checksum computation, and transformation diffing are left as stubs):

from dataclasses import dataclass
from datetime import datetime
from typing import Optional

@dataclass
class DataVersion:
    """A version of a dataset.

    Records are created by DataVersionManager.create_version and linked into
    a lineage chain through ``parent_version``.
    """
    dataset_id: str
    version: str  # ID generated by DataVersionManager._generate_version_id
    created_at: datetime  # Set from datetime.utcnow() at creation (naive)
    schema_hash: str
    row_count: int  # len() of the data at creation time
    checksum: str
    parent_version: Optional[str]  # None for a version with no parent
    transformations: list[str]

class DataVersionManager:
    """Manage dataset versions at scale.

    Data is written to ``storage`` under "<dataset_id>/<version>" and the
    version record is registered with ``catalog``.
    """

    def __init__(self, catalog, storage):
        self.catalog = catalog
        self.storage = storage

    async def create_version(
        self,
        dataset_id: str,
        data,
        parent_version: Optional[str] = None,
        transformations: Optional[list[str]] = None
    ) -> DataVersion:
        """Create a new version of a dataset.

        Args:
            dataset_id: Dataset the version belongs to.
            data: The data to store; must support ``len()``.
            parent_version: Version this one was derived from, if any.
            transformations: Steps applied to the parent to get this data.

        Returns:
            The registered version record.
        """
        # Compute metadata
        schema_hash = self._hash_schema(data)
        row_count = len(data)
        checksum = self._compute_checksum(data)

        # Generate version ID
        version = self._generate_version_id()

        # Store data
        storage_path = f"{dataset_id}/{version}"
        await self.storage.write(storage_path, data)

        # Create version record
        version_record = DataVersion(
            dataset_id=dataset_id,
            version=version,
            created_at=datetime.utcnow(),
            schema_hash=schema_hash,
            row_count=row_count,
            checksum=checksum,
            parent_version=parent_version,
            transformations=transformations or []
        )

        await self.catalog.register_version(version_record)

        return version_record

    async def get_lineage(self, dataset_id: str, version: str) -> list[DataVersion]:
        """Get full lineage of a dataset version.

        Follows ``parent_version`` links from *version* back to the root.

        Returns:
            Version records ordered from *version* to its oldest ancestor;
            empty if the catalog does not know *version*.
        """
        lineage = []
        seen = set()
        current = await self.catalog.get_version(dataset_id, version)

        while current:
            lineage.append(current)
            seen.add(current.version)

            parent = current.parent_version
            # Stop at the root, and also when a parent link points back into
            # the chain: a corrupted catalog with a cycle would otherwise
            # loop forever.
            if not parent or parent in seen:
                break
            current = await self.catalog.get_version(dataset_id, parent)

        return lineage

    async def compare_versions(
        self,
        dataset_id: str,
        version_a: str,
        version_b: str
    ) -> dict:
        """Compare two versions of a dataset.

        Returns:
            Dict with 'schema_changed', 'row_count_diff' (b minus a),
            'same_data' and 'transformations_between'.

        Raises:
            ValueError: If either version is unknown to the catalog.
        """
        v_a = await self.catalog.get_version(dataset_id, version_a)
        v_b = await self.catalog.get_version(dataset_id, version_b)

        # Name the missing version instead of failing later with
        # "'NoneType' object has no attribute 'schema_hash'".
        for requested, record in ((version_a, v_a), (version_b, v_b)):
            if record is None:
                raise ValueError(
                    f"Unknown version {requested} for dataset {dataset_id}"
                )

        return {
            'schema_changed': v_a.schema_hash != v_b.schema_hash,
            'row_count_diff': v_b.row_count - v_a.row_count,
            'same_data': v_a.checksum == v_b.checksum,
            'transformations_between': self._get_transformations_between(v_a, v_b)
        }

    def _hash_schema(self, data) -> str:
        """Hash the schema of data.

        Placeholder: currently returns None; override to supply a hash.
        """
        pass

    def _compute_checksum(self, data) -> str:
        """Compute checksum of data.

        Placeholder: currently returns None; override to supply a checksum.
        """
        pass

    def _generate_version_id(self) -> str:
        """Generate unique version ID (first 8 characters of a UUID4)."""
        import uuid
        return str(uuid.uuid4())[:8]

    def _get_transformations_between(self, v_a: DataVersion, v_b: DataVersion) -> list[str]:
        """Get transformations between two versions.

        Placeholder: currently returns None.
        """
        pass

Failure Handling

Circuit Breaker

Complete implementation for circuit breaker pattern:

import time
from enum import Enum
from dataclasses import dataclass
from typing import Callable, Any

class CircuitState(Enum):
    """States of a circuit breaker."""
    CLOSED = "closed"      # Normal operation
    OPEN = "open"          # Failing fast
    HALF_OPEN = "half_open"  # Testing recovery

@dataclass
class CircuitBreakerConfig:
    """Tuning knobs for CircuitBreaker."""
    failure_threshold: int = 5  # Consecutive failures that open the circuit
    success_threshold: int = 3  # Half-open successes needed to close it
    timeout: float = 30.0  # Seconds before trying again
    half_open_max_calls: int = 3  # Trial calls allowed per half-open period

class CircuitBreaker:
    """Circuit breaker for external dependencies.

    CLOSED: calls pass through; ``failure_threshold`` consecutive failures
    open the circuit. OPEN: calls are rejected (or served by the fallback)
    until ``timeout`` seconds have passed since the circuit was last tripped.
    HALF_OPEN: up to ``half_open_max_calls`` trial calls are let through;
    ``success_threshold`` successes close the circuit, any failure re-opens
    it.
    """

    def __init__(self, name: str, config: CircuitBreakerConfig = None):
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = 0
        self.half_open_calls = 0

    async def call(
        self,
        func: Callable,
        *args,
        fallback: Callable = None,
        **kwargs
    ) -> Any:
        """Execute function with circuit breaker protection.

        Args:
            func: Async callable to protect; awaited with ``*args, **kwargs``.
            fallback: Optional async callable awaited with the same arguments
                when the call is rejected or fails.

        Returns:
            The result of ``func``, or of ``fallback`` when it is used.

        Raises:
            CircuitOpenError: If the call is rejected and no fallback is
                given.
            Exception: Whatever ``func`` raises, when no fallback is given.
        """
        if self.state == CircuitState.OPEN:
            if not self._should_try_reset():
                return await self._reject(
                    f"Circuit {self.name} is open", fallback, args, kwargs
                )
            self.state = CircuitState.HALF_OPEN
            self.half_open_calls = 0
            # Start every trial period with a clean success tally.
            self.success_count = 0

        if self.state == CircuitState.HALF_OPEN:
            self.half_open_calls += 1
            if self.half_open_calls > self.config.half_open_max_calls:
                # Out of trial calls: re-open and restart the timeout. This is
                # a rejection, not a dependency failure, so it is handled
                # here rather than raised into the failure handler below.
                self._trip()
                return await self._reject(
                    f"Circuit {self.name} half-open limit reached",
                    fallback, args, kwargs
                )

        try:
            result = await func(*args, **kwargs)
        except Exception:
            self._record_failure()
            if fallback:
                return await fallback(*args, **kwargs)
            raise

        self._record_success()
        return result

    async def _reject(self, message: str, fallback, args, kwargs) -> Any:
        """Serve a rejected call from the fallback, or raise CircuitOpenError."""
        if fallback:
            return await fallback(*args, **kwargs)
        raise CircuitOpenError(message)

    def _trip(self):
        """Open the circuit and restart the recovery timeout."""
        self.state = CircuitState.OPEN
        self.last_failure_time = time.time()
        self.success_count = 0

    def _record_success(self):
        """Reset the failure streak; close the circuit after enough trial successes."""
        self.failure_count = 0
        if self.state == CircuitState.HALF_OPEN:
            self.success_count += 1
            if self.success_count >= self.config.success_threshold:
                self.state = CircuitState.CLOSED
                self.success_count = 0

    def _record_failure(self):
        """Count a failure and open the circuit when warranted."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        self.success_count = 0

        # Any failure during a trial period means the dependency has not
        # recovered. Relying on the threshold alone kept the circuit
        # half-open after an earlier trial success had reset failure_count.
        if (
            self.state == CircuitState.HALF_OPEN
            or self.failure_count >= self.config.failure_threshold
        ):
            self.state = CircuitState.OPEN

    def _should_try_reset(self) -> bool:
        """Return True once ``timeout`` seconds have passed since the last trip."""
        return time.time() - self.last_failure_time >= self.config.timeout

class CircuitOpenError(Exception):
    """Raised when a call is rejected because the circuit is open."""
    pass

Graceful Degradation

Complete implementation for graceful degradation:

from dataclasses import dataclass

@dataclass
class DegradationLevel:
    """A level of service degradation.

    A descriptive record; every field is free-form text.
    """
    name: str  # Identifier of the level, e.g. "NORMAL"
    description: str
    trigger: str  # Condition under which the level applies
    actions: list[str]  # Measures taken at this level
    user_impact: str

# Degradation ladder, ordered from full service to unavailable. The names
# match the level strings GracefulDegradation assigns; the list itself is
# reference data and is not read by the code in this file.
DEGRADATION_LEVELS = [
    DegradationLevel(
        name="NORMAL",
        description="Full service",
        trigger="All systems healthy",
        actions=["Full feature set"],
        user_impact="None"
    ),
    DegradationLevel(
        name="DEGRADED_L1",
        description="Non-critical features disabled",
        trigger="Secondary services unhealthy",
        actions=[
            "Disable personalization",
            "Use cached recommendations",
            "Reduce logging verbosity"
        ],
        user_impact="Slightly less relevant results"
    ),
    DegradationLevel(
        name="DEGRADED_L2",
        description="Reduced functionality",
        trigger="Primary services stressed",
        actions=[
            "Use smaller/faster model",
            "Reduce max output length",
            "Aggressive caching",
            "Rate limit non-essential endpoints"
        ],
        user_impact="Slower responses, reduced quality"
    ),
    DegradationLevel(
        name="DEGRADED_L3",
        description="Minimal functionality",
        trigger="Critical capacity issues",
        actions=[
            "Static fallback responses",
            "Queue requests",
            "Reject low-priority traffic"
        ],
        user_impact="Limited functionality, potential errors"
    ),
    DegradationLevel(
        name="MAINTENANCE",
        description="Service unavailable",
        trigger="Complete outage or maintenance",
        actions=[
            "Return maintenance page",
            "Queue requests for later"
        ],
        user_impact="Service unavailable"
    ),
]

class GracefulDegradation:
    """Track the service's degradation level from health-check results."""

    def __init__(self, health_checker):
        self.current_level = "NORMAL"
        self.health = health_checker

    async def check_and_adjust(self) -> str:
        """Poll the health checker and move to the level its status implies.

        Runs the transition hook only when the level changes.

        Returns:
            The level in effect after the check.
        """
        status = await self.health.get_status()
        target = self._select_level(status)

        previous = self.current_level
        if target != previous:
            await self._transition(previous, target)
            self.current_level = target

        return self.current_level

    @staticmethod
    def _select_level(status) -> str:
        """Map a health status mapping to a degradation level name.

        Flags are consulted lazily, in order: critical services healthy
        (then all services healthy), primary services stressed, capacity
        critical; with none set the result is MAINTENANCE.
        """
        if status['critical_services_healthy']:
            return "NORMAL" if status['all_services_healthy'] else "DEGRADED_L1"
        if status['primary_services_stressed']:
            return "DEGRADED_L2"
        if status['capacity_critical']:
            return "DEGRADED_L3"
        return "MAINTENANCE"

    async def _transition(self, from_level: str, to_level: str):
        """Hook run on every level change; currently a no-op."""
        # Log transition
        # Alert if degrading
        # Apply actions for new level
        pass

Operational Excellence

AI Service Mesh Config

Complete implementation for configuring service mesh for AI workloads:

from dataclasses import dataclass

@dataclass
class ServiceMeshFeature:
    """Features relevant for AI workloads.

    A descriptive record; every field is free-form text.
    """
    feature: str  # Name of the mesh capability
    benefit_for_ai: str  # Why it matters for AI serving
    implementation: str  # How to apply it

# Service mesh capabilities and what each offers AI workloads. Reference
# data only; not read by the AIServiceMeshConfig code in this file.
SERVICE_MESH_AI_FEATURES = [
    ServiceMeshFeature(
        feature="Automatic retries",
        benefit_for_ai="Handle transient GPU failures gracefully",
        implementation="Configure retry budget to avoid thundering herd"
    ),
    ServiceMeshFeature(
        feature="Circuit breaking",
        benefit_for_ai="Prevent cascade failures when model serving degrades",
        implementation="Set thresholds based on inference latency patterns"
    ),
    ServiceMeshFeature(
        feature="Load balancing",
        benefit_for_ai="GPU-aware routing based on memory/compute availability",
        implementation="Custom load balancing policies for AI workloads"
    ),
    ServiceMeshFeature(
        feature="mTLS",
        benefit_for_ai="Secure model serving and feature transmission",
        implementation="Automatic encryption without code changes"
    ),
    ServiceMeshFeature(
        feature="Observability",
        benefit_for_ai="Distributed tracing across inference pipeline",
        implementation="Automatic span creation for all service calls"
    ),
    ServiceMeshFeature(
        feature="Traffic splitting",
        benefit_for_ai="Canary deployments for model updates",
        implementation="Percentage-based routing to model versions"
    ),
]

class AIServiceMeshConfig:
    """Apply AI-oriented service mesh settings through a mesh client."""

    def __init__(self, mesh_client):
        self.mesh = mesh_client

    async def configure_inference_service(self, service_name: str):
        """Push retry, circuit-breaker, timeout and health-check settings.

        The four policies are applied to *service_name* one after another,
        in that order.
        """
        mesh = self.mesh

        # Retries: take care with non-idempotent operations
        retry_policy = {
            'attempts': 3,
            'per_try_timeout': '10s',
            'retry_on': ['5xx', 'reset', 'connect-failure'],
            # No retry once the response has started streaming
            'retriable_headers': ['x-retry-safe']
        }
        await mesh.set_retry_policy(service_name, retry_policy)

        # Circuit breaking
        breaker_policy = {
            'max_connections': 1000,
            'max_pending_requests': 100,
            'max_requests': 1000,
            'max_retries': 3,
            'consecutive_5xx': 5,  # Trips after 5 errors in a row
            'interval': '10s',
            'base_ejection_time': '30s'
        }
        await mesh.set_circuit_breaker(service_name, breaker_policy)

        # Timeouts: generous, since AI workloads run long
        timeout_policy = {
            'timeout': '60s',  # LLM generation can be slow
            'idle_timeout': '300s'  # Hold connections open for batching
        }
        await mesh.set_timeout(service_name, timeout_policy)

        # Health checking
        health_policy = {
            'path': '/health',
            'interval': '5s',
            'timeout': '3s',
            'unhealthy_threshold': 3,
            'healthy_threshold': 2
        }
        await mesh.set_health_check(service_name, health_policy)

AI Instrumentor

Reference implementation for comprehensive AI inference instrumentation (the slot-waiting, preprocessing, and postprocessing hooks are placeholders to override):

from dataclasses import dataclass
from typing import Optional
import time

@dataclass
class InferenceSpan:
    """Detailed tracing for inference requests.

    A plain record of one request's identifiers, timing breakdown, resource
    usage and quality signals. It is not constructed by the AIInstrumentor
    code visible in this file.
    """
    trace_id: str
    request_id: str
    model_name: str
    model_version: str

    # Timing breakdown
    queue_time_ms: float
    preprocessing_time_ms: float
    inference_time_ms: float
    postprocessing_time_ms: float
    total_time_ms: float

    # Resource usage
    gpu_memory_used_mb: int
    tokens_processed: int
    tokens_generated: int

    # Quality signals
    cache_hit: bool
    batch_size: int
    early_stopped: bool

class AIInstrumentor:
    """Instrument AI inference for observability.

    Wraps an inference call with a tracing span plus latency, throughput and
    error metrics. ``_wait_for_slot``, ``_preprocess`` and ``_postprocess``
    are hooks to override; by default they do nothing and pass data through.
    """

    def __init__(self, tracer, metrics, logger):
        self.tracer = tracer
        self.metrics = metrics
        self.logger = logger

    async def instrument_inference(self, request: dict, inference_func):
        """Wrap inference with comprehensive instrumentation.

        Args:
            request: Request dict; 'model' and 'input_tokens' are read for
                labels and span attributes.
            inference_func: Async callable awaited with the preprocessed
                request.

        Returns:
            The postprocessed inference result.

        Raises:
            Exception: Whatever any stage raises, after metrics and the span
                have been recorded.
        """
        span = self.tracer.start_span("inference")

        timing = {}
        result = None
        error = None

        try:
            # Track queue time
            timing['queue_start'] = time.time()
            await self._wait_for_slot()
            timing['queue_end'] = time.time()

            # Track preprocessing
            timing['preprocess_start'] = time.time()
            processed_input = await self._preprocess(request)
            timing['preprocess_end'] = time.time()

            # Track inference
            timing['inference_start'] = time.time()
            result = await inference_func(processed_input)
            timing['inference_end'] = time.time()

            # Track postprocessing
            timing['postprocess_start'] = time.time()
            final_result = await self._postprocess(result)
            timing['postprocess_end'] = time.time()

            return final_result

        except Exception as e:
            error = e
            raise

        finally:
            # Stamp the end so a request that failed part-way still gets a
            # real total duration.
            timing['finished_at'] = time.time()
            try:
                # Record detailed metrics
                self._record_metrics(request, result, timing, error)
            finally:
                # End span with attributes, even if metric recording raised.
                span.set_attributes({
                    'model': request.get('model'),
                    'tokens_in': request.get('input_tokens', 0),
                    'tokens_out': len(result.get('tokens', [])) if result else 0,
                    'cache_hit': result.get('cache_hit', False) if result else False,
                    'error': str(error) if error else None
                })
                span.end()

    @staticmethod
    def _elapsed_seconds(timing, start_key, end_key):
        """Return end minus start in seconds, or None if either stamp is missing."""
        start = timing.get(start_key)
        end = timing.get(end_key)
        if start is None or end is None:
            return None
        return end - start

    def _record_metrics(self, request, result, timing, error):
        """Record comprehensive metrics.

        Emits latency histograms for every stage whose start and end stamps
        both exist, a tokens-per-second gauge when there is a result, and an
        error counter when there is an error. Stages that never completed
        are skipped; defaulting a missing stamp to 0 produced huge negative
        latencies on failed requests.
        """
        labels = {
            'model': request.get('model', 'unknown'),
            'status': 'error' if error else 'success'
        }

        # Latency histogram: up to the end of postprocessing, or to the point
        # of failure when the request did not get that far.
        total_end_key = 'postprocess_end' if 'postprocess_end' in timing else 'finished_at'
        total_seconds = self._elapsed_seconds(timing, 'queue_start', total_end_key)
        if total_seconds is not None:
            self.metrics.histogram('inference_latency_ms', total_seconds * 1000, labels)

        # Breakdown histograms
        for metric_name, start_key, end_key in (
            ('queue_time_ms', 'queue_start', 'queue_end'),
            ('inference_time_ms', 'inference_start', 'inference_end'),
        ):
            stage_seconds = self._elapsed_seconds(timing, start_key, end_key)
            if stage_seconds is not None:
                self.metrics.histogram(metric_name, stage_seconds * 1000, labels)

        # Token throughput
        if result:
            tokens = len(result.get('tokens', []))
            inference_seconds = self._elapsed_seconds(
                timing, 'inference_start', 'inference_end'
            ) or 0.0
            # Floor the divisor so a near-instant call cannot divide by zero.
            tokens_per_second = tokens / max(inference_seconds, 0.001)
            self.metrics.gauge('tokens_per_second', tokens_per_second, labels)

        # Error counter
        if error:
            self.metrics.counter('inference_errors', 1, {
                **labels,
                'error_type': type(error).__name__
            })

    async def _wait_for_slot(self):
        """Wait for available inference slot. Hook; returns immediately by default."""
        pass

    async def _preprocess(self, request: dict) -> dict:
        """Preprocess request. Hook; passes the request through by default."""
        # Returning None here fed None to the inference function.
        return request

    async def _postprocess(self, result: dict) -> dict:
        """Postprocess result. Hook; passes the result through by default."""
        # Returning None here made instrument_inference always return None.
        return result

AI Workload Queue

Complete implementation for queue-based AI workload processing:

from dataclasses import dataclass
from typing import Optional
import time

@dataclass
class QueueConfig:
    """Configuration for AI workload queue.

    In the AIWorkloadQueue code visible in this file, ``queue_name``
    addresses the backend queue, ``max_queue_depth`` caps enqueueing, and
    ``visibility_timeout_seconds`` is passed to the backend when receiving.
    """
    queue_name: str
    priority_levels: int
    max_queue_depth: int  # enqueue raises once depth reaches this value
    visibility_timeout_seconds: int
    dead_letter_queue: str
    ordering: str  # fifo, priority, fair

class AIWorkloadQueue:
    """Queue system for AI inference workloads.

    Thin async facade over a pluggable ``queue_backend``. The backend must
    provide these coroutines:

    * ``get_depth(queue_name)``
    * ``send(queue=..., message=..., priority=...)``
    * ``receive_batch(queue=..., max_messages=..., visibility_timeout=...)``
    * ``delete(queue_name, message_id)``
    """

    def __init__(self, config: QueueConfig, queue_backend):
        """Bind the queue to its configuration and backend.

        Args:
            config: Queue settings (name, depth limit, priority levels, ...).
            queue_backend: Object implementing the backend coroutines listed
                in the class docstring.
        """
        self.config = config
        self.backend = queue_backend
        self.metrics = {}

    async def enqueue(
        self,
        request: dict,
        priority: int = 5,
        deadline: Optional[float] = None
    ) -> str:
        """Enqueue a request for processing.

        Args:
            request: Opaque payload stored under the message's ``request`` key.
            priority: Numeric priority forwarded to the backend and used to
                scale the wait estimate (smaller values estimate shorter waits).
            deadline: Optional Unix timestamp by which processing must start.

        Returns:
            The generated message ID.

        Raises:
            QueueFullError: The backend reports a depth at or above
                ``config.max_queue_depth``.
            DeadlineExceededError: The estimated wait would overrun ``deadline``.
        """
        # NOTE: the depth check and the send are separate backend calls, so
        # the limit is advisory under concurrent producers.
        current_depth = await self.backend.get_depth(self.config.queue_name)
        if current_depth >= self.config.max_queue_depth:
            raise QueueFullError("Queue at capacity")

        # Compare against None explicitly: a deadline of 0.0 is a real (already
        # expired) timestamp and must not be treated as "no deadline".
        if deadline is not None:
            estimated_wait = self._estimate_wait_time(priority, current_depth)
            if time.time() + estimated_wait > deadline:
                raise DeadlineExceededError("Cannot meet deadline")

        # Only mint an ID once the request has passed admission checks.
        message_id = self._generate_id()
        await self.backend.send(
            queue=self.config.queue_name,
            message={
                'id': message_id,
                'request': request,
                'priority': priority,
                'deadline': deadline,
                'enqueued_at': time.time()
            },
            priority=priority
        )

        return message_id

    async def dequeue_batch(self, batch_size: int) -> list[dict]:
        """Dequeue a batch of requests for processing.

        Messages whose ``deadline`` has already passed are handed to
        ``_handle_expired`` and left out of the returned list.

        Args:
            batch_size: Maximum number of messages to request from the backend.

        Returns:
            The received messages that are still within their deadline, in
            the order the backend returned them.
        """
        messages = await self.backend.receive_batch(
            queue=self.config.queue_name,
            max_messages=batch_size,
            visibility_timeout=self.config.visibility_timeout_seconds
        )

        valid_messages = []
        for msg in messages:
            msg_deadline = msg.get('deadline')
            # The clock is re-read per message because handling an expired
            # message is awaited and may take time.
            if msg_deadline is not None and time.time() > msg_deadline:
                await self._handle_expired(msg)
            else:
                valid_messages.append(msg)

        return valid_messages

    async def complete(self, message_id: str, result: dict):
        """Mark message as completed.

        The result is persisted *before* the message is deleted. If storing
        fails, the exception propagates and the message stays in the queue
        (so a backend with visibility-timeout semantics can redeliver it)
        instead of the work being lost.

        Args:
            message_id: ID returned by ``enqueue``.
            result: Result payload to store for later retrieval.
        """
        await self._store_result(message_id, result)
        await self.backend.delete(self.config.queue_name, message_id)

    def _estimate_wait_time(self, priority: int, depth: int) -> float:
        """Estimate wait time in seconds based on queue state.

        The estimate is ``depth * 0.1`` seconds scaled by
        ``priority / config.priority_levels``: a smaller numeric priority
        yields a shorter estimate, and priority 0 estimates no wait.
        ``config.priority_levels`` must be non-zero.
        """
        priority_factor = priority / self.config.priority_levels
        base_time = depth * 0.1  # 100ms per queued request
        return base_time * priority_factor

    def _generate_id(self) -> str:
        """Generate unique message ID (random UUID4 rendered as a string)."""
        import uuid
        return str(uuid.uuid4())

    async def _handle_expired(self, msg: dict):
        """Handle expired message.

        Placeholder hook: currently does nothing with ``msg``.
        """
        pass

    async def _store_result(self, message_id: str, result: dict):
        """Store result for retrieval.

        Placeholder hook: currently does nothing with the result.
        """
        pass

class QueueFullError(Exception):
    """Raised when a request is rejected because the queue is at capacity."""

class DeadlineExceededError(Exception):
    """Raised when a request's deadline cannot be met."""

Capacity Plan

Complete implementation for capacity planning:

from dataclasses import dataclass
from datetime import datetime
import math

@dataclass
class CapacityPlan:
    """Capacity planning model.

    Projects each resource's utilization forward under compound monthly
    growth and turns the time remaining before a critical threshold into a
    scaling recommendation.
    """
    service: str
    current_capacity: dict  # {resource: amount}
    # {resource: utilization}, on the same scale as the critical threshold
    # (the default threshold is 0.8, i.e. fractions of 1 rather than 0-100).
    current_utilization: dict
    growth_rate: float  # Monthly growth rate
    lead_time_weeks: int  # Time to provision new capacity

    # Month-to-week conversion factor used by the projection.
    _WEEKS_PER_MONTH = 4.33

    def weeks_until_critical(self, critical_utilization: float = 0.8) -> dict:
        """Calculate weeks until each resource hits critical utilization.

        Args:
            critical_utilization: Threshold on the same scale as
                ``current_utilization``.

        Returns:
            ``{resource: weeks}`` where ``weeks`` is ``0`` for a resource
            already at or above the threshold and ``float('inf')`` when the
            threshold is never reached (non-positive growth, or a resource
            with zero/negative utilization, which compound growth cannot
            raise).
        """
        results = {}
        for resource, current_util in self.current_utilization.items():
            if current_util >= critical_utilization:
                results[resource] = 0
            elif self.growth_rate <= 0 or current_util <= 0:
                # Either nothing grows, or there is nothing to grow from;
                # the latter also avoids dividing by zero / log of a
                # negative ratio below.
                results[resource] = float('inf')
            else:
                # Solve current * (1 + g) ** months == critical for months.
                months = (
                    math.log(critical_utilization / current_util)
                    / math.log(1 + self.growth_rate)
                )
                results[resource] = months * self._WEEKS_PER_MONTH

        return results

    def scaling_recommendation(self) -> dict:
        """Generate scaling recommendations.

        Each resource's weeks-until-critical (at the default threshold) is
        compared against multiples of ``lead_time_weeks``:

        * ``<= 1x`` lead time: ``SCALE_NOW`` / ``critical``
        * ``<= 2x`` lead time: ``PLAN_SCALING`` / ``high``
        * ``<= 4x`` lead time: ``MONITOR`` / ``medium``
        * otherwise: ``OK`` / ``low``

        Returns:
            ``{resource: {'action', 'urgency', 'weeks_until_critical'}}``.
        """
        lead = self.lead_time_weeks
        # Ordered from most to least urgent; the first matching tier wins.
        tiers = (
            (lead, 'SCALE_NOW', 'critical'),
            (lead * 2, 'PLAN_SCALING', 'high'),
            (lead * 4, 'MONITOR', 'medium'),
        )

        recommendations = {}
        for resource, weeks in self.weeks_until_critical().items():
            action, urgency = next(
                ((act, urg) for limit, act, urg in tiers if weeks <= limit),
                ('OK', 'low'),
            )
            recommendations[resource] = {
                'action': action,
                'urgency': urgency,
                'weeks_until_critical': weeks
            }

        return recommendations

Disaster Recovery Orchestrator

Complete implementation for disaster recovery:

from dataclasses import dataclass
import time

@dataclass
class DRPlan:
    """Disaster recovery plan for AI system.

    One instance describes how a single component is recovered. The
    dataclass performs no validation on any field.
    """
    # Component name; DisasterRecoveryOrchestrator uses it as the lookup key.
    component: str
    # Recovery Time Objective in hours; the orchestrator sorts by it and
    # compares the measured recovery duration against it.
    rto_hours: float  # Recovery Time Objective
    # Recovery Point Objective in hours; not read by the orchestrator.
    rpo_hours: float  # Recovery Point Objective
    # Step descriptions, executed in list order by the orchestrator.
    recovery_steps: list[str]
    # Names of other components this one depends on.
    dependencies: list[str]
    # Free-text test cadence (e.g. "Monthly"); not read by the orchestrator.
    test_frequency: str

# Reference disaster-recovery plans, one per component.
# rto_hours / rpo_hours are expressed in hours (0.25 == 15 minutes).
# NOTE(review): "Model serving" has the lowest RTO (0.25h) yet lists
# "Model registry" (0.5h) and "Feature store" (1h) as dependencies, so an
# ordering based purely on RTO would recover it before those dependencies.
# NOTE(review): "Bi-annually" is ambiguous (twice a year vs. every two
# years) -- confirm the intended test cadence.
DR_PLANS = [
    DRPlan(
        component="Model serving",
        rto_hours=0.25,  # 15 minutes
        rpo_hours=0,     # No data loss - stateless
        recovery_steps=[
            "Switch traffic to standby region",
            "Verify model health in standby",
            "Update DNS/routing",
            "Monitor for errors"
        ],
        dependencies=["Model registry", "Feature store"],
        test_frequency="Monthly"
    ),
    DRPlan(
        component="Feature store",
        rto_hours=1,
        rpo_hours=0.25,  # 15 minutes of feature data
        recovery_steps=[
            "Failover to replica",
            "Verify feature freshness",
            "Backfill any missing features",
            "Validate feature quality"
        ],
        dependencies=["Data warehouse", "Streaming pipeline"],
        test_frequency="Quarterly"
    ),
    DRPlan(
        component="Training infrastructure",
        rto_hours=24,  # Can wait for training
        rpo_hours=168,  # Weekly checkpoint saves
        recovery_steps=[
            "Restore from checkpoint",
            "Verify GPU cluster in DR region",
            "Resume training from last checkpoint",
            "Validate training metrics"
        ],
        dependencies=["Training data", "Checkpoint storage"],
        test_frequency="Bi-annually"
    ),
    DRPlan(
        component="Model registry",
        rto_hours=0.5,
        rpo_hours=0,  # All versions must be preserved
        recovery_steps=[
            "Restore from backup",
            "Verify model version integrity",
            "Validate model signatures",
            "Update serving to use restored registry"
        ],
        dependencies=["Backup storage"],
        test_frequency="Monthly"
    ),
]

class DisasterRecoveryOrchestrator:
    """Orchestrate disaster recovery.

    Holds one plan per component and recovers affected components
    sequentially, dependencies first and tightest RTO first among the
    components that are ready.
    """

    def __init__(self, dr_plans: list[DRPlan]):
        """Index the given plans by component name.

        Args:
            dr_plans: Plans to orchestrate; a later plan with the same
                ``component`` replaces an earlier one.
        """
        self.plans = {p.component: p for p in dr_plans}
        self.state = {}

    async def initiate_dr(self, affected_components: list[str]) -> dict:
        """Initiate disaster recovery for affected components.

        Components are recovered one at a time. A component is only started
        once every dependency that is itself affected has been attempted;
        among the components that are ready, the one with the lowest RTO
        goes first. Dependencies that are not in ``affected_components`` are
        assumed healthy. A failed dependency does not block its dependents;
        inspect each result's ``status`` to detect that.

        Args:
            affected_components: Names of components to recover. Duplicates
                are recovered once.

        Returns:
            ``{component: recovery_result}`` in the order recovery ran.

        Raises:
            KeyError: A name in ``affected_components`` has no plan.
        """
        # De-duplicate while keeping input order, then order by RTO (stable).
        pending = sorted(
            dict.fromkeys(affected_components),
            key=lambda c: self.plans[c].rto_hours
        )
        affected = set(pending)

        results = {}
        while pending:
            # First ready component in RTO order. If none is ready the
            # remaining plans form a dependency cycle; fall back to plain RTO
            # order rather than stalling the recovery.
            component = next(
                (c for c in pending
                 if self._dependencies_handled(c, affected, results)),
                pending[0]
            )
            pending.remove(component)
            results[component] = await self._recover_component(
                component, self.plans[component]
            )

        return results

    def _dependencies_handled(self, component: str, affected: set, results: dict) -> bool:
        """Return True when no affected dependency of ``component`` is still pending."""
        return all(
            dep not in affected or dep in results
            for dep in self.plans[component].dependencies
        )

    async def _recover_component(self, component: str, plan: DRPlan) -> dict:
        """Execute recovery for a component.

        Runs ``plan.recovery_steps`` in order and stops at the first step
        that raises.

        Returns:
            A dict with ``component``, ``status`` (``'success'`` or
            ``'failed'``), ``duration_seconds``, ``within_rto`` (elapsed time
            compared with ``plan.rto_hours``, regardless of status) and
            ``steps`` (one entry per attempted step).
        """
        # Monotonic clock: elapsed time must not be skewed by wall-clock jumps.
        start_time = time.monotonic()
        steps_completed = []
        status = 'success'

        for step in plan.recovery_steps:
            try:
                await self._execute_step(component, step)
            except Exception as e:
                # Record the failure and abort; later steps assume this one worked.
                steps_completed.append({'step': step, 'status': 'failed', 'error': str(e)})
                status = 'failed'
                break
            else:
                steps_completed.append({'step': step, 'status': 'success'})

        # Measure once so duration and the RTO verdict agree with each other.
        duration_seconds = time.monotonic() - start_time
        return {
            'component': component,
            'status': status,
            'duration_seconds': duration_seconds,
            'within_rto': duration_seconds / 3600 <= plan.rto_hours,
            'steps': steps_completed
        }

    async def _execute_step(self, component: str, step: str):
        """Execute a recovery step.

        Placeholder hook: currently does nothing.
        """
        pass