System Design at Scale Code Reference
Complete code implementations for Chapter 59: System Design at Scale.
Table of Contents
- Scaling AI Inference
- Multi-Region Deployment
- Traffic Management
- Deployment Patterns
- Chaos Engineering
- Data Architecture
- Failure Handling
- Operational Excellence
Scaling AI Inference
Dynamic Batcher
Complete implementation for batching requests dynamically for efficient GPU inference:
from dataclasses import dataclass, field
from typing import Optional
import asyncio
import time
@dataclass
class InferenceRequest:
    """A single inference request."""

    id: str
    input_tokens: list[int]
    max_output_tokens: int
    priority: int = 0  # higher values are placed into a batch first
    deadline: Optional[float] = None  # Unix timestamp
    arrival_time: float = field(default_factory=time.time)


class DynamicBatcher:
    """Batch requests dynamically for efficient inference.

    Requests accumulate in ``pending`` as ``(request, future)`` pairs and are
    handed to ``_process_batch`` once ``max_batch_size`` requests are waiting
    or the oldest one has waited ``max_wait_ms``.  A batch never holds more
    than ``max_batch_tokens`` tokens, counting each request's input tokens
    plus its ``max_output_tokens``.
    """

    def __init__(
        self,
        max_batch_size: int = 32,
        max_wait_ms: float = 50,
        max_batch_tokens: int = 4096
    ):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.max_batch_tokens = max_batch_tokens
        self.pending: list[tuple] = []  # (InferenceRequest, Future) pairs
        self.lock = asyncio.Lock()
        # The event loop only keeps weak references to tasks, so in-flight
        # batch tasks are held here until they finish.
        self._tasks: set = set()
        # Background task that enforces max_wait_ms when no new request
        # arrives to trigger a dispatch.
        self._flush_task = None

    async def add_request(self, request: InferenceRequest) -> asyncio.Future:
        """Queue ``request`` and return a future for its result.

        The future fails with ``ValueError`` if the request alone needs more
        than ``max_batch_tokens`` tokens.
        """
        future = asyncio.get_running_loop().create_future()
        async with self.lock:
            self.pending.append((request, future))
            # Check if we should dispatch immediately
            if self._should_dispatch():
                self._dispatch(self._form_batch())
            # Anything still waiting needs a timer so it is flushed even if
            # no further request ever arrives.
            if self.pending and self._flush_task is None:
                self._flush_task = self._spawn(self._flush_when_due())
        return future

    def _should_dispatch(self) -> bool:
        """Decide if we should dispatch current batch."""
        if not self.pending:
            return False
        if len(self.pending) >= self.max_batch_size:
            return True
        # _form_batch re-sorts pending by priority, so the oldest request is
        # not necessarily at index 0.
        oldest = min(item[0].arrival_time for item in self.pending)
        return (time.time() - oldest) * 1000 >= self.max_wait_ms

    def _form_batch(self) -> list[tuple]:
        """Remove and return the next batch from the pending requests.

        Requests are considered by descending priority, then earliest
        deadline.  A request that does not fit in the remaining token budget
        stays pending; a request that can never fit is failed immediately.
        """
        # Sort by priority, then by deadline
        self.pending.sort(
            key=lambda x: (-x[0].priority, x[0].deadline or float('inf'))
        )
        ordered = list(self.pending)
        batch = []
        leftover = []
        total_tokens = 0
        for index, item in enumerate(ordered):
            request, future = item
            request_tokens = len(request.input_tokens) + request.max_output_tokens
            if request_tokens > self.max_batch_tokens:
                # Would otherwise sit in the queue forever.
                if not future.done():
                    future.set_exception(ValueError(
                        f"Request {request.id} needs {request_tokens} tokens, "
                        f"more than max_batch_tokens={self.max_batch_tokens}"
                    ))
                continue
            if total_tokens + request_tokens <= self.max_batch_tokens:
                batch.append(item)
                total_tokens += request_tokens
                if len(batch) >= self.max_batch_size:
                    leftover.extend(ordered[index + 1:])
                    break
            else:
                leftover.append(item)
        self.pending[:] = leftover
        return batch

    def _dispatch(self, batch: list[tuple]) -> None:
        """Start processing ``batch`` in the background; empty batches are skipped."""
        if batch:
            self._spawn(self._process_batch(batch))

    def _spawn(self, coro) -> asyncio.Task:
        """Create a task for ``coro`` and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def _flush_when_due(self) -> None:
        """Dispatch pending requests as they reach ``max_wait_ms``; exit when none remain."""
        try:
            while True:
                async with self.lock:
                    if not self.pending:
                        return
                    if self._should_dispatch():
                        # Each _form_batch call removes at least one request,
                        # so this loop always makes progress.
                        self._dispatch(self._form_batch())
                        continue
                    oldest = min(item[0].arrival_time for item in self.pending)
                    remaining = self.max_wait_ms / 1000 - (time.time() - oldest)
                await asyncio.sleep(max(remaining, 0))
        finally:
            self._flush_task = None

    async def _process_batch(self, batch: list[tuple]):
        """Process a batch of requests."""
        # This would call the actual model inference
        # Results would be set on the futures
        pass


# Section: Inference Cache
Complete implementation for multi-level caching of inference results:
from typing import Optional
import hashlib
class InferenceCache:
    """Multi-level cache for inference results.

    Lookups try the local L1 cache, then the distributed L2 cache (both by
    exact key), and finally a similarity search over previously indexed
    prompts.
    """

    def __init__(
        self,
        l1_cache,  # In-memory, per-node
        l2_cache,  # Distributed cache (Redis)
        embedding_cache  # Vector similarity cache
    ):
        self.l1 = l1_cache
        self.l2 = l2_cache
        self.embedding = embedding_cache

    def cache_key(self, prompt: str, params: dict) -> str:
        """Generate cache key from prompt and parameters.

        Parameters are sorted by name so the key does not depend on dict
        insertion order.
        """
        content = f"{prompt}|{sorted(params.items())}"
        return hashlib.sha256(content.encode()).hexdigest()

    async def get(
        self,
        prompt: str,
        params: dict,
        similarity_threshold: float = 0.95
    ) -> Optional[dict]:
        """Get cached result, checking multiple cache levels.

        Returns a dict with ``result`` and ``cache_hit`` (``'l1_exact'``,
        ``'l2_exact'`` or ``'semantic'``; semantic hits also carry
        ``similarity`` and ``original_prompt``), or ``None`` on a miss.
        """
        key = self.cache_key(prompt, params)
        # L1: Exact match in local memory
        # Compare against None so a legitimately falsy cached value (for
        # example an empty completion) still counts as a hit.
        # NOTE(review): assumes both caches report a miss as None — confirm.
        result = self.l1.get(key)
        if result is not None:
            return {'result': result, 'cache_hit': 'l1_exact'}
        # L2: Exact match in distributed cache
        result = await self.l2.get(key)
        if result is not None:
            self.l1.set(key, result)  # Populate L1
            return {'result': result, 'cache_hit': 'l2_exact'}
        # L3: Semantic similarity match
        # NOTE(review): this lookup is by prompt only; `params` are not part
        # of the similarity query.
        similar = await self.embedding.find_similar(
            prompt, threshold=similarity_threshold
        )
        if similar:
            # Found semantically similar cached result
            return {
                'result': similar['result'],
                'cache_hit': 'semantic',
                'similarity': similar['score'],
                'original_prompt': similar['prompt']
            }
        return None

    async def set(
        self,
        prompt: str,
        params: dict,
        result: str,
        ttl: int = 3600
    ):
        """Cache result at all levels.

        ``ttl`` is passed to the L1 and L2 caches; the embedding index call
        takes no TTL.
        """
        key = self.cache_key(prompt, params)
        # Set in all cache levels
        self.l1.set(key, result, ttl=ttl)
        await self.l2.set(key, result, ttl=ttl)
        await self.embedding.index(prompt, result)


# Section: GPU-Aware Load Balancer
Complete implementation for GPU-aware request routing:
from enum import Enum
from dataclasses import dataclass
class LoadBalancingStrategy(Enum):
    """Routing strategies a load balancer profile can select."""

    ROUND_ROBIN = "round_robin"
    LEAST_CONNECTIONS = "least_connections"
    LEAST_LATENCY = "least_latency"
    WEIGHTED = "weighted"
    CONSISTENT_HASH = "consistent_hash"
    GPU_AWARE = "gpu_aware"


@dataclass
class LoadBalancerConfig:
    """Configuration for AI-aware load balancer."""

    strategy: LoadBalancingStrategy
    health_check_interval_ms: int
    unhealthy_threshold: int
    considerations: list[str]  # free-text design notes for this profile


# Preset load-balancer profiles, keyed by workload type.
AI_LOAD_BALANCING = dict(
    inference_general=LoadBalancerConfig(
        strategy=LoadBalancingStrategy.LEAST_LATENCY,
        health_check_interval_ms=1000,
        unhealthy_threshold=3,
        considerations=[
            "Route to server with lowest current latency",
            "Account for GPU memory utilization",
            "Consider request complexity (input length)",
        ],
    ),
    inference_stateful=LoadBalancerConfig(
        strategy=LoadBalancingStrategy.CONSISTENT_HASH,
        health_check_interval_ms=1000,
        unhealthy_threshold=3,
        considerations=[
            "Route same user/session to same server",
            "Enables KV cache reuse",
            "Good for conversational AI",
        ],
    ),
    batch_processing=LoadBalancerConfig(
        strategy=LoadBalancingStrategy.GPU_AWARE,
        health_check_interval_ms=5000,
        unhealthy_threshold=2,
        considerations=[
            "Balance GPU utilization across cluster",
            "Consider memory requirements per request",
            "Group similar-sized requests",
        ],
    ),
)
class GPUAwareLoadBalancer:
    """Load balancer that considers GPU state.

    ``server_stats`` maps a server name to a dict with the keys read by
    ``select_server``: ``healthy``, ``gpu_memory_free``, ``current_requests``
    and ``avg_latency_ms``.
    """

    # Latency assumed for a server that reports none (or a non-positive one).
    _DEFAULT_LATENCY_MS = 1000

    def __init__(self, servers: list[str]):
        self.servers = servers
        self.server_stats = {}

    async def select_server(self, request: dict) -> str:
        """Select best server for request.

        Raises:
            NoHealthyServerError: if no healthy server has enough free GPU
                memory for the request.
        """
        # Estimate request resource needs
        estimated_memory = self._estimate_memory(request)
        # A zero-sized request would otherwise divide by zero below.
        memory_unit = max(estimated_memory, 1)
        # Get current server stats
        await self._refresh_stats()
        # Score each server
        scores = {}
        for server in self.servers:
            stats = self.server_stats.get(server, {})
            if not stats.get('healthy', False):
                continue
            available_memory = stats.get('gpu_memory_free', 0)
            # Skip if insufficient memory
            if available_memory < estimated_memory:
                continue
            current_load = stats.get('current_requests', 0)
            avg_latency = stats.get('avg_latency_ms', self._DEFAULT_LATENCY_MS)
            if avg_latency <= 0:
                # Treat a non-positive latency like a missing measurement
                # instead of dividing by zero.
                avg_latency = self._DEFAULT_LATENCY_MS
            # Score based on available capacity and performance
            scores[server] = (
                available_memory / memory_unit * 0.4 +
                (1 / (current_load + 1)) * 0.3 +
                (1000 / avg_latency) * 0.3
            )
        if not scores:
            raise NoHealthyServerError("No servers available")
        return max(scores, key=scores.get)

    def _estimate_memory(self, request: dict) -> int:
        """Estimate GPU memory needed for request."""
        input_tokens = request.get('input_tokens', 0)
        max_output = request.get('max_tokens', 100)
        # Rough estimate: ~2KB per token for KV cache
        return (input_tokens + max_output) * 2048

    def _estimate_compute(self, request: dict) -> float:
        """Estimate compute needed for request."""
        return request.get('input_tokens', 0) + request.get('max_tokens', 100)

    async def _refresh_stats(self):
        """Refresh server statistics."""
        # Would poll servers for current stats
        pass


class NoHealthyServerError(Exception):
    """Raised by ``GPUAwareLoadBalancer.select_server`` when no server qualifies."""


# Section: Speculative Sampler
Complete implementation for speculative decoding:
import random
class SpeculativeSampler:
    """Speculative decoding implementation.

    A draft model proposes up to ``gamma`` tokens per round and a target
    model scores them; draft tokens are accepted until the first rejection,
    where a replacement token is sampled from the target instead.
    """

    def __init__(self, draft_model, target_model, gamma: int = 4):
        self.draft = draft_model
        self.target = target_model
        self.gamma = gamma  # Number of speculative tokens

    async def generate(self, prompt: str, max_tokens: int) -> str:
        """Generate with speculative decoding.

        At most ``max_tokens`` tokens are produced.  Generation stops early
        if the draft model proposes nothing.
        """
        tokens = self._tokenize(prompt)
        generated = []
        while len(generated) < max_tokens:
            budget = max_tokens - len(generated)
            # Draft model generates gamma candidates quickly; never ask for
            # (or keep) more than the remaining budget.
            draft_tokens = await self.draft.generate(
                tokens + generated,
                num_tokens=min(self.gamma, budget)
            )
            draft_tokens = list(draft_tokens)[:budget]
            if not draft_tokens:
                # No proposals means no progress is possible; looping again
                # would never terminate.
                break
            # Target model verifies in parallel
            target_logits = await self.target.forward(
                tokens + generated + draft_tokens
            )
            # Accept matching tokens, reject from first mismatch
            accepted = self._verify_and_accept(
                draft_tokens, target_logits
            )
            generated.extend(accepted)
            if len(accepted) < len(draft_tokens):
                # Sample new token from target at rejection point
                new_token = self._sample_from_target(target_logits, len(accepted))
                generated.append(new_token)
        return self._detokenize(generated)

    def _verify_and_accept(self, draft_tokens, target_logits) -> list:
        """Accept draft tokens that match target distribution.

        Returns the accepted prefix of ``draft_tokens``.
        """
        # NOTE(review): target_logits is indexed by draft position here,
        # although the target was run on prompt + generated + draft.
        # Confirm the target model returns draft-aligned logits.
        accepted = []
        for i, token in enumerate(draft_tokens):
            target_prob = self._get_prob(target_logits[i], token)
            draft_prob = self._get_draft_prob(token)
            if draft_prob <= 0:
                # Avoid dividing by zero: the ratio is unbounded when the
                # target gives the token any mass, otherwise reject.
                keep = target_prob > 0
            else:
                # Accept with probability min(1, target/draft)
                keep = random.random() < min(1.0, target_prob / draft_prob)
            if not keep:
                break
            accepted.append(token)
        return accepted

    def _tokenize(self, text: str) -> list[int]:
        """Tokenize text."""
        pass

    def _detokenize(self, tokens: list[int]) -> str:
        """Detokenize tokens."""
        pass

    def _get_prob(self, logits, token) -> float:
        """Get probability of token from logits."""
        pass

    def _get_draft_prob(self, token) -> float:
        """Get draft model probability for token."""
        pass

    def _sample_from_target(self, logits, position) -> int:
        """Sample token from target distribution."""
        pass


# Section: AI Auto-Scaler
Complete implementation for auto-scaling AI inference workloads:
from dataclasses import dataclass
@dataclass
class AutoScalingPolicy:
    """Policy for auto-scaling inference servers."""

    name: str
    metric: str  # key looked up in the metrics dict passed to evaluate()
    target_value: float
    scale_up_threshold: float
    scale_down_threshold: float
    cooldown_seconds: int
    min_instances: int
    max_instances: int


class AIAutoScaler:
    """Auto-scaler for AI inference workloads.

    Each policy is compared against the current metrics and the per-policy
    verdicts are merged into a single action.
    """

    def __init__(self, policies: list[AutoScalingPolicy]):
        self.policies = policies
        self.last_scale_time = 0

    async def evaluate(self, current_metrics: dict) -> dict:
        """Evaluate scaling decision.

        Policies whose metric is absent from ``current_metrics``, or whose
        value lies between the two thresholds, contribute no decision.
        Returns ``{'action': ..., 'decisions': [...]}``.
        """
        decisions = []
        for policy in self.policies:
            observed = current_metrics.get(policy.metric)
            if observed is None:
                continue
            if observed > policy.scale_up_threshold:
                action, sign, limit = 'scale_up', '>', policy.scale_up_threshold
            elif observed < policy.scale_down_threshold:
                action, sign, limit = 'scale_down', '<', policy.scale_down_threshold
            else:
                continue
            decisions.append({
                'policy': policy.name,
                'action': action,
                'reason': f"{policy.metric}={observed} {sign} {limit}"
            })
        return self._resolve_decisions(decisions)

    def _resolve_decisions(self, decisions: list[dict]) -> dict:
        """Resolve potentially conflicting decisions.

        Any scale-up vote wins; scale-down needs at least one decision and
        every decision to be a scale-down.
        """
        votes = {entry['action'] for entry in decisions}
        if 'scale_up' in votes:
            verdict = 'scale_up'
        elif votes == {'scale_down'}:
            verdict = 'scale_down'
        else:
            verdict = 'none'
        return {'action': verdict, 'decisions': decisions}
# Common AI scaling policies
# One policy per row; all share the same instance bounds (2 to 20).
# Columns: name, metric, target, scale-up threshold, scale-down threshold,
# cooldown in seconds.
AI_SCALING_POLICIES = [
    AutoScalingPolicy(
        name=name,
        metric=metric,
        target_value=target,
        scale_up_threshold=scale_up,
        scale_down_threshold=scale_down,
        cooldown_seconds=cooldown,
        min_instances=2,
        max_instances=20,
    )
    for name, metric, target, scale_up, scale_down, cooldown in (
        ("gpu_utilization", "avg_gpu_utilization", 0.7, 0.8, 0.4, 300),
        ("request_latency", "p99_latency_ms", 500, 750, 200, 180),
        ("queue_depth", "pending_requests", 10, 50, 5, 60),
    )
]


# Section: Multi-Region Deployment
Global Load Balancer
Complete implementation for multi-region routing and architecture patterns:
from dataclasses import dataclass
from enum import Enum
class MultiRegionPattern(Enum):
    """Deployment topologies for running in more than one region."""

    ACTIVE_PASSIVE = "active_passive"
    ACTIVE_ACTIVE = "active_active"
    FOLLOW_THE_SUN = "follow_the_sun"
    DATA_SOVEREIGN = "data_sovereign"


@dataclass
class RegionConfig:
    """Configuration for a region."""

    region: str
    role: str  # primary, secondary, read_replica
    model_versions: list[str]
    capacity: dict
    data_residency: list[str]  # What data can be stored/processed here


@dataclass
class MultiRegionArchitecture:
    """Multi-region deployment architecture."""

    pattern: MultiRegionPattern
    regions: list[RegionConfig]
    routing_policy: str
    failover_policy: str
    data_replication: str
    considerations: list[str]


# Reference architectures, keyed by the pattern's string value.
MULTI_REGION_PATTERNS = dict(
    active_passive=MultiRegionArchitecture(
        pattern=MultiRegionPattern.ACTIVE_PASSIVE,
        regions=[
            RegionConfig(
                region="us-east-1", role="primary", model_versions=["v1.0"],
                capacity={"gpus": 100}, data_residency=["us"],
            ),
            RegionConfig(
                region="us-west-2", role="secondary", model_versions=["v1.0"],
                capacity={"gpus": 50}, data_residency=["us"],
            ),
        ],
        routing_policy="Route all traffic to primary; failover to secondary",
        failover_policy="Automatic failover on health check failure",
        data_replication="Async replication from primary to secondary",
        considerations=[
            "Secondary resources often underutilized",
            "Failover adds latency for affected users",
            "Simpler to reason about consistency",
        ],
    ),
    active_active=MultiRegionArchitecture(
        pattern=MultiRegionPattern.ACTIVE_ACTIVE,
        regions=[
            RegionConfig(
                region="us-east-1", role="primary", model_versions=["v1.0"],
                capacity={"gpus": 75}, data_residency=["us"],
            ),
            RegionConfig(
                region="eu-west-1", role="primary", model_versions=["v1.0"],
                capacity={"gpus": 75}, data_residency=["eu"],
            ),
            RegionConfig(
                region="ap-northeast-1", role="primary", model_versions=["v1.0"],
                capacity={"gpus": 50}, data_residency=["apac"],
            ),
        ],
        routing_policy="Route to nearest healthy region (GeoDNS)",
        failover_policy="Automatic rerouting on regional failure",
        data_replication="Multi-master with conflict resolution",
        considerations=[
            "Better resource utilization",
            "Lower latency for global users",
            "Conflict resolution complexity",
            "Higher operational complexity",
        ],
    ),
)
class GlobalLoadBalancer:
    """Route requests across regions.

    Keeps a name-to-config map of regions and a per-region health flag;
    every region starts out healthy.
    """

    def __init__(self, regions: list[RegionConfig]):
        self.regions = {cfg.region: cfg for cfg in regions}
        self.region_health = dict.fromkeys(self.regions, True)

    async def route(self, request: dict) -> str:
        """Determine best region for request.

        Reads ``geo_location`` and ``data_residency_required`` from the
        request and returns the name of the healthy, residency-compliant
        region with the lowest estimated latency.

        Raises:
            NoEligibleRegionError: if no healthy region satisfies the
                residency requirement.
        """
        origin = request.get('geo_location')
        residency = request.get('data_residency_required')
        # Keep only regions allowed to handle the data and currently healthy
        candidates = []
        for cfg in self.regions.values():
            if not self._meets_residency(cfg, residency):
                continue
            if not self.region_health[cfg.region]:
                continue
            candidates.append(cfg)
        if not candidates:
            raise NoEligibleRegionError("No region meets requirements")
        # Select by latency (closest region)
        closest = min(
            candidates,
            key=lambda cfg: self._estimated_latency(cfg.region, origin)
        )
        return closest.region

    def _meets_residency(self, region: RegionConfig, required: list[str]) -> bool:
        """Return True when no residency is required or the region covers all of it."""
        return not required or all(
            tag in region.data_residency for tag in required
        )

    def _estimated_latency(self, region: str, user_location: str) -> float:
        """Placeholder latency estimate; currently the same for every region."""
        # Would use actual latency measurements or geographic distance
        return 0.0


class NoEligibleRegionError(Exception):
    """Raised by ``GlobalLoadBalancer.route`` when no region qualifies."""


# Section: Cross-Region Replicator
Complete implementation for cross-region data replication:
from dataclasses import dataclass
import asyncio
@dataclass
class ReplicationConfig:
    """Cross-region replication configuration."""

    mode: str  # sync, async, semi-sync
    lag_tolerance_ms: int
    conflict_resolution: str  # last_write_wins, merge, custom
    what_to_replicate: list[str]


class CrossRegionReplicator:
    """Manage cross-region data replication.

    Every write goes to the primary first; how the secondaries are written
    depends on the requested consistency level.
    """

    CONSISTENCY_LEVELS = ('strong', 'eventual', 'quorum')

    def __init__(self, primary: str, secondaries: list[str], config: ReplicationConfig):
        self.primary = primary
        self.secondaries = secondaries
        self.config = config
        # The event loop keeps only weak references to tasks, so background
        # replication tasks are held here until they complete.
        self._background = set()

    async def write(self, key: str, value: object, options: dict = None) -> dict:
        """Write with cross-region replication.

        ``options['consistency']`` selects the mode (default ``'eventual'``):

        * ``'strong'``   - wait for every secondary.
        * ``'eventual'`` - replicate in the background and return at once.
        * ``'quorum'``   - wait until a majority of all regions (primary
          included) has acknowledged; secondaries that fail are tolerated
          as long as a majority is still reached.

        Raises:
            ValueError: for an unknown consistency level, before anything
                is written.
        """
        options = options or {}
        consistency = options.get('consistency', 'eventual')
        if consistency not in self.CONSISTENCY_LEVELS:
            raise ValueError(f"Unknown consistency level: {consistency!r}")
        # Write to primary
        primary_result = await self._write_primary(key, value)
        if consistency == 'strong':
            # Wait for all secondaries (slow but consistent)
            secondary_results = await asyncio.gather(*[
                self._write_secondary(region, key, value)
                for region in self.secondaries
            ])
            return {
                'primary': primary_result,
                'secondaries': secondary_results,
                'consistency': 'strong'
            }
        if consistency == 'eventual':
            # Async replication (fast but eventual)
            for region in self.secondaries:
                self._run_in_background(
                    self._write_secondary(region, key, value)
                )
            return {
                'primary': primary_result,
                'consistency': 'eventual'
            }
        # Wait for majority (balanced)
        quorum_size = (len(self.secondaries) + 1) // 2 + 1
        results = [primary_result]
        last_error = None
        in_flight = [
            self._run_in_background(self._write_secondary(r, key, value))
            for r in self.secondaries
        ]
        for finished in asyncio.as_completed(in_flight):
            try:
                results.append(await finished)
            except Exception as exc:
                # A single failed region must not abort a write that can
                # still reach a majority through the others.
                last_error = exc
                continue
            if len(results) >= quorum_size:
                break
        if len(results) < quorum_size:
            if last_error is not None:
                raise last_error
            raise RuntimeError("Quorum not reached")
        return {
            'acked_regions': len(results),
            'consistency': 'quorum'
        }

    def _run_in_background(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.ensure_future(coro)
        self._background.add(task)
        task.add_done_callback(self._background.discard)
        return task

    async def handle_conflict(self, key: str, versions: list[dict]) -> object:
        """Resolve conflict between versions.

        Raises:
            ConflictError: if there are no versions, or the configured
                strategy is neither ``last_write_wins`` nor ``merge``.
        """
        if not versions:
            raise ConflictError(f"Cannot resolve conflict for {key}")
        if self.config.conflict_resolution == 'last_write_wins':
            return max(versions, key=lambda v: v['timestamp'])['value']
        elif self.config.conflict_resolution == 'merge':
            return self._merge_values(versions)
        else:
            raise ConflictError(f"Cannot resolve conflict for {key}")

    async def _write_primary(self, key: str, value: object) -> dict:
        """Write to primary region."""
        pass

    async def _write_secondary(self, region: str, key: str, value: object) -> dict:
        """Write to secondary region."""
        pass

    def _merge_values(self, versions: list[dict]) -> object:
        """Merge conflicting values."""
        pass


class ConflictError(Exception):
    """Raised when conflicting versions cannot be reconciled."""


# Section: Traffic Management
Token Bucket Rate Limiter
Complete implementation for rate limiting with token bucket algorithm:
from dataclasses import dataclass
from enum import Enum
import time
import asyncio
class RateLimitStrategy(Enum):
    """Rate-limiting algorithms a configuration can name."""

    FIXED_WINDOW = "fixed_window"
    SLIDING_WINDOW = "sliding_window"
    TOKEN_BUCKET = "token_bucket"
    LEAKY_BUCKET = "leaky_bucket"


@dataclass
class RateLimitConfig:
    """Rate limiting configuration."""

    requests_per_second: float
    burst_size: int
    strategy: RateLimitStrategy
    scope: str  # global, per_user, per_ip, per_api_key


class TokenBucketRateLimiter:
    """Token bucket rate limiter.

    Each scope key owns a bucket of up to ``burst_size`` tokens that refills
    at ``requests_per_second``; an allowed request consumes one token.
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self.buckets = {}  # scope_key -> (tokens, last_update)

    async def allow(self, scope_key: str) -> tuple[bool, dict]:
        """Check if request should be allowed.

        Returns ``(True, {'remaining', 'reset_in'})`` when a token was
        consumed, otherwise ``(False, {'remaining': 0, 'retry_after'})``.
        The time values are in seconds and are ``inf`` when the configured
        rate is not positive (the bucket never refills).
        """
        # Monotonic clock: refill must not be affected by wall-clock
        # adjustments.
        now = time.monotonic()
        rate = self.config.requests_per_second
        capacity = self.config.burst_size
        tokens, last_update = self.buckets.setdefault(scope_key, (capacity, now))
        # Add tokens based on time passed
        elapsed = max(0.0, now - last_update)
        tokens = min(capacity, tokens + elapsed * rate)
        if tokens >= 1:
            # Allow request
            self.buckets[scope_key] = (tokens - 1, now)
            return True, {
                'remaining': int(tokens - 1),
                'reset_in': self._refill_seconds(capacity - tokens + 1, rate)
            }
        # Deny request
        return False, {
            'remaining': 0,
            'retry_after': self._refill_seconds(1 - tokens, rate)
        }

    @staticmethod
    def _refill_seconds(missing_tokens: float, rate: float) -> float:
        """Seconds needed to refill ``missing_tokens`` at ``rate`` tokens per second."""
        if rate <= 0:
            return float('inf')
        return missing_tokens / rate
