|
import gradio as gr |
|
import numpy as np |
|
from usearch.index import Index |
|
from sentence_transformers import SentenceTransformer |
|
from datasets import load_dataset |
|
from sentencex import segment |
|
|
|
# Sentence-embedding model fine-tuned for scientific rhetorical-function text.
model = SentenceTransformer("Corran/SciGenAllMiniLM")

# Training split of the triplet dataset; its 'query' column holds the
# rhetorical-function labels we index against.
train = load_dataset("Corran/SciGenColbertTriplets")['train']

# Unique rhetorical-function labels, sorted so key i -> rf[i] is stable.
rf = sorted(set(train['query']))

rf_emb = model.encode(rf)

# Vector index over the label embeddings; integer key i maps back to rf[i].
index = Index(ndim=rf_emb[0].size)
# BUG FIX: the original called index.add(range(len(classes)), classes_emb),
# but `classes`/`classes_emb` were never defined (NameError). The intended
# names are `rf`/`rf_emb` built just above.
index.add(np.arange(len(rf)), rf_emb)
|
|
|
def get_matches(input, k=4):
    """Embed ``input`` and look up its nearest rhetorical-function labels.

    Parameters
    ----------
    input : str | list[str]
        A single sentence or a batch of sentences. (Name kept as ``input``
        for backward compatibility, though it shadows the builtin.)
    k : int, optional
        Number of nearest labels retrieved per sentence (default 4,
        matching the original hard-coded value).

    Returns
    -------
    list[tuple[str, float]]
        ``(label, distance)`` pairs. For a batch of more than one sentence,
        only the single best match per sentence is kept; otherwise all
        ``k`` matches for the one query are returned.
    """
    global index, model, rf
    emb = model.encode(input, batch_size=128)
    matches = index.search(emb, k)
    if isinstance(input, list) and len(input) > 1:
        # Batched search yields one result set per sentence; keep only the
        # top match for each.
        matches = [m[0] for m in matches]
    else:
        matches = list(matches)
    return [(rf[m.key], m.distance) for m in matches]
|
|
|
|
|
|
|
def return_rf_scores(paragraph):
    """Score each sentence of ``paragraph`` for its rhetorical function.

    The paragraph is split into sentences with ``sentencex.segment`` and each
    sentence is matched against the label index via ``get_matches``.

    Parameters
    ----------
    paragraph : str
        Free-text paragraph (English segmentation rules are used).

    Returns
    -------
    dict
        Maps each sentence to its match data from ``get_matches``.
        NOTE(review): for a multi-sentence paragraph each value is the best
        (label, distance) pair; a single-sentence paragraph pairs the
        sentence with only the first of its matches — confirm this shape
        is intended.
    """
    sentences = list(segment("en", paragraph))
    matches = get_matches(sentences)
    # Pair sentences with their matches positionally.
    return dict(zip(sentences, matches))
|
|
|
# Gradio UI: a single text box for a paragraph; the per-sentence
# rhetorical-function scores come back rendered as JSON.
demo = gr.Interface(fn=return_rf_scores, inputs="text", outputs="json")

# Start the web server (blocks; runs on module import, as is conventional
# for Hugging Face Spaces apps).
demo.launch()
|
|