# AIEEG / app_utils.py
# (header copied from a Hugging Face file view: author audrey06100,
#  commit efa3bef "update", 13.2 kB)
import utils
import os
import math
import json
import jsbeautifier
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.channels import read_custom_montage
from scipy.interpolate import Rbf
from scipy.optimize import linear_sum_assignment
from sklearn.neighbors import NearestNeighbors
def reorder_data(idx_order, orig_flags, inputname, filename):
    """Rearrange an input recording into the 30-row template order and save it.

    Each template row ``row`` is driven by ``idx_order[row]`` and
    ``orig_flags[row]``:

    * flag True   -> copy input row ``idx_set[0]`` unchanged;
    * ``[None]``  -> leave the row all zeros (no source channel);
    * otherwise   -> element-wise mean of the input rows listed in ``idx_set``.

    Args:
        idx_order: per template row, a list of input-row indices (or ``[None]``).
        orig_flags: per template row, whether the row is a direct copy.
        inputname: path passed to ``utils.read_train_data`` for the input.
        filename: path passed to ``utils.save_data`` for the reordered array.

    Returns:
        The shape of the input array as read from ``inputname``.
    """
    raw_data = utils.read_train_data(inputname)
    # rows are indexed as input channels; presumably (channels, samples) -- confirm with utils
    n_samples = raw_data.shape[1]
    new_data = np.zeros((30, n_samples))

    for row, (idx_set, flag) in enumerate(zip(idx_order, orig_flags)):
        if flag == True:
            new_data[row, :] = raw_data[idx_set[0], :]
        elif idx_set != [None]:
            new_data[row, :] = np.mean([raw_data[j, :] for j in idx_set], axis=0)
        # a [None] entry needs no work: new_data was allocated as zeros

    utils.save_data(new_data, filename)
    return raw_data.shape
def restore_order(batch_cnt, raw_data_shape, idx_order, orig_flags, filename, outputname):
    """Scatter denoised template-ordered rows back into the input-channel layout.

    Only rows whose ``orig_flags`` entry is True are written back, to input row
    ``idx_set[0]``. On the first batch (``batch_cnt == 0``) the output starts as
    zeros with ``raw_data_shape[0]`` rows; on later batches the partially filled
    array previously saved to ``outputname`` is loaded and extended.

    Args:
        batch_cnt: index of the current mapping batch.
        raw_data_shape: shape of the original input (only ``[0]`` is used).
        idx_order: per template row, the list of input-row indices.
        orig_flags: per template row, whether it is a direct (copyable) channel.
        filename: path of the denoised, template-ordered data.
        outputname: path of the restored output (read on later batches, always written).
    """
    denoised = utils.read_train_data(filename)
    if batch_cnt != 0:
        restored = utils.read_train_data(outputname)
    else:
        restored = np.zeros((raw_data_shape[0], denoised.shape[1]))

    for row, (idx_set, flag) in enumerate(zip(idx_order, orig_flags)):
        if flag == True:
            restored[idx_set[0], :] = denoised[row, :]

    utils.save_data(restored, outputname)
def get_matched(tpl_order, tpl_dict):
    """Return the template channel names whose "matched" flag is True, in ``tpl_order`` order."""
    return list(filter(lambda name: tpl_dict[name]["matched"] == True, tpl_order))
def get_empty_templates(tpl_order, tpl_dict):
    """Return the template channel names whose "matched" flag is False, in ``tpl_order`` order."""
    return list(filter(lambda name: tpl_dict[name]["matched"] == False, tpl_order))
def get_unassigned_inputs(in_order, in_dict):
    """Return the input channel names whose "assigned" flag is False, in ``in_order`` order."""
    return list(filter(lambda name: in_dict[name]["assigned"] == False, in_order))
