fellafrom26 committed on
Commit
ca4bc25
·
verified ·
1 Parent(s): 4f65f33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -44
app.py CHANGED
@@ -1,44 +1,51 @@
1
- # app.py
2
- import streamlit as st
3
- from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
- import torch
5
- import numpy as np
6
-
7
- @st.cache(allow_output_mutation=True)
8
- def load_model():
9
- tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
10
- model = DistilBertForSequenceClassification.from_pretrained('model/')
11
- return tokenizer, model
12
-
13
- tokenizer, model = load_model()
14
-
15
- st.title('arXiv Article Classifier')
16
- title = st.text_input('Title')
17
- abstract = st.text_area('Abstract')
18
- text = title + ' ' + abstract if abstract else title
19
-
20
- if st.button('Predict'):
21
- if not text.strip():
22
- st.error('Please enter at least a title.')
23
- else:
24
- inputs = tokenizer(
25
- text,
26
- truncation=True,
27
- padding=True,
28
- max_length=512,
29
- return_tensors='pt'
30
- )
31
- with torch.no_grad():
32
- logits = model(**inputs).logits
33
- probs = torch.nn.functional.softmax(logits, dim=1).numpy()[0]
34
- sorted_indices = np.argsort(-probs)
35
-
36
- cumulative = 0
37
- result = []
38
- for idx in sorted_indices:
39
- cumulative += probs[idx]
40
- result.append((model.config.id2label[idx], probs[idx]))
41
- if cumulative >= 0.95:
42
- break
43
- for tag, prob in result:
44
- st.write(f'{tag}: {prob:.2%}')
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
3
+ import torch
4
+ import numpy as np
5
+
6
+ MAPPING = {
7
+ 'cs': 'Computer Science', 'econ': 'Economics', 'eess': 'Electrical Engineering and Systems Science', 'math': 'Mathematics',
8
+ 'q-bio': 'Quantitative Biology', 'q-fin': 'Quantitative Finance', 'stat': 'Statistics'
9
+ }
10
+
11
+ @st.cache_resource
12
+ def load_model():
13
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
14
+ model = DistilBertForSequenceClassification.from_pretrained('model/')
15
+ return tokenizer, model
16
+
17
+ tokenizer, model = load_model()
18
+
19
+ st.title('arXiv Article Classifier')
20
+ title = st.text_input('Title')
21
+ abstract = st.text_area('Abstract')
22
+ text = title + ' ' + abstract if abstract else title
23
+
24
+ if st.button('Predict'):
25
+ if not text.strip():
26
+ st.error('Please enter at least a title.')
27
+ else:
28
+ inputs = tokenizer(
29
+ text,
30
+ truncation=True,
31
+ padding=True,
32
+ max_length=512,
33
+ return_tensors='pt'
34
+ )
35
+ with torch.no_grad():
36
+ logits = model(**inputs).logits
37
+ probs = torch.nn.functional.softmax(logits, dim=1).numpy()[0]
38
+ sorted_indices = np.argsort(-probs)
39
+
40
+ cumulative = 0
41
+ result = []
42
+ for idx in sorted_indices:
43
+ cumulative += probs[idx]
44
+ result.append((model.config.id2label[idx], probs[idx]))
45
+ if cumulative >= 0.95:
46
+ break
47
+ for tag, prob in result:
48
+ if tag in MAPPING:
49
+ st.write(f'{MAPPING[tag]}: {prob:.2%}')
50
+ else:
51
+ st.write(f'{tag}: {prob:.2%}')