audrey06100 commited on
Commit
94bf054
·
1 Parent(s): f3fbfd6

update channel_mapping.py

Browse files
Files changed (1) hide show
  1. channel_mapping.py +208 -108
channel_mapping.py CHANGED
@@ -1,78 +1,135 @@
1
  import utils
 
 
2
  import os
3
  import numpy as np
4
 
5
  import mne
6
  from mne.channels import read_custom_montage
7
 
8
- def reorder_data(filename, old_idx):
9
  filepath = os.path.dirname(str(filename))
10
  old_data = utils.read_train_data(filename)
11
  new_data = np.zeros((30, old_data.shape[1]))
12
- #print('old = ', old_data.shape)
 
 
 
 
13
 
14
- for j in range(30):
15
- new_data[j, :] = old_data[old_idx[j]-1, :]
16
 
17
- #print('i = ', i+1, ', ', new_data.shape)
18
  utils.save_data(new_data, filepath+'/mapped.csv')
19
  return
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def mapping(input_file, loc_file, fill_mode):
 
 
22
  template_montage = read_custom_montage("./template_chanlocs.loc")
23
  input_montage = read_custom_montage(loc_file)
24
  #template_montage.plot()
25
  #input_montage.plot()
 
 
 
26
 
27
- input_labels_dict = {}
28
  for i in range(30):
29
- template_montage.rename_channels({template_montage.ch_names[i]:str.upper(template_montage.ch_names[i])}) # 統一大寫
 
 
 
 
30
 
31
  for i in range(len(input_montage.ch_names)):
32
- input_montage.rename_channels({input_montage.ch_names[i]:str.upper(input_montage.ch_names[i])}) # 統一大寫
33
- input_labels_dict[input_montage.ch_names[i]] = i
 
 
 
34
 
35
 
36
  new_idx = [-1]*30
37
- new_idx_name = ['']*30 # tmp
38
- input_used = [0]*len(input_montage.ch_names)
 
 
 
 
 
 
39
  finish_flag = 1
40
-
41
- alias = {'T3':'T7', 'T4':'T8', 'T5':'P7', 'T6':'P8'} # CP7,FT7 ?
 
 
 
 
 
 
42
 
43
- # correct place
44
  for i in range(30):
45
- channel_name = template_montage.ch_names[i]
46
-
47
- if channel_name in input_labels_dict:
48
- new_idx[i] = input_labels_dict[channel_name]
49
- new_idx_name[i] = channel_name # tmp
50
 
51
- input_used[new_idx[i]] = 1
 
 
 
 
 
 
 
 
52
 
53
- elif channel_name in alias:
54
- template_montage.rename_channels({channel_name:alias[channel_name]})
55
- channel_name = template_montage.ch_names[i]
56
- new_idx[i] = input_labels_dict[channel_name]
57
- new_idx_name[i] = channel_name # tmp
58
 
59
- input_used[new_idx[i]] = 1
60
- else:
61
- finish_flag = 0
62
-
63
  if finish_flag == 1:
64
- print('Finish at stage 1,2 !')
65
- reorder_data(input_file, new_idx) # & save data to mapped.csv
 
 
 
 
66
  return
67
 
68
 
69
 
70
- # store channel positions in 2-d array
71
- template_pos = []
72
- template_pos_idx = []
73
-
74
  temporal_channels = []
75
- temporal_row_prefix = ['FC','C','CP','P']
76
 
77
  cnt = 0
78
  for i in range(7):
@@ -81,93 +138,136 @@ def mapping(input_file, loc_file, fill_mode):
81
  if [i,j] in [[0,0],[0,2],[0,4],[6,0],[6,4]]:
82
  tmp.append('')
83
  else:
84
- tmp.append(template_montage.ch_names[cnt])
85
- template_pos_idx.append([i,j])
 
 
 
 
 
86
 
87
- if i>1 and j in [0,4]:
88
- temporal_channels.append(template_montage.ch_names[cnt])
89
  cnt += 1
90
- template_pos.append(tmp)
91
-
92
-
93
-
94
- # CZ
95
- template_CZ_idx = 14
96
- if new_idx[template_CZ_idx] == -1:
97
  min_dist = 1e5
98
- nearest_channel = 'CZ'
99
  for channel in input_montage.ch_names:
100
- cur_x, cur_y, cur_z = input_montage.get_positions()['ch_pos'][channel]
101
- if cur_x**2+cur_y**2 < min_dist and channel != 'CZ':
102
- nearest_channel = channel
103
- min_dist = cur_x**2+cur_y**2
104
- input_labels_dict['CZ'] = input_labels_dict[nearest_channel]
105
 
 
 
 
106
 
 
 
 
 
107
 
108
- finish_flag = 1
109
 
