# File size: 2,691 Bytes
# 7472a45 e64db30 7472a45 e64db30 7472a45 e64db30 5552636 e64db30 43a9654 e64db30 43a9654 e64db30 5552636 e64db30 5552636 0e17055 5552636 43a9654 5552636 43a9654 e64db30 43a9654 e64db30 43a9654 e64db30 43a9654 0e17055 e64db30 5552636 e64db30 43a9654 e64db30 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Checkpoint identifier shared by the tokenizer and the model
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Tokenizer paired with the checkpoint above
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sequence-classification model, loaded with attention outputs enabled
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    output_attentions=True,
)

# Inference only: switch the model to evaluation mode
model.eval()
def _calculate_threshold(base_relevance, min_threshold=0.0, max_threshold=0.5, k=10):
    """Map a relevance score to an attention threshold with a logistic curve.

    The result lies between ``min_threshold`` and ``max_threshold``; the curve
    is centred on ``base_relevance == 0.5`` and ``k`` controls its steepness.
    """
    base_relevance_tensor = torch.tensor(base_relevance)
    threshold = min_threshold + (max_threshold - min_threshold) * (
        1 / (1 + torch.exp(-k * (base_relevance_tensor - 0.5)))
    )
    return threshold.item()


def _render_highlighted(tokens, highlighted_indices):
    """Render tokens as an HTML string, bolding the highlighted positions.

    Token text is HTML-escaped because it originates from user input and is
    displayed in an HTML component. A WordPiece ``##`` continuation prefix is
    kept *outside* the ``<b>`` tag so that ``convert_tokens_to_string`` can
    still recognise it and glue the piece onto the preceding token.
    """
    pieces = []
    for idx, token in enumerate(tokens):
        if token.startswith("##") and len(token) > 2:
            prefix, text = "##", token[2:]
        else:
            prefix, text = "", token
        # quote=False: quotes are harmless in element content, and leaving
        # them untouched keeps the tokenizer's own text clean-up working.
        text = html.escape(text, quote=False)
        if idx in highlighted_indices:
            text = f"<b>{text}</b>"
        pieces.append(prefix + text)
    return tokenizer.convert_tokens_to_string(pieces)


# Function to compute relevance and highlight relevant tokens
def process_text(query, document, weight):
    """Score a (query, document) pair and highlight high-attention tokens.

    Args:
        query: Query text.
        document: Document paragraph text.
        weight: Multiplier applied to the dynamic attention threshold.

    Returns:
        A tuple ``(relevance_score, dynamic_threshold, highlighted_text)``:
        the sigmoid of the model logit, the threshold that was applied to the
        attention weights, and an HTML string of the tokenized input in which
        tokens exceeding the threshold are wrapped in ``<b>`` tags.
    """
    # Tokenize input as a single (query, document) pair
    inputs = tokenizer(query, document, return_tensors="pt", truncation=True, padding=True)
    input_ids = inputs["input_ids"]

    # Get model outputs with attentions
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    relevance_score = torch.sigmoid(outputs.logits).item()  # Convert logits to relevance score
    # Last attention entry, batch dimension dropped, averaged over its leading
    # (head) dimension -> one square matrix of token-to-token weights.
    attentions = outputs.attentions[-1].squeeze(0).mean(0)

    # Calculate dynamic threshold from the relevance score, scaled by the user weight
    dynamic_threshold = _calculate_threshold(relevance_score) * weight

    # A position is highlighted when any weight in its attention row exceeds
    # the threshold. A set gives O(1) membership tests and drops the duplicate
    # indices produced when one row has several weights above the threshold.
    # NOTE(review): this selects by row (the attending token); confirm that is
    # the intended notion of "important" rather than attention received.
    row_hits = (attentions > dynamic_threshold).any(dim=-1)
    relevant_indices = set(row_hits.nonzero(as_tuple=True)[0].tolist())

    # Highlight tokens in the original order, using HTML bold tags
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    highlighted_text = _render_highlighted(tokens, relevant_indices)

    # Print values to debug
    print(f"Relevance Score: {relevance_score}")
    print(f"Dynamic Threshold: {dynamic_threshold}")
    return relevance_score, dynamic_threshold, highlighted_text
# Create Gradio interface with a slider for threshold adjustment weight.
# The inputs are passed positionally to process_text(query, document, weight),
# and the outputs receive its returned
# (relevance_score, dynamic_threshold, highlighted_text) tuple in order.
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="Query"),
        gr.Textbox(label="Document Paragraph"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Threshold Weight"),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Dynamic Threshold"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
)

# Start the web UI
iface.launch()