File size: 2,911 Bytes
7472a45
 
 
 
 
 
 
 
691416f
7472a45
6e70d21
 
691416f
6e70d21
 
 
 
691416f
6e70d21
 
691416f
6e70d21
 
 
691416f
6e70d21
 
 
 
 
691416f
6e70d21
 
 
 
 
 
 
 
 
 
 
7472a45
6e70d21
 
b951b0b
6e70d21
b951b0b
691416f
7472a45
6e70d21
8014cd1
 
 
 
6e70d21
 
 
 
 
 
 
 
7472a45
 
 
8014cd1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Checkpoint identifier handed to `from_pretrained` for both the tokenizer and the
# sequence-classification head used by the scoring function below.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Both objects are created once, at import time, and reused for every request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# `eval()` returns the module itself, so inference mode (no dropout) is switched on
# in the same expression that loads the weights.
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

# Function to get relevance score and relevant excerpt based on attention scores
def get_relevance_score_and_excerpt(query, paragraph, top_k=5):
    """Score *paragraph* against *query* and extract an attention-based excerpt.

    Args:
        query: The search query text.
        paragraph: The document paragraph to score against the query.
        top_k: Maximum number of paragraph tokens kept in the excerpt
            (defaults to 5; values below 0 are treated as 0).

    Returns:
        A ``(score, excerpt)`` tuple. ``score`` is the sigmoid of the
        cross-encoder's logit, rounded to 4 decimals. ``excerpt`` is built
        from the ``top_k`` paragraph tokens that the query tokens attend to
        most in the last attention layer, joined back in document order.
        If either input is blank, returns a prompt message and ``""``.
    """
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Encode the (query, paragraph) pair jointly, as a cross-encoder expects.
    inputs = tokenizer(
        query,
        paragraph,
        return_tensors="pt",
        truncation=True,
        return_attention_mask=True,
    )

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)

    relevance_score = round(torch.sigmoid(output.logits.squeeze()).item(), 4)

    # Last attention layer, averaged over heads; row i attends to column j.
    attention_scores = output.attentions[-1].mean(dim=1).squeeze(0)

    # sequence_ids() tags each position 0 (query), 1 (paragraph) or None (special
    # token). Unlike `len(tokenizer.tokenize(query)) + 2`, this stays correct when
    # truncation shortens the query. Requires a fast tokenizer (AutoTokenizer default).
    sequence_ids = inputs.sequence_ids(0)
    query_positions = [i for i, sid in enumerate(sequence_ids) if sid == 0]
    paragraph_positions = [i for i, sid in enumerate(sequence_ids) if sid == 1]
    if not paragraph_positions:
        # Nothing of the paragraph survived tokenization/truncation.
        return relevance_score, ""

    # Score each paragraph token by the attention the query pays to it, so the
    # excerpt actually depends on the query (fall back to paragraph rows if the
    # query somehow contributed no tokens).
    source_rows = query_positions or paragraph_positions
    paragraph_attention = attention_scores[source_rows][:, paragraph_positions].mean(dim=0)

    k = max(0, min(top_k, len(paragraph_positions)))
    top_indices = torch.topk(paragraph_attention, k).indices.tolist()

    # Re-sort the winners by position so the excerpt reads in document order.
    input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    highlighted_tokens = [input_tokens[paragraph_positions[i]] for i in sorted(top_indices)]

    excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)

    return relevance_score, excerpt

# Define Gradio interface
# The two input boxes map positionally onto the scoring function's (query, paragraph)
# parameters; its returned (score, excerpt) tuple fills the two output boxes in order.
query_input = gr.Textbox(label="Query", placeholder="Enter your search query...")
paragraph_input = gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
score_output = gr.Textbox(label="Relevance Score")
excerpt_output = gr.Textbox(label="Most Relevant Excerpt")

interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[query_input, paragraph_input],
    outputs=[score_output, excerpt_output],
    title="Cross-Encoder Relevance Scoring with Attention-Based Excerpt Extraction",
    description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt using attention scores.",
    allow_flagging="never",  # no flag button in the UI
    live=True,  # re-run the model whenever an input changes, without a Submit click
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()