Umar - Bootcamp
This commit is contained in:
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import random
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import login
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import chromadb
|
||||
from tqdm import tqdm
|
||||
|
||||
# Load environment variables from a local .env file, overriding any existing values.
load_dotenv(override=True)

# Fall back to a placeholder so a missing token fails loudly at login time
# rather than with a KeyError here.
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')

hf_token = os.environ['HF_TOKEN']
# Authenticate with the Hugging Face Hub (needed to download the embedding model).
login(hf_token, add_to_git_credential=True)

# Directory of the persistent ChromaDB vector store this script (re)builds.
DB = "travel_vectorstore"
# Deal categories; each has a matching generator function below.
CATEGORIES = ['Flights', 'Hotels', 'Car_Rentals', 'Vacation_Packages', 'Cruises', 'Activities']

# Value pools the synthetic-data generators sample from.
AIRLINES = ['American Airlines', 'Delta', 'United', 'Southwest', 'JetBlue', 'Spirit', 'Frontier', 'Alaska Airlines', 'Emirates', 'British Airways', 'Air France', 'Lufthansa', 'Qatar Airways']
CITIES = ['New York', 'Los Angeles', 'Chicago', 'Houston', 'Miami', 'San Francisco', 'Boston', 'Seattle', 'Denver', 'Atlanta', 'Las Vegas', 'Orlando', 'Phoenix', 'London', 'Paris', 'Tokyo', 'Dubai', 'Singapore', 'Sydney', 'Rome']
HOTELS = ['Hilton', 'Marriott', 'Hyatt', 'Holiday Inn', 'Best Western', 'Sheraton', 'Ritz-Carlton', 'Four Seasons', 'Westin', 'Radisson']
CLASSES = ['Economy', 'Premium Economy', 'Business', 'First Class']
CAR_COMPANIES = ['Hertz', 'Enterprise', 'Avis', 'Budget', 'National', 'Alamo']
CAR_TYPES = ['Compact', 'Sedan', 'SUV', 'Luxury', 'Van']
|
||||
|
||||
def generate_flight_description():
    """Generate one synthetic flight deal.

    Returns:
        tuple[str, int]: (description text, price in USD).
    """
    airline = random.choice(AIRLINES)
    source = random.choice(CITIES)
    # Destination must differ from the origin city.
    dest = random.choice([c for c in CITIES if c != source])
    flight_class = random.choice(CLASSES)
    stops = random.choice(['non-stop', 'one-stop', 'two-stops'])
    duration = f"{random.randint(1, 15)} hours {random.randint(0, 59)} minutes"

    description = f"{airline} {flight_class} {stops} flight from {source} to {dest}. "
    description += f"Flight duration approximately {duration}. "

    if random.random() > 0.5:
        # Pluralize based on the actual bag count. Previously the trailing "s"
        # was appended on an independent coin flip, which could produce
        # "Includes 2 checked bag." or "Includes 1 checked bags.".
        bags = random.randint(1, 2)
        description += f"Includes {bags} checked bag"
        if bags > 1:
            description += "s"
        description += ". "

    # Premium cabins advertise an extra perk.
    if flight_class in ['Business', 'First Class']:
        description += random.choice(['Priority boarding included. ', 'Lounge access available. ', 'Lie-flat seats. '])

    # Premium cabins are priced from a higher band.
    price = random.randint(150, 2500) if flight_class == 'Economy' else random.randint(800, 8000)
    return description, price
