import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    roc_auc_score,
    precision_score,
    recall_score,
    matthews_corrcoef,
)
from models import ModernBertForSentiment  # models.py is expected in the same directory
from tqdm import tqdm  # progress bar for the evaluation loop


def evaluate(model, dataloader, device):
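    """Run evaluation over ``dataloader`` as a generator.

    Yields human-readable progress strings while batches are processed, then a
    final dict of metrics (accuracy, f1, roc_auc, precision, recall, mcc,
    average_loss). The caller is expected to iterate until a dict is received.
    """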
    model.eval()
    all_preds = []
    all_labels = []
    all_probs_for_auc = []
    total_loss = 0
    num_batches = len(dataloader)
    processed_batches = 0
    with torch.no_grad():
        for batch in dataloader:  # the caller should not wrap the dataloader in tqdm; progress is yielded from here
            processed_batches += 1
            # Move every model input in the batch to the target device.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            lengths = batch.get('lengths')  # per-example token counts, used by some pooling/loss strategies
            if lengths is not None:
                lengths = lengths.to(device)
            # If 'lengths' is missing, some pooling strategies can still run without it;
            # raise an error here instead if your model requires it when labels are given:
            #     raise ValueError("'lengths' not found in batch, but required by model")
            # Pass all required parts of the batch to the model.
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
            }
            if lengths is not None:
                model_inputs['lengths'] = lengths

            outputs = model(**model_inputs)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()
            if logits.shape[1] > 1:
                # Multi-class head: predicted class is the argmax over logits.
                preds = torch.argmax(logits, dim=1)
            else:
                # Single-logit binary head: threshold the sigmoid at 0.5 and flatten to 1-D.
                preds = (torch.sigmoid(logits) > 0.5).long().squeeze(-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # Collect positive-class probabilities for AUC-ROC.
            if logits.shape[1] > 1:
                # Multi-class (or two-logit) head: take the probability of class index 1.
                probs_for_auc = torch.softmax(logits, dim=1)[:, 1]
            else:
                # Single-logit binary head; squeeze only the last dim so a batch of 1 stays 1-D.
                probs_for_auc = torch.sigmoid(logits).squeeze(-1)
            all_probs_for_auc.extend(probs_for_auc.cpu().numpy())
            # Yield a progress update roughly 20 times over the run, plus on the final batch.
            progress_update_frequency = max(1, num_batches // 20)  # at least 1 to avoid modulo by zero
            if processed_batches % progress_update_frequency == 0 or processed_batches == num_batches:
                yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"
    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)
    try:
        roc_auc = roc_auc_score(all_labels, all_probs_for_auc)
    except ValueError as e:
        print(f"Could not calculate AUC-ROC: {e}. Labels: {list(set(all_labels))[:10]}. "
              f"Probs example: {all_probs_for_auc[:5]}. Setting to 0.0")
        roc_auc = 0.0
    results = {
        'accuracy': accuracy,
        'f1': f1,
        'roc_auc': roc_auc,
        'precision': precision,
        'recall': recall,
        'mcc': mcc,
        'average_loss': avg_loss,
    }
    yield f"Processed {processed_batches}/{num_batches} batches (100.00%)"  # final progress update
    yield "Evaluation complete. Compiling results..."
    yield results
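
# Typical consumption of the generator (a minimal sketch; the __main__ block below
# shows the full tqdm-wrapped version):
#
#     for update in evaluate(model, dataloader, device):
#         if isinstance(update, dict):
#             results = update
#             break
#         print(update)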


if __name__ == "__main__":
    import argparse
    import sys
    import yaml
    from torch.utils.data import DataLoader
    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoConfig
    # Note: a local copy of SentimentInference is defined below, so the version in
    # inference.py is not imported here; ModernBertForSentiment is already imported above.

    class SentimentInference:
        def __init__(self, config_path):
            with open(config_path, 'r') as f:
                config_data = yaml.safe_load(f)
            self.config_path = config_path
            self.config_data = config_data

            # Read model settings from the nested 'model' section of the config.
            self.model_hf_repo_id = config_data['model']['name_or_path']
            self.tokenizer_name_or_path = config_data['model'].get('tokenizer_name_or_path', self.model_hf_repo_id)
            self.local_model_weights_path = config_data['model'].get('local_model_weights_path', None)
            self.load_from_local_pt = config_data['model'].get('load_from_local_pt', False)
            self.trust_remote_code_for_config = config_data['model'].get('trust_remote_code_for_config', True)  # default True for custom model code
            self.max_length = config_data['model']['max_length']
            self.device = torch.device(
                "cuda" if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available()
                else "cpu"
            )
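
            # The reads above assume a nested YAML config roughly shaped like this
            # (illustrative sketch only; key names mirror the lookups above, values are placeholders):
            #
            # model:
            #   name_or_path: answerdotai/ModernBERT-base        # or your fine-tuned repo id
            #   tokenizer_name_or_path: answerdotai/ModernBERT-base
            #   local_model_weights_path: ./checkpoints/model.pt  # optional
            #   load_from_local_pt: false
            #   trust_remote_code_for_config: true
            #   max_length: 512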

            try:
                if self.load_from_local_pt and self.local_model_weights_path:
                    print(f"Loading model from local path: {self.local_model_weights_path}")
                    # When loading locally, the config is assumed to have been saved alongside the
                    # checkpoint (or the base config suffices because the architecture lives in code).
                    self.config = AutoConfig.from_pretrained(self.local_model_weights_path, trust_remote_code=self.trust_remote_code_for_config)
                    self.model = ModernBertForSentiment.from_pretrained(self.local_model_weights_path, config=self.config, trust_remote_code=True)
                else:
                    print(f"Loading base ModernBertConfig from: {self.model_hf_repo_id}")
                    self.config = AutoConfig.from_pretrained(self.model_hf_repo_id, trust_remote_code=self.trust_remote_code_for_config)
                    print(f"Instantiating and loading model weights for {self.model_hf_repo_id} using ModernBertForSentiment...")
                    self.model = ModernBertForSentiment.from_pretrained(self.model_hf_repo_id, config=self.config, trust_remote_code=True)
                    print(f"Model {self.model_hf_repo_id} loaded successfully from the Hugging Face Hub using ModernBertForSentiment.")
                self.model.to(self.device)
            except Exception as e:
                print(f"Failed to load model: {e}")
                import traceback
                traceback.print_exc()
                sys.exit(1)

            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=self.trust_remote_code_for_config)

        def print_debug_info(self):
            print(f"Model HF Repo ID: {self.model_hf_repo_id}")
            print(f"Tokenizer Name or Path: {self.tokenizer_name_or_path}")
            print(f"Local Model Weights Path: {self.local_model_weights_path}")
            print(f"Load from Local PT: {self.load_from_local_pt}")

    parser = argparse.ArgumentParser(description="Evaluate a sentiment analysis model on the IMDB test set.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="local_test_config.yaml",
        help="Path to the configuration file for SentimentInference (e.g., local_test_config.yaml or config.yaml)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for evaluation.",
    )
    args = parser.parse_args()
print(f"Using configuration: {args.config_path}") | |
print("Loading sentiment model and tokenizer...") | |
inferer = SentimentInference(config_path=args.config_path) | |
model = inferer.model | |
tokenizer = inferer.tokenizer | |
max_length = inferer.max_length | |
device = inferer.device | |
print("Loading IMDB test dataset...") | |
try: | |
imdb_dataset_test = load_dataset("imdb", split="test") | |
except Exception as e: | |
print(f"Failed to load IMDB dataset: {e}") | |
exit() | |
print("Tokenizing dataset...") | |
def tokenize_function(examples): | |
tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length) | |
tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]] | |
return tokenized_output | |
tokenized_imdb_test = imdb_dataset_test.map(tokenize_function, batched=True) | |
tokenized_imdb_test = tokenized_imdb_test.remove_columns(["text"]) | |
tokenized_imdb_test = tokenized_imdb_test.rename_column("label", "labels") | |
tokenized_imdb_test.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"]) | |
test_dataloader = DataLoader(tokenized_imdb_test, batch_size=args.batch_size) | |
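    # With set_format("torch"), the DataLoader's default collate_fn stacks these tensor
    # columns, so each batch is a dict with the keys evaluate() expects:
    # 'input_ids', 'attention_mask', 'labels', and 'lengths'.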
print("Starting evaluation...") | |
progress_bar = tqdm(evaluate(model, test_dataloader, device), desc="Evaluating") | |
for update in progress_bar: | |
if isinstance(update, dict): | |
results = update | |
break | |
else: | |
progress_bar.set_postfix_str(update) | |
print("\n--- Evaluation Results ---") | |
for key, value in results.items(): | |
if isinstance(value, float): | |
print(f"{key.capitalize()}: {value:.4f}") | |
else: | |
print(f"{key.capitalize()}: {value}") |