1111 lines
37 KiB
Plaintext
1111 lines
37 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3",
|
|
"metadata": {
|
|
"id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3"
|
|
},
|
|
"source": [
|
|
"# Fine Tuning"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "n9sehdR5Cv6A",
|
|
"metadata": {
|
|
"id": "n9sehdR5Cv6A"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install gensim\n",
|
|
"!pip install --upgrade datasets==3.6.0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b9bf9a6a",
|
|
"metadata": {
|
|
"id": "b9bf9a6a"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import math\n",
|
|
"import random\n",
|
|
"import json\n",
|
|
"import pickle\n",
|
|
"import re\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"from tqdm import tqdm\n",
|
|
"from pathlib import Path\n",
|
|
"from openai import OpenAI\n",
|
|
"from datetime import datetime\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from huggingface_hub import login\n",
|
|
"from sklearn.svm import LinearSVR\n",
|
|
"from gensim.models import Word2Vec\n",
|
|
"from IPython.display import display\n",
|
|
"from transformers import AutoTokenizer\n",
|
|
"from gensim.utils import simple_preprocess\n",
|
|
"from collections import Counter, defaultdict\n",
|
|
"from sklearn.linear_model import LinearRegression\n",
|
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|
"from concurrent.futures import ProcessPoolExecutor\n",
|
|
"from datasets import Dataset, DatasetDict, load_dataset\n",
|
|
"from sklearn.metrics import mean_squared_error, r2_score\n",
|
|
"from sklearn.feature_extraction.text import CountVectorizer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "261d16fa",
|
|
"metadata": {
|
|
"id": "261d16fa"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"load_dotenv(override=True)\n",
|
|
"openai_key = os.environ.get(\"OPENAI_API_KEY\")\n",
|
|
"\n",
|
|
"#anthropic_key = os.environ.get(\"ANTHROPIC_API_KEY\")\n",
|
|
"\n",
|
|
"hf_token = os.environ.get(\"HF_TOKEN\")\n",
|
|
"print(\"HF_TOKEN loaded\" if hf_token else \"HF_TOKEN not set\")\n",
|
|
"\n",
|
|
"if hf_token:\n",
|
|
" print(\"Logging in...\")\n",
|
|
" login(hf_token, add_to_git_credential=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2cdfe762-3200-4459-981e-0ded7c14b4de",
|
|
"metadata": {
|
|
"id": "2cdfe762-3200-4459-981e-0ded7c14b4de"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"GREEN = \"\\033[92m\"\n",
|
|
"YELLOW = \"\\033[93m\"\n",
|
|
"RED = \"\\033[91m\"\n",
|
|
"RESET = \"\\033[0m\"\n",
|
|
"COLOR_MAP = {\"red\": RED, \"orange\": YELLOW, \"green\": GREEN}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0832e74b-2779-4822-8e6c-4361ec165c7f",
|
|
"metadata": {
|
|
"id": "0832e74b-2779-4822-8e6c-4361ec165c7f"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
|
|
"\n",
|
|
"MIN_CHARS = 300\n",
|
|
"MIN_TOKENS = 150\n",
|
|
"MAX_TOKENS = 160\n",
|
|
"CEILING_CHARS = MAX_TOKENS * 7\n",
|
|
"\n",
|
|
"class Item:\n",
|
|
" tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
|
|
" PREFIX = \"Price is $\"\n",
|
|
" QUESTION = \"How much does this cost to the nearest dollar?\"\n",
|
|
" REMOVALS = ['\"Batteries Included?\": \"No\"', '\"Batteries Included?\": \"Yes\"', '\"Batteries Required?\": \"No\"', '\"Batteries Required?\": \"Yes\"', \"By Manufacturer\", \"Item\", \"Date First\", \"Package\", \":\", \"Number of\", \"Best Sellers\", \"Number\", \"Product \"]\n",
|
|
"\n",
|
|
" def __init__(self, data, price):\n",
|
|
" self.title = data[\"title\"]\n",
|
|
" self.price = price\n",
|
|
" self.category = data.get(\"category\", \"Unknown\")\n",
|
|
" self.token_count = 0\n",
|
|
" self.details = None\n",
|
|
" self.prompt = None\n",
|
|
" self.include = False\n",
|
|
" self.parse(data)\n",
|
|
"\n",
|
|
" def scrub_details(self):\n",
|
|
" details = self.details\n",
|
|
"\n",
|
|
" for remove in self.REMOVALS:\n",
|
|
" details = details.replace(remove, \"\")\n",
|
|
"\n",
|
|
" return details\n",
|
|
"\n",
|
|
" def scrub(self, text):\n",
|
|
" text = re.sub(r'[:\\[\\]\"{}【】\\s]+', ' ', text).strip()\n",
|
|
" text = text.replace(\" ,\", \",\").replace(\",,,\",\",\").replace(\",,\",\",\")\n",
|
|
" words = text.split(\" \")\n",
|
|
" select = [word for word in words if len(word) < 7 or not any(char.isdigit() for char in word)]\n",
|
|
" return \" \".join(select)\n",
|
|
"\n",
|
|
" def parse(self, data):\n",
|
|
" contents = '\\n'.join(data.get(\"description\", []))\n",
|
|
"\n",
|
|
" if contents:\n",
|
|
" contents += '\\n'\n",
|
|
"\n",
|
|
" features = '\\n'.join(data.get(\"features\", []))\n",
|
|
" if features:\n",
|
|
" contents += features + '\\n'\n",
|
|
"\n",
|
|
" self.details = data.get(\"details\")\n",
|
|
" if self.details:\n",
|
|
" contents += self.scrub_details() + '\\n'\n",
|
|
"\n",
|
|
" if len(contents) > MIN_CHARS:\n",
|
|
" contents = contents[:CEILING_CHARS]\n",
|
|
" text = f\"{self.scrub(self.title)}\\n{self.scrub(contents)}\"\n",
|
|
" tokens = self.tokenizer.encode(text, add_special_tokens=False)\n",
|
|
"\n",
|
|
" if len(tokens) > MIN_TOKENS:\n",
|
|
" tokens = tokens[:MAX_TOKENS]\n",
|
|
" text = self.tokenizer.decode(tokens)\n",
|
|
" self.make_prompt(text)\n",
|
|
" self.include = True\n",
|
|
"\n",
|
|
" def make_prompt(self, text):\n",
|
|
" self.prompt = f\"{self.QUESTION}\\n\\n{text}\\n\\n\"\n",
|
|
" self.prompt += f\"{self.PREFIX}{str(round(self.price))}.00\"\n",
|
|
" self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))\n",
|
|
"\n",
|
|
" def test_prompt(self):\n",
|
|
" return self.prompt.split(self.PREFIX)[0] + self.PREFIX\n",
|
|
"\n",
|
|
" def __repr__(self):\n",
|
|
" return f\"<{self.title} = ${self.price}>\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "aa478d70",
|
|
"metadata": {
|
|
"id": "aa478d70"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"MIN_PRICE = 0.5\n",
|
|
"CHUNK_SIZE = 1000\n",
|
|
"MAX_PRICE = 999.49\n",
|
|
"\n",
|
|
"class ItemLoader:\n",
|
|
" def __init__(self, name):\n",
|
|
" self.name = name\n",
|
|
" self.dataset = None\n",
|
|
"\n",
|
|
" def from_datapoint(self, datapoint):\n",
|
|
" try:\n",
|
|
" price_str = datapoint.get(\"price\")\n",
|
|
" if price_str:\n",
|
|
" price = float(price_str)\n",
|
|
" if MIN_PRICE <= price <= MAX_PRICE:\n",
|
|
" item = Item(datapoint, price)\n",
|
|
" if item.include:\n",
|
|
" return item\n",
|
|
" except ValueError:\n",
|
|
" return None\n",
|
|
"\n",
|
|
" def from_chunk(self, chunk):\n",
|
|
" batch = []\n",
|
|
" for datapoint in chunk:\n",
|
|
" item = self.from_datapoint(datapoint)\n",
|
|
"\n",
|
|
" if item:\n",
|
|
" batch.append(item)\n",
|
|
"\n",
|
|
" return batch\n",
|
|
"\n",
|
|
" def chunk_generator(self):\n",
|
|
" size = len(self.dataset)\n",
|
|
" for start in range(0, size, CHUNK_SIZE):\n",
|
|
" yield self.dataset.select(range(start, min(start + CHUNK_SIZE, size)))\n",
|
|
"\n",
|
|
" def load_in_parallel(self, workers):\n",
|
|
" results = []\n",
|
|
" chunk_count = (len(self.dataset) // CHUNK_SIZE) + 1\n",
|
|
"\n",
|
|
" with ProcessPoolExecutor(max_workers=workers) as pool:\n",
|
|
" for batch in tqdm(pool.map(self.from_chunk, self.chunk_generator()), total=chunk_count):\n",
|
|
" results.extend(batch)\n",
|
|
"\n",
|
|
" for result in results:\n",
|
|
" result.category = self.name\n",
|
|
"\n",
|
|
" return results\n",
|
|
"\n",
|
|
" def load(self, workers=8):\n",
|
|
" self.dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", f\"raw_meta_{self.name}\", split=\"full\", trust_remote_code=True)\n",
|
|
" start = datetime.now()\n",
|
|
" print(f\"Loading {self.dataset}\")\n",
|
|
" results = self.load_in_parallel(workers)\n",
|
|
" duration = (datetime.now() - start).total_seconds() / 60\n",
|
|
" print(f\"Completed {self.name} with {len(results):,} items in {duration:.1f} mins\")\n",
|
|
" return results"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6XywRUiUro69",
|
|
"metadata": {
|
|
"id": "6XywRUiUro69"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Tester:\n",
|
|
"\n",
|
|
" def __init__(self, predictor, data, title=None, size=250):\n",
|
|
" self.predictor = predictor\n",
|
|
" self.data = data\n",
|
|
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
|
|
" self.size = size\n",
|
|
" self.guesses = []\n",
|
|
" self.truths = []\n",
|
|
" self.errors = []\n",
|
|
" self.sles = []\n",
|
|
" self.colors = []\n",
|
|
"\n",
|
|
" def color_for(self, error, truth):\n",
|
|
" if error < 40 or error / truth < 0.2:\n",
|
|
" return \"green\"\n",
|
|
"\n",
|
|
" if error < 80 or error / truth < 0.4:\n",
|
|
" return \"orange\"\n",
|
|
"\n",
|
|
" return \"red\"\n",
|
|
"\n",
|
|
" def run_datapoint(self, index):\n",
|
|
" datapoint = self.data[index]\n",
|
|
" guess = self.predictor(datapoint)\n",
|
|
" truth = datapoint.price\n",
|
|
" error = abs(guess - truth)\n",
|
|
" log_error = math.log(truth + 1) - math.log(guess + 1)\n",
|
|
" sle = log_error ** 2\n",
|
|
" color = self.color_for(error, truth)\n",
|
|
" name = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40] + \"...\"\n",
|
|
" self.guesses.append(guess)\n",
|
|
" self.truths.append(truth)\n",
|
|
" self.errors.append(error)\n",
|
|
" self.sles.append(sle)\n",
|
|
" self.colors.append(color)\n",
|
|
" print(f\"{COLOR_MAP[color]}{index + 1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {name}{RESET}\")\n",
|
|
"\n",
|
|
" def chart(self, title):\n",
|
|
" plt.figure(figsize=(12, 8))\n",
|
|
" max_val = max(max(self.truths), max(self.guesses))\n",
|
|
" plt.plot([0, max_val], [0, max_val], color=\"deepskyblue\", lw=2, alpha=0.6)\n",
|
|
" plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
|
|
" plt.xlabel(\"Ground Truth\")\n",
|
|
" plt.ylabel(\"Model Estimate\")\n",
|
|
" plt.xlim(0, max_val)\n",
|
|
" plt.ylim(0, max_val)\n",
|
|
" plt.title(title)\n",
|
|
" plt.show()\n",
|
|
"\n",
|
|
" def report(self):\n",
|
|
" average_error = sum(self.errors) / self.size\n",
|
|
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
|
|
" hits = sum(1 for color in self.colors if color == \"green\")\n",
|
|
" title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits / self.size * 100:.1f}%\"\n",
|
|
" self.chart(title)\n",
|
|
"\n",
|
|
" def run(self):\n",
|
|
" for index in range(self.size):\n",
|
|
" self.run_datapoint(index)\n",
|
|
"\n",
|
|
" self.report()\n",
|
|
"\n",
|
|
" @classmethod\n",
|
|
" def test(cls, function, data):\n",
|
|
" cls(function, data).run()\n",
|
|
"\n",
|
|
"def get_price(s):\n",
|
|
" s = s.replace(\"$\", \"\").replace(\",\", \"\")\n",
|
|
" match = re.search(r\"[-+]?\\d*\\.?\\d+\", s)\n",
|
|
" return float(match.group()) if match else 0.0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9856f570",
|
|
"metadata": {
|
|
"id": "9856f570"
|
|
},
|
|
"source": [
|
|
"## Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3XTxVhq0xC8Z",
|
|
"metadata": {
|
|
"id": "3XTxVhq0xC8Z"
|
|
},
|
|
"source": [
|
|
"### Load Catalogs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22",
|
|
"metadata": {
|
|
"id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"catalog_labels = [\n",
|
|
" \"All_Beauty\",\n",
|
|
" # \"Automotive\",\n",
|
|
" # \"Electronics\",\n",
|
|
" # \"Office_Products\",\n",
|
|
" # \"Tools_and_Home_Improvement\",\n",
|
|
" # \"Cell_Phones_and_Accessories\",\n",
|
|
" # \"Toys_and_Games\",\n",
|
|
" \"Appliances\",\n",
|
|
" \"Musical_Instruments\",\n",
|
|
" \"Software\",\n",
|
|
" \"Handmade_Products\"\n",
|
|
"]\n",
|
|
"curated_pool = []\n",
|
|
"\n",
|
|
"for label in catalog_labels:\n",
|
|
" print(\"Loading \" + label)\n",
|
|
" loader = ItemLoader(label)\n",
|
|
" curated_pool.extend(loader.load())\n",
|
|
"\n",
|
|
"print(f\"Total curated items: {len(curated_pool):,}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d",
|
|
"metadata": {
|
|
"id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"price_series = [item.price for item in curated_pool]\n",
|
|
"token_series = [item.token_count for item in curated_pool]\n",
|
|
"category_tally = Counter(item.category for item in curated_pool)\n",
|
|
"summary_frame = pd.DataFrame({\"price\": price_series, \"tokens\": token_series})\n",
|
|
"\n",
|
|
"display(summary_frame.describe())\n",
|
|
"display(pd.DataFrame.from_dict(category_tally, orient=\"index\", columns=[\"count\"]).sort_values(\"count\", ascending=False))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b",
|
|
"metadata": {
|
|
"id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"price_slots = defaultdict(list)\n",
|
|
"for item in curated_pool:\n",
|
|
" key = round(item.price)\n",
|
|
" if 1 <= key <= 999:\n",
|
|
" price_slots[key].append(item)\n",
|
|
"\n",
|
|
"slot_counts = {k: len(v) for k, v in price_slots.items()}\n",
|
|
"print(f\"Slots populated: {len(slot_counts)}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585",
|
|
"metadata": {
|
|
"id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"random.seed(123)\n",
|
|
"np.random.seed(123)\n",
|
|
"balanced_bundle = []\n",
|
|
"\n",
|
|
"for price in range(1, 1000):\n",
|
|
" bucket = price_slots.get(price, [])\n",
|
|
"\n",
|
|
" if price >= 240:\n",
|
|
" balanced_bundle.extend(bucket)\n",
|
|
"\n",
|
|
" elif len(bucket) <= 1200:\n",
|
|
" balanced_bundle.extend(bucket)\n",
|
|
"\n",
|
|
" else:\n",
|
|
" weights = np.array([1 if item.category == \"Automotive\" else 5 for item in bucket], dtype=float)\n",
|
|
" weights /= weights.sum()\n",
|
|
" indices = np.random.choice(len(bucket), size=1200, replace=False, p=weights)\n",
|
|
" for idx in indices:\n",
|
|
" balanced_bundle.append(bucket[idx])\n",
|
|
"\n",
|
|
"print(f\"Balanced bundle size: {len(balanced_bundle):,}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a506f42c-81c0-4198-bc0b-1e0653620be8",
|
|
"metadata": {
|
|
"id": "a506f42c-81c0-4198-bc0b-1e0653620be8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"bundle_prices = [item.price for item in balanced_bundle]\n",
|
|
"bundle_tokens = [item.token_count for item in balanced_bundle]\n",
|
|
"bundle_categories = Counter(item.category for item in balanced_bundle)\n",
|
|
"display(pd.Series(bundle_prices).describe())\n",
|
|
"display(pd.DataFrame.from_dict(bundle_categories, orient=\"index\", columns=[\"count\"]).sort_values(\"count\", ascending=False))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5842ace6-332d-46da-a853-5ea5a2a1cf88",
|
|
"metadata": {
|
|
"id": "5842ace6-332d-46da-a853-5ea5a2a1cf88"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.figure(figsize=(12, 5))\n",
|
|
"plt.hist(bundle_prices, bins=range(0, 1000, 10), color=\"midnightblue\", rwidth=0.8)\n",
|
|
"plt.xlabel(\"Price\")\n",
|
|
"plt.ylabel(\"Count\")\n",
|
|
"plt.figure(figsize=(12, 5))\n",
|
|
"plt.hist(bundle_tokens, bins=range(0, 300, 10), color=\"forestgreen\", rwidth=0.8)\n",
|
|
"plt.xlabel(\"Tokens\")\n",
|
|
"plt.ylabel(\"Count\")\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "42ee0099-0d2a-4331-a01c-3462363a6987",
|
|
"metadata": {
|
|
"id": "42ee0099-0d2a-4331-a01c-3462363a6987"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"random.seed(123)\n",
|
|
"random.shuffle(balanced_bundle)\n",
|
|
"test_target = min(2000, max(1, len(balanced_bundle) // 20))\n",
|
|
"train_target = min(400_000, len(balanced_bundle) - test_target)\n",
|
|
"train_items = balanced_bundle[:train_target]\n",
|
|
"test_items = balanced_bundle[train_target:train_target + test_target]\n",
|
|
"print(f\"Training set: {len(train_items):,}\")\n",
|
|
"print(f\"Test set: {len(test_items):,}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b",
|
|
"metadata": {
|
|
"id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_prompts = [item.prompt for item in train_items]\n",
|
|
"train_prices = [item.price for item in train_items]\n",
|
|
"test_prompts = [item.test_prompt() for item in test_items]\n",
|
|
"test_prices = [item.price for item in test_items]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "31ca360d-5fc6-487a-91c6-d61758b2ff16",
|
|
"metadata": {
|
|
"id": "31ca360d-5fc6-487a-91c6-d61758b2ff16"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
|
|
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
|
|
"pricing_dataset = DatasetDict({\"train\": train_dataset, \"test\": test_dataset})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8",
|
|
"metadata": {
|
|
"id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8"
|
|
},
|
|
"source": [
|
|
"### Persist"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c",
|
|
"metadata": {
|
|
"id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"storage_dir = Path(\"data\")\n",
|
|
"storage_dir.mkdir(exist_ok=True)\n",
|
|
"\n",
|
|
"with open(storage_dir / \"balanced_train.pkl\", \"wb\") as f:\n",
|
|
" pickle.dump(train_items, f)\n",
|
|
"\n",
|
|
"with open(storage_dir / \"balanced_test.pkl\", \"wb\") as f:\n",
|
|
" pickle.dump(test_items, f)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b2164662-9bc9-4a66-9e4e-a8a955a45753",
|
|
"metadata": {
|
|
"id": "b2164662-9bc9-4a66-9e4e-a8a955a45753"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"pricing_dataset[\"train\"].to_parquet(storage_dir / \"balanced_train.parquet\")\n",
|
|
"pricing_dataset[\"test\"].to_parquet(storage_dir / \"balanced_test.parquet\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013",
|
|
"metadata": {
|
|
"id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013"
|
|
},
|
|
"source": [
|
|
"## Baselines"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "qX0c_prppnyZ",
|
|
"metadata": {
|
|
"id": "qX0c_prppnyZ"
|
|
},
|
|
"source": [
|
|
"### Stochastic Anchor"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7323252b-db50-4b8a-a7fc-8504bb3d218b",
|
|
"metadata": {
|
|
"id": "7323252b-db50-4b8a-a7fc-8504bb3d218b"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def stochastic_anchor(item):\n",
|
|
" return random.randrange(1, 1000)\n",
|
|
"\n",
|
|
"random.seed(123)\n",
|
|
"Tester.test(stochastic_anchor, test_items[:250])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "O0xVXRXkp9sQ",
|
|
"metadata": {
|
|
"id": "O0xVXRXkp9sQ"
|
|
},
|
|
"source": [
|
|
"### Global Mean"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6a932b0e-ba6e-45d2-8436-b740c3681272",
|
|
"metadata": {
|
|
"id": "6a932b0e-ba6e-45d2-8436-b740c3681272"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_price_values = [item.price for item in train_items]\n",
|
|
"global_mean_price = sum(train_price_values) / len(train_price_values)\n",
|
|
"\n",
|
|
"def global_mean_estimator(item):\n",
|
|
" return global_mean_price\n",
|
|
"\n",
|
|
"Tester.test(global_mean_estimator, test_items[:250])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d3410bd4-98e4-42a6-a702-4423cfd034b4",
|
|
"metadata": {
|
|
"id": "d3410bd4-98e4-42a6-a702-4423cfd034b4"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def parse_features(raw):\n",
|
|
" if not raw:\n",
|
|
" return {}\n",
|
|
" try:\n",
|
|
" return json.loads(raw)\n",
|
|
" except json.JSONDecodeError:\n",
|
|
" return {}\n",
|
|
"for item in train_items:\n",
|
|
" item.features = parse_features(item.details)\n",
|
|
"for item in test_items:\n",
|
|
" item.features = parse_features(item.details)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "44537051-7b4e-4b8c-95a7-a989ea51e517",
|
|
"metadata": {
|
|
"id": "44537051-7b4e-4b8c-95a7-a989ea51e517"
|
|
},
|
|
"source": [
|
|
"### Feature Engineering"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec",
|
|
"metadata": {
|
|
"id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def infer_weight(item):\n",
|
|
" payload = item.features.get(\"Item Weight\")\n",
|
|
" if not payload:\n",
|
|
" return None\n",
|
|
"\n",
|
|
" parts = payload.split(\" \")\n",
|
|
" amount = float(parts[0])\n",
|
|
" unit = parts[1].lower()\n",
|
|
"\n",
|
|
" if unit == \"pounds\":\n",
|
|
" return amount\n",
|
|
"\n",
|
|
" if unit == \"ounces\":\n",
|
|
" return amount / 16\n",
|
|
"\n",
|
|
" if unit == \"grams\":\n",
|
|
" return amount / 453.592\n",
|
|
"\n",
|
|
" if unit == \"milligrams\":\n",
|
|
" return amount / 453592\n",
|
|
"\n",
|
|
" if unit == \"kilograms\":\n",
|
|
" return amount / 0.453592\n",
|
|
"\n",
|
|
" if unit == \"hundredths\" and len(parts) > 2 and parts[2].lower() == \"pounds\":\n",
|
|
" return amount / 100\n",
|
|
"\n",
|
|
" return None"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4d7b6f35-890c-4227-8990-6b62694a332d",
|
|
"metadata": {
|
|
"id": "4d7b6f35-890c-4227-8990-6b62694a332d"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def infer_rank(item):\n",
|
|
" payload = item.features.get(\"Best Sellers Rank\")\n",
|
|
" if not payload:\n",
|
|
" return None\n",
|
|
"\n",
|
|
" values = list(payload.values()) if isinstance(payload, dict) else []\n",
|
|
" if not values:\n",
|
|
" return None\n",
|
|
"\n",
|
|
" return sum(values) / len(values)\n",
|
|
"\n",
|
|
"top_brands = {\"nvidia\",\"hp\",\"dell\",\"lenovo\",\"samsung\",\"asus\",\"sony\",\"canon\",\"apple\",\"intel\"}\n",
|
|
"\n",
|
|
"def is_top_brand(item):\n",
|
|
" brand = item.features.get(\"Brand\")\n",
|
|
" return 1 if brand and brand.lower() in top_brands else 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1a6e06f3-614f-4687-bd43-9ac03aaface8",
|
|
"metadata": {
|
|
"id": "1a6e06f3-614f-4687-bd43-9ac03aaface8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_weights = [infer_weight(item) for item in train_items]\n",
|
|
"train_weights = [value for value in train_weights if value is not None]\n",
|
|
"average_weight = sum(train_weights) / len(train_weights) if train_weights else 1.0\n",
|
|
"train_ranks = [infer_rank(item) for item in train_items]\n",
|
|
"train_ranks = [value for value in train_ranks if value is not None]\n",
|
|
"average_rank = sum(train_ranks) / len(train_ranks) if train_ranks else 1_000_000.0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42",
|
|
"metadata": {
|
|
"id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def build_features(item):\n",
|
|
" weight = infer_weight(item)\n",
|
|
" rank = infer_rank(item)\n",
|
|
"\n",
|
|
" return {\n",
|
|
" \"weight\": weight if weight is not None else average_weight,\n",
|
|
" \"rank\": rank if rank is not None else average_rank,\n",
|
|
" \"text_length\": len(item.test_prompt()),\n",
|
|
" \"top_brand\": is_top_brand(item)\n",
|
|
" }"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "189e959c-d70c-4509-bff6-1cbd8e8db637",
|
|
"metadata": {
|
|
"id": "189e959c-d70c-4509-bff6-1cbd8e8db637"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_frame = pd.DataFrame([build_features(item) for item in train_items])\n",
|
|
"train_frame[\"price\"] = [item.price for item in train_items]\n",
|
|
"test_frame = pd.DataFrame([build_features(item) for item in test_items[:250]])\n",
|
|
"test_frame[\"price\"] = [item.price for item in test_items[:250]]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2",
|
|
"metadata": {
|
|
"id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"feature_columns = [\"weight\", \"rank\", \"text_length\", \"top_brand\"]\n",
|
|
"X_train = train_frame[feature_columns]\n",
|
|
"y_train = train_frame[\"price\"]\n",
|
|
"X_test = test_frame[feature_columns]\n",
|
|
"y_test = test_frame[\"price\"]\n",
|
|
"linear_model = LinearRegression()\n",
|
|
"linear_model.fit(X_train, y_train)\n",
|
|
"\n",
|
|
"def linear_baseline(item):\n",
|
|
" return float(linear_model.predict(pd.DataFrame([build_features(item)]))[0])\n",
|
|
"\n",
|
|
"Tester.test(linear_baseline, test_items[:250])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ga-f4JK7sPU2",
|
|
"metadata": {
|
|
"id": "ga-f4JK7sPU2"
|
|
},
|
|
"source": [
|
|
"### NLP Baselines"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "de958a51-69ba-420c-84b7-d32765898fd2",
|
|
"metadata": {
|
|
"id": "de958a51-69ba-420c-84b7-d32765898fd2"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"document_texts = [item.test_prompt() for item in train_items]\n",
|
|
"price_targets = np.array([item.price for item in train_items])\n",
|
|
"vectorizer = CountVectorizer(max_features=1000, stop_words=\"english\")\n",
|
|
"X_matrix = vectorizer.fit_transform(document_texts)\n",
|
|
"bow_model = LinearRegression()\n",
|
|
"bow_model.fit(X_matrix, price_targets)\n",
|
|
"\n",
|
|
"def bow_predictor(item):\n",
|
|
" pred = float(bow_model.predict(vectorizer.transform([item.test_prompt()]))[0])\n",
|
|
" return max(pred, 0)\n",
|
|
"\n",
|
|
"Tester.test(bow_predictor, test_items[:250])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "QFDAoNnoRCk1",
|
|
"metadata": {
|
|
"id": "QFDAoNnoRCk1"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"processed_docs = [simple_preprocess(text) for text in document_texts]\n",
|
|
"word2vec_model = Word2Vec(sentences=processed_docs, vector_size=400, window=5, min_count=1, workers=4)\n",
|
|
"\n",
|
|
"def document_vector(text):\n",
|
|
" words = simple_preprocess(text)\n",
|
|
" vectors = [word2vec_model.wv[word] for word in words if word in word2vec_model.wv]\n",
|
|
"\n",
|
|
" if not vectors:\n",
|
|
" return np.zeros(word2vec_model.vector_size)\n",
|
|
" return np.mean(vectors, axis=0)\n",
|
|
"\n",
|
|
"w2v_features = np.array([document_vector(text) for text in document_texts])\n",
|
|
"svr_model = LinearSVR()\n",
|
|
"svr_model.fit(w2v_features, price_targets)\n",
|
|
"\n",
|
|
"def w2v_predictor(item):\n",
|
|
" return float(svr_model.predict([document_vector(item.test_prompt())])[0])\n",
|
|
"\n",
|
|
"Tester.test(w2v_predictor, test_items[:250])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "kBVWisusQwDq",
|
|
"metadata": {
|
|
"id": "kBVWisusQwDq"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"forest_model = RandomForestRegressor(n_estimators=200, random_state=123)\n",
|
|
"forest_model.fit(X_train, y_train)\n",
|
|
"\n",
|
|
"def forest_predictor(item):\n",
|
|
" return float(forest_model.predict(pd.DataFrame([build_features(item)])[feature_columns])[0])\n",
|
|
"\n",
|
|
"Tester.test(forest_predictor, test_items[:250])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "wgth1KvMSEOb",
|
|
"metadata": {
|
|
"id": "wgth1KvMSEOb"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"fine_tune_train = train_items[:200]\n",
|
|
"fine_tune_validation = train_items[200:250]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "g7uz8SC5S3_s",
|
|
"metadata": {
|
|
"id": "g7uz8SC5S3_s"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def compose_messages(item, include_price=True):\n",
|
|
" system_message = \"You estimate prices of items. Reply only with the price\"\n",
|
|
" user_prompt = item.test_prompt().replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
|
|
" assistant_content = f\"Price is ${item.price:.2f}\" if include_price else \"Price is $\"\n",
|
|
" return [\n",
|
|
" {\"role\": \"system\", \"content\": system_message},\n",
|
|
" {\"role\": \"user\", \"content\": user_prompt},\n",
|
|
" {\"role\": \"assistant\", \"content\": assistant_content}\n",
|
|
" ]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "_zHswJwzWCHZ",
|
|
"metadata": {
|
|
"id": "_zHswJwzWCHZ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def build_jsonl(items):\n",
|
|
" lines = []\n",
|
|
" for item in items:\n",
|
|
" payload = {\"messages\": compose_messages(item)}\n",
|
|
" lines.append(json.dumps(payload))\n",
|
|
"\n",
|
|
" return \"\\n\".join(lines)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "rSHYkQojWH8Q",
|
|
"metadata": {
|
|
"id": "rSHYkQojWH8Q"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_jsonl = storage_dir / \"balanced_pricer_train.jsonl\"\n",
|
|
"validation_jsonl = storage_dir / \"balanced_pricer_validation.jsonl\"\n",
|
|
"train_jsonl.write_text(build_jsonl(fine_tune_train))\n",
|
|
"validation_jsonl.write_text(build_jsonl(fine_tune_validation))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "37BH0u-QWOiY",
|
|
"metadata": {
|
|
"id": "37BH0u-QWOiY"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"openai_client = OpenAI()\n",
|
|
"\n",
|
|
"with open(train_jsonl, \"rb\") as f:\n",
|
|
" train_file = openai_client.files.create(file=f, purpose=\"fine-tune\")\n",
|
|
"\n",
|
|
"with open(validation_jsonl, \"rb\") as f:\n",
|
|
" validation_file = openai_client.files.create(file=f, purpose=\"fine-tune\")\n",
|
|
"\n",
|
|
"train_file, validation_file"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2nNSE_AzWYMq",
|
|
"metadata": {
|
|
"id": "2nNSE_AzWYMq"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"balanced-pricer\"}}\n",
|
|
"fine_tune_job = openai_client.fine_tuning.jobs.create(\n",
|
|
" training_file=train_file.id,\n",
|
|
" validation_file=validation_file.id,\n",
|
|
" model=\"gpt-4o-mini-2024-07-18\",\n",
|
|
" seed=123,\n",
|
|
" hyperparameters={\"n_epochs\": 1},\n",
|
|
" integrations=[wandb_integration],\n",
|
|
" suffix=\"balanced-pricer\"\n",
|
|
")\n",
|
|
"fine_tune_job"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ASiJUw-Fh8Ul",
|
|
"metadata": {
|
|
"id": "ASiJUw-Fh8Ul"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"job_status = openai_client.fine_tuning.jobs.retrieve(fine_tune_job.id)\n",
|
|
"job_events = openai_client.fine_tuning.jobs.list_events(fine_tuning_job_id=fine_tune_job.id, limit=10)\n",
|
|
"job_status, job_events"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7jB_7gqBiH_r",
|
|
"metadata": {
|
|
"id": "7jB_7gqBiH_r"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"fine_tuned_model_name = openai_client.fine_tuning.jobs.retrieve(fine_tune_job.id).fine_tuned_model\n",
|
|
"print(fine_tuned_model_name)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "BHfLSadhiVQE",
|
|
"metadata": {
|
|
"id": "BHfLSadhiVQE"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def tuned_predictor(item):\n",
|
|
" messages = compose_messages(item, include_price=False)\n",
|
|
" response = openai_client.chat.completions.create(\n",
|
|
" model=fine_tuned_model_name,\n",
|
|
" messages=messages,\n",
|
|
" seed=123,\n",
|
|
" max_tokens=7\n",
|
|
")\n",
|
|
" answer = response.choices[0].message.content\n",
|
|
" return get_price(answer)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "C0CiTZ4jkjrI",
|
|
"metadata": {
|
|
"id": "C0CiTZ4jkjrI"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"if test_items:\n",
|
|
" sample_item = test_items[0]\n",
|
|
" print(sample_item.price)\n",
|
|
" print(tuned_predictor(sample_item))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "WInQE0ObkuBl",
|
|
"metadata": {
|
|
"id": "WInQE0ObkuBl"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"Tester.test(tuned_predictor, test_items[:250])"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "env",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.13.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|