# Spaces: Running on Zero
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import re | |
import requests | |
from urllib.parse import urlparse | |
import xml.etree.ElementTree as ET | |
################################################## | |
# Global setup | |
################################################## | |
# Hugging Face Hub repository id handed to every `from_pretrained` call in `predict()`.
model_path = "ssocean/NAIP"
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Both are filled in lazily by `predict()` the first time it is called.
model = None
tokenizer = None
################################################## | |
# Fetch paper info from arXiv | |
################################################## | |
# Upper bound (seconds) on the arXiv API round trip so a stalled connection
# cannot hang the UI callback forever.
_ARXIV_TIMEOUT_SECONDS = 10
_ARXIV_API_URL = "https://export.arxiv.org/api/query"
_ATOM_NS = {"arxiv": "http://www.w3.org/2005/Atom"}


def _arxiv_failure(message):
    """Build the failure-shaped result dict returned by `fetch_arxiv_paper`."""
    return {"title": "", "abstract": "", "success": False, "message": message}


def _extract_arxiv_id(arxiv_input):
    """Return the bare arXiv identifier contained in a URL or a raw ID string."""
    text = arxiv_input.strip()
    if "arxiv.org" not in text:
        return text

    segments = [seg for seg in urlparse(text).path.split("/") if seg]
    # Keep everything after the /abs/, /pdf/ or /html/ marker so old-style
    # identifiers such as "hep-th/9901001" are not truncated to "9901001".
    for marker in ("abs", "pdf", "html"):
        if marker in segments:
            segments = segments[segments.index(marker) + 1:]
            break
    else:
        # Unknown URL layout: fall back to the last path component.
        segments = segments[-1:]

    arxiv_id = "/".join(segments)
    if arxiv_id.endswith(".pdf"):
        arxiv_id = arxiv_id[: -len(".pdf")]
    return arxiv_id


def fetch_arxiv_paper(arxiv_input):
    """
    Fetch paper title & abstract from an arXiv URL or ID.

    Args:
        arxiv_input: Either a bare arXiv identifier (e.g. "2504.11651") or any
            string containing "arxiv.org" that is treated as a URL.

    Returns:
        A dict with the keys "title", "abstract", "success" and "message".
        On any failure "success" is False, "title"/"abstract" are empty
        strings and "message" describes the problem. This function never
        raises; every exception is converted into a failure dict.
    """
    try:
        arxiv_id = _extract_arxiv_id(arxiv_input)
        if not arxiv_id:
            # Nothing to look up: skip the network round trip entirely.
            return _arxiv_failure("No arXiv ID found in input")

        # `params=` lets requests URL-encode the id, so characters such as
        # "&" in user input cannot inject extra query parameters.
        resp = requests.get(
            _ARXIV_API_URL,
            params={"id_list": arxiv_id},
            timeout=_ARXIV_TIMEOUT_SECONDS,
        )
        if resp.status_code != 200:
            return _arxiv_failure("Error fetching paper from arXiv API")

        # Parse the raw bytes so the XML declaration decides the encoding.
        root = ET.fromstring(resp.content)
        entry = root.find(".//arxiv:entry", _ATOM_NS)
        if entry is None:
            return _arxiv_failure("Paper not found")

        # findtext() yields "" for a missing or empty element instead of None.
        title = entry.findtext("arxiv:title", default="", namespaces=_ATOM_NS).strip()
        abstract = entry.findtext("arxiv:summary", default="", namespaces=_ATOM_NS).strip()
        if not title and not abstract:
            return _arxiv_failure("Paper not found")

        return {
            "title": title,
            "abstract": abstract,
            "success": True,
            "message": "Paper fetched successfully!",
        }
    except Exception as e:
        # Boundary handler: callers rely on always getting a result dict back.
        return _arxiv_failure(f"Error fetching paper: {e}")
