dnlp-demo / main.py
Jackss
Stuff
0810777
raw
history blame
1.54 kB
import numpy as np
import torch
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
# Load the 'allenai/specter' tokenizer and model once, at import time, as
# module-level globals; the /similarity handler reuses them on every request.
tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
model = AutoModel.from_pretrained('allenai/specter')
# Example `papers` payload for POST /similarity:
# papers = [{'title': 'BERT', 'abstract': 'We introduce a new language representation model called BERT'},
#           {'title': 'Attention is all you need', 'abstract': 'The dominant sequence transduction models are based on complex recurrent or convolutional neural networks'}]
# The handler concatenates each paper's title and abstract before embedding.
# Request body schema for POST /similarity.
class Input(BaseModel):
    # Papers to compare. The /similarity handler reads a 'title' key and an
    # optional 'abstract' key from each entry. Defaults to an empty list.
    papers: list = []
# FastAPI application object; the routes and the static mount are registered on it.
app = FastAPI()
@app.post('/similarity')
def similarity(input: Input):
    """Return the pairwise cosine-similarity matrix of the submitted papers.

    Each paper's title and abstract are joined with the tokenizer's separator
    token, encoded with the module-level ``tokenizer`` and ``model``
    ('allenai/specter'), and represented by the hidden state of the first
    token of its sequence.

    Args:
        input: Request body. ``input.papers`` is a list of objects, each with
            a string ``title`` and an optional string ``abstract`` (a missing
            or null abstract is treated as empty).

    Returns:
        ``{"output": matrix}`` where ``matrix[i][j]`` is the cosine similarity
        between the embeddings of papers ``i`` and ``j``. An empty ``papers``
        list yields ``{"output": []}``.

    Raises:
        HTTPException: 422 when an entry is not an object with a string
            ``title`` (and, if present, a string ``abstract``).
    """
    papers = input.papers
    if not papers:
        # Nothing to embed: skip the tokenizer/model/cosine_similarity calls,
        # which are not meant to be run on an empty batch.
        return {"output": []}

    try:
        # concatenate title and abstract
        title_abs = [d['title'] + tokenizer.sep_token + (d.get('abstract') or '') for d in papers]
    except (KeyError, TypeError) as err:
        # Malformed entries are a client error, not an internal failure.
        raise HTTPException(
            status_code=422,
            detail="each paper must be an object with a string 'title' and an optional string 'abstract'",
        ) from err

    # preprocess the input
    inputs = tokenizer(title_abs, padding=True, truncation=True, return_tensors="pt", max_length=512)
    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        result = model(**inputs)
    # take the first token of each sequence as the paper embedding
    embeddings = result.last_hidden_state[:, 0, :].numpy()
    res = cosine_similarity(embeddings, embeddings).tolist()
    return {"output": res}
# Serve the "static" directory at the site root; html=True enables HTML mode
# (index.html is served for directory requests).
app.mount("/", StaticFiles(directory="static", html=True), name="static")
@app.get("/")
def index() -> FileResponse:
    # Respond with the front-end page located at an absolute path, as HTML.
    # NOTE(review): this route is registered after the static mount at "/",
    # so requests for "/" are presumably answered by that mount first --
    # confirm whether this handler is ever reached.
    landing_page = FileResponse("/app/static/index.html", media_type="text/html")
    return landing_page