RAG Systems - Complete Code Reference
This file contains the complete code implementations for Chapter 7: RAG Systems Deep Dive. The chapter contains condensed examples with links to the full implementations here.
Table of Contents
- Document Extraction
- Chunking Strategies
- Vector Search with FAISS
- Hybrid Search
- Reranking
- Context Assembly and Generation
- Vector Database Operations
- Advanced RAG Patterns
- Evaluation and Debugging
- Production Components
- Complete RAG System
Document Extraction
Complete implementation for extracting text from various document formats.
import fitz # PyMuPDF for PDFs
from docx import Document
from bs4 import BeautifulSoup
import markdown
def extract_text(file_path: str) -> tuple[str, dict]:
    """Extract text and metadata from a document, dispatching on its extension.

    Args:
        file_path: Path of the document to read.

    Returns:
        A ``(text, metadata)`` tuple. Unknown extensions are read as UTF-8
        plain text with metadata ``{'format': 'plain'}``.

    Note:
        ``extract_docx`` and ``extract_markdown`` are not defined in this part
        of the file; they are expected to be provided elsewhere.
    """
    import os

    # splitext only inspects the final path component, so an extension-less
    # file that happens to be named "pdf" or "md" is treated as plain text
    # instead of being routed to a format-specific extractor.
    extension = os.path.splitext(file_path)[1].lstrip('.').lower()

    if extension == 'pdf':
        return extract_pdf(file_path)
    if extension == 'docx':
        return extract_docx(file_path)
    if extension in ('html', 'htm'):
        return extract_html(file_path)
    if extension == 'md':
        return extract_markdown(file_path)

    # Explicit encoding keeps the result independent of the platform locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read(), {'format': 'plain'}
def extract_pdf(file_path: str) -> tuple[str, dict]:
    """Extract text from a PDF, one ``[Page N]``-labelled section per page.

    Args:
        file_path: Path of the PDF file.

    Returns:
        ``(text, metadata)`` where ``text`` joins the page sections with blank
        lines and ``metadata`` holds ``format``, ``pages``, ``title`` and
        ``author``.
    """
    doc = fitz.open(file_path)
    try:
        # Missing metadata (or missing individual fields) falls back to ''.
        info = doc.metadata or {}
        metadata = {
            'format': 'pdf',
            'pages': len(doc),
            'title': info.get('title') or '',
            'author': info.get('author') or '',
        }
        text_parts = []
        for page_num, page in enumerate(doc, start=1):
            # Plain-text extraction of the page
            page_text = page.get_text("text")
            text_parts.append(f"[Page {page_num}]\n{page_text}")
    finally:
        # Release the file handle even if a page fails to extract.
        doc.close()
    return '\n\n'.join(text_parts), metadata
def extract_html(file_path: str) -> tuple[str, dict]:
    """Extract the main text of an HTML file, removing boilerplate elements.

    Args:
        file_path: Path of the HTML file.

    Returns:
        ``(text, metadata)`` where ``metadata`` holds ``format`` and ``title``
        (``''`` when the document has no usable ``<title>``).
    """
    # Pass raw bytes so the parser decodes the document itself rather than
    # relying on the platform's default text encoding.
    with open(file_path, 'rb') as f:
        soup = BeautifulSoup(f.read(), 'html.parser')

    # Remove navigation, scripts, styles and other page chrome
    for element in soup(['nav', 'script', 'style', 'header', 'footer', 'aside']):
        element.decompose()

    # Prefer the most specific content container; fall back to the whole
    # document for fragments that have no <body> at all.
    main = soup.find('main') or soup.find('article') or soup.find('body') or soup
    text = main.get_text(separator='\n', strip=True)

    title = soup.title.string if soup.title else ''
    metadata = {
        'format': 'html',
        'title': title or '',
    }
    return text, metadata


# Chunking Strategies
Fixed-Size Chunking
def fixed_size_chunk(text: str, chunk_size: int = 512, overlap: int = 50) -> list[str]:
    """Split text into fixed-size token windows that overlap by ``overlap`` tokens.

    Uses the module-level ``tokenizer`` (``encode``/``decode``).

    Args:
        text: Text to split.
        chunk_size: Maximum number of tokens per chunk; must be positive.
        overlap: Tokens shared between consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        The decoded chunks in document order (empty for empty input).

    Raises:
        ValueError: If the sizes would make the window fail to advance.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        # overlap >= chunk_size would never advance the window.
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    tokens = tokenizer.encode(text)
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(tokens), step):
        end = start + chunk_size
        chunks.append(tokenizer.decode(tokens[start:end]))
        if end >= len(tokens):
            # This window already reached the last token; another step would
            # only re-emit the overlap region as a redundant chunk.
            break
    return chunks


# Semantic Chunking
def semantic_chunk(text: str, max_chunk_size: int = 512) -> list[str]:
    """Group blank-line-separated paragraphs into chunks of bounded token size.

    Paragraphs are packed greedily. A paragraph that is larger than
    ``max_chunk_size`` on its own is split with ``fixed_size_chunk``.
    """
    chunks = []
    buffered = []       # paragraphs collected for the chunk being built
    buffered_tokens = 0

    for para in text.split('\n\n'):
        para_tokens = len(tokenizer.encode(para))

        # Close the running chunk when this paragraph would overflow it.
        if buffered and buffered_tokens + para_tokens > max_chunk_size:
            chunks.append('\n\n'.join(buffered))
            buffered, buffered_tokens = [], 0

        if para_tokens > max_chunk_size:
            # An oversized paragraph always overflows, so the buffer was
            # flushed just above and is empty here.
            chunks.extend(fixed_size_chunk(para, max_chunk_size))
            continue

        buffered.append(para)
        buffered_tokens += para_tokens

    if buffered:
        chunks.append('\n\n'.join(buffered))
    return chunks


# Recursive Chunking
def recursive_chunk(
    text: str,
    chunk_size: int = 512,
    separators: "list[str] | None" = None
) -> list[str]:
    """Recursively split text using progressively finer separators.

    Args:
        text: Text to split.
        chunk_size: Maximum number of tokens per chunk.
        separators: Separators to try, coarsest first. Defaults to paragraph
            breaks, line breaks, sentence ends, then spaces.

    Returns:
        Chunks that each fit in ``chunk_size`` tokens, measured with the
        module-level ``tokenizer``. Text that no separator can reduce is split
        by ``fixed_size_chunk``.
    """
    if separators is None:
        # Built per call so no mutable default is shared between callers.
        separators = ['\n\n', '\n', '. ', ' ']

    def split_text(segment: str, sep_idx: int) -> list[str]:
        if len(tokenizer.encode(segment)) <= chunk_size:
            return [segment]
        if sep_idx >= len(separators):
            # Out of separators: fall back to fixed-size token windows
            return fixed_size_chunk(segment, chunk_size)

        sep = separators[sep_idx]
        parts = segment.split(sep)
        if len(parts) == 1:
            # This separator does not occur, try the next one
            return split_text(segment, sep_idx + 1)

        # Greedily merge adjacent parts while the merged text still fits.
        merged = []
        current = []
        for part in parts:
            candidate = sep.join(current + [part])
            if len(tokenizer.encode(candidate)) <= chunk_size:
                current.append(part)
            else:
                if current:
                    merged.append(sep.join(current))
                current = [part]
        if current:
            merged.append(sep.join(current))

        # A single part can still be too large; refine it with finer separators.
        result = []
        for piece in merged:
            result.extend(split_text(piece, sep_idx + 1))
        return result

    return split_text(text, 0)


# Sentence-Window Chunking
class SentenceWindowChunker:
    """Index single sentences but hand back each one with its neighbours."""

    def __init__(self, window_size: int = 3):
        # Number of sentences kept on each side of the embedded sentence
        self.window_size = window_size

    def create_chunks(self, text: str) -> list[dict]:
        """Build one record per sentence of ``text``.

        Each record holds the sentence itself (``embed_text``), the sentence
        joined with up to ``window_size`` neighbours on either side
        (``context``), and its position (``sentence_idx``).
        """
        sentences = sent_tokenize(text)
        total = len(sentences)
        radius = self.window_size

        def window(position: int) -> str:
            lo = max(0, position - radius)
            hi = min(total, position + radius + 1)
            return ' '.join(sentences[lo:hi])

        return [
            {
                'embed_text': sentence,        # Used for embedding
                'context': window(position),   # Retrieved text
                'sentence_idx': position,
            }
            for position, sentence in enumerate(sentences)
        ]


# Document-Structure-Aware Chunking
def structure_aware_chunk(text: str, max_chunk_size: int = 512) -> list[dict]:
    """Chunk text along its markdown-style header structure.

    Args:
        text: Document text; lines of the form ``#`` to ``######`` plus a
            title start a new section.
        max_chunk_size: Maximum tokens per chunk, measured with the
            module-level ``tokenizer``. Larger sections are split with
            ``recursive_chunk``.

    Returns:
        One dict per chunk with ``text``, ``headers`` (titles of the enclosing
        headers, outermost first) and ``header_context`` (those titles joined
        with ``' > '``).
    """
    import re

    header_re = re.compile(r'^(#{1,6})\s+(.+)$')

    sections = []
    # (level, title) of every header enclosing the current section. Keeping the
    # level lets a header replace siblings and deeper headers even when the
    # document skips levels (e.g. "#" followed directly by "###").
    header_stack = []
    current_section = {'headers': [], 'content': []}

    for line in text.split('\n'):
        header_match = header_re.match(line)
        if header_match is None:
            current_section['content'].append(line)
            continue

        # A header closes the section collected so far
        if current_section['content']:
            sections.append(current_section)

        level = len(header_match.group(1))
        title = header_match.group(2)
        while header_stack and header_stack[-1][0] >= level:
            header_stack.pop()
        header_stack.append((level, title))
        current_section = {
            'headers': [name for _, name in header_stack],
            'content': [],
        }

    if current_section['content']:
        sections.append(current_section)

    # Convert sections to chunks
    chunks = []
    for section in sections:
        headers = section['headers']
        header_context = ' > '.join(headers)
        content = '\n'.join(section['content'])

        if len(tokenizer.encode(content)) <= max_chunk_size:
            pieces = [content]
        else:
            # Split large section
            pieces = recursive_chunk(content, max_chunk_size)

        for piece in pieces:
            chunks.append({
                'text': piece,
                'headers': headers,
                'header_context': header_context,
            })
    return chunks


# Chunk Metadata
def create_indexed_chunk(text: str, doc_metadata: dict, chunk_idx: int) -> dict:
    """Bundle a chunk's text, embedding and metadata into one index record.

    ``doc_metadata`` must contain ``'file_path'``; every other field is
    optional and defaults to ``None`` (``''`` for the title, ``'public'`` for
    the access level).
    """
    optional = doc_metadata.get

    # The embedding is computed here, before any metadata is read.
    record = {'text': text, 'embedding': embed(text)}

    # Metadata for citations
    record['source_file'] = doc_metadata['file_path']
    record['source_title'] = optional('title', '')
    record['page_number'] = optional('page')
    record['section'] = optional('section')
    record['chunk_index'] = chunk_idx
    record['created_at'] = datetime.now().isoformat()

    # Metadata for filtering
    record['document_type'] = optional('type')
    record['department'] = optional('department')
    record['access_level'] = optional('access_level', 'public')
    return record


# Vector Search with FAISS
import faiss
# Exact search (brute force)
def create_exact_index(embeddings: np.ndarray) -> faiss.Index:
    """Build a flat inner-product index containing every row of ``embeddings``."""
    exact = faiss.IndexFlatIP(embeddings.shape[1])  # inner product (dot product)
    exact.add(embeddings)
    return exact
# HNSW (fast, accurate, memory-heavy)
def create_hnsw_index(embeddings: np.ndarray, M: int = 32) -> faiss.Index:
    """Build an HNSW graph index over ``embeddings``.

    Args:
        embeddings: 2-D array with one vector per row.
        M: Number of graph links per node passed to ``IndexHNSWFlat``.

    Returns:
        The populated index. It scores by inner product so its results are
        comparable with the other index builders in this file, which all use
        inner product.
    """
    dim = embeddings.shape[1]
    # Without an explicit metric the index would default to L2 distance, where
    # smaller is better - the opposite of the inner-product scores used elsewhere.
    index = faiss.IndexHNSWFlat(dim, M, faiss.METRIC_INNER_PRODUCT)
    index.add(embeddings)
    return index
# IVF + PQ (scalable to billions)
def create_ivf_pq_index(
    embeddings: np.ndarray,
    nlist: int = 1024,  # Number of clusters
    m: int = 32,        # Number of subquantizers
    bits: int = 8       # Bits per subquantizer
) -> faiss.Index:
    """Build a trained IVF + product-quantization index over ``embeddings``.

    Args:
        embeddings: 2-D array with one vector per row; also used as the
            training set.
        nlist: Number of coarse clusters.
        m: Number of subquantizers; must divide the embedding dimension.
        bits: Bits per subquantizer code.

    Returns:
        The trained, populated index, scoring by inner product.

    Raises:
        ValueError: If the dimension is not divisible by ``m`` or there are
            fewer vectors than clusters to train.
    """
    count, dim = embeddings.shape[0], embeddings.shape[1]
    if dim % m != 0:
        raise ValueError(f"embedding dimension {dim} is not divisible by m={m}")
    if count < nlist:
        raise ValueError(f"need at least nlist={nlist} vectors to train, got {count}")

    quantizer = faiss.IndexFlatIP(dim)
    # Keep the fine-grained metric aligned with the inner-product quantizer.
    index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, bits, faiss.METRIC_INNER_PRODUCT)
    index.train(embeddings)  # Required for IVF+PQ
    index.add(embeddings)
    return index
# Search
def search(index: faiss.Index, query_embedding: np.ndarray, k: int = 10):
    """Query ``index`` with a single vector and return ``(distances, indices)``.

    The query is reshaped to a one-row batch; the first (only) row of each
    result array is returned.
    """
    single_row_batch = query_embedding.reshape(1, -1)
    all_distances, all_indices = index.search(single_row_batch, k)
    return all_distances[0], all_indices[0]
# Batch embedding for efficiency
def embed_documents_batched(texts: list[str], batch_size: int = 64) -> np.ndarray:
    """Embed documents in batches using the module-level ``model``.

    Args:
        texts: Documents to embed.
        batch_size: Number of documents per ``model.encode`` call; must be
            positive.

    Returns:
        The batch embeddings stacked row-wise. For empty ``texts`` an empty
        ``(0, 0)`` float32 array is returned without calling the model.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if not texts:
        # np.vstack rejects an empty sequence, so short-circuit here.
        return np.empty((0, 0), dtype=np.float32)

    embeddings = [
        model.encode(
            texts[i:i + batch_size],
            normalize_embeddings=True,
            show_progress_bar=False
        )
        for i in range(0, len(texts), batch_size)
    ]
    return np.vstack(embeddings)


