# Model Metadata Checker — Gradio app for the Hugging Face Hub.
import datetime
import os
from dataclasses import asdict, dataclass
from functools import lru_cache
from json import JSONDecodeError
from typing import List, Optional, Union

import gradio as gr
import requests
from huggingface_hub import (
    HfApi,
    ModelCard,
    hf_hub_url,
    list_repo_commits,
    logging,
    model_info,
)
from huggingface_hub.utils import EntryNotFoundError, disable_progress_bars

# Keep hub downloads quiet so progress bars and info logs don't pollute
# the Space's output.
disable_progress_bars()
logging.set_verbosity_error()

# Hub API token; may be None, in which case requests are anonymous
# (and rate-limited).
token = os.getenv("HF_TOKEN")
def get_model_labels(model):
    """Return the label names from a model's hosted config.json, or None.

    Downloads `config.json` for the repo `model` and extracts the keys of
    its "label2id" mapping. Returns None when the key is absent, the file
    is not valid JSON, the response body is not a dict, or the HTTP
    request itself fails.
    """
    try:
        url = hf_hub_url(repo_id=model, filename="config.json")
        # Timeout so one slow/hung hub response can't stall the app.
        response = requests.get(url, timeout=10)
        return list(response.json()["label2id"].keys())
    except (KeyError, JSONDecodeError, AttributeError, requests.RequestException):
        return None
@dataclass
class EngagementStats:
    """Popularity and age stats for a hub repository.

    The missing @dataclass decorator was a bug: callers construct this
    with keyword arguments, which requires the generated __init__.
    """

    likes: int
    downloads: int
    created_at: datetime.datetime
def _get_engagement_stats(hub_id):
    """Collect likes, downloads and creation date for a hub model."""
    hub_api = HfApi(token=token)
    repo_data = hub_api.repo_info(hub_id)
    # Last element of the commit list is used as the repo's creation date
    # (assumes list_repo_commits returns newest-first — confirm upstream).
    oldest_commit = list_repo_commits(hub_id, repo_type="model")[-1]
    return EngagementStats(
        likes=repo_data.likes,
        downloads=repo_data.downloads,
        created_at=oldest_commit.created_at,
    )
def _try_load_model_card(hub_id):
    """Return (card_text, length) for a model card, or (None, None) if the
    repo has no README/model card."""
    try:
        text = ModelCard.load(hub_id, token=token).text
    except EntryNotFoundError:
        return None, None
    return text, len(text)
def _try_parse_card_data(hub_id):
    """Extract selected card-metadata fields for a hub model.

    Returns a dict with keys "license", "language", "datasets", "tags";
    each value is the cardData entry or None when missing.

    Bug fix: the original called model_info() (a network round-trip)
    once per key inside the loop — it is now fetched a single time.
    """
    data = {}
    try:
        card_data = model_info(hub_id, token=token).cardData
    except AttributeError:
        card_data = {}
    for key in ["license", "language", "datasets", "tags"]:
        try:
            data[key] = card_data[key]
        except (KeyError, AttributeError):
            data[key] = None
    return data
@dataclass
class ModelMetadata:
    """Metadata collected for a single hub model, used for scoring.

    Bug fixes: the @dataclass decorator (required for the keyword
    construction below and for asdict() in the scoring code) and the
    @classmethod decorator on from_hub were missing.
    """

    hub_id: str
    tags: Optional[List[str]]
    license: Optional[str]
    library_name: Optional[str]
    datasets: Optional[List[str]]
    pipeline_tag: Optional[str]
    labels: Optional[List[str]]
    languages: Optional[Union[str, List[str]]]
    engagement_stats: Optional[EngagementStats] = None
    model_card_text: Optional[str] = None
    model_card_length: Optional[int] = None

    @classmethod
    def from_hub(cls, hub_id):
        """Build a ModelMetadata by querying the hub for `hub_id`.

        Performs several network calls (model info, model card, config
        labels, engagement stats); exceptions from the hub propagate.
        """
        # Pass the token for parity with the other hub calls in this file.
        model = model_info(hub_id, token=token)
        card_text, length = _try_load_model_card(hub_id)
        data = _try_parse_card_data(hub_id)
        # Older hub payloads may lack these attributes entirely.
        library_name = getattr(model, "library_name", None)
        pipeline_tag = getattr(model, "pipeline_tag", None)
        return cls(
            hub_id=hub_id,
            languages=data["language"],
            tags=data["tags"],
            license=data["license"],
            library_name=library_name,
            datasets=data["datasets"],
            pipeline_tag=pipeline_tag,
            labels=get_model_labels(hub_id),
            engagement_stats=_get_engagement_stats(hub_id),
            model_card_text=card_text,
            model_card_length=length,
        )
