# Spaces: Running
import streamlit as st | |
import json | |
import numpy as np | |
import sentencepiece | |
from pathlib import Path | |
import torch | |
from transformers import ( | |
DebertaV2Config, | |
DebertaV2Model, | |
DebertaV2Tokenizer, | |
) | |
# Hugging Face checkpoint name; used below to load the tokenizer.
MODEL_NAME = "microsoft/deberta-v3-base"
# Default length that preprocess_text pads/truncates every input to.
MAX_LENGTH = 512
# Default number of logits produced by the classification head.
NUM_LABELS = 47
# Default dropout probability of the classification head.
DROPOUT_RATE = 0.1
# NOTE(review): not referenced anywhere in the visible part of this file.
THRESHOLD = 0.5
class DebertaV3PaperClassifier(torch.nn.Module):
    """DeBERTa encoder followed by a small MLP head that emits one logit per label.

    The hidden state of the first token is passed through
    Dropout -> Linear(hidden_size, 512) -> LayerNorm -> GELU -> Dropout ->
    Linear(512, num_labels). When labels are supplied, the logits are scored
    with ``torch.nn.BCEWithLogitsLoss``.

    Args:
        device: Device the optional ``class_weights`` tensor is moved to.
        num_labels: Number of output logits.
        dropout_rate: Dropout probability of both dropout layers in the head.
        class_weights: Optional tensor forwarded as ``weight`` to the loss
            (presumably one weight per label -- confirm against training code).
        model_name: Hugging Face checkpoint used for both the config and the
            encoder weights. Defaults to ``MODEL_NAME``.
    """

    def __init__(
        self,
        device,
        num_labels=NUM_LABELS,
        dropout_rate=DROPOUT_RATE,
        class_weights=None,
        model_name=MODEL_NAME,
    ):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )
        # NOTE: BCEWithLogitsLoss keeps ``weight`` as a module buffer, so when
        # class weights are given they become part of this model's state_dict.
        # A checkpoint saved that way only loads (strictly) into a model that
        # was also built with class weights.
        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and the head.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional targets for the loss.

        Returns:
            ``logits`` when ``labels`` is None, otherwise ``(loss, logits)``.
        """
        encoder_out = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Only the first token's hidden state feeds the classification head.
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        logits = self.classifier(first_token_state)
        if labels is None:
            return logits
        # BCEWithLogitsLoss needs floating-point targets; cast so integer
        # 0/1 label tensors are accepted as well.
        loss = self.loss_fct(logits, labels.to(logits.dtype))
        return loss, logits

    def _init_weights(self):
        """Re-initialise the head's Linear layers (normal std=0.02, zero bias).

        Not invoked by ``__init__``; call it explicitly if wanted.
        """
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
# Streamlit re-executes the script on every widget interaction; caching keeps
# the model, tokenizer and lookup tables in memory instead of reloading them
# on each rerun.
@st.cache_resource
def load_assets():
    """Load the model, tokenizer and label metadata stored next to this file.

    Returns:
        Tuple ``(model, tokenizer, device, label_to_theme, theme_to_description)``
        where ``model`` is in eval mode on ``device`` (CUDA when available,
        otherwise CPU) and the two mappings are the parsed JSON files.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_path = Path(__file__).parent

    with open(base_path / "label_to_theme.json", encoding="utf-8") as f:
        label_to_theme = json.load(f)
    # NOTE(review): the spelling "descripiton" is kept as-is; it presumably
    # matches the file name on disk -- verify before renaming.
    with open(base_path / "theme_to_descripiton.json", encoding="utf-8") as f:
        theme_to_description = json.load(f)

    # map_location lets tensors saved from a GPU be loaded on a CPU-only host.
    class_weights = torch.load(base_path / "class_weights.pth", map_location=device)
    model = DebertaV3PaperClassifier(
        device=device,
        num_labels=NUM_LABELS,
        class_weights=class_weights,
    ).to(device)
    model.load_state_dict(torch.load(base_path / "full_model_v4.pth", map_location=device))
    model.eval()

    tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer, device, label_to_theme, theme_to_description
def preprocess_text(text, tokenizer, max_length=MAX_LENGTH):
    """Tokenize ``text`` into PyTorch tensors of fixed length.

    The input is truncated and padded to exactly ``max_length`` tokens.

    Args:
        text: Text to encode.
        tokenizer: Callable tokenizer accepting the Hugging Face keyword options.
        max_length: Target sequence length.

    Returns:
        Whatever ``tokenizer`` returns for these options.
    """
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    return tokenizer(text, **tokenizer_options)
def predict(text: str, model, tokenizer, device) -> np.ndarray:
    """Return the per-label sigmoid probabilities the model assigns to ``text``.

    Args:
        text: Text to classify.
        model: Model called with ``input_ids`` and ``attention_mask``; it must
            return the logits tensor.
        tokenizer: Tokenizer handed to ``preprocess_text``.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        NumPy array holding the probabilities of the first (only) batch row.
    """
    inputs = preprocess_text(text, tokenizer)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.sigmoid(logits)
    return probs.cpu().numpy()[0]
def get_themes(probs: np.ndarray, label_to_theme: dict, cumulative_threshold: float = 0.95) -> list:
    """Return the most probable themes up to a cumulative-probability cutoff.

    Labels are visited in order of decreasing probability and collected until
    the running sum of their probabilities reaches ``cumulative_threshold``.
    If the sum never reaches the cutoff, every label is returned.

    Args:
        probs: 1-D array of per-label probabilities.
        label_to_theme: Maps the label index, as a string, to a theme name.
        cumulative_threshold: Running-sum value at which collection stops.

    Returns:
        List of ``(theme, probability)`` tuples, highest probability first.

    Raises:
        KeyError: If a visited label index is missing from ``label_to_theme``.
    """
    # Negating the array makes argsort yield indices in descending order.
    descending_indices = np.argsort(-probs)
    themes = []
    cumulative = 0
    for idx in descending_indices:
        # JSON object keys are strings, hence the str() conversion.
        themes.append((label_to_theme[str(idx)], probs[idx]))
        cumulative += probs[idx]
        if cumulative >= cumulative_threshold:
            break
    return themes
def main():
    """Render the Streamlit page: collect title/abstract, classify, show themes."""
    st.title("Paper Classification App")
    st.write("Classify research papers using DeBERTa model")
    model, tokenizer, device, label_to_theme, theme_to_description = load_assets()

    title = st.text_input("Title")
    abstract = st.text_area("Abstract")

    if not st.button("Classify"):
        return

    # Streamlit text widgets return an empty string (not None) when left
    # blank, so filter on truthiness. Whitespace-only input counts as empty.
    parts = [part.strip() for part in (title, abstract) if part and part.strip()]
    if not parts:
        st.warning("Please enter title and/or abstract")
        return
    # Title only, abstract only, or both separated by a blank line.
    text = "\n\n".join(parts)

    with st.spinner("Analyzing text..."):
        probabilities = predict(text, model, tokenizer, device)
        themes = get_themes(probabilities, label_to_theme)

    st.success("Predicted themes (click to expand):")
    for theme, prob in themes:
        with st.expander(f"{theme} ({prob:.1%})"):
            st.markdown(f"**Description**: {theme_to_description[theme]}")


if __name__ == "__main__":
    main()