import datetime import os import copy from dataclasses import asdict, dataclass from functools import lru_cache from json import JSONDecodeError from typing import Any, Dict, List, Optional, Union from huggingface_hub.utils import GatedRepoError import gradio as gr from requests.exceptions import HTTPError import requests from diskcache import Cache from huggingface_hub import ( HfApi, hf_hub_url, list_repo_commits, logging, model_info, ) from tqdm.auto import tqdm from tqdm.contrib.concurrent import thread_map import backoff from huggingface_hub.utils import EntryNotFoundError, disable_progress_bars import httpx import orjson import httpx from functools import lru_cache from httpx import Client from httpx_caching import CachingClient from httpx_caching import OneDayCacheHeuristic client = Client() client = CachingClient(client, heuristic=OneDayCacheHeuristic()) # CACHE_DIR = "./cache" if platform == "darwin" else "/data/" disable_progress_bars() logging.set_verbosity_error() token = os.getenv("HF_TOKEN") # cache = Cache(CACHE_DIR) def get_model_labels(model): try: url = hf_hub_url(repo_id=model, filename="config.json") return list(requests.get(url).json()["label2id"].keys()) except (KeyError, JSONDecodeError, AttributeError): return None @dataclass class EngagementStats: likes: int downloads: int created_at: datetime.datetime def _get_engagement_stats(hub_id): api = HfApi(token=token) repo = api.repo_info(hub_id) return EngagementStats( likes=repo.likes, downloads=repo.downloads, created_at=list_repo_commits(hub_id, repo_type="model")[-1].created_at, ) def _try_load_model_card(hub_id): try: url = hf_hub_url( repo_id=hub_id, filename="README.md" ) # We grab card this way rather than via client library to improve performance card_text = client.get(url).text length = len(card_text) except EntryNotFoundError: card_text = None length = None except ( GatedRepoError ): # TODO return different values to reflect gating rather than no card card_text = None length = None return card_text, length def _try_parse_card_data(hub_id): data = {} keys = ["license", "language", "datasets", "tags"] for key in keys: try: value = model_info(hub_id, token=token).cardData[key] data[key] = value except (KeyError, AttributeError): data[key] = None return data @dataclass class ModelMetadata: hub_id: str tags: Optional[List[str]] license: Optional[str] library_name: Optional[str] datasets: Optional[List[str]] pipeline_tag: Optional[str] labels: Optional[List[str]] languages: Optional[Union[str, List[str]]] engagement_stats: Optional[EngagementStats] = None model_card_text: Optional[str] = None model_card_length: Optional[int] = None @classmethod def from_hub(cls, hub_id): try: model = model_info(hub_id) except (GatedRepoError, HTTPError): return None # TODO catch gated repos and handle properly card_text, length = _try_load_model_card(hub_id) data = _try_parse_card_data(hub_id) try: library_name = model.library_name except AttributeError: library_name = None try: pipeline_tag = model.pipeline_tag except AttributeError: pipeline_tag = None return ModelMetadata( hub_id=hub_id, languages=data["language"], tags=data["tags"], license=data["license"], library_name=library_name, datasets=data["datasets"], pipeline_tag=pipeline_tag, labels=get_model_labels(hub_id), engagement_stats=_get_engagement_stats(hub_id), model_card_text=card_text, model_card_length=length, ) COMMON_SCORES = { "license": { "required": True, "score": 2, "missing_recommendation": ( "You have not added a license to your models metadata" ), }, "datasets": { "required": False, "score": 1, "missing_recommendation": ( "You have not added any datasets to your models metadata" ), }, "model_card_text": { "required": True, "score": 3, "missing_recommendation": """You haven't created a model card for your model. It is strongly recommended to have a model card for your model. \nYou can create for your model by clicking [here](https://huggingface.co/HUB_ID/edit/main/README.md)""", }, "tags": { "required": False, "score": 2, "missing_recommendation": ( "You don't have any tags defined in your model metadata. Tags can help" " people find relevant models on the Hub. You can create for your model by" " clicking [here](https://huggingface.co/HUB_ID/edit/main/README.md)" ), }, } TASK_TYPES_WITH_LANGUAGES = { "text-classification", "token-classification", "table-question-answering", "question-answering", "zero-shot-classification", "translation", "summarization", "text-generation", "text2text-generation", "fill-mask", "sentence-similarity", "text-to-speech", "automatic-speech-recognition", "text-to-image", "image-to-text", "visual-question-answering", "document-question-answering", } LABELS_REQUIRED_TASKS = { "text-classification", "token-classification", "object-detection", "audio-classification", "image-classification", "tabular-classification", } ALL_PIPELINES = { "audio-classification", "audio-to-audio", "automatic-speech-recognition", "conversational", "depth-estimation", "document-question-answering", "feature-extraction", "fill-mask", "graph-ml", "image-classification", "image-segmentation", "image-to-image", "image-to-text", "object-detection", "question-answering", "reinforcement-learning", "robotics", "sentence-similarity", "summarization", "table-question-answering", "tabular-classification", "tabular-regression", "text-classification", "text-generation", "text-to-image", "text-to-speech", "text-to-video", "text2text-generation", "token-classification", "translation", "unconditional-image-generation", "video-classification", "visual-question-answering", "voice-activity-detection", "zero-shot-classification", "zero-shot-image-classification", } @lru_cache(maxsize=None) def generate_task_scores_dict(): task_scores = {} for task in ALL_PIPELINES: task_dict = copy.deepcopy(COMMON_SCORES) if task in TASK_TYPES_WITH_LANGUAGES: task_dict = { **task_dict, **{ "languages": { "required": True, "score": 2, "missing_recommendation": ( "You haven't defined any languages in your metadata. This" f" is usually recommend for {task} task" ), } }, } if task in LABELS_REQUIRED_TASKS: task_dict = { **task_dict, **{ "labels": { "required": True, "score": 2, "missing_recommendation": ( "You haven't defined any labels in the config.json file" f" these are usually recommended for {task}" ), } }, } max_score = sum(value["score"] for value in task_dict.values()) task_dict["_max_score"] = max_score task_scores[task] = task_dict return task_scores @lru_cache(maxsize=None) def generate_common_scores(): GENERIC_SCORES = copy.deepcopy(COMMON_SCORES) GENERIC_SCORES["_max_score"] = sum( value["score"] for value in GENERIC_SCORES.values() ) return GENERIC_SCORES SCORES = generate_task_scores_dict() GENERIC_SCORES = generate_common_scores() # @cache.memoize(expire=60 * 60 * 24 * 3) # expires after 3 days def _basic_check(hub_id): data = ModelMetadata.from_hub(hub_id) score = 0 if data is None: return None to_fix = {} if task := data.pipeline_tag: task_scores = SCORES[task] data_dict = asdict(data) for k, v in task_scores.items(): if k.startswith("_"): continue if data_dict[k] is None: to_fix[k] = task_scores[k]["missing_recommendation"].replace( "HUB_ID", hub_id ) if data_dict[k] is not None: score += v["score"] max_score = task_scores["_max_score"] score = score / max_score ( f"Your model's metadata score is {round(score*100)}% based on suggested" f" metadata for {task}. \n" ) if to_fix: recommendations = ( "Here are some suggestions to improve your model's metadata for" f" {task}: \n" ) for v in to_fix.values(): recommendations += f"\n- {v}" data_dict["recommendations"] = recommendations data_dict["score"] = score * 100 else: data_dict = asdict(data) for k, v in GENERIC_SCORES.items(): if k.startswith("_"): continue if data_dict[k] is None: to_fix[k] = GENERIC_SCORES[k]["missing_recommendation"].replace( "HUB_ID", hub_id ) if data_dict[k] is not None: score += v["score"] score = score / GENERIC_SCORES["_max_score"] data_dict["score"] = max( 0, (score / 2) * 100 ) # TODO currently setting a manual penalty for not having a task return orjson.dumps(data_dict) def basic_check(hub_id): return _basic_check(hub_id) def create_query_url(query, skip=0): return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model" # @cache.memoize(expire=60 * 60 * 24 * 3) # expires after 3 days def get_results(query) -> Dict[Any, Any]: url = create_query_url(query) r = client.get(url) return r.json() @backoff.on_exception( backoff.expo, Exception, max_time=2, raise_on_giveup=False, ) def parse_single_result(result): name, filename = result["name"], result["fileName"] search_result_file_url = hf_hub_url(name, filename) repo_hub_url = f"https://huggingface.co/{name}" score = _basic_check(name) if score is None: return None score = orjson.loads(score) return { "name": name, "search_result_file_url": search_result_file_url, "repo_hub_url": repo_hub_url, "metadata_score": score["score"], "model_card_length": score["model_card_length"], "is_licensed": bool(score["license"]), # "metadata_report": score } def filter_search_results( results: List[Dict[Any, Any]], min_score=None, min_model_card_length=None ): # TODO make code more intuitive results = thread_map(parse_single_result, results) for i, parsed_result in tqdm(enumerate(results)): # parsed_result = parse_single_result(result) if parsed_result is None: continue if ( min_score is None and min_model_card_length is not None and parsed_result["model_card_length"] > min_model_card_length or min_score is None and min_model_card_length is None ): yield parsed_result elif min_score is not None: if parsed_result["metadata_score"] <= min_score: continue if ( min_model_card_length is not None and parsed_result["model_card_length"] > min_model_card_length or min_model_card_length is None ): parsed_result["original_position"] = i yield parsed_result def sort_search_results(filtered_search_results): return sorted( list(filtered_search_results), key=lambda x: (x["metadata_score"], x["original_position"]), reverse=True, ) def find_context(text, query, window_size): # Split the text into words words = text.split() # Find the index of the query token try: index = words.index(query) # Get the start and end indices of the context window start = max(0, index - window_size) end = min(len(words), index + window_size + 1) return " ".join(words[start:end]) except ValueError: return " ".join(words[:window_size]) # single_result[ # "text" # ] = "lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." # results = [single_result] * 3 def create_markdown(results): rows = [] for result in results: row = f"""# [{result['name']}]({result['repo_hub_url']}) | Metadata Quality Score | Model card length | Licensed | |------------------------|-------------------|----------| | {result['metadata_score']:.0f}% | {result['model_card_length']} | {"✅" if result['is_licensed'] else "❌"} | \n *{result['text']}*
\n""" rows.append(row) return "\n".join(rows) def get_result_card_snippet(result): try: result_text = httpx.get(result["search_result_file_url"]).text result["text"] = find_context(result_text, query, 100) except httpx.ConnectError: result["text"] = "Could not load model card" return result def _search_hub( query: str, min_score: Optional[int] = None, min_model_card_length: Optional[int] = None, ): results = get_results(query) print(f"Found {len(results['hits'])} results") results = results["hits"] number_original_results = len(results) filtered_results = filter_search_results( results, min_score=min_score, min_model_card_length=min_model_card_length ) filtered_results = sort_search_results(filtered_results) # final_results = [] # for result in filtered_results: # result_text = httpx.get(result["search_result_file_url"]).text # result["text"] = find_context(result_text, query, 100) # final_results.append(result) final_results = thread_map(get_result_card_snippet, filtered_results) percent_of_original = round( len(final_results) / number_original_results * 100, ndigits=0 ) filtered_vs_og = f""" | Number of original results | Number of results after filtering | Percentage of results after filtering | | -------------------------- | --------------------------------- | -------------------------------------------- | | {number_original_results} | {len(final_results)} | {percent_of_original}% | """ print(final_results) return filtered_vs_og, create_markdown(final_results) def search_hub(query: str, min_score=None, min_model_card_length=None): return _search_hub(query, min_score, min_model_card_length) with gr.Blocks() as demo: with gr.Tab("Hub search with metadata quality filter"): gr.Markdown("# 🤗 Hub model search with metadata quality filters") with gr.Row(): with gr.Column(): query = gr.Textbox("x-ray", label="Search query") with gr.Column(): button = gr.Button("Search") with gr.Row(): # gr.Checkbox(False, label="Must have licence?") mim_model_card_length = gr.Number( None, label="Minimum model card length" ) min_metadata_score = gr.Slider(0, label="Minimum metadata score") filter_results = gr.Markdown("Filter results vs original search") results_markdown = gr.Markdown("Search results") button.click( search_hub, [query, min_metadata_score, mim_model_card_length], [filter_results, results_markdown], ) with gr.Tab("Scoring metadata quality"): with gr.Row(): gr.Markdown( f""" # Metadata quality scoring ``` {COMMON_SCORES} ``` For example, `TASK_TYPES_WITH_LANGUAGES` defines all the tasks for which it is expected to have language metadata associated with the model. ``` {TASK_TYPES_WITH_LANGUAGES} ``` """ ) demo.launch() # with gr.Blocks() as demo: # gr.Markdown( # """ # # Model Metadata Checker # This app will check your model's metadata for a few common issues.""" # ) # with gr.Row(): # text = gr.Text(label="Model ID") # button = gr.Button(label="Check", type="submit") # with gr.Row(): # gr.Markdown("Results") # markdown = gr.JSON() # button.click(_basic_check, text, markdown) # demo.queue(concurrency_count=32) # demo.launch() # gr.Interface(fn=basic_check, inputs="text", outputs="markdown").launch(debug=True)