Merge pull request #431 from lisekarimi/feature/week6
Add week6 contributions
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -0,0 +1,510 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "12934dbc-ff4f-4dfc-8cc1-d92cc8826cf2",
   "metadata": {},
   "source": [
    "# 🔍 Predicting Item Prices from Descriptions (Part 4)\n",
    "---\n",
    "- Data Curation & Preprocessing\n",
    "- Model Benchmarking – Traditional ML vs LLMs\n",
    "- E5 Embeddings & RAG\n",
    "- ➡️ Fine-Tuning GPT-4o Mini\n",
    "- Evaluating LLaMA 3.1 8B Quantized\n",
    "- Fine-Tuning LLaMA 3.1 with QLoRA\n",
    "- Evaluating Fine-Tuned LLaMA\n",
    "- Summary & Leaderboard\n",
    "\n",
    "---\n",
    "\n",
    "# 🔧 Part 4: Fine-Tuning GPT-4o Mini\n",
    "\n",
    "- 🧑‍💻 Skill Level: Advanced\n",
    "- ⚙️ Hardware: ✅ CPU is sufficient — no GPU required\n",
    "- 🛠️ Requirements: 🔑 HF Token, OpenAI API Key, wandb API Key\n",
    "- Tasks:\n",
    "  - Convert chat data to .jsonl format for OpenAI\n",
    "  - Fine-tune the model and monitor with Weights & Biases\n",
    "  - Test the fine-tuned GPT-4o Mini\n",
    "\n",
    "Can fine-tuning GPT-4o Mini outperform both its zero-shot baseline and RAG-enhanced version? \n",
    "Time to find out.\n",
    "\n",
    "---\n",
    "📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5809630f-d3ea-41df-86ec-9cbf59a46f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import importlib\n",
    "import json\n",
    "import re\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "from datasets import load_dataset\n",
    "from openai import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4120c84d-c310-4d31-9e1f-1549ea4a4186",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv(override=True)\n",
    "\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "if not openai_api_key:\n",
    "    print(\"❌ OPENAI_API_KEY is missing\")\n",
    "\n",
    "openai = OpenAI(api_key=openai_api_key)\n",
    "\n",
    "hf_token = os.getenv('HF_TOKEN')\n",
    "if not hf_token:\n",
    "    print(\"❌ HF_TOKEN is missing\")\n",
    "\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31d3aa97-68a8-4f71-a43f-107f7c8553c5",
   "metadata": {},
   "source": [
    "## 📥 Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2bae96a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you face NotImplementedError (Loading a dataset cached in a LocalFileSystem is not supported), run:\n",
    "# %pip install -U datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45e23d6-1304-4859-81f0-35a9ddf1c755",
   "metadata": {},
   "outputs": [],
   "source": [
    "HF_USER = \"lisekarimi\"\n",
    "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
    "\n",
    "dataset = load_dataset(DATASET_NAME)\n",
    "train = dataset['train']\n",
    "test = dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667adda8-add8-41b6-9e60-7870bad20c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "test[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b85d86d0-b6b1-49cd-9ef0-9214c1267199",
   "metadata": {},
   "source": [
    "## 🛠️ Step 1: Data Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3ba760d-467a-4cd9-8d3f-e6ce84273610",
   "metadata": {},
   "source": [
    "To fine-tune GPT-4o Mini, OpenAI requires training data in **.jsonl format**.\n",
    "\n",
    "`make_jsonl` converts our chat data from:\n",
    "\n",
    "[\n",
    "  {\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"},\n",
    "  {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"},\n",
    "  {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}\n",
    "]\n",
    "\n",
    "into the .jsonl format:\n",
    "\n",
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"}, {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"}, {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}]}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec254755-67f6-4676-b67f-c1376ea00124",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mask the price value in an item's text, keeping the \"Price is $\" prefix\n",
    "def mask_price_value(text):\n",
    "    return re.sub(r\"(\\n\\nPrice is \\$).*\", r\"\\1\", text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5e51957-b0ec-49f9-ae70-74771a101756",
   "metadata": {},
   "outputs": [],
   "source": [
    "def messages_for(datapoint):\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
    "    user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
    "    assistant_response = f\"Price is ${datapoint['price']:.2f}\"\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt},\n",
    "        {\"role\": \"assistant\", \"content\": assistant_response}\n",
    "    ]\n",
    "\n",
    "messages_for(train[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03583d32-b0f2-44c0-820e-62c8e7e48247",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_jsonl(datapoints):\n",
    "    result = \"\"\n",
    "    for datapoint in datapoints:\n",
    "        messages = messages_for(datapoint)\n",
    "        messages_str = json.dumps(messages, ensure_ascii=False)\n",
    "        result += '{\"messages\": ' + messages_str + '}\\n'\n",
    "    return result.strip()\n",
    "\n",
    "make_jsonl(train.select([0]))"
   ]
  },
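  {
   "cell_type": "code",
   "execution_count": null,
   "id": "jsonl-sanity-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Added sketch, not in the original notebook.) Quick sanity check before uploading:\n",
    "# every JSONL line must be standalone valid JSON with a 3-message \"messages\" list,\n",
    "# or OpenAI will reject the file at upload time.\n",
    "for line in make_jsonl(train.select(range(5))).split(\"\\n\"):\n",
    "    record = json.loads(line)  # Raises an error if a line is malformed\n",
    "    assert len(record[\"messages\"]) == 3\n",
    "print(\"JSONL sanity check passed\")"
   ]
  },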
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c9cf60-0bcb-44cb-8df6-ff2ed4110cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small fine-tuning split: 100 training and 50 validation examples\n",
    "ft_train = train.select(range(100))\n",
    "ft_validation = train.select(range(100, 150))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "494eaecd-ae5d-4396-b694-6faf88fb7fd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the items into jsonl and write them to a file\n",
    "\n",
    "def write_jsonl(datapoints, filename):\n",
    "    os.makedirs(os.path.dirname(filename) or \".\", exist_ok=True)  # Ensure the target folder exists\n",
    "    with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
    "        jsonl = make_jsonl(datapoints)\n",
    "        f.write(jsonl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae42986d-ab02-4a11-aa0c-ede9c63ec7a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "write_jsonl(ft_train, \"data/ft_train.jsonl\")\n",
    "write_jsonl(ft_validation, \"data/ft_val.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9bed22d-73ad-4820-a983-cbdccd8dbbc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload both files to OpenAI for fine-tuning\n",
    "with open(\"data/ft_train.jsonl\", \"rb\") as f:\n",
    "    train_file = openai.files.create(file=f, purpose=\"fine-tune\")\n",
    "with open(\"data/ft_val.jsonl\", \"rb\") as f:\n",
    "    validation_file = openai.files.create(file=f, purpose=\"fine-tune\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e6c6ce8-6600-4068-9ec5-32c6428ce9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26943fad-4301-4bb4-97e8-be52a9743322",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_file"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edb0a3ec-1607-4c5b-ab06-852f951cae8b",
   "metadata": {},
   "source": [
    "## 🚀 Step 2: Run Fine-Tuning & Monitor with wandb\n",
    "We will use https://wandb.ai to monitor the training runs.\n",
    "\n",
    "1. Create an API key in wandb\n",
    "\n",
    "2. Add this key in the OpenAI dashboard: https://platform.openai.com/account/organization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f552fe-5e80-4742-94a8-5492556a6543",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"gpt-pricer\"}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144088d7-7c30-439a-9282-1e6096c181ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the fine-tuning job\n",
    "\n",
    "openai.fine_tuning.jobs.create(\n",
    "    training_file=train_file.id,\n",
    "    validation_file=validation_file.id,\n",
    "    model=\"gpt-4o-mini-2024-07-18\",\n",
    "    seed=42,\n",
    "    hyperparameters={\"n_epochs\": 1},\n",
    "    integrations=[wandb_integration],\n",
    "    suffix=\"pricer\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "330e75f5-0208-4c74-8dd3-07bc06047b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab the most recent job's ID, then check your wandb dashboard to view its run\n",
    "job_id = openai.fine_tuning.jobs.list(limit=1).data[0].id\n",
    "job_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a92dac5-e6d8-439c-b55e-507becb37a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use this command to track the fine-tuning progress\n",
    "\n",
    "openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=2).data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6b65677-06b2-47d3-b0e6-51210a3d832b",
   "metadata": {},
   "source": [
    "# 📧 You’ll get an email once fine-tuning is complete. ☕ You can take a break until then. ▶️ Once you receive it, run the cells below to continue."
   ]
  },
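  {
   "cell_type": "code",
   "execution_count": null,
   "id": "status-poll-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Added sketch, not in the original notebook.) If you'd rather not wait for the email,\n",
    "# you can poll the job status from the notebook; it should read \"succeeded\" once training is done.\n",
    "# Assumes `job_id` from above and the standard OpenAI Python SDK.\n",
    "\n",
    "import time\n",
    "\n",
    "status = openai.fine_tuning.jobs.retrieve(job_id).status\n",
    "while status in (\"validating_files\", \"queued\", \"running\"):\n",
    "    time.sleep(60)  # Check once a minute\n",
    "    status = openai.fine_tuning.jobs.retrieve(job_id).status\n",
    "print(status)"
   ]
  },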
  {
   "cell_type": "markdown",
   "id": "0a7af4be-0b55-4654-af7a-f47485babc52",
   "metadata": {},
   "source": [
    "## 🧪 Step 3: Test the Fine-Tuned Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8497eb8-49ee-4a05-9e51-fc1b4b2b41d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_model_name = openai.fine_tuning.jobs.retrieve(job_id).fine_tuned_model\n",
    "ft_model_name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12bed33f-be31-4d7c-8651-3f267c529304",
   "metadata": {},
   "source": [
    "You can find the entire fine-tuning process in the **Fine-tuning** dashboard on OpenAI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac6a89ef-f982-457a-bad7-bd84b6132a07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build LLM messages, ending with an assistant turn that the model completes\n",
    "def build_messages(datapoint):\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
    "    user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt},\n",
    "        {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
    "    ]\n",
    "\n",
    "# Extract the first number from the model's reply\n",
    "def get_price(s):\n",
    "    s = s.replace('$', '').replace(',', '')\n",
    "    match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
    "    return float(match.group()) if match else 0\n",
    "\n",
    "# Query the fine-tuned model for a price estimate\n",
    "def gpt_ft(datapoint):\n",
    "    response = openai.chat.completions.create(\n",
    "        model=ft_model_name,\n",
    "        messages=build_messages(datapoint),\n",
    "        seed=42,\n",
    "        max_tokens=7\n",
    "    )\n",
    "    reply = response.choices[0].message.content\n",
    "    return get_price(reply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93a93017-458c-4769-b81c-b2dad2af7552",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(test[0][\"price\"])\n",
    "print(gpt_ft(test[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87a5ad10-ed60-4533-ad61-225ceb847e6c",
   "metadata": {},
   "source": [
    "🔔 **Reminder:** \n",
    "- In **Part 2**, GPT-4o Mini (zero-shot) scored: \n",
    "  Avg. Error: ~$99 | RMSLE: 0.75 | Accuracy: 44.8% \n",
    "\n",
    "- In **Part 3**, with **RAG**, performance improved to: \n",
    "  Avg. Error: ~$59.54 | RMSLE: 0.42 | Accuracy: 69.2%\n",
    "\n",
    "🧪 **Now it’s time to see** if fine-tuning can push GPT-4o Mini even further and outperform both baselines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0adf1500-9cc7-491a-9ea6-88932af85dca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import helpers.testing\n",
    "importlib.reload(helpers.testing)\n",
    "\n",
    "from helpers.testing import Tester  # noqa: E402\n",
    "\n",
    "tester = Tester(gpt_ft, test)\n",
    "tester.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37439666",
   "metadata": {},
   "source": [
    "Gpt Ft Error=$129.16 RMSLE=0.94 Hits=35.2%"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5487da30-e1a8-4db5-bf17-80bc4f109524",
   "metadata": {},
   "source": [
    "**Fine-tuning GPT-4o Mini led to worse performance than both its zero-shot and RAG-enhanced versions.**\n",
    "\n",
    "⚠️ When Fine-Tuning Isn’t Needed:\n",
    "- For tasks like price prediction, GPT-4o Mini performs well with prompting alone — thanks to strong pretraining and generalization.\n",
    "- 💡 Fine-tuning isn’t always better. Use it when prompting fails — not by default.\n",
    "\n",
    "✅ **When Fine-Tuning Is Worth It (based on OpenAI’s own guidelines)**\n",
    "- Custom tone/style – e.g., mimicking a brand voice or writing like a specific author\n",
    "- More consistent output – e.g., always following a strict format\n",
    "- Fix prompt failures – e.g., when multi-step instructions get ignored\n",
    "- Handle edge cases – e.g., rare product types or weird inputs\n",
    "- Teach new tasks – e.g., estimating prices in a custom format no model has seen before\n",
    "\n",
    "---\n",
    "\n",
    "Now that we’ve explored both frontier closed-source models and traditional ML, it’s time to turn to open source.\n",
    "\n",
    "🚀 **Next up: Fine-tuned LLaMA 3.1 8B (quantized)** — can it beat its base version, outperform GPT-4o Mini, or even challenge the big players?\n",
    "\n",
    "🔍 Let’s find out in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part5_llama31_8b_quant.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
1500
week6/community-contributions/lisekarimi/data/human_output.csv
Normal file
File diff suppressed because it is too large
120
week6/community-contributions/lisekarimi/helpers/items.py
Normal file
@@ -0,0 +1,120 @@
from typing import Optional  # A variable might be a certain type or None
from transformers import AutoTokenizer
import re

BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"

MIN_TOKENS = 150  # Minimum tokens required to accept an item
MAX_TOKENS = 160  # We limit to 160 tokens so that, after adding prompt text, the total stays around 180 tokens

MIN_CHARS = 300  # Reject items with fewer than 300 characters
CEILING_CHARS = MAX_TOKENS * 7  # Truncate long text to about 1120 characters (approx. 160 tokens)


class Item:
    """
    An Item is a cleaned, curated datapoint of a Product with a Price
    """

    # Load the tokenizer for the base model
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

    # Define PRICE_LABEL and QUESTION for the training prompt
    PRICE_LABEL = "Price is $"
    QUESTION = "How much does this cost to the nearest dollar?"

    # A list of useless phrases to remove to reduce noise for price prediction
    REMOVALS = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "]

    # Attributes for each item
    title: str
    price: float
    category: str
    token_count: int = 0  # How many tokens are in the final prompt

    # Optional fields
    details: Optional[str]  # The value can be a string or None
    prompt: Optional[str] = None
    include = False  # Whether to keep the item or not

    def __init__(self, data, price):
        self.title = data['title']
        self.price = price
        self.parse(data)

    def scrub_details(self):
        """
        Removes useless phrases from details, which often has repeated specs or boilerplate text.
        """
        details = self.details
        for remove in self.REMOVALS:
            details = details.replace(remove, "")
        return details

    def scrub(self, stuff):
        """
        Clean up the provided text by removing unnecessary characters and whitespace.
        Also remove words that are 7+ chars and contain numbers, as these are likely irrelevant product numbers.
        """
        stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip()
        stuff = stuff.replace(" ,", ",").replace(",,,", ",").replace(",,", ",")
        words = stuff.split(' ')
        select = [word for word in words if len(word) < 7 or not any(char.isdigit() for char in word)]
        return " ".join(select)

    def parse(self, data):
        """
        Prepares the text, checks length, tokenizes it, and sets include = True if it's valid.
        """
        # Build a full contents string by combining description, features, and cleaned details
        contents = '\n'.join(data['description'])
        if contents:
            contents += '\n'
        features = '\n'.join(data['features'])
        if features:
            contents += features + '\n'
        self.details = data['details']
        if self.details:
            contents += self.scrub_details() + '\n'

        # If the content is long enough, trim it to the max char limit before processing
        if len(contents) > MIN_CHARS:
            contents = contents[:CEILING_CHARS]

            # Clean and tokenize the text, then check the token count
            text = f"{self.scrub(self.title)}\n{self.scrub(contents)}"
            tokens = self.tokenizer.encode(text, add_special_tokens=False)

            if len(tokens) > MIN_TOKENS:
                # Truncate the tokens, decode them back, and create the training prompt
                tokens = tokens[:MAX_TOKENS]
                text = self.tokenizer.decode(tokens)
                self.make_prompt(text)

                # Mark the item as valid and ready to be used in training
                self.include = True  # Only items with more than MIN_TOKENS tokens are kept, truncated to MAX_TOKENS

    def make_prompt(self, text):
        """
        Builds the training prompt using the question, text, and price. Then counts the tokens.
        """
        self.prompt = f"{self.QUESTION}\n\n{text}\n\n"
        self.prompt += f"{self.PRICE_LABEL}{round(self.price)}.00"
        self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))

    def test_prompt(self):
        """
        Returns the prompt without the actual price, useful for testing/inference.
        """
        return self.prompt.split(self.PRICE_LABEL)[0] + self.PRICE_LABEL

    def __repr__(self):
        """
        Defines how the Item object looks when printed — it shows the title and price.
        """
        return f"<{self.title} = ${self.price}>"
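
# --- Usage sketch (added for illustration; not part of the original commit) ---
# The datapoint below is hypothetical; real ones come from the Amazon-Reviews-2023
# metadata loaded by ItemLoader ('title', 'description' list, 'features' list, 'details' string).
if __name__ == "__main__":
    example = {
        'title': "Cordless Drill Kit",
        'description': ["A compact 20V cordless drill with two batteries, a charger and a carry case. " * 10],
        'features': ["2-speed gearbox", "LED work light", "Keyless chuck"],
        'details': '"Brand": "Acme", "Batteries Required?": "No"',
    }
    item = Item(example, price=89.99)
    print(item.include)  # True only if the cleaned text exceeds MIN_TOKENS tokens
    if item.include:
        print(item.test_prompt())  # Prompt ending in "Price is $", ready for inference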
106
week6/community-contributions/lisekarimi/helpers/loaders.py
Normal file
@@ -0,0 +1,106 @@
from datetime import datetime  # Measure how long loading takes
from tqdm import tqdm  # Shows a progress bar while processing data
from datasets import load_dataset  # Load a dataset from the Hugging Face Hub
from concurrent.futures import ProcessPoolExecutor  # For parallel processing (speed)
from items import Item

CHUNK_SIZE = 1000  # Process the dataset in chunks of 1000 datapoints at a time (for efficiency)
MIN_PRICE = 0.5
MAX_PRICE = 999.49
WORKER = 4  # Set the number of workers here


class ItemLoader:

    def __init__(self, name):
        """
        Initialize the loader with a dataset name.
        """
        self.name = name  # Store the category name
        self.dataset = None  # Placeholder for the dataset (we load it later in load())

    def process_chunk(self, chunk):
        """
        Convert a chunk of datapoints into valid Item objects.
        """
        batch = []  # Initialize the list to hold valid items

        # Loop through each datapoint in the chunk
        for datapoint in chunk:
            try:
                # Extract the price from the datapoint
                price_str = datapoint['price']
                if price_str:
                    price = float(price_str)

                    # Check if the price is within the valid range
                    if MIN_PRICE <= price <= MAX_PRICE:
                        item = Item(datapoint, price)

                        # Keep only valid items
                        if item.include:
                            batch.append(item)
            except ValueError:
                continue  # Skip datapoints with an invalid price format
        return batch  # Return the list of valid items

    def load_in_parallel(self, workers):
        """
        Split the dataset into chunks and process them in parallel.
        """
        results = []
        size = len(self.dataset)
        chunk_count = (size + CHUNK_SIZE - 1) // CHUNK_SIZE  # Ceiling division, so tqdm's total is exact

        # Build chunks directly here (no separate function)
        chunks = [
            self.dataset.select(range(i, min(i + CHUNK_SIZE, size)))
            for i in range(0, size, CHUNK_SIZE)
        ]

        # Process chunks in parallel using multiple CPU cores
        with ProcessPoolExecutor(max_workers=workers) as pool:
            for batch in tqdm(pool.map(self.process_chunk, chunks), total=chunk_count):
                results.extend(batch)

        # Add the category name to each result
        for result in results:
            result.category = self.name

        return results

    def load(self, workers=WORKER):
        """
        Load and process the dataset, returning valid items.
        """
        # Record the start time
        start = datetime.now()

        # Print a loading message
        print(f"Loading dataset {self.name}", flush=True)

        # Load the dataset from Hugging Face (based on the category name)
        self.dataset = load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_meta_{self.name}",
            split="full",
            trust_remote_code=True
        )

        # Process the dataset in parallel and collect valid items
        results = self.load_in_parallel(workers)

        # Record the end time and print a summary
        finish = datetime.now()
        print(
            f"Completed {self.name} with {len(results):,} datapoints in {(finish-start).total_seconds()/60:.1f} mins",
            flush=True
        )

        # Return the list of valid items
        return results
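
# --- Usage sketch (added for illustration; not part of the original commit) ---
# "Appliances" is assumed to be one of the raw_meta_<category> configs in
# McAuley-Lab/Amazon-Reviews-2023; swap in whichever category you need.
if __name__ == "__main__":
    loader = ItemLoader("Appliances")
    items = loader.load(workers=4)  # List of curated Item objects within the price range
    if items:
        print(items[0])  # Prints the Item repr, e.g. <Some appliance title = $54.99>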
84
week6/community-contributions/lisekarimi/helpers/testing.py
Normal file
@@ -0,0 +1,84 @@
import math
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
RESET = "\033[0m"
COLOR_MAP = {"red": RED, "orange": YELLOW, "green": GREEN}


class Tester:

    def __init__(self, predictor, data, title=None, size=250):
        self.predictor = predictor
        self.data = data
        self.title = title or predictor.__name__.replace("_", " ").title()
        self.size = size
        self.guesses = []
        self.truths = []
        self.errors = []
        self.sles = []
        self.colors = []

    def color_for(self, error, truth):
        # Green: small absolute or relative error; orange: medium; red: large
        if error < 40 or error / truth < 0.2:
            return "green"
        elif error < 80 or error / truth < 0.4:
            return "orange"
        else:
            return "red"

    def run_datapoint(self, i):
        datapoint = self.data[i]
        guess = self.predictor(datapoint)
        truth = datapoint["price"]
        error = abs(guess - truth)
        log_error = math.log(truth + 1) - math.log(guess + 1)
        sle = log_error ** 2
        color = self.color_for(error, truth)
        title = datapoint["text"][:40] + "..." if len(datapoint["text"]) > 40 else datapoint["text"]
        self.guesses.append(guess)
        self.truths.append(truth)
        self.errors.append(error)
        self.sles.append(sle)
        self.colors.append(color)
        # print(f"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}")

    def chart(self, title):
        plt.figure(figsize=(15, 6))
        max_val = max(max(self.truths), max(self.guesses))
        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)
        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)
        plt.xlabel('Ground Truth')
        plt.ylabel('Model Estimate')
        plt.xlim(0, max_val)
        plt.ylim(0, max_val)
        plt.title(title)

        # Add a color legend
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', label='Accurate (green)', markerfacecolor='green', markersize=8),
            Line2D([0], [0], marker='o', color='w', label='Medium error (orange)', markerfacecolor='orange', markersize=8),
            Line2D([0], [0], marker='o', color='w', label='High error (red)', markerfacecolor='red', markersize=8)
        ]
        plt.legend(handles=legend_elements, loc='upper left')
        plt.show()

    def report(self):
        average_error = sum(self.errors) / self.size
        rmsle = math.sqrt(sum(self.sles) / self.size)
        hits = sum(1 for color in self.colors if color == "green")
        title = f"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%"
        self.chart(title)

    def run(self):
        for i in range(self.size):
            self.run_datapoint(i)
        self.report()

    @classmethod
    def test(cls, function, data):
        cls(function, data).run()
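
# --- Usage sketch (added for illustration; not part of the original commit) ---
# A hypothetical predictor and dataset: each datapoint needs a "text" field and a
# numeric "price" ground truth, as in the pricer-data test split.
if __name__ == "__main__":
    def always_100(datapoint):
        return 100.0  # Trivial baseline: guess $100 for everything

    datapoints = [{"text": f"Sample item {i}", "price": float(50 + i)} for i in range(250)]
    Tester.test(always_100, datapoints)  # Scatter chart; Error/RMSLE/Hits appear in the chart title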