This commit is contained in:
Dmitry Kisselev
2025-10-26 00:34:38 -07:00
parent 61d8281cf7
commit 907748e560
2 changed files with 144 additions and 7949 deletions

View File

@@ -0,0 +1,820 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/dkisselev-zz/llm_engineering/blob/wk7/Week_7_Excersise_fine_tuned_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GHsssBgWM_l0"
},
"source": [
"# Predict Product Prices\n",
"\n",
"Model evaluation and inference tuning\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HnwMdAP3IHad"
},
"source": [
"## Libraries and configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MDyR63OTNUJ6"
},
"outputs": [],
"source": [
"!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
"!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-yikV8pRBer9"
},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import math\n",
"import numpy as np\n",
"from google.colab import userdata\n",
"from huggingface_hub import login\n",
"import wandb\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n",
"from datasets import load_dataset\n",
"from peft import PeftModel\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uuTX-xonNeOK"
},
"outputs": [],
"source": [
"# Models\n",
"\n",
"# WB or HF location of artifacts\n",
"ARTIFCAT_LOCATTION=\"HF\"\n",
"\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
"\n",
"PROJECT_NAME = \"pricer\"\n",
"\n",
"# RUN_NAME = \"2025-10-23_23.41.24\" # - Fine tuned 16 batches / 8 bit run\n",
"# RUN_NAME = \"2025-10-25_05.02.00\" # - Fine tuned 4 batches / 4 bit / LoRA 64/128 / Gradient 8\n",
"RUN_NAME = \"2024-09-13_13.04.39\" # Ed's model run\n",
"\n",
"# Hugging Face\n",
"HF_USER = \"dkisselev\"\n",
"\n",
"if ARTIFCAT_LOCATTION==\"HF\":\n",
" PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n",
" # REVISION = None\n",
" REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
"\n",
"\n",
" # FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n",
"\n",
" # Ed's model\n",
" FINETUNED_MODEL = f\"ed-donner/{PROJECT_RUN_NAME}\"\n",
"else:\n",
" # Weights and Biases\n",
" WANDB_ENTITY = \"dkisselev\"\n",
" os.environ[\"WANDB_API_KEY\"]=userdata.get('WANDB_API_KEY')\n",
"\n",
" MODEL_ARTIFACT_NAME = f\"model-{RUN_NAME}\"\n",
" REVISION_TAG=\"v22\"\n",
" WANDB_ARTIFACT_PATH = f\"{WANDB_ENTITY}/{PROJECT_NAME}/{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\"\n",
"\n",
"# Data set\n",
"\n",
"# DATASET_NAME = f\"{HF_USER}/pricer-data2\"\n",
"DATASET_NAME = \"ed-donner/pricer-data\"\n",
"\n",
"# Hyperparameters for QLoRA\n",
"QUANT_4_BIT = True\n",
"K_SEARCH_LIMIT = 900\n",
"\n",
"# Used for writing to output in color\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"BLUE = \"\\033[94m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": BLUE, \"green\": GREEN}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8JArT3QAQAjx"
},
"source": [
"### Load Data\n",
"\n",
"Data is loaded from Huggin Face\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WyFPZeMcM88v"
},
"outputs": [],
"source": [
"# Log in to HuggingFace\n",
"hf_token = userdata.get('HF_TOKEN')\n",
"login(hf_token)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvXVoJH8LS6u"
},
"outputs": [],
"source": [
"dataset = load_dataset(DATASET_NAME)\n",
"train = dataset['train']\n",
"test = dataset['test']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qJWQ0a3wZ0Bw"
},
"source": [
"## Load Tokenizer and Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lAUAAcEC6ido"
},
"outputs": [],
"source": [
"# 4 or 8 but quantization\n",
"if QUANT_4_BIT:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
" bnb_4bit_quant_type=\"nf4\"\n",
" )\n",
"else:\n",
" quant_config = BitsAndBytesConfig(\n",
" load_in_8bit=True\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OQy4pCk-dutf"
},
"outputs": [],
"source": [
"# Load model from w&b\n",
"if ARTIFCAT_LOCATTION==\"WB\":\n",
" artifact = wandb.Api().artifact(WANDB_ARTIFACT_PATH, type='model')\n",
" artifact_dir = artifact.download() # Downloads to a local cache dir"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R_O04fKxMMT-"
},
"outputs": [],
"source": [
"# Load the Tokenizer and the Model\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\"\n",
"\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" quantization_config=quant_config,\n",
" device_map=\"auto\",\n",
")\n",
"base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"\n",
"if ARTIFCAT_LOCATTION==\"HF\":\n",
" # Load the fine-tuned model with PEFT\n",
" if REVISION:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n",
" else:\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n",
"else:\n",
" # Model at W&B\n",
" fine_tuned_model = PeftModel.from_pretrained(base_model, artifact_dir)\n",
"\n",
"print(f\"Memory footprint: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UObo1-RqaNnT"
},
"source": [
"## Hyperparameter helpers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n4u27kbwlekE"
},
"outputs": [],
"source": [
"def calculate_weighted_price(prices, probabilities):\n",
" \"\"\"\n",
" Calculates a normalized weighted average price.\n",
"\n",
" Args:\n",
" prices (list or np.array): A list of prices.\n",
" probabilities (list or np.array): A list of corresponding probabilities (or weights).\n",
" Returns:\n",
" float: The normalized weighted average price.\n",
" \"\"\"\n",
" # Convert lists to numpy arrays\n",
" prices_array = np.array(prices)\n",
" probs_array = np.array(probabilities)\n",
"\n",
" # Total of the probabilities to use for normalization\n",
" total_prob = np.sum(probs_array)\n",
"\n",
" # Catch zero\n",
" if total_prob == 0:\n",
" if len(prices_array) > 0:\n",
" return np.mean(prices_array)\n",
" else:\n",
" return 0.0\n",
"\n",
" # Weighted avrage\n",
" weighted_price = np.average(prices_array, weights=probs_array)\n",
"\n",
" return weighted_price"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ROjIbGuH0FWS"
},
"outputs": [],
"source": [
"def get_top_k_predictions(prompt, device=\"cuda\"):\n",
" \"\"\"\n",
" Gets the top K price/probability pairs from the model.\n",
"\n",
" Returns:\n",
" (list, list): A tuple containing (prices, probabilities)\n",
" \"\"\"\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
" attention_mask = torch.ones(inputs.shape, device=device)\n",
"\n",
" with torch.no_grad():\n",
" outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
" next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
"\n",
" next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
" top_prob, top_token_id = next_token_probs.topk(K_SEARCH_LIMIT)\n",
"\n",
" prices = []\n",
" probabilities = []\n",
"\n",
" for i in range(K_SEARCH_LIMIT):\n",
" predicted_token = tokenizer.decode(top_token_id[0][i])\n",
" probability_tensor = top_prob[0][i]\n",
"\n",
" try:\n",
" price = float(predicted_token)\n",
" except ValueError as e:\n",
" price = 0.0\n",
"\n",
" if price > 0:\n",
" prices.append(price)\n",
" probabilities.append(probability_tensor.item())\n",
"\n",
" if not prices:\n",
" return [], []\n",
"\n",
" return prices, probabilities"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tnmTAiEG32xK"
},
"outputs": [],
"source": [
"def make_prompt(text):\n",
" if ARTIFCAT_LOCATTION==\"HF\":\n",
" return text\n",
" p_array = text.split(\"\\n\")\n",
" p_question = p_array[0].replace(\"How much does this cost to the nearest dollar?\",\"What is the price of this item?\")\n",
" p_title = p_array[2]\n",
" p_descr = re.sub(r'\\d', '', p_array[3])\n",
" p_price = p_array[5]\n",
" prompt = p_title + \"\\n\" + p_descr + \"\\n\" + \"Question: \"+ p_question + \"\\n\\n\" + p_price\n",
" # prompt = p_array[0] + \"\\n\\n\\n\" + p_title + \"\\n\\n\" + p_descr + \"\\n\\n\" + p_price\n",
" # return text\n",
" return prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VNAEw5Eg4ABk"
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"class Tester:\n",
"\n",
" def __init__(self, predictor, data, title=None, size=250):\n",
" self.predictor = predictor\n",
" self.data = data\n",
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
" self.size = size\n",
" self.guesses = []\n",
" self.truths = []\n",
" self.errors = []\n",
" self.sles = []\n",
" self.colors = []\n",
"\n",
" def color_for(self, error, truth):\n",
" if error<40 or error/truth < 0.2:\n",
" return \"green\"\n",
" elif error<80 or error/truth < 0.4:\n",
" return \"orange\"\n",
" else:\n",
" return \"red\"\n",
"\n",
" def run_datapoint(self, i):\n",
" datapoint = self.data[i]\n",
"\n",
" base_prompt = datapoint[\"text\"]\n",
" prompt = make_prompt(base_prompt)\n",
"\n",
" guess = self.predictor(prompt)\n",
"\n",
" # guess = self.predictor(datapoint[\"text\"])\n",
" truth = datapoint[\"price\"]\n",
" error = abs(guess - truth)\n",
" log_error = math.log(truth+1) - math.log(guess+1)\n",
" sle = log_error ** 2\n",
" color = self.color_for(error, truth)\n",
" title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
" self.guesses.append(guess)\n",
" self.truths.append(truth)\n",
" self.errors.append(error)\n",
" self.sles.append(sle)\n",
" self.colors.append(color)\n",
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
"\n",
" def chart(self, title):\n",
" max_error = max(self.errors)\n",
" plt.figure(figsize=(12, 8))\n",
" max_val = max(max(self.truths), max(self.guesses))\n",
" plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
" plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
" plt.xlabel('Ground Truth')\n",
" plt.ylabel('Model Estimate')\n",
" plt.xlim(0, max_val)\n",
" plt.ylim(0, max_val)\n",
" plt.title(title)\n",
" plt.show()\n",
"\n",
" def report(self):\n",
" average_error = sum(self.errors) / self.size\n",
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
" hits = sum(1 for color in self.colors if color==\"green\")\n",
" title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
" self.chart(title)\n",
"\n",
" def run(self):\n",
" self.error = 0\n",
" for i in range(self.size):\n",
" self.run_datapoint(i)\n",
" self.report()\n",
"\n",
" @classmethod\n",
" def test(cls, function, data):\n",
" cls(function, data).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dbWS1DPV4TPQ"
},
"outputs": [],
"source": [
"class Search_K:\n",
" \"\"\"\n",
" Search for the optimal 'k' value.\n",
" \"\"\"\n",
" def __init__(self, predictor, data, title=None, size=250):\n",
" self.predictor = predictor\n",
" self.data = data\n",
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
" self.size = size\n",
" self.truths = []\n",
"\n",
" self.all_k_errors = []\n",
" self.max_k = K_SEARCH_LIMIT\n",
"\n",
" # Store the list of probabilities for each inference\n",
" self.all_prob_lists = []\n",
" # Store the standard deviation of probs for each inference\n",
" self.prob_std_devs = []\n",
"\n",
" def color_for(self, error, truth):\n",
" if error<40 or error/truth < 0.2:\n",
" return \"green\"\n",
" elif error<80 or error/truth < 0.4:\n",
" return \"orange\"\n",
" else:\n",
" return \"red\"\n",
"\n",
" def run_datapoint(self, i):\n",
" datapoint = self.data[i]\n",
" base_prompt = datapoint[\"text\"]\n",
" prompt = make_prompt(base_prompt)\n",
" truth = datapoint[\"price\"]\n",
" self.truths.append(truth)\n",
"\n",
" # Get the raw lists of prices and probabilities\n",
" prices, probabilities = self.predictor(prompt)\n",
"\n",
" self.all_prob_lists.append(probabilities)\n",
"\n",
" if probabilities:\n",
" # Calculate and store the spread (std dev) of this prob list\n",
" self.prob_std_devs.append(np.std(probabilities))\n",
" else:\n",
" # No probabilities, append 0 for spread\n",
" self.prob_std_devs.append(0.0)\n",
"\n",
" errors_for_this_datapoint = []\n",
"\n",
" if not prices:\n",
" print(f\"{i+1}: No valid prices found. Truth: ${truth:,.2f}.\")\n",
" error = np.abs(0 - truth)\n",
" errors_for_this_datapoint = [error] * self.max_k\n",
" self.all_k_errors.append(errors_for_this_datapoint)\n",
" return\n",
"\n",
" # Iterate from k=1 up to max_k\n",
" for k in range(1, self.max_k + 1):\n",
" k_prices = prices[:k]\n",
" k_probabilities = probabilities[:k]\n",
"\n",
" # Calculate the weighted price just for this k\n",
" guess = calculate_weighted_price(k_prices, k_probabilities)\n",
"\n",
" # Calculate and store the error for this k\n",
" error = np.abs(guess - truth)\n",
" errors_for_this_datapoint.append(error)\n",
"\n",
" # Store the list of errors (for k=1 to max_k)\n",
" self.all_k_errors.append(errors_for_this_datapoint)\n",
"\n",
" # Print a summary for this datapoint\n",
" title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n",
"\n",
" # Using [0], [19], [-1] for k=1, k=20, k=max_k (0-indexed)\n",
" k_1_err = errors_for_this_datapoint[0]\n",
" k_20_err = errors_for_this_datapoint[19]\n",
" k_max_err = errors_for_this_datapoint[-1]\n",
"\n",
" color = self.color_for(k_1_err, truth)\n",
" print(f\"{COLOR_MAP[color]}{i+1}: Truth: ${truth:,.2f}. \"\n",
" f\"Errors (k=1, k=20, k={self.max_k}): \"\n",
" f\"(${k_1_err:,.2f}, ${k_20_err:,.2f}, ${k_max_err:,.2f}) \"\n",
" f\"Item: {title}{RESET}\")\n",
"\n",
" def plot_k_vs_error(self, k_values, avg_errors_by_k, best_k, min_error):\n",
" \"\"\"\n",
" Plots the Average Error vs. k\n",
" \"\"\"\n",
" plt.figure(figsize=(12, 8))\n",
" plt.plot(k_values, avg_errors_by_k, label='Average Error vs. k')\n",
"\n",
" # Highlight the best k\n",
" plt.axvline(x=best_k, color='red', linestyle='--',\n",
" label=f'Best k = {best_k} (Avg Error: ${min_error:,.2f})')\n",
"\n",
" plt.xlabel('Number of Top Probabilities/Prices (k)')\n",
" plt.ylabel('Average Absolute Error ($)')\n",
" plt.title(f'Optimal k Analysis for {self.title}')\n",
" plt.legend()\n",
" plt.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
" # Set x-axis to start at 1\n",
" plt.xlim(left=1)\n",
" plt.savefig(\"k_vs_error_plot.png\")\n",
" plt.show()\n",
"\n",
"\n",
" def plot_probability_spread(self, idx_min_std, idx_med_std, idx_max_std):\n",
" probs_min = self.all_prob_lists[idx_min_std]\n",
" probs_med = self.all_prob_lists[idx_med_std]\n",
" probs_max = self.all_prob_lists[idx_max_std]\n",
" std_min = self.prob_std_devs[idx_min_std]\n",
" std_med = self.prob_std_devs[idx_med_std]\n",
" std_max = self.prob_std_devs[idx_max_std]\n",
"\n",
" fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 7), sharey=True)\n",
" fig.suptitle('Probability Distribution Spread Analysis (Examples)', fontsize=16)\n",
"\n",
" def plot_strip(ax, probs, title):\n",
" if not probs:\n",
" ax.set_title(f\"{title}\\n(No probabilities found)\")\n",
" return\n",
" jitter = np.random.normal(0, 0.01, size=len(probs))\n",
" ax.scatter(jitter, probs, alpha=0.5, s=10) # Made points slightly larger\n",
" ax.set_title(title)\n",
" ax.set_xlabel(\"Jitter\")\n",
" ax.get_xaxis().set_ticks([])\n",
"\n",
" plot_strip(ax1, probs_min,\n",
" f'Inference {idx_min_std} (Lowest Spread)\\nStd Dev: {std_min:.6f}')\n",
" ax1.set_ylabel('Probability')\n",
" plot_strip(ax2, probs_med,\n",
" f'Inference {idx_med_std} (Median Spread)\\nStd Dev: {std_med:.6f}')\n",
" plot_strip(ax3, probs_max,\n",
" f'Inference {idx_max_std} (Highest Spread)\\nStd Dev: {std_max:.6f}')\n",
"\n",
" plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
" plt.savefig(\"spread_examples_plot.png\")\n",
" plt.show()\n",
"\n",
" def plot_all_std_devs(self):\n",
" \"\"\"\n",
" Plots a histogram and a line plot of the standard deviation\n",
" for ALL inferences.\n",
" \"\"\"\n",
" if not self.prob_std_devs:\n",
" print(\"No probability spreads recorded, skipping all-std plot.\")\n",
" return\n",
"\n",
" # Create a figure with two subplots\n",
" fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12))\n",
" fig.suptitle('Full Spread Analysis for All Inferences', fontsize=16)\n",
"\n",
" # --- Plot Histogram ---\n",
" ax1.hist(self.prob_std_devs, bins=50, edgecolor='black')\n",
" ax1.set_title('Distribution of Probability Standard Deviations')\n",
" ax1.set_xlabel('Standard Deviation')\n",
" ax1.set_ylabel('Frequency (Number of Inferences)')\n",
"\n",
" mean_std = np.mean(self.prob_std_devs)\n",
" ax1.axvline(mean_std, color='red', linestyle='--',\n",
" label=f'Mean Std Dev: {mean_std:.6f}')\n",
" ax1.legend()\n",
"\n",
" # --- Plot Line Plot ---\n",
" ax2.plot(self.prob_std_devs, marker='o', linestyle='-',\n",
" markersize=3, alpha=0.7, label='Std Dev per Inference')\n",
" ax2.set_title('Probability Standard Deviation per Inference')\n",
" ax2.set_xlabel('Inference Index (0 to 249)')\n",
" ax2.set_ylabel('Standard Deviation')\n",
"\n",
" ax2.axhline(mean_std, color='red', linestyle='--',\n",
" label=f'Mean Std Dev: {mean_std:.6f}')\n",
" ax2.legend()\n",
" ax2.set_xlim(0, len(self.prob_std_devs) - 1)\n",
"\n",
" plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
" plt.savefig(\"all_std_devs_plot.png\") # Save the plot\n",
" plt.show()\n",
"\n",
" def report(self):\n",
" \"\"\"\n",
" Calls all three plotting functions.\n",
" \"\"\"\n",
" if not self.all_k_errors:\n",
" print(\"\\nNo data to report on. Exiting.\")\n",
" return\n",
"\n",
" # Optimal k Analysis ---\n",
" errors_array = np.array(self.all_k_errors)\n",
" avg_errors_by_k = np.mean(errors_array, axis=0)\n",
" best_k_index = np.argmin(avg_errors_by_k)\n",
" min_error = avg_errors_by_k[best_k_index]\n",
" best_k = best_k_index + 1\n",
"\n",
" print(\"\\n\" + \"=\"*40)\n",
" print(\"--- Optimal k Analysis Report ---\")\n",
" print(f\"Model: {self.title}\")\n",
" print(f\"Inferences Run: {self.size}\")\n",
" print(f\"Analyzed k from 1 to {self.max_k}\")\n",
" print(f\"===================================\")\n",
" print(f\"==> Best k: {best_k}\")\n",
" print(f\"==> Minimum Average Error: ${min_error:,.2f}\")\n",
" print(\"=\"*40 + \"\\n\")\n",
"\n",
" k_values = np.arange(1, self.max_k + 1)\n",
" self.plot_k_vs_error(k_values, avg_errors_by_k, best_k, min_error)\n",
"\n",
" # Probability Spread Analysis ---\n",
" if not self.prob_std_devs:\n",
" print(\"\\nNo probability spreads recorded, skipping spread plots.\")\n",
" return\n",
"\n",
" print(\"\\n\" + \"=\"*40)\n",
" print(\"--- Probability Spread Analysis ---\")\n",
"\n",
" # Find indices for examples\n",
" std_sorted_indices = np.argsort(self.prob_std_devs)\n",
" idx_min_std = std_sorted_indices[0]\n",
" idx_med_std = std_sorted_indices[len(std_sorted_indices) // 2]\n",
" idx_max_std = std_sorted_indices[-1]\n",
"\n",
" print(f\"Lowest spread (std): {self.prob_std_devs[idx_min_std]:.6f} (Inference {idx_min_std})\")\n",
" print(f\"Median spread (std): {self.prob_std_devs[idx_med_std]:.6f} (Inference {idx_med_std})\")\n",
" print(f\"Highest spread (std): {self.prob_std_devs[idx_max_std]:.6f} (Inference {idx_max_std})\")\n",
" print(\"=\"*40 + \"\\n\")\n",
"\n",
" # Plot example spreads\n",
" self.plot_probability_spread(idx_min_std, idx_med_std, idx_max_std)\n",
"\n",
" # Plot all spreads\n",
" self.plot_all_std_devs()\n",
"\n",
" return best_k\n",
"\n",
" def run(self):\n",
" for i in range(self.size):\n",
" self.run_datapoint(i)\n",
" best_k=self.report()\n",
" return best_k\n",
"\n",
" @classmethod\n",
" def test(cls, function, data):\n",
" cls(function, data).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vtt13OuVE-t7"
},
"outputs": [],
"source": [
"# Search best K\n",
"search_k = Search_K(get_top_k_predictions, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n",
"best_k = search_k.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tuwYu1NYljIv"
},
"outputs": [],
"source": [
"top_K = best_k\n",
"\n",
"def improved_model_predict(prompt, device=\"cuda\"):\n",
" set_seed(42)\n",
" inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
" attention_mask = torch.ones(inputs.shape, device=device)\n",
"\n",
" with torch.no_grad():\n",
" outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n",
" next_token_logits = outputs.logits[:, -1, :].to('cpu')\n",
"\n",
" next_token_probs = F.softmax(next_token_logits, dim=-1)\n",
" top_prob, top_token_id = next_token_probs.topk(top_K)\n",
"\n",
" prices = []\n",
" # Renamed 'weights' to 'probabilities' for clarity\n",
" probabilities = []\n",
"\n",
" for i in range(top_K):\n",
" predicted_token = tokenizer.decode(top_token_id[0][i])\n",
" # This is a torch.Tensor\n",
" probability_tensor = top_prob[0][i]\n",
"\n",
" # print(predicted_token, probability_tensor)\n",
"\n",
" try:\n",
" # Try to convert the decoded token string to a float\n",
" price = float(predicted_token)\n",
" except ValueError as e:\n",
" price = 0.0\n",
"\n",
" # Only include valid, positive prices\n",
" if price > 0:\n",
" prices.append(price)\n",
" # We append the tensor to our list\n",
" probabilities.append(probability_tensor)\n",
"\n",
" if not prices:\n",
" # If no valid prices were found, return 0.0\n",
" return 0.0\n",
"\n",
"\n",
" # Convert the list of prices to a numpy array\n",
" prices_np = np.array(prices)\n",
"\n",
" # Convert the list of torch.Tensors to a numpy array of floats\n",
" probs_np = np.array([p.item() for p in probabilities])\n",
"\n",
" # Calculate the normalized weighted average\n",
" final_price = np.average(prices_np, weights=probs_np)\n",
"\n",
" return float(final_price) # Return as a standard python float"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3SxpLBJH70E-"
},
"outputs": [],
"source": [
"prompt=make_prompt(test[80]['text'])\n",
"print(prompt)\n",
"\n",
"improved_model_predict(prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W_KcLvyt6kbb"
},
"outputs": [],
"source": [
"# Run Estimate vs Ground Truth\n",
"tester = Tester(improved_model_predict, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n",
"tester.run()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"include_colab_link": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}