File size: 3,869 Bytes
7472a45
 
 
 
 
 
 
 
43a9654
7472a45
7f50308
6e70d21
691416f
6e70d21
5552636
6e70d21
a39d1e7
43a9654
 
 
 
7f50308
43a9654
7f50308
 
 
 
43a9654
5552636
43a9654
7f50308
a39d1e7
5552636
43a9654
 
 
7f50308
a39d1e7
43a9654
 
a39d1e7
 
7f50308
a39d1e7
7f50308
a39d1e7
 
 
7f50308
 
 
 
43a9654
7f50308
 
 
 
 
 
43a9654
5552636
 
 
 
 
 
 
43a9654
7f50308
5552636
43a9654
7f50308
43a9654
 
 
 
 
 
 
 
 
7f50308
5552636
43a9654
7f50308
 
43a9654
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hub identifier of the pretrained cross-encoder checkpoint used for scoring
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Load the sequence-classification weights and put the module in evaluation
# mode right away (eval() returns the module itself, so it can be chained)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

# Tokenizer belonging to the same checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Function to compute relevance score and dynamically adjust threshold
def get_relevance_score_and_excerpt(query, paragraph):
    """Score a paragraph against a query and bold its most-attended words.

    The query/paragraph pair is run through the module-level cross-encoder
    (``tokenizer`` / ``model``). The sigmoid of the logit is the base
    relevance score. Last-layer attention, averaged over heads, is then used
    to pick paragraph tokens whose mean received attention exceeds a
    threshold derived from that score.

    Args:
        query: Search query text.
        paragraph: Document paragraph to score against the query.

    Returns:
        A 2-tuple ``(score, html_text)``:

        * If either input is empty or whitespace-only: a prompt message and
          an empty string.
        * If no paragraph tokens survive encoding: the rounded base relevance
          score and the message ``"No relevant tokens extracted."``.
        * Otherwise: the attention-weighted score rounded to 4 decimals and
          the paragraph as an HTML string, with relevant words wrapped in
          ``<b>`` tags and all text HTML-escaped.
    """
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Encode the pair as a single input: [CLS] query [SEP] paragraph [SEP]
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)  # Get attention scores

    # Squash the single logit into a base relevance score
    base_relevance_score = torch.sigmoid(output.logits.squeeze()).item()

    # Dynamically adjust the attention threshold based on relevance score
    dynamic_threshold = max(0.02, base_relevance_score * 0.1)  # Example formula

    # Last-layer attention: (batch_size, num_heads, seq_len, seq_len)
    attention = output.attentions[-1]
    attention_scores = attention.mean(dim=1).mean(dim=0)  # Average over heads and batch

    # Locate the paragraph span from the *encoded* input rather than from a
    # separate tokenization of the query: truncation may shorten the query,
    # in which case len(tokenize(query)) + 2 would point at the wrong place.
    if "token_type_ids" in inputs:
        # Segment 0 covers [CLS] + query + first [SEP] for BERT-style
        # tokenizers; a single-example batch carries no padding.
        para_start_idx = int((inputs["token_type_ids"][0] == 0).sum().item())
    else:
        # Fallback for tokenizers that do not emit segment ids
        para_start_idx = len(tokenizer.tokenize(query)) + 2
    para_end_idx = len(inputs["input_ids"][0]) - 1  # Ignore final [SEP] token

    # Nothing of the paragraph made it into the model input
    if para_end_idx <= para_start_idx:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # Mean attention each paragraph token receives from the other paragraph tokens
    para_attention_scores = attention_scores[
        para_start_idx:para_end_idx, para_start_idx:para_end_idx
    ].mean(dim=0)

    if para_attention_scores.numel() == 0:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # Indices (relative to the paragraph start) above the dynamic threshold
    relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()

    # Compute attention-weighted relevance score
    if relevant_indices:
        relevant_attention_values = para_attention_scores[relevant_indices]
        attention_weighted_score = relevant_attention_values.mean().item() * base_relevance_score
    else:
        attention_weighted_score = base_relevance_score  # No relevant tokens found

    # Tokens of the full paragraph; positions beyond the truncated model
    # input have no attention score and are therefore never highlighted.
    paragraph_tokens = tokenizer.tokenize(paragraph)
    relevant_lookup = set(relevant_indices)

    # Merge WordPiece continuation pieces ("##xyz") into whole words *before*
    # adding markup, so the "##" marker never leaks into the output. A word
    # is highlighted when any of its pieces is relevant.
    words = []  # list of (text, is_relevant)
    for idx, token in enumerate(paragraph_tokens):
        is_relevant = idx in relevant_lookup
        if token.startswith("##") and words:
            text, was_relevant = words[-1]
            words[-1] = (text + token[2:], was_relevant or is_relevant)
        else:
            words.append((token, is_relevant))

    # The result is displayed in an HTML component: escape the user-supplied
    # text and use <b> tags, which (unlike Markdown asterisks) render as bold.
    highlighted_text = " ".join(
        f"<b>{html.escape(text)}</b>" if is_relevant else html.escape(text)
        for text, is_relevant in words
    )

    return round(attention_weighted_score, 4), highlighted_text

# Gradio UI: two text inputs feed the scoring function; the score is shown in
# a text box and the highlighted paragraph in an HTML pane
interface = gr.Interface(
    title="Cross-Encoder with Dynamic Attention Threshold",
    description="Enter a query and document paragraph to get a relevance score with relevant tokens in bold.",
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
    ],
    outputs=[
        gr.Textbox(label="Attention-Weighted Relevance Score"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
    live=True,  # live mode enabled
    allow_flagging="never",  # flagging turned off
)

# Launch the app only when this file is run as a script
if __name__ == "__main__":
    interface.launch()