wilwork committed on
Commit a39d1e7 · verified · 1 Parent(s): 43a9654

Update app.py

Files changed (1)
  1. app.py +21 -15
app.py CHANGED
@@ -14,7 +14,7 @@ def get_relevance_score_and_excerpt(query, paragraph):
         return "Please provide both a query and a document paragraph.", ""

     # Tokenize the input
-    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True, return_attention_mask=True)
+    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

     with torch.no_grad():
         output = model(**inputs, output_attentions=True)  # Get attention scores
@@ -23,31 +23,37 @@ def get_relevance_score_and_excerpt(query, paragraph):
     logit = output.logits.squeeze().item()
     relevance_score = torch.sigmoid(torch.tensor(logit)).item()

-    # Extract attention scores (use the last attention layer)
+    # Extract attention scores (last attention layer)
     attention = output.attentions[-1]  # Shape: (batch_size, num_heads, seq_len, seq_len)

-    # Average across attention heads to get token importance
-    attention_scores = attention.mean(dim=1).squeeze(0)  # Shape: (seq_len, seq_len)
-
-    # Tokenize query and paragraph separately to get lengths
+    # Average attention across heads and batch dimension
+    attention_scores = attention.mean(dim=1).mean(dim=0)  # Shape: (seq_len, seq_len)
+
+    # Get tokenized query and paragraph separately
     query_tokens = tokenizer.tokenize(query)
     paragraph_tokens = tokenizer.tokenize(paragraph)

-    # Extract only paragraph-related attention scores
-    query_length = len(query_tokens) + 2  # +2 for special tokens like [CLS] and [SEP]
-    para_start_idx = query_length
+    query_len = len(query_tokens) + 2  # +2 for special tokens [CLS] and [SEP]
+    para_start_idx = query_len
     para_end_idx = len(inputs["input_ids"][0]) - 1  # Ignore final [SEP] token

-    para_attention_scores = attention_scores[0, para_start_idx:para_end_idx].mean(dim=0)
+    # Handle potential indexing issues
+    if para_end_idx <= para_start_idx:
+        return round(relevance_score, 4), "No relevant tokens extracted."
+
+    para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
+
+    if para_attention_scores.numel() == 0:
+        return round(relevance_score, 4), "No relevant tokens extracted."

     # Get indices of top-k attended tokens while preserving order
-    top_k = min(5, len(paragraph_tokens))  # Extract top 5 tokens or fewer if short
-    top_indices = para_attention_scores.argsort(descending=True)[:top_k].sort().values  # Sort to preserve order
+    top_k = min(5, para_attention_scores.size(0))  # Ensure top-k does not exceed available tokens
+    top_indices = para_attention_scores.topk(top_k).indices.sort().values  # Sort indices to preserve order

-    # Extract top tokens based on original order
-    highlighted_tokens = [paragraph_tokens[i] for i in top_indices]
+    # Extract highlighted tokens from the paragraph
+    highlighted_tokens = [paragraph_tokens[i] for i in top_indices.tolist()]

-    # Reconstruct the excerpt from ordered tokens
+    # Convert tokens back to a readable string
     excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)

     return round(relevance_score, 4), excerpt
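The diff only touches the body of get_relevance_score_and_excerpt; the model and tokenizer it relies on are loaded earlier in app.py and are not part of this change. As a rough way to exercise the new head/batch pooling, paragraph-block slicing, and top-k logic outside the Gradio app, a standalone sketch along these lines should behave the same way. The checkpoint name is a placeholder assumption, not necessarily the one this Space loads.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Placeholder checkpoint (assumption): app.py defines its own model/tokenizer outside this diff.
model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

query = "What causes rain?"
paragraph = "Rain forms when water vapor condenses into droplets that grow heavy enough to fall."

inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
    output = model(**inputs, output_attentions=True)

# Single-logit cross-encoder head -> sigmoid gives a 0..1 relevance score
relevance_score = torch.sigmoid(output.logits.squeeze()).item()

# Last-layer attention, averaged over heads then the batch dimension -> (seq_len, seq_len)
attention_scores = output.attentions[-1].mean(dim=1).mean(dim=0)

query_tokens = tokenizer.tokenize(query)
paragraph_tokens = tokenizer.tokenize(paragraph)

para_start_idx = len(query_tokens) + 2           # skip [CLS] + query tokens + [SEP]
para_end_idx = len(inputs["input_ids"][0]) - 1   # drop the trailing [SEP]

# Column-wise mean over the paragraph block: attention each paragraph token
# receives from the other paragraph tokens
para_attention_scores = attention_scores[para_start_idx:para_end_idx,
                                         para_start_idx:para_end_idx].mean(dim=0)

top_k = min(5, para_attention_scores.size(0))
top_indices = para_attention_scores.topk(top_k).indices.sort().values  # keep original token order
excerpt = tokenizer.convert_tokens_to_string([paragraph_tokens[i] for i in top_indices.tolist()])

print(round(relevance_score, 4), excerpt)

Averaging over dim=1 and then dim=0 collapses the heads and the batch dimension in turn, which is why the committed code can drop squeeze(0); pooling column-wise over the paragraph block then leaves one received-attention score per paragraph token to rank with topk.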