def _uppercase_channel_dict(montage, flag_key):
    """Uppercase ``montage``'s channel names in place and describe every channel.

    Returns ``{NAME: {"index": i, "coord_3d": pos, flag_key: False}}`` where
    ``i`` is the channel's position in ``montage.ch_names`` and ``pos`` is its
    entry in ``montage.get_positions()["ch_pos"]``.
    """
    orig_names = list(montage.ch_names)
    # One rename for all channels. Raises (as the per-channel rename did) if two
    # names collide after uppercasing.
    montage.rename_channels({name: name.upper() for name in orig_names})
    # Positions do not change on rename, so fetch them once instead of per channel.
    ch_pos = montage.get_positions()['ch_pos']
    return {
        name.upper(): {
            "index": i,
            "coord_3d": ch_pos[name.upper()],
            flag_key: False,
        }
        for i, name in enumerate(orig_names)
    }


def read_montage_data(loc_file, tpl_file="./template_chanlocs.loc"):
    """Load the template and input montages and index their channels by uppercase name.

    Both montages are modified in place: all channel names become uppercase.

    Args:
        loc_file: path of the input channel-location file (read_custom_montage format).
        tpl_file: path of the template channel-location file; defaults to the
            previously hard-coded ``./template_chanlocs.loc`` (relative to the CWD).

    Returns:
        (tpl_montage, in_montage, tpl_dict, in_dict). ``tpl_dict`` entries hold
        "index", "coord_3d" and ``"matched": False``; ``in_dict`` entries hold
        "index", "coord_3d" and ``"assigned": False``.
    """
    tpl_montage = read_custom_montage(tpl_file)
    in_montage = read_custom_montage(loc_file)
    tpl_dict = _uppercase_channel_dict(tpl_montage, "matched")
    in_dict = _uppercase_channel_dict(in_montage, "assigned")
    return tpl_montage, in_montage, tpl_dict, in_dict
