wilwork commited on
Commit
7472a45
·
verified ·
1 Parent(s): d7afc9b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -0
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ import torch
4
+
5
+ # Load model and tokenizer
6
+ model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+
10
+ # Function to compute relevance score
11
+ def get_relevance_score(query, paragraph):
12
+ inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
13
+ with torch.no_grad():
14
+ scores = model(**inputs).logits.squeeze().item()
15
+ return round(scores, 4)
16
+
17
+ # Gradio interface
18
+ interface = gr.Interface(
19
+ fn=get_relevance_score,
20
+ inputs=[
21
+ gr.Textbox(label="Query", placeholder="Enter your search query..."),
22
+ gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
23
+ ],
24
+ outputs=gr.Number(label="Relevance Score"),
25
+ title="Cross-Encoder Relevance Scoring",
26
+ description="Enter a query and a document paragraph to get a relevance score using the MS MARCO MiniLM L-12 v2 model."
27
+ )
28
+
29
+ if __name__ == "__main__":
30
+ interface.launch()