# Hybrid Search
BM25 Index
from rank_bm25 import BM25Okapi
class BM25Index:
    """Keyword index over a document list, scored with ``BM25Okapi``."""

    def __init__(self, documents: list[str]):
        # Tokenize documents (lowercased, whitespace-separated)
        self.tokenized = [self._tokenize(doc) for doc in documents]
        # An empty corpus has nothing to rank, so no scorer is built for it.
        self.bm25 = BM25Okapi(self.tokenized) if self.tokenized else None
        self.documents = documents

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Tokenize documents and queries the same way."""
        return text.lower().split()

    def search(self, query: str, k: int = 10) -> list[tuple[int, float]]:
        """Return up to ``k`` ``(doc_index, score)`` pairs, best score first.

        Indices and scores are converted to built-in ``int``/``float`` so
        results match the annotated return type and can be serialised.
        """
        if self.bm25 is None or k <= 0:
            return []
        scores = self.bm25.get_scores(self._tokenize(query))
        # Get top k
        top_k = np.argsort(scores)[::-1][:k]
        return [(int(idx), float(scores[idx])) for idx in top_k]


# Score Fusion
def score_fusion(
    vector_results: list[tuple[int, float]],  # (doc_id, similarity)
    bm25_results: list[tuple[int, float]],    # (doc_id, score)
    vector_weight: float = 0.7
) -> list[tuple[int, float]]:
    """Combine two result lists using weighted score fusion.

    Each list is min-max normalised to [0, 1]; a document missing from a list
    contributes 0 for that list. The combined score is
    ``vector_weight * vector + (1 - vector_weight) * bm25``.

    Args:
        vector_results: ``(doc_id, similarity)`` pairs from vector search.
        bm25_results: ``(doc_id, score)`` pairs from keyword search.
        vector_weight: Weight of the vector score.

    Returns:
        ``(doc_id, combined_score)`` pairs sorted by descending score.
    """
    def normalize(results):
        if not results:
            return {}
        scores = [s for _, s in results]
        min_s, max_s = min(scores), max(scores)
        if max_s == min_s:
            # A single hit (or an all-way tie) has no spread to scale by.
            # Treat every retrieved document as a full match so it still
            # outranks documents that this retriever did not return at all.
            return {idx: 1.0 for idx, _ in results}
        range_s = max_s - min_s
        return {idx: (s - min_s) / range_s for idx, s in results}

    vector_scores = normalize(vector_results)
    bm25_scores = normalize(bm25_results)

    # Combine over every document seen by either retriever
    combined = {
        doc_id: vector_weight * vector_scores.get(doc_id, 0)
        + (1 - vector_weight) * bm25_scores.get(doc_id, 0)
        for doc_id in set(vector_scores) | set(bm25_scores)
    }

    # Sort by combined score
    return sorted(combined.items(), key=lambda x: x[1], reverse=True)


# Reciprocal Rank Fusion (RRF)
def reciprocal_rank_fusion(
    result_lists: list[list[int]],  # Multiple ranked lists of doc_ids
    k: int = 60
) -> list[tuple[int, float]]:
    """
    Merge several ranked doc-id lists with Reciprocal Rank Fusion.

    A document's score is the sum of 1 / (k + rank) over every list it
    appears in, with ranks starting at 1. k is the smoothing constant.
    Returns (doc_id, score) pairs, highest score first.
    """
    totals = {}
    for ranking in result_lists:
        position = 0
        for doc_id in ranking:
            position += 1
            totals[doc_id] = totals.get(doc_id, 0.0) + 1 / (k + position)

    fused = list(totals.items())
    fused.sort(key=lambda pair: pair[1], reverse=True)
    return fused


# Complete Hybrid Retriever
class HybridRetriever:
    """Retriever that fuses dense (FAISS inner-product) and BM25 keyword search."""

    def __init__(
        self,
        documents: list[str],
        embedding_model: "SentenceTransformer",
        vector_weight: float = 0.7
    ):
        """Embed ``documents`` and build both indexes.

        Args:
            documents: Texts to index; must not be empty.
            embedding_model: Model whose ``encode`` produces the embeddings.
            vector_weight: Weight of the vector score in score fusion.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        if not documents:
            raise ValueError("documents must not be empty")

        self.documents = documents
        self.embedding_model = embedding_model
        self.vector_weight = vector_weight

        # Build vector index
        print("Building vector index...")
        self.embeddings = embedding_model.encode(
            [f"passage: {doc}" for doc in documents],
            normalize_embeddings=True
        )
        self.vector_index = faiss.IndexFlatIP(self.embeddings.shape[1])
        self.vector_index.add(self.embeddings)

        # Build BM25 index
        print("Building BM25 index...")
        self.bm25_index = BM25Index(documents)

    def search(self, query: str, k: int = 10, method: str = "rrf") -> list[dict]:
        """Return the top ``k`` documents for ``query``.

        Args:
            query: Search query.
            k: Number of results to return.
            method: ``"rrf"`` for reciprocal rank fusion; any other value
                selects weighted score fusion using ``self.vector_weight``.

        Returns:
            Dicts with ``doc_id``, ``score`` and ``text``, best first.
        """
        if k <= 0:
            return []

        # Over-fetch for fusion, but never ask for more hits than there are
        # documents: FAISS pads missing hits with id -1, which would otherwise
        # be used as a list index and silently pick the last document.
        n_candidates = min(k * 2, len(self.documents))

        # Vector search
        query_embedding = self.embedding_model.encode(
            f"query: {query}",
            normalize_embeddings=True
        ).reshape(1, -1)
        distances, vector_indices = self.vector_index.search(query_embedding, n_candidates)
        vector_results = [
            (int(idx), float(dist))
            for idx, dist in zip(vector_indices[0].tolist(), distances[0].tolist())
            if idx >= 0
        ]

        # BM25 search (ids cast to built-in int so both lists share key types)
        bm25_results = [
            (int(idx), float(score))
            for idx, score in self.bm25_index.search(query, n_candidates)
        ]

        # Fuse results
        if method == "rrf":
            vector_ranked = [idx for idx, _ in vector_results]
            bm25_ranked = [idx for idx, _ in bm25_results]
            fused = reciprocal_rank_fusion([vector_ranked, bm25_ranked])
        else:  # score fusion
            fused = score_fusion(vector_results, bm25_results, self.vector_weight)

        # Return top k with documents
        return [
            {'doc_id': doc_id, 'score': score, 'text': self.documents[doc_id]}
            for doc_id, score in fused[:k]
        ]


# Hybrid Weight Tuning
def tune_hybrid_weights(
    retriever: "HybridRetriever",
    eval_queries: list[str],
    ground_truth: list[list[int]],  # Relevant doc IDs for each query
    weight_range: "list[float] | None" = None
) -> float:
    """Find the vector weight with the best average recall@10 via grid search.

    Args:
        retriever: Retriever whose ``vector_weight`` is varied. Its original
            value is restored before returning.
        eval_queries: Evaluation queries.
        ground_truth: Relevant doc IDs per query. Queries without relevant
            documents are skipped because recall is undefined for them.
        weight_range: Candidate weights; defaults to ``[0.3, 0.5, 0.7, 0.9]``.

    Returns:
        The first weight achieving the highest average recall, or 0.5 when
        nothing can be evaluated or no weight retrieves a relevant document.
    """
    if weight_range is None:
        weight_range = [0.3, 0.5, 0.7, 0.9]

    best_weight = 0.5
    best_score = 0

    cases = [
        (query, set(relevant_ids))
        for query, relevant_ids in zip(eval_queries, ground_truth)
        if relevant_ids
    ]
    if not cases:
        return best_weight

    original_weight = retriever.vector_weight
    try:
        for weight in weight_range:
            retriever.vector_weight = weight
            total_recall = 0
            for query, relevant in cases:
                # The weight only influences weighted score fusion; the
                # default "rrf" method ignores it, which would make every
                # candidate weight score identically.
                results = retriever.search(query, k=10, method="score")
                retrieved_ids = {r['doc_id'] for r in results}
                total_recall += len(retrieved_ids & relevant) / len(relevant)

            avg_recall = total_recall / len(cases)
            if avg_recall > best_score:
                best_score = avg_recall
                best_weight = weight
    finally:
        # Do not leave the retriever at whichever weight was tried last.
        retriever.vector_weight = original_weight
    return best_weight


