elsayedmohammed commited on
Commit
c26f717
·
verified ·
1 Parent(s): 3e17f60

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +3 -40
script.py CHANGED
@@ -1,45 +1,8 @@
1
  mport torch
2
  import numpy as np
3
- from sklearn.metrics import accuracy_score # Example metric
4
-
5
- # Load your hidden test set (adjust path and format to your data)
6
- TEST_DATA_PATH = "test_data.pt" # Replace with the actual path
7
- TEST_LABELS_PATH = "test_labels.pt"
8
-
9
- test_data = torch.load(TEST_DATA_PATH)
10
- test_labels = torch.load(TEST_LABELS_PATH)
11
-
12
- # Evaluation script entry point
13
- def evaluate_submission(model_checkpoint_path: str):
14
- """
15
- Evaluates the submitted model on the hidden test set.
16
- Args:
17
- model_checkpoint_path (str): Path to the submitted model checkpoint.
18
-
19
- Returns:
20
- dict: A dictionary containing the evaluation metrics.
21
- """
22
- # Load the participant's model
23
- model = torch.load(model_checkpoint_path)
24
- model.eval()
25
-
26
- # Move model and data to the appropriate device (e.g., GPU if available)
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- model = model.to(device)
29
- test_data_tensor = test_data.to(device)
30
-
31
- # Perform inference
32
- with torch.no_grad():
33
- predictions = model(test_data_tensor)
34
- predictions = torch.argmax(predictions, axis=1).cpu().numpy()
35
-
36
- # Calculate evaluation metric (e.g., accuracy)
37
- accuracy = accuracy_score(test_labels, predictions)
38
-
39
- return {"accuracy": accuracy} # Replace with other metrics as needed
40
 
41
  if __name__ == "__main__":
42
  # For local testing, you can pass a sample model path here
43
- sample_model_path = "sample_submission.pt" # Replace with a test checkpoint
44
- result = evaluate_submission(sample_model_path)
45
- print(result)
 
1
  mport torch
2
  import numpy as np
3
+ from sklearn.metrics import accuracy_score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  if __name__ == "__main__":
6
  # For local testing, you can pass a sample model path here
7
+
8
+ print("inside script.py")