wilwork committed on
Commit
e64db30
·
verified ·
1 Parent(s): 51f19d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -64
app.py CHANGED
@@ -1,64 +1,44 @@
1
  import gradio as gr
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
 
4
 
5
  # Load model and tokenizer
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
- model.eval()
10
 
11
- # Sigmoid-based threshold adjustment function
12
def calculate_threshold(base_relevance, min_threshold=0.02, max_threshold=0.5, k=10):
    """Map a relevance score to a threshold along a sigmoid curve.

    The result is ``min_threshold + (max_threshold - min_threshold) *
    sigmoid(k * (base_relevance - 0.5))``, so with the default arguments it
    moves from ``min_threshold`` toward ``max_threshold`` as the relevance
    grows, crossing the halfway point at a relevance of 0.5.

    Args:
        base_relevance: Scalar relevance score, given either as a Python
            number or as a single-element tensor.
        min_threshold: Value approached when the sigmoid term goes to 0.
        max_threshold: Value approached when the sigmoid term goes to 1.
        k: Steepness of the sigmoid around the 0.5 midpoint.

    Returns:
        The threshold as a plain Python float.
    """
    # as_tensor accepts numbers and tensors alike; torch.tensor(<tensor>)
    # would emit a copy-construction warning.
    relevance = torch.as_tensor(base_relevance, dtype=torch.float32)
    # torch.sigmoid(x) is 1 / (1 + exp(-x)) without the intermediate exp()
    # overflowing to inf for extreme inputs.
    gate = torch.sigmoid(k * (relevance - 0.5))
    threshold = min_threshold + (max_threshold - min_threshold) * gate
    return threshold.item()  # plain float for the callers
18
-
19
- # Function to compute relevance score and dynamically adjust threshold
20
- def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
21
- if not query.strip() or not paragraph.strip():
22
- return "Please provide both a query and a document paragraph.", ""
23
 
24
- # Tokenize the input
25
- inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
 
 
 
26
 
 
27
  with torch.no_grad():
28
- output = model(**inputs, output_attentions=True)
29
-
30
- # Extract logits and calculate base relevance score
31
- logit = output.logits.squeeze().item()
32
- base_relevance_score = torch.sigmoid(torch.tensor(logit)).item()
33
-
34
- # Compute dynamic threshold using sigmoid-based adjustment
35
- dynamic_threshold = calculate_threshold(base_relevance_score) * threshold_weight
36
-
37
- # Extract attention scores (last layer)
38
- attention = output.attentions[-1]
39
- attention_scores = attention.mean(dim=1).mean(dim=0)
40
-
41
- query_tokens = tokenizer.tokenize(query)
42
- paragraph_tokens = tokenizer.tokenize(paragraph)
43
-
44
- query_len = len(query_tokens) + 2 # +2 for special tokens [CLS] and first [SEP]
45
- para_start_idx = query_len
46
- para_end_idx = len(inputs["input_ids"][0]) - 1
47
-
48
- if para_end_idx <= para_start_idx:
49
- return round(base_relevance_score, 4), "No relevant tokens extracted."
50
-
51
- para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
52
-
53
- if para_attention_scores.numel() == 0:
54
- return round(base_relevance_score, 4), "No relevant tokens extracted."
55
-
56
- # Get indices of relevant tokens above dynamic threshold
57
- relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
58
-
59
- # Reconstruct paragraph with bolded relevant tokens using HTML tags
60
  highlighted_text = ""
61
- for idx, token in enumerate(paragraph_tokens):
62
  if idx in relevant_indices:
63
  highlighted_text += f"<b>{token}</b> "
64
  else:
@@ -66,25 +46,25 @@ def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
66
 
67
  highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
68
 
69
- return round(base_relevance_score, 4), highlighted_text
 
 
70
 
