# UMLS / app.py
# mgbam's picture
# Update app.py
# 8a63984 verified
import os
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import faiss
# Artefact locations and the Hugging Face model identifier.
# These are plain constants, so declaring them ahead of the Streamlit calls
# has no effect on what the page renders.
MODEL_NAME = 'microsoft/BiomedNLP-KRISSBERT-PubMed-UMLS-EL'
METADATA_PATH = 'umls_metadata.json'
EMBED_PATH = 'umls_embeddings.npy'
INDEX_PATH = 'umls_index.faiss'

# Page configuration: set_page_config is still issued before st.title.
st.set_page_config(layout='wide', page_title='KRISSBERT UMLS Linker')
st.title('🧬 KRISSBERT + UMLS Entity Linker (Local FAISS)')
# 1️⃣ Load model & tokenizer
@st.cache_resource
def load_model():
    """Load the tokenizer and encoder identified by ``MODEL_NAME``.

    The function is wrapped in ``st.cache_resource``. The encoder is put
    into evaluation mode before it is handed back.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    # nn.Module.eval() returns the module itself, so the call can be chained.
    encoder = AutoModel.from_pretrained(MODEL_NAME).eval()
    return tok, encoder


tokenizer, model = load_model()
# 2️⃣ Load UMLS FAISS index & metadata
@st.cache_resource
def load_umls_index():
    """Load the local FAISS index and its companion metadata.

    Reads ``METADATA_PATH`` as JSON and ``INDEX_PATH`` through
    ``faiss.read_index``. The function is wrapped in ``st.cache_resource``.

    Returns:
        An ``(index, meta)`` tuple: the FAISS index object and the parsed
        JSON metadata.

    Raises:
        OSError: If the metadata file cannot be opened.
        json.JSONDecodeError: If the metadata file is not valid JSON.
    """
    # Use a context manager so the handle is closed deterministically, and
    # pin the encoding instead of relying on the platform default.
    with open(METADATA_PATH, 'r', encoding='utf-8') as fh:
        meta = json.load(fh)
    # The previous version also np.load()-ed EMBED_PATH into a local that was
    # never used or returned; that load is dropped to avoid reading the whole
    # embedding matrix into memory for nothing.
    index = faiss.read_index(INDEX_PATH)
    return index, meta


faiss_index, umls_meta = load_umls_index()
# 3️⃣ Embed text (prefix underscores to avoid caching errors)
# st.cache_data (rather than st.cache_resource) because the result is a
# per-input array, not a shared global resource; max_entries bounds the cache
# so that an open-ended stream of distinct sentences cannot grow it forever.
@st.cache_data(max_entries=1024)
def embed_text(text, _tokenizer, _model):
    """Return the L2-normalised first-token embedding of ``text``.

    Args:
        text: Input passed straight to ``_tokenizer``. A single string
            yields a 1-D vector; a list of strings yields one row per item.
        _tokenizer: Callable tokenizer returning PyTorch tensors. The leading
            underscore keeps Streamlit from trying to hash it.
        _model: Encoder whose output exposes ``last_hidden_state``. Also
            excluded from hashing via the leading underscore.

    Returns:
        A float32 NumPy array of unit-length embedding(s). An all-zero
        embedding is returned unchanged rather than as NaNs.
    """
    inputs = _tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = _model(**inputs)
    # Position 0 of the last hidden state (presumably the [CLS] token).
    emb = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    # The FAISS search downstream works on float32; make that explicit.
    emb = emb.astype(np.float32, copy=False)
    # Normalise along the feature axis so batched input is scaled per row.
    norm = np.linalg.norm(emb, axis=-1, keepdims=True)
    # Guard against division by zero for a degenerate all-zero vector.
    norm[norm == 0] = 1.0
    return emb / norm
# UI: examples + input
st.markdown('Enter a biomedical sentence to link entities via local UMLS FAISS index and KRISSBERT:')

# Canned sentences offered in the drop-down.
examples = [
    'The patient was administered metformin for type 2 diabetes.',
    'ER crowding has become a widespread issue in hospitals.',
    'Tamoxifen is used in the treatment of ER-positive breast cancer.',
]

# The first option is a placeholder that maps to an empty text area.
selected = st.selectbox('🔍 Example queries', ['Choose...', *examples])
sentence = st.text_area('📝 Sentence:', value='' if selected == 'Choose...' else selected)
# Button handler: embed the sentence, query the FAISS index for the five
# nearest entries, and render whichever of them resolve to metadata.
if st.button('🔗 Link Entities'):
    if not sentence.strip():
        st.warning('Please enter a sentence first.')
    else:
        with st.spinner('Embedding sentence and searching FAISS…'):
            # One query row; coerce to contiguous float32, which is the
            # layout FAISS searches operate on.
            sent_emb = np.ascontiguousarray(
                embed_text(sentence, tokenizer, model).reshape(1, -1),
                dtype=np.float32,
            )
            _distances, indices = faiss_index.search(sent_emb, 5)
            results = []
            for idx in indices[0]:
                # FAISS pads the label array with -1 when it finds fewer
                # than k neighbours; those slots are not real hits.
                if idx < 0:
                    continue
                entry = umls_meta.get(str(idx))
                # Skip indices with no metadata instead of rendering a blank
                # candidate; this also makes the "no matches" branch below
                # reachable.
                if not entry:
                    continue
                results.append({
                    'cui': entry.get('cui', ''),
                    'name': entry.get('name', ''),
                    'definition': entry.get('definition', ''),
                    'source': entry.get('source', ''),
                })
        # Display
        if results:
            st.success('Top UMLS candidates:')
            for item in results:
                st.markdown(f"**{item['name']}** (CUI: `{item['cui']}`)")
                if item['definition']:
                    st.markdown(f"> {item['definition']}\n")
                st.markdown(f"_Source: {item['source']}_\n---")
        else:
            st.info('No matches found in UMLS index.')