File size: 12,124 Bytes
7129427
94bf054
7129427
 
a22369d
7129427
 
 
a22369d
 
 
 
884c10b
995c1d0
a22369d
 
884c10b
995c1d0
 
8f041e4
 
 
94bf054
 
8f041e4
 
 
 
 
 
 
 
a22369d
 
f1a11e6
7129427
 
884c10b
995c1d0
a22369d
884c10b
a22369d
884c10b
 
a22369d
 
 
 
 
 
 
884c10b
a22369d
 
 
 
 
94bf054
8f041e4
94bf054
a22369d
8f041e4
 
a22369d
 
8f041e4
7a54f74
94bf054
 
a22369d
 
7129427
 
a22369d
 
94bf054
7a54f74
 
 
 
 
 
 
a22369d
7a54f74
 
a22369d
 
 
 
884c10b
a22369d
884c10b
 
 
 
a22369d
7a54f74
a22369d
 
 
 
7a54f74
a22369d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94bf054
a22369d
 
 
 
94bf054
a22369d
 
 
 
 
 
 
 
 
 
884c10b
a22369d
 
 
884c10b
a22369d
884c10b
a22369d
995c1d0
884c10b
 
 
 
 
a22369d
 
884c10b
a22369d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
995c1d0
a22369d
 
 
 
 
995c1d0
 
 
 
 
 
884c10b
a22369d
884c10b
a22369d
94bf054
a22369d
 
 
 
94bf054
 
 
 
 
a22369d
 
7a54f74
7129427
a22369d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
884c10b
a22369d
 
 
 
 
 
884c10b
995c1d0
 
884c10b
a22369d
 
884c10b
a22369d
884c10b
a22369d
 
 
884c10b
7129427
884c10b
a22369d
94bf054
884c10b
 
 
a22369d
 
884c10b
 
8f041e4
a22369d
 
7129427
a22369d
 
 
7129427
a22369d
 
995c1d0
a22369d
 
 
 
995c1d0
 
7129427
a22369d
 
94bf054
a22369d
 
7129427
a22369d
 
0ab020b
a22369d
 
 
 
 
995c1d0
 
94bf054
884c10b
a22369d
 
 
884c10b
995c1d0
 
884c10b
94bf054
a22369d
884c10b
7129427
a22369d
884c10b
 
a22369d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import utils
import time
import os
import numpy as np
import gradio as gr

import mne
from mne.channels import read_custom_montage
from scipy.interpolate import Rbf
from scipy.optimize import linear_sum_assignment
from sklearn.neighbors import NearestNeighbors

def reorder_to_template(app_state, filename):
	"""Build the 30-row, template-ordered data file from the raw input data.

	Row ``i`` of the output is the element-wise mean of the input rows listed in
	the current stage's index set (``app_state["stage1NewOrder"]`` when
	``app_state["runnigState"] == "stage1"``, otherwise
	``app_state["stage2NewOrder"]``). An empty index set leaves the row at zero.
	One all-zero row is appended to the input before indexing, so an index equal
	to the original row count selects zeros.

	The result is written with ``utils.save_data`` to
	``app_state["filepath"] + 'mapped.csv'``. Returns None.
	"""
	if app_state["runnigState"] == "stage1":
		index_sets = app_state["stage1NewOrder"]
	else:
		index_sets = app_state["stage2NewOrder"]

	# NOTE(review): rows are indexed as channels and columns as samples here — confirm against utils.read_train_data
	source = utils.read_train_data(filename)
	n_cols = source.shape[1]
	out_path = app_state["filepath"] + 'mapped.csv'

	padding = np.zeros((1, n_cols))
	source = np.concatenate((source, padding), axis=0)

	mapped = np.zeros((30, n_cols))
	for row in range(30):
		members = index_sets[row]
		if members == []:
			continue  # the row is already all zeros
		mapped[row, :] = np.mean([source[j, :] for j in members], axis=0)

	print('old.shape, new.shape: ', source.shape, mapped.shape)
	utils.save_data(mapped, out_path)