################################################## | |
# Prediction function | |
################################################## | |
def predict(title, abstract):
    """
    Predict a normalized academic impact score (0-1) from title & abstract.

    The model and tokenizer are loaded from `model_path` on the first call and
    cached in the module-level `model` / `tokenizer` globals afterwards.

    Returns:
        The sigmoid of the single output logit plus 0.05, capped at 1.0 and
        rounded to 4 decimals. If tokenization or inference raises, the error
        is printed and 0.0 is returned. Errors while loading the model or
        while building the prompt are not caught here.
    """
    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is visible on this function - confirm GPU access works when
    # the Space runs on ZeroGPU hardware.
    global model, tokenizer

    if model is None:
        cfg = AutoConfig.from_pretrained(model_path)
        # Drop quantization_config when present (avoids a NoneType error in PEFT).
        if hasattr(cfg, "quantization_config"):
            del cfg.quantization_config
        # Single output logit.
        cfg.num_labels = 1

        net = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            config=cfg,
            torch_dtype=torch.float32,  # float32 for stable cublasLt
            device_map=None,
            low_cpu_mem_usage=False,
        )
        net.to(device)
        net.eval()

        tok = AutoTokenizer.from_pretrained(model_path)

        # Publish both objects only after each has loaded successfully.
        model, tokenizer = net, tok

    prompt_lines = [
        "Given a certain paper,",
        f"Title: {title.strip()}",
        f"Abstract: {abstract.strip()}",
        "Predict its normalized academic impact (0~1):",
    ]
    text = "\n".join(prompt_lines)

    try:
        encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = model(**on_device).logits
        probability = torch.sigmoid(logits).item()
        # NOTE(review): the fixed +0.05 offset is kept as in the original; its
        # rationale is not visible in this file.
        shifted = probability + 0.05
        return round(min(1.0, shifted), 4)
    except Exception as e:
        print("Prediction error:", e)
        return 0.0
################################################## | |
# Grading | |
################################################## | |
# (inclusive lower bound, label) pairs, ordered from the best grade downwards.
_GRADE_BANDS = (
    (0.900, "AAA \U0001F31F"),
    (0.800, "AA \u2B50"),
    (0.650, "A \u2728"),
    (0.600, "BBB \U0001F535"),
    (0.550, "BB \U0001F4D8"),
    (0.500, "B \U0001F4D6"),
    (0.400, "CCC \U0001F4DD"),
    (0.300, "CC \u270F\uFE0F"),
)
# Label for every score below the lowest band.
_LOWEST_GRADE = "C \U0001F4D1"


def get_grade_and_emoji(score):
    """
    Map a 0-1 score to an A/B/C style grade with an emoji indicator.

    Args:
        score: Numeric impact score; each band's lower bound is inclusive.

    Returns:
        A string of the form "<grade> <emoji>", e.g. "AA" followed by a star.
        Anything below 0.300 (including values that compare false against
        every bound, such as NaN) maps to the lowest grade "C".
    """
    for lower_bound, label in _GRADE_BANDS:
        if score >= lower_bound:
            return label
    return _LOWEST_GRADE
################################################## | |
# Validation | |
################################################## | |
def validate_input(title, abstract):
    """
    Check that the title and abstract are acceptable model input.

    Rules, evaluated in this order (the first failure wins):
      1. the title has at least 3 whitespace-separated words,
      2. the abstract has at least 50 whitespace-separated words,
      3. the title contains only ASCII characters,
      4. the abstract contains only ASCII characters.

    Returns:
        A `(is_valid, message)` tuple.
    """
    title_words = title.split()
    if len(title_words) < 3:
        return False, "Title must be at least 3 words."

    abstract_words = abstract.split()
    if len(abstract_words) < 50:
        return False, "Abstract must be at least 50 words."

    # str.isascii() is True exactly when every code point is <= 0x7F.
    for field_name, value in (("Title", title), ("Abstract", abstract)):
        if not value.isascii():
            return False, f"{field_name} contains non-ASCII characters."

    return True, "Inputs look good."
def update_button_status(title, abstract):
    """
    Refresh the validation status text and the predict button's state.

    Returns:
        Two `gr.update` objects: the first sets the status text (prefixed
        with "Error: " when validation fails), the second makes the predict
        button interactive only when validation passes.
    """
    is_valid, message = validate_input(title, abstract)
    status_text = message if is_valid else "Error: " + message
    return gr.update(value=status_text), gr.update(interactive=is_valid)
################################################## | |
# Process arXiv input | |
################################################## | |
def process_arxiv_input(arxiv_input):
    """
    Handler for the 'Fetch Paper Details' button.

    Returns:
        A `(title, abstract, status_message)` tuple. Title and abstract are
        empty strings when the input is blank or the lookup fails.
    """
    if arxiv_input.strip() == "":
        return "", "", "Please enter an arXiv URL or ID"

    result = fetch_arxiv_paper(arxiv_input)
    if not result["success"]:
        return "", "", result["message"]
    return result["title"], result["abstract"], result["message"]
