# imdb-sentiment-demo / evaluation.py

import torch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef
from models import ModernBertForSentiment  # models.py is expected alongside this file
from tqdm import tqdm  # Progress bar used by the __main__ evaluation loop


def evaluate(model, dataloader, device):
    """Run evaluation over `dataloader`, yielding progress strings and finally a dict of metrics."""
    model.eval()
    all_preds = []
    all_labels = []
    all_probs_for_auc = []
    total_loss = 0
    num_batches = len(dataloader)
    processed_batches = 0
    with torch.no_grad():
        for batch in dataloader:  # The caller should not wrap the dataloader in tqdm; progress is yielded instead.
            processed_batches += 1

            # Move batch tensors to the target device.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # 'lengths' is optional for some pooling strategies, but the model requires it
            # when labels are supplied for the length-weighted loss.
            lengths = batch.get('lengths')
            if lengths is not None:
                lengths = lengths.to(device)

            # Pass all necessary parts of the batch to the model.
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
            }
            if lengths is not None:
                model_inputs['lengths'] = lengths

            outputs = model(**model_inputs)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()

            if logits.shape[1] > 1:
                preds = torch.argmax(logits, dim=1)
            else:
                # Single-logit binary head: squeeze the class dimension so preds are 1-D.
                preds = (torch.sigmoid(logits) > 0.5).long().squeeze(-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Collect positive-class probabilities for AUC-ROC.
            if logits.shape[1] > 1:
                # Multi-class logits: take the probability of the positive class (index 1) for a binary-style AUC.
                probs_for_auc = torch.softmax(logits, dim=1)[:, 1]
            else:
                # Binary classification with a single logit output.
                probs_for_auc = torch.sigmoid(logits).squeeze(-1)
            all_probs_for_auc.extend(probs_for_auc.cpu().numpy())

            # Yield a progress update roughly 20 times over the run, plus one for the final batch.
            progress_update_frequency = max(1, num_batches // 20)  # At least 1 to avoid modulo by zero
            if processed_batches % progress_update_frequency == 0 or processed_batches == num_batches:
                yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)
    try:
        roc_auc = roc_auc_score(all_labels, all_probs_for_auc)
    except ValueError as e:
        print(f"Could not calculate AUC-ROC: {e}. Labels: {list(set(all_labels))[:10]}. Probs example: {all_probs_for_auc[:5]}. Setting to 0.0")
        roc_auc = 0.0

    results = {
        'accuracy': accuracy,
        'f1': f1,
        'roc_auc': roc_auc,
        'precision': precision,
        'recall': recall,
        'mcc': mcc,
        'average_loss': avg_loss,
    }
    yield f"Processed {processed_batches}/{num_batches} batches (100.00%)"  # Ensure a final progress update
    yield "Evaluation complete. Compiling results..."
    yield results
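
# A minimal sketch (not part of the original script) of consuming the evaluate() generator from
# other code, e.g. a UI callback that streams progress; it mirrors the __main__ block below.
# `report_progress` is a hypothetical callback, not something defined in this repo:
#
#   results = None
#   for update in evaluate(model, dataloader, device):
#       if isinstance(update, dict):
#           results = update       # final metrics dict
#           break
#       report_progress(update)    # intermediate status string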


if __name__ == "__main__":
    import argparse
    from torch.utils.data import DataLoader
    from datasets import load_dataset
    import yaml
    from transformers import AutoTokenizer, AutoConfig
    from models import ModernBertForSentiment  # Assumes models.py is in the same directory or on PYTHONPATH

    # Note: inference.py also provides a SentimentInference class, but this script uses its own
    # local definition below, so that import is not needed here.

    class SentimentInference:
        """Local stand-in for inference.SentimentInference: loads the config, model, and tokenizer."""

        def __init__(self, config_path):
            with open(config_path, 'r') as f:
                config_data = yaml.safe_load(f)
            self.config_path = config_path
            self.config_data = config_data
            # Read settings from the nested 'model' section of the config file.
            self.model_hf_repo_id = config_data['model']['name_or_path']
            self.tokenizer_name_or_path = config_data['model'].get('tokenizer_name_or_path', self.model_hf_repo_id)
            self.local_model_weights_path = config_data['model'].get('local_model_weights_path', None)
            self.load_from_local_pt = config_data['model'].get('load_from_local_pt', False)
            self.trust_remote_code_for_config = config_data['model'].get('trust_remote_code_for_config', True)  # Default to True for custom model code
            self.max_length = config_data['model']['max_length']
            self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
            try:
                if self.load_from_local_pt and self.local_model_weights_path:
                    print(f"Loading model from local path: {self.local_model_weights_path}")
                    # When loading locally, the config is expected to be saved alongside the
                    # checkpoint (or the full architecture to live in code).
                    self.config = AutoConfig.from_pretrained(self.local_model_weights_path, trust_remote_code=self.trust_remote_code_for_config)
                    self.model = ModernBertForSentiment.from_pretrained(self.local_model_weights_path, config=self.config, trust_remote_code=True)
                else:
                    print(f"Loading base ModernBertConfig from: {self.model_hf_repo_id}")
                    self.config = AutoConfig.from_pretrained(self.model_hf_repo_id, trust_remote_code=self.trust_remote_code_for_config)
                    print(f"Instantiating and loading model weights for {self.model_hf_repo_id} using ModernBertForSentiment...")
                    self.model = ModernBertForSentiment.from_pretrained(self.model_hf_repo_id, config=self.config, trust_remote_code=True)
                    print(f"Model {self.model_hf_repo_id} loaded successfully from Hugging Face Hub using ModernBertForSentiment.")
                self.model.to(self.device)
            except Exception as e:
                print(f"Failed to load model: {e}")
                import traceback
                traceback.print_exc()
                exit()
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=self.trust_remote_code_for_config)

        def print_debug_info(self):
            print(f"Model HF Repo ID: {self.model_hf_repo_id}")
            print(f"Tokenizer Name or Path: {self.tokenizer_name_or_path}")
            print(f"Local Model Weights Path: {self.local_model_weights_path}")
            print(f"Load from Local PT: {self.load_from_local_pt}")

    parser = argparse.ArgumentParser(description="Evaluate a sentiment analysis model on the IMDB test set.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="local_test_config.yaml",
        help="Path to the configuration file for SentimentInference (e.g., local_test_config.yaml or config.yaml)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for evaluation.",
    )
    args = parser.parse_args()

    print(f"Using configuration: {args.config_path}")
    print("Loading sentiment model and tokenizer...")
    inferer = SentimentInference(config_path=args.config_path)
    model = inferer.model
    tokenizer = inferer.tokenizer
    max_length = inferer.max_length
    device = inferer.device

    print("Loading IMDB test dataset...")
    try:
        imdb_dataset_test = load_dataset("imdb", split="test")
    except Exception as e:
        print(f"Failed to load IMDB dataset: {e}")
        exit()
print("Tokenizing dataset...")
def tokenize_function(examples):
tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]]
return tokenized_output
tokenized_imdb_test = imdb_dataset_test.map(tokenize_function, batched=True)
tokenized_imdb_test = tokenized_imdb_test.remove_columns(["text"])
tokenized_imdb_test = tokenized_imdb_test.rename_column("label", "labels")
tokenized_imdb_test.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"])
test_dataloader = DataLoader(tokenized_imdb_test, batch_size=args.batch_size)
print("Starting evaluation...")
progress_bar = tqdm(evaluate(model, test_dataloader, device), desc="Evaluating")
for update in progress_bar:
if isinstance(update, dict):
results = update
break
else:
progress_bar.set_postfix_str(update)
print("\n--- Evaluation Results ---")
for key, value in results.items():
if isinstance(value, float):
print(f"{key.capitalize()}: {value:.4f}")
else:
print(f"{key.capitalize()}: {value}")