#!/usr/bin/python3
import streamlit as st
import json
import numpy as np
import torch
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2Tokenizer,
)

model_name = "microsoft/deberta-v3-base"
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name)


def preprocess_text(text, tokenizer, max_length=512):
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return inputs


def classify_text(text, model, tokenizer, device, threshold=0.5):
    inputs = preprocess_text(text, tokenizer)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    # Multi-label head: one independent sigmoid per label. Move the result to
    # the CPU before converting, since .numpy() fails on CUDA tensors.
    probs = torch.sigmoid(logits).cpu()
    predictions = (probs > threshold).int().numpy()
    return probs.numpy(), predictions


def get_themes(text, model, tokenizer, label_to_theme, device, limit=5):
    probabilities, _ = classify_text(text, model, tokenizer, device)
    # Renormalize the independent per-label scores so the displayed values
    # sum to 1 across all labels.
    probabilities = probabilities / probabilities.sum()
    themes = []
    # argsort is ascending, so this slice yields the top `limit` labels in
    # ascending order of probability.
    for label in probabilities[0].argsort()[-limit:]:
        themes.append((label_to_theme[str(label)], probabilities[0][label]))
    return themes


class DebertPaperClassifier(torch.nn.Module):
    """Earlier classifier variant with explicit initialization of the head."""

    def __init__(self, num_labels, device, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )
        self._init_weights()

        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def _init_weights(self):
        for module in self.classifier.modules():
            if isinstance(module, torch.nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Classify from the [CLS] token representation.
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits, labels)
        return (loss, logits) if loss is not None else logits


class DebertPaperClassifierV5(torch.nn.Module):
    """Classifier actually loaded by the app (same head, default init)."""

    def __init__(self, device, num_labels=47, dropout_rate=0.1, class_weights=None):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name, config=self.config)

        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(self.config.hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_labels),
        )

        if class_weights is not None:
            self.loss_fct = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
        else:
            self.loss_fct = torch.nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits, labels)
        return (loss, logits) if loss is not None else logits


@st.cache_resource
def load_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("model_info/label_to_theme.json", "r") as f:
        label_to_theme = json.load(f)
    model = DebertPaperClassifierV5(device=device, num_labels=len(label_to_theme)).to(device)
    model.load_state_dict(torch.load("model_info/deberta_v3.pth", map_location=device))
    return model, tokenizer, label_to_theme, device
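
# A minimal usage sketch (illustration only; never called by the app). It
# reuses load_model() and get_themes() to classify one hard-coded abstract.
# The sample text and the name demo_get_themes are invented for this example,
# and the "model_info/" artifacts must exist on disk for it to run.
def demo_get_themes():
    model, tok, label_to_theme, device = load_model()
    abstract = (
        "We present a transformer-based method for multi-label topic "
        "classification of scientific paper abstracts."
    )
    # get_themes() yields the top labels in ascending order of probability,
    # so reverse for a most-likely-first printout.
    for theme, prob in reversed(get_themes(abstract, model, tok, label_to_theme, device)):
        print(f"{theme}: {prob:.1%}")
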
def kek():
    st.title("arXiv Paper Classifier")
    st.markdown("""