# Reranking
Cross-Encoder Reranker
from sentence_transformers import CrossEncoder
class Reranker:
    """Cross-encoder reranker scoring (query, document) pairs."""

    def __init__(self, model_name: str = "BAAI/bge-reranker-large"):
        self.model = CrossEncoder(model_name)

    def rerank(
        self,
        query: str,
        documents: list[str],
        top_k: int = 5
    ) -> list[tuple[int, float, str]]:
        """Rerank documents by relevance to query.

        Returns:
            Up to ``top_k`` ``(original_index, score, document)`` tuples,
            highest score first; empty when there are no documents.
        """
        if not documents:
            # Nothing to score; do not hand the model an empty batch.
            return []

        # Create query-document pairs
        pairs = [[query, doc] for doc in documents]
        # Score all pairs
        scores = self.model.predict(pairs)
        # Sort by score
        scored_docs = sorted(
            zip(range(len(documents)), scores, documents),
            key=lambda item: item[1],
            reverse=True,
        )
        return scored_docs[:top_k]

    def batch_rerank(self, query: str, documents: list[str], batch_size: int = 32):
        """Score every document against ``query`` in batches of ``batch_size``.

        Returns:
            The raw scores in document order (not sorted).

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        pairs = [[query, doc] for doc in documents]
        scores = []
        for i in range(0, len(pairs), batch_size):
            scores.extend(self.model.predict(pairs[i:i + batch_size]))
        return scores


# RAG Pipeline with Reranking
class RAGPipeline:
    """Retrieve candidates, rerank them, then answer from the best few."""

    def __init__(
        self,
        retriever: HybridRetriever,
        reranker: Reranker,
        llm_client
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.llm = llm_client

    def query(self, question: str, retrieve_k: int = 50, rerank_k: int = 5) -> str:
        """Answer ``question`` from the ``rerank_k`` best of ``retrieve_k`` candidates."""
        # Stage 1: broad candidate retrieval
        hits = self.retriever.search(question, k=retrieve_k)
        candidate_texts = [hit['text'] for hit in hits]

        # Stage 2: rerank the candidate texts
        best = self.reranker.rerank(question, candidate_texts, top_k=rerank_k)

        # Stage 3: generate from the reranked passages
        passages = [passage for _, _, passage in best]
        context = "\n\n---\n\n".join(passages)
        prompt = (
            "Based on the following context, answer the question.\n"
            "Context:\n"
            f"{context}\n"
            f"Question: {question}\n"
            "Answer:"
        )
        return self.llm.generate(prompt)


# LLM-based Reranking
def llm_rerank(
    query: str,
    documents: list[str],
    llm_client,
    top_k: int = 5
) -> list[tuple[int, str]]:
    """Rerank documents by asking an LLM for a 0-10 relevance rating.

    Args:
        query: The search query.
        documents: Candidate documents; only the first 1000 characters of
            each are shown to the LLM.
        llm_client: Object with ``generate(prompt, max_tokens=...)``.
        top_k: Number of documents to return.

    Returns:
        Up to ``top_k`` ``(original_index, document)`` pairs, most relevant
        first. A reply without a number scores 0.
    """
    import re

    number_re = re.compile(r'-?\d+(?:\.\d+)?')

    scores = []
    for i, doc in enumerate(documents):
        snippet = doc[:1000]  # Truncate for efficiency
        prompt = (
            "Rate the relevance of this document to the query on a scale of 0-10.\n"
            "Only output the number, nothing else.\n"
            f"Query: {query}\n"
            f"Document: {snippet}\n"
            "Relevance (0-10):"
        )
        response = llm_client.generate(prompt, max_tokens=5)

        # Take the first number in the reply so answers such as "8/10" or
        # "Score: 7" still count, then clamp to the requested scale.
        found = number_re.search(str(response))
        score = float(found.group()) if found else 0.0
        score = min(max(score, 0.0), 10.0)
        scores.append((i, score, doc))

    scores.sort(key=lambda x: x[1], reverse=True)
    return [(idx, doc) for idx, _, doc in scores[:top_k]]


# Context Assembly and Generation
Context Window Management
def assemble_context(
    documents: list[dict],
    max_context_tokens: int = 8000,
    tokenizer=None,
    min_partial_tokens: int = 100
) -> str:
    """Assemble documents into one context string within a token budget.

    Args:
        documents: Dicts with a ``'text'`` entry, most relevant first.
        max_context_tokens: Token budget for the document texts.
        tokenizer: Object with ``encode``/``decode``. Required; it has a
            placeholder default only because it follows a defaulted parameter.
        min_partial_tokens: A document that does not fit is included
            truncated only if more than this many tokens remain.

    Returns:
        ``[Document N]`` sections joined by ``---`` separators. Assembly
        stops at the first document that does not fit completely.

    Raises:
        TypeError: If ``tokenizer`` is not supplied.
    """
    if tokenizer is None:
        raise TypeError("assemble_context() missing required argument: 'tokenizer'")

    assembled = []
    total_tokens = 0
    for i, doc in enumerate(documents):
        doc_text = doc['text']
        token_ids = tokenizer.encode(doc_text)  # encoded once, reused for truncation

        if total_tokens + len(token_ids) > max_context_tokens:
            # Truncate this document to fit
            remaining = max_context_tokens - total_tokens
            if remaining > min_partial_tokens:  # Worth including partial
                truncated = tokenizer.decode(token_ids[:remaining])
                assembled.append(f"[Document {i+1}]\n{truncated}\n[truncated]")
            break

        assembled.append(f"[Document {i+1}]\n{doc_text}")
        total_tokens += len(token_ids)

    return "\n\n---\n\n".join(assembled)


# Citation and Attribution
def generate_with_citations(
    question: str,
    documents: list[dict],
    llm_client
) -> dict:
    """Generate an answer with inline ``[n]`` citations.

    Args:
        question: The user question.
        documents: Dicts with ``'text'`` and optionally ``'source'``,
            ``'title'`` and ``'url'``. They are not modified.
        llm_client: Object with ``generate(prompt)``.

    Returns:
        ``{'answer': str, 'citations': list[dict]}`` where ``citations`` lists,
        in document order, every provided document the answer cited.
    """
    import re

    # Number documents for reference; numbering is kept local so the caller's
    # dicts are left untouched.
    numbered = list(enumerate(documents, start=1))
    context = "\n\n".join(f"[{number}] {doc['text']}" for number, doc in numbered)

    prompt = (
        "Answer the question based on the provided documents.\n"
        "Include citations in brackets [1], [2], etc. for any facts from the documents.\n"
        "If the documents don't contain relevant information, say so.\n"
        "Documents:\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer with citations:"
    )
    response = llm_client.generate(prompt)

    # Extract which documents were cited; accepts both "[1]" and grouped
    # forms such as "[1, 3]".
    cited_ids = set()
    for group in re.findall(r'\[(\d+(?:\s*,\s*\d+)*)\]', response):
        cited_ids.update(int(part) for part in group.split(','))

    # Build citation list
    citations = [
        {
            'id': number,
            'source': doc.get('source', 'Unknown'),
            'title': doc.get('title', ''),
            'url': doc.get('url', ''),
        }
        for number, doc in numbered
        if number in cited_ids
    ]
    return {
        'answer': response,
        'citations': citations
    }


# Handling No Results with Fallback
def query_with_fallback(
    question: str,
    retriever: HybridRetriever,
    reranker: Reranker,
    llm_client,
    relevance_threshold: float = 0.3
) -> str:
    """Answer ``question``, or explain why no confident answer is possible.

    Returns a fixed apology when retrieval yields nothing, a hedged reply
    quoting the best document when its rerank score is below
    ``relevance_threshold``, and a generated answer otherwise.
    """
    hits = retriever.search(question, k=20)
    if not hits:
        return "I couldn't find any relevant information in the knowledge base."

    texts = [hit['text'] for hit in hits]
    reranked = reranker.rerank(question, texts, top_k=5)

    # The best reranked entry decides whether the results are usable at all
    _, top_score, top_text = reranked[0]
    if not top_score < relevance_threshold:
        # Proceed with generation
        context = "\n\n".join(text for _, _, text in reranked)
        return generate_answer(question, context, llm_client)

    return (
        "I found some information that might be related, but I'm not confident it answers your question.\n"
        f"The most relevant document mentions: {top_text[:200]}...\n"
        "Could you rephrase your question or provide more context?"
    )


# Vector Database Operations
Qdrant Operations
from qdrant_client import QdrantClient
from qdrant_client.http import models
client = QdrantClient("localhost", port=6333)

# Create the collection
client.create_collection(
    collection_name="documents",
    vectors_config=models.VectorParams(
        distance=models.Distance.COSINE,
        size=1024,  # Embedding dimension
    ),
)

# Insert vectors with metadata: one point per chunk, numbered in input order
client.upsert(
    collection_name="documents",
    points=[
        models.PointStruct(
            id=i,
            vector=embedding.tolist(),
            payload={
                "text": chunk_text,
                "source": source_file,
                "page": page_num,
                "department": department,
            },
        )
        for i, (embedding, chunk_text, source_file, page_num, department)
        in enumerate(chunks_with_metadata)
    ],
)

# Search with filtering: only points whose payload department is "engineering"
results = client.search(
    collection_name="documents",
    query_vector=query_embedding.tolist(),
    query_filter=models.Filter(
        must=[
            models.FieldCondition(
                key="department",
                match=models.MatchValue(value="engineering"),
            ),
        ],
    ),
    limit=10,
)


# Metadata Filtering and Indexing
# Create payload indexes for filterable fields
client.create_payload_index(
    collection_name="documents",
    field_name="department",
    field_schema=models.PayloadSchemaType.KEYWORD
)
client.create_payload_index(
    collection_name="documents",
    field_name="date",
    field_schema=models.PayloadSchemaType.DATETIME
)

# Filtered queries use indexes instead of scanning
results = client.search(
    collection_name="documents",
    # Converted to a plain list, consistent with the earlier search call
    query_vector=query_embedding.tolist(),
    query_filter=models.Filter(
        must=[
            models.FieldCondition(
                key="department",
                match=models.MatchValue(value="legal")
            ),
            models.FieldCondition(
                key="date",
                # The bound is a timestamp string, so it needs the datetime
                # range type; models.Range compares numeric values.
                range=models.DatetimeRange(
                    gte="2024-01-01T00:00:00Z"
                )
            )
        ]
    ),
    limit=10
)


# Index Tuning
# HNSW parameters (set when the collection is created)
client.create_collection(
    collection_name="documents",
    hnsw_config=models.HnswConfigDiff(
        ef_construct=100,  # Size of candidate list during construction
        m=16,  # Number of bi-directional links (higher = more accurate, more memory)
    ),
    vectors_config=models.VectorParams(
        distance=models.Distance.COSINE,
        size=1024,
    ),
)

# Search-time parameter
results = client.search(
    collection_name="documents",
    query_vector=query_embedding,
    search_params=models.SearchParams(
        hnsw_ef=128,  # Size of candidate list during search (higher = more accurate, slower)
    ),
    limit=10,
)


# Vector Store Manager with Updates
class VectorStoreManager:
    """Manage vector store updates with minimal disruption.

    Reads and writes go to ``self.current_collection``, which is switched to a
    freshly built collection after a re-index.

    NOTE(review): ``iterate_all_documents()`` is called below but is not
    defined in this class; it is expected to yield batches (lists) of document
    dicts with a ``'text'`` entry - confirm against its implementation.
    """

    def __init__(self, client=None, collection_name: str = "documents"):
        self.client = client
        self.current_collection = collection_name

    def update_document(self, doc_id: str, new_chunks: list[dict]):
        """Replace a document's chunks: delete the old points, insert the new ones."""
        collection = getattr(self, "current_collection", "documents")

        # Delete old chunks for this document
        self.client.delete(
            collection_name=collection,
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="doc_id",
                            match=models.MatchValue(value=doc_id)
                        )
                    ]
                )
            )
        )
        # Insert new chunks
        self.client.upsert(
            collection_name=collection,
            points=new_chunks
        )

    def reindex_with_new_embeddings(self, new_embedding_model):
        """Re-embed all documents with a new model into a fresh collection.

        The new collection is created lazily from the first batch so its
        vector size matches the new model's output. Nothing is switched or
        deleted when there are no documents to migrate.
        """
        import time

        old_collection = getattr(self, "current_collection", "documents")
        new_collection = f"documents_v{int(time.time())}"
        created = False
        next_id = 0

        # Batch re-embed and insert
        for batch in self.iterate_all_documents():
            if not batch:
                continue
            new_embeddings = new_embedding_model.encode([d['text'] for d in batch])

            if not created:
                self.client.create_collection(
                    collection_name=new_collection,
                    vectors_config=models.VectorParams(
                        size=len(new_embeddings[0]),
                        distance=models.Distance.COSINE,
                    ),
                )
                created = True

            # NOTE(review): the original left the point layout elided. This
            # assumes a document's own 'id' (when present) is its point id and
            # the whole dict is its payload - confirm before relying on it.
            points = []
            for doc, vector in zip(batch, new_embeddings):
                points.append(
                    models.PointStruct(
                        id=doc.get('id', next_id),
                        vector=[float(x) for x in vector],
                        payload=doc,
                    )
                )
                next_id += 1
            self.client.upsert(collection_name=new_collection, points=points)

        if not created:
            return

        # Atomic switch (application reads from new collection)
        self.current_collection = new_collection
        # Delete the collection that was live before the switch
        if old_collection != new_collection:
            self.client.delete_collection(old_collection)


