# %% Configuration
# Every tunable value for the notebook lives in this cell.

# Where the fine-tuned adapter is stored: "HF" (Hugging Face Hub) or
# "WB" (Weights & Biases artifact).
ARTIFCAT_LOCATTION="HF"

# Fail fast on an unknown value. The cells below disagree on how they treat
# anything other than "HF" (this cell falls through to the W&B branch, while the
# artifact download cell requires exactly "WB"), so a typo here would otherwise
# only surface later as a NameError on `artifact_dir`.
if ARTIFCAT_LOCATTION not in ("HF", "WB"):
    raise ValueError(
        f"ARTIFCAT_LOCATTION must be 'HF' or 'WB', got {ARTIFCAT_LOCATTION!r}"
    )

BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"

PROJECT_NAME = "pricer"

# Alternative fine-tuning runs; uncomment exactly one RUN_NAME.
# RUN_NAME = "2025-10-23_23.41.24" # - Fine tuned 16 batches / 8 bit run
# RUN_NAME = "2025-10-25_05.02.00" # - Fine tuned 4 batches / 4 bit / LoRA 64/128 / Gradient 8
RUN_NAME = "2024-09-13_13.04.39" # Ed's model run

# Hugging Face
HF_USER = "dkisselev"

if ARTIFCAT_LOCATTION=="HF":
    PROJECT_RUN_NAME = f"{PROJECT_NAME}-{RUN_NAME}"
    # Pin the adapter to a specific commit; set to None to use the latest revision.
    # REVISION = None
    REVISION = "e8d637df551603dc86cd7a1598a8f44af4d7ae36"

    # FINETUNED_MODEL = f"{HF_USER}/{PROJECT_RUN_NAME}"

    # Ed's model
    FINETUNED_MODEL = f"ed-donner/{PROJECT_RUN_NAME}"
else:
    # Weights and Biases
    WANDB_ENTITY = "dkisselev"
    # The key is read from Colab secrets and exported for the wandb client.
    os.environ["WANDB_API_KEY"]=userdata.get('WANDB_API_KEY')

    MODEL_ARTIFACT_NAME = f"model-{RUN_NAME}"
    REVISION_TAG="v22"
    WANDB_ARTIFACT_PATH = f"{WANDB_ENTITY}/{PROJECT_NAME}/{MODEL_ARTIFACT_NAME}:{REVISION_TAG}"

# Data set

# DATASET_NAME = f"{HF_USER}/pricer-data2"
DATASET_NAME = "ed-donner/pricer-data"

# Quantization used when loading the base model: 4-bit if True, 8-bit otherwise
QUANT_4_BIT = True
# Number of highest-probability next tokens inspected when searching for the best k
K_SEARCH_LIMIT = 900

# Used for writing to output in color
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
BLUE = "\033[94m"
RESET = "\033[0m"
# Note: the "orange" band is printed in blue on the console.
COLOR_MAP = {"red":RED, "orange": BLUE, "green": GREEN}

# %% Log in to HuggingFace
# The token is read from Colab secrets.
hf_token = userdata.get('HF_TOKEN')
login(hf_token)

# %% Load the dataset from the Hugging Face Hub
dataset = load_dataset(DATASET_NAME)
train = dataset['train']
test = dataset['test']
# %% Quantization config
# Pick the bitsandbytes settings once, then build the config from them:
# 4-bit NF4 with double quantization when QUANT_4_BIT is set, plain 8-bit otherwise.
quant_settings = (
    dict(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    )
    if QUANT_4_BIT
    else dict(load_in_8bit=True)
)
quant_config = BitsAndBytesConfig(**quant_settings)

# %% Fetch the adapter from Weights & Biases (only when it is stored there)
if ARTIFCAT_LOCATTION == "WB":
    artifact = wandb.Api().artifact(WANDB_ARTIFACT_PATH, type='model')
    # download() places the files in a local cache directory and returns its path
    artifact_dir = artifact.download()

# %% Load the tokenizer, the quantized base model and the fine-tuned adapter
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Reuse the EOS token for padding and pad on the right
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=quant_config,
    device_map="auto",
)
base_model.generation_config.pad_token_id = tokenizer.pad_token_id

