import utils
import os
import math
import json
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.channels import read_custom_montage
from scipy.interpolate import Rbf
from scipy.optimize import linear_sum_assignment
from sklearn.neighbors import NearestNeighbors

# Number of rows every mapped batch has (one per template channel).  The
# mapping lists, the reordered data array and the assignment cost matrix are
# all sized with this value.
N_TEMPLATE_CHANNELS = 30


def reorder_data(idx_order, fill_flags, filename, new_filename):
    """Rearrange the rows of a recording into template-channel order.

    Args:
        idx_order: for each template row, a list of input-row indices.
        fill_flags: for each template row, falsy when the row is a direct
            copy of input row ``idx_set[0]``; truthy when the row has to be
            filled (mean of the listed rows, or zeros if the list is empty).
        filename: path handed to ``utils.read_train_data``.
        new_filename: path handed to ``utils.save_data``.

    Returns:
        The shape of the data that was read from ``filename``.
    """
    raw_data = utils.read_train_data(filename)
    # Rows that are never assigned below stay all-zero.
    new_data = np.zeros((N_TEMPLATE_CHANNELS, raw_data.shape[1]))

    for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
        if not flag:
            new_data[i, :] = raw_data[idx_set[0], :]
        elif idx_set:
            neighbor_rows = [raw_data[j, :] for j in idx_set]
            new_data[i, :] = np.mean(neighbor_rows, axis=0)

    utils.save_data(new_data, new_filename)
    return raw_data.shape


def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags,
                  filename, new_filename):
    """Write the rows of a processed batch back to their input-row positions.

    Args:
        batch_cnt: batch counter; for ``0`` a fresh zero array is created,
            otherwise the array previously saved at ``new_filename`` is
            loaded and updated.
        raw_data_shape: shape of the original recording; only its first
            entry (the row count) is used.
        idx_order: for each template row, a list of input-row indices.
        fill_flags: for each template row, truthy when the row was filled
            rather than copied; such rows are not written back.
        filename: path of the processed batch (``utils.read_train_data``).
        new_filename: path of the restored data (``utils.save_data``).
    """
    d_data = utils.read_train_data(filename)
    if batch_cnt == 0:
        new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
    else:
        new_data = utils.read_train_data(new_filename)

    for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
        # Filled rows have no single source row, so they are skipped.
        if not flag:
            new_data[idx_set[0], :] = d_data[i, :]

    utils.save_data(new_data, new_filename)


def get_matched(tpl_order, tpl_dict):
    """Return the template channels (in ``tpl_order``) flagged as matched."""
    return [channel for channel in tpl_order if tpl_dict[channel]["matched"]]


def get_empty_templates(tpl_order, tpl_dict):
    """Return the template channels (in ``tpl_order``) not flagged as matched."""
    return [channel for channel in tpl_order
            if not tpl_dict[channel]["matched"]]


def get_unassigned_inputs(in_order, in_dict):
    """Return the input channels (in ``in_order``) not flagged as assigned."""
    return [channel for channel in in_order
            if not in_dict[channel]["assigned"]]


def _uppercase_and_index(montage, flag_key):
    """Uppercase a montage's channel names in place and index its channels.

    Returns a dict keyed by uppercased channel name holding the channel's
    position in the montage ("index"), its 3D position ("coord_3d") and a
    boolean flag named ``flag_key`` initialised to False.
    """
    # Snapshot the names first so the rename cannot affect the iteration.
    names = list(montage.ch_names)
    montage.rename_channels({name: name.upper() for name in names})
    # Positions are looked up once, after renaming, instead of per channel.
    positions = montage.get_positions()['ch_pos']

    info = {}
    for i, name in enumerate(names):
        up_name = name.upper()
        info[up_name] = {
            "index": i,
            "coord_3d": positions[up_name],
            flag_key: False,
        }
    return info


def read_montage_data(loc_file):
    """Load the template montage and the input montage.

    The template is always read from ``./template_chanlocs.loc``.  All
    channel names of both montages are converted to uppercase.

    Args:
        loc_file: path of the input channel-location file.

    Returns:
        ``(tpl_montage, in_montage, tpl_dict, in_dict)`` where the dicts map
        each uppercased channel name to its "index", "coord_3d" and a flag
        ("matched" for the template, "assigned" for the input).
    """
    tpl_montage = read_custom_montage("./template_chanlocs.loc")
    in_montage = read_custom_montage(loc_file)

    tpl_dict = _uppercase_and_index(tpl_montage, "matched")
    in_dict = _uppercase_and_index(in_montage, "assigned")
    return tpl_montage, in_montage, tpl_dict, in_dict


