AIEEG / app_utils.py
audrey06100's picture
update
18b3426
import utils
import os
import math
import json
import jsbeautifier
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.channels import read_custom_montage
from scipy.interpolate import Rbf
from scipy.optimize import linear_sum_assignment
from sklearn.neighbors import NearestNeighbors
def get_matched(tpl_names, tpl_dict):
    """Return the template channel names that are flagged as matched.

    Args:
        tpl_names: Iterable of template channel names; its order is kept.
        tpl_dict: Mapping from channel name to an info dict that holds a
            "matched" flag.

    Returns:
        A list with every name from ``tpl_names`` whose "matched" flag
        compares equal to True.
    """
    matched = []
    for ch_name in tpl_names:
        # Equality (rather than truthiness) is the test used for the flag.
        if tpl_dict[ch_name]["matched"] == True:  # noqa: E712
            matched.append(ch_name)
    return matched
def get_empty_template(tpl_names, tpl_dict):
    """Return the template channel names that have no match yet.

    Args:
        tpl_names: Iterable of template channel names; its order is kept.
        tpl_dict: Mapping from channel name to an info dict that holds a
            "matched" flag.

    Returns:
        A list with every name from ``tpl_names`` whose "matched" flag
        compares equal to False.
    """
    def is_empty(ch_name):
        return tpl_dict[ch_name]["matched"] == False  # noqa: E712

    return list(filter(is_empty, tpl_names))
def get_unassigned_input(in_names, in_dict):
    """Return the input channel names that are not assigned to a template slot.

    Args:
        in_names: Iterable of input channel names; its order is kept.
        in_dict: Mapping from channel name to an info dict that holds an
            "assigned" flag.

    Returns:
        A list with every name from ``in_names`` whose "assigned" flag
        compares equal to False.
    """
    unassigned = []
    for ch_name in in_names:
        if in_dict[ch_name]["assigned"] != False:  # noqa: E712
            continue
        unassigned.append(ch_name)
    return unassigned
def read_montage(loc_file):
    """Load the template and input montages and index their channels.

    The template montage is always read from "./template_chanlocs.loc"
    (resolved against the current working directory); the input montage is
    read from ``loc_file``. Both montages are renamed in place so that every
    channel name is upper case.

    Args:
        loc_file: Path of the input channel-location file, handed to
            ``read_custom_montage``.

    Returns:
        A tuple ``(tpl_montage, in_montage, tpl_dict, in_dict)``. Each dict
        maps an upper-cased channel name to an info dict with
        "index" (position of the channel in the montage), "coord_3d" (its
        entry in ``get_positions()['ch_pos']``) and a boolean flag that starts
        as False: "matched" for the template, "assigned" for the input.
    """
    def _uppercase_and_index(montage, flag_key):
        # Rename every channel with a single mapping and read the position
        # table once, instead of rebuilding it for each channel in a loop.
        original_names = list(montage.ch_names)
        upper_names = [str.upper(name) for name in original_names]
        montage.rename_channels(dict(zip(original_names, upper_names)))
        ch_pos = montage.get_positions()['ch_pos']
        return {
            up_name: {
                "index": i,
                "coord_3d": ch_pos[up_name],
                flag_key: False,
            }
            for i, up_name in enumerate(upper_names)
        }

    tpl_montage = read_custom_montage("./template_chanlocs.loc")
    in_montage = read_custom_montage(loc_file)

    tpl_dict = _uppercase_and_index(tpl_montage, "matched")
    in_dict = _uppercase_and_index(in_montage, "assigned")

    return tpl_montage, in_montage, tpl_dict, in_dict
