import json
import logging
import asyncio
from typing import Tuple, Optional, Dict, Any

from datasets import load_dataset
from huggingface_hub import HfApi, ModelCard, hf_hub_download, hf_api
from transformers import AutoConfig, AutoTokenizer

from app.config.base import HF_TOKEN
from app.config.hf_config import OFFICIAL_PROVIDERS_REPO
from app.core.formatting import LogFormatter

logger = logging.getLogger(__name__)

GATED_ERROR = (
    "The model is gated by the model authors and requires special access "
    "permissions. Please contact us to request evaluation."
)


class ModelValidator:
    def __init__(self):
        self.token = HF_TOKEN
        self.api = HfApi(token=self.token)
        self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}

    async def check_model_card(
        self, model_id: str
    ) -> Tuple[bool, str, Optional[Dict[str, Any]]]:
        """Check if model has a valid model card"""
        try:
            logger.info(LogFormatter.info(f"Checking model card for {model_id}"))

            # The card must exist at all before its contents can be validated.
            try:
                model_card = await asyncio.to_thread(ModelCard.load, model_id)
                logger.info(LogFormatter.success("Model card found"))
            except Exception as e:
                error_msg = "Please add a model card to your model to explain how you trained/fine-tuned it."
                logger.error(LogFormatter.error(error_msg, e))
                return False, error_msg, None

            # A license can be declared via the `license` metadata field or a
            # `license_name`/`license_link` pair.
            if model_card.data.license is None and not (
                "license_name" in model_card.data and "license_link" in model_card.data
            ):
                error_msg = "License not found. Please add a license to your model card using the `license` metadata or a `license_name`/`license_link` pair."
                logger.warning(LogFormatter.warning(error_msg))
                return False, error_msg, None

            # Require at least a minimal free-text description.
            if len(model_card.text) < 200:
                error_msg = "Please add a description to your model card; it is too short."
                logger.warning(LogFormatter.warning(error_msg))
                return False, error_msg, None

            logger.info(LogFormatter.success("Model card validation passed"))
            return True, "", model_card

        except Exception as e:
            error_msg = "Failed to validate model card"
            logger.error(LogFormatter.error(error_msg, e))
            return False, str(e), None

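    # For reference, either of these model-card metadata forms satisfies the
    # license check above (YAML front matter, values illustrative):
    #   license: apache-2.0
    # or:
    #   license_name: my-custom-license
    #   license_link: https://example.com/LICENSE
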
    async def get_safetensors_metadata(
        self, model_id: str, is_adapter: bool = False, revision: str = "main"
    ) -> Optional[Dict]:
        """Get metadata from a safetensors file"""
        try:
            if is_adapter:
                metadata = await asyncio.to_thread(
                    hf_api.parse_safetensors_file_metadata,
                    model_id,
                    "adapter_model.safetensors",
                    token=self.token,
                    revision=revision,
                )
            else:
                metadata = await asyncio.to_thread(
                    hf_api.get_safetensors_metadata,
                    repo_id=model_id,
                    token=self.token,
                    revision=revision,
                )
            return metadata

        except Exception as e:
            logger.error(f"Failed to get safetensors metadata: {str(e)}")
            return None

    async def get_model_size(
        self, model_info: Any, precision: str, base_model: str, revision: str
    ) -> Tuple[Optional[float], Optional[str]]:
        """
        Get model size in billions of parameters.

        First, try the safetensors metadata, which includes an exact parameter
        count. If that is unavailable, fall back to summing the sizes of the
        weight files listed in the repository file metadata.

        The fallback assumes ~2 bytes per parameter (float16 storage). For GPTQ
        models (detected via the precision argument or the model ID), the
        estimate is scaled by a factor of 8.

        Returns:
            Tuple of (model_size_in_billions, error_message). If successful,
            error_message is None.
        """
        try:
            logger.info(
                LogFormatter.info(f"Checking model size for {model_info.id}")
            )

            # Adapters (e.g. LoRA) are detected by the presence of adapter_config.json.
            is_adapter = any(
                hasattr(s, "rfilename") and s.rfilename == "adapter_config.json"
                for s in model_info.siblings
            )

            model_size = None

            if is_adapter and base_model:
                # For adapters, the effective size is adapter plus base model.
                adapter_meta = await self.get_safetensors_metadata(
                    model_info.id, is_adapter=True, revision=revision
                )
                base_meta = await self.get_safetensors_metadata(
                    base_model, revision="main"
                )
                if adapter_meta and base_meta:
                    adapter_size = sum(adapter_meta.parameter_count.values())
                    base_size = sum(base_meta.parameter_count.values())
                    model_size = adapter_size + base_size
            else:
                meta = await self.get_safetensors_metadata(
                    model_info.id, revision=revision
                )
                if meta:
                    model_size = sum(meta.parameter_count.values())

            if model_size is not None:
                factor = (
                    8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
                )
                model_size = round((model_size / 1e9) * factor, 3)
                logger.info(
                    LogFormatter.success(
                        f"Model size: {model_size}B parameters (from safetensors metadata)"
                    )
                )
                return model_size, None

            logger.info(
                "Safetensors metadata not available. Falling back to file metadata to estimate model size."
            )
            weight_file_extensions = [".bin", ".safetensors"]
            fallback_size_bytes = 0

            # The initial model_info may have been fetched without per-file sizes.
            if not model_info.siblings or all(
                getattr(s, "size", None) is None for s in model_info.siblings
            ):
                logger.info(
                    "Re-fetching model info with file metadata for fallback estimation."
                )
                model_info = await asyncio.to_thread(
                    self.api.model_info, model_info.id, files_metadata=True
                )

            # Sum the sizes of all weight files in the repository.
            for sibling in model_info.siblings:
                if hasattr(sibling, "rfilename") and sibling.size is not None:
                    if any(
                        sibling.rfilename.endswith(ext)
                        for ext in weight_file_extensions
                    ):
                        fallback_size_bytes += sibling.size

            if fallback_size_bytes > 0:
                factor = (
                    8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
                )
                estimated_param_count = (fallback_size_bytes / 2) * factor
                model_size = round(estimated_param_count / 1e9, 3)
                logger.info(
                    LogFormatter.success(
                        f"Fallback model size: {model_size}B parameters"
                    )
                )
                return model_size, None
            else:
                return (
                    None,
                    "Model size could not be determined using file metadata fallback",
                )

        except Exception as e:
            logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
            return None, str(e)

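    # Worked example of the fallback arithmetic above (numbers illustrative):
    # a 14e9-byte float16 checkpoint gives 14e9 / 2 = 7e9 parameters, reported
    # as 7.0B; the same byte count with the GPTQ factor of 8 would be reported
    # as 56.0B.
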
    async def check_chat_template(
        self, model_id: str, revision: str
    ) -> Tuple[bool, Optional[str]]:
        """Check if model has a valid chat template"""
        try:
            logger.info(LogFormatter.info(f"Checking chat template for {model_id}"))

            try:
                config_file = await asyncio.to_thread(
                    hf_hub_download,
                    repo_id=model_id,
                    filename="tokenizer_config.json",
                    revision=revision,
                    repo_type="model",
                )

                with open(config_file, "r") as f:
                    tokenizer_config = json.load(f)

                if "chat_template" not in tokenizer_config:
                    error_msg = f"The model {model_id} doesn't have a chat_template in its tokenizer_config.json. Chat templates are required to accurately evaluate responses."
                    logger.error(LogFormatter.error(error_msg))
                    return False, error_msg

                logger.info(LogFormatter.success("Valid chat template found"))
                return True, None

            except Exception as e:
                error_msg = f"Error checking chat_template: {str(e)}"
                logger.error(LogFormatter.error(error_msg))
                return False, error_msg

        except Exception as e:
            error_msg = "Failed to check chat template"
            logger.error(LogFormatter.error(error_msg, e))
            return False, str(e)

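    # For reference, a tokenizer_config.json that passes the check above has a
    # top-level "chat_template" key (abridged, illustrative):
    #   {"chat_template": "{% for message in messages %}...{% endfor %}", ...}
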
    async def is_model_on_hub(
        self,
        model_name: str,
        revision: str,
        gated: bool = False,
        test_tokenizer: bool = False,
        trust_remote_code: bool = False,
    ) -> Tuple[bool, Optional[str], Optional[Any]]:
        """Check if model exists and is properly configured on the Hub"""
        try:
            config = await asyncio.to_thread(
                AutoConfig.from_pretrained,
                model_name,
                revision=revision,
                trust_remote_code=trust_remote_code,
                token=self.token,
                force_download=True,
            )

            if test_tokenizer:
                try:
                    await asyncio.to_thread(
                        AutoTokenizer.from_pretrained,
                        model_name,
                        revision=revision,
                        trust_remote_code=trust_remote_code,
                        token=self.token,
                    )
                except ValueError as e:
                    return (
                        False,
                        f"The tokenizer is not available in an official Transformers release: {e}",
                        None,
                    )
                except Exception:
                    # If the repo is flagged as gated, attribute the failure
                    # to access restrictions rather than a broken tokenizer.
                    if gated:
                        return False, GATED_ERROR, None
                    return (
                        False,
                        "The tokenizer cannot be loaded. Ensure the tokenizer class is part of a stable Transformers release and correctly configured.",
                        None,
                    )

            return True, None, config

        except ValueError:
            return (
                False,
                "The model requires `trust_remote_code=True` to launch, and for safety reasons, we don't accept such models automatically.",
                None,
            )
        except Exception as e:
            if gated:
                return False, GATED_ERROR, None
            return (
                False,
                f"The model was not found or is misconfigured on the Hub. Error: {e.args[0]}",
                None,
            )

    async def check_official_provider_status(
        self, model_id: str, existing_models: Dict[str, list]
    ) -> Tuple[bool, Optional[str]]:
        """
        Check if model is from official provider and has finished submission.

        Args:
            model_id: The model identifier (org/model-name)
            existing_models: Dictionary of models by status from get_models()

        Returns:
            Tuple[bool, Optional[str]]: (is_valid, error_message)
        """
        try:
            logger.info(
                LogFormatter.info(f"Checking official provider status for {model_id}")
            )

            model_org = model_id.split("/")[0] if "/" in model_id else None

            if not model_org:
                return True, None

            # The curated list of official provider organizations lives in a
            # dataset repo; load it off the event loop like other blocking I/O.
            dataset = await asyncio.to_thread(load_dataset, OFFICIAL_PROVIDERS_REPO)
            official_providers = dataset["train"][0]["CURATED_SET"]

            is_official = model_org in official_providers

            if is_official:
                logger.info(
                    LogFormatter.info(
                        f"Model organization '{model_org}' is an official provider"
                    )
                )

                # Official provider models may only be evaluated once; block
                # re-submission if a finished run already exists.
                if "finished" in existing_models:
                    for model in existing_models["finished"]:
                        if model["name"] == model_id:
                            error_msg = (
                                f"Model {model_id} is an official provider model "
                                f"with a completed evaluation. "
                                f"To re-evaluate, please open a discussion."
                            )
                            logger.error(
                                LogFormatter.error("Validation failed", error_msg)
                            )
                            return False, error_msg

                logger.info(
                    LogFormatter.success(
                        "No finished submission found for this official provider model"
                    )
                )
            else:
                logger.info(
                    LogFormatter.info(
                        f"Model organization '{model_org}' is not an official provider"
                    )
                )

            return True, None

        except Exception as e:
            error_msg = f"Failed to check official provider status: {str(e)}"
            logger.error(LogFormatter.error(error_msg))
            return False, error_msg
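

# --- Usage sketch (illustrative; not part of the validator itself) ----------
# A minimal driver showing how the checks compose. The model id and revision
# below are hypothetical, and HF_TOKEN must be configured for gated repos.
async def _demo() -> None:
    validator = ModelValidator()

    ok, err, _card = await validator.check_model_card("org/model-name")
    if not ok:
        print(f"Model card check failed: {err}")
        return

    on_hub, err, _config = await validator.is_model_on_hub(
        "org/model-name", revision="main", test_tokenizer=True
    )
    if not on_hub:
        print(f"Hub check failed: {err}")
        return

    has_template, err = await validator.check_chat_template(
        "org/model-name", revision="main"
    )
    print("Chat template ok" if has_template else f"Chat template check failed: {err}")


if __name__ == "__main__":
    asyncio.run(_demo())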