"""Channel-mapping helpers.

Match the channels of an input montage to a fixed-size template montage,
fill template channels that have no counterpart, and reorder data files
between the input layout and the template layout.
"""

import os
import time

import gradio as gr
import mne
import numpy as np
from mne.channels import read_custom_montage
from scipy.interpolate import Rbf
from scipy.optimize import linear_sum_assignment
from sklearn.neighbors import NearestNeighbors

import utils

# Number of channels in the template layout (rows of the mapped data).
TEMPLATE_CHANNEL_COUNT = 30

# Size (inches) given to the montage figures in ``align_coords``.
_PLOT_SIZE_INCHES = 5.6
# Divisor and offsets used to turn display pixels into CSS percentages.
# NOTE(review): 560 presumably equals 5.6 in x 100 dpi, and 11 / 7 look like
# marker-size corrections -- confirm against the front-end layout.
_PLOT_SIZE_PX = 560
_CSS_X_OFFSET_PX = 11
_CSS_Y_OFFSET_PX = 7

# Cost assigned to dummy columns in the stage-2 assignment problem.
_DUMMY_COST = 1e6


def _order_key(app_state):
    """Return the ``app_state`` key holding the order list of the running stage.

    The state key is spelled ``"runnigState"`` everywhere in this module; the
    spelling is kept as-is because it is a runtime dictionary key.
    """
    if app_state["runnigState"] == "stage1":
        return "stage1NewOrder"
    return "stage2NewOrder"


def _css_position(x_px, y_px):
    """Format a display-pixel coordinate as ``[left%, bottom%]`` strings."""
    css_left = (x_px - _CSS_X_OFFSET_PX) / _PLOT_SIZE_PX
    css_bottom = (y_px - _CSS_Y_OFFSET_PX) / _PLOT_SIZE_PX
    return [
        str(round(css_left * 100, 2)) + "%",
        str(round(css_bottom * 100, 2)) + "%",
    ]


def _thin_plate_map(src, dst, points):
    """Fit one thin-plate RBF per column of ``dst`` on ``src``; apply to ``points``.

    ``src`` and ``dst`` are the matched control points (one row per channel);
    the returned array has one row per row of ``points`` and one column per
    column of ``dst``.
    """
    mapped_columns = []
    for dim in range(dst.shape[1]):
        rbf = Rbf(*src.T, dst[:, dim], function='thin_plate')
        mapped_columns.append(rbf(*points.T))
    return np.column_stack(mapped_columns)


def reorder_to_template(app_state, filename):
    """Rewrite the data in ``filename`` in template channel order.

    For every template channel the rows listed in the running stage's order
    list are averaged; a template channel with an empty index set stays
    all-zero.  The result is saved to ``app_state["filepath"] + 'mapped.csv'``.

    Args:
        app_state: State dict; reads ``"runnigState"``, the matching
            ``"stage?NewOrder"`` list and ``"filepath"``.
        filename: Path passed to ``utils.read_train_data`` (original raw data).
    """
    old_idx = app_state[_order_key(app_state)]
    old_data = utils.read_train_data(filename)  # original raw data
    n_samples = old_data.shape[1]
    new_filename = app_state["filepath"] + 'mapped.csv'

    # Append one all-zero row so an index set may point one past the last real
    # row (``fill_channels`` in 'zero' mode stores ``dataShape[0]``, which
    # presumably is that row -- confirm against the caller).
    zero_arr = np.zeros((1, n_samples))
    old_data = np.concatenate((old_data, zero_arr), axis=0)

    new_data = np.zeros((TEMPLATE_CHANNEL_COUNT, n_samples))  # reordered data
    for i in range(TEMPLATE_CHANNEL_COUNT):
        curr_idx_set = old_idx[i]
        if not curr_idx_set:
            continue  # nothing mapped here: the row stays zero
        new_data[i, :] = np.mean(old_data[curr_idx_set, :], axis=0)

    print('old.shape, new.shape: ', old_data.shape, new_data.shape)
    utils.save_data(new_data, new_filename)
    return


