Spaces:
Sleeping
Sleeping
Commit
·
94bf054
1
Parent(s):
f3fbfd6
update channel_mapping.py
Browse files- channel_mapping.py +208 -108
channel_mapping.py
CHANGED
@@ -1,78 +1,135 @@
|
|
1 |
import utils
|
|
|
|
|
2 |
import os
|
3 |
import numpy as np
|
4 |
|
5 |
import mne
|
6 |
from mne.channels import read_custom_montage
|
7 |
|
8 |
-
def reorder_data(filename, old_idx):
|
9 |
filepath = os.path.dirname(str(filename))
|
10 |
old_data = utils.read_train_data(filename)
|
11 |
new_data = np.zeros((30, old_data.shape[1]))
|
12 |
-
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
for
|
15 |
-
new_data[
|
16 |
|
17 |
-
|
18 |
utils.save_data(new_data, filepath+'/mapped.csv')
|
19 |
return
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def mapping(input_file, loc_file, fill_mode):
|
|
|
|
|
22 |
template_montage = read_custom_montage("./template_chanlocs.loc")
|
23 |
input_montage = read_custom_montage(loc_file)
|
24 |
#template_montage.plot()
|
25 |
#input_montage.plot()
|
|
|
|
|
|
|
26 |
|
27 |
-
|
28 |
for i in range(30):
|
29 |
-
template_montage.
|
|
|
|
|
|
|
|
|
30 |
|
31 |
for i in range(len(input_montage.ch_names)):
|
32 |
-
input_montage.
|
33 |
-
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
new_idx = [-1]*30
|
37 |
-
new_idx_name = ['']*30
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
finish_flag = 1
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
# correct place
|
44 |
for i in range(30):
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
new_idx_name[i] = channel_name # tmp
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
new_idx[i] = input_labels_dict[channel_name]
|
57 |
-
new_idx_name[i] = channel_name # tmp
|
58 |
|
59 |
-
input_used[new_idx[i]] = 1
|
60 |
-
else:
|
61 |
-
finish_flag = 0
|
62 |
-
|
63 |
if finish_flag == 1:
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
66 |
return
|
67 |
|
68 |
|
69 |
|
70 |
-
#
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
temporal_channels = []
|
75 |
-
temporal_row_prefix = ['FC','C','CP','P']
|
76 |
|
77 |
cnt = 0
|
78 |
for i in range(7):
|
@@ -81,93 +138,136 @@ def mapping(input_file, loc_file, fill_mode):
|
|
81 |
if [i,j] in [[0,0],[0,2],[0,4],[6,0],[6,4]]:
|
82 |
tmp.append('')
|
83 |
else:
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
-
if i>1 and j in [0,4]:
|
88 |
-
temporal_channels.append(
|
89 |
cnt += 1
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
template_CZ_idx = 14
|
96 |
-
if new_idx[template_CZ_idx] == -1:
|
97 |
min_dist = 1e5
|
98 |
-
nearest_channel = 'CZ'
|
99 |
for channel in input_montage.ch_names:
|
100 |
-
|
101 |
-
if cur_x**2+cur_y**2 < min_dist and channel != 'CZ':
|
102 |
-
nearest_channel = channel
|
103 |
-
min_dist = cur_x**2+cur_y**2
|
104 |
-
input_labels_dict['CZ'] = input_labels_dict[nearest_channel]
|
105 |
|
|
|
|
|
|
|
106 |
|
|
|
|
|
|
|
|
|
107 |
|
108 |
-
|
109 |
|
110 |
-
if fill_mode == "zero":
|
111 |
-
z_row_idx = len(input_montage.ch_names)
|
112 |
|
113 |
for i in range(30):
|
114 |
if new_idx[i] != -1:
|
115 |
continue
|
116 |
|
117 |
-
|
118 |
-
channel_prefix = channel_name[:len(channel_name)-1]
|
119 |
-
channel_suffix = -1 if channel_name[-1]=='Z' else int(channel_name[-1])
|
120 |
|
121 |
-
|
122 |
-
|
123 |
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
else:
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
|
|
1 |
import utils
|
2 |
+
|
3 |
+
import time
|
4 |
import os
|
5 |
import numpy as np
|
6 |
|
7 |
import mne
|
8 |
from mne.channels import read_custom_montage
|
9 |
|
10 |
+
def reorder_data(filename, old_idx, fill_mode):
|
11 |
filepath = os.path.dirname(str(filename))
|
12 |
old_data = utils.read_train_data(filename)
|
13 |
new_data = np.zeros((30, old_data.shape[1]))
|
14 |
+
print('old data shape: ', old_data.shape)
|
15 |
+
|
16 |
+
if fill_mode == 'zero':
|
17 |
+
zero_arr = np.zeros((1, old_data.shape[1]))
|
18 |
+
old_data = np.concatenate((old_data, zero_arr), axis=0)
|
19 |
|
20 |
+
for i in range(30):
|
21 |
+
new_data[i, :] = old_data[old_idx[i]-1, :]
|
22 |
|
23 |
+
print('new data shape: ', new_data.shape)
|
24 |
utils.save_data(new_data, filepath+'/mapped.csv')
|
25 |
return
|
26 |
|
27 |
+
|
28 |
+
class Channel: # (DigMontage):
|
29 |
+
|
30 |
+
def __init__(self, index, name=None, used=False, coord=None, topo_index=None, topo_location=None):
|
31 |
+
|
32 |
+
# super().__init__()
|
33 |
+
# self._montage = montage
|
34 |
+
|
35 |
+
self.name = name
|
36 |
+
self.index = index
|
37 |
+
self.used = used
|
38 |
+
self.coord = coord
|
39 |
+
self.topo_index = topo_index
|
40 |
+
self.topo_location = topo_location
|
41 |
+
|
42 |
+
def prefix(self):
|
43 |
+
ret = ''.join(filter(str.isalpha, self.name))
|
44 |
+
return ret[:len(ret) - 1] if ret[-1] == 'Z' else ret
|
45 |
+
|
46 |
+
def suffix(self):
|
47 |
+
return -1 if self.name[-1] == 'Z' else int(''.join(filter(str.isdigit, self.name)))
|
48 |
+
|
49 |
+
|
50 |
def mapping(input_file, loc_file, fill_mode):
|
51 |
+
second1 = time.time()
|
52 |
+
|
53 |
template_montage = read_custom_montage("./template_chanlocs.loc")
|
54 |
input_montage = read_custom_montage(loc_file)
|
55 |
#template_montage.plot()
|
56 |
#input_montage.plot()
|
57 |
+
|
58 |
+
template = {}
|
59 |
+
input = {}
|
60 |
|
61 |
+
# convert all channel names to uppercase
|
62 |
for i in range(30):
|
63 |
+
channel = template_montage.ch_names[i]
|
64 |
+
template_montage.rename_channels({channel: str.upper(channel)})
|
65 |
+
|
66 |
+
channel = str.upper(channel)
|
67 |
+
template[channel] = Channel(index=i, name=channel)
|
68 |
|
69 |
for i in range(len(input_montage.ch_names)):
|
70 |
+
channel = input_montage.ch_names[i]
|
71 |
+
input_montage.rename_channels({channel: str.upper(channel)})
|
72 |
+
|
73 |
+
channel = str.upper(channel)
|
74 |
+
input[channel] = Channel(index=i, coord=input_montage.get_positions()['ch_pos'][channel])
|
75 |
|
76 |
|
77 |
new_idx = [-1]*30
|
78 |
+
new_idx_name = ['']*30 # tmp
|
79 |
+
missing_channels = []
|
80 |
+
z_row_idx = len(input_montage.ch_names)
|
81 |
+
|
82 |
+
|
83 |
+
# STAGE_1
|
84 |
+
|
85 |
+
# match the template's channel names with the input ones
|
86 |
finish_flag = 1
|
87 |
+
alias = {
|
88 |
+
'T3': 'T7',
|
89 |
+
'T4': 'T8',
|
90 |
+
'T5': 'P7',
|
91 |
+
'T6': 'P8',
|
92 |
+
'TP7': 'T5\'',
|
93 |
+
'TP8': 'T6\'',
|
94 |
+
}
|
95 |
|
|
|
96 |
for i in range(30):
|
97 |
+
channel = template_montage.ch_names[i]
|
98 |
+
if channel not in input.keys() | alias.keys():
|
99 |
+
finish_flag = 0
|
100 |
+
continue
|
|
|
101 |
|
102 |
+
if channel not in input and channel in alias:
|
103 |
+
if alias[channel] in input:
|
104 |
+
template_montage.rename_channels({channel: alias[channel]})
|
105 |
+
template[alias[channel]] = template.pop(channel)
|
106 |
+
template[alias[channel]].name = alias[channel]
|
107 |
+
channel = alias[channel]
|
108 |
+
else:
|
109 |
+
finish_flag = 0
|
110 |
+
continue
|
111 |
|
112 |
+
new_idx[i] = input[channel].index
|
113 |
+
new_idx_name[i] = channel # tmp
|
114 |
+
input[channel].used = True
|
|
|
|
|
115 |
|
|
|
|
|
|
|
|
|
116 |
if finish_flag == 1:
|
117 |
+
second2 = time.time()
|
118 |
+
print('Finish at stage 1 ! (',second2 - second1,'s)')
|
119 |
+
#print('new idx order:', new_idx)
|
120 |
+
#print('new_idx_name:', new_idx_name) # tmp
|
121 |
+
|
122 |
+
reorder_data(input_file, new_idx, fill_mode) # & save data to mapped.csv
|
123 |
return
|
124 |
|
125 |
|
126 |
|
127 |
+
# STAGE_2
|
128 |
+
|
129 |
+
# store channel positions in a 2-d array
|
130 |
+
template_topo_pos = []
|
131 |
temporal_channels = []
|
132 |
+
temporal_row_prefix = ['FC', 'C', 'CP', 'P']
|
133 |
|
134 |
cnt = 0
|
135 |
for i in range(7):
|
|
|
138 |
if [i,j] in [[0,0],[0,2],[0,4],[6,0],[6,4]]:
|
139 |
tmp.append('')
|
140 |
else:
|
141 |
+
channel = template_montage.ch_names[cnt]
|
142 |
+
tmp.append(channel)
|
143 |
+
|
144 |
+
ver = 'front' if i<3 else 'center' if i==3 else 'back'
|
145 |
+
hor = 'left' if j<2 else 'center' if j==2 else 'right'
|
146 |
+
template[channel].topo_index = [i, j]
|
147 |
+
template[channel].topo_location = [ver, hor]
|
148 |
|
149 |
+
if i > 1 and j in [0, 4]:
|
150 |
+
temporal_channels.append(channel)
|
151 |
cnt += 1
|
152 |
+
template_topo_pos.append(tmp)
|
153 |
+
|
154 |
+
|
155 |
+
# ensure that CZ is found or imputed by another channel
|
156 |
+
if 'CZ' not in input and fill_mode=='adjacent':
|
|
|
|
|
157 |
min_dist = 1e5
|
|
|
158 |
for channel in input_montage.ch_names:
|
159 |
+
curr_x, curr_y, curr_z = input[channel].coord.round(6)
|
|
|
|
|
|
|
|
|
160 |
|
161 |
+
if curr_x**2 + curr_y**2 < min_dist:
|
162 |
+
nearest_channel = channel
|
163 |
+
min_dist = curr_x**2 + curr_y**2
|
164 |
|
165 |
+
if input[nearest_channel].used == True:
|
166 |
+
missing_channels.append('CZ')
|
167 |
+
input[nearest_channel].used = True
|
168 |
+
input['CZ'] = input[nearest_channel]
|
169 |
|
170 |
+
print("CZ's nearest neighbor:", nearest_channel)
|
171 |
|
|
|
|
|
172 |
|
173 |
for i in range(30):
|
174 |
if new_idx[i] != -1:
|
175 |
continue
|
176 |
|
177 |
+
channel = template_montage.ch_names[i]
|
|
|
|
|
178 |
|
179 |
+
curr_prefix = template[channel].prefix()
|
180 |
+
curr_suffix = template[channel].suffix()
|
181 |
|
182 |
+
curr_row = template[channel].topo_index[0]
|
183 |
+
curr_col = template[channel].topo_index[1]
|
184 |
+
curr_ver = template[channel].topo_location[0]
|
185 |
+
curr_hor = template[channel].topo_location[1]
|
186 |
+
|
187 |
+
impute_channel = ''
|
188 |
+
|
189 |
+
# if the current channel is a temporal channel
|
190 |
+
if channel in temporal_channels:
|
191 |
+
curr_prefix = temporal_row_prefix[temporal_channels.index(channel)//2]
|
192 |
+
curr_suffix = 7 if curr_hor=='left' else 8
|
193 |
+
|
194 |
+
if fill_mode == 'zero':
|
195 |
|
196 |
+
impute_channel = curr_prefix+str(1) if curr_hor=='center' else curr_prefix+str(curr_suffix-2)
|
197 |
+
if impute_channel not in input or input[impute_channel].used==True:
|
198 |
+
impute_channel = ''
|
199 |
+
new_idx[i] = z_row_idx
|
200 |
+
missing_channels.append(channel)
|
201 |
+
continue
|
202 |
+
|
203 |
+
elif fill_mode == 'adjacent':
|
204 |
+
|
205 |
+
if curr_hor == 'center': # FZ, FPZ, CZ...
|
206 |
+
|
207 |
+
if curr_prefix+str(1) in input: # ex: FZ<-F1
|
208 |
+
impute_channel = curr_prefix + str(1)
|
209 |
+
elif curr_ver=='front' and template_topo_pos[curr_row+1][curr_col] in input: # ex: FZ<-FCZ
|
210 |
+
impute_channel = template_topo_pos[curr_row+1][curr_col]
|
211 |
+
|
212 |
+
elif curr_ver=='back' and template_topo_pos[curr_row-1][curr_col] in input: # ex: PZ<-CPZ
|
213 |
+
impute_channel = template_topo_pos[curr_row-1][curr_col]
|
214 |
+
|
215 |
+
elif curr_prefix+str(3) in input: # ex: FZ<-F3
|
216 |
+
impute_channel = curr_prefix + str(3)
|
217 |
+
|
218 |
else:
|
219 |
+
impute_channel = 'CZ'
|
220 |
+
|
221 |
+
elif curr_hor == 'left' or curr_hor == 'right':
|
222 |
+
|
223 |
+
ver_ctrl = 1 if curr_ver=='front' else 2 if curr_ver=='back' else 3 # bit0: row+1, bit1: row-1
|
224 |
+
ver_dir = 1 if curr_ver == 'front' else -1
|
225 |
+
|
226 |
+
# search horizontally
|
227 |
+
cnt = 0
|
228 |
+
tmp_suffix = curr_suffix
|
229 |
+
while tmp_suffix > 0: # ex: F7<-F5/F3/F1
|
230 |
+
tmp_suffix = curr_suffix - 2*cnt
|
231 |
+
if curr_prefix+str(tmp_suffix) in input:
|
232 |
+
impute_channel = curr_prefix + str(tmp_suffix)
|
233 |
+
break
|
234 |
+
|
235 |
+
if cnt == 2:
|
236 |
+
# check row+1/row-1
|
237 |
+
if ver_ctrl&1 and template_topo_pos[curr_row+1][curr_col] in input:
|
238 |
+
impute_channel = template_topo_pos[curr_row+1][curr_col]
|
239 |
+
break
|
240 |
+
if ver_ctrl&2 and template_topo_pos[curr_row-1][curr_col] in input:
|
241 |
+
impute_channel = template_topo_pos[curr_row-1][curr_col]
|
242 |
+
break
|
243 |
+
cnt += 1
|
244 |
+
|
245 |
+
# search vertically
|
246 |
+
if impute_channel == '':
|
247 |
+
cnt = 0
|
248 |
+
tmp_row = curr_row + ver_dir
|
249 |
+
while tmp_row-ver_dir != 3: # terminate if the last channel is a middle one
|
250 |
+
if template_topo_pos[tmp_row][curr_col] in input:
|
251 |
+
impute_channel = template_topo_pos[tmp_row][curr_col]
|
252 |
+
break
|
253 |
+
tmp_row += ver_dir
|
254 |
+
|
255 |
+
# if still cannot find available channel...
|
256 |
+
if impute_channel == '':
|
257 |
+
impute_channel = 'CZ'
|
258 |
+
|
259 |
+
new_idx[i] = input[impute_channel].index
|
260 |
+
new_idx_name[i] = impute_channel # tmp
|
261 |
+
if input[impute_channel].used == True: # this channel is shared with others
|
262 |
+
missing_channels.append(channel)
|
263 |
+
input[impute_channel].used = True
|
264 |
+
|
265 |
+
second2 = time.time()
|
266 |
+
print('Finish at stage 2 ! (',second2 - second1,'s)')
|
267 |
+
#print('new_idx:', new_idx)
|
268 |
+
#print('new_idx_name:', new_idx_name) # tmp
|
269 |
+
print('missing_channels:', missing_channels)
|
270 |
+
reorder_data(input_file, new_idx, fill_mode) # & save data to mapped.csv
|
271 |
+
|
272 |
+
return
|
273 |
|