Yozhikoff commited on
Commit
cd9e1fc
Β·
1 Parent(s): 6ebf305

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import xml.etree.ElementTree as ET
4
+ import re
5
+ import urllib
6
+ import torch
7
+
8
+
9
+ from transformers import pipeline
10
+
11
+ classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
12
+
13
+
14
+ def get_arxiv_title_and_abstract(link):
15
+ # Regular expression pattern for arXiv link validation
16
+ try:
17
+ pattern = r'^https?://arxiv\.org/(?:abs|pdf)/(\d{4}\.\d{4,5})(?:\.pdf)?/?$'
18
+ match = re.match(pattern, link)
19
+
20
+ if not match:
21
+ raise ValueError("Invalid arXiv link")
22
+
23
+ # Construct the arXiv API URL for the paper
24
+ arxiv_id = match.group(1)
25
+ api_url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"
26
+
27
+ # Retrieve the paper metadata using the arXiv API
28
+ with urllib.request.urlopen(api_url) as response:
29
+ xml_data = response.read().decode()
30
+
31
+ # Extract the title and abstract from the XML data
32
+ title = re.search(r'<title>(.*?)</title>', xml_data).group(1)
33
+ abstract = re.search(r'<summary>(.*?)</summary>', xml_data, re.DOTALL).group(1)
34
+
35
+ # Clean up the title and abstract
36
+ title = re.sub(r'\s+', ' ', title).strip()
37
+ abstract = re.sub(r'\s+', ' ', abstract).strip()
38
+
39
+ return title, abstract
40
+ except:
41
+ raise gr.Error('Invalid arXiv URL!')
42
+
43
+
44
+ def classify_paper(title, abstract):
45
+ text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"
46
+ item = classifier.tokenizer(text)
47
+ input_tensor = torch.tensor(item['input_ids'])[None]
48
+ logits = classifier.model(input_tensor).logits[0]
49
+ preds = torch.sigmoid(logits).detach().cpu().numpy()
50
+ result = {classifier.model.config.id2label[num]: float(prob) for num, prob in enumerate(preds) if prob > 0.1}
51
+ return result
52
+
53
+
54
+ with gr.Blocks(title='Paper classifier') as demo:
55
+
56
+ gr.Markdown('Please enter an arXiv link **OR** fill title and abstract manually')
57
+ with gr.Row():
58
+ with gr.Column():
59
+ arxiv_link = gr.Textbox(label="Arxiv link")
60
+
61
+ b1 = gr.Button("Parse Link")
62
+
63
+ title = gr.Textbox(label="Paper title")
64
+ abstract = gr.Textbox(label="Paper abstract")
65
+
66
+ b2 = gr.Button("Classify Paper", variant='primary')
67
+
68
+ b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")
69
+
70
+
71
+ with gr.Column():
72
+ out = gr.Label(label="Topics")
73
+ b2.click(classify_paper, inputs=[title, abstract], outputs=out)
74
+
75
+ gr.Markdown('## Examples')
76
+ gr.Examples(
77
+ examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/1503.04376'], ['https://arxiv.org/abs/2201.06601']],
78
+ inputs=arxiv_link,
79
+ outputs=[title, abstract],
80
+ fn=get_arxiv_title_and_abstract,
81
+ cache_examples=True,
82
+ )
83
+
84
+ demo.launch(share=True)