|
import asyncio |
|
import os |
|
import re |
|
from typing import Dict |
|
|
|
import gradio as gr |
|
import httpx |
|
from cachetools import TTLCache, cached |
|
from cashews import NOT_NONE, cache |
|
from dotenv import load_dotenv |
|
from httpx import AsyncClient, Limits |
|
from huggingface_hub import ( |
|
ModelCard, |
|
ModelFilter, |
|
get_repo_discussions, |
|
hf_hub_url, |
|
list_models, |
|
logging, |
|
) |
|
from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError |
|
from tqdm.asyncio import tqdm as atqdm |
|
from tqdm.auto import tqdm |
|
import random |
|
|
|
# In-process memory backend for the cashews ``@cache`` decorators below.
cache.setup("mem://")

load_dotenv()


def _require_env(name):
    """Return the non-empty value of environment variable ``name``.

    Raises:
        RuntimeError: if the variable is unset or empty. Unlike ``assert``,
            this check is not stripped when Python runs with ``-O``, so an
            empty token can never be sent as ``Bearer ``.
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} must be set to a non-empty value"
        )
    return value


token = _require_env("HUGGINGFACE_TOKEN")
user_agent = _require_env("USER_AGENT")

# Sent with every request made by clients built in create_client().
headers = {"user-agent": user_agent, "authorization": f"Bearer {token}"}

# Connection-pool limits; these apply per AsyncClient instance.
limits = Limits(max_keepalive_connections=10, max_connections=50)
|
|
|
|
|
def create_client():
    """Build a new ``AsyncClient`` using the module's auth headers, pool limits and HTTP/2.

    Each call creates a separate client with its own connection pool. The
    caller owns it and should close it, e.g. via ``async with``.
    """
    client_kwargs = dict(headers=headers, limits=limits, http2=True)
    return AsyncClient(**client_kwargs)
|
|
|
|
|
@cached(cache=TTLCache(maxsize=100, ttl=60 * 10))
def get_models(user_or_org):
    """List the ``transformers`` models authored by ``user_or_org``.

    Models are requested sorted by downloads (``direction=-1``), with card
    data and full metadata. They are materialised into a list while a tqdm
    progress bar is shown. Results are memoised per ``user_or_org`` for
    10 minutes.
    """
    transformers_by_author = ModelFilter(library="transformers", author=user_or_org)
    listing = list_models(
        filter=transformers_by_author,
        sort="downloads",
        direction=-1,
        cardData=True,
        full=True,
    )
    # Wrapped in ``iter`` exactly as before so tqdm's progress display is unchanged.
    progress = tqdm(iter(listing))
    return list(progress)
|
|
|
|
|
def filter_models(models):
    """Keep models that have card data but no ``base_model`` entry in it.

    Models with empty/missing card data are dropped. So is any model whose
    ``cardData`` access or ``.get`` raises AttributeError.
    """

    def _lacks_base_model(model):
        # An AttributeError anywhere in the check means "skip this model".
        try:
            card_data = model.cardData
            return bool(card_data) and not card_data.get("base_model", None)
        except AttributeError:
            return False

    return [model for model in tqdm(models) if _lacks_base_model(model)]
|
|
|
|
|
# Captures the linked model id from README sentences of the form
# "This model is a fine-tuned version of [<id>](<url>)".
# Raw strings: the previous plain literals relied on invalid escape
# sequences (\[, \(, \s), which warn on modern Python.
MODEL_ID_RE_PATTERN = re.compile(
    r"This model is a fine-tuned version of \[(.*?)\]\(.*?\)"
)

# Detects an existing ``base_model:`` entry (followed by whitespace and a
# value) anywhere in the README text.
BASE_MODEL_PATTERN = re.compile(r"base_model:\s+(.+)")
|
|
|
|
|
def has_model_card(model):
    """Return True if ``model`` lists a ``README.md`` among its siblings.

    ``model.siblings`` may be ``None`` or empty, in which case the result is
    False. This is a cheap linear scan, so it is intentionally not memoised.
    A cache keyed on the model object only pinned those objects in memory,
    and it fails with TypeError for unhashable model objects.
    """
    siblings = model.siblings or ()
    return any(sibling.rfilename == "README.md" for sibling in siblings)
|
|
|
|
|
@cached(cache=TTLCache(maxsize=100, ttl=60))
def check_already_has_base_model(text):
    """Report whether ``text`` already contains a ``base_model:`` entry.

    Results are memoised per text for 60 seconds.
    """
    found = BASE_MODEL_PATTERN.search(text)
    return found is not None
|
|
|
|
|
@cached(cache=TTLCache(maxsize=100, ttl=60))
def extract_model_name(text):
    """Return the model id from the "fine-tuned version of [...]" sentence, or None.

    Results are memoised per text for 60 seconds.
    """
    found = MODEL_ID_RE_PATTERN.search(text)
    if found is None:
        return None
    return found.group(1)
|
|
|
|
|
|
|
|
|
|
|
@cache(ttl=120, condition=NOT_NONE)
async def check_readme_for_match(model):
    """Return the base-model id referenced in ``model``'s README, or None.

    Returns None when:
    - the model has no README sibling;
    - the README could not be fetched (network error, non-200 status);
    - the README already declares ``base_model``;
    - no "fine-tuned version of [...]" sentence is present.

    Non-None results are cached by cashews for 120 seconds.
    """
    if not has_model_card(model):
        return None
    model_card_url = hf_hub_url(model.modelId, "README.md")
    try:
        # ``async with`` closes the client's connection pool after the request;
        # otherwise one AsyncClient per model is left open.
        async with create_client() as client:
            resp = await client.get(model_card_url)
    except (httpx.ConnectError, httpx.ReadTimeout, httpx.ConnectTimeout):
        return None
    except Exception as e:
        print(e)
        return None
    if resp.status_code != 200:
        return None
    if check_already_has_base_model(resp.text):
        return None
    return extract_model_name(resp.text)
|
|
|
|
|
@cache(ttl=120, condition=NOT_NONE)
async def check_model_exists(model, match):
    """Check via the Hub API whether the candidate base model ``match`` exists.

    Returns:
        ``{"modelid": model.modelId, "match": match}`` on HTTP 200.
        ``False`` on HTTP 401. This is not None, so it is cached like a
        definitive negative answer.
        ``None`` for any other status or on request errors (not cached).
    """
    url = f"https://huggingface.co/api/models/{match}"
    try:
        # Close the per-call client so its connections are not leaked.
        async with create_client() as client:
            resp = await client.get(url)
    except (httpx.ConnectError, httpx.ReadTimeout, httpx.ConnectTimeout):
        return None
    except Exception as e:
        print(e)
        return None
    if resp.status_code == 200:
        return {"modelid": model.modelId, "match": match}
    if resp.status_code == 401:
        return False
    return None
|
|
|
|
|
@cache(ttl=120, condition=NOT_NONE)
async def check_model(model):
    """Look up a candidate base model in ``model``'s README and verify it on the Hub.

    Returns check_model_exists's result when the README names a candidate,
    otherwise None.
    """
    candidate = await check_readme_for_match(model)
    if not candidate:
        return None
    return await check_model_exists(model, candidate)
|
|
|
|
|
async def prep_tasks(models, max_concurrency: int = 50):
    """Run check_model over ``models`` concurrently and collect the results.

    Args:
        models: Iterable of model objects to check.
        max_concurrency: Maximum number of check_model calls in flight at once.
            Each check opens its own AsyncClient, so the per-client
            ``Limits`` never bounded the total. The default mirrors the
            module's ``max_connections=50``.

    Returns:
        The check results in completion order (not input order).

    Raises:
        ValueError: if ``max_concurrency`` is less than 1.
    """
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")
    semaphore = asyncio.Semaphore(max_concurrency)

    async def _bounded_check(model):
        # Hold one slot for the whole check (up to two HTTP requests).
        async with semaphore:
            return await check_model(model)

    tasks = [asyncio.create_task(_bounded_check(model)) for model in models]
    return [await f for f in atqdm.as_completed(tasks)]
|
|
|
|
|
def get_data_for_user(user_or_org):
    """Return ``{"modelid", "match"}`` payloads for models needing a base_model PR.

    Fetches ``user_or_org``'s transformers models and keeps those lacking
    ``base_model``. It then runs the README/Hub checks concurrently.
    """
    models = filter_models(get_models(user_or_org))
    results = asyncio.run(prep_tasks(models))
    # Drop both None and False. check_model_exists returns False on HTTP 401,
    # and those entries must not reach update_metadata, which indexes the
    # payload as a dict.
    return [result for result in results if result]
|
|
|
|
|
# ``logging`` here is ``huggingface_hub.logging``; with no name argument this
# returns the library's logger, which the rest of the module logs through.
# ``token`` is read and validated once at import from HUGGINGFACE_TOKEN and is
# deliberately not re-read here.
logger = logging.get_logger()
|
|
|
|
|
def generate_issue_text(based_model_regex_match, opened_by=None):
    """Build the Markdown description for a base_model metadata PR.

    Args:
        based_model_regex_match: Hub model id of the base model. It is used
            both as link text and as the huggingface.co URL path.
        opened_by: Username that requested the PR. When falsy (the default
            ``None``), the "by request of" attribution is omitted instead of
            rendering the literal text "None".

    Returns:
        The PR description as a Markdown string.
    """
    attribution = f" by request of {opened_by}" if opened_by else ""
    return f"""This pull request aims to enrich the metadata of your model by adding [`{based_model_regex_match}`](https://huggingface.co/{based_model_regex_match}) as a `base_model` field, situated in the `YAML` block of your model's `README.md`.

How did we find this information? We performed a regular expression match on your `README.md` file to determine the connection.

**Why add this?** Enhancing your model's metadata in this way:
- **Boosts Discoverability** - It becomes straightforward to trace the relationships between various models on the Hugging Face Hub.
- **Highlights Impact** - It showcases the contributions and influences different models have within the community.

For a hands-on example of how such metadata can play a pivotal role in mapping model connections, take a look at [librarian-bots/base_model_explorer](https://huggingface.co/spaces/librarian-bots/base_model_explorer).

This PR comes courtesy of [Librarian Bot](https://huggingface.co/librarian-bot){attribution}"""
|
|
|
|
|
# Title/commit message of the PRs this bot opens. The duplicate-PR check
# compares against the same constant, so the two can never drift apart.
_BASE_MODEL_PR_TITLE = "Librarian Bot: Add base_model information to model"
_LIBRARIAN_BOT_AUTHOR = "librarian-bot"


def _has_previous_base_model_pr(repo_id):
    """Return True if librarian-bot already opened a base_model PR on ``repo_id``.

    Assumes push_to_hub(create_pr=True) uses the commit message as the PR
    title (NOTE(review): confirm against huggingface_hub docs). May raise
    HfHubHTTPError from get_repo_discussions.
    """
    discussions = list(get_repo_discussions(repo_id))
    if not discussions:
        return False
    logger.info("found previous discussions")
    prs = [discussion for discussion in discussions if discussion.is_pull_request]
    if not prs:
        return False
    logger.info("found previous pull requests")
    for pr in prs:
        if pr.author != _LIBRARIAN_BOT_AUTHOR:
            continue
        logger.info("previously opened PR")
        if pr.title == _BASE_MODEL_PR_TITLE:
            logger.info("previously opened PR to add base_model tag")
            return True
    return False


def update_metadata(metadata_payload: Dict[str, str], user_making_request=None):
    """Add ``base_model`` to a model card and open a Hub PR with the change.

    Args:
        metadata_payload: Dict with ``"modelid"`` (repo id) and ``"match"``
            (base model id), as produced by check_model_exists. It is mutated
            in place: ``"opened_pr"`` is set.
        user_making_request: Attribution passed to generate_issue_text.

    Returns:
        The same dict. ``opened_pr`` is True if a PR was opened now or a
        matching librarian-bot PR already exists. It is False if the repo was
        not found or a Hub HTTP error occurred while checking/pushing.

    Raises:
        Exceptions from ModelCard.load other than RepositoryNotFoundError
        propagate to the caller.
    """
    metadata_payload["opened_pr"] = False
    regex_match = metadata_payload["match"]
    repo_id = metadata_payload["modelid"]
    try:
        model_card = ModelCard.load(repo_id)
    except RepositoryNotFoundError:
        return metadata_payload
    model_card.data["base_model"] = regex_match
    template = generate_issue_text(regex_match, opened_by=user_making_request)
    try:
        if _has_previous_base_model_pr(repo_id):
            # Don't open a duplicate; report the existing PR as opened.
            metadata_payload["opened_pr"] = True
            return metadata_payload
        model_card.push_to_hub(
            repo_id,
            token=token,
            repo_type="model",
            create_pr=True,
            commit_message=_BASE_MODEL_PR_TITLE,
            commit_description=template,
        )
    except HfHubHTTPError:
        return metadata_payload
    metadata_payload["opened_pr"] = True
    return metadata_payload
|
|
|
|
|
def open_prs(profile: gr.OAuthProfile | None, user_or_org: str | None = None):
    """Open base_model metadata PRs for a user's or organisation's models.

    Args:
        profile: OAuth profile of the logged-in user, or None when nobody is
            logged in (presumably injected by Gradio from the type hint).
        user_or_org: Hub namespace to open PRs for. Empty or whitespace-only
            input falls back to the logged-in user's own namespace.

    Returns:
        A short status message shown in the UI.
    """
    if not profile:
        return "Please login to open PR requests"
    username = profile.preferred_username
    # Ignore stray whitespace typed into the textbox.
    user_or_org = user_or_org.strip() if user_or_org else user_or_org
    user_to_receive_prs = user_or_org or username
    data = get_data_for_user(user_to_receive_prs)
    if user_or_org:
        # When acting on someone else's namespace, limit PRs to a random
        # sample of at most 10 models.
        data = random.sample(data, min(10, len(data)))
    if not data:
        return "No PRs to open"
    results = []
    for metadata_payload in data:
        try:
            results.append(
                update_metadata(metadata_payload, user_making_request=username)
            )
        except Exception as e:
            logger.error(e)
    opened = sum(1 for result in results if result["opened_pr"])
    return f"Opened {opened} PRs"
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Librarian Bot")
    # Hugging Face OAuth controls; open_prs takes the resulting profile.
    gr.LoginButton()
    gr.LogoutButton()
    user = gr.Textbox(label="user or org to Open PRs for")
    button = gr.Button()
    results = gr.Markdown()
    button.click(fn=open_prs, inputs=[user], outputs=results)

# A single queue worker, presumably so only one PR-opening run happens at a time.
demo.queue(concurrency_count=1).launch()
|
|