wilwork committed
Commit 5552636 · verified · 1 Parent(s): a39d1e7

Update app.py
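
This update replaces the fixed top-k excerpt extraction with attention-threshold filtering: paragraph tokens whose averaged last-layer attention exceeds THRESHOLD (0.02) are bolded in place, and the excerpt output moves from a plain textbox to an HTML component that shows the whole paragraph with highlights.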

Files changed (1):
  app.py  (+25, -17)
app.py CHANGED
@@ -8,11 +8,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSequenceClassification.from_pretrained(model_name)
 model.eval()  # Set the model to evaluation mode
 
-# Function to get relevance score and relevant excerpt while preserving order
+# Threshold for attention relevance
+THRESHOLD = 0.02  # Adjust as needed based on observations
+
+# Function to get relevance score and relevant excerpt with bolded tokens
 def get_relevance_score_and_excerpt(query, paragraph):
     if not query.strip() or not paragraph.strip():
         return "Please provide both a query and a document paragraph.", ""
-
+
     # Tokenize the input
     inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
 
@@ -23,17 +26,17 @@ def get_relevance_score_and_excerpt(query, paragraph):
     logit = output.logits.squeeze().item()
     relevance_score = torch.sigmoid(torch.tensor(logit)).item()
 
-    # Extract attention scores (last attention layer)
+    # Extract attention scores (last layer)
     attention = output.attentions[-1]  # Shape: (batch_size, num_heads, seq_len, seq_len)
 
-    # Average attention across heads and batch dimension
+    # Average across heads and batch dimension
     attention_scores = attention.mean(dim=1).mean(dim=0)  # Shape: (seq_len, seq_len)
 
-    # Get tokenized query and paragraph separately
+    # Tokenize query and paragraph separately
     query_tokens = tokenizer.tokenize(query)
     paragraph_tokens = tokenizer.tokenize(paragraph)
 
-    query_len = len(query_tokens) + 2  # +2 for special tokens [CLS] and [SEP]
+    query_len = len(query_tokens) + 2  # +2 for [CLS] and first [SEP]
     para_start_idx = query_len
     para_end_idx = len(inputs["input_ids"][0]) - 1  # Ignore final [SEP] token
 
@@ -41,22 +44,27 @@ def get_relevance_score_and_excerpt(query, paragraph):
     if para_end_idx <= para_start_idx:
         return round(relevance_score, 4), "No relevant tokens extracted."
 
+    # Extract paragraph attention scores
     para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
 
     if para_attention_scores.numel() == 0:
         return round(relevance_score, 4), "No relevant tokens extracted."
 
-    # Get indices of top-k attended tokens while preserving order
-    top_k = min(5, para_attention_scores.size(0))  # Ensure top-k does not exceed available tokens
-    top_indices = para_attention_scores.topk(top_k).indices.sort().values  # Sort indices to preserve order
+    # Filter tokens based on threshold and preserve order
+    relevant_indices = (para_attention_scores > THRESHOLD).nonzero(as_tuple=True)[0].tolist()
 
-    # Extract highlighted tokens from the paragraph
-    highlighted_tokens = [paragraph_tokens[i] for i in top_indices.tolist()]
+    # Reconstruct paragraph with bolded relevant tokens
+    highlighted_text = ""
+    for idx, token in enumerate(paragraph_tokens):
+        if idx in relevant_indices:
+            highlighted_text += f"**{token}** "
+        else:
+            highlighted_text += f"{token} "
 
-    # Convert tokens back to a readable string
-    excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)
+    # Convert tokens to readable format (handling special characters)
+    highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
 
-    return round(relevance_score, 4), excerpt
+    return round(relevance_score, 4), highlighted_text
 
 # Define Gradio interface
 interface = gr.Interface(
@@ -67,10 +75,10 @@ interface = gr.Interface(
     ],
     outputs=[
         gr.Textbox(label="Relevance Score"),
-        gr.Textbox(label="Most Relevant Excerpt")
+        gr.HTML(label="Highlighted Document Paragraph")
     ],
-    title="Cross-Encoder Relevance Scoring with Ordered Excerpt Extraction",
-    description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt in original order.",
+    title="Cross-Encoder Relevance Scoring with Highlighted Excerpt",
+    description="Enter a query and a document paragraph to get a relevance score and see relevant tokens in bold.",
     allow_flagging="never",
     live=True
 )
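
For context on the unchanged lines between the first two hunks: output.attentions is only populated when the forward pass is asked for attention weights, so the scoring call (outside this diff) presumably runs with output_attentions=True. A minimal standalone sketch of that step; the checkpoint name here is a stand-in, since model_name is defined above the diff context:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # hypothetical stand-in for the Space's model_name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

inputs = tokenizer("what is a cross-encoder?",
                   "A cross-encoder scores a query and a passage jointly.",
                   return_tensors="pt", truncation=True, padding=True)

with torch.no_grad():
    # Without output_attentions=True, output.attentions would be None
    output = model(**inputs, output_attentions=True)

logit = output.logits.squeeze().item()
relevance_score = torch.sigmoid(torch.tensor(logit)).item()

attention = output.attentions[-1]                     # (batch, heads, seq_len, seq_len)
attention_scores = attention.mean(dim=1).mean(dim=0)  # (seq_len, seq_len)
print(round(relevance_score, 4), attention_scores.shape)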
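
The core change in the third hunk swaps a fixed top-k selection for a threshold filter, so the number of highlighted tokens now adapts to how attention is actually distributed rather than always being five. Both strategies on the same toy score vector, with THRESHOLD = 0.02 as in the commit:

import torch

scores = torch.tensor([0.010, 0.045, 0.005, 0.030, 0.015, 0.025])
THRESHOLD = 0.02

# Old approach: always pick the top-k tokens, re-sorted into document order
top_k = min(5, scores.size(0))
top_indices = scores.topk(top_k).indices.sort().values
print(top_indices.tolist())  # [0, 1, 3, 4, 5] -- k tokens regardless of how low the scores are

# New approach: keep every token above the threshold, already in document order
relevant_indices = (scores > THRESHOLD).nonzero(as_tuple=True)[0].tolist()
print(relevant_indices)      # [1, 3, 5] -- count varies with the attention distribution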
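
The reconstruction loop can also be exercised on its own. A minimal sketch with hardcoded WordPiece-style tokens standing in for paragraph_tokens and for the thresholded attention indices; it also shows why a bolded subword keeps a visible gap, since convert_tokens_to_string only merges a continuation token while its leading "##" is still at the start of the string:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical stand-in

paragraph_tokens = ["cross", "-", "en", "##code", "##rs", "score", "passages"]
relevant_indices = [0, 3]  # stand-in for (para_attention_scores > THRESHOLD) hits

highlighted_text = ""
for idx, token in enumerate(paragraph_tokens):
    if idx in relevant_indices:
        highlighted_text += f"**{token}** "  # bold marker around a relevant token
    else:
        highlighted_text += f"{token} "

# "**##code**" no longer starts with "##", so it is not merged back into the
# previous token; unbolded subwords like "##rs" still rejoin cleanly.
print(tokenizer.convert_tokens_to_string(highlighted_text.split()))
# -> **cross** - en **##code**rs score passages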