#!/usr/bin/python3
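"""NlpDev / app.py

Streamlit app that classifies arXiv paper titles/abstracts into subject themes
with a fine-tuned DeBERTa-v3 multi-label classifier.
"""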
import streamlit as st
import json
import numpy as np
import torch
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2Tokenizer,
)
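# DeBERTa-v3 checkpoints are served through the DebertaV2* classes in transformers,
# so these are the right imports for "microsoft/deberta-v3-base".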
model_name = "microsoft/deberta-v3-base"
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)
def preprocess_text(text, tokenizer, max_length=512):
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    return inputs
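# The returned BatchEncoding holds "input_ids" and "attention_mask", each a
# (1, max_length) tensor because of padding="max_length" and return_tensors="pt".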
def classify_text(text, model, tokenizer, device, threshold=0.5):
    inputs = preprocess_text(text, tokenizer)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    # Multi-label head: independent sigmoid per label, thresholded at `threshold`.
    # Move to CPU before converting to NumPy so this also works on a GPU device.
    probs = torch.sigmoid(logits).cpu()
    predictions = (probs > threshold).int().numpy()
    return probs.numpy(), predictions
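# Rough usage sketch (assuming a loaded model, tokenizer and device):
#   probs, preds = classify_text("Attention Is All You Need", model, tokenizer, device)
#   probs has shape (1, num_labels); preds is the 0/1 multi-hot vector at threshold 0.5.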
def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
    probabilities, _ = classify_text(text, model, tokenizer, device)
    # Renormalise the per-label sigmoid scores so the displayed values sum to 1.
    probabilities = probabilities / probabilities.sum()
    themes = []
    # argsort is ascending, so the last `limit` indices are the top labels,
    # ordered from least to most probable.
    for label in probabilities[0].argsort()[-limit:]:
        themes.append((label_to_theme[str(label)], probabilities[0][label]))
    return themes
class DebertPaperClassifier(torch.nn.Module):
    def __init__(self, num_labels, device, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels)
        )
        self._init_weights()
        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def _init_weights(self):
        # Re-initialise only the classification head; the DeBERTa backbone keeps
        # its pretrained weights.
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Use the hidden state of the first ([CLS]) token as the sequence representation.
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits, labels)
        return (loss, logits) if loss is not None else logits


class DebertPaperClassifierV5(torch.nn.Module):
    def __init__(self, device, num_labels=47, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained("microsoft/deberta-v3-base")
        self.deberta = DebertaV2Model.from_pretrained("microsoft/deberta-v3-base", config=self.config)
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels)
        )
        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits, labels)
        return (loss, logits) if loss is not None else logits
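# Only DebertPaperClassifierV5 is instantiated in load_model() below;
# DebertPaperClassifier appears to be an earlier revision of the same
# classification head and is kept for reference.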
@st.cache_resource
def load_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open('model_info/label_to_theme.json', 'r') as f:
        label_to_theme = json.load(f)
    model = DebertPaperClassifierV5(device=device, num_labels=len(label_to_theme)).to(device)
    model.load_state_dict(torch.load("model_info/deberta_v3.pth", map_location=device))
    return model, tokenizer, label_to_theme, device
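# st.cache_resource keeps the model, tokenizer, label map and device cached for
# the lifetime of the Streamlit server process, so the checkpoint is loaded once
# rather than on every rerun.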
def main():
    st.title("arXiv Paper Classifier")
    st.markdown("""
    <style>
    .image-row {
        display: flex;
        flex-direction: row;
        gap: 10px;
    }
    </style>
    <div class="image-row">
        <img width=100px src='https://storage.yandexcloud.net/lms-vault/media/cache/c9/a7/c9a754ba1b2bb5b34e1f178d4ec26f24.jpg'>
        <img width=300px src='https://pic.rutubelist.ru/video/ba/b6/bab6ab515c15837e28eb6c99df192cae.jpg'>
    </div>
    """, unsafe_allow_html=True)
    st.write("Enter a title and/or abstract to classify the paper's themes.")
    title = st.text_input("Title")
    abstract = st.text_area("Abstract")
    lim = int(st.number_input("How many top themes to show", min_value=1, value=5, step=1))
    if st.button("Classify"):
        if not title and not abstract:
            st.warning("Please enter a title or an abstract.")
            return
        text = f"{title}\n\n{abstract}" if title and abstract else title or abstract
        model, tokenizer, label_to_theme, device = load_model()
        with st.spinner("Classifying..."):
            themes = get_themes(text, model, tokenizer, label_to_theme, device, lim)
        st.success(f"Top {lim} results:")
        # get_themes returns themes in ascending probability order, so reverse
        # them to print the most likely theme first.
        for rank, (theme, prob) in enumerate(reversed(themes), start=1):
            st.write(f"{rank}. {theme}: {prob:.1%}")
if __name__ == "__main__":
    main()