File size: 2,691 Bytes
7472a45
 
e64db30
7472a45
 
 
e64db30
7472a45
 
e64db30
 
5552636
e64db30
 
 
 
 
43a9654
e64db30
43a9654
e64db30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5552636
e64db30
5552636
0e17055
5552636
 
43a9654
5552636
43a9654
e64db30
 
 
43a9654
e64db30
 
 
 
 
43a9654
e64db30
 
 
43a9654
 
0e17055
e64db30
5552636
e64db30
43a9654
 
e64db30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load model and tokenizer
# Cross-encoder reranker fine-tuned on MS MARCO passage ranking.
# output_attentions=True configures the model to return per-layer attention
# weights alongside the classification logits (used below for highlighting).
# NOTE: from_pretrained downloads the weights on first run (network required).
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model to evaluation mode
# Disables dropout so scores and attentions are deterministic at inference.
model.eval()

# Function to compute relevance and highlight relevant tokens
def process_text(query, document, weight):
    """Score a query/document pair and bold high-attention tokens.

    Args:
        query: Search query string.
        document: Document paragraph to score against the query.
        weight: User-controlled multiplier applied to the dynamic threshold
            (comes from the UI slider, 0.0-2.0).

    Returns:
        Tuple of (relevance_score, dynamic_threshold, highlighted_text):
        the sigmoid-squashed cross-encoder score in [0, 1], the scaled
        attention threshold, and the document tokens rejoined into a string
        with high-attention tokens wrapped in <b>...</b> HTML tags.
    """
    # Tokenize the pair as a single cross-encoder input (query [SEP] document)
    inputs = tokenizer(query, document, return_tensors="pt", truncation=True, padding=True)
    input_ids = inputs["input_ids"]

    # Forward pass without gradients; attentions come back per layer
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
        relevance_score = torch.sigmoid(outputs.logits).item()  # logits -> [0, 1] relevance
        # Last layer, mean over heads -> (seq_len, seq_len) attention matrix
        attentions = outputs.attentions[-1].squeeze(0).mean(0)

    # Map the relevance score to a threshold in [min_threshold, max_threshold]
    # via a steep sigmoid (slope k) centered at 0.5: relevant pairs get a
    # higher bar, so fewer tokens are highlighted for already-strong matches.
    def calculate_threshold(base_relevance, min_threshold=0.0, max_threshold=0.5, k=10):
        base_relevance_tensor = torch.tensor(base_relevance)
        threshold = min_threshold + (max_threshold - min_threshold) * (
            1 / (1 + torch.exp(-k * (base_relevance_tensor - 0.5)))
        )
        return threshold.item()

    dynamic_threshold = calculate_threshold(relevance_score) * weight

    # A token (row) counts as relevant if any of its attention weights exceeds
    # the threshold. Use a set: nonzero() on the 2-D matrix yields one entry
    # per matrix cell above threshold, i.e. many duplicate row indices, and a
    # set makes the membership test in the loop below O(1) instead of O(n).
    relevant_indices = set((attentions > dynamic_threshold).nonzero(as_tuple=True)[0].tolist())

    # Rebuild the text in original token order, bolding relevant tokens.
    # NOTE(review): wrapping wordpiece tokens (e.g. "##ing") in <b> tags hides
    # the "##" prefix from convert_tokens_to_string, so subwords may not merge
    # back into whole words in the highlighted output — confirm the rendering
    # is acceptable.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    highlighted_parts = []
    for idx, token in enumerate(tokens):
        if idx in relevant_indices:
            highlighted_parts.append(f"<b>{token}</b>")
        else:
            highlighted_parts.append(token)

    highlighted_text = tokenizer.convert_tokens_to_string(highlighted_parts)

    # Print values to debug
    print(f"Relevance Score: {relevance_score}")
    print(f"Dynamic Threshold: {dynamic_threshold}")

    return relevance_score, dynamic_threshold, highlighted_text

# Create Gradio interface with a slider for threshold adjustment weight
# Three inputs map positionally onto process_text(query, document, weight);
# the three outputs map onto its (relevance_score, dynamic_threshold,
# highlighted_text) return tuple. The HTML component renders the <b> tags.
iface = gr.Interface(
    fn=process_text,
    inputs=[
        gr.Textbox(label="Query"),
        gr.Textbox(label="Document Paragraph"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Threshold Weight"),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Dynamic Threshold"),
        gr.HTML(label="Highlighted Document Paragraph")
    ]
)

# Start the local web server (blocks until the app is stopped).
iface.launch()