def save_figures(channel_info, tpl_montage, filename1, filename2):
    """Draw the input montage on the template's head outline and save two figures.

    ``filename1`` receives the input channels drawn in black on the template's
    head outline. ``filename2`` receives the same figure with every input
    channel whose "assigned" flag is False redrawn in red. Afterwards every
    template and input channel gets a "css_position" entry: its display
    position in the 640x640 px figure (6.4 in at 100 dpi), expressed as
    ``[left%, bottom%]`` strings.

    Args:
        channel_info: dict with "templateOrder", "inputOrder", "templateDict"
            and "inputDict". Every channel entry must already have a
            "coord_2d" (x, y) pair. It is not set here; presumably align_coords
            sets it (confirm).
        tpl_montage: template montage; the lines of its ``plot()`` axes are
            reused as the head outline.
        filename1: output path for the plain input-montage figure.
        filename2: output path for the figure with unassigned inputs in red.

    Returns:
        channel_info, whose "templateDict"/"inputDict" were mutated in place.

    Note: calls ``plt.close('all')``, which closes every open matplotlib figure.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    # (n, 2) arrays of 2D plot coordinates, in template / input channel order
    tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
    tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
    in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
    in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
    tpl_coords = np.vstack((tpl_x, tpl_y)).T
    in_coords = np.vstack((in_x, in_y)).T
    # extract template's head figure: copy the data of every line on the MNE
    # montage plot (presumably head circle, nose and ears) so it can be redrawn
    tpl_fig = tpl_montage.plot()
    tpl_ax = tpl_fig.axes[0]
    lines = tpl_ax.lines
    head_lines = []
    for line in lines:
        x, y = line.get_data()
        head_lines.append((x,y))
    # -------------------------plot input montage------------------------------
    # 6.4 in x 100 dpi = 640 x 640 px; the CSS percentages below rely on this size
    fig = plt.figure(figsize=(6.4,6.4), dpi=100)
    ax = fig.add_subplot(111)
    # NOTE(review): called before anything is drawn, so it likely has little effect
    fig.tight_layout()
    ax.set_aspect('equal')
    ax.axis('off')
    # plot template's head
    for x, y in head_lines:
        ax.plot(x, y, color='black', linewidth=1.0)
    # plot in_channels on it, with each label nudged slightly to the right of its marker
    ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
    for i, channel in enumerate(in_order):
        ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
    # save input_montage (before the red overlay is added)
    fig.savefig(filename1)
    # ---------------------------add indications-------------------------------
    # plot unmatched input channels in red, over the black markers/labels.
    # Assumes in_dict[channel]["index"] equals channel's position in in_order.
    indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
    if indices != []:
        ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
        for i in indices:
            ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
    # save mapped_montage
    fig.savefig(filename2)
    # -------------------------------------------------------------------------
    # store the tpl and in_channels' display positions (in px).
    # transData maps data coordinates to display pixels (origin at the bottom-left)
    tpl_coords = ax.transData.transform(tpl_coords)
    in_coords = ax.transData.transform(in_coords)
    plt.close('all')
    # (px - offset) / 6.4 == percentage of the 640 px figure.
    # NOTE(review): the 11 / 7 px offsets presumably centre the frontend's
    # channel marker on the point; confirm against the CSS.
    for i, channel in enumerate(tpl_order):
        css_left = (tpl_coords[i,0]-11)/6.4
        css_bottom = (tpl_coords[i,1]-7)/6.4
        tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
    for i, channel in enumerate(in_order):
        css_left = (in_coords[i,0]-11)/6.4
        css_bottom = (in_coords[i,1]-7)/6.4
        in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
    channel_info.update({
        "templateDict" : tpl_dict,
        "inputDict" : in_dict
    })
    return channel_info
def _thin_plate_warp(src, dst, points):
    """Warp ``points`` with thin-plate-spline RBFs fitted to map ``src`` onto ``dst``.

    ``src`` and ``dst`` are (n_matched, d) arrays of corresponding coordinates and
    ``points`` is (n_points, d). One ``Rbf(..., function='thin_plate')`` is fitted
    per output axis. Returns an (n_points, d) array.
    """
    src_cols = [src[:, k] for k in range(src.shape[1])]
    point_cols = [points[:, k] for k in range(points.shape[1])]
    warped_axes = [
        Rbf(*src_cols, dst[:, k], function='thin_plate')(*point_cols)
        for k in range(dst.shape[1])
    ]
    return np.vstack(warped_axes).T


def align_coords(channel_info, tpl_montage, in_montage):
    """Warp input-channel positions onto the template using name-matched channels.

    Channels flagged "matched" in the template dict serve as control points for
    a thin-plate-spline warp (scipy ``Rbf``) from input to template coordinates.

    * 2D, for visualisation: the scatter positions MNE draws in each montage's
      ``plot()``. Template channels get their plotted position as "coord_2d";
      input channels get their warped position (a list) as "coord_2d".
    * 3D: each input channel's "coord_3d" is replaced by its warped position (a list).

    NOTE(review): there is no guard for an empty or too-small matched set; the
    control-point indexing or the Rbf fit raises in that case.

    Returns:
        channel_info, whose "templateDict"/"inputDict" were mutated in place.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    matched = get_matched(tpl_order, tpl_dict)

    # --- 2D alignment: positions as displayed by the MNE montage plots ---
    tpl_fig, in_fig = tpl_montage.plot(), in_montage.plot()
    tpl_2d = tpl_fig.axes[0].collections[0].get_offsets().data
    in_2d = in_fig.axes[0].collections[0].get_offsets().data
    dst_2d = np.array([tpl_2d[tpl_dict[ch]["index"]] for ch in matched])
    src_2d = np.array([in_2d[in_dict[ch]["index"]] for ch in matched])
    plt.close('all')

    warped_2d = _thin_plate_warp(src_2d, dst_2d, in_2d)
    for pos, ch in enumerate(tpl_order):
        tpl_dict[ch]["coord_2d"] = tpl_2d[pos]
    for pos, ch in enumerate(in_order):
        in_dict[ch]["coord_2d"] = warped_2d[pos].tolist()

    # --- 3D alignment ---
    tpl_3d = np.array([tpl_dict[ch]["coord_3d"].tolist() for ch in tpl_order])
    in_3d = np.array([in_dict[ch]["coord_3d"].tolist() for ch in in_order])
    dst_3d = np.array([tpl_3d[tpl_dict[ch]["index"]] for ch in matched])
    src_3d = np.array([in_3d[in_dict[ch]["index"]] for ch in matched])

    warped_3d = _thin_plate_warp(src_3d, dst_3d, in_3d)
    for pos, ch in enumerate(in_order):
        in_dict[ch]["coord_3d"] = warped_3d[pos].tolist()

    channel_info.update({
        "templateDict" : tpl_dict,
        "inputDict" : in_dict
    })
    return channel_info
