717 lines
44 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "40978455-23da-4159-bf08-15d9e8f79984",
|
||
"metadata": {},
|
||
"source": [
|
||
"# 🔍 Predicting Item Prices from Descriptions (Part 1)\n",
|
||
"A complete pipeline from raw text to fine-tuned frontier and open source models\n",
|
||
"\n",
|
||
"---\n",
|
||
"In this project, we aim to **predict item prices based solely on their textual descriptions**. \n",
|
||
"\n",
|
||
"We approach the problem with a structured 8-part pipeline:\n",
|
||
"\n",
|
||
"- 🧩 **Part 1: Data Curation & Preprocessing** : We aggregate, clean, analyze, and balance the dataset — then export it in .pkl format and save it in the HuggingFace Hub for the next step: model training and evaluation.\n",
|
||
"\n",
|
||
"- ⚔️ **Part 2: Traditional ML vs Frontier LLMs** : We compare traditional machine learning models (LR, SVR, XGBoost) using vectorized text inputs (BoW, Word2Vec) against LLMs like GPT-4o, LLaMA, Deepseek ... ❗ Who will predict better: handcrafted features or massive pretraining?\n",
|
||
"\n",
|
||
"- 🧠 **Part 3: E5 Embeddings & RAG** : We compare XGBoost on **contextual dense embeddings** vs. Word2Vec, and test if **RAG** boosts GPT-4o Mini’s price predictions. 📦 Do contextual embeddings and retrieval improve price prediction?\n",
|
||
"\n",
|
||
"- 🔧 **Part 4: Fine-Tuning GPT-4o Mini** : We fine-tune GPT-4o Mini on our curated dataset and compare performance before and after.\n",
|
||
"🤖 Can a fine-tuned GPT-4o Mini beat its own zero-shot performance?\n",
|
||
"\n",
|
||
"- 🦙 **Part 5: Evaluating LLaMA 3.1 8B Quantized** : We run LLaMA 3.1 (8B, quantized) using the same evaluation setup to see how well an open-source base model performs with no fine-tuning.\n",
|
||
"\n",
|
||
"- ⚙️ **Part 6: Fine-Tuning LLaMA 3.1 with QLoRA** : We fine-tune LLaMA 3.1 using QLoRA and explore key hyperparameters, tracking **training and validation loss** to monitor overfitting and select the best configuration.\n",
|
||
"\n",
|
||
"- 🧪 **Part 7: Evaluating Fine-Tuned LLaMA 3.1 8B (Quantized)** : After fine-tuning LLaMA 3.1, it's time to evaluate its performance and see how it stacks up against other models. Let's dive into the results.\n",
|
||
"\n",
|
||
"- 🏆**Part 8: Summary & Leaderboard** : Who comes out on top? Let’s find out. We wrap up with final model rankings and key insights across ML, embeddings, RAG, and fine-tuned frontier and open-source models.\n",
|
||
"\n",
|
||
"---\n",
|
||
"- ➡️ Data Curation & Preprocessing\n",
|
||
"- Model Benchmarking – Traditional ML vs LLMs\n",
|
||
"- E5 Embeddings & RAG\n",
|
||
"- Fine-Tuning GPT-4o Mini\n",
|
||
"- Evaluating LLaMA 3.1 8B Quantized\n",
|
||
"- Fine-Tuning LLaMA 3.1 with QLoRA\n",
|
||
"- Evaluating Fine-Tuned LLaMA \n",
|
||
"- Summary & Leaderboard\n",
|
||
"\n",
|
||
"---\n",
|
||
"\n",
|
||
"Let’s begin with Part 1.\n",
|
||
"\n",
|
||
"# 🧩 Part 1: Data Curation & Preprocessing\n",
|
||
"\n",
|
||
"- Tasks:\n",
|
||
" - Load and filter dataset, then prepare each datapoint\n",
|
||
" - Explore, visualize, balance price distribution\n",
|
||
" - Export .pkl, upload to HF Hub\n",
|
||
"- 🧑💻 Skill Level: Advanced\n",
|
||
"- ⚙️ Hardware: ✅ CPU is sufficient — no GPU required\n",
|
||
"- 🛠️ Requirements: 🔑 Hugging Face Token\n",
|
||
"\n",
|
||
"---\n",
|
||
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "dcf2f470",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"!uv pip install transformers"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "ddbb5eb0-9ab7-4675-b195-0bf4055b9320",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# imports\n",
|
||
"\n",
|
||
"import os\n",
|
||
"import sys\n",
|
||
"import random\n",
|
||
"import pickle\n",
|
||
"import importlib\n",
|
||
"from dotenv import load_dotenv\n",
|
||
"from huggingface_hub import login\n",
|
||
"from datasets import Dataset, DatasetDict\n",
|
||
"from collections import Counter, defaultdict\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"%matplotlib inline\n",
|
||
"import numpy as np"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "fa916b7a-9044-4461-b29a-815d47973e75",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# import datasets\n",
|
||
"# print(datasets.__version__)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "e6cf6e19-1276-4b37-8f9b-6acf1473a7c6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# environment\n",
|
||
"\n",
|
||
"load_dotenv(override=True)\n",
|
||
"hf_token = os.getenv('HF_TOKEN')\n",
|
||
"if not hf_token:\n",
|
||
" print(\"❌ HF_TOKEN is missing\")\n",
|
||
"\n",
|
||
"login(hf_token, add_to_git_credential=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "a1637a14-b2df-4286-a8d6-ddae413f4a8a",
|
||
"metadata": {},
|
||
"source": [
|
||
"## ⚙️ Data Loading & Curation (Simultaneously)\n",
|
||
"We load and curate the data at the same time using loaders.py and items.py.\n",
|
||
"- Datasets come from: https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/tree/main/raw/meta_categories\n",
|
||
"- `loaders.py` handles parallel loading and filtering of products\n",
|
||
"- `items.py` defines the Item class to clean, validate, and prepare each datapoint (title, description, price...) for modeling.\n",
|
||
"\n",
|
||
"\n",
|
||
"🛠️ Note: Data is filtered to include items priced between 1 and 999 USD.\n",
|
||
"\n",
|
||
"💡 Comments have been added in both files to clarify the processing logic.\n",
|
||
"\n",
|
||
"⚠️ Loading 2.8M+ items can take 40+ mins on a regular laptop.\n",
|
||
"\n",
|
||
"⚠️ Set WORKER wisely in `loaders.py` to match your system capacity. Too many may crash your machine."
|
||
]
|
||
},
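The filtering step described above can be sketched in a few lines. This is only an illustration under stated assumptions, not the actual contents of `loaders.py`: the record layout, `parse_price`, `filter_chunk`, `load_filtered`, and the `WORKERS` constant are all hypothetical names, and threads are used here for simplicity where the real loader may well use processes.

```python
from concurrent.futures import ThreadPoolExecutor

MIN_PRICE, MAX_PRICE = 1, 999  # price window stated in the note above
WORKERS = 4  # hypothetical setting; tune it to your machine


def parse_price(raw):
    """Return the price as a float, or None if it is missing or unparseable."""
    try:
        return float(raw)
    except (TypeError, ValueError):
        return None


def filter_chunk(chunk):
    """Keep only records whose price falls inside the accepted window."""
    kept = []
    for record in chunk:
        price = parse_price(record.get("price"))
        if price is not None and MIN_PRICE <= price <= MAX_PRICE:
            kept.append({**record, "price": price})
    return kept


def load_filtered(records, chunk_size=2, workers=WORKERS):
    """Split records into chunks and filter the chunks concurrently."""
    chunks = [records[i:i + chunk_size] for i in range(0, len(records), chunk_size)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        results = list(pool.map(filter_chunk, chunks))  # map preserves chunk order
    return [record for chunk in results for record in chunk]


# Made-up records, for illustration only
raw_records = [
    {"title": "USB cable", "price": "7.99"},
    {"title": "Mystery item", "price": "None"},
    {"title": "Pro camera", "price": "2499.00"},
    {"title": "Desk lamp", "price": "34.50"},
    {"title": "Sticker", "price": "0.50"},
]
filtered = load_filtered(raw_records)
print([record["title"] for record in filtered])
```

Only the cable and the lamp survive: the other three records have an unparseable price or one outside the 1 to 999 USD window.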
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "8b89273c-e02f-4c15-8394-5d948a266bfc",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"sys.path.append('./helpers')\n",
|
||
"import helpers.items\n",
|
||
"import helpers.loaders\n",
|
||
"\n",
|
||
"importlib.reload(helpers.items)\n",
|
||
"importlib.reload(helpers.loaders)\n",
|
||
"\n",
|
||
"from helpers.items import Item # noqa: E402\n",
|
||
"from helpers.loaders import ItemLoader # noqa: E402"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "260a123b-8f34-4c66-bcac-1c3b25e95d7f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dataset_names = [\n",
|
||
" \"Automotive\",\n",
|
||
" \"Electronics\",\n",
|
||
" \"Office_Products\",\n",
|
||
" \"Tools_and_Home_Improvement\",\n",
|
||
" \"Cell_Phones_and_Accessories\",\n",
|
||
" \"Toys_and_Games\",\n",
|
||
" \"Appliances\",\n",
|
||
" \"Musical_Instruments\",\n",
|
||
"]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "9b482032-cba9-4ee9-9451-9b7dc9f41be6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"items = []\n",
|
||
"for dataset_name in dataset_names:\n",
|
||
" loader = ItemLoader(dataset_name)\n",
|
||
" items.extend(loader.load())\n",
|
||
"\n",
|
||
"# Now, time for a coffee break!!\n",
|
||
"# By the way, the larger datasets first... it speeds up the process."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "145d0648-e01d-46b9-ad42-f10b69fccbc3",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 🔍 Inspecting a Sample Datapoint"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "0185985d-5f67-4e4b-ac66-95b5b293231f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"print(f\"A grand total of {len(items):,} items\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "2b0c0ae8-c0ec-4f6f-b847-800da379c01b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Investigate the first item from the list\n",
|
||
"\n",
|
||
"datapoint = items[0]\n",
|
||
"\n",
|
||
"# Access various attributes\n",
|
||
"title = datapoint.title\n",
|
||
"details = datapoint.details\n",
|
||
"price = datapoint.price\n",
|
||
"category = datapoint.category\n",
|
||
"\n",
|
||
"print(f\"Datapoint: {datapoint}\")\n",
|
||
"print('*' * 40)\n",
|
||
"print(f\"Title: {title}\")\n",
|
||
"print('*' * 40)\n",
|
||
"print(f\"Detail: {details}\")\n",
|
||
"print('*' * 40)\n",
|
||
"print(f\"Price: ${price}\")\n",
|
||
"print('*' * 40)\n",
|
||
"print(f\"Category: {category}\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "e05ed6e4-1cbc-46a4-be2f-4832b99e5ec3",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# The prompt that will be used during training\n",
|
||
"print(items[0].prompt)\n",
|
||
"print('*' * 40)\n",
|
||
"# The prompt that will be used during testing\n",
|
||
"print(items[0].test_prompt())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "f66e714d-2bae-458e-a0f6-1ce78d0696b3",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 📊 Data Visualization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "dd50ae2c-b34e-4be7-bd74-62055e4d5b2d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"plt.figure(figsize=(15, 6))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "c736b038-2dcd-40b9-8ae9-d17271f1ff81",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Plot the distribution of token counts\n",
|
||
"\n",
|
||
"tokens = [item.token_count for item in items]\n",
|
||
"plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
|
||
"plt.xlabel('Length (tokens)')\n",
|
||
"plt.ylabel('Count')\n",
|
||
"plt.hist(tokens, rwidth=0.7, color=\"blue\", bins=range(0, 300, 10))\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {
|
||
"image.png": {
|
||
"image/png": ""
|
||
}
|
||
},
|
||
"cell_type": "markdown",
|
||
"id": "940ba698",
|
||
"metadata": {},
|
||
"source": [
|
||
""
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "da33633a-7ad5-479c-8dff-f7a7a149d49c",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Plot the distribution of prices\n",
|
||
"\n",
|
||
"prices = [item.price for item in items]\n",
|
||
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
|
||
"plt.xlabel('Price ($)')\n",
|
||
"plt.ylabel('Count')\n",
|
||
"plt.hist(prices, rwidth=0.7, color=\"blueviolet\", bins=range(0, 1000, 10))\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d0f494d7-349e-4878-929c-075ac97c6b6d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Plot the distribution of categories\n",
|
||
"\n",
|
||
"category_counts = Counter()\n",
|
||
"for item in items:\n",
|
||
" category_counts[item.category]+=1\n",
|
||
"\n",
|
||
"categories = category_counts.keys()\n",
|
||
"counts = [category_counts[category] for category in categories]\n",
|
||
"\n",
|
||
"# Bar chart by category\n",
|
||
"plt.bar(categories, counts, color=\"goldenrod\")\n",
|
||
"plt.title('How many items in each category')\n",
|
||
"plt.xlabel('Categories')\n",
|
||
"plt.ylabel('Count')\n",
|
||
"\n",
|
||
"plt.xticks(rotation=30, ha='right')\n",
|
||
"\n",
|
||
"# Add value labels on top of each bar\n",
|
||
"for i, v in enumerate(counts):\n",
|
||
" plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
|
||
"\n",
|
||
"# Display the chart\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d4fe384d-049b-4742-98e5-20d162db5151",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 🎯 Data Sampling\n",
|
||
"\n",
|
||
"We sample to keep the dataset balanced but rich:\n",
|
||
"- 🎯 Keep all items if price ≥ $240 or group size ≤ 1200\n",
|
||
"- 🎯 For large groups, randomly sample 1200 items, favoring rare categories\n",
|
||
"\n",
|
||
"✅ This keeps valuable high-price items and avoids overrepresented classes"
|
||
]
|
||
},
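To make "favoring rare categories" concrete, here is a toy sketch of weighted sampling without replacement. The group of ten items and the sample size of four are made up for illustration; only the 1-versus-5 weighting mirrors the notebook, and `default_rng` is used here instead of the global NumPy seed.

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy price group: 8 items from an over-represented category, 2 from a rare one
categories = ["Automotive"] * 8 + ["Appliances"] * 2
heavy = "Automotive"

# Weight 1 for the heavy category, 5 for everything else, then normalise to probabilities
weights = np.array([1 if c == heavy else 5 for c in categories], dtype=float)
probs = weights / weights.sum()

# Draw 4 distinct indices; replace=False guarantees no item is picked twice
picked = rng.choice(len(categories), size=4, replace=False, p=probs)

print("Probability of one heavy-category item:", probs[0])
print("Probability of one rare-category item: ", probs[-1])
print("Picked categories:", [categories[i] for i in picked])
```

Each rare-category item starts out five times as likely to be drawn as each heavy-category item (5/18 versus 1/18 here), which is what pulls the sampled category mix toward balance.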
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "20330037-744d-4834-8ece-413a8dbe2030",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"HEAVY_DATASET = \"Automative\"\n",
|
||
"\n",
|
||
"# Group items by rounded price\n",
|
||
"# Slots is a dictionary where the keys are rounded prices and the values are lists of items that have that rounded price\n",
|
||
"slots = defaultdict(list)\n",
|
||
"for item in items:\n",
|
||
" slots[round(item.price)].append(item)\n",
|
||
"\n",
|
||
"np.random.seed(42) # Set random seed for reproducibility\n",
|
||
"sample = [] # Final collection of items after our sampling process completes\n",
|
||
"\n",
|
||
"# Sampling loop\n",
|
||
"for price, items_at_price in slots.items():\n",
|
||
"\n",
|
||
" # Take all items if price ≥ 240 or small group\n",
|
||
" if price >= 240 or len(items_at_price) <= 1200:\n",
|
||
" sample.extend(items_at_price)\n",
|
||
"\n",
|
||
" # Otherwise sample 1200 items with weights\n",
|
||
" else:\n",
|
||
"\n",
|
||
" # Weight: 1 for toys, 5 for others\n",
|
||
" weights = [1 if item.category == HEAVY_DATASET else 5 for item in items_at_price]\n",
|
||
" weights = np.array(weights) / sum(weights)\n",
|
||
"\n",
|
||
" indices = np.random.choice(len(items_at_price), 1200, False, weights) # False = don't pick the same index twice\n",
|
||
" sample.extend([items_at_price[i] for i in indices])\n",
|
||
"\n",
|
||
"print(f\"There are {len(sample):,} items in the sample\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "21aed337-6f15-48e4-8155-70551ed1d5e0",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Plot the distribution of prices in the sample\n",
|
||
"\n",
|
||
"prices = [float(item.price) for item in sample]\n",
|
||
"plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
|
||
"plt.xlabel('Price ($)')\n",
|
||
"plt.ylabel('Count')\n",
|
||
"plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "08a7353e-2752-4493-bb0b-6057d1eab16d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Plot the distribution of categories in the sample\n",
|
||
"\n",
|
||
"category_counts = Counter()\n",
|
||
"for item in sample:\n",
|
||
" category_counts[item.category]+=1\n",
|
||
"\n",
|
||
"categories = category_counts.keys()\n",
|
||
"counts = [category_counts[category] for category in categories]\n",
|
||
"\n",
|
||
"# Create bar chart\n",
|
||
"plt.bar(categories, counts, color=\"pink\")\n",
|
||
"\n",
|
||
"# Customize the chart\n",
|
||
"plt.title('How many in each category')\n",
|
||
"plt.xlabel('Categories')\n",
|
||
"plt.ylabel('Count')\n",
|
||
"\n",
|
||
"plt.xticks(rotation=30, ha='right')\n",
|
||
"\n",
|
||
"# Add value labels on top of each bar\n",
|
||
"for i, v in enumerate(counts):\n",
|
||
" plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
|
||
"\n",
|
||
"# Display the chart\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9bdb0c58-24e0-4ab5-8a28-2136b53ab915",
|
||
"metadata": {},
|
||
"source": [
|
||
"The HEAVY_DATASET still in the lead, but improved somewhat"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "4ce8ff80-cd19-4c3b-965f-ce6af8ee347d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Create pie chart\n",
|
||
"\n",
|
||
"fig, ax = plt.subplots(figsize=(8, 8))\n",
|
||
"wedges, texts, autotexts = ax.pie(\n",
|
||
" counts,\n",
|
||
" # labels=categories,\n",
|
||
" autopct='%1.0f%%',\n",
|
||
" startangle=90,\n",
|
||
" pctdistance=0.85,\n",
|
||
" labeldistance=1.1\n",
|
||
")\n",
|
||
"ax.legend(wedges, categories, title=\"Categories\", loc=\"lower center\", bbox_to_anchor=(0.5, 1.15), ncol=3)\n",
|
||
"\n",
|
||
"# Draw donut center\n",
|
||
"centre_circle = plt.Circle((0, 0), 0.70, fc='white')\n",
|
||
"fig.gca().add_artist(centre_circle)\n",
|
||
"\n",
|
||
"# Add center label\n",
|
||
"ax.text(0, 0, \"Categories\", ha='center', va='center', fontsize=14, fontweight='bold')\n",
|
||
"\n",
|
||
"# Equal aspect ratio\n",
|
||
"plt.axis('equal')\n",
|
||
"plt.title(\"Category Distribution\")\n",
|
||
"plt.tight_layout()\n",
|
||
"plt.show()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "acbc6beb-fab4-49ab-bc7e-243638c1fa99",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# How does the price vary with the character count of the prompt?\n",
|
||
"\n",
|
||
"sizes = [len(item.prompt) for item in sample]\n",
|
||
"prices = [item.price for item in sample]\n",
|
||
"\n",
|
||
"# Create the scatter plot\n",
|
||
"plt.scatter(sizes, prices, s=0.2, color=\"red\")\n",
|
||
"\n",
|
||
"# Add labels and title\n",
|
||
"plt.xlabel('Size')\n",
|
||
"plt.ylabel('Price')\n",
|
||
"plt.title('Is there a simple correlation between prompt length and item price?')\n",
|
||
"\n",
|
||
"# Display the plot\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "76b060a4-0b8d-495c-bb96-28cb7b7ec623",
|
||
"metadata": {},
|
||
"source": [
|
||
"There is no strong or simple correlation between prompt length and item price.\n",
|
||
"\n",
|
||
"In other words, longer prompts don’t clearly mean higher prices, and vice versa."
|
||
]
|
||
},
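The visual impression from the scatter plot can be backed by a number. The sketch below computes the Pearson correlation coefficient by hand; the helper `pearson` and the toy lists are illustrative assumptions, not part of the notebook. On the real data you would pass the notebook's `sizes` and `prices` lists.

```python
from math import sqrt


def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / sqrt(var_x * var_y)


# Toy data: price rises in lockstep with prompt length
r_linear = pearson([100, 200, 300, 400], [10.0, 20.0, 30.0, 40.0])

# Toy data: price shows no linear trend with prompt length
r_flat = pearson([100, 200, 300, 400], [20.0, 10.0, 10.0, 20.0])

print(f"lockstep: {r_linear:.2f}, no trend: {r_flat:.2f}")
```

A coefficient near 0 would support the "no simple correlation" reading, while values near +1 or -1 would contradict it. Pearson only measures linear association, so it complements the plot rather than replacing it.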
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "0f33211c-3548-4a21-990b-21aa55089186",
|
||
"metadata": {},
|
||
"source": [
|
||
"## ✅ Final Check Before Training"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "be8d0c68-ac6e-4a4d-a6c7-64e9c6763ec4",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Ensure the price label is correctly placed by the end of the prompt\n",
|
||
"\n",
|
||
"def report(item):\n",
|
||
" prompt = item.prompt\n",
|
||
" tokens = Item.tokenizer.encode(item.prompt)\n",
|
||
" print(prompt)\n",
|
||
" print(tokens[-6:])\n",
|
||
" print(Item.tokenizer.batch_decode(tokens[-6:]))\n",
|
||
"\n",
|
||
"report(sample[50])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "656d523d-8297-4d75-a973-a7e5517d21bc",
|
||
"metadata": {},
|
||
"source": [
|
||
"LLaMA and GPT-4o both tokenize numbers from 1 to 999 as a single token, while models like Qwen2, Gemma, and Phi-3 split them into multiple tokens. This helps keep prices compact in our prompts — useful for our project, though not strictly required."
|
||
]
|
||
},
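One way to check the single-token claim for a given tokenizer is to encode every integer in the price range and count the tokens. The sketch below is tokenizer-agnostic: `multi_token_numbers` is a hypothetical helper that accepts any `encode` callable, and `toy_char_encode` is a deliberately naive stand-in (one token per character) so the example runs without downloading a model.

```python
def multi_token_numbers(encode, numbers):
    """Return the numbers whose string form encodes to more than one token."""
    return [n for n in numbers if len(encode(str(n))) > 1]


def toy_char_encode(text):
    """Stand-in tokenizer for illustration only: one token per character."""
    return list(text)


# With the toy tokenizer, every number with two or more digits is split
split = multi_token_numbers(toy_char_encode, range(1, 1000))
print(f"{len(split)} of 999 numbers need more than one token")
```

With a real Hugging Face tokenizer, `encode` could be something like `lambda s: Item.tokenizer.encode(s, add_special_tokens=False)`, and an empty result would mean every price from 1 to 999 fits in one token. Tokenization can depend on surrounding text such as a leading space or `$` sign, so treat this as a quick sanity check rather than a guarantee.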
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e36254ba-d20f-44ad-b991-1f1f3cdc4aaa",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 📦 Creating Train/Test Datasets"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "5cfb5092-c38d-4c14-8dd0-e1d97c06d7f6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"random.seed(42)\n",
|
||
"random.shuffle(sample)\n",
|
||
"train = sample[:400_000]\n",
|
||
"test = sample[400_000:402_000]\n",
|
||
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "2f084822-e489-4946-8cf5-f5b0ebd7a23c",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"print(train[0].prompt)\n",
|
||
"print('*' * 40)\n",
|
||
"print(test[0].test_prompt())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d49a08ce-dd41-4af8-82f6-4701628e8152",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Plot the distribution of prices in the first 250 test points\n",
|
||
"\n",
|
||
"prices = [float(item.price) for item in test[:250]]\n",
|
||
"plt.figure(figsize=(15, 6))\n",
|
||
"plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
|
||
"plt.xlabel('Price ($)')\n",
|
||
"plt.ylabel('Count')\n",
|
||
"plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "0c581439-93f2-422a-924f-fd6c58ef8693",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Extract prompts and prices\n",
|
||
"train_prompts = [item.prompt for item in train]\n",
|
||
"train_prices = [item.price for item in train]\n",
|
||
"test_prompts = [item.test_prompt() for item in test]\n",
|
||
"test_prices = [item.price for item in test]\n",
|
||
"\n",
|
||
"# Create Hugging Face datasets\n",
|
||
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
|
||
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
|
||
"dataset = DatasetDict({\n",
|
||
" \"train\": train_dataset,\n",
|
||
" \"test\": test_dataset\n",
|
||
"})\n",
|
||
"\n",
|
||
"# Save full Item objects\n",
|
||
"os.makedirs(\"data\", exist_ok=True) # Make sure the folder exists\n",
|
||
"\n",
|
||
"# Save full Item objects to the folder\n",
|
||
"with open('data/train.pkl', 'wb') as file:\n",
|
||
" pickle.dump(train, file)\n",
|
||
"\n",
|
||
"with open('data/test.pkl', 'wb') as file:\n",
|
||
" pickle.dump(test, file)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "3914d029-350e-4140-a31f-e931fa289a41",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Push to the Hugging Face Hub\n",
|
||
"USERNAME = \"lisekarimi\" # 🔧 Replace with your Hugging Face username\n",
|
||
"DATASET_NAME = f\"{USERNAME}/pricer-data\"\n",
|
||
"\n",
|
||
"dataset.push_to_hub(DATASET_NAME, private=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "3d8f3b33-41f8-4ee6-96ed-27677ffc8ec4",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Note:** \n",
|
||
"- The dataset `pricer-data` on Hugging Face only contains `text` and `price`:\n",
|
||
"\n",
|
||
"\n",
|
||
"{\n",
|
||
" \"text\": \"How much does this cost...Price is $175.00\",\n",
|
||
" \"price\": 175.0\n",
|
||
"}\n",
|
||
"\n",
|
||
"- Full `Item` objects (with metadata) are available in `train.pkl` and `test.pkl`:\n",
|
||
"\n",
|
||
"Item(data={\n",
|
||
" \"title\": str,\n",
|
||
" \"description\": list[str],\n",
|
||
" \"features\": list[str],\n",
|
||
" \"details\": str\n",
|
||
"}, price=float)\n",
|
||
"\n",
|
||
"\n",
|
||
"Now, it’s time to move on to **Part 2: Model Benchmarking – Traditional ML vs Frontier LLMs.**\n",
|
||
"\n",
|
||
"🔜 See you in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part2_tradml_vs_frontier.ipynb)"
|
||
]
|
||
}
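For later notebooks, the pickled files can be read back with `pickle.load`. The sketch below shows the round trip using plain dictionaries and a temporary folder as stand-ins; the real `train.pkl` and `test.pkl` hold `Item` objects, so the `Item` class must be importable (for example via `helpers.items`) before unpickling them.

```python
import os
import pickle
import tempfile

# Stand-in records for illustration; the real files hold Item objects, not dicts
train = [
    {"title": "Desk lamp", "price": 34.5},
    {"title": "USB cable", "price": 7.99},
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.pkl")

    # Write the list to disk in binary mode
    with open(path, "wb") as file:
        pickle.dump(train, file)

    # Read it back; the result is an equal, independent copy
    with open(path, "rb") as file:
        restored = pickle.load(file)

print(f"Restored {len(restored):,} records")
```

Only unpickle files you created yourself or otherwise trust, since loading a pickle can execute arbitrary code.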
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.7"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|