Spaces:
Sleeping
Sleeping
import utils | |
import os | |
import math | |
import json | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import mne | |
from mne.channels import read_custom_montage | |
from scipy.interpolate import Rbf | |
from scipy.optimize import linear_sum_assignment | |
from sklearn.neighbors import NearestNeighbors | |
def reorder_data(idx_order, fill_flags, filename, new_filename):
    """Rearrange the rows of a recording into a fixed 30-row layout.

    The input array is loaded with ``utils.read_train_data(filename)`` and the
    rearranged ``(30, n_columns)`` array is written with ``utils.save_data``.

    Args:
        idx_order: One entry per output row; each entry is a list of row
            indices into the input array.
        fill_flags: One flag per output row. A flag equal to ``False`` means
            the output row is a direct copy of input row ``idx_order[i][0]``.
            Otherwise the row is the mean of all listed input rows, or zeros
            when the list is empty.
        filename: Location handed to ``utils.read_train_data``.
        new_filename: Location handed to ``utils.save_data``.

    Returns:
        The shape of the input array as loaded.
    """
    source = utils.read_train_data(filename)
    n_cols = source.shape[1]
    reordered = np.zeros((30, n_cols))
    blank_row = np.zeros((1, n_cols))

    for row, (src_rows, filled) in enumerate(zip(idx_order, fill_flags)):
        if filled == False:
            # directly mapped row: copy the single source row
            reordered[row, :] = source[src_rows[0], :]
            continue
        if src_rows == []:
            # nothing to fill from: leave the row at zero
            reordered[row, :] = blank_row
            continue
        # filled row: average the listed source rows
        stacked = [source[k, :] for k in src_rows]
        reordered[row, :] = np.mean(stacked, axis=0)

    utils.save_data(reordered, new_filename)
    return source.shape
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
    """Write processed rows back to their original row positions.

    The processed array is loaded from ``filename``. For the first batch
    (``batch_cnt == 0``) a zero array with ``raw_data_shape[0]`` rows is
    created; for later batches the partially restored array is reloaded from
    ``new_filename``. The result is saved to ``new_filename``.

    Args:
        batch_cnt: Index of the batch being restored.
        raw_data_shape: Shape of the original input array; only element 0
            (the row count) is used.
        idx_order: One entry per processed row, listing original row indices.
        fill_flags: One flag per processed row. Only rows whose flag equals
            ``False`` are copied back, to original row ``idx_order[i][0]``.
        filename: Location of the processed data.
        new_filename: Location of the restored data (read for later batches,
            always written).
    """
    denoised = utils.read_train_data(filename)

    if batch_cnt == 0:
        restored = np.zeros((raw_data_shape[0], denoised.shape[1]))
    else:
        restored = utils.read_train_data(new_filename)

    for row, (src_rows, filled) in enumerate(zip(idx_order, fill_flags)):
        if filled == False:
            # rows produced by filling have no original counterpart
            restored[src_rows[0], :] = denoised[row, :]

    utils.save_data(restored, new_filename)
def get_matched(tpl_order, tpl_dict):
    """Return the names in ``tpl_order`` whose ``"matched"`` entry equals ``True``.

    The order of ``tpl_order`` is preserved.
    """
    matched_names = []
    for name in tpl_order:
        if tpl_dict[name]["matched"] == True:
            matched_names.append(name)
    return matched_names
def get_empty_templates(tpl_order, tpl_dict):
    """Return the names in ``tpl_order`` whose ``"matched"`` entry equals ``False``.

    The order of ``tpl_order`` is preserved.
    """
    unmatched_names = []
    for name in tpl_order:
        if tpl_dict[name]["matched"] == False:
            unmatched_names.append(name)
    return unmatched_names
def get_unassigned_inputs(in_order, in_dict):
    """Return the names in ``in_order`` whose ``"assigned"`` entry equals ``False``.

    The order of ``in_order`` is preserved.
    """
    pending_names = []
    for name in in_order:
        if in_dict[name]["assigned"] == False:
            pending_names.append(name)
    return pending_names
