Merge pull request #431 from lisekarimi/feature/week6

Add week6 contributions
Ed Donner
2025-06-14 19:35:28 -04:00
committed by GitHub
9 changed files with 4895 additions and 0 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large


@@ -0,0 +1,510 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "12934dbc-ff4f-4dfc-8cc1-d92cc8826cf2",
"metadata": {},
"source": [
"# 🔍 Predicting Item Prices from Descriptions (Part 4)\n",
"---\n",
"- Data Curation & Preprocessing\n",
"- Model Benchmarking Traditional ML vs LLMs\n",
"- E5 Embeddings & RAG\n",
"- ➡️ Fine-Tuning GPT-4o Mini\n",
"- Evaluating LLaMA 3.1 8B Quantized\n",
"- Fine-Tuning LLaMA 3.1 with QLoRA\n",
"- Evaluating Fine-Tuned LLaMA\n",
"- Summary & Leaderboard\n",
"\n",
"---\n",
"\n",
"# 🔧 Part 4: Fine-Tuning GPT-4o Mini\n",
"\n",
"- 🧑‍💻 Skill Level: Advanced\n",
"- ⚙️ Hardware: ✅ CPU is sufficient — no GPU required\n",
"- 🛠️ Requirements: 🔑 HF Token, Open API Key, wandb API Key\n",
"- Tasks:\n",
" - Convert chat data to .jsonl format for OpenAI\n",
" - Fine-tune the model and monitor with Weights & Biases\n",
" - Test the fine-tuned GPT-4o Mini \n",
"\n",
"Can fine-tuning GPT-4o Mini outperform both its zero-shot baseline and RAG-enhanced version? \n",
"Time to find out.\n",
"\n",
"---\n",
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5809630f-d3ea-41df-86ec-9cbf59a46f5c",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import importlib\n",
"import json\n",
"import re\n",
"from dotenv import load_dotenv\n",
"from huggingface_hub import login\n",
"from datasets import load_dataset\n",
"from openai import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4120c84d-c310-4d31-9e1f-1549ea4a4186",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"if not openai_api_key:\n",
" print(\"❌ OPENAI_API_KEY is missing\")\n",
"\n",
"openai = OpenAI(api_key=openai_api_key)\n",
"\n",
"hf_token = os.getenv('HF_TOKEN')\n",
"if not hf_token:\n",
" print(\"❌ HF_TOKEN is missing\")\n",
"\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "markdown",
"id": "31d3aa97-68a8-4f71-a43f-107f7c8553c5",
"metadata": {},
"source": [
"## 📥 Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2bae96a",
"metadata": {},
"outputs": [],
"source": [
"# #If you face NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported run:\n",
"# %pip install -U datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c45e23d6-1304-4859-81f0-35a9ddf1c755",
"metadata": {},
"outputs": [],
"source": [
"HF_USER = \"lisekarimi\"\n",
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
"\n",
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "667adda8-add8-41b6-9e60-7870bad20c02",
"metadata": {},
"outputs": [],
"source": [
"test[0]"
]
},
{
"cell_type": "markdown",
"id": "b85d86d0-b6b1-49cd-9ef0-9214c1267199",
"metadata": {},
"source": [
"## 🛠️ Step 1 : Data Preparation"
]
},
{
"cell_type": "markdown",
"id": "d3ba760d-467a-4cd9-8d3f-e6ce84273610",
"metadata": {},
"source": [
"To fine-tune GPT-4o-mini, OpenAI requires training data in **.jsonl format**. \n",
"\n",
"`make_jsonl` converts our chat data :\n",
"\n",
"from \n",
"\n",
"[\n",
" {\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"},\n",
" {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"},\n",
" {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}\n",
"]\n",
"\n",
"into the .jsonl format \n",
"\n",
"{\"messages\": [{\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"}, {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"}, {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}]}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec254755-67f6-4676-b67f-c1376ea00124",
"metadata": {},
"outputs": [],
"source": [
"# Mask the price in the test item\n",
"def mask_price_value(text):\n",
" return re.sub(r\"(\\n\\nPrice is \\$).*\", r\"\\1\", text)"
]
},
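{
"cell_type": "markdown",
"id": "added-mask-demo-md",
"metadata": {},
"source": [
"Quick sanity check: a minimal sketch of `mask_price_value` on a hypothetical prompt string (not taken from the dataset). Since `.` does not match newlines, only the value after the \"Price is $\" label is stripped."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-mask-demo-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical example string, purely to illustrate the regex behavior\n",
"sample = \"Gaming laptop, 16GB RAM\\n\\nPrice is $999.00\"\n",
"print(repr(mask_price_value(sample)))  # ends with '\\n\\nPrice is $'"
]
},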
{
"cell_type": "code",
"execution_count": null,
"id": "e5e51957-b0ec-49f9-ae70-74771a101756",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(datapoint):\n",
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
" user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
" assistant_response = f\"Price is ${datapoint['price']:.2f}\"\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt},\n",
" {\"role\": \"assistant\", \"content\": assistant_response}\n",
" ]\n",
"\n",
"messages_for(train[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03583d32-b0f2-44c0-820e-62c8e7e48247",
"metadata": {},
"outputs": [],
"source": [
"def make_jsonl(datapoints):\n",
" result = \"\"\n",
" for datapoint in datapoints:\n",
" messages = messages_for(datapoint)\n",
" messages_str = json.dumps(messages, ensure_ascii=False)\n",
" result += '{\"messages\": ' + messages_str + '}\\n'\n",
" return result.strip()\n",
"\n",
"make_jsonl(train.select([0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36c9cf60-0bcb-44cb-8df6-ff2ed4110cd2",
"metadata": {},
"outputs": [],
"source": [
"ft_train = train.select(range(100))\n",
"ft_validation = train.select(range(100, 150))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "494eaecd-ae5d-4396-b694-6faf88fb7fd6",
"metadata": {},
"outputs": [],
"source": [
"# Convert the items into jsonl and write them to a file\n",
"\n",
"def write_jsonl(datapoints, filename):\n",
" with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
" jsonl = make_jsonl(datapoints)\n",
" f.write(jsonl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae42986d-ab02-4a11-aa0c-ede9c63ec7a2",
"metadata": {},
"outputs": [],
"source": [
"write_jsonl(ft_train, \"data/ft_train.jsonl\")\n",
"write_jsonl(ft_validation, \"data/ft_val.jsonl\")"
]
},
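{
"cell_type": "markdown",
"id": "added-jsonl-check-md",
"metadata": {},
"source": [
"Before uploading, it is worth a quick validation pass: every line of a .jsonl file must parse as a standalone JSON object with a `messages` list. A minimal sketch using the files written above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-jsonl-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Parse each line back to confirm the structure OpenAI expects\n",
"for path in [\"data/ft_train.jsonl\", \"data/ft_val.jsonl\"]:\n",
"    with open(path, encoding=\"utf-8\") as f:\n",
"        records = [json.loads(line) for line in f]\n",
"    assert all(\"messages\" in record for record in records)\n",
"    print(f\"{path}: {len(records)} valid records\")"
]
},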
{
"cell_type": "code",
"execution_count": null,
"id": "b9bed22d-73ad-4820-a983-cbdccd8dbbc8",
"metadata": {},
"outputs": [],
"source": [
"with open(\"data/ft_train.jsonl\", \"rb\") as f:\n",
" train_file = openai.files.create(file=f, purpose=\"fine-tune\")\n",
"with open(\"data/ft_val.jsonl\", \"rb\") as f:\n",
" validation_file = openai.files.create(file=f, purpose=\"fine-tune\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e6c6ce8-6600-4068-9ec5-32c6428ce9ea",
"metadata": {},
"outputs": [],
"source": [
"train_file"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26943fad-4301-4bb4-97e8-be52a9743322",
"metadata": {},
"outputs": [],
"source": [
"validation_file"
]
},
{
"cell_type": "markdown",
"id": "edb0a3ec-1607-4c5b-ab06-852f951cae8b",
"metadata": {},
"source": [
"## 🚀 Step 2: Run Fine-Tuning & Monitor with wandb\n",
"We will use https://wandb.ai to monitor the training runs\n",
"\n",
"1- Create an API key in wandb\n",
"\n",
"2- Add this key in OpenAI dashboard https://platform.openai.com/account/organization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59f552fe-5e80-4742-94a8-5492556a6543",
"metadata": {},
"outputs": [],
"source": [
"wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"gpt-pricer\"}}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "144088d7-7c30-439a-9282-1e6096c181ea",
"metadata": {},
"outputs": [],
"source": [
"# Run the fine tuning\n",
"\n",
"openai.fine_tuning.jobs.create(\n",
" training_file=train_file.id,\n",
" validation_file=validation_file.id,\n",
" model=\"gpt-4o-mini-2024-07-18\",\n",
" seed=42,\n",
" hyperparameters={\"n_epochs\": 1},\n",
" integrations = [wandb_integration],\n",
" suffix=\"pricer\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "330e75f5-0208-4c74-8dd3-07bc06047b2e",
"metadata": {},
"outputs": [],
"source": [
"job_id = openai.fine_tuning.jobs.list(limit=1).data[0].id\n",
"job_id\n",
"\n",
"# Then check your wandb dashboard to view the run of this job ID"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a92dac5-e6d8-439c-b55e-507becb37a6c",
"metadata": {},
"outputs": [],
"source": [
"# Use this command to track the fine-tuning progress here\n",
"\n",
"openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=2).data"
]
},
{
"cell_type": "markdown",
"id": "b6b65677-06b2-47d3-b0e6-51210a3d832b",
"metadata": {},
"source": [
"# 📧 Youll get an email once fine-tuning is complete. ☕ You can take a break until then. ▶️ Once you receive it, run the cells below to continue."
]
},
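{
"cell_type": "markdown",
"id": "added-poll-md",
"metadata": {},
"source": [
"If you prefer polling over waiting for the email, here is a minimal sketch built on the same `jobs.retrieve` call used below. (Job statuses move through values such as `validating_files` and `running` before ending in `succeeded`, `failed`, or `cancelled`.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-poll-code",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Check the job status once a minute until it reaches a terminal state\n",
"while True:\n",
"    status = openai.fine_tuning.jobs.retrieve(job_id).status\n",
"    print(status)\n",
"    if status in (\"succeeded\", \"failed\", \"cancelled\"):\n",
"        break\n",
"    time.sleep(60)"
]
},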
{
"cell_type": "markdown",
"id": "0a7af4be-0b55-4654-af7a-f47485babc52",
"metadata": {},
"source": [
"## Step 3 : Test the fine tuned model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8497eb8-49ee-4a05-9e51-fc1b4b2b41d4",
"metadata": {},
"outputs": [],
"source": [
"ft_model_name = openai.fine_tuning.jobs.retrieve(job_id).fine_tuned_model\n",
"ft_model_name"
]
},
{
"cell_type": "markdown",
"id": "12bed33f-be31-4d7c-8651-3f267c529304",
"metadata": {},
"source": [
"You can find the entire fine-tuning process in the **Fine-tuning** dashboard on OpenAI.\n",
"\n",
"![Fine-tuning Process](https://github.com/lisekarimi/lexo/blob/main/assets/09_ft_gpt4omini.png?raw=true)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac6a89ef-f982-457a-bad7-bd84b6132a07",
"metadata": {},
"outputs": [],
"source": [
"# Build LLM messages\n",
"def build_messages(datapoint):\n",
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
" user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt},\n",
" {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
" ]\n",
"\n",
"def get_price(s):\n",
" s = s.replace('$','').replace(',','')\n",
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
" return float(match.group()) if match else 0\n",
"\n",
"def gpt_ft(datapoint):\n",
" response = openai.chat.completions.create(\n",
" model=ft_model_name,\n",
" messages=build_messages(datapoint),\n",
" seed=42,\n",
" max_tokens=7\n",
" )\n",
" reply = response.choices[0].message.content\n",
" return get_price(reply)"
]
},
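{
"cell_type": "markdown",
"id": "added-get-price-md",
"metadata": {},
"source": [
"A quick check of `get_price` on a few hypothetical reply strings, since this extraction step drives every metric that follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-get-price-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical model replies, purely to illustrate the extraction logic\n",
"print(get_price(\"Price is $1,299.99\"))  # 1299.99\n",
"print(get_price(\"$42\"))                 # 42.0\n",
"print(get_price(\"no idea\"))             # 0 (fallback when no number is found)"
]
},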
{
"cell_type": "code",
"execution_count": null,
"id": "93a93017-458c-4769-b81c-b2dad2af7552",
"metadata": {},
"outputs": [],
"source": [
"print(test[0][\"price\"])\n",
"print(gpt_ft(test[0]))"
]
},
{
"cell_type": "markdown",
"id": "87a5ad10-ed60-4533-ad61-225ceb847e6c",
"metadata": {},
"source": [
"🔔 **Reminder:** \n",
"- In **Part 2**, GPT-4o Mini (zero-shot) scored: \n",
" Avg. Error: ~$99 | RMSLE: 0.75 | Accuracy: 44.8% \n",
"\n",
"- In **Part 3**, with **RAG**, performance improved to: \n",
" Avg. Error: ~$59.54 | RMSLE: 0.42 | Accuracy: 69.2%\n",
"\n",
"🧪 **Now its time to see** if fine-tuning can push GPT-4o Mini even further and outperform both baselines."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0adf1500-9cc7-491a-9ea6-88932af85dca",
"metadata": {},
"outputs": [],
"source": [
"import helpers.testing\n",
"importlib.reload(helpers.testing)\n",
"\n",
"from helpers.testing import Tester # noqa: E402\n",
"\n",
"tester = Tester(gpt_ft, test)\n",
"tester.run()"
]
},
{
"cell_type": "markdown",
"id": "37439666",
"metadata": {},
"source": [
"Gpt Ft Error=$129.16 RMSLE=0.94 Hits=35.2%"
]
},
{
"cell_type": "markdown",
"id": "5487da30-e1a8-4db5-bf17-80bc4f109524",
"metadata": {},
"source": [
"**Fine-tuning GPT-4o Mini led to worse performance than both its zero-shot and RAG-enhanced versions.**\n",
"\n",
"⚠️ When Fine-Tuning Isnt Needed:\n",
"- For tasks like price prediction, GPT-4o performs well with prompting alone — thanks to strong pretraining and generalization.\n",
"- 💡 Fine-tuning isnt always better. Use it when prompting fails — not by default.\n",
"\n",
"✅ **When Fine-Tuning Is Worth It (based on OpenAIs own guidelines)**\n",
"- Custom tone/style e.g., mimicking a brand voice or writing like a specific author\n",
"- More consistent output e.g., always following a strict format\n",
"- Fix prompt failures e.g., when multi-step instructions get ignored\n",
"- Handle edge cases e.g., rare product types or weird inputs\n",
"- Teach new tasks e.g., estimating prices in a custom format no model has seen before\n",
"\n",
"---\n",
"\n",
"Now that weve explored both frontier closed-source models and traditional ML, its time to turn to open-source.\n",
"\n",
"🚀 **Next up: Fine-tuned LLaMA 3.1 8B (quantized)** — can it beat its base version, outperform GPT-4o Mini, or even challenge the big players?\n",
"\n",
"🔍 Lets find out in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part5_llama31_8b_quant.ipynb)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large


@@ -0,0 +1,120 @@
from typing import Optional  # A variable might be a certain type or None
from transformers import AutoTokenizer
import re

BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"

MIN_TOKENS = 150  # Minimum tokens required to accept an item
MAX_TOKENS = 160  # We limit to 160 tokens so that after adding prompt text, the total stays around 180 tokens.
MIN_CHARS = 300  # Reject items with fewer than 300 characters
CEILING_CHARS = MAX_TOKENS * 7  # Truncate long text to about 1120 characters (approx 160 tokens)


class Item:
    """
    An Item is a cleaned, curated datapoint of a Product with a Price
    """

    # Load tokenizer for the model
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

    # Define PRICE_LABEL and question for the training prompt
    PRICE_LABEL = "Price is $"
    QUESTION = "How much does this cost to the nearest dollar?"

    # A list of useless phrases to remove to reduce noise for price prediction
    REMOVALS = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "]

    # Attributes for each item
    title: str
    price: float
    category: str
    token_count: int = 0  # How many tokens in the final prompt

    # Optional fields
    details: Optional[str]  # The value can be a string or can be None
    prompt: Optional[str] = None
    include = False  # Whether to keep the item or not

    def __init__(self, data, price):
        self.title = data['title']
        self.price = price
        self.parse(data)

    def scrub_details(self):
        """
        Removes useless phrases from details, which often has repeated specs or boilerplate text.
        """
        details = self.details
        for remove in self.REMOVALS:
            details = details.replace(remove, "")
        return details

    def scrub(self, stuff):
        """
        Clean up the provided text by removing unnecessary characters and whitespace.
        Also remove words that are 7+ chars and contain numbers, as these are likely irrelevant product numbers.
        """
        stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip()
        stuff = stuff.replace(" ,", ",").replace(",,,", ",").replace(",,", ",")
        words = stuff.split(' ')
        select = [word for word in words if len(word) < 7 or not any(char.isdigit() for char in word)]
        return " ".join(select)

    def parse(self, data):
        """
        Prepares the text, checks length, tokenizes it, and sets include = True if it's valid.
        """
        # Build a full contents string by combining description, features, and cleaned details
        contents = '\n'.join(data['description'])
        if contents:
            contents += '\n'
        features = '\n'.join(data['features'])
        if features:
            contents += features + '\n'
        self.details = data['details']
        if self.details:
            contents += self.scrub_details() + '\n'
        # If content is long enough, trim it to the max char limit before processing
        if len(contents) > MIN_CHARS:
            contents = contents[:CEILING_CHARS]
            # Clean and tokenize text, then check token count
            text = f"{self.scrub(self.title)}\n{self.scrub(contents)}"
            tokens = self.tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) > MIN_TOKENS:
                # Truncate tokens, decode them back and create the training prompt
                tokens = tokens[:MAX_TOKENS]
                text = self.tokenizer.decode(tokens)
                self.make_prompt(text)
                # Mark the item as valid and ready to be used in training.
                # Only items with more than MIN_TOKENS tokens are kept (then truncated to MAX_TOKENS).
                self.include = True

    def make_prompt(self, text):
        """
        Builds the training prompt using the question, text, and price. Then counts the tokens.
        """
        self.prompt = f"{self.QUESTION}\n\n{text}\n\n"
        self.prompt += f"{self.PRICE_LABEL}{round(self.price)}.00"
        self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))

    def test_prompt(self):
        """
        Returns the prompt without the actual price, useful for testing/inference.
        """
        return self.prompt.split(self.PRICE_LABEL)[0] + self.PRICE_LABEL

    def __repr__(self):
        """
        Defines how the Item object looks when printed — it shows the title and price.
        """
        return f"<{self.title} = ${self.price}>"
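

if __name__ == "__main__":
    # Minimal usage sketch with a hand-made datapoint. The dict layout is an
    # assumption that mirrors the raw Amazon-Reviews-2023 schema consumed by
    # ItemLoader: 'title', 'description', 'features' and 'details'.
    sample = {
        "title": "Stainless Steel Stand Mixer, 5.5 Quart",
        "description": ["Powerful 600W motor with 10 speed settings and a tilt-head design. " * 8],
        "features": ["Dishwasher-safe bowl", "Includes dough hook and whisk"],
        "details": '"Batteries Required?": "No" Brushed stainless finish',
    }
    item = Item(sample, price=149.99)
    print(item.include)  # True only if the cleaned text passes the token threshold
    if item.include:
        print(item.test_prompt())  # prompt ending in "Price is $", ready for inference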


@@ -0,0 +1,106 @@
from datetime import datetime  # Measure how long loading takes
from tqdm import tqdm  # Shows a progress bar while processing data
from datasets import load_dataset  # Load a dataset from Hugging Face Hub
from concurrent.futures import ProcessPoolExecutor  # For parallel processing (speed)
from items import Item

CHUNK_SIZE = 1000  # Process the dataset in chunks of 1000 datapoints at a time (for efficiency)
MIN_PRICE = 0.5
MAX_PRICE = 999.49
WORKER = 4  # Set the number of workers here


class ItemLoader:

    def __init__(self, name):
        """
        Initialize the loader with a dataset name.
        """
        self.name = name  # Store the category name
        self.dataset = None  # Placeholder for the dataset (we load it later in load())

    def process_chunk(self, chunk):
        """
        Convert a chunk of datapoints into valid Item objects.
        """
        batch = []  # Initialize the list to hold valid items
        # Loop through each datapoint in the chunk
        for datapoint in chunk:
            try:
                # Extract price from datapoint
                price_str = datapoint['price']
                if price_str:
                    price = float(price_str)
                    # Check if price is within valid range
                    if MIN_PRICE <= price <= MAX_PRICE:
                        item = Item(datapoint, price)
                        # Keep only valid items
                        if item.include:
                            batch.append(item)
            except ValueError:
                continue  # Skip datapoints with an invalid price format
        return batch  # Return the list of valid items

    def load_in_parallel(self, workers):
        """
        Split the dataset into chunks and process them in parallel.
        """
        results = []
        size = len(self.dataset)
        # Build chunks directly here (no separate function)
        chunks = [
            self.dataset.select(range(i, min(i + CHUNK_SIZE, size)))
            for i in range(0, size, CHUNK_SIZE)
        ]
        # Process chunks in parallel using multiple CPU cores
        with ProcessPoolExecutor(max_workers=workers) as pool:
            for batch in tqdm(pool.map(self.process_chunk, chunks), total=len(chunks)):
                results.extend(batch)
        # Add the category name to each result
        for result in results:
            result.category = self.name
        return results

    def load(self, workers=WORKER):
        """
        Load and process the dataset, returning valid items.
        """
        # Record start time
        start = datetime.now()
        # Print loading message
        print(f"Loading dataset {self.name}", flush=True)
        # Load dataset from Hugging Face (based on category name)
        self.dataset = load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_meta_{self.name}",
            split="full",
            trust_remote_code=True
        )
        # Process the dataset in parallel and collect valid items
        results = self.load_in_parallel(workers)
        # Record end time and print summary
        finish = datetime.now()
        print(
            f"Completed {self.name} with {len(results):,} datapoints in {(finish-start).total_seconds()/60:.1f} mins",
            flush=True
        )
        # Return the list of valid items
        return results
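

if __name__ == "__main__":
    # Minimal usage sketch. "Appliances" is an assumed category; any raw_meta_*
    # config of McAuley-Lab/Amazon-Reviews-2023 works here. The __main__ guard
    # matters because ProcessPoolExecutor re-imports this module in its worker
    # processes on spawn-based platforms.
    items = ItemLoader("Appliances").load(workers=WORKER)
    if items:
        print(items[0].prompt)  # first curated training prompt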


@@ -0,0 +1,84 @@
import math
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
RESET = "\033[0m"
COLOR_MAP = {"red": RED, "orange": YELLOW, "green": GREEN}


class Tester:

    def __init__(self, predictor, data, title=None, size=250):
        self.predictor = predictor
        self.data = data
        self.title = title or predictor.__name__.replace("_", " ").title()
        self.size = size
        self.guesses = []
        self.truths = []
        self.errors = []
        self.sles = []
        self.colors = []

    def color_for(self, error, truth):
        if error < 40 or error / truth < 0.2:
            return "green"
        elif error < 80 or error / truth < 0.4:
            return "orange"
        else:
            return "red"

    def run_datapoint(self, i):
        datapoint = self.data[i]
        guess = self.predictor(datapoint)
        truth = datapoint["price"]
        error = abs(guess - truth)
        log_error = math.log(truth + 1) - math.log(guess + 1)
        sle = log_error ** 2
        color = self.color_for(error, truth)
        title = datapoint["text"][:40] + "..." if len(datapoint["text"]) > 40 else datapoint["text"]
        self.guesses.append(guess)
        self.truths.append(truth)
        self.errors.append(error)
        self.sles.append(sle)
        self.colors.append(color)
        # print(f"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}")

    def chart(self, title):
        plt.figure(figsize=(15, 6))
        max_val = max(max(self.truths), max(self.guesses))
        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)
        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)
        plt.xlabel('Ground Truth')
        plt.ylabel('Model Estimate')
        plt.xlim(0, max_val)
        plt.ylim(0, max_val)
        plt.title(title)
        # Add color legend
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', label='Accurate (green)', markerfacecolor='green', markersize=8),
            Line2D([0], [0], marker='o', color='w', label='Medium error (orange)', markerfacecolor='orange', markersize=8),
            Line2D([0], [0], marker='o', color='w', label='High error (red)', markerfacecolor='red', markersize=8)
        ]
        plt.legend(handles=legend_elements, loc='upper left')
        plt.show()

    def report(self):
        average_error = sum(self.errors) / self.size
        rmsle = math.sqrt(sum(self.sles) / self.size)
        hits = sum(1 for color in self.colors if color == "green")
        title = f"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%"
        self.chart(title)

    def run(self):
        for i in range(self.size):
            self.run_datapoint(i)
        self.report()

    @classmethod
    def test(cls, function, data):
        cls(function, data).run()
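

if __name__ == "__main__":
    # Minimal smoke test with synthetic data. Real runs pass a curated test set of
    # dicts with "text" and "price" keys, plus a model-backed predictor function.
    import random

    data = [
        {"text": f"Synthetic item {i} with a made-up description", "price": random.uniform(1, 999)}
        for i in range(250)
    ]

    def noisy_predictor(datapoint):
        # Guesses within ±30% of the truth, just to exercise the chart and metrics
        return datapoint["price"] * random.uniform(0.7, 1.3)

    Tester.test(noisy_predictor, data)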