# AIEEG / channel_mapping.py
# Last commit 995c1d0 ("update") by audrey06100 · 12.1 kB
import utils
import time
import os
import numpy as np
import gradio as gr
import mne
from mne.channels import read_custom_montage
from scipy.interpolate import Rbf
from scipy.optimize import linear_sum_assignment
from sklearn.neighbors import NearestNeighbors
def reorder_to_template(app_state, filename):
    """Rearrange raw input rows into template-slot order and save the result.

    Template slot ``i`` receives the mean of the input rows listed in the
    current stage's order list (``stage1NewOrder`` when ``runnigState`` is
    ``"stage1"``, otherwise ``stage2NewOrder``). A slot with an empty list
    stays all zeros. One all-zero row is appended to the input first, so an
    index equal to the original row count selects zeros.

    The output has one row per entry of the order list and is written with
    ``utils.save_data`` to ``app_state["filepath"] + 'mapped.csv'``.

    Args:
        app_state: app state providing ``runnigState``, the order lists and
            ``filepath``.
        filename: raw data file passed to ``utils.read_train_data``.
    """
    order_key = "stage1NewOrder" if app_state["runnigState"] == "stage1" else "stage2NewOrder"
    old_idx = app_state[order_key]
    old_data = utils.read_train_data(filename)  # original raw data
    n_cols = old_data.shape[1]
    new_filename = app_state["filepath"] + 'mapped.csv'

    # Padding row: index old_data.shape[0] now refers to an all-zero channel.
    old_data = np.concatenate((old_data, np.zeros((1, n_cols))), axis=0)

    # Reordered raw data: one row per slot of the order list (the template size).
    new_data = np.zeros((len(old_idx), n_cols))
    for i, idx_set in enumerate(old_idx):
        if len(idx_set) > 0:  # empty slots keep their zero row
            new_data[i, :] = old_data[list(idx_set), :].mean(axis=0)
    print('old.shape, new.shape: ', old_data.shape, new_data.shape)
    utils.save_data(new_data, new_filename)
    return
def reorder_to_origin(app_state, channel_info, filename, new_filename):
    """Write denoised, template-ordered rows back into input-channel order.

    Row ``i`` of the denoised file (template slot ``i``) is copied to input
    row ``k`` only when the current stage's order list maps slot ``i`` to
    exactly one input row ``[k]`` and that template channel is flagged
    ``matched``. Zero-filled or neighbour-averaged slots are skipped.

    In stage 1 the output starts as zeros, one row per input channel. In
    later stages the rows are written over the current contents of
    ``new_filename``. The result is saved to ``new_filename``.
    """
    in_stage1 = app_state["runnigState"] == "stage1"
    order = app_state["stage1NewOrder"] if in_stage1 else app_state["stage2NewOrder"]
    denoised = utils.read_train_data(filename)  # denoised data
    slot_names = channel_info["templateByIndex"]

    if in_stage1:
        restored = np.zeros((len(channel_info["inputByName"]), denoised.shape[1]))
    else:
        restored = utils.read_train_data(new_filename)

    for slot, tpl_name in enumerate(slot_names):
        sources = order[slot]
        # skip slots with no single, directly matched input channel
        if len(sources) != 1 or channel_info["templateByName"][tpl_name]["matched"] != True:
            continue
        restored[sources[0], :] = denoised[slot, :]

    print('old.shape, new.shape: ', denoised.shape, restored.shape)
    utils.save_data(restored, new_filename)
    return None
class Channel:
    """Mutable record for one channel of the template or input montage.

    Instances are exported as plain dicts through their ``__dict__``, so the
    attribute assignment order in ``__init__`` defines that dict's key order.

    Attributes:
        name: channel name (upper-cased by ``read_montage_data``).
        index: position of the channel in its montage's ``ch_names``.
        matched: whether the channel has been paired with a counterpart.
        assigned: for input channels, whether a mapping stage has used it.
        coord: channel position (initially taken from the montage).
        css_position: ``[left, bottom]`` percentage strings for UI placement.
    """

    def __init__(self, index, name=None, matched=False, assigned=False, coord=None, css_position=None):
        # Keep this order: it is the key order of the exported __dict__.
        self.name = name
        self.index = index
        self.matched = matched
        self.assigned = assigned  # for input channels
        self.coord = coord
        self.css_position = css_position

    def __repr__(self):
        return (
            f"{type(self).__name__}(index={self.index!r}, name={self.name!r}, "
            f"matched={self.matched!r}, assigned={self.assigned!r})"
        )
