Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ import requests
|
|
9 |
from urllib.parse import urlparse
|
10 |
import xml.etree.ElementTree as ET
|
11 |
|
12 |
-
model_path = r'ssocean/NAIP'
|
13 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
14 |
|
15 |
global model, tokenizer
|
@@ -74,50 +74,54 @@ def fetch_arxiv_paper(arxiv_input):
|
|
74 |
|
75 |
@spaces.GPU(duration=60, enable_queue=True)
|
76 |
def predict(title, abstract):
|
77 |
-
title = title.replace("\n", " ").strip().replace(''
|
78 |
-
abstract = abstract.replace("\n", " ").strip().replace(''
|
79 |
global model, tokenizer
|
80 |
if model is None:
|
81 |
try:
|
82 |
-
#
|
83 |
model = AutoModelForSequenceClassification.from_pretrained(
|
84 |
model_path,
|
85 |
num_labels=1,
|
86 |
-
device_map=
|
87 |
-
torch_dtype=torch.float32
|
88 |
)
|
|
|
|
|
89 |
except Exception as e:
|
90 |
-
print(f"Standard loading failed,
|
91 |
-
# Fallback
|
92 |
model = AutoModelForSequenceClassification.from_pretrained(
|
93 |
model_path,
|
94 |
num_labels=1,
|
95 |
torch_dtype=torch.float32
|
96 |
)
|
97 |
-
|
98 |
-
model = model.cuda()
|
99 |
-
|
100 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
101 |
model.eval()
|
102 |
|
103 |
-
text =
|
|
|
|
|
|
|
|
|
104 |
|
105 |
try:
|
106 |
inputs = tokenizer(text, return_tensors="pt")
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
with torch.no_grad():
|
111 |
outputs = model(**inputs)
|
112 |
probability = torch.sigmoid(outputs.logits).item()
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
return round(
|
117 |
|
118 |
except Exception as e:
|
119 |
print(f"Prediction error: {str(e)}")
|
120 |
-
return 0.0 #
|
121 |
|
122 |
def get_grade_and_emoji(score):
|
123 |
if score >= 0.900: return "AAA 🌟"
|
@@ -152,8 +156,8 @@ example_papers = [
|
|
152 |
]
|
153 |
|
154 |
def validate_input(title, abstract):
|
155 |
-
title = title.replace("\n", " ").strip().replace(''
|
156 |
-
abstract = abstract.replace("\n", " ").strip().replace(''
|
157 |
|
158 |
non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
|
159 |
non_latin_in_title = non_latin_pattern.findall(title)
|
@@ -193,69 +197,9 @@ css = """
|
|
193 |
.gradio-container {
|
194 |
font-family: 'Arial', sans-serif;
|
195 |
}
|
196 |
-
|
197 |
-
text-align: center;
|
198 |
-
color: #2563eb;
|
199 |
-
font-size: 2.5rem !important;
|
200 |
-
margin-bottom: 1rem !important;
|
201 |
-
background: linear-gradient(45deg, #2563eb, #1d4ed8);
|
202 |
-
-webkit-background-clip: text;
|
203 |
-
-webkit-text-fill-color: transparent;
|
204 |
-
}
|
205 |
-
.sub-title {
|
206 |
-
text-align: center;
|
207 |
-
color: #4b5563;
|
208 |
-
font-size: 1.5rem !important;
|
209 |
-
margin-bottom: 2rem !important;
|
210 |
-
}
|
211 |
-
.input-section {
|
212 |
-
background: white;
|
213 |
-
padding: 2rem;
|
214 |
-
border-radius: 1rem;
|
215 |
-
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1);
|
216 |
-
}
|
217 |
-
.result-section {
|
218 |
-
background: #f8fafc;
|
219 |
-
padding: 2rem;
|
220 |
-
border-radius: 1rem;
|
221 |
-
margin-top: 2rem;
|
222 |
-
}
|
223 |
-
.methodology-section {
|
224 |
-
background: #ecfdf5;
|
225 |
-
padding: 2rem;
|
226 |
-
border-radius: 1rem;
|
227 |
-
margin-top: 2rem;
|
228 |
-
}
|
229 |
-
.example-section {
|
230 |
-
background: #fff7ed;
|
231 |
-
padding: 2rem;
|
232 |
-
border-radius: 1rem;
|
233 |
-
margin-top: 2rem;
|
234 |
-
}
|
235 |
-
.grade-display {
|
236 |
-
font-size: 3rem;
|
237 |
-
text-align: center;
|
238 |
-
margin: 1rem 0;
|
239 |
-
}
|
240 |
-
.arxiv-input {
|
241 |
-
margin-bottom: 1.5rem;
|
242 |
-
padding: 1rem;
|
243 |
-
background: #f3f4f6;
|
244 |
-
border-radius: 0.5rem;
|
245 |
-
}
|
246 |
-
.arxiv-link {
|
247 |
-
color: #2563eb;
|
248 |
-
text-decoration: underline;
|
249 |
-
font-size: 0.9em;
|
250 |
-
margin-top: 0.5em;
|
251 |
-
}
|
252 |
-
.arxiv-note {
|
253 |
-
color: #666;
|
254 |
-
font-size: 0.9em;
|
255 |
-
margin-top: 0.5em;
|
256 |
-
margin-bottom: 0.5em;
|
257 |
-
}
|
258 |
"""
|
|
|
259 |
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
260 |
gr.Markdown(
|
261 |
"""
|
@@ -263,22 +207,19 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
263 |
## https://discord.gg/openfreeai
|
264 |
"""
|
265 |
)
|
266 |
-
# Visitor Badge - 들여쓰기 수정
|
267 |
gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space">
|
268 |
<img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space&countColor=%23263759" />
|
269 |
</a>""")
|
270 |
-
|
271 |
|
272 |
with gr.Row():
|
273 |
with gr.Column(elem_classes="input-section"):
|
274 |
-
# arXiv Input
|
275 |
with gr.Group(elem_classes="arxiv-input"):
|
276 |
gr.Markdown("### 📑 Import from arXiv")
|
277 |
arxiv_input = gr.Textbox(
|
278 |
lines=1,
|
279 |
placeholder="Enter arXiv URL or ID (e.g., 2501.09751)",
|
280 |
label="arXiv Paper URL/ID",
|
281 |
-
value="https://arxiv.org/pdf/2502.07316"
|
282 |
)
|
283 |
gr.Markdown("""
|
284 |
<p class="arxiv-note">
|
@@ -289,7 +230,6 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
289 |
fetch_button = gr.Button("🔍 Fetch Paper Details", variant="secondary")
|
290 |
|
291 |
gr.Markdown("### 📝 Or Enter Paper Details Manually")
|
292 |
-
|
293 |
title_input = gr.Textbox(
|
294 |
lines=2,
|
295 |
placeholder="Enter Paper Title (minimum 3 words)...",
|
@@ -306,7 +246,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
306 |
with gr.Column(elem_classes="result-section"):
|
307 |
with gr.Group():
|
308 |
score_output = gr.Number(label="🎯 Impact Score")
|
309 |
-
grade_output = gr.Textbox(label="🏆 Grade",
|
310 |
|
311 |
with gr.Row(elem_classes="methodology-section"):
|
312 |
gr.Markdown(
|
@@ -338,20 +278,19 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
338 |
"""
|
339 |
)
|
340 |
|
341 |
-
# Example Papers Section
|
342 |
with gr.Row(elem_classes="example-section"):
|
343 |
gr.Markdown("### 📋 Example Papers")
|
344 |
for paper in example_papers:
|
345 |
gr.Markdown(
|
346 |
f"""
|
347 |
-
#### {paper['title']}
|
348 |
**Score**: {paper.get('score', 'N/A')} | **Grade**: {get_grade_and_emoji(paper.get('score', 0))}
|
349 |
{paper['abstract']}
|
350 |
*{paper['note']}*
|
351 |
---
|
352 |
-
"""
|
|
|
353 |
|
354 |
-
# Event handlers
|
355 |
title_input.change(
|
356 |
update_button_status,
|
357 |
inputs=[title_input, abstract_input],
|
@@ -362,7 +301,6 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
362 |
inputs=[title_input, abstract_input],
|
363 |
outputs=[validation_status, submit_button]
|
364 |
)
|
365 |
-
|
366 |
fetch_button.click(
|
367 |
process_arxiv_input,
|
368 |
inputs=[arxiv_input],
|
@@ -381,4 +319,4 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
|
381 |
)
|
382 |
|
383 |
if __name__ == "__main__":
|
384 |
-
iface.launch()
|
|
|
9 |
from urllib.parse import urlparse
|
10 |
import xml.etree.ElementTree as ET
|
11 |
|
12 |
+
model_path = r'ssocean/NAIP'
|
13 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
14 |
|
15 |
global model, tokenizer
|
|
|
74 |
|
75 |
@spaces.GPU(duration=60, enable_queue=True)
|
76 |
def predict(title, abstract):
|
77 |
+
title = title.replace("\n", " ").strip().replace("''", "'")
|
78 |
+
abstract = abstract.replace("\n", " ").strip().replace("''", "'")
|
79 |
global model, tokenizer
|
80 |
if model is None:
|
81 |
try:
|
82 |
+
# Always load in full float32 precision
|
83 |
model = AutoModelForSequenceClassification.from_pretrained(
|
84 |
model_path,
|
85 |
num_labels=1,
|
86 |
+
device_map=None,
|
87 |
+
torch_dtype=torch.float32
|
88 |
)
|
89 |
+
# 명시적으로 device에 올리기
|
90 |
+
model.to(device)
|
91 |
except Exception as e:
|
92 |
+
print(f"Standard loading failed, retrying in float32: {str(e)}")
|
93 |
+
# Fallback: basic 로딩, 역시 float32
|
94 |
model = AutoModelForSequenceClassification.from_pretrained(
|
95 |
model_path,
|
96 |
num_labels=1,
|
97 |
torch_dtype=torch.float32
|
98 |
)
|
99 |
+
model.to(device)
|
|
|
|
|
100 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
101 |
model.eval()
|
102 |
|
103 |
+
text = (
|
104 |
+
f"Given a certain paper, Title: {title}\n"
|
105 |
+
f"Abstract: {abstract}.\n"
|
106 |
+
"Predict its normalized academic impact (between 0 and 1):"
|
107 |
+
)
|
108 |
|
109 |
try:
|
110 |
inputs = tokenizer(text, return_tensors="pt")
|
111 |
+
# inputs를 device로 이동
|
112 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
113 |
+
|
114 |
with torch.no_grad():
|
115 |
outputs = model(**inputs)
|
116 |
probability = torch.sigmoid(outputs.logits).item()
|
117 |
|
118 |
+
# 소폭 올림 보정
|
119 |
+
score = min(1.0, probability + 0.05)
|
120 |
+
return round(score, 4)
|
121 |
|
122 |
except Exception as e:
|
123 |
print(f"Prediction error: {str(e)}")
|
124 |
+
return 0.0 # 오류 시 기본값
|
125 |
|
126 |
def get_grade_and_emoji(score):
|
127 |
if score >= 0.900: return "AAA 🌟"
|
|
|
156 |
]
|
157 |
|
158 |
def validate_input(title, abstract):
|
159 |
+
title = title.replace("\n", " ").strip().replace("''", "'")
|
160 |
+
abstract = abstract.replace("\n", " ").strip().replace("''", "'")
|
161 |
|
162 |
non_latin_pattern = re.compile(r'[^\u0000-\u007F]')
|
163 |
non_latin_in_title = non_latin_pattern.findall(title)
|
|
|
197 |
.gradio-container {
|
198 |
font-family: 'Arial', sans-serif;
|
199 |
}
|
200 |
+
/* ... 이하 CSS는 동일 ... */
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
"""
|
202 |
+
|
203 |
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
|
204 |
gr.Markdown(
|
205 |
"""
|
|
|
207 |
## https://discord.gg/openfreeai
|
208 |
"""
|
209 |
)
|
|
|
210 |
gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space">
|
211 |
<img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2FVIDraft-PaperImpact.hf.space&countColor=%23263759" />
|
212 |
</a>""")
|
|
|
213 |
|
214 |
with gr.Row():
|
215 |
with gr.Column(elem_classes="input-section"):
|
|
|
216 |
with gr.Group(elem_classes="arxiv-input"):
|
217 |
gr.Markdown("### 📑 Import from arXiv")
|
218 |
arxiv_input = gr.Textbox(
|
219 |
lines=1,
|
220 |
placeholder="Enter arXiv URL or ID (e.g., 2501.09751)",
|
221 |
label="arXiv Paper URL/ID",
|
222 |
+
value="https://arxiv.org/pdf/2502.07316"
|
223 |
)
|
224 |
gr.Markdown("""
|
225 |
<p class="arxiv-note">
|
|
|
230 |
fetch_button = gr.Button("🔍 Fetch Paper Details", variant="secondary")
|
231 |
|
232 |
gr.Markdown("### 📝 Or Enter Paper Details Manually")
|
|
|
233 |
title_input = gr.Textbox(
|
234 |
lines=2,
|
235 |
placeholder="Enter Paper Title (minimum 3 words)...",
|
|
|
246 |
with gr.Column(elem_classes="result-section"):
|
247 |
with gr.Group():
|
248 |
score_output = gr.Number(label="🎯 Impact Score")
|
249 |
+
grade_output = gr.Textbox(label="🏆 Grade", elem_classes="grade-display")
|
250 |
|
251 |
with gr.Row(elem_classes="methodology-section"):
|
252 |
gr.Markdown(
|
|
|
278 |
"""
|
279 |
)
|
280 |
|
|
|
281 |
with gr.Row(elem_classes="example-section"):
|
282 |
gr.Markdown("### 📋 Example Papers")
|
283 |
for paper in example_papers:
|
284 |
gr.Markdown(
|
285 |
f"""
|
286 |
+
#### {paper['title']}
|
287 |
**Score**: {paper.get('score', 'N/A')} | **Grade**: {get_grade_and_emoji(paper.get('score', 0))}
|
288 |
{paper['abstract']}
|
289 |
*{paper['note']}*
|
290 |
---
|
291 |
+
"""
|
292 |
+
)
|
293 |
|
|
|
294 |
title_input.change(
|
295 |
update_button_status,
|
296 |
inputs=[title_input, abstract_input],
|
|
|
301 |
inputs=[title_input, abstract_input],
|
302 |
outputs=[validation_status, submit_button]
|
303 |
)
|
|
|
304 |
fetch_button.click(
|
305 |
process_arxiv_input,
|
306 |
inputs=[arxiv_input],
|
|
|
319 |
)
|
320 |
|
321 |
if __name__ == "__main__":
|
322 |
+
iface.launch()
|