bvd757 committed on
Commit
283dd4a
·
1 Parent(s): 31d74a6

Initial commit

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +173 -0
  3. class_weights.pth +0 -0
  4. full_model_v4.pth +3 -0
  5. label_to_theme.json +1 -0
  6. requirements.txt +121 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ full_model_v4.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import numpy as np
4
+ import torch
5
+ from transformers import (
6
+ DebertaV2Config,
7
+ DebertaV2Model,
8
+ DebertaV2Tokenizer,
9
+ )
10
+ import sentencepiece
11
+
12
# Hugging Face checkpoint identifier used for both the tokenizer below and
# the encoder inside DebertPaperClassifier.
model_name = "microsoft/deberta-v3-base"

# Created once at import time; this call fetches/loads the tokenizer files
# for `model_name`.
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
14
+
15
def preprocess_text(text, tokenizer, max_length=512):
    """Tokenize ``text`` into fixed-length PyTorch tensors.

    Args:
        text: Input forwarded to ``tokenizer`` as its first positional argument.
        tokenizer: Callable tokenizer; it receives the padding/truncation
            options listed below as keyword arguments.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        Whatever ``tokenizer`` returns for the given options (the callers in
        this file read its ``"input_ids"`` and ``"attention_mask"`` entries).
    """
    tokenizer_options = {
        "padding": "max_length",   # always pad up to `max_length`
        "truncation": True,        # cut longer inputs down to `max_length`
        "max_length": max_length,
        "return_tensors": "pt",    # ask for PyTorch tensors
    }
    return tokenizer(text, **tokenizer_options)
24
+
25
+
26
def classify_text(text, model, tokenizer, device, threshold=0.5):
    """Run ``model`` on ``text`` and return per-label probabilities and decisions.

    Args:
        text: Input text; tokenized via ``preprocess_text``.
        model: Callable taking ``(input_ids, attention_mask)`` and returning
            logits; it is switched to eval mode before the call.
        tokenizer: Tokenizer forwarded to ``preprocess_text``.
        device: Device the input tensors are moved to.
        threshold: Probability above which a label counts as predicted.

    Returns:
        ``(probabilities, predictions)`` as NumPy arrays: the sigmoid of the
        logits, and an integer 0/1 mask of ``probabilities > threshold``.
    """
    encoded = preprocess_text(text, tokenizer)
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        probs = torch.sigmoid(model(ids, mask))
        predictions = (probs > threshold).int().cpu().numpy()

    return probs.cpu().numpy(), predictions
37
+
38
def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
    """Return the ``limit`` most probable themes for ``text``, best first.

    Args:
        text: Input text to classify.
        model: Classifier forwarded to ``classify_text``.
        tokenizer: Tokenizer forwarded to ``classify_text``.
        label_to_theme: Mapping from the *string* form of a label index
            (e.g. ``"45"``) to a theme name.
        device: Device forwarded to ``classify_text``.
        limit: Maximum number of themes to return. Values ``<= 0`` yield an
            empty list.

    Returns:
        A list of ``(theme_name, probability)`` tuples ordered from the most
        to the least probable.
    """
    # A slice of ``[-0:]`` would select *every* label, so handle it up front.
    if limit <= 0:
        return []

    probabilities, _ = classify_text(text, model, tokenizer, device)
    scores = probabilities[0]  # first (only) row of the batch

    # argsort is ascending: keep the last `limit` indices, then reverse so
    # the most probable theme comes first.
    top_labels = scores.argsort()[-limit:][::-1]
    return [(label_to_theme[str(label)], scores[label]) for label in top_labels]
44
+
45
class _DebertaClassifierBase(torch.nn.Module):
    """Forward pass and head initialisation shared by the classifiers below.

    Subclasses must assign ``self.deberta`` (the encoder), ``self.classifier``
    (the head) and ``self.loss_fct`` (the loss) in their ``__init__``.
    """

    def _init_weights(self):
        """Reset every ``Linear`` layer in ``self.classifier``.

        Weights are drawn from a normal distribution (mean 0, std 0.02) and
        biases, when present, are zeroed.
        """
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute logits, and the loss when ``labels`` is given.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            labels: Optional targets passed to ``self.loss_fct`` together
                with the logits.

        Returns:
            ``logits`` when ``labels`` is ``None``, otherwise the tuple
            ``(loss, logits)``.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # The head only sees the hidden state at sequence position 0.
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)

        if labels is None:
            return logits
        return self.loss_fct(logits, labels), logits


