File size: 4,379 Bytes
7472a45
e64db30
a1e51f4
7472a45
 
 
 
a1e51f4
e64db30
5552636
0f56dc9
271e600
 
 
 
 
6700bfc
a1e51f4
6700bfc
43a9654
6700bfc
 
 
 
 
 
 
a1e51f4
6700bfc
 
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
 
 
 
 
 
43a9654
6700bfc
43a9654
6700bfc
 
 
 
 
 
 
 
 
 
 
 
271e600
43a9654
6700bfc
a1e51f4
 
43a9654
a1e51f4
6700bfc
 
 
0f56dc9
43a9654
 
6700bfc
 
a1e51f4
6700bfc
 
a1e51f4
 
43a9654
 
a1e51f4
271e600
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model and tokenizer
# Cross-encoder reranker: scores a (query, paragraph) pair jointly with a
# single forward pass; loading happens once at import time (downloads on
# first run, then uses the local HF cache).
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Inference only — disable dropout etc.
model.eval()

# Function to compute relevance score (in logits) and dynamically adjust threshold
def get_relevance_score_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
    """Score up to three paragraphs against a query and highlight attended tokens.

    Args:
        query: Search query string.
        paragraph1, paragraph2, paragraph3: Candidate paragraphs; blank ones
            are ignored.
        threshold_weight: Attention threshold from the UI slider (floored at
            0.02 to avoid highlighting everything).

    Returns:
        Tuple of (scores text, HTML string) with paragraphs ordered by
        descending logit; a (message, "") tuple when input is missing or no
        paragraph could be scored.
    """
    # Handle empty input for paragraphs
    paragraphs = [p for p in [paragraph1, paragraph2, paragraph3] if p.strip()]

    if not query.strip() or not paragraphs:
        return "Please provide both a query and at least one document paragraph.", ""

    # Loop-invariant values — compute once, not per paragraph (was inside the loop).
    dynamic_threshold = max(0.02, threshold_weight)
    # Paragraph span starts after [CLS] + query tokens + [SEP].
    para_start_idx = len(tokenizer.tokenize(query)) + 2

    ranked_paragraphs = []

    # Process each paragraph and calculate its logits and highlighted text
    for paragraph in paragraphs:
        inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

        with torch.no_grad():
            output = model(**inputs, output_attentions=True)

        # Raw logit is the relevance score (no sigmoid applied).
        logit = output.logits.squeeze().item()

        # Last layer attention: average over heads (dim=1), then drop the
        # batch dim (dim=0) -> (seq_len, seq_len) token-to-token matrix.
        attention = output.attentions[-1]
        attention_scores = attention.mean(dim=1).mean(dim=0)

        paragraph_tokens = tokenizer.tokenize(paragraph)

        # Exclude the trailing [SEP].
        para_end_idx = len(inputs["input_ids"][0]) - 1

        if para_end_idx <= para_start_idx:
            continue  # query consumed the whole (possibly truncated) sequence

        # Mean attention each paragraph token receives from the paragraph span.
        para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)

        if para_attention_scores.numel() == 0:
            continue

        # Set for O(1) membership tests in the token loop below.
        relevant_indices = set((para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist())

        # NOTE(review): under truncation, len(paragraph_tokens) can exceed the
        # attention span, so trailing tokens are never highlighted — confirm
        # this is acceptable for very long paragraphs.
        # Rebuild the paragraph, gluing WordPiece continuations ("##xxx") onto
        # the previous word. The previous approach passed "<b>##token</b>"
        # through convert_tokens_to_string, which left literal "##" inside
        # bolded words in the HTML output.
        pieces = []
        for idx, token in enumerate(paragraph_tokens):
            is_continuation = token.startswith("##")
            text = token[2:] if is_continuation else token
            if idx in relevant_indices:
                text = f"<b>{text}</b>"
            if is_continuation and pieces:
                pieces[-1] += text
            else:
                pieces.append(text)
        highlighted_text = " ".join(pieces)

        ranked_paragraphs.append({
            "logit": logit,
            "highlighted_text": highlighted_text
        })

    if not ranked_paragraphs:
        # Every paragraph was skipped by the guards above (e.g. fully truncated).
        return "No paragraph could be scored for this query.", ""

    # Sort paragraphs by logit (descending)
    ranked_paragraphs.sort(key=lambda x: x["logit"], reverse=True)

    scores_text = "\n".join(
        f"Relevance Score (Logits): {round(p['logit'], 4)}" for p in ranked_paragraphs
    )
    html_text = "\n\n".join(p["highlighted_text"] for p in ranked_paragraphs)
    return scores_text, html_text

# Gradio UI: a query box, up to three paragraph boxes, and a slider that
# controls the attention threshold used for token highlighting. `live=True`
# reruns scoring on every input change.
_input_components = [
    gr.Textbox(label="Query", placeholder="Enter your search query..."),
    gr.Textbox(label="Document Paragraph 1", placeholder="Enter a paragraph to match...", lines=4),
    gr.Textbox(label="Document Paragraph 2 (optional)", placeholder="Enter another paragraph...", lines=4),
    gr.Textbox(label="Document Paragraph 3 (optional)", placeholder="Enter another paragraph...", lines=4),
    gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold"),
]
_output_components = [
    gr.Textbox(label="Relevance Scores (Logits)"),
    gr.HTML(label="Highlighted Document Paragraphs"),
]
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=_input_components,
    outputs=_output_components,
    title="Cross-Encoder Attention Highlighting with Reranking",
    description="Adjust the attention threshold to control token highlighting sensitivity. Multiple paragraphs can be added and reranked based on their logits.",
    allow_flagging="never",
    live=True,
)

# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    interface.launch()