"""Streamlit app that classifies a paper (title and/or abstract) into themes.

A DeBERTa-v3 encoder with a small MLP head produces one logit per theme;
the logits are passed through a sigmoid and the highest-scoring themes are
shown together with their descriptions.
"""

import json
from pathlib import Path

import numpy as np
import sentencepiece  # noqa: F401  # unused here; NOTE(review): presumably kept so a missing tokenizer backend fails at import time - confirm
import streamlit as st
import torch
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2Tokenizer,
)

MODEL_NAME = "microsoft/deberta-v3-base"
MAX_LENGTH = 512
NUM_LABELS = 47
DROPOUT_RATE = 0.1
THRESHOLD = 0.5  # not referenced anywhere in this file


class DebertaV3PaperClassifier(torch.nn.Module):
    """DeBERTa-v3 encoder followed by an MLP head over the first token.

    The head is Dropout -> Linear(hidden, 512) -> LayerNorm -> GELU ->
    Dropout -> Linear(512, num_labels). Training loss is
    ``BCEWithLogitsLoss``, optionally weighted by ``class_weights``.
    """

    def __init__(self, device, num_labels=NUM_LABELS, dropout_rate=DROPOUT_RATE, class_weights=None):
        """Build the encoder, the classification head and the loss.

        Args:
            device: Device the class weights are moved to.
            num_labels: Number of output logits.
            dropout_rate: Dropout probability used twice in the head.
            class_weights: Optional tensor passed as ``weight`` to
                ``BCEWithLogitsLoss``; ``None`` gives an unweighted loss.
        """
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(MODEL_NAME)
        self.deberta = DebertaV2Model.from_pretrained(MODEL_NAME, config=self.config)
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )
        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and the head.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional targets for the loss.
                NOTE(review): presumably float multi-hot with the same shape
                as the logits - confirm against the training code.

        Returns:
            ``logits`` when ``labels`` is ``None``, otherwise the tuple
            ``(loss, logits)``.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Only the hidden state at sequence position 0 feeds the head.
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])
        if labels is None:
            return logits
        loss = self.loss_fct(logits, labels)
        return (loss, logits)

    def _init_weights(self):
        """Re-initialise the head's Linear layers (normal std=0.02, zero bias).

        Not called anywhere in this file; the app loads trained weights
        from a checkpoint instead.
        """
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()


@st.cache_resource
def load_assets():
    """Load the model, tokenizer and label metadata stored next to this file.

    Wrapped in ``st.cache_resource`` so the result is cached by Streamlit.

    Returns:
        Tuple ``(model, tokenizer, device, label_to_theme,
        theme_to_description)``; the model is on ``device`` in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_path = Path(__file__).parent

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(base_path / "label_to_theme.json", encoding="utf-8") as f:
        label_to_theme = json.load(f)
    # NOTE: the misspelled filename is kept on purpose; presumably it matches
    # the file shipped with the app - verify before renaming.
    with open(base_path / "theme_to_descripiton.json", encoding="utf-8") as f:
        theme_to_description = json.load(f)

    # map_location lets weights saved on a GPU load on a CPU-only host.
    class_weights = torch.load(f"{base_path}/class_weights.pth", map_location=device)

    # NOTE(review): class_weights is passed so the loss module is built the
    # same way as at training time; presumably the checkpoint contains the
    # matching `loss_fct.weight` entry - confirm.
    model = DebertaV3PaperClassifier(
        device=device,
        num_labels=NUM_LABELS,
        class_weights=class_weights,
    ).to(device)
    model.load_state_dict(torch.load(f"{base_path}/full_model_v4.pth", map_location=device))
    model.eval()

    tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer, device, label_to_theme, theme_to_description


def preprocess_text(text, tokenizer, max_length=MAX_LENGTH):
    """Tokenize ``text`` into PyTorch tensors padded/truncated to ``max_length``."""
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )


def predict(text: str, model, tokenizer, device) -> np.ndarray:
    """Return the per-label sigmoid probabilities for a single text.

    Args:
        text: Input text to classify.
        model: Callable returning logits for ``input_ids``/``attention_mask``.
        tokenizer: Tokenizer passed to ``preprocess_text``.
        device: Device the input tensors are moved to.

    Returns:
        NumPy array with the probabilities of the first (only) batch row.
    """
    inputs = preprocess_text(text, tokenizer)
    with torch.no_grad():
        logits = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
        )
    return torch.sigmoid(logits).cpu().numpy()[0]


def get_themes(probs: np.ndarray, label_to_theme: dict, cutoff: float = 0.95) -> list:
    """Pick the most probable themes until their scores add up to ``cutoff``.

    Labels are visited in descending probability order; each one is added
    to the result and the walk stops as soon as the running sum of the
    collected probabilities reaches ``cutoff``. At least one theme is
    returned for non-empty ``probs``.

    Args:
        probs: 1-D array of per-label probabilities.
        label_to_theme: Maps the label index as a string (e.g. ``"3"``) to
            the theme name.
        cutoff: Cumulative probability at which to stop (default 0.95).

    Returns:
        List of ``(theme, probability)`` tuples, highest probability first.
    """
    themes = []
    cumulative = 0.0
    for idx in np.argsort(-probs):
        # Plain Python floats keep the running sum independent of the
        # array's dtype.
        prob = float(probs[idx])
        themes.append((label_to_theme[str(idx)], prob))
        cumulative += prob
        if cumulative >= cutoff:
            break
    return themes


def main():
    """Render the Streamlit UI and classify the entered title/abstract."""
    st.title("Paper Classification App")
    st.write("Classify research papers using DeBERTa model")

    model, tokenizer, device, label_to_theme, theme_to_description = load_assets()

    title = st.text_input("Title")
    abstract = st.text_area("Abstract")

    if st.button("Classify"):
        # The widgets yield "" (not None) when left empty, so select the
        # fields by content; this avoids a stray "\n\n" when only one of
        # them is filled in.
        parts = [part for part in (title, abstract) if part and part.strip()]
        if not parts:
            st.warning("Please enter title and/or abstract")
            return
        text = "\n\n".join(parts)

        with st.spinner("Analyzing text..."):
            probabilities = predict(text, model, tokenizer, device)

        themes = get_themes(probabilities, label_to_theme)

        st.success("Predicted themes (click to expand):")
        for theme, prob in themes:
            with st.expander(f"{theme} ({prob:.1%})"):
                st.markdown(f"**Description**: {theme_to_description[theme]}")


if __name__ == "__main__":
    main()