def reorder_to_origin(app_state, channel_info, filename, new_filename):
	"""Write template-ordered (denoised) rows back into the input channel order.

	For each template channel ``i`` (in ``channel_info["templateByIndex"]`` order),
	row ``i`` of the data read from ``filename`` is copied to input row
	``index_set[0]``. The copy happens only when the current stage's index set for ``i``
	holds exactly one index and the template channel is flagged ``matched``.
	Every other template row is skipped.

	In stage 1 the output starts as zeros with one row per input channel. In
	later stages it starts from the current contents of ``new_filename``, so rows
	written earlier are kept unless overwritten. The result is saved to
	``new_filename``. Returns None.
	"""
	first_stage = app_state["runnigState"] == "stage1"
	index_sets = app_state["stage1NewOrder"] if first_stage else app_state["stage2NewOrder"]
	denoised = utils.read_train_data(filename)
	tpl_names = channel_info["templateByIndex"]

	if first_stage:
		restored = np.zeros((len(channel_info["inputByName"]), denoised.shape[1]))
	else:
		restored = utils.read_train_data(new_filename)

	for row, name in enumerate(tpl_names):
		members = index_sets[row]
		# rows filled by zeros/averages are not one-to-one and are not copied back
		one_to_one = len(members) == 1 and channel_info["templateByName"][name]["matched"] == True
		if one_to_one:
			restored[members[0], :] = denoised[row, :]

	print('old.shape, new.shape: ', denoised.shape, restored.shape)
	utils.save_data(restored, new_filename)

class Channel:
	"""Per-channel record used while mapping input channels onto the template.

	Instances are turned into plain dicts via ``__dict__`` before being stored in
	``channel_info``, so the attribute names below double as dict keys.
	"""

	def __init__(self, index, name=None, matched=False, assigned=False, coord=None, css_position=None):
		# NOTE: assignment order is kept stable because __dict__ (and thus key order) is serialized
		self.name = name
		self.index = index  # position of the channel in its montage's ch_names
		self.matched = matched  # True once paired with a channel on the other side
		self.assigned = assigned  # for input channels: True once given a template slot
		self.coord = coord  # 3-D position (initially from the montage's get_positions())
		self.css_position = css_position  # [left, bottom] percentage strings for the UI

	def __repr__(self):
		return (
			f"{type(self).__name__}(index={self.index!r}, name={self.name!r}, "
			f"matched={self.matched!r}, assigned={self.assigned!r}, "
			f"coord={self.coord!r}, css_position={self.css_position!r})"
		)


def read_montage_data(loc_file, template_file="./template_chanlocs.loc", n_template=30):
	"""Load the template and input montages and index their channels by name.

	Channel names are upper-cased in place on both montages. For the template
	only the first ``n_template`` channels are processed; for the input, all
	channels are.

	Args:
		loc_file: path of the input channel-location file.
		template_file: path of the template channel-location file.
		n_template: number of leading template channels to index.

	Returns:
		``(template_montage, input_montage, template_dict, input_dict)``. Each dict
		maps an upper-cased channel name to a ``Channel``. ``index`` holds the
		channel's position in the montage's ``ch_names``, and ``coord`` holds its
		entry from ``get_positions()['ch_pos']``.

	Raises:
		ValueError: if the template montage has fewer than ``n_template`` channels.
	"""
	template_montage = read_custom_montage(template_file)
	input_montage = read_custom_montage(loc_file)
	if len(template_montage.ch_names) < n_template:
		raise ValueError(
			f"template montage {template_file!r} has {len(template_montage.ch_names)} "
			f"channels, expected at least {n_template}"
		)

	template_dict = {}
	input_dict = {}
	jobs = [
		(template_montage, template_dict, n_template),
		(input_montage, input_dict, len(input_montage.ch_names)),
	]
	for montage, by_name, count in jobs:
		# Rename in one call and fetch positions once, instead of renaming and
		# calling get_positions() once per channel.
		montage.rename_channels({ch: ch.upper() for ch in montage.ch_names[:count]})
		positions = montage.get_positions()['ch_pos']
		for j, channel in enumerate(montage.ch_names[:count]):
			by_name[channel] = Channel(index=j, name=channel, coord=positions[channel])

	return template_montage, input_montage, template_dict, input_dict

