from huggingface_hub import upload_folder, create_repo
from transformers import AutoTokenizer, AutoConfig
import os
import shutil
import tempfile
import torch
# --- Configuration ---
HUGGING_FACE_USERNAME = "voxmenthe" # Your Hugging Face username
MODEL_NAME_ON_HF = "modernbert-imdb-sentiment" # The name of the model on Hugging Face
REPO_ID = f"{HUGGING_FACE_USERNAME}/{MODEL_NAME_ON_HF}"
# Original base model from which the tokenizer and initial config were derived
ORIGINAL_BASE_MODEL_NAME = "answerdotai/ModernBERT-base"
# Local path to your fine-tuned model checkpoint
LOCAL_MODEL_CHECKPOINT_DIR = "checkpoints"
FINE_TUNED_MODEL_FILENAME = "mean_epoch5_0.9575acc_0.9575f1.pt" # Your best checkpoint
# If your fine-tuned model is just a .pt file, a config.json for ModernBERT must be
# uploaded alongside it; here we re-save the config from the original base model and
# patch it for the classification head (see below).

# Files from this project to include (e.g., custom model code, inference script).
# These now live at the repository root.
PROJECT_FILES_TO_UPLOAD = [
    "config.yaml",
    "inference.py",
    "models.py",
    "train_utils.py",
    "classifiers.py",
    "README.md",
]
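# NOTE: the paths above are resolved relative to the current working directory,
# so run this script from the project root where these files live.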
def upload_model_and_tokenizer():
    print(f"Preparing to upload to Hugging Face Hub repository: {REPO_ID}")

    # Create the repository on Hugging Face Hub if it doesn't exist.
    # Run `huggingface-cli login` (or set HF_TOKEN) first so this call has the
    # correct permissions.
    print(f"Ensuring repository '{REPO_ID}' exists on Hugging Face Hub...")
    try:
        create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
        print(f"Repository '{REPO_ID}' ensured.")
    except Exception as e:
        print(f"Error creating/accessing repository {REPO_ID}: {e}")
        print("Please check your Hugging Face token and repository permissions.")
        return

    # Gather all files for upload in a temporary directory; the context manager
    # deletes it automatically when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Created temporary directory for upload: {temp_dir}")
        # 1. Save tokenizer files from ORIGINAL_BASE_MODEL_NAME.
        print(f"Saving tokenizer from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
            tokenizer.save_pretrained(temp_dir)
            print("Tokenizer files saved.")
        except Exception as e:
            print(f"Error saving tokenizer from {ORIGINAL_BASE_MODEL_NAME}: {e}")
            print("Please ensure this model name is correct and accessible.")
            return
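        # For reference: save_pretrained for a fast tokenizer typically writes
        # tokenizer.json, tokenizer_config.json, and special_tokens_map.json into
        # temp_dir; the exact file set can vary with the tokenizer class.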
        # 2. Save the base model config.json (architecture) from ORIGINAL_BASE_MODEL_NAME.
        # This is crucial for AutoModelForSequenceClassification.from_pretrained(REPO_ID) to work.
        print(f"Saving model config.json from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
        try:
            config = AutoConfig.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
            print(f"Config loaded. Initial num_labels (if set): {getattr(config, 'num_labels', 'Not set')}")

            config.architectures = ["ModernBertForSentiment"]
            # Classification head attributes for AutoModelForSequenceClassification.
            config.num_labels = 1  # binary IMDB sentiment, trained with a single logit output
            config.id2label = {0: "NEGATIVE", 1: "POSITIVE"}  # standard binary labels
            config.label2id = {"NEGATIVE": 0, "POSITIVE": 1}
            # In transformers, num_labels is derived from id2label, so setting the two
            # labels above flips num_labels back to 2; force it to 1 immediately before
            # saving. (Note that this assignment may in turn reset id2label to generic
            # LABEL_0-style names, depending on the transformers version.)
            config.num_labels = 1
            print(f"Final config.num_labels before save: {config.num_labels}")

            # Safeguard: remove any stale config.json from temp_dir before saving ours.
            potential_old_config_path = os.path.join(temp_dir, "config.json")
            if os.path.exists(potential_old_config_path):
                os.remove(potential_old_config_path)
                print(f"Removed existing config.json from {temp_dir} to ensure a clean save.")

            config.save_pretrained(temp_dir)
            print(f"Model config.json (num_labels={config.num_labels}, architectures={config.architectures}) saved to {temp_dir}.")
        except Exception as e:
            print(f"Error saving config.json from {ORIGINAL_BASE_MODEL_NAME}: {e}")
            return
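        # Optional sanity check (a commented-out sketch; `json` is from the standard
        # library and would need an import at the top if you enable this):
        #
        #   import json
        #   with open(os.path.join(temp_dir, "config.json")) as f_cfg:
        #       saved_cfg = json.load(f_cfg)
        #   print(saved_cfg.get("architectures"), saved_cfg.get("id2label"))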
        # 3. Load the fine-tuned model checkpoint to extract its state_dict.
        full_checkpoint_path = os.path.join(LOCAL_MODEL_CHECKPOINT_DIR, FINE_TUNED_MODEL_FILENAME)
        hf_model_path = os.path.join(temp_dir, "pytorch_model.bin")
        if not os.path.exists(full_checkpoint_path):
            print(f"ERROR: Local model checkpoint not found at {full_checkpoint_path}")
            return  # temp_dir is cleaned up by the context manager

        print(f"Loading local checkpoint from: {full_checkpoint_path}")
        # Load onto CPU so the script works without (and doesn't occupy) a GPU.
        checkpoint = torch.load(full_checkpoint_path, map_location='cpu')
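        # NOTE: on PyTorch 2.6+ torch.load defaults to weights_only=True, which can
        # reject full training checkpoints that pickle optimizer state or custom
        # classes. If that happens and you trust the file, pass weights_only=False
        # to the call above.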
        model_state_dict = None
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model_state_dict = checkpoint['model_state_dict']
            print("Extracted 'model_state_dict' from checkpoint.")
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:  # another common key
            model_state_dict = checkpoint['state_dict']
            print("Extracted 'state_dict' from checkpoint.")
        elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
            # The checkpoint may already be a raw state_dict
            # (e.g. from torch.save(model.state_dict(), ...)).
            # Basic check: does it have keys that look like weights/biases?
            if any(key.endswith('.weight') or key.endswith('.bias') for key in checkpoint.keys()):
                model_state_dict = checkpoint
                print("Checkpoint appears to be a raw state_dict (contains .weight or .bias keys).")
            else:
                print("Checkpoint is a dict but does not look like a state_dict (no .weight/.bias keys found).")
                print(f"Checkpoint keys: {list(checkpoint.keys())[:10]}...")  # a few keys for diagnosis
        else:
            # The checkpoint is not a dict, or does not match any known structure.
            print("ERROR: Could not find a known state_dict key in the checkpoint, and it is not a recognizable raw state_dict.")
            if isinstance(checkpoint, dict):
                print(f"Checkpoint dictionary keys found: {list(checkpoint.keys())}")
            else:
                print(f"Checkpoint is not a dictionary. Type: {type(checkpoint)}")
            return

        if model_state_dict is None:
            print("ERROR: model_state_dict was not successfully extracted. Aborting upload.")
            return
        # --- DEBUG: print keys of the extracted (original) state_dict ---
        print("\n--- Keys in extracted (original) model_state_dict (first 30 and last 10): ---")
        state_dict_keys = list(model_state_dict.keys())
        if state_dict_keys:
            for i, key in enumerate(state_dict_keys[:30]):
                print(f"  {i + 1}. {key}")
            if len(state_dict_keys) > 40:  # show an ellipsis if there is a gap
                print("  ...")
            # Print the last 10 keys if there are more than 30.
            for key_idx in range(max(30, len(state_dict_keys) - 10), len(state_dict_keys)):
                print(f"  {key_idx + 1}. {state_dict_keys[key_idx]}")
        else:
            print("  (No keys found in model_state_dict)")
        print(f"Total keys: {len(state_dict_keys)}")
        print("-----------------------------------------------------------\n")
        # --- END DEBUG ---
        # Transform keys for Hugging Face compatibility if needed.
        # For ModernBertForSentiment with self.bert and a custom self.classifier head:
        #   - checkpoint 'bert.*' keys remain 'bert.*'
        #   - checkpoint 'classifier.*' keys (e.g. classifier.dense1.weight,
        #     classifier.out_proj.weight) remain 'classifier.*'
        transformed_state_dict = {}
        found_out_proj = False  # tracks whether the classifier output layer was seen
        print("Transforming state_dict keys for Hugging Face Hub compatibility...")
        for key, value in model_state_dict.items():
            new_key = None
            if key.startswith("bert."):
                # Keep the 'bert.' prefix, as ModernBertForSentiment uses self.bert.
                new_key = key
            elif key.startswith("classifier."):
                # All parts of the custom classifier head retain their names.
                new_key = key
                if "out_proj" in key:
                    found_out_proj = True
            if new_key:
                transformed_state_dict[new_key] = value
                if key != new_key:
                    print(f"  Mapping '{key}' -> '{new_key}'")
            else:
                print(f"  INFO: Discarding unmapped key: {key}")

        # Check that the classifier output layer was present in the source checkpoint.
        # This check may need adjustment to match the actual layers of ClassifierHead.
        if not found_out_proj:
            print("WARNING: No 'classifier.out_proj.*' keys were found in the source checkpoint.")
            print("         Ensure your checkpoint contains the expected classifier layers.")
            # Not necessarily fatal: other classifier keys may still be valid.

        model_state_dict = transformed_state_dict
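        # Optional sanity check (hypothetical sketch; assumes models.py defines a
        # ModernBertForSentiment(config) whose parameter names match the checkpoint):
        #
        #   from models import ModernBertForSentiment
        #   model = ModernBertForSentiment(config)
        #   missing, unexpected = model.load_state_dict(transformed_state_dict, strict=False)
        #   print(f"missing: {missing}\nunexpected: {unexpected}")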
        # --- DEBUG: print keys of the TRANSFORMED state_dict ---
        print("\n--- Keys in TRANSFORMED model_state_dict for upload (first 30 and last 10): ---")
        transformed_keys = list(transformed_state_dict.keys())
        if transformed_keys:
            for i, key in enumerate(transformed_keys[:30]):
                print(f"  {i + 1}. {key}")
            if len(transformed_keys) > 40:
                print("  ...")
            for key_idx in range(max(30, len(transformed_keys) - 10), len(transformed_keys)):
                print(f"  {key_idx + 1}. {transformed_keys[key_idx]}")
        else:
            print("  (No keys found in transformed_state_dict)")
        print(f"Total keys in transformed_state_dict: {len(transformed_keys)}")
        print("-----------------------------------------------------------\n")
        # --- END DEBUG ---

        # Save the TRANSFORMED state_dict under the standard Hugging Face filename.
        torch.save(transformed_state_dict, hf_model_path)
        print(f"Saved TRANSFORMED model state_dict to {hf_model_path}.")
        # 4. Copy other project files (these now live at the repository root).
        for project_file in PROJECT_FILES_TO_UPLOAD:
            if os.path.exists(project_file):
                shutil.copy(project_file, os.path.join(temp_dir, os.path.basename(project_file)))
                print(f"Copied project file {project_file} to {temp_dir}.")
            else:
                print(f"WARNING: project file {project_file} not found; skipping.")

        # Inspect temp_dir before uploading, to be absolutely sure what is there.
        print(f"--- Inspecting temp_dir ({temp_dir}) before upload: ---")
        for item in os.listdir(temp_dir):
            print(f"  - {item}")
        temp_config_path_to_check = os.path.join(temp_dir, "config.json")
        if os.path.exists(temp_config_path_to_check):
            print(f"--- Content of {temp_config_path_to_check} before upload: ---")
            with open(temp_config_path_to_check, 'r') as f_check:
                print(f_check.read())
            print("--- End of config.json content ---")
        else:
            print(f"WARNING: {temp_config_path_to_check} does NOT exist before upload!")
        # 5. Upload the contents of the temporary directory.
        print(f"Uploading all files from {temp_dir} to {REPO_ID}...")
        try:
            upload_folder(
                folder_path=temp_dir,
                repo_id=REPO_ID,
                repo_type="model",
                commit_message=f"Upload fine-tuned model, tokenizer, and supporting files for {MODEL_NAME_ON_HF}",
            )
            print("All files uploaded successfully!")
        except Exception as e:
            print(f"Error uploading files: {e}")
        finally:
            # The TemporaryDirectory context manager handles the actual cleanup;
            # this message is just for clarity.
            print(f"Cleaning up temporary directory: {temp_dir}")

    print("Upload process finished.")
if __name__ == "__main__":
    upload_model_and_tokenizer()
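# After a successful upload, loading everything back might look like the sketch
# below. Because config.architectures names the custom ModernBertForSentiment
# class, the stock AutoModel* classes will not know how to build it; this assumes
# you instantiate the custom class from the uploaded models.py yourself:
#
#   from transformers import AutoTokenizer, AutoConfig
#   from huggingface_hub import hf_hub_download
#   import torch
#
#   tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
#   config = AutoConfig.from_pretrained(REPO_ID)
#   weights_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
#   state_dict = torch.load(weights_path, map_location="cpu")
#   # model = ModernBertForSentiment(config); model.load_state_dict(state_dict)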