Multimodal Systems - Complete Code Reference

This reference contains the full implementations for Chapter 53: Multimodal Systems. Each section provides a reference implementation that you can adapt to your specific use case. Harden it with error handling, retries, and configuration before deploying it to production.


Table of Contents

  1. Vision-Language Model Integration
  2. Image Optimization
  3. Vision Result Validation
  4. Video Processing
  5. Audio Transcription
  6. Speaker Diarization
  7. Text-to-Speech Processing
  8. Document Processing
  9. Hybrid Document Processing
  10. Multi-Page Document Processing
  11. Multimodal Embeddings and Indexing
  12. Late Fusion Search
  13. Multimodal Document RAG
  14. Multimodal Orchestration
  15. Cost-Aware Processing

Vision-Language Model Integration

Complete implementation for processing images with vision-language models, including encoding, single-image analysis, and multi-image comparison.

import anthropic
import base64
from pathlib import Path

class VisionProcessor:
    """Process images with vision-language models.

    Wraps an Anthropic client and builds Messages API requests that carry
    base64-encoded image content blocks. Images may be given either as file
    paths or as raw encoded bytes (e.g. rendered PDF pages or video frames).
    """

    DEFAULT_MODEL = "claude-sonnet-4-20250514"

    # File extension -> media type for the image formats accepted here.
    MEDIA_TYPES = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp'
    }

    def __init__(self, client: "anthropic.Anthropic", model: str = DEFAULT_MODEL):
        """Store the API client and the model used for every request.

        Args:
            client: Anthropic client; only ``client.messages.create`` is used.
            model: model identifier sent with each request.
        """
        self.client = client
        self.model = model

    def encode_image(self, image_path: str) -> dict:
        """Encode an image file as a base64 image content block.

        Raises:
            ValueError: if the file extension is not a supported image format.
        """
        path = Path(image_path)
        suffix = path.suffix.lower()

        if suffix not in self.MEDIA_TYPES:
            raise ValueError(f"Unsupported image format: {suffix}")

        return self.encode_image_bytes(path.read_bytes(), self.MEDIA_TYPES[suffix])

    def encode_image_bytes(self, image_bytes: bytes, media_type: str = None) -> dict:
        """Encode raw image bytes as a base64 image content block.

        Args:
            image_bytes: encoded image data (JPEG, PNG, GIF or WebP).
            media_type: MIME type; detected from the leading magic bytes when
                omitted.

        Raises:
            ValueError: if ``media_type`` is omitted and the format is not
                recognized.
        """
        if media_type is None:
            media_type = self._sniff_media_type(image_bytes)
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": base64.standard_b64encode(image_bytes).decode('utf-8')
            }
        }

    @staticmethod
    def _sniff_media_type(image_bytes: bytes) -> str:
        """Identify the image format from its file signature."""
        head = bytes(image_bytes[:12])
        if head.startswith(b'\xff\xd8\xff'):
            return 'image/jpeg'
        if head.startswith(b'\x89PNG\r\n\x1a\n'):
            return 'image/png'
        if head.startswith((b'GIF87a', b'GIF89a')):
            return 'image/gif'
        if head[:4] == b'RIFF' and head[8:12] == b'WEBP':
            return 'image/webp'
        raise ValueError("Unsupported or unrecognized image data")

    def _to_image_block(self, image) -> dict:
        """Build an image content block from a file path or raw bytes."""
        if isinstance(image, (bytes, bytearray, memoryview)):
            return self.encode_image_bytes(bytes(image))
        return self.encode_image(image)

    def analyze_image(
        self,
        image_path: str,
        prompt: str,
        system_prompt: str = None
    ) -> str:
        """Analyze a single image with a text prompt.

        Args:
            image_path: path to an image file, or the raw encoded image bytes.
            prompt: instruction placed after the image in the user turn.
            system_prompt: optional system prompt; a generic image-analysis
                prompt is used when omitted.

        Returns:
            The text of the first content block of the response.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    self._to_image_block(image_path),
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        response = self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=system_prompt or "You are a helpful assistant that analyzes images.",
            messages=messages
        )

        return response.content[0].text

    def compare_images(
        self,
        image_paths: list[str],
        prompt: str
    ) -> str:
        """Compare multiple images in one request.

        Each image (path or raw bytes) is preceded by an "Image N:" label so
        the prompt can refer to images by number.
        """
        content = []
        for i, image in enumerate(image_paths, start=1):
            content.append({"type": "text", "text": f"Image {i}:"})
            content.append(self._to_image_block(image))

        content.append({"type": "text", "text": prompt})

        response = self.client.messages.create(
            model=self.model,
            max_tokens=4096,
            messages=[{"role": "user", "content": content}]
        )

        return response.content[0].text

Image Optimization

Optimize images for VLM processing by resizing based on task requirements and estimating token usage.

from PIL import Image
import io

class ImageOptimizer:
    """Optimize images for VLM processing."""

    # Maximum (width, height) bounding box per resolution tier. The token
    # figures are rough guides only; estimate_tokens() uses its own heuristic.
    RESOLUTION_TIERS = {
        'low': (512, 512),      # ~85 tokens
        'medium': (1024, 1024), # ~340 tokens
        'high': (2048, 2048),   # ~1360 tokens
    }

    # Resolution tier each task needs; unknown tasks fall back to 'medium'.
    TASK_REQUIREMENTS = {
        'ocr': 'high',           # Need detail for text
        'classification': 'low', # Overall content is enough
        'object_detection': 'medium',
        'chart_reading': 'high',
        'general_qa': 'medium',
    }

    # Pillow modes written to JPEG without conversion; everything else is
    # converted to RGB (or flattened, when it carries transparency).
    JPEG_MODES = ('1', 'L', 'RGB', 'CMYK')

    def optimize_for_task(
        self,
        image_path: str,
        task: str,
        quality: int = 85
    ) -> bytes:
        """Downscale an image to the task's resolution tier and encode it as JPEG.

        Args:
            image_path: path of the source image.
            task: key of TASK_REQUIREMENTS; unknown tasks use the 'medium' tier.
            quality: JPEG quality passed to Pillow's encoder.

        Returns:
            The JPEG-encoded bytes.
        """
        tier = self.TASK_REQUIREMENTS.get(task, 'medium')
        max_size = self.RESOLUTION_TIERS[tier]

        # The context manager releases the file handle Image.open() keeps.
        with Image.open(image_path) as img:
            # Resize if larger than tier allows (aspect ratio is preserved)
            if img.width > max_size[0] or img.height > max_size[1]:
                img.thumbnail(max_size, Image.Resampling.LANCZOS)

            jpeg_ready = self._to_jpeg_compatible(img)

            # Compress to JPEG for smaller payload
            buffer = io.BytesIO()
            jpeg_ready.save(buffer, format='JPEG', quality=quality)
        return buffer.getvalue()

    def _to_jpeg_compatible(self, img):
        """Return an image the JPEG encoder accepts.

        Transparent images are composited onto a white background instead of
        having their alpha channel dropped; dropping it exposes whatever RGB
        values sit under transparent pixels (often black), which can hide
        dark text.
        """
        has_alpha = img.mode in ('RGBA', 'LA', 'PA') or (
            img.mode == 'P' and 'transparency' in img.info
        )
        if has_alpha:
            rgba = img.convert('RGBA')
            flattened = Image.new('RGB', rgba.size, (255, 255, 255))
            flattened.paste(rgba, mask=rgba.getchannel('A'))
            return flattened
        if img.mode not in self.JPEG_MODES:
            return img.convert('RGB')
        return img

    def estimate_tokens(self, image_path: str) -> int:
        """Roughly estimate the token count for an image at its stored size.

        Only the header is needed for the dimensions. The estimate does not
        account for any downscaling a provider may apply.
        """
        with Image.open(image_path) as img:
            width, height = img.size
        # Rough estimation: ~0.5 tokens per 32x32 patch
        patches = (width // 32) * (height // 32)
        return int(patches * 0.5)

Vision Result Validation

Validate vision model outputs, detect hallucinations, and implement redundancy-based verification.

import re

class VisionResultValidator:
    """Validate vision model outputs."""

    # Phrases suggesting the model was unsure about what it read.
    HALLUCINATION_PHRASES = (
        "I cannot read",
        "text is not clear",
        "appears to say",
        "might be",
    )

    def validate_ocr_result(
        self,
        extracted_text: str,
        image_path: str,
        expected_patterns: list[str] = None
    ) -> dict:
        """Score an OCR extraction with simple heuristics.

        Starting from confidence 1.0, each uncertainty phrase costs 0.2 and
        each missing expected regex costs 0.1. Text shorter than 10
        non-whitespace-trimmed characters caps confidence at 0.3.

        Args:
            extracted_text: text returned by the vision model.
            image_path: source image (kept for interface compatibility; unused).
            expected_patterns: optional regexes that should match the text.

        Returns:
            Dict with 'valid', 'confidence' (clamped at 0.0), 'issues' and
            'text'. 'valid' requires confidence > 0.5 and fewer than 3 issues.
        """
        issues = []
        lowered = extracted_text.lower()

        # Check for common hallucination indicators
        confidence = 1.0
        for phrase in self.HALLUCINATION_PHRASES:
            if phrase.lower() in lowered:
                confidence -= 0.2
                issues.append(f"Uncertainty indicator: '{phrase}'")

        # Check expected patterns if provided
        for pattern in expected_patterns or []:
            if not re.search(pattern, extracted_text):
                issues.append(f"Expected pattern not found: {pattern}")
                confidence -= 0.1

        # Check for empty or minimal extraction
        if len(extracted_text.strip()) < 10:
            issues.append("Extraction returned minimal text")
            # Cap rather than overwrite, so earlier penalties are not undone.
            confidence = min(confidence, 0.3)

        return {
            'valid': confidence > 0.5 and len(issues) < 3,
            'confidence': max(0.0, confidence),
            'issues': issues,
            'text': extracted_text
        }

    def validate_with_redundancy(
        self,
        image_path: str,
        prompt: str,
        processors: list,
        agreement_threshold: float = 0.8
    ) -> dict:
        """Run the same prompt through several processors and check agreement.

        Each processor must expose ``analyze_image(image_path, prompt)``.

        Returns:
            Dict with 'results', 'agreement' (None with fewer than two
            results), 'confident', and 'consensus' (the first result when
            agreement meets the threshold, otherwise None).
        """
        results = [processor.analyze_image(image_path, prompt) for processor in processors]

        if len(results) < 2:
            # Nothing to cross-check; keep the same keys as the normal case.
            return {
                'results': results,
                'agreement': None,
                'confident': False,
                'consensus': None
            }

        # Simplified lexical agreement - production would use semantic similarity
        agreement = self._calculate_agreement(results)
        confident = agreement >= agreement_threshold
        return {
            'results': results,
            'agreement': agreement,
            'confident': confident,
            'consensus': results[0] if confident else None
        }

    @staticmethod
    def _calculate_agreement(results: list) -> float:
        """Mean pairwise Jaccard similarity of the results' lowercase word sets.

        1.0 means every pair used exactly the same words, 0.0 means no pair
        shared a word. Two empty results count as full agreement.
        """
        from itertools import combinations

        token_sets = [set(re.findall(r"\w+", str(r).lower())) for r in results]
        scores = []
        for left, right in combinations(token_sets, 2):
            union = left | right
            scores.append(len(left & right) / len(union) if union else 1.0)
        return sum(scores) / len(scores) if scores else 1.0

Video Processing

Extract frames from video using various sampling strategies and analyze video content with VLMs.

import cv2
from typing import Generator
import numpy as np

class VideoProcessor:
    """Process video for VLM analysis."""

    DEFAULT_MODEL = "claude-sonnet-4-20250514"

    def __init__(self, vlm_client):
        # The VLM wrapper; analyze_video() uses vlm_client.client.messages.create
        # and, when present, vlm_client.encode_image_bytes.
        self.vlm = vlm_client

    @staticmethod
    def _uniform_indices(
        total_frames: int,
        video_fps: float,
        fps: float,
        max_frames: int
    ) -> list[int]:
        """Pick frame indices sampled at ``fps``, thinned to at most ``max_frames``.

        The raw grid takes one frame every ``video_fps / fps`` source frames.
        When that grid has more than ``max_frames`` entries, it is subsampled
        evenly so the selection spans the whole video rather than only its
        opening seconds.
        """
        if max_frames <= 0 or total_frames <= 0:
            return []
        interval = max(1, int(video_fps / fps))
        candidates = list(range(0, total_frames, interval))
        if len(candidates) > max_frames:
            picks = np.linspace(0, len(candidates) - 1, max_frames, dtype=int)
            candidates = [candidates[i] for i in picks]
        return candidates

    def extract_frames(
        self,
        video_path: str,
        strategy: str = "uniform",
        max_frames: int = 20,
        fps: float = 1.0
    ) -> Generator[tuple[float, np.ndarray], None, None]:
        """Yield ``(timestamp_seconds, frame)`` pairs using the given strategy.

        Args:
            strategy: "uniform" samples at ``fps`` across the whole video;
                "keyframes" uses frame-difference scene changes.
            max_frames: upper bound on frames yielded.
            fps: sampling rate for the "uniform" strategy.

        Raises:
            ValueError: if the video cannot be opened, reports no frame rate,
                or the strategy is unknown.
        """
        cap = cv2.VideoCapture(video_path)
        # try/finally releases the capture even if the consumer stops early.
        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {video_path}")
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            if video_fps <= 0:
                raise ValueError(f"Could not determine frame rate for: {video_path}")
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if strategy == "uniform":
                frame_indices = self._uniform_indices(total_frames, video_fps, fps, max_frames)
            elif strategy == "keyframes":
                # Extract scene changes
                frame_indices = self._detect_scene_changes(video_path, max_frames)
            else:
                raise ValueError(f"Unknown strategy: {strategy}")

            frames_yielded = 0
            for idx in frame_indices:
                if frames_yielded >= max_frames:
                    break

                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
                ret, frame = cap.read()
                if ret:
                    yield idx / video_fps, frame
                    frames_yielded += 1
        finally:
            cap.release()

    def _detect_scene_changes(
        self,
        video_path: str,
        max_scenes: int,
        threshold: float = 30.0
    ) -> list[int]:
        """Detect scene changes using grayscale frame differencing.

        A frame starts a new scene when its mean absolute difference from the
        previous frame exceeds ``threshold`` (adjust based on content). Frame
        0 is always included. If more than ``max_scenes`` are found, they are
        subsampled evenly.
        """
        cap = cv2.VideoCapture(video_path)
        scene_frames = [0]  # Always include first frame

        prev_frame = None
        frame_idx = 0
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

                if prev_frame is not None:
                    mean_diff = np.mean(cv2.absdiff(prev_frame, gray))
                    if mean_diff > threshold:
                        scene_frames.append(frame_idx)

                prev_frame = gray
                frame_idx += 1
        finally:
            cap.release()

        # If too many scenes, sample uniformly from them
        if len(scene_frames) > max_scenes:
            indices = np.linspace(0, len(scene_frames) - 1, max_scenes, dtype=int)
            scene_frames = [scene_frames[i] for i in indices]

        return scene_frames

    def _image_block(self, jpeg_bytes: bytes) -> dict:
        """Wrap JPEG bytes as an image content block.

        Uses the VLM wrapper's ``encode_image_bytes`` when it has one;
        otherwise builds the base64 block directly.
        """
        encoder = getattr(self.vlm, 'encode_image_bytes', None)
        if encoder is not None:
            return encoder(jpeg_bytes)
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": base64.standard_b64encode(jpeg_bytes).decode('utf-8')
            }
        }

    def analyze_video(
        self,
        video_path: str,
        prompt: str,
        max_frames: int = 10
    ) -> dict:
        """Analyze video content by sending timestamped frames to the VLM.

        Returns:
            Dict with the model's 'analysis', the number of frames actually
            sent ('frames_analyzed'), and 'video_path'.
        """
        frames = self.extract_frames(
            video_path,
            strategy="uniform",
            max_frames=max_frames
        )

        # Encode first so the header count matches what is actually sent.
        encoded_frames = []
        for timestamp, frame in frames:
            ok, buffer = cv2.imencode('.jpg', frame)
            if not ok:
                continue  # skip frames OpenCV could not encode
            encoded_frames.append((timestamp, self._image_block(buffer.tobytes())))

        # Build content with frames and timestamps
        content = [{"type": "text", "text": f"Video analysis ({len(encoded_frames)} frames):"}]
        for timestamp, image_block in encoded_frames:
            content.append({
                "type": "text",
                "text": f"\n[Frame at {timestamp:.1f}s]:"
            })
            content.append(image_block)

        content.append({"type": "text", "text": f"\n\n{prompt}"})

        response = self.vlm.client.messages.create(
            model=self.DEFAULT_MODEL,
            max_tokens=4096,
            messages=[{"role": "user", "content": content}]
        )

        return {
            'analysis': response.content[0].text,
            'frames_analyzed': len(encoded_frames),
            'video_path': video_path
        }

    def summarize_video(self, video_path: str) -> dict:
        """Generate a search-indexing summary of video content."""
        return self.analyze_video(
            video_path,
            prompt="""Summarize this video:
