From 54d7ffce6d96345cc1821c2b0486aaded6559c74 Mon Sep 17 00:00:00 2001 From: Tochi-Nwachukwu Date: Mon, 27 Oct 2025 16:56:34 +0100 Subject: [PATCH] [Bug Fix] - Fixed the overfitting bug - issue was with the messages_for_function() --- .../tochi/product_pricer_finetuning.ipynb | 99 ++++++++++++++++--- 1 file changed, 87 insertions(+), 12 deletions(-) diff --git a/week6/community-contributions/tochi/product_pricer_finetuning.ipynb b/week6/community-contributions/tochi/product_pricer_finetuning.ipynb index 18625e4..444f827 100644 --- a/week6/community-contributions/tochi/product_pricer_finetuning.ipynb +++ b/week6/community-contributions/tochi/product_pricer_finetuning.ipynb @@ -51,6 +51,7 @@ "# imports\n", "\n", "import os\n", + "import re\n", "from google.colab import userdata\n", "import json\n", "from dotenv import load_dotenv\n", @@ -63,7 +64,8 @@ "from openai import OpenAI\n", "from typing import Optional\n", "import re\n", - "from datasets import load_dataset\n" + "from datasets import load_dataset\n", + "import random\n" ] }, { @@ -146,8 +148,17 @@ "source": [ "# Access the training data, and dividing it into train and test data\n", "total_length = len(dataset[\"train\"])\n", - "train_data = dataset[\"train\"].select(range(total_length - 2000))\n", - "test_data = dataset[\"train\"].select(range(total_length - 2000, total_length))" + "\n", + "# Shuffle indices\n", + "all_indices = list(range(total_length))\n", + "random.seed(42)\n", + "random.shuffle(all_indices)\n", + "\n", + "train_indices = all_indices[:-2000]\n", + "test_indices = all_indices[-2000:]\n", + "\n", + "train_data = dataset[\"train\"].select(train_indices)\n", + "test_data = dataset[\"train\"].select(test_indices)" ] }, { @@ -199,18 +210,40 @@ }, "outputs": [], "source": [ + "# This function thoroughly formats the price data to make sure that there is no data leak into the training model\n", "def messages_for(item):\n", " system_message = \"You are a price estimation assistant. Respond only with the estimated price in the format: Price is $X.XX\"\n", - "\n", - " # Clean the user prompt more robustly\n", + " \n", " user_prompt = item[\"text\"]\n", + " price = item[\"price\"]\n", + "\n", " user_prompt = user_prompt.replace(\" to the nearest dollar\", \"\")\n", " user_prompt = user_prompt.replace(\"\\n\\nPrice is $\", \"\")\n", + " \n", + " price_formats = [\n", + " f\"{price:.2f}\", \n", + " f\"{price:.0f}\", \n", + " f\"{price}\", \n", + " f\"{int(price)}\", \n", + " f\"{price:.2f}\".replace('.', ''), \n", + " ]\n", + " \n", + " for price_str in price_formats:\n", + " if user_prompt.endswith(price_str):\n", + " user_prompt = user_prompt[:-len(price_str)].strip()\n", + " break\n", + " if f\"${price_str}\" in user_prompt:\n", + " user_prompt = user_prompt.replace(f\"${price_str}\", \"\").strip()\n", + " if user_prompt.rstrip().endswith(price_str):\n", + " user_prompt = user_prompt.rstrip()[:-len(price_str)].strip()\n", "\n", - " # Remove any trailing price information if it exists\n", - " if user_prompt.endswith(str(item[\"price\"])):\n", - " user_prompt = user_prompt.rsplit(str(item[\"price\"]), 1)[0].strip()\n", + " user_prompt = re.sub(r'(\\d+\\.?\\d{0,2})$', '', user_prompt).strip()\n", + " \n", + " user_prompt = re.sub(r'\\$\\s*[\\d,]+\\.?\\d{0,2}', '', user_prompt)\n", "\n", + " if re.search(rf'\\b{int(price)}\\b\\s*$', user_prompt):\n", + " user_prompt = re.sub(rf'\\b{int(price)}\\b\\s*$', '', user_prompt).strip()\n", + " \n", " return [\n", " {\"role\": \"system\", \"content\": system_message},\n", " {\"role\": \"user\", \"content\": user_prompt.strip()},\n", @@ -451,13 +484,55 @@ "outputs": [], "source": [ "# Try this out\n", + "\n", + "\n", "def messages_for(item):\n", - " system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n", - " user_prompt = item[\"text\"].replace(\" to the nearest dollar\",\"\").replace(\"\\n\\nPrice is $\",\"\")\n", + " system_message = \"You are a price estimation assistant. Respond only with the estimated price in the format: Price is $X.XX\"\n", + " \n", + " user_prompt = item[\"text\"]\n", + " price = item[\"price\"]\n", + " \n", + " # Remove common price-related phrases\n", + " user_prompt = user_prompt.replace(\" to the nearest dollar\", \"\")\n", + " user_prompt = user_prompt.replace(\"\\n\\nPrice is $\", \"\")\n", + " \n", + " # Create multiple price format variations to remove\n", + " price_formats = [\n", + " f\"{price:.2f}\", # 329.00\n", + " f\"{price:.0f}\", # 329\n", + " f\"{price}\", # 329.0\n", + " f\"{int(price)}\", # 329\n", + " f\"{price:.2f}\".replace('.', ''), # 32900\n", + " ]\n", + " \n", + " # Try to remove each format from the end of the string\n", + " for price_str in price_formats:\n", + " # Remove from end (most common)\n", + " if user_prompt.endswith(price_str):\n", + " user_prompt = user_prompt[:-len(price_str)].strip()\n", + " break\n", + " # Remove with $ prefix\n", + " if f\"${price_str}\" in user_prompt:\n", + " user_prompt = user_prompt.replace(f\"${price_str}\", \"\").strip()\n", + " # Remove standalone number at the end\n", + " if user_prompt.rstrip().endswith(price_str):\n", + " user_prompt = user_prompt.rstrip()[:-len(price_str)].strip()\n", + " \n", + " # Additional regex cleanup - remove any trailing number that might be a price\n", + " # This catches cases where the price is stuck to the end of a word\n", + " user_prompt = re.sub(r'(\\d+\\.?\\d{0,2})$', '', user_prompt).strip()\n", + " \n", + " # Remove $ signs followed by numbers anywhere in the text\n", + " user_prompt = re.sub(r'\\$\\s*[\\d,]+\\.?\\d{0,2}', '', user_prompt)\n", + " \n", + " # Final safety check - if the price (as int) appears at the very end, remove it\n", + " if re.search(rf'\\b{int(price)}\\b\\s*$', user_prompt):\n", + " user_prompt = re.sub(rf'\\b{int(price)}\\b\\s*$', '', user_prompt).strip()\n", + " \n", " return [\n", " {\"role\": \"system\", \"content\": system_message},\n", - " {\"role\": \"user\", \"content\": user_prompt},\n", - " {\"role\": \"assistant\", \"content\": \"Price is $\"}\n", + " {\"role\": \"user\", \"content\": user_prompt.strip()},\n", + " {\"role\": \"assistant\", \"content\": f\"Price is ${item['price']:.2f}\"}\n", " ]" ] },