def find_neighbors(channel_info, missing_channels, new_idx):
    """Fill unmatched template slots with their nearest input channels.

    For every name in ``missing_channels``, ``new_idx[template index]`` is set to
    the positions (in "inputOrder") of the ``min(4, n_inputs)`` input channels
    closest to that template channel, by Euclidean distance between "coord_3d"
    values. All input channels are candidates, assigned or not.

    ``new_idx`` is modified in place and also returned.
    """
    tpl_dict = channel_info["templateDict"]
    input_dict = channel_info["inputDict"]
    input_coords = [np.array(input_dict[name]["coord_3d"]) for name in channel_info["inputOrder"]]

    # use KNN to choose up to 4 nearest input channels
    knn = NearestNeighbors(n_neighbors=min(4, len(input_coords)), metric='euclidean')
    knn.fit(input_coords)

    for name in missing_channels:
        query = np.array(tpl_dict[name]["coord_3d"]).reshape(1, -1)
        _, neighbor_rows = knn.kneighbors(query)
        new_idx[tpl_dict[name]["index"]] = neighbor_rows[0].tolist()
    return new_idx
def match_names(stage1_info):
    """Pair template channels with input channels that have the same (uppercase) name.

    Reads the input location file named by
    ``stage1_info["fileNames"]["inputLocation"]`` together with the template.
    Template channels T3/T4/T5/T6 correspond to T7/T8/P7/P8 in the newer naming.
    Such a template channel is renamed (in the montage and in the template dict)
    when the input uses the newer label, so both conventions match.

    Returns:
        (stage1_info, channel_info, tpl_montage, in_montage).

        ``stage1_info`` is updated in place with:

        * "unassignedInputs" and "missingTemplates";
        * "mappingResults": a one-element list holding this name-based mapping.
          In that mapping, ``newOrder[i]`` is ``[input index]`` for a matched
          template channel and ``[None]`` otherwise, and ``isOriginalData[i]``
          marks the matched ones.
    """
    # read the location file
    loc_file = stage1_info["fileNames"]["inputLocation"]
    tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
    in_order = in_montage.ch_names

    # Indices of the in_channels in template order. Build independent [None]
    # lists: ``[[None]] * 30`` would make all 30 slots the same list object.
    new_idx = [[None] for _ in range(30)]
    orig_flags = [False] * 30

    alias_dict = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8'
    }

    # Iterate over a snapshot: the montage's channel names are renamed inside the loop.
    for i, channel in enumerate(list(tpl_montage.ch_names)):
        alias = alias_dict.get(channel)
        if alias is not None and alias in in_dict:
            tpl_montage.rename_channels({channel: alias})
            tpl_dict[alias] = tpl_dict.pop(channel)
            channel = alias
        if channel in in_dict:
            new_idx[i] = [in_dict[channel]["index"]]
            orig_flags[i] = True
            tpl_dict[channel]["matched"] = True
            in_dict[channel]["assigned"] = True

    # pick up any alias renames
    tpl_order = tpl_montage.ch_names

    stage1_info.update({
        "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
        "missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
        "mappingResults" : [
            {
                "newOrder" : new_idx,
                "isOriginalData" : orig_flags
            }
        ]
    })
    channel_info = {
        "templateOrder" : tpl_order,
        "inputOrder" : in_order,
        "templateDict" : tpl_dict,
        "inputDict" : in_dict
    }
    return stage1_info, channel_info, tpl_montage, in_montage
