# Gradio demo: cross-encoder relevance scoring with attention-based token highlighting.
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Load model and tokenizer.
# ms-marco MiniLM cross-encoder: jointly encodes a (query, passage) pair and
# emits a single relevance logit. Downloaded from the HF Hub on first run.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval() # Set the model to evaluation mode (disables dropout; inference only)
# Threshold for attention relevance: paragraph tokens whose mean attention
# exceeds this value are highlighted in the UI.
THRESHOLD = 0.02 # Adjust as needed based on observations
# Function to get relevance score and relevant excerpt with bolded tokens.
def get_relevance_score_and_excerpt(query, paragraph):
    """Score a query/paragraph pair and highlight attention-relevant tokens.

    Returns a 2-tuple ``(score, html)``:
      * ``score`` — sigmoid of the cross-encoder logit, rounded to 4 places
        (or an instructional message when either input is blank/whitespace).
      * ``html`` — the paragraph rebuilt token by token, with tokens whose
        mean last-layer attention exceeds ``THRESHOLD`` wrapped in ``<b>``
        tags. HTML tags (not Markdown ``**``) are used because the output
        component is ``gr.HTML``, which renders raw HTML only.
    """
    # Guard against empty inputs before touching the tokenizer.
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Pack the pair as [CLS] query [SEP] paragraph [SEP].
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)  # Get attention scores

    # Single-pair batch: squeeze to a scalar logit; sigmoid maps it to [0, 1].
    logit = output.logits.squeeze().item()
    relevance_score = torch.sigmoid(torch.tensor(logit)).item()

    # Last attention layer, shape (batch, num_heads, seq_len, seq_len);
    # average over heads (dim=1) then the batch dim (dim=0) -> (seq_len, seq_len).
    attention = output.attentions[-1]
    attention_scores = attention.mean(dim=1).mean(dim=0)

    # Locate the paragraph span inside the packed sequence.
    query_tokens = tokenizer.tokenize(query)
    paragraph_tokens = tokenizer.tokenize(paragraph)
    para_start_idx = len(query_tokens) + 2  # +2 for [CLS] and first [SEP]
    para_end_idx = len(inputs["input_ids"][0]) - 1  # Ignore final [SEP] token

    # Handle potential indexing issues (e.g. paragraph fully truncated away).
    if para_end_idx <= para_start_idx:
        return round(relevance_score, 4), "No relevant tokens extracted."

    # Mean attention each paragraph token receives from the paragraph span.
    para_attention_scores = attention_scores[
        para_start_idx:para_end_idx, para_start_idx:para_end_idx
    ].mean(dim=0)

    if para_attention_scores.numel() == 0:
        return round(relevance_score, 4), "No relevant tokens extracted."

    # Indices (relative to the paragraph span) above the threshold; a set gives
    # O(1) membership tests in the reconstruction loop below.
    relevant_indices = set(
        (para_attention_scores > THRESHOLD).nonzero(as_tuple=True)[0].tolist()
    )

    # Rebuild the paragraph word by word. Bug fix: the output component is
    # gr.HTML, which does not render Markdown, so relevant tokens are wrapped
    # in <b> tags instead of **...**. Each token is detokenized individually —
    # the old split()-based re-merge broke on '##' sub-word pieces anyway,
    # because the '**' prefix hid the continuation marker.
    words = []
    for idx, token in enumerate(paragraph_tokens):
        word = tokenizer.convert_tokens_to_string([token]).strip()
        words.append(f"<b>{word}</b>" if idx in relevant_indices else word)
    highlighted_text = " ".join(words)

    return round(relevance_score, 4), highlighted_text
# Gradio UI: two free-text inputs -> (relevance score textbox, highlighted HTML excerpt).
_query_box = gr.Textbox(label="Query", placeholder="Enter your search query...")
_paragraph_box = gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")

interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[_query_box, _paragraph_box],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
    title="Cross-Encoder Relevance Scoring with Highlighted Excerpt",
    description="Enter a query and a document paragraph to get a relevance score and see relevant tokens in bold.",
    live=True,  # re-score on every keystroke instead of waiting for a submit click
    allow_flagging="never",
)
# Launch the app only when run as a script, not on import.
# Bug fix: removed a stray trailing '|' artifact that made this line a syntax error.
if __name__ == "__main__":
    interface.launch()