|
||||
|
||||
def generate_hotel_description():
    """Generate one synthetic hotel deal.

    Returns:
        tuple[str, int]: (description text, total price for the whole stay).
    """
    brand = random.choice(HOTELS)
    location = random.choice(CITIES)
    star_rating = random.randint(2, 5)
    room = random.choice(['Standard Room', 'Deluxe Room', 'Suite', 'Executive Suite'])
    stay_length = random.randint(1, 7)

    plural = "s" if stay_length > 1 else ""
    description = (
        f"{brand} {star_rating}-star hotel in {location}. "
        f"{room} for {stay_length} night{plural}. "
    )

    # Each amenity is included with its own independent probability,
    # rolled in a fixed order.
    amenity_rolls = [
        (0.3, 'Free WiFi'),
        (0.5, 'Breakfast included'),
        (0.6, 'Pool access'),
        (0.7, 'Gym'),
        (0.8, 'Spa services'),
    ]
    included = [name for threshold, name in amenity_rolls if random.random() > threshold]
    if included:
        description += f"Amenities: {', '.join(included)}. "

    # Higher-star properties draw from a pricier nightly band.
    nightly_rate = random.randint(80, 500) if star_rating <= 3 else random.randint(200, 1200)
    return description, nightly_rate * stay_length
|
||||
|
||||
def generate_car_rental_description():
    """Generate one synthetic car-rental deal.

    Returns:
        tuple[str, int]: (description text, total price for the rental period).
    """
    # Daily-rate bands per vehicle class: (min USD, max USD).
    daily_rate_ranges = {
        'Compact': (25, 45),
        'Sedan': (35, 65),
        'SUV': (50, 90),
        'Luxury': (80, 200),
        'Van': (60, 100),
    }

    company = random.choice(CAR_COMPANIES)
    car_type = random.choice(CAR_TYPES)
    city = random.choice(CITIES)
    days = random.randint(1, 14)

    description = f"{company} car rental in {city}. {car_type} class vehicle for {days} day"
    if days > 1:
        description += "s"
    description += ". "

    if random.random() > 0.6:
        description += "Unlimited mileage included. "
    if random.random() > 0.5:
        description += "Airport pickup available. "
    if random.random() > 0.7:
        description += "GPS navigation included. "

    # Draw a rate only for the chosen class. The original drew a random rate
    # for every class on each call, consuming RNG state on four unused values.
    low, high = daily_rate_ranges[car_type]
    total_price = random.randint(low, high) * days

    return description, total_price
|
||||
|
||||
def generate_vacation_package_description():
    """Generate one synthetic vacation-package deal.

    Returns:
        tuple[str, int]: (description text, package price in USD).
    """
    city = random.choice(CITIES)
    nights = random.randint(3, 10)

    description = f"All-inclusive vacation package to {city} for {nights} nights. "
    # Note: no trailing comma here; it is added only when extras follow.
    description += f"Includes round-trip {random.choice(CLASSES)} flights, {random.choice(HOTELS)} hotel accommodation"

    # Optional extras, each rolled independently.
    extras = []
    if random.random() > 0.3:
        extras.append('daily breakfast')
    if random.random() > 0.5:
        extras.append('airport transfers')
    if random.random() > 0.6:
        extras.append('city tour')
    if random.random() > 0.7:
        extras.append('travel insurance')

    if extras:
        description += f", and {', '.join(extras)}. "
    else:
        # Previously the sentence was left dangling ("... accommodation, ")
        # whenever no extras were rolled; close it properly instead.
        description += ". "

    base_price = random.randint(800, 4000)
    return description, base_price
|
||||
|
||||
def generate_cruise_description():
    """Generate one synthetic cruise deal.

    Returns:
        tuple[str, int]: (description text, base price in USD).
    """
    regions = ['Caribbean', 'Mediterranean', 'Alaska', 'Hawaii', 'Baltic Sea', 'South Pacific']
    # Pick 2-4 distinct regions for the itinerary.
    itinerary = ', '.join(random.sample(regions, k=random.randint(2, 4)))
    duration = random.choice([3, 5, 7, 10, 14])

    parts = [
        f"{duration}-night cruise visiting {itinerary}. ",
        "All meals and entertainment included. ",
        f"{random.choice(['Interior cabin', 'Ocean view cabin', 'Balcony cabin', 'Suite'])}. ",
    ]
    if random.random() > 0.5:
        parts.append("Unlimited beverage package available. ")
    if random.random() > 0.6:
        parts.append("Shore excursions at each port. ")

    return ''.join(parts), random.randint(500, 5000)