# Resolve where the adapter comes from, then attach it with a single call.
if ARTIFCAT_LOCATTION == "HF":
    adapter_source = FINETUNED_MODEL
    # Only pass `revision` when one is pinned in the config cell
    adapter_kwargs = {"revision": REVISION} if REVISION else {}
else:
    # Adapter downloaded from W&B above
    adapter_source = artifact_dir
    adapter_kwargs = {}

fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_source, **adapter_kwargs)

print(f"Memory footprint: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB")
# %% Weighted price helper
def calculate_weighted_price(prices, probabilities):
    """
    Calculates a normalized weighted average price.

    Args:
        prices (list or np.array): A list of prices.
        probabilities (list or np.array): A list of corresponding probabilities (or weights).
    Returns:
        float: The weighted average of `prices`. When the weights sum to zero the
        plain mean of `prices` is returned instead, or 0.0 if `prices` is empty.
    """
    prices_array = np.asarray(prices, dtype=float)
    probs_array = np.asarray(probabilities, dtype=float)

    # np.average raises ZeroDivisionError when the weights sum to zero,
    # so that case falls back to an unweighted mean.
    if probs_array.sum() == 0:
        return float(prices_array.mean()) if prices_array.size else 0.0

    # Always hand back a built-in float rather than a numpy scalar
    return float(np.average(prices_array, weights=probs_array))


# %% Top-k price candidates
def _parse_price(token_text):
    """Return the token as a positive, finite price, or None if it is not one."""
    try:
        price = float(token_text)
    except ValueError:
        return None
    # float() also accepts "inf", "infinity" and "nan"; an infinite price would
    # pass a plain `> 0` check and turn every weighted average into inf.
    if math.isfinite(price) and price > 0:
        return price
    return None


def get_top_k_predictions(prompt, device="cuda"):
    """
    Gets the price/probability pairs found among the top K_SEARCH_LIMIT next tokens.

    Runs one forward pass of `fine_tuned_model` on `prompt`, takes the softmax of the
    logits at the last position and keeps every top token whose decoded text parses
    as a positive, finite number.

    Args:
        prompt (str): Text to feed to the model.
        device (str): Device the input tensors are moved to.
    Returns:
        (list, list): A tuple (prices, probabilities) of equal length, ordered by
        decreasing probability. Both lists are empty when no token is a valid price.
    """
    set_seed(42)
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
    attention_mask = torch.ones(inputs.shape, device=device)

    with torch.no_grad():
        outputs = fine_tuned_model(inputs, attention_mask=attention_mask)
        next_token_logits = outputs.logits[:, -1, :].to('cpu')

    # Upcast before the softmax so that small probabilities are not rounded to zero
    # if the model emits half-precision logits (a no-op when they are already float32).
    next_token_probs = F.softmax(next_token_logits.float(), dim=-1)
    # topk() fails if k exceeds the vocabulary size
    k = min(K_SEARCH_LIMIT, next_token_probs.shape[-1])
    top_prob, top_token_id = next_token_probs.topk(k)

    prices = []
    probabilities = []
    for token_id, probability_tensor in zip(top_token_id[0], top_prob[0]):
        price = _parse_price(tokenizer.decode(token_id))
        if price is not None:
            prices.append(price)
            probabilities.append(probability_tensor.item())

    return prices, probabilities


# %% Prompt construction
def make_prompt(text):
    """
    Builds the prompt sent to the model from a dataset `text` field.

    For the Hugging Face model the text is used unchanged. Otherwise the text is
    split on newlines and rearranged: line 2 is used as the title, line 3 as the
    description (with all digits removed), line 5 as the trailing price stub, and
    the question on line 0 is reworded.

    Args:
        text (str): The datapoint's prompt text.
    Returns:
        str: The prompt to run inference on.
    Raises:
        IndexError: If the text has fewer than six lines (non-"HF" mode only).
    """
    if ARTIFCAT_LOCATTION == "HF":
        return text

    lines = text.split("\n")
    question = lines[0].replace(
        "How much does this cost to the nearest dollar?",
        "What is the price of this item?",
    )
    title = lines[2]
    # Strip digits from the description
    description = re.sub(r'\d', '', lines[3])
    price_stub = lines[5]
    return f"{title}\n{description}\nQuestion: {question}\n\n{price_stub}"
