import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load model and tokenizer
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model to evaluation mode
model.eval()

# Marker the tokenizer is assumed to put on word-continuation pieces
# (WordPiece convention, e.g. "play" + "##ing") -- TODO confirm if the
# checkpoint is ever swapped for one with a different tokenizer.
_SUBWORD_PREFIX = "##"


def _calculate_threshold(base_relevance, min_threshold=0.0, max_threshold=0.5, k=10):
    """Map a relevance score to an attention threshold with a logistic curve.

    Args:
        base_relevance: Relevance score; the curve is centred on 0.5.
        min_threshold: Value approached for low relevance.
        max_threshold: Value approached for high relevance.
        k: Steepness of the logistic curve.

    Returns:
        A float between ``min_threshold`` and ``max_threshold``.
    """
    gate = torch.sigmoid(torch.tensor(k * (base_relevance - 0.5)))
    return min_threshold + (max_threshold - min_threshold) * gate.item()


def _mark_token(token, is_relevant):
    """Return *token* as HTML-safe text, wrapped in ``<b>`` when relevant."""
    # The text originates from user input and ends up in an HTML component,
    # so markup characters must be escaped. Quotes are left alone because the
    # text is never placed inside an attribute value.
    text = html.escape(token, quote=False)
    if not is_relevant:
        return text
    if token.startswith(_SUBWORD_PREFIX):
        # Keep the continuation marker outside the tag so the tokenizer can
        # still glue this piece onto the preceding one when detokenizing.
        return f"{_SUBWORD_PREFIX}<b>{text[len(_SUBWORD_PREFIX):]}</b>"
    return f"<b>{text}</b>"


def process_text(query, document, weight):
    """Score a query/document pair and bold the tokens selected by attention.

    Args:
        query: Query text.
        document: Document paragraph scored against the query.
        weight: Multiplier applied to the relevance-derived threshold.

    Returns:
        A tuple ``(relevance_score, dynamic_threshold, highlighted_text)``:
        the sigmoid of the model logit, the threshold that was applied, and
        the detokenized input as an HTML string with selected tokens in
        ``<b>`` tags.
    """
    # Tokenize input
    inputs = tokenizer(query, document, return_tensors="pt", truncation=True, padding=True)

    # Get model outputs with attentions
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Convert logits to relevance score
    relevance_score = torch.sigmoid(outputs.logits).item()
    # Last layer, single batch item, mean attention across heads
    attentions = outputs.attentions[-1].squeeze(0).mean(0)

    dynamic_threshold = _calculate_threshold(relevance_score) * weight

    # A token is selected when any entry in its row of the head-averaged
    # attention matrix exceeds the threshold. One boolean per token replaces
    # repeated membership tests against a list of duplicated row indices.
    is_relevant = (attentions > dynamic_threshold).any(dim=1).tolist()

    # Highlight tokens in the original order, using HTML bold tags
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    marked_tokens = [_mark_token(token, flag) for token, flag in zip(tokens, is_relevant)]
    highlighted_text = tokenizer.convert_tokens_to_string(marked_tokens)

    # Print values to debug
    print(f"Relevance Score: {relevance_score}")
    print(f"Dynamic Threshold: {dynamic_threshold}")

    return relevance_score, dynamic_threshold, highlighted_text


# Create Gradio interface with a slider for threshold adjustment weight
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="Query"),
        gr.Textbox(label="Document Paragraph"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Threshold Weight"),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Dynamic Threshold"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
)

iface.launch()