|
||||
|
||||
def generate_activity_description():
    """Generate one synthetic activity/excursion deal.

    Returns:
        tuple[str, int]: (description text, price in USD).
    """
    destination = random.choice(CITIES)
    options = [
        'City sightseeing tour', 'Museum pass', 'Adventure sports package',
        'Wine tasting tour', 'Cooking class', 'Hot air balloon ride',
        'Snorkeling excursion', 'Helicopter tour', 'Spa day package',
        'Theme park tickets',
    ]
    chosen = random.choice(options)

    pieces = [f"{chosen} in {destination}. "]
    # Only tours carry an explicit duration.
    if 'tour' in chosen.lower():
        pieces.append(f"Duration: {random.randint(2, 8)} hours. ")
    if random.random() > 0.5:
        pieces.append("Hotel pickup included. ")
    if random.random() > 0.6:
        pieces.append("Small group experience. ")

    return ''.join(pieces), random.randint(30, 500)
|
||||
|
||||
# Map each category label to its generator function.
GENERATORS = {
    'Flights': generate_flight_description,
    'Hotels': generate_hotel_description,
    'Car_Rentals': generate_car_rental_description,
    'Vacation_Packages': generate_vacation_package_description,
    'Cruises': generate_cruise_description,
    'Activities': generate_activity_description
}

print("Generating synthetic travel dataset...")
travel_data = []

# ~20k items total: 3334 per category x 6 categories.
items_per_category = 3334
for category in CATEGORIES:
    print(f"Generating {category}...")
    generator = GENERATORS[category]
    for _ in range(items_per_category):
        description, price = generator()
        # Stored as (description, price, category); price cast to float for metadata.
        travel_data.append((description, float(price), category))

# Shuffle so the insertion batches below mix categories (data was built per-category).
random.shuffle(travel_data)
print(f"Generated {len(travel_data)} travel deals")

print("\nInitializing SentenceTransformer model...")
# Compact general-purpose sentence-embedding model from the HF hub.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

print(f"Connecting to ChromaDB at {DB}...")
client = chromadb.PersistentClient(path=DB)

# Recreate the collection from scratch so reruns don't duplicate documents.
collection_name = "travel_deals"
existing_collections = [col.name for col in client.list_collections()]
if collection_name in existing_collections:
    client.delete_collection(collection_name)
    print(f"Deleted existing collection: {collection_name}")

collection = client.create_collection(collection_name)
print(f"Created new collection: {collection_name}")

print("\nCreating embeddings and adding to ChromaDB...")
# Encode and insert in batches of 1000 to bound per-call memory use.
for i in tqdm(range(0, len(travel_data), 1000)):
    batch = travel_data[i:i+1000]
    documents = [desc for desc, _, _ in batch]
    # Cast embeddings to plain Python floats for Chroma's API.
    vectors = model.encode(documents).astype(float).tolist()
    metadatas = [{"category": cat, "price": price} for _, price, cat in batch]
    # Stable ids keyed by each item's global position in travel_data.
    ids = [f"travel_{j}" for j in range(i, i+len(batch))]

    collection.add(
        ids=ids,
        documents=documents,
        embeddings=vectors,
        metadatas=metadatas
    )

total_items = collection.count()
print(f"\nVectorstore created successfully with {total_items} travel deals")

# Pull metadata back out to sanity-check what was actually stored.
result = collection.get(include=['metadatas'], limit=total_items)
categories = [m['category'] for m in result['metadatas']]
prices = [m['price'] for m in result['metadatas']]
category_counts = {}
for cat in categories:
    category_counts[cat] = category_counts.get(cat, 0) + 1

print("\nCategory distribution:")
for category, count in sorted(category_counts.items()):
    print(f"  {category}: {count}")

