sach91 bootcamp week8 exercise
@@ -0,0 +1,18 @@
"""
KnowledgeHub Agents
"""
from .base_agent import BaseAgent
from .ingestion_agent import IngestionAgent
from .question_agent import QuestionAgent
from .summary_agent import SummaryAgent
from .connection_agent import ConnectionAgent
from .export_agent import ExportAgent

__all__ = [
    'BaseAgent',
    'IngestionAgent',
    'QuestionAgent',
    'SummaryAgent',
    'ConnectionAgent',
    'ExportAgent'
]
@@ -0,0 +1,91 @@
"""
Base Agent class - Foundation for all specialized agents
"""
from abc import ABC, abstractmethod
import logging
from typing import Optional, Any
from utils.ollama_client import OllamaClient

logger = logging.getLogger(__name__)


class BaseAgent(ABC):
    """Abstract base class for all agents"""

    def __init__(self, name: str, llm_client: Optional[OllamaClient] = None,
                 model: str = "llama3.2"):
        """
        Initialize base agent

        Args:
            name: Agent name for logging
            llm_client: Shared Ollama client (creates a new one if None)
            model: Ollama model to use
        """
        self.name = name
        self.model = model

        # Use the shared client or create a new one
        if llm_client is None:
            self.llm = OllamaClient(model=model)
            logger.info(f"{self.name} initialized with new LLM client (model: {model})")
        else:
            self.llm = llm_client
            logger.info(f"{self.name} initialized with shared LLM client (model: {model})")

    def generate(self, prompt: str, system: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: int = 2048) -> str:
        """
        Generate text using the LLM

        Args:
            prompt: User prompt
            system: System message (optional)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Generated text
        """
        logger.info(f"{self.name} generating response")
        response = self.llm.generate(
            prompt=prompt,
            system=system,
            temperature=temperature,
            max_tokens=max_tokens
        )
        logger.debug(f"{self.name} generated {len(response)} characters")
        return response

    def chat(self, messages: list, temperature: float = 0.7,
             max_tokens: int = 2048) -> str:
        """
        Chat completion with message history

        Args:
            messages: List of message dicts with 'role' and 'content'
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Generated text
        """
        logger.info(f"{self.name} processing chat with {len(messages)} messages")
        response = self.llm.chat(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens
        )
        logger.debug(f"{self.name} generated {len(response)} characters")
        return response

    @abstractmethod
    def process(self, *args, **kwargs) -> Any:
        """
        Main processing method - must be implemented by subclasses

        Each agent implements its specialized logic here
        """
        pass

    def __str__(self):
        return f"{self.name} (model: {self.model})"
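For reference, a minimal concrete subclass might look like the sketch below. EchoAgent and its behavior are purely illustrative (they are not part of this commit); the point is the process() contract and how a subclass reuses the inherited generate() helper.

from agents.base_agent import BaseAgent

class EchoAgent(BaseAgent):
    """Toy agent that rephrases its input - illustrates the process() contract"""

    def __init__(self, llm_client=None, model: str = "llama3.2"):
        super().__init__(name="EchoAgent", llm_client=llm_client, model=model)

    def process(self, text: str) -> str:
        # Delegates to the inherited generate() helper from BaseAgent
        return self.generate(prompt=f"Rephrase this: {text}", temperature=0.2)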
@@ -0,0 +1,289 @@
"""
Connection Agent - Finds relationships and connections between documents
"""
import logging
from typing import Dict, Optional
from agents.base_agent import BaseAgent
from models.knowledge_graph import KnowledgeNode, KnowledgeEdge, KnowledgeGraph
from utils.embeddings import EmbeddingModel
import chromadb
import numpy as np

logger = logging.getLogger(__name__)


class ConnectionAgent(BaseAgent):
    """Agent that discovers connections between documents and concepts"""

    def __init__(self, collection: chromadb.Collection,
                 embedding_model: EmbeddingModel,
                 llm_client=None, model: str = "llama3.2"):
        """
        Initialize connection agent

        Args:
            collection: ChromaDB collection with documents
            embedding_model: Model for computing similarities
            llm_client: Optional shared LLM client
            model: Ollama model name
        """
        super().__init__(name="ConnectionAgent", llm_client=llm_client, model=model)

        self.collection = collection
        self.embedding_model = embedding_model

        logger.info(f"{self.name} initialized")

    def process(self, document_id: Optional[str] = None, query: Optional[str] = None,
                top_k: int = 5) -> Dict:
        """
        Find documents related to a document or query

        Args:
            document_id: ID of the reference document
            query: Search query (used if document_id is not provided)
            top_k: Number of related documents to find

        Returns:
            Dictionary with related documents and connections
        """
        if document_id:
            logger.info(f"{self.name} finding connections for document: {document_id}")
            return self._find_related_to_document(document_id, top_k)
        elif query:
            logger.info(f"{self.name} finding connections for query: {query[:100]}")
            return self._find_related_to_query(query, top_k)
        else:
            return {'related': [], 'error': 'No document_id or query provided'}

    def _find_related_to_document(self, document_id: str, top_k: int) -> Dict:
        """Find documents related to a specific document"""
        try:
            # Get chunks from the document
            results = self.collection.get(
                where={"document_id": document_id},
                include=['embeddings', 'documents', 'metadatas']
            )

            if not results['ids']:
                return {'related': [], 'error': 'Document not found'}

            # Use the first chunk's embedding as representative
            query_embedding = results['embeddings'][0]
            document_name = results['metadatas'][0].get('filename', 'Unknown')

            # Search for similar chunks from OTHER documents
            search_results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k * 3,  # Get more to filter out the same document
                include=['documents', 'metadatas', 'distances']
            )

            # Filter out chunks from the same document
            related = []
            seen_docs = {document_id}

            if search_results['ids']:
                for i in range(len(search_results['ids'][0])):
                    related_doc_id = search_results['metadatas'][0][i].get('document_id')

                    if related_doc_id not in seen_docs:
                        seen_docs.add(related_doc_id)

                        similarity = 1.0 - search_results['distances'][0][i]

                        related.append({
                            'document_id': related_doc_id,
                            'document_name': search_results['metadatas'][0][i].get('filename', 'Unknown'),
                            'similarity': float(similarity),
                            'preview': search_results['documents'][0][i][:150] + "..."
                        })

                        if len(related) >= top_k:
                            break

            return {
                'source_document': document_name,
                'source_id': document_id,
                'related': related,
                'num_related': len(related)
            }

        except Exception as e:
            logger.error(f"Error finding related documents: {e}")
            return {'related': [], 'error': str(e)}

    def _find_related_to_query(self, query: str, top_k: int) -> Dict:
        """Find documents related to a query"""
        try:
            # Generate the query embedding
            query_embedding = self.embedding_model.embed_query(query)

            # Search
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k * 2,  # Get more to deduplicate by document
                include=['documents', 'metadatas', 'distances']
            )

            # Deduplicate by document
            related = []
            seen_docs = set()

            if results['ids']:
                for i in range(len(results['ids'][0])):
                    doc_id = results['metadatas'][0][i].get('document_id')

                    if doc_id not in seen_docs:
                        seen_docs.add(doc_id)

                        similarity = 1.0 - results['distances'][0][i]

                        related.append({
                            'document_id': doc_id,
                            'document_name': results['metadatas'][0][i].get('filename', 'Unknown'),
                            'similarity': float(similarity),
                            'preview': results['documents'][0][i][:150] + "..."
                        })

                        if len(related) >= top_k:
                            break

            return {
                'query': query,
                'related': related,
                'num_related': len(related)
            }

        except Exception as e:
            logger.error(f"Error finding related documents: {e}")
            return {'related': [], 'error': str(e)}

    def build_knowledge_graph(self, similarity_threshold: float = 0.7) -> KnowledgeGraph:
        """
        Build a knowledge graph showing document relationships

        Args:
            similarity_threshold: Minimum cosine similarity to create an edge

        Returns:
            KnowledgeGraph object
        """
        logger.info(f"{self.name} building knowledge graph")

        graph = KnowledgeGraph()

        try:
            # Get all documents
            all_results = self.collection.get(
                include=['embeddings', 'metadatas']
            )

            if not all_results['ids']:
                return graph

            # Group by document, keeping the first chunk's embedding as representative
            documents = {}
            for i, metadata in enumerate(all_results['metadatas']):
                doc_id = metadata.get('document_id')
                if doc_id not in documents:
                    documents[doc_id] = {
                        'name': metadata.get('filename', 'Unknown'),
                        'embedding': all_results['embeddings'][i]
                    }

            # Create nodes
            for doc_id, doc_data in documents.items():
                node = KnowledgeNode(
                    id=doc_id,
                    name=doc_data['name'],
                    node_type='document',
                    description=f"Document: {doc_data['name']}"
                )
                graph.add_node(node)

            # Create edges based on similarity
            doc_ids = list(documents.keys())
            for i, doc_id1 in enumerate(doc_ids):
                emb1 = np.array(documents[doc_id1]['embedding'])

                for doc_id2 in doc_ids[i + 1:]:
                    emb2 = np.array(documents[doc_id2]['embedding'])

                    # Cosine similarity between the two document embeddings
                    similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))

                    if similarity >= similarity_threshold:
                        edge = KnowledgeEdge(
                            source_id=doc_id1,
                            target_id=doc_id2,
                            relationship='similar_to',
                            weight=float(similarity)
                        )
                        graph.add_edge(edge)

            logger.info(f"{self.name} built graph with {len(graph.nodes)} nodes and {len(graph.edges)} edges")
            return graph

        except Exception as e:
            logger.error(f"Error building knowledge graph: {e}")
            return graph

    def explain_connection(self, doc_id1: str, doc_id2: str) -> str:
        """
        Use the LLM to explain why two documents are related

        Args:
            doc_id1: First document ID
            doc_id2: Second document ID

        Returns:
            Explanation text
        """
        try:
            # Get sample chunks from each document
            results1 = self.collection.get(
                where={"document_id": doc_id1},
                limit=2,
                include=['documents', 'metadatas']
            )

            results2 = self.collection.get(
                where={"document_id": doc_id2},
                limit=2,
                include=['documents', 'metadatas']
            )

            if not results1['ids'] or not results2['ids']:
                return "Could not retrieve documents"

            doc1_name = results1['metadatas'][0].get('filename', 'Document 1')
            doc2_name = results2['metadatas'][0].get('filename', 'Document 2')

            doc1_text = " ".join(results1['documents'][:2])[:1000]
            doc2_text = " ".join(results2['documents'][:2])[:1000]

            system_prompt = """You analyze documents and explain their relationships.
Provide a brief, clear explanation of how two documents are related."""

            user_prompt = f"""Analyze these two documents and explain how they are related:

Document 1 ({doc1_name}):
{doc1_text}

Document 2 ({doc2_name}):
{doc2_text}

How are these documents related? Provide a concise explanation:"""

            explanation = self.generate(
                prompt=user_prompt,
                system=system_prompt,
                temperature=0.3,
                max_tokens=256
            )

            return explanation

        except Exception as e:
            logger.error(f"Error explaining connection: {e}")
            return f"Error: {str(e)}"
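A quick usage sketch for this agent. The bootstrap is an assumption about the surrounding app, which this commit does not show: the storage path, the collection name, and the default EmbeddingModel constructor are all illustrative.

import chromadb
from utils.embeddings import EmbeddingModel
from agents.connection_agent import ConnectionAgent

# Assumed bootstrap - path and collection name are illustrative
client = chromadb.PersistentClient(path="./knowledge_db")
collection = client.get_or_create_collection("documents")
embedder = EmbeddingModel()  # assuming a default constructor

agent = ConnectionAgent(collection=collection, embedding_model=embedder)
related = agent.process(query="vector databases", top_k=3)
for doc in related['related']:
    print(f"{doc['document_name']}: similarity {doc['similarity']:.2f}")

graph = agent.build_knowledge_graph(similarity_threshold=0.75)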
@@ -0,0 +1,233 @@
"""
Export Agent - Generates formatted reports and exports
"""
import logging
from typing import List, Dict
from datetime import datetime
from agents.base_agent import BaseAgent
from models.document import Summary

logger = logging.getLogger(__name__)


class ExportAgent(BaseAgent):
    """Agent that exports summaries and reports in various formats"""

    def __init__(self, llm_client=None, model: str = "llama3.2"):
        """
        Initialize export agent

        Args:
            llm_client: Optional shared LLM client
            model: Ollama model name
        """
        super().__init__(name="ExportAgent", llm_client=llm_client, model=model)

        logger.info(f"{self.name} initialized")

    def process(self, content: Dict, format: str = "markdown") -> str:
        """
        Export content in the specified format

        Args:
            content: Content dictionary to export
            format: Export format ('markdown', 'text', 'html')

        Returns:
            Formatted content string
        """
        logger.info(f"{self.name} exporting as {format}")

        if format == "markdown":
            return self._export_markdown(content)
        elif format == "text":
            return self._export_text(content)
        elif format == "html":
            return self._export_html(content)
        else:
            return str(content)

    def _export_markdown(self, content: Dict) -> str:
        """Export as Markdown"""
        md = []
        md.append("# Knowledge Report")
        md.append(f"\n*Generated: {datetime.now().strftime('%Y-%m-%d %H:%M')}*\n")

        if 'title' in content:
            md.append(f"## {content['title']}\n")

        if 'summary' in content:
            md.append("### Summary\n")
            md.append(f"{content['summary']}\n")

        if 'key_points' in content and content['key_points']:
            md.append("### Key Points\n")
            for point in content['key_points']:
                md.append(f"- {point}")
            md.append("")

        if 'sections' in content:
            for section in content['sections']:
                md.append(f"### {section['title']}\n")
                md.append(f"{section['content']}\n")

        if 'sources' in content and content['sources']:
            md.append("### Sources\n")
            for i, source in enumerate(content['sources'], 1):
                md.append(f"{i}. {source}")
            md.append("")

        return "\n".join(md)

    def _export_text(self, content: Dict) -> str:
        """Export as plain text"""
        lines = []
        lines.append("=" * 60)
        lines.append("KNOWLEDGE REPORT")
        lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M')}")
        lines.append("=" * 60)
        lines.append("")

        if 'title' in content:
            lines.append(content['title'])
            lines.append("-" * len(content['title']))
            lines.append("")

        if 'summary' in content:
            lines.append("SUMMARY:")
            lines.append(content['summary'])
            lines.append("")

        if 'key_points' in content and content['key_points']:
            lines.append("KEY POINTS:")
            for i, point in enumerate(content['key_points'], 1):
                lines.append(f"  {i}. {point}")
            lines.append("")

        if 'sections' in content:
            for section in content['sections']:
                lines.append(section['title'].upper())
                lines.append("-" * 40)
                lines.append(section['content'])
                lines.append("")

        if 'sources' in content and content['sources']:
            lines.append("SOURCES:")
            for i, source in enumerate(content['sources'], 1):
                lines.append(f"  {i}. {source}")

        lines.append("")
        lines.append("=" * 60)

        return "\n".join(lines)

    def _export_html(self, content: Dict) -> str:
        """Export as HTML"""
        html = []
        html.append("<!DOCTYPE html>")
        html.append("<html>")
        html.append("<head>")
        html.append("    <meta charset='utf-8'>")
        html.append("    <title>Knowledge Report</title>")
        html.append("    <style>")
        html.append("        body { font-family: Arial, sans-serif; max-width: 800px; margin: 40px auto; padding: 20px; }")
        html.append("        h1 { color: #333; border-bottom: 3px solid #007bff; padding-bottom: 10px; }")
        html.append("        h2 { color: #555; margin-top: 30px; }")
        html.append("        .meta { color: #888; font-style: italic; }")
        html.append("        .key-points { background: #f8f9fa; padding: 15px; border-left: 4px solid #007bff; }")
        html.append("        .source { color: #666; font-size: 0.9em; }")
        html.append("    </style>")
        html.append("</head>")
        html.append("<body>")

        html.append("    <h1>Knowledge Report</h1>")
        html.append(f"    <p class='meta'>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M')}</p>")

        if 'title' in content:
            html.append(f"    <h2>{content['title']}</h2>")

        if 'summary' in content:
            html.append("    <h3>Summary</h3>")
            html.append(f"    <p>{content['summary']}</p>")

        if 'key_points' in content and content['key_points']:
            html.append("    <h3>Key Points</h3>")
            html.append("    <div class='key-points'>")
            html.append("        <ul>")
            for point in content['key_points']:
                html.append(f"            <li>{point}</li>")
            html.append("        </ul>")
            html.append("    </div>")

        if 'sections' in content:
            for section in content['sections']:
                html.append(f"    <h3>{section['title']}</h3>")
                html.append(f"    <p>{section['content']}</p>")

        if 'sources' in content and content['sources']:
            html.append("    <h3>Sources</h3>")
            html.append("    <ol class='source'>")
            for source in content['sources']:
                html.append(f"        <li>{source}</li>")
            html.append("    </ol>")

        html.append("</body>")
        html.append("</html>")

        return "\n".join(html)

    def create_study_guide(self, summaries: List[Summary]) -> str:
        """
        Create a study guide from multiple summaries

        Args:
            summaries: List of Summary objects

        Returns:
            Formatted study guide
        """
        logger.info(f"{self.name} creating study guide from {len(summaries)} summaries")

        # Compile all content
        all_summaries = "\n\n".join([
            f"{s.document_name}:\n{s.summary_text}"
            for s in summaries
        ])

        all_key_points = []
        for s in summaries:
            all_key_points.extend(s.key_points)

        # Use the LLM to create a cohesive study guide
        system_prompt = """You create excellent study guides that synthesize information from multiple sources.
Create a well-organized study guide with clear sections, key concepts, and important points."""

        user_prompt = f"""Create a comprehensive study guide based on these document summaries:

{all_summaries}

Create a well-structured study guide with:
1. An overview
2. Key concepts
3. Important details
4. Study tips

Study Guide:"""

        study_guide = self.generate(
            prompt=user_prompt,
            system=system_prompt,
            temperature=0.5,
            max_tokens=2048
        )

        # Format as markdown
        content = {
            'title': 'Study Guide',
            'sections': [
                {'title': 'Overview', 'content': study_guide},
                {'title': 'Key Points from All Documents', 'content': '\n'.join([f"• {p}" for p in all_key_points[:15]])}
            ],
            'sources': [s.document_name for s in summaries]
        }

        return self._export_markdown(content)
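A minimal usage sketch. The report dict mirrors the keys process() looks for; every value below is made up for illustration.

from agents.export_agent import ExportAgent

agent = ExportAgent()
report = {
    'title': 'Week 8 Notes',
    'summary': 'Agents coordinate ingestion, Q&A, and export.',
    'key_points': ['Shared LLM client', 'ChromaDB storage'],
    'sources': ['notes.pdf']
}
print(agent.process(report, format="markdown"))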
@@ -0,0 +1,157 @@
"""
Ingestion Agent - Processes and stores documents in the vector database
"""
import logging
from typing import Dict
import uuid
from datetime import datetime

from agents.base_agent import BaseAgent
from models.document import Document, DocumentChunk
from utils.document_parser import DocumentParser
from utils.embeddings import EmbeddingModel
import chromadb

logger = logging.getLogger(__name__)


class IngestionAgent(BaseAgent):
    """Agent responsible for ingesting and storing documents"""

    def __init__(self, collection: chromadb.Collection,
                 embedding_model: EmbeddingModel,
                 llm_client=None, model: str = "llama3.2"):
        """
        Initialize ingestion agent

        Args:
            collection: ChromaDB collection for storage
            embedding_model: Model for generating embeddings
            llm_client: Optional shared LLM client
            model: Ollama model name
        """
        super().__init__(name="IngestionAgent", llm_client=llm_client, model=model)

        self.collection = collection
        self.embedding_model = embedding_model
        self.parser = DocumentParser(chunk_size=1000, chunk_overlap=200)

        logger.info(f"{self.name} ready with ChromaDB collection")

    def process(self, file_path: str) -> Document:
        """
        Process and ingest a document

        Args:
            file_path: Path to the document file

        Returns:
            Document object with metadata
        """
        logger.info(f"{self.name} processing: {file_path}")

        # Parse the document
        parsed = self.parser.parse_file(file_path)

        # Generate a document ID
        doc_id = str(uuid.uuid4())

        # Create document chunks
        chunks = []
        chunk_texts = []
        chunk_ids = []
        chunk_metadatas = []

        for i, chunk_text in enumerate(parsed['chunks']):
            chunk_id = f"{doc_id}_chunk_{i}"

            chunk = DocumentChunk(
                id=chunk_id,
                document_id=doc_id,
                content=chunk_text,
                chunk_index=i,
                metadata={
                    'filename': parsed['filename'],
                    'extension': parsed['extension'],
                    'total_chunks': len(parsed['chunks'])
                }
            )

            chunks.append(chunk)
            chunk_texts.append(chunk_text)
            chunk_ids.append(chunk_id)
            chunk_metadatas.append({
                'document_id': doc_id,
                'filename': parsed['filename'],
                'chunk_index': i,
                'extension': parsed['extension']
            })

        # Generate embeddings
        logger.info(f"{self.name} generating embeddings for {len(chunks)} chunks")
        embeddings = self.embedding_model.embed_documents(chunk_texts)

        # Store in ChromaDB
        logger.info(f"{self.name} storing in ChromaDB")
        self.collection.add(
            ids=chunk_ids,
            documents=chunk_texts,
            embeddings=embeddings,
            metadatas=chunk_metadatas
        )

        # Create the document object
        document = Document(
            id=doc_id,
            filename=parsed['filename'],
            filepath=parsed['filepath'],
            content=parsed['text'],
            chunks=chunks,
            metadata={
                'extension': parsed['extension'],
                'num_chunks': len(chunks),
                'total_chars': parsed['total_chars']
            },
            created_at=datetime.now()
        )

        logger.info(f"{self.name} successfully ingested: {document}")
        return document

    def get_statistics(self) -> Dict:
        """Get statistics about stored documents"""
        try:
            count = self.collection.count()
            return {
                'total_chunks': count,
                'collection_name': self.collection.name
            }
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            return {'total_chunks': 0, 'error': str(e)}

    def delete_document(self, document_id: str) -> bool:
        """
        Delete all chunks of a document

        Args:
            document_id: ID of the document to delete

        Returns:
            True if successful
        """
        try:
            # Get all chunk IDs for this document
            results = self.collection.get(
                where={"document_id": document_id}
            )

            if results['ids']:
                self.collection.delete(ids=results['ids'])
                logger.info(f"{self.name} deleted document {document_id}")
                return True

            return False

        except Exception as e:
            logger.error(f"Error deleting document: {e}")
            return False
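Ingestion in practice, as a sketch. It reuses the assumed collection/embedder bootstrap from the ConnectionAgent sketch above; the file path is hypothetical.

from agents.ingestion_agent import IngestionAgent

ingester = IngestionAgent(collection=collection, embedding_model=embedder)
doc = ingester.process("notes/week8_lecture.pdf")  # hypothetical file path
print(f"Stored {doc.metadata['num_chunks']} chunks")
print(ingester.get_statistics())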
@@ -0,0 +1,156 @@
"""
Question Agent - Answers questions using RAG (Retrieval-Augmented Generation)
"""
import logging
from typing import Any, Dict, List, Optional
from agents.base_agent import BaseAgent
from models.document import SearchResult, DocumentChunk
from utils.embeddings import EmbeddingModel
import chromadb

logger = logging.getLogger(__name__)


class QuestionAgent(BaseAgent):
    """Agent that answers questions using retrieved context"""

    def __init__(self, collection: chromadb.Collection,
                 embedding_model: EmbeddingModel,
                 llm_client=None, model: str = "llama3.2"):
        """
        Initialize question agent

        Args:
            collection: ChromaDB collection with documents
            embedding_model: Model for query embeddings
            llm_client: Optional shared LLM client
            model: Ollama model name
        """
        super().__init__(name="QuestionAgent", llm_client=llm_client, model=model)

        self.collection = collection
        self.embedding_model = embedding_model
        self.top_k = 5  # Number of chunks to retrieve

        logger.info(f"{self.name} initialized")

    def retrieve(self, query: str, top_k: Optional[int] = None) -> List[SearchResult]:
        """
        Retrieve relevant document chunks for a query

        Args:
            query: Search query
            top_k: Number of results to return (uses self.top_k if None)

        Returns:
            List of SearchResult objects
        """
        if top_k is None:
            top_k = self.top_k

        logger.info(f"{self.name} retrieving top {top_k} chunks for query")

        # Generate the query embedding
        query_embedding = self.embedding_model.embed_query(query)

        # Search ChromaDB
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )

        # Convert to SearchResult objects
        search_results = []

        if results['ids'] and len(results['ids']) > 0:
            for i in range(len(results['ids'][0])):
                chunk = DocumentChunk(
                    id=results['ids'][0][i],
                    document_id=results['metadatas'][0][i].get('document_id', ''),
                    content=results['documents'][0][i],
                    chunk_index=results['metadatas'][0][i].get('chunk_index', 0),
                    metadata=results['metadatas'][0][i]
                )

                result = SearchResult(
                    chunk=chunk,
                    score=1.0 - results['distances'][0][i],  # Convert distance to similarity
                    document_id=results['metadatas'][0][i].get('document_id', ''),
                    document_name=results['metadatas'][0][i].get('filename', 'Unknown')
                )

                search_results.append(result)

        logger.info(f"{self.name} retrieved {len(search_results)} results")
        return search_results

    def process(self, question: str, top_k: Optional[int] = None) -> Dict[str, Any]:
        """
        Answer a question using RAG

        Args:
            question: User's question
            top_k: Number of chunks to retrieve

        Returns:
            Dictionary with the answer and sources
        """
        logger.info(f"{self.name} processing question: {question[:100]}...")

        # Retrieve relevant chunks
        search_results = self.retrieve(question, top_k)

        if not search_results:
            return {
                'answer': "I don't have any relevant information in my knowledge base to answer this question.",
                'sources': [],
                'context_used': ""
            }

        # Build context from the retrieved chunks
        context_parts = []
        sources = []

        for i, result in enumerate(search_results, 1):
            context_parts.append(f"[Source {i}] {result.chunk.content}")
            sources.append({
                'document': result.document_name,
                'score': result.score,
                'preview': result.chunk.content[:150] + "..."
            })

        context = "\n\n".join(context_parts)

        # Create the prompt for the LLM
        system_prompt = """You are a helpful research assistant. Answer questions based on the provided context.
Be accurate and cite sources when possible. If the context doesn't contain enough information to answer fully, say so.
Keep your answer concise and relevant."""

        user_prompt = f"""Context from my knowledge base:

{context}

Question: {question}

Answer based on the context above. If you reference specific information, mention which source(s) you're using."""

        # Generate the answer
        answer = self.generate(
            prompt=user_prompt,
            system=system_prompt,
            temperature=0.3,  # Lower temperature for more factual responses
            max_tokens=1024
        )

        logger.info(f"{self.name} generated answer ({len(answer)} chars)")

        return {
            'answer': answer,
            'sources': sources,
            'context_used': context,
            'num_sources': len(sources)
        }

    def set_top_k(self, k: int):
        """Set the number of chunks to retrieve"""
        self.top_k = k
        logger.info(f"{self.name} top_k set to {k}")
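The RAG flow end to end, as a sketch using the same assumed collection/embedder bootstrap as the earlier examples; the question text is illustrative.

from agents.question_agent import QuestionAgent

qa = QuestionAgent(collection=collection, embedding_model=embedder)
result = qa.process("What does the ingestion pipeline store in ChromaDB?")
print(result['answer'])
for src in result['sources']:
    print(f"  {src['document']} (score {src['score']:.2f})")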
@@ -0,0 +1,181 @@
"""
Summary Agent - Creates summaries and extracts key points from documents
"""
import logging
from typing import List, Optional
from agents.base_agent import BaseAgent
from models.document import Summary
import chromadb

logger = logging.getLogger(__name__)


class SummaryAgent(BaseAgent):
    """Agent that creates summaries of documents"""

    def __init__(self, collection: chromadb.Collection,
                 llm_client=None, model: str = "llama3.2"):
        """
        Initialize summary agent

        Args:
            collection: ChromaDB collection with documents
            llm_client: Optional shared LLM client
            model: Ollama model name
        """
        super().__init__(name="SummaryAgent", llm_client=llm_client, model=model)
        self.collection = collection

        logger.info(f"{self.name} initialized")

    def process(self, document_id: Optional[str] = None, document_text: Optional[str] = None,
                document_name: str = "Unknown") -> Summary:
        """
        Create a summary of a document

        Args:
            document_id: ID of the document in ChromaDB (retrieves chunks if provided)
            document_text: Full document text (used if document_id is not provided)
            document_name: Name of the document

        Returns:
            Summary object
        """
        logger.info(f"{self.name} creating summary for: {document_name}")

        # Get the document text
        if document_id:
            text = self._get_document_text(document_id)
            if not text:
                return Summary(
                    document_id=document_id,
                    document_name=document_name,
                    summary_text="Error: Could not retrieve document",
                    key_points=[]
                )
        elif document_text:
            text = document_text
        else:
            return Summary(
                document_id="",
                document_name=document_name,
                summary_text="Error: No document provided",
                key_points=[]
            )

        # Truncate if too long (to fit in the context window)
        max_chars = 8000
        if len(text) > max_chars:
            logger.warning(f"{self.name} truncating document from {len(text)} to {max_chars} chars")
            text = text[:max_chars] + "\n\n[Document truncated...]"

        # Generate the summary
        summary_text = self._generate_summary(text)

        # Extract key points
        key_points = self._extract_key_points(text)

        summary = Summary(
            document_id=document_id or "",
            document_name=document_name,
            summary_text=summary_text,
            key_points=key_points
        )

        logger.info(f"{self.name} completed summary with {len(key_points)} key points")
        return summary

    def _get_document_text(self, document_id: str) -> str:
        """Retrieve and reconstruct the document text from its chunks"""
        try:
            results = self.collection.get(
                where={"document_id": document_id}
            )

            if not results['ids']:
                return ""

            # Sort by chunk index
            chunks_data = list(zip(
                results['documents'],
                results['metadatas']
            ))

            chunks_data.sort(key=lambda x: x[1].get('chunk_index', 0))

            # Combine the chunks
            text = "\n\n".join([chunk[0] for chunk in chunks_data])
            return text

        except Exception as e:
            logger.error(f"Error retrieving document: {e}")
            return ""

    def _generate_summary(self, text: str) -> str:
        """Generate a concise summary of the text"""
        system_prompt = """You are an expert at creating concise, informative summaries.
Your summaries capture the main ideas and key information in clear, accessible language.
Keep summaries to 3-5 sentences unless the document is very long."""

        user_prompt = f"""Please create a concise summary of the following document:

{text}

Summary:"""

        summary = self.generate(
            prompt=user_prompt,
            system=system_prompt,
            temperature=0.3,
            max_tokens=512
        )

        return summary.strip()

    def _extract_key_points(self, text: str) -> List[str]:
        """Extract key points from the text"""
        system_prompt = """You extract the most important key points from documents.
List 3-7 key points as concise bullet points. Each point should be a complete, standalone statement."""

        user_prompt = f"""Please extract the key points from the following document:

{text}

List the key points (one per line, without bullets or numbers):"""

        response = self.generate(
            prompt=user_prompt,
            system=system_prompt,
            temperature=0.3,
            max_tokens=512
        )

        # Parse the response into a list
        key_points = []
        for line in response.split('\n'):
            line = line.strip()
            # Remove common list markers
            line = line.lstrip('•-*0123456789.)')
            line = line.strip()

            if line and len(line) > 10:  # Filter out very short lines
                key_points.append(line)

        return key_points[:7]  # Limit to 7 points

    def summarize_multiple(self, document_ids: List[str]) -> List[Summary]:
        """
        Create summaries for multiple documents

        Args:
            document_ids: List of document IDs

        Returns:
            List of Summary objects
        """
        summaries = []

        for doc_id in document_ids:
            summary = self.process(document_id=doc_id)
            summaries.append(summary)

        return summaries
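Putting the pieces together: a sketch of how SummaryAgent and ExportAgent might be chained, again with the assumed collection from earlier; the document IDs are hypothetical.

from agents.summary_agent import SummaryAgent
from agents.export_agent import ExportAgent

summarizer = SummaryAgent(collection=collection)
exporter = ExportAgent(llm_client=summarizer.llm)  # reuse the shared Ollama client

summaries = summarizer.summarize_multiple(["doc-id-1", "doc-id-2"])  # hypothetical IDs
print(exporter.create_study_guide(summaries))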