X-encoder / app.py
wilwork's picture
Update app.py
e64db30 verified
raw
history blame
2.69 kB
import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Checkpoint identifier shared by the tokenizer and the model
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
# Tokenizer belonging to the same checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the sequence-classification model with attention outputs enabled and
# switch it to evaluation mode in the same expression (eval() returns the module)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    output_attentions=True,
).eval()
# Helper: map a relevance score onto an attention threshold with a logistic curve
def _calculate_threshold(base_relevance, min_threshold=0.0, max_threshold=0.5, k=10):
    """Return a threshold between ``min_threshold`` and ``max_threshold``.

    The value follows a logistic curve centred on a relevance of 0.5: it tends
    towards ``max_threshold`` as ``base_relevance`` grows and towards
    ``min_threshold`` as it shrinks. ``k`` controls the steepness of the curve.
    """
    base_relevance_tensor = torch.tensor(base_relevance)
    threshold = min_threshold + (max_threshold - min_threshold) * (
        1 / (1 + torch.exp(-k * (base_relevance_tensor - 0.5)))
    )
    return threshold.item()


# Function to compute relevance and highlight relevant tokens
def process_text(query, document, weight):
    """Score a query/document pair and bold the tokens with high attention.

    Args:
        query: Query text, tokenized as the first segment of the pair.
        document: Document text, tokenized as the second segment of the pair.
        weight: Multiplier applied to the dynamic threshold (supplied by the
            UI slider); larger values highlight fewer tokens.

    Returns:
        A tuple ``(relevance_score, dynamic_threshold, highlighted_text)``:
        the sigmoid of the model logit, the attention threshold actually used,
        and an HTML string of the tokenized input in which selected tokens are
        wrapped in ``<b>`` tags.
    """
    # Tokenize the pair as a single (truncated) sequence
    inputs = tokenizer(query, document, return_tensors="pt", truncation=True, padding=True)
    input_ids = inputs["input_ids"]

    # Inference only: no gradients are needed
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # .item() requires the logits to hold exactly one value
    relevance_score = torch.sigmoid(outputs.logits).item()

    # Last entry of the attentions tuple, batch dimension squeezed out, then
    # averaged over the leading dimension.
    # NOTE(review): per the transformers docs this is the final layer with
    # shape (batch, heads, seq, seq), so the result is a (seq, seq) matrix of
    # head-averaged attention -- confirm for this checkpoint.
    attentions = outputs.attentions[-1].squeeze(0).mean(0)

    dynamic_threshold = _calculate_threshold(relevance_score) * weight

    # First-dimension indices of every entry above the threshold. A set removes
    # the duplicates nonzero() yields and makes the per-token lookup O(1).
    relevant_indices = set(
        (attentions > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
    )

    # Rebuild the text in token order. Token text originates from user input
    # and is rendered by an HTML component, so escape it before adding markup.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    pieces = []
    for idx, token in enumerate(tokens):
        safe_token = html.escape(token, quote=False)
        pieces.append(f"<b>{safe_token}</b>" if idx in relevant_indices else safe_token)
    highlighted_text = tokenizer.convert_tokens_to_string(pieces)

    # Print values to debug
    print(f"Relevance Score: {relevance_score}")
    print(f"Dynamic Threshold: {dynamic_threshold}")

    return relevance_score, dynamic_threshold, highlighted_text
# Assemble the Gradio UI: two text inputs plus a slider whose value is passed
# to process_text as the threshold weight
input_components = [
    gr.Textbox(label="Query"),
    gr.Textbox(label="Document Paragraph"),
    gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Threshold Weight"),
]

# One output component per element of the tuple returned by process_text
output_components = [
    gr.Textbox(label="Relevance Score"),
    gr.Textbox(label="Dynamic Threshold"),
    gr.HTML(label="Highlighted Document Paragraph"),
]

iface = gr.Interface(
    fn=process_text,
    inputs=input_components,
    outputs=output_components,
)

# Start the web app
iface.launch()