# Advanced RAG Patterns
Query Expansion
def expand_query(query: str, llm_client) -> str:
    """Return ``query`` followed by LLM-suggested additional search terms."""
    instruction = (
        "Given this search query, generate additional search terms "
        "that would help find relevant documents.\n"
    )
    prompt = instruction + f"Query: {query}\n" + "Additional search terms (comma-separated):"
    extra_terms = llm_client.generate(prompt, max_tokens=50)
    return " ".join((query, f"{extra_terms}"))


# Hypothetical Document Embedding (HyDE)
def hyde_search(query: str, retriever, llm_client) -> list[dict]:
    """HyDE search: retrieve with an LLM-written hypothetical answer.

    The LLM drafts a paragraph that would answer ``query``; that draft, not
    the query itself, is passed to ``retriever.search``.
    """
    prompt_lines = [
        "Answer this question in a detailed paragraph, as if you were writing",
        "a section of a document that would answer it perfectly.",
        f"Question: {query}",
        "Answer:",
    ]
    draft_answer = llm_client.generate("\n".join(prompt_lines), max_tokens=300)

    # Search with the hypothetical document
    return retriever.search(draft_answer)


# Query Decomposition
def decompose_query(query: str, llm_client) -> list[str]:
    """Ask the LLM to break ``query`` into sub-questions, one per reply line.

    Blank lines are dropped and each remaining line is stripped.
    """
    prompt = "\n".join([
        "Break this complex question into simpler sub-questions that can be",
        "answered independently.",
        f"Question: {query}",
        "Sub-questions (one per line):",
    ])
    reply = llm_client.generate(prompt)

    sub_queries = []
    for raw_line in reply.strip().split('\n'):
        cleaned = raw_line.strip()
        if cleaned:
            sub_queries.append(cleaned)
    return sub_queries
def multi_query_search(query: str, retriever, llm_client) -> list[dict]:
    """Search with decomposed queries and merge results.

    Args:
        query: The complex question.
        retriever: Object whose ``search(text)`` returns dicts with a
            ``'doc_id'`` entry.
        llm_client: Client passed to ``decompose_query``.

    Returns:
        Results of all sub-queries in retrieval order, keeping the first
        occurrence of each ``doc_id``.
    """
    # If the LLM produces no sub-questions, still search the original query
    # rather than returning nothing.
    sub_queries = decompose_query(query, llm_client) or [query]

    all_results = []
    seen_ids = set()
    for sub_query in sub_queries:
        for r in retriever.search(sub_query):
            if r['doc_id'] in seen_ids:
                continue
            seen_ids.add(r['doc_id'])
            all_results.append(r)
    return all_results


# Self-RAG
class SelfRAG:
    """Self-reflective RAG that decides when to retrieve and evaluates outputs.

    NOTE(review): ``query`` relies on ``generate_with_context``,
    ``regenerate``, ``generate_direct``, ``evaluate_usefulness`` and
    ``improve_response``, none of which are defined in this class; they must
    be supplied (e.g. by a subclass).
    """

    def __init__(self, retriever=None, llm_client=None):
        self.retriever = retriever
        self.llm = llm_client

    def query(self, question: str) -> str:
        """Answer ``question``, retrieving and self-checking as needed."""
        # Step 1: Decide if retrieval is needed
        needs_retrieval = self.should_retrieve(question)

        if needs_retrieval:
            # Step 2: Retrieve
            documents = self.retriever.search(question)
            # Step 3: Generate with retrieval
            response = self.generate_with_context(question, documents)
            # Step 4: Evaluate if response is supported by documents
            is_supported = self.evaluate_support(response, documents)
            if not is_supported:
                # Try again with more documents or different approach
                response = self.regenerate(question, documents, response)
        else:
            # Generate without retrieval
            response = self.generate_direct(question)

        # Step 5: Evaluate if response is useful
        is_useful = self.evaluate_usefulness(question, response)
        if not is_useful:
            response = self.improve_response(question, response)
        return response

    def should_retrieve(self, question: str) -> bool:
        """Decide if external knowledge is needed (LLM answers YES or NO)."""
        prompt = (
            "Does answering this question require external knowledge beyond\n"
            "general knowledge? Answer YES or NO.\n"
            f"Question: {question}\n"
            "Needs external knowledge:"
        )
        response = self.llm.generate(prompt, max_tokens=5)
        return "YES" in response.upper()

    def evaluate_support(self, response: str, documents: list[dict]) -> bool:
        """Check if response is supported by retrieved documents."""
        context = "\n\n".join([d['text'] for d in documents])
        prompt = (
            "Is this response fully supported by the provided documents?\n"
            "Documents:\n"
            f"{context}\n"
            f"Response: {response}\n"
            "Answer SUPPORTED or UNSUPPORTED:"
        )
        result = self.llm.generate(prompt, max_tokens=10)
        verdict = result.upper()
        # "UNSUPPORTED" contains "SUPPORTED", so the negative verdicts have to
        # be ruled out before the positive substring check.
        if "UNSUPPORTED" in verdict or "NOT SUPPORTED" in verdict:
            return False
        return "SUPPORTED" in verdict


# Corrective RAG (CRAG)
class CorrectiveRAG:
    """RAG with retrieval quality evaluation and correction.

    NOTE(review): ``query`` relies on ``extract_relevant_content``,
    ``web_search`` and ``generate``, none of which are defined in this class;
    they must be supplied (e.g. by a subclass).
    """

    def __init__(self, retriever=None, llm_client=None):
        self.retriever = retriever
        self.llm = llm_client

    def query(self, question: str) -> str:
        """Answer ``question``, correcting the retrieval when it looks poor."""
        # Initial retrieval
        documents = self.retriever.search(question)
        # Evaluate retrieval quality
        quality = self.evaluate_retrieval(question, documents)

        if quality == "CORRECT":
            # Retrieval is good, proceed with generation
            sources = documents
        elif quality == "INCORRECT":
            # Retrieval failed, fall back to web search or other sources
            sources = self.web_search(question)
        else:  # AMBIGUOUS
            # Partially relevant, combine with additional search
            sources = documents + self.web_search(question)

        context = self.extract_relevant_content(question, sources)
        return self.generate(question, context)

    def evaluate_retrieval(self, question: str, documents: list[dict]) -> str:
        """Evaluate if retrieval results are relevant.

        The top five documents are each labelled by the LLM.

        Returns:
            ``"CORRECT"`` or ``"INCORRECT"`` when that label holds a strict
            majority of the evaluated documents, otherwise ``"AMBIGUOUS"``
            (also when there are no documents).
        """
        import re

        label_re = re.compile(r'\b(INCORRECT|CORRECT|AMBIGUOUS)\b')

        evaluations = []
        for doc in documents[:5]:  # Evaluate top 5
            prompt = (
                "Is this document relevant to the question?\n"
                f"Question: {question}\n"
                f"Document: {doc['text'][:500]}\n"
                "Answer CORRECT (highly relevant), INCORRECT (not relevant), "
                "or AMBIGUOUS (partially relevant):"
            )
            result = self.llm.generate(prompt, max_tokens=10)
            # Pick out the label word so replies such as "Correct." or
            # "INCORRECT - off topic" are still recognised.
            found = label_re.search(result.strip().upper())
            evaluations.append(found.group(1) if found else "AMBIGUOUS")

        # Aggregate evaluations. The majority is relative to how many
        # documents were actually judged (3 of 5 when five are available).
        majority = len(evaluations) // 2 + 1
        if evaluations and evaluations.count("CORRECT") >= majority:
            return "CORRECT"
        elif evaluations and evaluations.count("INCORRECT") >= majority:
            return "INCORRECT"
        else:
            return "AMBIGUOUS"


# Multi-hop RAG
class MultiHopRAG:
    """RAG for questions requiring multiple retrieval steps.

    NOTE(review): ``query`` relies on ``generate_followup_query`` and
    ``generate_answer``, which are not defined in this class; they must be
    supplied (e.g. by a subclass).
    """

    def __init__(self, retriever=None, llm_client=None):
        self.retriever = retriever
        self.llm = llm_client

    def query(self, question: str, max_hops: int = 3) -> str:
        """Answer ``question`` using up to ``max_hops`` retrieval rounds."""
        collected_context = []
        seen_ids = set()
        current_query = question

        for hop in range(max_hops):
            # Retrieve for current query
            documents = self.retriever.search(current_query)
            for doc in documents:
                # Follow-up hops often return documents already collected;
                # keep one copy of each identifiable document.
                doc_id = doc.get('doc_id') if isinstance(doc, dict) else None
                if doc_id is not None:
                    if doc_id in seen_ids:
                        continue
                    seen_ids.add(doc_id)
                collected_context.append(doc)

            # Check if we have enough to answer
            can_answer, missing_info = self.check_answerability(
                question, collected_context
            )
            if can_answer or not missing_info:
                break

            # Generate follow-up query to find missing information
            current_query = self.generate_followup_query(
                question, collected_context, missing_info
            )

        return self.generate_answer(question, collected_context)

    def check_answerability(
        self,
        question: str,
        context: list[dict]
    ) -> "tuple[bool, str | None]":
        """Check if context is sufficient to answer, identify what's missing.

        Returns:
            ``(True, None)`` when the LLM replies ANSWERABLE,
            ``(False, description)`` when it replies ``MISSING: ...`` and
            ``(False, None)`` for anything else.
        """
        import re

        context_text = "\n\n".join([d['text'] for d in context])
        prompt = (
            "Given this question and context, can the question be fully answered?\n"
            f"Question: {question}\n"
            "Context:\n"
            f"{context_text}\n"
            "If YES, respond: ANSWERABLE\n"
            "If NO, respond: MISSING: [what information is still needed]"
        )
        result = self.llm.generate(prompt)

        # Look for the MISSING marker first and case-insensitively, so the
        # extracted text is right whatever capitalisation the LLM used.
        missing = re.search(r'MISSING:\s*(.*)', result, flags=re.IGNORECASE | re.DOTALL)
        if missing:
            return False, missing.group(1).strip() or None

        verdict = result.upper()
        # Whole-word match: "UNANSWERABLE" must not count as answerable.
        if re.search(r'\bANSWERABLE\b', verdict) and not re.search(r'\bNOT\s+ANSWERABLE\b', verdict):
            return True, None
        return False, None


# GraphRAG
from dataclasses import dataclass, field
from typing import Optional

import networkx as nx
@dataclass
class Entity:
    """A knowledge-graph node extracted from a document."""

    id: str
    name: str
    type: str
    # Optional free-form attributes; defaults to a fresh empty dict so
    # extractor output that omits "properties" can still be loaded.
    properties: dict = field(default_factory=dict)


@dataclass
class Relationship:
    """A typed edge between two entities, referenced by their ids."""

    source_id: str
    target_id: str
    type: str
    # Optional free-form attributes; a fresh empty dict per instance.
    properties: dict = field(default_factory=dict)
