{ "cells": [ { "cell_type": "markdown", "id": "a246687d", "metadata": {}, "source": [ "# The Product Pricer\n", "\n", "A model that can estimate how much something costs, from its description\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3792ce5b", "metadata": {}, "outputs": [], "source": [ "! uv -q pip install langchain-ollama" ] }, { "cell_type": "code", "execution_count": null, "id": "390c3ce3", "metadata": {}, "outputs": [], "source": [ "# imports\n", "\n", "import os\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "\n", "from dotenv import load_dotenv\n", "from huggingface_hub import login\n", "from datasets import load_dataset, Dataset, DatasetDict\n", "import matplotlib.pyplot as plt\n", "import pickle\n", "import re\n", "from langchain_ollama import OllamaLLM\n", "from openai import OpenAI\n", "from testing import Tester\n", "import json\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8a8ff331", "metadata": {}, "outputs": [], "source": [ "load_dotenv(override=True)\n", "hf_token = os.getenv(\"HF_TOKEN\")\n", "login(hf_token, add_to_git_credential=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "1051e21e", "metadata": {}, "outputs": [], "source": [ "from items import Item\n", "from loaders import ItemLoader\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "id": "290fa868", "metadata": {}, "outputs": [], "source": [ "dataset_names = [\n", " \"Appliances\",\n", "]" ] }, { "cell_type": "code", "execution_count": null, "id": "12ffad66", "metadata": {}, "outputs": [], "source": [ "items = []\n", "for dataset_name in dataset_names:\n", " loader = ItemLoader(dataset_name)\n", " items.extend(loader.load())" ] }, { "cell_type": "code", "execution_count": null, "id": "0b3890d7", "metadata": {}, "outputs": [], "source": [ "print(f\"A grand total of {len(items):,} items\")" ] }, { "cell_type": "code", "execution_count": null, "id": "246ab22a", "metadata": {}, "outputs": [], "source": [ "# Plot the distribution of token counts again\n", "\n", "tokens = [item.token_count for item in items]\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n", "plt.xlabel('Length (tokens)')\n", "plt.ylabel('Count')\n", "plt.hist(tokens, rwidth=0.7, color=\"skyblue\", bins=range(0, 300, 10))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "3a49a4d4", "metadata": {}, "outputs": [], "source": [ "# Plot the distribution of prices\n", "\n", "prices = [item.price for item in items]\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n", "plt.xlabel('Price ($)')\n", "plt.ylabel('Count')\n", "plt.hist(prices, rwidth=0.7, color=\"blueviolet\", bins=range(0, 1000, 10))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "57e4ea1b", "metadata": {}, "outputs": [], "source": [ "# How does the price vary with the character count of the prompt?\n", "\n", "sample = items\n", "\n", "sizes = [len(item.prompt) for item in sample]\n", "prices = [item.price for item in sample]\n", "\n", "# Create the scatter plot\n", "plt.figure(figsize=(15, 8))\n", "plt.scatter(sizes, prices, s=0.2, color=\"red\")\n", "\n", "# Add labels and title\n", "plt.xlabel('Size')\n", "plt.ylabel('Price')\n", "plt.title('Is there a simple correlation?')\n", "\n", "# Display the plot\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": 
"e6620daa", "metadata": {}, "outputs": [], "source": [ "def report(item):\n", " prompt = item.prompt\n", " tokens = Item.tokenizer.encode(item.prompt)\n", " print(prompt)\n", " print(tokens[-10:])\n", " print(Item.tokenizer.batch_decode(tokens[-10:]))" ] }, { "cell_type": "code", "execution_count": null, "id": "af71d177", "metadata": {}, "outputs": [], "source": [ "report(sample[50])" ] }, { "cell_type": "code", "execution_count": null, "id": "75ab3c21", "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "\n", "random.seed(42)\n", "random.shuffle(sample)\n", "train = sample[:25_000]\n", "test = sample[25_000:27_000]\n", "print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")" ] }, { "cell_type": "code", "execution_count": null, "id": "6d5cbd3a", "metadata": {}, "outputs": [], "source": [ "print(train[0].prompt)" ] }, { "cell_type": "code", "execution_count": null, "id": "39de86d6", "metadata": {}, "outputs": [], "source": [ "print(test[0].test_prompt())" ] }, { "cell_type": "code", "execution_count": null, "id": "65480df9", "metadata": {}, "outputs": [], "source": [ "# Plot the distribution of prices in the first 250 test points\n", "\n", "prices = [float(item.price) for item in test[:250]]\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n", "plt.xlabel('Price ($)')\n", "plt.ylabel('Count')\n", "plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "7a315b10", "metadata": {}, "outputs": [], "source": [ "filtered_prices = [float(item.price) for item in test if item.price > 99.999]" ] }, { "cell_type": "markdown", "id": "5693c9c6", "metadata": {}, "source": [ "### Confirm that the tokenizer tokenizes all 3 digit prices into 1 token" ] }, { "cell_type": "code", "execution_count": null, "id": "99e8cfc3", "metadata": {}, "outputs": [], "source": [ "for price in filtered_prices:\n", " tokens = Item.tokenizer.encode(f\"{price}\", add_special_tokens=False)\n", " assert len(tokens) == 3\n" ] }, { "cell_type": "markdown", "id": "f3159195", "metadata": {}, "source": [ "## Helpers" ] }, { "cell_type": "code", "execution_count": null, "id": "7bdc5dd5", "metadata": {}, "outputs": [], "source": [ "def messages_for(item):\n", " system_message = \"You estimate prices of items. 
Reply only with the price, no explanation\"\n", " user_prompt = item.test_prompt().replace(\" to the nearest dollar\",\"\").replace(\"\\n\\nPrice is $\",\"\")\n", " return [\n", " {\"role\": \"system\", \"content\": system_message},\n", " {\"role\": \"user\", \"content\": user_prompt},\n", " {\"role\": \"assistant\", \"content\": \"Price is $\"}\n", " ]" ] }, { "cell_type": "code", "execution_count": null, "id": "211b0658", "metadata": {}, "outputs": [], "source": [ "# A utility function to extract the price from a string\n", "\n", "def get_price(s):\n", " s = s.replace('$','').replace(',','')\n", " match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n", " return float(match.group()) if match else 0" ] }, { "cell_type": "code", "execution_count": null, "id": "ee01da84", "metadata": {}, "outputs": [], "source": [ "# Convert the items into a list of json objects - a \"jsonl\" string\n", "# Each row represents a message in the form:\n", "# {\"messages\" : [{\"role\": \"system\", \"content\": \"You estimate prices...\n", "\n", "\n", "def make_jsonl(items):\n", " result = \"\"\n", " for item in items:\n", " messages = messages_for(item)\n", " messages_str = json.dumps(messages)\n", " result += '{\"messages\": ' + messages_str +'}\\n'\n", " return result.strip()" ] }, { "cell_type": "code", "execution_count": null, "id": "f23e8959", "metadata": {}, "outputs": [], "source": [ "# Convert the items into jsonl and write them to a file\n", "\n", "def write_jsonl(items, filename):\n", " with open(filename, \"w\") as f:\n", " jsonl = make_jsonl(items)\n", " f.write(jsonl)" ] }, { "cell_type": "markdown", "id": "b6a83580", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": null, "id": "451b974f", "metadata": {}, "outputs": [], "source": [ "with open('train_lite.pkl', 'rb') as f:\n", " train_lite = pickle.load(f)\n", "\n", "with open('test_lite.pkl', 'rb') as f:\n", " test_lite = pickle.load(f)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f365d65c", "metadata": {}, "outputs": [], "source": [ "messages_for(test_lite[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "57b0b160", "metadata": {}, "outputs": [], "source": [ "get_price(\"The price is roughly $99.99 because blah blah\")" ] }, { "cell_type": "markdown", "id": "ff3e4670", "metadata": {}, "source": [ "## Models" ] }, { "cell_type": "code", "execution_count": null, "id": "9f62c94b", "metadata": {}, "outputs": [], "source": [ "MODEL_LLAMA3_2 = \"llama3.2\"\n", "MODEL_MISTRAL = \"mistral\"\n", "MODEL_TINY_LLAMA = \"tinyllama\"\n", "\n", "llm3_2 = OllamaLLM(model=MODEL_LLAMA3_2)\n", "llmMistral = OllamaLLM(model=MODEL_MISTRAL)\n", "llmTinyLlama = OllamaLLM(model=MODEL_TINY_LLAMA)\n" ] }, { "cell_type": "markdown", "id": "d18394fb", "metadata": {}, "source": [ "## Model Tests" ] }, { "cell_type": "code", "execution_count": null, "id": "7dac335f", "metadata": {}, "outputs": [], "source": [ "def llama3_2_model(item):\n", " response = llm3_2.invoke(messages_for(item))\n", " return get_price(response)\n", "\n", "def mistral_model(item):\n", " response = llmMistral.invoke(messages_for(item))\n", " return get_price(response)\n", "\n", "def tinyllama_model(item):\n", " response = llmTinyLlama.invoke(messages_for(item))\n", " return get_price(response)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "062e78c2", "metadata": {}, "outputs": [], "source": [ "test_lite[0].price" ] }, { "cell_type": "code", "execution_count": null, "id": "c58756f2", "metadata": {}, "outputs": [], 
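"source": [ "# Optional sanity check: predict the price of a single test item before the full Tester runs below.\n", "# Assumes the Ollama models above are pulled and the Ollama server is running locally;\n", "# compare this prediction with the true price printed in the previous cell.\n", "\n", "llama3_2_model(test_lite[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "4b7e90aa", "metadata": {}, "outputs": [], 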
"source": [ "Tester.test(llama3_2_model, test_lite)" ] }, { "cell_type": "code", "execution_count": null, "id": "899e2401", "metadata": {}, "outputs": [], "source": [ "Tester.test(mistral_model, test_lite)" ] }, { "cell_type": "code", "execution_count": null, "id": "2f5bc9ad", "metadata": {}, "outputs": [], "source": [ "Tester.test(tinyllama_model, test_lite)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 5 }