class AdaptiveRateLimiter:
    """Adjust rate limits based on system health.

    Scales the base configuration by a multiplier derived from the load
    reported by the health checker.
    """

    def __init__(self, base_config: RateLimitConfig, health_checker):
        self.base_config = base_config
        self.health = health_checker
        self.current_multiplier = 1.0

    async def get_limit(self, scope_key: str) -> RateLimitConfig:
        """Get current rate limit based on system health.

        Load above 0.9 halves the limits, above 0.8 reduces them to 75%, and
        below 0.5 restores them in full.  Between 0.5 and 0.8 the previous
        multiplier is kept.
        """
        status = await self.health.get_status()
        load = status['load']
        if load > 0.9:
            self.current_multiplier = 0.5  # Cut limits in half
        elif load > 0.8:
            self.current_multiplier = 0.75
        elif load < 0.5:
            self.current_multiplier = 1.0  # Full limits
        factor = self.current_multiplier
        base = self.base_config
        scaled_rate = base.requests_per_second * factor
        scaled_burst = int(base.burst_size * factor)
        return RateLimitConfig(
            requests_per_second=scaled_rate,
            burst_size=scaled_burst,
            strategy=base.strategy,
            scope=base.scope
        )


# Section: AI Traffic Shaper
Complete implementation for AI workload traffic shaping:
import asyncio
import time
class AITrafficShaper:
    """Shape traffic for AI workloads.

    Requests are placed on one of several priority queues as
    ``(priority, enqueue_timestamp, request, future)`` tuples.
    """

    def __init__(self, config: dict):
        self.priority_queues = {
            'realtime': asyncio.PriorityQueue(),  # Interactive users
            'batch': asyncio.PriorityQueue(),  # Background processing
            'internal': asyncio.PriorityQueue(),  # System operations
        }
        self.config = config
        # Last timestamp handed out by enqueue(); used to keep timestamps
        # strictly increasing.
        self._last_enqueue_time = 0.0

    async def enqueue(self, request: dict) -> asyncio.Future:
        """Enqueue request with appropriate priority."""
        future = asyncio.Future()
        priority = self._calculate_priority(request)
        queue_name = self._select_queue(request)
        # Two entries with equal priority and equal timestamp would make the
        # queue compare the request dicts, which raises TypeError.  Nudging a
        # repeated timestamp forward keeps the tuple shape unchanged and the
        # ordering FIFO within a priority.
        stamp = time.time()
        if stamp <= self._last_enqueue_time:
            stamp = self._last_enqueue_time + 1e-6
        self._last_enqueue_time = stamp
        await self.priority_queues[queue_name].put(
            (priority, stamp, request, future)
        )
        return future

    def _calculate_priority(self, request: dict) -> int:
        """Calculate request priority (lower = higher priority)."""
        base_priority = 50
        # Adjust based on user tier
        tier_adjustment = {
            'enterprise': -20,
            'pro': -10,
            'free': 0
        }.get(request.get('user_tier', 'free'), 0)
        # Adjust based on request type
        type_adjustment = {
            'completion': 0,
            'embedding': -5,  # Fast operations get slight priority
            'fine_tuning': 20  # Background work is lower priority
        }.get(request.get('request_type', 'completion'), 0)
        # Adjust based on wait time (prevent starvation)
        wait_penalty = 0
        if 'enqueue_time' in request:
            wait_seconds = time.time() - request['enqueue_time']
            wait_penalty = -int(wait_seconds / 10)  # Boost priority over time
        return base_priority + tier_adjustment + type_adjustment + wait_penalty

    def _select_queue(self, request: dict) -> str:
        """Select queue based on request characteristics.

        Only requests flagged ``batch`` (and not ``stream``) go to the batch
        queue; everything else is treated as realtime.
        """
        if request.get('stream', False):
            return 'realtime'
        elif request.get('batch', False):
            return 'batch'
        else:
            return 'realtime'


