Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,49 @@ from transformers import pipeline
|
|
10 |
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
import re
|
14 |
import urllib.request
|
15 |
import xml.etree.ElementTree as ET
|
@@ -50,10 +93,54 @@ def classify_paper(title, abstract):
|
|
50 |
input_tensor = torch.tensor(item['input_ids'])[None]
|
51 |
logits = classifier.model(input_tensor).logits[0]
|
52 |
preds = torch.sigmoid(logits).detach().cpu().numpy()
|
53 |
-
result = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
return result
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
with gr.Blocks(title='Paper classifier') as demo:
|
58 |
gr.Markdown('# Paper Topic Classifier')
|
59 |
with gr.Row():
|
|
|
10 |
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
|
11 |
|
12 |
|
13 |
+
import re
|
14 |
+
import urllib.request
|
15 |
+
import xml.etree.ElementTree as ET
|
16 |
+
|
17 |
+
def get_arxiv_title_and_abstract(link):
    """Fetch the title and abstract of an arXiv paper from its URL.

    The link must be an ``abs`` or ``pdf`` URL on ``arxiv.org`` (optionally
    ``www.arxiv.org``) carrying a new-style identifier such as
    ``1706.03762``, optionally followed by a version suffix (``v2``) and,
    for PDF links, a ``.pdf`` extension.

    Args:
        link: The arXiv URL entered by the user.

    Returns:
        A ``(title, summary)`` tuple holding the text of the Atom ``title``
        and ``summary`` elements of the first entry returned by the arXiv
        export API.

    Raises:
        gr.Error: If the link is not a recognised arXiv URL, the API request
            or XML parsing fails, or the response has no entry with a title
            and a summary.
    """
    atom_ns = '{http://www.w3.org/2005/Atom}'
    request_timeout = 30  # seconds; without it a stalled request blocks forever

    # Validate the arXiv link. The dots are escaped so that only the real
    # host name matches, and fullmatch() rejects any trailing garbage.
    pattern = (
        r'https?://(?:www\.)?arxiv\.org/(abs|pdf)/'
        r'(\d{4}\.\d{4,5}(?:v\d+)?)(\.pdf)?'
    )
    match = re.fullmatch(pattern, (link or '').strip())
    if match is None:
        raise gr.Error('Invalid arXiv URL!')

    # Construct the arXiv API URL (a version suffix, if any, is passed along).
    arxiv_id = match.group(2)
    api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'

    try:
        # The context manager closes the HTTP response even if read() fails.
        with urllib.request.urlopen(api_url, timeout=request_timeout) as response:
            xml_data = response.read()
        root = ET.fromstring(xml_data)
    except Exception as err:
        # UI boundary: report every fetch/parse failure as a gradio error,
        # but keep the original exception chained for the server log.
        raise gr.Error('Invalid arXiv URL!') from err

    # An unknown identifier yields a feed without a usable entry; check for
    # that explicitly instead of relying on an AttributeError.
    entry = root.find(f'{atom_ns}entry')
    title_node = None if entry is None else entry.find(f'{atom_ns}title')
    summary_node = None if entry is None else entry.find(f'{atom_ns}summary')
    if title_node is None or summary_node is None:
        raise gr.Error('Invalid arXiv URL!')

    return title_node.text, summary_node.text
|
42 |
+
|
43 |
+
|
44 |
+
import gradio as gr
|
45 |
+
import xml.etree.ElementTree as ET
|
46 |
+
import re
|
47 |
+
import urllib
|
48 |
+
import torch
|
49 |
+
|
50 |
+
|
51 |
+
from transformers import pipeline
|
52 |
+
|
53 |
+
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
|
54 |
+
|
55 |
+
|
56 |
import re
|
57 |
import urllib.request
|
58 |
import xml.etree.ElementTree as ET
|
|
|
93 |
input_tensor = torch.tensor(item['input_ids'])[None]
|
94 |
logits = classifier.model(input_tensor).logits[0]
|
95 |
preds = torch.sigmoid(logits).detach().cpu().numpy()
|
96 |
+
result = {}
|
97 |
+
for num, prob in enumerate(preds):
|
98 |
+
if prob < 0.25:
|
99 |
+
continue
|
100 |
+
if classifier.model.config.id2label[num] in result:
|
101 |
+
if result[classifier.model.config.id2label[num]] > prob:
|
102 |
+
continue
|
103 |
+
result[classifier.model.config.id2label[num]] = float(prob)
|
104 |
return result
|
105 |
|
106 |
|
107 |
+
with gr.Blocks(title='Paper classifier') as demo:
|
108 |
+
gr.Markdown('# Paper Topic Classifier')
|
109 |
+
with gr.Row():
|
110 |
+
with gr.Column():
|
111 |
+
gr.Markdown('## Inputs')
|
112 |
+
gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
|
113 |
+
arxiv_link = gr.Textbox(label="Arxiv link", placeholder="Flip this text")
|
114 |
+
|
115 |
+
b1 = gr.Button("Parse Link")
|
116 |
+
|
117 |
+
title = gr.Textbox(label="Paper title", placeholder="Title text")
|
118 |
+
abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")
|
119 |
+
|
120 |
+
b2 = gr.Button("Classify Paper", variant='primary')
|
121 |
+
|
122 |
+
b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")
|
123 |
+
|
124 |
+
|
125 |
+
with gr.Column():
|
126 |
+
gr.Markdown('## Topics')
|
127 |
+
gr.Markdown('## ')
|
128 |
+
gr.Markdown('## ')
|
129 |
+
out = gr.Label(label="Topics")
|
130 |
+
b2.click(classify_paper, inputs=[title, abstract], outputs=out)
|
131 |
+
|
132 |
+
gr.Markdown('## Examples')
|
133 |
+
gr.Examples(
|
134 |
+
examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
|
135 |
+
inputs=arxiv_link,
|
136 |
+
outputs=[title, abstract],
|
137 |
+
fn=get_arxiv_title_and_abstract,
|
138 |
+
cache_examples=True,
|
139 |
+
)
|
140 |
+
|
141 |
+
demo.launch()
|
142 |
+
|
143 |
+
|
144 |
with gr.Blocks(title='Paper classifier') as demo:
|
145 |
gr.Markdown('# Paper Topic Classifier')
|
146 |
with gr.Row():
|