Merge pull request #345 from zoya-hammad/community-contributions
Added ensemble agent updated with an XG Boost Agent to community-contributions.
This commit is contained in:
@@ -0,0 +1,541 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "40d49349-faaa-420c-9b65-0bdc9edfabce",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# The Price is Right\n",
|
||||
"\n",
|
||||
"## Finishing off with Random Forests, XG Boost & Ensemble"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6cd8b15e-f88a-470d-a9a6-b6370effaff9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install xgboost"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fbcdfea8-7241-46d7-a771-c0381a3e7063",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# imports\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import math\n",
|
||||
"import json\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import random\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"from huggingface_hub import login\n",
|
||||
"import numpy as np\n",
|
||||
"import pickle\n",
|
||||
"from openai import OpenAI\n",
|
||||
"from sentence_transformers import SentenceTransformer\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"import chromadb\n",
|
||||
"from items import Item\n",
|
||||
"from testing import Tester\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.ensemble import RandomForestRegressor\n",
|
||||
"from sklearn.linear_model import LinearRegression\n",
|
||||
"from sklearn.metrics import mean_squared_error, r2_score\n",
|
||||
"import joblib\n",
|
||||
"import xgboost as xgb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e6e88bd1-f89c-4b98-92fa-aa4bc1575bca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# CONSTANTS\n",
|
||||
"\n",
|
||||
"DB = \"products_vectorstore\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "98666e73-938e-469d-8987-e6e55ba5e034",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# environment\n",
|
||||
"\n",
|
||||
"load_dotenv(override=True)\n",
|
||||
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
|
||||
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dc696493-0b6f-48aa-9fa8-b1ae0ecaf3cd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load in the test pickle file:\n",
|
||||
"\n",
|
||||
"with open('test.pkl', 'rb') as file:\n",
|
||||
" test = pickle.load(file)\n",
|
||||
" \n",
|
||||
"# training data is already in Chroma"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d26a1104-cd11-4361-ab25-85fb576e0582",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"client = chromadb.PersistentClient(path=DB)\n",
|
||||
"collection = client.get_or_create_collection('products')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e00b82a9-a8dc-46f1-8ea9-2f07cbc8e60d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
|
||||
"vectors = np.array(result['embeddings'])\n",
|
||||
"documents = result['documents']\n",
|
||||
"prices = [metadata['price'] for metadata in result['metadatas']]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bf6492cb-b11a-4ad5-859b-a71a78ffb949",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Random Forest\n",
|
||||
"\n",
|
||||
"We will now train a Random Forest model.\n",
|
||||
"\n",
|
||||
"Can you spot the difference from what we did in Week 6? In week 6 we used the word2vec model to form vectors; this time we'll use the vectors we already have in Chroma, from the SentenceTransformer model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "48894777-101f-4fe5-998c-47079407f340",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This next line takes an hour on my M1 Mac!\n",
|
||||
"\n",
|
||||
"rf_model = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)\n",
|
||||
"rf_model.fit(vectors, prices)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "90a07dde-6f57-4488-8d08-e8e5646754e7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"n_job = -1 means it is using every core"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "62eb7ddf-e1da-481e-84c6-1256547566bd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save the model to a file\n",
|
||||
"\n",
|
||||
"joblib.dump(rf_model, 'random_forest_model.pkl')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d281dc5e-761e-4a5e-86b3-29d9c0a33d4a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load it back in again\n",
|
||||
"\n",
|
||||
"rf_model = joblib.load('random_forest_model.pkl')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "23760bf5-fe52-473d-bfbe-def6b7a67a77",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# XG Boost Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c65dcfb9-d2c1-431c-843d-c5908bc39e3f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dmatrix = xgb.DMatrix(vectors, label=prices)\n",
|
||||
"\n",
|
||||
"params = {\n",
|
||||
" \"objective\": \"reg:squarederror\",\n",
|
||||
" \"max_depth\": 6,\n",
|
||||
" \"learning_rate\": 0.1,\n",
|
||||
" \"nthread\": -1,\n",
|
||||
" \"verbosity\": 1,\n",
|
||||
" \"subsample\": 0.8,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"model = xgb.train(params, train_dmatrix, num_boost_round=100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a6980ca7-fc38-482c-8346-80c435058886",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"joblib.dump(model,'xg_boost_model.pkl')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0605f48-04f8-44a3-8d8c-c7be4cd840b2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"xgb_model = joblib.load('xg_boost_model.pkl')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "22d10315-2b11-43b0-b042-679a2814dea1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Agents"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5d438dec-8e5b-4e60-bb6f-c3f82e522dd9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from agents.specialist_agent import SpecialistAgent\n",
|
||||
"from agents.frontier_agent import FrontierAgent\n",
|
||||
"from agents.random_forest_agent import RandomForestAgent\n",
|
||||
"from agents.xg_boost_agent import XGBoostAgent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "afc39369-b97b-4a90-b17e-b20ef501d3c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"specialist = SpecialistAgent()\n",
|
||||
"frontier = FrontierAgent(collection)\n",
|
||||
"random_forest = RandomForestAgent()\n",
|
||||
"xg_boost = XGBoostAgent()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8e2d0d0a-8bb8-4b39-b046-322828c39244",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def description(item):\n",
|
||||
" return item.prompt.split(\"to the nearest dollar?\\n\\n\")[1].split(\"\\n\\nPrice is $\")[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bfe0434f-b29e-4cc0-bad9-b07624665727",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def rf(item):\n",
|
||||
" return random_forest.price(description(item))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cdf233ec-264f-4b34-9f2b-27c39692137b",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Tester.test(rf, test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "192b94ac-37d0-4569-bc7c-8fc4f92d129b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def xg_b(item):\n",
|
||||
" return xg_boost.price(description(item))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a3fa01c2-42d9-4ce7-ae36-1d874a0003c1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"xg_b(test[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9183aab7-0586-4d43-b212-c40442c7ab34",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Tester.test(xg_b, test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0045825e-2df0-429a-8ebb-2617517a2e75",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Moving towards the ensemble model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f759bd2-7a7e-4c1a-80a0-e12470feca89",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"product = \"Quadcast HyperX condenser mic for high quality audio for podcasting\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e44dbd25-fb95-4b6b-bbbb-8da5fc817105",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(specialist.price(product))\n",
|
||||
"print(frontier.price(product))\n",
|
||||
"print(random_forest.price(product))\n",
|
||||
"print(xg_boost.price(product))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1779b353-e2bb-4fc7-be7c-93057e4d688a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"specialists = []\n",
|
||||
"frontiers = []\n",
|
||||
"random_forests = []\n",
|
||||
"xg_boosts = []\n",
|
||||
"prices = []\n",
|
||||
"\n",
|
||||
"for item in tqdm(test[1000:1250]):\n",
|
||||
" text = description(item)\n",
|
||||
" specialists.append(specialist.price(text))\n",
|
||||
" frontiers.append(frontier.price(text))\n",
|
||||
" random_forests.append(random_forest.price(text))\n",
|
||||
" xg_boosts.append(xg_boost.price(text))\n",
|
||||
" prices.append(item.price)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f0bca725-4e34-405b-8d90-41d67086a25d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mins = [min(s,f,r,x) for s,f,r,x in zip(specialists, frontiers, random_forests, xg_boosts)]\n",
|
||||
"maxes = [max(s,f,r,x) for s,f,r,x in zip(specialists, frontiers, random_forests, xg_boosts)]\n",
|
||||
"\n",
|
||||
"X = pd.DataFrame({\n",
|
||||
" 'Specialist': specialists,\n",
|
||||
" 'Frontier': frontiers,\n",
|
||||
" 'RandomForest': random_forests,\n",
|
||||
" 'XGBoost' : xg_boosts,\n",
|
||||
" 'Min': mins,\n",
|
||||
" 'Max': maxes,\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"# Convert y to a Series\n",
|
||||
"y = pd.Series(prices)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "baac4947-02d8-4d12-82ed-9ace3c0bee39",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Train a Linear Regression - current\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
"lr = LinearRegression()\n",
|
||||
"lr.fit(X, y)\n",
|
||||
"\n",
|
||||
"feature_columns = X.columns.tolist()\n",
|
||||
"\n",
|
||||
"for feature, coef in zip(feature_columns, lr.coef_):\n",
|
||||
" print(f\"{feature}: {coef:.2f}\")\n",
|
||||
"print(f\"Intercept={lr.intercept_:.2f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "702de4cb-2311-4753-9c05-f3a0fa7e9990",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Train a Linear Regression - old vals w/o xg\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
"lr = LinearRegression()\n",
|
||||
"lr.fit(X, y)\n",
|
||||
"\n",
|
||||
"feature_columns = X.columns.tolist()\n",
|
||||
"\n",
|
||||
"for feature, coef in zip(feature_columns, lr.coef_):\n",
|
||||
" print(f\"{feature}: {coef:.2f}\")\n",
|
||||
"print(f\"Intercept={lr.intercept_:.2f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0bdf6e68-28a3-4ed2-b17e-de0ede923d34",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"joblib.dump(lr, 'ensemble_model.pkl')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e762441a-9470-4dd7-8a8f-ec0430e908c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from agents.ensemble_agent import EnsembleAgent\n",
|
||||
"ensemble = EnsembleAgent(collection)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a29f03c-8010-43b7-ae7d-1bc85ca6e8e2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ensemble.price(product) #old val"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "13dbf002-eba6-4c7a-898f-d697f68ca28e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ensemble.price(product)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e6a5e226-a508-43d5-aa42-cefbde72ffdf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def ensemble_pricer(item):\n",
|
||||
" return max(0,ensemble.price(description(item)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8397b1ef-2ea3-4af8-bb34-36594e0600cc",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Tester.test(ensemble_pricer, test) #old "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0d26c9ff-994b-4799-af51-09d00ddc0c06",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Tester.test(ensemble_pricer, test)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
import pandas as pd
|
||||
from sklearn.linear_model import LinearRegression
|
||||
import joblib
|
||||
|
||||
from agents.agent import Agent
|
||||
from agents.specialist_agent import SpecialistAgent
|
||||
from agents.frontier_agent import FrontierAgent
|
||||
from agents.random_forest_agent import RandomForestAgent
|
||||
from agents.xg_boost_agent import XGBoostAgent
|
||||
|
||||
class EnsembleAgent(Agent):
|
||||
|
||||
name = "Ensemble Agent"
|
||||
color = Agent.YELLOW
|
||||
|
||||
def __init__(self, collection):
|
||||
"""
|
||||
Create an instance of Ensemble, by creating each of the models
|
||||
And loading the weights of the Ensemble
|
||||
"""
|
||||
self.log("Initializing Ensemble Agent")
|
||||
self.specialist = SpecialistAgent()
|
||||
self.frontier = FrontierAgent(collection)
|
||||
self.random_forest = RandomForestAgent()
|
||||
self.xg_boost = XGBoostAgent()
|
||||
self.model = joblib.load('ensemble_model.pkl')
|
||||
self.log("Ensemble Agent is ready")
|
||||
|
||||
def price(self, description: str) -> float:
|
||||
"""
|
||||
Run this ensemble model
|
||||
Ask each of the models to price the product
|
||||
Then use the Linear Regression model to return the weighted price
|
||||
:param description: the description of a product
|
||||
:return: an estimate of its price
|
||||
"""
|
||||
self.log("Running Ensemble Agent - collaborating with specialist, frontier, xg boost and random forest agents")
|
||||
specialist = self.specialist.price(description)
|
||||
frontier = self.frontier.price(description)
|
||||
random_forest = self.random_forest.price(description)
|
||||
xg_boost = self.xg_boost.price(description)
|
||||
X = pd.DataFrame({
|
||||
'Specialist': [specialist],
|
||||
'Frontier': [frontier],
|
||||
'RandomForest': [random_forest],
|
||||
'XGBoost' : [xg_boost],
|
||||
'Min': [min(specialist, frontier, random_forest, xg_boost)],
|
||||
'Max': [max(specialist, frontier, random_forest, xg_boost)],
|
||||
})
|
||||
y = max(0, self.model.predict(X)[0])
|
||||
self.log(f"Ensemble Agent complete - returning ${y:.2f}")
|
||||
return y
|
||||
@@ -0,0 +1,46 @@
|
||||
# imports
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import joblib
|
||||
from agents.agent import Agent
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
|
||||
|
||||
class XGBoostAgent(Agent):
|
||||
|
||||
name = "XG Boost Agent"
|
||||
color = Agent.BRIGHT_MAGENTA
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize this object by loading in the saved model weights
|
||||
and the SentenceTransformer vector encoding model
|
||||
"""
|
||||
self.log("XG Boost Agent is initializing")
|
||||
self.vectorizer = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
||||
self.model = joblib.load('xg_boost_model.pkl')
|
||||
self.log("XG Boost Agent is ready")
|
||||
|
||||
def price(self, description: str) -> float:
|
||||
"""
|
||||
Use an XG Boost model to estimate the price of the described item
|
||||
:param description: the product to be estimated
|
||||
:return: the price as a float
|
||||
"""
|
||||
self.log("XG Boost Agent is starting a prediction")
|
||||
vector = self.vectorizer.encode([description])
|
||||
vector = vector.reshape(1, -1)
|
||||
# Convert the vector to DMatrix
|
||||
dmatrix = xgb.DMatrix(vector)
|
||||
# Predict the price using the model
|
||||
result = max(0, self.model.predict(dmatrix)[0])
|
||||
self.log(f"XG Boost Agent completed - predicting ${result:.2f}")
|
||||
return result
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user