# AIEEG / app_utils.py
# Author: audrey06100 -- commit 89c0b72 "Update app_utils.py" (verified)
# Hugging Face file-view residue: raw | history | blame | 13 kB
import utils
import os
import math
import json
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.channels import read_custom_montage
from scipy.interpolate import Rbf
from scipy.optimize import linear_sum_assignment
from sklearn.neighbors import NearestNeighbors
def reorder_data(idx_order, fill_flags, filename, new_filename):
    """Rearrange the rows (channels) of a recording into the 30-slot template order.

    Slot ``row`` of the output is built from ``idx_order[row]`` / ``fill_flags[row]``:

    * flag is False   -> copy input row ``idx_order[row][0]`` verbatim;
    * index list empty -> leave the slot as an all-zero row;
    * otherwise        -> average of the input rows listed in ``idx_order[row]``.

    Slots beyond ``len(idx_order)`` stay zero. The result is written to
    *new_filename* with ``utils.save_data``.

    Returns:
        The shape of the data read from *filename*.
    """
    source = utils.read_train_data(filename)
    n_samples = source.shape[1]
    reordered = np.zeros((30, n_samples))

    for row, (indices, filled) in enumerate(zip(idx_order, fill_flags)):
        # compare with == on purpose so the original semantics are kept exactly
        if filled == False:
            reordered[row, :] = source[indices[0], :]
            continue
        if indices == []:
            # nothing to fill from: the slot stays all zeros
            reordered[row, :] = 0.0
            continue
        # "fillmode" slot: mean of the selected input channels
        reordered[row, :] = np.mean([source[k, :] for k in indices], axis=0)

    utils.save_data(reordered, new_filename)
    return source.shape
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
    """Scatter one batch of denoised template-ordered rows back to input order.

    Row ``i`` of the denoised data in *filename* is written to input row
    ``idx_order[i][0]``, but only when ``fill_flags[i]`` is False. Slots that
    were synthesised by "fillmode" are skipped.

    The first batch (``batch_cnt == 0``) starts from a zero array with
    ``raw_data_shape[0]`` rows. Later batches re-read *new_filename*, so they
    accumulate on top of what earlier batches wrote. The result is saved back
    to *new_filename*.
    """
    denoised = utils.read_train_data(filename)
    restored = (
        np.zeros((raw_data_shape[0], denoised.shape[1]))
        if batch_cnt == 0
        else utils.read_train_data(new_filename)
    )

    for src_row, (indices, filled) in enumerate(zip(idx_order, fill_flags)):
        if filled != False:
            # synthesised ("fillmode") slot: there is no input channel to restore
            continue
        restored[indices[0], :] = denoised[src_row, :]

    utils.save_data(restored, new_filename)
def get_matched(tpl_order, tpl_dict):
    """Return the template channels (in *tpl_order* order) whose ``"matched"`` flag is set."""
    return [channel for channel in tpl_order if tpl_dict[channel]["matched"]]
def get_empty_templates(tpl_order, tpl_dict):
    """Return the template channels (in *tpl_order* order) that are not yet matched."""
    return [channel for channel in tpl_order if not tpl_dict[channel]["matched"]]
def get_unassigned_inputs(in_order, in_dict):
    """Return the input channels (in *in_order* order) not yet assigned to a template slot."""
    return [channel for channel in in_order if not in_dict[channel]["assigned"]]
def _uppercase_channels(montage, flag_key):
    """Rename *montage*'s channels to upper case in place and index them.

    Returns a dict mapping each upper-cased name to
    ``{"index": position, "coord_3d": 3D position, flag_key: False}``.
    """
    names = list(montage.ch_names)
    upper_names = [name.upper() for name in names]
    # One batched rename instead of one call per channel. As with the
    # per-channel version, mne presumably refuses names that collide once
    # upper-cased.
    montage.rename_channels(dict(zip(names, upper_names)))
    # fetch the positions once rather than once per channel
    positions = montage.get_positions()["ch_pos"]
    return {
        up_name: {"index": i, "coord_3d": positions[up_name], flag_key: False}
        for i, up_name in enumerate(upper_names)
    }