def _montage_display_coords(montage):
	"""Plot *montage* and return its sensor markers' display coordinates (px)."""
	import sys

	fig = montage.plot()
	fig.set_size_inches(5.6, 5.6)
	ax = fig.axes[0]
	ax.set_aspect('equal')
	ax.figure.canvas.draw()  # draw so transData reflects the final layout
	coords = ax.transData.transform(ax.collections[0].get_offsets().data)
	# The figure is only needed for this transform. If pyplot is loaded it
	# presumably tracks the figure; close it so repeated runs don't accumulate
	# open figures (plt.close on an untracked figure is a no-op).
	pyplot = sys.modules.get("matplotlib.pyplot")
	if pyplot is not None:
		pyplot.close(fig)
	return coords

def _css_position(x_px, y_px):
	"""Convert display coordinates (px) into CSS ``[left, bottom]`` percentage strings."""
	# 560 presumably matches the 5.6-inch figure at 100 dpi; 11/7 px look like
	# marker-centering offsets — NOTE(review): confirm.
	left = (x_px - 11) / 560
	bottom = (y_px - 7) / 560
	return [str(round(left * 100, 2)) + "%", str(round(bottom * 100, 2)) + "%"]

def align_coords(channel_info, template_montage, input_montage):
	"""Align template and input channel positions with thin-plate RBF warps.

	2-D: the matched channels' plot coordinates define a warp from the template
	plot onto the input plot. Every template channel's ``css_position`` is set
	from its warped position, and every input channel's from its own position.
	This is used to show missing template channels' positions.

	3-D: the matched channels' montage coordinates define a warp from input
	space to template space. Every input channel's ``coord`` is replaced by its
	warped position (as a list), for the KNN search in ``fill_channels``.

	Dicts in ``channel_info`` are updated in place; ``channel_info`` is returned.

	Raises:
		ValueError: if no input channel is flagged ``matched``.
	"""
	template_dict = channel_info["templateByName"]
	input_dict = channel_info["inputByName"]
	template_order = channel_info["templateByIndex"]
	input_order = channel_info["inputByIndex"]
	matched = [channel for channel in input_dict if input_dict[channel]["matched"] == True]
	if not matched:
		raise ValueError("align_coords: no input channel matches a template channel; cannot fit the coordinate transform")

	# ---- 2-D: plot positions -> CSS positions ----
	all_tpl = _montage_display_coords(template_montage)
	all_in = _montage_display_coords(input_montage)
	matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
	matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])

	# warp template plot coords into the input plot's frame
	rbf_x = Rbf(matched_tpl[:, 0], matched_tpl[:, 1], matched_in[:, 0], function='thin_plate')
	rbf_y = Rbf(matched_tpl[:, 0], matched_tpl[:, 1], matched_in[:, 1], function='thin_plate')
	tpl_x = rbf_x(all_tpl[:, 0], all_tpl[:, 1])
	tpl_y = rbf_y(all_tpl[:, 0], all_tpl[:, 1])

	for i, channel in enumerate(template_order):
		template_dict[channel]["css_position"] = _css_position(tpl_x[i], tpl_y[i])
	for i, channel in enumerate(input_order):
		input_dict[channel]["css_position"] = _css_position(all_in[i][0], all_in[i][1])

	# ---- 3-D: warp input coords into template space ----
	# dtype=float accepts both ndarray and list coords
	all_tpl = np.array([template_dict[channel]["coord"] for channel in template_order], dtype=float)
	all_in = np.array([input_dict[channel]["coord"] for channel in input_order], dtype=float)
	matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
	matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])

	rbf_x = Rbf(matched_in[:, 0], matched_in[:, 1], matched_in[:, 2], matched_tpl[:, 0], function='thin_plate')
	rbf_y = Rbf(matched_in[:, 0], matched_in[:, 1], matched_in[:, 2], matched_tpl[:, 1], function='thin_plate')
	rbf_z = Rbf(matched_in[:, 0], matched_in[:, 1], matched_in[:, 2], matched_tpl[:, 2], function='thin_plate')

	warped_in = np.vstack((
		rbf_x(all_in[:, 0], all_in[:, 1], all_in[:, 2]),
		rbf_y(all_in[:, 0], all_in[:, 1], all_in[:, 2]),
		rbf_z(all_in[:, 0], all_in[:, 1], all_in[:, 2]),
	)).T

	for i, channel in enumerate(input_order):
		input_dict[channel]["coord"] = warped_in[i].tolist()

	channel_info.update({
	    "templateByName" : template_dict,
	    "inputByName" : input_dict,
	})
	return channel_info