def _index_montage_channels(montage, count):
    """Upper-case the first *count* channel names of *montage* (in place) and wrap each in a Channel."""
    # Indexing (not slicing) keeps the IndexError when the montage has fewer channels.
    names = [montage.ch_names[j] for j in range(count)]
    montage.rename_channels({name: name.upper() for name in names})
    # get_positions() rebuilds the full position dict on every call, so fetch it once.
    ch_pos = montage.get_positions()['ch_pos']
    channels = {}
    for j, name in enumerate(names):
        upper = name.upper()
        channels[upper] = Channel(index=j, name=upper, coord=ch_pos[upper])
    return channels


def read_montage_data(loc_file, template_file="./template_chanlocs.loc"):
    """Load the template and input montages and index their channels by name.

    Channel names are upper-cased in place on both montages. Only the first
    30 template channels are indexed and renamed; all input channels are.

    Args:
        loc_file: input montage file, read with ``read_custom_montage``.
        template_file: template montage file (default
            ``./template_chanlocs.loc`` relative to the working directory).

    Returns:
        ``(template_montage, input_montage, template_dict, input_dict)``,
        where each dict maps an upper-cased channel name to a ``Channel``
        whose ``index`` is its position in the montage.
    """
    template_montage = read_custom_montage(template_file)
    input_montage = read_custom_montage(loc_file)
    template_dict = _index_montage_channels(template_montage, 30)
    input_dict = _index_montage_channels(input_montage, len(input_montage.ch_names))
    return template_montage, input_montage, template_dict, input_dict
def align_coords(channel_info, template_montage, input_montage):
    """Compute UI positions for both montages and warp input coords into template space.

    Two thin-plate-spline warps (``scipy.interpolate.Rbf``) are fitted on the
    channels that were matched by name:

    * 2-D: template plot positions -> input plot positions. Every template
      channel, including unmatched ones, gets a ``css_position`` on the
      input's montage figure. Every input channel gets its own plot
      position. Both are stored as ``[left, bottom]`` percentage strings.
    * 3-D: input montage coordinates -> template coordinates. Every input
      channel's ``coord`` is replaced by its warped position (a list). The
      later nearest-neighbour (``fill_channels``) and distance-based
      (``mapping_stage2``) steps then compare input and template coordinates.

    Args:
        channel_info: dict with ``templateByName``/``inputByName`` (channel
            name -> attribute dict) and ``templateByIndex``/``inputByIndex``
            (channel names in montage order).
        template_montage: template montage; its ``plot()`` figure supplies 2-D positions.
        input_montage: input montage; its ``plot()`` figure supplies 2-D positions.

    Returns:
        ``channel_info``, with the name dicts (modified in place) stored back.

    NOTE(review): the two figures from ``montage.plot()`` are never closed
    here. In a long-running app they may accumulate; confirm this and close
    them if so.
    NOTE(review): with no name-matched channels, ``matched_tpl[:,0]`` below
    raises IndexError. With very few matches the spline fit may be
    degenerate. Callers presumably guarantee enough matches.
    """
    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    template_order = channel_info["templateByIndex"]
    input_order = channel_info["inputByIndex"]
    matched = [channel for channel in input_dict if input_dict[channel]["matched"]==True]
    # 2-d (for the indication of missing template channel's position when fill_mode:'mean_manual')
    fig = [template_montage.plot(), input_montage.plot()]
    fig[0].set_size_inches(5.6, 5.6)
    fig[1].set_size_inches(5.6, 5.6)
    ax = [fig[0].axes[0], fig[1].axes[0]]
    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    ax[0].figure.canvas.draw() #update the figure (so transData reflects the final layout)
    ax[1].figure.canvas.draw()
    # get the original coords
    # assumes collections[0] is the sensor scatter and its points follow ch_names order -- confirm
    all_tpl = ax[0].transData.transform(ax[0].collections[0].get_offsets().data) # display coords (px)
    all_in= ax[1].transData.transform(ax[1].collections[0].get_offsets().data)
    matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
    matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
    # transform the xy axis (template's -> input's)
    rbf_x = Rbf(matched_tpl[:,0], matched_tpl[:,1], matched_in[:,0], function='thin_plate')
    rbf_y = Rbf(matched_tpl[:,0], matched_tpl[:,1], matched_in[:,1], function='thin_plate')
    # apply to all template channels
    transformed_tpl_x = rbf_x(all_tpl[:,0], all_tpl[:,1])
    transformed_tpl_y = rbf_y(all_tpl[:,0], all_tpl[:,1])
    #transformed_tpl = np.vstack((transformed_tpl_x, transformed_tpl_y)).T
    # update input, template's position
    # NOTE(review): 560 presumably = 5.6 in figure * 100 dpi (px); the 11/7 px
    # offsets presumably align the UI marker with the point -- confirm.
    for i, channel in enumerate(template_order):
        css_left = (transformed_tpl_x[i]-11)/560
        css_bottom = (transformed_tpl_y[i]-7)/560
        template_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
    for i, channel in enumerate(input_order):
        css_left = (all_in[i][0]-11)/560
        css_bottom = (all_in[i][1]-7)/560
        input_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
    # 3-d (to use KNN)
    # get the original coords (montage positions stored on each channel)
    all_tpl = np.array([template_dict[channel]["coord"].tolist() for channel in template_order])
    all_in = np.array([input_dict[channel]["coord"].tolist() for channel in input_order])
    matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
    matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
    # transform the xyz axis (input's -> template's)
    rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
    rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
    rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
    # apply to all input channels
    transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
    transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
    transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
    transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
    # update input's position (coord becomes a plain list, so .tolist() above would fail on a second call)
    for i, channel in enumerate(input_order):
        input_dict[channel]["coord"] = transformed_in[i].tolist()
    channel_info.update({
        "templateByName" : template_dict,
        "inputByName" : input_dict,
    })
    return channel_info
