import copy
import os
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from functools import lru_cache, partial
from json import JSONDecodeError
from typing import Any, Dict, List, Optional, Union

import backoff
import gradio as gr
import httpx
import orjson
import requests
from cachetools import TTLCache, cached
from httpx import Client
from httpx_caching import CachingClient, OneDayCacheHeuristic
from huggingface_hub import (
    HfApi,
    hf_hub_url,
    list_repo_commits,
    logging,
    model_info,
)
from huggingface_hub.utils import GatedRepoError, disable_progress_bars
from requests.exceptions import HTTPError
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

cache = TTLCache(maxsize=500_000, ttl=timedelta(hours=24), timer=datetime.now)

client = Client()
client = CachingClient(client, heuristic=OneDayCacheHeuristic())

# CACHE_DIR = "./cache" if platform == "darwin" else "/data/"
# cache = Cache(CACHE_DIR)  # optional on-disk cache (requires `from diskcache import Cache`)

disable_progress_bars()
logging.set_verbosity_error()

token = os.getenv("HF_TOKEN")
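
# Two layers of caching are used: the CachingClient above caches raw HTTP responses
# for roughly a day, while the TTLCache instances memoize computed results for 24
# hours (see the @cached decorators further down). The helpers below fetch per-model
# information from the Hub.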
def get_model_labels(model):
try:
url = hf_hub_url(repo_id=model, filename="config.json")
return list(requests.get(url).json()["label2id"].keys())
except (KeyError, JSONDecodeError, AttributeError):
return None
@dataclass
class EngagementStats:
likes: int
downloads: int
created_at: datetime
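
# Engagement stats come from the Hub API; `created_at` is taken from the last entry
# of `list_repo_commits`, i.e. the oldest commit, which approximates the repo's
# creation date (assuming commits are returned newest-first).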
def _get_engagement_stats(hub_id):
api = HfApi(token=token)
repo = api.repo_info(hub_id)
return EngagementStats(
likes=repo.likes,
downloads=repo.downloads,
created_at=list_repo_commits(hub_id, repo_type="model")[-1].created_at,
)
def _try_load_model_card(hub_id):
    # Fetch the raw README rather than going through the client library to improve performance.
    url = hf_hub_url(repo_id=hub_id, filename="README.md")
    response = client.get(url)
    if response.status_code != 200:
        # Missing cards (404) and gated repos (401/403) are both treated as "no card".
        # TODO return different values to reflect gating rather than no card
        return None, None
    card_text = response.text
    return card_text, len(card_text)
def _try_parse_card_data(hub_id):
    data = {}
    keys = ["license", "language", "datasets", "tags"]
    card_data = model_info(hub_id, token=token).cardData  # fetch once rather than per key
    for key in keys:
        try:
            data[key] = card_data[key]
        except (KeyError, AttributeError, TypeError):
            data[key] = None
    return data
@dataclass
class ModelMetadata:
hub_id: str
tags: Optional[List[str]]
license: Optional[str]
library_name: Optional[str]
datasets: Optional[List[str]]
pipeline_tag: Optional[str]
labels: Optional[List[str]]
languages: Optional[Union[str, List[str]]]
engagement_stats: Optional[EngagementStats] = None
model_card_text: Optional[str] = None
model_card_length: Optional[int] = None
    @classmethod
    def from_hub(cls, hub_id):
        try:
            model = model_info(hub_id, token=token)
        except (GatedRepoError, HTTPError):
            return None  # TODO catch gated repos and handle properly
        card_text, length = _try_load_model_card(hub_id)
        data = _try_parse_card_data(hub_id)
        library_name = getattr(model, "library_name", None)
        pipeline_tag = getattr(model, "pipeline_tag", None)
        return ModelMetadata(
            hub_id=hub_id,
            languages=data["language"],
            tags=data["tags"],
            license=data["license"],
            library_name=library_name,
            datasets=data["datasets"],
            pipeline_tag=pipeline_tag,
            labels=get_model_labels(hub_id),
            engagement_stats=_get_engagement_stats(hub_id),
            model_card_text=card_text,
            model_card_length=length,
        )
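
# Scoring schema shared by every task: each entry maps a metadata field to a weight
# ("score") and the recommendation shown when that field is missing. The HUB_ID
# placeholder in recommendations is replaced with the actual repo id at check time.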
COMMON_SCORES = {
    "license": {
        "required": True,
        "score": 2,
        "missing_recommendation": (
            "You have not added a license to your model's metadata"
        ),
    },
    "datasets": {
        "required": False,
        "score": 1,
        "missing_recommendation": (
            "You have not added any datasets to your model's metadata"
        ),
    },
    "model_card_text": {
        "required": True,
        "score": 3,
        "missing_recommendation": """You haven't created a model card for your model. It is strongly recommended to have one. \nYou can create one by clicking [here](https://huggingface.co/HUB_ID/edit/main/README.md)""",
    },
    "tags": {
        "required": False,
        "score": 2,
        "missing_recommendation": (
            "You don't have any tags defined in your model metadata. Tags can help"
            " people find relevant models on the Hub. You can add them by clicking"
            " [here](https://huggingface.co/HUB_ID/edit/main/README.md)"
        ),
    },
}
TASK_TYPES_WITH_LANGUAGES = {
"text-classification",
"token-classification",
"table-question-answering",
"question-answering",
"zero-shot-classification",
"translation",
"summarization",
"text-generation",
"text2text-generation",
"fill-mask",
"sentence-similarity",
"text-to-speech",
"automatic-speech-recognition",
"text-to-image",
"image-to-text",
"visual-question-answering",
"document-question-answering",
}
LABELS_REQUIRED_TASKS = {
"text-classification",
"token-classification",
"object-detection",
"audio-classification",
"image-classification",
"tabular-classification",
}
ALL_PIPELINES = {
"audio-classification",
"audio-to-audio",
"automatic-speech-recognition",
"conversational",
"depth-estimation",
"document-question-answering",
"feature-extraction",
"fill-mask",
"graph-ml",
"image-classification",
"image-segmentation",
"image-to-image",
"image-to-text",
"object-detection",
"question-answering",
"reinforcement-learning",
"robotics",
"sentence-similarity",
"summarization",
"table-question-answering",
"tabular-classification",
"tabular-regression",
"text-classification",
"text-generation",
"text-to-image",
"text-to-speech",
"text-to-video",
"text2text-generation",
"token-classification",
"translation",
"unconditional-image-generation",
"video-classification",
"visual-question-answering",
"voice-activity-detection",
"zero-shot-classification",
"zero-shot-image-classification",
}
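
# Build one scoring schema per pipeline task: start from COMMON_SCORES, add a
# "languages" requirement for text/audio-style tasks and a "labels" requirement for
# classification-style tasks, then record the maximum achievable score under
# "_max_score" (used as the denominator when scoring a model).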
@lru_cache(maxsize=None)
def generate_task_scores_dict():
    task_scores = {}
    for task in ALL_PIPELINES:
        task_dict = copy.deepcopy(COMMON_SCORES)
        if task in TASK_TYPES_WITH_LANGUAGES:
            task_dict["languages"] = {
                "required": True,
                "score": 2,
                "missing_recommendation": (
                    "You haven't defined any languages in your metadata. This"
                    f" is usually recommended for the {task} task"
                ),
            }
        if task in LABELS_REQUIRED_TASKS:
            task_dict["labels"] = {
                "required": True,
                "score": 2,
                "missing_recommendation": (
                    "You haven't defined any labels in the config.json file."
                    f" These are usually recommended for {task}"
                ),
            }
        max_score = sum(value["score"] for value in task_dict.values())
        task_dict["_max_score"] = max_score
        task_scores[task] = task_dict
    return task_scores
@lru_cache(maxsize=None)
def generate_common_scores():
GENERIC_SCORES = copy.deepcopy(COMMON_SCORES)
GENERIC_SCORES["_max_score"] = sum(
value["score"] for value in GENERIC_SCORES.values()
)
return GENERIC_SCORES
SCORES = generate_task_scores_dict()
GENERIC_SCORES = generate_common_scores()
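
# Core metadata check: fetch a model's metadata, pick the schema for its pipeline
# tag (or fall back to the generic schema, with a score penalty, when no tag is
# set), then score the model as the weighted fraction of schema fields that are
# present and collect recommendations for the missing ones.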
# @cache.memoize(expire=60 * 60 * 24 * 3) # expires after 3 days
@cached(cache)
def _basic_check(hub_id):
    data = ModelMetadata.from_hub(hub_id)
    if data is None:
        return None
    score = 0
    to_fix = {}
    if task := data.pipeline_tag:
        task_scores = SCORES[task]
        data_dict = asdict(data)
        for k, v in task_scores.items():
            if k.startswith("_"):
                continue
            if data_dict[k] is None:
                to_fix[k] = task_scores[k]["missing_recommendation"].replace(
                    "HUB_ID", hub_id
                )
            else:
                score += v["score"]
        max_score = task_scores["_max_score"]
        score = score / max_score
        if to_fix:
            recommendations = (
                f"Your model's metadata score is {round(score * 100)}% based on"
                f" suggested metadata for {task}. \n"
                "Here are some suggestions to improve your model's metadata for"
                f" {task}: \n"
            )
            for v in to_fix.values():
                recommendations += f"\n- {v}"
            data_dict["recommendations"] = recommendations
        data_dict["score"] = score * 100
    else:
        data_dict = asdict(data)
        for k, v in GENERIC_SCORES.items():
            if k.startswith("_"):
                continue
            if data_dict[k] is None:
                to_fix[k] = GENERIC_SCORES[k]["missing_recommendation"].replace(
                    "HUB_ID", hub_id
                )
            else:
                score += v["score"]
        score = score / GENERIC_SCORES["_max_score"]
        data_dict["score"] = max(
            0, (score / 2) * 100
        )  # TODO currently setting a manual penalty for not having a task
    return orjson.dumps(data_dict)
def basic_check(hub_id):
return _basic_check(hub_id)
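
# Illustrative usage (the model id below is just an example):
#   report = orjson.loads(basic_check("distilbert-base-uncased"))
#   report["score"], report.get("recommendations")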
def create_query_url(query, skip=0):
return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
# Use a separate cache for full-text search responses so that a search query can
# never collide with a model id already cached by `_basic_check`.
search_cache = TTLCache(maxsize=500_000, ttl=timedelta(hours=24), timer=datetime.now)


@cached(search_cache)
def get_results(query) -> Dict[Any, Any]:
    url = create_query_url(query)
    r = client.get(url)
    return r.json()
@backoff.on_exception(
backoff.expo,
Exception,
max_time=2,
raise_on_giveup=False,
)
def parse_single_result(result):
name, filename = result["name"], result["fileName"]
search_result_file_url = hf_hub_url(name, filename)
repo_hub_url = f"https://huggingface.co/{name}"
score = _basic_check(name)
if score is None:
return None
score = orjson.loads(score)
return {
"name": name,
"search_result_file_url": search_result_file_url,
"repo_hub_url": repo_hub_url,
"metadata_score": score["score"],
"model_card_length": score["model_card_length"],
"is_licensed": bool(score["license"]),
# "metadata_report": score
}
def filter_for_license(results):
for result in results:
if result["is_licensed"]:
yield result
def filter_for_min_model_card_length(results, min_model_card_length):
for result in results:
if result["model_card_length"] > min_model_card_length:
yield result
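
# Combined filter over raw search hits: each hit is scored with `parse_single_result`
# (in a thread pool), then dropped if it falls below the optional metadata-score or
# model-card-length thresholds. The standalone generators above offer the same
# per-criterion filters individually.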
def filter_search_results(
    results: List[Dict[Any, Any]],
    min_score=None,
    min_model_card_length=None,
):
    results = thread_map(parse_single_result, results)
    for i, parsed_result in tqdm(enumerate(results)):
        if parsed_result is None:
            continue
        if min_score is not None and parsed_result["metadata_score"] <= min_score:
            continue
        if min_model_card_length is not None and not (
            parsed_result["model_card_length"]
            and parsed_result["model_card_length"] > min_model_card_length
        ):
            continue
        parsed_result["original_position"] = i
        yield parsed_result
def sort_search_results(
filtered_search_results,
first_sort="metadata_score",
second_sort="original_position", # TODO expose these in results
):
return sorted(
list(filtered_search_results),
key=lambda x: (x[first_sort], x[second_sort]),
reverse=True,
)
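
# Return roughly `window_size` words either side of the first exact token match for
# `query`, or simply the first `window_size` words when the token is not found.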
def find_context(text, query, window_size):
# Split the text into words
words = text.split()
# Find the index of the query token
try:
index = words.index(query)
# Get the start and end indices of the context window
start = max(0, index - window_size)
end = min(len(words), index + window_size + 1)
return " ".join(words[start:end])
except ValueError:
return " ".join(words[:window_size])
def create_markdown(results): # TODO move to separate file
rows = []
for result in results:
row = f"""# [{result['name']}]({result['repo_hub_url']})
| Metadata Quality Score | Model card length | Licensed |
|------------------------|-------------------|----------|
| {result['metadata_score']:.0f}% | {result['model_card_length']} | {"&#9989;" if result['is_licensed'] else "&#10060;"} |
\n
*{result['text']}*
<hr>
\n"""
rows.append(row)
return "\n".join(rows)
def get_result_card_snippet(result, query=None):
    try:
        result_text = httpx.get(result["search_result_file_url"]).text
        result["text"] = find_context(result_text, query, 100)
    except httpx.HTTPError:
        result["text"] = "Could not load model card"
    return result
def _search_hub(
    query: str,
    min_score: Optional[int] = None,
    min_model_card_length: Optional[int] = None,
):
    results = get_results(query)
    print(f"Found {len(results['hits'])} results")
    results = results["hits"]
    number_original_results = len(results)
    filtered_results = filter_search_results(
        results, min_score=min_score, min_model_card_length=min_model_card_length
    )
    filtered_results = sort_search_results(filtered_results)
    final_results = thread_map(
        partial(get_result_card_snippet, query=query), filtered_results
    )
    percent_of_original = (
        round(len(final_results) / number_original_results * 100, ndigits=0)
        if number_original_results
        else 0
    )
    filtered_vs_og = f"""
| Number of original results | Number of results after filtering | Percentage of results after filtering |
| -------------------------- | --------------------------------- | ------------------------------------- |
| {number_original_results} | {len(final_results)} | {percent_of_original}% |
"""
    return filtered_vs_og, create_markdown(final_results)
def search_hub(query: str, min_score=None, min_model_card_length=None):
return _search_hub(query, min_score, min_model_card_length)
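
# Gradio UI: a search box plus optional minimum-metadata-score and minimum
# model-card-length thresholds; results are rendered as a markdown summary table
# followed by one card per matching model.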
with gr.Blocks() as demo:
with gr.Tab("Hub search with metadata quality filter"):
gr.Markdown("# &#129303; Hub model search with metadata quality filters")
with gr.Row():
with gr.Column():
query = gr.Textbox("x-ray", label="Search query")
with gr.Column():
button = gr.Button("Search")
with gr.Row():
# gr.Checkbox(False, label="Must have licence?")
            min_model_card_length = gr.Number(
None, label="Minimum model card length"
)
min_metadata_score = gr.Slider(0, label="Minimum metadata score")
filter_results = gr.Markdown("Filter results vs original search")
results_markdown = gr.Markdown("Search results")
button.click(
search_hub,
            [query, min_metadata_score, min_model_card_length],
[filter_results, results_markdown],
)
# with gr.Tab("Scoring metadata quality"):
# with gr.Row():
# gr.Markdown(
# f"""
# # Metadata quality scoring
# ```
# {COMMON_SCORES}
# ```
# For example, `TASK_TYPES_WITH_LANGUAGES` defines all the tasks for which it
# is expected to have language metadata associated with the model.
# ```
# {TASK_TYPES_WITH_LANGUAGES}
# ```
# """
# )
demo.launch()
# with gr.Blocks() as demo:
# gr.Markdown(
# """
# # Model Metadata Checker
# This app will check your model's metadata for a few common issues."""
# )
# with gr.Row():
# text = gr.Text(label="Model ID")
# button = gr.Button(label="Check", type="submit")
# with gr.Row():
# gr.Markdown("Results")
# markdown = gr.JSON()
# button.click(_basic_check, text, markdown)
# demo.queue(concurrency_count=32)
# demo.launch()
# gr.Interface(fn=basic_check, inputs="text", outputs="markdown").launch(debug=True)