import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model and tokenizer
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # Evaluation mode: disables dropout so scores and attention maps are deterministic

# Threshold for attention relevance
THRESHOLD = 0.02  # Adjust as needed based on observations

# Function to get relevance score and relevant excerpt with bolded tokens
def get_relevance_score_and_excerpt(query, paragraph):
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Tokenize the input
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
    
    with torch.no_grad():
        output = model(**inputs, output_attentions=True)  # Get attention scores

    # The cross-encoder head emits a single unbounded logit; a sigmoid squashes it
    # into a (0, 1) relevance score (convenient, though not a calibrated probability)
    logit = output.logits.squeeze().item()
    relevance_score = torch.sigmoid(torch.tensor(logit)).item()

    # Extract attention scores (last layer)
    attention = output.attentions[-1]  # Shape: (batch_size, num_heads, seq_len, seq_len)

    # Average across heads and batch dimension
    attention_scores = attention.mean(dim=1).mean(dim=0)  # Shape: (seq_len, seq_len)

    # Tokenize query and paragraph separately
    query_tokens = tokenizer.tokenize(query)
    paragraph_tokens = tokenizer.tokenize(paragraph)

    query_len = len(query_tokens) + 2  # +2 for [CLS] and first [SEP]
    para_start_idx = query_len
    para_end_idx = len(inputs["input_ids"][0]) - 1  # Ignore final [SEP] token
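    # Note: this indexing assumes truncation (if any) only clips the end of the
    # paragraph; query_len is computed from the untruncated query, so an extremely
    # long query could misalign the span (an edge case left unhandled here)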

    # Handle potential indexing issues
    if para_end_idx <= para_start_idx:
        return round(relevance_score, 4), "No relevant tokens extracted."

    # For each paragraph token, average the attention it receives from the other
    # paragraph tokens (mean over the rows of the paragraph-to-paragraph block)
    para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)

    if para_attention_scores.numel() == 0:
        return round(relevance_score, 4), "No relevant tokens extracted."

    # Collect indices of paragraph tokens whose averaged attention clears the threshold
    relevant_indices = set((para_attention_scores > THRESHOLD).nonzero(as_tuple=True)[0].tolist())

    # Rebuild the paragraph, wrapping relevant tokens in HTML <b> tags
    # (the output component is gr.HTML, so Markdown ** markers would not render)
    highlighted_tokens = []
    for idx, token in enumerate(paragraph_tokens):
        if idx in relevant_indices:
            highlighted_tokens.append(f"<b>{token}</b>")
        else:
            highlighted_tokens.append(token)

    # Convert tokens back to readable text (merges WordPiece "##" subword pieces)
    highlighted_text = tokenizer.convert_tokens_to_string(highlighted_tokens)

    return round(relevance_score, 4), highlighted_text

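# Quick sanity check without the UI (hypothetical example inputs):
#   score, excerpt = get_relevance_score_and_excerpt(
#       "what is the capital of france",
#       "Paris is the capital and most populous city of France.",
#   )
#   print(score)    # a float in (0, 1)
#   print(excerpt)  # the paragraph with high-attention tokens wrapped in <b> tags
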
# Define Gradio interface
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.HTML(label="Highlighted Document Paragraph")
    ],
    title="Cross-Encoder Relevance Scoring with Highlighted Excerpt",
    description="Enter a query and a document paragraph to get a relevance score and see relevant tokens in bold.",
    allow_flagging="never",
    live=True
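    # Note: live=True re-runs the model on every keystroke; drop it if inference
    # latency becomes noticeable on longer paragraphs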
)

if __name__ == "__main__":
    interface.launch()