Spaces:
Sleeping
Sleeping
import utils | |
import time | |
import os | |
import numpy as np | |
import gradio as gr | |
import mne | |
from mne.channels import read_custom_montage | |
from scipy.interpolate import Rbf | |
from scipy.optimize import linear_sum_assignment | |
from sklearn.neighbors import NearestNeighbors | |
def reorder_to_template(app_state, filename):
    """Rebuild the raw input data in the 30-channel template order and save it.

    Output row ``i`` is derived from the list of input row indices stored at
    ``app_state["stage1NewOrder"][i]`` (when ``app_state["runnigState"]`` is
    ``"stage1"``) or ``app_state["stage2NewOrder"][i]`` (otherwise):

    * an empty list yields an all-zero row;
    * otherwise the listed input rows are averaged sample-by-sample.

    One all-zero row is appended to the input before indexing, so an index
    equal to the original number of input rows selects zeros.

    The result is written with ``utils.save_data`` to
    ``app_state["filepath"] + 'mapped.csv'``; nothing is returned.
    """
    order_key = "stage1NewOrder" if app_state["runnigState"] == "stage1" else "stage2NewOrder"
    index_sets = app_state[order_key]

    raw = utils.read_train_data(filename)  # original raw data
    n_samples = raw.shape[1]
    target_path = app_state["filepath"] + 'mapped.csv'

    # Extra trailing zero row: an index pointing at it means "no signal".
    raw = np.vstack((raw, np.zeros((1, n_samples))))

    mapped = np.zeros((30, n_samples))  # reordered raw data; rows start as zeros
    for row in range(30):
        rows = index_sets[row]
        if rows != []:
            mapped[row, :] = np.mean(raw[list(rows)], axis=0)

    print('old.shape, new.shape: ', raw.shape, mapped.shape)
    utils.save_data(mapped, target_path)
def reorder_to_origin(app_state, channel_info, filename, new_filename):
    """Write denoised, template-ordered rows back into the input channel layout.

    Reads template-ordered data from ``filename``. A template channel is
    copied back only when its index list (from ``stage1NewOrder`` in stage 1,
    ``stage2NewOrder`` otherwise) holds exactly one input row AND the channel
    is flagged ``"matched"``; zero-filled or neighbour-averaged channels are
    skipped.

    In stage 1 the output starts as zeros with one row per input channel; in
    any other stage the existing contents of ``new_filename`` are loaded and
    updated. The result is saved to ``new_filename``; nothing is returned.
    """
    first_stage = app_state["runnigState"] == "stage1"
    index_sets = app_state["stage1NewOrder" if first_stage else "stage2NewOrder"]
    denoised = utils.read_train_data(filename)  # denoised data
    template_names = channel_info["templateByIndex"]

    if first_stage:
        restored = np.zeros((len(channel_info["inputByName"]), denoised.shape[1]))
    else:
        restored = utils.read_train_data(new_filename)

    for row, name in enumerate(template_names):
        targets = index_sets[row]
        # ignore channels without a one-to-one match to an input row
        if len(targets) != 1 or channel_info["templateByName"][name]["matched"] != True:
            continue
        restored[targets[0], :] = denoised[row, :]

    print('old.shape, new.shape: ', denoised.shape, restored.shape)
    utils.save_data(restored, new_filename)
class Channel:
    """Mutable record describing one electrode of the template or input montage.

    Instances are later turned into plain dicts through ``__dict__``, so the
    attribute set (and its insertion order) is part of the data format.
    """

    def __init__(self, index, name=None, matched=False, assigned=False, coord=None, css_position=None):
        # Keep this assignment order: ``__dict__`` key order follows it.
        self.name, self.index = name, index  # channel label, position in the montage's ch_names
        self.matched, self.assigned = matched, assigned  # pairing flags; "assigned" is for input channels
        self.coord, self.css_position = coord, css_position  # 3-D position, 2-D CSS [left, bottom]
def _index_montage_channels(montage, count):
    """Upper-case the first ``count`` channel names of ``montage`` in place and index them.

    Returns ``{upper_name: Channel(index=j, name=upper_name, coord=position)}``
    where ``j`` is the channel's position in ``montage.ch_names`` and
    ``position`` comes from ``montage.get_positions()['ch_pos']``.

    Raises:
        ValueError: if the montage has fewer than ``count`` channels.
    """
    if len(montage.ch_names) < count:
        raise ValueError(
            f"montage has {len(montage.ch_names)} channels, expected at least {count}"
        )
    # A single rename and a single get_positions() call: both act on the whole
    # montage, so calling them once per channel made this loop quadratic.
    montage.rename_channels({name: name.upper() for name in montage.ch_names[:count]})
    positions = montage.get_positions()['ch_pos']
    return {
        name: Channel(index=j, name=name, coord=positions[name])
        for j, name in enumerate(montage.ch_names[:count])
    }


