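"""Streamlit app that predicts research-paper themes from a title and/or abstract
using a fine-tuned DeBERTa-v3-base multi-label classifier."""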
import streamlit as st
import json
import numpy as np
import sentencepiece  # not used directly, but required by the slow DebertaV2Tokenizer
from pathlib import Path
import torch
from transformers import (
DebertaV2Config,
DebertaV2Model,
DebertaV2Tokenizer,
)
MODEL_NAME = "microsoft/deberta-v3-base"
MAX_LENGTH = 512
NUM_LABELS = 47
DROPOUT_RATE = 0.1
THRESHOLD = 0.5
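# NUM_LABELS is the number of theme classes; THRESHOLD is kept for reference but is
# unused below (get_themes ranks labels by cumulative probability instead).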
class DebertaV3PaperClassifier(torch.nn.Module):
def __init__(self, device, num_labels=NUM_LABELS, dropout_rate=DROPOUT_RATE, class_weights=None):
super().__init__()
        self.config = DebertaV2Config.from_pretrained(MODEL_NAME)
        self.deberta = DebertaV2Model.from_pretrained(MODEL_NAME, config=self.config)
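        # Classification head: dropout -> 512-unit projection with LayerNorm/GELU -> per-label logits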
self.classifier = torch.nn.Sequential(
torch.nn.Dropout(dropout_rate),
torch.nn.Linear(self.config.hidden_size, 512),
torch.nn.LayerNorm(512),
torch.nn.GELU(),
torch.nn.Dropout(dropout_rate),
torch.nn.Linear(512, num_labels)
)
if class_weights is not None:
self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
else:
self.loss_fct = torch.nn.BCEWithLogitsLoss()
def forward(self, input_ids, attention_mask, labels=None):
outputs = self.deberta(
input_ids=input_ids,
attention_mask=attention_mask
)
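        # Use the hidden state of the first ([CLS]) token as the sequence representation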
logits = self.classifier(outputs.last_hidden_state[:, 0, :])
loss = None
if labels is not None:
loss = self.loss_fct(logits, labels)
return (loss, logits) if loss is not None else logits
def _init_weights(self):
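        """Reinitialize the classifier head (BERT-style init, std=0.02).

        Not called in this app: all weights are restored from the checkpoint in load_assets().
        """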
for module in self.classifier.modules():
if isinstance(module, torch.nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
@st.cache_resource
def load_assets():
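    """Load the model, tokenizer, label map, and theme descriptions once per process (cached by Streamlit)."""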
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_path = Path(__file__).parent
with open(base_path/"label_to_theme.json") as f:
label_to_theme = json.load(f)
with open(base_path/"theme_to_descripiton.json") as f:
theme_to_description = json.load(f)
    # Class weights only affect the training loss; they are loaded here so the module
    # layout stays consistent with the saved state_dict.
    class_weights = torch.load(base_path / "class_weights.pth", map_location=device)
    model = DebertaV3PaperClassifier(device=device, num_labels=NUM_LABELS, class_weights=class_weights).to(device)
    model.load_state_dict(torch.load(base_path / "full_model_v4.pth", map_location=device))
model.eval()
tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_NAME)
return model, tokenizer, device, label_to_theme, theme_to_description
def preprocess_text(text, tokenizer, max_length=MAX_LENGTH):
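    """Tokenize the input text, padding and truncating to MAX_LENGTH tokens."""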
inputs = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=max_length,
return_tensors="pt"
)
return inputs
def predict(text: str, model, tokenizer, device) -> np.ndarray:
    """Run the model on the input text and return per-label probabilities."""
inputs = preprocess_text(text, tokenizer)
with torch.no_grad():
logits = model(
input_ids=inputs["input_ids"].to(device),
attention_mask=inputs["attention_mask"].to(device)
)
probs = torch.sigmoid(logits).cpu().numpy()[0]
return probs
def get_themes(probs: np.ndarray, label_to_theme: dict) -> list:
    """Return (theme, probability) pairs in descending order until 95% cumulative probability is covered."""
sorted_indices = np.argsort(-probs)
labels = []
sum_percent = 0
for idx in sorted_indices:
labels.append((label_to_theme[str(idx)], probs[idx]))
sum_percent += probs[idx]
if sum_percent >= 0.95:
break
return labels
def main():
st.title("Paper Classification App")
st.write("Classify research papers using DeBERTa model")
model, tokenizer, device, label_to_theme, theme_to_description = load_assets()
title = st.text_input("Title")
abstract = st.text_area("Abstract")
if st.button("Classify"):
        if not title and not abstract:
            st.warning("Please enter a title and/or an abstract")
            return
        # st.text_input/st.text_area return empty strings, never None
        if not abstract:
            text = title
        elif not title:
            text = abstract
        else:
            text = f"{title}\n\n{abstract}"
with st.spinner("Analyzing text..."):
probabilities = predict(text, model, tokenizer, device)
themes = get_themes(probabilities, label_to_theme)
st.success("Predicted themes (click to expand):")
for theme, prob in themes:
with st.expander(f"{theme} ({prob:.1%})"):
st.markdown(f"**Description**: {theme_to_description[theme]}")
if __name__ == "__main__":
main()