Add classifier testing framework for Banking Intent Model

This commit introduces a new Python module, classifier_tester.py, which provides a framework for evaluating the accuracy of intent classification models. The module includes methods for testing individual data points, reporting metrics, and visualizing confusion pairs, extending the testing capabilities of the Banking77 application.
Author: Hope Ogbons
Date:   2025-10-31 03:19:59 +01:00
parent  8368944a43
commit  3414454f43

classifier_tester.py

@@ -0,0 +1,123 @@
"""
Classification Tester for Banking Intent Model
Evaluates model accuracy on intent classification
"""
import matplotlib.pyplot as plt
from collections import Counter
from banking_intents import get_intent

GREEN = "\033[92m"
RED = "\033[91m"
RESET = "\033[0m"


class ClassifierTester:
    """Test framework for classification models"""

    def __init__(self, predictor, data, title=None, size=100):
        self.predictor = predictor
        self.data = data
        self.title = title or predictor.__name__.replace("_", " ").title()
        self.size = min(size, len(data))
        self.predictions = []
        self.actuals = []
        self.correct = 0
        self.incorrect = 0

    def run_datapoint(self, i):
        """Test a single example"""
        item = self.data[i]

        # Get prediction
        predicted_intent = self.predictor(item)
        actual_intent = get_intent(item['label'])

        # Check if correct
        is_correct = predicted_intent == actual_intent
        if is_correct:
            self.correct += 1
            color = GREEN
            status = "✓"
        else:
            self.incorrect += 1
            color = RED
            status = "✗"

        self.predictions.append(predicted_intent)
        self.actuals.append(actual_intent)

        # Print result
        query = item['text'][:60] + "..." if len(item['text']) > 60 else item['text']
        print(f"{color}{status} {i+1}: {query}")
        print(f"   Predicted: {predicted_intent} | Actual: {actual_intent}{RESET}")

    def chart(self):
        """Visualize top confusion pairs"""
        # Find misclassifications
        errors = {}
        for pred, actual in zip(self.predictions, self.actuals):
            if pred != actual:
                pair = f"{actual} → {pred}"
                errors[pair] = errors.get(pair, 0) + 1

        if not errors:
            print("\n🎉 Perfect accuracy - no confusion to plot!")
            return

        # Plot top 10 confusions
        top_errors = sorted(errors.items(), key=lambda x: x[1], reverse=True)[:10]
        labels = [pair for pair, _ in top_errors]
        counts = [count for _, count in top_errors]
        plt.figure(figsize=(12, 6))
        plt.barh(labels, counts, color='coral')
        plt.xlabel('Count')
        plt.title('Top 10 Confusion Pairs (Actual → Predicted)')
        plt.tight_layout()
        plt.show()

    def report(self):
        """Print final metrics and chart"""
        accuracy = (self.correct / self.size) * 100
        print("\n" + "=" * 70)
        print(f"MODEL: {self.title}")
        print(f"TESTED: {self.size} examples")
        print(f"CORRECT: {self.correct} ({accuracy:.1f}%)")
        print(f"INCORRECT: {self.incorrect}")
        print("=" * 70)

        # Show most common errors
        if self.incorrect > 0:
            print("\nMost Common Errors:")
            error_pairs = [(self.actuals[i], self.predictions[i])
                           for i in range(len(self.actuals))
                           if self.actuals[i] != self.predictions[i]]
            error_counts = Counter(error_pairs).most_common(5)
            for (actual, pred), count in error_counts:
                print(f"   {actual} → {pred}: {count} times")

        # Chart
        self.chart()
        return accuracy

    def run(self):
        """Run the complete evaluation"""
        print(f"Testing {self.title} on {self.size} examples...\n")
        for i in range(self.size):
            self.run_datapoint(i)
        return self.report()

    @classmethod
    def test(cls, function, data, size=100):
        """Convenience method to test a predictor function"""
        return cls(function, data, size=size).run()
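
A minimal usage sketch for the new framework follows. It assumes the Banking77 test split is fetched with the Hugging Face datasets library and uses a hypothetical keyword_predictor baseline; the intent strings it returns are illustrative and would need to match whatever banking_intents.get_intent produces in this repository.

from datasets import load_dataset  # assumed dependency for fetching Banking77

from classifier_tester import ClassifierTester

# Each Banking77 example is a dict with 'text' (the customer query) and
# 'label' (an integer intent id), the shape run_datapoint() expects.
test_data = list(load_dataset("PolyAI/banking77", split="test"))

def keyword_predictor(item):
    """Hypothetical keyword baseline; swap in a real model call here."""
    text = item["text"].lower()
    if "lost" in text or "stolen" in text:
        return "lost_or_stolen_card"  # illustrative intent name
    return "card_arrival"             # naive fallback, also illustrative

# Prints per-example results and the summary report, shows the confusion-pair
# chart, and returns overall accuracy as a percentage.
accuracy = ClassifierTester.test(keyword_predictor, test_data, size=100)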