File size: 3,295 Bytes
7472a45
 
 
 
 
 
 
 
43a9654
7472a45
43a9654
6e70d21
691416f
6e70d21
 
 
a39d1e7
43a9654
 
 
 
 
 
 
 
a39d1e7
43a9654
 
a39d1e7
 
 
 
43a9654
 
 
a39d1e7
 
43a9654
 
a39d1e7
 
 
 
 
 
 
 
43a9654
 
a39d1e7
 
43a9654
a39d1e7
 
43a9654
a39d1e7
43a9654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Checkpoint shared by the tokenizer and the sequence-classification model
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# eval() returns the module itself, so it can be chained onto the load;
# the model is only ever used for inference in this script
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

# Function to get relevance score and relevant excerpt while preserving order
def get_relevance_score_and_excerpt(query, paragraph):
    """Score ``paragraph`` against ``query`` and pick out its most-attended tokens.

    The pair is encoded together and run through the module-level ``model``.
    The output logit is passed through a sigmoid to give the score. The last
    attention layer, averaged over heads, is restricted to the paragraph
    positions and used to choose up to five paragraph tokens, which are
    returned in their original order.

    Args:
        query: Search query text. ``None`` or a blank string counts as missing.
        paragraph: Document text to score. ``None`` or a blank string counts
            as missing.

    Returns:
        A 2-tuple ``(score, excerpt)``:

        * both inputs present: ``score`` is the sigmoid of the logit rounded
          to 4 decimals and ``excerpt`` is the string rebuilt from the
          selected tokens;
        * an input missing: ``("Please provide both a query and a document
          paragraph.", "")``;
        * no paragraph token in the encoded pair: ``(score, "No relevant
          tokens extracted.")``.
    """
    # Check for None before calling .strip() so a missing value produces the
    # prompt message instead of an AttributeError.
    if not (query and query.strip()) or not (paragraph and paragraph.strip()):
        return "Please provide both a query and a document paragraph.", ""

    # Encode the pair once; everything below is derived from this encoding so
    # that positions stay correct even when truncation shortened an input.
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)  # Also return attention maps

    # Sigmoid of the first (only expected) logit, rounded for display
    relevance_score = round(torch.sigmoid(output.logits.reshape(-1)[0]).item(), 4)

    # Last attention layer, averaged over heads and then over the batch
    # dimension, leaving a (seq_len, seq_len) matrix
    attention_scores = output.attentions[-1].mean(dim=1).mean(dim=0)

    # sequence_ids() labels each encoded position: None for special tokens,
    # 0 for the first text (query), 1 for the second text (paragraph).
    # NOTE: this is only available on fast tokenizers (AutoTokenizer's default).
    sequence_ids = inputs.sequence_ids(0)
    para_positions = [pos for pos, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not para_positions:
        return relevance_score, "No relevant tokens extracted."

    # The paragraph occupies one contiguous run of positions
    para_start_idx = para_positions[0]
    para_end_idx = para_positions[-1] + 1

    # Read the paragraph tokens back from the encoded ids so that index i here
    # is exactly row/column i of the attention slice below
    paragraph_tokens = tokenizer.convert_ids_to_tokens(
        inputs["input_ids"][0, para_start_idx:para_end_idx].tolist()
    )

    # Mean attention each paragraph token receives from the paragraph tokens
    para_attention_scores = attention_scores[
        para_start_idx:para_end_idx, para_start_idx:para_end_idx
    ].mean(dim=0)

    # Take at most five tokens, then sort the indices to restore reading order
    top_k = min(5, para_attention_scores.size(0))
    top_indices = para_attention_scores.topk(top_k).indices.sort().values

    highlighted_tokens = [paragraph_tokens[i] for i in top_indices.tolist()]

    # Merge the selected word pieces back into a readable string
    excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)

    return relevance_score, excerpt

# Gradio UI: two text boxes feed get_relevance_score_and_excerpt, and its
# two return values are shown in two output text boxes
interface = gr.Interface(
    title="Cross-Encoder Relevance Scoring with Ordered Excerpt Extraction",
    description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt in original order.",
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Most Relevant Excerpt"),
    ],
    live=True,
    allow_flagging="never",
)

# Launch the app only when this file is executed directly
if __name__ == "__main__":
    interface.launch()