File size: 5,925 Bytes
7934b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle as pkl
from argparse import ArgumentParser
from collections import OrderedDict
from typing import Dict
import numpy as np
import torch
from build_index import load_model
from omegaconf import DictConfig, OmegaConf
from nemo.utils import logging
try:
import faiss
except ModuleNotFoundError:
logging.warning("Faiss is required for building the index. Please install faiss-gpu")
# Module-wide torch device used when moving query tensors:
# CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def get_query_embedding(query, model):
    """
    Encode an index query with the entity linking encoder.

    The query is tokenized with `model.tokenizer`, each tokenizer output
    field is wrapped in an extra list level (presumably to act as the batch
    dimension - confirm against the tokenizer's output for a single string),
    converted to a `torch.LongTensor`, moved to the module-level `device`,
    and passed to `model.forward`.

    Args:
        query: text to embed
        model: object exposing a `tokenizer` callable and a `forward` method

    Returns:
        Whatever `model.forward` returns for the encoded query.
    """
    encoded = model.tokenizer(
        query,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=512,
        return_token_type_ids=True,
        return_attention_mask=True,
    )

    def _as_batch(field):
        # Wrap the field in a list and place the resulting tensor on `device`
        return torch.LongTensor([encoded[field]]).to(device)

    return model.forward(
        input_ids=_as_batch("input_ids"),
        token_type_ids=_as_batch("token_type_ids"),
        attention_mask=_as_batch("attention_mask"),
    )
def query_index(
    query: str, cfg: DictConfig, model: object, index: object, pca: object, idx2id: dict, id2string: dict,
) -> Dict:
    """
    Query the nearest neighbor index of entities to find the
    concepts in the index dataset that are most similar to the
    query.

    Args:
        query (str): entity to look up in the index
        cfg (DictConfig): config object to specify query parameters;
                          `apply_pca`, `query_num_factor` and `top_n`
                          are read from it
        model (EntityLinkingModel): entity linking encoder model
        index (object): faiss index
        pca (object): sklearn pca transformation to be applied to queries;
                      only used when `cfg.apply_pca` is truthy
        idx2id (dict): dictionary mapping unique concept dataset index to
                       its CUI
        id2string (dict): dictionary mapping each unique CUI to a
                          representative english description of
                          the concept

    Returns:
        A dictionary with the concept ids of the index's most similar
        entities as the keys and a tuple containing the string
        representation of that concept and its similarity to the query
        (computed as `1 - distance` from the index search) as the values.
        Holds at most `cfg.top_n` entries, in the order the index returned
        them.
    """
    query_emb = get_query_embedding(query, model).detach().cpu().numpy()

    if cfg.apply_pca:
        query_emb = pca.transform(query_emb)

    # Over-fetch by `query_num_factor`: many neighbors can map to the same
    # concept id, so more than `top_n` raw hits are needed to fill the result.
    dist, neighbors = index.search(query_emb.astype(np.float32), cfg.query_num_factor * cfg.top_n)
    dist, neighbors = dist[0], neighbors[0]

    unique_ids = {}

    for neighbor_dist, concept_id_idx in zip(dist, neighbors):
        if len(unique_ids) >= cfg.top_n:
            break

        # faiss pads the result with label -1 when it finds fewer neighbors
        # than requested; those slots do not refer to any indexed entity.
        if concept_id_idx < 0:
            continue

        concept_id = idx2id[concept_id_idx]

        # Only want one instance of each unique concept
        if concept_id not in unique_ids:
            unique_ids[concept_id] = (id2string[concept_id], 1 - neighbor_dist)

    return unique_ids
def main(cfg: DictConfig, restore: bool):
"""
Loads faiss index and allows commandline queries
to the index. Builds new index if one hasn't been built yet.
Args:
cfg: Config file specifying index parameters
restore: Whether to restore model weights trained
by the user. Otherwise will load weights
used before self alignment pretraining.
"""
if not os.path.isfile(cfg.index.index_save_name) or (
cfg.apply_pca and not os.path.isfile(cfg.index.pca.pca_save_name) or not os.path.isfile(cfg.index.idx_to_id)
):
logging.warning("Either no index and/or no mapping from entity idx to ids exists. Please run `build_index.py`")
return
logging.info("Loading entity linking encoder model")
model = load_model(cfg.model, restore)
logging.info("Loading index and associated files")
index = faiss.read_index(cfg.index.index_save_name)
idx2id = pkl.load(open(cfg.index.idx_to_id, "rb"))
id2string = pkl.load(open(cfg.index.id_to_string, "rb")) # Should be created during dataset prep
if cfg.index.apply_pca:
pca = pkl.load(open(cfg.index.pca.pca_save_name, "rb"))
while True:
query = input("enter index query: ")
output = query_index(query, cfg.top_n, cfg.index, model, index, pca, idx2id, id2string)
if query == "exit":
break
for concept_id in output:
concept_details = output[concept_id]
concept_id = "C" + str(concept_id).zfill(7)
print(concept_id, concept_details)
print("----------------\n")
if __name__ == '__main__':
    # Script entry point: parse the command line, load the config file,
    # record the project directory on it, and start the query loop.
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--restore",
        action="store_true",
        help="Whether to restore encoder model weights from nemo path",
    )
    arg_parser.add_argument("--project_dir", required=False, type=str, default=".")
    arg_parser.add_argument(
        "--cfg", required=False, type=str, default="./conf/umls_medical_entity_linking_config.yaml"
    )
    cli_args = arg_parser.parse_args()

    config = OmegaConf.load(cli_args.cfg)
    config.project_dir = cli_args.project_dir

    main(config, cli_args.restore)
|