File size: 3,719 Bytes
7472a45
 
 
 
 
 
 
 
0e17055
7472a45
fe29744
 
51f19d1
 
 
 
 
 
7f50308
0e17055
691416f
6e70d21
5552636
6e70d21
a39d1e7
43a9654
 
0e17055
43a9654
7f50308
43a9654
7f50308
 
fe29744
 
43a9654
5552636
0e17055
 
a39d1e7
43a9654
 
 
7f50308
a39d1e7
0e17055
43a9654
a39d1e7
7f50308
a39d1e7
 
 
 
7f50308
 
 
 
43a9654
0e17055
5552636
 
 
0e17055
5552636
 
43a9654
5552636
43a9654
0e17055
43a9654
0e17055
43a9654
 
 
 
0e17055
fe29744
43a9654
 
0e17055
5552636
43a9654
0e17055
fe29744
43a9654
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Cross-encoder checkpoint shared by the tokenizer and the classifier below
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# eval() returns the module itself, so loading and switching to inference
# mode can be written as one expression.
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

# Sigmoid-based threshold adjustment function
def calculate_threshold(base_relevance, min_threshold=0.02, max_threshold=0.5, k=10):
    """Map a relevance score onto an attention threshold along a logistic curve.

    The result moves smoothly from ``min_threshold`` towards ``max_threshold``
    as ``base_relevance`` grows, and sits exactly halfway between the two when
    ``base_relevance`` is 0.5.

    Args:
        base_relevance: Relevance score, given as a Python number or as a
            single-element tensor.
        min_threshold: Value approached for very low relevance scores.
        max_threshold: Value approached for very high relevance scores.
        k: Steepness of the curve around ``base_relevance == 0.5``.

    Returns:
        The threshold as a plain Python float.
    """
    # as_tensor() reuses an incoming tensor as-is; torch.tensor() would
    # copy-construct it and emit a UserWarning.
    relevance = torch.as_tensor(base_relevance)
    # torch.sigmoid(z) is 1 / (1 + exp(-z)); using the built-in avoids the
    # hand-written exponential.
    gate = torch.sigmoid(k * (relevance - 0.5))
    threshold = min_threshold + (max_threshold - min_threshold) * gate
    return threshold.item()  # Convert tensor back to float for use in other functions
    
# Marker the WordPiece vocabulary puts in front of word-continuation pieces
_WORDPIECE_PREFIX = "##"


def _mark_tokens(tokens, highlighted_indices):
    """Return HTML-safe copies of *tokens* with the highlighted ones wrapped in <b>."""
    marked = []
    for idx, token in enumerate(tokens):
        # Keep the continuation marker outside the <b> tag: the tokenizer can
        # only glue a piece back onto its word when the token still starts
        # with the marker.
        if token.startswith(_WORDPIECE_PREFIX):
            prefix, piece = _WORDPIECE_PREFIX, token[len(_WORDPIECE_PREFIX):]
        else:
            prefix, piece = "", token
        # The result is rendered as HTML, so user-supplied text must not be
        # able to introduce markup. Quotes are left alone so the tokenizer's
        # own detokenization still recognises apostrophes.
        piece = html.escape(piece, quote=False)
        if idx in highlighted_indices:
            piece = f"<b>{piece}</b>"
        marked.append(prefix + piece)
    return marked


# Function to compute relevance score and dynamically adjust threshold
def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
    """Score *paragraph* against *query* and bold its most-attended tokens.

    The relevance score is the sigmoid of the cross-encoder logit. Paragraph
    tokens are highlighted when the last-layer attention they receive from
    the other paragraph tokens exceeds ``calculate_threshold(score)`` scaled
    by ``threshold_weight``.

    Args:
        query: Search query text.
        paragraph: Document paragraph to score and highlight.
        threshold_weight: Multiplier applied to the computed threshold; larger
            values highlight fewer tokens.

    Returns:
        A ``(score, html)`` pair. ``score`` is the relevance rounded to four
        decimals and ``html`` is the paragraph with highlighted tokens wrapped
        in ``<b>`` tags. When either text is missing or blank, the pair is a
        prompt message and an empty string. When no paragraph tokens are
        available, ``html`` is a short explanatory message.
    """
    # Treat None like blank text instead of failing on .strip().
    if not (query and query.strip() and paragraph and paragraph.strip()):
        return "Please provide both a query and a document paragraph.", ""

    # Tokenize the input
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)

    # Relevance is the sigmoid of the single output logit.
    base_relevance_score = torch.sigmoid(output.logits.squeeze()).item()

    # Compute dynamic threshold using sigmoid-based adjustment
    dynamic_threshold = calculate_threshold(base_relevance_score) * threshold_weight

    # Last layer's attention, averaged over dim 1 and then dim 0, leaving a
    # 2-D token-by-token matrix (presumably heads and batch -- see the
    # transformers docs for the attentions layout).
    attention_scores = output.attentions[-1].mean(dim=1).mean(dim=0)

    paragraph_tokens = tokenizer.tokenize(paragraph)

    # +2 for special tokens [CLS] and first [SEP]
    # NOTE(review): this offset is derived from the untruncated query; if
    # truncation ever shortens the query the slice below would be misaligned.
    para_start_idx = len(tokenizer.tokenize(query)) + 2
    para_end_idx = len(inputs["input_ids"][0]) - 1  # drop the final token

    if para_end_idx <= para_start_idx:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # Attention each paragraph token receives, averaged over the paragraph
    # tokens attending to it.
    para_attention_scores = attention_scores[
        para_start_idx:para_end_idx, para_start_idx:para_end_idx
    ].mean(dim=0)

    if para_attention_scores.numel() == 0:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # Indices of tokens above the dynamic threshold; a set keeps the
    # per-token membership test cheap.
    relevant_indices = set(
        (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
    )

    # Reconstruct the paragraph with relevant tokens bolded via HTML tags
    highlighted_text = tokenizer.convert_tokens_to_string(
        _mark_tokens(paragraph_tokens, relevant_indices)
    )

    return round(base_relevance_score, 4), highlighted_text

# Gradio front end: query and paragraph text boxes plus a slider whose value
# is passed to the scoring function as its threshold weight
interface = gr.Interface(
    title="Cross-Encoder Attention Highlighting",
    description="Adjust the attention threshold weight to control token highlighting sensitivity.",
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(placeholder="Enter your search query...", label="Query"),
        gr.Textbox(placeholder="Enter a paragraph to match...", label="Document Paragraph"),
        gr.Slider(label="Threshold Weight", minimum=0.5, maximum=2.0, step=0.1, value=1.0),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        # HTML output so the <b> tags in the returned excerpt are rendered
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
    live=True,
    allow_flagging="never",
)

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    interface.launch()