def _css_position(pixel_xy):
    """Convert a display-space (x, y) pixel pair to [left, bottom] CSS strings.

    Dividing by 6.4 expresses a pixel offset as a percentage of the 640 px
    figure (6.4 in at 100 dpi).  The 11 px / 7 px shifts are kept from the
    original code; presumably they compensate for the marker size -- TODO
    confirm.
    """
    css_left = (pixel_xy[0] - 11) / 6.4
    css_bottom = (pixel_xy[1] - 7) / 6.4
    return [str(round(css_left, 2)) + "%", str(round(css_bottom, 2)) + "%"]


def save_figures(channel_info, tpl_montage, filename1, filename2):
    """Render the input montage on the template head and save two images.

    ``filename1`` receives the plain input montage; ``filename2`` receives
    the same figure with the still-unassigned input channels redrawn in red.
    Afterwards each template and input channel gets a "css_position" entry
    derived from its display coordinates in that figure.

    Args:
        channel_info: dict with "templateOrder", "inputOrder",
            "templateDict" and "inputDict"; every channel entry must already
            carry "coord_2d".
        tpl_montage: template montage, plotted only to obtain the head
            outline.
        filename1: output path of the input-montage image.
        filename2: output path of the annotated image.

    Returns:
        ``channel_info`` with updated "templateDict" and "inputDict".
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
    tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
    in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
    in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
    tpl_coords = np.vstack((tpl_x, tpl_y)).T
    in_coords = np.vstack((in_x, in_y)).T

    # Borrow the head outline from the template's own plot, then discard
    # that figure explicitly rather than relying on the "current figure".
    tpl_fig = tpl_montage.plot()
    head_lines = [line.get_data() for line in tpl_fig.axes[0].lines]
    plt.close(tpl_fig)

    # ------------------------- plot input montage -------------------------
    fig = plt.figure(figsize=(6.4, 6.4), dpi=100)
    ax = fig.add_subplot(111)
    fig.tight_layout()
    ax.set_aspect('equal')
    ax.axis('off')

    for x, y in head_lines:
        ax.plot(x, y, color='black', linewidth=1.0)

    ax.scatter(in_coords[:, 0], in_coords[:, 1], s=35, color='black')
    for i, channel in enumerate(in_order):
        ax.text(in_coords[i, 0] + 0.003, in_coords[i, 1], channel,
                color='black', fontsize=10.0, va='center')
    fig.savefig(filename1)

    # -------------------------- add indications ---------------------------
    # Redraw the input channels that are not assigned yet in red.
    indices = [in_dict[channel]["index"] for channel in in_order
               if not in_dict[channel]["assigned"]]
    ax.scatter(in_coords[indices, 0], in_coords[indices, 1], s=35,
               color='red')
    for i in indices:
        ax.text(in_coords[i, 0] + 0.003, in_coords[i, 1], in_order[i],
                color='red', fontsize=10.0, va='center')
    fig.savefig(filename2)

    # ----------------------------------------------------------------------
    # Convert data coordinates to display coordinates (px) of this figure.
    tpl_coords = ax.transData.transform(tpl_coords)
    in_coords = ax.transData.transform(in_coords)
    plt.close(fig)

    for i, channel in enumerate(tpl_order):
        tpl_dict[channel]["css_position"] = _css_position(tpl_coords[i])
    for i, channel in enumerate(in_order):
        in_dict[channel]["css_position"] = _css_position(in_coords[i])

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info


def _thin_plate_transform(src_ctrl, dst_ctrl, points):
    """Warp ``points`` with thin-plate RBFs fitted on control-point pairs.

    One ``Rbf`` is fitted per output coordinate, mapping ``src_ctrl`` onto
    the matching column of ``dst_ctrl``; each is then evaluated at
    ``points``.  All three arrays have one row per point and the same number
    of columns.
    """
    src_columns = [src_ctrl[:, d] for d in range(src_ctrl.shape[1])]
    point_columns = [points[:, d] for d in range(points.shape[1])]
    warped = []
    for d in range(dst_ctrl.shape[1]):
        rbf = Rbf(*src_columns, dst_ctrl[:, d], function='thin_plate')
        warped.append(rbf(*point_columns))
    return np.vstack(warped).T


def align_coords(channel_info, tpl_montage, in_montage):
    """Align the input channel positions with the template positions.

    The channels already matched between template and input serve as control
    points of a thin-plate transformation that is applied to every input
    channel, once for the 2D plot coordinates (stored as "coord_2d") and
    once for the 3D coordinates (overwriting the input's "coord_3d").

    Args:
        channel_info: dict with "templateOrder", "inputOrder",
            "templateDict" and "inputDict".
        tpl_montage: template montage.
        in_montage: input montage.

    Returns:
        ``channel_info`` with updated "templateDict" and "inputDict".
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    matched = get_matched(tpl_order, tpl_dict)
    tpl_rows = [tpl_dict[channel]["index"] for channel in matched]
    in_rows = [in_dict[channel]["index"] for channel in matched]

    # ---- 2D alignment (for visualization purposes) ----
    # The displayed 2D coordinates are read back from the montage plots.
    # They are copied so both figures can be closed right away; leaving
    # them open leaks two figures per call.
    tpl_fig = tpl_montage.plot()
    in_fig = in_montage.plot()
    all_tpl = np.array(tpl_fig.axes[0].collections[0].get_offsets().data)
    all_in = np.array(in_fig.axes[0].collections[0].get_offsets().data)
    plt.close(tpl_fig)
    plt.close(in_fig)

    matched_tpl = np.array([all_tpl[row] for row in tpl_rows])
    matched_in = np.array([all_in[row] for row in in_rows])
    transformed_in = _thin_plate_transform(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(tpl_order):
        tpl_dict[channel]["coord_2d"] = all_tpl[i]
    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_2d"] = transformed_in[i].tolist()

    # ---- 3D alignment ----
    # np.asarray accepts both arrays and plain lists, so this also works
    # when the input coordinates were already converted to lists by an
    # earlier call.
    all_tpl = np.array([np.asarray(tpl_dict[channel]["coord_3d"], dtype=float)
                        for channel in tpl_order])
    all_in = np.array([np.asarray(in_dict[channel]["coord_3d"], dtype=float)
                       for channel in in_order])
    matched_tpl = np.array([all_tpl[row] for row in tpl_rows])
    matched_in = np.array([all_in[row] for row in in_rows])
    transformed_in = _thin_plate_transform(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_3d"] = transformed_in[i].tolist()

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info


def find_neighbors(channel_info, missing_channels, new_idx):
    """Assign the nearest input channels to each missing template channel.

    For every channel in ``missing_channels`` the (up to) four input
    channels closest in 3D Euclidean distance are looked up and their
    positions in "inputOrder" are stored in ``new_idx`` at the template
    channel's index.

    Args:
        channel_info: dict with "inputOrder", "templateDict", "inputDict".
        missing_channels: template channel names that need filling.
        new_idx: per-template-row index lists; modified in place.

    Returns:
        ``new_idx``.
    """
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    # Nothing to search for, or nothing to search in: leave new_idx as is.
    if not missing_channels or not in_order:
        return new_idx

    all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
    k = min(4, len(in_order))
    knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
    knn.fit(all_in)

    for channel in missing_channels:
        query = np.array(tpl_dict[channel]["coord_3d"]).reshape(1, -1)
        indices = knn.kneighbors(query, return_distance=False)
        new_idx[tpl_dict[channel]["index"]] = indices[0].tolist()
    return new_idx


def match_names(stage1_info, channel_info):
    """Match template and input channels that share a name.

    The input location file is taken from
    ``stage1_info["fileNames"]["input_loc"]``.  A template channel named
    T3/T4/T5/T6 is renamed to T7/T8/P7/P8 when the input contains that
    alternative name.  Every template channel whose name also appears in
    the input is marked matched and the input channel marked assigned.

    Args:
        stage1_info: dict; receives "unassignedInputs", "missingTemplates"
            and a one-element "mappingData" list.
        channel_info: dict; receives "templateOrder", "inputOrder",
            "templateDict" and "inputDict".

    Returns:
        ``(stage1_info, channel_info, tpl_montage, in_montage)``.
    """
    loc_file = stage1_info["fileNames"]["input_loc"]
    tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
    in_order = in_montage.ch_names

    # Independent lists per slot (``[[]] * n`` would alias a single list).
    new_idx = [[] for _ in range(N_TEMPLATE_CHANNELS)]
    # True means the template row has no directly matched input channel.
    fill_flags = [True] * N_TEMPLATE_CHANNELS
    alias_dict = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8'
    }

    # Iterate over a snapshot because channels may be renamed in the loop.
    for i, channel in enumerate(list(tpl_montage.ch_names)):
        alias = alias_dict.get(channel)
        if alias is not None and alias in in_dict:
            tpl_montage.rename_channels({channel: alias})
            tpl_dict[alias] = tpl_dict.pop(channel)
            channel = alias

        if channel in in_dict:
            new_idx[i] = [in_dict[channel]["index"]]
            fill_flags[i] = False
            tpl_dict[channel]["matched"] = True
            in_dict[channel]["assigned"] = True

    # Pick up the names after any alias renaming.
    tpl_order = tpl_montage.ch_names

    stage1_info.update({
        "unassignedInputs": get_unassigned_inputs(in_order, in_dict),
        "missingTemplates": get_empty_templates(tpl_order, tpl_dict),
        "mappingData": [
            {
                "newOrder": new_idx,
                "fillFlags": fill_flags
            }
        ]
    })
    channel_info.update({
        "templateOrder": tpl_order,
        "inputOrder": in_order,
        "templateDict": tpl_dict,
        "inputDict": in_dict
    })
    return stage1_info, channel_info, tpl_montage, in_montage


def optimal_mapping(channel_info):
    """Map still-unassigned input channels onto the template by proximity.

    All template channels are first reset to unmatched.  A cost matrix of
    scaled 3D Euclidean distances between template channels and unassigned
    input channels is solved with ``linear_sum_assignment``; template
    channels left without a partner are filled via ``find_neighbors``.

    Args:
        channel_info: dict with "templateOrder", "inputOrder",
            "templateDict" and "inputDict".

    Returns:
        ``(mapping_data, channel_info)`` where ``mapping_data`` holds
        "newOrder" and "fillFlags" for this batch.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    unassigned = get_unassigned_inputs(in_order, in_dict)
    n_unassigned = len(unassigned)

    # Every batch starts from a template with no matched channel.
    for channel in tpl_dict:
        tpl_dict[channel]["matched"] = False

    all_tpl = np.array([tpl_dict[channel]["coord_3d"]
                        for channel in tpl_order])
    unassigned_in = np.array([in_dict[channel]["coord_3d"]
                              for channel in unassigned])

    if n_unassigned < N_TEMPLATE_CHANNELS:
        # Pad with high-cost dummy columns so num_col >= num_row.
        cost_matrix = np.full((N_TEMPLATE_CHANNELS, N_TEMPLATE_CHANNELS), 1e6)
    else:
        cost_matrix = np.zeros((N_TEMPLATE_CHANNELS, n_unassigned))

    # The factor 1000 is kept from the original code; presumably it rescales
    # the coordinates (e.g. m -> mm) -- TODO confirm.
    for i in range(N_TEMPLATE_CHANNELS):
        for j in range(n_unassigned):
            cost_matrix[i][j] = np.linalg.norm(
                (all_tpl[i] - unassigned_in[j]) * 1000)

    # One input channel per template channel, minimising the summed cost.
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    # Independent lists per slot (``[[]] * n`` would alias a single list).
    new_idx = [[] for _ in range(N_TEMPLATE_CHANNELS)]
    fill_flags = [True] * N_TEMPLATE_CHANNELS
    for i, j in zip(row_idx, col_idx):
        if j >= n_unassigned:
            continue  # assigned to a dummy column
        tpl_channel = tpl_order[i]
        in_channel = unassigned[j]
        new_idx[i] = [in_dict[in_channel]["index"]]
        fill_flags[i] = False
        tpl_dict[tpl_channel]["matched"] = True
        in_dict[in_channel]["assigned"] = True

    # Fill the template channels that received no input channel.
    missing_channels = get_empty_templates(tpl_order, tpl_dict)
    if missing_channels:
        new_idx = find_neighbors(channel_info, missing_channels, new_idx)

    mapping_data = {
        "newOrder": new_idx,
        "fillFlags": fill_flags
    }
    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict
    })
    return mapping_data, channel_info


def mapping_result(stage1_info, stage2_info, channel_info, filename):
    """Map the remaining input channels in batches and save all mappings.

    The batch count is one (the mapping already present in
    ``stage1_info["mappingData"]``) plus one per started group of 30
    channels listed in ``stage1_info["unassignedInputs"]``.  Each extra
    batch is produced by ``optimal_mapping`` and appended to
    ``stage1_info["mappingData"]``.  The result is written to ``filename``
    as JSON with the keys "batchNum" and "mappingData".

    Args:
        stage1_info: dict with "unassignedInputs" and "mappingData".
        stage2_info: dict; receives "totalBatchNum".
        channel_info: dict passed through ``optimal_mapping``.
        filename: path of the JSON file to write.

    Returns:
        ``(stage1_info, stage2_info, channel_info)``.
    """
    unassigned_num = len(stage1_info["unassignedInputs"])
    batch_num = math.ceil(unassigned_num / N_TEMPLATE_CHANNELS) + 1

    for _ in range(batch_num - 1):
        new_mapping_data, channel_info = optimal_mapping(channel_info)
        stage1_info["mappingData"].append(new_mapping_data)

    new_dict = {
        "batchNum": batch_num,
        "mappingData": stage1_info["mappingData"]
    }
    with open(filename, 'w') as jsonfile:
        json.dump(new_dict, jsonfile)

    stage2_info["totalBatchNum"] = batch_num
    return stage1_info, stage2_info, channel_info