add full evaluation suite to app
- app.py +80 -0
- config.yaml +1 -0
- evaluation.py +189 -0
- inference.py +84 -35
- local_test_config.yaml +12 -0
- original_config.yaml +46 -0
- run_sample_inference.py +60 -0
- upload_to_hf.py +190 -38

app.py
CHANGED
@@ -3,6 +3,10 @@ from inference import SentimentInference
 import os
 from datasets import load_dataset
 import random
+import torch
+from torch.utils.data import DataLoader
+from evaluation import evaluate
+from tqdm import tqdm
 
 # --- Initialize Sentiment Model ---
 CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.yaml")
@@ -66,6 +70,61 @@
         print(f"Error during prediction: {e}")
         return f"Error during prediction: {str(e)}", true_label_state
 
+def run_full_evaluation_gradio():
+    """Runs full evaluation on the IMDB test set and yields results for Gradio."""
+    if sentiment_inferer is None or sentiment_inferer.model is None:
+        yield "Error: Sentiment model could not be loaded. Cannot run evaluation."
+        return
+
+    try:
+        yield "Starting full evaluation... This will process 25,000 samples and may take 10-20 minutes. Please be patient."
+
+        device = sentiment_inferer.device
+        model = sentiment_inferer.model
+        tokenizer = sentiment_inferer.tokenizer
+        max_length = sentiment_inferer.max_length
+        batch_size = 16 # Consistent with evaluation.py default
+
+        yield "Loading IMDB test dataset (this might take a moment)..."
+        imdb_test_full = load_dataset("imdb", split="test")
+        yield f"IMDB test dataset loaded ({len(imdb_test_full)} samples). Tokenizing dataset..."
+
+        def tokenize_function(examples):
+            tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
+            tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]]
+            return tokenized_output
+
+        tokenized_imdb_test_full = imdb_test_full.map(tokenize_function, batched=True, num_proc=os.cpu_count()//2 if os.cpu_count() > 1 else 1)
+        tokenized_imdb_test_full = tokenized_imdb_test_full.remove_columns(["text"])
+        tokenized_imdb_test_full = tokenized_imdb_test_full.rename_column("label", "labels")
+        tokenized_imdb_test_full.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"])
+
+        test_dataloader_full = DataLoader(tokenized_imdb_test_full, batch_size=batch_size)
+        yield "Dataset tokenized and DataLoader prepared. Starting model evaluation on the test set..."
+
+        # The 'evaluate' function from evaluation.py expects the dataloader to be potentially wrapped by tqdm
+        # based on how it was called in evaluation.py's main block.
+        # We will wrap it with tqdm here for consistency if evaluate function expects it.
+        # Note: tqdm progress here will go to console, not Gradio UI directly.
+        tqdm_dataloader = tqdm(test_dataloader_full, desc="Evaluating in App")
+
+        results = evaluate(model, tqdm_dataloader, device)
+
+        results_str = "--- Full Evaluation Results ---\n"
+        for key, value in results.items():
+            if isinstance(value, float):
+                results_str += f"{key.capitalize()}: {value:.4f}\n"
+            else:
+                results_str += f"{key.capitalize()}: {value}\n"
+        results_str += "\nEvaluation finished."
+        yield results_str
+
+    except Exception as e:
+        import traceback
+        error_msg = f"An error occurred during full evaluation:\n{str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        yield error_msg
+
 # --- Gradio Interface ---
 with gr.Blocks() as demo:
     true_label = gr.State()
@@ -91,6 +150,27 @@
         inputs=input_textbox
     )
 
+    with gr.Accordion("Advanced: Full Model Evaluation on IMDB Test Set", open=False):
+        gr.Markdown(
+            """**WARNING!** Clicking the button below will run the sentiment analysis model on the **entire IMDB test dataset (25,000 reviews)**.
+This is a computationally intensive process and will take a considerable amount of time (potentially **10-20 minutes or more** depending on the hardware
+of the Hugging Face Space or machine running this app). The application might appear unresponsive during this period.
+Progress messages will be shown below."""
+        )
+        run_eval_button = gr.Button("Run Full Evaluation on IMDB Test Set")
+        evaluation_output_textbox = gr.Textbox(
+            label="Evaluation Progress & Results",
+            lines=15,
+            interactive=False,
+            show_label=True,
+            max_lines=20
+        )
+        run_eval_button.click(
+            fn=run_full_evaluation_gradio,
+            inputs=None,
+            outputs=evaluation_output_textbox
+        )
+
     # Wire actions
     submit_button.click(
         fn=predict_sentiment,
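
Note on the pattern: run_full_evaluation_gradio is a generator, and Gradio runs generator handlers as streaming updates, so each yield replaces the contents of the wired output Textbox. A minimal, self-contained sketch of the same mechanism (component names here are illustrative, not from app.py):

    import time
    import gradio as gr

    def slow_task():
        # Each yield overwrites the wired output component with a new message.
        yield "Starting..."
        time.sleep(1)
        yield "Working..."
        time.sleep(1)
        yield "Done."

    with gr.Blocks() as demo:
        btn = gr.Button("Run")
        status = gr.Textbox(label="Progress")
        btn.click(fn=slow_task, inputs=None, outputs=status)

    demo.launch()
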
config.yaml
CHANGED
@@ -4,6 +4,7 @@ model:
   max_length: 880 # 256
   dropout: 0.1
   pooling_strategy: "mean" # Current default, change as needed
+  num_weighted_layers: 6 # Match original training config
 
 inference:
   # Default path, can be overridden
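
The new num_weighted_layers key is consumed by SentimentInference (see the inference.py diff below), which copies it onto the model config before instantiating the model and falls back to 6 on the Hub-loading path when the key is absent. A small sketch of that read, assuming this config.yaml is in the working directory:

    import yaml

    with open("config.yaml") as f:
        model_cfg = yaml.safe_load(f).get("model", {})

    # Mirrors the read in SentimentInference.__init__ (inference.py below).
    num_weighted_layers = model_cfg.get("num_weighted_layers", 6)
    print(num_weighted_layers)  # -> 6 with this commit's config.yaml
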
evaluation.py
ADDED
@@ -0,0 +1,189 @@
+import torch
+from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef
+from models import ModernBertForSentiment # Assuming models.py is in the same directory
+from tqdm import tqdm # Add this import for the progress bar
+
+
+def evaluate(model, dataloader, device):
+    model.eval()
+    all_preds = []
+    all_labels = []
+    all_probs_for_auc = []
+    total_loss = 0
+
+    with torch.no_grad():
+        for batch in dataloader:
+            # Move batch to device, ensure all model inputs are covered
+            input_ids = batch['input_ids'].to(device)
+            attention_mask = batch['attention_mask'].to(device)
+            labels = batch['labels'].to(device)
+            lengths = batch.get('lengths') # Get lengths from batch
+            if lengths is None:
+                # Fallback or error if lengths are expected but not found
+                # For now, let's raise an error if using weighted loss that needs it
+                # Or, if your model can run without it for some pooling strategies, handle accordingly
+                # However, the error clearly states it's needed when labels are specified.
+                pass # Or handle error: raise ValueError("'lengths' not found in batch, but required by model")
+            else:
+                lengths = lengths.to(device) # Move to device if found
+
+            # Pass all necessary parts of the batch to the model
+            model_inputs = {
+                'input_ids': input_ids,
+                'attention_mask': attention_mask,
+                'labels': labels
+            }
+            if lengths is not None:
+                model_inputs['lengths'] = lengths
+
+            outputs = model(**model_inputs)
+            loss = outputs.loss
+            logits = outputs.logits
+
+            total_loss += loss.item()
+
+            if logits.shape[1] > 1:
+                preds = torch.argmax(logits, dim=1)
+            else:
+                preds = (torch.sigmoid(logits) > 0.5).long()
+            all_preds.extend(preds.cpu().numpy())
+
+            all_labels.extend(labels.cpu().numpy())
+
+            if logits.shape[1] > 1:
+                probs = torch.softmax(logits, dim=1)[:, 1]
+                all_probs_for_auc.extend(probs.cpu().numpy())
+            else:
+                probs = torch.sigmoid(logits)
+                all_probs_for_auc.extend(probs.squeeze().cpu().numpy())
+
+    avg_loss = total_loss / len(dataloader)
+    accuracy = accuracy_score(all_labels, all_preds)
+    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
+    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
+    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
+    mcc = matthews_corrcoef(all_labels, all_preds)
+
+    try:
+        roc_auc = roc_auc_score(all_labels, all_probs_for_auc)
+    except ValueError as e:
+        print(f"Could not calculate AUC-ROC: {e}. Labels: {list(set(all_labels))[:10]}. Probs example: {all_probs_for_auc[:5]}. Setting to 0.0")
+        roc_auc = 0.0
+
+    return {
+        'loss': avg_loss,
+        'accuracy': accuracy,
+        'f1': f1,
+        'roc_auc': roc_auc,
+        'precision': precision,
+        'recall': recall,
+        'mcc': mcc
+    }
+
+if __name__ == "__main__":
+    import argparse
+    from torch.utils.data import DataLoader
+    from datasets import load_dataset
+    from inference import SentimentInference # Assuming inference.py is in the same directory
+    import yaml
+    from transformers import AutoTokenizer, AutoConfig
+    from models import ModernBertForSentiment # Assuming models.py is in the same directory or PYTHONPATH
+
+    class SentimentInference:
+        def __init__(self, config_path):
+            with open(config_path, 'r') as f:
+                config_data = yaml.safe_load(f)
+            self.config_path = config_path
+            self.config_data = config_data
+            # Adjust to access the correct key from the nested config structure
+            self.model_hf_repo_id = config_data['model']['name_or_path']
+            self.tokenizer_name_or_path = config_data['model'].get('tokenizer_name_or_path', self.model_hf_repo_id)
+            self.local_model_weights_path = config_data['model'].get('local_model_weights_path', None) # Assuming it might be under 'model'
+            self.load_from_local_pt = config_data['model'].get('load_from_local_pt', False)
+            self.trust_remote_code_for_config = config_data['model'].get('trust_remote_code_for_config', True) # Default to True for custom code
+            self.max_length = config_data['model']['max_length']
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+            try:
+                if self.load_from_local_pt and self.local_model_weights_path:
+                    print(f"Loading model from local path: {self.local_model_weights_path}")
+                    # When loading local, config might also be local or from base model if not saved with custom checkpoint
+                    # For simplicity, assume config is part of the saved pretrained local model or not strictly needed if all architecture is in code
+                    self.config = AutoConfig.from_pretrained(self.local_model_weights_path, trust_remote_code=self.trust_remote_code_for_config)
+                    self.model = ModernBertForSentiment.from_pretrained(self.local_model_weights_path, config=self.config, trust_remote_code=True)
+                else:
+                    print(f"Loading base ModernBertConfig from: {self.model_hf_repo_id}")
+                    self.config = AutoConfig.from_pretrained(self.model_hf_repo_id, trust_remote_code=self.trust_remote_code_for_config)
+                    print(f"Instantiating and loading model weights for {self.model_hf_repo_id} using ModernBertForSentiment...")
+                    self.model = ModernBertForSentiment.from_pretrained(self.model_hf_repo_id, config=self.config, trust_remote_code=True)
+                    print(f"Model {self.model_hf_repo_id} loaded successfully from Hugging Face Hub using ModernBertForSentiment.")
+                self.model.to(self.device)
+            except Exception as e:
+                print(f"Failed to load model: {e}")
+                # Optionally print more detailed traceback
+                import traceback
+                traceback.print_exc()
+                exit()
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=self.trust_remote_code_for_config)
+
+        def print_debug_info(self):
+            print(f"Model HF Repo ID: {self.model_hf_repo_id}")
+            print(f"Tokenizer Name or Path: {self.tokenizer_name_or_path}")
+            print(f"Local Model Weights Path: {self.local_model_weights_path}")
+            print(f"Load from Local PT: {self.load_from_local_pt}")
+
+    parser = argparse.ArgumentParser(description="Evaluate a sentiment analysis model on the IMDB test set.")
+    parser.add_argument(
+        "--config_path",
+        type=str,
+        default="local_test_config.yaml",
+        help="Path to the configuration file for SentimentInference (e.g., local_test_config.yaml or config.yaml)"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=16,
+        help="Batch size for evaluation."
+    )
+    args = parser.parse_args()
+
+    print(f"Using configuration: {args.config_path}")
+    print("Loading sentiment model and tokenizer...")
+    inferer = SentimentInference(config_path=args.config_path)
+    model = inferer.model
+    tokenizer = inferer.tokenizer
+    max_length = inferer.max_length
+    device = inferer.device
+
+    print("Loading IMDB test dataset...")
+    try:
+        imdb_dataset_test = load_dataset("imdb", split="test")
+    except Exception as e:
+        print(f"Failed to load IMDB dataset: {e}")
+        exit()
+
+    print("Tokenizing dataset...")
+    def tokenize_function(examples):
+        tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
+        tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]]
+        return tokenized_output
+
+    tokenized_imdb_test = imdb_dataset_test.map(tokenize_function, batched=True)
+    tokenized_imdb_test = tokenized_imdb_test.remove_columns(["text"])
+    tokenized_imdb_test = tokenized_imdb_test.rename_column("label", "labels")
+    tokenized_imdb_test.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"])
+
+    test_dataloader = DataLoader(tokenized_imdb_test, batch_size=args.batch_size)
+
+    print("Starting evaluation...")
+    progress_bar = tqdm(test_dataloader, desc="Evaluating")
+
+    results = evaluate(model, progress_bar, device)
+
+    print("\n--- Evaluation Results ---")
+    for key, value in results.items():
+        if isinstance(value, float):
+            print(f"{key.capitalize()}: {value:.4f}")
+        else:
+            print(f"{key.capitalize()}: {value}")
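
The script is self-contained behind argparse; a typical invocation and a programmatic sketch (assumes the repo's models.py and the checkpoint referenced by the config are present locally):

    # From the repo root:
    #   python evaluation.py --config_path local_test_config.yaml --batch_size 16

    # evaluate() is also importable; it returns a plain dict of metrics:
    from evaluation import evaluate  # keys: loss, accuracy, f1, roc_auc, precision, recall, mcc
    # results = evaluate(model, dataloader, device)
    # print(f"Accuracy: {results['accuracy']:.4f}")
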
inference.py
CHANGED
@@ -1,58 +1,107 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
-# models.py (containing ModernBertForSentiment) will be loaded from the Hub due to trust_remote_code=True
 from typing import Dict, Any
 import yaml
+import os
+from models import ModernBertForSentiment
 
 class SentimentInference:
     def __init__(self, config_path: str = "config.yaml"):
-        """Load configuration and initialize model and tokenizer from Hugging Face Hub."""
+        """Load configuration and initialize model and tokenizer from local checkpoint or Hugging Face Hub."""
+        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---") # Add this
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)
+        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---") # Add this
 
         model_yaml_cfg = config_data.get('model', {})
         inference_yaml_cfg = config_data.get('inference', {})
 
         model_hf_repo_id = model_yaml_cfg.get('name_or_path')
-        if not model_hf_repo_id:
-            raise ValueError("model.name_or_path must be specified in config.yaml (e.g., 'username/model_name')")
-
         tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
+        local_model_weights_path = inference_yaml_cfg.get('model_path') # Path for local .pt file
+
+        print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---") # Add this
+        print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---") # Add this
 
         self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))
 
-        # but if ModernBertForSentiment.__init__ requires it, it must be provided.
-        # Assuming it's not critical for basic inference here to simplify.
-        # loaded_config.loss_function = model_yaml_cfg.get('loss_function', {'name': '...', 'params': {}})
+        # --- Tokenizer Loading (always from Hub for now, or could be made conditional) ---
+        if not tokenizer_hf_repo_id and not model_hf_repo_id:
+            raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
+        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
+        print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
+        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
+
+        # --- Model Loading --- #
+        # Determine if we are loading from a local .pt file or from Hugging Face Hub
+        load_from_local_pt = False
+        if local_model_weights_path and os.path.isfile(local_model_weights_path):
+            print(f"Found local model weights path: {local_model_weights_path}")
+            print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
+            load_from_local_pt = True
+        elif not model_hf_repo_id:
+            raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
+
+        print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
 
+        if load_from_local_pt:
+            print("Attempting to load model from local .pt checkpoint...")
+            print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
+            # Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
+            # This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
+            base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_hf_repo_id or tokenizer_hf_repo_id)
+            print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---") # Add this
+            if not base_model_for_config_id:
+                raise ValueError("For local .pt loading, model.base_model_for_config must be specified in config.yaml (e.g., 'answerdotai/ModernBERT-base') to build the model structure.")
+
+            print(f"Loading ModernBertConfig for structure from: {base_model_for_config_id}")
+            bert_config = ModernBertConfig.from_pretrained(base_model_for_config_id)
+
+            # Augment config with parameters from model_yaml_cfg
+            bert_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
+            bert_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
+            bert_config.classifier_dropout = model_yaml_cfg.get('dropout')
+            bert_config.num_labels = model_yaml_cfg.get('num_labels', 1)
+            # bert_config.loss_function = model_yaml_cfg.get('loss_function') # If needed by __init__
+
+            print("Instantiating ModernBertForSentiment model structure...")
+            self.model = ModernBertForSentiment(bert_config)
+
+            print(f"Loading model weights from local checkpoint: {local_model_weights_path}")
+            checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
+            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                model_state_to_load = checkpoint['model_state_dict']
+            else:
+                model_state_to_load = checkpoint # Assume it's the state_dict itself
+            self.model.load_state_dict(model_state_to_load)
+            print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")
+
+        else: # Load from Hugging Face Hub
+            print(f"Attempting to load model from Hugging Face Hub: {model_hf_repo_id}...")
+            print(f"--- Debug: Entering HUGGING FACE HUB loading path ---") # Add this
+            print(f"--- Debug: model_hf_repo_id (for Hub loading): {model_hf_repo_id} ---") # Add this
+            if not model_hf_repo_id:
+                raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
+
+            print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
+            loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)
+
+            # Augment loaded_config
+            loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
+            loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
+            loaded_config.classifier_dropout = model_yaml_cfg.get('dropout')
+            loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)
+
+            print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_hf_repo_id,
+                config=loaded_config,
+                trust_remote_code=True,
+                force_download=True # <--- TEMPORARY - remove when everything is working
+            )
+            print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
+
         self.model.eval()
-        print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
 
     def predict(self, text: str) -> Dict[str, Any]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
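
SentimentInference now selects its loading path from the config: if inference.model_path points at an existing .pt file, it builds ModernBertForSentiment from model.base_model_for_config and loads the local state_dict; otherwise it pulls weights from model.name_or_path on the Hub. A usage sketch (the sentiment and confidence keys follow how predict is consumed in run_sample_inference.py below):

    from inference import SentimentInference

    # Local-checkpoint path: local_test_config.yaml points inference.model_path at a .pt file.
    inferer = SentimentInference(config_path="local_test_config.yaml")

    # Hub path: config.yaml relies on model.name_or_path instead.
    # inferer = SentimentInference(config_path="config.yaml")

    result = inferer.predict("A surprisingly tender and funny film.")
    print(result["sentiment"], result["confidence"])
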
local_test_config.yaml
ADDED
@@ -0,0 +1,12 @@
+model:
+  base_model_for_config: "answerdotai/ModernBERT-base"
+  tokenizer_name_or_path: "answerdotai/ModernBERT-base"
+  max_length: 880
+  dropout: 0.1
+  pooling_strategy: "mean"
+  num_weighted_layers: 6
+  num_labels: 1
+
+inference:
+  model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
+  max_length: 880
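
This config exercises the local-checkpoint branch of SentimentInference: inference.model_path names the .pt file, and model.base_model_for_config supplies the architecture to build. A sketch of the branch selection it triggers (mirrors the check in the inference.py diff above):

    import os
    import yaml

    with open("local_test_config.yaml") as f:
        cfg = yaml.safe_load(f)

    model_path = cfg.get("inference", {}).get("model_path")
    uses_local_pt = bool(model_path and os.path.isfile(model_path))
    print("local .pt branch" if uses_local_pt else "Hugging Face Hub branch")
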
original_config.yaml
ADDED
@@ -0,0 +1,46 @@
+model:
+  name: "answerdotai/ModernBERT-base"
+  loss_function:
+    name: "SentimentWeightedLoss" # Options: "SentimentWeightedLoss", "SentimentFocalLoss"
+    # Parameters for the chosen loss function.
+    # For SentimentFocalLoss, common params are:
+    #   gamma_focal: 1.0 # (e.g., 2.0 for standard, -2.0 for reversed, 0 for none)
+    #   label_smoothing_epsilon: 0.05 # (e.g., 0.0 to 0.1)
+    # For SentimentWeightedLoss, params is empty:
+    params:
+      gamma_focal: 1.0
+      label_smoothing_epsilon: 0.05
+  output_dir: "checkpoints"
+  max_length: 880 # 256
+  dropout: 0.1
+  # --- Pooling Strategy --- #
+  # Options: "cls", "mean", "cls_mean_concat", "weighted_layer", "cls_weighted_concat"
+  # "cls" uses just the [CLS] token for classification
+  # "mean" uses mean pooling over final hidden states for classification
+  # "cls_mean_concat" uses both [CLS] and mean pooling over final hidden states for classification
+  # "weighted_layer" uses a weighted combination of the final hidden states from the top N layers for classification
+  # "cls_weighted_concat" uses a weighted combination of the final hidden states from the top N layers and the [CLS] token for classification
+
+  pooling_strategy: "mean" # Current default, change as needed
+
+  num_weighted_layers: 6 # Number of top BERT layers to use for 'weighted_layer' strategies (e.g., 1 to 12 for BERT-base)
+
+data:
+  # No specific data paths needed as we use HF datasets at the moment
+
+training:
+  epochs: 6
+  batch_size: 16
+  lr: 1e-5 # 1e-5 # 2.0e-5
+  weight_decay_rate: 0.02 # 0.01
+  resume_from_checkpoint: "" # "checkpoints/mean_epoch2_0.9361acc_0.9355f1.pt" # Path to checkpoint file, or empty to not resume
+
+inference:
+  # Default path, can be overridden
+  model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
+  # Using the same max_length as training for consistency
+  max_length: 880 # 256
+
+
+# "answerdotai/ModernBERT-base"
+# "answerdotai/ModernBERT-large"

run_sample_inference.py
ADDED
@@ -0,0 +1,60 @@
+import argparse
+from datasets import load_dataset
+from inference import SentimentInference
+
+def run_sample_inference(config_path: str = "config.yaml", num_samples: int = 5):
+    """
+    Loads a sentiment analysis model from a checkpoint, runs inference on a few
+    samples from the IMDB validation set, and prints the results.
+    """
+    print("Loading sentiment model...")
+    # Initialize SentimentInference
+    # Ensure config_path points to your configuration file that specifies the model path
+    inferer = SentimentInference(config_path=config_path)
+    print("Model loaded.")
+
+    print("\nLoading IMDB dataset (test split for validation samples)...")
+    # Load the IMDB dataset, test split is used as validation
+    try:
+        imdb_dataset = load_dataset("imdb", split="test")
+    except Exception as e:
+        print(f"Failed to load IMDB dataset: {e}")
+        print("Please ensure you have an internet connection and the `datasets` library can access Hugging Face.")
+        print("You might need to run `pip install datasets` or check your network settings.")
+        return
+
+    print(f"Taking {num_samples} samples from the dataset.")
+
+    # Take a few samples
+    samples = imdb_dataset.shuffle().select(range(num_samples))
+
+    print("\nRunning inference on selected samples:\n")
+    for i, sample in enumerate(samples):
+        text = sample["text"]
+        true_label_id = sample["label"]
+        true_label = "positive" if true_label_id == 1 else "negative"
+
+        print(f"--- Sample {i+1}/{num_samples} ---")
+        print(f"Text: {text[:200]}...") # Print first 200 chars for brevity
+        print(f"True Sentiment: {true_label}")
+
+        prediction = inferer.predict(text)
+        print(f"Predicted Sentiment: {prediction['sentiment']}")
+        print(f"Confidence: {prediction['confidence']:.4f}\n")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run sample inference on IMDB dataset.")
+    parser.add_argument(
+        "--config_path",
+        type=str,
+        default="config.yaml",
+        help="Path to the configuration file (e.g., config.yaml)"
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=5,
+        help="Number of samples from IMDB test set to run inference on."
+    )
+    args = parser.parse_args()
+    run_sample_inference(config_path=args.config_path, num_samples=args.num_samples)
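
The script runs from the CLI or as an import; both config files referenced here are the ones added or changed in this commit:

    # CLI, from the repo root:
    #   python run_sample_inference.py --config_path local_test_config.yaml --num_samples 3

    # Or programmatically:
    from run_sample_inference import run_sample_inference
    run_sample_inference(config_path="config.yaml", num_samples=2)
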
upload_to_hf.py
CHANGED
@@ -1,8 +1,10 @@
-from huggingface_hub import HfApi, upload_folder, create_repo
+from huggingface_hub import HfApi, upload_folder, create_repo, login
 from transformers import AutoTokenizer, AutoConfig
 import os
 import shutil
 import tempfile
+import torch
+import argparse
 
 # --- Configuration ---
 HUGGING_FACE_USERNAME = "voxmenthe" # Your Hugging Face username
@@ -15,7 +17,7 @@ ORIGINAL_BASE_MODEL_NAME = "answerdotai/ModernBERT-base"
 # Local path to your fine-tuned model checkpoint
 LOCAL_MODEL_CHECKPOINT_DIR = "checkpoints"
 FINE_TUNED_MODEL_FILENAME = "mean_epoch5_0.9575acc_0.9575f1.pt" # Your best checkpoint
-# If your fine-tuned model is just a .pt file, ensure you also have a config.json for
+# If your fine-tuned model is just a .pt file, ensure you also have a config.json for ModernBert
 # For simplicity, we'll re-save the config from the fine-tuned model structure if possible, or from original base.
 
 # Files from your project to include (e.g., custom model code, inference script)
@@ -32,19 +34,29 @@ PROJECT_FILES_TO_UPLOAD = [
 def upload_model_and_tokenizer():
     api = HfApi()
 
+    REPO_ID = f"{HUGGING_FACE_USERNAME}/{MODEL_NAME_ON_HF}"
+    print(f"Preparing to upload to Hugging Face Hub repository: {REPO_ID}")
+
+    # Create the repository on Hugging Face Hub if it doesn't exist
+    # This should be done after login to ensure correct permissions
+    print(f"Ensuring repository '{REPO_ID}' exists on Hugging Face Hub...")
+    try:
+        create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
+        print(f"Repository '{REPO_ID}' ensured.")
+    except Exception as e:
+        print(f"Error creating/accessing repository {REPO_ID}: {e}")
+        print("Please check your Hugging Face token and repository permissions.")
+        return
 
     # Create a temporary directory to gather all files for upload
+    with tempfile.TemporaryDirectory() as temp_dir:
+        print(f"Created temporary directory for upload: {temp_dir}")
 
         # 1. Save tokenizer files from the ORIGINAL_BASE_MODEL_NAME
+        print(f"Saving tokenizer from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
         try:
             tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
+            tokenizer.save_pretrained(temp_dir)
             print("Tokenizer files saved.")
         except Exception as e:
             print(f"Error saving tokenizer from {ORIGINAL_BASE_MODEL_NAME}: {e}")
@@ -53,58 +65,198 @@ def upload_model_and_tokenizer():
 
         # 2. Save base model config.json (architecture) from ORIGINAL_BASE_MODEL_NAME
         # This is crucial for AutoModelForSequenceClassification.from_pretrained(REPO_ID) to work.
+        print(f"Saving model config.json from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
         try:
             config = AutoConfig.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
+            print(f"Config loaded. Initial num_labels (if exists): {getattr(config, 'num_labels', 'Not set')}")
+
+            # Set architecture first
+            config.architectures = ["ModernBertForSentiment"]
+
+            # Add necessary classification head attributes for AutoModelForSequenceClassification
+            config.num_labels = 1 # For IMDB sentiment (binary, single logit output based on training)
+            print(f"After attempting to set: config.num_labels = {config.num_labels}")
+
+            config.id2label = {0: "NEGATIVE", 1: "POSITIVE"} # Standard for binary, even with num_labels=1
+            config.label2id = {"NEGATIVE": 0, "POSITIVE": 1}
+            print(f"After setting id2label/label2id, config.num_labels is: {config.num_labels}")
+
+            # CRITICAL: Force num_labels to 1 again immediately before saving
+            config.num_labels = 1
+            print(f"Immediately before save, FINAL check config.num_labels = {config.num_labels}")
+
+            # Safeguard: Remove any existing config.json from temp_dir before saving ours
+            potential_old_config_path = os.path.join(temp_dir, "config.json")
+            if os.path.exists(potential_old_config_path):
+                os.remove(potential_old_config_path)
+                print(f"Removed existing config.json from {temp_dir} to ensure clean save.")
+
+            config.save_pretrained(temp_dir)
+            print(f"Model config.json (with num_labels={config.num_labels}, architectures={config.architectures}) saved to {temp_dir}.")
         except Exception as e:
             print(f"Error saving config.json from {ORIGINAL_BASE_MODEL_NAME}: {e}")
             return
 
+        # Load the fine-tuned model checkpoint to extract the state_dict
+        full_checkpoint_path = os.path.join(LOCAL_MODEL_CHECKPOINT_DIR, FINE_TUNED_MODEL_FILENAME)
+        hf_model_path = os.path.join(temp_dir, "pytorch_model.bin")
+
+        if not os.path.exists(full_checkpoint_path):
+            print(f"ERROR: Local model checkpoint not found at {full_checkpoint_path}")
+            shutil.rmtree(temp_dir)
+            return
+
+        print(f"Loading local checkpoint from: {full_checkpoint_path}")
+        # Load checkpoint to CPU to avoid GPU memory issues if the script runner doesn't have/need GPU
+        checkpoint = torch.load(full_checkpoint_path, map_location='cpu')
+
+        model_state_dict = None
+        if 'model_state_dict' in checkpoint:
+            model_state_dict = checkpoint['model_state_dict']
+            print("Extracted 'model_state_dict' from checkpoint.")
+        elif 'state_dict' in checkpoint: # Another common key for state_dicts
+            model_state_dict = checkpoint['state_dict']
+            print("Extracted 'state_dict' from checkpoint.")
+        elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
+            # If the checkpoint is already a state_dict (e.g., from torch.save(model.state_dict(), ...))
+            # Basic check: does it have keys that look like weights/biases?
+            if any(key.endswith('.weight') or key.endswith('.bias') for key in checkpoint.keys()):
+                model_state_dict = checkpoint
+                print("Checkpoint appears to be a raw state_dict (contains .weight or .bias keys).")
+            else:
+                print("Checkpoint is a dict, but does not immediately appear to be a state_dict (no .weight/.bias keys found).")
+                print(f"Checkpoint keys: {list(checkpoint.keys())[:10]}...") # Print some keys for diagnosis
+
         else:
+            # This case handles if checkpoint is not a dict or doesn't match known structures
+            print(f"ERROR: Could not find a known state_dict key in the checkpoint, and it's not a recognizable raw state_dict.")
+            if isinstance(checkpoint, dict):
+                print(f"Checkpoint dictionary keys found: {list(checkpoint.keys())}")
+            else:
+                print(f"Checkpoint is not a dictionary. Type: {type(checkpoint)}")
+            shutil.rmtree(temp_dir)
             return
 
+        if model_state_dict is None:
+            print("ERROR: model_state_dict was not successfully extracted. Aborting upload.")
+            shutil.rmtree(temp_dir)
+            return
+
+        # --- DEBUG: Print keys of the state_dict ---
+        print("\n--- Keys in extracted (original) model_state_dict (first 30 and last 10): ---")
+        state_dict_keys = list(model_state_dict.keys())
+        if len(state_dict_keys) > 0:
+            for i, key in enumerate(state_dict_keys[:30]):
+                print(f"  {i+1}. {key}")
+            if len(state_dict_keys) > 40: # Show ellipsis if there's a gap
+                print("  ...")
+            # Print last 10 keys if there are more than 30
+            start_index_for_last_10 = max(30, len(state_dict_keys) - 10)
+            for i, key_idx in enumerate(range(start_index_for_last_10, len(state_dict_keys))):
+                print(f"  {key_idx+1}. {state_dict_keys[key_idx]}")
+        else:
+            print("  (No keys found in model_state_dict)")
+        print(f"Total keys: {len(state_dict_keys)}")
+        print("-----------------------------------------------------------\n")
+        # --- END DEBUG ---
+
+        # Transform keys for Hugging Face compatibility if needed.
+        # For ModernBertForSentiment with self.bert and self.classifier (custom head):
+        # - Checkpoint 'bert.*' should remain 'bert.*'
+        # - Checkpoint 'classifier.*' keys (e.g., classifier.dense1.weight, classifier.out_proj.weight) should remain 'classifier.*' as they are.
+        transformed_state_dict = {}
+        has_classifier_weights_transformed = False # Used to track if out_proj was found
+
+        print("Transforming state_dict keys for Hugging Face Hub compatibility...")
+        for key, value in model_state_dict.items():
+            new_key = None
+            if key.startswith("bert."):
+                # Keep 'bert.' prefix as ModernBertForSentiment uses self.bert
+                new_key = key
+            elif key.startswith("classifier."):
+                # All parts of the custom classifier head should retain their names
+                new_key = key
+                if "out_proj" in key: # Just to confirm it exists
+                    has_classifier_weights_transformed = True # Indicate out_proj was found and processed
+
+            if new_key:
+                transformed_state_dict[new_key] = value
+                if key != new_key:
+                    print(f"  Mapping '{key}' -> '{new_key}'")
+                else:
+                    # print(f"  Keeping key as is: '{key}'") # Optional
+                    pass
+            else:
+                print(f"  INFO: Discarding key not mapped: {key}")
+
+        # Check if the critical classifier output layer was present in the source checkpoint
+        # This check might need adjustment based on the actual layers of ClassifierHead
+        # For now, we check if any 'out_proj' key was seen under 'classifier.'
+        if not has_classifier_weights_transformed:
+            print("WARNING: No 'classifier.out_proj.*' keys were found in the source checkpoint.")
+            print("         Ensure your checkpoint contains the expected classifier layers.")
+            # Not necessarily an error to abort, as other classifier keys might be valid.
+
+        model_state_dict = transformed_state_dict
+
+        # --- DEBUG: Print keys of the TRANSFORMED state_dict ---
+        print("\n--- Keys in TRANSFORMED model_state_dict for upload (first 30 and last 10): ---")
+        state_dict_keys_transformed = list(transformed_state_dict.keys())
+        if len(state_dict_keys_transformed) > 0:
+            for i, key_t in enumerate(state_dict_keys_transformed[:30]):
+                print(f"  {i+1}. {key_t}")
+            if len(state_dict_keys_transformed) > 40:
+                print("  ...")
+            start_index_for_last_10_t = max(30, len(state_dict_keys_transformed) - 10)
+            for i, key_idx_t in enumerate(range(start_index_for_last_10_t, len(state_dict_keys_transformed))):
+                print(f"  {key_idx_t+1}. {state_dict_keys_transformed[key_idx_t]}")
+        else:
+            print("  (No keys found in transformed_state_dict)")
+        print(f"Total keys in transformed_state_dict: {len(state_dict_keys_transformed)}")
+        print("-----------------------------------------------------------\n")
+
+        # Save the TRANSFORMED state_dict
+        torch.save(transformed_state_dict, hf_model_path)
+        print(f"Saved TRANSFORMED model state_dict to {hf_model_path}.")
+
         # 4. Copy other project files
         for project_file in PROJECT_FILES_TO_UPLOAD:
             local_project_file_path = project_file # Files are now at the root
             if os.path.exists(local_project_file_path):
+                shutil.copy(local_project_file_path, os.path.join(temp_dir, os.path.basename(project_file)))
+                print(f"Copied project file {project_file} to {temp_dir}.")
+
+        # Before uploading, let's inspect the temp_dir to be absolutely sure what's there
+        print(f"--- Inspecting temp_dir ({temp_dir}) before upload: ---")
+        for item in os.listdir(temp_dir):
+            print(f"  - {item}")
+        temp_config_path_to_check = os.path.join(temp_dir, "config.json")
+        if os.path.exists(temp_config_path_to_check):
+            print(f"--- Content of {temp_config_path_to_check} before upload: ---")
+            with open(temp_config_path_to_check, 'r') as f_check:
+                print(f_check.read())
+            print("--- End of config.json content ---")
+        else:
+            print(f"WARNING: {temp_config_path_to_check} does NOT exist before upload!")
 
         # 5. Upload the contents of the temporary directory
+        print(f"Uploading all files from {temp_dir} to {REPO_ID}...")
         try:
             upload_folder(
+                folder_path=temp_dir,
                 repo_id=REPO_ID,
                 repo_type="model",
                 commit_message=f"Upload fine-tuned model, tokenizer, and supporting files for {MODEL_NAME_ON_HF}"
             )
             print("All files uploaded successfully!")
         except Exception as e:
+            print(f"Error uploading files: {e}")
+        finally:
+            print(f"Cleaning up temporary directory: {temp_dir}")
+            # The TemporaryDirectory context manager handles cleanup automatically
+            # but an explicit message is good for clarity.
+
+    print("Upload process finished.")
 
 if __name__ == "__main__":
-    # Make sure you are logged in to Hugging Face CLI:
-    # Run `huggingface-cli login` or `huggingface-cli login --token YOUR_HF_WRITE_TOKEN` in your terminal first.
-    print("Starting upload process...")
-    print(f"Target Hugging Face Repo ID: {REPO_ID}")
-    print("Ensure you have run 'huggingface-cli login' with a write token.")
     upload_model_and_tokenizer()
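
The rewrite drops the old inline `huggingface-cli login` reminders, so an authenticated Hub session is assumed before running the script. A minimal pre-flight sketch; the token value is a placeholder, not from this repo:

    # Either authenticate once in the shell:
    #   huggingface-cli login
    # or programmatically, using the login imported at the top of upload_to_hf.py:
    from huggingface_hub import login
    login(token="hf_...")  # placeholder; substitute your own write token

    # Then:
    #   python upload_to_hf.py
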