# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle as pkl
from argparse import ArgumentParser
from collections import OrderedDict
from typing import Dict

import numpy as np
import torch
from build_index import load_model
from omegaconf import DictConfig, OmegaConf

from nemo.utils import logging

try:
    import faiss
except ModuleNotFoundError:
    logging.warning("Faiss is required for building the index. Please install faiss-gpu")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_query_embedding(query, model):
    """Use entity linking encoder to get embedding for index query"""
    model_input = model.tokenizer(
        query,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=512,
        return_token_type_ids=True,
        return_attention_mask=True,
    )
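    # The tokenizer output here is plain Python lists, so each field below is wrapped
    # in a batch dimension and converted to a LongTensor before the forward pass.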

    query_emb = model.forward(
        input_ids=torch.LongTensor([model_input["input_ids"]]).to(device),
        token_type_ids=torch.LongTensor([model_input["token_type_ids"]]).to(device),
        attention_mask=torch.LongTensor([model_input["attention_mask"]]).to(device),
    )

    return query_emb


def query_index(
    query: str, cfg: DictConfig, model: object, index: object, pca: object, idx2id: dict, id2string: dict,
) -> Dict:

    """
    Query the nearest neighbor index of entities to find the 
    concepts in the index dataset that are most similar to the 
    query.

    Args:
        query (str): entity to look up in the index
        cfg (DictConfig): config object specifying query parameters
        model (EntityLinkingModel): entity linking encoder model
        index (object): faiss index
        pca (object): sklearn PCA transformation to be applied to queries
        idx2id (dict): dictionary mapping each unique concept dataset index to
                       its CUI
        id2string (dict): dictionary mapping each unique CUI to a
                          representative English description of
                          the concept
    Returns:
        A dictionary with the concept ids of the index's most similar 
        entities as the keys and a tuple containing the string 
        representation of that concept and its cosine similarity to 
        the query as the values. 
    """
    query_emb = get_query_embedding(query, model).detach().cpu().numpy()

    if cfg.apply_pca:
        query_emb = pca.transform(query_emb)

    dist, neighbors = index.search(query_emb.astype(np.float32), cfg.query_num_factor * cfg.top_n)
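    # index.search returns batched results, so [0] below selects the single query. We over-fetch
    # query_num_factor * top_n neighbors because several neighbors may map to the same concept id.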
    dist, neighbors = dist[0], neighbors[0]
    unique_ids = OrderedDict()
    neighbor_idx = 0

    # Many of the nearest neighbors may map to the same concept id; the dataset index is the unique identifier
    while len(unique_ids) < cfg.top_n and neighbor_idx < len(neighbors):
        concept_id_idx = neighbors[neighbor_idx]
        concept_id = idx2id[concept_id_idx]

        # Only want one instance of each unique concept
        if concept_id not in unique_ids:
            concept = id2string[concept_id]
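            # The stored score is 1 - faiss distance; how closely this matches the cosine
            # similarity described in the docstring depends on the metric used in build_index.py.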
            unique_ids[concept_id] = (concept, 1 - dist[neighbor_idx])

        neighbor_idx += 1

    unique_ids = dict(unique_ids)

    return unique_ids


def main(cfg: DictConfig, restore: bool):
    """
    Loads the faiss index and answers command-line queries
    against it. Warns and exits if the index has not been built yet.

    Args:
        cfg: Config file specifying index parameters
        restore: Whether to restore model weights trained
                 by the user. Otherwise, loads the weights
                 used before self-alignment pretraining.
    """

    if (
        not os.path.isfile(cfg.index.index_save_name)
        or (cfg.index.apply_pca and not os.path.isfile(cfg.index.pca.pca_save_name))
        or not os.path.isfile(cfg.index.idx_to_id)
    ):
        logging.warning(
            "The index, its PCA model, or the mapping from entity index to concept id is missing. "
            "Please run `build_index.py` first."
        )
        return

    logging.info("Loading entity linking encoder model")
    model = load_model(cfg.model, restore)

    logging.info("Loading index and associated files")
    index = faiss.read_index(cfg.index.index_save_name)
    with open(cfg.index.idx_to_id, "rb") as f:
        idx2id = pkl.load(f)
    with open(cfg.index.id_to_string, "rb") as f:  # Should be created during dataset prep
        id2string = pkl.load(f)

    pca = None  # query_index expects a pca argument even when PCA is not applied
    if cfg.index.apply_pca:
        with open(cfg.index.pca.pca_save_name, "rb") as f:
            pca = pkl.load(f)

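    # Simple interactive loop: type an entity mention to retrieve the most similar
    # concepts in the index, or type "exit" to quit.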
    while True:
        query = input("enter index query: ")

        if query == "exit":
            break

        # Assumes top_n, query_num_factor, and apply_pca all live under cfg.index
        output = query_index(query, cfg.index, model, index, pca, idx2id, id2string)

        for concept_id in output:
            concept_details = output[concept_id]
            concept_id = "C" + str(concept_id).zfill(7)
            print(concept_id, concept_details)

        print("----------------\n")


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument(
        "--restore", action="store_true", help="Whether to restore encoder model weights from nemo path"
    )
    parser.add_argument("--project_dir", required=False, type=str, default=".")
    parser.add_argument("--cfg", required=False, type=str, default="./conf/umls_medical_entity_linking_config.yaml")
    args = parser.parse_args()

    cfg = OmegaConf.load(args.cfg)
    cfg.project_dir = args.project_dir

    main(cfg, args.restore)
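

# Example invocation (the script filename is illustrative; adjust to your checkout):
#   python query_index.py --restore \
#       --project_dir . \
#       --cfg ./conf/umls_medical_entity_linking_config.yaml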