def match_name(stage1_info):
    """Match input channels to template channels by (upper-cased) name.

    The location file path is taken from
    ``stage1_info["fileNames"]["inputData"]``. A template channel named
    T3/T4/T5/T6 is renamed to T7/T8/P7/P8 when the input montage contains
    that alternative name, so both naming schemes can match.

    Args:
        stage1_info: Dict holding at least ``["fileNames"]["inputData"]``.
            It is updated in place with "unassignedInput", "emptyTemplate"
            and "mappingResult".

    Returns:
        A tuple ``(stage1_info, channel_info, tpl_montage, in_montage)``.
        ``channel_info`` holds "templateNames", "inputNames", "templateDict"
        and "inputDict". ``stage1_info["mappingResult"]`` is a one-element
        list whose dict has "index" (per template slot, ``[input_index]`` or
        ``[None]`` when nothing matched) and "isOriginalData" (per slot,
        True when a same-named input channel was found).
    """
    # read the location file
    loc_file = stage1_info["fileNames"]["inputData"]
    tpl_montage, in_montage, tpl_dict, in_dict = read_montage(loc_file)
    tpl_names = tpl_montage.ch_names
    in_names = in_montage.ch_names

    # The result always has 30 slots, filled in template-channel order
    # (the code assumes the template has at most 30 channels).
    n_slots = 30
    # Build one independent placeholder per slot; `[[None]] * n` would make
    # every slot share the same inner list object.
    old_idx = [[None] for _ in range(n_slots)]
    is_orig_data = [False] * n_slots

    alias_dict = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8'
    }

    for i, name in enumerate(tpl_names):
        alias = alias_dict.get(name)
        if alias is not None and alias in in_dict:
            # the input uses the alternative name: adopt it in the template
            tpl_montage.rename_channels({name: alias})
            tpl_dict[alias] = tpl_dict.pop(name)
            name = alias

        if name in in_dict:
            old_idx[i] = [in_dict[name]["index"]]
            is_orig_data[i] = True
            tpl_dict[name]["matched"] = True
            in_dict[name]["assigned"] = True

    # update the names (aliases may have renamed template channels)
    tpl_names = tpl_montage.ch_names

    stage1_info.update({
        "unassignedInput" : get_unassigned_input(in_names, in_dict),
        "emptyTemplate" : get_empty_template(tpl_names, tpl_dict),
        "mappingResult" : [
            {
                "index" : old_idx,
                "isOriginalData" : is_orig_data
            }
        ]
    })
    channel_info = {
        "templateNames" : tpl_names,
        "inputNames" : in_names,
        "templateDict" : tpl_dict,
        "inputDict" : in_dict
    }
    return stage1_info, channel_info, tpl_montage, in_montage
