Update app.py
app.py
CHANGED
@@ -10,6 +10,7 @@ from cachetools import TTLCache
 from cachetools.func import ttl_cache
 import time
 import os
+import json
 os.environ['HF_TRANSFER'] = "1"
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -37,22 +38,40 @@ def load_model():
         logger.error(f"Failed to load model: {e}")
         return False
 
-@functools.lru_cache(maxsize=100)
 def get_card_info(hub_id: str) -> Tuple[str, str]:
     """Get card information from a Hugging Face hub_id."""
+    model_exists = False
+    dataset_exists = False
+    model_text = None
+    dataset_text = None
+
+    # Try getting model card
     try:
         info = model_info(hub_id)
         card = ModelCard.load(hub_id)
-
+        model_exists = True
+        model_text = card.text
+    except Exception as e:
+        logger.debug(f"No model card found for {hub_id}: {e}")
+
+    # Try getting dataset card
+    try:
+        info = dataset_info(hub_id)
+        card = DatasetCard.load(hub_id)
+        dataset_exists = True
+        dataset_text = card.text
     except Exception as e:
-        logger.
-
-
-
-
-
-
-
+        logger.debug(f"No dataset card found for {hub_id}: {e}")
+
+    # Handle different cases
+    if model_exists and dataset_exists:
+        return "both", (model_text, dataset_text)
+    elif model_exists:
+        return "model", model_text
+    elif dataset_exists:
+        return "dataset", dataset_text
+    else:
+        raise ValueError(f"Could not find model or dataset with id {hub_id}")
 
 @spaces.GPU
 def _generate_summary_gpu(card_text: str, card_type: str) -> str:
@@ -71,7 +90,7 @@ def _generate_summary_gpu(card_text: str, card_type: str) -> str:
     with torch.no_grad():
         outputs = model.generate(
             inputs,
-            max_new_tokens=
+            max_new_tokens=0,
             pad_token_id=tokenizer.pad_token_id,
             eos_token_id=tokenizer.eos_token_id,
             temperature=0.4,
@@ -102,16 +121,29 @@ def summarize(hub_id: str = "") -> str:
         if hub_id:
             # Fetch and infer card type automatically
             card_type, card_text = get_card_info(hub_id)
-
-
-
+
+            if card_type == "both":
+                model_text, dataset_text = card_text
+                model_summary = generate_summary(model_text, "model")
+                dataset_summary = generate_summary(dataset_text, "dataset")
+                return json.dumps({
+                    "type": "both",
+                    "hub_id": hub_id,
+                    "model_summary": model_summary,
+                    "dataset_summary": dataset_summary
+                })
+            else:
+                summary = generate_summary(card_text, card_type)
+                return json.dumps({
+                    "summary": summary,
+                    "type": card_type,
+                    "hub_id": hub_id
+                })
         else:
-
-            return f'{{"error": "{error_msg}"}}'
+            return json.dumps({"error": "Hub ID must be provided"})
 
     except Exception as e:
-
-        return f'{{"error": "{error_msg}"}}'
+        return json.dumps({"error": str(e)})
 
 def create_interface():
     interface = gr.Interface(
@@ -126,6 +158,6 @@ def create_interface():
 if __name__ == "__main__":
     if load_model():
         interface = create_interface()
-        interface.launch(
+        interface.launch()
     else:
         print("Failed to load model. Please check the logs for details.")
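One loose end the diff leaves stale: get_card_info still declares -> Tuple[str, str], but the new "both" branch returns a nested tuple of two card texts. If the author wants the annotation to match, a more precise signature might look like this (CardInfo is a hypothetical alias, not part of the commit):

from typing import Tuple, Union

# "model"/"dataset" pair with a single card text; "both" pairs with
# a (model_text, dataset_text) tuple.
CardInfo = Tuple[str, Union[str, Tuple[str, str]]]

def get_card_info(hub_id: str) -> CardInfo:
    ...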