|
import gradio as gr
import torch
from torch import Tensor
from torch.nn.functional import cosine_similarity
from transformers import AutoTokenizer, AutoModel
from transformers import pipeline
|
|
|
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Average token embeddings over the sequence axis, skipping masked-out tokens.

    Args:
        last_hidden_states: Per-token embeddings, summed over dim 1; presumably
            (batch, seq, hidden) as returned by the encoder -- TODO confirm.
        attention_mask: Mask aligned with the first two dims of
            ``last_hidden_states``. Positions where it is zero are zeroed out
            before summing. The divisor is the mask's own sum, so it is assumed
            to be a 0/1 mask as produced by the tokenizer.

    Returns:
        Sum of the kept token vectors divided by the mask sum, one vector per
        row. A row whose mask sums to zero divides by zero.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    token_sums = last_hidden_states.masked_fill(keep.logical_not(), 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return token_sums / token_counts
|
|
|
def get_similarity(sentence1, sentence2, *, prefix="query: "):
    """Return the cosine similarity between the E5 embeddings of two texts.

    Both texts are tokenized together with the module-level ``tokenizer``,
    encoded by the module-level ``model``, and mean-pooled with
    ``average_pool`` before being compared.

    Args:
        sentence1: First text to compare.
        sentence2: Second text to compare.
        prefix: String prepended to both texts before tokenization. The
            multilingual-e5 model card asks that every input start with
            "query: " or "passage: ", and recommends "query: " for symmetric
            tasks such as similarity.

    Returns:
        The cosine similarity of the two pooled embeddings as a Python float,
        rounded to 4 decimal places (range [-1, 1]).
    """
    input_texts = [prefix + sentence1, prefix + sentence2]

    batch_dict = tokenizer(
        input_texts,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    # Pure inference: disable autograd so no graph is built per request.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        # Slices 0:1 and 1:2 keep a leading batch dim of 1 for each embedding.
        similarity = cosine_similarity(embeddings[0:1], embeddings[1:2])

    return round(similarity.item(), 4)
|
|
|
# Hugging Face Hub id; the tokenizer and the encoder weights are both loaded
# from this same checkpoint so their vocabularies match.
checkpoint = "intfloat/multilingual-e5-large"

tokenizer, model = (
    AutoTokenizer.from_pretrained(checkpoint),
    AutoModel.from_pretrained(checkpoint),
)
|
|
|
# Gradio UI: two passage inputs, a read-only score box, a Submit button
# wired to get_similarity, and a set of clickable examples.
demo = gr.Blocks(theme="freddyaboulton/dracula_revamped")

with demo:
    gr.Markdown("# Sentence Similarity")
    gr.Markdown("### How to use:")
    gr.Markdown("- Enter Passage 1 and Passage 2, then press Submit")
    gr.Markdown("- Select an example, then press Submit")
    gr.Markdown("Model: https://huggingface.co/intfloat/multilingual-e5-large (Multilingual: 94 languages)")

    with gr.Row():
        p_txt1 = gr.Textbox(placeholder="Enter passage 1", label="Passage 1", lines=3, scale=2)
        p_txt2 = gr.Textbox(placeholder="Enter passage 2", label="Passage 2", lines=3, scale=2)
        # Cosine similarity lies in [-1, 1]; the old "(0-1)" label misstated the range.
        o_txt = gr.Textbox(
            placeholder="Similarity score",
            lines=1,
            interactive=False,
            label="Similarity score (-1 to 1)",
            scale=1,
        )

    submit = gr.Button("Submit")

    # Clicking an example only fills the two inputs; the user still presses Submit.
    gr.Examples(
        [
            ["A big bus is running on the road in the city.", "There is a big bus running on the road."],
            ["A big bus is running on the road in the city.", "Two children in costumes are standing on the bed."],
            ["街中の道路を大きなバスが走っています。", "道路を大きなバスが走っています。"],
            ["街中の道路を大きなバスが走っています。", "ベッドの上で衣装を着た二人の子供が立っています。"],
            ["A big bus is running on the road in the city.", "道路を大きなバスが走っています。"],
        ],
        inputs=[p_txt1, p_txt2],
    )

    # Register the listener inside the Blocks context.
    submit.click(get_similarity, [p_txt1, p_txt2], o_txt)

demo.launch()