wilwork commited on
Commit
43a9654
·
verified ·
1 Parent(s): 6a947e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -3
app.py CHANGED
@@ -6,12 +6,68 @@ import torch
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
- model.eval() # Set model to evaluation mode
10
 
11
- # Function to get relevance score and relevant excerpt based on attention scores
12
  def get_relevance_score_and_excerpt(query, paragraph):
13
  if not query.strip() or not paragraph.strip():
14
  return "Please provide both a query and a document paragraph.", ""
15
 
16
  # Tokenize the input
17
- inputs = t
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+ model.eval() # Set the model to evaluation mode
10
 
11
+ # Function to get relevance score and relevant excerpt while preserving order
12
  def get_relevance_score_and_excerpt(query, paragraph):
13
  if not query.strip() or not paragraph.strip():
14
  return "Please provide both a query and a document paragraph.", ""
15
 
16
  # Tokenize the input
17
+ inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True, return_attention_mask=True)
18
+
19
+ with torch.no_grad():
20
+ output = model(**inputs, output_attentions=True) # Get attention scores
21
+
22
+ # Extract logits and calculate relevance score
23
+ logit = output.logits.squeeze().item()
24
+ relevance_score = torch.sigmoid(torch.tensor(logit)).item()
25
+
26
+ # Extract attention scores (use the last attention layer)
27
+ attention = output.attentions[-1] # Shape: (batch_size, num_heads, seq_len, seq_len)
28
+
29
+ # Average across attention heads to get token importance
30
+ attention_scores = attention.mean(dim=1).squeeze(0) # Shape: (seq_len, seq_len)
31
+
32
+ # Tokenize query and paragraph separately to get lengths
33
+ query_tokens = tokenizer.tokenize(query)
34
+ paragraph_tokens = tokenizer.tokenize(paragraph)
35
+
36
+ # Extract only paragraph-related attention scores
37
+ query_length = len(query_tokens) + 2 # +2 for special tokens like [CLS] and [SEP]
38
+ para_start_idx = query_length
39
+ para_end_idx = len(inputs["input_ids"][0]) - 1 # Ignore final [SEP] token
40
+
41
+ para_attention_scores = attention_scores[0, para_start_idx:para_end_idx].mean(dim=0)
42
+
43
+ # Get indices of top-k attended tokens while preserving order
44
+ top_k = min(5, len(paragraph_tokens)) # Extract top 5 tokens or fewer if short
45
+ top_indices = para_attention_scores.argsort(descending=True)[:top_k].sort().values # Sort to preserve order
46
+
47
+ # Extract top tokens based on original order
48
+ highlighted_tokens = [paragraph_tokens[i] for i in top_indices]
49
+
50
+ # Reconstruct the excerpt from ordered tokens
51
+ excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)
52
+
53
+ return round(relevance_score, 4), excerpt
54
+
55
+ # Define Gradio interface
56
+ interface = gr.Interface(
57
+ fn=get_relevance_score_and_excerpt,
58
+ inputs=[
59
+ gr.Textbox(label="Query", placeholder="Enter your search query..."),
60
+ gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
61
+ ],
62
+ outputs=[
63
+ gr.Textbox(label="Relevance Score"),
64
+ gr.Textbox(label="Most Relevant Excerpt")
65
+ ],
66
+ title="Cross-Encoder Relevance Scoring with Ordered Excerpt Extraction",
67
+ description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt in original order.",
68
+ allow_flagging="never",
69
+ live=True
70
+ )
71
+
72
+ if __name__ == "__main__":
73
+ interface.launch()