def read_montage_data(loc_file):
    """Load the template and input montages and index their channels by name.

    Channel names in both montages are upper-cased in place (the montage
    objects themselves are renamed) so later name matching is case-insensitive.
    Only the first 30 template channels are indexed, matching the 30-channel
    layout hard-coded elsewhere in this module.

    Args:
        loc_file: path to the input channel-location file, read with
            ``mne.channels.read_custom_montage``.

    Returns:
        ``(template_montage, input_montage, template_dict, input_dict)`` where
        each dict maps an upper-cased channel name to a ``Channel``.

    Raises:
        ValueError: if the template montage has fewer than 30 channels.
    """
    template_montage = read_custom_montage("./template_chanlocs.loc")
    input_montage = read_custom_montage(loc_file)
    template_dict = _index_montage_channels(template_montage, 30)
    input_dict = _index_montage_channels(input_montage, len(input_montage.ch_names))
    return template_montage, input_montage, template_dict, input_dict
def _montage_display_coords(montage):
    """Draw ``montage`` on a 5.6 x 5.6 inch figure and return its sensor positions in display pixels."""
    # show=False: the figure is only used to obtain pixel coordinates, never displayed.
    fig = montage.plot(show=False)
    fig.set_size_inches(5.6, 5.6)
    ax = fig.axes[0]
    ax.set_aspect('equal')
    ax.figure.canvas.draw()  # update the figure so transData reflects the final layout
    # NOTE(review): the figure is never closed; in a long-running server this
    # accumulates open figures. Consider matplotlib.pyplot.close(fig) here.
    return ax.transData.transform(ax.collections[0].get_offsets().data)


def _css_position(x_px, y_px):
    """Convert a display-pixel point into CSS ``[left, bottom]`` percentage strings.

    The point is shifted by (11, 7) px and divided by 560 px (presumably the
    5.6-inch figure at 100 dpi). NOTE(review): confirm these constants match
    the front-end marker size and container width.
    """
    css_left = (x_px - 11) / 560
    css_bottom = (y_px - 7) / 560
    return [str(round(css_left * 100, 2)) + "%", str(round(css_bottom * 100, 2)) + "%"]


def _fit_thin_plate(src, dst):
    """Fit one thin-plate ``Rbf`` per column of ``dst``, each interpolating over the ``src`` points.

    ``src`` is (m, d_in) and ``dst`` is (m, d_out); row k of both refers to the
    same matched channel.
    """
    return [Rbf(*src.T, dst[:, col], function='thin_plate') for col in range(dst.shape[1])]


def align_coords(channel_info, template_montage, input_montage):
    """Align template and input electrode layouts using the name-matched channels.

    Two independent thin-plate-spline warps are fitted on the channels whose
    ``"matched"`` flag is set:

    * 2-D (display): template positions, as drawn by ``montage.plot()``, are
      warped into the input figure's pixel space and stored as CSS
      ``[left, bottom]`` percentages in ``templateByName[name]["css_position"]``.
      Input channels get their own (unwarped) pixel positions. This indicates
      where a missing template channel would sit on the input layout.
    * 3-D (used later by KNN / Hungarian matching): input coordinates are
      warped into the template's coordinate frame and replace
      ``inputByName[name]["coord"]`` (stored as lists).

    Args:
        channel_info: dict with "templateByName", "inputByName",
            "templateByIndex" and "inputByIndex" (as built by
            ``mapping_stage1``); mutated in place.
        template_montage: mne montage the template entries were built from.
        input_montage: mne montage the input entries were built from.

    Returns:
        ``channel_info`` (the same dict object).

    Raises:
        ValueError: if no input channel is flagged as matched, since the warps
            cannot be fitted without corresponding points.
    """
    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    template_order = channel_info["templateByIndex"]
    input_order = channel_info["inputByIndex"]

    matched = [name for name in input_dict if input_dict[name]["matched"]]
    if not matched:
        raise ValueError("align_coords: no input channel matches a template channel name")
    tpl_rows = [template_dict[name]["index"] for name in matched]
    in_rows = [input_dict[name]["index"] for name in matched]

    # --- 2-D: warp template display coords (px) onto the input figure ---
    tpl_px = _montage_display_coords(template_montage)
    in_px = _montage_display_coords(input_montage)
    fit_x, fit_y = _fit_thin_plate(tpl_px[tpl_rows], in_px[in_rows])
    warped_x = fit_x(tpl_px[:, 0], tpl_px[:, 1])
    warped_y = fit_y(tpl_px[:, 0], tpl_px[:, 1])
    for i, name in enumerate(template_order):
        template_dict[name]["css_position"] = _css_position(warped_x[i], warped_y[i])
    for i, name in enumerate(input_order):
        input_dict[name]["css_position"] = _css_position(in_px[i][0], in_px[i][1])

    # --- 3-D: warp input coords into the template frame ---
    # dtype=float accepts coords stored either as arrays or as lists.
    tpl_xyz = np.array([template_dict[name]["coord"] for name in template_order], dtype=float)
    in_xyz = np.array([input_dict[name]["coord"] for name in input_order], dtype=float)
    fits = _fit_thin_plate(in_xyz[in_rows], tpl_xyz[tpl_rows])
    warped_in = np.vstack([fit(in_xyz[:, 0], in_xyz[:, 1], in_xyz[:, 2]) for fit in fits]).T
    for i, name in enumerate(input_order):
        input_dict[name]["coord"] = warped_in[i].tolist()

    channel_info.update({
        "templateByName": template_dict,
        "inputByName": input_dict,
    })
    return channel_info