def _index_montage_channels(montage, flag_key):
    """Upper-case every channel name of *montage* and build its lookup table.

    Each channel is renamed on the montage itself, then recorded under its
    upper-case name with its position in the channel list, its 3D coordinate
    as reported by ``montage.get_positions()``, and ``flag_key`` set to False.
    """
    table = {}
    for position, original_name in enumerate(montage.ch_names):
        upper_name = original_name.upper()
        montage.rename_channels({original_name: upper_name})
        table[upper_name] = {
            "index": position,
            "coord_3d": montage.get_positions()['ch_pos'][upper_name],
            flag_key: False,
        }
    return table


def read_montage_data(loc_file):
    """Load the template and input montages and index their channels.

    The template is always read from ``./template_chanlocs.loc``; the input
    montage is read from ``loc_file``. All channel names in both montages are
    converted to upper case.

    Args:
        loc_file: Path of the input channel-location file.

    Returns:
        ``(tpl_montage, in_montage, tpl_dict, in_dict)``. Each dict maps an
        upper-case channel name to ``{"index", "coord_3d", flag}``, where the
        flag is ``"matched"`` for the template and ``"assigned"`` for the
        input, both initialised to ``False``.
    """
    tpl_montage = read_custom_montage("./template_chanlocs.loc")
    in_montage = read_custom_montage(loc_file)

    tpl_dict = _index_montage_channels(tpl_montage, "matched")
    in_dict = _index_montage_channels(in_montage, "assigned")

    return tpl_montage, in_montage, tpl_dict, in_dict