class Tester:
    """
    Runs a price predictor over the first `size` datapoints and reports its accuracy.

    For each datapoint it records the guess, the true price, the absolute error and
    the squared log error, prints a colour-coded line, and finally draws a
    guess-vs-truth scatter plot titled with the average error, RMSLE and hit rate.

    Args:
        predictor: Callable taking a prompt string and returning a numeric price.
        data: Indexable dataset whose items have "text" and "price" fields.
        title: Chart title prefix; defaults to a prettified predictor name.
        size: Maximum number of datapoints to evaluate.
    """

    def __init__(self, predictor, data, title=None, size=250):
        self.predictor = predictor
        self.data = data
        self.title = title or predictor.__name__.replace("_", " ").title()
        self.size = size
        self.guesses = []
        self.truths = []
        self.errors = []
        self.sles = []
        self.colors = []

    def color_for(self, error, truth):
        """Classify an error as "green", "orange" or "red" by absolute and relative size."""
        # A zero truth has no meaningful relative error; treat it as infinite so that
        # only the absolute thresholds apply instead of raising ZeroDivisionError.
        relative_error = error / truth if truth else float("inf")
        if error < 40 or relative_error < 0.2:
            return "green"
        elif error < 80 or relative_error < 0.4:
            return "orange"
        else:
            return "red"

    def run_datapoint(self, i):
        """Predict datapoint `i`, record its metrics and print a one-line summary."""
        datapoint = self.data[i]

        base_prompt = datapoint["text"]
        prompt = make_prompt(base_prompt)

        guess = self.predictor(prompt)

        truth = datapoint["price"]
        error = abs(guess - truth)
        # Squared log error; the +1 keeps log() defined for a zero price
        log_error = math.log(truth+1) - math.log(guess+1)
        sle = log_error ** 2
        color = self.color_for(error, truth)
        # Short item label: first 20 characters of the second paragraph of the text
        title = datapoint["text"].split("\n\n")[1][:20] + "..."
        self.guesses.append(guess)
        self.truths.append(truth)
        self.errors.append(error)
        self.sles.append(sle)
        self.colors.append(color)
        print(f"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}")

    def chart(self, title):
        """Scatter the guesses against the truths, with the y=x line for reference."""
        plt.figure(figsize=(12, 8))
        max_val = max(max(self.truths), max(self.guesses))
        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)
        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)
        plt.xlabel('Ground Truth')
        plt.ylabel('Model Estimate')
        plt.xlim(0, max_val)
        plt.ylim(0, max_val)
        plt.title(title)
        plt.show()

    def report(self):
        """Compute the summary metrics over the evaluated datapoints and chart them."""
        # Average over what was actually evaluated, which can be fewer than
        # `self.size` when the dataset is smaller than the requested size.
        count = len(self.errors)
        if count == 0:
            print("No datapoints were evaluated; nothing to report.")
            return
        average_error = sum(self.errors) / count
        rmsle = math.sqrt(sum(self.sles) / count)
        hits = sum(1 for color in self.colors if color=="green")
        title = f"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/count*100:.1f}%"
        self.chart(title)

    def run(self):
        """Evaluate up to `size` datapoints, then report."""
        # Never index past the end of the dataset
        for i in range(min(self.size, len(self.data))):
            self.run_datapoint(i)
        self.report()

    @classmethod
    def test(cls, function, data):
        """Convenience entry point: build a Tester with default settings and run it."""
        cls(function, data).run()
