add counter to total
Browse files- hf-modelf-family-stats-gradio.py +107 -0
hf-modelf-family-stats-gradio.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
import csv
|
3 |
+
import pandas as pd
|
4 |
+
import gradio as gr
|
5 |
+
from huggingface_hub import HfApi
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
def get_model_stats(search_term):
|
9 |
+
# Initialize the API
|
10 |
+
api = HfApi()
|
11 |
+
|
12 |
+
# Create a temporary file for the CSV
|
13 |
+
temp_dir = tempfile.mkdtemp()
|
14 |
+
output_file = Path(temp_dir) / f"{search_term}_models_alltime.csv"
|
15 |
+
|
16 |
+
# Get the generator of models with the working sort parameter
|
17 |
+
print(f"Fetching {search_term} models with download statistics...")
|
18 |
+
models_generator = api.list_models(
|
19 |
+
search=search_term,
|
20 |
+
expand=["downloads", "downloadsAllTime"], # Get both 30-day and all-time downloads
|
21 |
+
sort="_id" # Sort by ID to avoid timeout issues
|
22 |
+
)
|
23 |
+
|
24 |
+
# Initialize counters for total downloads
|
25 |
+
total_30day_downloads = 0
|
26 |
+
total_alltime_downloads = 0
|
27 |
+
|
28 |
+
# Create and write to CSV
|
29 |
+
with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
|
30 |
+
csv_writer = csv.writer(csvfile)
|
31 |
+
# Write header
|
32 |
+
csv_writer.writerow(["Model ID", "Downloads (30 days)", "Downloads (All Time)"])
|
33 |
+
|
34 |
+
# Process models
|
35 |
+
model_count = 0
|
36 |
+
for model in models_generator:
|
37 |
+
# Get download counts
|
38 |
+
downloads_30day = getattr(model, 'downloads', 0)
|
39 |
+
downloads_alltime = getattr(model, 'downloads_all_time', 0)
|
40 |
+
|
41 |
+
# Add to totals
|
42 |
+
total_30day_downloads += downloads_30day
|
43 |
+
total_alltime_downloads += downloads_alltime
|
44 |
+
|
45 |
+
# Write to CSV
|
46 |
+
csv_writer.writerow([
|
47 |
+
getattr(model, 'id', "Unknown"),
|
48 |
+
downloads_30day,
|
49 |
+
downloads_alltime
|
50 |
+
])
|
51 |
+
model_count += 1
|
52 |
+
|
53 |
+
# Read the CSV file into a pandas DataFrame
|
54 |
+
df = pd.read_csv(output_file)
|
55 |
+
|
56 |
+
# Create status message with total downloads
|
57 |
+
status_message = (
|
58 |
+
f"Found {model_count} models for search term '{search_term}'\n"
|
59 |
+
f"Total 30-day downloads: {total_30day_downloads:,}\n"
|
60 |
+
f"Total all-time downloads: {total_alltime_downloads:,}"
|
61 |
+
)
|
62 |
+
|
63 |
+
# Return both the DataFrame, status message, and the CSV file path
|
64 |
+
return df, status_message, str(output_file)
|
65 |
+
|
66 |
+
# Create the Gradio interface
|
67 |
+
with gr.Blocks(title="Hugging Face Model Statistics") as demo:
|
68 |
+
gr.Markdown("# Hugging Face Model Statistics")
|
69 |
+
gr.Markdown("Enter a search term to find model statistics from Hugging Face Hub")
|
70 |
+
|
71 |
+
with gr.Row():
|
72 |
+
search_input = gr.Textbox(
|
73 |
+
label="Search Term",
|
74 |
+
placeholder="Enter a model name or keyword (e.g., 'gemma', 'llama')",
|
75 |
+
value="gemma"
|
76 |
+
)
|
77 |
+
search_button = gr.Button("Search")
|
78 |
+
|
79 |
+
with gr.Row():
|
80 |
+
output_table = gr.Dataframe(
|
81 |
+
headers=["Model ID", "Downloads (30 days)", "Downloads (All Time)"],
|
82 |
+
datatype=["str", "number", "number"],
|
83 |
+
label="Model Statistics"
|
84 |
+
)
|
85 |
+
status_message = gr.Textbox(label="Status", lines=3) # Increased lines to show all stats
|
86 |
+
|
87 |
+
with gr.Row():
|
88 |
+
download_button = gr.Button("Download CSV")
|
89 |
+
csv_file = gr.File(label="CSV File", visible=False)
|
90 |
+
|
91 |
+
# Store the CSV file path in a state
|
92 |
+
csv_path = gr.State()
|
93 |
+
|
94 |
+
search_button.click(
|
95 |
+
fn=get_model_stats,
|
96 |
+
inputs=search_input,
|
97 |
+
outputs=[output_table, status_message, csv_path]
|
98 |
+
)
|
99 |
+
|
100 |
+
download_button.click(
|
101 |
+
fn=lambda x: x,
|
102 |
+
inputs=csv_path,
|
103 |
+
outputs=csv_file
|
104 |
+
)
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
demo.launch()
|