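"""Evaluate a ModernBERT sentiment classifier on the IMDB test set.

Builds the model and tokenizer via SentimentInference (configured from a YAML
file), streams progress from the evaluate() generator, and prints accuracy,
F1, AUC-ROC, precision, recall, MCC, and average loss.
"""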
import torch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef
from models import ModernBertForSentiment # Assuming models.py is in the same directory
from tqdm import tqdm  # progress bar used when consuming the evaluate() generator


def evaluate(model, dataloader, device):
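    """Evaluate `model` on `dataloader` as a generator.

    Yields human-readable progress strings roughly 20 times per pass over the
    data, then yields a final dict with accuracy, f1, roc_auc, precision,
    recall, mcc, and average_loss. Callers should iterate until a dict appears.
    """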
    model.eval()
    all_preds = []
    all_labels = []
    all_probs_for_auc = []
    total_loss = 0
    num_batches = len(dataloader)
    processed_batches = 0
    with torch.no_grad():
        for batch in dataloader:  # the caller should not pre-wrap the dataloader in tqdm; progress is yielded below
            processed_batches += 1
            # Move all model inputs in the batch to the target device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            lengths = batch.get('lengths')  # optional: needed when the model weights the loss by sequence length
            if lengths is not None:
                lengths = lengths.to(device)
            # If 'lengths' is missing, some pooling strategies can still run; raise a
            # ValueError here instead if your weighted-loss configuration requires it.
            # Pass all necessary parts of the batch to the model
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
            }
            if lengths is not None:
                model_inputs['lengths'] = lengths
            outputs = model(**model_inputs)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()
            if logits.shape[1] > 1:
                preds = torch.argmax(logits, dim=1)
            else:
                # Single-logit binary head: threshold the sigmoid and drop the trailing dimension
                preds = (torch.sigmoid(logits) > 0.5).long().squeeze(-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # Collect probabilities for the AUC-ROC calculation
            if logits.shape[1] > 1:
                # Multi-logit head: use the probability of the positive class (index 1) for a binary-style AUC
                probs_for_auc = torch.softmax(logits, dim=1)[:, 1]
            else:
                # Single-logit binary head
                probs_for_auc = torch.sigmoid(logits).squeeze(-1)
            all_probs_for_auc.extend(probs_for_auc.cpu().numpy())
            # Yield a progress update roughly 20 times per pass, plus one for the final batch
            progress_update_frequency = max(1, num_batches // 20)  # at least 1 to avoid modulo by zero
            if processed_batches % progress_update_frequency == 0 or processed_batches == num_batches:
                yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches / num_batches * 100:.2f}%)"
    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)
    try:
        roc_auc = roc_auc_score(all_labels, all_probs_for_auc)
    except ValueError as e:
        print(f"Could not calculate AUC-ROC: {e}. Labels: {list(set(all_labels))[:10]}. Probs example: {all_probs_for_auc[:5]}. Setting to 0.0")
        roc_auc = 0.0
    results = {
        'accuracy': accuracy,
        'f1': f1,
        'roc_auc': roc_auc,
        'precision': precision,
        'recall': recall,
        'mcc': mcc,
        'average_loss': avg_loss,
    }
    yield f"Processed {processed_batches}/{num_batches} batches (100.00%)"  # ensure a final progress update
    yield "Evaluation complete. Compiling results..."
    yield results


if __name__ == "__main__":
    import argparse
    from torch.utils.data import DataLoader
    from datasets import load_dataset
    import yaml
    from transformers import AutoTokenizer, AutoConfig
    # ModernBertForSentiment is already imported at module level; the SentimentInference
    # class defined just below is used in place of the one in inference.py.

    class SentimentInference:
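        """Load a ModernBertForSentiment model and its tokenizer from a YAML config.

        This local definition supersedes the SentimentInference class in
        inference.py; it reads everything it needs from the file passed via
        --config_path.
        """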
        def __init__(self, config_path):
            with open(config_path, 'r') as f:
                config_data = yaml.safe_load(f)
            self.config_path = config_path
            self.config_data = config_data
            # All model settings live under the nested 'model' key of the config
            self.model_hf_repo_id = config_data['model']['name_or_path']
            self.tokenizer_name_or_path = config_data['model'].get('tokenizer_name_or_path', self.model_hf_repo_id)
            self.local_model_weights_path = config_data['model'].get('local_model_weights_path', None)
            self.load_from_local_pt = config_data['model'].get('load_from_local_pt', False)
            self.trust_remote_code_for_config = config_data['model'].get('trust_remote_code_for_config', True)  # default True because the model uses custom code
            self.max_length = config_data['model']['max_length']
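            # Illustrative layout of the expected YAML (keys taken from the reads above;
            # the values shown are placeholders, not this project's actual defaults):
            #
            # model:
            #   name_or_path: "your-org/your-modernbert-sentiment-model"
            #   tokenizer_name_or_path: "your-org/your-modernbert-sentiment-model"  # optional
            #   local_model_weights_path: null                                      # optional
            #   load_from_local_pt: false                                           # optional
            #   trust_remote_code_for_config: true                                  # optional
            #   max_length: 512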
            self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
            try:
                if self.load_from_local_pt and self.local_model_weights_path:
                    print(f"Loading model from local path: {self.local_model_weights_path}")
                    # Assume the config was saved alongside the local checkpoint (or can be
                    # rebuilt from the base model, since the architecture lives in code).
                    self.config = AutoConfig.from_pretrained(self.local_model_weights_path, trust_remote_code=self.trust_remote_code_for_config)
                    self.model = ModernBertForSentiment.from_pretrained(self.local_model_weights_path, config=self.config, trust_remote_code=True)
                else:
                    print(f"Loading base ModernBertConfig from: {self.model_hf_repo_id}")
                    self.config = AutoConfig.from_pretrained(self.model_hf_repo_id, trust_remote_code=self.trust_remote_code_for_config)
                    print(f"Instantiating and loading model weights for {self.model_hf_repo_id} using ModernBertForSentiment...")
                    self.model = ModernBertForSentiment.from_pretrained(self.model_hf_repo_id, config=self.config, trust_remote_code=True)
                    print(f"Model {self.model_hf_repo_id} loaded successfully from the Hugging Face Hub using ModernBertForSentiment.")
                self.model.to(self.device)
            except Exception as e:
                print(f"Failed to load model: {e}")
                import traceback
                traceback.print_exc()
                exit()
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=self.trust_remote_code_for_config)

        def print_debug_info(self):
            print(f"Model HF Repo ID: {self.model_hf_repo_id}")
            print(f"Tokenizer Name or Path: {self.tokenizer_name_or_path}")
            print(f"Local Model Weights Path: {self.local_model_weights_path}")
            print(f"Load from Local PT: {self.load_from_local_pt}")

    parser = argparse.ArgumentParser(description="Evaluate a sentiment analysis model on the IMDB test set.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="local_test_config.yaml",
        help="Path to the configuration file for SentimentInference (e.g., local_test_config.yaml or config.yaml)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for evaluation.",
    )
    args = parser.parse_args()
print(f"Using configuration: {args.config_path}")
print("Loading sentiment model and tokenizer...")
inferer = SentimentInference(config_path=args.config_path)
model = inferer.model
tokenizer = inferer.tokenizer
max_length = inferer.max_length
device = inferer.device
print("Loading IMDB test dataset...")
try:
imdb_dataset_test = load_dataset("imdb", split="test")
except Exception as e:
print(f"Failed to load IMDB dataset: {e}")
exit()
print("Tokenizing dataset...")
def tokenize_function(examples):
tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]]
return tokenized_output

    tokenized_imdb_test = imdb_dataset_test.map(tokenize_function, batched=True)
    tokenized_imdb_test = tokenized_imdb_test.remove_columns(["text"])
    tokenized_imdb_test = tokenized_imdb_test.rename_column("label", "labels")
    tokenized_imdb_test.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"])
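    # Because every example is padded to max_length above, the default collate_fn
    # can stack the batch into uniform tensors without a custom collator.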
    test_dataloader = DataLoader(tokenized_imdb_test, batch_size=args.batch_size)
print("Starting evaluation...")
progress_bar = tqdm(evaluate(model, test_dataloader, device), desc="Evaluating")
for update in progress_bar:
if isinstance(update, dict):
results = update
break
else:
progress_bar.set_postfix_str(update)
print("\n--- Evaluation Results ---")
for key, value in results.items():
if isinstance(value, float):
print(f"{key.capitalize()}: {value:.4f}")
else:
print(f"{key.capitalize()}: {value}") |