[Bug Fix] - Fixed the overfitting bug - issue was with the messages_for_function()

This commit is contained in:
Tochi-Nwachukwu
2025-10-27 16:56:34 +01:00
parent 3081100dc9
commit 54d7ffce6d

View File

@@ -51,6 +51,7 @@
"# imports\n",
"\n",
"import os\n",
"import re\n",
"from google.colab import userdata\n",
"import json\n",
"from dotenv import load_dotenv\n",
@@ -63,7 +64,8 @@
"from openai import OpenAI\n",
"from typing import Optional\n",
"import re\n",
"from datasets import load_dataset\n"
"from datasets import load_dataset\n",
"import random\n"
]
},
{
@@ -146,8 +148,17 @@
"source": [
"# Access the training data, and dividing it into train and test data\n",
"total_length = len(dataset[\"train\"])\n",
"train_data = dataset[\"train\"].select(range(total_length - 2000))\n",
"test_data = dataset[\"train\"].select(range(total_length - 2000, total_length))"
"\n",
"# Shuffle indices\n",
"all_indices = list(range(total_length))\n",
"random.seed(42)\n",
"random.shuffle(all_indices)\n",
"\n",
"train_indices = all_indices[:-2000]\n",
"test_indices = all_indices[-2000:]\n",
"\n",
"train_data = dataset[\"train\"].select(train_indices)\n",
"test_data = dataset[\"train\"].select(test_indices)"
]
},
{
@@ -199,18 +210,40 @@
},
"outputs": [],
"source": [
"# This function thoroughly formats the price data to make sure that there is no data leak into the training model\n",
"def messages_for(item):\n",
" system_message = \"You are a price estimation assistant. Respond only with the estimated price in the format: Price is $X.XX\"\n",
"\n",
" # Clean the user prompt more robustly\n",
" \n",
" user_prompt = item[\"text\"]\n",
" price = item[\"price\"]\n",
"\n",
" user_prompt = user_prompt.replace(\" to the nearest dollar\", \"\")\n",
" user_prompt = user_prompt.replace(\"\\n\\nPrice is $\", \"\")\n",
" \n",
" price_formats = [\n",
" f\"{price:.2f}\", \n",
" f\"{price:.0f}\", \n",
" f\"{price}\", \n",
" f\"{int(price)}\", \n",
" f\"{price:.2f}\".replace('.', ''), \n",
" ]\n",
" \n",
" for price_str in price_formats:\n",
" if user_prompt.endswith(price_str):\n",
" user_prompt = user_prompt[:-len(price_str)].strip()\n",
" break\n",
" if f\"${price_str}\" in user_prompt:\n",
" user_prompt = user_prompt.replace(f\"${price_str}\", \"\").strip()\n",
" if user_prompt.rstrip().endswith(price_str):\n",
" user_prompt = user_prompt.rstrip()[:-len(price_str)].strip()\n",
"\n",
" # Remove any trailing price information if it exists\n",
" if user_prompt.endswith(str(item[\"price\"])):\n",
" user_prompt = user_prompt.rsplit(str(item[\"price\"]), 1)[0].strip()\n",
" user_prompt = re.sub(r'(\\d+\\.?\\d{0,2})$', '', user_prompt).strip()\n",
" \n",
" user_prompt = re.sub(r'\\$\\s*[\\d,]+\\.?\\d{0,2}', '', user_prompt)\n",
"\n",
" if re.search(rf'\\b{int(price)}\\b\\s*$', user_prompt):\n",
" user_prompt = re.sub(rf'\\b{int(price)}\\b\\s*$', '', user_prompt).strip()\n",
" \n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt.strip()},\n",
@@ -451,13 +484,55 @@
"outputs": [],
"source": [
"# Try this out\n",
"\n",
"\n",
"def messages_for(item):\n",
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
" user_prompt = item[\"text\"].replace(\" to the nearest dollar\",\"\").replace(\"\\n\\nPrice is $\",\"\")\n",
" system_message = \"You are a price estimation assistant. Respond only with the estimated price in the format: Price is $X.XX\"\n",
" \n",
" user_prompt = item[\"text\"]\n",
" price = item[\"price\"]\n",
" \n",
" # Remove common price-related phrases\n",
" user_prompt = user_prompt.replace(\" to the nearest dollar\", \"\")\n",
" user_prompt = user_prompt.replace(\"\\n\\nPrice is $\", \"\")\n",
" \n",
" # Create multiple price format variations to remove\n",
" price_formats = [\n",
" f\"{price:.2f}\", # 329.00\n",
" f\"{price:.0f}\", # 329\n",
" f\"{price}\", # 329.0\n",
" f\"{int(price)}\", # 329\n",
" f\"{price:.2f}\".replace('.', ''), # 32900\n",
" ]\n",
" \n",
" # Try to remove each format from the end of the string\n",
" for price_str in price_formats:\n",
" # Remove from end (most common)\n",
" if user_prompt.endswith(price_str):\n",
" user_prompt = user_prompt[:-len(price_str)].strip()\n",
" break\n",
" # Remove with $ prefix\n",
" if f\"${price_str}\" in user_prompt:\n",
" user_prompt = user_prompt.replace(f\"${price_str}\", \"\").strip()\n",
" # Remove standalone number at the end\n",
" if user_prompt.rstrip().endswith(price_str):\n",
" user_prompt = user_prompt.rstrip()[:-len(price_str)].strip()\n",
" \n",
" # Additional regex cleanup - remove any trailing number that might be a price\n",
" # This catches cases where the price is stuck to the end of a word\n",
" user_prompt = re.sub(r'(\\d+\\.?\\d{0,2})$', '', user_prompt).strip()\n",
" \n",
" # Remove $ signs followed by numbers anywhere in the text\n",
" user_prompt = re.sub(r'\\$\\s*[\\d,]+\\.?\\d{0,2}', '', user_prompt)\n",
" \n",
" # Final safety check - if the price (as int) appears at the very end, remove it\n",
" if re.search(rf'\\b{int(price)}\\b\\s*$', user_prompt):\n",
" user_prompt = re.sub(rf'\\b{int(price)}\\b\\s*$', '', user_prompt).strip()\n",
" \n",
" return [\n",
" {\"role\": \"system\", \"content\": system_message},\n",
" {\"role\": \"user\", \"content\": user_prompt},\n",
" {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
" {\"role\": \"user\", \"content\": user_prompt.strip()},\n",
" {\"role\": \"assistant\", \"content\": f\"Price is ${item['price']:.2f}\"}\n",
" ]"
]
},