Week 8 updates
This commit is contained in:
37
week8/agents/random_forest_agent.py
Normal file
37
week8/agents/random_forest_agent.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# imports
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import List
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import joblib
|
||||
from agents.agent import Agent
|
||||
|
||||
|
||||
|
||||
class RandomForestAgent(Agent):
|
||||
|
||||
name = "Random Forest Agent"
|
||||
color = Agent.MAGENTA
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize this object by loading in the saved model weights
|
||||
and the SentenceTransformer vector encoding model
|
||||
"""
|
||||
self.log("Random Forest Agent is initializing")
|
||||
self.vectorizer = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
||||
self.model = joblib.load('random_forest_model.pkl')
|
||||
self.log("Random Forest Agent is ready")
|
||||
|
||||
def price(self, description: str) -> float:
|
||||
"""
|
||||
Use a Random Forest model to estimate the price of the described item
|
||||
:param description: the product to be estimated
|
||||
:return: the price as a float
|
||||
"""
|
||||
self.log("Random Forest Agent is starting a prediction")
|
||||
vector = self.vectorizer.encode([description])
|
||||
result = max(0, self.model.predict(vector)[0])
|
||||
self.log(f"Random Forest Agent completed - predicting ${result:.2f}")
|
||||
return result
|
||||
Reference in New Issue
Block a user