# Guard against division by zero if the store came back empty.
avg_price = sum(prices) / len(prices) if prices else 0
print(f"\nAverage price: ${avg_price:.2f}")
print(f"Price range: ${min(prices):.2f} - ${max(prices):.2f}")
|
||||
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from dotenv import load_dotenv
|
||||
import chromadb
|
||||
import numpy as np
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Make the parent w8d5 directory importable so the sibling 'agents' and
# 'helpers' packages resolve when this file is run directly as a script.
w8d5_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if w8d5_path not in sys.path:
    sys.path.insert(0, w8d5_path)

from agents.travel_planning_agent import TravelPlanningAgent
from helpers.travel_deals import TravelOpportunity

# ANSI escape codes used to color the framework's log lines.
BG_BLUE = '\033[44m'
WHITE = '\033[37m'
RESET = '\033[0m'

# Category labels and their plot colors (index-aligned lists).
CATEGORIES = ['Flights', 'Hotels', 'Car_Rentals', 'Vacation_Packages', 'Cruises', 'Activities']
COLORS = ['red', 'blue', 'green', 'orange', 'purple', 'cyan']
|
||||
|
||||
def init_logging():
    """Configure the root logger to emit INFO-level records to stdout.

    Idempotent: calling this more than once (e.g. when several frameworks are
    constructed in the same process) no longer stacks duplicate stream
    handlers, which previously caused every log line to print multiple times.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Already configured by a previous call -> nothing to do.
    if any(getattr(h, '_travel_agents_handler', False) for h in root.handlers):
        return

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "[%(asctime)s] [Travel Agents] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S %z",
    )
    handler.setFormatter(formatter)
    handler._travel_agents_handler = True  # marker checked by the guard above
    root.addHandler(handler)
|
||||
|
||||
class TravelDealFramework:
    """Orchestrates the travel planning agent over a persistent deal memory
    backed by a ChromaDB vector store and a JSON memory file.
    """

    # Path of the persistent ChromaDB vector store.
    DB = "travel_vectorstore"
    # JSON file persisting previously surfaced opportunities.
    MEMORY_FILENAME = "travel_memory.json"

    def __init__(self):
        """Set up logging and env vars, open the vector store, load memory."""
        init_logging()
        load_dotenv()
        client = chromadb.PersistentClient(path=self.DB)
        self.memory = self.read_memory()
        self.collection = client.get_or_create_collection('travel_deals')
        self.planner = None  # created lazily by init_agents_as_needed()

    def init_agents_as_needed(self):
        """Lazily construct the planning agent on first use."""
        if self.planner:
            return
        self.log("Initializing Travel Agent Framework")
        self.planner = TravelPlanningAgent(self.collection)
        self.log("Travel Agent Framework ready")

    def read_memory(self) -> List[TravelOpportunity]:
        """Load prior opportunities from disk; empty list if no file exists."""
        if not os.path.exists(self.MEMORY_FILENAME):
            return []
        with open(self.MEMORY_FILENAME, "r") as file:
            raw_items = json.load(file)
        return [TravelOpportunity(**raw) for raw in raw_items]

    def write_memory(self) -> None:
        """Persist the current opportunity memory to disk as indented JSON."""
        serialized = [opportunity.dict() for opportunity in self.memory]
        with open(self.MEMORY_FILENAME, "w") as file:
            json.dump(serialized, file, indent=2)

    def log(self, message: str):
        """Emit an INFO log line with the framework's colored prefix."""
        logging.info(BG_BLUE + WHITE + "[Travel Framework] " + message + RESET)

    def run(self) -> List[TravelOpportunity]:
        """Run one planning pass, persist any new opportunities, return memory."""
        self.init_agents_as_needed()
        logging.info("Starting Travel Planning Agent")
        results = self.planner.plan(memory=self.memory)
        logging.info(f"Travel Planning Agent completed with {len(results) if results else 0} results")
        if results:
            self.memory.extend(results)
            self.write_memory()
        return self.memory

    @classmethod
    def get_plot_data(cls, max_datapoints=10000):
        """Fetch up to max_datapoints embeddings from the vector store and
        project them to 3-D with t-SNE for plotting.

        Returns:
            (documents, reduced_vectors, colors)
        """
        client = chromadb.PersistentClient(path=cls.DB)
        collection = client.get_or_create_collection('travel_deals')
        result = collection.get(include=['embeddings', 'documents', 'metadatas'], limit=max_datapoints)
        embeddings = np.array(result['embeddings'])
        docs = result['documents']
        # Map each stored category to its fixed plot color (index-aligned lists).
        point_colors = [COLORS[CATEGORIES.index(meta['category'])] for meta in result['metadatas']]
        projector = TSNE(n_components=3, random_state=42, n_jobs=-1)
        return docs, projector.fit_transform(embeddings), point_colors