class GraphRAG:
    """RAG system combining vector retrieval with knowledge graph traversal.

    NOTE(review): ``query`` calls ``self._build_prompt``, which is not defined
    in this class. The graph store is expected to expose ``add_node``,
    ``add_edge``, ``find_by_name``, ``get_node``, ``get_edge`` and a networkx
    graph as ``G`` - confirm against its implementation.
    """

    def __init__(
        self,
        vector_store,
        graph_store,
        llm_client,
        entity_extractor,
        embedding_fn=None
    ):
        self.vectors = vector_store
        self.graph = graph_store
        self.llm = llm_client
        self.entity_extractor = entity_extractor
        # Callable mapping text to an embedding; used by embed()
        self.embedding_fn = embedding_fn

    def embed(self, text: str):
        """Embed ``text`` with the configured ``embedding_fn``.

        Raises:
            NotImplementedError: If no ``embedding_fn`` was supplied and this
                method is not overridden.
        """
        if self.embedding_fn is None:
            raise NotImplementedError(
                "GraphRAG needs an embedding_fn (or a subclass overriding embed()) "
                "to index documents"
            )
        return self.embedding_fn(text)

    def build_index(self, documents: list[dict]) -> None:
        """Index documents and extract knowledge graph."""
        for doc in documents:
            # Standard vector indexing
            embedding = self.embed(doc['text'])
            self.vectors.add(doc['id'], embedding, doc)

            # Extract entities and relationships
            entities, relationships = self.entity_extractor.extract(doc['text'])
            for entity in entities:
                self.graph.add_node(
                    entity.id,
                    name=entity.name,
                    type=entity.type,
                    doc_id=doc['id'],
                    **entity.properties
                )
            for rel in relationships:
                self.graph.add_edge(
                    rel.source_id,
                    rel.target_id,
                    type=rel.type,
                    doc_id=doc['id'],
                    **rel.properties
                )

    def query(self, question: str, top_k: int = 10) -> dict:
        """Query combining vector retrieval and graph traversal.

        Returns:
            Dict with ``answer``, ``vector_sources``, ``graph_entities`` and
            ``relationships``.
        """
        # Step 1: Extract entities from query. Extractors that only offer
        # extract() return (entities, relationships); use the entities.
        extract_entities = getattr(self.entity_extractor, 'extract_entities', None)
        if extract_entities is not None:
            query_entities = extract_entities(question)
        else:
            query_entities, _ = self.entity_extractor.extract(question)

        # Step 2: Vector retrieval for relevant chunks
        vector_results = self.vectors.search(question, k=top_k)
        # Step 3: Graph-based retrieval
        graph_context = self._get_graph_context(query_entities)
        # Step 4: Combine and generate
        combined_context = self._merge_contexts(vector_results, graph_context)
        response = self.llm.generate(
            self._build_prompt(question, combined_context)
        )
        return {
            'answer': response,
            'vector_sources': vector_results,
            'graph_entities': graph_context['entities'],
            'relationships': graph_context['relationships']
        }

    def _get_graph_context(
        self,
        query_entities: "list[Entity]",
        max_hops: int = 2
    ) -> dict:
        """Traverse graph to find relevant connected entities.

        Returns:
            Dict with ``entities`` (node data plus hop ``distance`` and the
            originating ``source_entity``) and ``relationships`` (edge data
            for each edge on the paths found, listed once).
        """
        relevant_entities = []
        relevant_relationships = []
        seen_edges = set()

        for entity in query_entities:
            # Find matching nodes in graph
            matches = self.graph.find_by_name(entity.name, entity.type)
            for match in matches:
                # One traversal yields a shortest path to every node within
                # max_hops, so no per-neighbour path query is needed.
                paths = nx.single_source_shortest_path(
                    self.graph.G, match, cutoff=max_hops
                )
                for neighbor_id, path in paths.items():
                    relevant_entities.append({
                        'entity': self.graph.get_node(neighbor_id),
                        'distance': len(path) - 1,
                        'source_entity': entity.name
                    })
                    # Get edges on the path, skipping edges already recorded
                    for edge in zip(path, path[1:]):
                        if edge in seen_edges:
                            continue
                        seen_edges.add(edge)
                        relevant_relationships.append(
                            self.graph.get_edge(edge[0], edge[1])
                        )
        return {
            'entities': relevant_entities,
            'relationships': relevant_relationships
        }

    def _merge_contexts(
        self,
        vector_results: list[dict],
        graph_context: dict
    ) -> str:
        """Merge vector and graph contexts into unified context."""
        sections = []

        # Add vector-retrieved content (top five, 1000 characters each)
        sections.append("## Retrieved Documents")
        for i, result in enumerate(vector_results[:5]):
            sections.append(f"\n### Source {i+1}")
            sections.append(result['text'][:1000])

        # Add graph-derived facts
        if graph_context['entities']:
            sections.append("\n## Related Entities and Relationships")

            # Group by entity type
            by_type = {}
            for item in graph_context['entities']:
                etype = item['entity'].get('type', 'Unknown')
                by_type.setdefault(etype, []).append(item['entity']['name'])
            for etype, names in by_type.items():
                sections.append(f"\n**{etype}**: {', '.join(names[:10])}")

            # Add key relationships
            sections.append("\n**Key Relationships:**")
            for rel in graph_context['relationships'][:10]:
                sections.append(
                    f"- {rel['source']} --[{rel['type']}]--> {rel['target']}"
                )
        return '\n'.join(sections)