# Metadata fields scored for every pipeline task: points awarded when the
# field is present, and the recommendation shown when it is missing.
# "HUB_ID" inside a recommendation is replaced with the real model id when
# the report is rendered.
COMMON_SCORES = {
    "license": {
        "required": True,
        "score": 2,
        "missing_recommendation": (
            "You have not added a license to your models metadata"
        ),
    },
    "datasets": {
        "required": False,
        "score": 1,
        "missing_recommendation": (
            "You have not added any datasets to your models metadata"
        ),
    },
    "model_card_text": {
        "required": True,
        "score": 3,
        "missing_recommendation": """You haven't created a model card for your model. It is strongly recommended to have a model card for your model. \nYou can create for your model by clicking [here](https://huggingface.co/HUB_ID/edit/main/README.md)""",
    },
    "tags": {
        "required": False,
        "score": 2,
        "missing_recommendation": (
            "You don't have any tags defined in your model metadata. Tags can help"
            " people find relevant models on the Hub. You can create for your model by"
            " clicking [here](https://huggingface.co/HUB_ID/edit/main/README.md)"
        ),
    },
}
# Pipeline tasks for which declaring `language` metadata is expected
# (tasks that operate on or produce natural language).
TASK_TYPES_WITH_LANGUAGES = {
    "text-classification",
    "token-classification",
    "table-question-answering",
    "question-answering",
    "zero-shot-classification",
    "translation",
    "summarization",
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "sentence-similarity",
    "text-to-speech",
    "automatic-speech-recognition",
    "text-to-image",
    "image-to-text",
    "visual-question-answering",
    "document-question-answering",
}
# Classification-style tasks whose config.json should define a
# label2id mapping (see get_model_labels).
LABELS_REQUIRED_TASKS = {
    "text-classification",
    "token-classification",
    "object-detection",
    "audio-classification",
    "image-classification",
    "tabular-classification",
}
# Every pipeline tag a scoring rubric is generated for.
ALL_PIPELINES = {
    "audio-classification",
    "audio-to-audio",
    "automatic-speech-recognition",
    "conversational",
    "depth-estimation",
    "document-question-answering",
    "feature-extraction",
    "fill-mask",
    "graph-ml",
    "image-classification",
    "image-segmentation",
    "image-to-image",
    "image-to-text",
    "object-detection",
    "question-answering",
    "reinforcement-learning",
    "robotics",
    "sentence-similarity",
    "summarization",
    "table-question-answering",
    "tabular-classification",
    "tabular-regression",
    "text-classification",
    "text-generation",
    "text-to-image",
    "text-to-speech",
    "text-to-video",
    "text2text-generation",
    "token-classification",
    "translation",
    "unconditional-image-generation",
    "video-classification",
    "visual-question-answering",
    "voice-activity-detection",
    "zero-shot-classification",
    "zero-shot-image-classification",
}
def generate_task_scores_dict():
    """Build the per-task scoring rubric.

    Starts every task from COMMON_SCORES, layers in the task-specific
    "languages" / "labels" requirements where applicable, and records the
    total achievable points under the "_max_score" bookkeeping key.
    """
    scores_by_task = {}
    for pipeline in ALL_PIPELINES:
        rubric = dict(COMMON_SCORES)
        if pipeline in TASK_TYPES_WITH_LANGUAGES:
            rubric["languages"] = {
                "required": True,
                "score": 2,
                "missing_recommendation": (
                    "You haven't defined any languages in your metadata. This"
                    f" is usually recommend for {pipeline} task"
                ),
            }
        if pipeline in LABELS_REQUIRED_TASKS:
            rubric["labels"] = {
                "required": True,
                "score": 2,
                "missing_recommendation": (
                    "You haven't defined any labels in the config.json file"
                    f" these are usually recommended for {pipeline}"
                ),
            }
        # Sum runs before "_max_score" is inserted, so it only counts
        # the real metadata fields.
        rubric["_max_score"] = sum(entry["score"] for entry in rubric.values())
        scores_by_task[pipeline] = rubric
    return scores_by_task