def fill_channels(app_state, channel_info, fill_mode, n_neighbors=4):
    """Fill template slots that have no matched input channel.

    Args:
        app_state: mutable app state. ``runnigState`` selects the order list
            (``stage1NewOrder`` for ``"stage1"``, else ``stage2NewOrder``).
        channel_info: dict with ``templateByName``/``inputByName`` (name ->
            attribute dict), ``inputByIndex`` (input names in row order) and,
            for ``'zero'``, ``dataShape``.
        fill_mode: one of the following.
            ``'zero'``: each unmatched slot points at row
            ``dataShape[0]``, presumably the all-zero row that
            ``reorder_to_template`` appends after the data rows.
            ``'mean_auto'``: each unmatched slot points at the input rows of
            the ``n_neighbors`` input channels nearest (Euclidean) to the
            template channel's coordinate.
            Any other value (e.g. ``'mean_manual'``) leaves the slots unchanged.
        n_neighbors: maximum number of neighbours used by ``'mean_auto'``.

    Returns:
        ``app_state``, with the current stage's order list updated in place.
    """
    order_key = "stage1NewOrder" if app_state["runnigState"] == "stage1" else "stage2NewOrder"
    new_idx = app_state[order_key]
    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    input_order = channel_info["inputByIndex"]

    unmatched = [name for name, ch in template_dict.items() if not ch["matched"]]
    if not unmatched:
        return app_state

    if fill_mode == 'zero':
        zero_row = channel_info["dataShape"][0]
        for name in unmatched:
            new_idx[template_dict[name]["index"]] = [zero_row]
    elif fill_mode == 'mean_auto':
        # Row i of in_coords is input row i, so KNN indices are data-row indices.
        in_coords = np.array([input_dict[name]["coord"] for name in input_order])
        k = min(n_neighbors, len(in_coords))  # cannot ask for more neighbours than samples
        knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
        knn.fit(in_coords)
        for name in unmatched:
            query = np.asarray(template_dict[name]["coord"]).reshape(1, -1)
            _, indices = knn.kneighbors(query)
            neighbours = indices[0].tolist()
            print(name, ':', [input_order[i] for i in neighbours])
            new_idx[template_dict[name]["index"]] = neighbours

    app_state[order_key] = new_idx
    return app_state
