Add initial implementation of Deal Intel project
This commit introduces the foundational structure for the Deal Intel project, including:

- Environment configuration file (.env.example) for managing secrets and API keys.
- Scripts for building a ChromaDB vector store (build_vector_store.py) and training machine learning models (train_rf.py, train_ensemble.py).
- Health check functionality (health_check.py) to ensure system readiness.
- A launcher script (launcher.py) for executing various commands, including UI launch and health checks.
- Logging utilities (logging_utils.py) for consistent logging across the application.
- A README file providing an overview and setup instructions for the project.

These additions establish a comprehensive framework for an agentic deal-hunting AI system, integrating components for data processing, model training, and user interaction.
.env.example
@@ -0,0 +1,16 @@
# Modal & Hugging Face
MODAL_TOKEN_ID=your_modal_token_id
MODAL_TOKEN_SECRET=your_modal_token_secret
HF_TOKEN=your_hf_token

# LLM Providers (use one)
OPENAI_API_KEY=your_openai_api_key
DEEPSEEK_API_KEY=your_deepseek_api_key

# Pushover (push notifications)
PUSHOVER_USER=your_pushover_user
PUSHOVER_TOKEN=your_pushover_token

# Twilio (SMS)
TWILIO_ACCOUNT_SID=your_twilio_sid
TWILIO_AUTH_TOKEN=your_twilio_auth
README.md
@@ -0,0 +1,74 @@
# Deal Intel — Agentic Deal-Hunting AI

## Overview
An end-to-end agentic system that scans product sources, estimates fair value using a hybrid LLM + RAG + ML stack, ranks the best opportunities, and alerts you via push notification or SMS. Includes a Gradio UI and a vector-space visualization.

## Prerequisites
- Environment and secrets:
  - `HF_TOKEN`, `MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET`
  - Either `OPENAI_API_KEY` or `DEEPSEEK_API_KEY`
  - For push notifications: `PUSHOVER_USER`, `PUSHOVER_TOKEN`
  - Optional Twilio SMS: `TWILIO_ACCOUNT_SID`, `TWILIO_AUTH_TOKEN`
- Dependencies installed: `pip install -r requirements.txt`
- Modal set up: `modal setup` (or env vars) and credits available
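
A local `.env` copied from `.env.example` keeps these secrets out of your shell history. A minimal sketch of loading and sanity-checking them, assuming `python-dotenv` is available (it is not pinned in this commit, so treat the dependency as an assumption):

```python
# check_env.py: hypothetical helper, not part of this commit
import os

from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv()  # reads key=value pairs from a local .env file

required = ["HF_TOKEN", "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET"]
missing = [k for k in required if not os.getenv(k)]
print("Missing:", ", ".join(missing) if missing else "none")
```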

## Steps & Acceptance Criteria

1) Environment Setup
- Install Python deps and export required secrets.
- Acceptance: `openai`, `chromadb`, and `modal` import successfully; `modal setup` completes.

2) Deploy Specialist Pricer on Modal
- Use `pricer_service2.py` and deploy the `Pricer` class with GPU and a Hugging Face cache.
- Acceptance: `Pricer.price.remote("...")` returns a numeric price; `keep_warm.py` prints "ok" every cycle if used.
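
A quick way to confirm the deployment from a local Python session, mirroring the lookup used by `health_check.py` (the `pricer-service` app name and the `wake_up`/`price` methods come from this repo's code; the product string is arbitrary):

```python
# verify_pricer.py: sketch of the Step 2 acceptance check
import modal

try:
    Pricer = modal.Cls.from_name("pricer-service", "Pricer")  # newer Modal SDKs
except Exception:
    Pricer = modal.Cls.lookup("pricer-service", "Pricer")     # older SDK fallback

pricer = Pricer()
print(pricer.wake_up.remote())  # expect a readiness reply
print(pricer.price.remote("HyperX QuadCast USB condenser microphone"))  # expect a numeric price
```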

3) Build Product Vector Store (RAG)
- Populate the persistent `chromadb` database `products_vectorstore` with embeddings, documents, and metadatas (including `price` and `category`).
- Acceptance: A query for the 5 most similar items returns valid `documents` and `metadatas` with prices.
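
A minimal sketch of the Step 3 acceptance check, reusing the model and names from `config.py` (`sentence-transformers/all-MiniLM-L6-v2`, DB path `products_vectorstore`, collection `products`):

```python
# query_vectorstore.py: sketch of a 5-nearest-neighbour query
import chromadb
from sentence_transformers import SentenceTransformer

client = chromadb.PersistentClient(path="products_vectorstore")
collection = client.get_or_create_collection("products")

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
query_vec = model.encode(["USB-C docking station with dual HDMI"]).tolist()

results = collection.query(query_embeddings=query_vec, n_results=5,
                           include=["documents", "metadatas"])
for doc, md in zip(results["documents"][0], results["metadatas"][0]):
    print(f"${md['price']:.2f} [{md['category']}] {doc[:60]}...")
```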

4) Train Classical ML Models and Save Artifacts
- Train a Random Forest on the embeddings → save `random_forest_model.pkl` at the repo root.
- Train an ensemble `LinearRegression` over the Specialist/Frontier/RF predictions → save `ensemble_model.pkl`.
- Acceptance: Both files exist and load in `agents/random_forest_agent.py` and `agents/ensemble_agent.py`.
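
A quick artifact check for Step 4, assuming the feature layouts used by `train_rf.py` (one 384-dimensional MiniLM embedding per row) and `train_ensemble.py` (columns Specialist, Frontier, RandomForest, Min, Max); the numbers below are placeholders:

```python
# verify_artifacts.py: sketch of the Step 4 acceptance check
import joblib
import numpy as np
import pandas as pd

rf = joblib.load("random_forest_model.pkl")
ensemble = joblib.load("ensemble_model.pkl")

# The Random Forest was trained on raw embedding vectors
print("RF sample prediction:", rf.predict(np.zeros((1, 384)))[0])

# The ensemble was trained on the three base predictions plus their min and max
row = pd.DataFrame([{"Specialist": 100.0, "Frontier": 110.0,
                     "RandomForest": 95.0, "Min": 95.0, "Max": 110.0}])
print("Ensemble sample prediction:", ensemble.predict(row)[0])
```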

5) Verify Individual Agents
- SpecialistAgent → calls the Modal pricer and returns a float.
- FrontierAgent → performs RAG against `chromadb`, then calls `OpenAI`/`DeepSeek`.
- RandomForestAgent → loads `random_forest_model.pkl` and encodes descriptions with `SentenceTransformer`.
- ScannerAgent → pulls RSS feeds and returns consistently structured outputs with clear-price deals.
- Acceptance: Each agent returns sensible outputs without exceptions.
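
A smoke check for Step 5, using the constructor and `price()` signatures that appear in `train_ensemble.py` and `health_check.py` (the product description is arbitrary):

```python
# verify_agents.py: sketch of the Step 5 acceptance check
import chromadb

from agents.specialist_agent import SpecialistAgent
from agents.frontier_agent import FrontierAgent
from agents.random_forest_agent import RandomForestAgent

collection = chromadb.PersistentClient(path="products_vectorstore").get_or_create_collection("products")
description = "Dyson V8 cordless vacuum, 40 min runtime, wall dock and attachments included"

for name, agent in [("Specialist", SpecialistAgent()),
                    ("Frontier", FrontierAgent(collection)),
                    ("RandomForest", RandomForestAgent())]:
    print(f"{name}: ${agent.price(description):.2f}")
```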

6) Orchestration (Planning + Messaging)
- PlanningAgent coordinates scanning → ensemble pricing → selection against `DEAL_THRESHOLD`.
- MessagingAgent pushes alerts via Pushover; optionally Twilio SMS if enabled.
- Acceptance: The planner produces at least one `Opportunity`, and the alert includes price, estimate, discount, and URL.
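
MessagingAgent's internals are not included in this commit; for orientation, a Pushover alert reduces to one HTTP POST against Pushover's public message API, sketched here with `requests` (the message text simply mirrors the acceptance criterion):

```python
# pushover_sketch.py: illustration only; the actual MessagingAgent may differ
import os

import requests

def push(message: str) -> None:
    """Send a push notification through Pushover's message API."""
    requests.post(
        "https://api.pushover.net/1/messages.json",
        data={
            "token": os.environ["PUSHOVER_TOKEN"],
            "user": os.environ["PUSHOVER_USER"],
            "message": message,
        },
        timeout=10,
    )

push("Deal! price=$89.99 estimate=$149.00 discount=$59.01 https://example.com/deal")
```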

7) Framework & Persistence
- DealAgentFramework initializes logging, loads `chromadb`, and reads/writes `memory.json`.
- Acceptance: After a run, `memory.json` includes the new opportunity.
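
A minimal persistence check for Step 7, assuming `memory.json` holds a JSON list of serialized opportunities (the exact schema belongs to `DealAgentFramework`, which is not part of this commit):

```python
# check_memory.py: sketch of the Step 7 acceptance check
import json
from pathlib import Path

path = Path("memory.json")
if path.exists():
    memory = json.loads(path.read_text())
    print(f"{len(memory)} opportunities persisted")
    print("Most recent entry:", memory[-1] if memory else "none")
else:
    print("memory.json not found; run the planner first")
```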

8) UI (Gradio)
- Use `price_is_right_final.py` for logs, a 3D embeddings plot, and an interactive table; `price_is_right.py` is a simpler alternative.
- Acceptance: The UI loads; "Run" updates opportunities; selecting a row triggers an alert.

9) Operational Readiness
- Keep-warm is optional: ping `Pricer.wake_up.remote()` periodically to avoid cold starts (see `launcher.py keepwarm`).
- Acceptance: End-to-end run latency is acceptable; cold starts are reduced when keep-warm is active.

10) Testing
- Run the CI tests in `community_contributions/pricer_test/`.
- Add a smoke test for `DealAgentFramework.run()` and memory persistence.
- Acceptance: Tests pass; the smoke run returns plausible prices and discounts.
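
A sketch of the smoke test called for in Step 10, written for pytest and assuming `DealAgentFramework.run()` returns the in-memory list of opportunities (as `launcher.py` relies on when it logs the last one):

```python
# test_smoke.py: hypothetical smoke test for the planner and persistence
import json
from pathlib import Path

from deal_agent_framework import DealAgentFramework


def test_run_produces_opportunities_and_persists():
    fw = DealAgentFramework()
    fw.init_agents_as_needed()
    opportunities = fw.run()

    assert opportunities, "planner returned no opportunities"
    last = opportunities[-1]
    assert last.deal.price > 0 and last.estimate > 0

    memory = json.loads(Path("memory.json").read_text())
    assert len(memory) >= 1
```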

## Usage
- Launch the UI:
  - `python "Deal Intel/launcher.py" ui`
- Run the planner for one cycle:
  - `python "Deal Intel/launcher.py" run`
- Keep Modal warm (optional):
  - `python "Deal Intel/launcher.py" keepwarm`
- Run health checks:
  - `python "Deal Intel/launcher.py" health`

## Required Artifacts
- `random_forest_model.pkl` — required by `agents/random_forest_agent.py`
- `ensemble_model.pkl` — required by `agents/ensemble_agent.py`
build_vector_store.py
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
"""
Build a ChromaDB vector store ('products_vectorstore') with product documents and embeddings.
Streams from the McAuley-Lab/Amazon-Reviews-2023 raw_meta_* datasets.
"""

from itertools import islice
from typing import List, Dict, Iterable

import argparse
import chromadb
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from logging_utils import init_logger
import config as cfg

logger = init_logger("DealIntel.BuildVectorStore")


def text_for(dp: Dict) -> str:
    """
    Construct product text from typical raw_meta_* fields: title + description + features + details.
    """
    title = dp.get("title") or ""
    description = "\n".join(dp.get("description") or [])
    features = "\n".join(dp.get("features") or [])
    details = (dp.get("details") or "").strip()
    parts = [title, description, features, details]
    return "\n".join([p for p in parts if p])


def stream_category(category: str) -> Iterable[Dict]:
    """
    Stream datapoints from raw_meta_{category}.
    """
    ds = load_dataset(
        "McAuley-Lab/Amazon-Reviews-2023",
        f"raw_meta_{category}",
        split="full",
        trust_remote_code=True,
        streaming=True,
    )
    return ds


def build(categories: List[str], max_items_per_category: int, batch_size: int):
    logger.info(f"Initializing DB at '{cfg.DB_PATH}' collection '{cfg.COLLECTION_NAME}'")
    client = chromadb.PersistentClient(path=cfg.DB_PATH)
    collection = client.get_or_create_collection(cfg.COLLECTION_NAME)

    logger.info(f"Loading embedding model '{cfg.MODEL_NAME}'")
    model = SentenceTransformer(cfg.MODEL_NAME)

    total_added = 0
    for category in categories:
        logger.info(f"Category {category}: targeting up to {max_items_per_category} items")
        stream = stream_category(category)
        limited = islice(stream, max_items_per_category)

        buffer_docs: List[str] = []
        buffer_embeddings: List[List[float]] = []
        buffer_metadatas: List[Dict] = []
        buffer_ids: List[str] = []
        count = 0

        for dp in tqdm(limited, total=max_items_per_category, desc=f"{category}"):
            price = dp.get("price")
            if not price:
                continue
            try:
                price_val = float(price)
            except Exception:
                continue

            doc = text_for(dp)
            if not doc or len(doc) < 50:
                continue

            buffer_docs.append(doc)
            buffer_metadatas.append({"price": price_val, "category": category})
            buffer_ids.append(f"{category}-{dp.get('asin', str(count))}")
            count += 1

            if len(buffer_docs) >= batch_size:
                embeddings = model.encode(buffer_docs, show_progress_bar=False)
                buffer_embeddings = [emb.tolist() for emb in embeddings]
                collection.add(
                    ids=buffer_ids,
                    documents=buffer_docs,
                    metadatas=buffer_metadatas,
                    embeddings=buffer_embeddings,
                )
                total_added += len(buffer_docs)
                buffer_docs.clear()
                buffer_embeddings.clear()
                buffer_metadatas.clear()
                buffer_ids.clear()

        if buffer_docs:
            embeddings = model.encode(buffer_docs, show_progress_bar=False)
            buffer_embeddings = [emb.tolist() for emb in embeddings]
            collection.add(
                ids=buffer_ids,
                documents=buffer_docs,
                metadatas=buffer_metadatas,
                embeddings=buffer_embeddings,
            )
            total_added += len(buffer_docs)

        logger.info(f"Category {category}: added {count} items")

    logger.info(f"Completed build. Total items added: {total_added}")


def main():
    parser = argparse.ArgumentParser(description="Build ChromaDB vector store")
    parser.add_argument("--categories", nargs="*", default=cfg.CATEGORIES, help="Categories to ingest")
    parser.add_argument("--max-per-category", type=int, default=cfg.MAX_ITEMS_PER_CATEGORY)
    parser.add_argument("--batch-size", type=int, default=cfg.BATCH_SIZE)
    args = parser.parse_args()

    build(args.categories, args.max_per_category, args.batch_size)


if __name__ == "__main__":
    main()
config.py
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
"""
Centralized configuration for Deal Intel.
"""

import os
from typing import List

# Vector store
DB_PATH = os.getenv("DEAL_INTEL_DB_PATH", "products_vectorstore")
COLLECTION_NAME = os.getenv("DEAL_INTEL_COLLECTION", "products")

# Embedding model
MODEL_NAME = os.getenv("DEAL_INTEL_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

# Categories (kept consistent with framework plot colors)
CATEGORIES: List[str] = [
    "Appliances",
    "Automotive",
    "Cell_Phones_and_Accessories",
    "Electronics",
    "Musical_Instruments",
    "Office_Products",
    "Tools_and_Home_Improvement",
    "Toys_and_Games",
]

# Data limits
MAX_ITEMS_PER_CATEGORY = int(os.getenv("DEAL_INTEL_MAX_ITEMS", "2500"))
BATCH_SIZE = int(os.getenv("DEAL_INTEL_BATCH_SIZE", "500"))

# Training limits
RF_MAX_DATAPOINTS = int(os.getenv("DEAL_INTEL_RF_MAX_DATAPOINTS", "10000"))
ENSEMBLE_SAMPLE_SIZE = int(os.getenv("DEAL_INTEL_ENSEMBLE_SAMPLE_SIZE", "200"))
health_check.py
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
"""
Health checks for Deal Intel readiness:
- Environment variables presence
- Modal pricer availability
- ChromaDB collection populated
- Model artifacts load
- Agent instantiation
"""

import os
import joblib
import chromadb

from logging_utils import init_logger
import config as cfg

logger = init_logger("DealIntel.Health")


def check_env() -> bool:
    ok = True
    required_any = ["OPENAI_API_KEY", "DEEPSEEK_API_KEY"]
    required = ["HF_TOKEN", "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET"]
    push_vars = ["PUSHOVER_USER", "PUSHOVER_TOKEN"]

    logger.info("Checking environment variables")
    if not any(os.getenv(k) for k in required_any):
        logger.warning("Missing OPENAI_API_KEY or DEEPSEEK_API_KEY")
        ok = False
    for k in required:
        if not os.getenv(k):
            logger.warning(f"Missing {k}")
            ok = False
    if not all(os.getenv(k) for k in push_vars):
        logger.info("Pushover tokens not found — push alerts will be disabled")
    return ok


def check_modal() -> bool:
    import modal
    logger.info("Checking Modal pricer wake_up()")
    try:
        try:
            Pricer = modal.Cls.from_name("pricer-service", "Pricer")
        except Exception:
            Pricer = modal.Cls.lookup("pricer-service", "Pricer")
        pricer = Pricer()
        reply = pricer.wake_up.remote()
        logger.info(f"Modal wake_up reply: {reply}")
        return True
    except Exception as e:
        logger.error(f"Modal pricer check failed: {e}")
        return False


def check_chroma() -> bool:
    logger.info(f"Checking ChromaDB at '{cfg.DB_PATH}' collection '{cfg.COLLECTION_NAME}'")
    try:
        client = chromadb.PersistentClient(path=cfg.DB_PATH)
        collection = client.get_or_create_collection(cfg.COLLECTION_NAME)
        result = collection.get(include=['embeddings'], limit=10)
        embeddings = result.get("embeddings")
        count = 0 if embeddings is None else len(embeddings)
        logger.info(f"ChromaDB sample embeddings count: {count}")
        return count > 0
    except Exception as e:
        logger.error(f"ChromaDB check failed: {e}")
        return False


def check_models() -> bool:
    logger.info("Checking model artifacts load")
    ok = True
    try:
        joblib.load("random_forest_model.pkl")
        logger.info("Random Forest model loaded")
    except Exception as e:
        logger.error(f"Random Forest model load failed: {e}")
        ok = False
    try:
        joblib.load("ensemble_model.pkl")
        logger.info("Ensemble model loaded")
    except Exception as e:
        logger.error(f"Ensemble model load failed: {e}")
        ok = False
    return ok


def check_agents() -> bool:
    logger.info("Checking agent instantiation")
    try:
        from agents.random_forest_agent import RandomForestAgent
        from agents.frontier_agent import FrontierAgent
        from agents.specialist_agent import SpecialistAgent

        client = chromadb.PersistentClient(path=cfg.DB_PATH)
        collection = client.get_or_create_collection(cfg.COLLECTION_NAME)

        rf = RandomForestAgent()
        fr = FrontierAgent(collection)
        sp = SpecialistAgent()
        _ = (rf, fr, sp)
        logger.info("Agents instantiated")
        return True
    except Exception as e:
        logger.error(f"Agent instantiation failed: {e}")
        return False


def run_all() -> bool:
    env_ok = check_env()
    modal_ok = check_modal()
    chroma_ok = check_chroma()
    models_ok = check_models()
    agents_ok = check_agents()

    overall = all([env_ok, modal_ok, chroma_ok, models_ok, agents_ok])
    if overall:
        logger.info("Health check passed — system ready")
    else:
        logger.warning("Health check failed — see logs for details")
    return overall


if __name__ == "__main__":
    ready = run_all()
    if not ready:
        raise SystemExit(1)
launcher.py
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
"""
Deal Intel launcher script
- ui: launch Gradio UI from price_is_right_final.App
- run: execute one planner cycle and print resulting opportunities
- keepwarm: ping Modal Pricer.wake_up to keep container warm
- health: run readiness checks from health_check.py
"""
import argparse
import time

from logging_utils import init_logger

logger = init_logger("DealIntel.Launcher")


def launch_ui():
    from price_is_right_final import App
    logger.info("Launching UI")
    App().run()


def run_once():
    from deal_agent_framework import DealAgentFramework
    fw = DealAgentFramework()
    fw.init_agents_as_needed()
    logger.info("Running planner once")
    opportunities = fw.run()
    logger.info(f"Opportunities in memory: {len(opportunities)}")
    if opportunities:
        last = opportunities[-1]
        logger.info(f"Last opportunity: price=${last.deal.price:.2f}, estimate=${last.estimate:.2f}, discount=${last.discount:.2f}")
        logger.info(f"URL: {last.deal.url}")
        logger.info(f"Description: {last.deal.product_description[:120]}...")


def keep_warm(interval_sec: int = 30):
    import modal
    logger.info("Starting keep-warm loop for Modal Pricer")
    try:
        Pricer = modal.Cls.from_name("pricer-service", "Pricer")
    except Exception:
        Pricer = modal.Cls.lookup("pricer-service", "Pricer")
    pricer = Pricer()
    try:
        while True:
            reply = pricer.wake_up.remote()
            logger.info(f"Wake-up reply: {reply}")
            time.sleep(interval_sec)
    except KeyboardInterrupt:
        logger.info("Keep-warm loop stopped")


def health():
    logger.info("Running health checks")
    from health_check import run_all
    ok = run_all()
    if not ok:
        logger.warning("Health checks failed")
        raise SystemExit(1)
    logger.info("Health checks passed")


def main(argv=None):
    parser = argparse.ArgumentParser(description="Deal Intel Launcher")
    parser.add_argument("command", choices=["ui", "run", "keepwarm", "health"], help="Command to execute")
    parser.add_argument("--interval", type=int, default=30, help="Keep-warm ping interval (seconds)")
    args = parser.parse_args(argv)

    if args.command == "ui":
        launch_ui()
    elif args.command == "run":
        run_once()
    elif args.command == "keepwarm":
        keep_warm(args.interval)
    elif args.command == "health":
        health()


if __name__ == "__main__":
    main()
logging_utils.py
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
"""
Shared logging utilities for Deal Intel.
"""

import logging
import os
from typing import Optional

DEFAULT_FORMAT = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s"
DEFAULT_DATEFMT = "%Y-%m-%d %H:%M:%S %z"


def init_logger(name: str, level: Optional[str] = None) -> logging.Logger:
    """
    Initialize and return a logger with consistent formatting.
    Level can be overridden via env DEAL_INTEL_LOG_LEVEL.
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        return logger  # avoid duplicate handlers

    env_level = os.getenv("DEAL_INTEL_LOG_LEVEL", "INFO")
    level = level or env_level
    level_map = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
    }
    logger.setLevel(level_map.get(level.upper(), logging.INFO))

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(DEFAULT_FORMAT, datefmt=DEFAULT_DATEFMT))
    logger.addHandler(handler)
    return logger
train_ensemble.py
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Train a LinearRegression ensemble over Specialist, Frontier, and RF predictions.
Saves to ensemble_model.pkl and logs coefficients and metrics.
"""

import argparse
import random
import joblib
import pandas as pd
import chromadb
from tqdm import tqdm

from agents.specialist_agent import SpecialistAgent
from agents.frontier_agent import FrontierAgent
from agents.random_forest_agent import RandomForestAgent

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

from logging_utils import init_logger
import config as cfg

logger = init_logger("DealIntel.TrainEnsemble")


def main():
    parser = argparse.ArgumentParser(description="Train Ensemble LinearRegression")
    parser.add_argument("--sample-size", type=int, default=cfg.ENSEMBLE_SAMPLE_SIZE)
    args = parser.parse_args()

    logger.info("Initializing Chroma collection")
    client = chromadb.PersistentClient(path=cfg.DB_PATH)
    collection = client.get_or_create_collection(cfg.COLLECTION_NAME)

    logger.info("Loading datapoints")
    result = collection.get(include=['documents', 'metadatas'], limit=args.sample_size * 10)
    documents = result["documents"]
    metadatas = result["metadatas"]
    if not documents:
        raise RuntimeError("No documents in collection — build the vector store first.")

    pairs = list(zip(documents, metadatas))
    random.seed(42)
    random.shuffle(pairs)
    pairs = pairs[:args.sample_size]

    logger.info("Initializing agents")
    specialist = SpecialistAgent()
    frontier = FrontierAgent(collection)
    rf = RandomForestAgent()

    X_rows = []
    y_vals = []
    logger.info(f"Collecting predictions for {len(pairs)} samples")
    for doc, md in tqdm(pairs, desc="Collect"):
        description = doc
        target_price = float(md["price"])

        try:
            s = specialist.price(description)
        except Exception as e:
            logger.warning(f"Specialist failed; skipping sample: {e}")
            continue

        try:
            f = frontier.price(description)
        except Exception as e:
            logger.warning(f"Frontier failed; skipping sample: {e}")
            continue

        try:
            r = rf.price(description)
        except Exception as e:
            logger.warning(f"RF failed; skipping sample: {e}")
            continue

        X_rows.append({
            "Specialist": s,
            "Frontier": f,
            "RandomForest": r,
            "Min": min(s, f, r),
            "Max": max(s, f, r),
        })
        y_vals.append(target_price)

    if len(X_rows) < 20:
        raise RuntimeError("Too few samples collected. Ensure tokens/services are configured and retry.")

    X = pd.DataFrame(X_rows)
    y = pd.Series(y_vals)

    logger.info("Fitting LinearRegression")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
    lr = LinearRegression()
    lr.fit(X_train, y_train)

    preds = lr.predict(X_test)
    rmse = mean_squared_error(y_test, preds, squared=False)
    r2 = r2_score(y_test, preds)
    logger.info(f"Holdout RMSE={rmse:.2f}, R2={r2:.3f}")

    coef_log = ", ".join([f"{col}={coef:.3f}" for col, coef in zip(X.columns.tolist(), lr.coef_)])
    logger.info(f"Coefficients: {coef_log}; Intercept={lr.intercept_:.3f}")

    joblib.dump(lr, "ensemble_model.pkl")
    logger.info("Saved model to ensemble_model.pkl")


if __name__ == "__main__":
    main()
train_rf.py
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
"""
Train a RandomForestRegressor on embeddings from ChromaDB, save to random_forest_model.pkl.
Logs simple holdout metrics.
"""

import argparse
import joblib
import numpy as np
import chromadb
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

from logging_utils import init_logger
import config as cfg

logger = init_logger("DealIntel.TrainRF")


def main():
    parser = argparse.ArgumentParser(description="Train Random Forest pricer")
    parser.add_argument("--max-datapoints", type=int, default=cfg.RF_MAX_DATAPOINTS)
    args = parser.parse_args()

    logger.info(f"Loading embeddings from {cfg.DB_PATH}/{cfg.COLLECTION_NAME} (limit={args.max_datapoints})")
    client = chromadb.PersistentClient(path=cfg.DB_PATH)
    collection = client.get_or_create_collection(cfg.COLLECTION_NAME)
    result = collection.get(include=['embeddings', 'metadatas'], limit=args.max_datapoints)

    embeddings = result.get("embeddings")
    if embeddings is None or len(embeddings) == 0:
        raise RuntimeError("No embeddings found — build the vector store first.")

    X = np.array(embeddings)
    y = np.array([md["price"] for md in result["metadatas"]])

    logger.info(f"Training RF on {X.shape[0]} samples, {X.shape[1]} features")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    rf = RandomForestRegressor(n_estimators=300, random_state=42, n_jobs=-1)
    rf.fit(X_train, y_train)

    preds = rf.predict(X_test)
    rmse = mean_squared_error(y_test, preds, squared=False)
    r2 = r2_score(y_test, preds)
    logger.info(f"Holdout RMSE={rmse:.2f}, R2={r2:.3f}")

    joblib.dump(rf, "random_forest_model.pkl")
    logger.info("Saved model to random_forest_model.pkl")


if __name__ == "__main__":
    main()