{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "12934dbc-ff4f-4dfc-8cc1-d92cc8826cf2",
   "metadata": {},
   "source": [
    "# 🔍 Predicting Item Prices from Descriptions (Part 4)\n",
    "---\n",
    "- Data Curation & Preprocessing\n",
    "- Model Benchmarking – Traditional ML vs LLMs\n",
    "- E5 Embeddings & RAG\n",
    "- ➡️ Fine-Tuning GPT-4o Mini\n",
    "- Evaluating LLaMA 3.1 8B Quantized\n",
    "- Fine-Tuning LLaMA 3.1 with QLoRA\n",
    "- Evaluating Fine-Tuned LLaMA\n",
    "- Summary & Leaderboard\n",
    "\n",
    "---\n",
    "\n",
    "# 🔧 Part 4: Fine-Tuning GPT-4o Mini\n",
    "\n",
    "- 🧑‍💻 Skill Level: Advanced\n",
    "- ⚙️ Hardware: ✅ CPU is sufficient — no GPU required\n",
    "- 🛠️ Requirements: 🔑 HF Token, OpenAI API Key, wandb API Key\n",
    "- Tasks:\n",
    "  - Convert chat data to .jsonl format for OpenAI\n",
    "  - Fine-tune the model and monitor with Weights & Biases\n",
    "  - Test the fine-tuned GPT-4o Mini\n",
    "\n",
    "Can fine-tuning GPT-4o Mini outperform both its zero-shot baseline and RAG-enhanced version?  \n",
    "Time to find out.\n",
    "\n",
    "---\n",
    "📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5809630f-d3ea-41df-86ec-9cbf59a46f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import importlib\n",
    "import json\n",
    "import re\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "from datasets import load_dataset\n",
    "from openai import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4120c84d-c310-4d31-9e1f-1549ea4a4186",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv(override=True)\n",
    "\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "if not openai_api_key:\n",
    "    print(\"❌ OPENAI_API_KEY is missing\")\n",
    "\n",
    "openai = OpenAI(api_key=openai_api_key)\n",
    "\n",
    "hf_token = os.getenv('HF_TOKEN')\n",
    "if not hf_token:\n",
    "    print(\"❌ HF_TOKEN is missing\")\n",
    "\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31d3aa97-68a8-4f71-a43f-107f7c8553c5",
   "metadata": {},
   "source": [
    "## 📥 Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2bae96a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you hit \"NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported\", run:\n",
    "# %pip install -U datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45e23d6-1304-4859-81f0-35a9ddf1c755",
   "metadata": {},
   "outputs": [],
   "source": [
    "HF_USER = \"lisekarimi\"\n",
    "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
    "\n",
    "dataset = load_dataset(DATASET_NAME)\n",
    "train = dataset['train']\n",
    "test = dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667adda8-add8-41b6-9e60-7870bad20c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "test[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b85d86d0-b6b1-49cd-9ef0-9214c1267199",
   "metadata": {},
   "source": [
    "## 🛠️ Step 1: Data Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3ba760d-467a-4cd9-8d3f-e6ce84273610",
   "metadata": {},
   "source": [
    "To fine-tune GPT-4o Mini, OpenAI requires training data in **.jsonl format**: one JSON object per line, each holding a conversation under a `messages` key.\n",
    "\n",
    "`make_jsonl` converts our chat data from\n",
    "\n",
    "[\n",
    "  {\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"},\n",
    "  {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"},\n",
    "  {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}\n",
    "]\n",
    "\n",
    "into the .jsonl format\n",
    "\n",
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"}, {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"}, {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}]}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec254755-67f6-4676-b67f-c1376ea00124",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mask the price so the prompt does not leak the answer (keeps the \"Price is $\" prefix)\n",
    "def mask_price_value(text):\n",
    "    return re.sub(r\"(\\n\\nPrice is \\$).*\", r\"\\1\", text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5e51957-b0ec-49f9-ae70-74771a101756",
   "metadata": {},
   "outputs": [],
   "source": [
    "def messages_for(datapoint):\n",
    "    # One training conversation: system instruction, prompt without the price, target answer\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
    "    user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
    "    assistant_response = f\"Price is ${datapoint['price']:.2f}\"\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt},\n",
    "        {\"role\": \"assistant\", \"content\": assistant_response}\n",
    "    ]\n",
    "\n",
    "messages_for(train[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03583d32-b0f2-44c0-820e-62c8e7e48247",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_jsonl(datapoints):\n",
    "    # One JSON object per line, each wrapping the conversation under a \"messages\" key\n",
    "    lines = [\n",
    "        json.dumps({\"messages\": messages_for(datapoint)}, ensure_ascii=False)\n",
    "        for datapoint in datapoints\n",
    "    ]\n",
    "    return \"\\n\".join(lines)\n",
    "\n",
    "make_jsonl(train.select([0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c9cf60-0bcb-44cb-8df6-ff2ed4110cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_train = train.select(range(100))\n",
    "ft_validation = train.select(range(100, 150))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "494eaecd-ae5d-4396-b694-6faf88fb7fd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the items into jsonl and write them to a file\n",
    "\n",
    "def write_jsonl(datapoints, filename):\n",
    "    with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
    "        jsonl = make_jsonl(datapoints)\n",
    "        f.write(jsonl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae42986d-ab02-4a11-aa0c-ede9c63ec7a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the output folder exists before writing\n",
    "os.makedirs(\"data\", exist_ok=True)\n",
    "\n",
    "write_jsonl(ft_train, \"data/ft_train.jsonl\")\n",
    "write_jsonl(ft_validation, \"data/ft_val.jsonl\")"
   ]
},
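  {
   "cell_type": "markdown",
   "id": "added-check-jsonl-note",
   "metadata": {},
   "source": [
    "Before uploading, it can help to sanity-check the files that were just written. The cell below is a minimal sketch rather than part of OpenAI's tooling: `check_jsonl` is a hypothetical helper, and it assumes each line should be a JSON object whose `messages` list ends with an assistant turn, which is the shape `make_jsonl` produces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-check-jsonl-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper (an assumption, not an OpenAI API): basic structural checks on a chat-format .jsonl file\n",
    "import json\n",
    "\n",
    "def check_jsonl(path):\n",
    "    count = 0\n",
    "    with open(path, encoding=\"utf-8\") as f:\n",
    "        for line_number, line in enumerate(f, start=1):\n",
    "            record = json.loads(line)  # raises ValueError if a line is not valid JSON\n",
    "            roles = [message[\"role\"] for message in record[\"messages\"]]\n",
    "            # Each training example should end with the assistant's answer\n",
    "            assert roles and roles[-1] == \"assistant\", f\"line {line_number}: last role is not assistant\"\n",
    "            count += 1\n",
    "    return count\n",
    "\n",
    "print(check_jsonl(\"data/ft_train.jsonl\"), \"training examples\")\n",
    "print(check_jsonl(\"data/ft_val.jsonl\"), \"validation examples\")"
   ]
  },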
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9bed22d-73ad-4820-a983-cbdccd8dbbc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"data/ft_train.jsonl\", \"rb\") as f:\n",
    "    train_file = openai.files.create(file=f, purpose=\"fine-tune\")\n",
    "with open(\"data/ft_val.jsonl\", \"rb\") as f:\n",
    "    validation_file = openai.files.create(file=f, purpose=\"fine-tune\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e6c6ce8-6600-4068-9ec5-32c6428ce9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26943fad-4301-4bb4-97e8-be52a9743322",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_file"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edb0a3ec-1607-4c5b-ab06-852f951cae8b",
   "metadata": {},
   "source": [
    "## 🚀 Step 2: Run Fine-Tuning & Monitor with wandb\n",
    "We will use [Weights & Biases](https://wandb.ai) to monitor the training run.\n",
    "\n",
    "1. Create an API key in wandb.\n",
    "2. Add this key in the OpenAI dashboard: https://platform.openai.com/account/organization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f552fe-5e80-4742-94a8-5492556a6543",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"gpt-pricer\"}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144088d7-7c30-439a-9282-1e6096c181ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the fine tuning\n",
    "\n",
    "openai.fine_tuning.jobs.create(\n",
    "    training_file=train_file.id,\n",
    "    validation_file=validation_file.id,\n",
    "    model=\"gpt-4o-mini-2024-07-18\",\n",
    "    seed=42,\n",
    "    hyperparameters={\"n_epochs\": 1},\n",
    "    integrations=[wandb_integration],\n",
    "    suffix=\"pricer\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "330e75f5-0208-4c74-8dd3-07bc06047b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Most recent fine-tuning job on this account (assumed to be the one just created)\n",
    "job_id = openai.fine_tuning.jobs.list(limit=1).data[0].id\n",
    "job_id\n",
    "\n",
    "# Then check your wandb dashboard to view the run for this job ID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a92dac5-e6d8-439c-b55e-507becb37a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-run this cell to see the two most recent events of the fine-tuning job\n",
    "\n",
    "openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=2).data"
   ]
},
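  {
   "cell_type": "markdown",
   "id": "added-poll-status-note",
   "metadata": {},
   "source": [
    "If you would rather not re-run the events cell by hand, one option is to poll the job status from the notebook. This is a sketch under assumptions: the 60-second interval is arbitrary, and the terminal statuses (`succeeded`, `failed`, `cancelled`) are the ones documented for the fine-tuning API at the time of writing, so check the current OpenAI documentation before relying on them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-poll-status-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Assumed terminal states of a fine-tuning job\n",
    "TERMINAL_STATUSES = {\"succeeded\", \"failed\", \"cancelled\"}\n",
    "\n",
    "while True:\n",
    "    job = openai.fine_tuning.jobs.retrieve(job_id)\n",
    "    print(\"status:\", job.status)\n",
    "    if job.status in TERMINAL_STATUSES:\n",
    "        break\n",
    "    time.sleep(60)  # arbitrary polling interval"
   ]
  },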
  {
   "cell_type": "markdown",
   "id": "b6b65677-06b2-47d3-b0e6-51210a3d832b",
   "metadata": {},
   "source": [
    "# 📧 You’ll get an email once fine-tuning is complete. ☕ You can take a break until then. ▶️ Once you receive it, run the cells below to continue."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a7af4be-0b55-4654-af7a-f47485babc52",
   "metadata": {},
   "source": [
    "## Step 3: Test the Fine-Tuned Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8497eb8-49ee-4a05-9e51-fc1b4b2b41d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_model_name = openai.fine_tuning.jobs.retrieve(job_id).fine_tuned_model\n",
    "ft_model_name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12bed33f-be31-4d7c-8651-3f267c529304",
   "metadata": {},
   "source": [
    "You can find the entire fine-tuning process in the **Fine-tuning** dashboard on OpenAI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac6a89ef-f982-457a-bad7-bd84b6132a07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build LLM messages\n",
    "def build_messages(datapoint):\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
    "    user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt},\n",
    "        {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
    "    ]\n",
    "\n",
    "# Extract the first number from the reply; fall back to 0 if none is found\n",
    "def get_price(s):\n",
    "    s = s.replace('$','').replace(',','')\n",
    "    match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
    "    return float(match.group()) if match else 0\n",
    "\n",
    "# Query the fine-tuned model and parse its reply into a number\n",
    "def gpt_ft(datapoint):\n",
    "    response = openai.chat.completions.create(\n",
    "        model=ft_model_name,\n",
    "        messages=build_messages(datapoint),\n",
    "        seed=42,\n",
    "        max_tokens=7\n",
    "    )\n",
    "    reply = response.choices[0].message.content\n",
    "    return get_price(reply)"
   ]
},
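  {
   "cell_type": "markdown",
   "id": "added-get-price-note",
   "metadata": {},
   "source": [
    "To see what `get_price` extracts, the cell below runs it on a few replies. The strings are made up for this example and are not real model outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-get-price-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Made-up replies (assumptions for illustration) and what get_price returns for each\n",
    "for reply in [\"Price is $1,234.50\", \"$999\", \"no idea\"]:\n",
    "    print(repr(reply), \"->\", get_price(reply))\n",
    "\n",
    "# Prints 1234.5, 999.0 and 0 (the fallback when no number is found)"
   ]
  },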
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93a93017-458c-4769-b81c-b2dad2af7552",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(test[0][\"price\"])\n",
    "print(gpt_ft(test[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87a5ad10-ed60-4533-ad61-225ceb847e6c",
   "metadata": {},
   "source": [
    "🔔 **Reminder:**  \n",
    "- In **Part 2**, GPT-4o Mini (zero-shot) scored:  \n",
    "  Avg. Error: ~$99 | RMSLE: 0.75 | Accuracy: 44.8%  \n",
    "\n",
    "- In **Part 3**, with **RAG**, performance improved to:  \n",
    "  Avg. Error: ~$59.54 | RMSLE: 0.42 | Accuracy: 69.2%\n",
    "\n",
    "🧪 **Now it’s time to see** if fine-tuning can push GPT-4o Mini even further and outperform both baselines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0adf1500-9cc7-491a-9ea6-88932af85dca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import helpers.testing\n",
    "importlib.reload(helpers.testing)\n",
    "\n",
    "from helpers.testing import Tester  # noqa: E402\n",
    "\n",
    "tester = Tester(gpt_ft, test)\n",
    "tester.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37439666",
   "metadata": {},
   "source": [
    "**Result (fine-tuned GPT-4o Mini):** Error=$129.16 | RMSLE=0.94 | Hits=35.2%"
   ]
},
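  {
   "cell_type": "markdown",
   "id": "added-metrics-note",
   "metadata": {},
   "source": [
    "For reference, the two error metrics quoted in this series can be sketched as follows. This is an assumption about how such metrics are usually defined (mean absolute error, and RMSLE as the root mean squared difference of `log(1 + x)` values); the `Tester` helper is not shown in this notebook and may compute them differently. The sample numbers are invented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-metrics-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def average_error(truths, guesses):\n",
    "    # Mean absolute difference in dollars\n",
    "    return sum(abs(g - t) for t, g in zip(truths, guesses)) / len(truths)\n",
    "\n",
    "def rmsle(truths, guesses):\n",
    "    # Root mean squared logarithmic error; log1p(x) = log(1 + x)\n",
    "    squared = [(math.log1p(g) - math.log1p(t)) ** 2 for t, g in zip(truths, guesses)]\n",
    "    return math.sqrt(sum(squared) / len(squared))\n",
    "\n",
    "# Invented sample values, for illustration only\n",
    "truths = [100.0, 50.0, 20.0]\n",
    "guesses = [110.0, 40.0, 20.0]\n",
    "print(f\"Avg. error: ${average_error(truths, guesses):.2f}\")\n",
    "print(f\"RMSLE: {rmsle(truths, guesses):.2f}\")"
   ]
  },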
  {
   "cell_type": "markdown",
   "id": "5487da30-e1a8-4db5-bf17-80bc4f109524",
   "metadata": {},
   "source": [
    "**In this run (100 training examples, 1 epoch), fine-tuning GPT-4o Mini performed worse than both its zero-shot and RAG-enhanced versions.**\n",
    "\n",
    "⚠️ When Fine-Tuning Isn’t Needed:\n",
    "- For tasks like price prediction, GPT-4o Mini performs well with prompting alone, thanks to strong pretraining and generalization.\n",
    "- 💡 Fine-tuning isn’t always better. Use it when prompting fails, not by default.\n",
    "\n",
    "✅ **When Fine-Tuning Is Worth It (based on OpenAI’s own guidelines)**\n",
    "- Custom tone/style – e.g., mimicking a brand voice or writing like a specific author\n",
    "- More consistent output – e.g., always following a strict format\n",
    "- Fix prompt failures – e.g., when multi-step instructions get ignored\n",
    "- Handle edge cases – e.g., rare product types or weird inputs\n",
    "- Teach new tasks – e.g., estimating prices in a custom format no model has seen before\n",
    "\n",
    "---\n",
    "\n",
    "Now that we’ve explored both frontier closed-source models and traditional ML, it’s time to turn to open-source.\n",
    "\n",
    "🚀 **Next up: Fine-tuned LLaMA 3.1 8B (quantized)** — can it beat its base version, outperform GPT-4o Mini, or even challenge the big players?\n",
    "\n",
    "🔍 Let’s find out in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part5_llama31_8b_quant.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}