# Section: Deployment Patterns
Canary Deployment
Complete implementation for canary deployments:
from dataclasses import dataclass
import asyncio
@dataclass
class CanaryConfig:
    """Configuration for canary deployment."""

    initial_percentage: float  # Start with this % of traffic
    increment: float  # Increase by this % each step
    step_duration_minutes: int  # How long before each increment
    success_criteria: dict  # Metrics that must be met
    rollback_criteria: dict  # Metrics that trigger rollback


class CanaryDeployment:
    """Manage canary deployments for AI models.

    ``state`` moves from ``not_started`` to ``in_progress`` and ends as
    either ``completed`` or ``rolled_back``.
    """

    def __init__(self, config: CanaryConfig, metrics_client, router):
        self.config = config
        self.metrics = metrics_client
        self.router = router
        self.current_percentage = 0
        self.state = "not_started"
        self.new_version = None
        # Strong reference to the monitor task; the event loop alone only
        # holds a weak one.
        self._monitor_task = None

    async def start(self, new_version: str):
        """Start canary deployment."""
        self.new_version = new_version
        self.current_percentage = self.config.initial_percentage
        self.state = "in_progress"
        await self.router.set_traffic_split({
            'current': 100 - self.current_percentage,
            'canary': self.current_percentage
        })
        # Monitor and advance
        self._monitor_task = asyncio.create_task(self._monitor_loop())

    async def _monitor_loop(self):
        """Monitor metrics and advance canary.

        After each step interval: roll back if the rollback criteria trip,
        otherwise raise the canary share by ``increment`` when the success
        criteria hold, completing at 100%.
        """
        while self.state == "in_progress":
            await asyncio.sleep(self.config.step_duration_minutes * 60)
            metrics = await self._collect_metrics()
            if self._should_rollback(metrics):
                await self._rollback("Metrics degraded")
                return
            if self._meets_success_criteria(metrics):
                self.current_percentage += self.config.increment
                if self.current_percentage >= 100:
                    await self._complete()
                    return
                await self.router.set_traffic_split({
                    'current': 100 - self.current_percentage,
                    'canary': self.current_percentage
                })

    async def _collect_metrics(self) -> dict:
        """Collect comparison metrics."""
        return {
            'canary_latency_p99': await self.metrics.query(
                "latency_p99", filters={"version": self.new_version}
            ),
            'canary_error_rate': await self.metrics.query(
                "error_rate", filters={"version": self.new_version}
            ),
            'current_latency_p99': await self.metrics.query(
                "latency_p99", filters={"version": "current"}
            ),
            'current_error_rate': await self.metrics.query(
                "error_rate", filters={"version": "current"}
            ),
        }

    @staticmethod
    def _latency_ratio(metrics: dict) -> float:
        """Return canary p99 latency relative to the current version's.

        A non-positive baseline cannot be divided by: the ratio is then
        ``inf`` if the canary shows any latency and 1.0 otherwise.
        """
        canary = metrics['canary_latency_p99']
        baseline = metrics['current_latency_p99']
        if baseline > 0:
            return canary / baseline
        return float('inf') if canary > 0 else 1.0

    def _should_rollback(self, metrics: dict) -> bool:
        """Check if canary should be rolled back."""
        criteria = self.config.rollback_criteria
        # Error rate too high
        if metrics['canary_error_rate'] > criteria.get('max_error_rate', 0.01):
            return True
        # Latency regression
        if self._latency_ratio(metrics) > criteria.get('max_latency_increase', 1.5):
            return True
        return False

    def _meets_success_criteria(self, metrics: dict) -> bool:
        """Check if canary meets success criteria."""
        criteria = self.config.success_criteria
        # Error rate acceptable
        if metrics['canary_error_rate'] > criteria.get('max_error_rate', 0.005):
            return False
        # Latency acceptable
        if self._latency_ratio(metrics) > criteria.get('max_latency_increase', 1.1):
            return False
        return True

    async def _rollback(self, reason: str):
        """Rollback canary deployment."""
        self.state = "rolled_back"
        await self.router.set_traffic_split({'current': 100, 'canary': 0})
        # Alert on-call

    async def _complete(self):
        """Complete canary deployment."""
        self.state = "completed"
        await self.router.set_traffic_split({'current': 0, 'canary': 100})
        await self.router.promote_canary_to_current()


