import copy
import os
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from functools import lru_cache
from json import JSONDecodeError
from typing import Any, Dict, List, Optional, Union

import backoff
import gradio as gr
import httpx
import orjson
import requests
from cachetools import TTLCache, cached
from diskcache import Cache
from httpx import Client
from httpx_caching import CachingClient, OneDayCacheHeuristic
from huggingface_hub import (
    HfApi,
    hf_hub_url,
    list_repo_commits,
    logging,
    model_info,
)
from huggingface_hub.utils import (
    EntryNotFoundError,
    GatedRepoError,
    disable_progress_bars,
)
from requests.exceptions import HTTPError
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map
# In-memory TTL cache with a 24-hour expiry (defined for memoization but not currently applied)
cache = TTLCache(maxsize=500_000, ttl=timedelta(hours=24), timer=datetime.now)
# httpx client with a one-day response cache, used for README and full-text search requests
client = Client()
client = CachingClient(client, heuristic=OneDayCacheHeuristic())
# CACHE_DIR = "./cache" if platform == "darwin" else "/data/"
disable_progress_bars()
logging.set_verbosity_error()
token = os.getenv("HF_TOKEN")
# cache = Cache(CACHE_DIR)

def get_model_labels(model):
    """Return the label names from a model's config.json, or None if unavailable."""
    try:
        url = hf_hub_url(repo_id=model, filename="config.json")
        return list(requests.get(url).json()["label2id"].keys())
    except (KeyError, JSONDecodeError, AttributeError):
        return None

@dataclass
class EngagementStats:
    likes: int
    downloads: int
    created_at: datetime

def _get_engagement_stats(hub_id):
    api = HfApi(token=token)
    repo = api.repo_info(hub_id)
    return EngagementStats(
        likes=repo.likes,
        downloads=repo.downloads,
        # The last entry in the commit history is the initial commit, i.e. repo creation
        created_at=list_repo_commits(hub_id, repo_type="model")[-1].created_at,
    )

def _try_load_model_card(hub_id):
    try:
        # We grab the card this way rather than via the client library to improve performance
        url = hf_hub_url(repo_id=hub_id, filename="README.md")
        response = client.get(url)
        response.raise_for_status()
        card_text = response.text
        length = len(card_text)
    except (
        httpx.HTTPStatusError,
        EntryNotFoundError,
        GatedRepoError,
    ):  # TODO return different values to reflect gating rather than no card
        card_text = None
        length = None
    return card_text, length

def _try_parse_card_data(hub_id):
    data = {}
    keys = ["license", "language", "datasets", "tags"]
    # Fetch the model info once rather than once per key
    info = model_info(hub_id, token=token)
    for key in keys:
        try:
            data[key] = info.cardData[key]
        except (KeyError, AttributeError, TypeError):
            data[key] = None
    return data

@dataclass
class ModelMetadata:
    hub_id: str
    tags: Optional[List[str]]
    license: Optional[str]
    library_name: Optional[str]
    datasets: Optional[List[str]]
    pipeline_tag: Optional[str]
    labels: Optional[List[str]]
    languages: Optional[Union[str, List[str]]]
    engagement_stats: Optional[EngagementStats] = None
    model_card_text: Optional[str] = None
    model_card_length: Optional[int] = None

    @classmethod
    def from_hub(cls, hub_id):
        try:
            model = model_info(hub_id)
        except (GatedRepoError, HTTPError):
            return None  # TODO catch gated repos and handle properly
        card_text, length = _try_load_model_card(hub_id)
        data = _try_parse_card_data(hub_id)
        try:
            library_name = model.library_name
        except AttributeError:
            library_name = None
        try:
            pipeline_tag = model.pipeline_tag
        except AttributeError:
            pipeline_tag = None
        return cls(
            hub_id=hub_id,
            languages=data["language"],
            tags=data["tags"],
            license=data["license"],
            library_name=library_name,
            datasets=data["datasets"],
            pipeline_tag=pipeline_tag,
            labels=get_model_labels(hub_id),
            engagement_stats=_get_engagement_stats(hub_id),
            model_card_text=card_text,
            model_card_length=length,
        )

COMMON_SCORES = {
    "license": {
        "required": True,
        "score": 2,
        "missing_recommendation": (
            "You have not added a license to your model's metadata"
        ),
    },
    "datasets": {
        "required": False,
        "score": 1,
        "missing_recommendation": (
            "You have not added any datasets to your model's metadata"
        ),
    },
    "model_card_text": {
        "required": True,
        "score": 3,
        "missing_recommendation": """You haven't created a model card for your model. A model card is strongly recommended. \nYou can create one for your model by clicking [here](https://huggingface.co/HUB_ID/edit/main/README.md)""",
    },
    "tags": {
        "required": False,
        "score": 2,
        "missing_recommendation": (
            "You don't have any tags defined in your model metadata. Tags can help"
            " people find relevant models on the Hub. You can add tags to your model"
            " by clicking [here](https://huggingface.co/HUB_ID/edit/main/README.md)"
        ),
    },
}
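# Each entry above scores one metadata field; when the field is missing from a
# model, its "missing_recommendation" text is shown to the user, with the
# HUB_ID placeholder replaced by the actual repo id in _basic_check.
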
TASK_TYPES_WITH_LANGUAGES = {
    "text-classification",
    "token-classification",
    "table-question-answering",
    "question-answering",
    "zero-shot-classification",
    "translation",
    "summarization",
    "text-generation",
    "text2text-generation",
    "fill-mask",
    "sentence-similarity",
    "text-to-speech",
    "automatic-speech-recognition",
    "text-to-image",
    "image-to-text",
    "visual-question-answering",
    "document-question-answering",
}
LABELS_REQUIRED_TASKS = {
    "text-classification",
    "token-classification",
    "object-detection",
    "audio-classification",
    "image-classification",
    "tabular-classification",
}
ALL_PIPELINES = {
    "audio-classification",
    "audio-to-audio",
    "automatic-speech-recognition",
    "conversational",
    "depth-estimation",
    "document-question-answering",
    "feature-extraction",
    "fill-mask",
    "graph-ml",
    "image-classification",
    "image-segmentation",
    "image-to-image",
    "image-to-text",
    "object-detection",
    "question-answering",
    "reinforcement-learning",
    "robotics",
    "sentence-similarity",
    "summarization",
    "table-question-answering",
    "tabular-classification",
    "tabular-regression",
    "text-classification",
    "text-generation",
    "text-to-image",
    "text-to-speech",
    "text-to-video",
    "text2text-generation",
    "token-classification",
    "translation",
    "unconditional-image-generation",
    "video-classification",
    "visual-question-answering",
    "voice-activity-detection",
    "zero-shot-classification",
    "zero-shot-image-classification",
}

def generate_task_scores_dict():
    task_scores = {}
    for task in ALL_PIPELINES:
        task_dict = copy.deepcopy(COMMON_SCORES)
        if task in TASK_TYPES_WITH_LANGUAGES:
            task_dict = {
                **task_dict,
                **{
                    "languages": {
                        "required": True,
                        "score": 2,
                        "missing_recommendation": (
                            "You haven't defined any languages in your metadata. This"
                            f" is usually recommended for the {task} task"
                        ),
                    }
                },
            }
        if task in LABELS_REQUIRED_TASKS:
            task_dict = {
                **task_dict,
                **{
                    "labels": {
                        "required": True,
                        "score": 2,
                        "missing_recommendation": (
                            "You haven't defined any labels in the config.json file;"
                            f" these are usually recommended for {task}"
                        ),
                    }
                },
            }
        max_score = sum(value["score"] for value in task_dict.values())
        task_dict["_max_score"] = max_score
        task_scores[task] = task_dict
    return task_scores
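# For example, the generated entry for "text-classification" (which is in both
# TASK_TYPES_WITH_LANGUAGES and LABELS_REQUIRED_TASKS) has the four common
# fields plus "languages" and "labels", so its "_max_score" is 12.
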
def generate_common_scores():
    GENERIC_SCORES = copy.deepcopy(COMMON_SCORES)
    GENERIC_SCORES["_max_score"] = sum(
        value["score"] for value in GENERIC_SCORES.values()
    )
    return GENERIC_SCORES

SCORES = generate_task_scores_dict()
GENERIC_SCORES = generate_common_scores()
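# GENERIC_SCORES is the fallback used when a model has no pipeline_tag; in that
# case _basic_check also halves the final score as a manual penalty (see below).
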
# @cache.memoize(expire=60 * 60 * 24 * 3)  # expires after 3 days
def _basic_check(hub_id):
    data = ModelMetadata.from_hub(hub_id)
    score = 0
    if data is None:
        return None
    to_fix = {}
    if task := data.pipeline_tag:
        task_scores = SCORES[task]
        data_dict = asdict(data)
        for k, v in task_scores.items():
            if k.startswith("_"):
                continue
            if data_dict[k] is None:
                to_fix[k] = task_scores[k]["missing_recommendation"].replace(
                    "HUB_ID", hub_id
                )
            if data_dict[k] is not None:
                score += v["score"]
        max_score = task_scores["_max_score"]
        score = score / max_score
        score_summary = (
            f"Your model's metadata score is {round(score*100)}% based on suggested"
            f" metadata for {task}. \n"
        )
        if to_fix:
            recommendations = score_summary + (
                "Here are some suggestions to improve your model's metadata for"
                f" {task}: \n"
            )
            for v in to_fix.values():
                recommendations += f"\n- {v}"
            data_dict["recommendations"] = recommendations
        data_dict["score"] = score * 100
    else:
        data_dict = asdict(data)
        for k, v in GENERIC_SCORES.items():
            if k.startswith("_"):
                continue
            if data_dict[k] is None:
                to_fix[k] = GENERIC_SCORES[k]["missing_recommendation"].replace(
                    "HUB_ID", hub_id
                )
            if data_dict[k] is not None:
                score += v["score"]
        score = score / GENERIC_SCORES["_max_score"]
        data_dict["score"] = max(
            0, (score / 2) * 100
        )  # TODO currently setting a manual penalty for not having a task
    return orjson.dumps(data_dict)
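# Note that _basic_check returns orjson-encoded bytes rather than a dict;
# callers such as parse_single_result decode the result with orjson.loads.
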
def basic_check(hub_id):
    return _basic_check(hub_id)

def create_query_url(query, skip=0):
    return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
def get_results(query) -> Dict[Any, Any]:
    url = create_query_url(query)
    r = client.get(url)
    return r.json()

def parse_single_result(result):
    name, filename = result["name"], result["fileName"]
    search_result_file_url = hf_hub_url(name, filename)
    repo_hub_url = f"https://huggingface.co/{name}"
    score = _basic_check(name)
    if score is None:
        return None
    score = orjson.loads(score)
    return {
        "name": name,
        "search_result_file_url": search_result_file_url,
        "repo_hub_url": repo_hub_url,
        "metadata_score": score["score"],
        "model_card_length": score["model_card_length"],
        "is_licensed": bool(score["license"]),
        # "metadata_report": score
    }

def filter_for_license(results):
    for result in results:
        if result["is_licensed"]:
            yield result


def filter_for_min_model_card_length(results, min_model_card_length):
    for result in results:
        if result["model_card_length"] > min_model_card_length:
            yield result

def filter_search_results(
    results: List[Dict[Any, Any]],
    min_score=None,
    min_model_card_length=None,
):
    results = thread_map(parse_single_result, results)
    for i, parsed_result in tqdm(enumerate(results)):
        if parsed_result is None:
            continue
        # Skip results below the minimum metadata score, if one is set
        if min_score is not None and parsed_result["metadata_score"] <= min_score:
            continue
        # Skip results with model cards shorter than the minimum length, if one is set
        if (
            min_model_card_length is not None
            and parsed_result["model_card_length"] <= min_model_card_length
        ):
            continue
        # Keep the original search ranking so it can be used as a sort tie-breaker
        parsed_result["original_position"] = i
        yield parsed_result

def sort_search_results(
    filtered_search_results,
    first_sort="metadata_score",
    second_sort="original_position",  # TODO expose these in results
):
    return sorted(
        list(filtered_search_results),
        key=lambda x: (x[first_sort], x[second_sort]),
        reverse=True,
    )

def find_context(text, query, window_size):
    # Split the text into words
    words = text.split()
    try:
        # Find the index of the query token
        index = words.index(query)
        # Get the start and end indices of the context window
        start = max(0, index - window_size)
        end = min(len(words), index + window_size + 1)
        return " ".join(words[start:end])
    except ValueError:
        # Query not found: fall back to the start of the text
        return " ".join(words[:window_size])
def create_markdown(results):  # TODO move to separate file
    rows = []
    for result in results:
        row = f"""# [{result['name']}]({result['repo_hub_url']})
| Metadata Quality Score | Model card length | Licensed |
|------------------------|-------------------|----------|
| {result['metadata_score']:.0f}% | {result['model_card_length']} | {"✅" if result['is_licensed'] else "❌"} |
\n
*{result['text']}*
<hr>
\n"""
        rows.append(row)
    return "\n".join(rows)

def get_result_card_snippet(result, query=None):
    try:
        result_text = httpx.get(result["search_result_file_url"]).text
        # Pull out a snippet of the model card around the search query
        result["text"] = find_context(result_text, query, 100)
    except httpx.ConnectError:
        result["text"] = "Could not load model card"
    return result

def _search_hub(
    query: str,
    min_score: Optional[int] = None,
    min_model_card_length: Optional[int] = None,
):
    results = get_results(query)
    print(f"Found {len(results['hits'])} results")
    results = results["hits"]
    number_original_results = len(results)
    filtered_results = filter_search_results(
        results, min_score=min_score, min_model_card_length=min_model_card_length
    )
    filtered_results = sort_search_results(filtered_results)
    # final_results = []
    # for result in filtered_results:
    #     result_text = httpx.get(result["search_result_file_url"]).text
    #     result["text"] = find_context(result_text, query, 100)
    #     final_results.append(result)
    final_results = thread_map(
        lambda result: get_result_card_snippet(result, query), filtered_results
    )
    percent_of_original = round(
        len(final_results) / number_original_results * 100, ndigits=0
    )
    filtered_vs_og = f"""
| Number of original results | Number of results after filtering | Percentage of results after filtering |
| -------------------------- | --------------------------------- | ------------------------------------- |
| {number_original_results} | {len(final_results)} | {percent_of_original}% |
"""
    print(final_results)
    return filtered_vs_og, create_markdown(final_results)

def search_hub(query: str, min_score=None, min_model_card_length=None):
    return _search_hub(query, min_score, min_model_card_length)

with gr.Blocks() as demo:
    with gr.Tab("Hub search with metadata quality filter"):
        gr.Markdown("# 🤗 Hub model search with metadata quality filters")
        with gr.Row():
            with gr.Column():
                query = gr.Textbox("x-ray", label="Search query")
            with gr.Column():
                button = gr.Button("Search")
        with gr.Row():
            # gr.Checkbox(False, label="Must have licence?")
            min_model_card_length = gr.Number(None, label="Minimum model card length")
            min_metadata_score = gr.Slider(0, label="Minimum metadata score")
        filter_results = gr.Markdown("Filter results vs original search")
        results_markdown = gr.Markdown("Search results")
        button.click(
            search_hub,
            [query, min_metadata_score, min_model_card_length],
            [filter_results, results_markdown],
        )
    # with gr.Tab("Scoring metadata quality"):
    #     with gr.Row():
    #         gr.Markdown(
    #             f"""
    #             # Metadata quality scoring
    #             ```
    #             {COMMON_SCORES}
    #             ```
    #             For example, `TASK_TYPES_WITH_LANGUAGES` defines all the tasks for which it
    #             is expected to have language metadata associated with the model.
    #             ```
    #             {TASK_TYPES_WITH_LANGUAGES}
    #             ```
    #             """
    #         )

demo.launch()

# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         # Model Metadata Checker
#         This app will check your model's metadata for a few common issues."""
#     )
#     with gr.Row():
#         text = gr.Text(label="Model ID")
#         button = gr.Button(label="Check", type="submit")
#     with gr.Row():
#         gr.Markdown("Results")
#         markdown = gr.JSON()
#     button.click(_basic_check, text, markdown)
# demo.queue(concurrency_count=32)
# demo.launch()
# gr.Interface(fn=basic_check, inputs="text", outputs="markdown").launch(debug=True)