LLM_Engineering_OLD/week6/community-contributions/lisekarimi/09_part4_ft_gpt4omini.ipynb
2025-06-07 05:47:51 +02:00

{
"cells": [
{
"cell_type": "markdown",
"id": "12934dbc-ff4f-4dfc-8cc1-d92cc8826cf2",
"metadata": {},
"source": [
"# 🔍 Predicting Item Prices from Descriptions (Part 4)\n",
"---\n",
"- Data Curation & Preprocessing\n",
"- Model Benchmarking Traditional ML vs LLMs\n",
"- E5 Embeddings & RAG\n",
"- ➡️ Fine-Tuning GPT-4o Mini\n",
"- Evaluating LLaMA 3.1 8B Quantized\n",
"- Fine-Tuning LLaMA 3.1 with QLoRA\n",
"- Evaluating Fine-Tuned LLaMA\n",
"- Summary & Leaderboard\n",
"\n",
"---\n",
"\n",
"# 🔧 Part 4: Fine-Tuning GPT-4o Mini\n",
"\n",
"- 🧑‍💻 Skill Level: Advanced\n",
"- ⚙️ Hardware: ✅ CPU is sufficient — no GPU required\n",
"- 🛠️ Requirements: 🔑 HF Token, Open API Key, wandb API Key\n",
"- Tasks:\n",
" - Convert chat data to .jsonl format for OpenAI\n",
" - Fine-tune the model and monitor with Weights & Biases\n",
" - Test the fine-tuned GPT-4o Mini \n",
"\n",
"Can fine-tuning GPT-4o Mini outperform both its zero-shot baseline and RAG-enhanced version? \n",
"Time to find out.\n",
"\n",
"---\n",
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5809630f-d3ea-41df-86ec-9cbf59a46f5c",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import importlib\n",
"import json\n",
"import re\n",
"from dotenv import load_dotenv\n",
"from huggingface_hub import login\n",
"from datasets import load_dataset\n",
"from openai import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4120c84d-c310-4d31-9e1f-1549ea4a4186",
"metadata": {},
"outputs": [],
"source": [
"load_dotenv(override=True)\n",
"\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"if not openai_api_key:\n",
" print(\"❌ OPENAI_API_KEY is missing\")\n",
"\n",
"openai = OpenAI(api_key=openai_api_key)\n",
"\n",
"hf_token = os.getenv('HF_TOKEN')\n",
"if not hf_token:\n",
" print(\"❌ HF_TOKEN is missing\")\n",
"\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "markdown",
"id": "31d3aa97-68a8-4f71-a43f-107f7c8553c5",
"metadata": {},
"source": [
"## 📥 Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2bae96a",
"metadata": {},
"outputs": [],
"source": [
"# #If you face NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported run:\n",
"# %pip install -U datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c45e23d6-1304-4859-81f0-35a9ddf1c755",
"metadata": {},
"outputs": [],
"source": [
"HF_USER = \"lisekarimi\"\n",
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
"\n",
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "667adda8-add8-41b6-9e60-7870bad20c02",
"metadata": {},
"outputs": [],
"source": [
"test[0]"
]
},
{
"cell_type": "markdown",
"id": "b85d86d0-b6b1-49cd-9ef0-9214c1267199",
"metadata": {},
"source": [
"## 🛠️ Step 1 : Data Preparation"
]
},
{
"cell_type": "markdown",
"id": "d3ba760d-467a-4cd9-8d3f-e6ce84273610",
"metadata": {},
"source": [
"To fine-tune GPT-4o-mini, OpenAI requires training data in **.jsonl format**. \n",
"\n",
"`make_jsonl` converts our chat data :\n",
"\n",
"from \n",
"\n",
"[\n",
" {\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"},\n",
" {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"},\n",
" {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}\n",
"]\n",
"\n",
"into the .jsonl format \n",
"\n",
"{\"messages\": [{\"role\": \"system\", \"content\": \"You estimate prices of items. Reply only with the price, no explanation\"}, {\"role\": \"user\", \"content\": \"How much is this laptop worth?\"}, {\"role\": \"assistant\", \"content\": \"Price is $999.00\"}]}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec254755-67f6-4676-b67f-c1376ea00124",
"metadata": {},
"outputs": [],
"source": [
"# Mask the price in the test item\n",
"def mask_price_value(text):\n",
" return re.sub(r\"(\\n\\nPrice is \\$).*\", r\"\\1\", text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5e51957-b0ec-49f9-ae70-74771a101756",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(datapoint):\n",
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
" user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
" assistant_response = f\"Price is ${datapoint['price']:.2f}\"\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt},\n",
" {\"role\": \"assistant\", \"content\": assistant_response}\n",
" ]\n",
"\n",
"messages_for(train[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03583d32-b0f2-44c0-820e-62c8e7e48247",
"metadata": {},
"outputs": [],
"source": [
"def make_jsonl(datapoints):\n",
" result = \"\"\n",
" for datapoint in datapoints:\n",
" messages = messages_for(datapoint)\n",
" messages_str = json.dumps(messages, ensure_ascii=False)\n",
" result += '{\"messages\": ' + messages_str + '}\\n'\n",
" return result.strip()\n",
"\n",
"make_jsonl(train.select([0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36c9cf60-0bcb-44cb-8df6-ff2ed4110cd2",
"metadata": {},
"outputs": [],
"source": [
"ft_train = train.select(range(100))\n",
"ft_validation = train.select(range(100, 150))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "494eaecd-ae5d-4396-b694-6faf88fb7fd6",
"metadata": {},
"outputs": [],
"source": [
"# Convert the items into jsonl and write them to a file\n",
"\n",
"def write_jsonl(datapoints, filename):\n",
" with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
" jsonl = make_jsonl(datapoints)\n",
" f.write(jsonl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae42986d-ab02-4a11-aa0c-ede9c63ec7a2",
"metadata": {},
"outputs": [],
"source": [
"write_jsonl(ft_train, \"data/ft_train.jsonl\")\n",
"write_jsonl(ft_validation, \"data/ft_val.jsonl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9bed22d-73ad-4820-a983-cbdccd8dbbc8",
"metadata": {},
"outputs": [],
"source": [
"with open(\"data/ft_train.jsonl\", \"rb\") as f:\n",
" train_file = openai.files.create(file=f, purpose=\"fine-tune\")\n",
"with open(\"data/ft_val.jsonl\", \"rb\") as f:\n",
" validation_file = openai.files.create(file=f, purpose=\"fine-tune\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e6c6ce8-6600-4068-9ec5-32c6428ce9ea",
"metadata": {},
"outputs": [],
"source": [
"train_file"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26943fad-4301-4bb4-97e8-be52a9743322",
"metadata": {},
"outputs": [],
"source": [
"validation_file"
]
},
{
"cell_type": "markdown",
"id": "edb0a3ec-1607-4c5b-ab06-852f951cae8b",
"metadata": {},
"source": [
"## 🚀 Step 2: Run Fine-Tuning & Monitor with wandb\n",
"We will use https://wandb.ai to monitor the training runs\n",
"\n",
"1- Create an API key in wandb\n",
"\n",
"2- Add this key in OpenAI dashboard https://platform.openai.com/account/organization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59f552fe-5e80-4742-94a8-5492556a6543",
"metadata": {},
"outputs": [],
"source": [
"wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"gpt-pricer\"}}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "144088d7-7c30-439a-9282-1e6096c181ea",
"metadata": {},
"outputs": [],
"source": [
"# Run the fine tuning\n",
"\n",
"openai.fine_tuning.jobs.create(\n",
" training_file=train_file.id,\n",
" validation_file=validation_file.id,\n",
" model=\"gpt-4o-mini-2024-07-18\",\n",
" seed=42,\n",
" hyperparameters={\"n_epochs\": 1},\n",
" integrations = [wandb_integration],\n",
" suffix=\"pricer\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "330e75f5-0208-4c74-8dd3-07bc06047b2e",
"metadata": {},
"outputs": [],
"source": [
"job_id = openai.fine_tuning.jobs.list(limit=1).data[0].id\n",
"job_id\n",
"\n",
"# Then check your wandb dashboard to view the run of this job ID"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a92dac5-e6d8-439c-b55e-507becb37a6c",
"metadata": {},
"outputs": [],
"source": [
"# Use this command to track the fine-tuning progress here\n",
"\n",
"openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=2).data"
]
},
{
"cell_type": "markdown",
"id": "b6b65677-06b2-47d3-b0e6-51210a3d832b",
"metadata": {},
"source": [
"# 📧 Youll get an email once fine-tuning is complete. ☕ You can take a break until then. ▶️ Once you receive it, run the cells below to continue."
]
},
{
"cell_type": "markdown",
"id": "0a7af4be-0b55-4654-af7a-f47485babc52",
"metadata": {},
"source": [
"## Step 3 : Test the fine tuned model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8497eb8-49ee-4a05-9e51-fc1b4b2b41d4",
"metadata": {},
"outputs": [],
"source": [
"ft_model_name = openai.fine_tuning.jobs.retrieve(job_id).fine_tuned_model\n",
"ft_model_name"
]
},
{
"cell_type": "markdown",
"id": "12bed33f-be31-4d7c-8651-3f267c529304",
"metadata": {},
"source": [
"You can find the entire fine-tuning process in the **Fine-tuning** dashboard on OpenAI.\n",
"\n",
"![Fine-tuning Process](https://github.com/lisekarimi/lexo/blob/main/assets/09_ft_gpt4omini.png?raw=true)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac6a89ef-f982-457a-bad7-bd84b6132a07",
"metadata": {},
"outputs": [],
"source": [
"# Build LLM messages\n",
"def build_messages(datapoint):\n",
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
" user_prompt = mask_price_value(datapoint[\"text\"]).replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\",\"\")\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt},\n",
" {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
" ]\n",
"\n",
"def get_price(s):\n",
" s = s.replace('$','').replace(',','')\n",
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
" return float(match.group()) if match else 0\n",
"\n",
"def gpt_ft(datapoint):\n",
" response = openai.chat.completions.create(\n",
" model=ft_model_name,\n",
" messages=build_messages(datapoint),\n",
" seed=42,\n",
" max_tokens=7\n",
" )\n",
" reply = response.choices[0].message.content\n",
" return get_price(reply)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93a93017-458c-4769-b81c-b2dad2af7552",
"metadata": {},
"outputs": [],
"source": [
"print(test[0][\"price\"])\n",
"print(gpt_ft(test[0]))"
]
},
{
"cell_type": "markdown",
"id": "87a5ad10-ed60-4533-ad61-225ceb847e6c",
"metadata": {},
"source": [
"🔔 **Reminder:** \n",
"- In **Part 2**, GPT-4o Mini (zero-shot) scored: \n",
" Avg. Error: ~$99 | RMSLE: 0.75 | Accuracy: 44.8% \n",
"\n",
"- In **Part 3**, with **RAG**, performance improved to: \n",
" Avg. Error: ~$59.54 | RMSLE: 0.42 | Accuracy: 69.2%\n",
"\n",
"🧪 **Now its time to see** if fine-tuning can push GPT-4o Mini even further and outperform both baselines."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0adf1500-9cc7-491a-9ea6-88932af85dca",
"metadata": {},
"outputs": [],
"source": [
"import helpers.testing\n",
"importlib.reload(helpers.testing)\n",
"\n",
"from helpers.testing import Tester # noqa: E402\n",
"\n",
"tester = Tester(gpt_ft, test)\n",
"tester.run()"
]
},
{
"cell_type": "markdown",
"id": "37439666",
"metadata": {},
"source": [
"Gpt Ft Error=$129.16 RMSLE=0.94 Hits=35.2%"
]
},
{
"cell_type": "markdown",
"id": "5487da30-e1a8-4db5-bf17-80bc4f109524",
"metadata": {},
"source": [
"**Fine-tuning GPT-4o Mini led to worse performance than both its zero-shot and RAG-enhanced versions.**\n",
"\n",
"⚠️ When Fine-Tuning Isnt Needed:\n",
"- For tasks like price prediction, GPT-4o performs well with prompting alone — thanks to strong pretraining and generalization.\n",
"- 💡 Fine-tuning isnt always better. Use it when prompting fails — not by default.\n",
"\n",
"✅ **When Fine-Tuning Is Worth It (based on OpenAIs own guidelines)**\n",
"- Custom tone/style e.g., mimicking a brand voice or writing like a specific author\n",
"- More consistent output e.g., always following a strict format\n",
"- Fix prompt failures e.g., when multi-step instructions get ignored\n",
"- Handle edge cases e.g., rare product types or weird inputs\n",
"- Teach new tasks e.g., estimating prices in a custom format no model has seen before\n",
"\n",
"---\n",
"\n",
"Now that weve explored both frontier closed-source models and traditional ML, its time to turn to open-source.\n",
"\n",
"🚀 **Next up: Fine-tuned LLaMA 3.1 8B (quantized)** — can it beat its base version, outperform GPT-4o Mini, or even challenge the big players?\n",
"\n",
"🔍 Lets find out in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part5_llama31_8b_quant.ipynb)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}