import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model and tokenizer
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval() # Set the model to evaluation mode
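
# Note: this MS MARCO cross-encoder emits a single logit per (query, passage)
# pair; sigmoid maps that logit to a [0, 1] relevance score in the function below.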

# Compute the relevance score and highlight paragraph tokens whose attention
# clears a dynamically adjusted threshold
def get_relevance_score_and_excerpt(query, paragraph):
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Tokenize the query/paragraph pair as a single input
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)  # Also return attention maps

    # Extract the logit and calculate the base relevance score
    logit = output.logits.squeeze().item()
    base_relevance_score = torch.sigmoid(torch.tensor(logit)).item()

    # Dynamically adjust the attention threshold based on the relevance score
    dynamic_threshold = max(0.02, base_relevance_score * 0.1)  # Example formula
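    # Worked example of the formula above: a base score of 0.8 yields a
    # threshold of 0.08, while base scores at or below 0.2 fall back to the
    # 0.02 floor, so weak matches still highlight their strongest tokens.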

    # Extract attention scores from the last layer
    attention = output.attentions[-1]  # Shape: (batch_size, num_heads, seq_len, seq_len)
    attention_scores = attention.mean(dim=1).mean(dim=0)  # Average over heads and batch -> (seq_len, seq_len)

    # Tokenize query and paragraph separately to locate the paragraph span
    query_tokens = tokenizer.tokenize(query)
    paragraph_tokens = tokenizer.tokenize(paragraph)

    query_len = len(query_tokens) + 2  # +2 for the [CLS] and first [SEP] special tokens
    para_start_idx = query_len
    para_end_idx = len(inputs["input_ids"][0]) - 1  # Ignore the final [SEP] token
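    # Assumption: this span arithmetic relies on the tokenized pair not being
    # truncated; with truncation=True, long inputs may lose tokens, in which
    # case the offsets cover only the surviving paragraph tokens.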

    # Handle potential indexing issues
    if para_end_idx <= para_start_idx:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # Extract paragraph attention scores and apply the dynamic threshold
    para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)

    if para_attention_scores.numel() == 0:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # Get indices of paragraph tokens whose attention exceeds the dynamic threshold
    relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()

    # Compute the attention-weighted relevance score
    if relevant_indices:
        relevant_attention_values = para_attention_scores[relevant_indices]
        attention_weighted_score = relevant_attention_values.mean().item() * base_relevance_score
    else:
        attention_weighted_score = base_relevance_score  # No relevant tokens found
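
    # Note: scaling by the mean attention of the kept tokens makes this score
    # much smaller than the raw sigmoid output; it is best read as a ranking
    # signal rather than a calibrated probability.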

    # Reconstruct the paragraph with relevant tokens bolded; use HTML <b> tags,
    # since the gr.HTML output component renders raw HTML, not Markdown **
    highlighted_text = ""
    for idx, token in enumerate(paragraph_tokens):
        if idx in relevant_indices:
            highlighted_text += f"<b>{token}</b> "
        else:
            highlighted_text += f"{token} "

    # Convert tokens back to a readable string (merges WordPiece "##" subwords)
    highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())

    return round(attention_weighted_score, 4), highlighted_text
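
# A minimal sanity check (illustrative inputs, not part of the original app):
# the scorer can also be called directly, without the UI, e.g.
#
#   score, excerpt = get_relevance_score_and_excerpt(
#       "what causes rain",
#       "Rain forms when water vapour condenses into droplets heavy enough to fall.",
#   )
#   print(score)    # attention-weighted relevance in [0, 1]
#   print(excerpt)  # paragraph with high-attention tokens wrapped in <b> tags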

# Define the Gradio interface
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
    ],
    outputs=[
        gr.Textbox(label="Attention-Weighted Relevance Score"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
    title="Cross-Encoder with Dynamic Attention Threshold",
    description="Enter a query and document paragraph to get a relevance score with relevant tokens in bold.",
    allow_flagging="never",
    live=True,
)

if __name__ == "__main__":
    interface.launch()