wilwork committed on
Commit
7dce425
·
verified ·
1 Parent(s): cff536c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModel, AutoTokenizer
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
# Embedding backbone shared by every request: the checkpoint is loaded once
# at import time and reused by the functions below.
model_name = "Supabase/gte-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The model is only used for inference, so switch it to eval mode right away
# (eval() returns the module itself, which allows the chained call).
model = AutoModel.from_pretrained(model_name).eval()
11
+
12
def get_embedding(text):
    """Return a sentence embedding for ``text`` from the module-level model.

    The text is tokenized (truncated to the tokenizer's maximum length) and
    run through the model without gradient tracking. The token states of the
    last hidden layer are then averaged, weighting each position by the
    attention mask so that padding does not dilute the result.

    Args:
        text: The string to embed.

    Returns:
        A torch tensor holding the pooled embedding; for a single input
        string the leading batch dimension is dropped.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(**inputs)

    hidden = output.last_hidden_state
    # Attention-masked mean pooling. The gte-small usage examples pool this
    # way; taking only the first ([CLS]) position, as before, yields vectors
    # the model was not trained to produce and degrades similarity scores.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    # Clamp guards against division by zero for an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)

    # Drop only the batch axis (a blanket squeeze() could also remove other
    # size-1 dimensions).
    return (summed / counts).squeeze(0)
17
+
18
def get_similarity_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
    """Rank up to three paragraphs against ``query`` and render an HTML table.

    Each non-blank paragraph is embedded with ``get_embedding`` and compared
    to the query embedding by cosine similarity. Paragraphs whose similarity
    exceeds the threshold have every word wrapped in ``<b>`` tags. Rows are
    sorted from most to least similar.

    Args:
        query: The search query text.
        paragraph1: First candidate paragraph.
        paragraph2: Second candidate paragraph (may be blank).
        paragraph3: Third candidate paragraph (may be blank).
        threshold_weight: Similarity cut-off for highlighting; values below
            0.02 (or ``None``) are raised to 0.02.

    Returns:
        An HTML ``<table>`` string, or a plain help message when the query
        or all paragraphs are missing/blank.
    """
    # Local import keeps this block self-contained.
    import html

    # Treat None like a blank field instead of crashing on .strip().
    paragraphs = [p for p in (paragraph1, paragraph2, paragraph3) if p and p.strip()]

    if not query or not query.strip() or not paragraphs:
        return "Please provide both a query and at least one document paragraph."

    # Loop-invariant: compute once instead of per paragraph.
    threshold = max(0.02, threshold_weight if threshold_weight is not None else 0.0)
    query_embedding = get_embedding(query)

    ranked_paragraphs = []
    for paragraph in paragraphs:
        para_embedding = get_embedding(paragraph)
        similarity = F.cosine_similarity(query_embedding, para_embedding, dim=0).item()

        # Merge word pieces back into whole words BEFORE adding markup.
        # Wrapping raw tokens in <b> first hides the "##" continuation
        # markers from convert_tokens_to_string, so they leaked into the
        # output.
        text = tokenizer.convert_tokens_to_string(tokenizer.tokenize(paragraph))
        # The text is user-supplied and ends up inside HTML, so escape it.
        words = html.escape(text).split()
        # NOTE: the score is per paragraph, so highlighting is all-or-nothing
        # for a paragraph; per-word relevance would need per-token scores.
        if similarity > threshold:
            words = [f"<b>{word}</b>" for word in words]

        ranked_paragraphs.append(
            {"similarity": similarity, "highlighted_text": " ".join(words)}
        )

    ranked_paragraphs.sort(key=lambda x: x["similarity"], reverse=True)

    rows = "".join(
        f"<tr><td>{round(item['similarity'], 4)}</td><td>{item['highlighted_text']}</td></tr>"
        for item in ranked_paragraphs
    )
    return (
        "<table border='1' style='width:100%; border-collapse: collapse;'>"
        "<tr><th>Cosine Similarity</th><th>Highlighted Paragraph</th></tr>"
        f"{rows}"
        "</table>"
    )
48
+
49
# Gradio front end. The input components are listed in the same order as the
# positional parameters of get_similarity_and_excerpt: the query, three
# paragraph boxes, then the threshold slider.
interface = gr.Interface(
    title="Embedding-Based Similarity Highlighting",
    description="Uses cosine similarity with Supabase/gte-small embeddings to rank paragraphs and highlight relevant words.",
    fn=get_similarity_and_excerpt,
    inputs=[
        gr.Textbox(
            label="Query",
            placeholder="Enter your search query...",
        ),
        gr.Textbox(
            label="Document Paragraph 1",
            placeholder="Enter a paragraph to match...",
            lines=4,
        ),
        gr.Textbox(
            label="Document Paragraph 2 (optional)",
            placeholder="Enter another paragraph...",
            lines=4,
        ),
        gr.Textbox(
            label="Document Paragraph 3 (optional)",
            placeholder="Enter another paragraph...",
            lines=4,
        ),
        # The slider minimum matches the 0.02 floor applied in the callback.
        gr.Slider(
            minimum=0.02,
            maximum=0.5,
            value=0.1,
            step=0.01,
            label="Similarity Threshold",
        ),
    ],
    outputs=[gr.HTML(label="Ranked Paragraphs")],
    live=True,
    allow_flagging="never",
)


# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()