File size: 4,655 Bytes
cd9e1fc
 
f840f33
 
 
 
 
 
 
 
 
 
 
cd9e1fc
 
84837a0
 
cd9e1fc
 
84837a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd9e1fc
 
 
 
 
dcadea4
 
 
cd9e1fc
 
 
 
 
f840f33
 
 
 
 
 
 
 
cd9e1fc
 
 
f840f33
 
 
 
 
 
65e66d9
f840f33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd9e1fc
8defca4
cd9e1fc
 
8defca4
0749c7f
1516d05
cd9e1fc
 
 
18b8fa4
 
cd9e1fc
 
 
 
 
 
 
5055671
0749c7f
dcadea4
cd9e1fc
 
 
 
 
9e56b79
cd9e1fc
 
 
 
 
 
e993597
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import re
import urllib
import urllib.request
import xml.etree.ElementTree as ET

import gradio as gr
import torch

from transformers import pipeline

# Hugging Face Hub checkpoint backing the app; classify_paper reaches into this
# pipeline's tokenizer and model directly rather than calling the pipeline.
_CLASSIFIER_CHECKPOINT = "Yozhikoff/arxiv-topics-distilbert-base-cased"
classifier = pipeline(model=_CLASSIFIER_CHECKPOINT)


# XML namespace prefix of the Atom feed returned by the arXiv API.
_ATOM = '{http://www.w3.org/2005/Atom}'

# Seconds to wait for the arXiv API before giving up.
_ARXIV_TIMEOUT = 10

# Matches arXiv abs/pdf links and captures the paper id. Accepts new-style ids
# (e.g. 1706.03762) and old-style ids (e.g. hep-th/9901001), each with an
# optional version suffix (v2), an optional ".pdf" and an optional trailing "/".
_ARXIV_LINK_RE = re.compile(
    r'^https?://(?:www\.)?arxiv\.org/(?:abs|pdf)/'
    r'(\d{4}\.\d{4,5}(?:v\d+)?|[a-z\-]+(?:\.[A-Z]{2})?/\d{7}(?:v\d+)?)'
    r'(?:\.pdf)?/?$'
)


def get_arxiv_title_and_abstract(link):
    """Fetch the title and abstract of an arXiv paper from its link.

    Args:
        link: An arXiv ``abs`` or ``pdf`` URL, e.g.
            ``https://arxiv.org/abs/1706.03762``.

    Returns:
        ``(title, abstract)`` as strings. The title's line wrapping is collapsed
        into single spaces; the abstract has surrounding whitespace stripped.

    Raises:
        gr.Error: If the link is not a recognised arXiv URL, the arXiv API
            cannot be reached or returns unparseable XML, or the API response
            contains no usable entry for the id.
    """
    match = _ARXIV_LINK_RE.match((link or '').strip())
    if not match:
        raise gr.Error('Invalid arXiv URL!')
    # The regex restricts the id to [0-9A-Za-z./-], so it is safe in a query string.
    api_url = f'https://export.arxiv.org/api/query?id_list={match.group(1)}'

    try:
        # `with` closes the HTTP response; the timeout stops a stalled request
        # from hanging the UI indefinitely.
        with urllib.request.urlopen(api_url, timeout=_ARXIV_TIMEOUT) as response:
            xml_data = response.read()
    except OSError as err:  # URLError, HTTPError and socket timeouts are OSErrors
        raise gr.Error('Could not reach the arXiv API, please try again later.') from err

    try:
        root = ET.fromstring(xml_data)
    except ET.ParseError as err:
        raise gr.Error('Unexpected response from the arXiv API.') from err

    entry = root.find(_ATOM + 'entry')
    # NOTE(review): the arXiv API reports bad ids as an entry whose <id> points
    # at .../api/errors — treat that the same as a missing entry.
    if entry is None or '/api/errors' in entry.findtext(_ATOM + 'id', ''):
        raise gr.Error('Invalid arXiv URL!')

    title = entry.findtext(_ATOM + 'title')
    summary = entry.findtext(_ATOM + 'summary')
    if title is None or summary is None:
        raise gr.Error('Invalid arXiv URL!')

    # Atom text is hard-wrapped and padded with whitespace; tidy it for display.
    return ' '.join(title.split()), summary.strip()


def classify_paper(title, abstract, threshold=0.25):
    """Predict arXiv topics for a paper from its title and/or abstract.

    The pipeline's tokenizer and model are called directly, and a sigmoid is
    applied to each logit independently, so every topic gets its own
    probability rather than a softmax share.

    Args:
        title: Paper title; may be empty if an abstract is given.
        abstract: Paper abstract; may be empty if a title is given.
        threshold: Minimum probability for a topic to be reported
            (defaults to 0.25).

    Returns:
        Dict mapping topic label -> probability (Python float) for every topic
        whose probability is at least ``threshold``. If several output indices
        share a label, the highest probability is kept.

    Raises:
        gr.Error: If both title and abstract are empty or whitespace-only.
    """
    title = title or ''
    abstract = abstract or ''
    if not title.strip() and not abstract.strip():
        raise gr.Error('Fill Title or/and Abstract')

    text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"
    model = classifier.model

    # Truncate so a long abstract cannot exceed the model's position-embedding
    # limit (which would otherwise raise an indexing error inside the model).
    # If the config has no max_position_embeddings, this falls back to the
    # tokenizer's model_max_length.
    encoded = classifier.tokenizer(
        text,
        truncation=True,
        max_length=getattr(model.config, 'max_position_embeddings', None),
        return_tensors='pt',
    )

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(
            input_ids=encoded['input_ids'].to(model.device),
            attention_mask=encoded['attention_mask'].to(model.device),
        ).logits[0]
    probs = torch.sigmoid(logits).cpu().tolist()

    id2label = model.config.id2label
    result = {}
    for idx, prob in enumerate(probs):
        if prob < threshold:
            continue
        label = id2label[idx]
        result[label] = max(prob, result.get(label, prob))
    return result
    

# Gradio UI. Users either paste an arXiv link and press "Parse Link" to fill in
# the title/abstract, or type them manually; "Classify Paper" then shows the
# predicted topics. The layout is defined once and launched once.
with gr.Blocks(title='Paper classifier') as demo:
    gr.Markdown('# Paper Topic Classifier')
    with gr.Row():
        with gr.Column():
            gr.Markdown('## Inputs')
            gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
            arxiv_link = gr.Textbox(label="Arxiv link", placeholder="https://arxiv.org/abs/1706.03762")

            b1 = gr.Button("Parse Link")

            title = gr.Textbox(label="Paper title", placeholder="Title text")
            abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")

            b2 = gr.Button("Classify Paper", variant='primary')

            # Parsing a link fills the title/abstract boxes; also exposed as the "parse" API endpoint.
            b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")

        with gr.Column():
            gr.Markdown('## Topics')
            # Empty headings appear to serve as vertical spacers so the output
            # lines up with the input column.
            gr.Markdown('## ')
            gr.Markdown('## ')
            out = gr.Label(label="Topics")
            b2.click(classify_paper, inputs=[title, abstract], outputs=out)

    gr.Markdown('## Examples')
    # cache_examples=True runs get_arxiv_title_and_abstract on each example
    # while the app is being built, so examples load instantly when clicked.
    gr.Examples(
        examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
        inputs=arxiv_link,
        outputs=[title, abstract],
        fn=get_arxiv_title_and_abstract,
        cache_examples=True,
    )

if __name__ == '__main__':
    demo.launch()