# Rubric built once at import time: pipeline task -> scoring dict.
SCORES = generate_task_scores_dict()
def _basic_check(hub_id):
    """Score `hub_id`'s metadata against the rubric for its pipeline task.

    Returns a markdown summary string (plus recommendations for any
    missing fields), or None when the model has no pipeline tag or any
    hub lookup fails.
    """
    try:
        data = ModelMetadata.from_hub(hub_id)
        score = 0
        if task := data.pipeline_tag:
            task_scores = SCORES[task]
            to_fix = {}
            data_dict = asdict(data)
            for key, rubric in task_scores.items():
                # Skip bookkeeping entries such as "_max_score".
                if key.startswith("_"):
                    continue
                if data_dict[key] is None:
                    to_fix[key] = rubric["missing_recommendation"].replace(
                        "HUB_ID", hub_id
                    )
                else:
                    score += rubric["score"]
            score = score / task_scores["_max_score"]
            score_summary = (
                f"Your model's metadata score is {round(score*100)}% based on suggested"
                f" metadata for {task}. \n"
            )
            # Bug fix: `recommendations` was previously undefined when
            # nothing needed fixing, raising NameError on a perfect score.
            recommendations = None
            if to_fix:
                recommendations = (
                    "Here are some suggestions to improve your model's metadata for"
                    f" {task}: \n"
                )
                for suggestion in to_fix.values():
                    recommendations += f"\n- {suggestion}"
            return score_summary + recommendations if recommendations else score_summary
    except Exception as e:
        # Broad on purpose: any hub/parse failure degrades to "no result"
        # for the UI instead of crashing the app.
        print(e)
        return None
@lru_cache(maxsize=None)
def basic_check(hub_id):
    """Public, memoized entry point for checking a model's metadata.

    Restores the intended use of the (previously unused) lru_cache import:
    the commented-out bulk precaching below feeds thousands of ids through
    this wrapper. Results may go stale for the process lifetime.
    """
    return _basic_check(hub_id)
# print("caching models...") | |
# print("getting top 5,000 models") | |
# models = list_models(sort="downloads", direction=-1, limit=5_000) | |
# model_ids = [model.modelId for model in models] | |
# print("calculating metadata scores...") | |
# thread_map(basic_check, model_ids) | |
# Gradio UI: one text box for a model id, a button, and a markdown pane
# for the resulting score report.
with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Model Metadata Checker
    This app will check your model's metadata for a few common issues."""
    )
    with gr.Row():
        text = gr.Text(label="Model ID")
        # NOTE(review): gr.Button conventionally takes `value=` for its text;
        # confirm `label=`/`type=` are accepted by the pinned gradio version.
        button = gr.Button(label="Check", type="submit")
    with gr.Row():
        gr.Markdown("Results")
        markdown = gr.Markdown(name="markdown")
    # Wire clicks through the public wrapper (was the private _basic_check)
    # so repeated lookups of the same id benefit from its caching.
    button.click(basic_check, text, markdown)
demo.launch(debug=True)
# gr.Interface(fn=basic_check, inputs="text", outputs="markdown").launch(debug=True)