import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model and tokenizer
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()
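
# Note: ms-marco-MiniLM-L-12-v2 is a cross-encoder reranker trained on MS MARCO;
# it scores the (query, paragraph) pair jointly rather than embedding each side.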

# Compute the relevance score (raw logit) for a query/paragraph pair and
# highlight paragraph tokens whose attention exceeds a user-adjustable threshold
def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Tokenize the input
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
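
    # The pair is packed BERT-style as [CLS] query [SEP] paragraph [SEP],
    # so the paragraph tokens follow the query segment in a single sequence.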
    with torch.no_grad():
        output = model(**inputs, output_attentions=True)
    # Extract logits (no sigmoid applied)
    logit = output.logits.squeeze().item()
    base_relevance_score = logit  # relevance score in logits
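
    # The model emits a single unnormalized logit: higher means more relevant.
    # Apply torch.sigmoid to the logit if a 0-1 score is preferred instead.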
    # Floor the user-supplied attention threshold at 0.02; the relevance
    # score itself does not influence the threshold
    dynamic_threshold = max(0.02, threshold_weight)
    # Extract attention scores (last layer)
    attention = output.attentions[-1]
    attention_scores = attention.mean(dim=1).mean(dim=0)
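
    # Shape note: output.attentions[-1] is (batch, num_heads, seq_len, seq_len);
    # averaging over heads (dim=1) and then batch (dim=0) leaves (seq_len, seq_len).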
    query_tokens = tokenizer.tokenize(query)
    paragraph_tokens = tokenizer.tokenize(paragraph)

    query_len = len(query_tokens) + 2  # +2 for the [CLS] and first [SEP] tokens
    para_start_idx = query_len
    para_end_idx = len(inputs["input_ids"][0]) - 1  # exclude the trailing [SEP]
    if para_end_idx <= para_start_idx:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
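
    # mean(dim=0) averages the attention each paragraph token receives from the
    # other paragraph tokens, giving one salience score per paragraph token.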
    if para_attention_scores.numel() == 0:
        return round(base_relevance_score, 4), "No relevant tokens extracted."
    # Get indices of paragraph tokens whose attention exceeds the threshold
    relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()

    # Reconstruct the paragraph with relevant tokens bolded via HTML tags
    highlighted_text = ""
    for idx, token in enumerate(paragraph_tokens):
        if idx in relevant_indices:
            highlighted_text += f"<b>{token}</b> "
        else:
            highlighted_text += f"{token} "

    # Merge WordPiece pieces (e.g. "##ing") back into words; the <b> tags are
    # carried along, so a bolded subword may sit inside a rendered word
    highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())

    return round(base_relevance_score, 4), highlighted_text
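
# A quick sanity check outside Gradio (hypothetical inputs):
#   score, html = get_relevance_score_and_excerpt(
#       "how do plants make food",
#       "Photosynthesis lets plants turn sunlight into chemical energy.",
#       0.1,
#   )
# A relevant pair should yield a clearly positive logit; an unrelated pair
# should yield a negative one.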

# Define Gradio interface with a slider for threshold adjustment
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
        gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold"),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score (Logits)"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
    title="Cross-Encoder Attention Highlighting",
    description="Adjust the attention threshold to control token highlighting sensitivity.",
    allow_flagging="never",
    live=True,
)
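
# With live=True the interface re-runs on every input change, so dragging the
# slider immediately updates the highlighted excerpt without a submit click.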

if __name__ == "__main__":
    interface.launch()