audrey06100 commited on
Commit
4f8474c
·
1 Parent(s): 7e31e3e
Files changed (1) hide show
  1. utils.py +235 -0
utils.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import csv
3
+ #from model import cumbersome_model2
4
+ #from model import UNet_family
5
+ #from model import UNet_attention
6
+ from model import tf_model
7
+ from model import tf_data
8
+ #from opts import get_opts
9
+ #from tools import pick_models
10
+
11
+ import time
12
+ import torch
13
+ import os
14
+ import random
15
+ import shutil
16
+ from scipy.signal import decimate, resample_poly, firwin, lfilter
17
+
18
+
19
# Pin all model execution to the first CUDA device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20
+
21
def resample(signal, fs):
    """Resample every channel of `signal` to 256 Hz.

    Rates above 256 Hz are decimated by an integer factor (with
    anti-aliasing); rates below 256 Hz are upsampled with a polyphase
    filter. A 256 Hz input passes through unchanged.

    Returns a float64 numpy array of the resampled channels.
    """
    target_fs = 256
    if fs > target_fs:
        # Integer downsampling factor (truncates for non-multiples of 256).
        factor = int(fs / target_fs)
        resampled = [decimate(channel, factor) for channel in signal]
    elif fs < target_fs:
        # Integer upsampling factor.
        factor = int(target_fs / fs)
        resampled = [resample_poly(channel, factor, 1) for channel in signal]
    else:
        resampled = signal

    return np.array(resampled).astype(np.float64)
46
+
47
def FIR_filter(signal, lowcut, highcut):
    """Band-pass `signal` between `lowcut` and `highcut` Hz.

    Assumes 256 Hz data. A 1000-tap FIR band-pass filter is designed
    with the window method and applied causally along the last axis.
    """
    sample_rate = 256.0
    n_taps = 1000
    coefficients = firwin(n_taps, [lowcut, highcut], pass_zero=False, fs=sample_rate)
    # Denominator of 1.0 keeps the filter purely FIR (causal application).
    return lfilter(coefficients, 1.0, signal)
57
+
58
+
59
def read_train_data(file_name):
    """Load a CSV file of numeric rows into a float64 numpy array."""
    with open(file_name, 'r', newline='') as f:
        rows = [row for row in csv.reader(f)]
    return np.array(rows).astype(np.float64)
68
+
69
+
70
def cut_data(raw_data):
    """Split `raw_data` into 1024-sample segments saved as ./temp2/<i>.csv.

    Returns the number of complete segments written; a trailing remainder
    shorter than 1024 samples is dropped.
    """
    raw_data = np.array(raw_data).astype(np.float64)
    n_segments = int(len(raw_data[0]) / 1024)
    for idx in range(n_segments):
        segment = raw_data[:, idx * 1024:(idx + 1) * 1024]
        with open('./temp2/' + str(idx) + '.csv', 'w', newline='') as csvfile:
            csv.writer(csvfile).writerows(segment)
    return n_segments