def fill_channels(app_state, channel_info, fill_mode, n_neighbors=4):
    """Fill template channels that have no matched input channel.

    The index list for the current stage (``app_state["stage1NewOrder"]`` when
    ``app_state["runnigState"] == "stage1"``, else ``"stage2NewOrder"``) holds,
    per template position, the input data rows that feed it. Each template
    channel whose ``"matched"`` flag is false is filled according to
    ``fill_mode``:

    * ``'zero'``: point at row ``channel_info["dataShape"][0]``. This assumes
      that value is the input row count, i.e. the all-zero row that
      ``reorder_to_template`` appends.
    * ``'mean_auto'``: use the ``n_neighbors`` input channels nearest
      (Euclidean, on ``"coord"``) to the template channel; their rows are
      averaged later.
    * any other value: entries are left unchanged.

    Args:
        app_state: application state dict; updated in place and returned.
        channel_info: channel metadata (as built by ``mapping_stage1``).
        fill_mode: ``'zero'`` or ``'mean_auto'``.
        n_neighbors: maximum neighbours for ``'mean_auto'``. The default of 4
            matches the previous hard-coded value; it is capped at the number
            of input channels.

    Returns:
        ``app_state``.
    """
    stage_key = "stage1NewOrder" if app_state["runnigState"] == "stage1" else "stage2NewOrder"
    new_idx = app_state[stage_key]
    template_dict = channel_info["templateByName"]

    unmatched = [name for name, info in template_dict.items() if not info["matched"]]
    if not unmatched:
        return app_state

    if fill_mode == 'zero':
        zero_row = channel_info["dataShape"][0]  # index of the appended all-zero row
        for name in unmatched:
            new_idx[template_dict[name]["index"]] = [zero_row]
    elif fill_mode == 'mean_auto':
        # use KNN to choose the k nearest input channels
        input_dict = channel_info["inputByName"]
        input_order = channel_info["inputByIndex"]
        in_coords = np.array([input_dict[name]["coord"] for name in input_order])
        k = min(n_neighbors, len(input_dict))
        knn = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(in_coords)
        for name in unmatched:
            query = np.asarray(template_dict[name]["coord"]).reshape(1, -1)
            _, neighbours = knn.kneighbors(query)
            # neighbour positions follow input_order, which is also the data-row order
            print(name, ':', [input_order[i] for i in neighbours[0]])
            new_idx[template_dict[name]["index"]] = neighbours[0].tolist()

    app_state[stage_key] = new_idx
    return app_state