def read_montage_data(loc_file, tpl_file="./template_chanlocs.loc"):
    """Load the template and input montages and describe their channels.

    All channel names in both montages are converted to upper case (the
    returned montage objects are renamed in place), so later name matching is
    case-insensitive.

    Args:
        loc_file: path of the input channel-location file.
        tpl_file: path of the template channel-location file. Defaults to the
            bundled ``./template_chanlocs.loc``, as before.

    Returns:
        ``(tpl_montage, in_montage, tpl_dict, in_dict)``. Each dict maps an
        upper-cased channel name to ``{"index", "coord_3d", flag}``, where the
        flag is ``"matched"`` for the template and ``"assigned"`` for the
        input, both initialised to False.
    """
    tpl_montage = read_custom_montage(tpl_file)
    in_montage = read_custom_montage(loc_file)
    tpl_dict = _uppercase_channels(tpl_montage, "matched")
    in_dict = _uppercase_channels(in_montage, "assigned")
    return tpl_montage, in_montage, tpl_dict, in_dict
def _css_positions(px_coords, offset_x=11, offset_y=7, px_per_percent=6.4):
    """Convert display-pixel ``(x, y)`` rows into ``["<left>%", "<bottom>%"]`` pairs.

    The figure is 6.4 in at 100 dpi = 640 px square, so one percent equals
    6.4 px. The fixed (11, 7) px offsets come from the original code; they
    presumably centre an HTML marker on the channel (confirm against the
    front-end).
    """
    return [
        [str(round((x - offset_x) / px_per_percent, 2)) + "%",
         str(round((y - offset_y) / px_per_percent, 2)) + "%"]
        for x, y in px_coords
    ]


