embeds / main.py
chainyo's picture
fix password
1b29f8c
raw
history blame contribute delete
2.22 kB
import hmac

import pinecone
import requests
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

from config import config
def search(text: str, k: int = 5, *, timeout: float = 30.0, project_id: str = "5b18b87"):
    """Get the k closest articles to the text.

    The query text is embedded with ``_get_embeddings`` and sent to the
    Pinecone index's REST ``/query`` endpoint.

    Args:
        text: Free-text query to embed.
        k: Number of nearest matches to request (sent as ``top_k``); must be >= 1.
        timeout: Seconds to wait for the Pinecone HTTP request before giving up.
        project_id: Project-id segment of the Pinecone index host name. The
            default is the value that used to be hard-coded in the URL.

    Returns:
        The decoded JSON body of Pinecone's query response.

    Raises:
        ValueError: If ``k`` is less than 1 (no match could ever be returned).
        requests.HTTPError: If Pinecone answers with any status other than 200.
            It subclasses ``Exception``, so existing ``except Exception``
            handlers still catch it, and the message format is unchanged.
        requests.RequestException: On connection failure or timeout.
    """
    # Validate before running the (comparatively expensive) embedding model.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    embeds = _get_embeddings(text)
    url = (
        f"https://{config.pinecone_index}-{project_id}"
        f".svc.{config.pinecone_env}.pinecone.io/query"
    )
    response = requests.post(
        url,
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            "top_k": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        # Without a timeout, requests waits forever and the UI spinner never
        # resolves if the endpoint stops responding.
        timeout=timeout,
    )
    if response.status_code != 200:
        raise requests.HTTPError(
            f"Error: {response.status_code} - {response.text}", response=response
        )
    return response.json()
def _get_embeddings(text: str):
    """Embed *text* by averaging the model's first output over dimension 1.

    The tokenizer and model are read from ``st.session_state`` (the script
    stores them there before any search runs). Inference happens under
    ``torch.no_grad()`` so no autograd graph is built.

    The average does not use the attention mask. That is exact for a single
    string, because a lone sequence has nothing to pad against.

    Returns:
        The pooled vector, squeezed and converted to a plain Python list.
    """
    tokenizer = st.session_state.tokenizer
    model = st.session_state.model

    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        # Assumes output[0] is (batch, seq_len, hidden) as for Hugging Face
        # AutoModel encoders -- TODO confirm for the configured model.
        token_states = model(**encoded)[0]

    pooled = token_states.mean(dim=1)
    return pooled.squeeze().tolist()
def _password_matches(candidate, expected) -> bool:
    """Return True if *candidate* equals the configured password *expected*.

    ``hmac.compare_digest`` keeps the comparison time independent of how many
    leading characters match. Both sides are UTF-8 encoded because
    ``compare_digest`` rejects ``str`` arguments containing non-ASCII
    characters. Any non-``str`` value (e.g. an unset password) never matches,
    as with the plain ``==`` comparison on a text input.
    """
    if not isinstance(candidate, str) or not isinstance(expected, str):
        return False
    return hmac.compare_digest(candidate.encode("utf-8"), expected.encode("utf-8"))


def _load_resources() -> None:
    """Create the Pinecone index handle, tokenizer and model once per session.

    Each object is cached in ``st.session_state`` so Streamlit reruns reuse it.
    """
    if "index" not in st.session_state:
        # The Pinecone client's keyword is ``environment``; the previous
        # ``env=`` did not set it. Initialise only alongside index creation
        # instead of on every rerun.
        pinecone.init(api_key=config.pinecone_api_key, environment=config.pinecone_env)
        st.session_state.index = pinecone.Index(config.pinecone_index)
    if "tokenizer" not in st.session_state:
        st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if "model" not in st.session_state:
        st.session_state.model = AutoModel.from_pretrained(config.model_name)


password = st.text_input("Password", type="password")
if _password_matches(password, config.password):
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")
    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")
    with st.spinner("Loading Embedding Model..."):
        _load_resources()
    if st.button("Search"):
        if not text.strip():
            st.warning("Please enter a search query.")
        else:
            try:
                with st.spinner("Searching..."):
                    results = search(text)
            except Exception as err:  # UI boundary: report instead of a raw traceback
                st.error(f"Search failed: {err}")
            else:
                matches = results.get("matches", [])
                if not matches:
                    st.write("No matching articles found.")
                for res in matches:
                    st.write(f"{res['id']} - confidence: {res['score']:.2f}")
elif password:
    # Only complain once the user has actually typed something; an empty
    # field on first load is not a failed attempt.
    st.write("Password incorrect!")