|
import html
import math

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
|
# Hugging Face checkpoint shared by the tokenizer and the sequence-classification
# model below; process_text feeds (query, document) pairs through both.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# output_attentions=True asks the model to return per-layer attention weights,
# which process_text reads when choosing tokens to highlight. nn.Module.eval()
# returns the module itself, so inference mode is set in the same expression.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    output_attentions=True,
).eval()
|
|
|
|
|
def calculate_threshold(base_relevance, min_threshold=0.0, max_threshold=0.5, k=10):
    """Map a relevance probability onto an attention-weight threshold.

    A logistic curve centred at ``base_relevance == 0.5`` interpolates between
    ``min_threshold`` and ``max_threshold``; ``k`` sets how steep the transition is.

    Args:
        base_relevance: Relevance probability (process_text passes a sigmoid
            output, so it lies in [0, 1]).
        min_threshold: Value approached as ``base_relevance`` falls below 0.5.
        max_threshold: Value approached as ``base_relevance`` rises above 0.5.
        k: Steepness of the logistic transition.

    Returns:
        The interpolated threshold as a Python float.
    """
    z = k * (base_relevance - 0.5)
    # Numerically stable logistic: never exponentiate a large positive value,
    # so a large k cannot raise OverflowError.
    if z >= 0:
        fraction = 1.0 / (1.0 + math.exp(-z))
    else:
        exp_z = math.exp(z)
        fraction = exp_z / (1.0 + exp_z)
    return min_threshold + (max_threshold - min_threshold) * fraction


def _highlight_spans(text, spans):
    """Return ``text`` HTML-escaped with each ``(start, end)`` char span wrapped in <b>.

    Overlapping or touching spans are merged so adjacent sub-word tokens render
    as a single bold run. Text outside every span is emitted escaped but otherwise
    unchanged.
    """
    merged = []
    for start, end in sorted(spans):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    parts = []
    cursor = 0
    for start, end in merged:
        parts.append(html.escape(text[cursor:start]))
        parts.append(f"<b>{html.escape(text[start:end])}</b>")
        cursor = end
    # Any tail not covered by tokens (e.g. cut off by truncation) is kept as-is.
    parts.append(html.escape(text[cursor:]))
    return "".join(parts)


def process_text(query, document, weight):
    """Score a (query, document) pair and bold the document tokens selected by attention.

    Args:
        query: Query text, tokenized as the first segment of the pair.
        document: Document paragraph, tokenized as the second segment.
        weight: Multiplier applied to the relevance-derived attention threshold
            (the UI slider ranges over 0.0-2.0).

    Returns:
        ``(relevance_score, dynamic_threshold, highlighted_html)``:
        ``relevance_score`` is ``sigmoid(logits)`` as a float,
        ``dynamic_threshold`` is the attention cutoff that was applied, and
        ``highlighted_html`` is the HTML-escaped ``document`` with selected
        tokens wrapped in ``<b>``.

    Note:
        Tokens are mapped back to ``document`` through character offsets. This
        requires a fast tokenizer, which is ``AutoTokenizer``'s default.
    """
    inputs = tokenizer(
        query,
        document,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
    )
    # Offsets are only needed for rendering; the model's forward() rejects them.
    offsets = inputs.pop("offset_mapping")[0].tolist()
    # Per token: None for special tokens, 0 for the query, 1 for the document.
    sequence_ids = inputs.sequence_ids(0)

    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    relevance_score = torch.sigmoid(outputs.logits).item()
    # Last layer, batch element 0, averaged over heads -> (seq_len, seq_len).
    attentions = outputs.attentions[-1][0].mean(dim=0)

    dynamic_threshold = calculate_threshold(relevance_score) * weight

    # A token is selected when any weight in its attention row exceeds the
    # threshold. This is the same selection as taking row indices from
    # nonzero(), but as one boolean per token instead of one list entry per
    # matching (row, col) pair.
    # NOTE(review): rows are the *attending* token; confirm that "token attends
    # sharply to something" (rather than "token receives attention") is the
    # intended relevance signal.
    selected = (attentions > dynamic_threshold).any(dim=-1).tolist()

    # Only document tokens are rendered; the query and special tokens are skipped.
    spans = [
        (start, end)
        for (start, end), seq_id, is_selected in zip(offsets, sequence_ids, selected)
        if seq_id == 1 and is_selected and end > start
    ]
    highlighted_text = _highlight_spans(document, spans)

    print(f"Relevance Score: {relevance_score}")
    print(f"Dynamic Threshold: {dynamic_threshold}")

    return relevance_score, dynamic_threshold, highlighted_text
|
|
|
|
|
# Gradio UI: two text inputs plus a weight slider for process_text; shows the
# relevance score, the threshold used, and the highlighted paragraph as HTML.
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="Query"),
        gr.Textbox(label="Document Paragraph"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Threshold Weight"),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Dynamic Threshold"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
)

# Only start the (blocking) web server when run as a script, so importing this
# module does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()