def reorder_to_origin(app_state, channel_info, filename, new_filename):
    """Copy template-ordered rows of ``filename`` back to input channel order.

    In stage 1 the output starts as zeros with one row per input channel; in
    any other stage it starts from the data already stored in
    ``new_filename``.  Only template channels that are flagged as matched and
    map to exactly one input row are copied back.

    Args:
        app_state: State dict; reads ``"runnigState"`` and the matching
            ``"stage?NewOrder"`` list.
        channel_info: Reads ``"templateByIndex"``, ``"templateByName"`` and
            ``"inputByName"``.
        filename: Path passed to ``utils.read_train_data`` (denoised data).
        new_filename: Path the restored data is saved to (and, outside
            stage 1, first read from).
    """
    old_idx = app_state[_order_key(app_state)]
    old_data = utils.read_train_data(filename)  # denoised data
    template_order = channel_info["templateByIndex"]
    template_dict = channel_info["templateByName"]

    if app_state["runnigState"] == "stage1":
        new_data = np.zeros(
            (len(channel_info["inputByName"]), old_data.shape[1])
        )
    else:
        new_data = utils.read_train_data(new_filename)

    for i, channel in enumerate(template_order):
        idx_set = old_idx[i]
        # Skip channels that were filled (zero / averaged) rather than matched.
        if len(idx_set) == 1 and template_dict[channel]["matched"]:
            new_data[idx_set[0], :] = old_data[i, :]

    print('old.shape, new.shape: ', old_data.shape, new_data.shape)
    utils.save_data(new_data, new_filename)
    return


class Channel:
    """Plain record describing one montage channel.

    The instance ``__dict__`` is exported directly by ``mapping_stage1``, so
    the attribute names below double as dictionary keys elsewhere.
    """

    def __init__(self, index, name=None, matched=False, assigned=False,
                 coord=None, css_position=None):
        self.name = name                  # channel name
        self.index = index                # position within its montage
        self.matched = matched            # paired with a counterpart channel
        self.assigned = assigned          # for input channels
        self.coord = coord                # position taken from the montage
        self.css_position = css_position  # [left, bottom] percentage strings


def _index_montage(montage, count):
    """Uppercase the first ``count`` channel names and index them by name."""
    names = montage.ch_names[:count]
    # One rename call for all channels instead of one call per channel.
    montage.rename_channels({name: str.upper(name) for name in names})
    # Positions are fetched once, after renaming, so keys are uppercase.
    positions = montage.get_positions()['ch_pos']
    channels = {}
    for j, name in enumerate(names):
        upper_name = str.upper(name)
        channels[upper_name] = Channel(
            index=j, name=upper_name, coord=positions[upper_name]
        )
    return channels


def read_montage_data(loc_file):
    """Load the template and input montages and index their channels.

    The template is read from ``./template_chanlocs.loc``; only its first
    ``TEMPLATE_CHANNEL_COUNT`` channels are indexed.  All indexed channel
    names are converted to uppercase in the montages themselves.

    Args:
        loc_file: Path of the input montage file.

    Returns:
        ``(template_montage, input_montage, template_dict, input_dict)`` where
        the dicts map uppercase channel name -> ``Channel``.
    """
    template_montage = read_custom_montage("./template_chanlocs.loc")
    input_montage = read_custom_montage(loc_file)

    template_dict = _index_montage(template_montage, TEMPLATE_CHANNEL_COUNT)
    input_dict = _index_montage(input_montage, len(input_montage.ch_names))

    return template_montage, input_montage, template_dict, input_dict