# Section: Blue-Green Deployment
Complete implementation for blue-green deployments:
class BlueGreenDeployment:
    """Manage blue-green deployments.

    Two environments exist side by side; ``active`` names the one receiving
    traffic and starts as ``"blue"``.
    """

    def __init__(self, blue_environment, green_environment, router):
        self.blue = blue_environment
        self.green = green_environment
        self.router = router
        self.active = "blue"

    async def deploy(self, new_version: str):
        """Deploy new version to inactive environment.

        Returns the name of the environment that was deployed to.

        Raises:
            DeploymentError: if the smoke or integration tests fail.
        """
        if self.active == "blue":
            inactive, inactive_env = "green", self.green
        else:
            inactive, inactive_env = "blue", self.blue
        # Deploy to inactive environment
        await inactive_env.deploy(new_version)
        # Run smoke tests
        smoke = await self._run_smoke_tests(inactive_env)
        if not smoke['passed']:
            raise DeploymentError(f"Smoke tests failed: {smoke['failures']}")
        # Run integration tests
        integration = await self._run_integration_tests(inactive_env)
        if not integration['passed']:
            raise DeploymentError("Integration tests failed")
        return inactive

    async def switch(self):
        """Switch traffic to the other environment."""
        target = "blue" if self.active != "blue" else "green"
        await self.router.switch_to(target)
        self.active = target
        # Keep old environment warm for quick rollback
        return {
            'active': self.active,
            'rollback_available': True
        }

    async def rollback(self):
        """Rollback to previous environment.

        With only two environments this routes traffic to whichever one is
        not currently active.
        """
        target = "blue" if self.active != "blue" else "green"
        await self.router.switch_to(target)
        self.active = target

    async def _run_smoke_tests(self, env) -> dict:
        """Run smoke tests on environment."""
        pass

    async def _run_integration_tests(self, env) -> dict:
        """Run integration tests on environment."""
        pass


class DeploymentError(Exception):
    """Raised when a deployment fails its smoke or integration tests."""


# Section: Chaos Engineering
Chaos Monkey
Complete implementation for chaos engineering experiments:
from dataclasses import dataclass
from enum import Enum
import random
import asyncio
import time
class ChaosType(Enum):
    """Kinds of fault a chaos experiment can inject."""

    LATENCY = "latency"
    ERROR = "error"
    RESOURCE_EXHAUSTION = "resource_exhaustion"
    DATA_CORRUPTION = "data_corruption"
    NETWORK_PARTITION = "network_partition"


@dataclass
class ChaosExperiment:
    """Definition of a chaos experiment."""

    name: str
    chaos_type: ChaosType
    target: str
    parameters: dict
    hypothesis: str
    expected_behavior: str
    abort_conditions: list[str]


