diff --git a/week7/community_contributions/dkisselev-zz/Week_7_Excersise_fine_tuned_model.ipynb b/week7/community_contributions/dkisselev-zz/Week_7_Excersise_fine_tuned_model.ipynb new file mode 100644 index 0000000..2090543 --- /dev/null +++ b/week7/community_contributions/dkisselev-zz/Week_7_Excersise_fine_tuned_model.ipynb @@ -0,0 +1,820 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GHsssBgWM_l0" + }, + "source": [ + "# Predict Product Prices\n", + "\n", + "Model evaluation and inference tuning\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HnwMdAP3IHad" + }, + "source": [ + "## Libraries and configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MDyR63OTNUJ6" + }, + "outputs": [], + "source": [ + "!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n", + "!pip install -q --upgrade requests==2.32.3 bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-yikV8pRBer9" + }, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import math\n", + "import numpy as np\n", + "from google.colab import userdata\n", + "from huggingface_hub import login\n", + "import wandb\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed\n", + "from datasets import load_dataset\n", + "from peft import PeftModel\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uuTX-xonNeOK" + }, + "outputs": [], + "source": [ + "# 
Models\n", + "\n", + "# WB or HF location of artifacts\n", + "ARTIFCAT_LOCATTION=\"HF\"\n", + "\n", + "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n", + "\n", + "PROJECT_NAME = \"pricer\"\n", + "\n", + "# RUN_NAME = \"2025-10-23_23.41.24\" # - Fine tuned 16 batches / 8 bit run\n", + "# RUN_NAME = \"2025-10-25_05.02.00\" # - Fine tuned 4 batches / 4 bit / LoRA 64/128 / Gradient 8\n", + "RUN_NAME = \"2024-09-13_13.04.39\" # Ed's model run\n", + "\n", + "# Hugging Face\n", + "HF_USER = \"dkisselev\"\n", + "\n", + "if ARTIFCAT_LOCATTION==\"HF\":\n", + " PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n", + " # REVISION = None\n", + " REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n", + "\n", + "\n", + " # FINETUNED_MODEL = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n", + "\n", + " # Ed's model\n", + " FINETUNED_MODEL = f\"ed-donner/{PROJECT_RUN_NAME}\"\n", + "else:\n", + " # Weights and Biases\n", + " WANDB_ENTITY = \"dkisselev\"\n", + " os.environ[\"WANDB_API_KEY\"]=userdata.get('WANDB_API_KEY')\n", + "\n", + " MODEL_ARTIFACT_NAME = f\"model-{RUN_NAME}\"\n", + " REVISION_TAG=\"v22\"\n", + " WANDB_ARTIFACT_PATH = f\"{WANDB_ENTITY}/{PROJECT_NAME}/{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\"\n", + "\n", + "# Data set\n", + "\n", + "# DATASET_NAME = f\"{HF_USER}/pricer-data2\"\n", + "DATASET_NAME = \"ed-donner/pricer-data\"\n", + "\n", + "# Hyperparameters for QLoRA\n", + "QUANT_4_BIT = True\n", + "K_SEARCH_LIMIT = 900\n", + "\n", + "# Used for writing to output in color\n", + "GREEN = \"\\033[92m\"\n", + "YELLOW = \"\\033[93m\"\n", + "RED = \"\\033[91m\"\n", + "BLUE = \"\\033[94m\"\n", + "RESET = \"\\033[0m\"\n", + "COLOR_MAP = {\"red\":RED, \"orange\": BLUE, \"green\": GREEN}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8JArT3QAQAjx" + }, + "source": [ + "### Load Data\n", + "\n", + "Data is loaded from Hugging Face\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WyFPZeMcM88v" + }, + "outputs": [], + "source": 
[ + "# Log in to HuggingFace\n", + "hf_token = userdata.get('HF_TOKEN')\n", + "login(hf_token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cvXVoJH8LS6u" + }, + "outputs": [], + "source": [ + "dataset = load_dataset(DATASET_NAME)\n", + "train = dataset['train']\n", + "test = dataset['test']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qJWQ0a3wZ0Bw" + }, + "source": [ + "## Load Tokenizer and Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lAUAAcEC6ido" + }, + "outputs": [], + "source": [ + "# 4 or 8 bit quantization\n", + "if QUANT_4_BIT:\n", + " quant_config = BitsAndBytesConfig(\n", + " load_in_4bit=True,\n", + " bnb_4bit_use_double_quant=True,\n", + " bnb_4bit_compute_dtype=torch.bfloat16,\n", + " bnb_4bit_quant_type=\"nf4\"\n", + " )\n", + "else:\n", + " quant_config = BitsAndBytesConfig(\n", + " load_in_8bit=True\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OQy4pCk-dutf" + }, + "outputs": [], + "source": [ + "# Load model from w&b\n", + "if ARTIFCAT_LOCATTION==\"WB\":\n", + " artifact = wandb.Api().artifact(WANDB_ARTIFACT_PATH, type='model')\n", + " artifact_dir = artifact.download() # Downloads to a local cache dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "R_O04fKxMMT-" + }, + "outputs": [], + "source": [ + "# Load the Tokenizer and the Model\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "tokenizer.padding_side = \"right\"\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(\n", + " BASE_MODEL,\n", + " quantization_config=quant_config,\n", + " device_map=\"auto\",\n", + ")\n", + "base_model.generation_config.pad_token_id = tokenizer.pad_token_id\n", + "\n", + "if ARTIFCAT_LOCATTION==\"HF\":\n", + " # Load the fine-tuned model with 
PEFT\n", + " if REVISION:\n", + " fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL, revision=REVISION)\n", + " else:\n", + " fine_tuned_model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)\n", + "else:\n", + " # Model at W&B\n", + " fine_tuned_model = PeftModel.from_pretrained(base_model, artifact_dir)\n", + "\n", + "print(f\"Memory footprint: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UObo1-RqaNnT" + }, + "source": [ + "## Hyperparameter helpers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n4u27kbwlekE" + }, + "outputs": [], + "source": [ + "def calculate_weighted_price(prices, probabilities):\n", + " \"\"\"\n", + " Calculates a normalized weighted average price.\n", + "\n", + " Args:\n", + " prices (list or np.array): A list of prices.\n", + " probabilities (list or np.array): A list of corresponding probabilities (or weights).\n", + " Returns:\n", + " float: The normalized weighted average price.\n", + " \"\"\"\n", + " # Convert lists to numpy arrays\n", + " prices_array = np.array(prices)\n", + " probs_array = np.array(probabilities)\n", + "\n", + " # Total of the probabilities to use for normalization\n", + " total_prob = np.sum(probs_array)\n", + "\n", + " # Catch zero\n", + " if total_prob == 0:\n", + " if len(prices_array) > 0:\n", + " return np.mean(prices_array)\n", + " else:\n", + " return 0.0\n", + "\n", + " # Weighted avrage\n", + " weighted_price = np.average(prices_array, weights=probs_array)\n", + "\n", + " return weighted_price" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ROjIbGuH0FWS" + }, + "outputs": [], + "source": [ + "def get_top_k_predictions(prompt, device=\"cuda\"):\n", + " \"\"\"\n", + " Gets the top K price/probability pairs from the model.\n", + "\n", + " Returns:\n", + " (list, list): A tuple containing (prices, probabilities)\n", + " 
\"\"\"\n", + " set_seed(42)\n", + " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", + " attention_mask = torch.ones(inputs.shape, device=device)\n", + "\n", + " with torch.no_grad():\n", + " outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n", + " next_token_logits = outputs.logits[:, -1, :].to('cpu')\n", + "\n", + " next_token_probs = F.softmax(next_token_logits, dim=-1)\n", + " top_prob, top_token_id = next_token_probs.topk(K_SEARCH_LIMIT)\n", + "\n", + " prices = []\n", + " probabilities = []\n", + "\n", + " for i in range(K_SEARCH_LIMIT):\n", + " predicted_token = tokenizer.decode(top_token_id[0][i])\n", + " probability_tensor = top_prob[0][i]\n", + "\n", + " try:\n", + " price = float(predicted_token)\n", + " except ValueError as e:\n", + " price = 0.0\n", + "\n", + " if price > 0:\n", + " prices.append(price)\n", + " probabilities.append(probability_tensor.item())\n", + "\n", + " if not prices:\n", + " return [], []\n", + "\n", + " return prices, probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tnmTAiEG32xK" + }, + "outputs": [], + "source": [ + "def make_prompt(text):\n", + " if ARTIFCAT_LOCATTION==\"HF\":\n", + " return text\n", + " p_array = text.split(\"\\n\")\n", + " p_question = p_array[0].replace(\"How much does this cost to the nearest dollar?\",\"What is the price of this item?\")\n", + " p_title = p_array[2]\n", + " p_descr = re.sub(r'\\d', '', p_array[3])\n", + " p_price = p_array[5]\n", + " prompt = p_title + \"\\n\" + p_descr + \"\\n\" + \"Question: \"+ p_question + \"\\n\\n\" + p_price\n", + " # prompt = p_array[0] + \"\\n\\n\\n\" + p_title + \"\\n\\n\" + p_descr + \"\\n\\n\" + p_price\n", + " # return text\n", + " return prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VNAEw5Eg4ABk" + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "class Tester:\n", + "\n", + " def __init__(self, 
predictor, data, title=None, size=250):\n", + " self.predictor = predictor\n", + " self.data = data\n", + " self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n", + " self.size = size\n", + " self.guesses = []\n", + " self.truths = []\n", + " self.errors = []\n", + " self.sles = []\n", + " self.colors = []\n", + "\n", + " def color_for(self, error, truth):\n", + " if error<40 or error/truth < 0.2:\n", + " return \"green\"\n", + " elif error<80 or error/truth < 0.4:\n", + " return \"orange\"\n", + " else:\n", + " return \"red\"\n", + "\n", + " def run_datapoint(self, i):\n", + " datapoint = self.data[i]\n", + "\n", + " base_prompt = datapoint[\"text\"]\n", + " prompt = make_prompt(base_prompt)\n", + "\n", + " guess = self.predictor(prompt)\n", + "\n", + " # guess = self.predictor(datapoint[\"text\"])\n", + " truth = datapoint[\"price\"]\n", + " error = abs(guess - truth)\n", + " log_error = math.log(truth+1) - math.log(guess+1)\n", + " sle = log_error ** 2\n", + " color = self.color_for(error, truth)\n", + " title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n", + " self.guesses.append(guess)\n", + " self.truths.append(truth)\n", + " self.errors.append(error)\n", + " self.sles.append(sle)\n", + " self.colors.append(color)\n", + " print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n", + "\n", + " def chart(self, title):\n", + " max_error = max(self.errors)\n", + " plt.figure(figsize=(12, 8))\n", + " max_val = max(max(self.truths), max(self.guesses))\n", + " plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n", + " plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n", + " plt.xlabel('Ground Truth')\n", + " plt.ylabel('Model Estimate')\n", + " plt.xlim(0, max_val)\n", + " plt.ylim(0, max_val)\n", + " plt.title(title)\n", + " plt.show()\n", + "\n", + " def report(self):\n", + " average_error = sum(self.errors) / 
self.size\n", + " rmsle = math.sqrt(sum(self.sles) / self.size)\n", + " hits = sum(1 for color in self.colors if color==\"green\")\n", + " title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n", + " self.chart(title)\n", + "\n", + " def run(self):\n", + " self.error = 0\n", + " for i in range(self.size):\n", + " self.run_datapoint(i)\n", + " self.report()\n", + "\n", + " @classmethod\n", + " def test(cls, function, data):\n", + " cls(function, data).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dbWS1DPV4TPQ" + }, + "outputs": [], + "source": [ + "class Search_K:\n", + " \"\"\"\n", + " Search for the optimal 'k' value.\n", + " \"\"\"\n", + " def __init__(self, predictor, data, title=None, size=250):\n", + " self.predictor = predictor\n", + " self.data = data\n", + " self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n", + " self.size = size\n", + " self.truths = []\n", + "\n", + " self.all_k_errors = []\n", + " self.max_k = K_SEARCH_LIMIT\n", + "\n", + " # Store the list of probabilities for each inference\n", + " self.all_prob_lists = []\n", + " # Store the standard deviation of probs for each inference\n", + " self.prob_std_devs = []\n", + "\n", + " def color_for(self, error, truth):\n", + " if error<40 or error/truth < 0.2:\n", + " return \"green\"\n", + " elif error<80 or error/truth < 0.4:\n", + " return \"orange\"\n", + " else:\n", + " return \"red\"\n", + "\n", + " def run_datapoint(self, i):\n", + " datapoint = self.data[i]\n", + " base_prompt = datapoint[\"text\"]\n", + " prompt = make_prompt(base_prompt)\n", + " truth = datapoint[\"price\"]\n", + " self.truths.append(truth)\n", + "\n", + " # Get the raw lists of prices and probabilities\n", + " prices, probabilities = self.predictor(prompt)\n", + "\n", + " self.all_prob_lists.append(probabilities)\n", + "\n", + " if probabilities:\n", + " # Calculate and store the spread (std dev) of 
this prob list\n", + " self.prob_std_devs.append(np.std(probabilities))\n", + " else:\n", + " # No probabilities, append 0 for spread\n", + " self.prob_std_devs.append(0.0)\n", + "\n", + " errors_for_this_datapoint = []\n", + "\n", + " if not prices:\n", + " print(f\"{i+1}: No valid prices found. Truth: ${truth:,.2f}.\")\n", + " error = np.abs(0 - truth)\n", + " errors_for_this_datapoint = [error] * self.max_k\n", + " self.all_k_errors.append(errors_for_this_datapoint)\n", + " return\n", + "\n", + " # Iterate from k=1 up to max_k\n", + " for k in range(1, self.max_k + 1):\n", + " k_prices = prices[:k]\n", + " k_probabilities = probabilities[:k]\n", + "\n", + " # Calculate the weighted price just for this k\n", + " guess = calculate_weighted_price(k_prices, k_probabilities)\n", + "\n", + " # Calculate and store the error for this k\n", + " error = np.abs(guess - truth)\n", + " errors_for_this_datapoint.append(error)\n", + "\n", + " # Store the list of errors (for k=1 to max_k)\n", + " self.all_k_errors.append(errors_for_this_datapoint)\n", + "\n", + " # Print a summary for this datapoint\n", + " title = datapoint[\"text\"].split(\"\\n\\n\")[1][:20] + \"...\"\n", + "\n", + " # Using [0], [19], [-1] for k=1, k=20, k=max_k (0-indexed)\n", + " k_1_err = errors_for_this_datapoint[0]\n", + " k_20_err = errors_for_this_datapoint[19]\n", + " k_max_err = errors_for_this_datapoint[-1]\n", + "\n", + " color = self.color_for(k_1_err, truth)\n", + " print(f\"{COLOR_MAP[color]}{i+1}: Truth: ${truth:,.2f}. \"\n", + " f\"Errors (k=1, k=20, k={self.max_k}): \"\n", + " f\"(${k_1_err:,.2f}, ${k_20_err:,.2f}, ${k_max_err:,.2f}) \"\n", + " f\"Item: {title}{RESET}\")\n", + "\n", + " def plot_k_vs_error(self, k_values, avg_errors_by_k, best_k, min_error):\n", + " \"\"\"\n", + " Plots the Average Error vs. k\n", + " \"\"\"\n", + " plt.figure(figsize=(12, 8))\n", + " plt.plot(k_values, avg_errors_by_k, label='Average Error vs. 
k')\n", + "\n", + " # Highlight the best k\n", + " plt.axvline(x=best_k, color='red', linestyle='--',\n", + " label=f'Best k = {best_k} (Avg Error: ${min_error:,.2f})')\n", + "\n", + " plt.xlabel('Number of Top Probabilities/Prices (k)')\n", + " plt.ylabel('Average Absolute Error ($)')\n", + " plt.title(f'Optimal k Analysis for {self.title}')\n", + " plt.legend()\n", + " plt.grid(True, which='both', linestyle='--', linewidth=0.5)\n", + " # Set x-axis to start at 1\n", + " plt.xlim(left=1)\n", + " plt.savefig(\"k_vs_error_plot.png\")\n", + " plt.show()\n", + "\n", + "\n", + " def plot_probability_spread(self, idx_min_std, idx_med_std, idx_max_std):\n", + " probs_min = self.all_prob_lists[idx_min_std]\n", + " probs_med = self.all_prob_lists[idx_med_std]\n", + " probs_max = self.all_prob_lists[idx_max_std]\n", + " std_min = self.prob_std_devs[idx_min_std]\n", + " std_med = self.prob_std_devs[idx_med_std]\n", + " std_max = self.prob_std_devs[idx_max_std]\n", + "\n", + " fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 7), sharey=True)\n", + " fig.suptitle('Probability Distribution Spread Analysis (Examples)', fontsize=16)\n", + "\n", + " def plot_strip(ax, probs, title):\n", + " if not probs:\n", + " ax.set_title(f\"{title}\\n(No probabilities found)\")\n", + " return\n", + " jitter = np.random.normal(0, 0.01, size=len(probs))\n", + " ax.scatter(jitter, probs, alpha=0.5, s=10) # Made points slightly larger\n", + " ax.set_title(title)\n", + " ax.set_xlabel(\"Jitter\")\n", + " ax.get_xaxis().set_ticks([])\n", + "\n", + " plot_strip(ax1, probs_min,\n", + " f'Inference {idx_min_std} (Lowest Spread)\\nStd Dev: {std_min:.6f}')\n", + " ax1.set_ylabel('Probability')\n", + " plot_strip(ax2, probs_med,\n", + " f'Inference {idx_med_std} (Median Spread)\\nStd Dev: {std_med:.6f}')\n", + " plot_strip(ax3, probs_max,\n", + " f'Inference {idx_max_std} (Highest Spread)\\nStd Dev: {std_max:.6f}')\n", + "\n", + " plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n", + " 
plt.savefig(\"spread_examples_plot.png\")\n", + " plt.show()\n", + "\n", + " def plot_all_std_devs(self):\n", + " \"\"\"\n", + " Plots a histogram and a line plot of the standard deviation\n", + " for ALL inferences.\n", + " \"\"\"\n", + " if not self.prob_std_devs:\n", + " print(\"No probability spreads recorded, skipping all-std plot.\")\n", + " return\n", + "\n", + " # Create a figure with two subplots\n", + " fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12))\n", + " fig.suptitle('Full Spread Analysis for All Inferences', fontsize=16)\n", + "\n", + " # --- Plot Histogram ---\n", + " ax1.hist(self.prob_std_devs, bins=50, edgecolor='black')\n", + " ax1.set_title('Distribution of Probability Standard Deviations')\n", + " ax1.set_xlabel('Standard Deviation')\n", + " ax1.set_ylabel('Frequency (Number of Inferences)')\n", + "\n", + " mean_std = np.mean(self.prob_std_devs)\n", + " ax1.axvline(mean_std, color='red', linestyle='--',\n", + " label=f'Mean Std Dev: {mean_std:.6f}')\n", + " ax1.legend()\n", + "\n", + " # --- Plot Line Plot ---\n", + " ax2.plot(self.prob_std_devs, marker='o', linestyle='-',\n", + " markersize=3, alpha=0.7, label='Std Dev per Inference')\n", + " ax2.set_title('Probability Standard Deviation per Inference')\n", + " ax2.set_xlabel('Inference Index (0 to 249)')\n", + " ax2.set_ylabel('Standard Deviation')\n", + "\n", + " ax2.axhline(mean_std, color='red', linestyle='--',\n", + " label=f'Mean Std Dev: {mean_std:.6f}')\n", + " ax2.legend()\n", + " ax2.set_xlim(0, len(self.prob_std_devs) - 1)\n", + "\n", + " plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n", + " plt.savefig(\"all_std_devs_plot.png\") # Save the plot\n", + " plt.show()\n", + "\n", + " def report(self):\n", + " \"\"\"\n", + " Calls all three plotting functions.\n", + " \"\"\"\n", + " if not self.all_k_errors:\n", + " print(\"\\nNo data to report on. 
Exiting.\")\n", + " return\n", + "\n", + " # Optimal k Analysis ---\n", + " errors_array = np.array(self.all_k_errors)\n", + " avg_errors_by_k = np.mean(errors_array, axis=0)\n", + " best_k_index = np.argmin(avg_errors_by_k)\n", + " min_error = avg_errors_by_k[best_k_index]\n", + " best_k = best_k_index + 1\n", + "\n", + " print(\"\\n\" + \"=\"*40)\n", + " print(\"--- Optimal k Analysis Report ---\")\n", + " print(f\"Model: {self.title}\")\n", + " print(f\"Inferences Run: {self.size}\")\n", + " print(f\"Analyzed k from 1 to {self.max_k}\")\n", + " print(f\"===================================\")\n", + " print(f\"==> Best k: {best_k}\")\n", + " print(f\"==> Minimum Average Error: ${min_error:,.2f}\")\n", + " print(\"=\"*40 + \"\\n\")\n", + "\n", + " k_values = np.arange(1, self.max_k + 1)\n", + " self.plot_k_vs_error(k_values, avg_errors_by_k, best_k, min_error)\n", + "\n", + " # Probability Spread Analysis ---\n", + " if not self.prob_std_devs:\n", + " print(\"\\nNo probability spreads recorded, skipping spread plots.\")\n", + " return\n", + "\n", + " print(\"\\n\" + \"=\"*40)\n", + " print(\"--- Probability Spread Analysis ---\")\n", + "\n", + " # Find indices for examples\n", + " std_sorted_indices = np.argsort(self.prob_std_devs)\n", + " idx_min_std = std_sorted_indices[0]\n", + " idx_med_std = std_sorted_indices[len(std_sorted_indices) // 2]\n", + " idx_max_std = std_sorted_indices[-1]\n", + "\n", + " print(f\"Lowest spread (std): {self.prob_std_devs[idx_min_std]:.6f} (Inference {idx_min_std})\")\n", + " print(f\"Median spread (std): {self.prob_std_devs[idx_med_std]:.6f} (Inference {idx_med_std})\")\n", + " print(f\"Highest spread (std): {self.prob_std_devs[idx_max_std]:.6f} (Inference {idx_max_std})\")\n", + " print(\"=\"*40 + \"\\n\")\n", + "\n", + " # Plot example spreads\n", + " self.plot_probability_spread(idx_min_std, idx_med_std, idx_max_std)\n", + "\n", + " # Plot all spreads\n", + " self.plot_all_std_devs()\n", + "\n", + " return best_k\n", + "\n", + " 
def run(self):\n", + " for i in range(self.size):\n", + " self.run_datapoint(i)\n", + " best_k=self.report()\n", + " return best_k\n", + "\n", + " @classmethod\n", + " def test(cls, function, data):\n", + " cls(function, data).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Vtt13OuVE-t7" + }, + "outputs": [], + "source": [ + "# Search best K\n", + "search_k = Search_K(get_top_k_predictions, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n", + "best_k = search_k.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tuwYu1NYljIv" + }, + "outputs": [], + "source": [ + "top_K = best_k\n", + "\n", + "def improved_model_predict(prompt, device=\"cuda\"):\n", + " set_seed(42)\n", + " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", + " attention_mask = torch.ones(inputs.shape, device=device)\n", + "\n", + " with torch.no_grad():\n", + " outputs = fine_tuned_model(inputs, attention_mask=attention_mask)\n", + " next_token_logits = outputs.logits[:, -1, :].to('cpu')\n", + "\n", + " next_token_probs = F.softmax(next_token_logits, dim=-1)\n", + " top_prob, top_token_id = next_token_probs.topk(top_K)\n", + "\n", + " prices = []\n", + " # Renamed 'weights' to 'probabilities' for clarity\n", + " probabilities = []\n", + "\n", + " for i in range(top_K):\n", + " predicted_token = tokenizer.decode(top_token_id[0][i])\n", + " # This is a torch.Tensor\n", + " probability_tensor = top_prob[0][i]\n", + "\n", + " # print(predicted_token, probability_tensor)\n", + "\n", + " try:\n", + " # Try to convert the decoded token string to a float\n", + " price = float(predicted_token)\n", + " except ValueError as e:\n", + " price = 0.0\n", + "\n", + " # Only include valid, positive prices\n", + " if price > 0:\n", + " prices.append(price)\n", + " # We append the tensor to our list\n", + " probabilities.append(probability_tensor)\n", + "\n", 
+ " if not prices:\n", + " # If no valid prices were found, return 0.0\n", + " return 0.0\n", + "\n", + "\n", + " # Convert the list of prices to a numpy array\n", + " prices_np = np.array(prices)\n", + "\n", + " # Convert the list of torch.Tensors to a numpy array of floats\n", + " probs_np = np.array([p.item() for p in probabilities])\n", + "\n", + " # Calculate the normalized weighted average\n", + " final_price = np.average(prices_np, weights=probs_np)\n", + "\n", + " return float(final_price) # Return as a standard python float" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3SxpLBJH70E-" + }, + "outputs": [], + "source": [ + "prompt=make_prompt(test[80]['text'])\n", + "print(prompt)\n", + "\n", + "improved_model_predict(prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "W_KcLvyt6kbb" + }, + "outputs": [], + "source": [ + "# Run Estimate vs Ground Truth\n", + "tester = Tester(improved_model_predict, test, title=f\"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}\" if ARTIFCAT_LOCATTION==\"WB\" else None)\n", + "tester.run()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}