LLM_Engineering_OLD/week6/community-contributions/lisekarimi/09_part1_data_curation.ipynb
2025-06-07 05:47:51 +02:00

{
"cells": [
{
"cell_type": "markdown",
"id": "40978455-23da-4159-bf08-15d9e8f79984",
"metadata": {},
"source": [
"# 🔍 Predicting Item Prices from Descriptions (Part 1)\n",
"A complete pipeline from raw text to fine-tuned frontier and open-source models\n",
"\n",
"---\n",
"In this project, we aim to **predict item prices based solely on their textual descriptions**. \n",
"\n",
"We approach the problem with a structured 8-part pipeline:\n",
"\n",
"- 🧩 **Part 1: Data Curation & Preprocessing** : We aggregate, clean, analyze, and balance the dataset, then export it in .pkl format and push it to the Hugging Face Hub for the next steps: model training and evaluation.\n",
"\n",
"- ⚔️ **Part 2: Traditional ML vs Frontier LLMs** : We compare traditional machine learning models (LR, SVR, XGBoost) trained on vectorized text inputs (BoW, Word2Vec) against LLMs such as GPT-4o, LLaMA, and DeepSeek. ❗ Who will predict better: handcrafted features or massive pretraining?\n",
"\n",
"- 🧠 **Part 3: E5 Embeddings & RAG** : We compare XGBoost on **contextual dense embeddings** vs. Word2Vec, and test whether **RAG** boosts GPT-4o Mini's price predictions. 📦 Do contextual embeddings and retrieval improve price prediction?\n",
"\n",
"- 🔧 **Part 4: Fine-Tuning GPT-4o Mini** : We fine-tune GPT-4o Mini on our curated dataset and compare performance before and after.\n",
"🤖 Can a fine-tuned GPT-4o Mini beat its own zero-shot performance?\n",
"\n",
"- 🦙 **Part 5: Evaluating LLaMA 3.1 8B Quantized** : We run LLaMA 3.1 (8B, quantized) using the same evaluation setup to see how well an open-source base model performs with no fine-tuning.\n",
"\n",
"- ⚙️ **Part 6: Fine-Tuning LLaMA 3.1 with QLoRA** : We fine-tune LLaMA 3.1 using QLoRA and explore key hyperparameters, tracking **training and validation loss** to monitor overfitting and select the best configuration.\n",
"\n",
"- 🧪 **Part 7: Evaluating Fine-Tuned LLaMA 3.1 8B (Quantized)** : After fine-tuning LLaMA 3.1, it's time to evaluate its performance and see how it stacks up against other models. Let's dive into the results.\n",
"\n",
"- 🏆 **Part 8: Summary & Leaderboard** : Who comes out on top? Let's find out. We wrap up with final model rankings and key insights across ML, embeddings, RAG, and fine-tuned frontier and open-source models.\n",
"\n",
"---\n",
"- ➡️ Data Curation & Preprocessing\n",
"- Model Benchmarking: Traditional ML vs LLMs\n",
"- E5 Embeddings & RAG\n",
"- Fine-Tuning GPT-4o Mini\n",
"- Evaluating LLaMA 3.1 8B Quantized\n",
"- Fine-Tuning LLaMA 3.1 with QLoRA\n",
"- Evaluating Fine-Tuned LLaMA \n",
"- Summary & Leaderboard\n",
"\n",
"---\n",
"\n",
"Let's begin with Part 1.\n",
"\n",
"# 🧩 Part 1: Data Curation & Preprocessing\n",
"\n",
"- Tasks:\n",
" - Load and filter dataset, then prepare each datapoint\n",
" - Explore, visualize, balance price distribution\n",
" - Export .pkl, upload to HF Hub\n",
"- 🧑‍💻 Skill Level: Advanced\n",
"- ⚙️ Hardware: ✅ CPU is sufficient — no GPU required\n",
"- 🛠️ Requirements: 🔑 Hugging Face Token\n",
"\n",
"---\n",
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcf2f470",
"metadata": {},
"outputs": [],
"source": [
"!uv pip install transformers datasets huggingface_hub python-dotenv matplotlib numpy"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ddbb5eb0-9ab7-4675-b195-0bf4055b9320",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import sys\n",
"import random\n",
"import pickle\n",
"import importlib\n",
"from dotenv import load_dotenv\n",
"from huggingface_hub import login\n",
"from datasets import Dataset, DatasetDict\n",
"from collections import Counter, defaultdict\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa916b7a-9044-4461-b29a-815d47973e75",
"metadata": {},
"outputs": [],
"source": [
"# import datasets\n",
"# print(datasets.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6cf6e19-1276-4b37-8f9b-6acf1473a7c6",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"\n",
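"load_dotenv(override=True)\n",
"hf_token = os.getenv('HF_TOKEN')\n",
"if not hf_token:\n",
"    # Fail early: the Hub login and the final push both need a token\n",
"    raise EnvironmentError(\"❌ HF_TOKEN is missing: add it to your .env file\")\n",
"\n",
"login(hf_token, add_to_git_credential=True)"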
]
},
{
"cell_type": "markdown",
"id": "a1637a14-b2df-4286-a8d6-ddae413f4a8a",
"metadata": {},
"source": [
"## ⚙️ Data Loading & Curation (Simultaneously)\n",
"We load and curate the data in a single pass using `loaders.py` and `items.py`:\n",
"- Datasets come from: https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/tree/main/raw/meta_categories\n",
"- `loaders.py` handles parallel loading and filtering of products.\n",
"- `items.py` defines the `Item` class, which cleans, validates, and prepares each datapoint (title, description, price, ...) for modeling.\n",
"\n",
"🛠️ Note: Only items priced between $1 and $999 are kept.\n",
"\n",
"💡 Comments have been added in both files to clarify the processing logic.\n",
"\n",
"⚠️ Loading 2.8M+ items can take 40+ minutes on a regular laptop.\n",
"\n",
"⚠️ Set `WORKER` in `loaders.py` to match your machine's capacity: too many workers may crash it."
]
},
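{
"cell_type": "markdown",
"id": "price-filter-sketch-md",
"metadata": {},
"source": [
"The exact filtering rules live in `helpers/loaders.py` and `helpers/items.py`. As a rough illustration only, here is a minimal sketch of what the $1 to $999 price filter might look like. `parse_price`, `keep_price`, the comma handling, and the inclusive bounds are assumptions for this example, not the helpers' actual code."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "price-filter-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of the price filter (not the actual helpers code)\n",
"MIN_PRICE, MAX_PRICE = 1, 999  # Assumed inclusive bounds, in USD\n",
"\n",
"def parse_price(raw):\n",
"    # Return the price as a float, or None if it is missing or not numeric\n",
"    try:\n",
"        return float(str(raw).replace(',', ''))\n",
"    except (TypeError, ValueError):\n",
"        return None\n",
"\n",
"def keep_price(raw):\n",
"    # Keep an item only if its price parses and falls inside the bounds\n",
"    price = parse_price(raw)\n",
"    return price is not None and MIN_PRICE <= price <= MAX_PRICE\n",
"\n",
"print([keep_price(p) for p in ['19.99', 'None', '1,250.00', 0.5, 999]])  # [True, False, False, False, True]"
]
},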
{
"cell_type": "code",
"execution_count": null,
"id": "8b89273c-e02f-4c15-8394-5d948a266bfc",
"metadata": {},
"outputs": [],
"source": [
"sys.path.append('./helpers')\n",
"import helpers.items\n",
"import helpers.loaders\n",
"\n",
"importlib.reload(helpers.items)\n",
"importlib.reload(helpers.loaders)\n",
"\n",
"from helpers.items import Item # noqa: E402\n",
"from helpers.loaders import ItemLoader # noqa: E402"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "260a123b-8f34-4c66-bcac-1c3b25e95d7f",
"metadata": {},
"outputs": [],
"source": [
"dataset_names = [\n",
" \"Automotive\",\n",
" \"Electronics\",\n",
" \"Office_Products\",\n",
" \"Tools_and_Home_Improvement\",\n",
" \"Cell_Phones_and_Accessories\",\n",
" \"Toys_and_Games\",\n",
" \"Appliances\",\n",
" \"Musical_Instruments\",\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b482032-cba9-4ee9-9451-9b7dc9f41be6",
"metadata": {},
"outputs": [],
"source": [
"items = []\n",
"for dataset_name in dataset_names:\n",
" loader = ItemLoader(dataset_name)\n",
" items.extend(loader.load())\n",
"\n",
"# Now, time for a coffee break!\n",
"# Tip: loading the larger datasets first speeds up the overall process."
]
},
{
"cell_type": "markdown",
"id": "145d0648-e01d-46b9-ad42-f10b69fccbc3",
"metadata": {},
"source": [
"## 🔍 Inspecting a Sample Datapoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0185985d-5f67-4e4b-ac66-95b5b293231f",
"metadata": {},
"outputs": [],
"source": [
"print(f\"A grand total of {len(items):,} items\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b0c0ae8-c0ec-4f6f-b847-800da379c01b",
"metadata": {},
"outputs": [],
"source": [
"# Investigate the first item from the list\n",
"\n",
"datapoint = items[0]\n",
"\n",
"# Access various attributes\n",
"title = datapoint.title\n",
"details = datapoint.details\n",
"price = datapoint.price\n",
"category = datapoint.category\n",
"\n",
"print(f\"Datapoint: {datapoint}\")\n",
"print('*' * 40)\n",
"print(f\"Title: {title}\")\n",
"print('*' * 40)\n",
"print(f\"Detail: {details}\")\n",
"print('*' * 40)\n",
"print(f\"Price: ${price}\")\n",
"print('*' * 40)\n",
"print(f\"Category: {category}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e05ed6e4-1cbc-46a4-be2f-4832b99e5ec3",
"metadata": {},
"outputs": [],
"source": [
"# The prompt that will be used during training\n",
"print(items[0].prompt)\n",
"print('*' * 40)\n",
"# The prompt that will be used during testing\n",
"print(items[0].test_prompt())"
]
},
{
"cell_type": "markdown",
"id": "f66e714d-2bae-458e-a0f6-1ce78d0696b3",
"metadata": {},
"source": [
"## 📊 Data Visualization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd50ae2c-b34e-4be7-bd74-62055e4d5b2d",
"metadata": {},
"outputs": [],
"source": [
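"# Set a wide default figure size for the plots below.\n",
"# (A bare plt.figure() in its own cell is closed once the cell finishes, so it would not affect later cells.)\n",
"plt.rcParams['figure.figsize'] = (15, 6)"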
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c736b038-2dcd-40b9-8ae9-d17271f1ff81",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of token counts\n",
"\n",
"tokens = [item.token_count for item in items]\n",
"plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
"plt.xlabel('Length (tokens)')\n",
"plt.ylabel('Count')\n",
"plt.hist(tokens, rwidth=0.7, color=\"blue\", bins=range(0, 300, 10))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da33633a-7ad5-479c-8dff-f7a7a149d49c",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices\n",
"\n",
"prices = [item.price for item in items]\n",
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
"plt.xlabel('Price ($)')\n",
"plt.ylabel('Count')\n",
"plt.hist(prices, rwidth=0.7, color=\"blueviolet\", bins=range(0, 1000, 10))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0f494d7-349e-4878-929c-075ac97c6b6d",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of categories\n",
"\n",
"category_counts = Counter()\n",
"for item in items:\n",
" category_counts[item.category]+=1\n",
"\n",
"categories = category_counts.keys()\n",
"counts = [category_counts[category] for category in categories]\n",
"\n",
"# Bar chart by category\n",
"plt.bar(categories, counts, color=\"goldenrod\")\n",
"plt.title('How many items in each category')\n",
"plt.xlabel('Categories')\n",
"plt.ylabel('Count')\n",
"\n",
"plt.xticks(rotation=30, ha='right')\n",
"\n",
"# Add value labels on top of each bar\n",
"for i, v in enumerate(counts):\n",
" plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
"\n",
"# Display the chart\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "d4fe384d-049b-4742-98e5-20d162db5151",
"metadata": {},
"source": [
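"## 🎯 Data Sampling\n",
"\n",
"We sample to keep the dataset balanced but rich. Items are grouped by their price rounded to the nearest dollar, then:\n",
"- 🎯 Keep every item in a group if its price is ≥ $240 or the group has ≤ 1,200 items\n",
"- 🎯 Otherwise, randomly sample 1,200 items from the group, down-weighting the over-represented `HEAVY_DATASET` category (weight 1 vs 5 for all other categories)\n",
"\n",
"✅ This keeps the valuable high-price items and reduces the dominance of over-represented price points and categories"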
]
},
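{
"cell_type": "markdown",
"id": "toy-weighted-sampling-md",
"metadata": {},
"source": [
"To see what the 1-vs-5 weighting does inside one over-full price slot, here is a toy run on synthetic data; the 3,000 / 1,000 split is made up for illustration. It uses `np.random.default_rng` and `Generator.choice` with `replace=False` and probability weights `p`, the same idea as the `np.random.choice` call in the sampling cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "toy-weighted-sampling",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Synthetic price slot: 3,000 items from the heavy category, 1,000 from others\n",
"rng = np.random.default_rng(0)\n",
"is_heavy = np.array([True] * 3000 + [False] * 1000)\n",
"\n",
"# Same weighting scheme as the sampling cell: 1 for heavy items, 5 for others\n",
"toy_weights = np.where(is_heavy, 1, 5).astype(float)\n",
"toy_weights /= toy_weights.sum()\n",
"\n",
"# Draw 1,200 distinct items, as done for price slots with more than 1,200 items\n",
"picked = rng.choice(len(is_heavy), size=1200, replace=False, p=toy_weights)\n",
"print(f'Heavy share before: {is_heavy.mean():.0%}, after: {is_heavy[picked].mean():.0%}')"
]
},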
{
"cell_type": "code",
"execution_count": null,
"id": "20330037-744d-4834-8ece-413a8dbe2030",
"metadata": {},
"outputs": [],
"source": [
"HEAVY_DATASET = \"Automotive\"  # Over-represented category to down-weight; must match item.category exactly\n",
"\n",
"# Group items by rounded price\n",
"# slots maps each rounded price to the list of items that have that rounded price\n",
"slots = defaultdict(list)\n",
"for item in items:\n",
"    slots[round(item.price)].append(item)\n",
"\n",
"np.random.seed(42)  # Set random seed for reproducibility\n",
"sample = []  # Final collection of items after the sampling process completes\n",
"\n",
"# Sampling loop\n",
"for price, items_at_price in slots.items():\n",
"\n",
"    # Take all items if price ≥ 240 or the group is small\n",
"    if price >= 240 or len(items_at_price) <= 1200:\n",
"        sample.extend(items_at_price)\n",
"\n",
"    # Otherwise sample 1200 items with weights\n",
"    else:\n",
"\n",
"        # Weight: 1 for HEAVY_DATASET items, 5 for all other categories\n",
"        weights = [1 if item.category == HEAVY_DATASET else 5 for item in items_at_price]\n",
"        weights = np.array(weights) / sum(weights)\n",
"\n",
"        # replace=False: never pick the same item twice\n",
"        indices = np.random.choice(len(items_at_price), size=1200, replace=False, p=weights)\n",
"        sample.extend([items_at_price[i] for i in indices])\n",
"\n",
"print(f\"There are {len(sample):,} items in the sample\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21aed337-6f15-48e4-8155-70551ed1d5e0",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices in the sample\n",
"\n",
"prices = [float(item.price) for item in sample]\n",
"plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
"plt.xlabel('Price ($)')\n",
"plt.ylabel('Count')\n",
"plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08a7353e-2752-4493-bb0b-6057d1eab16d",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of categories in the sample\n",
"\n",
"category_counts = Counter()\n",
"for item in sample:\n",
" category_counts[item.category]+=1\n",
"\n",
"categories = category_counts.keys()\n",
"counts = [category_counts[category] for category in categories]\n",
"\n",
"# Create bar chart\n",
"plt.bar(categories, counts, color=\"pink\")\n",
"\n",
"# Customize the chart\n",
"plt.title('How many in each category')\n",
"plt.xlabel('Categories')\n",
"plt.ylabel('Count')\n",
"\n",
"plt.xticks(rotation=30, ha='right')\n",
"\n",
"# Add value labels on top of each bar\n",
"for i, v in enumerate(counts):\n",
" plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
"\n",
"# Display the chart\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "9bdb0c58-24e0-4ab5-8a28-2136b53ab915",
"metadata": {},
"source": [
"The `HEAVY_DATASET` category is still in the lead, but the balance has improved somewhat."
]
},
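{
"cell_type": "markdown",
"id": "heavy-share-check-md",
"metadata": {},
"source": [
"To quantify the improvement, a quick check compares the share of `HEAVY_DATASET` items before and after sampling. `category_share` is a small helper introduced here for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "heavy-share-check",
"metadata": {},
"outputs": [],
"source": [
"def category_share(collection, category):\n",
"    # Fraction of items in `collection` whose category equals `category`\n",
"    return sum(item.category == category for item in collection) / len(collection)\n",
"\n",
"# A 0.0% share means HEAVY_DATASET matches no item.category value (e.g. a typo)\n",
"print(f'{HEAVY_DATASET} share before sampling: {category_share(items, HEAVY_DATASET):.1%}')\n",
"print(f'{HEAVY_DATASET} share after sampling:  {category_share(sample, HEAVY_DATASET):.1%}')"
]
},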
{
"cell_type": "code",
"execution_count": null,
"id": "4ce8ff80-cd19-4c3b-965f-ce6af8ee347d",
"metadata": {},
"outputs": [],
"source": [
"# Create pie chart\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 8))\n",
"wedges, texts, autotexts = ax.pie(\n",
" counts,\n",
" # labels=categories,\n",
" autopct='%1.0f%%',\n",
" startangle=90,\n",
" pctdistance=0.85,\n",
" labeldistance=1.1\n",
")\n",
"ax.legend(wedges, categories, title=\"Categories\", loc=\"lower center\", bbox_to_anchor=(0.5, 1.15), ncol=3)\n",
"\n",
"# Draw donut center\n",
"centre_circle = plt.Circle((0, 0), 0.70, fc='white')\n",
"fig.gca().add_artist(centre_circle)\n",
"\n",
"# Add center label\n",
"ax.text(0, 0, \"Categories\", ha='center', va='center', fontsize=14, fontweight='bold')\n",
"\n",
"# Equal aspect ratio\n",
"plt.axis('equal')\n",
"plt.title(\"Category Distribution\")\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acbc6beb-fab4-49ab-bc7e-243638c1fa99",
"metadata": {},
"outputs": [],
"source": [
"# How does the price vary with the character count of the prompt?\n",
"\n",
"sizes = [len(item.prompt) for item in sample]\n",
"prices = [item.price for item in sample]\n",
"\n",
"# Create the scatter plot\n",
"plt.scatter(sizes, prices, s=0.2, color=\"red\")\n",
"\n",
"# Add labels and title\n",
"plt.xlabel('Prompt length (characters)')\n",
"plt.ylabel('Price ($)')\n",
"plt.title('Is there a simple correlation between prompt length and item price?')\n",
"\n",
"# Display the plot\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "76b060a4-0b8d-495c-bb96-28cb7b7ec623",
"metadata": {},
"source": [
"There is no strong or simple correlation between prompt length and item price.\n",
"\n",
"In other words, longer prompts don't clearly mean higher prices, and vice versa."
]
},
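{
"cell_type": "markdown",
"id": "length-price-corr-md",
"metadata": {},
"source": [
"To back the scatter plot with a number, one option is Pearson's correlation coefficient, computed here with `np.corrcoef` on the `sizes` and `prices` lists from the scatter-plot cell. A value close to 0 would support the reading above; note that Pearson's r only measures linear association."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "length-price-corr",
"metadata": {},
"outputs": [],
"source": [
"# Pearson's r between prompt length (characters) and price\n",
"# Reuses `sizes` and `prices` from the scatter-plot cell above\n",
"r = np.corrcoef(sizes, prices)[0, 1]\n",
"print(f'Pearson r = {r:.3f}')"
]
},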
{
"cell_type": "markdown",
"id": "0f33211c-3548-4a21-990b-21aa55089186",
"metadata": {},
"source": [
"## ✅ Final Check Before Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be8d0c68-ac6e-4a4d-a6c7-64e9c6763ec4",
"metadata": {},
"outputs": [],
"source": [
"# Check that the price label sits at the very end of the prompt\n",
"\n",
"def report(item):\n",
" prompt = item.prompt\n",
" tokens = Item.tokenizer.encode(item.prompt)\n",
" print(prompt)\n",
" print(tokens[-6:])\n",
" print(Item.tokenizer.batch_decode(tokens[-6:]))\n",
"\n",
"report(sample[50])"
]
},
{
"cell_type": "markdown",
"id": "656d523d-8297-4d75-a973-a7e5517d21bc",
"metadata": {},
"source": [
"The LLaMA 3.1 and GPT-4o tokenizers both encode every integer from 1 to 999 as a single token, while models like Qwen2, Gemma, and Phi-3 split such numbers into multiple tokens. This keeps prices compact in our prompts: useful for this project, though not strictly required."
]
},
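{
"cell_type": "markdown",
"id": "number-tokens-check-md",
"metadata": {},
"source": [
"One way to check this for the tokenizer used in this project is to encode every integer from 1 to 999 and count the tokens. The sketch assumes `Item.tokenizer` is a Hugging Face tokenizer (it is used with `encode` and `batch_decode` above); `add_special_tokens=False` keeps a BOS token out of the count. It tests bare numbers, so a price inside a prompt (after `$`, with decimals) may tokenize slightly differently."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "number-tokens-check",
"metadata": {},
"outputs": [],
"source": [
"def multi_token_numbers(tokenizer, upper=999):\n",
"    # Integers in [1, upper] that the tokenizer does NOT encode as exactly one token\n",
"    return [n for n in range(1, upper + 1)\n",
"            if len(tokenizer.encode(str(n), add_special_tokens=False)) != 1]\n",
"\n",
"UPPER = 999\n",
"exceptions = multi_token_numbers(Item.tokenizer, UPPER)\n",
"print(f'{UPPER - len(exceptions)} of {UPPER} numbers are single tokens')\n",
"print('First exceptions:', exceptions[:10])"
]
},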
{
"cell_type": "markdown",
"id": "e36254ba-d20f-44ad-b991-1f1f3cdc4aaa",
"metadata": {},
"source": [
"## 📦 Creating Train/Test Datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5cfb5092-c38d-4c14-8dd0-e1d97c06d7f6",
"metadata": {},
"outputs": [],
"source": [
"random.seed(42)\n",
"random.shuffle(sample)\n",
"train = sample[:400_000]\n",
"test = sample[400_000:402_000]\n",
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f084822-e489-4946-8cf5-f5b0ebd7a23c",
"metadata": {},
"outputs": [],
"source": [
"print(train[0].prompt)\n",
"print('*' * 40)\n",
"print(test[0].test_prompt())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d49a08ce-dd41-4af8-82f6-4701628e8152",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices in the first 250 test points\n",
"\n",
"prices = [float(item.price) for item in test[:250]]\n",
"plt.figure(figsize=(15, 6))\n",
"plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
"plt.xlabel('Price ($)')\n",
"plt.ylabel('Count')\n",
"plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c581439-93f2-422a-924f-fd6c58ef8693",
"metadata": {},
"outputs": [],
"source": [
"# Extract prompts and prices\n",
"train_prompts = [item.prompt for item in train]\n",
"train_prices = [item.price for item in train]\n",
"test_prompts = [item.test_prompt() for item in test]\n",
"test_prices = [item.price for item in test]\n",
"\n",
"# Create Hugging Face datasets\n",
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
"dataset = DatasetDict({\n",
" \"train\": train_dataset,\n",
" \"test\": test_dataset\n",
"})\n",
"\n",
"# Save the full Item objects (with metadata) as pickles\n",
"os.makedirs(\"data\", exist_ok=True)  # Make sure the folder exists\n",
"\n",
"with open('data/train.pkl', 'wb') as file:\n",
" pickle.dump(train, file)\n",
"\n",
"with open('data/test.pkl', 'wb') as file:\n",
" pickle.dump(test, file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3914d029-350e-4140-a31f-e931fa289a41",
"metadata": {},
"outputs": [],
"source": [
"# Push to the Hugging Face Hub\n",
"USERNAME = \"lisekarimi\" # 🔧 Replace with your Hugging Face username\n",
"DATASET_NAME = f\"{USERNAME}/pricer-data\"\n",
"\n",
"dataset.push_to_hub(DATASET_NAME, private=True)"
]
},
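{
"cell_type": "markdown",
"id": "reload-check-md",
"metadata": {},
"source": [
"Before moving on, it can be worth reloading both artifacts to confirm the export round-trips. This sketch uses `datasets.load_dataset` for the Hub copy (which is private, so it relies on the `login` call at the top) and `pickle.load` for the local file; unpickling `Item` objects requires `helpers.items` to be importable, as it is in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reload-check",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"# Hub copy: only the `text` and `price` columns\n",
"hub_dataset = load_dataset(DATASET_NAME)\n",
"print(hub_dataset)\n",
"\n",
"# Local copy: full Item objects with metadata\n",
"with open('data/test.pkl', 'rb') as file:\n",
"    test_reloaded = pickle.load(file)\n",
"assert len(test_reloaded) == len(test)\n",
"print(test_reloaded[0].title, test_reloaded[0].price)"
]
},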
{
"cell_type": "markdown",
"id": "3d8f3b33-41f8-4ee6-96ed-27677ffc8ec4",
"metadata": {},
"source": [
"**Note:**\n",
"- The `pricer-data` dataset on Hugging Face contains only the `text` and `price` columns, e.g. a training row such as `{\"text\": \"How much does this cost...Price is $175.00\", \"price\": 175.0}`.\n",
"- The full `Item` objects, with metadata, are saved locally in `data/train.pkl` and `data/test.pkl`, with the shape `Item(data={\"title\": str, \"description\": list[str], \"features\": list[str], \"details\": str}, price=float)`.\n",
"\n",
"Now it's time to move on to **Part 2: Model Benchmarking: Traditional ML vs Frontier LLMs**.\n",
"\n",
"🔜 See you in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part2_tradml_vs_frontier.ipynb)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}