From 6a1717a7b397ab62ce56676937310bf226f3f3f8 Mon Sep 17 00:00:00 2001 From: abdoulrasheed Date: Thu, 30 Oct 2025 03:52:03 +0000 Subject: [PATCH] W6 --- .../abdoul/week_six_exercise.ipynb | 1110 +++++++++++++++++ 1 file changed, 1110 insertions(+) create mode 100644 community-contributions/abdoul/week_six_exercise.ipynb diff --git a/community-contributions/abdoul/week_six_exercise.ipynb b/community-contributions/abdoul/week_six_exercise.ipynb new file mode 100644 index 0000000..d7382bc --- /dev/null +++ b/community-contributions/abdoul/week_six_exercise.ipynb @@ -0,0 +1,1110 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3", + "metadata": { + "id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3" + }, + "source": [ + "# Fine Tuning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "n9sehdR5Cv6A", + "metadata": { + "id": "n9sehdR5Cv6A" + }, + "outputs": [], + "source": [ + "!pip install gensim\n", + "!pip install --upgrade datasets==3.6.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9bf9a6a", + "metadata": { + "id": "b9bf9a6a" + }, + "outputs": [], + "source": [ + "import os\n", + "import math\n", + "import random\n", + "import json\n", + "import pickle\n", + "import re\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from pathlib import Path\n", + "from openai import OpenAI\n", + "from datetime import datetime\n", + "from dotenv import load_dotenv\n", + "import matplotlib.pyplot as plt\n", + "from huggingface_hub import login\n", + "from sklearn.svm import LinearSVR\n", + "from gensim.models import Word2Vec\n", + "from IPython.display import display\n", + "from transformers import AutoTokenizer\n", + "from gensim.utils import simple_preprocess\n", + "from collections import Counter, defaultdict\n", + "from sklearn.linear_model import LinearRegression\n", + "from sklearn.ensemble import RandomForestRegressor\n", + "from concurrent.futures import ProcessPoolExecutor\n", + "from datasets import Dataset, DatasetDict, load_dataset\n", + "from sklearn.metrics import mean_squared_error, r2_score\n", + "from sklearn.feature_extraction.text import CountVectorizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "261d16fa", + "metadata": { + "id": "261d16fa" + }, + "outputs": [], + "source": [ + "load_dotenv(override=True)\n", + "openai_key = os.environ.get(\"OPENAI_API_KEY\")\n", + "\n", + "#anthropic_key = os.environ.get(\"ANTHROPIC_API_KEY\")\n", + "\n", + "hf_token = os.environ.get(\"HF_TOKEN\")\n", + "print(hf_token)\n", + "\n", + "if hf_token:\n", + " print(\"Loggin in...\")\n", + " login(hf_token, add_to_git_credential=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cdfe762-3200-4459-981e-0ded7c14b4de", + "metadata": { + "id": "2cdfe762-3200-4459-981e-0ded7c14b4de" + }, + "outputs": [], + "source": [ + "GREEN = \"\\033[92m\"\n", + "YELLOW = \"\\033[93m\"\n", + "RED = \"\\033[91m\"\n", + "RESET = \"\\033[0m\"\n", + "COLOR_MAP = {\"red\": RED, \"orange\": YELLOW, \"green\": GREEN}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0832e74b-2779-4822-8e6c-4361ec165c7f", + "metadata": { + "id": "0832e74b-2779-4822-8e6c-4361ec165c7f" + }, + "outputs": [], + "source": [ + "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n", + "\n", + "MIN_CHARS = 300\n", + "MIN_TOKENS = 150\n", + "MAX_TOKENS = 160\n", + "CEILING_CHARS = MAX_TOKENS * 7\n", + "\n", + "class Item:\n", + " tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n", + " PREFIX = \"Price is $\"\n", + " QUESTION = \"How much does this cost to the nearest dollar?\"\n", + " REMOVALS = ['\"Batteries Included?\": \"No\"', '\"Batteries Included?\": \"Yes\"', '\"Batteries Required?\": \"No\"', '\"Batteries Required?\": \"Yes\"', \"By Manufacturer\", \"Item\", \"Date First\", \"Package\", \":\", \"Number of\", \"Best Sellers\", \"Number\", \"Product \"]\n", + "\n", + " def __init__(self, data, price):\n", + " self.title = data[\"title\"]\n", + " self.price = price\n", + " self.category = data.get(\"category\", \"Unknown\")\n", + " self.token_count = 0\n", + " self.details = None\n", + " self.prompt = None\n", + " self.include = False\n", + " self.parse(data)\n", + "\n", + " def scrub_details(self):\n", + " details = self.details\n", + "\n", + " for remove in self.REMOVALS:\n", + " details = details.replace(remove, \"\")\n", + "\n", + " return details\n", + "\n", + " def scrub(self, text):\n", + " text = re.sub(r'[:\\[\\]\"{}【】\\s]+', ' ', text).strip()\n", + " text = text.replace(\" ,\", \",\").replace(\",,,\",\",\").replace(\",,\",\",\")\n", + " words = text.split(\" \")\n", + " select = [word for word in words if len(word) < 7 or not any(char.isdigit() for char in word)]\n", + " return \" \".join(select)\n", + "\n", + " def parse(self, data):\n", + " contents = '\\n'.join(data.get(\"description\", []))\n", + "\n", + " if contents:\n", + " contents += '\\n'\n", + "\n", + " features = '\\n'.join(data.get(\"features\", []))\n", + " if features:\n", + " contents += features + '\\n'\n", + "\n", + " self.details = data.get(\"details\")\n", + " if self.details:\n", + " contents += self.scrub_details() + '\\n'\n", + "\n", + " if len(contents) > MIN_CHARS:\n", + " contents = contents[:CEILING_CHARS]\n", + " text = f\"{self.scrub(self.title)}\\n{self.scrub(contents)}\"\n", + " tokens = self.tokenizer.encode(text, add_special_tokens=False)\n", + "\n", + " if len(tokens) > MIN_TOKENS:\n", + " tokens = tokens[:MAX_TOKENS]\n", + " text = self.tokenizer.decode(tokens)\n", + " self.make_prompt(text)\n", + " self.include = True\n", + "\n", + " def make_prompt(self, text):\n", + " self.prompt = f\"{self.QUESTION}\\n\\n{text}\\n\\n\"\n", + " self.prompt += f\"{self.PREFIX}{str(round(self.price))}.00\"\n", + " self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))\n", + "\n", + " def test_prompt(self):\n", + " return self.prompt.split(self.PREFIX)[0] + self.PREFIX\n", + "\n", + " def __repr__(self):\n", + " return f\"<{self.title} = ${self.price}>\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa478d70", + "metadata": { + "id": "aa478d70" + }, + "outputs": [], + "source": [ + "MIN_PRICE = 0.5\n", + "CHUNK_SIZE = 1000\n", + "MAX_PRICE = 999.49\n", + "\n", + "class ItemLoader:\n", + " def __init__(self, name):\n", + " self.name = name\n", + " self.dataset = None\n", + "\n", + " def from_datapoint(self, datapoint):\n", + " try:\n", + " price_str = datapoint.get(\"price\")\n", + " if price_str:\n", + " price = float(price_str)\n", + " if MIN_PRICE <= price <= MAX_PRICE:\n", + " item = Item(datapoint, price)\n", + " if item.include:\n", + " return item\n", + " except ValueError:\n", + " return None\n", + "\n", + " def from_chunk(self, chunk):\n", + " batch = []\n", + " for datapoint in chunk:\n", + " item = self.from_datapoint(datapoint)\n", + "\n", + " if item:\n", + " batch.append(item)\n", + "\n", + " return batch\n", + "\n", + " def chunk_generator(self):\n", + " size = len(self.dataset)\n", + " for start in range(0, size, CHUNK_SIZE):\n", + " yield self.dataset.select(range(start, min(start + CHUNK_SIZE, size)))\n", + "\n", + " def load_in_parallel(self, workers):\n", + " results = []\n", + " chunk_count = (len(self.dataset) // CHUNK_SIZE) + 1\n", + "\n", + " with ProcessPoolExecutor(max_workers=workers) as pool:\n", + " for batch in tqdm(pool.map(self.from_chunk, self.chunk_generator()), total=chunk_count):\n", + " results.extend(batch)\n", + "\n", + " for result in results:\n", + " result.category = self.name\n", + "\n", + " return results\n", + "\n", + " def load(self, workers=8):\n", + " self.dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", f\"raw_meta_{self.name}\", split=\"full\", trust_remote_code=True)\n", + " start = datetime.now()\n", + " print(f\"Loading {self.dataset}\")\n", + " results = self.load_in_parallel(workers)\n", + " duration = (datetime.now() - start).total_seconds() / 60\n", + " print(f\"Completed {self.name} with {len(results):,} items in {duration:.1f} mins\")\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6XywRUiUro69", + "metadata": { + "id": "6XywRUiUro69" + }, + "outputs": [], + "source": [ + "class Tester:\n", + "\n", + " def __init__(self, predictor, data, title=None, size=250):\n", + " self.predictor = predictor\n", + " self.data = data\n", + " self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n", + " self.size = size\n", + " self.guesses = []\n", + " self.truths = []\n", + " self.errors = []\n", + " self.sles = []\n", + " self.colors = []\n", + "\n", + " def color_for(self, error, truth):\n", + " if error < 40 or error / truth < 0.2:\n", + " return \"green\"\n", + "\n", + " if error < 80 or error / truth < 0.4:\n", + " return \"orange\"\n", + "\n", + " return \"red\"\n", + "\n", + " def run_datapoint(self, index):\n", + " datapoint = self.data[index]\n", + " guess = self.predictor(datapoint)\n", + " truth = datapoint.price\n", + " error = abs(guess - truth)\n", + " log_error = math.log(truth + 1) - math.log(guess + 1)\n", + " sle = log_error ** 2\n", + " color = self.color_for(error, truth)\n", + " name = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40] + \"...\"\n", + " self.guesses.append(guess)\n", + " self.truths.append(truth)\n", + " self.errors.append(error)\n", + " self.sles.append(sle)\n", + " self.colors.append(color)\n", + " print(f\"{COLOR_MAP[color]}{index + 1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {name}{RESET}\")\n", + "\n", + " def chart(self, title):\n", + " plt.figure(figsize=(12, 8))\n", + " max_val = max(max(self.truths), max(self.guesses))\n", + " plt.plot([0, max_val], [0, max_val], color=\"deepskyblue\", lw=2, alpha=0.6)\n", + " plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n", + " plt.xlabel(\"Ground Truth\")\n", + " plt.ylabel(\"Model Estimate\")\n", + " plt.xlim(0, max_val)\n", + " plt.ylim(0, max_val)\n", + " plt.title(title)\n", + " plt.show()\n", + "\n", + " def report(self):\n", + " average_error = sum(self.errors) / self.size\n", + " rmsle = math.sqrt(sum(self.sles) / self.size)\n", + " hits = sum(1 for color in self.colors if color == \"green\")\n", + " title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits / self.size * 100:.1f}%\"\n", + " self.chart(title)\n", + "\n", + " def run(self):\n", + " for index in range(self.size):\n", + " self.run_datapoint(index)\n", + "\n", + " self.report()\n", + "\n", + " @classmethod\n", + " def test(cls, function, data):\n", + " cls(function, data).run()\n", + "\n", + "def get_price(s):\n", + " s = s.replace(\"$\", \"\").replace(\",\", \"\")\n", + " match = re.search(r\"[-+]?\\d*\\.?\\d+\", s)\n", + " return float(match.group()) if match else 0.0" + ] + }, + { + "cell_type": "markdown", + "id": "9856f570", + "metadata": { + "id": "9856f570" + }, + "source": [ + "## Data" + ] + }, + { + "cell_type": "markdown", + "id": "3XTxVhq0xC8Z", + "metadata": { + "id": "3XTxVhq0xC8Z" + }, + "source": [ + "### Load Catalogs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22", + "metadata": { + "id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22" + }, + "outputs": [], + "source": [ + "catalog_labels = [\n", + " \"All_Beauty\",\n", + " # \"Automotive\",\n", + " # \"Electronics\",\n", + " # \"Office_Products\",\n", + " # \"Tools_and_Home_Improvement\",\n", + " # \"Cell_Phones_and_Accessories\",\n", + " # \"Toys_and_Games\",\n", + " \"Appliances\",\n", + " \"Musical_Instruments\",\n", + " \"Software\",\n", + " \"Handmade_Products\"\n", + "]\n", + "curated_pool = []\n", + "\n", + "for label in catalog_labels:\n", + " print(\"Loading \" + label)\n", + " loader = ItemLoader(label)\n", + " curated_pool.extend(loader.load())\n", + "\n", + "print(f\"Total curated items: {len(curated_pool):,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d", + "metadata": { + "id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d" + }, + "outputs": [], + "source": [ + "price_series = [item.price for item in curated_pool]\n", + "token_series = [item.token_count for item in curated_pool]\n", + "category_tally = Counter(item.category for item in curated_pool)\n", + "summary_frame = pd.DataFrame({\"price\": price_series, \"tokens\": token_series})\n", + "\n", + "display(summary_frame.describe())\n", + "display(pd.DataFrame.from_dict(category_tally, orient=\"index\", columns=[\"count\"]).sort_values(\"count\", ascending=False))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b", + "metadata": { + "id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b" + }, + "outputs": [], + "source": [ + "price_slots = defaultdict(list)\n", + "for item in curated_pool:\n", + " key = round(item.price)\n", + " if 1 <= key <= 999:\n", + " price_slots[key].append(item)\n", + "\n", + "slot_counts = {k: len(v) for k, v in price_slots.items()}\n", + "print(f\"Slots populated: {len(slot_counts)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585", + "metadata": { + "id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585" + }, + "outputs": [], + "source": [ + "random.seed(123)\n", + "np.random.seed(123)\n", + "balanced_bundle = []\n", + "\n", + "for price in range(1, 1000):\n", + " bucket = price_slots.get(price, [])\n", + "\n", + " if price >= 240:\n", + " balanced_bundle.extend(bucket)\n", + "\n", + " elif len(bucket) <= 1200:\n", + " balanced_bundle.extend(bucket)\n", + "\n", + " else:\n", + " weights = np.array([1 if item.category == \"Automotive\" else 5 for item in bucket], dtype=float)\n", + " weights /= weights.sum()\n", + " indices = np.random.choice(len(bucket), size=1200, replace=False, p=weights)\n", + " for idx in indices:\n", + " balanced_bundle.append(bucket[idx])\n", + "\n", + "print(f\"Balanced bundle size: {len(balanced_bundle):,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a506f42c-81c0-4198-bc0b-1e0653620be8", + "metadata": { + "id": "a506f42c-81c0-4198-bc0b-1e0653620be8" + }, + "outputs": [], + "source": [ + "bundle_prices = [item.price for item in balanced_bundle]\n", + "bundle_tokens = [item.token_count for item in balanced_bundle]\n", + "bundle_categories = Counter(item.category for item in balanced_bundle)\n", + "display(pd.Series(bundle_prices).describe())\n", + "display(pd.DataFrame.from_dict(bundle_categories, orient=\"index\", columns=[\"count\"]).sort_values(\"count\", ascending=False))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5842ace6-332d-46da-a853-5ea5a2a1cf88", + "metadata": { + "id": "5842ace6-332d-46da-a853-5ea5a2a1cf88" + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(12, 5))\n", + "plt.hist(bundle_prices, bins=range(0, 1000, 10), color=\"midnightblue\", rwidth=0.8)\n", + "plt.xlabel(\"Price\")\n", + "plt.ylabel(\"Count\")\n", + "plt.figure(figsize=(12, 5))\n", + "plt.hist(bundle_tokens, bins=range(0, 300, 10), color=\"forestgreen\", rwidth=0.8)\n", + "plt.xlabel(\"Tokens\")\n", + "plt.ylabel(\"Count\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42ee0099-0d2a-4331-a01c-3462363a6987", + "metadata": { + "id": "42ee0099-0d2a-4331-a01c-3462363a6987" + }, + "outputs": [], + "source": [ + "random.seed(123)\n", + "random.shuffle(balanced_bundle)\n", + "test_target = min(2000, max(1, len(balanced_bundle) // 20))\n", + "train_target = min(400_000, len(balanced_bundle) - test_target)\n", + "train_items = balanced_bundle[:train_target]\n", + "test_items = balanced_bundle[train_target:train_target + test_target]\n", + "print(f\"Training set: {len(train_items):,}\")\n", + "print(f\"Test set: {len(test_items):,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b", + "metadata": { + "id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b" + }, + "outputs": [], + "source": [ + "train_prompts = [item.prompt for item in train_items]\n", + "train_prices = [item.price for item in train_items]\n", + "test_prompts = [item.test_prompt() for item in test_items]\n", + "test_prices = [item.price for item in test_items]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31ca360d-5fc6-487a-91c6-d61758b2ff16", + "metadata": { + "id": "31ca360d-5fc6-487a-91c6-d61758b2ff16" + }, + "outputs": [], + "source": [ + "train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n", + "test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n", + "pricing_dataset = DatasetDict({\"train\": train_dataset, \"test\": test_dataset})" + ] + }, + { + "cell_type": "markdown", + "id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8", + "metadata": { + "id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8" + }, + "source": [ + "### Persist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c", + "metadata": { + "id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c" + }, + "outputs": [], + "source": [ + "storage_dir = Path(\"data\")\n", + "storage_dir.mkdir(exist_ok=True)\n", + "\n", + "with open(storage_dir / \"balanced_train.pkl\", \"wb\") as f:\n", + " pickle.dump(train_items, f)\n", + "\n", + "with open(storage_dir / \"balanced_test.pkl\", \"wb\") as f:\n", + " pickle.dump(test_items, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2164662-9bc9-4a66-9e4e-a8a955a45753", + "metadata": { + "id": "b2164662-9bc9-4a66-9e4e-a8a955a45753" + }, + "outputs": [], + "source": [ + "pricing_dataset[\"train\"].to_parquet(storage_dir / \"balanced_train.parquet\")\n", + "pricing_dataset[\"test\"].to_parquet(storage_dir / \"balanced_test.parquet\")" + ] + }, + { + "cell_type": "markdown", + "id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013", + "metadata": { + "id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013" + }, + "source": [ + "## Baselines" + ] + }, + { + "cell_type": "markdown", + "id": "qX0c_prppnyZ", + "metadata": { + "id": "qX0c_prppnyZ" + }, + "source": [ + "### Stochastic Anchor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7323252b-db50-4b8a-a7fc-8504bb3d218b", + "metadata": { + "id": "7323252b-db50-4b8a-a7fc-8504bb3d218b" + }, + "outputs": [], + "source": [ + "def stochastic_anchor(item):\n", + " return random.randrange(1, 1000)\n", + "\n", + "random.seed(123)\n", + "Tester.test(stochastic_anchor, test_items[:250])" + ] + }, + { + "cell_type": "markdown", + "id": "O0xVXRXkp9sQ", + "metadata": { + "id": "O0xVXRXkp9sQ" + }, + "source": [ + "### Global Mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a932b0e-ba6e-45d2-8436-b740c3681272", + "metadata": { + "id": "6a932b0e-ba6e-45d2-8436-b740c3681272" + }, + "outputs": [], + "source": [ + "train_price_values = [item.price for item in train_items]\n", + "global_mean_price = sum(train_price_values) / len(train_price_values)\n", + "\n", + "def global_mean_estimator(item):\n", + " return global_mean_price\n", + "\n", + "Tester.test(global_mean_estimator, test_items[:250])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3410bd4-98e4-42a6-a702-4423cfd034b4", + "metadata": { + "id": "d3410bd4-98e4-42a6-a702-4423cfd034b4" + }, + "outputs": [], + "source": [ + "def parse_features(raw):\n", + " if not raw:\n", + " return {}\n", + " try:\n", + " return json.loads(raw)\n", + " except json.JSONDecodeError:\n", + " return {}\n", + "for item in train_items:\n", + " item.features = parse_features(item.details)\n", + "for item in test_items:\n", + " item.features = parse_features(item.details)" + ] + }, + { + "cell_type": "markdown", + "id": "44537051-7b4e-4b8c-95a7-a989ea51e517", + "metadata": { + "id": "44537051-7b4e-4b8c-95a7-a989ea51e517" + }, + "source": [ + "### Feature Engineering" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec", + "metadata": { + "id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec" + }, + "outputs": [], + "source": [ + "def infer_weight(item):\n", + " payload = item.features.get(\"Item Weight\")\n", + " if not payload:\n", + " return None\n", + "\n", + " parts = payload.split(\" \")\n", + " amount = float(parts[0])\n", + " unit = parts[1].lower()\n", + "\n", + " if unit == \"pounds\":\n", + " return amount\n", + "\n", + " if unit == \"ounces\":\n", + " return amount / 16\n", + "\n", + " if unit == \"grams\":\n", + " return amount / 453.592\n", + "\n", + " if unit == \"milligrams\":\n", + " return amount / 453592\n", + "\n", + " if unit == \"kilograms\":\n", + " return amount / 0.453592\n", + "\n", + " if unit == \"hundredths\" and len(parts) > 2 and parts[2].lower() == \"pounds\":\n", + " return amount / 100\n", + "\n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d7b6f35-890c-4227-8990-6b62694a332d", + "metadata": { + "id": "4d7b6f35-890c-4227-8990-6b62694a332d" + }, + "outputs": [], + "source": [ + "def infer_rank(item):\n", + " payload = item.features.get(\"Best Sellers Rank\")\n", + " if not payload:\n", + " return None\n", + "\n", + " values = list(payload.values()) if isinstance(payload, dict) else []\n", + " if not values:\n", + " return None\n", + "\n", + " return sum(values) / len(values)\n", + "\n", + "top_brands = {\"nvidea\",\"hp\",\"dell\",\"lenovo\",\"samsung\",\"asus\",\"sony\",\"canon\",\"apple\",\"intel\"}\n", + "\n", + "def is_top_brand(item):\n", + " brand = item.features.get(\"Brand\")\n", + " return 1 if brand and brand.lower() in top_brands else 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a6e06f3-614f-4687-bd43-9ac03aaface8", + "metadata": { + "id": "1a6e06f3-614f-4687-bd43-9ac03aaface8" + }, + "outputs": [], + "source": [ + "train_weights = [infer_weight(item) for item in train_items]\n", + "train_weights = [value for value in train_weights if value is not None]\n", + "average_weight = sum(train_weights) / len(train_weights) if train_weights else 1.0\n", + "train_ranks = [infer_rank(item) for item in train_items]\n", + "train_ranks = [value for value in train_ranks if value is not None]\n", + "average_rank = sum(train_ranks) / len(train_ranks) if train_ranks else 1_000_000.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42", + "metadata": { + "id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42" + }, + "outputs": [], + "source": [ + "def build_features(item):\n", + " weight = infer_weight(item)\n", + " rank = infer_rank(item)\n", + "\n", + " return {\n", + " \"weight\": weight if weight is not None else average_weight,\n", + " \"rank\": rank if rank is not None else average_rank,\n", + " \"text_length\": len(item.test_prompt()),\n", + " \"top_brand\": is_top_brand(item)\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "189e959c-d70c-4509-bff6-1cbd8e8db637", + "metadata": { + "id": "189e959c-d70c-4509-bff6-1cbd8e8db637" + }, + "outputs": [], + "source": [ + "train_frame = pd.DataFrame([build_features(item) for item in train_items])\n", + "train_frame[\"price\"] = [item.price for item in train_items]\n", + "test_frame = pd.DataFrame([build_features(item) for item in test_items[:250]])\n", + "test_frame[\"price\"] = [item.price for item in test_items[:250]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2", + "metadata": { + "id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2" + }, + "outputs": [], + "source": [ + "feature_columns = [\"weight\", \"rank\", \"text_length\", \"top_brand\"]\n", + "X_train = train_frame[feature_columns]\n", + "y_train = train_frame[\"price\"]\n", + "X_test = test_frame[feature_columns]\n", + "y_test = test_frame[\"price\"]\n", + "linear_model = LinearRegression()\n", + "linear_model.fit(X_train, y_train)\n", + "\n", + "def linear_baseline(item):\n", + " return float(linear_model.predict(pd.DataFrame([build_features(item)]))[0])\n", + "\n", + "Tester.test(linear_baseline, test_items[:250])" + ] + }, + { + "cell_type": "markdown", + "id": "ga-f4JK7sPU2", + "metadata": { + "id": "ga-f4JK7sPU2" + }, + "source": [ + "### NLP Baselines" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de958a51-69ba-420c-84b7-d32765898fd2", + "metadata": { + "id": "de958a51-69ba-420c-84b7-d32765898fd2" + }, + "outputs": [], + "source": [ + "document_texts = [item.test_prompt() for item in train_items]\n", + "price_targets = np.array([item.price for item in train_items])\n", + "vectorizer = CountVectorizer(max_features=1000, stop_words=\"english\")\n", + "X_matrix = vectorizer.fit_transform(document_texts)\n", + "bow_model = LinearRegression()\n", + "bow_model.fit(X_matrix, price_targets)\n", + "\n", + "def bow_predictor(item):\n", + " pred = float(bow_model.predict(vectorizer.transform([item.test_prompt()]))[0])\n", + " return max(pred, 0)\n", + "\n", + "Tester.test(bow_predictor, test_items[:250])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "QFDAoNnoRCk1", + "metadata": { + "id": "QFDAoNnoRCk1" + }, + "outputs": [], + "source": [ + "processed_docs = [simple_preprocess(text) for text in document_texts]\n", + "word2vec_model = Word2Vec(sentences=processed_docs, vector_size=400, window=5, min_count=1, workers=4)\n", + "\n", + "def document_vector(text):\n", + " words = simple_preprocess(text)\n", + " vectors = [word2vec_model.wv[word] for word in words if word in word2vec_model.wv]\n", + "\n", + " if not vectors:\n", + " return np.zeros(word2vec_model.vector_size)\n", + " return np.mean(vectors, axis=0)\n", + "\n", + "w2v_features = np.array([document_vector(text) for text in document_texts])\n", + "svr_model = LinearSVR()\n", + "svr_model.fit(w2v_features, price_targets)\n", + "\n", + "def w2v_predictor(item):\n", + " return float(svr_model.predict([document_vector(item.test_prompt())])[0])\n", + "\n", + "Tester.test(w2v_predictor, test_items[:250])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "kBVWisusQwDq", + "metadata": { + "id": "kBVWisusQwDq" + }, + "outputs": [], + "source": [ + "forest_model = RandomForestRegressor(n_estimators=200, random_state=123)\n", + "forest_model.fit(X_train, y_train)\n", + "\n", + "def forest_predictor(item):\n", + " return float(forest_model.predict(pd.DataFrame([build_features(item)])[feature_columns])[0])\n", + "\n", + "Tester.test(forest_predictor, test_items[:250])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "wgth1KvMSEOb", + "metadata": { + "id": "wgth1KvMSEOb" + }, + "outputs": [], + "source": [ + "fine_tune_train = train_items[:200]\n", + "fine_tune_validation = train_items[200:250]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "g7uz8SC5S3_s", + "metadata": { + "id": "g7uz8SC5S3_s" + }, + "outputs": [], + "source": [ + "def compose_messages(item, include_price=True):\n", + " system_message = \"You estimate prices of items. Reply only with the price\"\n", + " user_prompt = item.test_prompt().replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n", + " assistant_content = f\"Price is ${item.price:.2f}\" if include_price else \"Price is $\"\n", + " return [\n", + " {\"role\": \"system\", \"content\": system_message},\n", + " {\"role\": \"user\", \"content\": user_prompt},\n", + " {\"role\": \"assistant\", \"content\": assistant_content}\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "_zHswJwzWCHZ", + "metadata": { + "id": "_zHswJwzWCHZ" + }, + "outputs": [], + "source": [ + "def build_jsonl(items):\n", + " lines = []\n", + " for item in items:\n", + " payload = {\"messages\": compose_messages(item)}\n", + " lines.append(json.dumps(payload))\n", + "\n", + " return \"\\n\".join(lines)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rSHYkQojWH8Q", + "metadata": { + "id": "rSHYkQojWH8Q" + }, + "outputs": [], + "source": [ + "train_jsonl = storage_dir / \"balanced_pricer_train.jsonl\"\n", + "validation_jsonl = storage_dir / \"balanced_pricer_validation.jsonl\"\n", + "train_jsonl.write_text(build_jsonl(fine_tune_train))\n", + "validation_jsonl.write_text(build_jsonl(fine_tune_validation))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37BH0u-QWOiY", + "metadata": { + "id": "37BH0u-QWOiY" + }, + "outputs": [], + "source": [ + "openai_client = OpenAI()\n", + "\n", + "with open(train_jsonl, \"rb\") as f:\n", + " train_file = openai_client.files.create(file=f, purpose=\"fine-tune\")\n", + "\n", + "with open(validation_jsonl, \"rb\") as f:\n", + " validation_file = openai_client.files.create(file=f, purpose=\"fine-tune\")\n", + "\n", + "train_file, validation_file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2nNSE_AzWYMq", + "metadata": { + "id": "2nNSE_AzWYMq" + }, + "outputs": [], + "source": [ + "wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"balanced-pricer\"}}\n", + "fine_tune_job = openai_client.fine_tuning.jobs.create(\n", + " training_file=train_file.id,\n", + " validation_file=validation_file.id,\n", + " model=\"gpt-4o-mini-2024-07-18\",\n", + " seed=123,\n", + " hyperparameters={\"n_epochs\": 1},\n", + " integrations=[wandb_integration],\n", + " suffix=\"balanced-pricer\"\n", + ")\n", + "fine_tune_job" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ASiJUw-Fh8Ul", + "metadata": { + "id": "ASiJUw-Fh8Ul" + }, + "outputs": [], + "source": [ + "job_status = openai_client.fine_tuning.jobs.retrieve(fine_tune_job.id)\n", + "job_events = openai_client.fine_tuning.jobs.list_events(fine_tuning_job_id=fine_tune_job.id, limit=10)\n", + "job_status, job_events" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7jB_7gqBiH_r", + "metadata": { + "id": "7jB_7gqBiH_r" + }, + "outputs": [], + "source": [ + "fine_tuned_model_name = openai_client.fine_tuning.jobs.retrieve(fine_tune_job.id).fine_tuned_model\n", + "print(fine_tuned_model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "BHfLSadhiVQE", + "metadata": { + "id": "BHfLSadhiVQE" + }, + "outputs": [], + "source": [ + "def tuned_predictor(item):\n", + " messages = compose_messages(item, include_price=False)\n", + " response = openai_client.chat.completions.create(\n", + " model=fine_tuned_model_name,\n", + " messages=messages,\n", + " seed=123,\n", + " max_tokens=7\n", + ")\n", + " answer = response.choices[0].message.content\n", + " return get_price(answer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "C0CiTZ4jkjrI", + "metadata": { + "id": "C0CiTZ4jkjrI" + }, + "outputs": [], + "source": [ + "if test_items:\n", + " sample_item = test_items[0]\n", + " print(sample_item.price)\n", + " print(tuned_predictor(sample_item))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "WInQE0ObkuBl", + "metadata": { + "id": "WInQE0ObkuBl" + }, + "outputs": [], + "source": [ + "Tester.test(tuned_predictor, test_items[:250])" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}