AIEEG / channel_mapping.py
audrey06100's picture
init
7129427
raw
history blame
5.33 kB
import utils
import os
import numpy as np
import mne
from mne.channels import read_custom_montage
def reorder_data(filename, old_idx, n_channels=30):
    """Rearrange the rows of a recording into template order and save them.

    The recording is loaded with ``utils.read_train_data`` and the reordered
    array is written with ``utils.save_data`` to ``mapped.csv`` in the same
    directory as ``filename``.

    Args:
        filename: Path of the input data file.
        old_idx: For each output row, the 0-based row of the input data to
            copy. Any value outside ``[0, number of input rows)`` -- e.g. the
            ``len(input channels)`` sentinel that ``mapping`` uses for
            zero-filling, or an unresolved ``-1`` -- leaves that output row
            filled with zeros.
        n_channels: Number of rows in the output array (30 by default, the
            size of the template).

    Returns:
        None. The result is only written to disk.
    """
    # NOTE(review): assumes read_train_data returns a 2-D array with one row
    # per input channel, in montage order -- confirm against utils.
    old_data = utils.read_train_data(filename)
    n_input_rows = old_data.shape[0]
    new_data = np.zeros((n_channels, old_data.shape[1]))

    for row, source_row in enumerate(old_idx[:n_channels]):
        # `mapping` hands over 0-based positions, so index directly. The
        # previous `- 1` offset shifted every channel by one row and wrapped
        # channel 0 around to the last row.
        if 0 <= source_row < n_input_rows:
            new_data[row, :] = old_data[source_row, :]
        # Out-of-range entries keep the zeros the array was created with.

    # os.path.join keeps the output relative when `filename` has no directory
    # part (plain concatenation produced '/mapped.csv' in that case).
    out_dir = os.path.dirname(str(filename))
    utils.save_data(new_data, os.path.join(out_dir, 'mapped.csv'))
    return
# Old/new 10-20 spellings of the same electrodes, usable in both directions.
_CHANNEL_ALIASES = {
    'T3': 'T7', 'T4': 'T8', 'T5': 'P7', 'T6': 'P8',
    'T7': 'T3', 'T8': 'T4', 'P7': 'T5', 'P8': 'T6',
}  # CP7, FT7 ?


def _lookup_input_channel(name, input_labels):
    """Return the input index for ``name`` (or its alias), or None if absent."""
    if name in input_labels:
        return input_labels[name]
    alias = _CHANNEL_ALIASES.get(name)
    if alias in input_labels:
        return input_labels[alias]
    return None


def _build_template_grid(template_names):
    """Place the template channels, in order, on a 7x5 grid with 5 blank cells.

    Returns ``(grid, cells, temporal)``: ``grid[r][c]`` is a channel name or
    ``''``; ``cells[k]`` is the ``(row, col)`` of the k-th template channel;
    ``temporal`` lists the channels in the outer columns of rows 2 and below.
    """
    blanks = {(0, 0), (0, 2), (0, 4), (6, 0), (6, 4)}
    remaining = iter(template_names)
    grid, cells, temporal = [], [], []
    for r in range(7):
        row = []
        for c in range(5):
            if (r, c) in blanks:
                row.append('')
                continue
            name = next(remaining)
            row.append(name)
            cells.append((r, c))
            if r > 1 and c in (0, 4):
                temporal.append(name)
        grid.append(row)
    return grid, cells, temporal


def _closest_to_z_axis(input_montage):
    """Return the input channel with the smallest ``x**2 + y**2``, or None."""
    # Fetch the positions once instead of on every loop iteration.
    positions = input_montage.get_positions()['ch_pos']
    best_name, best_dist = None, 1e5
    for name in input_montage.ch_names:
        x, y, _ = positions[name]
        dist = x ** 2 + y ** 2
        if dist < best_dist:
            best_name, best_dist = name, dist
    return best_name


def _midline_name_for_row(grid, row):
    """Return the centre-column channel of ``row``, or of the closest row that has one."""
    for offset in range(len(grid)):
        for candidate in (row - offset, row + offset):
            if 0 <= candidate < len(grid) and grid[candidate][2]:
                return grid[candidate][2]
    raise ValueError('the template grid has no midline channel')


def _fill_midline(name, input_labels):
    """Pick the input index that stands in for a missing midline ('...Z') channel."""
    prefix = name[:-1]
    if prefix + '1' in input_labels:        # ex: FCZ<-FC1
        return input_labels[prefix + '1']
    if name in ('FCZ', 'CPZ'):              # ex: FCZ<-CZ
        return input_labels['CZ']
    if prefix + '3' in input_labels:        # ex: FCZ<-FC3
        return input_labels[prefix + '3']
    return input_labels['CZ']


