import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Load model and tokenizer
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()
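
# Optional sketch (an assumption, not part of the original app): move the model
# to a GPU when one is available. The tokenized inputs built below would then
# also need inputs.to(model.device) before the forward pass, so this is left
# commented out.
# if torch.cuda.is_available():
#     model = model.to("cuda")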
# Compute a relevance score (raw logit) for each paragraph and highlight
# tokens whose attention exceeds the user-adjustable threshold
def get_relevance_score_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
    # Keep only the paragraphs that are non-empty
    paragraphs = [p for p in [paragraph1, paragraph2, paragraph3] if p.strip()]
    if not query.strip() or not paragraphs:
        return "Please provide both a query and at least one document paragraph.", ""
    ranked_paragraphs = []

    # Process each paragraph: compute its logit and build its highlighted text
    for paragraph in paragraphs:
        # Tokenize the query/paragraph pair
        inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            output = model(**inputs, output_attentions=True)

        # Extract the raw logit (no sigmoid applied); this is the relevance score
        logit = output.logits.squeeze().item()

        # Clamp the user-supplied attention threshold to a floor of 0.02
        dynamic_threshold = max(0.02, threshold_weight)

        # Attention from the last layer, averaged over heads (dim=1) and batch (dim=0)
        attention = output.attentions[-1]
        attention_scores = attention.mean(dim=1).mean(dim=0)
        query_tokens = tokenizer.tokenize(query)
        paragraph_tokens = tokenizer.tokenize(paragraph)

        query_len = len(query_tokens) + 2  # +2 for the [CLS] and first [SEP] special tokens
        para_start_idx = query_len
        para_end_idx = len(inputs["input_ids"][0]) - 1  # exclude the final [SEP]
        if para_end_idx <= para_start_idx:
            continue

        # Per-token score: mean attention received from the other paragraph tokens
        para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
        if para_attention_scores.numel() == 0:
            continue

        # Indices of paragraph tokens whose attention exceeds the dynamic threshold
        relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
        # Reconstruct the paragraph, bolding relevant tokens with HTML tags.
        # Note: wrapping subword tokens (e.g. "##ing") in <b> tags keeps
        # convert_tokens_to_string from merging them back into whole words,
        # so highlighted subwords may render space-separated.
        highlighted_text = ""
        for idx, token in enumerate(paragraph_tokens):
            if idx in relevant_indices:
                highlighted_text += f"<b>{token}</b> "
            else:
                highlighted_text += f"{token} "
        highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())

        ranked_paragraphs.append({
            "logit": logit,
            "highlighted_text": highlighted_text
        })
    # Sort paragraphs by logit, highest first
    ranked_paragraphs.sort(key=lambda x: x["logit"], reverse=True)

    # Prepare the two outputs: one score per line, plus the highlighted HTML
    relevance_scores = [round(p["logit"], 4) for p in ranked_paragraphs]
    highlighted_texts = [p["highlighted_text"] for p in ranked_paragraphs]
    return "\n".join(f"Relevance Score (Logits): {score}" for score in relevance_scores), "\n\n".join(highlighted_texts)
# Gradio interface: a query box, up to three paragraph inputs, and a slider
# for adjusting the attention threshold
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph 1", placeholder="Enter a paragraph to match...", lines=4),
        gr.Textbox(label="Document Paragraph 2 (optional)", placeholder="Enter another paragraph...", lines=4),
        gr.Textbox(label="Document Paragraph 3 (optional)", placeholder="Enter another paragraph...", lines=4),
        gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold")
    ],
    outputs=[
        gr.Textbox(label="Relevance Scores (Logits)"),
        gr.HTML(label="Highlighted Document Paragraphs")
    ],
    title="Cross-Encoder Attention Highlighting with Reranking",
    description="Adjust the attention threshold to control token highlighting sensitivity. Multiple paragraphs can be added and reranked based on their logits.",
    allow_flagging="never",
    live=True
)
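
# Usage note (not a change to the app above): live=True reruns the model on
# every keystroke. For a heavier model, dropping live=True makes Gradio add a
# Submit button instead, avoiding redundant forward passes while typing.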
if __name__ == "__main__":
    interface.launch()
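
# launch() also accepts networking options when the defaults don't fit,
# e.g. (illustrative values, not used by this Space):
# interface.launch(server_name="0.0.0.0", server_port=7860)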