class DebertPaperClassifier(_DebertaClassifierBase):
    """DeBERTa encoder with a two-layer classification head.

    The encoder is loaded from the module-level ``model_name`` checkpoint and
    the head's linear layers are re-initialised with ``_init_weights``.

    Args:
        num_labels: Size of the output layer.
        device: Device ``class_weights`` is moved to.
        dropout_rate: Dropout probability used twice inside the head.
        class_weights: Optional tensor passed as ``weight`` to
            ``BCEWithLogitsLoss``.
    """

    def __init__(self, num_labels, device, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )

        self._init_weights()

        # NOTE(review): the loss is registered as a submodule, so whether it
        # was built with weights presumably has to match the checkpoint that
        # is later loaded into this model - confirm against the saved file.
        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()


class DebertPaperClassifierV5(_DebertaClassifierBase):
    """Variant of ``DebertPaperClassifier`` with a different constructor.

    Differences visible in the constructor: ``device`` comes first,
    ``num_labels`` defaults to 47, the checkpoint name is hard-coded, and the
    head keeps PyTorch's default initialisation (``_init_weights`` is not
    called).

    Args:
        device: Device ``class_weights`` is moved to.
        num_labels: Size of the output layer.
        dropout_rate: Dropout probability used twice inside the head.
        class_weights: Optional tensor passed as ``weight`` to
            ``BCEWithLogitsLoss``.
    """

    def __init__(self, device, num_labels=47, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained("microsoft/deberta-v3-base")
        self.deberta = DebertaV2Model.from_pretrained("microsoft/deberta-v3-base", config=self.config)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )

        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()
122
+
123
@st.cache_resource
def load_model(test=False):
    """Load the label map, class weights and trained classifier.

    The asset directory is the first of these that contains a readable
    ``label_to_theme.json``: ``Homework4`` (relative to the working
    directory), the original author's absolute path, and finally the
    directory of this script.

    Args:
        test: When true, print the device and model and run one sample
            abstract through ``get_themes`` as a smoke test.

    Returns:
        ``(model, tokenizer, label_to_theme, device)``, where ``tokenizer``
        is the module-level tokenizer.

    Raises:
        FileNotFoundError: If no candidate directory holds the label map.
    """
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    candidates = [
        'Homework4',
        '/Users/bvd757/my_documents/Машинное обучение 2/Howework4',
    ]
    # Last resort: look next to this script, which is where the assets live
    # when the repository is checked out as-is.
    script_file = globals().get("__file__")
    if script_file:
        candidates.append(os.path.dirname(os.path.abspath(script_file)))

    for path in candidates:
        try:
            with open(f'{path}/label_to_theme.json', 'r') as f:
                label_to_theme = json.load(f)
        except OSError:
            # Missing or unreadable here; try the next location. Malformed
            # JSON is deliberately NOT swallowed - it signals a real problem.
            continue
        break
    else:
        raise FileNotFoundError(
            f"label_to_theme.json not found in any of: {candidates}"
        )

    # NOTE: torch.load unpickles its input; only load checkpoints you trust.
    # map_location keeps tensors saved on another device loadable here.
    class_weights = torch.load(f'{path}/class_weights.pth', map_location=device).to(device)

    model = DebertPaperClassifier(
        device=device,
        num_labels=len(label_to_theme),
        class_weights=class_weights,
    ).to(device)
    model.load_state_dict(torch.load(f"{path}/full_model_v4.pth", map_location=device))

    if test:
        print(device)
        print(model)
        print("Model!!!")
        text = 'We propose an architecture for VQA which utilizes recurrent layers to\ngenerate visual and textual attention. The memory characteristic of the\nproposed recurrent attention units offers a rich joint embedding of visual and\ntextual features and enables the model to reason relations between several\nparts of the image and question. Our single model outperforms the first place\nwinner on the VQA 1.0 dataset, performs within margin to the current\nstate-of-the-art ensemble model. We also experiment with replacing attention\nmechanisms in other state-of-the-art models with our implementation and show\nincreased accuracy. In both cases, our recurrent attention mechanism improves\nperformance in tasks requiring sequential or relational reasoning on the VQA\ndataset.'
        print(get_themes(text, model, tokenizer, label_to_theme, device))
    return model, tokenizer, label_to_theme, device
