import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model and tokenizer
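# ms-marco-MiniLM-L-12-v2 is a cross-encoder reranker: it reads a (query, passage)
# pair jointly and outputs a single relevance logit.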
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()


# Sigmoid-based threshold adjustment function
def calculate_threshold(base_relevance, min_threshold=0.02, max_threshold=0.5, k=10):
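    """Map a base relevance score in [0, 1] to an attention threshold.

    A logistic curve centered at 0.5 with steepness k interpolates between
    min_threshold and max_threshold: low relevance gives a low threshold
    (more tokens highlighted), high relevance a high threshold (fewer, stronger
    tokens). With the defaults, 0.1 -> ~0.03, 0.5 -> 0.26, 0.9 -> ~0.49.
    """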
    base_relevance_tensor = torch.tensor(base_relevance)  # Ensure input is a tensor
    threshold = min_threshold + (max_threshold - min_threshold) * (
        1 / (1 + torch.exp(-k * (base_relevance_tensor - 0.5)))
    )
    return threshold.item()  # Convert tensor back to float for use in other functions


# Function to compute relevance score and dynamically adjust threshold
def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Tokenize the input
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(**inputs, output_attentions=True)

    # Extract logits and calculate base relevance score
    logit = output.logits.squeeze().item()
    base_relevance_score = torch.sigmoid(torch.tensor(logit)).item()
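    # The model outputs one raw logit per (query, paragraph) pair; the sigmoid above
    # squashes it into [0, 1] so it can drive the threshold curve and be displayed.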

    # Compute dynamic threshold using sigmoid-based adjustment
    dynamic_threshold = calculate_threshold(base_relevance_score) * threshold_weight

    # Extract attention scores (last layer)
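    # output.attentions[-1] has shape (batch, num_heads, seq_len, seq_len); averaging
    # over heads (dim=1) and the singleton batch dim (dim=0) leaves a
    # (seq_len, seq_len) token-to-token attention matrix.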
    attention = output.attentions[-1]
    attention_scores = attention.mean(dim=1).mean(dim=0)

    query_tokens = tokenizer.tokenize(query)
    paragraph_tokens = tokenizer.tokenize(paragraph)

    query_len = len(query_tokens) + 2  # +2 for special tokens [CLS] and first [SEP]
    para_start_idx = query_len
    para_end_idx = len(inputs["input_ids"][0]) - 1
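    # Input layout is [CLS] query [SEP] paragraph [SEP]: the paragraph span starts
    # right after the query plus its two special tokens and ends before the final [SEP].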

    if para_end_idx <= para_start_idx:
        return round(base_relevance_score, 4), "No relevant tokens extracted."
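    # Keep only paragraph-to-paragraph attention, then average over the attending rows
    # to get, for each paragraph token, the mean attention it receives.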
    para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
    if para_attention_scores.numel() == 0:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # Get indices of relevant tokens above dynamic threshold
    relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()

    # Reconstruct paragraph with bolded relevant tokens using HTML tags
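    # Note: if the tokenizer truncated the paragraph, paragraph_tokens can be longer
    # than the scored span, so tokens past the truncation point are never highlighted.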
highlighted_text = ""
for idx, token in enumerate(paragraph_tokens):
if idx in relevant_indices:
highlighted_text += f"<b>{token}</b> "
else:
highlighted_text += f"{token} "
highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
return round(base_relevance_score, 4), highlighted_text


# Define Gradio interface with a slider for threshold adjustment
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Threshold Weight")
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.HTML(label="Highlighted Document Paragraph")
    ],
    title="Cross-Encoder Attention Highlighting",
    description="Adjust the attention threshold weight to control token highlighting sensitivity.",
    allow_flagging="never",
    live=True
)
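# live=True re-runs the scoring function automatically whenever any input changes.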

if __name__ == "__main__":
    interface.launch()