import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Loaded once at import time; get_relevance_score reuses these module-level
# objects on every call instead of reloading them.
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # inference only: switch off train-time behaviour such as dropout
print("Model and tokenizer loaded successfully.")


def get_relevance_score(query, paragraph):
    """Score how relevant ``paragraph`` is to ``query`` using the cross-encoder.

    Args:
        query: Search query text. ``None``, empty, or whitespace-only values
            are treated as missing.
        paragraph: Candidate document text; same rules as ``query``.

    Returns:
        The model's logit for the (query, paragraph) pair rounded to four
        decimal places, or a prompt string when either input is missing.
    """
    # Guard clause. Check truthiness before .strip() so a None value (e.g. an
    # unset field) is reported as missing input instead of raising
    # AttributeError.
    if not (query and query.strip()) or not (paragraph and paragraph.strip()):
        return "Please provide both a query and a document paragraph."

    print(f"Received inputs -> Query: {query}, Paragraph: {paragraph}")

    # Encode the two texts together as one sentence pair; truncation=True
    # lets the tokenizer cut over-long input down to its length limit.
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)

    # No gradients are needed for scoring.
    with torch.no_grad():
        # squeeze().item() expects the logits to hold exactly one value.
        score = model(**inputs).logits.squeeze().item()

    print(f"Calculated score: {score}")
    return round(score, 4)


def test_function(query, paragraph):
    """Return a string echoing both inputs; performs no model inference.

    Args:
        query: Value interpolated after ``Received query:``.
        paragraph: Value interpolated after ``paragraph:``.

    Returns:
        The formatted echo string.
    """
    return f"Received query: {query}, paragraph: {paragraph}"


# Gradio UI definition: two text boxes supply (query, paragraph) to
# get_relevance_score, and its return value is shown in a single text box.
interface = gr.Interface(
    fn=get_relevance_score,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
    ],
    outputs=gr.Textbox(label="Relevance Score"),
    title="Cross-Encoder Relevance Scoring",
    description="Enter a query and a document paragraph to get a relevance score using the MS MARCO MiniLM L-12 v2 model.",
)


# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    print("Launching Gradio app...")
    interface.launch()