Week 8 updates
This commit is contained in:
105
week8/agents/frontier_agent.py
Normal file
105
week8/agents/frontier_agent.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# imports
|
||||
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from openai import OpenAI
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from datasets import load_dataset
|
||||
import chromadb
|
||||
from items import Item
|
||||
from testing import Tester
|
||||
from agents.agent import Agent
|
||||
|
||||
|
||||
class FrontierAgent(Agent):
|
||||
|
||||
name = "Frontier Agent"
|
||||
color = Agent.BLUE
|
||||
|
||||
MODEL = "gpt-4o-mini"
|
||||
|
||||
def __init__(self, collection):
|
||||
"""
|
||||
Set up this instance by connecting to OpenAI, to the Chroma Datastore,
|
||||
And setting up the vector encoding model
|
||||
"""
|
||||
self.log("Initializing Frontier Agent")
|
||||
self.openai = OpenAI()
|
||||
self.collection = collection
|
||||
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
||||
self.log("Frontier Agent is ready")
|
||||
|
||||
def make_context(self, similars: List[str], prices: List[float]) -> str:
|
||||
"""
|
||||
Create context that can be inserted into the prompt
|
||||
:param similars: similar products to the one being estimated
|
||||
:param prices: prices of the similar products
|
||||
:return: text to insert in the prompt that provides context
|
||||
"""
|
||||
message = "To provide some context, here are some other items that might be similar to the item you need to estimate.\n\n"
|
||||
for similar, price in zip(similars, prices):
|
||||
message += f"Potentially related product:\n{similar}\nPrice is ${price:.2f}\n\n"
|
||||
return message
|
||||
|
||||
def messages_for(self, description: str, similars: List[str], prices: List[float]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Create the message list to be included in a call to OpenAI
|
||||
With the system and user prompt
|
||||
:param description: a description of the product
|
||||
:param similars: similar products to this one
|
||||
:param prices: prices of similar products
|
||||
:return: the list of messages in the format expected by OpenAI
|
||||
"""
|
||||
system_message = "You estimate prices of items. Reply only with the price, no explanation"
|
||||
user_prompt = self.make_context(similars, prices)
|
||||
user_prompt += "And now the question for you:\n\n"
|
||||
user_prompt += "How much does this cost?\n\n" + description
|
||||
return [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "user", "content": user_prompt},
|
||||
{"role": "assistant", "content": "Price is $"}
|
||||
]
|
||||
|
||||
def find_similars(self, description: str):
|
||||
"""
|
||||
Return a list of items similar to the given one by looking in the Chroma datastore
|
||||
"""
|
||||
self.log("Frontier Agent is performing a RAG search of the Chroma datastore to find 5 similar products")
|
||||
vector = self.model.encode([description])
|
||||
results = self.collection.query(query_embeddings=vector.astype(float).tolist(), n_results=5)
|
||||
documents = results['documents'][0][:]
|
||||
prices = [m['price'] for m in results['metadatas'][0][:]]
|
||||
self.log("Frontier Agent has found similar products")
|
||||
return documents, prices
|
||||
|
||||
def get_price(self, s) -> float:
|
||||
"""
|
||||
A utility that plucks a floating point number out of a string
|
||||
"""
|
||||
s = s.replace('$','').replace(',','')
|
||||
match = re.search(r"[-+]?\d*\.\d+|\d+", s)
|
||||
return float(match.group()) if match else 0.0
|
||||
|
||||
def price(self, description: str) -> float:
|
||||
"""
|
||||
Make a call to OpenAI to estimate the price of the described product,
|
||||
by looking up 5 similar products and including them in the prompt to give context
|
||||
:param description: a description of the product
|
||||
:return: an estimate of the price
|
||||
"""
|
||||
documents, prices = self.find_similars(description)
|
||||
self.log("Frontier Agent is about to call OpenAI with context including 5 similar products")
|
||||
response = self.openai.chat.completions.create(
|
||||
model=self.MODEL,
|
||||
messages=self.messages_for(description, documents, prices),
|
||||
seed=42,
|
||||
max_tokens=5
|
||||
)
|
||||
reply = response.choices[0].message.content
|
||||
result = self.get_price(reply)
|
||||
self.log(f"Frontier Agent completed - predicting ${result:.2f}")
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user