LLM_Engineering_OLD/week6/community-contributions/lisekarimi/09_part3_e5embeddings_rag.ipynb
2025-06-07 05:47:51 +02:00
{
"cells": [
{
"cell_type": "markdown",
"id": "d9b9eaa6-a12f-4cf8-a4c5-e8ac2c15d15b",
"metadata": {
"id": "d9b9eaa6-a12f-4cf8-a4c5-e8ac2c15d15b"
},
"source": [
"# 🔍 Predicting Item Prices from Descriptions (Part 3)\n",
"---\n",
"- Data Curation & Preprocessing\n",
"- Model Benchmarking Traditional ML vs LLMs\n",
"- ➡E5 Embeddings & RAG\n",
"- Fine-Tuning GPT-4o Mini\n",
"- Evaluating LLaMA 3.1 8B Quantized\n",
"- Fine-Tuning LLaMA 3.1 with QLoRA\n",
"- Evaluating Fine-Tuned LLaMA\n",
"- Summary & Leaderboard\n",
"\n",
"---\n",
"\n",
"# 🧠 Part 3: E5 Embeddings & RAG\n",
"\n",
"- 🧑‍💻 Skill Level: Advanced\n",
"- ⚙️ Hardware: ⚠️ GPU required for embeddings (400K items) - use Google Colab\n",
"- 🛠️ Requirements: 🔑 HF Token, Open API Key\n",
"- Tasks:\n",
" - Preprocessed item descriptions\n",
" - Generated and stored embeddings in ChromaDB\n",
" - Trained XGBoost on embeddings, pushed to HF Hub, and ran predictions\n",
" - Predicted prices with GPT-4o Mini using RAG\n",
"\n",
"Is Word2Vec enough for XGBoost, or do contextual E5 embeddings perform better?\n",
"\n",
"Does retrieval improve price prediction for GPT-4o Mini?\n",
"\n",
"Lets find out.\n",
"\n",
"⚠️ This notebook assumes basic familiarity with RAG and contextual embeddings.\n",
"We use the same E5 embedding space for both XGBoost and GPT-4o Mini with RAG, enabling a fair comparison.\n",
"Embeddings are stored and queried via ChromaDB — no LangChain is used for creation or retrieval.\n",
"\n",
"---\n",
"📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8e2af5e-03cc-46dc-8a8b-37cb102d0e92",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d8e2af5e-03cc-46dc-8a8b-37cb102d0e92",
"outputId": "905907cc-81c5-4a3b-e7c8-9e237e594a09"
},
"outputs": [],
"source": [
"# Install required packages in Google Colab\n",
"%pip install -q tqdm huggingface_hub numpy sentence-transformers datasets chromadb xgboost"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ce6a892-b357-4132-b9c0-a3142a0244c8",
"metadata": {
"id": "4ce6a892-b357-4132-b9c0-a3142a0244c8"
},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import math\n",
"import chromadb\n",
"import re\n",
"import joblib\n",
"import os\n",
"from tqdm import tqdm\n",
"import gc\n",
"from huggingface_hub import login, HfApi\n",
"import numpy as np\n",
"from sentence_transformers import SentenceTransformer\n",
"from datasets import load_dataset\n",
"from google.colab import userdata\n",
"from xgboost import XGBRegressor\n",
"from openai import OpenAI\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "yBH-mvV0QBiw",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yBH-mvV0QBiw",
"outputId": "b4b6df10-dc05-4dbe-dd8b-55bae5a2b7af"
},
"outputs": [],
"source": [
"# Mount Google Drive to access persistent storage\n",
"\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3OUI1jQYyaeX",
"metadata": {
"id": "3OUI1jQYyaeX"
},
"outputs": [],
"source": [
"# Google Colab User Data\n",
"# Ensure you have set the following in your Google Colab environment:\n",
"openai_api_key = userdata.get(\"OPENAI_API_KEY\")\n",
"hf_token = userdata.get('HF_TOKEN')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99f6f632",
"metadata": {},
"outputs": [],
"source": [
"openai = OpenAI(api_key=openai_api_key)\n",
"login(hf_token, add_to_git_credential=True)\n",
"\n",
"# Configuration\n",
"ROOT = \"/content/drive/MyDrive/deal_finder\"\n",
"CHROMA_PATH = f\"{ROOT}/chroma\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "FF-HryRnDXm5",
"metadata": {
"id": "FF-HryRnDXm5"
},
"outputs": [],
"source": [
"# Helper class for evaluating model predictions\n",
"\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}\n",
"\n",
"class Tester:\n",
"\n",
" def __init__(self, predictor, data, title=None, size=250):\n",
" self.predictor = predictor\n",
" self.data = data\n",
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
" self.size = size\n",
" self.guesses = []\n",
" self.truths = []\n",
" self.errors = []\n",
" self.sles = []\n",
" self.colors = []\n",
"\n",
" def color_for(self, error, truth):\n",
" if error<40 or error/truth < 0.2:\n",
" return \"green\"\n",
" elif error<80 or error/truth < 0.4:\n",
" return \"orange\"\n",
" else:\n",
" return \"red\"\n",
"\n",
" def run_datapoint(self, i):\n",
" datapoint = self.data[i]\n",
" guess = self.predictor(datapoint)\n",
" truth = datapoint[\"price\"]\n",
" error = abs(guess - truth)\n",
" log_error = math.log(truth+1) - math.log(guess+1)\n",
" sle = log_error ** 2\n",
" color = self.color_for(error, truth)\n",
" # title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
" self.guesses.append(guess)\n",
" self.truths.append(truth)\n",
" self.errors.append(error)\n",
" self.sles.append(sle)\n",
" self.colors.append(color)\n",
" # print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
"\n",
" def chart(self, title):\n",
" # max_error = max(self.errors)\n",
" plt.figure(figsize=(12, 8))\n",
" max_val = max(max(self.truths), max(self.guesses))\n",
" plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
" plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
" plt.xlabel('Ground Truth')\n",
" plt.ylabel('Model Estimate')\n",
" plt.xlim(0, max_val)\n",
" plt.ylim(0, max_val)\n",
" plt.title(title)\n",
"\n",
" # Add color legend\n",
" from matplotlib.lines import Line2D\n",
" legend_elements = [\n",
" Line2D([0], [0], marker='o', color='w', label='Accurate (green)', markerfacecolor='green', markersize=8),\n",
" Line2D([0], [0], marker='o', color='w', label='Medium error (orange)', markerfacecolor='orange', markersize=8),\n",
" Line2D([0], [0], marker='o', color='w', label='High error (red)', markerfacecolor='red', markersize=8)\n",
" ]\n",
" plt.legend(handles=legend_elements, loc='upper right')\n",
"\n",
" plt.show()\n",
"\n",
"\n",
" def report(self):\n",
" average_error = sum(self.errors) / self.size\n",
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
" hits = sum(1 for color in self.colors if color==\"green\")\n",
" title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
" self.chart(title)\n",
"\n",
" def run(self):\n",
" self.error = 0\n",
" for i in range(self.size):\n",
" self.run_datapoint(i)\n",
" self.report()\n",
"\n",
" @classmethod\n",
" def test(cls, function, data):\n",
" cls(function, data).run()\n"
]
},
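{
"cell_type": "markdown",
"id": "tester-metrics-note",
"metadata": {},
"source": [
"For reference, the `Tester` report summarizes three metrics over the evaluated subset (250 items by default): the average absolute error in dollars, the hit rate (share of \"green\" predictions), and the root mean squared log error computed in `run_datapoint` / `report`:\n",
"\n",
"$$\\mathrm{RMSLE} = \\sqrt{\\frac{1}{N}\\sum_{i=1}^{N}\\bigl(\\ln(y_i + 1) - \\ln(\\hat{y}_i + 1)\\bigr)^2}$$"
]
},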
{
"cell_type": "markdown",
"id": "6f82b230-2e03-4b1e-9be5-926fcd19acbe",
"metadata": {
"id": "6f82b230-2e03-4b1e-9be5-926fcd19acbe"
},
"source": [
"## 📥 Load Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ae00568",
"metadata": {},
"outputs": [],
"source": [
"# #If you face NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported run:\n",
"# %pip install -U datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55f1495b-f343-4152-8739-3a99f5ac405d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 177,
"referenced_widgets": [
"6e7c01d666f64fa58d6a059cc8d8f323",
"597b7155767441e6a0283a19edced00f",
"cf1360550eaa49a0867f55db8b8c4c77",
"94f26137cccf47f6a36d9325bc8f5b9c",
"a764b97f3dcd480c8860dde979e5e114",
"f1ec9a46c9ce4e038f3051bbd1b2c661",
"992f46ae91554731987b4baf79ba1bbd",
"b4abe22402fe40fd82b7fe93b4bc06f3",
"57ec058518734e3dbd27324cbba243c0",
"f101230e8a9a431d85ee2f8e51add7ad",
"e196658b093746588113240a60336437",
"cb06a4d26cb84c708857b683d1e84c12",
"e82ad07ba22e465cbe0232c504c3b693",
"c4e0ed1165f54393aaec24cd4624d562",
"295a3c6662034aaaab4d2e0192d1d1ce",
"c38aff0c91a849feb547e78156c2c347",
"69647c5595874c3185cebf6813ee908c",
"1036b1af4b154916a3d4f16f5ed799eb",
"e6347ff832cc4c04aef86594ea5a9e64",
"01c63224aa6a4f0c9c88a4d85527e767",
"1db34b9a4f1f42a897345b5a6630ced6",
"9293f2d745024d7facb68e04cc188850",
"26f6ec91efaf42909cec172fafe55987",
"c1131f0324b0498da9bc59720e867eb6",
"3e58017527a04634a489a33ed53fd312",
"06cd89f57d08466c875d179e79e3ecd2",
"2e0aa0aa87a04419a277f303f577f7ff",
"8fa0fe1992db42a997e7cd3ee08bd09e",
"accb1d5142a9498da0117f746fedd691",
"fcc2fc2f82e2441995b9e61b23b9b91e",
"da93fe316dd24cb48538b52ef2eaf6b5",
"5cea58775faf41829c04d2a84e3e2c31",
"1914ec7959d143d09a55da324bbcd47b",
"a3d3504148df46f59b6770fb377e2bb6",
"b088b9a503e24f179741d40d21a730d9",
"b77dcf4632954d0c9c3b6d441c5f684d",
"4cc8b3c4d9934f24a94b4601ab7816b5",
"c093f1c0806a43b79594ddac856a301c",
"9f4d9ac1aa074ed6b0248a4b18fde7db",
"c00785b8fdda409e9cb435abbb0466da",
"612e211af4cd46eb9d2f3148d1c7cb0b",
"86f93c663cc446adbc6366a528cb01b0",
"dd42911451ec48e086c1c99e76492321",
"5b942241f11c4f2ab086f0f289f99a03",
"d28a5c6172f74c0f8bbd2d949455f22e",
"0e67b2055f214eb691b4b54d9431bdd8",
"f81c4dc72b3b4b40a6a70528db732482",
"043a355b6a85471ba0142eb25e2c9eb0",
"8682bfab79a8409499797a3307e4d64d",
"55a837644bb643ac864fa1a674e665c8",
"33aae5a98bf5433b813ff8216e015089",
"56eedfc5ba6642dc8443ab60f5f09b8c",
"a1b710c227a84ea1a55c310084f13a93",
"0d4bc0d0e88a4c77a202f9c11b2ee2a9",
"20858379c2cd45d59070b18149d6e925"
]
},
"id": "55f1495b-f343-4152-8739-3a99f5ac405d",
"outputId": "37317fe6-b560-4ad0-c7d6-66517fd67c42"
},
"outputs": [],
"source": [
"HF_USER = \"lisekarimi\"\n",
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
"\n",
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85880d79-f1ba-4ee8-a039-b6acea84562c",
"metadata": {
"id": "85880d79-f1ba-4ee8-a039-b6acea84562c"
},
"outputs": [],
"source": [
"print(train[0][\"text\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88842541-d73b-4fae-a550-6dedf8fab633",
"metadata": {
"id": "88842541-d73b-4fae-a550-6dedf8fab633"
},
"outputs": [],
"source": [
"print(train[0][\"price\"])"
]
},
{
"cell_type": "markdown",
"id": "7b8a9a5b-f74d-487d-a400-d157fea8c979",
"metadata": {
"id": "7b8a9a5b-f74d-487d-a400-d157fea8c979"
},
"source": [
"## 📦 Embed + Save Training Data to Chroma\n",
"- No LangChain used.\n",
"- We use `intfloat/e5-small-v2` for embeddings:\n",
" - Fast, high-quality, retrieval-tuned\n",
" - **Requires 'passage:' prefix**\n",
"- We embed item descriptions and store them in ChromaDB, with price saved as metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b95a87a8-2136-4e03-a36c-42e5d53a3e28",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 337,
"referenced_widgets": [
"8216f5d45e9345e493a43b8cbbe6598a",
"ec3854658f8448fc8463e8635889f700",
"7a90822b2aff4d5cb926442f01a77a9b",
"9518c3af589744cfbbb51f87d68f216e",
"327044765c044384a14be4e660bb152f",
"0b773d68d2394d80a2baf73c1808752a",
"21568b9954c8411d863baa7385df624f",
"0a08828a0ba4430ea6e039949f220b5b",
"3d5a51cfb5f44eecbf80d46e2e4608fd",
"313f059a82104a9394182f6dcdb0bfb4",
"6a625748afc84fe89a8af7a4ef638675",
"ebe43cd30e414f31ab52614c6e9f9f2b",
"88c29992adaa44af857e3216f7e53e60",
"0528af78cef844e8a2b489dcb8fce049",
"8cbccd78a79447158f02caadfa7d805f",
"076ce072490c493ba5b3c431f6166eda",
"dd7780038f8a4cd3837972c78b6583bc",
"9e285e2b58934552b98edd998b82a678",
"338efda3245a4989a9b3ee0795949bb8",
"136dfb68394742ea98d9eb845730846c",
"891d821725b6457c9d06737bf75fe3ed",
"14feb4e20339465d966a6a80504eb819",
"c02b637785324b9eb88e6a2c00cb986b",
"3635da14e6f04e8f90548eb6381290a8",
"1314757f404e47f5b0f6fa4de8537863",
"9e5f2478e931476d882e471c7f66aaeb",
"4ad885d69d9f492c960ca53426189707",
"992d5e88d7844a52a283c0e19475ab78",
"43eaec936c774e3380ae4ff1a823f3dc",
"ceeb11b317ac4d37b59641024f77265f",
"5e0371de53164830b4e8c2b6954b5947",
"63a729492e8a4a759d75b769cbb3e1e7",
"14dde2c87b7b4c9ea16d48732108dcd7",
"f50717b099d142be95390ae8f1e99e6a",
"ffa64c304dab4ef18e9ef50ac1625cd6",
"f358351612004f64adffb931c3130603",
"7593358526ae4a87bf4be0eb1bcfc076",
"51536b45f5674d498272dc7b2def635d",
"8fbe2a3fc07943e7bf0fdc927bab795a",
"6b265cc65d5a42638572c1776faafdb1",
"39fa86a7760d43c793eb8ef27475af7d",
"eee5113e2dd1402faf76d00f07d8e0af",
"6792ed7123724b2d8091bc8d36255e68",
"e35094b24c154340bb1b3ebba7ac0a0d",
"dd63bb6ffed34b6687a0c79d8af93fb7",
"32080bc9381c449ab63794655ec6d714",
"eb7aa289fefc465d98edeed9ce2bff51",
"53fae218b4b74863af5fe53a66a5f7ef",
"35bc6d95c60f4c3d8ddc6b3b0845ff7e",
"f4765ca278ad4da4b465bd2920a21320",
"7ac6ead5baef4f30aff170a30a9a7977",
"e7adb5eb38d54b29b734d207982411c8",
"8f4f51b75af74daa9b9ad6696760109c",
"ae4db932b7544c6cb9ff668fa954addd",
"be63f07eedbd4d46ac4913df45216108",
"2e47d9e7b36a4ec69a9071930671ae8e",
"7b1c7f9bf0e8412abb66bcfc24cf9668",
"5c8742d3f663470e9977d006e83314b7",
"74ec67e07ee0477eb41e21093ae82858",
"4b60a8f023bc4d759bc197b11bf4e160",
"7a090f162fa84568a5e486ba935c3ed1",
"8b650428a6834f5d8ebe62ad327493e0",
"5c4d22bce82546d28a8b0c041895c8e3",
"16121b830a2948afb3ca8eb54e27a678",
"0305a4b4408f4562b87b58098148326d",
"68f07b5b7ad447ce9a87023d872c2e73",
"2156a5ced089414c99a1bb8dd3a0b3b7",
"2e6cd134c70e455a85c47b1575135883",
"f4264985b5cc4a0f970a088fb90b8bcf",
"71d790bf25324e6dbb5372f636c53da9",
"dac3ba29ee4d4083a9abca7eab632534",
"5c75c020a1914da680340fe826f3f58d",
"195e6dfb82c84f0191838acbbfe38126",
"b06adcaf8d4c497897ed3625f3afb4eb",
"d4ab3971183a4e8fa10402e3542e6466",
"444ca1f5213241c2bc71fa9ebe9ac3ca",
"34d571f76ef845f4bc272a5e05491c31",
"e8ee76b022d64b2cb24a2cb7b61aeef7",
"8c9ac87788b04ae6899f3b62fdc3ed0d",
"431b638c435444c38e50a09573b8f31b",
"0430f22e24d14171b83261faa090f349",
"0fa5ae935a554461b086a4b81470b9ad",
"f072e665d27e442ab4d0e2eb33c98db9",
"fd3b1885c39c4b70b083d7fddf74d4b6",
"f77051cb151645559223ecf835426688",
"0e17661f878948598703ee7942e5e1a2",
"fca913c6cfff48099d1744d5b091fc46",
"085baf51ecef46318ceafbaba2bb4490",
"52309039c2d8421bbb8e99f63f5ba91f",
"f4233cd960ea4f549734a5b1e1da5e2e",
"42ce1b7765f547cd9ecd8b428ec1c718",
"e72a08514d3b42d2b5fbf87a920bcdf0",
"ad05cf4c0ed44341aa3cd2cbd22b513d",
"db9915d53d784b85accebe1552c4e7e1",
"9519b6d9bf1b45e3b56da4c28d2aeb2e",
"cfeb0597708b49fa9b65342e1ac446ae",
"e29617eff6fd4199a74b670198ba2a69",
"1cea197a15d94654a0e792318435d707",
"89dcb96670a8433593e3452fad3c9210",
"0802085388be453b8fe5edee7e0a01ef",
"1ed257f19b8b44ee85f09e10178ae52f",
"04107981561149cba5baf74ccba87aa6",
"09afb010020e4b2f91d7cdbdca316962",
"b11b51beaa54474cb7682110bd2d24ae",
"47822470ddf842cd9e3368090549a2b5",
"835bce5d87a2417c9b6a5b27627447dc",
"5ca06dd536d44de784984a492d23573f",
"8e75bdb4469e497c8f021ebde7c6c9b3",
"7f4d4f8ece1d4651a2186f10a0cc25a5",
"92036442af5f4b698f2a54ecba4650e2"
]
},
"id": "b95a87a8-2136-4e03-a36c-42e5d53a3e28",
"outputId": "6094328e-8c33-4b40-80e9-08c5cfb3e277"
},
"outputs": [],
"source": [
"# Load embedding model\n",
"model_embedding = SentenceTransformer(\"intfloat/e5-small-v2\", device='cuda')"
]
},
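{
"cell_type": "markdown",
"id": "e5-prefix-note",
"metadata": {},
"source": [
"Optional sanity check (a minimal sketch, not part of the pipeline): E5 models are trained with `query:` / `passage:` prefixes, which is why the `description()` helper below prepends `passage:` to every stored item. The two product texts here are made up purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5-prefix-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the E5 prefix convention: a query should sit closer\n",
"# to a related passage than to an unrelated one (example texts are invented).\n",
"from sentence_transformers import util\n",
"\n",
"q = model_embedding.encode(\"query: wireless noise cancelling headphones\", normalize_embeddings=True)\n",
"p_related = model_embedding.encode(\"passage: Over-ear Bluetooth headphones with active noise cancellation\", normalize_embeddings=True)\n",
"p_unrelated = model_embedding.encode(\"passage: Stainless steel double-basin kitchen sink\", normalize_embeddings=True)\n",
"\n",
"print(\"related:  \", util.cos_sim(q, p_related).item())\n",
"print(\"unrelated:\", util.cos_sim(q, p_unrelated).item())"
]
},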
{
"cell_type": "code",
"execution_count": null,
"id": "733cf41d-e81e-4cfc-b597-67da02dbc3cf",
"metadata": {
"id": "733cf41d-e81e-4cfc-b597-67da02dbc3cf"
},
"outputs": [],
"source": [
"# Init Chroma\n",
"client = chromadb.PersistentClient(path=CHROMA_PATH)\n",
"collection = client.get_or_create_collection(name=\"price_items\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f493c7d-1c72-40f9-a5c6-63c7f6b1cf2c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 91
},
"id": "1f493c7d-1c72-40f9-a5c6-63c7f6b1cf2c",
"outputId": "72627732-4eee-4d9a-c8cb-0c42e2541a80"
},
"outputs": [],
"source": [
"# Format description function (no price in text)\n",
"def description(item):\n",
" text = item[\"text\"].replace(\"How much does this cost to the nearest dollar?\\n\\n\", \"\")\n",
" text = text.split(\"\\n\\nPrice is $\")[0]\n",
" return f\"passage: {text}\"\n",
"\n",
"description(train[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f44bf613-adf6-4993-bf7b-6aa9fad21a03",
"metadata": {
"id": "f44bf613-adf6-4993-bf7b-6aa9fad21a03"
},
"outputs": [],
"source": [
"batch_size = 300 # how many items to insert into Chroma at once\n",
"encode_batch_size = 1024 # how many items to encode at once in GPU memory\n",
"\n",
"for i in tqdm(range(0, len(train), batch_size), desc=\"Processing batches\"):\n",
"\n",
" end_idx = min(i + batch_size, len(train))\n",
"\n",
" # Collect documents and metadata\n",
" documents = [description(train[j]) for j in range(i, end_idx)]\n",
" metadatas = [{\"price\": train[j][\"price\"]} for j in range(i, end_idx)]\n",
" ids = [f\"doc_{j}\" for j in range(i, end_idx)]\n",
"\n",
" # GPU batch encoding\n",
" vectors = model_embedding.encode(\n",
" documents,\n",
" batch_size=encode_batch_size,\n",
" show_progress_bar=False,\n",
" normalize_embeddings=True\n",
" ).tolist()\n",
"\n",
" # Insert into Chroma\n",
" collection.add(\n",
" ids=ids,\n",
" documents=documents,\n",
" embeddings=vectors,\n",
" metadatas=metadatas\n",
" )\n",
"\n",
"print(\"✅ Embedding and storage to ChromaDB completed.\")"
]
},
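{
"cell_type": "code",
"execution_count": null,
"id": "chroma-count-check",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: one stored vector per training item\n",
"print(f\"Items in ChromaDB: {collection.count():,} (expected {len(train):,})\")"
]
},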
{
"cell_type": "code",
"execution_count": null,
"id": "f2e2ccc9-b772-45f7-8258-cbc4f9c3ed59",
"metadata": {},
"outputs": [],
"source": [
"# Now flush and clean\n",
"print(\"🧹 Cleaning up and saving ChromaDB...\")\n",
"client = None\n",
"gc.collect()"
]
},
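{
"cell_type": "code",
"execution_count": null,
"id": "chroma-reconnect-note",
"metadata": {},
"outputs": [],
"source": [
"# If you resume later in a fresh session, the embeddings persist on Drive;\n",
"# reconnect before running the sections below (minimal sketch, same settings as above):\n",
"# client = chromadb.PersistentClient(path=CHROMA_PATH)\n",
"# collection = client.get_or_create_collection(name=\"price_items\")"
]
},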
{
"cell_type": "markdown",
"id": "c35d2fab-583f-4527-a7cc-9d31214b2f35",
"metadata": {},
"source": [
"Our ChromaDB is currently saved in a persistent Google Drive path; for a production-ready app, we recommend uploading it to AWS S3 for better reliability and scalability.\n",
"\n",
"🧩 Now that we've generated the E5 embeddings, let's use them for both **XGBoost regression** and **GPT-4o Mini with RAG** ."
]
},
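{
"cell_type": "markdown",
"id": "s3-upload-note",
"metadata": {},
"source": [
"The S3 upload itself is out of scope here, but a minimal sketch with `boto3` could look like the cell below. The bucket name is a placeholder and AWS credentials are assumed to be configured (e.g. via environment variables or Colab userdata); `boto3` is not installed by this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "s3-upload-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch (not run): copy the persisted Chroma directory to S3.\n",
"# import boto3\n",
"#\n",
"# s3 = boto3.client(\"s3\")\n",
"# BUCKET = \"your-deal-finder-bucket\"  # placeholder bucket name\n",
"#\n",
"# for dirpath, _, filenames in os.walk(CHROMA_PATH):\n",
"#     for name in filenames:\n",
"#         local_path = os.path.join(dirpath, name)\n",
"#         key = os.path.relpath(local_path, ROOT)  # e.g. chroma/...\n",
"#         s3.upload_file(local_path, BUCKET, key)"
]
},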
{
"cell_type": "markdown",
"id": "40e4c587-211d-4bc0-91cf-6267f45405d6",
"metadata": {
"id": "40e4c587-211d-4bc0-91cf-6267f45405d6"
},
"source": [
"## 📈 Embedding-Based Regression with XGBoost"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f058ccac-3392-457d-b54c-6471960e9af3",
"metadata": {
"id": "f058ccac-3392-457d-b54c-6471960e9af3"
},
"outputs": [],
"source": [
"# Step 1: Load vectors and prices from Chroma\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
"vectors = np.array(result['embeddings'])\n",
"documents = result['documents']\n",
"prices = [meta['price'] for meta in result['metadatas']]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "JYQo0RaMb8Ql",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 254
},
"id": "JYQo0RaMb8Ql",
"outputId": "c1641347-1fd4-41bb-e060-147224fc6bed"
},
"outputs": [],
"source": [
"# Step 2: Train XGBoost model\n",
"xgb_model = XGBRegressor(n_estimators=100, random_state=42, n_jobs=-1, verbosity=0)\n",
"xgb_model.fit(vectors, prices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "yaqG0z7jb919",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yaqG0z7jb919",
"outputId": "6a2f9120-97e0-4436-aa12-40d94fbc5c64"
},
"outputs": [],
"source": [
"# Step 3: Serialize XGBoost model locally for Hugging Face upload\n",
"MODEL_DIR = os.path.join(ROOT, \"models\")\n",
"MODEL_FILENAME = \"xgboost_model.pkl\"\n",
"LOCAL_MODEL = os.path.join(MODEL_DIR, MODEL_FILENAME)\n",
"\n",
"os.makedirs(MODEL_DIR, exist_ok=True)\n",
"joblib.dump(xgb_model, LOCAL_MODEL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "Z_17sQUdxIr3",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104,
"referenced_widgets": [
"2362f3121e5546b98e4623eb3680e96b",
"ef53ee3b68c840d6a3fe98386d26bbd9",
"a4768d0ecdd640a2a5bccd07a93c54b7",
"e177440016974bc699b666fa721c6490",
"2a9d0e5829174b738b4dfea1c71a3481",
"ee6dffc7b79e405d923940166ef10590",
"57bf3388622241869a5e9dab558dca72",
"aa87f4feddd6409fbfb81f417e5d6662",
"973a83ca118e4ed1b5a51821034ecc31",
"d5a3c955aba14b3ea8e9b5c90a3bf20a",
"daaa4f26bad545a394685e266f85a6ae"
]
},
"id": "Z_17sQUdxIr3",
"outputId": "68ebdbdb-d42e-4bc8-addc-85b42d418d1d"
},
"outputs": [],
"source": [
"# Step 4: Push serialized XGBoost model to Hugging Face Hub\n",
"api = HfApi(token=hf_token)\n",
"REPO_NAME = \"smart-deal-finder-models\"\n",
"REPO_ID = f\"{HF_USER}/{REPO_NAME}\"\n",
"\n",
"# Create the model repo if it doesn't exist\n",
"api.create_repo(repo_id=REPO_ID, repo_type=\"model\", private=True, exist_ok=True)\n",
"\n",
"# Upload the saved model\n",
"api.upload_file(\n",
" path_or_fileobj=LOCAL_MODEL,\n",
" path_in_repo=MODEL_FILENAME,\n",
" repo_id=REPO_ID,\n",
" repo_type=\"model\"\n",
")"
]
},
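{
"cell_type": "code",
"execution_count": null,
"id": "hf-reload-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Optional (minimal sketch): pull the serialized model back from the Hub in a\n",
"# fresh session, using the same REPO_ID and MODEL_FILENAME as above.\n",
"# from huggingface_hub import hf_hub_download\n",
"# model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, repo_type=\"model\", token=hf_token)\n",
"# xgb_model = joblib.load(model_path)"
]
},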
{
"cell_type": "code",
"execution_count": null,
"id": "3f59125d-9fa6-483b-957f-4423a9b2c900",
"metadata": {
"id": "3f59125d-9fa6-483b-957f-4423a9b2c900"
},
"outputs": [],
"source": [
"# Step 5: Define the predictor\n",
"def xgb_predictor(datapoint):\n",
" doc = description(datapoint)\n",
" vector = model_embedding.encode([doc], normalize_embeddings=True)[0]\n",
" return max(0, xgb_model.predict([vector])[0])"
]
},
{
"cell_type": "markdown",
"id": "a890f1f0-d827-472f-a7a9-6c2cbe3d8341",
"metadata": {
"id": "a890f1f0-d827-472f-a7a9-6c2cbe3d8341"
},
"source": [
"🔔 Reminder: In Part 2, XGBoost with Word2Vec (non-contextual embeddings) achieved:\n",
"- Avg. Error: ~$107\n",
"- RMSLE: 0.83\n",
"- Accuracy: 29.20%\n",
"\n",
"🧪 Now, lets see if contextual embeddings improve XGBoost."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "q-tIbVilTPxP",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 718
},
"id": "q-tIbVilTPxP",
"outputId": "7c9043ef-a2c4-4933-b334-18d99690ba0f"
},
"outputs": [],
"source": [
"# Step 4: Run the Tester on a subset of test data\n",
"tester = Tester(xgb_predictor, test)\n",
"tester.run()"
]
},
{
"cell_type": "markdown",
"id": "dcb09db0-7d69-40e1-a6e3-b92263e38f1e",
"metadata": {
"id": "dcb09db0-7d69-40e1-a6e3-b92263e38f1e"
},
"source": [
"Xgb Predictor Error=$110.68 RMSLE=0.93 Hits=30.4%"
]
},
{
"cell_type": "markdown",
"id": "1ccd5d3f-98cd-45a8-951f-d6446062addc",
"metadata": {
"id": "1ccd5d3f-98cd-45a8-951f-d6446062addc"
},
"source": [
"Results are nearly the same. In this setup, switching to contextual embeddings didnt yield performance gains for XGBoost."
]
},
{
"cell_type": "markdown",
"id": "4db1051d-9a7e-4cec-87fc-0d77fd858ced",
"metadata": {
"id": "4db1051d-9a7e-4cec-87fc-0d77fd858ced"
},
"source": [
"## 🚰 Retrieval-Augmented Pipeline GPT-4o Mini\n",
"\n",
"- Preprocess: clean the input text (description(item))\n",
"- Embed: generate embedding vector (get_embedding(item))\n",
"- Retrieve: find similar items from ChromaDB (find_similar_items)\n",
"- Build Prompt: create the LLM prompt using context and masked target (build_messages)\n",
"- Predict: get price estimate from LLM (estimate_price)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "YPLxSn7eHp9N",
"metadata": {
"id": "YPLxSn7eHp9N"
},
"outputs": [],
"source": [
"test[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eFxFKNroNiyD",
"metadata": {
"id": "eFxFKNroNiyD"
},
"outputs": [],
"source": [
"# Step 1: Preprocess test item text\n",
"# (uses the same `description(item)` function as during training)\n",
"description(test[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "lxIEtSWYHqCT",
"metadata": {
"id": "lxIEtSWYHqCT"
},
"outputs": [],
"source": [
"# Step 2: Embed a test item\n",
"def get_embedding(item):\n",
" return model_embedding.encode([description(item)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "y43prQsuHp_w",
"metadata": {
"id": "y43prQsuHp_w"
},
"outputs": [],
"source": [
"# Step 3: Query Chroma for similar items\n",
"def find_similars(item):\n",
" results = collection.query(query_embeddings=get_embedding(item).astype(float).tolist(), n_results=5)\n",
" documents = results['documents'][0][:]\n",
" prices = [m['price'] for m in results['metadatas'][0][:]]\n",
" return documents, prices"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "nxAOUFRkHp6v",
"metadata": {
"id": "nxAOUFRkHp6v"
},
"outputs": [],
"source": [
"documents, prices = find_similars(test[1])\n",
"documents, prices"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "djPoSk6sHo84",
"metadata": {
"id": "djPoSk6sHo84"
},
"outputs": [],
"source": [
"# Step 4: Format similar items as context\n",
"def format_context(similars, prices):\n",
" message = \"To provide some context, here are some other items that might be similar to the item you need to estimate.\\n\\n\"\n",
" for similar, price in zip(similars, prices):\n",
" message += f\"Potentially related product:\\n{similar}\\nPrice is ${price:.2f}\\n\\n\"\n",
" return message"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "F3yxhnqSHp4C",
"metadata": {
"id": "F3yxhnqSHp4C"
},
"outputs": [],
"source": [
"print(format_context(documents, prices))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pEJobsKNHqE8",
"metadata": {
"id": "pEJobsKNHqE8"
},
"outputs": [],
"source": [
"# Step 5: Mask the price in the test item\n",
"def mask_price_value(text):\n",
" return re.sub(r\"(\\n\\nPrice is \\$).*\", r\"\\1\", text)"
]
},
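{
"cell_type": "code",
"execution_count": null,
"id": "mask-price-demo",
"metadata": {},
"outputs": [],
"source": [
"# Quick illustrative check: the masked text should end with \"Price is $\" and no value\n",
"print(mask_price_value(test[1][\"text\"]))"
]
},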
{
"cell_type": "code",
"execution_count": null,
"id": "vLhBNVBNQAHS",
"metadata": {
"id": "vLhBNVBNQAHS"
},
"outputs": [],
"source": [
"# Step 6: Build LLM messages\n",
"def build_messages(datapoint, similars, prices):\n",
"\n",
" system_message = \"You estimate prices of items. Reply only with the price, no explanation.\"\n",
"\n",
" context = format_context(similars, prices)\n",
"\n",
" prompt = mask_price_value(datapoint[\"text\"])\n",
" prompt = prompt.replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
"\n",
" user_prompt = context + \"And now the question for you:\\n\\n\" + prompt\n",
"\n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt},\n",
" {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "I94fNHfBHp1a",
"metadata": {
"id": "I94fNHfBHp1a"
},
"outputs": [],
"source": [
"build_messages(test[1], documents, prices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5NfY_GAVHpy4",
"metadata": {
"id": "5NfY_GAVHpy4"
},
"outputs": [],
"source": [
"# Step 7: Run prediction\n",
"def get_price(s):\n",
" s = s.replace('$','').replace(',','')\n",
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
" return float(match.group()) if match else 0\n",
"\n",
"def gpt_4o_mini_rag(item):\n",
" documents, prices = find_similars(item)\n",
" response = openai.chat.completions.create(\n",
" model=\"gpt-4o-mini\",\n",
" messages=build_messages(item, documents, prices),\n",
" seed=42,\n",
" max_tokens=5\n",
" )\n",
" reply = response.choices[0].message.content\n",
" return get_price(reply)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "Pg-GJTT0HpwV",
"metadata": {
"id": "Pg-GJTT0HpwV"
},
"outputs": [],
"source": [
"print(test[1][\"price\"])\n",
"print(gpt_4o_mini_rag(test[1]))"
]
},
{
"cell_type": "markdown",
"id": "54103ab4-d6dd-4c0b-add5-5d9741e934b4",
"metadata": {
"id": "54103ab4-d6dd-4c0b-add5-5d9741e934b4"
},
"source": [
"🔔 Reminder: In Part 2, GPT-4o Mini (without RAG) achieved:\n",
"- Avg. Error: ~$99\n",
"- RMSLE: 0.75\n",
"- Accuracy: 44.8%\n",
"\n",
"🧪 Lets find out if RAG can boost GPT-4o Minis price prediction capabilities.\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "r0NGJupwHppF",
"metadata": {
"id": "r0NGJupwHppF"
},
"outputs": [],
"source": [
"Tester.test(gpt_4o_mini_rag, test)"
]
},
{
"cell_type": "markdown",
"id": "00545880-d9e1-4934-8008-b62c105d177b",
"metadata": {
"id": "00545880-d9e1-4934-8008-b62c105d177b"
},
"source": [
"Gpt 4O Mini Rag Error=$59.54 RMSLE=0.42 Hits=69.2%"
]
},
{
"cell_type": "markdown",
"id": "2b9f46ae-92b5-4189-89b0-df88a600bb89",
"metadata": {
"id": "2b9f46ae-92b5-4189-89b0-df88a600bb89"
},
"source": [
"🎉 **GPT-4o Mini + RAG shows clear gains:** \n",
"Average error dropped from **$99 → $59.54**, RMSLE from **0.75 → 0.42**, and accuracy rose from **48.8% → 69.2%**. \n",
"\n",
"Adding retrieval-based context led to a strong performance boost for GPT-4o Mini.\n",
"\n",
"Now the question is — can fine-tuning push it even further, surpass RAG, and challenge larger models?\n",
"\n",
"🔜 See you in the [next notebook](https://github.com/lisekarimi/lexo/blob/main/09_part4_ft_gpt4omini.ipynb)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}