from huggingface_hub import HfApi, upload_folder, create_repo, login
from transformers import AutoTokenizer, AutoConfig
import os
import shutil
import tempfile
import torch
import argparse

# --- Configuration ---
HUGGING_FACE_USERNAME = "voxmenthe"  # Your Hugging Face username
MODEL_NAME_ON_HF = "modernbert-imdb-sentiment"  # The name of the model on Hugging Face
REPO_ID = f"{HUGGING_FACE_USERNAME}/{MODEL_NAME_ON_HF}"

# Original base model from which the tokenizer and initial config were derived.
ORIGINAL_BASE_MODEL_NAME = "answerdotai/ModernBERT-base"

# Local path to your fine-tuned model checkpoint.
LOCAL_MODEL_CHECKPOINT_DIR = "checkpoints"
FINE_TUNED_MODEL_FILENAME = "mean_epoch5_0.9575acc_0.9575f1.pt"  # Your best checkpoint
# If your fine-tuned model is just a .pt file, make sure a config.json for ModernBERT is also available.
# For simplicity, we re-save the config from the fine-tuned model structure if possible, or from the original base model.

# Files from your project to include (e.g., custom model code, inference script).
# These files now live in the repository root.
PROJECT_FILES_TO_UPLOAD = [
    "config.yaml",
    "inference.py",
    "models.py",
    "train_utils.py",
    "classifiers.py",
    "README.md",
]
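
# NOTE (assumption): this script expects you to already be authenticated with the
# Hugging Face Hub, e.g. via `huggingface-cli login`. A minimal programmatic
# alternative (the HF_TOKEN environment variable name is illustrative):
#
#   from huggingface_hub import login
#   login(token=os.environ.get("HF_TOKEN"))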


def upload_model_and_tokenizer():
    api = HfApi()
    print(f"Preparing to upload to Hugging Face Hub repository: {REPO_ID}")

    # Create the repository on Hugging Face Hub if it doesn't exist.
    # This should be done after login to ensure correct permissions.
    print(f"Ensuring repository '{REPO_ID}' exists on Hugging Face Hub...")
    try:
        create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
        print(f"Repository '{REPO_ID}' ensured.")
    except Exception as e:
        print(f"Error creating/accessing repository {REPO_ID}: {e}")
        print("Please check your Hugging Face token and repository permissions.")
        return
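
    # (If the model should not be public, note that create_repo also accepts a private=True flag.)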

    # Create a temporary directory to gather all files for upload.
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Created temporary directory for upload: {temp_dir}")

        # 1. Save tokenizer files from the ORIGINAL_BASE_MODEL_NAME.
        print(f"Saving tokenizer from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
            tokenizer.save_pretrained(temp_dir)
            print("Tokenizer files saved.")
        except Exception as e:
            print(f"Error saving tokenizer from {ORIGINAL_BASE_MODEL_NAME}: {e}")
            print("Please ensure this model name is correct and accessible.")
            return

        # 2. Save the base model config.json (architecture) from ORIGINAL_BASE_MODEL_NAME.
        # This is crucial for AutoModelForSequenceClassification.from_pretrained(REPO_ID) to work.
        print(f"Saving model config.json from {ORIGINAL_BASE_MODEL_NAME} to {temp_dir}...")
        try:
            config = AutoConfig.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
            print(f"Config loaded. Initial num_labels (if exists): {getattr(config, 'num_labels', 'Not set')}")

            # Set the architecture first.
            config.architectures = ["ModernBertForSentiment"]

            # Add the classification head attributes needed by AutoModelForSequenceClassification.
            config.num_labels = 1  # For IMDB sentiment (binary, single-logit output based on training)
            print(f"After attempting to set: config.num_labels = {config.num_labels}")
            config.id2label = {0: "NEGATIVE", 1: "POSITIVE"}  # Standard for binary, even with num_labels=1
            config.label2id = {"NEGATIVE": 0, "POSITIVE": 1}
            print(f"After setting id2label/label2id, config.num_labels is: {config.num_labels}")

            # CRITICAL: force num_labels to 1 again immediately before saving.
            # (Note: PretrainedConfig derives num_labels from len(id2label), which is
            # why the value can flip back to 2 after the label maps are assigned.)
            config.num_labels = 1
            print(f"Immediately before save, FINAL check config.num_labels = {config.num_labels}")

            # Safeguard: remove any existing config.json from temp_dir before saving ours.
            potential_old_config_path = os.path.join(temp_dir, "config.json")
            if os.path.exists(potential_old_config_path):
                os.remove(potential_old_config_path)
                print(f"Removed existing config.json from {temp_dir} to ensure a clean save.")

            config.save_pretrained(temp_dir)
            print(f"Model config.json (with num_labels={config.num_labels}, architectures={config.architectures}) saved to {temp_dir}.")
        except Exception as e:
            print(f"Error saving config.json from {ORIGINAL_BASE_MODEL_NAME}: {e}")
            return
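
        # Optional sanity check (a minimal sketch, not required for the upload):
        # reload the just-saved config and confirm the fields survived serialization, e.g.
        #
        #   reloaded = AutoConfig.from_pretrained(temp_dir)
        #   print(reloaded.num_labels, reloaded.architectures, reloaded.id2label)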

        # 3. Load the fine-tuned model checkpoint and extract its state_dict.
        full_checkpoint_path = os.path.join(LOCAL_MODEL_CHECKPOINT_DIR, FINE_TUNED_MODEL_FILENAME)
        hf_model_path = os.path.join(temp_dir, "pytorch_model.bin")

        if not os.path.exists(full_checkpoint_path):
            print(f"ERROR: Local model checkpoint not found at {full_checkpoint_path}")
            # No manual cleanup needed: the TemporaryDirectory context manager removes temp_dir.
            return

        print(f"Loading local checkpoint from: {full_checkpoint_path}")
        # Load the checkpoint onto the CPU to avoid GPU memory issues if the script runner doesn't have/need a GPU.
        checkpoint = torch.load(full_checkpoint_path, map_location='cpu')
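        # Note (version-dependent behavior): on PyTorch 2.6+ torch.load defaults to
        # weights_only=True; if this checkpoint stores non-tensor Python objects,
        # loading it may require passing weights_only=False explicitly.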

        model_state_dict = None
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model_state_dict = checkpoint['model_state_dict']
            print("Extracted 'model_state_dict' from checkpoint.")
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:  # Another common key for state_dicts
            model_state_dict = checkpoint['state_dict']
            print("Extracted 'state_dict' from checkpoint.")
        elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
            # The checkpoint may already be a raw state_dict (e.g., from torch.save(model.state_dict(), ...)).
            # Basic check: does it have keys that look like weights/biases?
            if any(key.endswith('.weight') or key.endswith('.bias') for key in checkpoint.keys()):
                model_state_dict = checkpoint
                print("Checkpoint appears to be a raw state_dict (contains .weight or .bias keys).")
            else:
                print("Checkpoint is a dict, but does not immediately appear to be a state_dict (no .weight/.bias keys found).")
                print(f"Checkpoint keys: {list(checkpoint.keys())[:10]}...")  # Print some keys for diagnosis
        else:
            # The checkpoint is not a dict or does not match any known structure.
            print("ERROR: Could not find a known state_dict key in the checkpoint, and it's not a recognizable raw state_dict.")
            if isinstance(checkpoint, dict):
                print(f"Checkpoint dictionary keys found: {list(checkpoint.keys())}")
            else:
                print(f"Checkpoint is not a dictionary. Type: {type(checkpoint)}")
            return

        if model_state_dict is None:
            print("ERROR: model_state_dict was not successfully extracted. Aborting upload.")
            return

        # --- DEBUG: Print keys of the state_dict ---
        print("\n--- Keys in extracted (original) model_state_dict (first 30 and last 10): ---")
        state_dict_keys = list(model_state_dict.keys())
        if len(state_dict_keys) > 0:
            for i, key in enumerate(state_dict_keys[:30]):
                print(f" {i+1}. {key}")
            if len(state_dict_keys) > 40:  # Show an ellipsis if there's a gap
                print(" ...")
            # Print the last 10 keys if there are more than 30.
            start_index_for_last_10 = max(30, len(state_dict_keys) - 10)
            for key_idx in range(start_index_for_last_10, len(state_dict_keys)):
                print(f" {key_idx+1}. {state_dict_keys[key_idx]}")
        else:
            print(" (No keys found in model_state_dict)")
        print(f"Total keys: {len(state_dict_keys)}")
        print("-----------------------------------------------------------\n")
        # --- END DEBUG ---

        # Transform keys for Hugging Face compatibility if needed.
        # For ModernBertForSentiment with self.bert and self.classifier (custom head):
        #   - checkpoint 'bert.*' keys should remain 'bert.*'
        #   - checkpoint 'classifier.*' keys (e.g., classifier.dense1.weight, classifier.out_proj.weight) should remain 'classifier.*'
        transformed_state_dict = {}
        has_classifier_weights_transformed = False  # Tracks whether an out_proj key was found
        print("Transforming state_dict keys for Hugging Face Hub compatibility...")
        for key, value in model_state_dict.items():
            new_key = None
            if key.startswith("bert."):
                # Keep the 'bert.' prefix, as ModernBertForSentiment uses self.bert.
                new_key = key
            elif key.startswith("classifier."):
                # All parts of the custom classifier head retain their names.
                new_key = key
                if "out_proj" in key:  # Just to confirm it exists
                    has_classifier_weights_transformed = True  # out_proj was found and processed
            if new_key:
                transformed_state_dict[new_key] = value
                if key != new_key:
                    print(f" Mapping '{key}' -> '{new_key}'")
                else:
                    # print(f" Keeping key as is: '{key}'")  # Optional
                    pass
            else:
                print(f" INFO: Discarding key not mapped: {key}")

        # Check whether the critical classifier output layer was present in the source checkpoint.
        # This check may need adjustment based on the actual layers of ClassifierHead;
        # for now, we check whether any 'out_proj' key was seen under 'classifier.'.
        if not has_classifier_weights_transformed:
            print("WARNING: No 'classifier.out_proj.*' keys were found in the source checkpoint.")
            print(" Ensure your checkpoint contains the expected classifier layers.")
            # Not necessarily a reason to abort, as other classifier keys might still be valid.

        model_state_dict = transformed_state_dict

        # --- DEBUG: Print keys of the TRANSFORMED state_dict ---
        print("\n--- Keys in TRANSFORMED model_state_dict for upload (first 30 and last 10): ---")
        state_dict_keys_transformed = list(transformed_state_dict.keys())
        if len(state_dict_keys_transformed) > 0:
            for i, key_t in enumerate(state_dict_keys_transformed[:30]):
                print(f" {i+1}. {key_t}")
            if len(state_dict_keys_transformed) > 40:
                print(" ...")
            start_index_for_last_10_t = max(30, len(state_dict_keys_transformed) - 10)
            for key_idx_t in range(start_index_for_last_10_t, len(state_dict_keys_transformed)):
                print(f" {key_idx_t+1}. {state_dict_keys_transformed[key_idx_t]}")
        else:
            print(" (No keys found in transformed_state_dict)")
        print(f"Total keys in transformed_state_dict: {len(state_dict_keys_transformed)}")
        print("-----------------------------------------------------------\n")

        # Save the TRANSFORMED state_dict.
        torch.save(transformed_state_dict, hf_model_path)
        print(f"Saved TRANSFORMED model state_dict to {hf_model_path}.")

        # 4. Copy other project files.
        for project_file in PROJECT_FILES_TO_UPLOAD:
            local_project_file_path = project_file  # Files are now at the repository root
            if os.path.exists(local_project_file_path):
                shutil.copy(local_project_file_path, os.path.join(temp_dir, os.path.basename(project_file)))
                print(f"Copied project file {project_file} to {temp_dir}.")
            else:
                print(f"WARNING: Project file {project_file} not found; skipping.")

        # Before uploading, inspect temp_dir to be absolutely sure of what's there.
        print(f"--- Inspecting temp_dir ({temp_dir}) before upload: ---")
        for item in os.listdir(temp_dir):
            print(f" - {item}")
        temp_config_path_to_check = os.path.join(temp_dir, "config.json")
        if os.path.exists(temp_config_path_to_check):
            print(f"--- Content of {temp_config_path_to_check} before upload: ---")
            with open(temp_config_path_to_check, 'r') as f_check:
                print(f_check.read())
            print("--- End of config.json content ---")
        else:
            print(f"WARNING: {temp_config_path_to_check} does NOT exist before upload!")

        # 5. Upload the contents of the temporary directory.
        print(f"Uploading all files from {temp_dir} to {REPO_ID}...")
        try:
            upload_folder(
                folder_path=temp_dir,
                repo_id=REPO_ID,
                repo_type="model",
                commit_message=f"Upload fine-tuned model, tokenizer, and supporting files for {MODEL_NAME_ON_HF}",
            )
            print("All files uploaded successfully!")
        except Exception as e:
            print(f"Error uploading files: {e}")
        finally:
            print(f"Cleaning up temporary directory: {temp_dir}")
            # The TemporaryDirectory context manager handles cleanup automatically,
            # but an explicit message is good for clarity.

    print("Upload process finished.")


if __name__ == "__main__":
    upload_model_and_tokenizer()
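
# Post-upload smoke test (a minimal sketch, not executed by this script; it assumes,
# as the config step above intends, that the standard Auto* classes can load the repo):
#
#   from transformers import AutoTokenizer, AutoModelForSequenceClassification
#   tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
#   model = AutoModelForSequenceClassification.from_pretrained(REPO_ID)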