def save_figures(channel_info, tpl_montage, filename1, filename2):
    """Draw the input montage on the template head outline and save two images.

    * *filename1*: every input channel, drawn in black.
    * *filename2*: the same drawing with unassigned input channels re-drawn in red.

    Also stores a ``"css_position"`` (``[left%, bottom%]`` within the 640x640 px
    image) for every template and input channel.

    Expects ``channel_info`` to hold ``templateOrder``, ``inputOrder``,
    ``templateDict`` and ``inputDict``, with ``coord_2d`` already set on every
    channel (``align_coords`` sets it).

    Returns:
        *channel_info*, with both dicts updated in place.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
    tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
    in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
    in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
    tpl_coords = np.vstack((tpl_x, tpl_y)).T
    in_coords = np.vstack((in_x, in_y)).T

    # Borrow the head outline (the Line2D artists) from the template's own
    # montage plot. Close that figure explicitly: a bare plt.close() closes
    # whichever figure happens to be "current".
    tpl_fig = tpl_montage.plot()
    try:
        head_lines = [line.get_data() for line in tpl_fig.axes[0].lines]
    finally:
        plt.close(tpl_fig)

    fig = plt.figure(figsize=(6.4, 6.4), dpi=100)
    try:
        ax = fig.add_subplot(111)
        fig.tight_layout()
        ax.set_aspect('equal')
        ax.axis('off')

        # -------------------------- plot input montage --------------------------
        for x, y in head_lines:
            ax.plot(x, y, color='black', linewidth=1.0)
        ax.scatter(in_coords[:, 0], in_coords[:, 1], s=35, color='black')
        for i, channel in enumerate(in_order):
            ax.text(in_coords[i, 0] + 0.003, in_coords[i, 1], channel,
                    color='black', fontsize=10.0, va='center')
        fig.savefig(filename1)

        # ------------- overlay unassigned input channels in red ----------------
        indices = [in_dict[channel]["index"] for channel in in_order
                   if not in_dict[channel]["assigned"]]
        ax.scatter(in_coords[indices, 0], in_coords[indices, 1], s=35, color='red')
        for i in indices:
            ax.text(in_coords[i, 0] + 0.003, in_coords[i, 1], in_order[i],
                    color='red', fontsize=10.0, va='center')
        fig.savefig(filename2)

        # Data -> display pixels. This uses the limits/aspect of the drawing
        # that was just saved.
        tpl_px = ax.transData.transform(tpl_coords)
        in_px = ax.transData.transform(in_coords)
    finally:
        plt.close(fig)

    for channel, css in zip(tpl_order, _css_positions(tpl_px)):
        tpl_dict[channel]["css_position"] = css
    for channel, css in zip(in_order, _css_positions(in_px)):
        in_dict[channel]["css_position"] = css

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info
def _plotted_sensor_xy(montage):
    """Return the 2D sensor positions drawn by ``montage.plot()`` as an (n, 2) float array.

    The figure is closed afterwards so repeated calls do not leak figures.
    """
    fig = montage.plot()
    try:
        offsets = fig.axes[0].collections[0].get_offsets()
        # get_offsets() may be a masked or a plain ndarray. np.ma.getdata
        # handles both, whereas `.data` on a plain ndarray is a memoryview.
        return np.array(np.ma.getdata(offsets), dtype=float)
    finally:
        plt.close(fig)


def _tps_warp(src_ctrl, dst_ctrl, points):
    """Thin-plate-spline warp of *points* learned from control pairs src_ctrl -> dst_ctrl.

    Fits one ``scipy.interpolate.Rbf(function='thin_plate')`` per output axis
    and returns the warped points as an (n_points, n_dims) array.
    """
    warped_axes = [
        Rbf(*src_ctrl.T, dst_ctrl[:, k], function='thin_plate')(*points.T)
        for k in range(dst_ctrl.shape[1])
    ]
    return np.vstack(warped_axes).T


def align_coords(channel_info, tpl_montage, in_montage):
    """Warp every input channel position onto the template layout.

    The name-matched channels (``get_matched``) serve as control points for a
    thin-plate-spline mapping from input positions to template positions:

    * 2D (visualisation): sensor positions as drawn by ``montage.plot()``.
      Stored as ``coord_2d`` on template channels (unwarped) and on input
      channels (warped, as a list).
    * 3D: each input channel's ``coord_3d`` is replaced by its warped
      position, as a list.

    Presumably needs enough matched channels for the spline fit to be solvable
    (TODO confirm the minimum).

    Returns:
        *channel_info*, with ``templateDict``/``inputDict`` updated in place.
    """
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]
    matched = get_matched(tpl_order, tpl_dict)

    # ---------------- 2D alignment (for visualization purposes) ----------------
    # assumes the plotted sensors follow ch_names order, so "index" selects the row
    all_tpl = _plotted_sensor_xy(tpl_montage)
    all_in = _plotted_sensor_xy(in_montage)
    matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
    matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
    transformed_in = _tps_warp(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(tpl_order):
        tpl_dict[channel]["coord_2d"] = all_tpl[i]
    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_2d"] = transformed_in[i].tolist()

    # ------------------------------ 3D alignment ------------------------------
    # np.asarray accepts coord_3d both as an ndarray and as a list
    all_tpl = np.array([np.asarray(tpl_dict[channel]["coord_3d"], dtype=float)
                        for channel in tpl_order])
    all_in = np.array([np.asarray(in_dict[channel]["coord_3d"], dtype=float)
                       for channel in in_order])
    matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
    matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
    transformed_in = _tps_warp(matched_in, matched_tpl, all_in)

    for i, channel in enumerate(in_order):
        in_dict[channel]["coord_3d"] = transformed_in[i].tolist()

    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return channel_info
def find_neighbors(channel_info, missing_channels, new_idx, n_neighbors=4):
    """Fill empty template slots with the indices of their nearest input channels.

    For each template channel in *missing_channels*, finds the
    ``min(n_neighbors, number of input channels)`` input channels closest to it
    (Euclidean distance between ``coord_3d`` positions). Their input indices
    are stored at ``new_idx[tpl_dict[channel]["index"]]``.

    Args:
        channel_info: dict with ``inputOrder``, ``templateDict`` and ``inputDict``.
        missing_channels: template channel names that have no assigned input.
        new_idx: per-template-slot list of input indices; modified in place.
        n_neighbors: maximum number of neighbours per slot (default 4, as before).

    Returns:
        *new_idx*.
    """
    if not missing_channels:
        return new_idx

    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    all_in = np.array([in_dict[channel]["coord_3d"] for channel in in_order], dtype=float)
    queries = np.array([tpl_dict[channel]["coord_3d"] for channel in missing_channels],
                       dtype=float)

    k = min(n_neighbors, len(in_order))
    knn = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(all_in)
    # one batched query instead of one kneighbors() call per missing channel
    _, neighbor_rows = knn.kneighbors(queries)

    for channel, rows in zip(missing_channels, neighbor_rows):
        new_idx[tpl_dict[channel]["index"]] = rows.tolist()
    return new_idx
def match_names(stage1_info, channel_info):
    """First mapping stage: pair input channels with template channels by name.

    Reads the input location file named in
    ``stage1_info["fileNames"]["input_loc"]``. A template channel is renamed to
    its alias (T3->T7, T4->T8, T5->P7, T6->P8) when the input uses the aliased
    name. Every template channel whose (possibly aliased) name also exists in
    the input is then marked matched/assigned.

    Updates *stage1_info* with ``unassignedInputs``, ``missingTemplates`` and
    ``mappingData``. ``mappingData`` is a one-element list holding ``newOrder``
    (per template slot: ``[input index]`` or ``[]``) and ``fillFlags`` (False
    where a name match was found). Updates *channel_info* with
    ``templateOrder``, ``inputOrder``, ``templateDict`` and ``inputDict``.

    Returns:
        ``(stage1_info, channel_info, tpl_montage, in_montage)``.
    """
    loc_file = stage1_info["fileNames"]["input_loc"]
    tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
    tpl_order = tpl_montage.ch_names
    in_order = in_montage.ch_names

    # One independent list per template slot. `[[]] * 30` would make every
    # slot share the same list object.
    new_idx = [[] for _ in range(30)]
    fill_flags = [True] * 30  # True = slot must be filled by "fillmode"

    # older 10-20 names -> newer names for the same electrode sites
    alias_dict = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8',
    }

    # iterating the pre-rename name list is safe: rename_channels rebinds ch_names
    for i, channel in enumerate(tpl_order):
        alias = alias_dict.get(channel)
        if alias is not None and alias in in_dict:
            tpl_montage.rename_channels({channel: alias})
            tpl_dict[alias] = tpl_dict.pop(channel)
            channel = alias
        if channel in in_dict:
            new_idx[i] = [in_dict[channel]["index"]]
            fill_flags[i] = False
            tpl_dict[channel]["matched"] = True
            in_dict[channel]["assigned"] = True

    # pick up the aliased names
    tpl_order = tpl_montage.ch_names

    stage1_info.update({
        "unassignedInputs": get_unassigned_inputs(in_order, in_dict),
        "missingTemplates": get_empty_templates(tpl_order, tpl_dict),
        "mappingData": [
            {
                "newOrder": new_idx,
                "fillFlags": fill_flags,
            }
        ],
    })
    channel_info.update({
        "templateOrder": tpl_order,
        "inputOrder": in_order,
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return stage1_info, channel_info, tpl_montage, in_montage
def optimal_mapping(channel_info):
    """Assign up to 30 still-unassigned input channels to template slots by proximity.

    Every template slot is first reset to unmatched. The Hungarian algorithm
    (``linear_sum_assignment``) then pairs slots with unassigned inputs so that
    the summed 3D distance (``coord_3d``, scaled by 1000) is minimal. When
    fewer than 30 inputs remain, dummy columns costing 1e6 pad the cost
    matrix. Slots left without an input are given their nearest input channels
    via ``find_neighbors``; their fill flag stays True.

    Returns:
        ``(mapping_data, channel_info)``, where ``mapping_data`` is
        ``{"newOrder", "fillFlags"}`` for this batch. The ``matched``/``assigned``
        flags in *channel_info* are updated in place.
    """
    n_slots = 30
    tpl_order = channel_info["templateOrder"]
    in_order = channel_info["inputOrder"]
    tpl_dict = channel_info["templateDict"]
    in_dict = channel_info["inputDict"]

    unassigned = get_unassigned_inputs(in_order, in_dict)
    n_unassigned = len(unassigned)

    # each batch starts from an empty template
    for channel in tpl_dict:
        tpl_dict[channel]["matched"] = False

    all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order], dtype=float)
    # reshape keeps a 2D (0, 3) shape even when no inputs are left
    unassigned_in = np.array(
        [in_dict[channel]["coord_3d"] for channel in unassigned], dtype=float
    ).reshape(-1, all_tpl.shape[1])

    # Cost matrix: scaled distance between every template slot and every
    # unassigned input. Dummy columns (cost 1e6) keep num_col >= num_row.
    cost_matrix = np.full((n_slots, max(n_slots, n_unassigned)), 1e6)
    diffs = all_tpl[:n_slots, None, :] - unassigned_in[None, :, :]
    cost_matrix[:, :n_unassigned] = np.linalg.norm(diffs * 1000, axis=-1)

    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    # one independent list per slot (`[[]] * 30` would alias a single list)
    new_idx = [[] for _ in range(n_slots)]
    fill_flags = [True] * n_slots
    for i, j in zip(row_idx, col_idx):
        if j >= n_unassigned:
            continue  # dummy column: this slot received no real input
        tpl_channel = tpl_order[i]
        in_channel = unassigned[j]
        new_idx[i] = [in_dict[in_channel]["index"]]
        fill_flags[i] = False
        tpl_dict[tpl_channel]["matched"] = True
        in_dict[in_channel]["assigned"] = True

    # fill the remaining empty template slots from their nearest inputs
    missing_channels = get_empty_templates(tpl_order, tpl_dict)
    if missing_channels:
        new_idx = find_neighbors(channel_info, missing_channels, new_idx)

    mapping_data = {
        "newOrder": new_idx,
        "fillFlags": fill_flags,
    }
    channel_info.update({
        "templateDict": tpl_dict,
        "inputDict": in_dict,
    })
    return mapping_data, channel_info
def mapping_result(stage1_info, stage2_info, channel_info, filename):
    """Run the remaining proximity-mapping batches and save every batch's mapping.

    The first entry of ``stage1_info["mappingData"]`` is presumably the
    name-based mapping from ``match_names``. One further batch per 30 leftover
    inputs (rounded up) is produced by ``optimal_mapping`` and appended to that
    list. ``{"batchNum", "mappingData"}`` is written to *filename* as JSON, and
    ``stage2_info["totalBatchNum"]`` is set to the total number of batches.

    Returns:
        ``(stage1_info, stage2_info, channel_info)``.
    """
    leftover = len(stage1_info["unassignedInputs"])
    total_batches = 1 + math.ceil(leftover / 30)

    for _ in range(total_batches - 1):
        # optimally pick (up to) 30 inputs for the template slots
        batch_mapping, channel_info = optimal_mapping(channel_info)
        stage1_info["mappingData"].append(batch_mapping)

    summary = {
        "batchNum": total_batches,
        "mappingData": stage1_info["mappingData"],
    }
    with open(filename, 'w') as jsonfile:
        jsonfile.write(json.dumps(summary))

    stage2_info["totalBatchNum"] = total_batches
    return stage1_info, stage2_info, channel_info