|
||||
|
||||
if __name__ == "__main__":
    # Run a single planning pass when executed as a script.
    framework = TravelDealFramework()
    framework.run()
|
||||
|
||||
67
week8/community_contributions/w8d5/helpers/travel_deals.py
Normal file
67
week8/community_contributions/w8d5/helpers/travel_deals.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Self
|
||||
from bs4 import BeautifulSoup
|
||||
import re
|
||||
import feedparser
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import time
|
||||
|
||||
# RSS feeds scraped for travel deals; add more URLs here to widen coverage.
feeds = [
    "https://thepointsguy.com/feed/",
]
|
||||
|
||||
def extract(html_snippet: str) -> str:
    """Strip markup from an HTML snippet and return single-line plain text."""
    plain = BeautifulSoup(html_snippet, 'html.parser').get_text(strip=True)
    # Defensive second pass: remove any tag-like fragments left in the text.
    plain = re.sub('<[^<]+?>', '', plain)
    return plain.replace('\n', ' ').strip()
|
||||
|
||||
class ScrapedTravelDeal:
    """A travel deal scraped from a single RSS feed entry."""

    # Headline of the feed entry.
    title: str
    # Plain-text summary extracted from the entry's HTML.
    summary: str
    # Link to the full article.
    url: str
    # Text used downstream as the deal description; currently the summary.
    details: str

    def __init__(self, entry: Dict[str, str]):
        """Populate fields from one feedparser entry dict."""
        self.title = entry.get('title', '')
        # Some feeds use 'summary', others 'description'.
        summary_text = entry.get('summary', entry.get('description', ''))
        self.summary = extract(summary_text)
        self.url = entry.get('link', '')
        self.details = self.summary

    def __repr__(self):
        return f"<{self.title}>"

    def describe(self):
        """Return a human-readable multi-line description of this deal."""
        return f"Title: {self.title}\nDetails: {self.details.strip()}\nURL: {self.url}"

    @classmethod
    def fetch(cls, show_progress: bool = False) -> List[Self]:
        """Fetch up to 10 entries from each configured feed.

        A failing feed is reported and skipped rather than aborting the run.

        Args:
            show_progress: wrap the feed list in a tqdm progress bar.
        """
        deals = []
        feed_iter = tqdm(feeds) if show_progress else feeds
        for feed_url in feed_iter:
            try:
                feed = feedparser.parse(feed_url)
                for entry in feed.entries[:10]:
                    deals.append(cls(entry))
                # Pause between *feeds*, not per entry: parse() has already
                # downloaded the whole feed, so the original per-entry sleep
                # only slowed down local processing (0.3s x 10 per feed)
                # without rate-limiting any network request.
                time.sleep(0.3)
            except Exception as e:
                print(f"Error fetching {feed_url}: {e}")
        return deals
|
||||
|
||||
class TravelDeal(BaseModel):
    """A structured travel deal with its advertised price."""

    destination: str  # city or region the deal is for
    deal_type: str  # category of the deal (presumably one of the CATEGORIES labels — confirm)
    description: str  # full deal text; used for price estimation downstream
    price: float  # advertised price (presumably USD — confirm with feed source)
    url: str  # link to the original offer
|
||||
|
||||
class TravelDealSelection(BaseModel):
    """A batch of structured travel deals."""

    deals: List[TravelDeal]  # deals selected in one scan
