Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,8 @@ import spaces
|
|
9 |
from cachetools import TTLCache
|
10 |
from cachetools.func import ttl_cache
|
11 |
import time
|
|
|
|
|
12 |
# Set up logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
@@ -94,25 +96,19 @@ def generate_summary(card_text: str, card_type: str) -> str:
|
|
94 |
"""Cached wrapper for generate_summary with TTL."""
|
95 |
return _generate_summary_gpu(card_text, card_type)
|
96 |
|
97 |
-
def summarize(hub_id: str = ""
|
98 |
"""Interface function for Gradio. Returns JSON format."""
|
99 |
try:
|
100 |
if hub_id:
|
101 |
-
# Fetch and infer card type
|
102 |
-
|
103 |
-
#
|
104 |
-
|
105 |
-
|
106 |
-
return f'{{"error": "{error_msg}"}}'
|
107 |
-
card_type = inferred_type
|
108 |
else:
|
109 |
error_msg = "Error: Hub ID must be provided"
|
110 |
return f'{{"error": "{error_msg}"}}'
|
111 |
|
112 |
-
# Use the cached wrapper
|
113 |
-
summary = generate_summary(card_text, card_type)
|
114 |
-
return f'{{"summary": "{summary}", "type": "{card_type}", "hub_id": "{hub_id}"}}'
|
115 |
-
|
116 |
except Exception as e:
|
117 |
error_msg = str(e)
|
118 |
return f'{{"error": "{error_msg}"}}'
|
@@ -120,13 +116,10 @@ def summarize(hub_id: str = "", card_type: str = None) -> str:
|
|
120 |
def create_interface():
|
121 |
interface = gr.Interface(
|
122 |
fn=summarize,
|
123 |
-
inputs=
|
124 |
-
gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
|
125 |
-
gr.Radio(choices=["model", "dataset", None], label="Card Type (optional)", value=None),
|
126 |
-
],
|
127 |
outputs=gr.JSON(label="Output"),
|
128 |
title="Hugging Face Hub TLDR Generator",
|
129 |
-
description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.
|
130 |
)
|
131 |
return interface
|
132 |
|
|
|
9 |
from cachetools import TTLCache
|
10 |
from cachetools.func import ttl_cache
|
11 |
import time
|
12 |
+
import os
|
13 |
+
os.environ['HF_TRANSFER'] = "1"
|
14 |
# Set up logging
|
15 |
logging.basicConfig(level=logging.INFO)
|
16 |
logger = logging.getLogger(__name__)
|
|
|
96 |
"""Cached wrapper for generate_summary with TTL."""
|
97 |
return _generate_summary_gpu(card_text, card_type)
|
98 |
|
99 |
+
def summarize(hub_id: str = "") -> str:
|
100 |
"""Interface function for Gradio. Returns JSON format."""
|
101 |
try:
|
102 |
if hub_id:
|
103 |
+
# Fetch and infer card type automatically
|
104 |
+
card_type, card_text = get_card_info(hub_id)
|
105 |
+
# Use the cached wrapper
|
106 |
+
summary = generate_summary(card_text, card_type)
|
107 |
+
return f'{{"summary": "{summary}", "type": "{card_type}", "hub_id": "{hub_id}"}}'
|
|
|
|
|
108 |
else:
|
109 |
error_msg = "Error: Hub ID must be provided"
|
110 |
return f'{{"error": "{error_msg}"}}'
|
111 |
|
|
|
|
|
|
|
|
|
112 |
except Exception as e:
|
113 |
error_msg = str(e)
|
114 |
return f'{{"error": "{error_msg}"}}'
|
|
|
116 |
def create_interface():
|
117 |
interface = gr.Interface(
|
118 |
fn=summarize,
|
119 |
+
inputs=gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
|
|
|
|
|
|
|
120 |
outputs=gr.JSON(label="Output"),
|
121 |
title="Hugging Face Hub TLDR Generator",
|
122 |
+
description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
|
123 |
)
|
124 |
return interface
|
125 |
|