def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
    """Stage 1: match input channels to template channels by name.

    Loads both montages and pairs channels whose upper-cased names are
    identical. A template channel T3/T4/T5/T6 is first renamed to T7/T8/P7/P8
    when the input uses that newer name. The slot order is stored in
    ``app_state["stage1NewOrder"]``. Coordinates are then aligned with
    ``align_coords``, and unmatched template slots are filled with
    ``fill_channels``.

    Args:
        app_state: mutable app state (updated and returned).
        channel_info: dict that receives the channel metadata.
        data_file: not used by this function.
        loc_file: path of the input montage file.
        fill_mode: forwarded to ``fill_channels``.

    Returns:
        ``(app_state, channel_info)``.
    """
    started = time.time()
    template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)

    # Old 10-20 template names -> newer names the input may use instead.
    # (TP7 -> T5' and TP8 -> T6' aliases are currently disabled.)
    legacy_names = {
        'T3': 'T7',
        'T4': 'T8',
        'T5': 'P7',
        'T6': 'P8',
    }
    slot_order = [[]] * 30  # template slot -> list of input row indices
    missing = []

    for slot, tpl_name in enumerate(list(template_montage.ch_names)):
        modern = legacy_names.get(tpl_name)
        if modern is not None and modern in input_dict:
            template_montage.rename_channels({tpl_name: modern})
            entry = template_dict.pop(tpl_name)
            entry.name = modern
            template_dict[modern] = entry
            tpl_name = modern

        source = input_dict.get(tpl_name)
        if source is None:
            missing.append(slot)
            continue
        slot_order[slot] = [source.index]
        template_dict[tpl_name].matched = True
        source.matched = True
        source.assigned = True

    channel_info.update({
        "missingChannelsIndex": missing,
        "templateByName": {name: vars(ch) for name, ch in template_dict.items()},
        "inputByName": {name: vars(ch) for name, ch in input_dict.items()},
        "templateByIndex": template_montage.ch_names,
        "inputByIndex": input_montage.ch_names,
    })
    app_state.update({
        "stage1NewOrder": slot_order,
        "runnigState": "stage1",
    })

    channel_info = align_coords(channel_info, template_montage, input_montage)
    app_state = fill_channels(app_state, channel_info, fill_mode)

    elapsed = time.time() - started
    print('Mapping (stage1) finished in', elapsed, 's.')
    return app_state, channel_info
def mapping_stage2(app_state, channel_info, fill_mode):
    """Stage 2: assign still-unassigned input channels to template slots by position.

    All template channels are first marked unmatched. The unassigned input
    channels are then paired with template channels so that the summed
    Euclidean distance between their coordinates is minimal (Hungarian
    algorithm, ``scipy.optimize.linear_sum_assignment``). Each template slot
    gets at most one input channel, so at most ``len(template_order)`` inputs
    are assigned per call; the rest stay unassigned for a later call.
    Template slots left empty are then filled with ``fill_channels``.

    Args:
        app_state: mutable app state. ``stage2NewOrder`` and ``runnigState``
            are written, and ``batchCount`` is read for the log line.
        channel_info: dict with ``templateByName``/``inputByName`` (name ->
            attribute dict, mutated in place) and ``templateByIndex``
            (template names in slot order).
        fill_mode: forwarded to ``fill_channels``.

    Returns:
        ``(app_state, channel_info)``. If every input channel is already
        assigned, only ``runnigState`` is changed, to ``"finished"``.
    """
    start = time.time()
    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    template_order = channel_info["templateByIndex"]

    unassigned = [name for name, ch in input_dict.items() if not ch["assigned"]]
    if not unassigned:
        app_state["runnigState"] = "finished"
        return app_state, channel_info

    tpl_coords = np.array([template_dict[name]["coord"] for name in template_order], dtype=float)
    unassigned_coords = np.array([input_dict[name]["coord"] for name in unassigned], dtype=float)

    # Every template slot is re-decided in this stage.
    for ch in template_dict.values():
        ch["matched"] = False

    # Pairwise distances, shape (n_template, n_unassigned). linear_sum_assignment
    # handles rectangular matrices itself, so no dummy columns are needed.
    # Scaled by 1000 (presumably metres -> mm); a uniform positive scale does not
    # change the optimal assignment.
    diff = tpl_coords[:, np.newaxis, :] - unassigned_coords[np.newaxis, :, :]
    cost_matrix = np.linalg.norm(diff * 1000, axis=-1)
    row_idx, col_idx = linear_sum_assignment(cost_matrix)

    new_idx = [[]] * len(template_order)
    # row_idx are template slots and col_idx are positions in `unassigned`,
    # so index new_idx by the template row, not by the loop counter.
    for row, col in zip(row_idx, col_idx):
        tpl_channel = template_order[row]
        in_channel = unassigned[col]
        template_dict[tpl_channel]["matched"] = True
        input_dict[in_channel]["assigned"] = True
        new_idx[row] = [input_dict[in_channel]["index"]]
        print(tpl_channel, '<-', in_channel)

    channel_info.update({
        "templateByName" : template_dict,
        "inputByName" : input_dict
    })
    app_state.update({
        "stage2NewOrder" : new_idx,
        "runnigState" : "stage2"
    })
    # fill the unmatched channels
    app_state = fill_channels(app_state, channel_info, fill_mode)
    second1, second2 = start, time.time()
    print(f'Mapping (stage2-{app_state["batchCount"]-1}) finished in {second2 - second1}s.')
    return app_state, channel_info