| | """ |
| | Evaluation script for trained model with comprehensive analysis |
| | """ |
| | import argparse |
| | import sys |
| | import os |
| | import numpy as np |
| | import pandas as pd |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer |
| |
|
| | |
| | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
| |
|
| | from src import ( |
| | load_config, |
| | compute_metrics_factory, |
| | plot_confusion_matrix, |
| | print_classification_report |
| | ) |
| | from src.data_loader import prepare_datasets_for_training |
| |
|
| |
|
def analyze_errors(
    test_dataset,
    predictions: np.ndarray,
    labels: np.ndarray,
    id2label: dict,
    tokenizer=None,
    top_n: int = 10
) -> pd.DataFrame:
    """
    Analyze misclassified examples.

    Builds one record per misclassification (index, true/predicted label,
    and an "true -> pred" error-type string). When a tokenizer is supplied
    and the dataset exposes ``input_ids``, the original text is decoded and
    included so individual errors can be inspected.

    Args:
        test_dataset: Test dataset (indexable; items may contain 'input_ids')
        predictions: Predicted label ids
        labels: True label ids
        id2label: Mapping from label id to label name
        tokenizer: Optional tokenizer used to decode example text
        top_n: Maximum number of sample errors to print

    Returns:
        DataFrame with one row per misclassified example (empty if none)
    """
    errors = []
    for i, (pred, true_label) in enumerate(zip(predictions, labels)):
        if pred != true_label:
            record = {
                'index': i,
                'true_label': id2label[true_label],
                'predicted_label': id2label[pred],
                'error_type': f"{id2label[true_label]} -> {id2label[pred]}"
            }
            # Decode the source text when possible so the CSV is reviewable;
            # datasets without 'input_ids' simply omit the column.
            if tokenizer is not None:
                try:
                    record['text'] = tokenizer.decode(
                        test_dataset[i]['input_ids'], skip_special_tokens=True
                    )
                except (KeyError, IndexError, TypeError):
                    pass
            errors.append(record)

    error_df = pd.DataFrame(errors)
    if not error_df.empty:
        print(f"\nError Analysis:")
        print(f"Total errors: {len(error_df)}")
        print(f"\nError type distribution:")
        print(error_df['error_type'].value_counts())
        # Show a few concrete mistakes (truncated) when text was decoded.
        if 'text' in error_df.columns:
            print(f"\nSample misclassified examples (up to {top_n}):")
            for _, row in error_df.head(top_n).iterrows():
                print(f"  [{row['error_type']}] {row['text'][:100]}")

    return error_df
| |
|
| |
|
def evaluate_model(
    model_path: str,
    config_path: str = "config.yaml",
    save_plots: bool = True
):
    """
    Evaluate trained model on test set with comprehensive analysis.

    Loads the tokenized test split, runs prediction through a bare Trainer,
    prints overall and per-class metrics plus a classification report, and
    (optionally) saves confusion-matrix plots and an error-analysis CSV to
    the configured output directory.

    Args:
        model_path: Path to the trained model directory
        config_path: Path to configuration file
        save_plots: Whether to save visualization plots and the error CSV
    """
    print("=" * 60)
    print("Model Evaluation")
    print("=" * 60)

    config = load_config(config_path)

    # All artifacts (plots, CSV) land under the training output dir.
    output_dir = config['training'].get('output_dir', './results')
    os.makedirs(output_dir, exist_ok=True)

    # Re-tokenizes from config so label mappings match training exactly.
    print("\n[1/5] Loading datasets...")
    tokenized_datasets, label2id, id2label, _ = prepare_datasets_for_training(config_path)
    test_dataset = tokenized_datasets['test']
    print(f"✓ Test samples: {len(test_dataset)}")

    print("\n[2/5] Loading trained model...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    print(f"✓ Model loaded from {model_path}")

    # Trainer is used purely as a prediction harness (no training args).
    # NOTE(review): `tokenizer=` is deprecated in newer transformers in
    # favor of `processing_class=` — confirm the pinned version.
    print("\n[3/5] Running evaluation...")
    compute_metrics_fn = compute_metrics_factory(id2label)
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn
    )

    predictions_output = trainer.predict(test_dataset)
    # Logits -> hard label ids for the report and confusion matrices.
    predictions = np.argmax(predictions_output.predictions, axis=1)
    labels = predictions_output.label_ids

    print("\n[4/5] Computing detailed metrics...")
    print("\n" + "=" * 60)
    print("Test Set Results")
    print("=" * 60)

    # Metrics dict keys are prefixed "test_" by Trainer.predict.
    metrics = predictions_output.metrics

    print("\nOverall Metrics:")
    overall_metrics = ['accuracy', 'f1_weighted', 'f1_macro', 'precision_weighted', 'recall_weighted']
    for metric in overall_metrics:
        key = f'test_{metric}'
        if key in metrics:
            print(f"  {metric.replace('_', ' ').title()}: {metrics[key]:.4f}")

    print("\nPer-Class Metrics:")
    label_names = [id2label[i] for i in range(len(id2label))]
    for label_name in label_names:
        precision_key = f'test_precision_{label_name}'
        recall_key = f'test_recall_{label_name}'
        f1_key = f'test_f1_{label_name}'
        if precision_key in metrics:
            print(f"\n  {label_name.upper()}:")
            print(f"    Precision: {metrics[precision_key]:.4f}")
            print(f"    Recall:    {metrics[recall_key]:.4f}")
            print(f"    F1-Score:  {metrics[f1_key]:.4f}")
            print(f"    Support:   {metrics.get(f'test_support_{label_name}', 'N/A')}")

    print("\n" + "=" * 60)
    print_classification_report(labels, predictions, label_names)

    # Fix: only announce visualization work when it will actually happen.
    if save_plots:
        print("\n[5/5] Generating visualizations...")
        plot_confusion_matrix(
            labels,
            predictions,
            label_names,
            save_path=os.path.join(output_dir, "confusion_matrix.png"),
            normalize=False
        )

        plot_confusion_matrix(
            labels,
            predictions,
            label_names,
            save_path=os.path.join(output_dir, "confusion_matrix_normalized.png"),
            normalize=True
        )

    # Error analysis always runs (it only prints); CSV only when saving.
    error_df = analyze_errors(test_dataset, predictions, labels, id2label, tokenizer)
    if len(error_df) > 0 and save_plots:
        error_path = os.path.join(output_dir, "error_analysis.csv")
        error_df.to_csv(error_path, index=False)
        print(f"✓ Error analysis saved to {error_path}")

    print("\n" + "=" * 60)
    print("Evaluation Complete! 🎉")
    print("=" * 60)
    print(f"\nResults saved to: {output_dir}")
| |
|
| |
|
def main() -> None:
    """Parse command-line arguments and run the evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate trained model")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./results/final_model",
        help="Path to the trained model",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config.yaml",
        help="Path to configuration file",
    )
    parser.add_argument(
        "--no-plots",
        action="store_true",
        help="Skip generating visualization plots",
    )
    args = parser.parse_args()
    evaluate_model(args.model_path, args.config, save_plots=not args.no_plots)


if __name__ == "__main__":
    main()
| |
|