|
||||
|
||||
class TravelOpportunity(BaseModel):
    """A deal paired with an estimated fair price and the implied savings."""

    deal: TravelDeal  # the underlying scraped/structured deal
    estimate: float  # model-estimated fair price
    discount: float  # estimate - deal.price; positive means savings
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
from dotenv import load_dotenv
|
||||
import chromadb
|
||||
import numpy as np
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Make the parent w8d5 directory importable so the sibling 'agents' and
# 'helpers' packages resolve when this file is run directly as a script.
w8d5_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if w8d5_path not in sys.path:
    sys.path.insert(0, w8d5_path)

from agents.travel_scanner_agent import TravelScannerAgent
from agents.travel_estimator_agent import TravelEstimatorAgent
from agents.travel_xgboost_agent import TravelXGBoostAgent
from agents.travel_messaging_agent import TravelMessagingAgent
from helpers.travel_deals import TravelOpportunity, TravelDeal

# ANSI escape codes used to color the framework's log lines.
BG_BLUE = '\033[44m'
WHITE = '\033[37m'
RESET = '\033[0m'

# Category labels and their plot colors (index-aligned lists).
CATEGORIES = ['Flights', 'Hotels', 'Car_Rentals', 'Vacation_Packages', 'Cruises', 'Activities']
COLORS = ['red', 'blue', 'green', 'orange', 'purple', 'cyan']
|
||||
|
||||
def init_logging():
    """Configure the root logger to emit INFO-level records to stdout.

    Idempotent: calling this more than once (e.g. when several frameworks are
    constructed in the same process) no longer stacks duplicate stream
    handlers, which previously caused every log line to print multiple times.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Already configured by a previous call -> nothing to do.
    if any(getattr(h, '_travel_agents_handler', False) for h in root.handlers):
        return

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "[%(asctime)s] [Travel Agents] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S %z",
    )
    handler.setFormatter(formatter)
    handler._travel_agents_handler = True  # marker checked by the guard above
    root.addHandler(handler)
|
||||
|
||||
|
||||
class TravelDualFramework:
    """Runs two price estimators (an LLM-based one and an XGBoost one) side by
    side over scanned travel deals, alerting on and persisting the
    opportunities each estimator finds in its own memory file.
    """

    # Path of the persistent ChromaDB vector store.
    DB = "travel_vectorstore"
    # Separate memory files so the two estimators' histories stay independent.
    LLM_MEMORY_FILE = "travel_memory_llm.json"
    XGB_MEMORY_FILE = "travel_memory_xgb.json"
    # Minimum (estimate - price) in dollars for a deal to count as an opportunity.
    DEAL_THRESHOLD = 200.0

    def __init__(self):
        """Set up logging/env vars, open the vector store, load both memories."""
        init_logging()
        load_dotenv()
        client = chromadb.PersistentClient(path=self.DB)
        self.collection = client.get_or_create_collection('travel_deals')

        self.llm_memory = self.read_memory(self.LLM_MEMORY_FILE)
        self.xgb_memory = self.read_memory(self.XGB_MEMORY_FILE)

        # Agents are created lazily by init_agents_as_needed().
        self.scanner = None
        self.llm_estimator = None
        self.xgb_estimator = None
        self.messenger = None

    def init_agents_as_needed(self):
        """Construct the scanner, both estimators, and the messenger on first use."""
        if not self.scanner:
            self.log("Initializing Travel Dual Estimation Framework")
            self.scanner = TravelScannerAgent()
            self.llm_estimator = TravelEstimatorAgent(self.collection)
            self.xgb_estimator = TravelXGBoostAgent(self.collection)
            self.messenger = TravelMessagingAgent()
            self.log("Travel Dual Framework ready")

    def read_memory(self, filename: str) -> List[TravelOpportunity]:
        """Load opportunities from a JSON memory file; empty list if absent."""
        if os.path.exists(filename):
            with open(filename, "r") as file:
                data = json.load(file)
                opportunities = [TravelOpportunity(**item) for item in data]
                return opportunities
        return []

    def write_memory(self, opportunities: List[TravelOpportunity], filename: str) -> None:
        """Serialize the given opportunities to the given JSON memory file."""
        data = [opportunity.dict() for opportunity in opportunities]
        with open(filename, "w") as file:
            json.dump(data, file, indent=2)

    def log(self, message: str):
        """Emit an INFO log line with the framework's colored prefix."""
        text = BG_BLUE + WHITE + "[Dual Framework] " + message + RESET
        logging.info(text)

    def run(self) -> Tuple[List[TravelOpportunity], List[TravelOpportunity]]:
        """Scan for deals, estimate each with both models, alert and persist.

        Returns:
            (llm_memory, xgb_memory): the full opportunity histories after
            this scan (unchanged if no deals were found).
        """
        self.init_agents_as_needed()

        self.log("Starting dual estimation scan")
        deal_selection = self.scanner.scan()

        # Nothing new scanned -> return existing memories unchanged.
        if not deal_selection or not deal_selection.deals:
            self.log("No deals found")
            return self.llm_memory, self.xgb_memory

        deals = deal_selection.deals
        self.log(f"Processing {len(deals)} deals with both estimators")

        llm_opportunities = []
        xgb_opportunities = []

        for deal in deals:
            # LLM estimator pass.
            llm_estimate = self.llm_estimator.estimate(deal.description)
            llm_discount = llm_estimate - deal.price  # positive = savings

            if llm_discount >= self.DEAL_THRESHOLD:
                llm_opp = TravelOpportunity(
                    deal=deal,
                    estimate=llm_estimate,
                    discount=llm_discount
                )
                llm_opportunities.append(llm_opp)
                self.log(f"LLM found opportunity: {deal.destination} - ${llm_discount:.0f} savings")
                # Alert per-deal, as soon as the opportunity is found.
                self.messenger.alert(llm_opp)

            # XGBoost estimator pass on the same deal.
            xgb_estimate = self.xgb_estimator.estimate(deal.description)
            xgb_discount = xgb_estimate - deal.price

            if xgb_discount >= self.DEAL_THRESHOLD:
                xgb_opp = TravelOpportunity(
                    deal=deal,
                    estimate=xgb_estimate,
                    discount=xgb_discount
                )
                xgb_opportunities.append(xgb_opp)
                self.log(f"XGBoost found opportunity: {deal.destination} - ${xgb_discount:.0f} savings")
                self.messenger.alert(xgb_opp)

        # Persist each memory only when that estimator found something new.
        if llm_opportunities:
            self.llm_memory.extend(llm_opportunities)
            self.write_memory(self.llm_memory, self.LLM_MEMORY_FILE)

        if xgb_opportunities:
            self.xgb_memory.extend(xgb_opportunities)
            self.write_memory(self.xgb_memory, self.XGB_MEMORY_FILE)

        self.log(f"Scan complete: {len(llm_opportunities)} LLM, {len(xgb_opportunities)} XGBoost opportunities")

        return self.llm_memory, self.xgb_memory

    @classmethod
    def get_plot_data(cls, max_datapoints=10000):
        """Fetch up to max_datapoints embeddings from the vector store and
        project them to 3-D with t-SNE for plotting.

        Returns:
            (documents, reduced_vectors, colors, categories)
        """
        client = chromadb.PersistentClient(path=cls.DB)
        collection = client.get_or_create_collection('travel_deals')
        result = collection.get(include=['embeddings', 'documents', 'metadatas'], limit=max_datapoints)
        vectors = np.array(result['embeddings'])
        documents = result['documents']
        categories = [metadata['category'] for metadata in result['metadatas']]
        # Map each stored category to its fixed plot color (index-aligned lists).
        colors = [COLORS[CATEGORIES.index(c)] for c in categories]
        tsne = TSNE(n_components=3, random_state=42, n_jobs=-1)
        reduced_vectors = tsne.fit_transform(vectors)
        return documents, reduced_vectors, colors, categories
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Kick off one dual-estimation scan when run directly.
    app = TravelDualFramework()
    app.run()
|
||||
|
||||
Reference in New Issue
Block a user