Bootcamp: Solisoma(week7 assesment)
This commit is contained in:
@@ -0,0 +1,750 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "275415f0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# pip installs\n",
|
||||||
|
"\n",
|
||||||
|
"!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
|
||||||
|
"!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "535bd9de",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# imports\n",
|
||||||
|
"import re\n",
|
||||||
|
"import math\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"from google.colab import userdata\n",
|
||||||
|
"from huggingface_hub import login\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
|
||||||
|
"from datasets import load_dataset\n",
|
||||||
|
"from peft import PeftModel\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"\n",
|
||||||
|
"# Auto-detect device\n",
|
||||||
|
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||||
|
"print(f\"Using device: {device}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "fc58234a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Constants\n",
|
||||||
|
"\n",
|
||||||
|
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
|
||||||
|
"PROJECT_NAME = \"pricer\"\n",
|
||||||
|
"HF_USER = \"ed-donner\"\n",
|
||||||
|
"RUN_NAME = \"2024-09-13_13.04.39\"\n",
|
||||||
|
"PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
|
||||||
|
"REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
|
||||||
|
"FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
|
||||||
|
"# Or just use the one I've uploaded\n",
|
||||||
|
"# DATASET_NAME = \"ed-donner/pricer-data\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Hyperparameters for QLoRA\n",
|
||||||
|
"\n",
|
||||||
|
"QUANT_4_BIT = True\n",
|
||||||
|
"\n",
|
||||||
|
"%matplotlib inline\n",
|
||||||
|
"\n",
|
||||||
|
"# Used for writing to output in color\n",
|
||||||
|
"\n",
|
||||||
|
"GREEN = \"\\033[92m\"\n",
|
||||||
|
"YELLOW = \"\\033[93m\"\n",
|
||||||
|
"RED = \"\\033[91m\"\n",
|
||||||
|
"RESET = \"\\033[0m\"\n",
|
||||||
|
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0145ad8a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Log in to HuggingFace\n",
|
||||||
|
"\n",
|
||||||
|
"hf_token = userdata.get('HF_TOKEN')\n",
|
||||||
|
"login(hf_token, add_to_git_credential=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6919506e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset = load_dataset(DATASET_NAME)\n",
|
||||||
|
"train_full = dataset['train']\n",
|
||||||
|
"test_full = dataset['test']\n",
|
||||||
|
"\n",
|
||||||
|
"# TRAIN_SIZE = len(train_full)\n",
|
||||||
|
"# TEST_SIZE = len(test_full)\n",
|
||||||
|
"\n",
|
||||||
|
"TRAIN_SIZE = 8000 # Very small for testing\n",
|
||||||
|
"TEST_SIZE = 2000 # Very small for testing\n",
|
||||||
|
"\n",
|
||||||
|
"train = train_full.select(range(min(TRAIN_SIZE, len(train_full))))\n",
|
||||||
|
"test = test_full.select(range(min(TEST_SIZE, len(test_full))))\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Using small test dataset:\")\n",
|
||||||
|
"print(f\" Train samples: {len(train)} (full dataset has {len(train_full)})\")\n",
|
||||||
|
"print(f\" Test samples: {len(test)} (full dataset has {len(test_full)})\")\n",
|
||||||
|
"print(f\"\\nTo use full dataset, set TRAIN_SIZE and TEST_SIZE to None or large numbers\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ea79cde1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"if QUANT_4_BIT:\n",
|
||||||
|
" quant_config = BitsAndBytesConfig(\n",
|
||||||
|
" load_in_4bit=True,\n",
|
||||||
|
" bnb_4bit_use_double_quant=True,\n",
|
||||||
|
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
|
||||||
|
" bnb_4bit_quant_type=\"nf4\"\n",
|
||||||
|
" )\n",
|
||||||
|
"else:\n",
|
||||||
|
" quant_config = BitsAndBytesConfig(\n",
|
||||||
|
" load_in_8bit=True,\n",
|
||||||
|
" )"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ef108f8d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load the Tokenizer and the Model\n",
|
||||||
|
"\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
|
||||||
|
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||||
|
"tokenizer.padding_side = \"right\"\n",
|
||||||
|
"\n",
|
||||||
|
"base_model = AutoModelForCausalLM.from_pretrained(\n",
|
||||||
|
" BASE_MODEL,\n",
|
||||||
|
" quantization_config=quant_config,\n",
|
||||||
|
" device_map=\"auto\",\n",
|
||||||
|
")\n",
|
||||||
|
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
|
||||||
|
"\n",
|
||||||
|
"# Load the fine-tuned model with PEFT\n",
|
||||||
|
"if REVISION:\n",
|
||||||
|
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
|
||||||
|
"else:\n",
|
||||||
|
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
|
||||||
|
"\n",
|
||||||
|
"fine_tuned_model.eval()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7f3c4176",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def extract_price(s):\n",
|
||||||
|
" \"\"\"Extract price from model output - expects format 'Price is $X.XX'\"\"\"\n",
|
||||||
|
" if not s or not isinstance(s, str):\n",
|
||||||
|
" return None\n",
|
||||||
|
" \n",
|
||||||
|
" if \"Price is $\" in s:\n",
|
||||||
|
" contents = s.split(\"Price is $\")[1]\n",
|
||||||
|
" contents = contents.replace(',', '') # Remove commas from numbers\n",
|
||||||
|
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
|
||||||
|
" \n",
|
||||||
|
" if match:\n",
|
||||||
|
" try:\n",
|
||||||
|
" return float(match.group())\n",
|
||||||
|
" except (ValueError, AttributeError):\n",
|
||||||
|
" return None\n",
|
||||||
|
" \n",
|
||||||
|
" return None"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "436fa29a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Original prediction function - greedy decoding (supports batch processing)\n",
|
||||||
|
"\n",
|
||||||
|
"def model_predict(prompt, device=device, batch_mode=False):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Simple greedy prediction with improved generation parameters.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" set_seed(42)\n",
|
||||||
|
" \n",
|
||||||
|
" # Handle batch mode\n",
|
||||||
|
" if batch_mode and isinstance(prompt, list):\n",
|
||||||
|
" return model_predict_batch(prompt, device)\n",
|
||||||
|
" \n",
|
||||||
|
" try:\n",
|
||||||
|
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
|
||||||
|
" attention_mask = torch.ones(inputs.shape, device=device)\n",
|
||||||
|
" \n",
|
||||||
|
" outputs = fine_tuned_model.generate(\n",
|
||||||
|
" inputs, \n",
|
||||||
|
" attention_mask=attention_mask, \n",
|
||||||
|
" max_new_tokens=15,\n",
|
||||||
|
" num_return_sequences=1,\n",
|
||||||
|
" temperature=0.1, # Lower temperature for more deterministic\n",
|
||||||
|
" do_sample=False, # Greedy decoding\n",
|
||||||
|
" pad_token_id=tokenizer.pad_token_id\n",
|
||||||
|
" )\n",
|
||||||
|
" response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
|
||||||
|
" price = extract_price(response)\n",
|
||||||
|
" return price if price is not None else 0.0\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error in model_predict: {e}\")\n",
|
||||||
|
" return 0.0\n",
|
||||||
|
"\n",
|
||||||
|
"def model_predict_batch(prompts, device=device):\n",
|
||||||
|
" \"\"\"Batch prediction for multiple prompts at once - much faster!\"\"\"\n",
|
||||||
|
" set_seed(42)\n",
|
||||||
|
" try:\n",
|
||||||
|
" # Tokenize all prompts at once with padding\n",
|
||||||
|
" inputs = tokenizer(\n",
|
||||||
|
" prompts, \n",
|
||||||
|
" return_tensors=\"pt\", \n",
|
||||||
|
" padding=True, \n",
|
||||||
|
" truncation=True,\n",
|
||||||
|
" max_length=512\n",
|
||||||
|
" ).to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" outputs = fine_tuned_model.generate(\n",
|
||||||
|
" **inputs,\n",
|
||||||
|
" max_new_tokens=15,\n",
|
||||||
|
" num_return_sequences=1,\n",
|
||||||
|
" temperature=0.1,\n",
|
||||||
|
" do_sample=False,\n",
|
||||||
|
" pad_token_id=tokenizer.pad_token_id\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" # Decode all responses\n",
|
||||||
|
" responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Extract prices for all responses\n",
|
||||||
|
" prices = []\n",
|
||||||
|
" for response in responses:\n",
|
||||||
|
" price = extract_price(response)\n",
|
||||||
|
" prices.append(price if price is not None else 0.0)\n",
|
||||||
|
" \n",
|
||||||
|
" return prices\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error in model_predict_batch: {e}\")\n",
|
||||||
|
" return [0.0] * len(prompts)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a666dab6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Improved prediction function with dual strategy: full generation + fallback to weighted top-K\n",
|
||||||
|
"# Supports batch processing for faster inference\n",
|
||||||
|
"\n",
|
||||||
|
"top_K = 6\n",
|
||||||
|
"\n",
|
||||||
|
"def improved_model_predict(prompt, device=device, max_tokens=15, batch_mode=False):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Improved prediction using dual strategy:\n",
|
||||||
|
" 1. Full generation and extract price (handles multi-token prices)\n",
|
||||||
|
" 2. Fallback to weighted average of top-K token probabilities\n",
|
||||||
|
" \n",
|
||||||
|
" Args:\n",
|
||||||
|
" prompt: Single string or list of strings for batch processing\n",
|
||||||
|
" device: Device to use\n",
|
||||||
|
" max_tokens: Maximum tokens to generate\n",
|
||||||
|
" batch_mode: If True and prompt is a list, processes all at once (much faster!)\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Handle batch mode\n",
|
||||||
|
" if batch_mode and isinstance(prompt, list):\n",
|
||||||
|
" return improved_model_predict_batch(prompt, device, max_tokens)\n",
|
||||||
|
" \n",
|
||||||
|
" set_seed(42)\n",
|
||||||
|
" try:\n",
|
||||||
|
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
|
||||||
|
" attention_mask = torch.ones(inputs.shape, device=device)\n",
|
||||||
|
"\n",
|
||||||
|
" # Strategy 1: Full generation and extract price (handles multi-token prices)\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" outputs = fine_tuned_model.generate(\n",
|
||||||
|
" inputs,\n",
|
||||||
|
" attention_mask=attention_mask,\n",
|
||||||
|
" max_new_tokens=max_tokens,\n",
|
||||||
|
" num_return_sequences=1,\n",
|
||||||
|
" temperature=0.1, # Lower temperature for deterministic output\n",
|
||||||
|
" do_sample=False, # Greedy decoding\n",
|
||||||
|
" pad_token_id=tokenizer.pad_token_id\n",
|
||||||
|
" )\n",
|
||||||
|
" full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
|
||||||
|
" extracted_price = extract_price(full_response)\n",
|
||||||
|
" \n",
|
||||||
|
" if extracted_price is not None and extracted_price > 0:\n",
|
||||||
|
" return float(extracted_price)\n",
|
||||||
|
" \n",
|
||||||
|
" # Strategy 2: Fallback to single-token weighted average\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
|
||||||
|
" next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
|
||||||
|
"\n",
|
||||||
|
" next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
|
||||||
|
" top_probs, top_token_ids = next_token_probs.topk(top_K)\n",
|
||||||
|
" \n",
|
||||||
|
" prices, weights = [], []\n",
|
||||||
|
" for i in range(top_K):\n",
|
||||||
|
" predicted_token = tokenizer.decode([top_token_ids[0][i].item()], skip_special_tokens=True)\n",
|
||||||
|
" probability = top_probs[0][i].item()\n",
|
||||||
|
" try:\n",
|
||||||
|
" result = float(predicted_token)\n",
|
||||||
|
" except ValueError:\n",
|
||||||
|
" continue\n",
|
||||||
|
" if result > 0:\n",
|
||||||
|
" prices.append(result)\n",
|
||||||
|
" weights.append(probability)\n",
|
||||||
|
" \n",
|
||||||
|
" if not prices:\n",
|
||||||
|
" return 0.0\n",
|
||||||
|
" \n",
|
||||||
|
" # Weighted average\n",
|
||||||
|
" total = sum(weights)\n",
|
||||||
|
" if total == 0:\n",
|
||||||
|
" return 0.0\n",
|
||||||
|
" \n",
|
||||||
|
" weighted_prices = [price * weight / total for price, weight in zip(prices, weights)]\n",
|
||||||
|
" return sum(weighted_prices)\n",
|
||||||
|
" \n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error in improved_model_predict: {e}\")\n",
|
||||||
|
" return 0.0\n",
|
||||||
|
"\n",
|
||||||
|
"def improved_model_predict_batch(prompts, device=device, max_tokens=15):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Batch version of improved_model_predict - processes multiple prompts in parallel.\n",
|
||||||
|
" This is MUCH faster than calling improved_model_predict in a loop!\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" set_seed(42)\n",
|
||||||
|
" try:\n",
|
||||||
|
" # Tokenize all prompts at once with padding\n",
|
||||||
|
" inputs = tokenizer(\n",
|
||||||
|
" prompts,\n",
|
||||||
|
" return_tensors=\"pt\",\n",
|
||||||
|
" padding=True,\n",
|
||||||
|
" truncation=True,\n",
|
||||||
|
" max_length=512\n",
|
||||||
|
" ).to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" prices = []\n",
|
||||||
|
" \n",
|
||||||
|
" # Strategy 1: Full generation for all prompts at once\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" outputs = fine_tuned_model.generate(\n",
|
||||||
|
" **inputs,\n",
|
||||||
|
" max_new_tokens=max_tokens,\n",
|
||||||
|
" num_return_sequences=1,\n",
|
||||||
|
" temperature=0.1,\n",
|
||||||
|
" do_sample=False,\n",
|
||||||
|
" pad_token_id=tokenizer.pad_token_id\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" # Decode all responses\n",
|
||||||
|
" responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Extract prices - try Strategy 1 first\n",
|
||||||
|
" need_fallback = []\n",
|
||||||
|
" fallback_indices = []\n",
|
||||||
|
" \n",
|
||||||
|
" for idx, response in enumerate(responses):\n",
|
||||||
|
" extracted_price = extract_price(response)\n",
|
||||||
|
" if extracted_price is not None and extracted_price > 0:\n",
|
||||||
|
" prices.append(float(extracted_price))\n",
|
||||||
|
" else:\n",
|
||||||
|
" prices.append(None) # Mark for fallback\n",
|
||||||
|
" need_fallback.append(prompts[idx])\n",
|
||||||
|
" fallback_indices.append(idx)\n",
|
||||||
|
" \n",
|
||||||
|
" # Strategy 2: Fallback for items that failed Strategy 1\n",
|
||||||
|
" if need_fallback:\n",
|
||||||
|
" # Re-encode only the ones that need fallback\n",
|
||||||
|
" fallback_inputs = tokenizer(\n",
|
||||||
|
" need_fallback,\n",
|
||||||
|
" return_tensors=\"pt\",\n",
|
||||||
|
" padding=True,\n",
|
||||||
|
" truncation=True,\n",
|
||||||
|
" max_length=512\n",
|
||||||
|
" ).to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" fallback_outputs = fine_tuned_model(**fallback_inputs)\n",
|
||||||
|
" next_token_logits = fallback_outputs.logits[:, -1, :].to('cpu')\n",
|
||||||
|
" \n",
|
||||||
|
" next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
|
||||||
|
" top_probs, top_token_ids = next_token_probs.topk(top_K)\n",
|
||||||
|
" \n",
|
||||||
|
" # Process each fallback item\n",
|
||||||
|
" for batch_idx, original_idx in enumerate(fallback_indices):\n",
|
||||||
|
" batch_prices, batch_weights = [], []\n",
|
||||||
|
" \n",
|
||||||
|
" for k in range(top_K):\n",
|
||||||
|
" predicted_token = tokenizer.decode(\n",
|
||||||
|
" [top_token_ids[batch_idx][k].item()], \n",
|
||||||
|
" skip_special_tokens=True\n",
|
||||||
|
" )\n",
|
||||||
|
" probability = top_probs[batch_idx][k].item()\n",
|
||||||
|
" \n",
|
||||||
|
" try:\n",
|
||||||
|
" result = float(predicted_token)\n",
|
||||||
|
" except ValueError:\n",
|
||||||
|
" continue\n",
|
||||||
|
" \n",
|
||||||
|
" if result > 0:\n",
|
||||||
|
" batch_prices.append(result)\n",
|
||||||
|
" batch_weights.append(probability)\n",
|
||||||
|
" \n",
|
||||||
|
" if batch_prices:\n",
|
||||||
|
" total = sum(batch_weights)\n",
|
||||||
|
" if total > 0:\n",
|
||||||
|
" weighted_avg = sum(p * w / total for p, w in zip(batch_prices, batch_weights))\n",
|
||||||
|
" prices[original_idx] = weighted_avg\n",
|
||||||
|
" else:\n",
|
||||||
|
" prices[original_idx] = 0.0\n",
|
||||||
|
" else:\n",
|
||||||
|
" prices[original_idx] = 0.0\n",
|
||||||
|
" \n",
|
||||||
|
" # Replace None with 0.0\n",
|
||||||
|
" return [p if p is not None else 0.0 for p in prices]\n",
|
||||||
|
" \n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error in improved_model_predict_batch: {e}\")\n",
|
||||||
|
" return [0.0] * len(prompts)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9664c4c7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class Tester:\n",
|
||||||
|
"\n",
|
||||||
|
" def __init__(self, predictor, data, title=None, size=250):\n",
|
||||||
|
" self.predictor = predictor\n",
|
||||||
|
" self.data = data\n",
|
||||||
|
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
|
||||||
|
" self.size = min(size, len(data)) if data else size\n",
|
||||||
|
" self.guesses = []\n",
|
||||||
|
" self.truths = []\n",
|
||||||
|
" self.errors = []\n",
|
||||||
|
" self.sles = []\n",
|
||||||
|
" self.colors = []\n",
|
||||||
|
" self.relative_errors = []\n",
|
||||||
|
"\n",
|
||||||
|
" def color_for(self, error, truth):\n",
|
||||||
|
" \"\"\"Determine color with safe division handling\"\"\"\n",
|
||||||
|
" if truth == 0:\n",
|
||||||
|
" # If truth is 0, use absolute error only\n",
|
||||||
|
" if error < 40:\n",
|
||||||
|
" return \"green\"\n",
|
||||||
|
" elif error < 80:\n",
|
||||||
|
" return \"orange\"\n",
|
||||||
|
" else:\n",
|
||||||
|
" return \"red\"\n",
|
||||||
|
" \n",
|
||||||
|
" relative_error = error / truth\n",
|
||||||
|
" if error < 40 or relative_error < 0.2:\n",
|
||||||
|
" return \"green\"\n",
|
||||||
|
" elif error < 80 or relative_error < 0.4:\n",
|
||||||
|
" return \"orange\"\n",
|
||||||
|
" else:\n",
|
||||||
|
" return \"red\"\n",
|
||||||
|
"\n",
|
||||||
|
" def run_datapoint(self, i):\n",
|
||||||
|
" \"\"\"Test a single datapoint\"\"\"\n",
|
||||||
|
" datapoint = self.data[i]\n",
|
||||||
|
" guess = self.predictor(datapoint[\"text\"])\n",
|
||||||
|
" truth = float(datapoint[\"price\"])\n",
|
||||||
|
" \n",
|
||||||
|
" # Handle invalid guesses (None, tuple, negative)\n",
|
||||||
|
" if guess is None:\n",
|
||||||
|
" guess = 0.0\n",
|
||||||
|
" if isinstance(guess, tuple):\n",
|
||||||
|
" guess = guess[0] if len(guess) > 0 else 0.0\n",
|
||||||
|
" if guess < 0:\n",
|
||||||
|
" guess = 0.0\n",
|
||||||
|
" \n",
|
||||||
|
" error = abs(guess - truth)\n",
|
||||||
|
" relative_error = error / truth if truth > 0 else error\n",
|
||||||
|
" log_error = math.log(truth + 1) - math.log(guess + 1)\n",
|
||||||
|
" sle = log_error ** 2\n",
|
||||||
|
" color = self.color_for(error, truth)\n",
|
||||||
|
" \n",
|
||||||
|
" # Extract item title safely\n",
|
||||||
|
" try:\n",
|
||||||
|
" title_parts = datapoint[\"text\"].split(\"\\n\\n\")\n",
|
||||||
|
" title = (title_parts[1][:40] + \"...\") if len(title_parts) > 1 else \"Unknown\"\n",
|
||||||
|
" except:\n",
|
||||||
|
" title = \"Unknown\"\n",
|
||||||
|
" \n",
|
||||||
|
" self.guesses.append(guess)\n",
|
||||||
|
" self.truths.append(truth)\n",
|
||||||
|
" self.errors.append(error)\n",
|
||||||
|
" self.relative_errors.append(relative_error)\n",
|
||||||
|
" self.sles.append(sle)\n",
|
||||||
|
" self.colors.append(color)\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} ({relative_error*100:.1f}%) SLE: {sle:.4f} Item: {title}{RESET}\")\n",
|
||||||
|
"\n",
|
||||||
|
" def chart(self, title):\n",
|
||||||
|
" \"\"\"Create comprehensive visualization\"\"\"\n",
|
||||||
|
" fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
|
||||||
|
" \n",
|
||||||
|
" # 1. Scatter plot: Predictions vs Truth\n",
|
||||||
|
" ax1 = axes[0, 0]\n",
|
||||||
|
" max_val = max(max(self.truths), max(self.guesses)) * 1.1\n",
|
||||||
|
" ax1.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6, label='Perfect prediction')\n",
|
||||||
|
" ax1.scatter(self.truths, self.guesses, s=20, c=self.colors, alpha=0.6)\n",
|
||||||
|
" ax1.set_xlabel('Ground Truth Price ($)', fontsize=12)\n",
|
||||||
|
" ax1.set_ylabel('Predicted Price ($)', fontsize=12)\n",
|
||||||
|
" ax1.set_xlim(0, max_val)\n",
|
||||||
|
" ax1.set_ylim(0, max_val)\n",
|
||||||
|
" ax1.set_title('Predictions vs Ground Truth', fontsize=14)\n",
|
||||||
|
" ax1.legend()\n",
|
||||||
|
" ax1.grid(True, alpha=0.3)\n",
|
||||||
|
" \n",
|
||||||
|
" # 2. Error distribution histogram\n",
|
||||||
|
" ax2 = axes[0, 1]\n",
|
||||||
|
" ax2.hist(self.errors, bins=30, color='skyblue', alpha=0.7, edgecolor='black')\n",
|
||||||
|
" ax2.axvline(np.mean(self.errors), color='red', linestyle='--', label='Mean Error')\n",
|
||||||
|
" ax2.set_xlabel('Absolute Error ($)', fontsize=12)\n",
|
||||||
|
" ax2.set_ylabel('Frequency', fontsize=12)\n",
|
||||||
|
" ax2.set_title('Error Distribution', fontsize=14)\n",
|
||||||
|
" ax2.legend()\n",
|
||||||
|
" ax2.grid(True, alpha=0.3)\n",
|
||||||
|
" \n",
|
||||||
|
" # 3. Relative error distribution\n",
|
||||||
|
" ax3 = axes[1, 0]\n",
|
||||||
|
" relative_errors_pct = [e * 100 for e in self.relative_errors]\n",
|
||||||
|
" ax3.hist(relative_errors_pct, bins=30, color='lightcoral', alpha=0.7, edgecolor='black')\n",
|
||||||
|
" ax3.set_xlabel('Relative Error (%)', fontsize=12)\n",
|
||||||
|
" ax3.set_ylabel('Frequency', fontsize=12)\n",
|
||||||
|
" ax3.set_title('Relative Error Distribution', fontsize=14)\n",
|
||||||
|
" ax3.grid(True, alpha=0.3)\n",
|
||||||
|
" \n",
|
||||||
|
" # 4. Accuracy by price range\n",
|
||||||
|
" ax4 = axes[1, 1]\n",
|
||||||
|
" price_ranges = [(0, 50), (50, 100), (100, 200), (200, 500), (500, float('inf'))]\n",
|
||||||
|
" range_errors = []\n",
|
||||||
|
" range_labels = []\n",
|
||||||
|
" for low, high in price_ranges:\n",
|
||||||
|
" range_indices = [i for i, t in enumerate(self.truths) if low <= t < high]\n",
|
||||||
|
" if range_indices:\n",
|
||||||
|
" avg_error = np.mean([self.errors[i] for i in range_indices])\n",
|
||||||
|
" range_errors.append(avg_error)\n",
|
||||||
|
" range_labels.append(f\"${low}-${high if high != float('inf') else '+'}\")\n",
|
||||||
|
" \n",
|
||||||
|
" ax4.bar(range_labels, range_errors, color='steelblue', alpha=0.7)\n",
|
||||||
|
" ax4.set_xlabel('Price Range ($)', fontsize=12)\n",
|
||||||
|
" ax4.set_ylabel('Average Error ($)', fontsize=12)\n",
|
||||||
|
" ax4.set_title('Average Error by Price Range', fontsize=14)\n",
|
||||||
|
" ax4.tick_params(axis='x', rotation=45)\n",
|
||||||
|
" ax4.grid(True, alpha=0.3, axis='y')\n",
|
||||||
|
" \n",
|
||||||
|
" plt.tight_layout()\n",
|
||||||
|
" plt.suptitle(title, fontsize=16, y=1.02)\n",
|
||||||
|
" plt.show()\n",
|
||||||
|
"\n",
|
||||||
|
" def calculate_metrics(self):\n",
|
||||||
|
" \"\"\"Calculate comprehensive evaluation metrics\"\"\"\n",
|
||||||
|
" guesses_arr = np.array(self.guesses)\n",
|
||||||
|
" truths_arr = np.array(self.truths)\n",
|
||||||
|
" errors_arr = np.array(self.errors)\n",
|
||||||
|
" \n",
|
||||||
|
" metrics = {\n",
|
||||||
|
" 'mae': np.mean(errors_arr), # Mean Absolute Error\n",
|
||||||
|
" 'median_error': np.median(errors_arr),\n",
|
||||||
|
" 'rmse': np.sqrt(np.mean(errors_arr ** 2)), # Root Mean Squared Error\n",
|
||||||
|
" 'rmsle': math.sqrt(sum(self.sles) / self.size),\n",
|
||||||
|
" 'mape': np.mean([abs(e) if t > 0 else 0 for e, t in zip(errors_arr/truths_arr, truths_arr)]) * 100,\n",
|
||||||
|
" }\n",
|
||||||
|
" \n",
|
||||||
|
" # R² (coefficient of determination)\n",
|
||||||
|
" ss_res = np.sum((truths_arr - guesses_arr) ** 2)\n",
|
||||||
|
" ss_tot = np.sum((truths_arr - np.mean(truths_arr)) ** 2)\n",
|
||||||
|
" metrics['r2'] = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0\n",
|
||||||
|
" \n",
|
||||||
|
" # Hit rates\n",
|
||||||
|
" hits_green = sum(1 for c in self.colors if c == \"green\")\n",
|
||||||
|
" hits_orange_green = sum(1 for c in self.colors if c in [\"green\", \"orange\"])\n",
|
||||||
|
" metrics['hit_rate_green'] = hits_green / self.size * 100\n",
|
||||||
|
" metrics['hit_rate_acceptable'] = hits_orange_green / self.size * 100\n",
|
||||||
|
" \n",
|
||||||
|
" return metrics\n",
|
||||||
|
"\n",
|
||||||
|
" def report(self):\n",
|
||||||
|
" \"\"\"Generate comprehensive report\"\"\"\n",
|
||||||
|
" metrics = self.calculate_metrics()\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"\\n{'='*70}\")\n",
|
||||||
|
" print(f\"FINAL REPORT: {self.title}\")\n",
|
||||||
|
" print(f\"{'='*70}\")\n",
|
||||||
|
" print(f\"Total Predictions: {self.size}\")\n",
|
||||||
|
" print(f\"\\n--- Error Metrics ---\")\n",
|
||||||
|
" print(f\"Mean Absolute Error (MAE): ${metrics['mae']:,.2f}\")\n",
|
||||||
|
" print(f\"Median Error: ${metrics['median_error']:,.2f}\")\n",
|
||||||
|
" print(f\"Root Mean Squared Error (RMSE): ${metrics['rmse']:,.2f}\")\n",
|
||||||
|
" print(f\"Root Mean Squared Log Error: {metrics['rmsle']:.4f}\")\n",
|
||||||
|
" print(f\"Mean Absolute Percentage Error: {metrics['mape']:.2f}%\")\n",
|
||||||
|
" print(f\"\\n--- Accuracy Metrics ---\")\n",
|
||||||
|
" print(f\"R² Score (Coefficient of Determination): {metrics['r2']:.4f}\")\n",
|
||||||
|
" print(f\"Hit Rate (Green - Excellent): {metrics['hit_rate_green']:.1f}%\")\n",
|
||||||
|
" print(f\"Hit Rate (Green+Orange - Good): {metrics['hit_rate_acceptable']:.1f}%\")\n",
|
||||||
|
" print(f\"{'='*70}\\n\")\n",
|
||||||
|
" \n",
|
||||||
|
" # Create visualization\n",
|
||||||
|
" chart_title = f\"{self.title} | MAE=${metrics['mae']:,.2f} | RMSLE={metrics['rmsle']:.3f} | R²={metrics['r2']:.3f}\"\n",
|
||||||
|
" self.chart(chart_title)\n",
|
||||||
|
" \n",
|
||||||
|
" return metrics\n",
|
||||||
|
"\n",
|
||||||
|
" def run(self, show_progress=True, batch_size=8):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Run test on all datapoints with progress bar.\n",
|
||||||
|
" \n",
|
||||||
|
" Args:\n",
|
||||||
|
" show_progress: Show progress bar\n",
|
||||||
|
" batch_size: Process this many items at once (0 = no batching, process one by one)\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" print(f\"Testing {self.size} predictions with {self.title}...\\n\")\n",
|
||||||
|
" \n",
|
||||||
|
" if batch_size > 1:\n",
|
||||||
|
" # Batch processing mode - much faster!\n",
|
||||||
|
" print(f\"Using batch processing with batch_size={batch_size}\")\n",
|
||||||
|
" texts = [self.data[i][\"text\"] for i in range(self.size)]\n",
|
||||||
|
" \n",
|
||||||
|
" iterator = tqdm(range(0, self.size, batch_size), desc=\"Batch Predicting\") if show_progress else range(0, self.size, batch_size)\n",
|
||||||
|
" \n",
|
||||||
|
" for batch_start in iterator:\n",
|
||||||
|
" batch_end = min(batch_start + batch_size, self.size)\n",
|
||||||
|
" batch_texts = texts[batch_start:batch_end]\n",
|
||||||
|
" \n",
|
||||||
|
" # Get batch predictions\n",
|
||||||
|
" batch_guesses = self.predictor(batch_texts, batch_mode=True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Process each result in the batch\n",
|
||||||
|
" for i, guess in enumerate(batch_guesses):\n",
|
||||||
|
" actual_idx = batch_start + i\n",
|
||||||
|
" self.run_datapoint_internal(actual_idx, guess)\n",
|
||||||
|
" else:\n",
|
||||||
|
" # Sequential processing (original method)\n",
|
||||||
|
" iterator = tqdm(range(self.size), desc=\"Predicting\") if show_progress else range(self.size)\n",
|
||||||
|
" for i in iterator:\n",
|
||||||
|
" self.run_datapoint(i)\n",
|
||||||
|
" \n",
|
||||||
|
" return self.report()\n",
|
||||||
|
" \n",
|
||||||
|
" def run_datapoint_internal(self, i, guess):\n",
|
||||||
|
" \"\"\"Internal method to process a single datapoint when we already have the guess\"\"\"\n",
|
||||||
|
" datapoint = self.data[i]\n",
|
||||||
|
" truth = float(datapoint[\"price\"])\n",
|
||||||
|
" \n",
|
||||||
|
" # Handle invalid guesses (None, tuple, negative)\n",
|
||||||
|
" if guess is None:\n",
|
||||||
|
" guess = 0.0\n",
|
||||||
|
" if isinstance(guess, tuple):\n",
|
||||||
|
" guess = guess[0] if len(guess) > 0 else 0.0\n",
|
||||||
|
" if guess < 0:\n",
|
||||||
|
" guess = 0.0\n",
|
||||||
|
" \n",
|
||||||
|
" error = abs(guess - truth)\n",
|
||||||
|
" relative_error = error / truth if truth > 0 else error\n",
|
||||||
|
" log_error = math.log(truth + 1) - math.log(guess + 1)\n",
|
||||||
|
" sle = log_error ** 2\n",
|
||||||
|
" color = self.color_for(error, truth)\n",
|
||||||
|
" \n",
|
||||||
|
" # Extract item title safely\n",
|
||||||
|
" try:\n",
|
||||||
|
" title_parts = datapoint[\"text\"].split(\"\\n\\n\")\n",
|
||||||
|
" title = (title_parts[1][:40] + \"...\") if len(title_parts) > 1 else \"Unknown\"\n",
|
||||||
|
" except:\n",
|
||||||
|
" title = \"Unknown\"\n",
|
||||||
|
" \n",
|
||||||
|
" self.guesses.append(guess)\n",
|
||||||
|
" self.truths.append(truth)\n",
|
||||||
|
" self.errors.append(error)\n",
|
||||||
|
" self.relative_errors.append(relative_error)\n",
|
||||||
|
" self.sles.append(sle)\n",
|
||||||
|
" self.colors.append(color)\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} ({relative_error*100:.1f}%) SLE: {sle:.4f} Item: {title}{RESET}\")\n",
|
||||||
|
"\n",
|
||||||
|
" @classmethod\n",
|
||||||
|
" def test(cls, function, data, title=None, size=250, batch_size=8):\n",
|
||||||
|
" \"\"\"Quick test method with optional batch processing\"\"\"\n",
|
||||||
|
" return cls(function, data, title, size).run(batch_size=batch_size)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2e60a696",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"test_size = len(test)\n",
|
||||||
|
"batch_size = 1 # increase to 2 for faster processing\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Running test with {test_size} samples, batch_size={batch_size}\")\n",
|
||||||
|
"\n",
|
||||||
|
"results = Tester.test(\n",
|
||||||
|
" improved_model_predict, \n",
|
||||||
|
" test, \n",
|
||||||
|
" title=\"GPT-4o-mini Fine-tuned (Improved - Test Mode)\",\n",
|
||||||
|
" size=test_size,\n",
|
||||||
|
" batch_size=batch_size\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user