class ChaosMonkey:
    """Inject chaos into AI systems.

    Running experiments are tracked in ``active_experiments`` by id, each
    with its experiment definition, injector, start time and monitor task.
    """

    # Share of calls affected when an experiment does not set 'percentage'.
    # The safety check uses the same default as the injectors, so an
    # experiment cannot pass the check at 0 and then run at 0.1.
    DEFAULT_PERCENTAGE = 0.1

    def __init__(self, config: dict):
        self.config = config
        self.active_experiments = {}

    async def start_experiment(self, experiment: ChaosExperiment) -> str:
        """Start a chaos experiment and return its id.

        Raises:
            UnsafeExperimentError: if the experiment fails the safety check.
            ValueError: if no injector exists for the chaos type.
            NotImplementedError: if the injector factory produced nothing.
        """
        experiment_id = self._generate_id()
        # Validate experiment is safe
        if not self._is_safe(experiment):
            raise UnsafeExperimentError("Experiment exceeds safety limits")
        # Start injection
        factories = {
            ChaosType.LATENCY: self._create_latency_injector,
            ChaosType.ERROR: self._create_error_injector,
            ChaosType.RESOURCE_EXHAUSTION: self._create_resource_injector,
        }
        factory = factories.get(experiment.chaos_type)
        if factory is None:
            raise ValueError(f"Unsupported chaos type: {experiment.chaos_type}")
        injector = factory(experiment)
        if injector is None:
            raise NotImplementedError(
                f"No injector available for {experiment.chaos_type}"
            )
        # Register only once the injector is running, so a failed start
        # leaves no stale entry behind.
        await injector.start()
        record = {
            'experiment': experiment,
            'injector': injector,
            'start_time': time.time()
        }
        self.active_experiments[experiment_id] = record
        # Monitor for abort conditions; the task is stored because the event
        # loop only keeps a weak reference to it.
        record['monitor_task'] = asyncio.create_task(
            self._monitor_experiment(experiment_id)
        )
        return experiment_id

    def _create_latency_injector(self, experiment: ChaosExperiment):
        """Create latency injection."""
        return LatencyInjector(
            target=experiment.target,
            latency_ms=experiment.parameters.get('latency_ms', 1000),
            percentage=experiment.parameters.get('percentage', self.DEFAULT_PERCENTAGE)
        )

    def _create_error_injector(self, experiment: ChaosExperiment):
        """Create error injection."""
        return ErrorInjector(
            target=experiment.target,
            error_type=experiment.parameters.get('error_type', 'timeout'),
            percentage=experiment.parameters.get('percentage', self.DEFAULT_PERCENTAGE)
        )

    def _create_resource_injector(self, experiment: ChaosExperiment):
        """Create resource exhaustion injector."""
        pass

    def _is_safe(self, experiment: ChaosExperiment) -> bool:
        """Check if experiment is safe to run."""
        # Don't run in production without explicit approval
        if self.config.get('environment') == 'production':
            if not experiment.parameters.get('production_approved'):
                return False
        # Limit blast radius
        max_percentage = self.config.get('max_chaos_percentage', 0.1)
        percentage = experiment.parameters.get('percentage', self.DEFAULT_PERCENTAGE)
        if percentage > max_percentage:
            return False
        return True

    async def _monitor_experiment(self, experiment_id: str):
        """Monitor experiment and abort if needed.

        Polls the target's metrics every 10 seconds while the experiment is
        active and stops it on the first abort condition that is met.
        """
        exp_data = self.active_experiments[experiment_id]
        experiment = exp_data['experiment']
        while experiment_id in self.active_experiments:
            await asyncio.sleep(10)
            metrics = await self._get_metrics(experiment.target)
            for condition in experiment.abort_conditions:
                if self._should_abort(condition, metrics):
                    await self.stop_experiment(experiment_id, reason=condition)
                    return

    async def stop_experiment(self, experiment_id: str, reason: str = None):
        """Stop a chaos experiment.

        Returns a summary dict, or ``None`` if the id is not active.
        """
        if experiment_id not in self.active_experiments:
            return
        exp_data = self.active_experiments.pop(experiment_id)
        await exp_data['injector'].stop()
        # The monitor has nothing left to watch; cancel it unless it is the
        # caller (an abort triggered from inside the monitor itself).
        monitor = exp_data.get('monitor_task')
        if monitor is not None and monitor is not asyncio.current_task():
            monitor.cancel()
        # Log results
        return {
            'experiment': exp_data['experiment'].name,
            'duration': time.time() - exp_data['start_time'],
            'stop_reason': reason or 'manual'
        }

    def _generate_id(self) -> str:
        """Generate unique experiment ID."""
        import uuid
        return str(uuid.uuid4())

    async def _get_metrics(self, target: str) -> dict:
        """Get metrics for target."""
        pass

    def _should_abort(self, condition: str, metrics: dict) -> bool:
        """Check if abort condition is met."""
        pass


class LatencyInjector:
    """Inject latency into service calls."""

    def __init__(self, target: str, latency_ms: int, percentage: float):
        self.target = target
        self.latency_ms = latency_ms
        self.percentage = percentage
        self.active = False

    async def start(self):
        """Enable injection."""
        self.active = True

    async def stop(self):
        """Disable injection."""
        self.active = False

    async def maybe_inject(self):
        """Potentially inject latency."""
        if self.active and random.random() < self.percentage:
            await asyncio.sleep(self.latency_ms / 1000)


class ErrorInjector:
    """Inject errors into service calls."""

    # NOTE(review): unlike LatencyInjector there is no injection method in
    # this listing; only the active flag is toggled.
    def __init__(self, target: str, error_type: str, percentage: float):
        self.target = target
        self.error_type = error_type
        self.percentage = percentage
        self.active = False

    async def start(self):
        """Enable injection."""
        self.active = True

    async def stop(self):
        """Disable injection."""
        self.active = False


class UnsafeExperimentError(Exception):
    """Raised when an experiment fails the safety check."""


# Section: Data Architecture
Feature Store
Complete implementation for feature store operations:
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
@dataclass
class FeatureDefinition:
    """Definition of a feature."""

    name: str
    dtype: str
    description: str
    entity: str  # What entity this feature belongs to (user, item, etc.)
    freshness_requirement: str  # real-time, hourly, daily
    computation: str  # SQL, Python, streaming


@dataclass
class FeatureView:
    """A collection of related features."""

    name: str
    entities: list[str]
    features: list[FeatureDefinition]
    online: bool  # Available for real-time serving
    offline: bool  # Available for batch/training
    ttl: Optional[int] = None  # Time-to-live in seconds


class FeatureStore:
    """Interface for feature store operations.

    Reads go to the online store for serving and to the offline store for
    training data; view definitions come from the registry.
    """

    def __init__(self, online_store, offline_store, registry):
        self.online = online_store
        self.offline = offline_store
        self.registry = registry

    async def get_online_features(
        self,
        feature_view: str,
        entity_keys: dict[str, Any]
    ) -> dict[str, Any]:
        """Get features for real-time inference.

        Raises:
            ValueError: if the view is not available online.
        """
        view = self.registry.get_view(feature_view)
        if not view.online:
            raise ValueError(f"Feature view {feature_view} not available online")
        features = await self.online.get(
            view=feature_view,
            keys=entity_keys
        )
        # First definition wins when a name is declared more than once.
        definitions = {}
        for definition in view.features:
            definitions.setdefault(definition.name, definition)
        # Check freshness
        for feature_name, value in features.items():
            feature_def = definitions.get(feature_name)
            if feature_def is None:
                # The store returned a feature the view does not declare.
                # A bare next() here would raise StopIteration, which a
                # coroutine turns into RuntimeError.
                continue
            if not self._check_freshness(value, feature_def):
                # Log stale feature, potentially fall back
                pass
        return features

    def get_training_features(
        self,
        feature_view: str,
        entity_df,  # DataFrame with entity keys and timestamps
        features: list[str]
    ):
        """Get point-in-time correct features for training.

        Raises:
            ValueError: if the view is not available offline.
        """
        view = self.registry.get_view(feature_view)
        if not view.offline:
            raise ValueError(f"Feature view {feature_view} not available offline")
        # Point-in-time join to prevent data leakage
        return self.offline.get_historical_features(
            view=feature_view,
            entity_df=entity_df,
            features=features
        )

    def _check_freshness(self, value: dict, feature_def: FeatureDefinition) -> bool:
        """Check if feature value meets freshness requirement.

        Values without a ``timestamp`` entry (including plain scalars) are
        treated as fresh.  Naive timestamps are taken to be UTC.
        """
        from datetime import timezone

        if not isinstance(value, dict) or 'timestamp' not in value:
            return True
        stamp = value['timestamp']
        now = datetime.now(timezone.utc)
        if stamp.utcoffset() is None:
            # Naive timestamp: compare against naive UTC, since naive and
            # aware datetimes cannot be subtracted.
            now = now.replace(tzinfo=None)
        age = now - stamp
        max_age = {
            'real-time': 60,  # 1 minute
            'hourly': 3600,
            'daily': 86400
        }.get(feature_def.freshness_requirement, 86400)
        return age.total_seconds() <= max_age


# Section: Skew Detector
Complete implementation for detecting training-serving skew:
from dataclasses import dataclass
import math
@dataclass
class SkewSource:
    """One documented cause of training-serving skew.

    Attributes:
        category: Short label for the kind of skew.
        description: What goes wrong.
        detection: How to detect it.
        prevention: How to prevent it.
    """

    category: str
    description: str
    detection: str
    prevention: str
# Catalogue of common training-serving skew causes. Each entry pairs a cause
# with a way to detect it and a way to prevent it; the list is descriptive
# reference data only.
SKEW_SOURCES = [
    SkewSource(
        category="Feature computation",
        description="Features computed differently in training vs serving",
        detection="Compare feature distributions between batch and online",
        prevention="Use feature store with shared computation logic"
    ),
    SkewSource(
        category="Data freshness",
        description="Training uses stale data that doesn't reflect production",
        detection="Monitor feature staleness at serving time",
        prevention="Regular retraining; real-time feature updates"
    ),
    SkewSource(
        category="Preprocessing",
        description="Different tokenization, normalization, etc.",
        detection="Hash preprocessing outputs; compare training vs serving",
        prevention="Package preprocessing with model; use same code path"
    ),
    SkewSource(
        category="Data leakage",
        description="Training features include future information",
        detection="Point-in-time validation; hold-out by time",
        prevention="Strict point-in-time feature joins"
    ),
]
class SkewDetector:
    """Detect training-serving skew for individual features.

    ``training_stats`` maps feature name to a dict that may hold
    ``'histogram'``, ``'mean'``, ``'std'`` and ``'missing_rate'``.
    ``serving_monitor.get_stats(name)`` is awaited and must return a dict
    with the same keys (or ``None`` when it has nothing).
    """

    def __init__(self, training_stats: dict, serving_monitor):
        self.training_stats = training_stats
        self.serving = serving_monitor

    async def check_skew(self, feature_name: str) -> dict:
        """Compare training and serving statistics for one feature.

        Returns:
            A dict with ``psi``, ``psi_warning`` (psi > 0.1),
            ``psi_critical`` (psi > 0.25), ``mean_shift_stds`` and
            ``missing_rate_change``.
        """
        training = self.training_stats.get(feature_name, {})
        # A monitor with no data for the feature may return None.
        serving = (await self.serving.get_stats(feature_name)) or {}

        # Distribution shift (PSI - Population Stability Index)
        psi = self._calculate_psi(
            training.get('histogram', []),
            serving.get('histogram', [])
        )

        # Mean shift, expressed in training standard deviations.
        mean_shift = abs(training.get('mean', 0) - serving.get('mean', 0))
        std = training.get('std', 1)
        if std > 0:
            mean_shift_stds = mean_shift / std
        elif mean_shift > 0:
            # A feature that was constant in training has moved: report it
            # as unbounded instead of hiding it behind a 0.
            mean_shift_stds = float('inf')
        else:
            mean_shift_stds = 0

        missing_rate_change = abs(
            training.get('missing_rate', 0) - serving.get('missing_rate', 0)
        )

        return {
            'psi': psi,
            'psi_warning': psi > 0.1,
            'psi_critical': psi > 0.25,
            'mean_shift_stds': mean_shift_stds,
            'missing_rate_change': missing_rate_change,
        }

    def _calculate_psi(self, expected: list, actual: list) -> float:
        """Calculate the Population Stability Index of two histograms.

        Both histograms are normalised to proportions first, so raw counts
        with different totals can be compared. Histograms with a different
        number of bins yield ``inf``; two empty histograms yield 0.
        """
        if len(expected) != len(actual):
            return float('inf')

        expected_share = self._to_proportions(expected)
        actual_share = self._to_proportions(actual)

        psi = 0
        for e, a in zip(expected_share, actual_share):
            e = max(e, 0.0001)  # Avoid division by zero
            a = max(a, 0.0001)
            psi += (a - e) * math.log(a / e)
        return psi

    @staticmethod
    def _to_proportions(bins: list) -> list:
        """Scale ``bins`` to sum to 1; leave them as-is when the sum is not positive."""
        total = sum(bins)
        if total <= 0:
            return list(bins)
        return [b / total for b in bins]