################################################## | |
# Custom CSS | |
################################################## | |
# Stylesheet passed to `gr.Blocks(css=css)` below. The class selectors here are
# the ones the layout attaches via `elem_classes` (input-section,
# result-section, grade-display, arxiv-input) or inline HTML (main-title).
css = """
.gradio-container { font-family: Arial, sans-serif; }
.main-title {
text-align: center; color: #2563eb; font-size: 2.5rem!important;
margin-bottom:1rem!important;
background: linear-gradient(45deg,#2563eb,#1d4ed8);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.input-section {
background:#fff; padding:1.5rem; border-radius:0.5rem;
box-shadow:0 4px 6px rgba(0,0,0,0.1);
}
.result-section {
background:#f7f9fc; padding:1.5rem; border-radius:0.5rem;
margin-top:2rem;
}
.grade-display {
font-size:2.5rem; text-align:center; margin-top:1rem;
}
.arxiv-input {
margin-bottom:1.5rem; padding:1rem; background:#f3f4f6;
border-radius:0.5rem;
}
.arxiv-link {
color:#2563eb; text-decoration: underline;
}
"""
################################################## | |
# Example Papers | |
################################################## | |
# Static showcase entries rendered under "Example Papers" in the UI. Each entry
# carries a hard-coded "score"; the displayed grade is derived from that score
# at render time. These entries are display-only and are never sent to the model.
example_papers = [
    {
        "title": "Attention Is All You Need",
        "abstract": (
            "The dominant sequence transduction models are based on complex recurrent or "
            "convolutional neural networks that include an encoder and a decoder. The best performing "
            "models also connect the encoder and decoder through an attention mechanism. We propose a "
            "new simple network architecture, the Transformer, based solely on attention mechanisms, "
            "dispensing with recurrence and convolutions entirely. Experiments on two machine "
            "translation tasks show these models to be superior in quality while being more "
            "parallelizable and requiring significantly less time to train."
        ),
        "score": 0.982,
        "note": "Revolutionary paper that introduced the Transformer architecture.",
    },
    {
        "title": "Language Models are Few-Shot Learners",
        "abstract": (
            "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by "
            "pre-training on a large corpus of text followed by fine-tuning on a specific task. While "
            "typically task-agnostic in architecture, this method still requires task-specific "
            "fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans "
            "can generally perform a new language task from only a few examples or from simple "
            # Em dash restored here; it had been corrupted by a wrong-encoding round trip.
            "instructions\u2014something which current NLP systems still largely struggle to do. Here we "
            "show that scaling up language models greatly improves task-agnostic, few-shot "
            "performance, sometimes even reaching competitiveness with prior state-of-the-art "
            "fine-tuning approaches."
        ),
        "score": 0.956,
        "note": "Groundbreaking GPT-3 paper on few-shot learning.",
    },
    {
        "title": "An Empirical Study of Neural Network Training Protocols",
        "abstract": (
            "This paper presents a comparative analysis of different training protocols for neural "
            "networks across various architectures. We examine the effects of learning rate schedules, "
            "batch size selection, and optimization algorithms on model convergence and final "
            "performance. Our experiments span multiple datasets and model sizes, providing practical "
            "insights for deep learning practitioners."
        ),
        "score": 0.623,
        "note": "Solid empirical comparison of training protocols.",
    },
]
################################################## | |
# Build the Gradio Interface | |
################################################## | |
# Layout and event wiring for the whole app. Event listeners are registered
# inside the `gr.Blocks` context, alongside the components they connect.
# NOTE(review): the nesting of the sections below (which components sit inside
# the Row/Column/Group containers) was reconstructed - confirm it matches the
# intended layout.
with gr.Blocks(theme=gr.themes.Default(), css=css) as iface:
    gr.Markdown("<div class='main-title'>Papers Impact: AI-Powered Research Impact Predictor</div>")
    gr.Markdown("**Predict the potential research impact (0\u20131) from title & abstract.**")

    with gr.Row():
        with gr.Column(elem_classes="input-section"):
            gr.Markdown("### Import from arXiv")
            with gr.Group(elem_classes="arxiv-input"):
                arxiv_input = gr.Textbox(
                    lines=1,
                    placeholder="e.g. 2504.11651",
                    label="arXiv URL or ID",
                    value="2504.11651"
                )
                gr.Markdown(
                    """
                    <p>
                    Enter an arXiv ID or URL. For example:
                    <code>2504.11651</code> or <code>https://arxiv.org/pdf/2504.11651</code>
                    </p>
                    """
                )
                fetch_btn = gr.Button("\U0001F50D Fetch Paper Details", variant="secondary")

            gr.Markdown("### Or Enter Manually")
            title_input = gr.Textbox(
                lines=2,
                placeholder="Paper title (\u22653 words)...",
                label="Paper Title"
            )
            abs_input = gr.Textbox(
                lines=5,
                placeholder="Paper abstract (\u226550 words)...",
                label="Paper Abstract"
            )
            status_box = gr.Textbox(label="Validation Status", interactive=False)
            # Starts disabled; `update_button_status` enables it once inputs validate.
            predict_btn = gr.Button("\U0001F3AF Predict Impact", interactive=False, variant="primary")

        with gr.Column(elem_classes="result-section"):
            score_box = gr.Number(label="Impact Score")
            grade_box = gr.Textbox(label="Grade", elem_classes="grade-display")

    ############## METHODOLOGY EXPLANATION ##############
    gr.Markdown(
        """
        ### Scientific Methodology
        - **Training Data**: Model trained on an extensive dataset of published papers in CS.CV, CS.CL, CS.AI
        - **Optimization**: NDCG optimization with Sigmoid activation and MSE loss
        - **Validation**: Cross-validated against historical citation data
        - **Architecture**: Advanced transformer-based (LLaMA derivative) textual encoder
        - **Metrics**: Quantitative analysis of citation patterns and research influence
        """
    )

    ############## RATING SCALE ##############
    # Keep these rows in sync with the thresholds in `get_grade_and_emoji`.
    gr.Markdown(
        """
        ### Rating Scale
        | Grade | Score Range | Description | Emoji |
        |-------|-------------|---------------------|-------|
        | AAA | 0.900\u20131.000 | **Exceptional** | \U0001F31F |
        | AA | 0.800\u20130.899 | **Very High** | \u2B50 |
        | A | 0.650\u20130.799 | **High** | \u2728 |
        | BBB | 0.600\u20130.649 | **Above Average** | \U0001F535 |
        | BB | 0.550\u20130.599 | **Moderate** | \U0001F4D8 |
        | B | 0.500\u20130.549 | **Average** | \U0001F4D6 |
        | CCC | 0.400\u20130.499 | **Below Average** | \U0001F4DD |
        | CC | 0.300\u20130.399 | **Low** | \u270F\uFE0F |
        | C | <0.300 | **Limited** | \U0001F4D1 |
        """
    )

    ############## EXAMPLE PAPERS ##############
    gr.Markdown("### Example Papers")
    for paper in example_papers:
        # Two trailing spaces before each "\n" form a Markdown hard line break,
        # so title, score line, abstract and note render on separate lines.
        gr.Markdown(
            f"**{paper['title']}**  \n"
            f"Score: {paper['score']} | Grade: {get_grade_and_emoji(paper['score'])}  \n"
            f"{paper['abstract']}  \n"
            f"*{paper['note']}*\n---"
        )

    ##################################################
    # Events
    ##################################################
    # Validation triggers: re-validate whenever either text field changes.
    title_input.change(update_button_status, [title_input, abs_input], [status_box, predict_btn])
    abs_input.change(update_button_status, [title_input, abs_input], [status_box, predict_btn])

    # arXiv fetch: fills title/abstract, which in turn fires the change events above.
    fetch_btn.click(process_arxiv_input, [arxiv_input], [title_input, abs_input, status_box])

    # Predict handler
    def run_predict(t, a):
        """Score the paper and return `(score, grade_label)` for the result widgets."""
        s = predict(t, a)
        return s, get_grade_and_emoji(s)

    predict_btn.click(run_predict, [title_input, abs_input], [score_box, grade_box])
################################################## | |
# Launch | |
################################################## | |
# Start the Gradio app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()