def fill_channels(app_state, channel_info, fill_mode, n_neighbors=4):
	"""Give every unmatched template channel a set of input row indices.

	Works on the current stage's index list: ``app_state["stage1NewOrder"]`` when
	``app_state["runnigState"] == "stage1"``, else ``app_state["stage2NewOrder"]``.
	The list is updated in place and stored back under the same key.

	fill_mode:
		'zero'      -- index set becomes ``[channel_info["dataShape"][0]]``
		               (presumably the all-zero row reorder_to_template appends).
		'mean_auto' -- index set becomes the ``n_neighbors`` input channels (capped
		               at the number of inputs) nearest to the template channel by
		               Euclidean distance between ``coord`` values.
		other       -- the index list is left unchanged.

	Returns ``app_state``; it is returned immediately if every template channel is matched.
	"""
	stage_key = "stage1NewOrder" if app_state["runnigState"] == "stage1" else "stage2NewOrder"
	new_idx = app_state[stage_key]
	template_dict = channel_info["templateByName"]
	input_dict = channel_info["inputByName"]
	input_order = channel_info["inputByIndex"]

	unmatched = [channel for channel in template_dict if template_dict[channel]["matched"] == False]
	if not unmatched:
		return app_state

	if fill_mode == 'zero':
		z_row_idx = channel_info["dataShape"][0]
		for channel in unmatched:
			new_idx[template_dict[channel]["index"]] = [z_row_idx]

	elif fill_mode == 'mean_auto':
		in_coords = np.array([input_dict[channel]["coord"] for channel in input_order])
		knn = NearestNeighbors(n_neighbors=min(n_neighbors, len(input_dict)), metric='euclidean')
		knn.fit(in_coords)

		# one batched query; each row's neighbours are independent of the others
		queries = np.array([np.asarray(template_dict[channel]["coord"]).reshape(-1) for channel in unmatched])
		_, neighbour_rows = knn.kneighbors(queries)

		for channel, indices in zip(unmatched, neighbour_rows):
			print(channel, ':', [input_order[i] for i in indices])
			# positions in input_order are the input channels' row indices
			new_idx[template_dict[channel]["index"]] = indices.tolist()

	app_state[stage_key] = new_idx
	return app_state

def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
	"""Stage 1 of the channel mapping: pair template and input channels by name.

	Reads both montages with ``read_montage_data`` and matches each template
	channel to the input channel with the same upper-cased name. A template
	channel T3/T4/T5/T6 is first renamed to T7/T8/P7/P8 when that alias appears
	in the input. Coordinates are then aligned with ``align_coords``, and
	unmatched template channels are filled with ``fill_channels(..., fill_mode)``.

	``data_file`` is not used here (NOTE(review): presumably kept for the caller's signature).

	Returns ``(app_state, channel_info)``. ``app_state["stage1NewOrder"]`` holds one
	list of input row indices per template channel, and
	``channel_info["missingChannelsIndex"]`` lists the template positions without
	a name match.
	"""
	second1 = time.time()

	template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
	# snapshot: the loop below renames template channels
	template_order = list(template_montage.ch_names)
	# One distinct list per template slot. [[]]*30 would make every slot share a
	# single list object, so an in-place append to one slot would hit all of them.
	new_idx = [[] for _ in range(30)]
	missing_channels = []
	alias = {
		'T3': 'T7',
		'T4': 'T8',
		'T5': 'P7',
		'T6': 'P8',
		#'TP7': 'T5\'',
		#'TP8': 'T6\'',
	}

	# match the names of input channels -> template channels
	for i, channel in enumerate(template_order):
		if channel in alias and alias[channel] in input_dict:
			new_name = alias[channel]
			template_montage.rename_channels({channel: new_name})
			template_dict[new_name] = template_dict.pop(channel)
			template_dict[new_name].name = new_name
			channel = new_name

		if channel in input_dict:
			new_idx[i] = [input_dict[channel].index]
			template_dict[channel].matched = True
			input_dict[channel].matched = True
			input_dict[channel].assigned = True
		else:
			missing_channels.append(i)

	channel_info.update({
	    "missingChannelsIndex" : missing_channels,
	    "templateByName" : {k : v.__dict__ for k,v in template_dict.items()},
	    "inputByName" : {k : v.__dict__ for k,v in input_dict.items()},
	    "templateByIndex" : template_montage.ch_names,
		"inputByIndex" : input_montage.ch_names
	})
	app_state.update({
	    "stage1NewOrder" : new_idx,
	    "runnigState" : "stage1"
	})

	# align input, template's coordinates
	channel_info = align_coords(channel_info, template_montage, input_montage)
	# fill the unmatched channels
	app_state = fill_channels(app_state, channel_info, fill_mode)

	second2 = time.time()
	print('Mapping (stage1) finished in',second2 - second1,'s.')
	return app_state, channel_info

