|
import tempfile |
|
import csv |
|
import pandas as pd |
|
import gradio as gr |
|
from huggingface_hub import HfApi |
|
from pathlib import Path |
|
|
|
def get_model_stats(search_term, *, api=None):
    """Fetch download statistics for Hub models matching *search_term*.

    Lists every model returned by the Hub search, writes one row per model
    to a CSV file in a fresh temporary directory, and summarizes the totals.

    Args:
        search_term: Free-text model search (e.g. ``"gemma"`` or
            ``"google/gemma"``). Leading/trailing whitespace is ignored.
        api: Optional client exposing ``list_models`` (for example an
            ``HfApi`` configured with a token or a custom endpoint). A
            default ``HfApi()`` is created when omitted.

    Returns:
        A ``(dataframe, status_message, csv_path)`` tuple. ``dataframe`` has
        the columns "Model ID", "Downloads (30 days)" and
        "Downloads (All Time)". ``csv_path`` is the written CSV as a string,
        or ``None`` when no search was run because the term was empty.
    """
    columns = ["Model ID", "Downloads (30 days)", "Downloads (All Time)"]

    search_term = (search_term or "").strip()
    if not search_term:
        # An empty search would enumerate every model on the Hub.
        return pd.DataFrame(columns=columns), "Please enter a search term.", None

    if api is None:
        api = HfApi()

    # Search terms often contain "/" (e.g. "org/model"); used verbatim in a
    # filename that would point into a non-existent sub-directory. The temp
    # directory is intentionally not removed: the UI serves the file later.
    safe_name = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in search_term
    )
    output_file = Path(tempfile.mkdtemp()) / f"{safe_name}_models_alltime.csv"

    print(f"Fetching {search_term} models with download statistics...")
    models = api.list_models(
        search=search_term,
        expand=["downloads", "downloadsAllTime"],
        sort="_id",
    )

    rows = []
    for model in models:
        # huggingface_hub's ModelInfo types these fields as Optional[int], so
        # the attribute can exist but be None; getattr's default alone would
        # not prevent a TypeError when summing.
        rows.append([
            getattr(model, "id", "Unknown"),
            getattr(model, "downloads", None) or 0,
            getattr(model, "downloads_all_time", None) or 0,
        ])

    with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(columns)
        writer.writerows(rows)

    # Build the table from the in-memory rows rather than re-parsing the CSV
    # (re-parsing would also turn IDs such as "NA" into NaN).
    df = pd.DataFrame(rows, columns=columns)

    total_30day_downloads = sum(row[1] for row in rows)
    total_alltime_downloads = sum(row[2] for row in rows)

    status_message = (
        f"Found {len(rows)} models for search term '{search_term}'\n"
        f"Total 30-day downloads: {total_30day_downloads:,}\n"
        f"Total all-time downloads: {total_alltime_downloads:,}"
    )

    return df, status_message, str(output_file)
|
|
|
|
|
# Build the Gradio UI: a search box, a results table with a status box, and
# a button that exposes the generated CSV for download.
with gr.Blocks(title="Hugging Face Model Statistics") as demo:
    gr.Markdown("# Hugging Face Model Statistics")
    gr.Markdown("Enter a search term to find model statistics from Hugging Face Hub")

    with gr.Row():
        search_input = gr.Textbox(
            label="Search Term",
            placeholder="Enter a model name or keyword (e.g., 'gemma', 'llama')",
            value="gemma",
        )
        search_button = gr.Button("Search")

    with gr.Row():
        output_table = gr.Dataframe(
            headers=["Model ID", "Downloads (30 days)", "Downloads (All Time)"],
            datatype=["str", "number", "number"],
            label="Model Statistics",
        )
        status_message = gr.Textbox(label="Status", lines=3)

    with gr.Row():
        download_button = gr.Button("Download CSV")
        # Starts hidden; _show_csv reveals it once a CSV path is available.
        csv_file = gr.File(label="CSV File", visible=False)

    # Holds the CSV path returned by the most recent search.
    csv_path = gr.State()

    search_button.click(
        fn=get_model_stats,
        inputs=search_input,
        outputs=[output_table, status_message, csv_path],
    )

    def _show_csv(path):
        """Load *path* into the file component and make it visible.

        The component is created with ``visible=False``, so returning only
        the path would set its value without ever showing it to the user.
        Stays hidden when no search has produced a file yet.
        """
        return gr.update(value=path, visible=path is not None)

    download_button.click(
        fn=_show_csv,
        inputs=csv_path,
        outputs=csv_file,
    )

if __name__ == "__main__":
    demo.launch()