wilwork commited on
Commit
691416f
·
verified ·
1 Parent(s): 0ebdae4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -4,28 +4,43 @@ import torch
4
 
5
  # Load model and tokenizer
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
9
 
10
  # Function to compute relevance score
11
  def get_relevance_score(query, paragraph):
 
 
 
 
 
 
12
  inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
13
- model.eval()
 
14
  with torch.no_grad():
15
- scores = model(**inputs).logits.squeeze().item()
16
- return round(scores, 4)
 
 
17
 
18
- # Gradio interface
19
  interface = gr.Interface(
20
  fn=get_relevance_score,
21
  inputs=[
22
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
23
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
24
  ],
25
- outputs=gr.Number(label="Relevance Score"),
26
  title="Cross-Encoder Relevance Scoring",
27
- description="Enter a query and a document paragraph to get a relevance score using the MS MARCO MiniLM L-12 v2 model."
 
28
  )
29
 
30
  if __name__ == "__main__":
 
31
  interface.launch()
 
4
 
5
  # Load model and tokenizer
6
  model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
7
+
8
+ print("Loading model and tokenizer...")
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
+ model.eval() # Set model to evaluation mode
12
+ print("Model and tokenizer loaded successfully.")
13
 
14
  # Function to compute relevance score
15
  def get_relevance_score(query, paragraph):
16
+ if not query.strip() or not paragraph.strip():
17
+ return "Please provide both a query and a document paragraph."
18
+
19
+ print(f"Received inputs -> Query: {query}, Paragraph: {paragraph}")
20
+
21
+ # Tokenize inputs
22
  inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
23
+
24
+ # Perform inference without gradient tracking
25
  with torch.no_grad():
26
+ score = model(**inputs).logits.squeeze().item()
27
+
28
+ print(f"Calculated score: {score}")
29
+ return round(score, 4)
30
 
31
+ # Define Gradio interface
32
  interface = gr.Interface(
33
  fn=get_relevance_score,
34
  inputs=[
35
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
36
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
37
  ],
38
+ outputs=gr.Textbox(label="Relevance Score"),
39
  title="Cross-Encoder Relevance Scoring",
40
+ description="Enter a query and a document paragraph to get a relevance score using the MS MARCO MiniLM L-12 v2 model.",
41
+ allow_flagging="never"
42
  )
43
 
44
  if __name__ == "__main__":
45
+ print("Launching Gradio app...")
46
  interface.launch()