def save_figures(channel_info, tpl_montage, filename1, filename2):
    """Render the input montage on the template head outline and save it twice.

    ``filename1`` receives the plain input montage (all channels in black).
    ``filename2`` receives the same figure with every input channel whose
    ``"assigned"`` entry equals ``False`` redrawn in red. The on-screen pixel
    position of every template and input channel is then converted to a pair
    of percentage strings and stored under ``"css_position"``.

    Args:
        channel_info: Dict holding ``"templateOrder"``, ``"inputOrder"``,
            ``"templateDict"`` and ``"inputDict"``; every channel entry must
            already carry a ``"coord_2d"`` value.
        tpl_montage: Template montage; only its ``plot()`` output is used, to
            borrow the head outline.
        filename1: Output path for the plain input montage image.
        filename2: Output path for the image with unassigned channels marked.

    Returns:
        ``channel_info`` with updated ``"templateDict"`` and ``"inputDict"``.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    def _xy_table(order, info):
        # one (x, y) row per channel, in list order
        xs = [info[name]["coord_2d"][0] for name in order]
        ys = [info[name]["coord_2d"][1] for name in order]
        return np.array([xs, ys]).T

    def _css_position(pixel_xy):
        # The figure is 6.4 in at 100 dpi (640 px), so /6.4 yields a percentage.
        # NOTE(review): the 11 px / 7 px offsets look like empirical margins - confirm.
        left = (pixel_xy[0] - 11) / 6.4
        bottom = (pixel_xy[1] - 7) / 6.4
        return [str(round(left, 2)) + "%", str(round(bottom, 2)) + "%"]

    tpl_coords = _xy_table(tpl_order, tpl_dict)
    in_coords = _xy_table(in_order, in_dict)

    # borrow the head outline drawn by the template montage plot
    outline_fig = tpl_montage.plot()
    head_lines = [line.get_data() for line in outline_fig.axes[0].lines]
    plt.close()

    # ---- input montage ----
    fig = plt.figure(figsize=(6.4, 6.4), dpi=100)
    ax = fig.add_subplot(111)
    fig.tight_layout()
    ax.set_aspect('equal')
    ax.axis('off')

    for outline_x, outline_y in head_lines:
        ax.plot(outline_x, outline_y, color='black', linewidth=1.0)

    ax.scatter(in_coords[:, 0], in_coords[:, 1], s=35, color='black')
    for pos, name in enumerate(in_order):
        ax.text(in_coords[pos, 0] + 0.003, in_coords[pos, 1], name,
                color='black', fontsize=10.0, va='center')
    fig.savefig(filename1)

    # ---- overlay: unassigned input channels in red ----
    red_rows = [
        in_dict[name]["index"]
        for name in in_order
        if in_dict[name]["assigned"] == False
    ]
    ax.scatter(in_coords[red_rows, 0], in_coords[red_rows, 1], s=35, color='red')
    for pos in red_rows:
        ax.text(in_coords[pos, 0] + 0.003, in_coords[pos, 1], in_order[pos],
                color='red', fontsize=10.0, va='center')
    fig.savefig(filename2)

    # ---- display positions (px) -> CSS percentages ----
    tpl_pixels = ax.transData.transform(tpl_coords)
    in_pixels = ax.transData.transform(in_coords)
    plt.close()

    for pos, name in enumerate(tpl_order):
        tpl_dict[name]["css_position"] = _css_position(tpl_pixels[pos])
    for pos, name in enumerate(in_order):
        in_dict[name]["css_position"] = _css_position(in_pixels[pos])

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info
def _plotted_positions(montage):
    """Return a copy of the 2D sensor positions drawn by ``montage.plot()``.

    The figure is closed before returning so repeated calls do not
    accumulate open matplotlib figures.
    """
    fig = montage.plot()
    try:
        return np.array(fig.axes[0].collections[0].get_offsets().data)
    finally:
        plt.close(fig)


def _thin_plate_warp(src_ctrl, dst_ctrl, points):
    """Warp *points* with a thin-plate RBF fitted to map src_ctrl onto dst_ctrl.

    One ``Rbf`` interpolant is fitted per output coordinate; the warped
    coordinates are returned as an ``(n_points, n_dims)`` array.
    """
    warped_columns = []
    for axis in range(dst_ctrl.shape[1]):
        interpolant = Rbf(*src_ctrl.T, dst_ctrl[:, axis], function='thin_plate')
        warped_columns.append(interpolant(*points.T))
    return np.vstack(warped_columns).T


def align_coords(channel_info, tpl_montage, in_montage):
    """Align the input channel positions with the template positions.

    The channels returned by ``get_matched`` serve as control points for a
    thin-plate warp from input space to template space, applied to every
    input channel. This is done twice:

    * in 2D, on the coordinates drawn by each montage's ``plot()``; the
      results are stored as ``"coord_2d"`` (template: plotted position,
      input: warped position);
    * in 3D, on the ``"coord_3d"`` values; the input channels' ``"coord_3d"``
      entries are overwritten with the warped positions (as lists).

    Args:
        channel_info: Dict holding ``"templateOrder"``, ``"inputOrder"``,
            ``"templateDict"`` and ``"inputDict"``.
        tpl_montage: Template montage (only ``plot()`` is used).
        in_montage: Input montage (only ``plot()`` is used).

    Returns:
        ``channel_info`` with updated ``"templateDict"`` and ``"inputDict"``.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    matched = get_matched(tpl_order, tpl_dict)

    # ---- 2D alignment (for visualization purposes) ----
    # the plotted offsets are copied and both figures closed right away
    all_tpl = _plotted_positions(tpl_montage)
    all_in = _plotted_positions(in_montage)
    matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
    matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
    transformed_in = _thin_plate_warp(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(tpl_order):
        tpl_dict[channel]["coord_2d"] = all_tpl[i]
    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_2d"] = transformed_in[i].tolist()

    # ---- 3D alignment ----
    # np.array accepts both ndarray and list coordinates
    all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
    all_in = np.array([in_dict[channel]["coord_3d"] for channel in in_order])
    matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
    matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
    transformed_in = _thin_plate_warp(matched_in, matched_tpl, all_in)

    # update in_channels' 3D positions
    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_3d"] = transformed_in[i].tolist()

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info
def find_neighbors(channel_info, missing_channels, new_idx):
    """Assign each missing template channel its nearest input channels.

    A Euclidean nearest-neighbour search is fitted on the ``"coord_3d"`` of
    all input channels. For every name in ``missing_channels`` the indices of
    the closest input channels (up to 4, fewer when there are fewer inputs)
    are written into ``new_idx`` at that template channel's ``"index"``.

    Args:
        channel_info: Dict holding ``"inputOrder"``, ``"templateDict"`` and
            ``"inputDict"``.
        missing_channels: Template channel names that still need a source.
        new_idx: List indexed by template channel index; modified in place.

    Returns:
        The same ``new_idx`` list.
    """
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    input_points = [np.array(in_dict[name]["coord_3d"]) for name in in_order]
    query_points = [np.array(tpl_dict[name]["coord_3d"]) for name in missing_channels]

    # never ask for more neighbours than there are input channels
    n_neighbors = min(4, len(in_order))
    searcher = NearestNeighbors(n_neighbors=n_neighbors, metric='euclidean')
    searcher.fit(input_points)

    for name, query in zip(missing_channels, query_points):
        _, nearest = searcher.kneighbors(query.reshape(1, -1))
        new_idx[tpl_dict[name]["index"]] = nearest[0].tolist()
    return new_idx
def match_names(stage1_info, channel_info):
    """Match input channels to template channels by (upper-case) name.

    The montages are loaded through ``read_montage_data``. A template channel
    named T3/T4/T5/T6 is renamed to T7/T8/P7/P8 when the input uses the
    latter name. Every template channel whose name also exists in the input
    is marked ``"matched"`` and the input channel marked ``"assigned"``.

    Args:
        stage1_info: Dict providing ``["fileNames"]["input_loc"]``. It is
            updated with ``"unassignedInputs"``, ``"missingTemplates"`` and a
            one-element ``"mappingData"`` list (``"newOrder"``/``"fillFlags"``).
        channel_info: Dict updated with ``"templateOrder"``, ``"inputOrder"``,
            ``"templateDict"`` and ``"inputDict"``.

    Returns:
        ``(stage1_info, channel_info, tpl_montage, in_montage)``.
    """
    # read the location file
    loc_file = stage1_info["fileNames"]["input_loc"]
    tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
    tpl_order = tpl_montage.ch_names
    in_order = in_montage.ch_names

    # One independent list per template slot. ``[[]] * 30`` would make all 30
    # slots point at the same list object.
    new_idx = [[] for _ in range(30)]  # input indices in template order
    fill_flags = [True] * 30  # True while a template slot has no direct match

    alias_dict = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8',
    }

    for i, channel in enumerate(tpl_order):
        alias = alias_dict.get(channel)
        if alias is not None and alias in in_dict:
            # adopt the name the input uses for the same electrode
            tpl_montage.rename_channels({channel: alias})
            tpl_dict[alias] = tpl_dict.pop(channel)
            channel = alias
        if channel in in_dict:
            new_idx[i] = [in_dict[channel]["index"]]
            fill_flags[i] = False
            tpl_dict[channel]["matched"] = True
            in_dict[channel]["assigned"] = True

    # pick up any names changed by the alias handling
    tpl_order = tpl_montage.ch_names

    stage1_info.update({
        "unassignedInputs": get_unassigned_inputs(in_order, in_dict),
        "missingTemplates": get_empty_templates(tpl_order, tpl_dict),
        "mappingData": [
            {
                "newOrder": new_idx,
                "fillFlags": fill_flags,
            }
        ],
    })
    channel_info.update({
        "templateOrder": tpl_order,
        "inputOrder": in_order,
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return stage1_info, channel_info, tpl_montage, in_montage
def optimal_mapping(channel_info):
    """Map still-unassigned input channels onto the 30 template channels.

    All template ``"matched"`` flags are reset, then the Hungarian algorithm
    (``linear_sum_assignment``) pairs template channels with unassigned input
    channels so that the summed 3D distance is minimal. Template channels left
    without a partner are filled through ``find_neighbors``.

    Args:
        channel_info: Dict holding ``"templateOrder"``, ``"inputOrder"``,
            ``"templateDict"`` and ``"inputDict"``; the two dicts are updated
            in place (``"matched"`` / ``"assigned"`` flags).

    Returns:
        ``(mapping_data, channel_info)`` where ``mapping_data`` has the keys
        ``"newOrder"`` (input indices per template slot) and ``"fillFlags"``
        (``False`` for directly assigned slots, ``True`` for filled ones).
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    unassigned = get_unassigned_inputs(in_order, in_dict)
    n_unassigned = len(unassigned)

    # reset all tpl.matched to False
    for channel in tpl_dict:
        tpl_dict[channel]["matched"] = False

    all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
    unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])

    # initialize the cost matrix for the Hungarian algorithm
    if n_unassigned < 30:
        # pad with high-cost dummy columns so that num_col >= num_row
        cost_matrix = np.full((30, 30), 1e6)
    else:
        cost_matrix = np.zeros((30, n_unassigned))

    # Euclidean distance between each template channel and each unassigned
    # input channel, with coordinates scaled by 1000
    # (presumably metres -> millimetres - confirm)
    for i in range(30):
        for j in range(n_unassigned):
            cost_matrix[i][j] = np.linalg.norm((all_tpl[i] - unassigned_in[j]) * 1000)

    # optimally assign one in_channel to each tpl_channel by minimizing the
    # total distance between their positions
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    # One independent list per template slot. ``[[]] * 30`` would make all 30
    # slots point at the same list object.
    new_idx = [[] for _ in range(30)]
    fill_flags = [True] * 30
    for i, j in zip(row_idx, col_idx):
        if j >= n_unassigned:
            # paired with a dummy column: this template slot stays empty
            continue
        tpl_channel = tpl_order[i]
        in_channel = unassigned[j]
        new_idx[i] = [in_dict[in_channel]["index"]]
        fill_flags[i] = False
        tpl_dict[tpl_channel]["matched"] = True
        in_dict[in_channel]["assigned"] = True

    # fill the remaining empty tpl_channels
    missing_channels = get_empty_templates(tpl_order, tpl_dict)
    if missing_channels:
        new_idx = find_neighbors(channel_info, missing_channels, new_idx)

    mapping_data = {
        "newOrder": new_idx,
        "fillFlags": fill_flags,
    }
    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return mapping_data, channel_info
def mapping_result(stage1_info, stage2_info, channel_info, filename):
    """Map the remaining input channels in batches and save all mappings.

    The batch count is one (the mapping already present in
    ``stage1_info["mappingData"]``) plus one batch per started group of 30
    entries in ``stage1_info["unassignedInputs"]``. Each extra batch is
    produced by ``optimal_mapping`` and appended to ``"mappingData"``. The
    batch count and the full mapping list are written to ``filename`` as JSON.

    Args:
        stage1_info: Dict with ``"unassignedInputs"`` and ``"mappingData"``.
        stage2_info: Dict that receives ``"totalBatchNum"``.
        channel_info: Channel bookkeeping passed through ``optimal_mapping``.
        filename: Path of the JSON file to write.

    Returns:
        ``(stage1_info, stage2_info, channel_info)``.
    """
    pending = len(stage1_info["unassignedInputs"])
    batch_num = math.ceil(pending / 30) + 1

    # every batch after the first maps up to 30 more input channels
    for _ in range(batch_num - 1):
        batch_mapping, channel_info = optimal_mapping(channel_info)
        stage1_info["mappingData"].append(batch_mapping)

    summary = {
        "batchNum": batch_num,
        "mappingData": stage1_info["mappingData"],
    }
    with open(filename, 'w') as jsonfile:
        jsonfile.write(json.dumps(summary))

    stage2_info["totalBatchNum"] = batch_num
    return stage1_info, stage2_info, channel_info