# Data Version Manager
Complete implementation for dataset versioning:
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
@dataclass
class DataVersion:
    """Catalog record for one version of a dataset.

    Attributes:
        dataset_id: Dataset the version belongs to.
        version: Version identifier.
        created_at: Creation timestamp.
        schema_hash: Hash of the data schema.
        row_count: Number of rows in the data.
        checksum: Checksum of the data.
        parent_version: Version this one derives from, or ``None``.
        transformations: Transformations recorded for this version.
    """

    dataset_id: str
    version: str
    created_at: datetime
    schema_hash: str
    row_count: int
    checksum: str
    parent_version: Optional[str]
    transformations: list[str]
class DataVersionManager:
    """Manage dataset versions at scale.

    ``catalog`` must provide awaitable ``register_version(record)`` and
    ``get_version(dataset_id, version)``; ``storage`` must provide an
    awaitable ``write(path, data)``.
    """

    def __init__(self, catalog, storage):
        self.catalog = catalog
        self.storage = storage

    async def create_version(
        self,
        dataset_id: str,
        data,
        parent_version: Optional[str] = None,
        transformations: Optional[list[str]] = None
    ) -> DataVersion:
        """Create, store and register a new version of a dataset.

        The data is written to ``"{dataset_id}/{version}"`` before the
        record is registered in the catalog.

        Args:
            dataset_id: Dataset to version.
            data: Payload to store; must support ``len()``.
            parent_version: Version this one derives from, if any.
            transformations: Transformations applied to obtain ``data``;
                ``None`` is recorded as an empty list.

        Returns:
            The registered ``DataVersion`` record.
        """
        from datetime import timezone

        # Compute metadata
        schema_hash = self._hash_schema(data)
        row_count = len(data)
        checksum = self._compute_checksum(data)

        # Generate version ID
        version = self._generate_version_id()

        # Store data
        storage_path = f"{dataset_id}/{version}"
        await self.storage.write(storage_path, data)

        # Create version record. created_at stays a naive UTC datetime;
        # this form avoids the deprecated datetime.utcnow().
        version_record = DataVersion(
            dataset_id=dataset_id,
            version=version,
            created_at=datetime.now(timezone.utc).replace(tzinfo=None),
            schema_hash=schema_hash,
            row_count=row_count,
            checksum=checksum,
            parent_version=parent_version,
            transformations=transformations or []
        )
        await self.catalog.register_version(version_record)
        return version_record

    async def get_lineage(self, dataset_id: str, version: str) -> list[DataVersion]:
        """Get the lineage of a dataset version, newest first.

        Follows ``parent_version`` links until a record has no parent, the
        catalog returns nothing, or a version id repeats. The last condition
        stops a cyclic parent chain from looping forever.
        """
        lineage = []
        visited = {version}
        current = await self.catalog.get_version(dataset_id, version)
        while current:
            lineage.append(current)
            parent = current.parent_version
            if not parent or parent in visited:
                break
            visited.add(parent)
            current = await self.catalog.get_version(dataset_id, parent)
        return lineage

    async def compare_versions(
        self,
        dataset_id: str,
        version_a: str,
        version_b: str
    ) -> dict:
        """Compare two versions of a dataset.

        Returns:
            A dict with ``schema_changed``, ``row_count_diff`` (b minus a),
            ``same_data`` and ``transformations_between``.
        """
        v_a = await self.catalog.get_version(dataset_id, version_a)
        v_b = await self.catalog.get_version(dataset_id, version_b)
        return {
            'schema_changed': v_a.schema_hash != v_b.schema_hash,
            'row_count_diff': v_b.row_count - v_a.row_count,
            'same_data': v_a.checksum == v_b.checksum,
            'transformations_between': self._get_transformations_between(v_a, v_b)
        }

    def _hash_schema(self, data) -> str:
        """Hash the schema of data (placeholder: returns ``None``)."""
        pass

    def _compute_checksum(self, data) -> str:
        """Compute checksum of data (placeholder: returns ``None``)."""
        pass

    def _generate_version_id(self) -> str:
        """Generate a version ID: the first 8 characters of a UUID4 string."""
        import uuid
        return str(uuid.uuid4())[:8]

    def _get_transformations_between(self, v_a: DataVersion, v_b: DataVersion) -> list[str]:
        """Get transformations between two versions (placeholder: returns ``None``)."""
        pass


# Failure Handling
Circuit Breaker
Complete implementation for circuit breaker pattern:
import time
from enum import Enum
from dataclasses import dataclass
from typing import Callable, Any
class CircuitState(Enum):
    """States a circuit breaker can be in."""

    # Normal operation
    CLOSED = "closed"
    # Failing fast
    OPEN = "open"
    # Testing recovery
    HALF_OPEN = "half_open"
@dataclass
class CircuitBreakerConfig:
    """Tunable thresholds for ``CircuitBreaker``.

    Attributes:
        failure_threshold: Failure count at which the circuit opens.
        success_threshold: Successes needed in half-open state to close.
        timeout: Seconds before trying again after the circuit opened.
        half_open_max_calls: Calls admitted while half-open.
    """

    failure_threshold: int = 5
    success_threshold: int = 3
    timeout: float = 30.0
    half_open_max_calls: int = 3
class CircuitBreaker:
    """Circuit breaker for external dependencies.

    CLOSED passes calls through and counts consecutive failures. OPEN
    rejects calls until ``config.timeout`` seconds have passed since the
    circuit opened. HALF_OPEN admits up to ``config.half_open_max_calls``
    probe calls: ``config.success_threshold`` successes close the circuit,
    and any failed probe opens it again.
    """

    def __init__(self, name: str, config: CircuitBreakerConfig = None):
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = 0
        self.half_open_calls = 0

    async def call(
        self,
        func: Callable,
        *args,
        fallback: Callable = None,
        **kwargs
    ) -> Any:
        """Execute ``func`` with circuit breaker protection.

        Args:
            func: Async callable to protect; awaited with ``*args, **kwargs``.
            fallback: Optional async callable awaited with the same
                arguments when the call is rejected or ``func`` raises.

        Returns:
            The result of ``func``, or of ``fallback`` when it was used.

        Raises:
            CircuitOpenError: If the call is rejected and no fallback is given.
            Exception: Whatever ``func`` raised, if no fallback is given.
        """
        rejection = self._admit()
        if rejection is not None:
            # A rejected call never reached the dependency, so it is not
            # recorded as a failure.
            if fallback:
                return await fallback(*args, **kwargs)
            raise CircuitOpenError(rejection)

        try:
            result = await func(*args, **kwargs)
        except Exception:
            self._record_failure()
            if fallback:
                return await fallback(*args, **kwargs)
            raise
        self._record_success()
        return result

    def _admit(self):
        """Return ``None`` if a call may proceed, else the rejection message."""
        if self.state == CircuitState.OPEN:
            if not self._should_try_reset():
                return f"Circuit {self.name} is open"
            self.state = CircuitState.HALF_OPEN
            self.half_open_calls = 0

        if self.state == CircuitState.HALF_OPEN:
            self.half_open_calls += 1
            if self.half_open_calls > self.config.half_open_max_calls:
                self._trip()
                return f"Circuit {self.name} half-open limit reached"
        return None

    def _trip(self):
        """Open the circuit and restart the cool-down window."""
        self.state = CircuitState.OPEN
        self.success_count = 0
        self.last_failure_time = time.time()

    def _record_success(self):
        self.failure_count = 0
        if self.state == CircuitState.HALF_OPEN:
            self.success_count += 1
            if self.success_count >= self.config.success_threshold:
                self.state = CircuitState.CLOSED
                self.success_count = 0

    def _record_failure(self):
        self.failure_count += 1
        self.last_failure_time = time.time()
        self.success_count = 0
        # A failed probe reopens the circuit at once; otherwise the circuit
        # opens when the failure threshold is reached.
        if (
            self.state == CircuitState.HALF_OPEN
            or self.failure_count >= self.config.failure_threshold
        ):
            self.state = CircuitState.OPEN

    def _should_try_reset(self) -> bool:
        return time.time() - self.last_failure_time >= self.config.timeout
class CircuitOpenError(Exception):
    """Raised by ``CircuitBreaker.call`` when a call is rejected.

    This happens while the circuit is open, or when the half-open call
    limit is reached, and no fallback was supplied.
    """


# Graceful Degradation
Complete implementation for graceful degradation:
from dataclasses import dataclass
@dataclass
class DegradationLevel:
    """One rung of the service degradation ladder.

    Attributes:
        name: Level identifier.
        description: Short summary of the level.
        trigger: Condition that puts the service at this level.
        actions: Measures applied at this level.
        user_impact: What users experience at this level.
    """

    name: str
    description: str
    trigger: str
    actions: list[str]
    user_impact: str
