wilwork committed
Commit 6700bfc · verified · 1 Parent(s): 0f56dc9

Update app.py

Files changed (1)
  1. app.py +60 -42
app.py CHANGED
@@ -9,71 +9,89 @@ model = AutoModelForSequenceClassification.from_pretrained(model_name)
 model.eval()
 
 # Function to compute relevance score (in logits) and dynamically adjust threshold
-def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
-    if not query.strip() or not paragraph.strip():
-        return "Please provide both a query and a document paragraph.", ""
-
-    # Tokenize the input
-    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
-
-    with torch.no_grad():
-        output = model(**inputs, output_attentions=True)
-
-    # Extract logits (no sigmoid applied)
-    logit = output.logits.squeeze().item()
-    base_relevance_score = logit  # Relevance score in logits
-
-    # Dynamically adjust the attention threshold based on user weight (no relevance score influence)
-    dynamic_threshold = max(0.02, threshold_weight)
-
-    # Extract attention scores (last layer)
-    attention = output.attentions[-1]
-    attention_scores = attention.mean(dim=1).mean(dim=0)
-
-    query_tokens = tokenizer.tokenize(query)
-    paragraph_tokens = tokenizer.tokenize(paragraph)
-
-    query_len = len(query_tokens) + 2  # +2 for special tokens [CLS] and first [SEP]
-    para_start_idx = query_len
-    para_end_idx = len(inputs["input_ids"][0]) - 1
-
-    if para_end_idx <= para_start_idx:
-        return round(base_relevance_score, 4), "No relevant tokens extracted."
-
-    para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
-
-    if para_attention_scores.numel() == 0:
-        return round(base_relevance_score, 4), "No relevant tokens extracted."
-
-    # Get indices of relevant tokens above dynamic threshold
-    relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
-
-    # Reconstruct paragraph with bolded relevant tokens using HTML tags
-    highlighted_text = ""
-    for idx, token in enumerate(paragraph_tokens):
-        if idx in relevant_indices:
-            highlighted_text += f"<b>{token}</b> "
-        else:
-            highlighted_text += f"{token} "
-
-    highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
-
-    return round(base_relevance_score, 4), highlighted_text
-
-# Define Gradio interface with a slider for threshold adjustment
 interface = gr.Interface(
     fn=get_relevance_score_and_excerpt,
     inputs=[
         gr.Textbox(label="Query", placeholder="Enter your search query..."),
-        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
         gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold")
     ],
     outputs=[
-        gr.Textbox(label="Relevance Score (Logits)"),
-        gr.HTML(label="Highlighted Document Paragraph")
     ],
-    title="Cross-Encoder Attention Highlighting",
-    description="Adjust the attention threshold to control token highlighting sensitivity.",
     allow_flagging="never",
     live=True
 )
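Both the removed and the added version locate paragraph tokens with the same index arithmetic: a BERT-style tokenizer packs a sentence pair as [CLS] query [SEP] paragraph [SEP], so the paragraph span starts at len(query_tokens) + 2 and ends just before the final [SEP]. A minimal sketch of that layout (the checkpoint name is illustrative, not necessarily the model_name defined above the hunk):

from transformers import AutoTokenizer

# Illustrative cross-encoder checkpoint; any BERT-style tokenizer packs pairs the same way
tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "capital of France"
paragraph = "Paris is the capital and largest city of France."
inputs = tokenizer(query, paragraph, return_tensors="pt")

para_start = len(tokenizer.tokenize(query)) + 2  # skip [CLS], query tokens, first [SEP]
para_end = inputs["input_ids"].shape[1] - 1      # drop the trailing [SEP]

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
print(tokens[:para_start])          # [CLS] + query tokens + [SEP]
print(tokens[para_start:para_end])  # paragraph tokens only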
 
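The attention pooling in the hunk reduces the last layer's (batch, heads, seq_len, seq_len) tensor to one score per token: the mean over dim=1 averages the heads, the mean over dim=0 collapses the batch of size 1, and the paragraph-block slice averaged over dim=0 gives each paragraph token the average attention it receives from the other paragraph tokens. A toy sketch with random weights (shapes only, not real model output):

import torch

# Toy stand-in for output.attentions[-1]: shape (batch=1, heads=12, seq_len=10, seq_len=10)
attn = torch.softmax(torch.randn(1, 12, 10, 10), dim=-1)

scores = attn.mean(dim=1).mean(dim=0)       # average heads, drop batch -> (10, 10)
para_scores = scores[4:9, 4:9].mean(dim=0)  # hypothetical paragraph span at positions 4..8
relevant = (para_scores > 0.1).nonzero(as_tuple=True)[0].tolist()
print(relevant)                             # paragraph-relative indices above the threshold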
 model.eval()
 
 # Function to compute relevance score (in logits) and dynamically adjust threshold
+def get_relevance_score_and_excerpt(query, *args):
+    # Gradio passes inputs positionally, so the slider value arrives as the last argument
+    *paragraphs, threshold_weight = args
+
+    if not query.strip() or not any(p.strip() for p in paragraphs):
+        return "Please provide both a query and at least one document paragraph.", ""
+
+    ranked_paragraphs = []
+
+    # Process each paragraph and calculate its logit and highlighted text
+    for paragraph in paragraphs:
+        if not paragraph.strip():
+            continue  # skip empty optional paragraphs
+
+        # Tokenize the query/paragraph pair
+        inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
+
+        with torch.no_grad():
+            output = model(**inputs, output_attentions=True)
+
+        # Extract logits (no sigmoid applied)
+        logit = output.logits.squeeze().item()
+        base_relevance_score = logit  # Relevance score in logits
+
+        # Dynamically adjust the attention threshold based on user weight
+        dynamic_threshold = max(0.02, threshold_weight)
+
+        # Extract attention scores (last layer)
+        attention = output.attentions[-1]
+        attention_scores = attention.mean(dim=1).mean(dim=0)
+
+        query_tokens = tokenizer.tokenize(query)
+        paragraph_tokens = tokenizer.tokenize(paragraph)
+
+        query_len = len(query_tokens) + 2  # +2 for special tokens [CLS] and first [SEP]
+        para_start_idx = query_len
+        para_end_idx = len(inputs["input_ids"][0]) - 1
+
+        if para_end_idx <= para_start_idx:
+            continue
+
+        para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
+
+        if para_attention_scores.numel() == 0:
+            continue
+
+        # Get indices of relevant tokens above dynamic threshold
+        relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
+
+        # Reconstruct paragraph with bolded relevant tokens using HTML tags
+        highlighted_text = ""
+        for idx, token in enumerate(paragraph_tokens):
+            if idx in relevant_indices:
+                highlighted_text += f"<b>{token}</b> "
+            else:
+                highlighted_text += f"{token} "
+
+        highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
+
+        ranked_paragraphs.append({
+            "logit": logit,
+            "highlighted_text": highlighted_text
+        })
+
+    # Sort paragraphs by logit (descending)
+    ranked_paragraphs.sort(key=lambda x: x["logit"], reverse=True)
+
+    # Prepare output; join the HTML with <br> so the gr.HTML component renders visible breaks
+    relevance_scores = [round(p["logit"], 4) for p in ranked_paragraphs]
+    highlighted_texts = [p["highlighted_text"] for p in ranked_paragraphs]
+
+    return "\n".join([f"Relevance Score: {score}" for score in relevance_scores]), "<br><br>".join(highlighted_texts)
+
+# Define Gradio interface with a slider for threshold adjustment and support for multiple paragraphs
 interface = gr.Interface(
     fn=get_relevance_score_and_excerpt,
     inputs=[
         gr.Textbox(label="Query", placeholder="Enter your search query..."),
+        gr.Textbox(label="Document Paragraph 1", placeholder="Enter a paragraph to match...", lines=4),
+        gr.Textbox(label="Document Paragraph 2 (optional)", placeholder="Enter another paragraph...", lines=4),
+        gr.Textbox(label="Document Paragraph 3 (optional)", placeholder="Enter another paragraph...", lines=4),
         gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold")
     ],
     outputs=[
+        gr.Textbox(label="Relevance Scores (Logits)"),
+        gr.HTML(label="Highlighted Document Paragraphs")
     ],
+    title="Cross-Encoder Attention Highlighting with Reranking",
+    description="Adjust the attention threshold to control token highlighting sensitivity. Multiple paragraphs can be added and reranked based on their logits.",
     allow_flagging="never",
     live=True
 )
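For a quick sanity check outside the UI, the updated function can be called directly, mirroring how gr.Interface passes its inputs (query first, then the three paragraph boxes, then the slider value). This assumes app.py's model and tokenizer are loaded; the texts below are made up:

scores_text, highlighted_html = get_relevance_score_and_excerpt(
    "what is the capital of France",
    "Paris is the capital and largest city of France.",
    "The Eiffel Tower was completed in 1889.",
    "",    # optional third paragraph left blank
    0.1,   # attention threshold from the slider
)
print(scores_text)  # one "Relevance Score: ..." line per non-empty paragraph, best first

Note that with live=True the interface re-runs this function on every input change, so each keystroke can trigger up to three forward passes, one per non-empty paragraph.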