Yozhikoff committed on
Commit
f840f33
·
1 Parent(s): 1516d05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -1
app.py CHANGED
@@ -10,6 +10,49 @@ from transformers import pipeline
10
  classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  import re
14
  import urllib.request
15
  import xml.etree.ElementTree as ET
@@ -50,10 +93,54 @@ def classify_paper(title, abstract):
50
  input_tensor = torch.tensor(item['input_ids'])[None]
51
  logits = classifier.model(input_tensor).logits[0]
52
  preds = torch.sigmoid(logits).detach().cpu().numpy()
53
- result = {classifier.model.config.id2label[num]: float(prob) for num, prob in enumerate(preds) if prob > 0.25}
 
 
 
 
 
 
 
54
  return result
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  with gr.Blocks(title='Paper classifier') as demo:
58
  gr.Markdown('# Paper Topic Classifier')
59
  with gr.Row():
 
10
  classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
11
 
12
 
13
+ import re
14
+ import urllib.request
15
+ import xml.etree.ElementTree as ET
16
+
17
def get_arxiv_title_and_abstract(link):
    """Fetch the title and abstract of an arXiv paper given its URL.

    Args:
        link: URL of the form ``http(s)://arxiv.org/abs/<id>`` or
            ``http(s)://arxiv.org/pdf/<id>[.pdf]``, where ``<id>`` is four
            digits, a dot, then four or five digits (e.g. ``1706.03762``).

    Returns:
        A ``(title, summary)`` tuple with the text of the Atom ``<title>``
        and ``<summary>`` elements of the first ``<entry>`` in the response.

    Raises:
        gr.Error: If the link does not match the expected pattern, the
            request fails, or the response cannot be parsed into an entry
            with a title and summary.
    """
    atom_ns = '{http://www.w3.org/2005/Atom}'
    try:
        # Validate the arxiv link. The dot in the host name is escaped so
        # that it only matches a literal "." rather than any character.
        pattern = r'^https?://arxiv\.org/(abs|pdf)/(\d{4}\.\d{4,5})(\.pdf)?$'
        match = re.match(pattern, link)
        if not match:
            raise ValueError('Invalid arxiv link')

        # Construct the arxiv API URL
        arxiv_id = match.group(2)
        api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'

        # Send a request to the arxiv API; the context manager closes the
        # connection even if reading the body fails.
        with urllib.request.urlopen(api_url) as response:
            xml_data = response.read()

        # Parse the XML data
        root = ET.fromstring(xml_data)
        entry = root.find(f'{atom_ns}entry')
        if entry is None:
            raise ValueError('No entry in arxiv API response')
        title = entry.find(f'{atom_ns}title').text
        summary = entry.find(f'{atom_ns}summary').text

        return title, summary
    except Exception as err:
        # Every failure is reported to the UI with the same message; the
        # original exception is chained so it still shows up in server logs.
        raise gr.Error('Invalid arXiv URL!') from err
42
+
43
+
44
+ import gradio as gr
45
+ import xml.etree.ElementTree as ET
46
+ import re
47
+ import urllib
48
+ import torch
49
+
50
+
51
+ from transformers import pipeline
52
+
53
+ classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")
54
+
55
+
56
  import re
57
  import urllib.request
58
  import xml.etree.ElementTree as ET
 
93
  input_tensor = torch.tensor(item['input_ids'])[None]
94
  logits = classifier.model(input_tensor).logits[0]
95
  preds = torch.sigmoid(logits).detach().cpu().numpy()
96
+ result = {}
97
+ for num, prob in enumerate(preds):
98
+ if prob < 0.25:
99
+ continue
100
+ if classifier.model.config.id2label[num] in result:
101
+ if result[classifier.model.config.id2label[num]] > prob:
102
+ continue
103
+ result[classifier.model.config.id2label[num]] = float(prob)
104
  return result
105
 
106
 
107
# Build the Gradio UI: an input column (arXiv link, or manual title/abstract)
# next to an output column showing the predicted topics, followed by example
# links.
# NOTE(review): the nesting below is reconstructed from the component order;
# confirm it against the deployed layout.
with gr.Blocks(title='Paper classifier') as demo:
    gr.Markdown('# Paper Topic Classifier')
    with gr.Row():
        with gr.Column():
            gr.Markdown('## Inputs')
            gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
            # The placeholder shows the expected URL format instead of the
            # unrelated leftover text it previously contained.
            arxiv_link = gr.Textbox(label="Arxiv link", placeholder="https://arxiv.org/abs/1706.03762")

            b1 = gr.Button("Parse Link")

            title = gr.Textbox(label="Paper title", placeholder="Title text")
            abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")

            b2 = gr.Button("Classify Paper", variant='primary')

            # Parsing a link fills the title and abstract boxes in place.
            b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")

        with gr.Column():
            gr.Markdown('## Topics')
            # Empty headings, presumably used as vertical spacers.
            gr.Markdown('## ')
            gr.Markdown('## ')
            out = gr.Label(label="Topics")
            b2.click(classify_paper, inputs=[title, abstract], outputs=out)

    gr.Markdown('## Examples')
    gr.Examples(
        examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
        inputs=arxiv_link,
        outputs=[title, abstract],
        fn=get_arxiv_title_and_abstract,
        cache_examples=True,
    )

demo.launch()
142
+
143
+
144
  with gr.Blocks(title='Paper classifier') as demo:
145
  gr.Markdown('# Paper Topic Classifier')
146
  with gr.Row():