bvd757 committed
Commit
6deef2b
·
1 Parent(s): 4b6287a
Files changed (1)
  1. app.py +80 -94
app.py CHANGED
@@ -1,77 +1,23 @@
import streamlit as st
import json
import numpy as np
import torch
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2Tokenizer,
)
- import sentencepiece
-
- try:
-     import sentencepiece
- except ImportError:
-     st.error("SentencePiece must be installed: pip install sentencepiece")
-     st.stop()
-
- model_name = "microsoft/deberta-v3-base"
- tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
-
- def preprocess_text(text, tokenizer, max_length=512):
-     inputs = tokenizer(
-         text,
-         padding="max_length",
-         truncation=True,
-         max_length=max_length,
-         return_tensors="pt"
-     )
-     return inputs
-
-
- def classify_text(text, model, tokenizer, device, threshold=0.5):
-     inputs = preprocess_text(text, tokenizer)
-     input_ids = inputs["input_ids"].to(device)
-     attention_mask = inputs["attention_mask"].to(device)
-     model.eval()
-     with torch.no_grad():
-         logits = model(input_ids, attention_mask)
-         probs = torch.sigmoid(logits)
-         predictions = (probs > threshold).int().cpu().numpy()
-
-     return probs.cpu().numpy(), predictions
-
- def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
-     probabilities, _ = classify_text(text, model, tokenizer, device)
-     print(probabilities)
-     themes = []
-     for label in probabilities[0].argsort()[-limit:]:
-         themes.append((label_to_theme[str(label)], probabilities[0][label]))
-     return themes
-
- class DebertPaperClassifier(torch.nn.Module):
-     def __init__(self, num_labels, device, dropout_rate=0.1, class_weights=None):
-         super().__init__()
-         self.config = DebertaV2Config.from_pretrained(model_name)
-         self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)
-
-         self.classifier = torch.nn.Sequential(
-             torch.nn.Dropout(dropout_rate),
-             torch.nn.Linear(self.config.hidden_size, 512),
-             torch.nn.LayerNorm(512),
-             torch.nn.GELU(),
-             torch.nn.Dropout(dropout_rate),
-             torch.nn.Linear(512, num_labels)
-         )
-
-         self._init_weights()
-         if class_weights is not None:
-             self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
-         else:
-             self.loss_fct = torch.nn.BCEWithLogitsLoss()
-
- class DebertPaperClassifierV5(torch.nn.Module):
-     def __init__(self, device, num_labels=47, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained("microsoft/deberta-v3-base")
        self.deberta = DebertaV2Model.from_pretrained("microsoft/deberta-v3-base", config=self.config)
@@ -107,48 +53,88 @@ class DebertPaperClassifierV5(torch.nn.Module):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
-
@st.cache_resource
- def load_model(test=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     try:
-         path = '/Users/bvd757/my_documents/Машинное обучение 2/Howework4'
-         with open(f'{path}/label_to_theme.json', 'r') as f:
-             label_to_theme = json.load(f)
-     except:
-         path = '/home/user/app'
-         with open(f'{path}/label_to_theme.json', 'r') as f:
-             label_to_theme = json.load(f)
-
-     # model = DebertPaperClassifier(device=device, num_labels=len(label_to_theme), class_weights=class_weights).to(device)
-     # model.load_state_dict(torch.load("model_info/full_model_v4.pth", map_location=device))
-
-     class_weights = torch.load(f'{path}/class_weights.pth').to(device)
-     model = DebertPaperClassifierV5(device=device, num_labels=47, class_weights=class_weights).to(device)
-     model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))
-     if test:
-         print(device)
-         print("Model!!!")
-         text = 'We propose an architecture for VQA which utilizes recurrent layers to\ngenerate visual and textual attention. The memory characteristic of the\nproposed recurrent attention units offers a rich joint embedding of visual and\ntextual features and enables the model to reason relations between several\nparts of the image and question. Our single model outperforms the first place\nwinner on the VQA 1.0 dataset, performs within margin to the current\nstate-of-the-art ensemble model. We also experiment with replacing attention\nmechanisms in other state-of-the-art models with our implementation and show\nincreased accuracy. In both cases, our recurrent attention mechanism improves\nperformance in tasks requiring sequential or relational reasoning on the VQA\ndataset.'
-         print(get_themes(text, model, tokenizer, label_to_theme, device))
    return model, tokenizer, label_to_theme, device


- def kek():

    title = st.text_input("Title")
    abstract = st.text_area("Abstract")
-
    if st.button("Classify"):
        if not title and not abstract:
-             st.warning("Please enter an abstract")
            return

-         text = f"{title}\n\n{abstract}" if title and abstract else title or abstract
-         model, tokenizer, label_to_theme, device = load_model()
-
-         with st.spinner("Classifying..."):
-             themes = get_themes(text, model, tokenizer, label_to_theme, device, 5)

        st.success("Classification results:")
        for theme, prob in themes:
@@ -156,4 +142,4 @@


if __name__ == "__main__":
-     kek()
 
import streamlit as st
import json
import numpy as np
+ import sentencepiece
+ from pathlib import Path
import torch
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2Tokenizer,
)

+ MODEL_NAME = "microsoft/deberta-v3-base"
+ MAX_LENGTH = 512
+ NUM_LABELS = 47
+ DROPOUT_RATE = 0.1
+ THRESHOLD = 0.5

+ class DebertaV3PaperClassifier(torch.nn.Module):
+     def __init__(self, device, num_labels=NUM_LABELS, dropout_rate=DROPOUT_RATE, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained("microsoft/deberta-v3-base")
        self.deberta = DebertaV2Model.from_pretrained("microsoft/deberta-v3-base", config=self.config)
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
+
@st.cache_resource
+ def load_assets():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     base_path = Path(__file__).parent
+     with open(base_path/"label_to_theme.json") as f:
+         label_to_theme = json.load(f)
+
+     class_weights = torch.load(f"{base_path}/class_weights.pth").to(device)
+
+     model = DebertaV3PaperClassifier(device=device, num_labels=NUM_LABELS, class_weights=class_weights).to(device)
+     model.load_state_dict(torch.load(f"{base_path}/full_model_v4.pth", map_location=device))
+     model.eval()
+
+     tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer, label_to_theme, device

+ def preprocess_text(text, tokenizer, max_length=MAX_LENGTH):
+     inputs = tokenizer(
+         text,
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+         return_tensors="pt"
+     )
+     return inputs
+
+
+ def predict(text: str, model, tokenizer, device) -> np.ndarray:
+     """Run the model on the input text and return per-label probabilities."""
+     inputs = preprocess_text(text, tokenizer)
+
+     with torch.no_grad():
+         logits = model(
+             input_ids=inputs["input_ids"].to(device),
+             attention_mask=inputs["attention_mask"].to(device)
+         )
+
+     probs = torch.sigmoid(logits).cpu().numpy()[0]
+     return probs
+
+
+ def get_themes(probs: np.ndarray, label_to_theme: dict) -> list:
+     """Return themes in descending probability until their cumulative probability reaches 0.95."""
+     sorted_indices = np.argsort(-probs)
+     labels = []
+     sum_percent = 0
+     for idx in sorted_indices:
+         labels.append((label_to_theme[str(idx)], probs[idx]))
+         sum_percent += probs[idx]
+         if sum_percent >= 0.95:
+             break
+
+     return labels
+
+
+ def main():
+     st.title("Paper Classification App")
+     st.write("Classify research papers using a DeBERTa model")
+
+     model, tokenizer, label_to_theme, device = load_assets()
+
    title = st.text_input("Title")
    abstract = st.text_area("Abstract")
+
    if st.button("Classify"):
        if not title and not abstract:
+             st.warning("Please enter a title and/or abstract")
            return

+         if not abstract:
+             text = title
+         elif not title:
+             text = abstract
+         else:
+             text = f"{title}\n\n{abstract}"
+
+         with st.spinner("Analyzing text..."):
+             probabilities = predict(text, model, tokenizer, device)
+             themes = get_themes(probabilities, label_to_theme)

        st.success("Classification results:")
        for theme, prob in themes:


if __name__ == "__main__":
+     main()
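
Quick way to sanity-check the refactored pipeline after this commit: the sketch below is a hypothetical smoke_test.py (not part of the commit) that imports the new helpers from app.py and runs them outside Streamlit, assuming label_to_theme.json, class_weights.pth and full_model_v4.pth sit next to app.py as load_assets() expects. Calling a function decorated with @st.cache_resource outside a Streamlit run should only log a cache warning and otherwise behave normally; the sample text is made up.

# smoke_test.py -- hypothetical helper, not part of this commit
from app import load_assets, predict, get_themes

model, tokenizer, label_to_theme, device = load_assets()

sample = (
    "Recurrent attention for VQA\n\n"
    "We propose an architecture for VQA which utilizes recurrent layers to "
    "generate visual and textual attention."
)

probs = predict(sample, model, tokenizer, device)   # per-label sigmoid probabilities
themes = get_themes(probs, label_to_theme)          # labels in descending order until cumulative prob >= 0.95

for theme, prob in themes:
    print(f"{theme}: {prob:.3f}")

Running streamlit run app.py afterwards exercises the same code path through main().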