71
- # Define Gradio interface with a slider for threshold adjustment
72
# Gradio front end: two text inputs plus a slider whose value is passed to
# get_relevance_score_and_excerpt as threshold_weight. live=True re-runs the
# function as inputs change.
interface = gr.Interface(
    title="Cross-Encoder Attention Highlighting",
    description="Adjust the attention threshold weight to control token highlighting sensitivity.",
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
        gr.Slider(label="Threshold Weight", minimum=0.5, maximum=2.0, step=0.1, value=1.0),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
    live=True,
    allow_flagging="never",
)

# Launch only when executed as a script.
if __name__ == "__main__":
    interface.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
 
5
  # Load model and tokenizer
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
7
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
9
 
10
+ # Set model to evaluation mode
11
+ model.eval()
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # Function to compute relevance and highlight relevant tokens
14
+ def process_text(query, document, weight):
15
+ # Tokenize input
16
+ inputs = tokenizer(query, document, return_tensors="pt", truncation=True, padding=True)
17
+ input_ids = inputs["input_ids"]
18
 
19
+ # Get model outputs with attentions
20
  with torch.no_grad():
21
+ outputs = model(**inputs, output_attentions=True)
22
+ relevance_score = torch.sigmoid(outputs.logits).item() # Convert logits to relevance score
23
+ attentions = outputs.attentions[-1].squeeze(0).mean(0) # Mean attention across heads
24
+
25
+ # Calculate dynamic threshold using sigmoid function
26
def calculate_threshold(base_relevance, min_threshold=0.0, max_threshold=0.5, k=10):
    """Map a relevance score to a threshold along a sigmoid curve.

    The result is ``min_threshold + (max_threshold - min_threshold) *
    sigmoid(k * (base_relevance - 0.5))``, so with the default arguments it
    moves from ``min_threshold`` toward ``max_threshold`` as the relevance
    grows, crossing the halfway point at a relevance of 0.5.

    Args:
        base_relevance: Scalar relevance score, given either as a Python
            number or as a single-element tensor.
        min_threshold: Value approached when the sigmoid term goes to 0.
        max_threshold: Value approached when the sigmoid term goes to 1.
        k: Steepness of the sigmoid around the 0.5 midpoint.

    Returns:
        The threshold as a plain Python float.
    """
    # as_tensor accepts numbers and tensors alike; torch.tensor(<tensor>)
    # would emit a copy-construction warning.
    relevance = torch.as_tensor(base_relevance, dtype=torch.float32)
    # torch.sigmoid(x) is 1 / (1 + exp(-x)) without the intermediate exp()
    # overflowing to inf for extreme inputs.
    gate = torch.sigmoid(k * (relevance - 0.5))
    threshold = min_threshold + (max_threshold - min_threshold) * gate
    return threshold.item()
32
+
33
+ dynamic_threshold = calculate_threshold(relevance_score) * weight
34
+
35
+ # Extract important tokens based on attention scores
36
+ relevant_indices = (attentions > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
37
+
38
+ # Highlight tokens in the original order, using HTML bold tags
39
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  highlighted_text = ""
41
+ for idx, token in enumerate(tokens):
42
  if idx in relevant_indices:
43
  highlighted_text += f"<b>{token}</b> "
44
  else:
 
46
 
47
  highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
48
 
49
+ # Print values to debug
50
+ print(f"Relevance Score: {relevance_score}")
51
+ print(f"Dynamic Threshold: {dynamic_threshold}")
52
 
53
+ return relevance_score, dynamic_threshold, highlighted_text
54
+
55
+ # Create Gradio interface with a slider for threshold adjustment weight
56
# Gradio front end: query and paragraph text boxes plus a slider whose value
# is passed to process_text as `weight`. The three outputs line up
# positionally with the (relevance_score, dynamic_threshold, highlighted_text)
# tuple that process_text returns.
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="Query"),
        gr.Textbox(label="Document Paragraph"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Threshold Weight"),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Dynamic Threshold"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
)

# Launch only when executed as a script, so that importing this module
# (e.g. from tests or tooling) does not start the web server.
if __name__ == "__main__":
    iface.launch()