110
- if fill_mode == "zero":
111
- z_row_idx = len(input_montage.ch_names)
112
 
113
  for i in range(30):
114
  if new_idx[i] != -1:
115
  continue
116
 
117
- channel_name = template_montage.ch_names[i]
118
- channel_prefix = channel_name[:len(channel_name)-1]
119
- channel_suffix = -1 if channel_name[-1]=='Z' else int(channel_name[-1])
120
 
121
- # current target channel is in the middle
122
- if channel_suffix == -1:
123
 
124
- if fill_mode == "zero":
125
- new_idx[i] = z_row_idx
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- elif fill_mode == "adjacent":
128
-
129
- if channel_prefix+str(1) in input_labels_dict: # ex: FCZ<-FC1
130
- new_idx[i] = input_labels_dict[channel_prefix+str(1)]
131
- new_idx_name[i] = channel_prefix+str(1) # tmp
132
- elif (channel_name in ['FCZ','CPZ']): # and ('CZ' in input_labels_dict): # ex: FCZ<-CZ
133
- new_idx[i] = input_labels_dict['CZ']
134
- new_idx_name[i] = 'CZ' # tmp
135
- elif channel_prefix+str(3) in input_labels_dict: # ex: FCZ<-FC3
136
- new_idx[i] = input_labels_dict[channel_prefix+str(3)]
137
- new_idx_name[i] = channel_prefix+str(3) # tmp
 
 
 
 
 
 
 
 
 
 
 
138
  else:
