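"""Gradio demo: score query-paragraph relevance with a cross-encoder
(cross-encoder/ms-marco-MiniLM-L-12-v2) and surface the most-attended
paragraph tokens as a short excerpt in their original order."""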
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model and tokenizer
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # Set the model to evaluation mode

# Function to get relevance score and relevant excerpt while preserving order
def get_relevance_score_and_excerpt(query, paragraph):
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""
    
    # Tokenize the input
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True, return_attention_mask=True)
    
    with torch.no_grad():
        output = model(**inputs, output_attentions=True)  # Get attention scores

    # Extract logits and calculate relevance score
    logit = output.logits.squeeze().item()
    relevance_score = torch.sigmoid(torch.tensor(logit)).item()

    # Extract attention scores (use the last attention layer)
    attention = output.attentions[-1]  # Shape: (batch_size, num_heads, seq_len, seq_len)

    # Average across attention heads to get token importance
    attention_scores = attention.mean(dim=1).squeeze(0)  # Shape: (seq_len, seq_len)
    
    # Tokenize query and paragraph separately to get lengths
    query_tokens = tokenizer.tokenize(query)
    paragraph_tokens = tokenizer.tokenize(paragraph)

    # Locate the paragraph tokens inside the joint encoding: [CLS] query [SEP] paragraph [SEP]
    query_length = len(query_tokens) + 2  # +2 for the leading [CLS] and the [SEP] after the query
    para_start_idx = query_length
    para_end_idx = len(inputs["input_ids"][0]) - 1  # Ignore final [SEP] token

    # Attention paid by the [CLS] token to each paragraph token (one score per token)
    para_attention_scores = attention_scores[0, para_start_idx:para_end_idx]

    # Get indices of top-k attended tokens while preserving order
    top_k = min(5, len(paragraph_tokens))  # Extract top 5 tokens or fewer if short
    top_indices = para_attention_scores.argsort(descending=True)[:top_k].sort().values  # Sort to preserve order

    # Extract top tokens based on original order
    highlighted_tokens = [paragraph_tokens[i] for i in top_indices]

    # Reconstruct the excerpt from ordered tokens
    excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)

    return round(relevance_score, 4), excerpt

# Define Gradio interface
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Most Relevant Excerpt")
    ],
    title="Cross-Encoder Relevance Scoring with Ordered Excerpt Extraction",
    description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt in original order.",
    allow_flagging="never",
    live=True
)

if __name__ == "__main__":
    interface.launch()
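
# The scorer can also be exercised directly, without launching the UI.
# The query and paragraph below are hypothetical example inputs:
#
#   score, excerpt = get_relevance_score_and_excerpt(
#       "what is the capital of france",
#       "Paris is the capital and largest city of France.",
#   )
#   print(score, excerpt)  # rounded relevance score and the ordered excerpt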