class Search_K:
    """
    Search for the optimal 'k' value.

    For every datapoint the predictor returns the valid prices (and their
    probabilities) found among the top tokens, ordered by decreasing probability.
    Here `k` means "use the first k of those valid prices": for each k from 1 to
    `max_k` the probability-weighted price is compared with the truth, and the k
    with the lowest average absolute error over all datapoints is reported.

    Args:
        predictor: Callable taking a prompt and returning (prices, probabilities).
        data: Indexable dataset whose items have "text" and "price" fields.
        title: Label used in the report; defaults to a prettified predictor name.
        size: Maximum number of datapoints to evaluate.
    """
    def __init__(self, predictor, data, title=None, size=250):
        self.predictor = predictor
        self.data = data
        self.title = title or predictor.__name__.replace("_", " ").title()
        self.size = size
        self.truths = []

        # One list per datapoint holding the error for k = 1..max_k
        self.all_k_errors = []
        self.max_k = K_SEARCH_LIMIT

        # Store the list of probabilities for each inference
        self.all_prob_lists = []
        # Store the standard deviation of probs for each inference
        self.prob_std_devs = []

    def color_for(self, error, truth):
        """Classify an error as "green", "orange" or "red" by absolute and relative size."""
        # A zero truth has no meaningful relative error; treat it as infinite so that
        # only the absolute thresholds apply instead of raising ZeroDivisionError.
        relative_error = error / truth if truth else float("inf")
        if error < 40 or relative_error < 0.2:
            return "green"
        elif error < 80 or relative_error < 0.4:
            return "orange"
        else:
            return "red"

    def run_datapoint(self, i):
        """Run inference on datapoint `i` and record its error for every k."""
        datapoint = self.data[i]
        base_prompt = datapoint["text"]
        prompt = make_prompt(base_prompt)
        truth = datapoint["price"]
        self.truths.append(truth)

        # Get the raw lists of prices and probabilities
        prices, probabilities = self.predictor(prompt)

        self.all_prob_lists.append(probabilities)
        # Spread (std dev) of this probability list; 0 when there is none
        self.prob_std_devs.append(np.std(probabilities) if probabilities else 0.0)

        if not prices:
            print(f"{i+1}: No valid prices found. Truth: ${truth:,.2f}.")
            # With no candidate the guess is 0 for every k
            error = np.abs(0 - truth)
            self.all_k_errors.append([error] * self.max_k)
            return

        # Slicing beyond the number of valid prices yields the same estimate, so only
        # compute the distinct prefixes and repeat the last error for the remaining k.
        distinct_k = min(len(prices), self.max_k)
        errors_for_this_datapoint = [
            np.abs(calculate_weighted_price(prices[:k], probabilities[:k]) - truth)
            for k in range(1, distinct_k + 1)
        ]
        errors_for_this_datapoint.extend(
            [errors_for_this_datapoint[-1]] * (self.max_k - distinct_k)
        )

        # Store the list of errors (for k=1 to max_k)
        self.all_k_errors.append(errors_for_this_datapoint)

        # Print a summary for this datapoint
        title = datapoint["text"].split("\n\n")[1][:20] + "..."

        # Sample three points of the curve: k=1, k=20 (capped at max_k) and k=max_k
        k_mid = min(20, self.max_k)
        k_1_err = errors_for_this_datapoint[0]
        k_mid_err = errors_for_this_datapoint[k_mid - 1]
        k_max_err = errors_for_this_datapoint[-1]

        color = self.color_for(k_1_err, truth)
        print(f"{COLOR_MAP[color]}{i+1}: Truth: ${truth:,.2f}. "
              f"Errors (k=1, k={k_mid}, k={self.max_k}): "
              f"(${k_1_err:,.2f}, ${k_mid_err:,.2f}, ${k_max_err:,.2f}) "
              f"Item: {title}{RESET}")

    def plot_k_vs_error(self, k_values, avg_errors_by_k, best_k, min_error):
        """
        Plots the Average Error vs. k and saves it to k_vs_error_plot.png
        """
        plt.figure(figsize=(12, 8))
        plt.plot(k_values, avg_errors_by_k, label='Average Error vs. k')

        # Highlight the best k
        plt.axvline(x=best_k, color='red', linestyle='--',
                    label=f'Best k = {best_k} (Avg Error: ${min_error:,.2f})')

        plt.xlabel('Number of Top Probabilities/Prices (k)')
        plt.ylabel('Average Absolute Error ($)')
        plt.title(f'Optimal k Analysis for {self.title}')
        plt.legend()
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
        # Set x-axis to start at 1
        plt.xlim(left=1)
        plt.savefig("k_vs_error_plot.png")
        plt.show()

    def plot_probability_spread(self, idx_min_std, idx_med_std, idx_max_std):
        """
        Plots the probability lists of three example inferences (lowest, median and
        highest spread) as jittered strips and saves them to spread_examples_plot.png
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharey=True)
        fig.suptitle('Probability Distribution Spread Analysis (Examples)', fontsize=16)

        def plot_strip(ax, probs, title):
            if not probs:
                ax.set_title(f"{title}\n(No probabilities found)")
                return
            # Horizontal jitter only separates overlapping points; it carries no meaning
            jitter = np.random.normal(0, 0.01, size=len(probs))
            ax.scatter(jitter, probs, alpha=0.5, s=10)
            ax.set_title(title)
            ax.set_xlabel("Jitter")
            ax.get_xaxis().set_ticks([])

        examples = [
            (idx_min_std, "Lowest Spread"),
            (idx_med_std, "Median Spread"),
            (idx_max_std, "Highest Spread"),
        ]
        for ax, (idx, label) in zip(axes, examples):
            plot_strip(ax, self.all_prob_lists[idx],
                       f'Inference {idx} ({label})\nStd Dev: {self.prob_std_devs[idx]:.6f}')
        # The y axis is shared, so only the first panel is labelled
        axes[0].set_ylabel('Probability')

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig("spread_examples_plot.png")
        plt.show()

    def plot_all_std_devs(self):
        """
        Plots a histogram and a line plot of the standard deviation
        for ALL inferences and saves them to all_std_devs_plot.png
        """
        if not self.prob_std_devs:
            print("No probability spreads recorded, skipping all-std plot.")
            return

        # Create a figure with two subplots
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12))
        fig.suptitle('Full Spread Analysis for All Inferences', fontsize=16)

        # --- Plot Histogram ---
        ax1.hist(self.prob_std_devs, bins=50, edgecolor='black')
        ax1.set_title('Distribution of Probability Standard Deviations')
        ax1.set_xlabel('Standard Deviation')
        ax1.set_ylabel('Frequency (Number of Inferences)')

        mean_std = np.mean(self.prob_std_devs)
        ax1.axvline(mean_std, color='red', linestyle='--',
                    label=f'Mean Std Dev: {mean_std:.6f}')
        ax1.legend()

        # --- Plot Line Plot ---
        last_index = len(self.prob_std_devs) - 1
        ax2.plot(self.prob_std_devs, marker='o', linestyle='-',
                 markersize=3, alpha=0.7, label='Std Dev per Inference')
        ax2.set_title('Probability Standard Deviation per Inference')
        # Label the real index range instead of assuming 250 inferences
        ax2.set_xlabel(f'Inference Index (0 to {last_index})')
        ax2.set_ylabel('Standard Deviation')

        ax2.axhline(mean_std, color='red', linestyle='--',
                    label=f'Mean Std Dev: {mean_std:.6f}')
        ax2.legend()
        ax2.set_xlim(0, last_index)

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig("all_std_devs_plot.png")
        plt.show()

    def report(self):
        """
        Prints the optimal-k summary and calls all three plotting functions.

        Returns:
            int or None: The k (1-based) with the lowest average absolute error,
            or None when no datapoint has been evaluated.
        """
        if not self.all_k_errors:
            print("\nNo data to report on. Exiting.")
            return None

        # Optimal k Analysis ---
        errors_array = np.array(self.all_k_errors)
        avg_errors_by_k = np.mean(errors_array, axis=0)
        best_k_index = int(np.argmin(avg_errors_by_k))
        min_error = avg_errors_by_k[best_k_index]
        # Index 0 holds k=1; return a built-in int so callers can use it directly
        best_k = best_k_index + 1

        print("\n" + "="*40)
        print("--- Optimal k Analysis Report ---")
        print(f"Model: {self.title}")
        print(f"Inferences Run: {len(self.all_k_errors)}")
        print(f"Analyzed k from 1 to {self.max_k}")
        print("===================================")
        print(f"==> Best k: {best_k}")
        print(f"==> Minimum Average Error: ${min_error:,.2f}")
        print("="*40 + "\n")

        k_values = np.arange(1, self.max_k + 1)
        self.plot_k_vs_error(k_values, avg_errors_by_k, best_k, min_error)

        # Probability Spread Analysis ---
        if not self.prob_std_devs:
            print("\nNo probability spreads recorded, skipping spread plots.")
            # The search result is still valid even though the spread plots are skipped
            return best_k

        print("\n" + "="*40)
        print("--- Probability Spread Analysis ---")

        # Find indices for examples
        std_sorted_indices = np.argsort(self.prob_std_devs)
        idx_min_std = std_sorted_indices[0]
        idx_med_std = std_sorted_indices[len(std_sorted_indices) // 2]
        idx_max_std = std_sorted_indices[-1]

        print(f"Lowest spread (std): {self.prob_std_devs[idx_min_std]:.6f} (Inference {idx_min_std})")
        print(f"Median spread (std): {self.prob_std_devs[idx_med_std]:.6f} (Inference {idx_med_std})")
        print(f"Highest spread (std): {self.prob_std_devs[idx_max_std]:.6f} (Inference {idx_max_std})")
        print("="*40 + "\n")

        # Plot example spreads
        self.plot_probability_spread(idx_min_std, idx_med_std, idx_max_std)

        # Plot all spreads
        self.plot_all_std_devs()

        return best_k

    def run(self):
        """Evaluate up to `size` datapoints and return the best k from the report."""
        # Never index past the end of the dataset
        for i in range(min(self.size, len(self.data))):
            self.run_datapoint(i)
        return self.report()

    @classmethod
    def test(cls, function, data):
        """Convenience entry point: run a search with default settings and return the best k."""
        return cls(function, data).run()
# %% Search best K
search_k = Search_K(get_top_k_predictions, test, title=f"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}" if ARTIFCAT_LOCATTION=="WB" else None)
best_k = search_k.run()

# %% Predictor using the best k
# Fail with a clear message if the search produced nothing to tune with.
if best_k is None:
    raise ValueError("The k search returned no result; run the search cell on a non-empty dataset first.")
# Use a built-in int regardless of the integer type the search returned
top_K = int(best_k)

def improved_model_predict(prompt, device="cuda"):
    """
    Predicts a price as the probability-weighted average of the model's best guesses.

    The candidates come from `get_top_k_predictions`, and only the first `top_K`
    valid prices are averaged. This is the same meaning of k that `Search_K`
    optimises (the first k valid prices among the top K_SEARCH_LIMIT tokens), so
    the tuned value is applied consistently; taking the top `top_K` raw tokens and
    filtering afterwards would average fewer prices than the search evaluated.

    Args:
        prompt (str): Text to feed to the model.
        device (str): Device the input tensors are moved to.
    Returns:
        float: The estimated price, or 0.0 when no valid price candidate is found.
    """
    prices, probabilities = get_top_k_predictions(prompt, device=device)

    # Drop non-finite values before slicing: float() parses tokens such as "inf",
    # and a single infinite price would make the whole average infinite.
    candidates = [
        (price, probability)
        for price, probability in zip(prices, probabilities)
        if math.isfinite(price)
    ][:top_K]

    if not candidates:
        # If no valid prices were found, return 0.0
        return 0.0

    kept_prices, kept_probabilities = zip(*candidates)
    # calculate_weighted_price falls back to a plain mean when the weights sum to
    # zero, where np.average alone would raise ZeroDivisionError.
    return float(calculate_weighted_price(kept_prices, kept_probabilities))

# %% Try a single prediction
prompt=make_prompt(test[80]['text'])
print(prompt)

improved_model_predict(prompt)

# %% Run Estimate vs Ground Truth
tester = Tester(improved_model_predict, test, title=f"{MODEL_ARTIFACT_NAME}:{REVISION_TAG}" if ARTIFCAT_LOCATTION=="WB" else None)
tester.run()