def align_coords(channel_info, tpl_montage, in_montage):
    """Warp the input channel coordinates onto the template's coordinates.

    The channels currently flagged as matched act as control points. For both
    the plotted 2D positions and the stored 3D positions, one thin-plate
    ``Rbf`` per output axis is fitted from the input control points to the
    template control points and then evaluated on every input channel.

    Args:
        channel_info: Dict with "templateNames", "inputNames", "templateDict"
            and "inputDict".
        tpl_montage: Template montage; only its ``plot()`` output is used.
        in_montage: Input montage; only its ``plot()`` output is used.

    Returns:
        ``channel_info`` with its dicts updated in place: every template
        entry gains "coord_2d" (a row of the plotted offsets), every input
        entry gains "coord_2d" and has "coord_3d" replaced by the warped
        position (both as plain lists).

    Note:
        All open matplotlib figures are closed via ``plt.close('all')``.
    """
    tpl_names = channel_info["templateNames"]
    in_names = channel_info["inputNames"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    anchor_names = get_matched(tpl_names, tpl_dict)

    def warp(src_anchor, dst_anchor, points, ndim):
        # Fit one thin-plate spline per output axis on the control pairs and
        # evaluate it on all points; returns an (n_points, ndim) array.
        src_axes = [src_anchor[:, k] for k in range(ndim)]
        point_axes = [points[:, k] for k in range(ndim)]
        warped_axes = []
        for k in range(ndim):
            spline = Rbf(*src_axes, dst_anchor[:, k], function='thin_plate')
            warped_axes.append(spline(*point_axes))
        return np.vstack(warped_axes).T

    # ---- 2D alignment (for visualization purposes) ----
    tpl_fig = tpl_montage.plot()
    in_fig = in_montage.plot()
    # the displayed 2D coordinates are the offsets of each plot's scatter
    tpl_2d = tpl_fig.axes[0].collections[0].get_offsets().data
    in_2d = in_fig.axes[0].collections[0].get_offsets().data
    tpl_anchor_2d = np.array(
        [tpl_2d[tpl_dict[name]["index"]] for name in anchor_names]
    )
    in_anchor_2d = np.array(
        [in_2d[in_dict[name]["index"]] for name in anchor_names]
    )
    plt.close('all')

    warped_2d = warp(in_anchor_2d, tpl_anchor_2d, in_2d, 2)
    for i, name in enumerate(tpl_names):
        tpl_dict[name]["coord_2d"] = tpl_2d[i]
    for i, name in enumerate(in_names):
        in_dict[name]["coord_2d"] = warped_2d[i].tolist()

    # ---- 3D alignment ----
    tpl_3d = np.array([tpl_dict[name]["coord_3d"].tolist() for name in tpl_names])
    in_3d = np.array([in_dict[name]["coord_3d"].tolist() for name in in_names])
    tpl_anchor_3d = np.array(
        [tpl_3d[tpl_dict[name]["index"]] for name in anchor_names]
    )
    in_anchor_3d = np.array(
        [in_3d[in_dict[name]["index"]] for name in anchor_names]
    )

    warped_3d = warp(in_anchor_3d, tpl_anchor_3d, in_3d, 3)
    for i, name in enumerate(in_names):
        in_dict[name]["coord_3d"] = warped_3d[i].tolist()

    channel_info.update({
        "templateDict" : tpl_dict,
        "inputDict" : in_dict
    })
    return channel_info
def save_figure(channel_info, tpl_montage, filename1, filename2):
    """Draw the input channels on the template head and record CSS positions.

    Two images are written: ``filename1`` shows every input channel in black
    on the template's head outline; ``filename2`` is the same figure with the
    not-yet-assigned input channels redrawn in red.

    Args:
        channel_info: Dict with "templateNames", "inputNames", "templateDict"
            and "inputDict"; every channel entry must already hold "coord_2d".
        tpl_montage: Template montage; its ``plot()`` lines supply the head
            outline.
        filename1: Output path for the plain input-montage image.
        filename2: Output path for the image with unassigned channels marked.

    Returns:
        ``channel_info`` whose template and input entries each gain
        "css_position": ``[left, bottom]``, the channel's display position in
        pixels divided by 6.4 and rounded to two decimals (the figure is
        6.4 in at 100 dpi, so this is a percentage of the 640-px side).

    Note:
        All open matplotlib figures are closed via ``plt.close('all')``.
    """
    tpl_names = channel_info["templateNames"]
    in_names = channel_info["inputNames"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    def xy_table(names, info):
        # stack the stored 2D coordinates into an (n, 2) array
        xs = [info[name]["coord_2d"][0] for name in names]
        ys = [info[name]["coord_2d"][1] for name in names]
        return np.vstack((xs, ys)).T

    def to_css(px_row):
        # display pixels -> [left, bottom], scaled by the 6.4 figure size
        return [round(px_row[0] / 6.4, 2), round(px_row[1] / 6.4, 2)]

    tpl_xy = xy_table(tpl_names, tpl_dict)
    in_xy = xy_table(in_names, in_dict)

    # extract template's head figure
    head_axes = tpl_montage.plot().axes[0]
    head_outline = [tuple(line.get_data()) for line in head_axes.lines]

    # -------------------------plot input montage------------------------------
    fig = plt.figure(figsize=(6.4, 6.4), dpi=100)
    ax = fig.add_subplot(111)
    fig.tight_layout()
    ax.set_aspect('equal')
    ax.axis('off')

    # plot template's head
    for head_x, head_y in head_outline:
        ax.plot(head_x, head_y, color='black', linewidth=1.0)
    # plot in_channels on it
    ax.scatter(in_xy[:, 0], in_xy[:, 1], s=35, color='black')
    for (px, py), label in zip(in_xy, in_names):
        ax.text(px + 0.004, py, label, color='black', fontsize=10.0, va='center')
    # save input_montage
    fig.savefig(filename1)

    # ---------------------------add indications-------------------------------
    # plot unmatched input channels in red
    red_idx = [
        in_dict[name]["index"]
        for name in in_names
        if in_dict[name]["assigned"] == False  # noqa: E712
    ]
    if len(red_idx) > 0:
        ax.scatter(in_xy[red_idx, 0], in_xy[red_idx, 1], s=35, color='red')
        for k in red_idx:
            ax.text(in_xy[k, 0] + 0.004, in_xy[k, 1], in_names[k],
                    color='red', fontsize=10.0, va='center')
    # save mapped_montage
    fig.savefig(filename2)

    # -------------------------------------------------------------------------
    # store the tpl and in_channels' display positions (in px)
    tpl_px = ax.transData.transform(tpl_xy)
    in_px = ax.transData.transform(in_xy)
    plt.close('all')

    for i, name in enumerate(tpl_names):
        tpl_dict[name]["css_position"] = to_css(tpl_px[i])
    for i, name in enumerate(in_names):
        in_dict[name]["css_position"] = to_css(in_px[i])

    channel_info.update({
        "templateDict" : tpl_dict,
        "inputDict" : in_dict
    })
    return channel_info
def find_neighbors(channel_info, empty_tpl_names, old_idx):
    """Fill empty template slots with the indices of nearby input channels.

    A nearest-neighbour model (Euclidean metric) is fitted on the 3D
    coordinates of all input channels. For every name in ``empty_tpl_names``
    the k closest input channels are looked up, where k is 4 or the number of
    input channels if that is smaller.

    Args:
        channel_info: Dict with "inputNames", "templateDict" and "inputDict".
        empty_tpl_names: Template channel names whose slots need filling.
        old_idx: Per-slot list of input-index lists; modified in place.

    Returns:
        ``old_idx``, where the slot at each empty template channel's "index"
        now holds the list of its k nearest input-channel row indices
        (rows follow the order of "inputNames").
    """
    in_names = channel_info["inputNames"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    candidates = [np.array(in_dict[name]["coord_3d"]) for name in in_names]
    queries = [np.array(tpl_dict[name]["coord_3d"]) for name in empty_tpl_names]

    # use KNN to choose k nearest channels
    k = min(4, len(in_names))
    knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
    knn.fit(candidates)

    for tpl_name, point in zip(empty_tpl_names, queries):
        _, nearest = knn.kneighbors(point.reshape(1, -1))
        slot = tpl_dict[tpl_name]["index"]
        old_idx[slot] = nearest[0].tolist()

    return old_idx
def optimal_mapping(channel_info):
    """Assign unassigned input channels to the 30 template slots by proximity.

    Every template "matched" flag is first reset to False. A cost matrix of
    Euclidean distances (coordinate differences scaled by 1000) between the
    template channels and the still-unassigned input channels is solved with
    ``linear_sum_assignment``. Template slots that receive no input channel
    are then filled by ``find_neighbors``.

    Args:
        channel_info: Dict with "templateNames", "inputNames", "templateDict"
            and "inputDict". The dicts are modified in place: assigned input
            channels get "assigned" = True and template channels that
            received one get "matched" = True.

    Returns:
        A tuple ``(result, channel_info)``. ``result["index"]`` holds, per
        slot, ``[input_index]`` for a direct assignment or the neighbour list
        produced by ``find_neighbors``; ``result["isOriginalData"]`` is True
        only for directly assigned slots.
    """
    tpl_names = channel_info["templateNames"]
    in_names = channel_info["inputNames"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    # number of template slots per batch (the code assumes a 30-channel template)
    n_slots = 30
    unass_in_names = get_unassigned_input(in_names, in_dict)
    n_unass = len(unass_in_names)

    # reset all tpl.matched to False
    for info in tpl_dict.values():
        info["matched"] = False

    all_tpl = np.array([tpl_dict[name]["coord_3d"] for name in tpl_names])
    unass_in = np.array([in_dict[name]["coord_3d"] for name in unass_in_names])

    # initialize the cost matrix for the Hungarian algorithm
    if n_unass < n_slots:
        # add dummy channels (very high cost) to ensure num_col >= num_row
        cost_matrix = np.full((n_slots, n_slots), 1e6)
    else:
        cost_matrix = np.zeros((n_slots, n_unass))

    # fill the cost matrix with Euclidean distances between tpl and
    # unassigned in_channels
    for i in range(n_slots):
        for j in range(n_unass):
            cost_matrix[i, j] = np.linalg.norm((all_tpl[i] - unass_in[j]) * 1000)

    # apply the Hungarian algorithm to optimally assign one in_channel to each
    # tpl_channel by minimizing the total distances between their positions.
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    # store the mapping result; one independent placeholder per slot
    # (`[[None]] * n` would make every slot share the same inner list object)
    old_idx = [[None] for _ in range(n_slots)]
    is_orig_data = [False] * n_slots

    for i, j in zip(row_idx, col_idx):
        if j >= n_unass:
            # column belongs to a dummy channel: leave the slot empty
            continue
        tpl_name = tpl_names[i]
        in_name = unass_in_names[j]

        old_idx[i] = [in_dict[in_name]["index"]]
        is_orig_data[i] = True
        tpl_dict[tpl_name]["matched"] = True
        in_dict[in_name]["assigned"] = True

    # fill the remaining empty tpl_channels
    empty_tpl_names = get_empty_template(tpl_names, tpl_dict)
    if empty_tpl_names:
        old_idx = find_neighbors(channel_info, empty_tpl_names, old_idx)

    result = {
        "index" : old_idx,
        "isOriginalData" : is_orig_data
    }
    channel_info["inputDict"] = in_dict
    return result, channel_info
def mapping_result(stage1_info, channel_info, filename):
    """Map the remaining input channels in batches and write the result as JSON.

    The batch count is ``ceil(len(unassignedInput) / 30) + 1``: the name-based
    mapping already stored in ``stage1_info["mappingResult"]`` plus one
    ``optimal_mapping`` pass per group of up to 30 unassigned channels.

    Args:
        stage1_info: Dict holding "unassignedInput" and "mappingResult"; the
            "mappingResult" list is extended in place and "batch" is set.
        channel_info: Channel bookkeeping dict passed through
            ``optimal_mapping``.
        filename: Path of the JSON file to write ("channelNum", "batch" and
            "mappingResult", beautified with a 4-space indent).

    Returns:
        A tuple ``(stage1_info, channel_info)`` with the updated state.
    """
    n_unassigned = len(stage1_info["unassignedInput"])
    batch = math.ceil(n_unassigned / 30) + 1

    # map the remaining in_channels
    results = stage1_info["mappingResult"]
    for _ in range(batch - 1):
        # optimally select 30 in_channels to map to the tpl_channels based on proximity
        result, channel_info = optimal_mapping(channel_info)
        results.append(result)

    # NOTE: a disabled debugging step used to attach, per batch, the input
    # channel names behind every template slot (with "zero" where the index
    # list was [None]); it is intentionally not executed.

    data = {
        "channelNum" : len(channel_info["inputNames"]),
        "batch" : batch,
        "mappingResult" : results
    }

    options = jsbeautifier.default_options()
    options.indent_size = 4
    json_data = jsbeautifier.beautify(json.dumps(data), options)
    with open(filename, 'w') as jsonfile:
        jsonfile.write(json_data)

    stage1_info["batch"] = batch
    stage1_info["mappingResult"] = results
    return stage1_info, channel_info