# Degradation ladder, listed from full service down to unavailable. The
# `name` values are the same level strings that
# GracefulDegradation.check_and_adjust returns.
DEGRADATION_LEVELS = [
    DegradationLevel(
        name="NORMAL",
        description="Full service",
        trigger="All systems healthy",
        actions=["Full feature set"],
        user_impact="None"
    ),
    DegradationLevel(
        name="DEGRADED_L1",
        description="Non-critical features disabled",
        trigger="Secondary services unhealthy",
        actions=[
            "Disable personalization",
            "Use cached recommendations",
            "Reduce logging verbosity"
        ],
        user_impact="Slightly less relevant results"
    ),
    DegradationLevel(
        name="DEGRADED_L2",
        description="Reduced functionality",
        trigger="Primary services stressed",
        actions=[
            "Use smaller/faster model",
            "Reduce max output length",
            "Aggressive caching",
            "Rate limit non-essential endpoints"
        ],
        user_impact="Slower responses, reduced quality"
    ),
    DegradationLevel(
        name="DEGRADED_L3",
        description="Minimal functionality",
        trigger="Critical capacity issues",
        actions=[
            "Static fallback responses",
            "Queue requests",
            "Reject low-priority traffic"
        ],
        user_impact="Limited functionality, potential errors"
    ),
    DegradationLevel(
        name="MAINTENANCE",
        description="Service unavailable",
        trigger="Complete outage or maintenance",
        actions=[
            "Return maintenance page",
            "Queue requests for later"
        ],
        user_impact="Service unavailable"
    ),
]
class GracefulDegradation:
    """Track the current degradation level from health-check results.

    ``health_checker.get_status()`` is awaited and must return a mapping
    with the boolean keys read in ``_level_for``.
    """

    def __init__(self, health_checker):
        self.health = health_checker
        self.current_level = "NORMAL"

    async def check_and_adjust(self) -> str:
        """Re-evaluate health, transition if the level changed, return it."""
        status = await self.health.get_status()
        target = self._level_for(status)

        if target != self.current_level:
            await self._transition(self.current_level, target)
            self.current_level = target
        return self.current_level

    @staticmethod
    def _level_for(status) -> str:
        """Map a health status mapping to a level name.

        NOTE(review): when critical services are unhealthy, a stressed
        primary wins over ``capacity_critical`` if both are set. Confirm
        that ordering is intended.
        """
        if status['critical_services_healthy']:
            return "NORMAL" if status['all_services_healthy'] else "DEGRADED_L1"
        if status['primary_services_stressed']:
            return "DEGRADED_L2"
        if status['capacity_critical']:
            return "DEGRADED_L3"
        return "MAINTENANCE"

    async def _transition(self, from_level: str, to_level: str):
        """Handle a level transition (placeholder: does nothing yet)."""
        # Log transition
        # Alert if degrading
        # Apply actions for new level
        pass


# Operational Excellence
AI Service Mesh Config
Complete implementation for configuring service mesh for AI workloads:
from dataclasses import dataclass
@dataclass
class ServiceMeshFeature:
    """A service-mesh capability that is relevant for AI workloads.

    Attributes:
        feature: Name of the mesh feature.
        benefit_for_ai: Why it helps AI workloads.
        implementation: How to apply it.
    """

    feature: str
    benefit_for_ai: str
    implementation: str
# Reference table of service-mesh capabilities, with the benefit each brings
# to AI serving and a note on how to apply it. Descriptive data only.
SERVICE_MESH_AI_FEATURES = [
    ServiceMeshFeature(
        feature="Automatic retries",
        benefit_for_ai="Handle transient GPU failures gracefully",
        implementation="Configure retry budget to avoid thundering herd"
    ),
    ServiceMeshFeature(
        feature="Circuit breaking",
        benefit_for_ai="Prevent cascade failures when model serving degrades",
        implementation="Set thresholds based on inference latency patterns"
    ),
    ServiceMeshFeature(
        feature="Load balancing",
        benefit_for_ai="GPU-aware routing based on memory/compute availability",
        implementation="Custom load balancing policies for AI workloads"
    ),
    ServiceMeshFeature(
        feature="mTLS",
        benefit_for_ai="Secure model serving and feature transmission",
        implementation="Automatic encryption without code changes"
    ),
    ServiceMeshFeature(
        feature="Observability",
        benefit_for_ai="Distributed tracing across inference pipeline",
        implementation="Automatic span creation for all service calls"
    ),
    ServiceMeshFeature(
        feature="Traffic splitting",
        benefit_for_ai="Canary deployments for model updates",
        implementation="Percentage-based routing to model versions"
    ),
]
class AIServiceMeshConfig:
    """Apply service-mesh settings tuned for AI inference services.

    ``mesh_client`` must provide awaitable ``set_retry_policy``,
    ``set_circuit_breaker``, ``set_timeout`` and ``set_health_check``, each
    taking ``(service_name, settings_dict)``.
    """

    def __init__(self, mesh_client):
        self.mesh = mesh_client

    async def configure_inference_service(self, service_name: str):
        """Push retry, circuit-breaker, timeout and health-check settings.

        The four mesh calls are awaited one after another, in that order.
        """
        # Retry policy - be careful with non-idempotent operations
        retry_policy = {
            'attempts': 3,
            'per_try_timeout': '10s',
            'retry_on': ['5xx', 'reset', 'connect-failure'],
            # Don't retry if response started streaming
            'retriable_headers': ['x-retry-safe'],
        }

        # Circuit breaker
        circuit_breaker = {
            'max_connections': 1000,
            'max_pending_requests': 100,
            'max_requests': 1000,
            'max_retries': 3,
            'consecutive_5xx': 5,  # Open after 5 consecutive errors
            'interval': '10s',
            'base_ejection_time': '30s',
        }

        # Timeout - longer for AI workloads
        timeouts = {
            'timeout': '60s',  # LLM generation can be slow
            'idle_timeout': '300s',  # Keep connections for batching
        }

        # Health check
        health_check = {
            'path': '/health',
            'interval': '5s',
            'timeout': '3s',
            'unhealthy_threshold': 3,
            'healthy_threshold': 2,
        }

        await self.mesh.set_retry_policy(service_name, retry_policy)
        await self.mesh.set_circuit_breaker(service_name, circuit_breaker)
        await self.mesh.set_timeout(service_name, timeouts)
        await self.mesh.set_health_check(service_name, health_check)


# AI Instrumentor
Complete implementation for comprehensive AI inference instrumentation:
from dataclasses import dataclass
from typing import Optional
import time
@dataclass
class InferenceSpan:
    """Detailed trace record for one inference request.

    The fields fall into four groups: identity, a timing breakdown in
    milliseconds, resource usage, and quality signals.
    """

    # Identity
    trace_id: str
    request_id: str
    model_name: str
    model_version: str

    # Timing breakdown
    queue_time_ms: float
    preprocessing_time_ms: float
    inference_time_ms: float
    postprocessing_time_ms: float
    total_time_ms: float

    # Resource usage
    gpu_memory_used_mb: int
    tokens_processed: int
    tokens_generated: int

    # Quality signals
    cache_hit: bool
    batch_size: int
    early_stopped: bool
class AIInstrumentor:
    """Instrument AI inference for observability.

    ``tracer.start_span(name)`` must return a span with ``set_attributes``
    and ``end``; ``metrics`` must provide ``histogram``, ``gauge`` and
    ``counter``, each taking ``(name, value, labels)``.
    """

    def __init__(self, tracer, metrics, logger):
        self.tracer = tracer
        self.metrics = metrics
        self.logger = logger

    async def instrument_inference(self, request: dict, inference_func):
        """Run one inference with timing, metrics and a trace span.

        Stages run in order: wait for a slot, preprocess, call
        ``inference_func`` on the preprocessed input, postprocess. Metrics
        are recorded and the span is ended whether or not a stage raises.

        Returns:
            The postprocessed result.

        Raises:
            Exception: Whatever a stage raised, after metrics are recorded.
        """
        span = self.tracer.start_span("inference")
        timing = {}
        result = None
        error = None

        try:
            # Track queue time
            timing['queue_start'] = time.time()
            await self._wait_for_slot()
            timing['queue_end'] = time.time()

            # Track preprocessing
            timing['preprocess_start'] = time.time()
            processed_input = await self._preprocess(request)
            timing['preprocess_end'] = time.time()

            # Track inference
            timing['inference_start'] = time.time()
            result = await inference_func(processed_input)
            timing['inference_end'] = time.time()

            # Track postprocessing
            timing['postprocess_start'] = time.time()
            final_result = await self._postprocess(result)
            timing['postprocess_end'] = time.time()

            return final_result
        except Exception as e:
            error = e
            raise
        finally:
            try:
                # Record detailed metrics
                self._record_metrics(request, result, timing, error)
            finally:
                # End the span even if metric recording itself failed.
                span.set_attributes({
                    'model': request.get('model'),
                    'tokens_in': request.get('input_tokens', 0),
                    'tokens_out': len(result.get('tokens', [])) if result else 0,
                    'cache_hit': result.get('cache_hit', False) if result else False,
                    'error': str(error) if error else None
                })
                span.end()

    def _record_metrics(self, request, result, timing, error):
        """Record latency, stage timings, throughput and error metrics.

        A stage histogram is only emitted when both its start and end
        timestamps exist. For a request that aborted part-way, total latency
        is measured up to the moment of recording, so failed requests never
        produce negative durations.
        """
        labels = {
            'model': request.get('model', 'unknown'),
            'status': 'error' if error else 'success'
        }

        # Latency histogram
        started = timing.get('queue_start')
        if started is not None:
            finished = timing.get('postprocess_end')
            if finished is None:
                finished = time.time()  # request aborted before the last stage
            self.metrics.histogram(
                'inference_latency_ms', (finished - started) * 1000, labels
            )

        # Breakdown histograms
        for metric_name, start_key, end_key in (
            ('queue_time_ms', 'queue_start', 'queue_end'),
            ('inference_time_ms', 'inference_start', 'inference_end'),
        ):
            elapsed = self._elapsed_seconds(timing, start_key, end_key)
            if elapsed is not None:
                self.metrics.histogram(metric_name, elapsed * 1000, labels)

        # Token throughput
        if result:
            inference_seconds = self._elapsed_seconds(
                timing, 'inference_start', 'inference_end'
            )
            if inference_seconds is not None:
                tokens = len(result.get('tokens', []))
                # Floor the duration at 1 ms to avoid dividing by ~0.
                tokens_per_second = tokens / max(inference_seconds, 0.001)
                self.metrics.gauge('tokens_per_second', tokens_per_second, labels)

        # Error counter
        if error:
            self.metrics.counter('inference_errors', 1, {
                **labels,
                'error_type': type(error).__name__
            })

    @staticmethod
    def _elapsed_seconds(timing, start_key, end_key):
        """Return ``timing[end_key] - timing[start_key]``, or ``None`` if either is missing."""
        start = timing.get(start_key)
        end = timing.get(end_key)
        if start is None or end is None:
            return None
        return end - start

    async def _wait_for_slot(self):
        """Wait for available inference slot (placeholder: returns at once)."""
        pass

    async def _preprocess(self, request: dict) -> dict:
        """Preprocess request (placeholder: returns ``None``)."""
        pass

    async def _postprocess(self, result: dict) -> dict:
        """Postprocess result (placeholder: returns ``None``)."""
        pass


# AI Workload Queue
Complete implementation for queue-based AI workload processing:
from dataclasses import dataclass
from typing import Optional
import time
@dataclass
class QueueConfig:
    """Configuration for an AI workload queue.

    Attributes:
        queue_name: Name of the backend queue.
        priority_levels: Number of priority levels.
        max_queue_depth: Depth at which new requests are refused.
        visibility_timeout_seconds: Visibility timeout used when receiving.
        dead_letter_queue: Name of the dead-letter queue.
        ordering: One of fifo, priority, fair.
    """

    queue_name: str
    priority_levels: int
    max_queue_depth: int
    visibility_timeout_seconds: int
    dead_letter_queue: str
    ordering: str