146
+
147
def kek(limit=5):
    """Render the Streamlit form and show classification results.

    Draws a title input, an abstract text area and a "Classify" button. When
    the button is pressed with at least one field filled in, the text is
    classified and the resulting themes are listed with their probabilities.

    Args:
        limit: Number of themes requested from ``get_themes``.
    """
    title = st.text_input("Title")
    abstract = st.text_area("Abstract")

    if not st.button("Classify"):
        return

    if not title and not abstract:
        st.warning("Please enter an abstract")
        return

    # Join both fields when both are present, otherwise use whichever exists.
    text = f"{title}\n\n{abstract}" if title and abstract else title or abstract
    model, tokenizer, label_to_theme, device = load_model()

    with st.spinner("Classifying..."):
        # The original passed an undefined name here, which raised NameError
        # on every click; the count is now an explicit parameter.
        themes = get_themes(text, model, tokenizer, label_to_theme, device, limit)

    st.success("Classification results:")
    for theme, prob in themes:
        st.write(f"- {theme}: {prob:.2%}")
166
+
167
+
168
if __name__ == "__main__":
    import sys

    # The UI is the default entry point. Pass "--self-test" (for example
    # `streamlit run app.py -- --self-test` or `python app.py --self-test`)
    # to load the model and print a sample prediction instead. The previous
    # hard-coded switch made the UI branch unreachable.
    if "--self-test" in sys.argv[1:]:
        load_model(True)
    else:
        kek()
class_weights.pth ADDED
Binary file (1.53 kB). View file
 