def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
    """Stage 1 of channel mapping: pair input channels with template channels by name.

    Steps:
      1. Load both montages (``read_montage_data``); names are upper-cased.
      2. For each template channel, rename it to its alias when the input uses
         the alias (T3->T7, T4->T8, T5->P7, T6->P8). Then, if the input has a
         channel of that name, record its row index and flag both sides as
         matched (the input side also as assigned). Otherwise record the
         template position as missing.
      3. Store the results in ``channel_info`` / ``app_state`` (state becomes
         ``"stage1"``), align coordinates (``align_coords``), and fill the
         unmatched template slots (``fill_channels``).

    ``data_file`` is part of the interface but is not used here.

    Returns:
        ``(app_state, channel_info)``.
    """
    started = time.time()
    template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)

    # template name -> name the input may use instead
    aliases = {'T3': 'T7', 'T4': 'T8', 'T5': 'P7', 'T6': 'P8'}
    # (previously considered, disabled: 'TP7': "T5'", 'TP8': "T6'")

    new_idx = [[] for _ in range(30)]
    missing_channels = []
    for pos, name in enumerate(template_montage.ch_names):
        alt = aliases.get(name)
        if alt is not None and alt in input_dict:
            template_montage.rename_channels({name: alt})
            entry = template_dict.pop(name)
            entry.name = alt
            template_dict[alt] = entry
            name = alt
        if name not in input_dict:
            missing_channels.append(pos)
            continue
        new_idx[pos] = [input_dict[name].index]
        template_dict[name].matched = True
        input_dict[name].matched = input_dict[name].assigned = True

    channel_info.update({
        "missingChannelsIndex": missing_channels,
        "templateByName": {key: ch.__dict__ for key, ch in template_dict.items()},
        "inputByName": {key: ch.__dict__ for key, ch in input_dict.items()},
        "templateByIndex": template_montage.ch_names,
        "inputByIndex": input_montage.ch_names,
    })
    app_state.update({
        "stage1NewOrder": new_idx,
        "runnigState": "stage1",
    })

    channel_info = align_coords(channel_info, template_montage, input_montage)
    app_state = fill_channels(app_state, channel_info, fill_mode)

    print('Mapping (stage1) finished in', time.time() - started, 's.')
    return app_state, channel_info
def mapping_stage2(app_state, channel_info, fill_mode):
    """Stage 2 of channel mapping: pair leftover input channels with template slots by position.

    Input channels whose ``"assigned"`` flag is false are matched one-to-one
    to template channels. The pairing minimises the total Euclidean distance
    between their ``"coord"`` values (input coords were warped into the
    template frame during stage 1), using the Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``). At most
    ``len(templateByIndex)`` input channels are paired per call; the rest
    stay unassigned.

    Side effects:
      * every template channel's ``"matched"`` flag is reset, then set for the
        new pairs;
      * paired input channels get ``"assigned" = True``;
      * ``app_state["stage2NewOrder"]`` and ``app_state["runnigState"]`` are
        set, after which ``fill_channels`` fills the unpaired template slots.

    Returns:
        ``(app_state, channel_info)``. If no input channel is left unassigned,
        only ``app_state["runnigState"] = "finished"`` is set.
    """
    started = time.time()
    template_dict = channel_info["templateByName"]
    input_dict = channel_info["inputByName"]
    template_order = channel_info["templateByIndex"]

    unassigned = [name for name in input_dict if not input_dict[name]["assigned"]]
    if not unassigned:
        app_state["runnigState"] = "finished"
        return app_state, channel_info

    n_tpl = len(template_order)
    n_in = len(unassigned)
    tpl_coords = np.array([template_dict[name]["coord"] for name in template_order], dtype=float)
    in_coords = np.array([input_dict[name]["coord"] for name in unassigned], dtype=float)

    # set all template "matched" flags to False before re-pairing
    for name in template_dict:
        template_dict[name]["matched"] = False

    # Pairwise Euclidean distances (scaled by 1000), shape (n_tpl, n_in).
    dist = np.linalg.norm((tpl_coords[:, None, :] - in_coords[None, :, :]) * 1000, axis=-1)
    if n_in < n_tpl:
        # Pad with expensive dummy columns so num_col >= num_row and every
        # template row receives a column; real pairs are always preferred.
        cost_matrix = np.full((n_tpl, n_tpl), 1e6)
        cost_matrix[:, :n_in] = dist
    else:
        cost_matrix = dist

    row_idx, col_idx = linear_sum_assignment(cost_matrix)
    new_idx = [[] for _ in range(n_tpl)]
    for row, col in zip(row_idx, col_idx):
        if col >= n_in:  # dummy column: this template slot stays unpaired
            continue
        tpl_name = template_order[row]
        in_name = unassigned[col]
        template_dict[tpl_name]["matched"] = True
        input_dict[in_name]["assigned"] = True
        # Index by the assigned template row, not by loop position.
        new_idx[row] = [input_dict[in_name]["index"]]
        print(tpl_name, '<-', in_name)

    channel_info.update({
        "templateByName": template_dict,
        "inputByName": input_dict,
    })
    app_state.update({
        "stage2NewOrder": new_idx,
        "runnigState": "stage2",
    })
    app_state = fill_channels(app_state, channel_info, fill_mode)
    print(f'Mapping (stage2-{app_state["batchCount"]-1}) finished in {time.time() - started}s.')
    return app_state, channel_info