1. What is the main content or activity?
2. What key events or changes occur?
3. Who or what are the main subjects?
4. What is the setting or context?

Provide a concise summary suitable for video search indexing."""
        )

Audio Transcription

Transcribe audio with the Whisper API. Files that exceed the upload size limit are compressed, and long recordings are split into chunks that are transcribed separately.

import openai
from pathlib import Path
import subprocess
import tempfile

class AudioTranscriber:
    """Transcribe audio using the Whisper API."""

    SUPPORTED_FORMATS = {'.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm'}
    MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB
    # Trailing chunks shorter than this are skipped rather than sent as a
    # separate request; a sub-100 ms sliver carries no useful speech.
    MIN_CHUNK_SECONDS = 0.1

    def __init__(self, client: "openai.OpenAI"):
        self.client = client

    def transcribe(
        self,
        audio_path: str,
        language: str = None,
        prompt: str = None,
        response_format: str = "verbose_json"
    ) -> dict:
        """Transcribe an audio file.

        Files larger than MAX_FILE_SIZE are first re-encoded with ffmpeg into
        a temporary low-bitrate mono MP3, which is deleted after the request.

        Args:
            audio_path: file whose extension is in SUPPORTED_FORMATS.
            language: optional language hint; omitted from the request if None.
            prompt: optional vocabulary hint; omitted from the request if None.
            response_format: "verbose_json" requests word and segment
                timestamps and returns text, language, duration, segments and
                words (as plain dicts). Any other format returns
                ``{'text': <raw API response>}``.

        Raises:
            ValueError: if the file extension is unsupported.
            subprocess.CalledProcessError: if ffmpeg compression fails.
        """
        path = Path(audio_path)

        # Validate format
        if path.suffix.lower() not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Unsupported format: {path.suffix}")

        # Check file size, convert if needed
        compressed_path = None
        if path.stat().st_size > self.MAX_FILE_SIZE:
            compressed_path = self._compress_audio(audio_path)
        upload_path = compressed_path or audio_path

        # Only send optional fields that were actually provided; timestamp
        # granularities are only valid together with verbose_json output.
        request = {"model": "whisper-1", "response_format": response_format}
        if language is not None:
            request["language"] = language
        if prompt is not None:
            request["prompt"] = prompt  # Vocabulary hints
        if response_format == "verbose_json":
            request["timestamp_granularities"] = ["word", "segment"]

        try:
            with open(upload_path, 'rb') as audio_file:
                response = self.client.audio.transcriptions.create(
                    file=audio_file,
                    **request
                )
        finally:
            if compressed_path is not None:
                Path(compressed_path).unlink(missing_ok=True)

        if response_format == "verbose_json":
            return {
                'text': response.text,
                'language': response.language,
                'duration': response.duration,
                'segments': self._as_dicts(getattr(response, 'segments', None)),
                'words': self._as_dicts(getattr(response, 'words', None))
            }
        return {'text': response}

    @staticmethod
    def _as_dicts(items) -> list[dict]:
        """Normalize segment/word entries into plain dicts.

        The SDK may return typed model objects rather than dicts (e.g.
        pydantic models in the v1 OpenAI SDK). Code in this module indexes
        entries like dicts and shifts their timestamps in place. ``None`` becomes
        an empty list.
        """
        normalized = []
        for item in items or []:
            if isinstance(item, dict):
                normalized.append(dict(item))
            elif hasattr(item, 'model_dump'):
                normalized.append(item.model_dump())
            else:
                normalized.append(dict(vars(item)))
        return normalized

    @staticmethod
    def _new_temp_path(suffix: str) -> str:
        """Create an empty temporary file and return its path (race-free, unlike mktemp)."""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
            return handle.name

    def _run_ffmpeg(self, args: list[str], output_path: str) -> str:
        """Run ffmpeg writing to ``output_path``; delete the output if ffmpeg fails."""
        # -y: the temp file already exists, so ffmpeg must overwrite without prompting.
        try:
            subprocess.run(['ffmpeg', '-y', *args, output_path], check=True, capture_output=True)
        except BaseException:
            Path(output_path).unlink(missing_ok=True)
            raise
        return output_path

    def _compress_audio(self, audio_path: str) -> str:
        """Re-encode audio as 16 kHz mono 32 kbps MP3 to meet size limits."""
        return self._run_ffmpeg([
            '-i', audio_path,
            '-ac', '1',  # Mono
            '-ar', '16000',  # 16kHz sample rate
            '-b:a', '32k',  # Low bitrate
        ], self._new_temp_path('.mp3'))

    def transcribe_long_audio(
        self,
        audio_path: str,
        chunk_duration: int = 600  # 10 minutes
    ) -> dict:
        """Transcribe long audio by splitting it into fixed-length chunks.

        Segment and word timestamps of each chunk are shifted by the chunk's
        start offset so they refer to the original recording. Words spanning
        a chunk boundary may be cut.

        Raises:
            ValueError: if ``chunk_duration`` is not positive.
        """
        if chunk_duration <= 0:
            raise ValueError("chunk_duration must be positive")

        duration = self._get_duration(audio_path)
        chunks = []

        offset = 0
        # Loop on the float duration so a fractional tail is not dropped.
        while offset < duration:
            length = min(chunk_duration, duration - offset)
            if length < self.MIN_CHUNK_SECONDS:
                break

            chunk_path = self._extract_chunk(audio_path, offset, length)
            try:
                result = self.transcribe(chunk_path)
            finally:
                Path(chunk_path).unlink(missing_ok=True)

            # Adjust timestamps to the position in the full recording
            for item in result.get('segments', []) + result.get('words', []):
                item['start'] += offset
                item['end'] += offset

            chunks.append(result)
            offset += chunk_duration

        return self._merge_transcriptions(chunks)

    def _get_duration(self, audio_path: str) -> float:
        """Return the audio duration in seconds as reported by ffprobe."""
        result = subprocess.run([
            'ffprobe', '-v', 'error',
            '-show_entries', 'format=duration',
            '-of', 'default=noprint_wrappers=1:nokey=1',
            audio_path
        ], capture_output=True, text=True, check=True)
        return float(result.stdout.strip())

    def _extract_chunk(
        self,
        audio_path: str,
        start: float,
        duration: float
    ) -> str:
        """Extract ``duration`` seconds starting at ``start`` into a temp 16 kHz mono MP3."""
        return self._run_ffmpeg([
            '-i', audio_path,
            '-ss', str(start),
            '-t', str(duration),
            '-ac', '1', '-ar', '16000',
        ], self._new_temp_path('.mp3'))

    def _merge_transcriptions(self, chunks: list[dict]) -> dict:
        """Merge per-chunk transcriptions into a single result.

        Text is joined with single spaces (empty chunks skipped); segments and
        words are concatenated in order; durations are summed. 'language' is
        the first non-empty language reported.
        """
        texts = (str(c.get('text') or '').strip() for c in chunks)
        all_segments = []
        all_words = []
        for chunk in chunks:
            all_segments.extend(chunk.get('segments') or [])
            all_words.extend(chunk.get('words') or [])

        return {
            'text': ' '.join(t for t in texts if t),
            'segments': all_segments,
            'words': all_words,
            'language': next((c['language'] for c in chunks if c.get('language')), None),
            'duration': sum(c.get('duration') or 0 for c in chunks)
        }

Speaker Diarization

Identify speakers in audio and align transcriptions with speaker labels.

class SpeakerDiarizer:
    """Identify speakers in audio and attach them to transcribed words."""

    def __init__(self, diarization_model):
        # Called as diarization_model(audio_path); the result must iterate
        # over dicts with 'start', 'end' and 'speaker' keys (see _find_speaker).
        self.diarizer = diarization_model

    def diarize_and_transcribe(
        self,
        audio_path: str,
        transcriber: "AudioTranscriber"
    ) -> list[dict]:
        """Combine diarization with word-level transcription.

        Each transcribed word is assigned the speaker whose segment contains
        the word's midpoint ('unknown' if none does). Consecutive words by the
        same speaker are then grouped into turns.

        Returns:
            List of turns, each with 'speaker', 'start', 'end', 'words' and
            'text'.
        """
        # Materialize once: _find_speaker scans every segment for every word,
        # which would silently exhaust a one-shot iterator after the first word.
        diarization = list(self.diarizer(audio_path))

        # Get transcription with word timestamps
        transcription = transcriber.transcribe(
            audio_path,
            response_format="verbose_json"
        )

        # Align words with speakers
        aligned = []
        for word in transcription.get('words') or []:
            start = self._field(word, 'start')
            end = self._field(word, 'end')
            aligned.append({
                'word': self._field(word, 'word'),
                'start': start,
                'end': end,
                'speaker': self._find_speaker(diarization, (start + end) / 2)
            })

        # Group into speaker turns
        return self._group_by_speaker(aligned)

    @staticmethod
    def _field(item, key):
        """Read ``key`` from a dict entry, or the same-named attribute of an object."""
        if isinstance(item, dict):
            return item[key]
        return getattr(item, key)

    def _find_speaker(self, diarization, timestamp: float) -> str:
        """Return the speaker of the first segment covering ``timestamp``, else 'unknown'."""
        for segment in diarization:
            if segment['start'] <= timestamp <= segment['end']:
                return segment['speaker']
        return 'unknown'

    def _group_by_speaker(self, aligned_words: list[dict]) -> list[dict]:
        """Group consecutive words by speaker into turns.

        Every turn carries 'start' and 'end', including single-word turns.
        """
        turns = []
        current = None

        for word in aligned_words:
            if current is not None and word['speaker'] == current['speaker']:
                current['words'].append(word['word'])
                current['end'] = word['end']
                continue

            if current is not None:
                current['text'] = ' '.join(current['words'])
                turns.append(current)
            current = {
                'speaker': word['speaker'],
                'start': word['start'],
                'end': word['end'],
                'words': [word['word']]
            }

        # Don't forget last turn
        if current is not None:
            current['text'] = ' '.join(current['words'])
            turns.append(current)

        return turns

Text-to-Speech Processing

Generate speech from text, splitting long text into chunks at sentence boundaries. The SSML method is illustrative only, because not every TTS API accepts SSML markup.

class TextToSpeechProcessor:
    """Generate speech from text."""

    def __init__(self, tts_client):
        # Used as tts_client.audio.speech.create(...), whose result exposes
        # the audio bytes as .content (see synthesize).
        self.client = tts_client

    def synthesize(
        self,
        text: str,
        voice: str = "alloy",
        model: str = "tts-1-hd",
        speed: float = 1.0
    ) -> bytes:
        """Synthesize speech for ``text`` and return the raw audio bytes."""
        response = self.client.audio.speech.create(
            model=model,
            voice=voice,
            input=text,
            speed=speed
        )
        return response.content

    def synthesize_long_text(
        self,
        text: str,
        voice: str = "alloy",
        max_chunk_chars: int = 4000,
        model: str = "tts-1-hd",
        speed: float = 1.0
    ) -> bytes:
        """Synthesize long text by chunking at sentence boundaries.

        Each chunk is at most ``max_chunk_chars`` characters and is synthesized
        separately with the same voice, model and speed. The resulting MP3
        segments are joined. Empty or whitespace-only text yields ``b""``.
        """
        chunks = self._chunk_text(text, max_chunk_chars)
        audio_segments = [
            self.synthesize(chunk, voice, model=model, speed=speed)
            for chunk in chunks
        ]
        return self._concatenate_audio(audio_segments)

    def _chunk_text(self, text: str, max_chars: int) -> list[str]:
        """Split text into chunks of at most ``max_chars`` characters.

        Chunks break at sentence boundaries where possible. A sentence longer
        than ``max_chars`` is split at word boundaries, and a single word
        longer than that is hard-split.

        Raises:
            ValueError: if ``max_chars`` is less than 1.
        """
        import re

        if max_chars < 1:
            raise ValueError("max_chars must be at least 1")

        pieces = []
        for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
            if not sentence:
                continue
            if len(sentence) <= max_chars:
                pieces.append(sentence)
            else:
                pieces.extend(self._split_oversized(sentence, max_chars))

        chunks = []
        current_chunk = ""
        for piece in pieces:
            candidate = f"{current_chunk} {piece}" if current_chunk else piece
            if len(candidate) <= max_chars:
                current_chunk = candidate
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = piece

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    @staticmethod
    def _split_oversized(sentence: str, max_chars: int) -> list[str]:
        """Pack a too-long sentence into word-bounded parts of at most ``max_chars``."""
        parts = []
        current = ""
        for word in sentence.split():
            # Hard-split words that cannot fit even on their own.
            while len(word) > max_chars:
                if current:
                    parts.append(current)
                    current = ""
                parts.append(word[:max_chars])
                word = word[max_chars:]
            if not word:
                continue
            candidate = f"{current} {word}" if current else word
            if len(candidate) <= max_chars:
                current = candidate
            else:
                parts.append(current)
                current = word
        if current:
            parts.append(current)
        return parts

    def _concatenate_audio(self, segments: list[bytes]) -> bytes:
        """Concatenate MP3 segments into a single MP3.

        Zero segments yield ``b""``. A single segment is returned unchanged,
        which avoids a lossy decode/re-encode round trip.
        """
        if not segments:
            return b""
        if len(segments) == 1:
            return segments[0]

        from pydub import AudioSegment
        import io

        combined = AudioSegment.empty()
        for segment in segments:
            combined += AudioSegment.from_mp3(io.BytesIO(segment))

        output = io.BytesIO()
        combined.export(output, format="mp3")
        return output.getvalue()

    def synthesize_with_ssml(
        self,
        ssml_text: str,
        voice: str = "alloy"
    ) -> bytes:
        """Synthesize speech from SSML markup (illustrative only).

        The markup is passed unchanged as the input text. A TTS API that
        does not interpret SSML will presumably read the tags aloud or reject
        them, so check the provider's SSML support before relying on this.
        """
        return self.synthesize(ssml_text, voice)

Document Processing

Process PDF documents using vision-language models for full content extraction, table extraction, and form field extraction.

import json

class DocumentProcessor:
    """Process documents using vision-language models."""

    def __init__(self, vlm_client, pdf_renderer):
        # vlm_client: exposes analyze_image(image, prompt) -> str.
        # pdf_renderer: exposes render_pages(pdf_path, dpi=...) returning a
        # sized sequence of page images.
        self.vlm = vlm_client
        self.renderer = pdf_renderer

    def process_pdf(
        self,
        pdf_path: str,
        extraction_type: str = "full"
    ) -> dict:
        """Process every page of a PDF with one extraction strategy.

        Args:
            pdf_path: path of the PDF to render.
            extraction_type: "full", "tables" or "key_value".

        Returns:
            Dict with per-page results ('pages'), 'total_pages' and
            'extraction_type'.

        Raises:
            ValueError: for an unknown extraction type. This is checked
                before rendering, so a typo does not waste a render.
        """
        extractors = {
            "full": self._extract_full_content,
            "tables": self._extract_tables,
            "key_value": self._extract_key_values,
        }
        if extraction_type not in extractors:
            raise ValueError(f"Unknown extraction type: {extraction_type}")
        extract = extractors[extraction_type]

        # Render PDF pages to images
        pages = self.renderer.render_pages(pdf_path, dpi=150)

        results = [
            extract(page_image, page_num)
            for page_num, page_image in enumerate(pages, start=1)
        ]

        return {
            'pages': results,
            'total_pages': len(pages),
            'extraction_type': extraction_type
        }

    def _extract_full_content(self, page_image: bytes, page_num: int) -> dict:
        """Extract all content from a page as structured markdown."""
        prompt = """Extract all content from this document page.

