wilwork committed on
Commit
fe29744
·
verified ·
1 Parent(s): 0e17055

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -8,6 +8,10 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  model.eval()
10
 
 
 
 
 
11
  # Function to compute relevance score and dynamically adjust threshold
12
  def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
13
  if not query.strip() or not paragraph.strip():
@@ -23,8 +27,8 @@ def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
23
  logit = output.logits.squeeze().item()
24
  base_relevance_score = torch.sigmoid(torch.tensor(logit)).item()
25
 
26
- # Dynamically adjust the attention threshold based on user weight (no relevance score influence)
27
- dynamic_threshold = max(0.02, threshold_weight)
28
 
29
  # Extract attention scores (last layer)
30
  attention = output.attentions[-1]
@@ -66,14 +70,14 @@ interface = gr.Interface(
66
  inputs=[
67
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
68
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
69
- gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold")
70
  ],
71
  outputs=[
72
  gr.Textbox(label="Relevance Score"),
73
  gr.HTML(label="Highlighted Document Paragraph")
74
  ],
75
  title="Cross-Encoder Attention Highlighting",
76
- description="Adjust the attention threshold to control token highlighting sensitivity.",
77
  allow_flagging="never",
78
  live=True
79
  )
 
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  model.eval()
10
 
11
+ # Sigmoid-based threshold adjustment function
12
+ def calculate_threshold(base_relevance, min_threshold=0.02, max_threshold=0.5, k=10):
13
+ return min_threshold + (max_threshold - min_threshold) * (1 / (1 + torch.exp(-k * (base_relevance - 0.5))))
14
+
15
  # Function to compute relevance score and dynamically adjust threshold
16
  def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
17
  if not query.strip() or not paragraph.strip():
 
27
  logit = output.logits.squeeze().item()
28
  base_relevance_score = torch.sigmoid(torch.tensor(logit)).item()
29
 
30
+ # Compute dynamic threshold using sigmoid-based adjustment
31
+ dynamic_threshold = calculate_threshold(base_relevance_score) * threshold_weight
32
 
33
  # Extract attention scores (last layer)
34
  attention = output.attentions[-1]
 
70
  inputs=[
71
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
72
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
73
+ gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Threshold Weight")
74
  ],
75
  outputs=[
76
  gr.Textbox(label="Relevance Score"),
77
  gr.HTML(label="Highlighted Document Paragraph")
78
  ],
79
  title="Cross-Encoder Attention Highlighting",
80
+ description="Adjust the attention threshold weight to control token highlighting sensitivity.",
81
  allow_flagging="never",
82
  live=True
83
  )