def align_coords(channel_info, template_montage, input_montage):
    """Align template and input coordinates using the matched channels.

    Two thin-plate RBF warps are fitted on the channels flagged as matched:

    * 2-D (display pixels): template -> input, used to compute the
      ``"css_position"`` of every template and input channel.
    * 3-D: input -> template, used to overwrite every input channel's
      ``"coord"`` so that nearest-neighbour / distance computations happen in
      the template's space.

    Args:
        channel_info: Reads ``"templateByName"``, ``"inputByName"``,
            ``"templateByIndex"`` and ``"inputByIndex"``; the two ``ByName``
            dicts are modified in place.
        template_montage: Template montage (plotted to obtain 2-D positions).
        input_montage: Input montage (plotted to obtain 2-D positions).

    Returns:
        The updated ``channel_info``.
    """
    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    template_order = channel_info["templateByIndex"]
    input_order = channel_info["inputByIndex"]
    matched = [
        channel for channel in input_dict if input_dict[channel]["matched"]
    ]
    tpl_rows = [template_dict[channel]["index"] for channel in matched]
    in_rows = [input_dict[channel]["index"] for channel in matched]

    # --- 2-D (for the indication of missing template channels' positions
    # when fill_mode is 'mean_manual') ---
    # NOTE(review): the two figures created by plot() are never closed in this
    # function; in a long-running process they may accumulate.
    figures = [template_montage.plot(), input_montage.plot()]
    axes = []
    for figure in figures:
        figure.set_size_inches(_PLOT_SIZE_INCHES, _PLOT_SIZE_INCHES)
        axes.append(figure.axes[0])
    for axis in axes:
        axis.set_aspect('equal')
    for axis in axes:
        axis.figure.canvas.draw()  # update the figure before reading transforms

    # Marker positions converted to display coordinates (px).
    all_tpl = axes[0].transData.transform(
        axes[0].collections[0].get_offsets().data
    )
    all_in = axes[1].transData.transform(
        axes[1].collections[0].get_offsets().data
    )
    matched_tpl = np.array([all_tpl[row] for row in tpl_rows])
    matched_in = np.array([all_in[row] for row in in_rows])

    # Warp the template's xy positions into the input's display space.
    transformed_tpl = _thin_plate_map(matched_tpl, matched_in, all_tpl)

    for i, channel in enumerate(template_order):
        template_dict[channel]["css_position"] = _css_position(
            transformed_tpl[i][0], transformed_tpl[i][1]
        )
    for i, channel in enumerate(input_order):
        input_dict[channel]["css_position"] = _css_position(
            all_in[i][0], all_in[i][1]
        )

    # --- 3-D (to use KNN) ---
    all_tpl = np.array(
        [template_dict[channel]["coord"].tolist() for channel in template_order]
    )
    all_in = np.array(
        [input_dict[channel]["coord"].tolist() for channel in input_order]
    )
    matched_tpl = np.array([all_tpl[row] for row in tpl_rows])
    matched_in = np.array([all_in[row] for row in in_rows])

    # Warp the input's xyz positions into the template's space.
    transformed_in = _thin_plate_map(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(input_order):
        input_dict[channel]["coord"] = transformed_in[i].tolist()

    channel_info.update({
        "templateByName": template_dict,
        "inputByName": input_dict,
    })
    return channel_info


def fill_channels(app_state, channel_info, fill_mode):
    """Give every unmatched template channel a set of source row indices.

    * ``'zero'``: point at row ``channel_info["dataShape"][0]``.
    * ``'mean_auto'``: point at the (up to) 4 input channels nearest to the
      template channel's coordinate.
    * any other value: unmatched channels are left untouched.

    Args:
        app_state: State dict; reads ``"runnigState"`` and reads/writes the
            matching ``"stage?NewOrder"`` list.
        channel_info: Reads ``"templateByName"``, ``"inputByName"``,
            ``"inputByIndex"`` and ``"dataShape"``.
        fill_mode: Fill strategy name (see above).

    Returns:
        The updated ``app_state``.
    """
    order_key = _order_key(app_state)
    new_idx = app_state[order_key]
    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    input_order = channel_info["inputByIndex"]
    z_row_idx = channel_info["dataShape"][0]

    unmatched = [
        channel for channel in template_dict
        if not template_dict[channel]["matched"]
    ]
    if not unmatched:
        return app_state

    if fill_mode == 'zero':
        for channel in unmatched:
            new_idx[template_dict[channel]["index"]] = [z_row_idx]
    elif fill_mode == 'mean_auto':
        # Use KNN to choose the k nearest input channels.
        in_coords = np.array(
            [input_dict[channel]["coord"] for channel in input_order]
        )
        k = min(4, len(input_dict))
        knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
        knn.fit(in_coords)
        for channel in unmatched:
            query = np.array(template_dict[channel]["coord"]).reshape(1, -1)
            _, indices = knn.kneighbors(query)
            selected = [input_order[i] for i in indices[0]]
            print(channel, ':', selected)
            new_idx[template_dict[channel]["index"]] = indices[0].tolist()

    app_state[order_key] = new_idx
    return app_state


def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
    """Stage 1: match input channels to template channels by name.

    A template channel whose name has an entry in the alias table is renamed
    to the alias when the input montage contains that alias.  Matched pairs
    are flagged, coordinates are aligned with ``align_coords`` and unmatched
    template channels are filled with ``fill_channels``.

    Args:
        app_state: State dict; ``"stage1NewOrder"`` and ``"runnigState"`` are
            written.
        channel_info: Channel metadata dict; the name/index tables and
            ``"missingChannelsIndex"`` are written.
        data_file: Not used by this function.
        loc_file: Path of the input montage file.
        fill_mode: Passed through to ``fill_channels``.

    Returns:
        ``(app_state, channel_info)``.
    """
    second1 = time.time()

    template_montage, input_montage, template_dict, input_dict = (
        read_montage_data(loc_file)
    )
    template_order = template_montage.ch_names
    # Independent lists per template channel (avoid ``[[]] * n`` aliasing).
    new_idx = [[] for _ in range(TEMPLATE_CHANNEL_COUNT)]
    missing_channels = []
    alias = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8',
    }

    # Match the names of input channels -> template channels.
    for i, channel in enumerate(template_order):
        alias_name = alias.get(channel)
        if alias_name is not None and alias_name in input_dict:
            template_montage.rename_channels({channel: alias_name})
            template_dict[alias_name] = template_dict.pop(channel)
            template_dict[alias_name].name = alias_name
            channel = alias_name

        if channel in input_dict:
            new_idx[i] = [input_dict[channel].index]
            template_dict[channel].matched = True
            input_dict[channel].matched = True
            input_dict[channel].assigned = True
        else:
            missing_channels.append(i)

    channel_info.update({
        "missingChannelsIndex": missing_channels,
        "templateByName": {k: vars(v) for k, v in template_dict.items()},
        "inputByName": {k: vars(v) for k, v in input_dict.items()},
        "templateByIndex": template_montage.ch_names,
        "inputByIndex": input_montage.ch_names,
    })
    app_state.update({
        "stage1NewOrder": new_idx,
        "runnigState": "stage1",
    })

    # Align input and template coordinates.
    channel_info = align_coords(channel_info, template_montage, input_montage)
    # Fill the unmatched channels.
    app_state = fill_channels(app_state, channel_info, fill_mode)

    second2 = time.time()
    print('Mapping (stage1) finished in',second2 - second1,'s.')
    return app_state, channel_info