class AIWorkloadQueue:
    """Queue front-end for AI inference requests.

    ``queue_backend`` must provide awaitable ``get_depth``, ``send``,
    ``receive_batch`` and ``delete``.
    """

    def __init__(self, config: QueueConfig, queue_backend):
        self.config = config
        self.backend = queue_backend
        self.metrics = {}

    async def enqueue(
        self,
        request: dict,
        priority: int = 5,
        deadline: Optional[float] = None
    ) -> str:
        """Submit a request and return its message ID.

        Args:
            request: Payload to queue.
            priority: Priority passed to the backend and used for the wait
                estimate.
            deadline: Optional Unix timestamp the request must meet.

        Raises:
            QueueFullError: If the queue depth has reached the maximum.
            DeadlineExceededError: If the estimated wait passes ``deadline``.
        """
        message_id = self._generate_id()

        # Refuse new work once the queue is at capacity.
        depth = await self.backend.get_depth(self.config.queue_name)
        if depth >= self.config.max_queue_depth:
            raise QueueFullError("Queue at capacity")

        # Refuse work that cannot finish before its deadline.
        expected_wait = self._estimate_wait_time(priority, depth)
        if deadline and time.time() + expected_wait > deadline:
            raise DeadlineExceededError("Cannot meet deadline")

        envelope = {
            'id': message_id,
            'request': request,
            'priority': priority,
            'deadline': deadline,
            'enqueued_at': time.time()
        }
        await self.backend.send(
            queue=self.config.queue_name,
            message=envelope,
            priority=priority
        )
        return message_id

    async def dequeue_batch(self, batch_size: int) -> list[dict]:
        """Receive up to ``batch_size`` messages, dropping expired ones.

        Messages whose deadline has passed are handed to
        ``_handle_expired`` and left out of the returned list.
        """
        received = await self.backend.receive_batch(
            queue=self.config.queue_name,
            max_messages=batch_size,
            visibility_timeout=self.config.visibility_timeout_seconds
        )

        live = []
        for msg in received:
            cutoff = msg.get('deadline')
            if cutoff and time.time() > cutoff:
                await self._handle_expired(msg)
                continue
            live.append(msg)
        return live

    async def complete(self, message_id: str, result: dict):
        """Delete the message from the queue, then store its result."""
        await self.backend.delete(self.config.queue_name, message_id)
        await self._store_result(message_id, result)

    def _estimate_wait_time(self, priority: int, depth: int) -> float:
        """Estimate the wait in seconds from queue depth and priority.

        Uses 100ms per queued request, scaled by
        ``priority / priority_levels``.
        """
        return (depth * 0.1) * (priority / self.config.priority_levels)

    def _generate_id(self) -> str:
        """Generate a unique message ID (UUID4 string)."""
        from uuid import uuid4
        return str(uuid4())

    async def _handle_expired(self, msg: dict):
        """Handle expired message (placeholder: does nothing)."""
        pass

    async def _store_result(self, message_id: str, result: dict):
        """Store result for retrieval (placeholder: does nothing)."""
        pass
class QueueFullError(Exception):
    """Raised by ``AIWorkloadQueue.enqueue`` when the queue is at capacity."""
class DeadlineExceededError(Exception):
    """Raised by ``AIWorkloadQueue.enqueue`` when the deadline cannot be met."""


# Capacity Plan
Complete implementation for capacity planning:
from dataclasses import dataclass
from datetime import datetime
import math
@dataclass
class CapacityPlan:
    """Capacity planning model for one service.

    Utilization is assumed to grow by ``growth_rate`` per month, compounded.
    """

    service: str
    current_capacity: dict  # {resource: amount}
    # {resource: utilization}, on the same scale as the critical threshold
    # (a fraction such as 0.65 for the default threshold of 0.8)
    current_utilization: dict
    growth_rate: float  # Monthly growth rate
    lead_time_weeks: int  # Time to provision new capacity

    def weeks_until_critical(self, critical_utilization: float = 0.8) -> dict:
        """Calculate weeks until each resource hits critical utilization.

        Returns:
            ``{resource: weeks}``. The value is 0 when the resource is
            already at or above the threshold. It is ``inf`` when growth is
            not positive, or when utilization is not positive, since
            compounding growth on zero usage never reaches the threshold.
        """
        results = {}
        for resource, current_util in self.current_utilization.items():
            if current_util >= critical_utilization:
                results[resource] = 0
            elif self.growth_rate <= 0 or current_util <= 0:
                results[resource] = float('inf')
            else:
                # Solve current * (1 + g) ** months == critical for months.
                months = (
                    math.log(critical_utilization / current_util)
                    / math.log(1 + self.growth_rate)
                )
                results[resource] = months * 4.33  # Convert to weeks
        return results

    def scaling_recommendation(self) -> dict:
        """Generate a scaling recommendation for each resource.

        The time left is compared against multiples of the provisioning lead
        time: within 1x is SCALE_NOW, within 2x PLAN_SCALING, within 4x
        MONITOR, otherwise OK.
        """
        # (lead-time multiple, action, urgency), most urgent first.
        tiers = (
            (1, 'SCALE_NOW', 'critical'),
            (2, 'PLAN_SCALING', 'high'),
            (4, 'MONITOR', 'medium'),
        )

        recommendations = {}
        for resource, weeks in self.weeks_until_critical().items():
            action, urgency = 'OK', 'low'
            for multiple, tier_action, tier_urgency in tiers:
                if weeks <= self.lead_time_weeks * multiple:
                    action, urgency = tier_action, tier_urgency
                    break
            recommendations[resource] = {
                'action': action,
                'urgency': urgency,
                'weeks_until_critical': weeks
            }
        return recommendations


# Disaster Recovery Orchestrator
Complete implementation for disaster recovery:
from dataclasses import dataclass
import time
@dataclass
class DRPlan:
    """Disaster recovery plan for one component of an AI system.

    Attributes:
        component: Name of the component the plan covers.
        rto_hours: Recovery Time Objective, in hours.
        rpo_hours: Recovery Point Objective, in hours.
        recovery_steps: Ordered steps to carry out.
        dependencies: Names of components this one depends on.
        test_frequency: How often the plan is exercised.
    """

    component: str
    rto_hours: float
    rpo_hours: float
    recovery_steps: list[str]
    dependencies: list[str]
    test_frequency: str
# Default disaster recovery plans, one per component. rto_hours and
# rpo_hours are in hours (0.25 = 15 minutes). The names in `dependencies`
# refer to other components; not all of them have a plan in this list.
DR_PLANS = [
    DRPlan(
        component="Model serving",
        rto_hours=0.25,  # 15 minutes
        rpo_hours=0,  # No data loss - stateless
        recovery_steps=[
            "Switch traffic to standby region",
            "Verify model health in standby",
            "Update DNS/routing",
            "Monitor for errors"
        ],
        dependencies=["Model registry", "Feature store"],
        test_frequency="Monthly"
    ),
    DRPlan(
        component="Feature store",
        rto_hours=1,
        rpo_hours=0.25,  # 15 minutes of feature data
        recovery_steps=[
            "Failover to replica",
            "Verify feature freshness",
            "Backfill any missing features",
            "Validate feature quality"
        ],
        dependencies=["Data warehouse", "Streaming pipeline"],
        test_frequency="Quarterly"
    ),
    DRPlan(
        component="Training infrastructure",
        rto_hours=24,  # Can wait for training
        rpo_hours=168,  # Weekly checkpoint saves
        recovery_steps=[
            "Restore from checkpoint",
            "Verify GPU cluster in DR region",
            "Resume training from last checkpoint",
            "Validate training metrics"
        ],
        dependencies=["Training data", "Checkpoint storage"],
        test_frequency="Bi-annually"
    ),
    DRPlan(
        component="Model registry",
        rto_hours=0.5,
        rpo_hours=0,  # All versions must be preserved
        recovery_steps=[
            "Restore from backup",
            "Verify model version integrity",
            "Validate model signatures",
            "Update serving to use restored registry"
        ],
        dependencies=["Backup storage"],
        test_frequency="Monthly"
    ),
]
class DisasterRecoveryOrchestrator:
    """Orchestrate disaster recovery across components.

    Plans are indexed by their ``component`` name.
    """

    def __init__(self, dr_plans: list[DRPlan]):
        self.plans = {p.component: p for p in dr_plans}
        self.state = {}

    async def initiate_dr(self, affected_components: list[str]) -> dict:
        """Recover the affected components, dependencies first.

        Among the components whose affected dependencies have already been
        attempted, the one with the lowest RTO goes next. If the remaining
        components block each other (a dependency cycle), the lowest-RTO one
        is recovered anyway so the run always finishes. Each component is
        recovered once, even if it is listed more than once.

        Returns:
            ``{component: recovery result}`` in the order recovery ran.

        Raises:
            KeyError: If a component has no plan.
        """
        # Most critical (lowest RTO) first; duplicates removed.
        pending = sorted(
            dict.fromkeys(affected_components),
            key=lambda c: self.plans[c].rto_hours
        )
        affected = set(pending)
        results = {}

        while pending:
            ready = next(
                (
                    component
                    for component in pending
                    if all(
                        dep in results or dep not in affected
                        for dep in self.plans[component].dependencies
                    )
                ),
                None,
            )
            if ready is None:
                # Every remaining component waits on another remaining one.
                ready = pending[0]
            pending.remove(ready)
            results[ready] = await self._recover_component(ready, self.plans[ready])
        return results

    async def _recover_component(self, component: str, plan: DRPlan) -> dict:
        """Run the plan's steps in order, stopping at the first failure.

        Returns:
            A dict with ``component``, ``duration_seconds``, ``within_rto``
            (elapsed time against ``plan.rto_hours``) and ``steps``, the
            per-step outcomes.
        """
        start_time = time.time()
        steps_completed = []
        for step in plan.recovery_steps:
            try:
                await self._execute_step(component, step)
            except Exception as e:
                steps_completed.append({'step': step, 'status': 'failed', 'error': str(e)})
                break
            else:
                steps_completed.append({'step': step, 'status': 'success'})

        elapsed = time.time() - start_time
        return {
            'component': component,
            'duration_seconds': elapsed,
            'within_rto': elapsed / 3600 <= plan.rto_hours,
            'steps': steps_completed
        }

    async def _execute_step(self, component: str, step: str):
        """Execute a recovery step (placeholder: does nothing)."""
        pass