528 lines
14 KiB
Plaintext
528 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2a0f44a9-37cd-4aa5-9b20-cfc0dc8dfc0a",
|
|
"metadata": {},
|
|
"source": [
|
|
"# The Price is Right\n",
|
|
"\n",
|
|
"Today we build a more complex solution for estimating prices of goods.\n",
|
|
"\n",
|
"1. Day 2.0 notebook: create a RAG database with our 400,000 training data points\n",
|
|
"2. Day 2.1 notebook: visualize in 2D\n",
|
|
"3. Day 2.2 notebook: visualize in 3D\n",
|
|
"4. Day 2.3 notebook: build and test a RAG pipeline with GPT-4o-mini\n",
|
"5. Day 2.4 notebook: (a) bring back our Random Forest pricer (b) create an Ensemble pricer that allows contributions from all the pricers\n",
|
|
"\n",
|
|
"Phew! That's a lot to get through in one day!\n",
|
|
"\n",
|
|
"## PLEASE NOTE:\n",
|
|
"\n",
|
|
"We already have a very powerful product estimator with our proprietary, fine-tuned LLM. Most people would be very satisfied with that! The main reason we're adding these extra steps is to deepen your expertise with RAG and with Agentic workflows.\n",
|
|
"\n",
|
|
"## We will go fast today! Hold on to your hat.."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fbcdfea8-7241-46d7-a771-c0381a3e7063",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# imports\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import re\n",
|
|
"import math\n",
|
|
"import json\n",
|
|
"from tqdm import tqdm\n",
|
|
"import random\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"from huggingface_hub import login\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import pickle\n",
|
|
"from openai import OpenAI\n",
|
|
"from sentence_transformers import SentenceTransformer\n",
|
|
"from datasets import load_dataset\n",
|
|
"import chromadb\n",
|
|
"from testing import Tester"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "98666e73-938e-469d-8987-e6e55ba5e034",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# environment\n",
"\n",
"load_dotenv(override=True)\n",
"# Make sure both keys are present in os.environ, using a placeholder when one is not configured\n",
"for key_name in ('OPENAI_API_KEY', 'HF_TOKEN'):\n",
"    os.environ[key_name] = os.getenv(key_name, 'your-key-if-not-using-env')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ce73b034-9ec1-4533-ba41-3e57c7878b61",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Log in to HuggingFace\n",
|
|
"\n",
|
|
"hf_token = os.environ['HF_TOKEN']\n",
|
|
"login(hf_token, add_to_git_credential=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4c01daad-86b0-4bc0-91ba-20a64df043ed",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Another import after Logging in to Hugging Face - thank you Trung N.!\n",
|
|
"\n",
|
|
"from items import Item"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9a25a5cf-8f6c-4b5d-ad98-fdd096f5adf8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"openai = OpenAI()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "dc696493-0b6f-48aa-9fa8-b1ae0ecaf3cd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load in the test pickle file\n",
|
|
"# See the section \"Back to the PKL files\" in the day2.0 notebook\n",
|
|
"# for instructions on obtaining this test.pkl file\n",
"# NOTE: unpickling can execute arbitrary code - only load a test.pkl that you created yourself or that comes from a source you trust\n",
|
|
"\n",
|
|
"with open('test.pkl', 'rb') as file:\n",
|
|
" test = pickle.load(file)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "33d38a06-0c0d-4e96-94d1-35ee183416ce",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def make_context(similars, prices):\n",
"    \"\"\"Build the context preamble that lists similar products with their prices.\n",
"\n",
"    Args:\n",
"        similars: iterable of product description strings.\n",
"        prices: iterable of numeric prices, paired positionally with `similars`\n",
"            (pairing stops at the shorter of the two).\n",
"\n",
"    Returns:\n",
"        A string made of an introductory sentence followed by one entry per\n",
"        (product, price) pair, each price formatted with two decimals.\n",
"    \"\"\"\n",
"    intro = \"To provide some context, here are some other items that might be similar to the item you need to estimate.\\n\\n\"\n",
"    entries = [\n",
"        f\"Potentially related product:\\n{similar}\\nPrice is ${price:.2f}\\n\\n\"\n",
"        for similar, price in zip(similars, prices)\n",
"    ]\n",
"    return intro + \"\".join(entries)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "61f203b7-63b6-48ed-869b-e393b5bfcad3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def messages_for(item, similars, prices):\n",
"    \"\"\"Assemble the chat messages that ask the model to price `item`.\n",
"\n",
"    Args:\n",
"        item: object exposing test_prompt(), which returns the question text.\n",
"        similars: descriptions of similar products, passed to make_context().\n",
"        prices: prices of those products, passed to make_context().\n",
"\n",
"    Returns:\n",
"        A list of three message dicts (system, user, assistant). The final assistant\n",
"        message is the fixed text \"Price is $\".\n",
"    \"\"\"\n",
"    context = make_context(similars, prices)\n",
"    # Drop the rounding instruction and the trailing price stub from the item's own prompt\n",
"    question = item.test_prompt()\n",
"    question = question.replace(\" to the nearest dollar\", \"\")\n",
"    question = question.replace(\"\\n\\nPrice is $\", \"\")\n",
"    user_prompt = context + \"And now the question for you:\\n\\n\" + question\n",
"    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
"    return [\n",
"        {\"role\": \"system\", \"content\": system_message},\n",
"        {\"role\": \"user\", \"content\": user_prompt},\n",
"        {\"role\": \"assistant\", \"content\": \"Price is $\"},\n",
"    ]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b26f405d-6e1f-4caa-b97f-1f62cd9d1ebc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"DB = \"products_vectorstore\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d26a1104-cd11-4361-ab25-85fb576e0582",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"client = chromadb.PersistentClient(path=DB)\n",
|
|
"collection = client.get_or_create_collection('products')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1e339760-96d8-4485-bec7-43fadcd30c4d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def description(item):\n",
"    \"\"\"Return the product text from item.prompt without the question and the price.\n",
"\n",
"    The leading question sentence is removed, and everything from the first\n",
"    \"Price is $\" marker onwards is discarded.\n",
"    \"\"\"\n",
"    question = \"How much does this cost to the nearest dollar?\\n\\n\"\n",
"    price_marker = \"\\n\\nPrice is $\"\n",
"    body = item.prompt.replace(question, \"\")\n",
"    product_text, _, _ = body.partition(price_marker)\n",
"    return product_text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a1bd0c87-8bad-43d9-9461-bb69a9e0e22c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"description(test[0])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9f759bd2-7a7e-4c1a-80a0-e12470feca89",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e44dbd25-fb95-4b6b-bbbb-8da5fc817105",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def vector(item):\n",
"    \"\"\"Encode the item's description with the sentence-transformer `model`.\n",
"\n",
"    The description is passed to model.encode() as a one-element list, and the\n",
"    result of that call is returned unchanged.\n",
"    \"\"\"\n",
"    texts = [description(item)]\n",
"    return model.encode(texts)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ffd5ee47-db5d-4263-b0d9-80d568c91341",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def find_similars(item):\n",
"    \"\"\"Query the Chroma `collection` for the 5 entries nearest to `item`.\n",
"\n",
"    Returns:\n",
"        A (documents, prices) tuple: the matched document texts and the 'price'\n",
"        value from each match's metadata, in the order returned by the query.\n",
"    \"\"\"\n",
"    query_embeddings = vector(item).astype(float).tolist()\n",
"    hits = collection.query(query_embeddings=query_embeddings, n_results=5)\n",
"    # Only one query embedding is sent, so the results of interest are at index 0\n",
"    documents = list(hits['documents'][0])\n",
"    prices = [metadata['price'] for metadata in hits['metadatas'][0]]\n",
"    return documents, prices"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6f7b9ff9-fd90-4627-bb17-7c2f7bbd21f3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(test[1].prompt)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ff1b2659-cc6b-47aa-a797-dd1cd3d1d6c3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"documents, prices = find_similars(test[1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "24756d4d-edac-41ce-bb80-c3b6f1cea7ee",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(make_context(documents, prices))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0b81eca2-0b58-4fe8-9dd6-47f13ba5f8ee",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(messages_for(test[1], documents, prices))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d11f1c8d-7480-4d64-a274-b030d701f1b8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def get_price(s):\n",
"    \"\"\"Extract the first number found in `s` and return it as a float.\n",
"\n",
"    Dollar signs and commas (thousands separators) are removed before searching.\n",
"\n",
"    Args:\n",
"        s: text such as a model reply. None or an empty string is treated as\n",
"            containing no price.\n",
"\n",
"    Returns:\n",
"        The first number in the text as a float, or 0.0 when there is none.\n",
"    \"\"\"\n",
"    if not s:\n",
"        # Guard against a missing reply (None) instead of failing on .replace()\n",
"        return 0.0\n",
"    s = s.replace('$','').replace(',','')\n",
"    match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
"    # Always return a float so callers get one consistent type\n",
"    return float(match.group()) if match else 0.0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "06743833-c362-47f8-b02a-139be2cd52ab",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"get_price(\"The price for this is $99.99\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a919cf7d-b3d3-4968-8c96-54a0da0b0219",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# The function for gpt-4o-mini\n",
|
|
"\n",
|
"def gpt_4o_mini_rag(item):\n",
"    \"\"\"Estimate the price of `item` with gpt-4o-mini, using similar products as context.\n",
"\n",
"    Looks up similar products with find_similars(), sends the messages built by\n",
"    messages_for() to the chat completions API, and parses the reply with get_price().\n",
"    \"\"\"\n",
"    similar_docs, similar_prices = find_similars(item)\n",
"    prompt_messages = messages_for(item, similar_docs, similar_prices)\n",
"    completion = openai.chat.completions.create(\n",
"        model=\"gpt-4o-mini\",\n",
"        messages=prompt_messages,\n",
"        seed=42,\n",
"        max_tokens=5,\n",
"    )\n",
"    return get_price(completion.choices[0].message.content)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3e519e26-ff15-4425-90bb-bfbf55deb39b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"gpt_4o_mini_rag(test[1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ce78741b-2966-41d2-9831-cbf8f8d176be",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"test[1].price"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "16d90455-ff7d-4f5f-8b8c-8e061263d1c7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"Tester.test(gpt_4o_mini_rag, test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d793c6d0-ce3f-4680-b37d-4643f0cd1d8e",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Optional Extra: Trying a DeepSeek API call instead of OpenAI\n",
|
|
"\n",
|
"If you have a DeepSeek API key, we will use it here as an alternative implementation; otherwise skip to the next section."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "21b6a22f-0195-47b6-8f6d-cab6ebe05742",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Connect to DeepSeek using the OpenAI client python library\n",
"\n",
"deepseek_api_key = os.getenv(\"DEEPSEEK_API_KEY\")\n",
"if not deepseek_api_key:\n",
"    # With api_key=None the OpenAI client falls back to the OPENAI_API_KEY environment\n",
"    # variable, which would then be sent to api.deepseek.com - so stop here instead.\n",
"    raise ValueError(\"DEEPSEEK_API_KEY is not set; add it to your .env file or skip this optional section\")\n",
"deepseek_via_openai_client = OpenAI(api_key=deepseek_api_key, base_url=\"https://api.deepseek.com\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ea7267d6-9489-4dac-a6e0-aec108e788c2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Added some retry logic here because DeepSeek is very oversubscribed and sometimes fails..\n",
|
|
"\n",
|
"def deepseek_api_rag(item, retries=8):\n",
"    \"\"\"Estimate the price of `item` with DeepSeek, using similar products as context.\n",
"\n",
"    The chat completion request is attempted up to `retries` times; each failure\n",
"    is printed and the request is tried again.\n",
"\n",
"    Args:\n",
"        item: object accepted by find_similars() and messages_for().\n",
"        retries: maximum number of attempts (default 8).\n",
"\n",
"    Returns:\n",
"        The price parsed from the model's reply by get_price().\n",
"\n",
"    Raises:\n",
"        RuntimeError: if every attempt fails; the last API error is attached as\n",
"            the cause. (Without this, exhausting the retries left `reply`\n",
"            unassigned and the function failed with an UnboundLocalError.)\n",
"    \"\"\"\n",
"    documents, prices = find_similars(item)\n",
"    # The messages do not change between attempts, so build them once\n",
"    messages = messages_for(item, documents, prices)\n",
"    last_error = None\n",
"    for _ in range(retries):\n",
"        try:\n",
"            response = deepseek_via_openai_client.chat.completions.create(\n",
"                model=\"deepseek-chat\",\n",
"                messages=messages,\n",
"                seed=42,\n",
"                max_tokens=8\n",
"            )\n",
"            reply = response.choices[0].message.content\n",
"            break\n",
"        except Exception as e:\n",
"            print(f\"Error: {e}\")\n",
"            last_error = e  # keep a reference; `e` is cleared when the except block ends\n",
"    else:\n",
"        # Loop finished without a successful reply\n",
"        raise RuntimeError(f\"DeepSeek request failed after {retries} attempts\") from last_error\n",
"    return get_price(reply)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6560faf2-4dec-41e5-95e2-b2c46cdb3ba8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"deepseek_api_rag(test[1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0578b116-869f-429d-8382-701f1c0882f3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"Tester.test(deepseek_api_rag, test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6739870f-1eec-4547-965d-4b594e685697",
|
|
"metadata": {},
|
|
"source": [
|
|
"## And now to wrap this in an \"Agent\" class"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e6d5deb3-6a2a-4484-872c-37176c5e1f07",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from agents.frontier_agent import FrontierAgent"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2efa7ba9-c2d7-4f95-8bb5-c4295bbeb01f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Let's print the logs so we can see what's going on\n",
|
|
"\n",
|
|
"import logging\n",
|
|
"root = logging.getLogger()\n",
|
|
"root.setLevel(logging.INFO)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "56e8dd5d-ed36-49d8-95f7-dc82e548255b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent = FrontierAgent(collection)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "980dd126-f675-4499-8817-0cc0bb73e247",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent.price(\"Quadcast HyperX condenser mic for high quality podcasting\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "66c18a06-d0f1-4ec9-8aff-ec3ca294dd09",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from agents.specialist_agent import SpecialistAgent"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ba672fb4-2c3e-42ee-9ea0-21bfcfc5260c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent2 = SpecialistAgent()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a5a97004-95b4-46ea-b12d-a4ead22fcb2a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent2.price(\"Quadcast HyperX condenser mic for high quality podcasting\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "26d5ddc6-baa6-4760-a430-05671847ac47",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|