KeivanR commited on
Commit
0c10f09
·
1 Parent(s): 22054c0

preds eval to numpy

Browse files
Files changed (1) hide show
  1. qwen_classifier/evaluate.py +2 -3
qwen_classifier/evaluate.py CHANGED
@@ -146,9 +146,8 @@ def _evaluate_local(test_data_path, hf_repo):
146
 
147
  logits = global_model(batch["input_ids"], batch["attention_mask"])
148
 
149
- preds = torch.sigmoid(logits).cpu() > 0.5 # Keeps as PyTorch tensor
150
- preds = preds.float() # Convert to 0.0/1.0 if needed
151
- labels = labels.cpu()
152
 
153
  all_preds.extend(preds)
154
  all_labels.extend(labels)
 
146
 
147
  logits = global_model(batch["input_ids"], batch["attention_mask"])
148
 
149
+ preds = torch.sigmoid(logits).cpu().numpy() > 0.5
150
+ labels = labels.cpu().numpy()
 
151
 
152
  all_preds.extend(preds)
153
  all_labels.extend(labels)