wilwork commited on
Commit
6e70d21
·
verified ·
1 Parent(s): 8014cd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -22
app.py CHANGED
@@ -4,47 +4,64 @@ import torch
4
 
5
  # Load model and tokenizer
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
7
-
8
- print("Loading model and tokenizer...")
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
  model.eval() # Set model to evaluation mode
12
- print("Model and tokenizer loaded successfully.")
13
 
14
- # Function to compute relevance score
15
- def get_relevance_score(query, paragraph):
16
  if not query.strip() or not paragraph.strip():
17
- return "Please provide both a query and a document paragraph."
 
 
 
18
 
19
- print(f"Received inputs -> Query: {query}, Paragraph: {paragraph}")
 
20
 
21
- # Tokenize inputs
22
- inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
 
23
 
24
- # Perform inference without gradient tracking
25
- with torch.no_grad():
26
- score = model(**inputs).logits.squeeze().item()
 
 
27
 
28
- print(f"Calculated score: {score}")
29
- return round(score, 4)
 
 
 
 
 
 
 
 
 
30
 
 
 
31
 
32
- def test_function(query, paragraph):
33
- return f"Received query: {query}, paragraph: {paragraph}"
34
 
35
  # Define Gradio interface
36
  interface = gr.Interface(
37
- fn=get_relevance_score,
38
  inputs=[
39
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
40
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
41
  ],
42
- outputs=gr.Textbox(label="Relevance Score"),
43
- title="Cross-Encoder Relevance Scoring",
44
- description="Enter a query and a document paragraph to get a relevance score using the MS MARCO MiniLM L-12 v2 model."
 
 
 
 
 
45
  )
46
 
47
-
48
  if __name__ == "__main__":
49
- print("Launching Gradio app...")
50
  interface.launch()
 
4
 
5
  # Load model and tokenizer
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  model.eval() # Set model to evaluation mode
 
10
 
11
+ # Function to get relevance score and relevant excerpt based on attention scores
12
+ def get_relevance_score_and_excerpt(query, paragraph):
13
  if not query.strip() or not paragraph.strip():
14
+ return "Please provide both a query and a document paragraph.", ""
15
+
16
+ # Tokenize the input
17
+ inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True, return_attention_mask=True)
18
 
19
+ with torch.no_grad():
20
+ output = model(**inputs, output_attentions=True) # Get attention scores
21
 
22
+ # Extract logits and calculate relevance score
23
+ logit = output.logits.squeeze().item()
24
+ relevance_score = torch.sigmoid(torch.tensor(logit)).item()
25
 
26
+ # Extract attention scores (use the last attention layer)
27
+ attention = output.attentions[-1] # Shape: (batch_size, num_heads, seq_len, seq_len)
28
+
29
+ # Average across attention heads to get token importance
30
+ attention_scores = attention.mean(dim=1).squeeze(0) # Shape: (seq_len, seq_len)
31
 
32
+ # Focus on the paragraph part only (ignore query tokens)
33
+ input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
34
+ query_length = len(tokenizer.tokenize(query))
35
+
36
+ # Extract attention for the paragraph tokens only
37
+ paragraph_tokens = input_tokens[query_length + 2 : -1] # Skip query and special tokens like [SEP]
38
+ paragraph_attention = attention_scores[query_length + 2 : -1, query_length + 2 : -1].mean(dim=0)
39
+
40
+ # Get the top tokens with highest attention scores
41
+ top_token_indices = torch.argsort(paragraph_attention, descending=True)[:5] # Top 5 tokens
42
+ highlighted_tokens = [paragraph_tokens[i] for i in top_token_indices]
43
 
44
+ # Reconstruct the excerpt from top attention tokens
45
+ excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)
46
 
47
+ return round(relevance_score, 4), excerpt
 
48
 
49
  # Define Gradio interface
50
  interface = gr.Interface(
51
+ fn=get_relevance_score_and_excerpt,
52
  inputs=[
53
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
54
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
55
  ],
56
+ outputs=[
57
+ gr.Textbox(label="Relevance Score"),
58
+ gr.Textbox(label="Most Relevant Excerpt")
59
+ ],
60
+ title="Cross-Encoder Relevance Scoring with Attention-Based Excerpt Extraction",
61
+ description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt using attention scores.",
62
+ allow_flagging="never",
63
+ live=True
64
  )
65
 
 
66
  if __name__ == "__main__":
 
67
  interface.launch()