def mapping(input_file, loc_file, fill_mode):
    """Map the input recording's channels onto the 30-channel template.

    Channel names of both montages are upper-cased. Template channels are
    first matched to input channels by name (including the T3/T7, T4/T8,
    T5/P7, T6/P8 aliases). Template channels still missing after that are
    filled according to ``fill_mode``. The resulting index list is passed to
    ``reorder_data``, which writes the reordered data.

    Args:
        input_file: Path of the input data file, forwarded to ``reorder_data``.
        loc_file: Channel-location file of the input recording, read with
            ``read_custom_montage``.
        fill_mode: ``"zero"`` or ``"adjacent"``. In both modes a missing
            left/right channel first tries an unused neighbouring input
            channel (ex: FT7<-FC5, FC3<-FC1). Otherwise ``"zero"`` assigns the
            index ``len(input channels)`` (one past the last input channel),
            and ``"adjacent"`` borrows the source of the midline channel in
            the same grid row. Missing midline channels are zero-filled or
            taken from a nearby channel in the same way.

    Returns:
        None.

    Raises:
        ValueError: If channels are missing and ``fill_mode`` is not one of
            the two supported values, or if no input channel can stand in
            for CZ.
    """
    n_template = 30
    template_montage = read_custom_montage("./template_chanlocs.loc")
    input_montage = read_custom_montage(loc_file)

    # Normalize all channel names to upper case.
    template_montage.rename_channels(
        {name: name.upper() for name in template_montage.ch_names[:n_template]})
    input_montage.rename_channels(
        {name: name.upper() for name in input_montage.ch_names})

    template_names = list(template_montage.ch_names[:n_template])
    input_names = list(input_montage.ch_names)
    input_labels = {name: idx for idx, name in enumerate(input_names)}

    new_idx = [-1] * n_template
    input_used = [False] * len(input_names)

    # Stage 1/2: same name, or same electrode under its alias name.
    # A template channel missing under both spellings is left for the fill
    # stage instead of raising KeyError.
    for i, name in enumerate(template_names):
        source = _lookup_input_channel(name, input_labels)
        if source is not None:
            new_idx[i] = source
            input_used[source] = True

    if -1 not in new_idx:
        print('Finish at stage 1,2 !')
        reorder_data(input_file, new_idx)  # & save data to mapped.csv
        return

    if fill_mode not in ("zero", "adjacent"):
        raise ValueError(
            'fill_mode must be "zero" or "adjacent", got {!r}'.format(fill_mode))

    # Store the channel positions in a 2-d grid.
    grid, cells, temporal_channels = _build_template_grid(template_names)
    temporal_row_prefix = ['FC', 'C', 'CP', 'P']

    # Index one past the last input channel marks "fill with zeros".
    z_row_idx = len(input_names)

    # CZ is the last-resort source in "adjacent" mode; if the input lacks it,
    # let the input channel closest to the z axis stand in for it.
    if fill_mode == "adjacent" and 'CZ' not in input_labels:
        nearest_channel = _closest_to_z_axis(input_montage)
        if nearest_channel is None:
            raise ValueError('no input channel can stand in for CZ')
        input_labels['CZ'] = input_labels[nearest_channel]

    missing = [i for i in range(n_template) if new_idx[i] == -1]
    midline = [i for i in missing if template_names[i][-1] == 'Z']
    lateral = [i for i in missing if template_names[i][-1] != 'Z']

    # Resolve midline channels first: lateral channels may copy their source,
    # and left-side channels come before their midline channel in template
    # order. Midline filling neither reads nor writes `input_used`, so this
    # reordering does not change any other assignment.
    for i in midline:
        if fill_mode == "zero":
            new_idx[i] = z_row_idx
        else:
            new_idx[i] = _fill_midline(template_names[i], input_labels)

    # Channels in the left/right region.
    for i in lateral:
        name = template_names[i]
        suffix = int(name[-1])
        if name in temporal_channels:
            # ex: FT7<-FC5
            row_prefix = temporal_row_prefix[temporal_channels.index(name) // 2]
            potential_neighbor = row_prefix + str(5 if suffix % 2 == 1 else 6)
        else:
            # ex: FC3<-FC1, FC4<-FC2
            potential_neighbor = name[:-1] + str(suffix - 2)

        source = input_labels.get(potential_neighbor)
        if source is not None and not input_used[source]:
            new_idx[i] = source
            input_used[source] = True
        elif fill_mode == "zero":
            new_idx[i] = z_row_idx
        else:
            # Temporary approach, kept from the original implementation:
            # borrow the source of the midline channel of this row. The FP
            # row has no midline channel, so the closest row that has one is
            # used instead.
            mid_channel = _midline_name_for_row(grid, cells[i][0])
            new_idx[i] = new_idx[template_names.index(mid_channel)]

    reorder_data(input_file, new_idx)