import json
import logging
import asyncio
from typing import Tuple, Optional, Dict, Any

from datasets import load_dataset
from huggingface_hub import HfApi, ModelCard, hf_hub_download
from huggingface_hub import hf_api
from transformers import AutoConfig, AutoTokenizer

from app.config.base import HF_TOKEN
from app.config.hf_config import OFFICIAL_PROVIDERS_REPO
from app.core.formatting import LogFormatter

logger = logging.getLogger(__name__)

GATED_ERROR = "The model is gated by the model authors and requires special access permissions. Please contact us to request evaluation."


class ModelValidator:
    def __init__(self):
        self.token = HF_TOKEN
        self.api = HfApi(token=self.token)
        self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}

    async def check_model_card(
        self, model_id: str
    ) -> Tuple[bool, str, Optional[Dict[str, Any]]]:
        """Check if model has a valid model card"""
        try:
            logger.info(LogFormatter.info(f"Checking model card for {model_id}"))

            # Get model card content using ModelCard.load
            try:
                model_card = await asyncio.to_thread(ModelCard.load, model_id)
                logger.info(LogFormatter.success("Model card found"))
            except Exception as e:
                error_msg = "Please add a model card to your model to explain how you trained/fine-tuned it."
                logger.error(LogFormatter.error(error_msg, e))
                return False, error_msg, None

            # Check license in model card data
            if model_card.data.license is None and not (
                "license_name" in model_card.data
                and "license_link" in model_card.data
            ):
                error_msg = "License not found. Please add a license to your model card using the `license` metadata or a `license_name`/`license_link` pair."
                logger.warning(LogFormatter.warning(error_msg))
                return False, error_msg, None

            # Enforce card content length
            if len(model_card.text) < 200:
                error_msg = (
                    "Please add a description to your model card, it is too short."
                )
                logger.warning(LogFormatter.warning(error_msg))
                return False, error_msg, None

            logger.info(LogFormatter.success("Model card validation passed"))
            return True, "", model_card

        except Exception as e:
            error_msg = "Failed to validate model card"
            logger.error(LogFormatter.error(error_msg, e))
            return False, str(e), None

    async def get_safetensors_metadata(
        self, model_id: str, is_adapter: bool = False, revision: str = "main"
    ) -> Optional[Dict]:
        """Get metadata from a safetensors file"""
        try:
            if is_adapter:
                metadata = await asyncio.to_thread(
                    hf_api.parse_safetensors_file_metadata,
                    model_id,
                    "adapter_model.safetensors",
                    token=self.token,
                    revision=revision,
                )
            else:
                metadata = await asyncio.to_thread(
                    hf_api.get_safetensors_metadata,
                    repo_id=model_id,
                    token=self.token,
                    revision=revision,
                )
            return metadata
        except Exception as e:
            logger.error(f"Failed to get safetensors metadata: {str(e)}")
            return None

    async def get_model_size(
        self, model_info: Any, precision: str, base_model: str, revision: str
    ) -> Tuple[Optional[float], Optional[str]]:
        """
        Get model size in billions of parameters.

        First, try to use safetensors metadata (which includes a parameter count).
        If that isn't available, then as a fallback, use file metadata from the
        repository to sum the sizes of weight files. For the fallback, we assume
        (for example) that for float16 storage each parameter takes ~2 bytes.
        For GPTQ models (detected via the precision argument or model ID), we
        adjust by a factor (e.g. 8).

        Returns:
            Tuple of (model_size_in_billions, error_message). If successful,
            error_message is None.
        """
        try:
            logger.info(
                LogFormatter.info(f"Checking model size for {model_info.modelId}")
            )

            # Check if model is an adapter by looking for an adapter config file.
            is_adapter = any(
                hasattr(s, "rfilename") and s.rfilename == "adapter_config.json"
                for s in model_info.siblings
            )

            model_size = None  # This will hold the total parameter count if available.

            if is_adapter and base_model:
                # For adapters, we need to get both the adapter and base model metadata.
                adapter_meta = await self.get_safetensors_metadata(
                    model_info.id, is_adapter=True, revision=revision
                )
                base_meta = await self.get_safetensors_metadata(
                    base_model, revision="main"
                )
                if adapter_meta and base_meta:
                    adapter_size = sum(adapter_meta.parameter_count.values())
                    base_size = sum(base_meta.parameter_count.values())
                    model_size = adapter_size + base_size
            else:
                # For regular models, try to get the model size from safetensors metadata.
                meta = await self.get_safetensors_metadata(
                    model_info.id, revision=revision
                )
                if meta:
                    model_size = sum(meta.parameter_count.values())

            if model_size is not None:
                # Adjust for GPTQ models if necessary.
                factor = (
                    8
                    if (precision == "GPTQ" or "gptq" in model_info.id.lower())
                    else 1
                )
                model_size = round((model_size / 1e9) * factor, 3)
                logger.info(
                    LogFormatter.success(
                        f"Model size: {model_size}B parameters (from safetensors metadata)"
                    )
                )
                return model_size, None

            # Fallback: use file metadata from the repository.
            logger.info(
                "Safetensors metadata not available. Falling back to file metadata to estimate model size."
            )
            weight_file_extensions = [".bin", ".safetensors"]
            fallback_size_bytes = 0

            # If model_info does not contain file metadata, re-fetch with files_metadata=True.
            if not model_info.siblings or all(
                getattr(s, "size", None) is None for s in model_info.siblings
            ):
                logger.info(
                    "Re-fetching model info with file metadata for fallback estimation."
                )
                model_info = await asyncio.to_thread(
                    self.api.model_info, model_info.id, files_metadata=True
                )

            # Sum up the sizes of files that appear to be weight files.
            for sibling in model_info.siblings:
                if hasattr(sibling, "rfilename") and sibling.size is not None:
                    if any(
                        sibling.rfilename.endswith(ext)
                        for ext in weight_file_extensions
                    ):
                        fallback_size_bytes += sibling.size

            if fallback_size_bytes > 0:
                # Estimate parameter count based on file size.
                # For float16 weights we assume ~2 bytes per parameter.
                factor = (
                    8
                    if (precision == "GPTQ" or "gptq" in model_info.id.lower())
                    else 1
                )
                estimated_param_count = (fallback_size_bytes / 2) * factor
                model_size = round(estimated_param_count / 1e9, 3)  # in billions
                logger.info(
                    LogFormatter.success(
                        f"Fallback model size: {model_size}B parameters"
                    )
                )
                return model_size, None
            else:
                return (
                    None,
                    "Model size could not be determined using file metadata fallback",
                )

        except Exception as e:
            logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
            return None, str(e)

    async def check_chat_template(
        self, model_id: str, revision: str
    ) -> Tuple[bool, Optional[str]]:
        """Check if model has a valid chat template"""
        try:
            logger.info(LogFormatter.info(f"Checking chat template for {model_id}"))

            try:
                config_file = await asyncio.to_thread(
                    hf_hub_download,
                    repo_id=model_id,
                    filename="tokenizer_config.json",
                    revision=revision,
                    repo_type="model",
                )
                with open(config_file, "r") as f:
                    tokenizer_config = json.load(f)

                if "chat_template" not in tokenizer_config:
                    error_msg = f"The model {model_id} doesn't have a chat_template in its tokenizer_config.json. Chat templates are required to accurately evaluate responses."
                    logger.error(LogFormatter.error(error_msg))
                    return False, error_msg

                logger.info(LogFormatter.success("Valid chat template found"))
                return True, None

            except Exception as e:
                error_msg = f"Error checking chat_template: {str(e)}"
                logger.error(LogFormatter.error(error_msg))
                return False, error_msg

        except Exception as e:
            error_msg = "Failed to check chat template"
            logger.error(LogFormatter.error(error_msg, e))
            return False, str(e)

    async def is_model_on_hub(
        self,
        model_name: str,
        revision: str,
        gated: bool = False,
        test_tokenizer: bool = False,
        trust_remote_code: bool = False,
    ) -> Tuple[bool, Optional[str], Optional[Any]]:
        """Check if model exists and is properly configured on the Hub"""
        try:
            config = await asyncio.to_thread(
                AutoConfig.from_pretrained,
                model_name,
                revision=revision,
                trust_remote_code=trust_remote_code,
                token=self.token,
                force_download=True,
            )

            if test_tokenizer:
                try:
                    await asyncio.to_thread(
                        AutoTokenizer.from_pretrained,
                        model_name,
                        revision=revision,
                        trust_remote_code=trust_remote_code,
                        token=self.token,
                    )
                except ValueError as e:
                    return (
                        False,
                        f"The tokenizer is not available in an official Transformers release: {e}",
                        None,
                    )
                except Exception:
                    # When running on Hugging Face we get into this except block instead of the one below
                    if gated:
                        return (
                            False,
                            GATED_ERROR,
                            None,
                        )
                    return (
                        False,
                        "The tokenizer cannot be loaded. Ensure the tokenizer class is part of a stable Transformers release and correctly configured.",
                        None,
                    )

            return True, None, config

        except ValueError:
            return (
                False,
                "The model requires `trust_remote_code=True` to launch, and for safety reasons, we don't accept such models automatically.",
                None,
            )
        except Exception as e:
            if gated:
                return (
                    False,
                    GATED_ERROR,
                    None,
                )
            return (
                False,
                f"The model was not found or is misconfigured on the Hub. Error: {e.args[0]}",
                None,
            )

    async def check_official_provider_status(
        self, model_id: str, existing_models: Dict[str, list]
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if model is from official provider and has finished submission.

        Args:
            model_id: The model identifier (org/model-name)
            existing_models: Dictionary of models by status from get_models()

        Returns:
            Tuple[bool, Optional[str]]: (is_valid, error_message)
        """
        try:
            logger.info(
                LogFormatter.info(f"Checking official provider status for {model_id}")
            )

            # Get model organization
            model_org = model_id.split("/")[0] if "/" in model_id else None
            if not model_org:
                return True, None

            # Load official providers dataset
            dataset = load_dataset(OFFICIAL_PROVIDERS_REPO)
            official_providers = dataset["train"][0]["CURATED_SET"]

            # Check if model org is in official providers
            is_official = model_org in official_providers

            if is_official:
                logger.info(
                    LogFormatter.info(
                        f"Model organization '{model_org}' is an official provider"
                    )
                )

                # Check for finished submissions
                if "finished" in existing_models:
                    for model in existing_models["finished"]:
                        # TODO: remove this after official provider evaluation is implemented
                        if model["name"] == model_id and False:
                            error_msg = (
                                f"Model {model_id} is an official provider model "
                                f"with a completed evaluation. "
                                f"To re-evaluate, please open a discussion."
                            )
                            logger.error(
                                LogFormatter.error("Validation failed", error_msg)
                            )
                            return False, error_msg

                logger.info(
                    LogFormatter.success(
                        "No finished submission found for this official provider model"
                    )
                )
            else:
                logger.info(
                    LogFormatter.info(
                        f"Model organization '{model_org}' is not an official provider"
                    )
                )

            return True, None

        except Exception as e:
            error_msg = f"Failed to check official provider status: {str(e)}"
            logger.error(LogFormatter.error(error_msg))
            return False, error_msg
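

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal example of driving the validator
# from an async entry point. The model id below is a hypothetical placeholder,
# and running this assumes HF_TOKEN and the app.* config modules imported above
# are available in your environment.
if __name__ == "__main__":

    async def _demo() -> None:
        validator = ModelValidator()
        model_id = "org/model-name"  # hypothetical id, replace with a real repo

        # Validate the model card (license and minimum description length).
        card_ok, card_err, _card = await validator.check_model_card(model_id)
        if not card_ok:
            logger.warning(f"Model card check failed: {card_err}")

        # Confirm the model loads from the Hub, including its tokenizer.
        on_hub, hub_err, _config = await validator.is_model_on_hub(
            model_id, revision="main", test_tokenizer=True
        )
        if not on_hub:
            logger.warning(f"Hub check failed: {hub_err}")

    asyncio.run(_demo())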