# imdb-sentiment-demo / evaluation.py
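"""Full evaluation suite for the IMDB sentiment model: computes loss, accuracy,
F1, precision, recall, MCC, and ROC AUC over a test dataloader."""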
import torch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef
from models import ModernBertForSentiment # Assuming models.py is in the same directory
from tqdm import tqdm  # progress bar for the evaluation loop


def evaluate(model, dataloader, device):
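    """Run the model over `dataloader` and return a dict of loss and classification metrics."""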
    model.eval()
    all_preds = []
    all_labels = []
    all_probs_for_auc = []
    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            # Move all model inputs in the batch to the target device.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # 'lengths' is optional: some pooling strategies in the model require it
            # when labels are supplied, so pass it through whenever the batch has it.
            lengths = batch.get('lengths')
            if lengths is not None:
                lengths = lengths.to(device)
            # Pass all necessary parts of the batch to the model.
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels
            }
            if lengths is not None:
                model_inputs['lengths'] = lengths

            outputs = model(**model_inputs)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()

            # Two-logit head: argmax prediction, softmax probability of the positive class.
            # Single-logit head: threshold the sigmoid at 0.5.
            if logits.shape[1] > 1:
                preds = torch.argmax(logits, dim=1)
                probs = torch.softmax(logits, dim=1)[:, 1]
            else:
                preds = (torch.sigmoid(logits) > 0.5).long().view(-1)
                probs = torch.sigmoid(logits).view(-1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs_for_auc.extend(probs.cpu().numpy())
    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)
    try:
        roc_auc = roc_auc_score(all_labels, all_probs_for_auc)
    except ValueError as e:
        print(f"Could not calculate AUC-ROC: {e}. Labels: {list(set(all_labels))[:10]}. "
              f"Probs example: {all_probs_for_auc[:5]}. Setting to 0.0")
        roc_auc = 0.0

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1': f1,
        'roc_auc': roc_auc,
        'precision': precision,
        'recall': recall,
        'mcc': mcc
    }
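

# Standalone usage:
#   python evaluation.py --config_path local_test_config.yaml --batch_size 16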
if __name__ == "__main__":
    import argparse
    from torch.utils.data import DataLoader
    from datasets import load_dataset
    from inference import SentimentInference  # assuming inference.py is in the same directory
    import yaml
    from transformers import AutoTokenizer, AutoConfig
    from models import ModernBertForSentiment  # assuming models.py is in the same directory or on PYTHONPATH
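
    # Local re-definition of SentimentInference; it shadows the class imported from inference.py above.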
    class SentimentInference:
        def __init__(self, config_path):
            with open(config_path, 'r') as f:
                config_data = yaml.safe_load(f)
            self.config_path = config_path
            self.config_data = config_data

            # Pull settings from the nested 'model' section of the config.
            self.model_hf_repo_id = config_data['model']['name_or_path']
            self.tokenizer_name_or_path = config_data['model'].get('tokenizer_name_or_path', self.model_hf_repo_id)
            self.local_model_weights_path = config_data['model'].get('local_model_weights_path', None)
            self.load_from_local_pt = config_data['model'].get('load_from_local_pt', False)
            self.trust_remote_code_for_config = config_data['model'].get('trust_remote_code_for_config', True)  # default True for custom code
            self.max_length = config_data['model']['max_length']
            self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

            try:
                if self.load_from_local_pt and self.local_model_weights_path:
                    print(f"Loading model from local path: {self.local_model_weights_path}")
                    # The config is assumed to be saved with the local checkpoint (or fully defined in code).
                    self.config = AutoConfig.from_pretrained(self.local_model_weights_path, trust_remote_code=self.trust_remote_code_for_config)
                    self.model = ModernBertForSentiment.from_pretrained(self.local_model_weights_path, config=self.config, trust_remote_code=True)
                else:
                    print(f"Loading base ModernBertConfig from: {self.model_hf_repo_id}")
                    self.config = AutoConfig.from_pretrained(self.model_hf_repo_id, trust_remote_code=self.trust_remote_code_for_config)
                    print(f"Instantiating and loading model weights for {self.model_hf_repo_id} using ModernBertForSentiment...")
                    self.model = ModernBertForSentiment.from_pretrained(self.model_hf_repo_id, config=self.config, trust_remote_code=True)
                    print(f"Model {self.model_hf_repo_id} loaded successfully from Hugging Face Hub using ModernBertForSentiment.")
                self.model.to(self.device)
            except Exception as e:
                print(f"Failed to load model: {e}")
                import traceback
                traceback.print_exc()
                exit()

            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=self.trust_remote_code_for_config)

        def print_debug_info(self):
            print(f"Model HF Repo ID: {self.model_hf_repo_id}")
            print(f"Tokenizer Name or Path: {self.tokenizer_name_or_path}")
            print(f"Local Model Weights Path: {self.local_model_weights_path}")
            print(f"Load from Local PT: {self.load_from_local_pt}")
    parser = argparse.ArgumentParser(description="Evaluate a sentiment analysis model on the IMDB test set.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="local_test_config.yaml",
        help="Path to the configuration file for SentimentInference (e.g., local_test_config.yaml or config.yaml)."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for evaluation."
    )
    args = parser.parse_args()

    print(f"Using configuration: {args.config_path}")
    print("Loading sentiment model and tokenizer...")
    inferer = SentimentInference(config_path=args.config_path)
    model = inferer.model
    tokenizer = inferer.tokenizer
    max_length = inferer.max_length
    device = inferer.device
print("Loading IMDB test dataset...")
try:
imdb_dataset_test = load_dataset("imdb", split="test")
except Exception as e:
print(f"Failed to load IMDB dataset: {e}")
exit()
print("Tokenizing dataset...")
def tokenize_function(examples):
tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]]
return tokenized_output
tokenized_imdb_test = imdb_dataset_test.map(tokenize_function, batched=True)
tokenized_imdb_test = tokenized_imdb_test.remove_columns(["text"])
tokenized_imdb_test = tokenized_imdb_test.rename_column("label", "labels")
tokenized_imdb_test.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"])
test_dataloader = DataLoader(tokenized_imdb_test, batch_size=args.batch_size)
print("Starting evaluation...")
progress_bar = tqdm(test_dataloader, desc="Evaluating")
results = evaluate(model, progress_bar, device)
print("\n--- Evaluation Results ---")
for key, value in results.items():
if isinstance(value, float):
print(f"{key.capitalize()}: {value:.4f}")
else:
print(f"{key.capitalize()}: {value}")