import os
from enum import Enum
from pathlib import Path
from typing import Any, Generator, List, Tuple

import numpy as np
import plotly.graph_objects as go
from chromadb import PersistentClient
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.memory import ConversationBufferMemory
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from sklearn.manifold import TSNE
# Locate the project-root .env (four levels above this file) and load it so
# the OpenAI clients can pick up their API keys.
cur_path = Path(__file__)
env_path = cur_path.parent.parent.parent.parent / '.env'
assert env_path.exists(), "Please add a .env file to the root project path"

load_dotenv(dotenv_path=env_path)


class Rag(Enum):
    """Shared configuration for the RAG helpers; read members via ``.value``."""

    GPT_MODEL = "gpt-4o-mini"
    HUG_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    EMBED_MODEL = OpenAIEmbeddings()
    DB_NAME = "vector_db"


def add_metadata(doc: Document, doc_type: str) -> Document:
    """
    Add metadata to a Document object.

    :param doc: The Document object to add metadata to.
    :type doc: Document
    :param doc_type: The type of document to be added as metadata.
    :type doc_type: str
    :return: The Document object with added metadata.
    :rtype: Document
    """
    doc.metadata["doc_type"] = doc_type
    return doc


def get_chunks(folders: Generator[Path, None, None], file_ext: str = '.txt') -> List[Document]:
    """
    Load documents from the given folders, add metadata, and split them into chunks.

    :param folders: Generator of folder paths containing documents.
    :type folders: Generator[Path, None, None]
    :param file_ext:
        The file extension to load from the local knowledge base (e.g. '.txt').
    :type file_ext: str
    :return: List of document chunks.
    :rtype: List[Document]
    """
    text_loader_kwargs = {'encoding': 'utf-8'}
    documents = []
    for folder in folders:
        # The folder name (e.g. 'products') becomes the chunk's doc_type.
        doc_type = os.path.basename(folder)
        loader = DirectoryLoader(
            str(folder), glob=f"**/*{file_ext}", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs
        )
        folder_docs = loader.load()
        documents.extend([add_metadata(doc, doc_type) for doc in folder_docs])

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)

    return chunks
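

# Example (hypothetical layout): with knowledge_base/products/*.txt and
# knowledge_base/company/*.txt on disk, passing the knowledge_base sub-folders
# to get_chunks() yields chunks whose metadata["doc_type"] is "products" or
# "company" respectively.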


def create_vector_db(db_name: str, chunks: List[Document], embeddings: Any) -> Any:
    """
    Create a vector database from document chunks.

    :param db_name: Name of the database to create.
    :type db_name: str
    :param chunks: List of document chunks.
    :type chunks: List[Document]
    :param embeddings: Embedding function to use.
    :type embeddings: Any
    :return: Created vector store.
    :rtype: Any
    """
    # Delete the collection if it already exists
    if os.path.exists(db_name):
        Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()

    # Create the vectorstore
    vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)

    return vectorstore


def get_local_vector_db(path: str) -> Any:
    """
    Get a local vector database.

    :param path: Path to the local vector database.
    :type path: str
    :return: Persistent client for the vector database.
    :rtype: Any
    """
    return PersistentClient(path=path)
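

# Example reopening an existing store for plotting (assumption: langchain's
# default Chroma collection name, "langchain"):
#   collection = get_local_vector_db(Rag.DB_NAME.value).get_collection("langchain")
#   fig = get_2d_plot(collection)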


def get_vector_db_info(vector_store: Any) -> None:
    """
    Print information about the vector database.

    :param vector_store: Vector store to get information from.
    :type vector_store: Any
    """
    collection = vector_store._collection
    count = collection.count()

    # Inspect one stored embedding to report the dimensionality.
    sample_embedding = collection.get(limit=1, include=["embeddings"])["embeddings"][0]
    dimensions = len(sample_embedding)

    print(f"There are {count:,} vectors with {dimensions:,} dimensions in the vector store")


def get_plot_data(collection: Any) -> Tuple[np.ndarray, List[str], List[str], List[str]]:
    """
    Get plot data from a collection.

    :param collection: Collection to get data from.
    :type collection: Any
    :return: Tuple containing vectors, colors, document types, and documents.
    :rtype: Tuple[np.ndarray, List[str], List[str], List[str]]
    """
    result = collection.get(include=['embeddings', 'documents', 'metadatas'])
    vectors = np.array(result['embeddings'])
    documents = result['documents']
    metadatas = result['metadatas']
    doc_types = [metadata['doc_type'] for metadata in metadatas]
    # One fixed colour per known doc_type.
    color_map = {'products': 'blue', 'employees': 'green', 'contracts': 'red', 'company': 'orange'}
    colors = [color_map[doc_type] for doc_type in doc_types]

    return vectors, colors, doc_types, documents


def get_2d_plot(collection: Any) -> go.Figure:
    """
    Generate a 2D plot of the vector store.

    :param collection: Collection to generate plot from.
    :type collection: Any
    :return: 2D scatter plot figure.
    :rtype: go.Figure
    """
    vectors, colors, doc_types, documents = get_plot_data(collection)
    # Reduce the high-dimensional embeddings to 2D for visualisation.
    tsne = TSNE(n_components=2, random_state=42)
    reduced_vectors = tsne.fit_transform(vectors)

    fig = go.Figure(data=[go.Scatter(
        x=reduced_vectors[:, 0],
        y=reduced_vectors[:, 1],
        mode='markers',
        marker=dict(size=5, color=colors, opacity=0.8),
        text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(doc_types, documents)],
        hoverinfo='text'
    )])

    fig.update_layout(
        title='2D Chroma Vector Store Visualization',
        xaxis_title='x',
        yaxis_title='y',
        width=800,
        height=600,
        margin=dict(r=20, b=10, l=10, t=40)
    )

    return fig


def get_3d_plot(collection: Any) -> go.Figure:
    """
    Generate a 3D plot of the vector store.

    :param collection: Collection to generate plot from.
    :type collection: Any
    :return: 3D scatter plot figure.
    :rtype: go.Figure
    """
    vectors, colors, doc_types, documents = get_plot_data(collection)
    tsne = TSNE(n_components=3, random_state=42)
    reduced_vectors = tsne.fit_transform(vectors)

    fig = go.Figure(data=[go.Scatter3d(
        x=reduced_vectors[:, 0],
        y=reduced_vectors[:, 1],
        z=reduced_vectors[:, 2],
        mode='markers',
        marker=dict(size=5, color=colors, opacity=0.8),
        text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(doc_types, documents)],
        hoverinfo='text'
    )])

    fig.update_layout(
        title='3D Chroma Vector Store Visualization',
        scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),
        width=900,
        height=700,
        margin=dict(r=20, b=10, l=10, t=40)
    )

    return fig


def get_conversation_chain(vectorstore: Any) -> ConversationalRetrievalChain:
    """
    Create a conversation chain using the vector store.

    :param vectorstore: Vector store to use in the conversation chain.
    :type vectorstore: Any
    :return: Conversational retrieval chain.
    :rtype: ConversationalRetrievalChain
    """
    llm = ChatOpenAI(temperature=0.7, model_name=Rag.GPT_MODEL.value)

    # Keep the full chat history in memory so follow-up questions have context.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')

    # Retrieve the 25 most similar chunks for each question.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 25})

    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
    )

    return conversation_chain


def get_lang_doc(document_text: str, doc_id: str, metadata: dict = None) -> Document:
    """
    Build a langchain Document that can be used to create a chroma database.

    :type document_text: str
    :param document_text:
        The text to add to a document object.
    :type doc_id: str
    :param doc_id:
        The document id to include.
    :type metadata: dict
    :param metadata:
        A dictionary of metadata to associate with the document object. This
        helps filter items in a vector database.
    """
    # Document does not accept an `encoding` kwarg; the text should already be
    # decoded before it gets here.
    return Document(
        page_content=document_text,
        id=doc_id,
        metadata=metadata or {},
    )
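

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumptions: a local
# `knowledge_base/` directory whose sub-folders match the doc_types used in
# get_plot_data() ('products', 'employees', 'contracts', 'company'), and an
# OPENAI_API_KEY in the .env loaded above. Adjust paths and the question to
# your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    folders = (p for p in Path("knowledge_base").iterdir() if p.is_dir())
    chunks = get_chunks(folders)
    vectorstore = create_vector_db(Rag.DB_NAME.value, chunks, Rag.EMBED_MODEL.value)
    get_vector_db_info(vectorstore)

    chain = get_conversation_chain(vectorstore)
    result = chain.invoke({"question": "What can you tell me about the company?"})
    print(result["answer"])

    # Visualise the embedding space.
    get_2d_plot(vectorstore._collection).show()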