|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import pickle as pkl |
|
from argparse import ArgumentParser |
|
from collections import OrderedDict |
|
from typing import Dict |
|
|
|
import numpy as np |
|
import torch |
|
from build_index import load_model |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
from nemo.utils import logging |
|
|
|
# faiss is only required once main() reads the index (faiss.read_index), so a
# missing install is reported as a warning instead of failing at import time.
try:
    import faiss
except ModuleNotFoundError:
    logging.warning("Faiss is required for building the index. Please install faiss-gpu")

# Run the encoder on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
def get_query_embedding(query, model):
    """Use the entity linking encoder to get an embedding for an index query.

    Args:
        query: text to embed; tokenized with ``model.tokenizer``
            (truncated to at most 512 tokens).
        model: entity linking model exposing ``tokenizer`` and ``forward``.

    Returns:
        Whatever ``model.forward`` returns for the single tokenized query,
        which is fed to it as a batch of one sequence.
    """
    model_input = model.tokenizer(
        query,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=512,
        return_token_type_ids=True,
        return_attention_mask=True,
    )

    # Inference only: disable autograd so no computation graph / activations
    # are retained for a backward pass that never happens.
    with torch.no_grad():
        query_emb = model.forward(
            # Wrap each list in [...] to add the batch dimension (batch of one).
            input_ids=torch.LongTensor([model_input["input_ids"]]).to(device),
            token_type_ids=torch.LongTensor([model_input["token_type_ids"]]).to(device),
            attention_mask=torch.LongTensor([model_input["attention_mask"]]).to(device),
        )

    return query_emb
|
|
|
|
|
def query_index(
    query: str, cfg: DictConfig, model: object, index: object, pca: object, idx2id: dict, id2string: dict,
) -> Dict:
    """
    Query the nearest neighbor index of entities to find the
    concepts in the index dataset that are most similar to the
    query.

    Args:
        query (str): entity to look up in the index
        cfg (DictConfig): config object specifying query parameters; this
                          function reads ``apply_pca``, ``top_n`` and
                          ``query_num_factor`` from it
        model (EntityLinkingModel): entity linking encoder model
        index (object): faiss index
        pca (object): sklearn pca transformation to be applied to queries;
                      only used when ``cfg.apply_pca`` is true (may be None
                      otherwise)
        idx2id (dict): dictionary mapping unique concept dataset index to
                       its CUI
        id2string (dict): dictionary mapping each unique CUI to a
                          representative english description of
                          the concept
    Returns:
        A dictionary (ordered from nearest to farthest) with the concept ids
        of the index's most similar entities as the keys and a tuple
        containing the string representation of that concept and
        ``1 - distance`` returned by the index search as the values
        (presumably cosine similarity for normalized embeddings; the exact
        meaning depends on the index metric). At most ``cfg.top_n`` entries.
    """
    query_emb = get_query_embedding(query, model).detach().cpu().numpy()

    if cfg.apply_pca:
        query_emb = pca.transform(query_emb)

    # Over-fetch neighbours: several index rows can map to the same concept id
    # (hence the de-duplication below), so more than top_n hits may be needed
    # to collect top_n distinct concepts.
    dist, neighbors = index.search(query_emb.astype(np.float32), cfg.query_num_factor * cfg.top_n)
    dist, neighbors = dist[0], neighbors[0]

    # Plain dicts keep insertion order, so the nearest concept stays first.
    unique_ids = {}
    for neighbor_idx, concept_id_idx in enumerate(neighbors):
        if len(unique_ids) >= cfg.top_n:
            break
        # faiss pads the label array with -1 when it finds fewer than k
        # results; those are not valid dataset indices.
        if concept_id_idx < 0:
            continue

        concept_id = idx2id[concept_id_idx]
        if concept_id not in unique_ids:
            unique_ids[concept_id] = (id2string[concept_id], 1 - dist[neighbor_idx])

    return unique_ids
|
|
|
|
|
def main(cfg: DictConfig, restore: bool):
    """
    Load the faiss index and its associated files, then answer index
    queries typed on the command line until the user enters ``exit``
    (or input reaches end-of-file). If the index or its mapping files are
    missing, logs a warning asking to run ``build_index.py`` and returns.

    Args:
        cfg: Config file specifying index parameters; the ``index`` section
             is passed to ``query_index`` as its query config
        restore: Whether to restore model weights trained
                 by the user. Otherwise will load weights
                 used before self alignment pretraining.
    """
    index_cfg = cfg.index
    # Read PCA usage from the index section, the same place query_index reads it.
    apply_pca = index_cfg.apply_pca

    required_files = [index_cfg.index_save_name, index_cfg.idx_to_id]
    if apply_pca:
        required_files.append(index_cfg.pca.pca_save_name)
    if not all(os.path.isfile(path) for path in required_files):
        logging.warning("Either no index and/or no mapping from entity idx to ids exists. Please run `build_index.py`")
        return

    logging.info("Loading entity linking encoder model")
    model = load_model(cfg.model, restore)

    logging.info("Loading index and associated files")
    index = faiss.read_index(index_cfg.index_save_name)
    # NOTE(review): pickle.load can execute arbitrary code; these files are
    # presumably produced locally (e.g. by build_index.py) - only load trusted files.
    with open(index_cfg.idx_to_id, "rb") as f:
        idx2id = pkl.load(f)
    with open(index_cfg.id_to_string, "rb") as f:
        id2string = pkl.load(f)

    # query_index only touches pca when apply_pca is true; default to None so
    # the name is always bound.
    pca = None
    if apply_pca:
        with open(index_cfg.pca.pca_save_name, "rb") as f:
            pca = pkl.load(f)

    while True:
        try:
            query = input("enter index query: ")
        except EOFError:
            # Treat end of input (e.g. Ctrl-D or a closed pipe) like "exit".
            break

        if query == "exit":
            break

        output = query_index(query, index_cfg, model, index, pca, idx2id, id2string)

        for concept_id, concept_details in output.items():
            # Display ids in UMLS CUI form: "C" followed by a 7-digit zero-padded number.
            print("C" + str(concept_id).zfill(7), concept_details)

        print("----------------\n")
|
|
|
|
|
if __name__ == '__main__':
    # Command-line interface: which config file to load, the project directory
    # stored into that config, and whether to restore user-trained weights.
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--restore", action="store_true", help="Whether to restore encoder model weights from nemo path")
    arg_parser.add_argument("--project_dir", required=False, type=str, default=".")
    arg_parser.add_argument(
        "--cfg", required=False, type=str, default="./conf/umls_medical_entity_linking_config.yaml",
    )
    cli_args = arg_parser.parse_args()

    config = OmegaConf.load(cli_args.cfg)
    # The command-line value replaces any project_dir declared in the YAML file.
    config.project_dir = cli_args.project_dir

    main(config, cli_args.restore)
|
|