changes to prompt, hyperparameters
This commit is contained in:
@@ -65,8 +65,8 @@ class PricePredictionFineTuner:
|
|||||||
print("Pickle files not found. Loading from Hugging Face...")
|
print("Pickle files not found. Loading from Hugging Face...")
|
||||||
self._load_from_huggingface(category)
|
self._load_from_huggingface(category)
|
||||||
|
|
||||||
self.fine_tune_train = self.train[:500]
|
self.fine_tune_train = self.train[:750]
|
||||||
self.fine_tune_validation = self.train[500:600]
|
self.fine_tune_validation = self.train[750:850]
|
||||||
|
|
||||||
print(f"Fine-tuning split: {len(self.fine_tune_train)} train, {len(self.fine_tune_validation)} validation")
|
print(f"Fine-tuning split: {len(self.fine_tune_train)} train, {len(self.fine_tune_validation)} validation")
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ class PricePredictionFineTuner:
|
|||||||
if processed % 1000 == 0:
|
if processed % 1000 == 0:
|
||||||
print(f"Processed {processed:,} items, found {len(items):,} valid items")
|
print(f"Processed {processed:,} items, found {len(items):,} valid items")
|
||||||
|
|
||||||
if len(items) >= 1500:
|
if len(items) >= 2000:
|
||||||
print(f"Collected {len(items)} items, stopping for efficiency")
|
print(f"Collected {len(items)} items, stopping for efficiency")
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -102,8 +102,8 @@ class PricePredictionFineTuner:
|
|||||||
|
|
||||||
print(f"Created {len(items):,} valid Item objects")
|
print(f"Created {len(items):,} valid Item objects")
|
||||||
|
|
||||||
if len(items) < 600:
|
if len(items) < 850:
|
||||||
raise ValueError(f"Not enough valid items found: {len(items)}. Need at least 600.")
|
raise ValueError(f"Not enough valid items found: {len(items)}. Need at least 850.")
|
||||||
|
|
||||||
random.shuffle(items)
|
random.shuffle(items)
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ class PricePredictionFineTuner:
|
|||||||
|
|
||||||
|
|
||||||
def messages_for(self, item: Item) -> List[Dict[str, str]]:
|
def messages_for(self, item: Item) -> List[Dict[str, str]]:
|
||||||
system_message = "You estimate prices of items. Reply only with the price, no explanation"
|
system_message = "You are a price estimation expert. You MUST provide a price estimate for any product described, based on the product details provided. Always respond with '$X.XX' format where X.XX is your best estimate. Never refuse to estimate. Never apologize. Never say you cannot determine the price. Make your best educated guess based on the product description, features, and market knowledge. and as said only reply with the cost nothing else no more comments or words from you just the cost"
|
||||||
user_prompt = item.test_prompt().replace(" to the nearest dollar", "").replace("\n\nPrice is $", "")
|
user_prompt = item.test_prompt().replace(" to the nearest dollar", "").replace("\n\nPrice is $", "")
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@@ -140,7 +140,7 @@ class PricePredictionFineTuner:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def messages_for_testing(self, item: Item) -> List[Dict[str, str]]:
|
def messages_for_testing(self, item: Item) -> List[Dict[str, str]]:
|
||||||
system_message = "You estimate prices of items. Reply only with the price, no explanation"
|
system_message = "You are a price estimation expert. You MUST provide a price estimate for any product described, based on the product details provided. Always respond with '$X.XX' format where $X.XX is your best estimate. Never refuse to estimate. Never apologize. Never say you cannot determine the price. Make your best educated guess based on the product description, features, and market knowledge. and as said only reply with the cost nothing else no more comments or words from you just the cost"
|
||||||
user_prompt = item.test_prompt().replace(" to the nearest dollar", "").replace("\n\nPrice is $", "")
|
user_prompt = item.test_prompt().replace(" to the nearest dollar", "").replace("\n\nPrice is $", "")
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@@ -195,15 +195,15 @@ class PricePredictionFineTuner:
|
|||||||
job = self.client.fine_tuning.jobs.create(
|
job = self.client.fine_tuning.jobs.create(
|
||||||
training_file=train_file_id,
|
training_file=train_file_id,
|
||||||
validation_file=validation_file_id,
|
validation_file=validation_file_id,
|
||||||
model="gpt-4o-mini-2024-07-18",
|
model="gpt-4o-mini-2024-07-18",
|
||||||
seed=42,
|
seed=42,
|
||||||
hyperparameters={
|
hyperparameters={
|
||||||
"n_epochs": 3,
|
"n_epochs": 1,
|
||||||
"learning_rate_multiplier": 0.1,
|
"learning_rate_multiplier": 0.5,
|
||||||
"batch_size": 4
|
"batch_size": 8
|
||||||
},
|
},
|
||||||
integrations=integrations,
|
integrations=integrations,
|
||||||
suffix="pricer-improved"
|
suffix="pricer-v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Fine-tuning job started: {job.id}")
|
print(f"Fine-tuning job started: {job.id}")
|
||||||
@@ -396,11 +396,13 @@ def main():
|
|||||||
print("\nCheck the generated chart for detailed RMSLE metrics!")
|
print("\nCheck the generated chart for detailed RMSLE metrics!")
|
||||||
|
|
||||||
print("\nPrice prediction fine-tuning process completed!")
|
print("\nPrice prediction fine-tuning process completed!")
|
||||||
print("\nFollows reference implementation exactly:")
|
print("\nImproved configuration to prevent overfitting:")
|
||||||
print(" Uses pickle files (train.pkl, test.pkl)")
|
print(" Uses pickle files (train.pkl, test.pkl)")
|
||||||
print(" 500 training examples, 100 validation examples")
|
print(" 750 training examples, 100 validation examples")
|
||||||
print(" 3 epochs with balanced learning rate (0.1)")
|
print(" 1 epoch to prevent overfitting")
|
||||||
print(" Batch size 4 for stable training")
|
print(" Learning rate: 0.5 (higher for better generalization)")
|
||||||
|
print(" Batch size: 8 (larger for stability)")
|
||||||
|
print(" Assertive system prompt (forces predictions)")
|
||||||
print(" Proper RMSLE evaluation using Tester class")
|
print(" Proper RMSLE evaluation using Tester class")
|
||||||
print(" Weights and Biases integration")
|
print(" Weights and Biases integration")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user