def mapping_stage2(app_state, channel_info, fill_mode):
    """Stage 2: assign not-yet-assigned input channels by position.

    Builds a template x unassigned-input distance matrix and solves it with
    ``linear_sum_assignment``.  When fewer inputs than template channels
    remain, dummy columns with a large cost pad the matrix and assignments to
    them are discarded.  If no unassigned input channel is left,
    ``"runnigState"`` is set to ``"finished"`` and nothing else happens.

    Args:
        app_state: State dict; ``"stage2NewOrder"`` and ``"runnigState"`` are
            written and ``"batchCount"`` is read for the log line.
        channel_info: Reads ``"templateByName"``, ``"inputByName"`` and
            ``"templateByIndex"``; the two ``ByName`` dicts are updated.
        fill_mode: Passed through to ``fill_channels``.

    Returns:
        ``(app_state, channel_info)``.
    """
    second1 = time.time()

    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    template_order = channel_info["templateByIndex"]
    unassigned = [
        channel for channel in input_dict
        if not input_dict[channel]["assigned"]
    ]
    if not unassigned:
        app_state["runnigState"] = "finished"
        return app_state, channel_info

    n_unassigned = len(unassigned)
    tpl_coords = np.array(
        [template_dict[channel]["coord"] for channel in template_order]
    )
    unassigned_coords = np.array(
        [input_dict[channel]["coord"] for channel in unassigned]
    )

    # Every template channel starts this stage as unmatched.
    for channel in template_dict:
        template_dict[channel]["matched"] = False

    # Pairwise Euclidean distances, scaled by 1000 as in the cost definition.
    deltas = (
        tpl_coords[:TEMPLATE_CHANNEL_COUNT, None, :]
        - unassigned_coords[None, :, :]
    )
    distances = np.linalg.norm(deltas * 1000, axis=2)

    if n_unassigned < TEMPLATE_CHANNEL_COUNT:
        # Add dummy channels so there are at least as many columns as rows.
        cost_matrix = np.full(
            (TEMPLATE_CHANNEL_COUNT, TEMPLATE_CHANNEL_COUNT), _DUMMY_COST
        )
    else:
        cost_matrix = np.zeros((TEMPLATE_CHANNEL_COUNT, n_unassigned))
    cost_matrix[:, :n_unassigned] = distances

    # Hungarian algorithm: minimise the summed input-to-template distance.
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    new_idx = [[] for _ in range(TEMPLATE_CHANNEL_COUNT)]
    for row, col in zip(row_idx, col_idx):
        if col >= n_unassigned:
            continue  # assigned to a dummy channel
        tpl_channel = template_order[row]
        in_channel = unassigned[col]
        template_dict[tpl_channel]["matched"] = True
        input_dict[in_channel]["assigned"] = True
        # Index by the assigned template row, not by the loop position.
        new_idx[row] = [input_dict[in_channel]["index"]]
        print(tpl_channel, '<-', in_channel)

    channel_info.update({
        "templateByName": template_dict,
        "inputByName": input_dict,
    })
    app_state.update({
        "stage2NewOrder": new_idx,
        "runnigState": "stage2",
    })

    # Fill the unmatched channels.
    app_state = fill_channels(app_state, channel_info, fill_mode)

    second2 = time.time()
    print(f'Mapping (stage2-{app_state["batchCount"]-1}) finished in {second2 - second1}s.')
    return app_state, channel_info