def optimal_mapping(channel_info):
    """Assign up to 30 still-unassigned input channels to the template channels.

    First, every template "matched" flag is reset. Then the Hungarian algorithm
    (``linear_sum_assignment``) pairs template channels with unassigned input
    channels, minimising the summed Euclidean distance between their "coord_3d"
    positions.

    When fewer than 30 inputs remain, dummy columns costing 1e6 pad the matrix
    so every template row can be assigned. Rows landing on a dummy column stay
    unmatched and are afterwards filled by ``find_neighbors``.

    Args:
        channel_info: dict with "templateOrder", "inputOrder", "templateDict"
            and "inputDict". The dicts are mutated in place: template "matched"
            and input "assigned" flags are updated.

    Returns:
        (result, channel_info), where result is
        ``{"newOrder": [...], "isOriginalData": [...]}`` with 30 entries each.
        ``newOrder[i]`` is ``[input index]`` for a direct assignment
        (``isOriginalData[i]`` True), otherwise the neighbour indices from
        ``find_neighbors``.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    unassigned = get_unassigned_inputs(in_order, in_dict)
    n_in = len(unassigned)

    # reset all tpl.matched to False: only this batch's assignments count
    for channel in tpl_dict:
        tpl_dict[channel]["matched"] = False

    all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order], dtype=float)

    # Cost matrix for the Hungarian algorithm. Columns >= n_in are dummy
    # channels (cost 1e6), so there are always at least as many columns as rows.
    cost_matrix = np.full((30, max(30, n_in)), 1e6)
    if n_in:
        unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned], dtype=float)
        # All 30 x n_in Euclidean distances in one broadcast step. The x1000
        # scale presumably converts metres to millimetres.
        diffs = (all_tpl[:30, np.newaxis, :] - unassigned_in[np.newaxis, :, :]) * 1000
        cost_matrix[:, :n_in] = np.linalg.norm(diffs, axis=-1)

    # optimally assign one in_channel to each tpl_channel by minimising total distance
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    # independent [None] lists (``[[None]] * 30`` would alias a single list)
    new_idx = [[None] for _ in range(30)]
    orig_flags = [False] * 30
    for i, j in zip(row_idx, col_idx):
        if j >= n_in:  # dummy column: this template channel stays unmatched
            continue
        tpl_channel = tpl_order[i]
        in_channel = unassigned[j]
        new_idx[i] = [in_dict[in_channel]["index"]]
        orig_flags[i] = True
        tpl_dict[tpl_channel]["matched"] = True
        in_dict[in_channel]["assigned"] = True

    # fill the remaining empty tpl_channels from their nearest inputs
    missing_channels = get_empty_templates(tpl_order, tpl_dict)
    if missing_channels:
        new_idx = find_neighbors(channel_info, missing_channels, new_idx)

    result = {
        "newOrder" : new_idx,
        "isOriginalData" : orig_flags
    }
    channel_info.update({
        "templateDict" : tpl_dict,
        "inputDict" : in_dict
    })
    return result, channel_info
def mapping_result(stage1_info, channel_info, filename):
    """Run the remaining distance-based mapping batches and write every result as JSON.

    Each ``optimal_mapping`` batch assigns up to 30 of the inputs left in
    ``stage1_info["unassignedInputs"]``, so ``ceil(n / 30)`` extra batches run.
    ``batchNum`` is that count plus one; the extra one presumably accounts for
    the name-matching result already stored in ``stage1_info["mappingResults"]``.

    The ``{"batchNum", "mappingResults"}`` payload is written to ``filename`` as
    JSON pretty-printed by jsbeautifier with a 4-space indent. The same two keys
    are stored back into ``stage1_info``.

    Returns:
        (stage1_info, channel_info), both updated.
    """
    leftover = len(stage1_info["unassignedInputs"])
    batch_num = math.ceil(leftover / 30) + 1

    # the list held by stage1_info is extended in place
    results = stage1_info["mappingResults"]
    for _ in range(batch_num - 1):
        # optimally select the next (up to) 30 in_channels by proximity to the template
        batch_result, channel_info = optimal_mapping(channel_info)
        results.append(batch_result)

    payload = {
        "batchNum" : batch_num,
        "mappingResults" : results
    }
    beautify_opts = jsbeautifier.default_options()
    beautify_opts.indent_size = 4
    pretty_json = jsbeautifier.beautify(json.dumps(payload), beautify_opts)
    with open(filename, 'w') as jsonfile:
        jsonfile.write(pretty_json)

    stage1_info.update({
        "batchNum" : batch_num,
        "mappingResults" : results
    })
    return stage1_info, channel_info