voxmenthe committed
Commit 6529956 · 1 Parent(s): 127dca6

add full evaluation suite to app
Files changed (8)
  1. app.py +80 -0
  2. config.yaml +1 -0
  3. evaluation.py +189 -0
  4. inference.py +84 -35
  5. local_test_config.yaml +12 -0
  6. original_config.yaml +46 -0
  7. run_sample_inference.py +60 -0
  8. upload_to_hf.py +190 -38
app.py CHANGED
@@ -3,6 +3,10 @@ from inference import SentimentInference
 import os
 from datasets import load_dataset
 import random
+import torch
+from torch.utils.data import DataLoader
+from evaluation import evaluate
+from tqdm import tqdm
 
 # --- Initialize Sentiment Model ---
 CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.yaml")
@@ -66,6 +70,61 @@ def predict_sentiment(text_input, true_label_state)
         print(f"Error during prediction: {e}")
         return f"Error during prediction: {str(e)}", true_label_state
 
+def run_full_evaluation_gradio():
+    """Runs full evaluation on the IMDB test set and yields results for Gradio."""
+    if sentiment_inferer is None or sentiment_inferer.model is None:
+        yield "Error: Sentiment model could not be loaded. Cannot run evaluation."
+        return
+
+    try:
+        yield "Starting full evaluation... This will process 25,000 samples and may take 10-20 minutes. Please be patient."
+
+        device = sentiment_inferer.device
+        model = sentiment_inferer.model
+        tokenizer = sentiment_inferer.tokenizer
+        max_length = sentiment_inferer.max_length
+        batch_size = 16  # Consistent with evaluation.py default
+
+        yield "Loading IMDB test dataset (this might take a moment)..."
+        imdb_test_full = load_dataset("imdb", split="test")
+        yield f"IMDB test dataset loaded ({len(imdb_test_full)} samples). Tokenizing dataset..."
+
+        def tokenize_function(examples):
+            tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
+            tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]]
+            return tokenized_output
+
+        tokenized_imdb_test_full = imdb_test_full.map(tokenize_function, batched=True, num_proc=os.cpu_count() // 2 if os.cpu_count() > 1 else 1)
+        tokenized_imdb_test_full = tokenized_imdb_test_full.remove_columns(["text"])
+        tokenized_imdb_test_full = tokenized_imdb_test_full.rename_column("label", "labels")
+        tokenized_imdb_test_full.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"])
+
+        test_dataloader_full = DataLoader(tokenized_imdb_test_full, batch_size=batch_size)
+        yield "Dataset tokenized and DataLoader prepared. Starting model evaluation on the test set..."
+
+        # evaluation.py's main block wraps the dataloader in tqdm before calling
+        # evaluate(), so we do the same here for consistency. Note that tqdm
+        # progress goes to the console, not the Gradio UI.
+        tqdm_dataloader = tqdm(test_dataloader_full, desc="Evaluating in App")
+
+        results = evaluate(model, tqdm_dataloader, device)
+
+        results_str = "--- Full Evaluation Results ---\n"
+        for key, value in results.items():
+            if isinstance(value, float):
+                results_str += f"{key.capitalize()}: {value:.4f}\n"
+            else:
+                results_str += f"{key.capitalize()}: {value}\n"
+        results_str += "\nEvaluation finished."
+        yield results_str
+
+    except Exception as e:
+        import traceback
+        error_msg = f"An error occurred during full evaluation:\n{str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        yield error_msg
+
 # --- Gradio Interface ---
 with gr.Blocks() as demo:
     true_label = gr.State()
@@ -91,6 +150,27 @@
         inputs=input_textbox
     )
 
+    with gr.Accordion("Advanced: Full Model Evaluation on IMDB Test Set", open=False):
+        gr.Markdown(
+            """**WARNING!** Clicking the button below will run the sentiment analysis model on the
+            **entire IMDB test dataset (25,000 reviews)**. This is a computationally intensive process
+            and will take a considerable amount of time (potentially **10-20 minutes or more** depending
+            on the hardware of the Hugging Face Space or machine running this app). The application
+            might appear unresponsive during this period. Progress messages will be shown below."""
+        )
+        run_eval_button = gr.Button("Run Full Evaluation on IMDB Test Set")
+        evaluation_output_textbox = gr.Textbox(
+            label="Evaluation Progress & Results",
+            lines=15,
+            interactive=False,
+            show_label=True,
+            max_lines=20
+        )
+        run_eval_button.click(
+            fn=run_full_evaluation_gradio,
+            inputs=None,
+            outputs=evaluation_output_textbox
+        )
+
     # Wire actions
     submit_button.click(
         fn=predict_sentiment,
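
A note on the streaming pattern used by run_full_evaluation_gradio: Gradio treats a generator callback as a stream, so each yield replaces the output Textbox with a fresh status message. A minimal self-contained sketch of that mechanism (names and timings are illustrative, not from this repo):

import time
import gradio as gr

def slow_task():
    # Each yield overwrites the Textbox, giving incremental progress updates.
    yield "Starting long job..."
    time.sleep(1)
    yield "Halfway there..."
    time.sleep(1)
    yield "Done."

with gr.Blocks() as demo:
    out = gr.Textbox(label="Progress", lines=5)
    gr.Button("Run").click(fn=slow_task, inputs=None, outputs=out)

demo.queue()  # older Gradio releases need the queue enabled for generator (streamed) outputs
demo.launch()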
config.yaml CHANGED
@@ -4,6 +4,7 @@ model:
   max_length: 880 # 256
   dropout: 0.1
   pooling_strategy: "mean" # Current default, change as needed
+  num_weighted_layers: 6 # Match original training config
 
 inference:
   # Default path, can be overridden
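
For context, a hedged sketch of how the new key is consumed (mirroring the Hub-loading path in inference.py below, and assuming config.yaml also defines model.name_or_path as that code expects): the YAML value is copied onto the ModernBertConfig before the model is instantiated, and it must match the value the checkpoint was trained with.

import yaml
from transformers import ModernBertConfig

with open("config.yaml") as f:
    model_cfg = yaml.safe_load(f)["model"]

# The checkpoint was trained with 6 weighted layers, so inference must match.
config = ModernBertConfig.from_pretrained(model_cfg["name_or_path"])
config.num_weighted_layers = model_cfg.get("num_weighted_layers", 6)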
evaluation.py ADDED
@@ -0,0 +1,189 @@
+import torch
+from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef
+from models import ModernBertForSentiment  # Assuming models.py is in the same directory
+from tqdm import tqdm  # Progress bar for the evaluation loop
+
+
+def evaluate(model, dataloader, device):
+    model.eval()
+    all_preds = []
+    all_labels = []
+    all_probs_for_auc = []
+    total_loss = 0
+
+    with torch.no_grad():
+        for batch in dataloader:
+            # Move the batch to the device, ensuring all model inputs are covered
+            input_ids = batch['input_ids'].to(device)
+            attention_mask = batch['attention_mask'].to(device)
+            labels = batch['labels'].to(device)
+            lengths = batch.get('lengths')
+            if lengths is None:
+                # 'lengths' is required by the weighted loss when labels are supplied,
+                # but some pooling strategies can run without it, so we tolerate its
+                # absence here rather than raising
+                # ValueError("'lengths' not found in batch, but required by model").
+                pass
+            else:
+                lengths = lengths.to(device)
+
+            # Pass all necessary parts of the batch to the model
+            model_inputs = {
+                'input_ids': input_ids,
+                'attention_mask': attention_mask,
+                'labels': labels
+            }
+            if lengths is not None:
+                model_inputs['lengths'] = lengths
+
+            outputs = model(**model_inputs)
+            loss = outputs.loss
+            logits = outputs.logits
+
+            total_loss += loss.item()
+
+            if logits.shape[1] > 1:
+                preds = torch.argmax(logits, dim=1)
+            else:
+                preds = (torch.sigmoid(logits) > 0.5).long()
+            all_preds.extend(preds.cpu().numpy())
+
+            all_labels.extend(labels.cpu().numpy())
+
+            if logits.shape[1] > 1:
+                probs = torch.softmax(logits, dim=1)[:, 1]
+                all_probs_for_auc.extend(probs.cpu().numpy())
+            else:
+                probs = torch.sigmoid(logits)
+                all_probs_for_auc.extend(probs.squeeze().cpu().numpy())
+
+    avg_loss = total_loss / len(dataloader)
+    accuracy = accuracy_score(all_labels, all_preds)
+    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
+    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
+    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
+    mcc = matthews_corrcoef(all_labels, all_preds)
+
+    try:
+        roc_auc = roc_auc_score(all_labels, all_probs_for_auc)
+    except ValueError as e:
+        print(f"Could not calculate AUC-ROC: {e}. Labels: {list(set(all_labels))[:10]}. Probs example: {all_probs_for_auc[:5]}. Setting to 0.0")
+        roc_auc = 0.0
+
+    return {
+        'loss': avg_loss,
+        'accuracy': accuracy,
+        'f1': f1,
+        'roc_auc': roc_auc,
+        'precision': precision,
+        'recall': recall,
+        'mcc': mcc
+    }
+
+
+if __name__ == "__main__":
+    import argparse
+    from torch.utils.data import DataLoader
+    from datasets import load_dataset
+    from inference import SentimentInference  # Assuming inference.py is in the same directory
+    import yaml
+    from transformers import AutoTokenizer, AutoConfig
+
+    # NOTE: this local class shadows the SentimentInference imported above;
+    # it loads checkpoints directly via ModernBertForSentiment instead of
+    # going through AutoModelForSequenceClassification.
+    class SentimentInference:
+        def __init__(self, config_path):
+            with open(config_path, 'r') as f:
+                config_data = yaml.safe_load(f)
+            self.config_path = config_path
+            self.config_data = config_data
+            # Access the relevant keys from the nested config structure
+            self.model_hf_repo_id = config_data['model']['name_or_path']
+            self.tokenizer_name_or_path = config_data['model'].get('tokenizer_name_or_path', self.model_hf_repo_id)
+            self.local_model_weights_path = config_data['model'].get('local_model_weights_path', None)  # Assuming it might be under 'model'
+            self.load_from_local_pt = config_data['model'].get('load_from_local_pt', False)
+            self.trust_remote_code_for_config = config_data['model'].get('trust_remote_code_for_config', True)  # Default to True for custom code
+            self.max_length = config_data['model']['max_length']
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+            try:
+                if self.load_from_local_pt and self.local_model_weights_path:
+                    print(f"Loading model from local path: {self.local_model_weights_path}")
+                    # When loading locally, the config may also be local, or come from the base
+                    # model if it was not saved with the custom checkpoint. For simplicity we
+                    # assume the config is part of the saved local model, or not strictly
+                    # needed because the architecture is defined in code.
+                    self.config = AutoConfig.from_pretrained(self.local_model_weights_path, trust_remote_code=self.trust_remote_code_for_config)
+                    self.model = ModernBertForSentiment.from_pretrained(self.local_model_weights_path, config=self.config, trust_remote_code=True)
+                else:
+                    print(f"Loading base ModernBertConfig from: {self.model_hf_repo_id}")
+                    self.config = AutoConfig.from_pretrained(self.model_hf_repo_id, trust_remote_code=self.trust_remote_code_for_config)
+                    print(f"Instantiating and loading model weights for {self.model_hf_repo_id} using ModernBertForSentiment...")
+                    self.model = ModernBertForSentiment.from_pretrained(self.model_hf_repo_id, config=self.config, trust_remote_code=True)
+                    print(f"Model {self.model_hf_repo_id} loaded successfully from Hugging Face Hub using ModernBertForSentiment.")
+                self.model.to(self.device)
+            except Exception as e:
+                print(f"Failed to load model: {e}")
+                import traceback
+                traceback.print_exc()
+                exit()
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=self.trust_remote_code_for_config)
+
+        def print_debug_info(self):
+            print(f"Model HF Repo ID: {self.model_hf_repo_id}")
+            print(f"Tokenizer Name or Path: {self.tokenizer_name_or_path}")
+            print(f"Local Model Weights Path: {self.local_model_weights_path}")
+            print(f"Load from Local PT: {self.load_from_local_pt}")
+
+    parser = argparse.ArgumentParser(description="Evaluate a sentiment analysis model on the IMDB test set.")
+    parser.add_argument(
+        "--config_path",
+        type=str,
+        default="local_test_config.yaml",
+        help="Path to the configuration file for SentimentInference (e.g., local_test_config.yaml or config.yaml)"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=16,
+        help="Batch size for evaluation."
+    )
+    args = parser.parse_args()
+
+    print(f"Using configuration: {args.config_path}")
+    print("Loading sentiment model and tokenizer...")
+    inferer = SentimentInference(config_path=args.config_path)
+    model = inferer.model
+    tokenizer = inferer.tokenizer
+    max_length = inferer.max_length
+    device = inferer.device
+
+    print("Loading IMDB test dataset...")
+    try:
+        imdb_dataset_test = load_dataset("imdb", split="test")
+    except Exception as e:
+        print(f"Failed to load IMDB dataset: {e}")
+        exit()
+
+    print("Tokenizing dataset...")
+    def tokenize_function(examples):
+        tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
+        tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]]
+        return tokenized_output
+
+    tokenized_imdb_test = imdb_dataset_test.map(tokenize_function, batched=True)
+    tokenized_imdb_test = tokenized_imdb_test.remove_columns(["text"])
+    tokenized_imdb_test = tokenized_imdb_test.rename_column("label", "labels")
+    tokenized_imdb_test.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"])
+
+    test_dataloader = DataLoader(tokenized_imdb_test, batch_size=args.batch_size)
+
+    print("Starting evaluation...")
+    progress_bar = tqdm(test_dataloader, desc="Evaluating")
+
+    results = evaluate(model, progress_bar, device)
+
+    print("\n--- Evaluation Results ---")
+    for key, value in results.items():
+        if isinstance(value, float):
+            print(f"{key.capitalize()}: {value:.4f}")
+        else:
+            print(f"{key.capitalize()}: {value}")
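
The script runs standalone, e.g.: python evaluation.py --config_path local_test_config.yaml --batch_size 16. One detail worth isolating is the prediction branch in evaluate(): a single-logit head (num_labels=1, as configured in this commit) thresholds a sigmoid at 0.5, while a two-logit head takes an argmax. A toy illustration with made-up tensors:

import torch

# Single-logit head (num_labels == 1): sigmoid + 0.5 threshold.
logits_single = torch.tensor([[0.3], [-1.2]])         # shape (batch, 1)
preds_single = (torch.sigmoid(logits_single) > 0.5).long()

# Two-logit head: argmax over the class dimension.
logits_two = torch.tensor([[0.1, 0.9], [2.0, -1.0]])  # shape (batch, 2)
preds_two = torch.argmax(logits_two, dim=1)

print(preds_single.squeeze(1).tolist())  # [1, 0]
print(preds_two.tolist())                # [1, 0]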
inference.py CHANGED
@@ -1,58 +1,107 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
-# models.py (containing ModernBertForSentiment) will be loaded from the Hub due to trust_remote_code=True
 from typing import Dict, Any
 import yaml
+import os
+from models import ModernBertForSentiment
 
 class SentimentInference:
     def __init__(self, config_path: str = "config.yaml"):
-        """Load configuration and initialize model and tokenizer from Hugging Face Hub."""
+        """Load configuration and initialize model and tokenizer from a local checkpoint or the Hugging Face Hub."""
+        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---")
         with open(config_path, 'r') as f:
             config_data = yaml.safe_load(f)
+        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---")
 
         model_yaml_cfg = config_data.get('model', {})
        inference_yaml_cfg = config_data.get('inference', {})
 
         model_hf_repo_id = model_yaml_cfg.get('name_or_path')
-        if not model_hf_repo_id:
-            raise ValueError("model.name_or_path must be specified in config.yaml (e.g., 'username/model_name')")
-
         tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
+        local_model_weights_path = inference_yaml_cfg.get('model_path')  # Path for a local .pt file
+
+        print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---")
+        print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---")
 
         self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))
 
-        print(f"Loading tokenizer from: {tokenizer_hf_repo_id}")
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_hf_repo_id)
-
-        print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
-        # Load the config that was uploaded with the model (config.json in the HF repo).
-        # This config should already have the correct architecture defined by ModernBertConfig.
-        # We then augment it with any custom parameters needed by ModernBertForSentiment's __init__.
-        loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)
-
-        # Augment loaded_config with parameters from model_yaml_cfg needed to initialize
-        # ModernBertForSentiment. These should reflect how the model was trained and its custom head.
-        loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')  # Default to 'mean' as per models.py
-        loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
-        loaded_config.classifier_dropout = model_yaml_cfg.get('dropout')  # Allow None if not in yaml
-        # num_labels should ideally be in the config.json uploaded to HF, but can be set here if needed.
-        # For binary sentiment with a single logit output, num_labels is 1.
-        loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)
-        # The loss_function may not be needed for inference if the model doesn't use it in eval forward
-        # passes, but if ModernBertForSentiment.__init__ requires it, it must be provided.
-        # Assuming it's not critical for basic inference here, to simplify:
-        # loaded_config.loss_function = model_yaml_cfg.get('loss_function', {'name': '...', 'params': {}})
+        # --- Tokenizer Loading (always from Hub for now; could be made conditional) ---
+        if not tokenizer_hf_repo_id and not model_hf_repo_id:
+            raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
+        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
+        print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
+        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
 
-        print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
-        # trust_remote_code=True allows loading models.py (containing ModernBertForSentiment)
-        # from the Hugging Face model repository.
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            model_hf_repo_id,
-            config=loaded_config,  # Pass the augmented config
-            trust_remote_code=True
-        )
+        # --- Model Loading ---
+        # Determine whether to load from a local .pt file or from the Hugging Face Hub.
+        load_from_local_pt = False
+        if local_model_weights_path and os.path.isfile(local_model_weights_path):
+            print(f"Found local model weights path: {local_model_weights_path}")
+            load_from_local_pt = True
+        elif not model_hf_repo_id:
+            raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
+
+        print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---")
+
+        if load_from_local_pt:
+            print("Attempting to load model from local .pt checkpoint...")
+            # The base BERT config must still be loaded, usually from a Hub ID (e.g., the original
+            # base model). This base_model_for_config_id is crucial for building the correct
+            # ModernBertForSentiment structure.
+            base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_hf_repo_id or tokenizer_hf_repo_id)
+            if not base_model_for_config_id:
+                raise ValueError("For local .pt loading, model.base_model_for_config must be specified in config.yaml (e.g., 'answerdotai/ModernBERT-base') to build the model structure.")
+
+            print(f"Loading ModernBertConfig for structure from: {base_model_for_config_id}")
+            bert_config = ModernBertConfig.from_pretrained(base_model_for_config_id)
+
+            # Augment config with parameters from model_yaml_cfg
+            bert_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
+            bert_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
+            bert_config.classifier_dropout = model_yaml_cfg.get('dropout')
+            bert_config.num_labels = model_yaml_cfg.get('num_labels', 1)
+            # bert_config.loss_function = model_yaml_cfg.get('loss_function')  # If needed by __init__
+
+            print("Instantiating ModernBertForSentiment model structure...")
+            self.model = ModernBertForSentiment(bert_config)
+
+            print(f"Loading model weights from local checkpoint: {local_model_weights_path}")
+            checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
+            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                model_state_to_load = checkpoint['model_state_dict']
+            else:
+                model_state_to_load = checkpoint  # Assume it is the state_dict itself
+            self.model.load_state_dict(model_state_to_load)
+            print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")
+
+        else:  # Load from Hugging Face Hub
+            print(f"Attempting to load model from Hugging Face Hub: {model_hf_repo_id}...")
+            if not model_hf_repo_id:
+                raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
+
+            print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
+            loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)
+
+            # Augment loaded_config
+            loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
+            loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 6)  # Default to 6 now
+            loaded_config.classifier_dropout = model_yaml_cfg.get('dropout')
+            loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)
+
+            print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_hf_repo_id,
+                config=loaded_config,
+                trust_remote_code=True,
+                force_download=True  # <--- TEMPORARY - remove when everything is working
+            )
+            print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
+
         self.model.eval()
-        print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
 
     def predict(self, text: str) -> Dict[str, Any]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
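
A hedged usage sketch of the two loading paths (assuming the config files from this commit sit next to inference.py, and that predict() returns the 'sentiment' and 'confidence' keys that run_sample_inference.py reads):

from inference import SentimentInference

# Hub path: config.yaml supplies model.name_or_path and no local checkpoint exists on disk.
hub_inferer = SentimentInference(config_path="config.yaml")

# Local path: local_test_config.yaml points inference.model_path at a .pt checkpoint.
local_inferer = SentimentInference(config_path="local_test_config.yaml")

result = local_inferer.predict("A surprisingly touching film.")
print(result["sentiment"], result["confidence"])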
local_test_config.yaml ADDED
@@ -0,0 +1,12 @@
+model:
+  base_model_for_config: "answerdotai/ModernBERT-base"
+  tokenizer_name_or_path: "answerdotai/ModernBERT-base"
+  max_length: 880
+  dropout: 0.1
+  pooling_strategy: "mean"
+  num_weighted_layers: 6
+  num_labels: 1
+
+inference:
+  model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
+  max_length: 880
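
Note that inference.py only takes the local path when inference.model_path actually exists on disk; otherwise it falls back to the Hub. A small sketch of that gate:

import os
import yaml

with open("local_test_config.yaml") as f:
    cfg = yaml.safe_load(f)

model_path = cfg["inference"]["model_path"]
# Mirrors the os.path.isfile() check in SentimentInference.__init__
print("local checkpoint found" if os.path.isfile(model_path) else "falling back to Hub")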
original_config.yaml ADDED
@@ -0,0 +1,46 @@
+model:
+  name: "answerdotai/ModernBERT-base"
+  loss_function:
+    name: "SentimentWeightedLoss" # Options: "SentimentWeightedLoss", "SentimentFocalLoss"
+    # Parameters for the chosen loss function.
+    # For SentimentFocalLoss, common params are:
+    #   gamma_focal: 1.0 # (e.g., 2.0 for standard, -2.0 for reversed, 0 for none)
+    #   label_smoothing_epsilon: 0.05 # (e.g., 0.0 to 0.1)
+    # For SentimentWeightedLoss, params is empty:
+    params:
+      gamma_focal: 1.0
+      label_smoothing_epsilon: 0.05
+  output_dir: "checkpoints"
+  max_length: 880 # 256
+  dropout: 0.1
+  # --- Pooling Strategy --- #
+  # Options: "cls", "mean", "cls_mean_concat", "weighted_layer", "cls_weighted_concat"
+  #   "cls" uses just the [CLS] token for classification
+  #   "mean" uses mean pooling over the final hidden states
+  #   "cls_mean_concat" concatenates the [CLS] token with mean pooling over the final hidden states
+  #   "weighted_layer" uses a weighted combination of the hidden states from the top N layers
+  #   "cls_weighted_concat" concatenates that weighted combination with the [CLS] token
+
+  pooling_strategy: "mean" # Current default, change as needed
+
+  num_weighted_layers: 6 # Number of top BERT layers to use for 'weighted_layer' strategies (e.g., 1 to 12 for BERT-base)
+
+data:
+  # No specific data paths needed as we use HF datasets at the moment
+
+training:
+  epochs: 6
+  batch_size: 16
+  lr: 1e-5 # 1e-5 # 2.0e-5
+  weight_decay_rate: 0.02 # 0.01
+  resume_from_checkpoint: "" # "checkpoints/mean_epoch2_0.9361acc_0.9355f1.pt" # Path to checkpoint file, or empty to not resume
+
+inference:
+  # Default path, can be overridden
+  model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
+  # Using the same max_length as training for consistency
+  max_length: 880 # 256
+
+
+# "answerdotai/ModernBERT-base"
+# "answerdotai/ModernBERT-large"
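
For readers unfamiliar with the 'weighted_layer' strategies referenced above, here is a hypothetical illustration of the idea: a learned softmax weighting over the hidden states of the top N encoder layers. This is a sketch only; the repo's actual implementation lives in models.py, which this commit does not show.

import torch
import torch.nn as nn

class WeightedLayerPooling(nn.Module):
    def __init__(self, num_weighted_layers: int = 6):
        super().__init__()
        # One learnable weight per pooled layer, normalized with softmax.
        self.layer_weights = nn.Parameter(torch.zeros(num_weighted_layers))

    def forward(self, all_hidden_states):  # list of (batch, seq, hidden) tensors
        stacked = torch.stack(all_hidden_states[-len(self.layer_weights):], dim=0)
        weights = torch.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
        return (weights * stacked).sum(dim=0).mean(dim=1)  # mean-pool over tokens

pool = WeightedLayerPooling(num_weighted_layers=6)
hidden = [torch.randn(2, 8, 16) for _ in range(12)]  # 12 layers, batch 2, seq 8, dim 16
print(pool(hidden).shape)  # torch.Size([2, 16])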
run_sample_inference.py ADDED
@@ -0,0 +1,60 @@
+import argparse
+from datasets import load_dataset
+from inference import SentimentInference
+
+def run_sample_inference(config_path: str = "config.yaml", num_samples: int = 5):
+    """
+    Loads a sentiment analysis model from a checkpoint, runs inference on a few
+    samples from the IMDB validation set, and prints the results.
+    """
+    print("Loading sentiment model...")
+    # Initialize SentimentInference; ensure config_path points to the
+    # configuration file that specifies the model path.
+    inferer = SentimentInference(config_path=config_path)
+    print("Model loaded.")
+
+    print("\nLoading IMDB dataset (test split used for validation samples)...")
+    try:
+        imdb_dataset = load_dataset("imdb", split="test")
+    except Exception as e:
+        print(f"Failed to load IMDB dataset: {e}")
+        print("Please ensure you have an internet connection and the `datasets` library can access Hugging Face.")
+        print("You might need to run `pip install datasets` or check your network settings.")
+        return
+
+    print(f"Taking {num_samples} samples from the dataset.")
+
+    # Take a few samples
+    samples = imdb_dataset.shuffle().select(range(num_samples))
+
+    print("\nRunning inference on selected samples:\n")
+    for i, sample in enumerate(samples):
+        text = sample["text"]
+        true_label_id = sample["label"]
+        true_label = "positive" if true_label_id == 1 else "negative"
+
+        print(f"--- Sample {i+1}/{num_samples} ---")
+        print(f"Text: {text[:200]}...")  # Print first 200 chars for brevity
+        print(f"True Sentiment: {true_label}")
+
+        prediction = inferer.predict(text)
+        print(f"Predicted Sentiment: {prediction['sentiment']}")
+        print(f"Confidence: {prediction['confidence']:.4f}\n")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run sample inference on IMDB dataset.")
+    parser.add_argument(
+        "--config_path",
+        type=str,
+        default="config.yaml",
+        help="Path to the configuration file (e.g., config.yaml)"
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=5,
+        help="Number of samples from IMDB test set to run inference on."
+    )
+    args = parser.parse_args()
+    run_sample_inference(config_path=args.config_path, num_samples=args.num_samples)
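
The script can also be driven from Python rather than the CLI (assuming the same working directory and a resolvable config):

from run_sample_inference import run_sample_inference

# Equivalent to: python run_sample_inference.py --config_path config.yaml --num_samples 3
run_sample_inference(config_path="config.yaml", num_samples=3)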
upload_to_hf.py CHANGED
@@ -1,8 +1,10 @@
-from huggingface_hub import HfApi, upload_folder, create_repo
+from huggingface_hub import HfApi, upload_folder, create_repo, login
 from transformers import AutoTokenizer, AutoConfig
 import os
 import shutil
 import tempfile
+import torch
+import argparse
 
 # --- Configuration ---
 HUGGING_FACE_USERNAME = "voxmenthe" # Your Hugging Face username
@@ -15,7 +17,7 @@ ORIGINAL_BASE_MODEL_NAME = "answerdotai/ModernBERT-base"
 # Local path to your fine-tuned model checkpoint
 LOCAL_MODEL_CHECKPOINT_DIR = "checkpoints"
 FINE_TUNED_MODEL_FILENAME = "mean_epoch5_0.9575acc_0.9575f1.pt" # Your best checkpoint
-# If your fine-tuned model is just a .pt file, ensure you also have a config.json for ModernBertForSentiment
+# If your fine-tuned model is just a .pt file, ensure you also have a config.json for ModernBert
 # For simplicity, we'll re-save the config from the fine-tuned model structure if possible, or from original base.
 
 # Files from your project to include (e.g., custom model code, inference script)
@@ -32,19 +34,29 @@ PROJECT_FILES_TO_UPLOAD = [
 def upload_model_and_tokenizer():
     api = HfApi()
 
-    # Create the repository on Hugging Face Hub (if it doesn't exist)
-    print(f"Creating repository {REPO_ID} on Hugging Face Hub...")
-    create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
+    REPO_ID = f"{HUGGING_FACE_USERNAME}/{MODEL_NAME_ON_HF}"
+    print(f"Preparing to upload to Hugging Face Hub repository: {REPO_ID}")
+
+    # Create the repository on Hugging Face Hub if it doesn't exist.
+    # This should be done after login to ensure correct permissions.
+    print(f"Ensuring repository '{REPO_ID}' exists on Hugging Face Hub...")
+    try:
+        create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
+        print(f"Repository '{REPO_ID}' ensured.")
+    except Exception as e:
+        print(f"Error creating/accessing repository {REPO_ID}: {e}")
+        print("Please check your Hugging Face token and repository permissions.")
+        return
 
     # Create a temporary directory to gather all files for upload
-    with tempfile.TemporaryDirectory() as tmp_upload_dir:
-        print(f"Created temporary directory for upload: {tmp_upload_dir}")
+    with tempfile.TemporaryDirectory() as temp_dir:
+        print(f"Created temporary directory for upload: {temp_dir}")
 
         # 1. Save tokenizer files from the ORIGINAL_BASE_MODEL_NAME
-        print(f"Saving tokenizer from {ORIGINAL_BASE_MODEL_NAME} to {tmp_upload_dir}...")
+        print(f"Saving tokenizer from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
         try:
             tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
-            tokenizer.save_pretrained(tmp_upload_dir)
+            tokenizer.save_pretrained(temp_dir)
             print("Tokenizer files saved.")
         except Exception as e:
             print(f"Error saving tokenizer from {ORIGINAL_BASE_MODEL_NAME}: {e}")
@@ -53,58 +65,198 @@ def upload_model_and_tokenizer():
 
         # 2. Save base model config.json (architecture) from ORIGINAL_BASE_MODEL_NAME
         # This is crucial for AutoModelForSequenceClassification.from_pretrained(REPO_ID) to work.
-        print(f"Saving model config.json from {ORIGINAL_BASE_MODEL_NAME} to {tmp_upload_dir}...")
+        print(f"Saving model config.json from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
         try:
             config = AutoConfig.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
-            # If your fine-tuned ModernBertForSentiment has specific architectural changes in its config
-            # that are NOT automatically handled by loading the state_dict (e.g. num_labels if not standard),
-            # you might need to update 'config' here before saving.
-            # For now, we assume the base config is sufficient or your model's state_dict handles it.
-            config.save_pretrained(tmp_upload_dir)
-            print("Model config.json saved.")
+            print(f"Config loaded. Initial num_labels (if exists): {getattr(config, 'num_labels', 'Not set')}")
+
+            # Set architecture first
+            config.architectures = ["ModernBertForSentiment"]
+
+            # Add necessary classification head attributes for AutoModelForSequenceClassification
+            config.num_labels = 1  # For IMDB sentiment (binary, single-logit output per training)
+            print(f"After attempting to set: config.num_labels = {config.num_labels}")
+
+            config.id2label = {0: "NEGATIVE", 1: "POSITIVE"}  # Standard for binary, even with num_labels=1
+            config.label2id = {"NEGATIVE": 0, "POSITIVE": 1}
+            print(f"After setting id2label/label2id, config.num_labels is: {config.num_labels}")
+
+            # CRITICAL: Force num_labels to 1 again immediately before saving
+            config.num_labels = 1
+            print(f"Immediately before save, FINAL check config.num_labels = {config.num_labels}")
+
+            # Safeguard: Remove any existing config.json from temp_dir before saving ours
+            potential_old_config_path = os.path.join(temp_dir, "config.json")
+            if os.path.exists(potential_old_config_path):
+                os.remove(potential_old_config_path)
+                print(f"Removed existing config.json from {temp_dir} to ensure clean save.")
+
+            config.save_pretrained(temp_dir)
+            print(f"Model config.json (with num_labels={config.num_labels}, architectures={config.architectures}) saved to {temp_dir}.")
         except Exception as e:
             print(f"Error saving config.json from {ORIGINAL_BASE_MODEL_NAME}: {e}")
             return
 
-        # 3. Copy fine-tuned model checkpoint to temporary directory
-        # The fine-tuned weights should be named 'pytorch_model.bin' or 'model.safetensors' for HF to
-        # auto-load, or your config.json in the repo should point to the custom name.
-        # For simplicity, we'll rename it to the HF standard name pytorch_model.bin.
-        local_checkpoint_path = os.path.join(LOCAL_MODEL_CHECKPOINT_DIR, FINE_TUNED_MODEL_FILENAME)
-        if os.path.exists(local_checkpoint_path):
-            hf_model_path = os.path.join(tmp_upload_dir, "pytorch_model.bin")
-            shutil.copyfile(local_checkpoint_path, hf_model_path)
-            print(f"Copied fine-tuned model {FINE_TUNED_MODEL_FILENAME} to {hf_model_path}.")
-        else:
-            print(f"Error: Fine-tuned model checkpoint {local_checkpoint_path} not found.")
-            return
+        # 3. Load the fine-tuned model checkpoint to extract the state_dict
+        full_checkpoint_path = os.path.join(LOCAL_MODEL_CHECKPOINT_DIR, FINE_TUNED_MODEL_FILENAME)
+        hf_model_path = os.path.join(temp_dir, "pytorch_model.bin")
+
+        if not os.path.exists(full_checkpoint_path):
+            print(f"ERROR: Local model checkpoint not found at {full_checkpoint_path}")
+            shutil.rmtree(temp_dir)
+            return
+
+        print(f"Loading local checkpoint from: {full_checkpoint_path}")
+        # Load the checkpoint to CPU to avoid GPU memory issues if the script runner doesn't have/need a GPU
+        checkpoint = torch.load(full_checkpoint_path, map_location='cpu')
+
+        model_state_dict = None
+        if 'model_state_dict' in checkpoint:
+            model_state_dict = checkpoint['model_state_dict']
+            print("Extracted 'model_state_dict' from checkpoint.")
+        elif 'state_dict' in checkpoint:  # Another common key for state_dicts
+            model_state_dict = checkpoint['state_dict']
+            print("Extracted 'state_dict' from checkpoint.")
+        elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
+            # The checkpoint may already be a state_dict (e.g., from torch.save(model.state_dict(), ...)).
+            # Basic check: does it have keys that look like weights/biases?
+            if any(key.endswith('.weight') or key.endswith('.bias') for key in checkpoint.keys()):
+                model_state_dict = checkpoint
+                print("Checkpoint appears to be a raw state_dict (contains .weight or .bias keys).")
+            else:
+                print("Checkpoint is a dict, but does not immediately appear to be a state_dict (no .weight/.bias keys found).")
+                print(f"Checkpoint keys: {list(checkpoint.keys())[:10]}...")  # Print some keys for diagnosis
+        else:
+            # The checkpoint is not a dict, or does not match any known structure
+            print("ERROR: Could not find a known state_dict key in the checkpoint, and it's not a recognizable raw state_dict.")
+            if isinstance(checkpoint, dict):
+                print(f"Checkpoint dictionary keys found: {list(checkpoint.keys())}")
+            else:
+                print(f"Checkpoint is not a dictionary. Type: {type(checkpoint)}")
+            shutil.rmtree(temp_dir)
+            return
+
+        if model_state_dict is None:
+            print("ERROR: model_state_dict was not successfully extracted. Aborting upload.")
+            shutil.rmtree(temp_dir)
+            return
+
+        # --- DEBUG: Print keys of the state_dict ---
+        print("\n--- Keys in extracted (original) model_state_dict (first 30 and last 10): ---")
+        state_dict_keys = list(model_state_dict.keys())
+        if len(state_dict_keys) > 0:
+            for i, key in enumerate(state_dict_keys[:30]):
+                print(f"  {i+1}. {key}")
+            if len(state_dict_keys) > 40:  # Show ellipsis if there's a gap
+                print("  ...")
+            # Print the last 10 keys if there are more than 30
+            start_index_for_last_10 = max(30, len(state_dict_keys) - 10)
+            for key_idx in range(start_index_for_last_10, len(state_dict_keys)):
+                print(f"  {key_idx+1}. {state_dict_keys[key_idx]}")
+        else:
+            print("  (No keys found in model_state_dict)")
+        print(f"Total keys: {len(state_dict_keys)}")
+        print("-----------------------------------------------------------\n")
+        # --- END DEBUG ---
+
+        # Transform keys for Hugging Face compatibility if needed.
+        # For ModernBertForSentiment with self.bert and self.classifier (custom head):
+        #   - Checkpoint 'bert.*' keys should remain 'bert.*'
+        #   - Checkpoint 'classifier.*' keys (e.g., classifier.dense1.weight,
+        #     classifier.out_proj.weight) should remain 'classifier.*' as they are.
+        transformed_state_dict = {}
+        has_classifier_weights_transformed = False  # Tracks whether out_proj was found
+
+        print("Transforming state_dict keys for Hugging Face Hub compatibility...")
+        for key, value in model_state_dict.items():
+            new_key = None
+            if key.startswith("bert."):
+                # Keep the 'bert.' prefix as ModernBertForSentiment uses self.bert
+                new_key = key
+            elif key.startswith("classifier."):
+                # All parts of the custom classifier head retain their names
+                new_key = key
+                if "out_proj" in key:  # Just to confirm it exists
+                    has_classifier_weights_transformed = True
+
+            if new_key:
+                transformed_state_dict[new_key] = value
+                if key != new_key:
+                    print(f"  Mapping '{key}' -> '{new_key}'")
+            else:
+                print(f"  INFO: Discarding key not mapped: {key}")
+
+        # Check whether the critical classifier output layer was present in the source
+        # checkpoint. This check may need adjustment based on the actual layers of
+        # ClassifierHead; for now we check whether any 'out_proj' key was seen under 'classifier.'.
+        if not has_classifier_weights_transformed:
+            print("WARNING: No 'classifier.out_proj.*' keys were found in the source checkpoint.")
+            print("         Ensure your checkpoint contains the expected classifier layers.")
+            # Not necessarily an error worth aborting over, as other classifier keys might be valid.
+
+        model_state_dict = transformed_state_dict
+
+        # --- DEBUG: Print keys of the TRANSFORMED state_dict ---
+        print("\n--- Keys in TRANSFORMED model_state_dict for upload (first 30 and last 10): ---")
+        state_dict_keys_transformed = list(transformed_state_dict.keys())
+        if len(state_dict_keys_transformed) > 0:
+            for i, key_t in enumerate(state_dict_keys_transformed[:30]):
+                print(f"  {i+1}. {key_t}")
+            if len(state_dict_keys_transformed) > 40:
+                print("  ...")
+            start_index_for_last_10_t = max(30, len(state_dict_keys_transformed) - 10)
+            for key_idx_t in range(start_index_for_last_10_t, len(state_dict_keys_transformed)):
+                print(f"  {key_idx_t+1}. {state_dict_keys_transformed[key_idx_t]}")
+        else:
+            print("  (No keys found in transformed_state_dict)")
+        print(f"Total keys in transformed_state_dict: {len(state_dict_keys_transformed)}")
+        print("-----------------------------------------------------------\n")
+
+        # Save the TRANSFORMED state_dict
+        torch.save(transformed_state_dict, hf_model_path)
+        print(f"Saved TRANSFORMED model state_dict to {hf_model_path}.")
 
         # 4. Copy other project files
         for project_file in PROJECT_FILES_TO_UPLOAD:
             local_project_file_path = project_file  # Files are now at the root
             if os.path.exists(local_project_file_path):
-                shutil.copy(local_project_file_path, os.path.join(tmp_upload_dir, os.path.basename(project_file)))
-                print(f"Copied project file {project_file} to {tmp_upload_dir}.")
-            else:
-                print(f"Warning: Project file {project_file} not found at {local_project_file_path}.")
+                shutil.copy(local_project_file_path, os.path.join(temp_dir, os.path.basename(project_file)))
+                print(f"Copied project file {project_file} to {temp_dir}.")
+
+        # Before uploading, inspect the temp_dir to be absolutely sure what's there
+        print(f"--- Inspecting temp_dir ({temp_dir}) before upload: ---")
+        for item in os.listdir(temp_dir):
+            print(f"  - {item}")
+        temp_config_path_to_check = os.path.join(temp_dir, "config.json")
+        if os.path.exists(temp_config_path_to_check):
+            print(f"--- Content of {temp_config_path_to_check} before upload: ---")
+            with open(temp_config_path_to_check, 'r') as f_check:
+                print(f_check.read())
+            print("--- End of config.json content ---")
+        else:
+            print(f"WARNING: {temp_config_path_to_check} does NOT exist before upload!")
 
         # 5. Upload the contents of the temporary directory
-        print(f"Uploading all files from {tmp_upload_dir} to {REPO_ID}...")
+        print(f"Uploading all files from {temp_dir} to {REPO_ID}...")
         try:
             upload_folder(
-                folder_path=tmp_upload_dir,
+                folder_path=temp_dir,
                 repo_id=REPO_ID,
                 repo_type="model",
                 commit_message=f"Upload fine-tuned model, tokenizer, and supporting files for {MODEL_NAME_ON_HF}"
             )
             print("All files uploaded successfully!")
         except Exception as e:
-            print(f"Error uploading folder to Hugging Face Hub: {e}")
+            print(f"Error uploading files: {e}")
+        finally:
+            print(f"Cleaning up temporary directory: {temp_dir}")
+            # The TemporaryDirectory context manager handles cleanup automatically,
+            # but an explicit message is good for clarity.
+
+    print("Upload process finished.")
 
 if __name__ == "__main__":
-    # Make sure you are logged in to Hugging Face CLI:
-    # Run `huggingface-cli login` or `huggingface-cli login --token YOUR_HF_WRITE_TOKEN` in your terminal first.
-    print("Starting upload process...")
-    print(f"Target Hugging Face Repo ID: {REPO_ID}")
-    print("Ensure you have run 'huggingface-cli login' with a write token.")
     upload_model_and_tokenizer()
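
After an upload, a hedged round-trip check mirrors the Hub path in inference.py; the repo id below is a placeholder, since MODEL_NAME_ON_HF is defined outside the hunks shown here:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

repo_id = "voxmenthe/YOUR-MODEL-NAME"  # hypothetical placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)
model.eval()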