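"""Streamlit app that classifies arXiv paper titles/abstracts into topic
themes with a DeBERTa-v3-based multi-label classifier."""
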
import json

import numpy as np
import streamlit as st
import torch
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2Tokenizer,
)

model_name = "microsoft/deberta-v3-base"
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)


def preprocess_text(text, tokenizer, max_length=512):
    # Tokenize one string, padded/truncated to a fixed length.
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return inputs
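
# Example (sketch): preprocess_text("Attention Is All You Need", tokenizer)
# returns a dict whose "input_ids" and "attention_mask" tensors both have
# shape (1, 512), matching the fixed max_length above.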


def classify_text(text, model, tokenizer, device, threshold=0.5):
    inputs = preprocess_text(text, tokenizer)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    probs = torch.sigmoid(logits)
    # Move to CPU before converting; .numpy() raises on a CUDA tensor.
    probs = probs.cpu()
    predictions = (probs > threshold).int().numpy()
    return probs.numpy(), predictions
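
# Note: each label gets an independent sigmoid score, so one paper can clear
# the 0.5 threshold for several themes at once (multi-label classification).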


def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
    probabilities, _ = classify_text(text, model, tokenizer, device)
    # Renormalize the per-label sigmoid scores so they sum to 1 across labels.
    probabilities = probabilities / probabilities.sum()
    themes = []
    # argsort is ascending, so this picks the top `limit` labels,
    # ordered from least to most probable.
    for label in probabilities[0].argsort()[-limit:]:
        themes.append((label_to_theme[str(label)], probabilities[0][label]))
    return themes
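
# Example (sketch, with made-up labels): a call like
#   get_themes(text, model, tokenizer, label_to_theme, device, limit=3)
# might return [("stat.ML", 0.08), ("cs.CL", 0.21), ("cs.LG", 0.35)],
# ordered from least to most probable.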


class DebertPaperClassifier(torch.nn.Module):
    def __init__(self, num_labels, device, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)

        # Classification head on top of the [CLS] representation.
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )

        self._init_weights()
        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def _init_weights(self):
        # Initialize only the classification head; the DeBERTa backbone
        # keeps its pretrained weights.
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Classify from the hidden state of the first ([CLS]) token.
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits, labels)
        return (loss, logits) if loss is not None else logits


class DebertPaperClassifierV5(torch.nn.Module):
    def __init__(self, device, num_labels=47, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )

        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits, labels)
        return (loss, logits) if loss is not None else logits
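
# Minimal usage sketch outside Streamlit (assumes the checkpoint and label map
# exist at the paths that load_model uses below):
#   device = torch.device("cpu")
#   clf = DebertPaperClassifierV5(device=device, num_labels=47).to(device)
#   clf.load_state_dict(torch.load("model_info/deberta_v3.pth", map_location=device))
#   probs, preds = classify_text("some abstract text", clf, tokenizer, device)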


@st.cache_resource
def load_model():
    # Cached by Streamlit so the model is loaded once per process.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("model_info/label_to_theme.json", "r") as f:
        label_to_theme = json.load(f)

    model = DebertPaperClassifierV5(device=device, num_labels=len(label_to_theme)).to(device)
    model.load_state_dict(torch.load("model_info/deberta_v3.pth", map_location=device))
    return model, tokenizer, label_to_theme, device


def main():
    st.title("arXiv Paper Classifier")
    st.markdown("""
        <style>
        .image-row {
            display: flex;
            flex-direction: row;
            gap: 10px;
        }
        </style>

        <div class="image-row">
            <img width=100px src='https://storage.yandexcloud.net/lms-vault/media/cache/c9/a7/c9a754ba1b2bb5b34e1f178d4ec26f24.jpg'>
            <img width=300px src='https://pic.rutubelist.ru/video/ba/b6/bab6ab515c15837e28eb6c99df192cae.jpg'>
        </div>
    """, unsafe_allow_html=True)
    st.write("Enter a title and/or an abstract to classify the paper's topic themes.")

    title = st.text_input("Title")
    abstract = st.text_area("Abstract")
    # Require at least one theme: a limit of 0 would make argsort()[-0:]
    # select every label.
    lim = int(st.number_input("Number of top themes to show", min_value=1, value=5, step=1))

    if st.button("Classify"):
        if not title and not abstract:
            st.warning("Please enter a title or an abstract.")
            return

        text = f"{title}\n\n{abstract}" if title and abstract else title or abstract
        model, tokenizer, label_to_theme, device = load_model()

        with st.spinner("Classifying..."):
            themes = get_themes(text, model, tokenizer, label_to_theme, device, lim)

        st.success(f"Top {lim} themes:")
        # get_themes returns themes in ascending probability order; reverse so
        # the most likely theme is shown first.
        for rank, (theme, prob) in enumerate(reversed(themes), start=1):
            st.write(f"{rank}. {theme}: {prob:.1%}")


if __name__ == "__main__":
    main()
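
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py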