class EntityExtractor:
    """Extract entities and relationships from text using LLM."""

    def __init__(self, llm_client):
        self.llm = llm_client

    def extract(self, text: str) -> "tuple[list[Entity], list[Relationship]]":
        """Extract entities and relationships from text.

        Only the first 3000 characters of ``text`` are sent to the LLM.

        Returns:
            ``(entities, relationships)``. Both are empty when the reply is
            not a JSON object; malformed individual items are skipped.
        """
        import json
        import logging

        prompt = (
            "Extract entities and relationships from this text.\n"
            f"Text: {text[:3000]}\n"
            "Output as JSON:\n"
            "{\n"
            '"entities": [\n'
            '{"id": "unique_id", "name": "Entity Name", '
            '"type": "PERSON|ORG|CONCEPT|PRODUCT|...", "properties": {}}\n'
            "],\n"
            '"relationships": [\n'
            '{"source_id": "id1", "target_id": "id2", '
            '"type": "WORKS_FOR|CREATED|RELATED_TO|...", "properties": {}}\n'
            "]\n"
            "}"
        )
        response = self.llm.generate(prompt, response_format="json")

        # Tolerate replies wrapped in a markdown code fence.
        cleaned = str(response).strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`").strip()
            if cleaned.lower().startswith("json"):
                cleaned = cleaned[4:]

        try:
            data = json.loads(cleaned)
        except json.JSONDecodeError:
            # One unparsable reply should not abort indexing a whole corpus.
            logging.getLogger(__name__).warning(
                "Entity extraction returned invalid JSON; no entities extracted"
            )
            return [], []
        if not isinstance(data, dict):
            return [], []

        entities = []
        for e in data.get('entities') or []:
            try:
                # Explicit fields: extra keys are ignored, "properties" is optional.
                entities.append(Entity(
                    id=e['id'],
                    name=e['name'],
                    type=e['type'],
                    properties=e.get('properties') or {},
                ))
            except (KeyError, TypeError, AttributeError):
                continue

        relationships = []
        for r in data.get('relationships') or []:
            try:
                relationships.append(Relationship(
                    source_id=r['source_id'],
                    target_id=r['target_id'],
                    type=r['type'],
                    properties=r.get('properties') or {},
                ))
            except (KeyError, TypeError, AttributeError):
                continue

        return entities, relationships

    def extract_entities(self, text: str) -> "list[Entity]":
        """Return only the entities found in ``text``."""
        entities, _ = self.extract(text)
        return entities


# Evaluation and Debugging
Retrieval Evaluation
def evaluate_retrieval(
    retriever,
    test_queries: list[str],
    ground_truth: list[set[int]],  # Relevant doc IDs per query
    k_values: Optional[list[int]] = None
) -> dict:
    """Evaluate retrieval quality.

    Args:
        retriever: Object whose ``search(query, k=...)`` returns a ranked
            list of dicts carrying a 'doc_id' key.
        test_queries: Queries to run.
        ground_truth: Relevant doc IDs for each query, aligned with
            ``test_queries``. Any iterable of IDs is accepted.
        k_values: Cut-offs to report; defaults to ``[1, 5, 10, 20]``.

    Returns:
        Mean 'recall@k' and 'precision@k' for every k, plus 'mrr'. A query
        with no relevant documents scores recall 0.0 instead of dividing by
        zero. With no queries at all every metric is 0.0.
    """
    if k_values is None:
        k_values = [1, 5, 10, 20]

    metrics = {f'recall@{k}': [] for k in k_values}
    metrics.update({f'precision@{k}': [] for k in k_values})
    metrics['mrr'] = []  # Mean Reciprocal Rank

    deepest_k = max(k_values)  # one search per query covers every cut-off
    for query, relevant in zip(test_queries, ground_truth):
        relevant_ids = set(relevant)
        results = retriever.search(query, k=deepest_k)
        retrieved_ids = [r['doc_id'] for r in results]

        for k in k_values:
            hits = len(set(retrieved_ids[:k]) & relevant_ids)
            # Recall@k: What fraction of relevant docs were retrieved?
            recall = hits / len(relevant_ids) if relevant_ids else 0.0
            metrics[f'recall@{k}'].append(recall)
            metrics[f'precision@{k}'].append(hits / k)

        # MRR: reciprocal rank of the first relevant document, 0 if none.
        reciprocal_rank = next(
            (
                1 / rank
                for rank, doc_id in enumerate(retrieved_ids, 1)
                if doc_id in relevant_ids
            ),
            0,
        )
        metrics['mrr'].append(reciprocal_rank)

    # Average metrics
    return {
        name: (np.mean(values) if values else 0.0)
        for name, values in metrics.items()
    }


# End-to-End Evaluation with LLM Judge
def evaluate_rag_e2e(
    rag_pipeline,
    test_cases: list[dict],  # {'question': str, 'expected_answer': str, 'expected_facts': list[str]}
    evaluator_llm
) -> dict:
    """Score end-to-end RAG answers with an LLM acting as judge.

    Each test case is answered by ``rag_pipeline.query`` and the answer is
    sent, with the expected answer, to ``evaluator_llm.generate``. The judge
    must reply with JSON holding 'accuracy', 'completeness' and
    'hallucination' ratings.

    Returns:
        Mean 'factual_accuracy', 'completeness' and 'hallucination_rate'
        over all cases.
    """
    accuracy_scores = []
    completeness_scores = []
    hallucination_scores = []

    for test_case in test_cases:
        system_answer = rag_pipeline.query(test_case['question'])

        # Ask the judge model for three ratings.
        judge_prompt = f"""Evaluate this RAG system response.
Question: {test_case['question']}
Expected Answer: {test_case['expected_answer']}
System Response: {system_answer}
Rate on a scale of 1-5:
1. Factual Accuracy: Does the response contain correct information?
2. Completeness: Does it cover the key points?
3. Hallucination: Does it include information not in the expected answer? (1=lots, 5=none)
Output as JSON: {{"accuracy": N, "completeness": N, "hallucination": N}}"""
        verdict = json.loads(evaluator_llm.generate(judge_prompt))

        accuracy_scores.append(verdict['accuracy'])
        completeness_scores.append(verdict['completeness'])
        # The judge rates 5 = no hallucination; flip it so higher means more.
        hallucination_scores.append(6 - verdict['hallucination'])

    collected = {
        'factual_accuracy': accuracy_scores,
        'completeness': completeness_scores,
        'hallucination_rate': hallucination_scores,
    }
    return {metric: np.mean(values) for metric, values in collected.items()}


# RAG Debugger
class RAGDebugger:
    """Debug RAG failures by analyzing each component.

    The class defines no constructor. It expects these to be supplied as
    attributes or by a subclass: ``retriever`` (``search(query, k=...)``),
    ``reranker`` (``rerank(query, texts, top_k=...)`` returning 3-tuples
    whose first item is an index into ``texts``), ``generate(query,
    context)``, ``has_answer_overlap(text, expected)`` and
    ``answer_contains_expected(generated, expected)``.
    """

    def diagnose(self, query: str, expected_answer: str) -> dict:
        """Diagnose why RAG failed for a query.

        Stages are checked in order: retrieval, ranking, reranking,
        generation. Retrieval and reranking failures stop the run, since
        later stages have nothing useful to work with.

        Returns:
            Dict with 'query', 'expected', 'issues' (one entry per failing
            stage) and 'status' ('SUCCESS' or 'FAILURE'). 'status' is set on
            every return path, including the early exits.
        """
        diagnosis = {
            'query': query,
            'expected': expected_answer,
            'issues': []
        }

        # Step 1: Check retrieval
        retrieved_docs = self.retriever.search(query, k=20)
        relevant_docs = self.find_relevant_docs(expected_answer, retrieved_docs)
        if not relevant_docs:
            diagnosis['issues'].append({
                'component': 'retrieval',
                'issue': 'No relevant documents retrieved',
                'details': {
                    'retrieved_docs': [d['text'][:200] for d in retrieved_docs[:5]],
                    'suggestion': 'Check embedding model, chunking, or document coverage'
                }
            })
            diagnosis['status'] = 'FAILURE'
            return diagnosis

        # Step 2: Check if relevant docs are ranked high enough
        # Build the ID set once instead of once per retrieved document.
        relevant_ids = {r['doc_id'] for r in relevant_docs}
        relevant_ranks = [
            i for i, d in enumerate(retrieved_docs)
            if d['doc_id'] in relevant_ids
        ]
        best_rank = min(relevant_ranks)  # 0-based position in the result list
        if best_rank > 5:
            diagnosis['issues'].append({
                'component': 'ranking',
                'issue': f'Relevant docs ranked too low (best rank: {best_rank})',
                'details': {
                    'relevant_doc_ranks': relevant_ranks,
                    'suggestion': 'Consider reranking or hybrid search tuning'
                }
            })

        # Step 3: Check reranking
        reranked = self.reranker.rerank(query, [d['text'] for d in retrieved_docs], top_k=5)
        reranked_ids = [retrieved_docs[idx]['doc_id'] for idx, _, _ in reranked]
        relevant_in_reranked = [
            r for r in relevant_docs if r['doc_id'] in reranked_ids
        ]
        if not relevant_in_reranked:
            diagnosis['issues'].append({
                'component': 'reranking',
                'issue': 'Reranker filtered out relevant documents',
                'details': {
                    'reranked_docs': [retrieved_docs[idx]['text'][:200] for idx, _, _ in reranked],
                    'suggestion': 'Reranker may not be suited for this domain'
                }
            })
            diagnosis['status'] = 'FAILURE'
            return diagnosis

        # Step 4: Check generation
        # The context is built from the known-relevant docs, so a failure
        # here points at generation rather than retrieval.
        context = "\n\n".join([d['text'] for d in relevant_docs[:5]])
        generated = self.generate(query, context)
        if not self.answer_contains_expected(generated, expected_answer):
            diagnosis['issues'].append({
                'component': 'generation',
                'issue': 'Generation failed despite good retrieval',
                'details': {
                    'context_provided': context[:500],
                    'generated_answer': generated,
                    'suggestion': 'Check prompt engineering or context assembly'
                }
            })

        diagnosis['status'] = 'FAILURE' if diagnosis['issues'] else 'SUCCESS'
        return diagnosis

    def find_relevant_docs(self, expected_answer: str, docs: list[dict]) -> list[dict]:
        """Find docs that contain information needed for the answer.

        Keeps, in order, each doc for which
        ``self.has_answer_overlap(doc['text'], expected_answer)`` is truthy.
        """
        return [
            doc for doc in docs
            if self.has_answer_overlap(doc['text'], expected_answer)
        ]


# Building Evaluation Datasets
def create_evaluation_dataset(
    documents: list[dict],
    llm_client,
    num_samples: int = 100
) -> list[dict]:
    """Generate evaluation questions from documents.

    Up to ``2 * num_samples`` documents are sampled at random. For each one
    the LLM is asked for a question/answer pair, until ``num_samples`` pairs
    have been parsed.

    Args:
        documents: Dicts with 'id' and 'text' keys.
        llm_client: Object whose ``generate(prompt)`` returns text with
            ``Question:`` and ``Answer:`` lines.
        num_samples: Maximum number of pairs to return.

    Returns:
        Dicts with 'question', 'expected_answer', 'source_doc_id' and
        'source_doc_text'. Documents whose reply cannot be parsed are
        skipped.
    """
    eval_data = []
    sampled_docs = random.sample(documents, min(num_samples * 2, len(documents)))
    for doc in sampled_docs:
        if len(eval_data) >= num_samples:
            break
        # Generate a question that this document answers
        prompt = f"""Based on this document, generate a question that someone might ask
that this document would answer. Also provide the expected answer.
Document:
{doc['text'][:1500]}
Generate in this format:
Question: [your question]
Answer: [expected answer based on the document]"""
        response = llm_client.generate(prompt)

        # Parse response. Skipping is limited to the failures a malformed
        # reply or record can cause (non-string reply, missing 'id').
        try:
            question = None
            answer = None
            for raw_line in response.strip().split('\n'):
                line = raw_line.strip()  # tolerate indented label lines
                # Slice off only the leading label; str.replace would also
                # delete the label text wherever it recurs in the line.
                if line.startswith('Question:'):
                    question = line[len('Question:'):].strip()
                elif line.startswith('Answer:'):
                    answer = line[len('Answer:'):].strip()
            if question and answer:
                eval_data.append({
                    'question': question,
                    'expected_answer': answer,
                    'source_doc_id': doc['id'],
                    'source_doc_text': doc['text']
                })
        except (AttributeError, KeyError, TypeError):
            continue
    return eval_data


# RAG Monitor for Production
class RAGMonitor:
    """Monitor RAG system health in production.

    Metrics are held in memory as lists keyed by metric name, one entry per
    logged query.
    """

    def __init__(self):
        self.metrics = defaultdict(list)

    @staticmethod
    def _score_of(doc):
        """Return a document's retrieval score from a dict or an object."""
        return doc['score'] if isinstance(doc, dict) else doc.score

    @staticmethod
    def _mean(values) -> Optional[float]:
        """Return the mean of ``values``, or None when there are none."""
        return float(np.mean(values)) if len(values) else None

    @staticmethod
    def _percentile(values, q: float) -> Optional[float]:
        """Return the q-th percentile of ``values``, or None when empty."""
        return float(np.percentile(values, q)) if len(values) else None

    def log_query(
        self,
        query: str,
        retrieved_docs: list[dict],
        reranked_docs: list[dict],
        response: str,
        latencies: dict,
        user_feedback: Optional[str] = None
    ):
        """Log a query for monitoring.

        Args:
            query: The user query (not stored).
            retrieved_docs: Ranked retrieval results, as dicts with a
                'score' key or objects with a ``score`` attribute.
            reranked_docs: Results kept after reranking; only the count is
                recorded.
            response: The generated answer (not stored).
            latencies: Seconds per stage. 'retrieval' and 'generation' are
                required; 'reranking' is recorded as 0.0 when absent.
            user_feedback: Optional label; 'negative' is the value the
                anomaly and satisfaction checks look for.
        """
        self.metrics['retrieval_latency'].append(latencies['retrieval'])
        # Reranking is an optional stage, so its timing may be missing.
        self.metrics['rerank_latency'].append(latencies.get('reranking', 0.0))
        self.metrics['generation_latency'].append(latencies['generation'])
        self.metrics['total_latency'].append(sum(latencies.values()))
        self.metrics['docs_retrieved'].append(len(retrieved_docs))
        self.metrics['docs_after_rerank'].append(len(reranked_docs))

        # Retrieval score distribution
        if retrieved_docs:
            scores = [self._score_of(d) for d in retrieved_docs]
            self.metrics['top_retrieval_score'].append(scores[0])
            self.metrics['avg_retrieval_score'].append(np.mean(scores))

        if user_feedback:
            self.metrics['user_feedback'].append(user_feedback)

    def calculate_satisfaction(self) -> Optional[float]:
        """Return the share of feedback entries that are not 'negative'.

        Returns None when no feedback has been logged.
        """
        feedback = self.metrics['user_feedback']
        if not feedback:
            return None
        negatives = sum(1 for f in feedback if f == 'negative')
        return 1.0 - negatives / len(feedback)

    def get_health_report(self) -> dict:
        """Generate health report.

        Any statistic with no samples yet is reported as None rather than
        NaN or an exception.
        """
        total = self.metrics['total_latency']
        top_scores = self.metrics['top_retrieval_score']
        return {
            'latency': {
                'p50_total': self._percentile(total, 50),
                'p95_total': self._percentile(total, 95),
                'p99_total': self._percentile(total, 99),
                'breakdown': {
                    'retrieval_avg': self._mean(self.metrics['retrieval_latency']),
                    'rerank_avg': self._mean(self.metrics['rerank_latency']),
                    'generation_avg': self._mean(self.metrics['generation_latency'])
                }
            },
            'retrieval': {
                'avg_docs_retrieved': self._mean(self.metrics['docs_retrieved']),
                'avg_top_score': self._mean(top_scores),
                'low_score_rate': self._mean([
                    1 if s < 0.3 else 0
                    for s in top_scores
                ])
            },
            'user_satisfaction': self.calculate_satisfaction()
        }

    def detect_anomalies(self) -> list[str]:
        """Detect potential issues.

        Checks, in order: mean latency of the last 100 queries above 1.5x
        the mean of all earlier ones; mean top retrieval score of the last
        100 queries below 0.3; more than 30% 'negative' among the last 50
        feedback entries.
        """
        anomalies = []

        # Latency spike
        recent_latency = self.metrics['total_latency'][-100:]
        baseline_latency = self.metrics['total_latency'][:-100]
        if baseline_latency and np.mean(recent_latency) > 1.5 * np.mean(baseline_latency):
            anomalies.append(f"Latency spike: {np.mean(recent_latency):.2f}s vs baseline {np.mean(baseline_latency):.2f}s")

        # Low retrieval scores (skipped until at least one score is logged)
        recent_scores = self.metrics['top_retrieval_score'][-100:]
        if recent_scores and np.mean(recent_scores) < 0.3:
            anomalies.append(f"Low retrieval scores: avg {np.mean(recent_scores):.2f}")

        # User feedback degradation
        recent_feedback = self.metrics['user_feedback'][-50:]
        if recent_feedback:
            negative_rate = sum(1 for f in recent_feedback if f == 'negative') / len(recent_feedback)
            if negative_rate > 0.3:
                anomalies.append(f"High negative feedback rate: {negative_rate:.0%}")
        return anomalies


# Production Components
RAG Cache
class RAGCache:
    """Multi-level caching for RAG systems.

    Embeddings are cached in-process (LRU); search results and generated
    answers are cached in Redis with a TTL.
    """

    def __init__(
        self,
        redis_client,
        embedding_cache_size: int = 10000,
        embedding_model=None,
        retriever=None
    ):
        """Create the cache.

        Args:
            redis_client: Client exposing ``get`` and ``setex``.
            embedding_cache_size: Capacity of the in-process embedding LRU.
            embedding_model: Object with ``encode(text)``; needed by
                ``get_or_compute_embedding``.
            retriever: Object with ``search(query, **params)``; needed by
                ``get_or_search``.
        """
        self.redis = redis_client
        self.embedding_cache = LRUCache(embedding_cache_size)
        # These collaborators are read by the methods below; accept them
        # here so the class is usable without external attribute patching.
        self.embedding_model = embedding_model
        self.retriever = retriever

    def get_or_compute_embedding(self, text: str) -> np.ndarray:
        """Cache embeddings to avoid recomputation."""
        # MD5 is used only as a compact cache key, not for security.
        cache_key = hashlib.md5(text.encode()).hexdigest()
        if cache_key in self.embedding_cache:
            return self.embedding_cache[cache_key]
        embedding = self.embedding_model.encode(text)
        self.embedding_cache[cache_key] = embedding
        return embedding

    def get_or_search(self, query: str, search_params: dict) -> list[dict]:
        """Cache search results for repeated queries.

        The key covers the query and a canonical (key-sorted) dump of
        ``search_params``, so equal parameters share an entry whatever
        their insertion order. Results must be JSON-serialisable.
        """
        # default=str only affects key derivation: values JSON cannot encode
        # are stringified for hashing.
        params_key = json.dumps(search_params, sort_keys=True, default=str)
        cache_key = f"search:{hashlib.md5((query + params_key).encode()).hexdigest()}"
        cached = self.redis.get(cache_key)
        if cached:
            return json.loads(cached)
        results = self.retriever.search(query, **search_params)
        # Cache with TTL (results may become stale as docs update)
        self.redis.setex(cache_key, 3600, json.dumps(results))  # 1 hour TTL
        return results

    def get_or_generate(self, prompt_hash: str, generate_fn: callable) -> str:
        """Cache LLM responses for identical prompts.

        ``generate_fn`` is called with no arguments on a miss.
        """
        cache_key = f"gen:{prompt_hash}"
        cached = self.redis.get(cache_key)
        if cached:
            # The client may return bytes or, if it decodes responses, str.
            return cached.decode() if isinstance(cached, bytes) else cached
        response = generate_fn()
        self.redis.setex(cache_key, 86400, response)  # 24 hour TTL
        return response


# Resilient RAG with Fallbacks
class ResilientRAG:
    """RAG with error handling and fallbacks.

    The class relies on ``_retrieve``, ``_generate``, ``_fallback_search``
    and ``_keyword_only_search`` being provided elsewhere (not defined
    here).
    """

    def query(self, question: str) -> RAGResponse:
        """Answer ``question``, degrading gracefully on known failures.

        Retrieval timeouts, embedding-service errors and generation errors
        each have a dedicated fallback. A fallback that itself fails, or any
        other exception, produces a generic ``status="error"`` response, so
        this method does not raise.
        """
        try:
            # Primary flow
            return self._primary_query(question)
        except RetrievalTimeout:
            # Fallback to cached results or simplified search
            return self._run_fallback(self._fallback_search, question)
        except EmbeddingServiceError:
            # Use BM25 only
            return self._run_fallback(self._keyword_only_search, question)
        except GenerationError:
            # Return retrieved docs without generation
            return self._run_fallback(self._documents_only, question)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("RAG failed: %s", e)
            return self._error_response()

    def _run_fallback(self, fallback, question: str) -> RAGResponse:
        """Run a fallback; turn its own failure into the error response.

        Without this guard, an exception raised inside one of the handlers
        above would propagate out of ``query``.
        """
        try:
            return fallback(question)
        except Exception as e:
            logger.exception("RAG fallback failed: %s", e)
            return self._error_response()

    def _documents_only(self, question: str) -> RAGResponse:
        """Re-run retrieval and return the sources without an answer."""
        docs = self._retrieve(question)
        return RAGResponse(
            answer="I found relevant documents but couldn't generate a summary. Here are the sources:",
            documents=docs,
            status="partial"
        )

    def _error_response(self) -> RAGResponse:
        """Build the generic last-resort response."""
        return RAGResponse(
            answer="I'm having trouble answering right now. Please try again.",
            status="error"
        )

    def _primary_query(self, question: str) -> RAGResponse:
        """Retrieve (10 s budget) then generate (30 s budget)."""
        with timeout(seconds=10):
            docs = self._retrieve(question)
        with timeout(seconds=30):
            answer = self._generate(question, docs)
        return RAGResponse(answer=answer, documents=docs, status="success")


# Cost-Aware RAG
class CostAwareRAG:
    """RAG with cost tracking and limits.

    ``_execute_query`` is not defined here and must be supplied by a
    subclass. Its result must expose ``input_token_count`` and
    ``output_token_count``.
    """

    def __init__(self, daily_budget: float = 100.0):
        self.daily_budget = daily_budget
        self.daily_spend = 0.0
        self.last_reset = datetime.now().date()

    def query(self, question: str, priority: str = "normal") -> RAGResponse:
        """Run a query unless it would exceed the daily budget.

        Queries with ``priority="high"`` bypass the budget check. The actual
        cost of every executed query is added to ``daily_spend``.
        """
        self._maybe_reset_daily()

        # Check budget
        estimated_cost = self._estimate_cost(question)
        over_budget = self.daily_spend + estimated_cost > self.daily_budget
        if over_budget and priority != "high":
            return RAGResponse(
                answer="Daily limit reached. Please try again tomorrow.",
                status="rate_limited"
            )

        # Track cost
        result = self._execute_query(question)
        actual_cost = self._calculate_actual_cost(result)
        self.daily_spend += actual_cost
        logger.info(f"Query cost: ${actual_cost:.4f}, Daily total: ${self.daily_spend:.2f}")
        return result

    def _maybe_reset_daily(self) -> None:
        """Zero the running spend when the local calendar day has changed."""
        today = datetime.now().date()
        if today != self.last_reset:
            self.daily_spend = 0.0
            self.last_reset = today

    def _estimate_cost(self, question: str) -> float:
        """Return a flat pre-execution cost estimate in dollars."""
        # Rough estimate based on typical usage
        # Embedding: ~$0.0001 per query
        # Retrieval: ~free (self-hosted)
        # Generation: ~$0.01-0.10 depending on model and length
        return 0.05

    def _calculate_actual_cost(self, result: RAGResponse) -> float:
        """Compute the dollar cost of a result from its token counts."""
        # Calculate from token counts
        input_tokens = result.input_token_count
        output_tokens = result.output_token_count
        # Prices vary by model
        return (input_tokens * 0.00001) + (output_tokens * 0.00003)


# Document Update Pipeline
class DocumentUpdatePipeline:
    """Pipeline for handling document updates."""

    def __init__(
        self,
        vector_store,
        embedding_model,
        chunker=None,
        cache=None,
        last_sync=None
    ):
        """Create the pipeline.

        Args:
            vector_store: Store exposing ``get_by_metadata``, ``delete``,
                ``insert`` and a ``transaction()`` context manager.
            embedding_model: Object with ``encode(list_of_texts)``.
            chunker: Object whose ``chunk(text)`` returns dicts with a
                'text' key. Required for create/update operations.
            cache: Optional object with ``invalidate_by_doc(doc_id)``.
            last_sync: Watermark passed to the source connector on the first
                ``sync_from_source`` call. None means "no previous sync";
                how a connector treats None is up to the connector.
        """
        self.vector_store = vector_store
        self.embedding_model = embedding_model
        self.chunker = chunker
        self.cache = cache
        self.last_sync = last_sync

    def process_update(self, doc_id: str, new_content: str, update_type: str):
        """Process a document update.

        ``update_type`` is one of "create", "update" or "delete". Any other
        value is ignored.
        """
        if update_type == "create":
            self._add_document(doc_id, new_content)
        elif update_type == "update":
            self._update_document(doc_id, new_content)
        elif update_type == "delete":
            self._delete_document(doc_id)

    def _chunk_and_embed(self, content: str):
        """Split ``content`` into chunks and embed each chunk's text."""
        if self.chunker is None:
            raise ValueError("DocumentUpdatePipeline needs a chunker to index content")
        chunks = self.chunker.chunk(content)
        embeddings = self.embedding_model.encode([c['text'] for c in chunks])
        return chunks, embeddings

    def _insert_chunks(self, doc_id: str, chunks, embeddings):
        """Insert chunks under ids of the form ``<doc_id>_<index>``."""
        self.vector_store.insert(
            ids=[f"{doc_id}_{i}" for i in range(len(chunks))],
            embeddings=embeddings,
            metadatas=[{**c, 'doc_id': doc_id} for c in chunks]
        )

    def _invalidate_cache(self, doc_id: str):
        """Drop cached results for ``doc_id`` when a cache is configured."""
        if self.cache is not None:
            self.cache.invalidate_by_doc(doc_id)

    def _add_document(self, doc_id: str, new_content: str):
        """Index a new document: chunk, embed, insert."""
        chunks, embeddings = self._chunk_and_embed(new_content)
        self._insert_chunks(doc_id, chunks, embeddings)
        self._invalidate_cache(doc_id)

    def _update_document(self, doc_id: str, new_content: str):
        """Update document: re-chunk, re-embed, replace in vector store."""
        # Get old chunks
        old_chunks = self.vector_store.get_by_metadata({"doc_id": doc_id})
        # Create new chunks before opening the transaction to keep it short.
        new_chunks, new_embeddings = self._chunk_and_embed(new_content)
        # Atomic update: delete old, insert new
        with self.vector_store.transaction():
            self.vector_store.delete(ids=[c['id'] for c in old_chunks])
            self._insert_chunks(doc_id, new_chunks, new_embeddings)
        # Invalidate caches
        self._invalidate_cache(doc_id)

    def _delete_document(self, doc_id: str):
        """Remove every stored chunk that belongs to ``doc_id``."""
        old_chunks = self.vector_store.get_by_metadata({"doc_id": doc_id})
        if old_chunks:
            self.vector_store.delete(ids=[c['id'] for c in old_chunks])
        self._invalidate_cache(doc_id)

    def sync_from_source(self, source_connector):
        """Sync documents from source system.

        Each change is a dict with 'id', 'type' and, except for deletes,
        'content'.
        """
        # Take the watermark before fetching: a change made while this batch
        # is processed must be picked up by the next sync, not skipped.
        sync_started = datetime.now()
        changes = source_connector.get_changes_since(self.last_sync)
        for change in changes:
            self.process_update(
                doc_id=change['id'],
                new_content=change.get('content'),
                update_type=change['type']
            )
        self.last_sync = sync_started


# Complete RAG System
Document Processor
import hashlib
from dataclasses import dataclass, field
from typing import Iterator
@dataclass
class Chunk:
    """A piece of a document ready for indexing."""

    # Identifier of the chunk.
    id: str
    # The chunk's text content.
    text: str
    # Arbitrary metadata attached to the chunk.
    metadata: dict
    # Embedding vector, filled in after encoding. It is left out of ``==``
    # because comparing NumPy arrays element-wise makes the generated
    # __eq__ raise "truth value of an array is ambiguous".
    embedding: Optional[np.ndarray] = field(default=None, compare=False)
class DocumentProcessor:
    """Process documents into chunks ready for indexing.

    ``recursive_chunk`` is not defined in this class and must be supplied
    elsewhere, as must a working ``extract_text``.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        embedding_model: str = "intfloat/e5-large-v2"
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # The same model name drives both the encoder and its tokenizer.
        self.embedding_model = SentenceTransformer(embedding_model)
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_model)

    def process_file(self, file_path: str) -> list[Chunk]:
        """Process a single file into chunks.

        Extracts text, chunks it, and attaches a normalised embedding to
        every chunk.
        """
        # Extract text
        text, metadata = self.extract_text(file_path)
        metadata['source_file'] = file_path

        # Chunk
        chunks = self.chunk_text(text, metadata)
        if not chunks:
            return chunks  # nothing to embed

        # Embed; each chunk text is sent with a "passage: " prefix.
        embeddings = self.embedding_model.encode(
            [f"passage: {c.text}" for c in chunks],
            normalize_embeddings=True,
            batch_size=32
        )
        for chunk, embedding in zip(chunks, embeddings):
            chunk.embedding = embedding
        return chunks

    def extract_text(self, file_path: str) -> tuple[str, dict]:
        """Extract text from various file formats.

        Must return ``(text, metadata)``. This base implementation is a
        placeholder and raises, instead of returning None and failing later
        with an opaque unpacking error.

        Raises:
            NotImplementedError: Always, until overridden.
        """
        raise NotImplementedError(
            "DocumentProcessor.extract_text must be implemented for the "
            "file types being ingested"
        )

    def chunk_text(self, text: str, base_metadata: dict) -> list[Chunk]:
        """Chunk text with semantic awareness.

        Args:
            text: Full document text.
            base_metadata: Metadata copied into every chunk; must contain
                'source_file'.

        Returns:
            One ``Chunk`` per piece from ``self.recursive_chunk(text)``,
            with 'chunk_index', 'char_start' and 'token_count' added to its
            metadata.
        """
        # Use recursive chunking
        raw_chunks = self.recursive_chunk(text)
        chunks = []
        search_from = 0
        for i, chunk_text in enumerate(raw_chunks):
            chunk_id = hashlib.md5(
                f"{base_metadata['source_file']}:{i}:{chunk_text[:100]}".encode()
            ).hexdigest()[:16]

            # Search forward from the previous chunk's start so repeated
            # text maps to successive occurrences, not always the first.
            char_start = text.find(chunk_text, search_from)
            if char_start == -1:
                # Chunk is out of order or was altered by the chunker.
                char_start = text.find(chunk_text)  # Approximate
            else:
                search_from = char_start + 1  # +1 still allows overlap

            chunks.append(Chunk(
                id=chunk_id,
                text=chunk_text,
                metadata={
                    **base_metadata,
                    'chunk_index': i,
                    'char_start': char_start,
                    'token_count': len(self.tokenizer.encode(chunk_text))
                }
            ))
        return chunks


# Hybrid Retriever (Full Implementation)
class HybridRetriever:
    """Hybrid vector + BM25 retrieval with configurable fusion."""

    def __init__(
        self,
        vector_store,
        bm25_index,
        vector_weight: float = 0.7,
        fusion_method: str = "rrf",
        embedding_model: str = "intfloat/e5-large-v2"
    ):
        """Create the retriever.

        Args:
            vector_store: Store exposing ``search(embedding, k=, filter=)``.
            bm25_index: Index exposing ``search(query, k=)``.
            vector_weight: Weight of the vector score in score fusion; the
                BM25 score gets ``1 - vector_weight``. Unused by RRF.
            fusion_method: "rrf" for reciprocal rank fusion; any other value
                selects weighted score fusion.
            embedding_model: Model name used to embed queries. It should
                match the model that embedded the indexed passages.
        """
        self.vector_store = vector_store
        self.bm25_index = bm25_index
        self.vector_weight = vector_weight
        self.fusion_method = fusion_method
        self.embedding_model = SentenceTransformer(embedding_model)

    def search(
        self,
        query: str,
        k: int = 10,
        filter: dict = None
    ) -> list[RetrievalResult]:
        """Execute hybrid search and return the top ``k`` fused results."""
        # Embed query; the text is sent with a "query: " prefix.
        query_embedding = self.embedding_model.encode(
            f"query: {query}",
            normalize_embeddings=True
        )
        # Vector search
        vector_results = self.vector_store.search(
            query_embedding,
            k=k * 3,  # Fetch more for fusion
            filter=filter
        )
        # BM25 search
        # NOTE(review): ``filter`` is applied to the vector search only, so
        # BM25 hits are unfiltered — confirm whether the BM25 index can
        # filter, or whether its hits need post-filtering.
        bm25_results = self.bm25_index.search(query, k=k * 3)

        # Fuse results
        if self.fusion_method == "rrf":
            fused = self.reciprocal_rank_fusion(vector_results, bm25_results)
        else:
            fused = self.score_fusion(vector_results, bm25_results)
        return fused[:k]

    def reciprocal_rank_fusion(
        self,
        vector_results: list,
        bm25_results: list,
        k: int = 60
    ) -> list[RetrievalResult]:
        """Combine results using RRF.

        Each list adds ``1 / (k + rank + 1)`` (rank is 0-based) to the score
        of every document it contains.
        """
        scores = defaultdict(float)
        for ranked_list in (vector_results, bm25_results):
            for rank, result in enumerate(ranked_list):
                scores[result.id] += 1.0 / (k + rank + 1)
        return self._build_results(scores, vector_results, bm25_results)

    def score_fusion(
        self,
        vector_results: list,
        bm25_results: list
    ) -> list[RetrievalResult]:
        """Combine results by a weighted sum of min-max normalised scores.

        Each list's scores are rescaled to [0, 1] so the two scales are
        comparable, then mixed as
        ``vector_weight * vector + (1 - vector_weight) * bm25``. A document
        missing from one list contributes 0 for that list.

        Assumes every result exposes a numeric ``score`` attribute — TODO
        confirm for the BM25 index.
        """
        def normalised(results: list) -> dict:
            if not results:
                return {}
            raw = [r.score for r in results]
            low, span = min(raw), max(raw) - min(raw)
            # All-equal scores carry no ranking signal; treat them as 1.0.
            return {
                r.id: ((r.score - low) / span if span else 1.0)
                for r in results
            }

        vector_scores = normalised(vector_results)
        bm25_scores = normalised(bm25_results)
        weight = self.vector_weight
        # Iterate in first-seen order so ties sort deterministically.
        ordered_ids = list(vector_scores) + [
            doc_id for doc_id in bm25_scores if doc_id not in vector_scores
        ]
        scores = {
            doc_id: weight * vector_scores.get(doc_id, 0.0)
            + (1.0 - weight) * bm25_scores.get(doc_id, 0.0)
            for doc_id in ordered_ids
        }
        return self._build_results(scores, vector_results, bm25_results)

    def _build_results(
        self,
        scores: dict,
        vector_results: list,
        bm25_results: list
    ) -> list[RetrievalResult]:
        """Return results sorted by fused score, best first."""
        # Sort by combined score
        sorted_ids = sorted(scores, key=scores.__getitem__, reverse=True)
        # When both lists hold a document, the BM25 copy supplies the
        # text and metadata.
        by_id = {r.id: r for r in vector_results + bm25_results}
        return [
            RetrievalResult(
                id=doc_id,
                score=scores[doc_id],
                text=by_id[doc_id].text,
                metadata=by_id[doc_id].metadata
            )
            for doc_id in sorted_ids
        ]


# RAG Generator
class RAGGenerator:
    """Turn retrieved documents into an answer with citations.

    ``assemble_context`` and ``extract_citations`` are called by
    ``generate`` but are not defined in this class.
    """

    def __init__(
        self,
        llm_client,
        reranker_model: str = "BAAI/bge-reranker-large",
        use_reranking: bool = True
    ):
        self.llm = llm_client
        self.use_reranking = use_reranking
        # The cross-encoder is loaded only when reranking is enabled.
        if use_reranking:
            self.reranker = CrossEncoder(reranker_model)

    def generate(
        self,
        query: str,
        retrieved_docs: list[RetrievalResult],
        max_context_tokens: int = 4000
    ) -> GenerationResult:
        """Produce an answer for ``query`` from ``retrieved_docs``.

        When reranking is enabled and more than five documents were
        retrieved, they are reranked and the ten best kept. Otherwise the
        documents are used as given.
        """
        docs = retrieved_docs
        if self.use_reranking and len(retrieved_docs) > 5:
            docs = self.rerank(query, retrieved_docs, top_k=10)

        # Context -> prompt -> completion
        context = self.assemble_context(docs, max_context_tokens)
        prompt = self.build_prompt(query, context, docs)
        response = self.llm.generate(prompt)

        # Map the citations found in the response back to the documents.
        citations = self.extract_citations(response, docs)
        return GenerationResult(
            answer=response,
            citations=citations,
            retrieved_docs=docs
        )

    def build_prompt(self, query: str, context: str, docs: list) -> str:
        """Build the generation prompt (``docs`` is accepted but unused)."""
        prompt_lines = [
            "Answer the question based on the provided context.",
            "Include citations [1], [2], etc. when referencing specific information.",
            "If the context doesn't contain relevant information, say so.",
            "Context:",
            f"{context}",
            f"Question: {query}",
            "Answer:",
        ]
        return "\n".join(prompt_lines)

    def rerank(
        self,
        query: str,
        docs: list[RetrievalResult],
        top_k: int = 5
    ) -> list[RetrievalResult]:
        """Return the ``top_k`` documents by cross-encoder score."""
        relevance = self.reranker.predict([[query, doc.text] for doc in docs])
        # sorted() is stable, so equal scores keep their incoming order.
        ranked = sorted(zip(docs, relevance), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _score in ranked[:top_k]]


# Production RAG System
class ProductionRAG:
    """Complete RAG system with all components."""

    def __init__(self, config: RAGConfig):
        """Wire up processing, storage, retrieval, generation, monitoring."""
        # Document processing
        self.processor = DocumentProcessor(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            embedding_model=config.embedding_model
        )
        # Storage
        self.vector_store = VectorStore(config.vector_db_config)
        self.bm25_index = BM25Index()
        # Retrieval
        self.retriever = HybridRetriever(
            self.vector_store,
            self.bm25_index,
            vector_weight=config.vector_weight
        )
        # Generation
        self.generator = RAGGenerator(
            llm_client=config.llm_client,
            use_reranking=config.use_reranking
        )
        # Monitoring
        self.monitor = RAGMonitor()

    def ingest(self, documents: list[str]) -> IngestionReport:
        """Ingest documents into the system.

        Each path is processed independently. A failure is recorded on the
        report and does not stop the remaining documents.
        """
        report = IngestionReport()
        for doc_path in documents:
            try:
                chunks = self.processor.process_file(doc_path)
                # Add to vector store
                self.vector_store.upsert(chunks)
                # Add to BM25 index
                for chunk in chunks:
                    self.bm25_index.add(chunk.id, chunk.text)
                report.success(doc_path, len(chunks))
            except Exception as e:
                # Best effort per document: record the failure and continue.
                report.failure(doc_path, str(e))
        return report

    def query(
        self,
        question: str,
        config: QueryConfig = None
    ) -> RAGResponse:
        """Answer a question using RAG.

        Retrieval and generation are timed separately, and the query is
        logged to the monitor before the response is returned.
        """
        config = config or QueryConfig()
        timings = {}

        # Retrieve
        start = time.time()
        docs = self.retriever.search(
            question,
            k=config.retrieve_k,
            filter=config.filter
        )
        timings['retrieval'] = time.time() - start

        # Generate
        start = time.time()
        result = self.generator.generate(
            question,
            docs,
            max_context_tokens=config.max_context_tokens
        )
        timings['generation'] = time.time() - start

        # Log for monitoring. log_query takes (query, retrieved_docs,
        # reranked_docs, response, latencies) and reads doc['score'], so the
        # result objects are converted to dicts and every argument is passed.
        # Reranking runs inside generate() and is not timed on its own, so
        # it is reported as 0.0 and its cost sits in 'generation'.
        self.monitor.log_query(
            question,
            [{'id': d.id, 'score': d.score} for d in docs],
            [{'id': d.id, 'score': d.score} for d in result.retrieved_docs],
            result.answer,
            {'reranking': 0.0, **timings}
        )
        return RAGResponse(
            answer=result.answer,
            citations=result.citations,
            latency=sum(timings.values())
        )


# Comprehensive RAG Evaluation
def evaluate_rag_system(
    rag: ProductionRAG,
    test_set: list[dict],  # {'question': str, 'answer': str, 'source_ids': list[str]}
    llm_judge
) -> EvaluationReport:
    """Comprehensive RAG evaluation.

    For every test case this measures retrieval (recall@5, recall@10 and
    reciprocal rank over the top ten results) and asks ``llm_judge`` to rate
    the generated answer from 1 to 5.

    Returns:
        ``EvaluationReport`` holding the mean of each metric. Metrics are
        0.0 when ``test_set`` is empty. A case with no ``source_ids`` scores
        recall 0.0.

    Raises:
        ValueError: If a judge reply contains no rating between 1 and 5.
    """
    import re  # local import: used only to parse the judge's reply

    def average(values: list) -> float:
        return np.mean(values) if values else 0.0

    retrieval_metrics = []
    generation_metrics = []
    for case in test_set:
        # Test retrieval
        source_ids = set(case['source_ids'])
        docs = rag.retriever.search(case['question'], k=10)
        retrieved_ids = [d.id for d in docs]
        if source_ids:
            recall_at_5 = len(set(retrieved_ids[:5]) & source_ids) / len(source_ids)
            recall_at_10 = len(set(retrieved_ids) & source_ids) / len(source_ids)
        else:
            # Recall is undefined without relevant documents; score it 0.
            recall_at_5 = recall_at_10 = 0.0
        mrr = 0
        for rank, doc_id in enumerate(retrieved_ids, 1):
            if doc_id in source_ids:
                mrr = 1.0 / rank
                break
        retrieval_metrics.append({
            'recall@5': recall_at_5,
            'recall@10': recall_at_10,
            'mrr': mrr
        })

        # Test generation
        response = rag.query(case['question'])

        # LLM judge for answer quality
        judge_prompt = f"""Rate the answer quality on a scale of 1-5.
Question: {case['question']}
Expected: {case['answer']}
Actual: {response.answer}
Rating (1-5):"""
        verdict = llm_judge.generate(judge_prompt)
        # Judges often wrap the number ("Rating: 4", "4/5"); take the first
        # standalone digit in range instead of requiring a bare integer.
        rating_match = re.search(r'\b([1-5])\b', verdict)
        if rating_match is None:
            raise ValueError(f"Judge reply contains no 1-5 rating: {verdict!r}")
        generation_metrics.append({'quality': int(rating_match.group(1))})

    return EvaluationReport(
        retrieval={
            'recall@5': average([m['recall@5'] for m in retrieval_metrics]),
            'recall@10': average([m['recall@10'] for m in retrieval_metrics]),
            'mrr': average([m['mrr'] for m in retrieval_metrics])
        },
        generation={
            'quality': average([m['quality'] for m in generation_metrics])
        }
    )