80
+
81
+
82
def glue_data(file_name, total, output):
    """Concatenate <file_name>output{i}.csv segments into one CSV at `output`.

    Adjacent segments are blended at the seam: the last column of the data
    glued so far is averaged with column index 1 of the incoming segment,
    and the average is written into both before concatenation along axis 1.
    """
    gluedata = 0
    for i in range(total):
        file_name1 = file_name + 'output{}.csv'.format(str(i))
        with open(file_name1, 'r', newline='') as f:
            lines = csv.reader(f)
            raw_data = []
            for line in lines:
                raw_data.append(line)
        raw_data = np.array(raw_data).astype(np.float64)
        #print(i)
        if i == 0:
            gluedata = raw_data
        else:
            # NOTE(review): this blends column index 1 (the SECOND column) of
            # the incoming segment with the last column of the glued data; if
            # the intent is to smooth the seam itself, index 0 looks more
            # plausible — confirm before changing.
            smooth = (gluedata[:, -1] + raw_data[:, 1]) / 2
            gluedata[:, -1] = smooth
            raw_data[:, 1] = smooth
            gluedata = np.append(gluedata, raw_data, axis=1)
    #print(gluedata.shape)
    filename2 = output
    with open(filename2, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(gluedata)
    #print("GLUE DONE!" + filename2)
106
+
107
+
108
def save_data(data, filename):
    """Write `data` (an iterable of rows) to `filename` as CSV."""
    with open(filename, 'w', newline='') as out:
        csv.writer(out).writerows(data)
112
+
113
def dataDelete(path):
    """Recursively remove `path`; report (but swallow) any OSError.

    Best-effort cleanup: a missing or locked directory is only printed,
    never raised to the caller.
    """
    try:
        shutil.rmtree(path)
    except OSError as err:
        print(err)
121
+
122
+
123
def decode_data(data, std_num, mode=5):
    """Run the artifact-removal model selected by `mode` over one chunk.

    Args:
        data: 2-D numeric array (channels x samples); a batch axis is
            prepended before it is fed to the network.
        std_num: standard deviation of the normalized input — unused in
            every branch visible here (presumably intended for
            de-normalization; confirm against callers).
        mode: model selector, one of "ICUNet", "UNetpp", "AttUnet" or
            "EEGART". The default of 5 matches no branch, which leaves
            `decode` unbound and raises UnboundLocalError at step 4.

    Returns:
        float64 numpy array holding the decoded (denoised) batch.
    """

    if mode == "ICUNet":
        # 1. read name
        # NOTE(review): the `cumbersome_model2` import is commented out at
        # the top of this file, so this branch currently raises NameError.
        model = cumbersome_model2.UNet1(n_channels=30, n_classes=30)
        resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
        # 2. load model (CPU only; strict=False ignores mismatched keys)
        checkpoint = torch.load(resumeLoc, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'], False)
        model.eval()
        # 3. decode strategy
        with torch.no_grad():
            data = data[np.newaxis, :, :]
            data = torch.Tensor(data)
            decode = model(data)


    elif mode == "UNetpp" or mode == "AttUnet":
        # 1. read name
        # NOTE(review): `UNet_family` / `UNet_attention` imports are also
        # commented out at the top of this file — NameError here too.
        if mode == "UNetpp":
            model = UNet_family.NestedUNet3(num_classes=30)
        elif mode == "AttUnet":
            model = UNet_attention.UNetpp3_Transformer(num_classes=30)
        resumeLoc = './model/'+ mode + '/modelsave' + '/checkpoint.pth.tar'
        # 2. load model
        checkpoint = torch.load(resumeLoc, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'], False)
        model.eval()
        # 3. decode strategy: the network returns three outputs; only the
        # last one is kept as the result.
        with torch.no_grad():
            data = data[np.newaxis, :, :]
            data = torch.Tensor(data)
            decode1, decode2, decode = model(data)


    elif mode == "EEGART":
        # 1. read name
        resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
        # 2. load model (transformer built by the project's tf_model helper)
        checkpoint = torch.load(resumeLoc, map_location='cpu')
        model = tf_model.make_model(30, 30, N=2)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        # 3. decode strategy
        with torch.no_grad():
            data = torch.FloatTensor(data)
            data = data.unsqueeze(0)
            src = data
            tgt = data
            batch = tf_data.Batch(src, tgt, 0)
            # Second argument is the source with its first sample along the
            # last axis dropped (shifted target for the decoder).
            out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
            decode = model.generator(out)
            # (batch, time, channel) -> (batch, channel, time)
            decode = decode.permute(0, 2, 1)
            #add_tensor = torch.zeros(1, 30, 1)
            #decode = torch.cat((decode, add_tensor), dim=2)

    # 4. numpy
    #print(decode.shape)
    decode = np.array(decode.cpu()).astype(np.float64)
    return decode
183
+
184
def preprocessing(filename, samplerate):
    """Read an EEG CSV, resample to 256 Hz, band-pass 1-50 Hz, and cut it
    into 1024-sample chunk files under ./temp2/.

    Returns the number of chunk files written.
    """
    # Ensure ./temp2/ exists and starts empty: if it already exists,
    # wipe and recreate it.
    try:
        os.mkdir("./temp2/")
    except OSError as err:
        dataDelete("./temp2/")
        os.mkdir("./temp2/")
        print(err)

    # read data
    signal = read_train_data(filename)
    # resample to the model's expected 256 Hz rate
    signal = resample(signal, samplerate)
    # band-pass 1-50 Hz
    signal = FIR_filter(signal, 1, 50)
    # cut into fixed-size segments for the network
    return cut_data(signal)
206
+
207
+
208
+ # model = tf.keras.models.load_model('./denoise_model/')
209
def reconstruct(model_name, total, outputfile):
    """Denoise every ./temp2/<i>.csv chunk with `model_name`, glue the
    outputs into `outputfile`, then delete the temp directory."""
    second1 = time.time()
    # -------------------decode_data---------------------------
    for idx in range(total):
        chunk = read_train_data('./temp2/{}.csv'.format(str(idx)))

        std = np.std(chunk)
        avg = np.average(chunk)
        # Z-score normalization before the network.
        normalized = (chunk - avg) / std

        # Deep Learning Artifact Removal; keep only the first batch item.
        denoised = decode_data(normalized, std, model_name)[0]

        save_data(denoised, "./temp2/output{}.csv".format(str(idx)))

    # --------------------glue_data----------------------------
    glue_data("./temp2/", total, outputfile)
    # -------------------delete_data---------------------------
    dataDelete("./temp2/")
    second2 = time.time()

    print("Using", model_name,"model to reconstruct", outputfile, " has been success in", second2 - second1, "sec(s)")
235
+