def mapping_stage2(app_state, channel_info, fill_mode):
	"""Stage 2 of the channel mapping: place still-unassigned input channels by position.

	If no input channel is unassigned, sets ``app_state["runnigState"]`` to
	``"finished"`` and returns immediately.

	Otherwise every template channel's ``matched`` flag is reset. The unassigned
	input channels are then paired one-to-one with template channels, minimizing
	the total Euclidean distance between their ``coord`` values (Hungarian
	algorithm via ``linear_sum_assignment``). Paired template channels become
	``matched`` and paired inputs ``assigned``. ``app_state["stage2NewOrder"]``
	gets the paired input row index for each template slot. Slots left empty
	are handled by ``fill_channels(..., fill_mode)``.

	Returns ``(app_state, channel_info)``.
	"""
	second1 = time.time()

	template_dict = channel_info["templateByName"]
	input_dict = channel_info["inputByName"]
	template_order = channel_info["templateByIndex"]
	unassigned = [channel for channel in input_dict if input_dict[channel]["assigned"] == False]
	if not unassigned:
		app_state["runnigState"] = "finished"
		return app_state, channel_info

	tpl_coords = np.array([template_dict[channel]["coord"] for channel in template_order], dtype=float)
	unassigned_coords = np.array([input_dict[channel]["coord"] for channel in unassigned], dtype=float)

	# every template slot starts unmatched for this batch
	for channel in template_dict:
		template_dict[channel]["matched"] = False

	# cost[i, j] = distance between template i and unassigned input j,
	# with coords scaled by 1000 (presumably metres -> millimetres)
	diff = tpl_coords[:, None, :] - unassigned_coords[None, :, :]
	cost_matrix = np.linalg.norm(diff * 1000, axis=2)

	# linear_sum_assignment accepts rectangular matrices. With fewer unassigned
	# inputs than template slots, only len(unassigned) slots get a partner, so no
	# dummy padding is needed.
	row_idx, col_idx = linear_sum_assignment(cost_matrix)

	new_idx = [[] for _ in range(30)]  # distinct list per slot (no [[]]*30 aliasing)
	for tpl_row, in_col in zip(row_idx, col_idx):
		tpl_channel = template_order[tpl_row]
		in_channel = unassigned[in_col]
		template_dict[tpl_channel]["matched"] = True
		input_dict[in_channel]["assigned"] = True
		new_idx[tpl_row] = [input_dict[in_channel]["index"]]
		print(tpl_channel, '<-', in_channel)

	channel_info.update({
	    "templateByName" : template_dict,
	    "inputByName" : input_dict
	})
	app_state.update({
	    "stage2NewOrder" : new_idx,
	    "runnigState" : "stage2"
	})

	# fill the unmatched channels
	app_state = fill_channels(app_state, channel_info, fill_mode)

	second2 = time.time()
	print(f'Mapping (stage2-{app_state["batchCount"]-1}) finished in {second2 - second1}s.')
	return app_state, channel_info