Output as structured markdown:
- Use headers (##, ###) for section titles
- Use tables for tabular data
- Use lists for enumerated items
- Preserve the logical reading order
- Note any figures or images with [Figure: description]

Be thorough but maintain document structure."""

        content = self.vlm.analyze_image(page_image, prompt)

        return {
            'page': page_num,
            'content': content,
            'type': 'full'
        }

    def _extract_tables(self, page_image: bytes, page_num: int) -> dict:
        """Extract tables from a page as markdown tables."""
        prompt = """Identify and extract all tables from this document page.

For each table:
1. Provide the table title or context
2. Extract as a markdown table with proper headers
3. Note any merged cells or special formatting

If no tables are present, respond with "No tables found."

Output format:
### Table 1: [title/context]
| Header 1 | Header 2 | ... |
|----------|----------|-----|
| data     | data     | ... |
"""

        content = self.vlm.analyze_image(page_image, prompt)

        return {
            'page': page_num,
            'content': content,
            'type': 'tables'
        }

    def _extract_key_values(self, page_image: bytes, page_num: int) -> dict:
        """Extract key-value pairs from forms as parsed JSON.

        Only the (image, prompt) form of analyze_image is used, matching the
        VisionProcessor signature in this module. JSON output is requested
        via the prompt, and the reply is parsed leniently.

        Raises:
            json.JSONDecodeError: if no JSON value can be recovered from the
                model output.
        """
        prompt = """This appears to be a form or structured document.
Extract all field labels and their values as key-value pairs.

Output as JSON:
{
    "field_name": "value",
    "another_field": "another_value"
}

For checkboxes, use true/false.
For empty fields, use null.
For multiple values, use arrays."""

        content = self.vlm.analyze_image(page_image, prompt)

        return {
            'page': page_num,
            'content': self._parse_json_reply(content),
            'type': 'key_value'
        }

    @staticmethod
    def _parse_json_reply(text: str):
        """Parse JSON from model output, tolerating ```json fences and surrounding prose."""
        cleaned = text.strip()
        fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", cleaned, re.DOTALL | re.IGNORECASE)
        if fenced:
            cleaned = fenced.group(1)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span, e.g. "Here it is: {...}".
            start, end = cleaned.find('{'), cleaned.rfind('}')
            if start == -1 or end <= start:
                raise
            return json.loads(cleaned[start:end + 1])

Hybrid Document Processing

Combine traditional OCR, layout analysis, and VLM-based processing with cross-verification.

class HybridDocumentProcessor:
    """Combine traditional and VLM-based document processing."""

    # VLM-only words of at least this length count as significant ...
    MIN_SIGNIFICANT_WORD_LEN = 4
    # ... and more than this many significant VLM-only words is flagged.
    MAX_VLM_ONLY_WORDS = 10
    # Whole-word mention of tables in the VLM output (not substrings such as
    # "suitable" or "timetable").
    _TABLE_MENTION = re.compile(r"\b(?:tables?|tabular)\b", re.IGNORECASE)

    def __init__(
        self,
        layout_model,  # e.g., LayoutLMv3, DocTR
        vlm_client,
        ocr_engine
    ):
        self.layout = layout_model
        self.vlm = vlm_client
        self.ocr = ocr_engine

    def process_with_verification(
        self,
        document_image: bytes
    ) -> dict:
        """Run OCR, layout analysis and a VLM on one image, then cross-check them.

        Returns:
            Dict with 'ocr_text', 'layout', 'vlm_interpretation',
            'verification' (see _verify_consistency) and its 'confidence'.
        """
        # Traditional OCR for baseline text
        ocr_result = self.ocr.extract_text(document_image)

        # Layout analysis for structure
        layout_result = self.layout.analyze(document_image)

        # VLM for semantic understanding
        vlm_result = self.vlm.analyze_image(
            document_image,
            "Extract all text and describe the document structure."
        )

        # Cross-verify results
        verification = self._verify_consistency(
            ocr_result, layout_result, vlm_result
        )

        return {
            'ocr_text': ocr_result['text'],
            'layout': layout_result,
            'vlm_interpretation': vlm_result,
            'verification': verification,
            'confidence': verification['confidence']
        }

    @staticmethod
    def _words(text: str) -> set:
        """Lowercased word tokens, ignoring punctuation and markdown symbols."""
        return set(re.findall(r"\w+", text.lower()))

    def _verify_consistency(
        self,
        ocr: dict,
        layout: dict,
        vlm: str
    ) -> dict:
        """Verify consistency across extraction methods.

        Flags (a) many significant words that only the VLM produced, and
        (b) tables found by layout analysis that the VLM never mentions.
        Each issue costs 0.2 confidence, floored at 0.3.
        """
        issues = []

        # Check if VLM found text that OCR missed
        vlm_only = self._words(vlm) - self._words(ocr['text'])
        significant_vlm_only = {
            w for w in vlm_only if len(w) >= self.MIN_SIGNIFICANT_WORD_LEN
        }

        if len(significant_vlm_only) > self.MAX_VLM_ONLY_WORDS:
            issues.append({
                'type': 'text_mismatch',
                'description': 'VLM found text not in OCR',
                'examples': sorted(significant_vlm_only)[:5]
            })

        # Check table consistency
        ocr_tables = layout.get('tables', [])
        vlm_mentions_tables = bool(self._TABLE_MENTION.search(vlm))

        if ocr_tables and not vlm_mentions_tables:
            issues.append({
                'type': 'table_mismatch',
                'description': 'Layout found tables but VLM did not mention them'
            })

        confidence = 1.0 - (len(issues) * 0.2)

        return {
            'consistent': len(issues) == 0,
            'issues': issues,
            'confidence': max(0.3, confidence)
        }

Multi-Page Document Processing

Process multi-page documents with context awareness across pages.

class MultiPageProcessor:
    """Process multi-page documents with context awareness."""

    # Characters of each previous page's output carried into the next prompt.
    PREVIEW_CHARS = 500

    def __init__(self, processor: "DocumentProcessor"):
        # Uses processor.renderer.render_pages(...) and processor.vlm.analyze_image(...).
        self.processor = processor

    def process_with_context(
        self,
        pdf_path: str,
        context_window: int = 2
    ) -> dict:
        """Process pages in order, giving each prompt the output of preceding pages.

        Args:
            pdf_path: PDF to render with the processor's renderer.
            context_window: how many immediately preceding pages' outputs
                (truncated to PREVIEW_CHARS) to include. Only earlier pages
                can be used, since later pages are not processed yet.

        Returns:
            ``{'pages': [{'page': n, 'content': str}, ...]}``.
        """
        pages = self.processor.renderer.render_pages(pdf_path)
        results = []

        for i, page in enumerate(pages):
            # Build context from preceding pages
            context = self._build_context(results, i, context_window)

            # Process with context
            prompt = f"""Process this document page.

Previous context:
{context}

Extract content maintaining consistency with previous pages.
Continue any tables or lists that span pages.
"""
            result = self.processor.vlm.analyze_image(page, prompt)
            results.append({
                'page': i + 1,
                'content': result
            })

        return {'pages': results}

    def _build_context(
        self,
        previous_results: list,
        current_index: int,
        window: int
    ) -> str:
        """Build the context string from up to ``window`` preceding page results.

        Only the actual first page is described as such. A later page with
        no context (e.g. ``window`` of 0) gets a neutral note instead of
        being mislabelled as the first page.
        """
        if current_index == 0:
            return "This is the first page."

        start = max(0, current_index - max(0, window))
        context_pages = previous_results[start:current_index]

        if not context_pages:
            return "No context from previous pages is available."

        context = []
        for p in context_pages:
            # Preview the start of each previous page's content
            content = p['content']
            if len(content) > self.PREVIEW_CHARS:
                content = content[:self.PREVIEW_CHARS] + "..."
            context.append(f"Page {p['page']}: {content}")

        return "\n".join(context)

Multimodal Embeddings and Indexing

Generate embeddings for multiple modalities and build a searchable multimodal index.

from typing import Union
import numpy as np

class MultimodalEmbedder:
    """Route content to the embedding model that matches its modality."""

    def __init__(
        self,
        clip_model,
        text_embedder,
        audio_embedder=None
    ):
        # CLIP provides encode_image / encode_text; the text and optional
        # audio embedders each expose embed(content).
        self.clip = clip_model
        self.text = text_embedder
        self.audio = audio_embedder

    def embed(
        self,
        content: Union[str, bytes],
        modality: str
    ) -> np.ndarray:
        """Return the embedding of ``content`` for ``modality``.

        Modalities: "text" (text embedder), "image" (CLIP image encoder),
        "image_text_query" (CLIP text encoder, for text queries against image
        vectors), and "audio" (only when an audio embedder is configured).

        Raises:
            ValueError: for any other modality, or "audio" with no audio model.
        """
        # Lambdas defer attribute lookup to the chosen model only.
        encoders = {
            "text": lambda: self.text.embed(content),
            "image": lambda: self.clip.encode_image(content),
            "image_text_query": lambda: self.clip.encode_text(content),
            "audio": lambda: self.audio.embed(content),
        }
        encoder = encoders.get(modality)
        audio_unavailable = modality == "audio" and not self.audio
        if encoder is None or audio_unavailable:
            raise ValueError(f"Unsupported modality: {modality}")
        return encoder()

    def embed_document(
        self,
        document: dict
    ) -> dict:
        """Embed each modality present in ``document``.

        Recognized keys are 'text' (string), 'images' (iterable of images)
        and 'audio' (only embedded when an audio model exists). The result
        uses the same keys, with 'images' holding a list of vectors.
        """
        vectors = {}

        if 'text' in document:
            vectors['text'] = self.embed(document['text'], 'text')

        if 'images' in document:
            vectors['images'] = [self.embed(picture, 'image') for picture in document['images']]

        wants_audio = 'audio' in document and self.audio
        if wants_audio:
            vectors['audio'] = self.embed(document['audio'], 'audio')

        return vectors


class MultimodalIndex:
    """Index and search multimodal content.

    Each document is stored as one vector per modality in ``vector_store``:
    ``<doc_id>_text`` for the text and ``<doc_id>_image_<i>`` for each image.
    Every vector is tagged with ``doc_id`` and ``modality`` metadata, so that
    searches can be filtered per modality and collapsed back to one hit per
    document.
    """

    def __init__(
        self,
        embedder: "MultimodalEmbedder",
        vector_store
    ):
        self.embedder = embedder
        self.store = vector_store

    def index_document(
        self,
        doc_id: str,
        document: dict,
        metadata: dict = None
    ) -> None:
        """Index a multimodal document.

        Args:
            doc_id: Identifier used as the prefix of every stored vector id.
            document: Dict accepted by ``embedder.embed_document`` (optional
                'text' and 'images'; 'text' is also used for the preview).
            metadata: Extra metadata copied onto every stored vector. Keys set
                here ('doc_id', 'modality', 'content_preview', 'image_index')
                override caller-supplied keys of the same name.

        NOTE: only text and image embeddings are stored; an 'audio' embedding
        returned by the embedder is not indexed.
        """
        embeddings = self.embedder.embed_document(document)
        base_metadata = metadata or {}

        if 'text' in embeddings:
            self.store.upsert(
                id=f"{doc_id}_text",
                vector=embeddings['text'],
                metadata={
                    **base_metadata,
                    'doc_id': doc_id,
                    'modality': 'text',
                    'content_preview': document['text'][:200]
                }
            )

        for i, img_emb in enumerate(embeddings.get('images', [])):
            self.store.upsert(
                id=f"{doc_id}_image_{i}",
                vector=img_emb,
                metadata={
                    **base_metadata,
                    'doc_id': doc_id,
                    'modality': 'image',
                    'image_index': i
                }
            )

    def search(
        self,
        query: str,
        modalities: list[str] = None,
        top_k: int = 10
    ) -> list[dict]:
        """Search across modalities and return at most one hit per document.

        Args:
            query: Free-text query.
            modalities: Subset of {'text', 'image'} to search. ``None`` means
                both; an explicit empty list searches nothing.
            top_k: Maximum number of documents returned (and hits requested
                per modality). Values <= 0 return an empty list.

        Returns:
            Store result dicts (each must carry 'score' and
            metadata['doc_id']), best score first, deduplicated by document.
        """
        if modalities is None:
            modalities = ['text', 'image']
        if top_k <= 0:
            return []
        all_results = []

        # Text vectors are matched with the text embedder's query embedding.
        if 'text' in modalities:
            text_query_emb = self.embedder.embed(query, 'text')
            all_results.extend(self.store.search(
                vector=text_query_emb,
                filter={'modality': 'text'},
                top_k=top_k
            ))

        # Image vectors are matched with CLIP's text embedding of the query.
        if 'image' in modalities:
            image_query_emb = self.embedder.embed(query, 'image_text_query')
            all_results.extend(self.store.search(
                vector=image_query_emb,
                filter={'modality': 'image'},
                top_k=top_k
            ))

        # NOTE(review): text and image hits come from two different embedding
        # models, so their raw scores may not be on a comparable scale;
        # consider per-modality normalization or rank fusion.
        all_results.sort(key=lambda x: x['score'], reverse=True)
        return self._dedupe_by_doc(all_results, top_k)

    def _dedupe_by_doc(
        self,
        results: list[dict],
        limit: int
    ) -> list[dict]:
        """Keep the first (highest-ranked) hit per doc_id, up to ``limit`` hits.

        A ``limit`` <= 0 yields an empty list.
        """
        if limit <= 0:
            return []
        seen_docs = set()
        deduped = []

        for result in results:
            doc_id = result['metadata']['doc_id']
            if doc_id in seen_docs:
                continue
            seen_docs.add(doc_id)
            deduped.append(result)
            if len(deduped) >= limit:
                break

        return deduped

Multimodal Document RAG

RAG system for documents containing text and images with cross-modal retrieval.

import base64

class MultimodalDocumentRAG:
    """RAG system for documents containing text and images.

    Text chunks are indexed with ``text_embedder``. Each image is indexed
    twice:
      * once through a VLM-generated description, embedded in the text space;
      * once through its CLIP image embedding, for visual similarity.

    NOTE(review): ``_extract_pages``, ``_chunk_text``, ``_merge_results``,
    ``_generate_answer`` and ``_load_image`` are called but not defined in
    this listing. They must be supplied, e.g. by a subclass.
    """

    # Leading bytes used to infer the media type of raw image data.
    _IMAGE_SIGNATURES = (
        (b"\x89PNG\r\n\x1a\n", "image/png"),
        (b"\xff\xd8\xff", "image/jpeg"),
        (b"GIF87a", "image/gif"),
        (b"GIF89a", "image/gif"),
    )
    # Fallback when no signature matches (a guess, not a detection).
    _DEFAULT_MEDIA_TYPE = "image/png"

    def __init__(
        self,
        text_embedder,
        clip_model,
        vlm_client,
        vector_store
    ):
        self.text_embedder = text_embedder
        self.clip = clip_model
        self.vlm = vlm_client
        self.store = vector_store

    @classmethod
    def _guess_media_type(cls, image_bytes: bytes) -> str:
        """Infer an image MIME type from the leading bytes of ``image_bytes``."""
        # WebP is a RIFF container: 'RIFF' <size:4> 'WEBP'.
        if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
            return "image/webp"
        for signature, media_type in cls._IMAGE_SIGNATURES:
            if image_bytes.startswith(signature):
                return media_type
        return cls._DEFAULT_MEDIA_TYPE

    @classmethod
    def _image_block(cls, image_b64) -> dict:
        """Build a base64 image content block, including its ``media_type``."""
        try:
            # 16 base64 characters decode to the first 12 bytes, which is
            # enough for every signature checked above.
            head = base64.b64decode(image_b64[:16])
        except (ValueError, TypeError):
            head = b""
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": cls._guess_media_type(head),
                "data": image_b64
            }
        }

    def index_document(self, doc_path: str, doc_id: str) -> dict:
        """Index a document's text chunks and images.

        Returns:
            {'doc_id': doc_id, 'chunks_indexed': n}, where n counts text
            chunks plus images. Each image contributes two vectors but counts
            once.
        """
        # Presumably each page is a dict with 'text' (str) and 'images'
        # (raw image bytes), judging by usage below; verify against _extract_pages.
        pages = self._extract_pages(doc_path)
        indexed_chunks = []

        for page_num, page in enumerate(pages):
            # Text chunks: ids stay unique because the counter spans the whole document.
            for chunk in self._chunk_text(page['text']):
                chunk_id = f"{doc_id}_p{page_num}_t{len(indexed_chunks)}"
                text_embedding = self.text_embedder.encode(chunk)
                self.store.add(
                    id=chunk_id,
                    vector=text_embedding,
                    metadata={
                        'doc_id': doc_id,
                        'page': page_num,
                        'modality': 'text',
                        'content': chunk
                    }
                )
                indexed_chunks.append(chunk_id)

            for img_idx, image in enumerate(page['images']):
                # Describe the image with the VLM, using the page's first
                # 500 characters as context.
                description = self._describe_image(image, page['text'][:500])

                # Both vectors below share this image_ref, so either hit can
                # be resolved back to the same image.
                img_id = f"{doc_id}_p{page_num}_i{img_idx}"

                # 1) Description vector, searchable with text queries.
                desc_embedding = self.text_embedder.encode(description)
                self.store.add(
                    id=f"{img_id}_desc",
                    vector=desc_embedding,
                    metadata={
                        'doc_id': doc_id,
                        'page': page_num,
                        'modality': 'image_description',
                        'description': description,
                        'image_ref': img_id
                    }
                )

                # 2) CLIP image vector, for visual similarity.
                img_embedding = self.clip.encode_image(image)
                self.store.add(
                    id=f"{img_id}_visual",
                    vector=img_embedding,
                    metadata={
                        'doc_id': doc_id,
                        'page': page_num,
                        'modality': 'image_visual',
                        'image_ref': img_id
                    }
                )

                indexed_chunks.append(img_id)

        return {'doc_id': doc_id, 'chunks_indexed': len(indexed_chunks)}

    def _describe_image(self, image_bytes: bytes, surrounding_text: str) -> str:
        """Ask the VLM for a searchable description of raw ``image_bytes``.

        Returns whatever ``self.vlm.generate`` returns (assumed to be the
        description text; not converted here).
        """
        image_block = {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": self._guess_media_type(image_bytes),
                "data": base64.b64encode(image_bytes).decode()
            }
        }
        response = self.vlm.generate(
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": f"""Describe this image for search indexing.
Context from surrounding document text: {surrounding_text}

Provide a detailed description including:
1. What the image shows (diagram, chart, photo, screenshot, etc.)
2. Key elements, labels, or data points
3. The main concept or information conveyed
4. Any text visible in the image"""},
                    image_block
                ]
            }]
        )
        return response

    def query(self, question: str, top_k: int = 5) -> dict:
        """Answer ``question`` using text, description and visual hits.

        Returns:
            {'answer': <generated answer>, 'sources': <merged results>}.
        """
        # Text chunks and image descriptions share the text embedding space.
        text_embedding = self.text_embedder.encode(question)
        text_results = self.store.search(
            vector=text_embedding,
            filter={'modality': {'$in': ['text', 'image_description']}},
            top_k=top_k
        )

        # Visual search uses CLIP's text encoder against CLIP image vectors.
        # Request at least one visual hit whenever any results are wanted,
        # so that top_k=1 does not silently disable visual search.
        visual_top_k = max(1, top_k // 2) if top_k > 0 else 0
        clip_text_embedding = self.clip.encode_text(question)
        visual_results = self.store.search(
            vector=clip_text_embedding,
            filter={'modality': 'image_visual'},
            top_k=visual_top_k
        )

        all_results = self._merge_results(text_results, visual_results)
        context = self._build_multimodal_context(all_results)
        answer = self._generate_answer(question, context)

        return {
            'answer': answer,
            'sources': all_results
        }

    def _build_multimodal_context(self, results: list[dict]) -> list:
        """Turn search results into VLM content blocks (text and images).

        Each referenced image is loaded and attached once, even if both its
        description and its visual vector matched. Its VLM description, when
        available among the results, is appended to the image label.
        """
        # Collect descriptions first, so they accompany the image regardless
        # of which of its two index entries appears first in ``results``.
        descriptions = {}
        for result in results:
            meta = result['metadata']
            if meta['modality'] == 'image_description' and meta.get('image_ref'):
                descriptions.setdefault(meta['image_ref'], meta.get('description'))

        context_parts = []
        included_images = set()

        for result in results:
            meta = result['metadata']
            modality = meta['modality']
            if modality == 'text':
                context_parts.append({
                    "type": "text",
                    "text": f"[Page {meta['page']}]: {meta['content']}"
                })
            elif modality in ('image_description', 'image_visual'):
                image_ref = meta.get('image_ref')
                if not image_ref or image_ref in included_images:
                    continue
                included_images.add(image_ref)
                # Used directly as base64 "data" below, so assumed to be a
                # base64 string; verify against _load_image.
                image_data = self._load_image(image_ref)
                label = f"[Image from page {meta['page']}]:"
                description = descriptions.get(image_ref)
                if description:
                    label = f"{label} {description}"
                context_parts.append({"type": "text", "text": label})
                context_parts.append(self._image_block(image_data))

        return context_parts

Multimodal Orchestration

Orchestrate multiple specialized models for processing different modalities in parallel.

class MultimodalOrchestrator:
    """Orchestrate multiple specialized models.

    Each modality present in the input ('image', 'audio', 'text') is handled
    by its own model, concurrently. If a fusion model is configured and more
    than one modality produced results, the per-modality results are merged
    into a single summary.
    """

    def __init__(
        self,
        vision_model,
        audio_model,
        text_model,
        fusion_model=None
    ):
        self.vision = vision_model
        self.audio = audio_model
        self.text = text_model
        self.fusion = fusion_model

    async def process_multimodal_input(
        self,
        inputs: dict
    ) -> dict:
        """Process multimodal input with parallel specialized models.

        Args:
            inputs: Dict with any of 'image', 'audio', 'text'.

        Returns:
            Dict keyed by 'vision' / 'audio' / 'text' for each processed
            modality, plus 'fused' when fusion ran. Exceptions raised by any
            model propagate (asyncio.gather default behavior).
        """
        import asyncio

        tasks = []

        if 'image' in inputs:
            tasks.append(('vision', self._process_image(inputs['image'])))

        if 'audio' in inputs:
            tasks.append(('audio', self._process_audio(inputs['audio'])))

        if 'text' in inputs:
            tasks.append(('text', self._process_text(inputs['text'])))

        # Run all modality pipelines concurrently.
        modality_results = await asyncio.gather(*[t[1] for t in tasks])
        results = {
            modality: result
            for (modality, _), result in zip(tasks, modality_results)
        }

        if self.fusion and len(results) > 1:
            results['fused'] = await self._fuse_results(results)

        return results

    async def _process_image(self, image: bytes) -> dict:
        """Describe, detect objects in, and OCR an image concurrently."""
        import asyncio

        # The three vision calls are independent, so run them together.
        description, objects, text = await asyncio.gather(
            self.vision.describe(image),
            self.vision.detect_objects(image),
            self.vision.extract_text(image),
        )

        return {
            'description': description,
            'objects': objects,
            'extracted_text': text
        }

    async def _process_audio(self, audio: bytes) -> dict:
        """Transcribe audio with the audio model."""
        transcription = await self.audio.transcribe(audio)
        return {'transcription': transcription}

    async def _process_text(self, text: str) -> dict:
        """Embed text and extract its entities concurrently."""
        import asyncio

        embedding, entities = await asyncio.gather(
            self.text.embed(text),
            self.text.extract_entities(text),
        )
        return {'embedding': embedding, 'entities': entities}

    async def _fuse_results(self, results: dict) -> dict:
        """Ask the fusion model to summarize results across modalities.

        Returns:
            {'summary': <fusion model output>}.
        """
        # Leave the raw embedding vector out of the prompt. It carries no
        # meaning for a language model and can add thousands of tokens.
        # The embedding is still returned to callers in results['text'].
        text_analysis = {
            key: value
            for key, value in results.get('text', {}).items()
            if key != 'embedding'
        }
        fusion_prompt = f"""Combine these analyses into a unified understanding:

Vision analysis: {results.get('vision', {})}
Audio analysis: {results.get('audio', {})}
Text analysis: {text_analysis}

Provide a coherent summary that integrates information across modalities."""

        fused = await self.fusion.generate(fusion_prompt)
        return {'summary': fused}

Cost-Aware Processing

Process multimodal content with budget awareness and optimization.

class CostAwareMultimodalProcessor:
    """Process multimodal content with cost awareness.

    Inputs are estimated before processing. Over-budget inputs are shrunk by
    downscaling images and truncating text, and then processed.

    NOTE(review): ``_estimate_image_tokens`` and ``_reduce_image_resolution``
    are called but not defined in this listing. They must be supplied, e.g.
    by a subclass.
    """

    # Approximate example rates. Despite the name, units differ per key.
    COST_PER_1K = {
        'text_input': 0.003,    # per 1K input text tokens
        'text_output': 0.015,   # per 1K output text tokens (not used by estimate_cost)
        'image_input': 0.005,   # per 1K image tokens
        'audio_minute': 0.006,  # per minute of audio (not per 1K tokens)
    }

    def __init__(self, processor, budget_per_request: float = 0.10):
        self.processor = processor
        self.budget = budget_per_request

    def estimate_cost(self, inputs: dict) -> float:
        """Estimate the input-side processing cost of ``inputs``.

        Args:
            inputs: Dict with optional 'text' (str), 'images' (iterable), and
                'audio'. When 'audio' is present, 'audio_duration' (minutes,
                default 1) sets its length.

        Returns:
            Estimated cost in the same currency units as ``COST_PER_1K``.
            Output tokens are not included.
        """
        cost = 0.0

        if 'text' in inputs:
            # Rough heuristic: ~1.3 tokens per whitespace-separated word.
            tokens = len(inputs['text'].split()) * 1.3
            cost += (tokens / 1000) * self.COST_PER_1K['text_input']

        if 'images' in inputs:
            for img in inputs['images']:
                img_tokens = self._estimate_image_tokens(img)
                cost += (img_tokens / 1000) * self.COST_PER_1K['image_input']

        if 'audio' in inputs:
            duration_minutes = inputs.get('audio_duration', 1)
            cost += duration_minutes * self.COST_PER_1K['audio_minute']

        return cost

    def process_within_budget(self, inputs: dict) -> dict:
        """Process ``inputs``, shrinking them first if over budget."""
        estimated_cost = self.estimate_cost(inputs)

        if estimated_cost <= self.budget:
            return self.processor.process(inputs)

        optimized_inputs = self._optimize_inputs(inputs, self.budget)
        return self.processor.process(optimized_inputs)

    def _optimize_inputs(
        self,
        inputs: dict,
        target_budget: float,
        max_passes: int = 3
    ) -> dict:
        """Shrink inputs until the estimate fits ``target_budget``.

        Each pass halves image resolution and caps text at 4000 characters.
        Passes repeat while the estimate is still above ``target_budget`` and
        there are images left to shrink, up to ``max_passes`` (minimum one).
        The result may still exceed the budget if reductions run out.
        ``inputs`` itself is not modified.
        """
        optimized = inputs.copy()

        for _ in range(max(1, max_passes)):
            if 'images' in optimized:
                optimized['images'] = [
                    self._reduce_image_resolution(img, factor=0.5)
                    for img in optimized['images']
                ]

            if 'text' in optimized and len(optimized['text']) > 4000:
                optimized['text'] = optimized['text'][:4000]

            # Stop once within budget. Also stop when there are no images:
            # text is already capped, so another pass cannot reduce cost.
            if not optimized.get('images') or self.estimate_cost(optimized) <= target_budget:
                break

        return optimized

Usage Notes

Selecting the Right Components

| Use case | Recommended components |
|---|---|
| Simple image Q&A | VisionProcessor |
| Document OCR | DocumentProcessor + ImageOptimizer |
| Video summarization | VideoProcessor |
| Podcast transcription | AudioTranscriber + SpeakerDiarizer |
| Multimodal search | MultimodalEmbedder + MultimodalIndex |
| Document Q&A with figures | MultimodalDocumentRAG |

Performance Considerations

  1. Batch processing: Process multiple items in parallel when possible
  2. Resolution optimization: Match image resolution to task requirements
  3. Chunking: Split long audio/video into manageable segments
  4. Caching: Cache embeddings and intermediate results
  5. Cost monitoring: Track token usage across modalities

Error Handling

All implementations should include:

- Input validation (format, size limits)
- Retry logic for transient API failures
- Graceful degradation when optional components fail
- Logging for debugging and monitoring