139
- new_idx[i] = input_labels_dict['CZ']
140
- new_idx_name[i] = 'CZ' # tmp
141
-
142
- # current target channel is in the left/right region
143
- else:
144
- try:
145
- # if the current target channel is a temporal channel
146
- potential_neighbor = temporal_row_prefix[temporal_channels.index(channel_name)//2]+str(5 if channel_suffix%2==1 else 6) # ex: FT7<-FC5
147
- except:
148
- potential_neighbor = channel_name[:len(channel_name)-1]+str(channel_suffix-2) # ex: FC3<-FC1, FC4<-FC2
149
-
150
- if (potential_neighbor in input_labels_dict) and (input_used[input_labels_dict[potential_neighbor]]==0):
151
- new_idx[i] = input_labels_dict[potential_neighbor]
152
- new_idx_name[i] = potential_neighbor # tmp
153
-
154
- input_used[new_idx[i]] = 1
155
- else:
156
- if fill_mode == "zero":
157
- new_idx[i] = z_row_idx
158
- elif fill_mode == "adjacent": # 先這樣暫時這樣...QQ
159
- mid_channel = template_pos[template_pos_idx[i][0]][2]
160
- mid_channel_idx = template_montage.ch_names.index(mid_channel)
161
- new_idx[i] = new_idx[mid_channel_idx]
162
- new_idx_name[i] = mid_channel # tmp
163
-
164
- #finish_flag = 0
165
-
166
- #if finish_flag == 1:
167
- # print('Finish at stage 3,4 !')
168
- # reorder_data(input_file, new_idx) # & save data to mapped.csv
169
- # return
170
- #else:
171
- # print('Error: the channel mapping process has failed!')
172
- reorder_data(input_file, new_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
 
1
  import utils
2
+
3
+ import time
4
  import os
5
  import numpy as np
6
 
7
  import mne
8
  from mne.channels import read_custom_montage
9
 
10
+ def reorder_data(filename, old_idx, fill_mode):
11
  filepath = os.path.dirname(str(filename))
12
  old_data = utils.read_train_data(filename)
13
  new_data = np.zeros((30, old_data.shape[1]))
14
+ print('old data shape: ', old_data.shape)
15
+
16
+ if fill_mode == 'zero':
17
+ zero_arr = np.zeros((1, old_data.shape[1]))
18
+ old_data = np.concatenate((old_data, zero_arr), axis=0)
19
 
20
+ for i in range(30):
21
+ new_data[i, :] = old_data[old_idx[i]-1, :]
22
 
23
+ print('new data shape: ', new_data.shape)
24
  utils.save_data(new_data, filepath+'/mapped.csv')
25
  return
26
 
27
+
28
+ class Channel: # (DigMontage):
29
+
30
+ def __init__(self, index, name=None, used=False, coord=None, topo_index=None, topo_location=None):
31
+
32
+ # super().__init__()
33
+ # self._montage = montage
34
+
35
+ self.name = name
36
+ self.index = index
37
+ self.used = used
38
+ self.coord = coord
39
+ self.topo_index = topo_index
40
+ self.topo_location = topo_location
41
+
42
+ def prefix(self):
43
+ ret = ''.join(filter(str.isalpha, self.name))
44
+ return ret[:len(ret) - 1] if ret[-1] == 'Z' else ret
45
+
46
+ def suffix(self):
47
+ return -1 if self.name[-1] == 'Z' else int(''.join(filter(str.isdigit, self.name)))
48
+
49
+
50
  def mapping(input_file, loc_file, fill_mode):
51
+ second1 = time.time()
52
+
53
  template_montage = read_custom_montage("./template_chanlocs.loc")
54
  input_montage = read_custom_montage(loc_file)
55
  #template_montage.plot()
56
  #input_montage.plot()
57
+
58
+ template = {}
59
+ input = {}
60
 
61
+ # convert all channel names to uppercase
62
  for i in range(30):
63
+ channel = template_montage.ch_names[i]
64
+ template_montage.rename_channels({channel: str.upper(channel)})
65
+
66
+ channel = str.upper(channel)
67
+ template[channel] = Channel(index=i, name=channel)
68
 
69
  for i in range(len(input_montage.ch_names)):
70
+ channel = input_montage.ch_names[i]
71
+ input_montage.rename_channels({channel: str.upper(channel)})
72
+
73
+ channel = str.upper(channel)
74
+ input[channel] = Channel(index=i, coord=input_montage.get_positions()['ch_pos'][channel])
75
 
76
 
77
  new_idx = [-1]*30
78
+ new_idx_name = ['']*30 # tmp
79
+ missing_channels = []
80
+ z_row_idx = len(input_montage.ch_names)
81
+
82
+
83
+ # STAGE_1
84
+
85
+ # match the template's channel names with the input ones
86
  finish_flag = 1
87
+ alias = {
88
+ 'T3': 'T7',
89
+ 'T4': 'T8',
90
+ 'T5': 'P7',
91
+ 'T6': 'P8',
92
+ 'TP7': 'T5\'',
93
+ 'TP8': 'T6\'',
94
+ }
95
 
 
96
  for i in range(30):
97
+ channel = template_montage.ch_names[i]
98
+ if channel not in input.keys() | alias.keys():
99
+ finish_flag = 0
100
+ continue
 
101
 
102
+ if channel not in input and channel in alias:
103
+ if alias[channel] in input:
104
+ template_montage.rename_channels({channel: alias[channel]})
105
+ template[alias[channel]] = template.pop(channel)
106
+ template[alias[channel]].name = alias[channel]
107
+ channel = alias[channel]
108
+ else:
109
+ finish_flag = 0
110
+ continue
111
 
112
+ new_idx[i] = input[channel].index
113
+ new_idx_name[i] = channel # tmp
114
+ input[channel].used = True
 
 
115
 
 
 
 
 
116
  if finish_flag == 1:
117
+ second2 = time.time()
118
+ print('Finish at stage 1 ! (',second2 - second1,'s)')
119
+ #print('new idx order:', new_idx)
120
+ #print('new_idx_name:', new_idx_name) # tmp
121
+
122
+ reorder_data(input_file, new_idx, fill_mode) # & save data to mapped.csv
123
  return
124
 
125
 
126
 
127
+ # STAGE_2
128
+
129
+ # store channel positions in a 2-d array
130
+ template_topo_pos = []
131
  temporal_channels = []
132
+ temporal_row_prefix = ['FC', 'C', 'CP', 'P']
133
 
134
  cnt = 0
135
  for i in range(7):
 
138
  if [i,j] in [[0,0],[0,2],[0,4],[6,0],[6,4]]:
139
  tmp.append('')
140
  else:
141
+ channel = template_montage.ch_names[cnt]
142
+ tmp.append(channel)
143
+
144
+ ver = 'front' if i<3 else 'center' if i==3 else 'back'
145
+ hor = 'left' if j<2 else 'center' if j==2 else 'right'
146
+ template[channel].topo_index = [i, j]
147
+ template[channel].topo_location = [ver, hor]
148
 
149
+ if i > 1 and j in [0, 4]:
150
+ temporal_channels.append(channel)
151
  cnt += 1
152
+ template_topo_pos.append(tmp)
153
+
154
+
155
+ # ensure that CZ is found or imputed by another channel
156
+ if 'CZ' not in input and fill_mode=='adjacent':
 
 
157
  min_dist = 1e5
 
158
  for channel in input_montage.ch_names:
159
+ curr_x, curr_y, curr_z = input[channel].coord.round(6)
 
 
 
 
160
 
161
+ if curr_x**2 + curr_y**2 < min_dist:
162
+ nearest_channel = channel
163
+ min_dist = curr_x**2 + curr_y**2
164
 
165
+ if input[nearest_channel].used == True:
166
+ missing_channels.append('CZ')
167
+ input[nearest_channel].used = True
168
+ input['CZ'] = input[nearest_channel]
169
 
170
+ print("CZ's nearest neighbor:", nearest_channel)
171
 
 
 
172
 
173
  for i in range(30):
174
  if new_idx[i] != -1:
175
  continue
176
 
177
+ channel = template_montage.ch_names[i]
 
 
178
 
179
+ curr_prefix = template[channel].prefix()
180
+ curr_suffix = template[channel].suffix()
181
 
182
+ curr_row = template[channel].topo_index[0]
183
+ curr_col = template[channel].topo_index[1]
184
+ curr_ver = template[channel].topo_location[0]
185
+ curr_hor = template[channel].topo_location[1]
186
+
187
+ impute_channel = ''
188
+
189
+ # if the current channel is a temporal channel
190
+ if channel in temporal_channels:
191
+ curr_prefix = temporal_row_prefix[temporal_channels.index(channel)//2]
192
+ curr_suffix = 7 if curr_hor=='left' else 8
193
+
194
+ if fill_mode == 'zero':
195
 
196
+ impute_channel = curr_prefix+str(1) if curr_hor=='center' else curr_prefix+str(curr_suffix-2)
197
+ if impute_channel not in input or input[impute_channel].used==True:
198
+ impute_channel = ''
199
+ new_idx[i] = z_row_idx
200
+ missing_channels.append(channel)
201
+ continue
202
+
203
+ elif fill_mode == 'adjacent':
204
+
205
+ if curr_hor == 'center': # FZ, FPZ, CZ...
206
+
207
+ if curr_prefix+str(1) in input: # ex: FZ<-F1
208
+ impute_channel = curr_prefix + str(1)
209
+ elif curr_ver=='front' and template_topo_pos[curr_row+1][curr_col] in input: # ex: FZ<-FCZ
210
+ impute_channel = template_topo_pos[curr_row+1][curr_col]
211
+
212
+ elif curr_ver=='back' and template_topo_pos[curr_row-1][curr_col] in input: # ex: PZ<-CPZ
213
+ impute_channel = template_topo_pos[curr_row-1][curr_col]
214
+
215
+ elif curr_prefix+str(3) in input: # ex: FZ<-F3
216
+ impute_channel = curr_prefix + str(3)
217
+
218
  else:
219
+ impute_channel = 'CZ'
220
+
221
+ elif curr_hor == 'left' or curr_hor == 'right':
222
+
223
+ ver_ctrl = 1 if curr_ver=='front' else 2 if curr_ver=='back' else 3 # bit0: row+1, bit1: row-1
224
+ ver_dir = 1 if curr_ver == 'front' else -1
225
+
226
+ # search horizontally
227
+ cnt = 0
228
+ tmp_suffix = curr_suffix
229
+ while tmp_suffix > 0: # ex: F7<-F5/F3/F1
230
+ tmp_suffix = curr_suffix - 2*cnt
231
+ if curr_prefix+str(tmp_suffix) in input:
232
+ impute_channel = curr_prefix + str(tmp_suffix)
233
+ break
234
+
235
+ if cnt == 2:
236
+ # check row+1/row-1
237
+ if ver_ctrl&1 and template_topo_pos[curr_row+1][curr_col] in input:
238
+ impute_channel = template_topo_pos[curr_row+1][curr_col]
239
+ break
240
+ if ver_ctrl&2 and template_topo_pos[curr_row-1][curr_col] in input:
241
+ impute_channel = template_topo_pos[curr_row-1][curr_col]
242
+ break
243
+ cnt += 1
244
+
245
+ # search vertically
246
+ if impute_channel == '':
247
+ cnt = 0
248
+ tmp_row = curr_row + ver_dir
249
+ while tmp_row-ver_dir != 3: # terminate if the last channel is a middle one
250
+ if template_topo_pos[tmp_row][curr_col] in input:
251
+ impute_channel = template_topo_pos[tmp_row][curr_col]
252
+ break
253
+ tmp_row += ver_dir
254
+
255
+ # if still cannot find available channel...
256
+ if impute_channel == '':
257
+ impute_channel = 'CZ'
258
+
259
+ new_idx[i] = input[impute_channel].index
260
+ new_idx_name[i] = impute_channel # tmp
261
+ if input[impute_channel].used == True: # this channel is shared with others
262
+ missing_channels.append(channel)
263
+ input[impute_channel].used = True
264
+
265
+ second2 = time.time()
266
+ print('Finish at stage 2 ! (',second2 - second1,'s)')
267
+ #print('new_idx:', new_idx)
268
+ #print('new_idx_name:', new_idx_name) # tmp
269
+ print('missing_channels:', missing_channels)
270
+ reorder_data(input_file, new_idx, fill_mode) # & save data to mapped.csv
271
+
272
+ return
273