Merge pull request #873 from TheTopDeveloper/community-contributions-branch
Add Week 6 finetuning solution with pickle data and enhanced modules- Joshua Oluoch (Gen AI Bootcamp)
This commit is contained in:
@@ -0,0 +1,828 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Week 6 - Product Pricer Challenge\n",
|
||||
"\n",
|
||||
"**A baseline established by GPT-4o and attempt to beat it with fine-tuning**\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Initialize and Load Configuration\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Imports\n",
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import math\n",
|
||||
"import json\n",
|
||||
"import random\n",
|
||||
"import pickle\n",
|
||||
"from collections import Counter\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"from openai import OpenAI\n",
|
||||
"\n",
|
||||
"# SimpleItem class definition for pickle compatibility\n",
|
||||
"class SimpleItem:\n",
|
||||
" \"\"\"\n",
|
||||
" Simple item class for pickle compatibility\n",
|
||||
" This matches the structure used in the CSV conversion script\n",
|
||||
" \"\"\"\n",
|
||||
" def __init__(self, title, description, price, category=\"Human_Generated\", token_count=0):\n",
|
||||
" self.title = title\n",
|
||||
" self.description = description\n",
|
||||
" self.price = price\n",
|
||||
" self.category = category\n",
|
||||
" self.token_count = token_count\n",
|
||||
"\n",
|
||||
" def test_prompt(self):\n",
|
||||
" \"\"\"\n",
|
||||
" Return a prompt suitable for testing, with the actual price removed\n",
|
||||
" This method is needed for compatibility with the testing framework\n",
|
||||
" \"\"\"\n",
|
||||
" return f\"How much does this cost to the nearest dollar?\\n\\n{self.title}\\n\\n{self.description}\\n\\nPrice is $\"\n",
|
||||
"\n",
|
||||
" def __repr__(self):\n",
|
||||
" return f\"SimpleItem(title='{self.title[:50]}...', price=${self.price})\"\n",
|
||||
"\n",
|
||||
"# Import our custom classes\n",
|
||||
"# Use original testing class to avoid matplotlib color issues\n",
|
||||
"try:\n",
|
||||
" from enhanced_items import Item\n",
|
||||
" # Use original Tester to avoid matplotlib color issues\n",
|
||||
" import sys\n",
|
||||
" import os\n",
|
||||
" sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))\n",
|
||||
" from testing import Tester\n",
|
||||
" print(\"✅ Using enhanced items and original testing from parent directory\")\n",
|
||||
"except ImportError:\n",
|
||||
" # Fallback to parent directory modules\n",
|
||||
" import sys\n",
|
||||
" import os\n",
|
||||
" sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))\n",
|
||||
" from items import Item\n",
|
||||
" from testing import Tester\n",
|
||||
" print(\"✅ Using modules from parent directory\")\n",
|
||||
"\n",
|
||||
"print(\"✅ All imports successful!\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Environment setup\n",
|
||||
"try:\n",
|
||||
" from google.colab import userdata\n",
|
||||
" os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n",
|
||||
" os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
|
||||
" print(\"✅ Using Colab secrets\")\n",
|
||||
"except:\n",
|
||||
" from dotenv import load_dotenv\n",
|
||||
" load_dotenv(override=True)\n",
|
||||
" os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
|
||||
" os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')\n",
|
||||
" print(\"✅ Using local .env file\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Log in to HuggingFace\n",
|
||||
"hf_token = os.environ['HF_TOKEN']\n",
|
||||
"login(hf_token)\n",
|
||||
"\n",
|
||||
"# Initialize OpenAI client\n",
|
||||
"openai = OpenAI()\n",
|
||||
"\n",
|
||||
"# Enable matplotlib inline for Colab\n",
|
||||
"%matplotlib inline\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load Data\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load pre-processed pickle files (our data loading hack)\n",
|
||||
"def load_pickle_data():\n",
|
||||
" \"\"\"\n",
|
||||
" Load pre-processed pickle files with fallback to sample data\n",
|
||||
" \"\"\"\n",
|
||||
" print(\"📦 Loading pre-processed pickle files...\")\n",
|
||||
" \n",
|
||||
" # Try to load pickle files\n",
|
||||
" pickle_files = ['train.pkl', 'test.pkl', 'validation.pkl', \n",
|
||||
" 'data/train.pkl', 'data/test.pkl', 'data/validation.pkl',\n",
|
||||
" '../train.pkl', '../test.pkl', '../validation.pkl']\n",
|
||||
" \n",
|
||||
" train = None\n",
|
||||
" test = None\n",
|
||||
" validation = None\n",
|
||||
" \n",
|
||||
" # Load training data\n",
|
||||
" for file_path in ['train.pkl', 'data/train.pkl', '../train.pkl']:\n",
|
||||
" if os.path.exists(file_path):\n",
|
||||
" try:\n",
|
||||
" with open(file_path, 'rb') as f:\n",
|
||||
" train = pickle.load(f)\n",
|
||||
" print(f\"✅ Loaded training data: {file_path} ({len(train)} items)\")\n",
|
||||
" break\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"❌ Error loading {file_path}: {e}\")\n",
|
||||
" # Try to load as dictionary and convert to SimpleItem\n",
|
||||
" try:\n",
|
||||
" with open(file_path, 'rb') as f:\n",
|
||||
" raw_data = pickle.load(f)\n",
|
||||
" if isinstance(raw_data, list) and len(raw_data) > 0:\n",
|
||||
" if isinstance(raw_data[0], dict):\n",
|
||||
" # Convert dictionary to SimpleItem\n",
|
||||
" train = []\n",
|
||||
" for item_dict in raw_data:\n",
|
||||
" item = SimpleItem(\n",
|
||||
" title=item_dict.get('title', ''),\n",
|
||||
" description=item_dict.get('description', ''),\n",
|
||||
" price=item_dict.get('price', 0.0),\n",
|
||||
" category=item_dict.get('category', 'Human_Generated'),\n",
|
||||
" token_count=item_dict.get('token_count', 0)\n",
|
||||
" )\n",
|
||||
" train.append(item)\n",
|
||||
" print(f\" Converted {len(train)} training items from dictionary format\")\n",
|
||||
" break\n",
|
||||
" except Exception as e2:\n",
|
||||
" print(f\" ❌ Failed to convert {file_path}: {e2}\")\n",
|
||||
" \n",
|
||||
" # Load test data\n",
|
||||
" for file_path in ['test.pkl', 'data/test.pkl', '../test.pkl']:\n",
|
||||
" if os.path.exists(file_path):\n",
|
||||
" try:\n",
|
||||
" with open(file_path, 'rb') as f:\n",
|
||||
" test = pickle.load(f)\n",
|
||||
" print(f\"✅ Loaded test data: {file_path} ({len(test)} items)\")\n",
|
||||
" break\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"❌ Error loading {file_path}: {e}\")\n",
|
||||
" # Try to load as dictionary and convert to SimpleItem\n",
|
||||
" try:\n",
|
||||
" with open(file_path, 'rb') as f:\n",
|
||||
" raw_data = pickle.load(f)\n",
|
||||
" if isinstance(raw_data, list) and len(raw_data) > 0:\n",
|
||||
" if isinstance(raw_data[0], dict):\n",
|
||||
" # Convert dictionary to SimpleItem\n",
|
||||
" test = []\n",
|
||||
" for item_dict in raw_data:\n",
|
||||
" item = SimpleItem(\n",
|
||||
" title=item_dict.get('title', ''),\n",
|
||||
" description=item_dict.get('description', ''),\n",
|
||||
" price=item_dict.get('price', 0.0),\n",
|
||||
" category=item_dict.get('category', 'Human_Generated'),\n",
|
||||
" token_count=item_dict.get('token_count', 0)\n",
|
||||
" )\n",
|
||||
" test.append(item)\n",
|
||||
" print(f\" Converted {len(test)} test items from dictionary format\")\n",
|
||||
" break\n",
|
||||
" except Exception as e2:\n",
|
||||
" print(f\" ❌ Failed to convert {file_path}: {e2}\")\n",
|
||||
" \n",
|
||||
" # Load validation data\n",
|
||||
" for file_path in ['validation.pkl', 'data/validation.pkl', '../validation.pkl']:\n",
|
||||
" if os.path.exists(file_path):\n",
|
||||
" try:\n",
|
||||
" with open(file_path, 'rb') as f:\n",
|
||||
" validation = pickle.load(f)\n",
|
||||
" print(f\"✅ Loaded validation data: {file_path} ({len(validation)} items)\")\n",
|
||||
" break\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"❌ Error loading {file_path}: {e}\")\n",
|
||||
" # Try to load as dictionary and convert to SimpleItem\n",
|
||||
" try:\n",
|
||||
" with open(file_path, 'rb') as f:\n",
|
||||
" raw_data = pickle.load(f)\n",
|
||||
" if isinstance(raw_data, list) and len(raw_data) > 0:\n",
|
||||
" if isinstance(raw_data[0], dict):\n",
|
||||
" # Convert dictionary to SimpleItem\n",
|
||||
" validation = []\n",
|
||||
" for item_dict in raw_data:\n",
|
||||
" item = SimpleItem(\n",
|
||||
" title=item_dict.get('title', ''),\n",
|
||||
" description=item_dict.get('description', ''),\n",
|
||||
" price=item_dict.get('price', 0.0),\n",
|
||||
" category=item_dict.get('category', 'Human_Generated'),\n",
|
||||
" token_count=item_dict.get('token_count', 0)\n",
|
||||
" )\n",
|
||||
" validation.append(item)\n",
|
||||
" print(f\" Converted {len(validation)} validation items from dictionary format\")\n",
|
||||
" break\n",
|
||||
" except Exception as e2:\n",
|
||||
" print(f\" ❌ Failed to convert {file_path}: {e2}\")\n",
|
||||
" \n",
|
||||
" # If no pickle files found, create sample data\n",
|
||||
" if not train or not test:\n",
|
||||
" print(\"🔄 No pickle files found, creating sample data...\")\n",
|
||||
" train, test, validation = create_sample_data()\n",
|
||||
" \n",
|
||||
" # Debug: Check what we actually loaded\n",
|
||||
" print(f\"\\n🔍 Debug - Data loaded:\")\n",
|
||||
" print(f\" train: {len(train) if train else 0} items\")\n",
|
||||
" print(f\" test: {len(test) if test else 0} items\") \n",
|
||||
" print(f\" validation: {len(validation) if validation else 0} items\")\n",
|
||||
" \n",
|
||||
" # Additional safety check\n",
|
||||
" if not test or len(test) == 0:\n",
|
||||
" print(\"⚠️ WARNING: Test dataset is empty! Creating emergency sample data...\")\n",
|
||||
" # Create emergency test data\n",
|
||||
" emergency_test = [\n",
|
||||
" SimpleItem(\"Test Product 1\", \"A test product for evaluation\", 25.99, \"Test\", 10),\n",
|
||||
" SimpleItem(\"Test Product 2\", \"Another test product\", 45.50, \"Test\", 12),\n",
|
||||
" SimpleItem(\"Test Product 3\", \"Third test product\", 15.75, \"Test\", 8)\n",
|
||||
" ]\n",
|
||||
" test = emergency_test\n",
|
||||
" print(f\" Emergency test data created: {len(test)} items\")\n",
|
||||
" \n",
|
||||
" return train, test, validation\n",
|
||||
"\n",
|
||||
"def create_sample_data():\n",
|
||||
" \"\"\"\n",
|
||||
" Create sample data for demonstration\n",
|
||||
" \"\"\"\n",
|
||||
" # Sample product data (expanded for better testing)\n",
|
||||
" sample_products = [\n",
|
||||
" {\"title\": \"Wireless Bluetooth Headphones\", \"price\": 89.99, \"category\": \"Electronics\"},\n",
|
||||
" {\"title\": \"Stainless Steel Water Bottle\", \"price\": 24.99, \"category\": \"Home & Kitchen\"},\n",
|
||||
" {\"title\": \"Organic Cotton T-Shirt\", \"price\": 19.99, \"category\": \"Clothing\"},\n",
|
||||
" {\"title\": \"Ceramic Coffee Mug\", \"price\": 12.99, \"category\": \"Home & Kitchen\"},\n",
|
||||
" {\"title\": \"LED Desk Lamp\", \"price\": 45.99, \"category\": \"Electronics\"},\n",
|
||||
" {\"title\": \"Yoga Mat\", \"price\": 29.99, \"category\": \"Sports & Outdoors\"},\n",
|
||||
" {\"title\": \"Leather Wallet\", \"price\": 39.99, \"category\": \"Accessories\"},\n",
|
||||
" {\"title\": \"Bluetooth Speaker\", \"price\": 79.99, \"category\": \"Electronics\"},\n",
|
||||
" {\"title\": \"Kitchen Knife Set\", \"price\": 129.99, \"category\": \"Home & Kitchen\"},\n",
|
||||
" {\"title\": \"Running Shoes\", \"price\": 89.99, \"category\": \"Sports & Outdoors\"},\n",
|
||||
" {\"title\": \"Smartphone Case\", \"price\": 15.99, \"category\": \"Electronics\"},\n",
|
||||
" {\"title\": \"Coffee Maker\", \"price\": 89.99, \"category\": \"Home & Kitchen\"},\n",
|
||||
" {\"title\": \"Backpack\", \"price\": 49.99, \"category\": \"Accessories\"},\n",
|
||||
" {\"title\": \"Tennis Racket\", \"price\": 79.99, \"category\": \"Sports & Outdoors\"},\n",
|
||||
" {\"title\": \"Laptop Stand\", \"price\": 34.99, \"category\": \"Electronics\"}\n",
|
||||
" ]\n",
|
||||
" \n",
|
||||
" # Create SimpleItem objects\n",
|
||||
" items = []\n",
|
||||
" for product in sample_products:\n",
|
||||
" item = SimpleItem(\n",
|
||||
" title=product['title'],\n",
|
||||
" description=f\"High-quality {product['title'].lower()}\",\n",
|
||||
" price=product['price'],\n",
|
||||
" category=product['category'],\n",
|
||||
" token_count=len(product['title'] + f\"High-quality {product['title'].lower()}\") // 4\n",
|
||||
" )\n",
|
||||
" items.append(item)\n",
|
||||
" \n",
|
||||
" # Split into train/test/validation (more balanced split)\n",
|
||||
" train = items[:10] # 10 items\n",
|
||||
" test = items[10:13] # 3 items \n",
|
||||
" validation = items[13:] # 2 items\n",
|
||||
" \n",
|
||||
" print(f\"✅ Created sample data: {len(train)} train, {len(test)} test, {len(validation)} validation\")\n",
|
||||
" return train, test, validation\n",
|
||||
"\n",
|
||||
"# Load the data\n",
|
||||
"train, test, validation = load_pickle_data()\n",
|
||||
"\n",
|
||||
"print(f\"\\n📊 Dataset Statistics:\")\n",
|
||||
"print(f\" Training: {len(train)} items\")\n",
|
||||
"print(f\" Test: {len(test)} items\")\n",
|
||||
"print(f\" Validation: {len(validation)} items\")\n",
|
||||
"\n",
|
||||
"if train:\n",
|
||||
" print(f\"\\n🔍 Sample Training Item:\")\n",
|
||||
" print(f\" Title: {train[0].title}\")\n",
|
||||
" print(f\" Price: ${train[0].price}\")\n",
|
||||
" print(f\" Category: {train[0].category}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Prepare Fine-tuning Data\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# OpenAI recommends fine-tuning with 50-100 examples\n",
|
||||
"# Use our actual train/validation split from the pickle files\n",
|
||||
"fine_tune_train = train # Use all training data (150 items)\n",
|
||||
"fine_tune_validation = validation # Use validation data (50 items)\n",
|
||||
"\n",
|
||||
"print(f\"📊 Fine-tuning data prepared:\")\n",
|
||||
"print(f\" Training: {len(fine_tune_train)} items\")\n",
|
||||
"print(f\" Validation: {len(fine_tune_validation)} items\")\n",
|
||||
"\n",
|
||||
"# Weight and Biases integration (optional)\n",
|
||||
"wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"gpt-pricer-ft\"}}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Helper Functions\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Utility function to extract price from a string\n",
|
||||
"def get_price(s):\n",
|
||||
" s = s.replace('$', '').replace(',', '')\n",
|
||||
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
|
||||
" return float(match.group()) if match else 0\n",
|
||||
"\n",
|
||||
"# Prompt generation functions\n",
|
||||
"def messages_for(item):\n",
|
||||
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
|
||||
" user_prompt = item.test_prompt().replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
|
||||
" return [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt},\n",
|
||||
" {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"def messages_with_price(item):\n",
|
||||
" system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
|
||||
" user_prompt = item.test_prompt().replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
|
||||
" return [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": user_prompt},\n",
|
||||
" {\"role\": \"assistant\", \"content\": f\"Price is ${item.price:.2f}\"}\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"print(\"✅ Helper functions defined!\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Baseline GPT-4o Model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def gpt_4o_frontier(item):\n",
|
||||
" response = openai.chat.completions.create(\n",
|
||||
" model=\"gpt-4o\",\n",
|
||||
" messages=messages_for(item),\n",
|
||||
" seed=42,\n",
|
||||
" max_tokens=5\n",
|
||||
" )\n",
|
||||
" reply = response.choices[0].message.content\n",
|
||||
" return get_price(reply)\n",
|
||||
"\n",
|
||||
"print(\"🧪 Testing baseline GPT-4o model...\")\n",
|
||||
"\n",
|
||||
"# Safety check: Make sure we have test data\n",
|
||||
"if not test or len(test) == 0:\n",
|
||||
" print(\"❌ No test data available! Cannot run baseline test.\")\n",
|
||||
" print(\"💡 Please check the data loading section above.\")\n",
|
||||
" print(\"🔍 Debug info:\")\n",
|
||||
" print(f\" test variable exists: {test is not None}\")\n",
|
||||
" print(f\" test length: {len(test) if test else 'N/A'}\")\n",
|
||||
" print(f\" test type: {type(test)}\")\n",
|
||||
"else:\n",
|
||||
" print(f\"📊 Testing on {len(test)} items...\")\n",
|
||||
" print(f\"🔍 Test data preview:\")\n",
|
||||
" for i, item in enumerate(test[:3]): # Show first 3 items\n",
|
||||
" print(f\" Item {i}: {item.title} - ${item.price}\")\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" # Create Tester with correct size parameter\n",
|
||||
" tester = Tester(gpt_4o_frontier, test, size=len(test))\n",
|
||||
" tester.run()\n",
|
||||
" except IndexError as e:\n",
|
||||
" print(f\"❌ IndexError in Tester.test: {e}\")\n",
|
||||
" print(f\"🔍 Test data length: {len(test)}\")\n",
|
||||
" print(\"💡 This suggests the Tester is trying to access more items than available.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fine-tuning Implementation\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if fine_tuned_model_name:\n",
|
||||
" def gpt_fine_tuned(item):\n",
|
||||
" response = openai.chat.completions.create(\n",
|
||||
" model=fine_tuned_model_name,\n",
|
||||
" messages=messages_for(item),\n",
|
||||
" seed=42,\n",
|
||||
" max_tokens=7\n",
|
||||
" )\n",
|
||||
" reply = response.choices[0].message.content\n",
|
||||
" return get_price(reply)\n",
|
||||
" \n",
|
||||
" print(\"🧪 Testing fine-tuned model...\")\n",
|
||||
" # Create Tester with correct size parameter to avoid IndexError\n",
|
||||
" tester = Tester(gpt_fine_tuned, test, size=len(test))\n",
|
||||
" tester.run()\n",
|
||||
"else:\n",
|
||||
" print(\"⏳ Fine-tuned model not ready yet. Please wait and re-run the previous cell.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert items to JSONL format for fine-tuning\n",
|
||||
"def make_jsonl(items):\n",
|
||||
" result = \"\"\n",
|
||||
" for item in items:\n",
|
||||
" messages = messages_with_price(item)\n",
|
||||
" messages_str = json.dumps(messages)\n",
|
||||
" result += '{\"messages\": ' + messages_str + '}\\n'\n",
|
||||
" return result.strip()\n",
|
||||
"\n",
|
||||
"def write_jsonl(items, filename):\n",
|
||||
" with open(filename, \"w\") as f:\n",
|
||||
" jsonl = make_jsonl(items)\n",
|
||||
" f.write(jsonl)\n",
|
||||
"\n",
|
||||
"# Create fine-tuning files\n",
|
||||
"write_jsonl(fine_tune_train, \"fine_tune_train.jsonl\")\n",
|
||||
"write_jsonl(fine_tune_validation, \"fine_tune_validation.jsonl\")\n",
|
||||
"\n",
|
||||
"print(\"✅ Fine-tuning files created:\")\n",
|
||||
"print(\" - fine_tune_train.jsonl\")\n",
|
||||
"print(\" - fine_tune_validation.jsonl\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Upload files to OpenAI\n",
|
||||
"with open(\"fine_tune_train.jsonl\", \"rb\") as f:\n",
|
||||
" train_file = openai.files.create(file=f, purpose=\"fine-tune\")\n",
|
||||
"\n",
|
||||
"with open(\"fine_tune_validation.jsonl\", \"rb\") as f:\n",
|
||||
" validation_file = openai.files.create(file=f, purpose=\"fine-tune\")\n",
|
||||
"\n",
|
||||
"print(f\"✅ Files uploaded to OpenAI:\")\n",
|
||||
"print(f\" Training file ID: {train_file.id}\")\n",
|
||||
"print(f\" Validation file ID: {validation_file.id}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create fine-tuning job\n",
|
||||
"fine_tuning_job = openai.fine_tuning.jobs.create(\n",
|
||||
" training_file=train_file.id,\n",
|
||||
" validation_file=validation_file.id,\n",
|
||||
" model=\"gpt-4o-mini\",\n",
|
||||
" seed=42,\n",
|
||||
" hyperparameters={\"n_epochs\": 1},\n",
|
||||
" integrations=[wandb_integration],\n",
|
||||
" suffix=\"pricer\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"🚀 Fine-tuning job created: {fine_tuning_job.id}\")\n",
|
||||
"print(\"⏳ This will take some time to complete...\")\n",
|
||||
"print(\"💡 You can monitor progress in the OpenAI dashboard or Weights & Biases\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# FIXED: Test enhanced model (if ready) - with correct Tester size\n",
|
||||
"try:\n",
|
||||
" enhanced_model_name = openai.fine_tuning.jobs.retrieve(fine_tuning_job_v2.id).fine_tuned_model\n",
|
||||
" \n",
|
||||
" def gpt_enhanced_fine_tuned(item):\n",
|
||||
" response = openai.chat.completions.create(\n",
|
||||
" model=enhanced_model_name,\n",
|
||||
" messages=messages_v2(item, with_price=False),\n",
|
||||
" seed=42,\n",
|
||||
" temperature=1.0,\n",
|
||||
" max_tokens=7\n",
|
||||
" )\n",
|
||||
" reply = response.choices[0].message.content\n",
|
||||
" return get_price(reply)\n",
|
||||
" \n",
|
||||
" print(\"🧪 Testing enhanced fine-tuned model...\")\n",
|
||||
" # Create Tester with correct size parameter to avoid IndexError\n",
|
||||
" tester = Tester(gpt_enhanced_fine_tuned, test, size=len(test))\n",
|
||||
" tester.run()\n",
|
||||
" \n",
|
||||
"except:\n",
|
||||
" print(\"⏳ Enhanced fine-tuned model not ready yet.\")\n",
|
||||
" print(\"💡 Please wait for completion and re-run this cell.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Check job status\n",
|
||||
"job_id = fine_tuning_job.id\n",
|
||||
"job_status = openai.fine_tuning.jobs.retrieve(job_id)\n",
|
||||
"\n",
|
||||
"print(f\"📊 Job Status: {job_status.status}\")\n",
|
||||
"print(f\"📈 Training File: {job_status.training_file}\")\n",
|
||||
"print(f\"📈 Validation File: {job_status.validation_file}\")\n",
|
||||
"print(f\"🤖 Model: {job_status.model}\")\n",
|
||||
"\n",
|
||||
"# Get recent events\n",
|
||||
"events = openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=10)\n",
|
||||
"print(f\"\\n📋 Recent Events:\")\n",
|
||||
"for event in events.data:\n",
|
||||
" print(f\" {event.created_at}: {event.message}\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Test Fine-tuned Model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Wait for fine-tuning to complete and get the model name\n",
|
||||
"# Note: In practice, you would wait for the job to complete\n",
|
||||
"try:\n",
|
||||
" fine_tuned_model_name = openai.fine_tuning.jobs.retrieve(job_id).fine_tuned_model\n",
|
||||
" print(f\"✅ Fine-tuned model ready: {fine_tuned_model_name}\")\n",
|
||||
"except:\n",
|
||||
" print(\"⏳ Fine-tuning still in progress...\")\n",
|
||||
" print(\"💡 Please wait for completion and re-run this cell\")\n",
|
||||
" fine_tuned_model_name = None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test the fine-tuned model (if ready)\n",
|
||||
"if fine_tuned_model_name:\n",
|
||||
" def gpt_fine_tuned(item):\n",
|
||||
" response = openai.chat.completions.create(\n",
|
||||
" model=fine_tuned_model_name,\n",
|
||||
" messages=messages_for(item),\n",
|
||||
" seed=42,\n",
|
||||
" max_tokens=7\n",
|
||||
" )\n",
|
||||
" reply = response.choices[0].message.content\n",
|
||||
" return get_price(reply)\n",
|
||||
" \n",
|
||||
" print(\"🧪 Testing fine-tuned model...\")\n",
|
||||
" Tester.test(gpt_fine_tuned, test)\n",
|
||||
"else:\n",
|
||||
" print(\"⏳ Fine-tuned model not ready yet. Please wait and re-run the previous cell.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced Fine-tuning with Enhanced Prompts\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Enhanced prompt function (based on gold standard)\n",
|
||||
"def messages_v2(item, with_price=True):\n",
|
||||
" system_message = (\n",
|
||||
" \"Role: You are a retail price estimator.\\n\"\n",
|
||||
" \"Market: United States; Currency: USD.\\n\"\n",
|
||||
" \"Scope: Predict the most likely new retail price. Ignore taxes, shipping, coupons, bundles, used/renewed.\\n\"\n",
|
||||
" \"Output: Only a number with two decimals (e.g., 129.99). No $ sign. No words.\\n\"\n",
|
||||
" \"Think silently; do not reveal reasoning.\"\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" user_prompt = item.test_prompt().replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
|
||||
" \n",
|
||||
" return [\n",
|
||||
" {\"role\": \"system\", \"content\": system_message},\n",
|
||||
" {\"role\": \"user\", \"content\": str({\n",
|
||||
" \"query\": \"price_estimate\",\n",
|
||||
" \"locale\": \"en_US\",\n",
|
||||
" \"currency\": \"USD\",\n",
|
||||
" \"category\": item.category,\n",
|
||||
" \"description\": user_prompt,\n",
|
||||
" \"brand\": json.loads(item.details).get(\"Brand\", \"Unknown\") if item.details else \"Unknown\"\n",
|
||||
" })},\n",
|
||||
" {\"role\": \"assistant\", \"content\": f\"Price is ${item.price:.2f}\" if with_price else \"Price is $\"}\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"print(\"✅ Enhanced prompt function created!\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create enhanced fine-tuning data\n",
|
||||
"def make_jsonl_v2(items):\n",
|
||||
" result = \"\"\n",
|
||||
" for item in items:\n",
|
||||
" messages = messages_v2(item)\n",
|
||||
" messages_str = json.dumps(messages)\n",
|
||||
" result += '{\"messages\": ' + messages_str + '}\\n'\n",
|
||||
" return result.strip()\n",
|
||||
"\n",
|
||||
"def write_jsonl_v2(items, filename):\n",
|
||||
" with open(filename, \"w\") as f:\n",
|
||||
" jsonl = make_jsonl_v2(items)\n",
|
||||
" f.write(jsonl)\n",
|
||||
"\n",
|
||||
"# Create enhanced fine-tuning files\n",
|
||||
"write_jsonl_v2(fine_tune_train, \"fine_tune_train_v2.jsonl\")\n",
|
||||
"write_jsonl_v2(fine_tune_validation, \"fine_tune_validation_v2.jsonl\")\n",
|
||||
"\n",
|
||||
"print(\"✅ Enhanced fine-tuning files created:\")\n",
|
||||
"print(\" - fine_tune_train_v2.jsonl\")\n",
|
||||
"print(\" - fine_tune_validation_v2.jsonl\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Upload enhanced files and create second fine-tuning job\n",
|
||||
"with open(\"fine_tune_train_v2.jsonl\", \"rb\") as f:\n",
|
||||
" train_file_v2 = openai.files.create(file=f, purpose=\"fine-tune\")\n",
|
||||
"\n",
|
||||
"with open(\"fine_tune_validation_v2.jsonl\", \"rb\") as f:\n",
|
||||
" validation_file_v2 = openai.files.create(file=f, purpose=\"fine-tune\")\n",
|
||||
"\n",
|
||||
"# Create second fine-tuning job with enhanced prompts\n",
|
||||
"fine_tuning_job_v2 = openai.fine_tuning.jobs.create(\n",
|
||||
" training_file=train_file_v2.id,\n",
|
||||
" validation_file=validation_file_v2.id,\n",
|
||||
" model=\"gpt-4o-mini\",\n",
|
||||
" seed=42,\n",
|
||||
" hyperparameters={\"n_epochs\": 1},\n",
|
||||
" integrations=[wandb_integration],\n",
|
||||
" suffix=\"pricer-v2\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"🚀 Enhanced fine-tuning job created: {fine_tuning_job_v2.id}\")\n",
|
||||
"print(\"⏳ This will take some time to complete...\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model Comparison and Results\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test enhanced model (if ready)\n",
|
||||
"try:\n",
|
||||
" enhanced_model_name = openai.fine_tuning.jobs.retrieve(fine_tuning_job_v2.id).fine_tuned_model\n",
|
||||
" \n",
|
||||
" def gpt_enhanced_fine_tuned(item):\n",
|
||||
" response = openai.chat.completions.create(\n",
|
||||
" model=enhanced_model_name,\n",
|
||||
" messages=messages_v2(item, with_price=False),\n",
|
||||
" seed=42,\n",
|
||||
" temperature=1.0,\n",
|
||||
" max_tokens=7\n",
|
||||
" )\n",
|
||||
" reply = response.choices[0].message.content\n",
|
||||
" return get_price(reply)\n",
|
||||
" \n",
|
||||
" print(\"🧪 Testing enhanced fine-tuned model...\")\n",
|
||||
" Tester.test(gpt_enhanced_fine_tuned, test)\n",
|
||||
" \n",
|
||||
"except:\n",
|
||||
" print(\"⏳ Enhanced fine-tuned model not ready yet.\")\n",
|
||||
" print(\"💡 Please wait for completion and re-run this cell.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Summary and Next Steps\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"🎉 Week 6 Product Pricer Challenge Complete!\")\n",
|
||||
"print(\"=\" * 50)\n",
|
||||
"\n",
|
||||
"print(\"\\n📊 What We Accomplished:\")\n",
|
||||
"print(\"✅ Loaded data using pickle files (our data loading hack)\")\n",
|
||||
"print(\"✅ Established baseline with GPT-4o\")\n",
|
||||
"print(\"✅ Implemented fine-tuning with OpenAI API\")\n",
|
||||
"print(\"✅ Created enhanced prompts for better performance\")\n",
|
||||
"print(\"✅ Set up comprehensive evaluation framework\")\n",
|
||||
"\n",
|
||||
"print(\"\\n🚀 Next Steps:\")\n",
|
||||
"print(\"1. Wait for fine-tuning jobs to complete\")\n",
|
||||
"print(\"2. Compare performance of all models\")\n",
|
||||
"print(\"3. Experiment with different hyperparameters\")\n",
|
||||
"print(\"4. Try different base models (GPT-4.1, etc.)\")\n",
|
||||
"print(\"5. Implement ensemble methods\")\n",
|
||||
"\n",
|
||||
"print(\"\\n💡 Key Learnings:\")\n",
|
||||
"print(\"• Fine-tuning can significantly improve model performance\")\n",
|
||||
"print(\"• Prompt engineering is crucial for good results\")\n",
|
||||
"print(\"• Data quality and quantity matter for fine-tuning\")\n",
|
||||
"print(\"• Evaluation metrics help track progress\")\n",
|
||||
"\n",
|
||||
"print(\"\\n🎯 This implementation follows the gold standard approach\")\n",
|
||||
"print(\" while incorporating our data loading improvements!\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
from typing import Optional
|
||||
from transformers import AutoTokenizer
|
||||
import re
|
||||
import os
|
||||
|
||||
# Try multiple model sources in order of preference
|
||||
BASE_MODEL_OPTIONS = [
|
||||
"/root/.llama/checkpoints/Llama3.1-8B", # Local llama-stack download
|
||||
"microsoft/DialoGPT-medium", # Accessible alternative
|
||||
"gpt2" # Fallback
|
||||
]
|
||||
|
||||
BASE_MODEL = None
|
||||
|
||||
MIN_TOKENS = 150 # Any less than this, and we don't have enough useful content
|
||||
MAX_TOKENS = 160 # Truncate after this many tokens. Then after adding in prompt text, we will get to around 180 tokens
|
||||
|
||||
MIN_CHARS = 300
|
||||
CEILING_CHARS = MAX_TOKENS * 7
|
||||
|
||||
class Item:
|
||||
"""
|
||||
An Item is a cleaned, curated datapoint of a Product with a Price
|
||||
Enhanced version with better error handling and alternative tokenizer
|
||||
"""
|
||||
|
||||
# Initialize tokenizer with fallback options
|
||||
tokenizer = None
|
||||
for model_path in BASE_MODEL_OPTIONS:
|
||||
try:
|
||||
if model_path.startswith("/") and not os.path.exists(model_path):
|
||||
continue # Skip local paths that don't exist
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
BASE_MODEL = model_path
|
||||
print(f"✅ Successfully loaded tokenizer from: {model_path}")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load {model_path}: {e}")
|
||||
continue
|
||||
|
||||
if tokenizer is None:
|
||||
print("❌ All tokenizer options failed. Using character-based fallback.")
|
||||
# Create a dummy tokenizer for fallback
|
||||
class DummyTokenizer:
|
||||
def encode(self, text, add_special_tokens=False):
|
||||
# Rough approximation: 1 token ≈ 4 characters
|
||||
return list(range(len(text) // 4))
|
||||
def decode(self, tokens):
|
||||
return "dummy text"
|
||||
tokenizer = DummyTokenizer()
|
||||
BASE_MODEL = "fallback"
|
||||
|
||||
PREFIX = "Price is $"
|
||||
QUESTION = "How much does this cost to the nearest dollar?"
|
||||
REMOVALS = [
|
||||
'"Batteries Included?": "No"',
|
||||
'"Batteries Included?": "Yes"',
|
||||
'"Batteries Required?": "No"',
|
||||
'"Batteries Required?": "Yes"',
|
||||
"By Manufacturer",
|
||||
"Item",
|
||||
"Date First",
|
||||
"Package",
|
||||
":",
|
||||
"Number of",
|
||||
"Best Sellers",
|
||||
"Number",
|
||||
"Product "
|
||||
]
|
||||
|
||||
title: str
|
||||
price: float
|
||||
category: str
|
||||
token_count: int = 0
|
||||
details: Optional[str]
|
||||
prompt: Optional[str] = None
|
||||
include = False
|
||||
|
||||
def __init__(self, data, price):
|
||||
self.title = data['title']
|
||||
self.price = price
|
||||
self.parse(data)
|
||||
|
||||
def scrub_details(self):
|
||||
"""
|
||||
Clean up the details string by removing common text that doesn't add value
|
||||
"""
|
||||
details = self.details
|
||||
for remove in self.REMOVALS:
|
||||
details = details.replace(remove, "")
|
||||
return details
|
||||
|
||||
def scrub(self, stuff):
|
||||
"""
|
||||
Clean up the provided text by removing unnecessary characters and whitespace
|
||||
Also remove words that are 7+ chars and contain numbers, as these are likely irrelevant product numbers
|
||||
"""
|
||||
stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip()
|
||||
stuff = stuff.replace(" ,", ",").replace(",,,",",").replace(",,",",")
|
||||
words = stuff.split(' ')
|
||||
select = [word for word in words if len(word)<7 or not any(char.isdigit() for char in word)]
|
||||
return " ".join(select)
|
||||
|
||||
def parse(self, data):
|
||||
"""
|
||||
Parse this datapoint and if it fits within the allowed Token range,
|
||||
then set include to True
|
||||
"""
|
||||
contents = '\n'.join(data['description'])
|
||||
if contents:
|
||||
contents += '\n'
|
||||
features = '\n'.join(data['features'])
|
||||
if features:
|
||||
contents += features + '\n'
|
||||
self.details = data['details']
|
||||
if self.details:
|
||||
contents += self.scrub_details() + '\n'
|
||||
if len(contents) > MIN_CHARS:
|
||||
contents = contents[:CEILING_CHARS]
|
||||
text = f"{self.scrub(self.title)}\n{self.scrub(contents)}"
|
||||
tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
if len(tokens) > MIN_TOKENS:
|
||||
tokens = tokens[:MAX_TOKENS]
|
||||
text = self.tokenizer.decode(tokens)
|
||||
self.make_prompt(text)
|
||||
self.include = True
|
||||
|
||||
def make_prompt(self, text):
|
||||
"""
|
||||
Set the prompt instance variable to be a prompt appropriate for training
|
||||
"""
|
||||
self.prompt = f"{self.QUESTION}\n\n{text}\n\n"
|
||||
self.prompt += f"{self.PREFIX}{str(round(self.price))}.00"
|
||||
self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))
|
||||
|
||||
def test_prompt(self):
|
||||
"""
|
||||
Return a prompt suitable for testing, with the actual price removed
|
||||
"""
|
||||
return self.prompt.split(self.PREFIX)[0] + self.PREFIX
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Return a String version of this Item
|
||||
"""
|
||||
return f"<{self.title} = ${self.price}>"
|
||||
|
||||
|
||||
|
||||
BIN
week6/community-contributions/finetuning-joshua/test.pkl
Normal file
BIN
week6/community-contributions/finetuning-joshua/test.pkl
Normal file
Binary file not shown.
75
week6/community-contributions/finetuning-joshua/testing.py
Normal file
75
week6/community-contributions/finetuning-joshua/testing.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import math
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
COLOR_MAP = {"red":RED, "orange": YELLOW, "green": GREEN}
|
||||
|
||||
class Tester:
|
||||
|
||||
def __init__(self, predictor, data, title=None, size=250):
|
||||
self.predictor = predictor
|
||||
self.data = data
|
||||
self.title = title or predictor.__name__.replace("_", " ").title()
|
||||
self.size = size
|
||||
self.guesses = []
|
||||
self.truths = []
|
||||
self.errors = []
|
||||
self.sles = []
|
||||
self.colors = []
|
||||
|
||||
def color_for(self, error, truth):
|
||||
if error<40 or error/truth < 0.2:
|
||||
return "green"
|
||||
elif error<80 or error/truth < 0.4:
|
||||
return "orange"
|
||||
else:
|
||||
return "red"
|
||||
|
||||
def run_datapoint(self, i):
|
||||
datapoint = self.data[i]
|
||||
guess = self.predictor(datapoint)
|
||||
truth = datapoint.price
|
||||
error = abs(guess - truth)
|
||||
log_error = math.log(truth+1) - math.log(guess+1)
|
||||
sle = log_error ** 2
|
||||
color = self.color_for(error, truth)
|
||||
title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+"..."
|
||||
self.guesses.append(guess)
|
||||
self.truths.append(truth)
|
||||
self.errors.append(error)
|
||||
self.sles.append(sle)
|
||||
self.colors.append(color)
|
||||
print(f"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}")
|
||||
|
||||
def chart(self, title):
|
||||
max_error = max(self.errors)
|
||||
plt.figure(figsize=(12, 8))
|
||||
max_val = max(max(self.truths), max(self.guesses))
|
||||
plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)
|
||||
plt.scatter(self.truths, self.guesses, s=3, c=self.colors)
|
||||
plt.xlabel('Ground Truth')
|
||||
plt.ylabel('Model Estimate')
|
||||
plt.xlim(0, max_val)
|
||||
plt.ylim(0, max_val)
|
||||
plt.title(title)
|
||||
plt.show()
|
||||
|
||||
def report(self):
|
||||
average_error = sum(self.errors) / self.size
|
||||
rmsle = math.sqrt(sum(self.sles) / self.size)
|
||||
hits = sum(1 for color in self.colors if color=="green")
|
||||
title = f"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%"
|
||||
self.chart(title)
|
||||
|
||||
def run(self):
|
||||
self.error = 0
|
||||
for i in range(self.size):
|
||||
self.run_datapoint(i)
|
||||
self.report()
|
||||
|
||||
@classmethod
|
||||
def test(cls, function, data):
|
||||
cls(function, data).run()
|
||||
BIN
week6/community-contributions/finetuning-joshua/train.pkl
Normal file
BIN
week6/community-contributions/finetuning-joshua/train.pkl
Normal file
Binary file not shown.
BIN
week6/community-contributions/finetuning-joshua/validation.pkl
Normal file
BIN
week6/community-contributions/finetuning-joshua/validation.pkl
Normal file
Binary file not shown.
Reference in New Issue
Block a user