diff --git a/week6/community-contributions/hopeogbons/classifier_tester.py b/week6/community-contributions/hopeogbons/classifier_tester.py
new file mode 100644
index 0000000..c9662f7
--- /dev/null
+++ b/week6/community-contributions/hopeogbons/classifier_tester.py
@@ -0,0 +1,123 @@
+"""
+Classification Tester for Banking Intent Model
+Evaluates model accuracy on intent classification
+"""
+
+import matplotlib.pyplot as plt
+from collections import Counter
+from banking_intents import get_intent
+
+GREEN = "\033[92m"
+RED = "\033[91m"
+RESET = "\033[0m"
+
+
+class ClassifierTester:
+    """Test framework for classification models"""
+
+    def __init__(self, predictor, data, title=None, size=100):
+        self.predictor = predictor
+        self.data = data
+        self.title = title or predictor.__name__.replace("_", " ").title()
+        self.size = min(size, len(data))
+        self.predictions = []
+        self.actuals = []
+        self.correct = 0
+        self.incorrect = 0
+
+    def run_datapoint(self, i):
+        """Test a single example"""
+        item = self.data[i]
+
+        # Get prediction
+        predicted_intent = self.predictor(item)
+        actual_intent = get_intent(item['label'])
+
+        # Check if correct
+        is_correct = predicted_intent == actual_intent
+
+        if is_correct:
+            self.correct += 1
+            color = GREEN
+            status = "āœ“"
+        else:
+            self.incorrect += 1
+            color = RED
+            status = "āœ—"
+
+        self.predictions.append(predicted_intent)
+        self.actuals.append(actual_intent)
+
+        # Print result
+        query = item['text'][:60] + "..." if len(item['text']) > 60 else item['text']
+        print(f"{color}{status} {i+1}: {query}")
+        print(f" Predicted: {predicted_intent} | Actual: {actual_intent}{RESET}")
+
+    def chart(self):
+        """Visualize top confusion pairs"""
+        # Find misclassifications
+        errors = {}
+        for pred, actual in zip(self.predictions, self.actuals):
+            if pred != actual:
+                pair = f"{actual} → {pred}"
+                errors[pair] = errors.get(pair, 0) + 1
+
+        if not errors:
+            print("\nšŸŽ‰ Perfect accuracy - no confusion to plot!")
+            return
+
+        # Plot top 10 confusions
+        top_errors = sorted(errors.items(), key=lambda x: x[1], reverse=True)[:10]
+
+        if top_errors:
+            labels = [pair for pair, _ in top_errors]
+            counts = [count for _, count in top_errors]
+
+            plt.figure(figsize=(12, 6))
+            plt.barh(labels, counts, color='coral')
+            plt.xlabel('Count')
+            plt.title('Top 10 Confusion Pairs (Actual → Predicted)')
+            plt.tight_layout()
+            plt.show()
+
+    def report(self):
+        """Print final metrics and chart"""
+        accuracy = (self.correct / self.size) * 100
+
+        print("\n" + "="*70)
+        print(f"MODEL: {self.title}")
+        print(f"TESTED: {self.size} examples")
+        print(f"CORRECT: {self.correct} ({accuracy:.1f}%)")
+        print(f"INCORRECT: {self.incorrect}")
+        print("="*70)
+
+        # Show most common errors
+        if self.incorrect > 0:
+            print("\nMost Common Errors:")
+            error_pairs = [(self.actuals[i], self.predictions[i])
+                           for i in range(len(self.actuals))
+                           if self.actuals[i] != self.predictions[i]]
+            error_counts = Counter(error_pairs).most_common(5)
+
+            for (actual, pred), count in error_counts:
+                print(f" {actual} → {pred}: {count} times")
+
+        # Chart
+        self.chart()
+
+        return accuracy
+
+    def run(self):
+        """Run the complete evaluation"""
+        print(f"Testing {self.title} on {self.size} examples...\n")
+
+        for i in range(self.size):
+            self.run_datapoint(i)
+
+        return self.report()
+
+    @classmethod
+    def test(cls, function, data, size=100):
+        """Convenience method to test a predictor function"""
+        return cls(function, data, size=size).run()
+
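
For context, a minimal usage sketch of the class added above (not part of the diff). The predictor, the toy datapoints, and their label values are hypothetical placeholders; the only assumption taken from the diff itself is that each datapoint is a dict with 'text' and 'label' keys and that banking_intents.get_intent maps a label to an intent name.

# Usage sketch -- illustrative only, not part of the contributed file.
from classifier_tester import ClassifierTester
from banking_intents import get_intent


def oracle_predictor(item):
    """Trivial predictor that looks up the true intent, used only to exercise the tester."""
    return get_intent(item['label'])


# Placeholder rows; real data and labels would come from the Banking77-style dataset
# used elsewhere in the project.
test_data = [
    {"text": "My new card still hasn't arrived", "label": 0},
    {"text": "Why was my top-up declined?", "label": 1},
]

# Runs the evaluation loop, prints per-example results, shows the confusion chart,
# and returns the accuracy percentage.
accuracy = ClassifierTester.test(oracle_predictor, test_data, size=len(test_data))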