Multimodal Systems - Complete Code Reference
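This reference contains the code implementations for Chapter 53: Multimodal Systems. Each section provides reference code that you can adapt to your specific use cases; the classes receive their model, embedding, storage, and rendering clients through their constructors, so you supply concrete implementations of those dependencies.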
Table of Contents
- Vision-Language Model Integration
- Image Optimization
- Vision Result Validation
- Video Processing
- Audio Transcription
- Speaker Diarization
- Text-to-Speech Processing
- Document Processing
- Hybrid Document Processing
- Multi-Page Document Processing
- Multimodal Embeddings and Indexing
- Late Fusion Search
- Multimodal Document RAG
- Multimodal Orchestration
- Cost-Aware Processing
Vision-Language Model Integration
Complete implementation for processing images with vision-language models, including encoding, single-image analysis, and multi-image comparison.
import anthropic
import base64
from pathlib import Path
class VisionProcessor:
    """Process images with vision-language models via the Anthropic Messages API.

    Images may be given as file paths (media type taken from the extension)
    or as raw encoded bytes (media type sniffed from the file signature), so
    the same processor can handle files on disk, rendered PDF pages, and
    video frames.
    """

    # Extension -> media type for images read from disk.
    MEDIA_TYPES = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
    }
    DEFAULT_MODEL = "claude-sonnet-4-20250514"

    def __init__(
        self,
        client: "anthropic.Anthropic",
        model: str = DEFAULT_MODEL,
        max_tokens: int = 4096,
    ):
        """Create a processor.

        Args:
            client: Anthropic client used to call ``messages.create``.
            model: Model name sent with every request.
            max_tokens: Upper bound on tokens generated per response.
        """
        self.client = client
        self.model = model
        self.max_tokens = max_tokens

    def encode_image(self, image_path: str) -> dict:
        """Read an image file and return a base64 image content block.

        Raises:
            ValueError: If the file extension is not a supported image type.
        """
        path = Path(image_path)
        suffix = path.suffix.lower()
        if suffix not in self.MEDIA_TYPES:
            raise ValueError(f"Unsupported image format: {suffix}")
        return self.encode_image_bytes(path.read_bytes(), self.MEDIA_TYPES[suffix])

    def encode_image_bytes(self, image_bytes: bytes, media_type: str = None) -> dict:
        """Return a base64 image content block for already-encoded image bytes.

        Args:
            image_bytes: Encoded JPEG, PNG, GIF, or WebP data.
            media_type: Explicit media type; sniffed from the data when omitted.

        Raises:
            ValueError: If ``media_type`` is omitted and the data is not a
                recognised JPEG, PNG, GIF, or WebP image.
        """
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type or self._sniff_media_type(image_bytes),
                "data": base64.standard_b64encode(image_bytes).decode('utf-8'),
            },
        }

    @staticmethod
    def _sniff_media_type(data: bytes) -> str:
        """Identify an image's media type from its leading signature bytes."""
        if data.startswith(b'\xff\xd8\xff'):
            return 'image/jpeg'
        if data.startswith(b'\x89PNG\r\n\x1a\n'):
            return 'image/png'
        if data.startswith((b'GIF87a', b'GIF89a')):
            return 'image/gif'
        if data[:4] == b'RIFF' and data[8:12] == b'WEBP':
            return 'image/webp'
        raise ValueError("Unrecognized image data: expected JPEG, PNG, GIF, or WebP")

    def _image_block(self, image) -> dict:
        """Encode an image given as a path or as raw bytes."""
        if isinstance(image, (bytes, bytearray, memoryview)):
            return self.encode_image_bytes(bytes(image))
        return self.encode_image(image)

    @staticmethod
    def _response_text(response) -> str:
        """Concatenate all text blocks of a Messages API response."""
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

    def analyze_image(
        self,
        image_path: "str | bytes",
        prompt: str,
        system_prompt: str = None,
    ) -> str:
        """Analyze one image with a text prompt and return the model's answer.

        Args:
            image_path: Path to an image file, or the encoded image bytes.
            prompt: Instruction or question about the image.
            system_prompt: Optional system prompt; a generic image-analysis
                prompt is used when omitted.
        """
        response = self.client.messages.create(
            model=self.model,
            max_tokens=self.max_tokens,
            system=system_prompt or "You are a helpful assistant that analyzes images.",
            messages=[
                {
                    "role": "user",
                    "content": [
                        self._image_block(image_path),
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
        )
        return self._response_text(response)

    def compare_images(
        self,
        image_paths: list[str],
        prompt: str
    ) -> str:
        """Send several labelled images plus a prompt in one request.

        Each image is preceded by an "Image N:" label (1-based) so the prompt
        can refer to images by number.

        Raises:
            ValueError: If ``image_paths`` is empty.
        """
        if not image_paths:
            raise ValueError("compare_images requires at least one image")
        content = []
        for number, image in enumerate(image_paths, start=1):
            content.append({"type": "text", "text": f"Image {number}:"})
            content.append(self._image_block(image))
        content.append({"type": "text", "text": prompt})
        response = self.client.messages.create(
            model=self.model,
            max_tokens=self.max_tokens,
            messages=[{"role": "user", "content": content}],
        )
        return self._response_text(response)


# Image Optimization
Optimize images for VLM processing by resizing based on task requirements and estimating token usage.
from PIL import Image
import io
class ImageOptimizer:
    """Downscale and re-encode images before sending them to a VLM."""

    # Maximum (width, height) per resolution tier. The token figures are the
    # original rough guidance; this class does not compute them.
    RESOLUTION_TIERS = {
        'low': (512, 512),  # ~85 tokens
        'medium': (1024, 1024),  # ~340 tokens
        'high': (2048, 2048),  # ~1360 tokens
    }

    # Resolution tier needed per task; unknown tasks fall back to DEFAULT_TIER.
    TASK_TIERS = {
        'ocr': 'high',  # Need detail for text
        'classification': 'low',  # Overall content is enough
        'object_detection': 'medium',
        'chart_reading': 'high',
        'general_qa': 'medium',
    }
    DEFAULT_TIER = 'medium'

    # PIL modes that the JPEG encoder accepts directly; anything else
    # (RGBA, P, LA, PA, I, F, ...) is converted to RGB first.
    _JPEG_MODES = ('1', 'L', 'RGB', 'CMYK')

    def optimize_for_task(
        self,
        image_path: str,
        task: str,
        quality: int = 85,
    ) -> bytes:
        """Return JPEG bytes resized to the resolution tier ``task`` needs.

        Args:
            image_path: Path of the source image.
            task: Key into ``TASK_TIERS``; unknown tasks use ``DEFAULT_TIER``.
            quality: JPEG quality passed to PIL.

        Images are only ever shrunk (aspect ratio preserved), never enlarged.
        """
        tier = self.TASK_TIERS.get(task, self.DEFAULT_TIER)
        max_width, max_height = self.RESOLUTION_TIERS[tier]
        with Image.open(image_path) as img:
            if img.width > max_width or img.height > max_height:
                img.thumbnail((max_width, max_height), Image.Resampling.LANCZOS)
            # JPEG cannot store alpha/palette/high-bit-depth modes.
            if img.mode not in self._JPEG_MODES:
                img = img.convert('RGB')
            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=quality)
        return buffer.getvalue()

    def estimate_tokens(self, image_path: str) -> int:
        """Roughly estimate tokens for an image at its stored resolution.

        Heuristic: ~0.5 tokens per full 32x32 pixel patch. It uses the file's
        own dimensions and does not account for any downscaling.
        """
        with Image.open(image_path) as img:
            width, height = img.size
        patches = (width // 32) * (height // 32)
        return int(patches * 0.5)


# Vision Result Validation
Validate vision model outputs, detect hallucinations, and implement redundancy-based verification.
import re
class VisionResultValidator:
    """Validate vision model outputs with heuristics and cross-run agreement."""

    # Phrases suggesting the model was unsure of what it read; each one found
    # (case-insensitively) costs 0.2 confidence.
    HALLUCINATION_PHRASES = (
        "I cannot read",
        "text is not clear",
        "appears to say",
        "might be",
    )

    def validate_ocr_result(
        self,
        extracted_text: str,
        image_path: str,
        expected_patterns: list[str] = None
    ) -> dict:
        """Score an OCR extraction heuristically.

        Confidence starts at 1.0, loses 0.2 per uncertainty phrase and 0.1 per
        missing ``expected_patterns`` regex, and is capped at 0.3 when fewer
        than 10 non-whitespace-trimmed characters were extracted.

        Returns:
            dict with ``valid`` (confidence > 0.5 and fewer than 3 issues),
            ``confidence`` (floored at 0), ``issues`` and the original ``text``.
        """
        issues = []
        confidence = 1.0
        lowered = extracted_text.lower()

        for phrase in self.HALLUCINATION_PHRASES:
            if phrase.lower() in lowered:
                confidence -= 0.2
                issues.append(f"Uncertainty indicator: '{phrase}'")

        for pattern in expected_patterns or []:
            if not re.search(pattern, extracted_text):
                issues.append(f"Expected pattern not found: {pattern}")
                confidence -= 0.1

        if len(extracted_text.strip()) < 10:
            issues.append("Extraction returned minimal text")
            # Cap confidence; never raise it above the penalties already applied.
            confidence = min(confidence, 0.3)

        return {
            'valid': confidence > 0.5 and len(issues) < 3,
            'confidence': max(0, confidence),
            'issues': issues,
            'text': extracted_text,
        }

    def validate_with_redundancy(
        self,
        image_path: str,
        prompt: str,
        processors: list,
        agreement_threshold: float = 0.8
    ) -> dict:
        """Run the same prompt through several processors and compare answers.

        Each processor must provide ``analyze_image(image_path, prompt)``.
        With fewer than two answers no agreement can be measured, so the
        result is never ``confident``.
        """
        results = [processor.analyze_image(image_path, prompt) for processor in processors]
        if len(results) < 2:
            return {'results': results, 'agreement': None, 'confident': False, 'consensus': None}

        agreement = self._calculate_agreement(results)
        confident = agreement >= agreement_threshold
        return {
            'results': results,
            'agreement': agreement,
            'confident': confident,
            'consensus': results[0] if confident else None,
        }

    @staticmethod
    def _calculate_agreement(results: list[str]) -> float:
        """Return the mean pairwise Jaccard similarity of the answers' word sets.

        A lexical stand-in for semantic similarity: 1.0 means every pair of
        answers used the same set of lower-cased words. Fewer than two answers
        (no pairs) counts as full agreement.
        """
        from itertools import combinations

        token_sets = [set(re.findall(r"\w+", text.lower())) for text in results]
        pairs = list(combinations(token_sets, 2))
        if not pairs:
            return 1.0
        total = 0.0
        for left, right in pairs:
            union = left | right
            total += len(left & right) / len(union) if union else 1.0
        return total / len(pairs)


# Video Processing
Extract frames from video using various sampling strategies and analyze video content with VLMs.
import cv2
from typing import Generator
import numpy as np
class VideoProcessor:
    """Sample frames from video files and analyze them with a VLM."""

    def __init__(self, vlm_client, model: str = "claude-sonnet-4-20250514"):
        """Create a processor.

        Args:
            vlm_client: Object exposing an Anthropic client as ``.client``
                (e.g. a ``VisionProcessor``); used to call ``messages.create``.
            model: Model name used for video analysis requests.
        """
        self.vlm = vlm_client
        self.model = model

    @staticmethod
    def _subsample(indices: list[int], max_count: int) -> list[int]:
        """Pick at most ``max_count`` evenly spaced items, keeping first and last."""
        if len(indices) <= max_count:
            return list(indices)
        picks = np.linspace(0, len(indices) - 1, max_count, dtype=int)
        return [indices[i] for i in picks]

    def extract_frames(
        self,
        video_path: str,
        strategy: str = "uniform",
        max_frames: int = 20,
        fps: float = 1.0
    ) -> Generator[tuple[float, np.ndarray], None, None]:
        """Yield ``(timestamp_seconds, frame)`` pairs sampled from a video.

        Strategies:
            ``"uniform"``: candidates every ``1 / fps`` seconds; if there are
                more than ``max_frames`` they are thinned evenly so the whole
                video is covered (not just its beginning).
            ``"keyframes"``: frames where the mean grayscale difference to the
                previous frame exceeds a threshold (see
                ``_detect_scene_changes``).

        The capture is released even if the caller stops iterating early.

        Raises:
            ValueError: If the video cannot be opened, reports no frame rate or
                frame count, or ``strategy`` is unknown.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {video_path}")
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if video_fps <= 0 or total_frames <= 0:
                raise ValueError(f"Could not read frame rate/count of video: {video_path}")

            if strategy == "uniform":
                interval = max(1, int(video_fps / fps))
                frame_indices = self._subsample(list(range(0, total_frames, interval)), max_frames)
            elif strategy == "keyframes":
                frame_indices = self._detect_scene_changes(video_path, max_frames)
            else:
                raise ValueError(f"Unknown strategy: {strategy}")

            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ok, frame = cap.read()
                if ok:
                    yield idx / video_fps, frame
        finally:
            cap.release()

    def _detect_scene_changes(
        self,
        video_path: str,
        max_scenes: int,
        threshold: float = 30.0,
    ) -> list[int]:
        """Return frame indices where the scene appears to change.

        Frame 0 is always included. A frame counts as a change when the mean
        absolute grayscale difference to the previous frame exceeds
        ``threshold`` (tune per content). At most ``max_scenes`` indices are
        returned, thinned evenly.
        """
        cap = cv2.VideoCapture(video_path)
        scene_frames = [0]
        previous = None
        frame_idx = 0
        try:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                if previous is not None and np.mean(cv2.absdiff(previous, gray)) > threshold:
                    scene_frames.append(frame_idx)
                previous = gray
                frame_idx += 1
        finally:
            cap.release()
        return self._subsample(scene_frames, max_scenes)

    @staticmethod
    def _frame_to_image_block(frame) -> dict:
        """JPEG-encode a frame into a base64 image content block."""
        import base64

        ok, buffer = cv2.imencode('.jpg', frame)
        if not ok:
            raise ValueError("Failed to JPEG-encode video frame")
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": base64.standard_b64encode(buffer.tobytes()).decode('utf-8'),
            },
        }

    def analyze_video(
        self,
        video_path: str,
        prompt: str,
        max_frames: int = 10
    ) -> dict:
        """Send uniformly sampled, timestamped frames plus ``prompt`` to the VLM.

        Raises:
            ValueError: If no frames could be extracted.
        """
        frames = list(self.extract_frames(video_path, strategy="uniform", max_frames=max_frames))
        if not frames:
            raise ValueError(f"No frames could be extracted from {video_path}")

        content = [{"type": "text", "text": f"Video analysis ({len(frames)} frames):"}]
        for timestamp, frame in frames:
            content.append({"type": "text", "text": f"\n[Frame at {timestamp:.1f}s]:"})
            content.append(self._frame_to_image_block(frame))
        content.append({"type": "text", "text": f"\n\n{prompt}"})

        response = self.vlm.client.messages.create(
            model=self.model,
            max_tokens=4096,
            messages=[{"role": "user", "content": content}],
        )
        return {
            'analysis': response.content[0].text,
            'frames_analyzed': len(frames),
            'video_path': video_path,
        }

    def summarize_video(self, video_path: str) -> dict:
        """Generate a search-indexing summary of a video's content."""
        return self.analyze_video(
            video_path,
            prompt=(
                "Summarize this video:\n"
                "1. What is the main content or activity?\n"
                "2. What key events or changes occur?\n"
                "3. Who or what are the main subjects?\n"
                "4. What is the setting or context?\n"
                "Provide a concise summary suitable for video search indexing."
            ),
        )


# Audio Transcription
Transcribe audio with the Whisper API, compressing oversized files and splitting long recordings into chunks.
import openai
from pathlib import Path
import subprocess
import tempfile
class AudioTranscriber:
    """Transcribe audio with the OpenAI Whisper API.

    ``transcribe`` handles one file, compressing it first if it exceeds the
    upload limit. ``transcribe_long_audio`` splits long recordings into chunks
    with ffmpeg and stitches the results together. Segments and words are
    returned as plain dicts.
    """

    SUPPORTED_FORMATS = {'.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm'}
    MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB upload limit
    # Tails shorter than this are not sent as their own chunk (too short to transcribe).
    MIN_CHUNK_SECONDS = 0.1

    def __init__(self, client: "openai.OpenAI"):
        self.client = client

    def transcribe(
        self,
        audio_path: str,
        language: str = None,
        prompt: str = None,
        response_format: str = "verbose_json"
    ) -> dict:
        """Transcribe one audio file.

        Args:
            audio_path: Path to an audio file in one of ``SUPPORTED_FORMATS``.
            language: Optional language hint passed to the API.
            prompt: Optional text (e.g. vocabulary hints) to guide recognition.
            response_format: Whisper response format. Only ``"verbose_json"``
                yields language, duration, segments and word timestamps.

        Returns:
            For ``"verbose_json"``: ``text``, ``language``, ``duration``,
            ``segments`` and ``words`` (lists of dicts). Otherwise
            ``{'text': <raw API response>}``.

        Raises:
            ValueError: If the file extension is not supported.
        """
        path = Path(audio_path)
        if path.suffix.lower() not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Unsupported format: {path.suffix}")

        # Send optional parameters only when provided; word/segment timestamps
        # are only accepted together with verbose_json.
        options = {"model": "whisper-1", "response_format": response_format}
        if language:
            options["language"] = language
        if prompt:
            options["prompt"] = prompt  # Vocabulary hints
        if response_format == "verbose_json":
            options["timestamp_granularities"] = ["word", "segment"]

        compressed = None
        if path.stat().st_size > self.MAX_FILE_SIZE:
            compressed = self._compress_audio(audio_path)
        try:
            with open(compressed or audio_path, 'rb') as audio_file:
                response = self.client.audio.transcriptions.create(file=audio_file, **options)
        finally:
            # The compressed copy is a temp file owned by this call.
            if compressed is not None:
                Path(compressed).unlink(missing_ok=True)

        if response_format != "verbose_json":
            return {'text': response}
        return {
            'text': response.text,
            'language': response.language,
            'duration': response.duration,
            'segments': [self._to_dict(s) for s in getattr(response, 'segments', None) or []],
            'words': [self._to_dict(w) for w in getattr(response, 'words', None) or []],
        }

    @staticmethod
    def _to_dict(item) -> dict:
        """Return a plain dict copy of a mapping or an SDK model object."""
        if isinstance(item, dict):
            return dict(item)
        if hasattr(item, 'model_dump'):
            return item.model_dump()
        return dict(vars(item))

    @staticmethod
    def _make_temp_path(suffix: str) -> str:
        """Securely create an empty temp file and return its path."""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
            return handle.name

    def _run_ffmpeg(self, args: list[str], output_path: str) -> str:
        """Run ffmpeg writing ``output_path``; delete the file if ffmpeg fails."""
        try:
            # -y: overwrite the placeholder created by _make_temp_path.
            subprocess.run(['ffmpeg', '-y', *args, output_path], check=True, capture_output=True)
        except BaseException:
            Path(output_path).unlink(missing_ok=True)
            raise
        return output_path

    def _compress_audio(self, audio_path: str) -> str:
        """Re-encode audio as low-bitrate mono MP3; caller must delete the result."""
        return self._run_ffmpeg(
            [
                '-i', audio_path,
                '-ac', '1',  # Mono
                '-ar', '16000',  # 16kHz sample rate
                '-b:a', '32k',  # Low bitrate
            ],
            self._make_temp_path('.mp3'),
        )

    def transcribe_long_audio(
        self,
        audio_path: str,
        chunk_duration: int = 600  # 10 minutes
    ) -> dict:
        """Transcribe long audio chunk by chunk and merge the results.

        Segment and word timestamps are shifted from chunk-relative to
        absolute time. Temporary chunk files are removed even on failure.

        Raises:
            ValueError: If ``chunk_duration`` is not positive.
        """
        if chunk_duration <= 0:
            raise ValueError("chunk_duration must be positive")
        duration = self._get_duration(audio_path)

        chunks = []
        start = 0
        while duration - start >= self.MIN_CHUNK_SECONDS:
            chunk_path = self._extract_chunk(audio_path, start, min(chunk_duration, duration - start))
            try:
                result = self.transcribe(chunk_path)
            finally:
                Path(chunk_path).unlink(missing_ok=True)
            for item in result.get('segments', []) + result.get('words', []):
                item['start'] += start
                item['end'] += start
            chunks.append(result)
            start += chunk_duration

        return self._merge_transcriptions(chunks)

    def _get_duration(self, audio_path: str) -> float:
        """Return the audio duration in seconds as reported by ffprobe.

        Raises:
            subprocess.CalledProcessError: If ffprobe fails.
            ValueError: If ffprobe output is not a number.
        """
        result = subprocess.run([
            'ffprobe', '-v', 'error',
            '-show_entries', 'format=duration',
            '-of', 'default=noprint_wrappers=1:nokey=1',
            audio_path
        ], capture_output=True, text=True, check=True)
        try:
            return float(result.stdout.strip())
        except ValueError as err:
            raise ValueError(f"Could not determine duration of {audio_path!r}") from err

    def _extract_chunk(
        self,
        audio_path: str,
        start: float,
        duration: float
    ) -> str:
        """Cut ``duration`` seconds starting at ``start`` into a temp mono MP3.

        ``-ss`` is given before ``-i`` (input seeking) so long files are not
        decoded from the beginning for every chunk.
        """
        return self._run_ffmpeg(
            [
                '-ss', str(start),
                '-i', audio_path,
                '-t', str(duration),
                '-ac', '1', '-ar', '16000',
            ],
            self._make_temp_path('.mp3'),
        )

    def _merge_transcriptions(self, chunks: list[dict]) -> dict:
        """Concatenate chunk transcripts, segments and words in order."""
        return {
            'text': ' '.join(chunk['text'].strip() for chunk in chunks),
            'language': chunks[0].get('language') if chunks else None,
            'segments': [seg for chunk in chunks for seg in chunk.get('segments', [])],
            'words': [word for chunk in chunks for word in chunk.get('words', [])],
            'duration': sum(chunk.get('duration') or 0 for chunk in chunks),
        }


# Speaker Diarization
Identify speakers in audio and align transcriptions with speaker labels.
class SpeakerDiarizer:
    """Label transcribed words with the speaker who said them."""

    def __init__(self, diarization_model):
        """Create a diarizer.

        Args:
            diarization_model: Callable taking an audio path. Its result is
                iterated as segments read via ``segment['start']``,
                ``segment['end']`` and ``segment['speaker']`` (assumed to be
                mappings; confirm against the wrapper in use).
        """
        self.diarizer = diarization_model

    @staticmethod
    def _field(item, key):
        """Read ``key`` from a dict, or as an attribute of an SDK object."""
        if isinstance(item, dict):
            return item[key]
        return getattr(item, key)

    def diarize_and_transcribe(
        self,
        audio_path: str,
        transcriber: "AudioTranscriber"
    ) -> list[dict]:
        """Transcribe ``audio_path`` and group its words into speaker turns.

        Each word is assigned the speaker whose segment contains the word's
        midpoint ('unknown' if none does).

        Returns:
            Turns as dicts with ``speaker``, ``start``, ``end``, ``words`` and
            ``text``.
        """
        # Materialize once: _find_speaker scans the segments for every word,
        # which would exhaust a one-shot iterator after the first word.
        diarization = list(self.diarizer(audio_path))
        transcription = transcriber.transcribe(audio_path, response_format="verbose_json")

        aligned = []
        for word in transcription.get('words') or []:
            start = self._field(word, 'start')
            end = self._field(word, 'end')
            aligned.append({
                'word': self._field(word, 'word'),
                'start': start,
                'end': end,
                'speaker': self._find_speaker(diarization, (start + end) / 2),
            })
        return self._group_by_speaker(aligned)

    def _find_speaker(self, diarization, timestamp: float) -> str:
        """Return the first speaker whose segment contains ``timestamp``."""
        for segment in diarization:
            if segment['start'] <= timestamp <= segment['end']:
                return segment['speaker']
        return 'unknown'

    def _group_by_speaker(self, aligned_words: list[dict]) -> list[dict]:
        """Merge consecutive words by the same speaker into turns.

        Every turn has ``start``/``end`` (first/last word times), even when it
        contains a single word.
        """
        turns = []
        for word in aligned_words:
            if turns and turns[-1]['speaker'] == word['speaker']:
                turns[-1]['words'].append(word['word'])
                turns[-1]['end'] = word['end']
            else:
                turns.append({
                    'speaker': word['speaker'],
                    'start': word['start'],
                    'end': word['end'],
                    'words': [word['word']],
                })
        for turn in turns:
            turn['text'] = ' '.join(turn['words'])
        return turns


# Text-to-Speech Processing
Generate speech from text, splitting long inputs into chunks at sentence boundaries; an illustrative SSML hook is included, although not all TTS APIs accept SSML.
class TextToSpeechProcessor:
    """Generate speech audio from text."""

    def __init__(self, tts_client):
        # assumes an OpenAI-style client exposing audio.speech.create; confirm
        self.client = tts_client

    def synthesize(
        self,
        text: str,
        voice: str = "alloy",
        model: str = "tts-1-hd",
        speed: float = 1.0
    ) -> bytes:
        """Synthesize ``text`` in one request and return the encoded audio bytes."""
        response = self.client.audio.speech.create(
            model=model,
            voice=voice,
            input=text,
            speed=speed
        )
        return response.content

    def synthesize_long_text(
        self,
        text: str,
        voice: str = "alloy",
        max_chunk_chars: int = 4000,
        model: str = "tts-1-hd",
        speed: float = 1.0,
    ) -> bytes:
        """Synthesize arbitrarily long text as a single MP3.

        The text is split into chunks of at most ``max_chunk_chars`` characters
        (see ``_chunk_text``), each chunk is synthesized with the same voice,
        model and speed, and the audio is concatenated.

        Raises:
            ValueError: If ``text`` contains no non-whitespace characters.
        """
        chunks = self._chunk_text(text, max_chunk_chars)
        if not chunks:
            raise ValueError("text must contain at least one non-whitespace character")
        segments = [self.synthesize(chunk, voice, model, speed) for chunk in chunks]
        return self._concatenate_audio(segments)

    def _chunk_text(self, text: str, max_chars: int) -> list[str]:
        """Split text into chunks of at most ``max_chars`` characters.

        Sentences (split after ., ! or ?) are packed greedily and joined by a
        single space. A sentence longer than ``max_chars`` is split at
        whitespace, and a single word longer than that is hard-split.

        Raises:
            ValueError: If ``max_chars`` is not positive.
        """
        import re

        if max_chars <= 0:
            raise ValueError("max_chars must be positive")
        chunks = []
        current = ""
        for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
            if not sentence:
                continue
            for piece in self._split_oversized(sentence, max_chars):
                if current and len(current) + 1 + len(piece) > max_chars:
                    chunks.append(current)
                    current = piece
                else:
                    current = f"{current} {piece}" if current else piece
        if current:
            chunks.append(current)
        return chunks

    @staticmethod
    def _split_oversized(sentence: str, max_chars: int) -> list[str]:
        """Break one sentence into pieces no longer than ``max_chars``."""
        if len(sentence) <= max_chars:
            return [sentence]
        pieces = []
        current = ""
        for word in sentence.split():
            while len(word) > max_chars:
                if current:
                    pieces.append(current)
                    current = ""
                pieces.append(word[:max_chars])
                word = word[max_chars:]
            if current and len(current) + 1 + len(word) > max_chars:
                pieces.append(current)
                current = word
            else:
                current = f"{current} {word}" if current else word
        if current:
            pieces.append(current)
        return pieces

    def _concatenate_audio(self, segments: list[bytes]) -> bytes:
        """Decode MP3 segments, join them in order and re-encode as one MP3."""
        from pydub import AudioSegment
        import io

        combined = AudioSegment.empty()
        for segment in segments:
            combined += AudioSegment.from_mp3(io.BytesIO(segment))
        output = io.BytesIO()
        combined.export(output, format="mp3")
        return output.getvalue()

    def synthesize_with_ssml(
        self,
        ssml_text: str,
        voice: str = "alloy"
    ) -> bytes:
        """Synthesize speech with SSML markup for prosody control."""
        # Note: Not all APIs support SSML; this is illustrative.
        # SSML allows control over pauses, emphasis, pronunciation.
        # The markup is passed through unchanged as the input text.
        return self.synthesize(ssml_text, voice)


# Document Processing
Process PDF documents using vision-language models for full content extraction, table extraction, and form field extraction.
import json
class DocumentProcessor:
    """Process documents page by page with a vision-language model."""

    _FULL_PROMPT = (
        "Extract all content from this document page.\n"
        "Output as structured markdown:\n"
        "- Use headers (##, ###) for section titles\n"
        "- Use tables for tabular data\n"
        "- Use lists for enumerated items\n"
        "- Preserve the logical reading order\n"
        "- Note any figures or images with [Figure: description]\n"
        "Be thorough but maintain document structure."
    )

    _TABLES_PROMPT = (
        "Identify and extract all tables from this document page.\n"
        "For each table:\n"
        "1. Provide the table title or context\n"
        "2. Extract as a markdown table with proper headers\n"
        "3. Note any merged cells or special formatting\n"
        'If no tables are present, respond with "No tables found."\n'
        "Output format:\n"
        "### Table 1: [title/context]\n"
        "| Header 1 | Header 2 | ... |\n"
        "|----------|----------|-----|\n"
        "| data | data | ... |\n"
    )

    _KEY_VALUE_PROMPT = (
        "This appears to be a form or structured document.\n"
        "Extract all field labels and their values as key-value pairs.\n"
        "Output as JSON:\n"
        "{\n"
        '"field_name": "value",\n'
        '"another_field": "another_value"\n'
        "}\n"
        "For checkboxes, use true/false.\n"
        "For empty fields, use null.\n"
        "For multiple values, use arrays."
    )

    def __init__(self, vlm_client, pdf_renderer):
        # vlm_client must provide analyze_image(image, prompt) -> str;
        # pdf_renderer must provide render_pages(pdf_path, dpi=...).
        self.vlm = vlm_client
        self.renderer = pdf_renderer

    def process_pdf(
        self,
        pdf_path: str,
        extraction_type: str = "full",
        dpi: int = 150,
    ) -> dict:
        """Render a PDF and run one extraction over every page.

        Args:
            pdf_path: Path of the PDF to process.
            extraction_type: ``"full"``, ``"tables"`` or ``"key_value"``.
            dpi: Rendering resolution passed to the renderer.

        Returns:
            ``{'pages': [...], 'total_pages': n, 'extraction_type': ...}`` with
            1-based page numbers in each page result.

        Raises:
            ValueError: If ``extraction_type`` is unknown (checked before
                any rendering work is done).
        """
        extractors = {
            "full": self._extract_full_content,
            "tables": self._extract_tables,
            "key_value": self._extract_key_values,
        }
        try:
            extract = extractors[extraction_type]
        except KeyError:
            raise ValueError(f"Unknown extraction type: {extraction_type}") from None

        pages = self.renderer.render_pages(pdf_path, dpi=dpi)
        results = [extract(page_image, page_num) for page_num, page_image in enumerate(pages, start=1)]
        return {
            'pages': results,
            'total_pages': len(pages),
            'extraction_type': extraction_type
        }

    def _extract_full_content(self, page_image: bytes, page_num: int) -> dict:
        """Extract all content from a page as structured markdown."""
        content = self.vlm.analyze_image(page_image, self._FULL_PROMPT)
        return {'page': page_num, 'content': content, 'type': 'full'}

    def _extract_tables(self, page_image: bytes, page_num: int) -> dict:
        """Extract the tables on a page as markdown tables."""
        content = self.vlm.analyze_image(page_image, self._TABLES_PROMPT)
        return {'page': page_num, 'content': content, 'type': 'tables'}

    def _extract_key_values(self, page_image: bytes, page_num: int) -> dict:
        """Extract form fields on a page as a parsed JSON object.

        Raises:
            json.JSONDecodeError: If no JSON object can be parsed from the reply.
        """
        # The prompt itself requests JSON; no provider-specific response_format
        # argument is passed, since analyze_image does not accept one.
        content = self.vlm.analyze_image(page_image, self._KEY_VALUE_PROMPT)
        return {
            'page': page_num,
            'content': self._parse_json_object(content),
            'type': 'key_value'
        }

    @staticmethod
    def _parse_json_object(text: str):
        """Parse JSON from a model reply, tolerating code fences and extra prose.

        Tries, in order: the whole (fence-stripped) reply, then the span from
        the first '{' to the last '}'. Re-raises the original decode error if
        neither parses.
        """
        cleaned = text.strip()
        fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", cleaned, re.DOTALL)
        if fenced:
            cleaned = fenced.group(1)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            start, end = cleaned.find('{'), cleaned.rfind('}')
            if start != -1 and end > start:
                return json.loads(cleaned[start:end + 1])
            raise


# Hybrid Document Processing
Combine traditional OCR, layout analysis, and VLM-based processing with cross-verification.
class HybridDocumentProcessor:
    """Combine traditional OCR/layout analysis with VLM-based processing."""

    # VLM-only words must be longer than this to count as significant.
    MIN_SIGNIFICANT_WORD_LEN = 3
    # More significant VLM-only words than this is reported as a text mismatch.
    MAX_VLM_ONLY_WORDS = 10

    def __init__(
        self,
        layout_model,  # e.g., LayoutLMv3, DocTR
        vlm_client,
        ocr_engine
    ):
        self.layout = layout_model
        self.vlm = vlm_client
        self.ocr = ocr_engine

    def process_with_verification(
        self,
        document_image: bytes
    ) -> dict:
        """Run OCR, layout analysis and a VLM on one image and cross-check them.

        Returns:
            dict with ``ocr_text``, ``layout``, ``vlm_interpretation``,
            ``verification`` (see ``_verify_consistency``) and ``confidence``.
        """
        ocr_result = self.ocr.extract_text(document_image)  # baseline text
        layout_result = self.layout.analyze(document_image)  # structure
        vlm_result = self.vlm.analyze_image(
            document_image,
            "Extract all text and describe the document structure."
        )  # semantic understanding

        verification = self._verify_consistency(ocr_result, layout_result, vlm_result)
        return {
            'ocr_text': ocr_result['text'],
            'layout': layout_result,
            'vlm_interpretation': vlm_result,
            'verification': verification,
            'confidence': verification['confidence']
        }

    @staticmethod
    def _words(text: str) -> set:
        """Lower-cased word tokens with punctuation removed."""
        return set(re.findall(r"\w+", text.lower()))

    def _verify_consistency(
        self,
        ocr: dict,
        layout: dict,
        vlm: str
    ) -> dict:
        """Compare OCR, layout and VLM outputs and score their agreement.

        Issues raised:
            ``text_mismatch``: more than ``MAX_VLM_ONLY_WORDS`` significant
                words appear in the VLM output but not in the OCR text
                (tokenized on word characters so punctuation and markdown do
                not cause false mismatches).
            ``table_mismatch``: layout found tables but the VLM output never
                mentions "table".

        Confidence is 1.0 minus 0.2 per issue, floored at 0.3.
        """
        issues = []

        vlm_only = self._words(vlm) - self._words(ocr['text'])
        significant = sorted(w for w in vlm_only if len(w) > self.MIN_SIGNIFICANT_WORD_LEN)
        if len(significant) > self.MAX_VLM_ONLY_WORDS:
            issues.append({
                'type': 'text_mismatch',
                'description': 'VLM found text not in OCR',
                'examples': significant[:5]
            })

        if layout.get('tables', []) and 'table' not in vlm.lower():
            issues.append({
                'type': 'table_mismatch',
                'description': 'Layout found tables but VLM did not mention them'
            })

        confidence = 1.0 - (len(issues) * 0.2)
        return {
            'consistent': not issues,
            'issues': issues,
            'confidence': max(0.3, confidence)
        }


# Multi-Page Document Processing
Process multi-page documents with context awareness across pages.
class MultiPageProcessor:
    """Extract multi-page documents, feeding earlier pages back as context."""

    def __init__(self, processor: DocumentProcessor):
        # The wrapped processor supplies both the renderer and the VLM client.
        self.processor = processor

    def process_with_context(
        self,
        pdf_path: str,
        context_window: int = 2
    ) -> dict:
        """Process each page with previews of up to ``context_window`` prior pages.

        Returns:
            ``{'pages': [{'page': <1-based number>, 'content': <VLM text>}, ...]}``
        """
        renderer = self.processor.renderer
        vlm = self.processor.vlm
        extracted = []
        for index, page_image in enumerate(renderer.render_pages(pdf_path)):
            # Context comes only from pages already processed in this loop.
            context = self._build_context(extracted, index, context_window)
            prompt = (
                "Process this document page.\n"
                "Previous context:\n"
                f"{context}\n"
                "Extract content maintaining consistency with previous pages.\n"
                "Continue any tables or lists that span pages.\n"
            )
            extracted.append({'page': index + 1, 'content': vlm.analyze_image(page_image, prompt)})
        return {'pages': extracted}

    def _build_context(
        self,
        previous_results: list,
        current_index: int,
        window: int
    ) -> str:
        """Render up to ``window`` preceding results as 500-char previews."""
        recent = previous_results[max(0, current_index - window):current_index]
        if not recent:
            return "This is the first page."

        lines = []
        for entry in recent:
            text = entry['content']
            preview = text if len(text) <= 500 else text[:500] + "..."
            lines.append(f"Page {entry['page']}: {preview}")
        return "\n".join(lines)


# Multimodal Embeddings and Indexing
Generate embeddings for multiple modalities and build a searchable multimodal index.
from typing import Union
import numpy as np
class MultimodalEmbedder:
    """Route content to the embedding model that matches its modality."""

    def __init__(
        self,
        clip_model,
        text_embedder,
        audio_embedder=None
    ):
        self.clip = clip_model
        self.text = text_embedder
        self.audio = audio_embedder  # optional; audio is rejected when absent

    def embed(
        self,
        content: Union[str, bytes],
        modality: str
    ) -> np.ndarray:
        """Embed ``content`` with the model for ``modality``.

        Modalities: ``"text"`` (text embedder), ``"image"`` (CLIP image
        encoder), ``"image_text_query"`` (CLIP text encoder, for searching
        images with text) and ``"audio"`` (only if an audio embedder is set).

        Raises:
            ValueError: For any other modality, or audio without an embedder.
        """
        if modality == "text":
            return self.text.embed(content)
        if modality == "image":
            return self.clip.encode_image(content)
        if modality == "image_text_query":
            return self.clip.encode_text(content)
        if modality == "audio" and self.audio:
            return self.audio.embed(content)
        raise ValueError(f"Unsupported modality: {modality}")

    def embed_document(
        self,
        document: dict
    ) -> dict:
        """Embed the ``text``, ``images`` and (if supported) ``audio`` of a document.

        Returns a dict with the same keys for whichever parts were embedded;
        ``images`` maps to a list with one embedding per image.
        """
        result = {}
        if 'text' in document:
            result['text'] = self.embed(document['text'], 'text')
        if 'images' in document:
            result['images'] = list(map(lambda image: self.embed(image, 'image'), document['images']))
        if 'audio' in document and self.audio:
            result['audio'] = self.embed(document['audio'], 'audio')
        return result
class MultimodalIndex:
    """Store per-modality embeddings in a vector store and search across them."""

    def __init__(
        self,
        embedder: MultimodalEmbedder,
        vector_store
    ):
        self.embedder = embedder
        self.store = vector_store

    def index_document(
        self,
        doc_id: str,
        document: dict,
        metadata: dict = None
    ) -> None:
        """Upsert a document's text and image embeddings.

        Vector ids are ``{doc_id}_text`` and ``{doc_id}_image_{i}``. Each record
        carries the caller's metadata plus ``doc_id`` and ``modality`` (and a
        200-char ``content_preview`` for text, ``image_index`` for images).
        Any audio embedding returned by the embedder is not stored.
        """
        vectors = self.embedder.embed_document(document)
        base = metadata or {}

        if 'text' in vectors:
            text_meta = dict(base)
            text_meta.update(doc_id=doc_id, modality='text', content_preview=document['text'][:200])
            self.store.upsert(id=f"{doc_id}_text", vector=vectors['text'], metadata=text_meta)

        for position, vector in enumerate(vectors.get('images', [])):
            image_meta = dict(base)
            image_meta.update(doc_id=doc_id, modality='image', image_index=position)
            self.store.upsert(id=f"{doc_id}_image_{position}", vector=vector, metadata=image_meta)

    def search(
        self,
        query: str,
        modalities: list[str] = None,
        top_k: int = 10
    ) -> list[dict]:
        """Search the requested modalities and return the best hit per document.

        Text records are queried with the text embedding; image records with
        the CLIP text embedding of the query.

        NOTE(review): scores from the two embedding spaces are compared as-is,
        which may favour one modality; LateFusionSearch normalizes per modality.
        """
        wanted = modalities or ['text', 'image']
        # (stored modality, embedding kind used for the query vector)
        plan = (('text', 'text'), ('image', 'image_text_query'))

        hits = []
        for modality, query_kind in plan:
            if modality not in wanted:
                continue
            hits.extend(self.store.search(
                vector=self.embedder.embed(query, query_kind),
                filter={'modality': modality},
                top_k=top_k
            ))

        ranked = sorted(hits, key=lambda hit: hit['score'], reverse=True)
        return self._dedupe_by_doc(ranked, top_k)

    def _dedupe_by_doc(
        self,
        results: list[dict],
        limit: int
    ) -> list[dict]:
        """Keep the first result per ``metadata['doc_id']``, stopping at ``limit``."""
        kept = []
        seen = set()
        for hit in results:
            doc = hit['metadata']['doc_id']
            if doc in seen:
                continue
            seen.add(doc)
            kept.append(hit)
            if len(kept) >= limit:
                break
        return kept


# Late Fusion Search
Combine results from modality-specific searches with weighted scoring.
class LateFusionSearch:
    """Fuse ranked results from separate text and image indexes."""

    def __init__(self, text_index, image_index):
        # Each index must provide search(query, top_k=...) returning dicts
        # with 'doc_id' and 'score'.
        self.text_index = text_index
        self.image_index = image_index

    def search(
        self,
        query: str,
        weights: dict = None,
        top_k: int = 10
    ) -> list[dict]:
        """Search both modalities and combine per-document scores.

        Each modality's scores are min-max normalized, reduced to the best hit
        per document, multiplied by that modality's weight and summed across
        modalities. A modality missing from ``weights`` (or weighted 0) is not
        searched.

        Args:
            query: Search text.
            weights: ``{'text': w, 'image': w}``; defaults to 0.6 / 0.4.
            top_k: Number of documents to return (each index is asked for
                ``2 * top_k`` candidates).

        Returns:
            ``[{'doc_id': ..., 'score': ...}, ...]`` sorted by descending score.
        """
        weights = weights or {'text': 0.6, 'image': 0.4}
        if top_k <= 0:
            return []

        fused = {}
        for modality, index in (('text', self.text_index), ('image', self.image_index)):
            weight = weights.get(modality, 0.0)
            if not weight:
                continue
            # Best normalized score per document within this modality, so a
            # document with many chunks is not counted several times.
            best = {}
            for hit in self._normalize_scores(index.search(query, top_k=top_k * 2)):
                best[hit['doc_id']] = max(best.get(hit['doc_id'], 0.0), hit['score'])
            for doc_id, score in best.items():
                fused[doc_id] = fused.get(doc_id, 0.0) + score * weight

        ranked = sorted(fused.items(), key=lambda item: item[1], reverse=True)
        return [{'doc_id': doc_id, 'score': score} for doc_id, score in ranked[:top_k]]

    def _normalize_scores(self, results: list[dict]) -> list[dict]:
        """Return min-max normalized copies of ``results`` (inputs untouched).

        When all scores are equal (including a single result) every result is
        treated as the modality's best and gets 1.0.
        """
        if not results:
            return []
        scores = [r['score'] for r in results]
        low, high = min(scores), max(scores)
        span = high - low
        return [
            {**r, 'score': (r['score'] - low) / span if span else 1.0}
            for r in results
        ]


# Multimodal Document RAG
RAG system for documents containing text and images with cross-modal retrieval.
import base64
class MultimodalDocumentRAG:
    """RAG system for documents containing text and images.

    Text chunks and VLM-generated image descriptions are embedded with
    ``text_embedder``. Images are also embedded with ``clip_model`` so that
    questions can retrieve them by visual similarity.

    Subclasses must implement ``_extract_pages`` (document parsing is format
    specific). Image bytes seen during indexing are kept in memory so that
    retrieved images can be sent to the VLM; override ``_load_image`` to read
    them from durable storage instead.
    """

    def __init__(
        self,
        text_embedder,
        clip_model,
        vlm_client,
        vector_store,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        self.text_embedder = text_embedder
        self.clip = clip_model
        self.vlm = vlm_client
        self.store = vector_store
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # image_ref -> raw image bytes, filled by index_document (in memory).
        self._images = {}

    def _extract_pages(self, doc_path: str) -> list[dict]:
        """Return ``[{'text': str, 'images': [bytes, ...]}, ...]`` per page."""
        raise NotImplementedError(
            "Override _extract_pages to parse documents into "
            "[{'text': str, 'images': [bytes, ...]}, ...]"
        )

    def _chunk_text(self, text: str) -> list[str]:
        """Split text into ``chunk_size``-char windows overlapping by ``chunk_overlap``."""
        text = (text or "").strip()
        if not text:
            return []
        step = max(1, self.chunk_size - self.chunk_overlap)
        chunks = []
        for start in range(0, len(text), step):
            chunk = text[start:start + self.chunk_size].strip()
            if chunk:
                chunks.append(chunk)
            if start + self.chunk_size >= len(text):
                break
        return chunks

    @staticmethod
    def _media_type(image_bytes: bytes) -> str:
        """Identify JPEG/PNG/GIF/WebP data from its signature bytes."""
        if image_bytes.startswith(b'\xff\xd8\xff'):
            return 'image/jpeg'
        if image_bytes.startswith(b'\x89PNG\r\n\x1a\n'):
            return 'image/png'
        if image_bytes.startswith((b'GIF87a', b'GIF89a')):
            return 'image/gif'
        if image_bytes[:4] == b'RIFF' and image_bytes[8:12] == b'WEBP':
            return 'image/webp'
        raise ValueError("Unsupported image data: expected JPEG, PNG, GIF, or WebP")

    def _image_block(self, image_bytes: bytes) -> dict:
        """Build a base64 image content block including its media type."""
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": self._media_type(image_bytes),
                "data": base64.b64encode(image_bytes).decode(),
            },
        }

    def index_document(self, doc_path: str, doc_id: str) -> dict:
        """Index a document's text chunks and images.

        Per image, two vectors are stored: ``{img_id}_desc`` (text embedding of
        a VLM description) and ``{img_id}_visual`` (CLIP image embedding).

        Returns:
            ``{'doc_id': ..., 'chunks_indexed': <text chunks + images>}``
        """
        pages = self._extract_pages(doc_path)
        indexed_chunks = []

        for page_num, page in enumerate(pages):
            for chunk in self._chunk_text(page['text']):
                chunk_id = f"{doc_id}_p{page_num}_t{len(indexed_chunks)}"
                self.store.add(
                    id=chunk_id,
                    vector=self.text_embedder.encode(chunk),
                    metadata={
                        'doc_id': doc_id,
                        'page': page_num,
                        'modality': 'text',
                        'content': chunk
                    }
                )
                indexed_chunks.append(chunk_id)

            for img_idx, image in enumerate(page.get('images', [])):
                img_id = f"{doc_id}_p{page_num}_i{img_idx}"
                # Keep the bytes so retrieved images can be shown to the VLM.
                self._images[img_id] = image

                description = self._describe_image(image, page['text'][:500])
                self.store.add(
                    id=f"{img_id}_desc",
                    vector=self.text_embedder.encode(description),
                    metadata={
                        'doc_id': doc_id,
                        'page': page_num,
                        'modality': 'image_description',
                        'description': description,
                        'image_ref': img_id
                    }
                )
                # Visual embedding for image-similarity retrieval.
                self.store.add(
                    id=f"{img_id}_visual",
                    vector=self.clip.encode_image(image),
                    metadata={
                        'doc_id': doc_id,
                        'page': page_num,
                        'modality': 'image_visual',
                        'image_ref': img_id
                    }
                )
                indexed_chunks.append(img_id)

        return {'doc_id': doc_id, 'chunks_indexed': len(indexed_chunks)}

    def _describe_image(self, image_bytes: bytes, surrounding_text: str) -> str:
        """Ask the VLM for a searchable description of an image."""
        prompt = (
            "Describe this image for search indexing.\n"
            f"Context from surrounding document text: {surrounding_text}\n"
            "Provide a detailed description including:\n"
            "1. What the image shows (diagram, chart, photo, screenshot, etc.)\n"
            "2. Key elements, labels, or data points\n"
            "3. The main concept or information conveyed\n"
            "4. Any text visible in the image"
        )
        return self.vlm.generate(
            messages=[{
                "role": "user",
                # Image first, then the instructions about it.
                "content": [self._image_block(image_bytes), {"type": "text", "text": prompt}],
            }]
        )

    def query(self, question: str, top_k: int = 5) -> dict:
        """Answer a question from retrieved text, descriptions and images.

        Returns:
            ``{'answer': <VLM answer>, 'sources': <merged search results>}``
        """
        text_results = self.store.search(
            vector=self.text_embedder.encode(question),
            filter={'modality': {'$in': ['text', 'image_description']}},
            top_k=top_k
        )
        # Visual search for questions about visual elements; always ask for at
        # least one hit (top_k // 2 is 0 when top_k == 1).
        visual_results = self.store.search(
            vector=self.clip.encode_text(question),
            filter={'modality': 'image_visual'},
            top_k=max(1, top_k // 2)
        )

        all_results = self._merge_results(text_results, visual_results)
        context = self._build_multimodal_context(all_results)
        return {
            'answer': self._generate_answer(question, context),
            'sources': all_results
        }

    def _merge_results(self, text_results: list[dict], visual_results: list[dict]) -> list[dict]:
        """Concatenate text then visual results, dropping repeated ``id``s.

        Scores are not compared across the two searches because they come from
        different embedding spaces.
        """
        merged = []
        seen = set()
        for result in [*text_results, *visual_results]:
            key = result.get('id')
            if key is not None:
                if key in seen:
                    continue
                seen.add(key)
            merged.append(result)
        return merged

    def _load_image(self, image_ref: str) -> bytes:
        """Return the raw bytes of an image indexed by this instance.

        Raises:
            KeyError: If the image was not indexed here.
        """
        try:
            return self._images[image_ref]
        except KeyError:
            raise KeyError(
                f"Image {image_ref!r} was not indexed by this instance; "
                "override _load_image to fetch it from storage"
            ) from None

    def _build_multimodal_context(self, results: list[dict]) -> list:
        """Build VLM content blocks for text chunks, images and image descriptions.

        Each referenced image is attached once, even if both its description
        and its visual vector were retrieved; descriptions are added as text.
        """
        context_parts = []
        included_images = set()
        for result in results:
            meta = result['metadata']
            modality = meta['modality']
            if modality == 'text':
                context_parts.append({
                    "type": "text",
                    "text": f"[Page {meta['page']}]: {meta['content']}"
                })
            elif modality in ('image_description', 'image_visual'):
                image_ref = meta.get('image_ref')
                if not image_ref:
                    continue
                if image_ref not in included_images:
                    included_images.add(image_ref)
                    context_parts.append({
                        "type": "text",
                        "text": f"[Image from page {meta['page']}]:"
                    })
                    context_parts.append(self._image_block(self._load_image(image_ref)))
                description = meta.get('description')
                if description:
                    context_parts.append({
                        "type": "text",
                        "text": f"[Description of image from page {meta['page']}]: {description}"
                    })
        return context_parts

    def _generate_answer(self, question: str, context: list) -> str:
        """Ask the VLM to answer ``question`` from the assembled context blocks."""
        instruction = (
            "Answer the question using only the document excerpts and images above. "
            "If they do not contain the answer, say so.\n\n"
            f"Question: {question}"
        )
        return self.vlm.generate(
            messages=[{"role": "user", "content": [*context, {"type": "text", "text": instruction}]}]
        )


# Multimodal Orchestration
Orchestrate multiple specialized models for processing different modalities in parallel.
class MultimodalOrchestrator:
    """Orchestrate multiple specialized models.

    Each modality present in the input is routed to its own model, and the
    modalities are processed concurrently. When a ``fusion_model`` is set and
    more than one modality produced results, it is asked for a combined
    summary.

    The models are duck-typed. The code below awaits these methods:

    * ``vision_model``: ``describe(image)``, ``detect_objects(image)``,
      ``extract_text(image)``
    * ``audio_model``: ``transcribe(audio)``
    * ``text_model``: ``embed(text)``, ``extract_entities(text)``
    * ``fusion_model``: ``generate(prompt)``
    """

    def __init__(
        self,
        vision_model,
        audio_model,
        text_model,
        fusion_model=None
    ):
        self.vision = vision_model
        self.audio = audio_model
        self.text = text_model
        self.fusion = fusion_model

    async def process_multimodal_input(
        self,
        inputs: dict
    ) -> dict:
        """Process multimodal input with parallel specialized models.

        Args:
            inputs: May contain ``'image'``, ``'audio'`` and/or ``'text'``;
                absent keys are skipped.

        Returns:
            A dict with ``'vision'``, ``'audio'`` and/or ``'text'`` entries for
            each processed modality, plus ``'fused'`` when fusion ran.

        Raises:
            The first exception raised by any model call; ``asyncio.gather``
            propagates it.
        """
        import asyncio

        tasks = []
        if 'image' in inputs:
            tasks.append(('vision', self._process_image(inputs['image'])))
        if 'audio' in inputs:
            tasks.append(('audio', self._process_audio(inputs['audio'])))
        if 'text' in inputs:
            tasks.append(('text', self._process_text(inputs['text'])))

        # Process modalities in parallel
        modality_results = await asyncio.gather(*(coro for _, coro in tasks))
        results = {
            modality: result
            for (modality, _), result in zip(tasks, modality_results)
        }

        # Fuse results if a fusion model is available and there is more than
        # one modality to combine.
        if self.fusion and len(results) > 1:
            results['fused'] = await self._fuse_results(results)
        return results

    async def _process_image(self, image: bytes) -> dict:
        """Describe, detect objects in, and extract text from an image.

        The three vision calls are independent, so they run concurrently
        instead of one after another.
        """
        import asyncio

        description, objects, text = await asyncio.gather(
            self.vision.describe(image),
            self.vision.detect_objects(image),
            self.vision.extract_text(image),
        )
        return {
            'description': description,
            'objects': objects,
            'extracted_text': text
        }

    async def _process_audio(self, audio: bytes) -> dict:
        """Transcribe audio with the audio model."""
        transcription = await self.audio.transcribe(audio)
        return {'transcription': transcription}

    async def _process_text(self, text: str) -> dict:
        """Embed text and extract its entities (independent calls, run concurrently)."""
        import asyncio

        embedding, entities = await asyncio.gather(
            self.text.embed(text),
            self.text.extract_entities(text),
        )
        return {'embedding': embedding, 'entities': entities}

    async def _fuse_results(self, results: dict) -> dict:
        """Ask the fusion model to combine the per-modality analyses.

        The text embedding is left out of the prompt. A raw vector is
        meaningless to a language model and can add thousands of tokens.
        """
        text_analysis = {
            key: value
            for key, value in results.get('text', {}).items()
            if key != 'embedding'
        }
        fusion_prompt = (
            "Combine these analyses into a unified understanding:\n"
            f"Vision analysis: {results.get('vision', {})}\n"
            f"Audio analysis: {results.get('audio', {})}\n"
            f"Text analysis: {text_analysis}\n"
            "Provide a coherent summary that integrates information across modalities."
        )
        fused = await self.fusion.generate(fusion_prompt)
        return {'summary': fused}
# Cost-Aware Processing
Process multimodal content with budget awareness and optimization.
class CostAwareMultimodalProcessor:
    """Process multimodal content with cost awareness.

    Wraps any ``processor`` that exposes ``process(inputs: dict)``. Before
    delegating, the input cost is estimated. If it exceeds
    ``budget_per_request``, images are downscaled and text is truncated
    first.

    Recognised input keys:

    * ``'text'``: a str
    * ``'images'``: an iterable of images
    * ``'audio'``: audio content, with an optional ``'audio_duration'`` in
      minutes
    """

    # Approximate costs (example rates). Token-based entries are per 1K
    # tokens; 'audio_minute' is a per-minute rate.
    COST_PER_1K = {
        'text_input': 0.003,
        'text_output': 0.015,
        'image_input': 0.005,  # Per 1K image tokens
        'audio_minute': 0.006,
    }

    # Anthropic's vision docs approximate image tokens as width * height / 750
    # and downscale images above roughly 1,600 tokens, so the per-image
    # estimate is capped there.
    PIXELS_PER_IMAGE_TOKEN = 750
    MAX_IMAGE_TOKENS = 1600

    def __init__(self, processor, budget_per_request: float = 0.10):
        """Wrap *processor* with a per-request budget in the same currency as COST_PER_1K."""
        self.processor = processor
        self.budget = budget_per_request

    def estimate_cost(self, inputs: dict) -> float:
        """Estimate the input-side processing cost for *inputs*.

        Output tokens are not estimated; the ``'text_output'`` rate is unused.
        """
        cost = 0.0
        if 'text' in inputs:
            tokens = len(inputs['text'].split()) * 1.3  # Rough words->tokens estimate
            cost += (tokens / 1000) * self.COST_PER_1K['text_input']
        if 'images' in inputs:
            for img in inputs['images']:
                # Estimate image tokens based on resolution
                img_tokens = self._estimate_image_tokens(img)
                cost += (img_tokens / 1000) * self.COST_PER_1K['image_input']
        if 'audio' in inputs:
            # Duration defaults to one minute when not supplied.
            duration_minutes = inputs.get('audio_duration', 1)
            cost += duration_minutes * self.COST_PER_1K['audio_minute']
        return cost

    def process_within_budget(self, inputs: dict) -> dict:
        """Process inputs, shrinking them first if the estimate exceeds the budget.

        Optimization is a single best-effort pass; the optimized inputs are not
        re-checked against the budget before processing.
        """
        estimated_cost = self.estimate_cost(inputs)
        if estimated_cost <= self.budget:
            return self.processor.process(inputs)
        # Optimize to reduce cost
        optimized_inputs = self._optimize_inputs(inputs, self.budget)
        return self.processor.process(optimized_inputs)

    def _optimize_inputs(self, inputs: dict, target_budget: float) -> dict:
        """Return a shallow copy of *inputs* with images halved and text capped at 4000 chars.

        ``target_budget`` is accepted for interface stability but is not yet
        used to choose how aggressively to shrink. The caller's dict is not
        modified.
        """
        optimized = inputs.copy()
        # Reduce image resolution
        if 'images' in optimized:
            optimized['images'] = [
                self._reduce_image_resolution(img, factor=0.5)
                for img in optimized['images']
            ]
        # Truncate text
        if 'text' in optimized and len(optimized['text']) > 4000:
            optimized['text'] = optimized['text'][:4000]
        return optimized

    @staticmethod
    def _image_dimensions(img):
        """Return ``(width, height)`` for *img*, or ``None`` if it cannot be determined.

        Supports objects whose ``size`` attribute is a ``(width, height)`` tuple
        (the Pillow convention) and dicts with ``'width'``/``'height'`` keys.
        """
        size = getattr(img, 'size', None)
        if isinstance(size, tuple) and len(size) == 2:
            return size
        if isinstance(img, dict) and 'width' in img and 'height' in img:
            return img['width'], img['height']
        return None

    def _estimate_image_tokens(self, img) -> float:
        """Approximate the tokens one image consumes, capped at MAX_IMAGE_TOKENS.

        Images whose dimensions are unknown are charged the cap. This errs
        toward triggering optimization rather than overspending.
        """
        dims = self._image_dimensions(img)
        if dims is None:
            return self.MAX_IMAGE_TOKENS
        width, height = dims
        return min((width * height) / self.PIXELS_PER_IMAGE_TOKEN, self.MAX_IMAGE_TOKENS)

    def _reduce_image_resolution(self, img, factor: float = 0.5):
        """Return *img* scaled by *factor* on both axes, or *img* unchanged.

        Only Pillow-style objects (a ``(width, height)`` ``size`` plus a
        ``resize(size)`` method returning a new image) can be scaled. Other
        representations are passed through, because resizing them would
        need an imaging library.
        """
        size = getattr(img, 'size', None)
        resize = getattr(img, 'resize', None)
        if not (isinstance(size, tuple) and len(size) == 2 and callable(resize)):
            return img
        width, height = size
        return resize((max(1, int(width * factor)), max(1, int(height * factor))))
# Usage Notes
Selecting the Right Components
| Use Case | Recommended Components |
|---|---|
| Simple image Q&A | VisionProcessor |
| Document OCR | DocumentProcessor + ImageOptimizer |
| Video summarization | VideoProcessor |
| Podcast transcription | AudioTranscriber + SpeakerDiarizer |
| Multimodal search | MultimodalEmbedder + MultimodalIndex |
| Document Q&A with figures | MultimodalDocumentRAG |
Performance Considerations
- Batch processing: Process multiple items in parallel when possible
- Resolution optimization: Match image resolution to task requirements
- Chunking: Split long audio/video into manageable segments
- Caching: Cache embeddings and intermediate results
- Cost monitoring: Track token usage across modalities
Error Handling
All implementations should include:
- Input validation (format, size limits)
- Retry logic for transient API failures
- Graceful degradation when optional components fail
- Logging for debugging and monitoring