full_model_v4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a8bd37371fde6f8e3c74815a6393577db39a188aca551475c862d70c67c98b3
3
+ size 737088702
label_to_theme.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "cs.AI", "1": "physics.soc-ph", "2": "stat.ML", "3": "cs.CE", "4": "cs.DB", "5": "cs.CL", "6": "cs.NA", "7": "cs.CY", "8": "cs.GT", "9": "cs.SI", "10": "stat.AP", "11": "cs.DL", "12": "math.ST", "13": "nlin.AO", "14": "cs.LO", "15": "cs.MM", "16": "cond-mat.dis-nn", "17": "cs.DM", "18": "cs.CC", "19": "stat.CO", "20": "cs.DC", "21": "cs.IT", "22": "cs.DS", "23": "cs.SY", "24": "q-bio.QM", "25": "cs.PL", "26": "cs.RO", "27": "cs.NE", "28": "cs.CR", "29": "cs.MA", "30": "q-bio.NC", "31": "cs.LG", "32": "cs.GR", "33": "physics.data-an", "34": "quant-ph", "35": "cs.IR", "36": "math.NA", "37": "math.PR", "38": "stat.ME", "39": "cs.SE", "40": "math.OC", "41": "math.IT", "42": "cs.HC", "43": "stat.TH", "44": "cs.NI", "45": "cs.CV", "46": "cs.SD"}
requirements.txt ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.11.14
3
+ aiosignal==1.3.2
4
+ altair==5.5.0
5
+ annotated-types==0.7.0
6
+ appnope==0.1.4
7
+ asttokens==3.0.0
8
+ async-timeout==5.0.1
9
+ attrs==25.3.0
10
+ blinker==1.9.0
11
+ cachetools==5.5.2
12
+ certifi==2025.1.31
13
+ charset-normalizer==3.4.1
14
+ click==8.1.8
15
+ colored==2.3.0
16
+ comm==0.2.2
17
+ contourpy==1.3.0
18
+ cycler==0.12.1
19
+ datasets==3.5.0
20
+ debugpy==1.8.12
21
+ decorator==5.2.1
22
+ dill==0.3.8
23
+ distlib==0.3.9
24
+ docker-pycreds==0.4.0
25
+ einops==0.8.1
26
+ eval_type_backport==0.2.2
27
+ exceptiongroup==1.2.2
28
+ executing==2.2.0
29
+ filelock==3.17.0
30
+ fonttools==4.56.0
31
+ frozenlist==1.5.0
32
+ fsspec==2024.12.0
33
+ gitdb==4.0.12
34
+ GitPython==3.1.44
35
+ huggingface-hub==0.29.3
36
+ idna==3.10
37
+ imageio==2.37.0
38
+ importlib_metadata==8.6.1
39
+ importlib_resources==6.5.2
40
+ ipykernel==6.29.5
41
+ ipython==8.18.1
42
+ jedi==0.19.2
43
+ Jinja2==3.1.5
44
+ joblib==1.4.2
45
+ jsonschema==4.23.0
46
+ jsonschema-specifications==2024.10.1
47
+ jupyter_client==8.6.3
48
+ jupyter_core==5.7.2
49
+ kiwisolver==1.4.7
50
+ lazy_loader==0.4
51
+ MarkupSafe==3.0.2
52
+ matplotlib==3.9.4
53
+ matplotlib-inline==0.1.7
54
+ mpmath==1.3.0
55
+ multidict==6.2.0
56
+ multiprocess==0.70.16
57
+ narwhals==1.33.0
58
+ nest-asyncio==1.6.0
59
+ networkx==3.2.1
60
+ numpy==2.0.2
61
+ opencv-python==4.11.0.86
62
+ packaging==24.2
63
+ pandas==2.2.3
64
+ parso==0.8.4
65
+ pexpect==4.9.0
66
+ pillow==11.1.0
67
+ platformdirs==4.3.6
68
+ prompt_toolkit==3.0.50
69
+ propcache==0.3.1
70
+ protobuf==5.29.4
71
+ psutil==7.0.0
72
+ ptyprocess==0.7.0
73
+ pure_eval==0.2.3
74
+ pyarrow==19.0.1
75
+ pydantic==2.11.1
76
+ pydantic_core==2.33.0
77
+ pydeck==0.9.1
78
+ Pygments==2.19.1
79
+ pyparsing==3.2.1
80
+ python-dateutil==2.9.0.post0
81
+ pytz==2025.1
82
+ PyYAML==6.0.2
83
+ pyzmq==26.2.1
84
+ referencing==0.36.2
85
+ regex==2024.11.6
86
+ requests==2.32.3
87
+ rpds-py==0.24.0
88
+ safetensors==0.5.3
89
+ scikit-image==0.24.0
90
+ scikit-learn==1.6.1
91
+ scipy==1.13.1
92
+ seaborn==0.13.2
93
+ sentencepiece==0.2.0
94
+ sentry-sdk==2.24.1
95
+ setproctitle==1.3.5
96
+ smmap==5.0.2
97
+ stack-data==0.6.3
98
+ streamlit==1.44.1
99
+ sympy==1.13.1
100
+ tenacity==9.1.2
101
+ termcolor==2.5.0
102
+ threadpoolctl==3.5.0
103
+ tifffile==2024.8.30
104
+ tokenizers==0.21.1
105
+ toml==0.10.2
106
+ torch==2.2.0
107
+ torchvision==0.17.0
108
+ tornado==6.4.2
109
+ tqdm==4.67.1
110
+ traitlets==5.14.3
111
+ transformers==4.50.3
112
+ typing-inspection==0.4.0
113
+ typing_extensions==4.12.2
114
+ tzdata==2025.1
115
+ urllib3==2.3.0
116
+ virtualenv==20.30.0
117
+ wandb==0.19.8
118
+ wcwidth==0.2.13
119
+ xxhash==3.5.0
120
+ yarl==1.18.3
121
+ zipp==3.21.0