Spaces:
Runtime error
Runtime error
File size: 4,655 Bytes
cd9e1fc f840f33 cd9e1fc 84837a0 cd9e1fc 84837a0 cd9e1fc dcadea4 cd9e1fc f840f33 cd9e1fc f840f33 65e66d9 f840f33 cd9e1fc 8defca4 cd9e1fc 8defca4 0749c7f 1516d05 cd9e1fc 18b8fa4 cd9e1fc 5055671 0749c7f dcadea4 cd9e1fc 9e56b79 cd9e1fc e993597 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import re
import urllib
import urllib.request
import xml.etree.ElementTree as ET
import gradio as gr
import torch
from transformers import pipeline
# Hugging Face Hub checkpoint used for topic prediction.
_MODEL_ID = "Yozhikoff/arxiv-topics-distilbert-base-cased"

# Built once at import time; classify_paper uses its .tokenizer and .model
# attributes directly instead of calling the pipeline object.
classifier = pipeline(model=_MODEL_ID)
def get_arxiv_title_and_abstract(link):
    """Fetch the title and abstract of an arXiv paper from its URL.

    Accepts links of the form ``http(s)://[www.]arxiv.org/abs/<id>`` or
    ``http(s)://[www.]arxiv.org/pdf/<id>[.pdf]``, where ``<id>`` is a
    new-style identifier such as ``1706.03762`` optionally followed by a
    version suffix such as ``v5``. Surrounding whitespace is ignored.
    Metadata is retrieved from the arXiv export API (Atom feed).

    Args:
        link: The arXiv URL entered by the user.

    Returns:
        A ``(title, summary)`` tuple of strings with leading/trailing
        whitespace removed.

    Raises:
        gr.Error: If the link is malformed or the paper is not found
            ("Invalid arXiv URL!"), or if the arXiv API cannot be reached
            or returns unparsable data.
    """
    atom = '{http://www.w3.org/2005/Atom}'
    # Dots are escaped so look-alike hosts are rejected; group 1 is the paper
    # id including an optional version suffix, which the API also accepts.
    pattern = (
        r'https?://(?:www\.)?arxiv\.org/(?:abs|pdf)/'
        r'(\d{4}\.\d{4,5}(?:v\d+)?)(?:\.pdf)?'
    )
    match = re.fullmatch(pattern, (link or '').strip())
    if match is None:
        raise gr.Error('Invalid arXiv URL!')

    arxiv_id = match.group(1)
    api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'

    # URLError, HTTPError and socket timeouts are all OSError subclasses.
    # The context manager guarantees the connection is closed.
    try:
        with urllib.request.urlopen(api_url, timeout=30) as response:
            xml_data = response.read()
    except OSError as err:
        raise gr.Error('Could not reach the arXiv API, please try again later.') from err

    try:
        root = ET.fromstring(xml_data)
    except ET.ParseError as err:
        raise gr.Error('Unexpected response from the arXiv API.') from err

    entry = root.find(f'{atom}entry')
    if entry is None:
        raise gr.Error('Invalid arXiv URL!')
    # NOTE(review): the arXiv API appears to report lookup failures as an
    # entry whose <id> points at .../api/errors -- treat that as not found.
    entry_id = entry.find(f'{atom}id')
    if entry_id is not None and '/api/errors' in (entry_id.text or ''):
        raise gr.Error('Invalid arXiv URL!')

    title_node = entry.find(f'{atom}title')
    summary_node = entry.find(f'{atom}summary')
    if title_node is None or summary_node is None:
        raise gr.Error('Invalid arXiv URL!')
    return (title_node.text or '').strip(), (summary_node.text or '').strip()
def classify_paper(title, abstract, *, threshold=0.25):
    """Predict arXiv topics for a paper from its title and/or abstract.

    The text is formatted as ``"TITLE\\n<title>\\n\\nABSTRACT\\n<abstract>"``
    and run through the module-level ``classifier`` model. Each output logit
    is passed through an independent sigmoid (multi-label prediction), and
    labels whose probability is below ``threshold`` are dropped.

    Args:
        title: Paper title; may be empty if an abstract is given.
        abstract: Paper abstract; may be empty if a title is given.
        threshold: Minimum probability for a label to be reported
            (keyword-only, default 0.25).

    Returns:
        A dict mapping label name to probability (Python float), suitable
        for ``gr.Label``. If several output indices map to the same label
        name, the highest probability is kept.

    Raises:
        gr.Error: If both title and abstract are empty or whitespace-only.
    """
    title = title or ''
    abstract = abstract or ''
    if not title.strip() and not abstract.strip():
        raise gr.Error('Fill Title or/and Abstract')

    text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"
    model = classifier.model
    # Truncate to the model's position-embedding limit: longer inputs would
    # otherwise index past the position embeddings and crash the forward pass.
    max_len = getattr(model.config, 'max_position_embeddings', None)
    encoded = classifier.tokenizer(
        text, truncation=True, max_length=max_len, return_tensors='pt'
    )
    # Only input_ids are passed (as before); a single unpadded sequence
    # needs no attention mask.
    input_ids = encoded['input_ids'].to(model.device)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        logits = model(input_ids).logits[0]
    probs = torch.sigmoid(logits).cpu().tolist()

    id2label = model.config.id2label
    result = {}
    for idx, prob in enumerate(probs):
        if prob < threshold:
            continue
        label = id2label[idx]
        result[label] = max(prob, result.get(label, prob))
    return result
# Gradio UI: parse an arXiv link into title/abstract, then classify topics.
# The layout is defined exactly once. Previously it was duplicated, which
# rebuilt the app, re-ran the cached examples and relaunched the server after
# the first one shut down; the copy also ended in a stray "|" (syntax error).
with gr.Blocks(title='Paper classifier') as demo:
    gr.Markdown('# Paper Topic Classifier')
    with gr.Row():
        with gr.Column():
            gr.Markdown('## Inputs')
            gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
            arxiv_link = gr.Textbox(label="Arxiv link", placeholder="https://arxiv.org/abs/1706.03762")
            b1 = gr.Button("Parse Link")
            title = gr.Textbox(label="Paper title", placeholder="Title text")
            abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")
            b2 = gr.Button("Classify Paper", variant='primary')
            b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")
        with gr.Column():
            gr.Markdown('## Topics')
            # Empty headings used as vertical spacers to align the output panel.
            gr.Markdown('## ')
            gr.Markdown('## ')
            out = gr.Label(label="Topics")
            b2.click(classify_paper, inputs=[title, abstract], outputs=out)
    gr.Markdown('## Examples')
    # cache_examples=True runs get_arxiv_title_and_abstract on each example
    # when the app is built, so startup needs network access to arXiv.
    gr.Examples(
        examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
        inputs=arxiv_link,
        outputs=[title, abstract],
        fn=get_arxiv_title_and_abstract,
        cache_examples=True,
    )

# Launch only when run as a script, so importing this module (e.g. by
# Gradio's reload mode or tests) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()