Spaces:
Sleeping
Sleeping
Commit
·
4f8474c
1
Parent(s):
7e31e3e
init
Browse files
utils.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import csv
|
3 |
+
#from model import cumbersome_model2
|
4 |
+
#from model import UNet_family
|
5 |
+
#from model import UNet_attention
|
6 |
+
from model import tf_model
|
7 |
+
from model import tf_data
|
8 |
+
#from opts import get_opts
|
9 |
+
#from tools import pick_models
|
10 |
+
|
11 |
+
import time
|
12 |
+
import torch
|
13 |
+
import os
|
14 |
+
import random
|
15 |
+
import shutil
|
16 |
+
from scipy.signal import decimate, resample_poly, firwin, lfilter
|
17 |
+
|
18 |
+
|
19 |
+
os.environ["CUDA_VISIBLE_DEVICES"]="0"
|
20 |
+
|
21 |
+
def resample(signal, fs):
|
22 |
+
# downsample the signal to a sample rate of 256 Hz
|
23 |
+
if fs>256:
|
24 |
+
fs_down = 256 # Desired sample rate
|
25 |
+
q = int(fs / fs_down) # Downsampling factor
|
26 |
+
signal_new = []
|
27 |
+
for ch in signal:
|
28 |
+
x_down = decimate(ch, q)
|
29 |
+
signal_new.append(x_down)
|
30 |
+
|
31 |
+
# upsample the signal to a sample rate of 256 Hz
|
32 |
+
elif fs<256:
|
33 |
+
fs_up = 256 # Desired sample rate
|
34 |
+
p = int(fs_up / fs) # Upsampling factor
|
35 |
+
signal_new = []
|
36 |
+
for ch in signal:
|
37 |
+
x_up = resample_poly(ch, p, 1)
|
38 |
+
signal_new.append(x_up)
|
39 |
+
|
40 |
+
else:
|
41 |
+
signal_new = signal
|
42 |
+
|
43 |
+
signal_new = np.array(signal_new).astype(np.float64)
|
44 |
+
|
45 |
+
return signal_new
|
46 |
+
|
47 |
+
def FIR_filter(signal, lowcut, highcut):
|
48 |
+
fs = 256.0
|
49 |
+
# Number of FIR filter taps
|
50 |
+
numtaps = 1000
|
51 |
+
# Use firwin to create a bandpass FIR filter
|
52 |
+
fir_coeff = firwin(numtaps, [lowcut, highcut], pass_zero=False, fs=fs)
|
53 |
+
# Apply the filter to signal:
|
54 |
+
filtered_signal = lfilter(fir_coeff, 1.0, signal)
|
55 |
+
|
56 |
+
return filtered_signal
|
57 |
+
|
58 |
+
|
59 |
+
def read_train_data(file_name):
|
60 |
+
with open(file_name, 'r', newline='') as f:
|
61 |
+
lines = csv.reader(f)
|
62 |
+
data = []
|
63 |
+
for line in lines:
|
64 |
+
data.append(line)
|
65 |
+
|
66 |
+
data = np.array(data).astype(np.float64)
|
67 |
+
return data
|
68 |
+
|
69 |
+
|
70 |
+
def cut_data(raw_data):
|
71 |
+
raw_data = np.array(raw_data).astype(np.float64)
|
72 |
+
total = int(len(raw_data[0]) / 1024)
|
73 |
+
for i in range(total):
|
74 |
+
table = raw_data[:, i * 1024:(i + 1) * 1024]
|
75 |
+
filename = './temp2/' + str(i) + '.csv'
|
76 |
+
with open(filename, 'w', newline='') as csvfile:
|
77 |
+
writer = csv.writer(csvfile)
|
78 |
+
writer.writerows(table)
|
79 |
+
return total
|
80 |
+
|
81 |
+
|
82 |
+
def glue_data(file_name, total, output):
|
83 |
+
gluedata = 0
|
84 |
+
for i in range(total):
|
85 |
+
file_name1 = file_name + 'output{}.csv'.format(str(i))
|
86 |
+
with open(file_name1, 'r', newline='') as f:
|
87 |
+
lines = csv.reader(f)
|
88 |
+
raw_data = []
|
89 |
+
for line in lines:
|
90 |
+
raw_data.append(line)
|
91 |
+
raw_data = np.array(raw_data).astype(np.float64)
|
92 |
+
#print(i)
|
93 |
+
if i == 0:
|
94 |
+
gluedata = raw_data
|
95 |
+
else:
|
96 |
+
smooth = (gluedata[:, -1] + raw_data[:, 1]) / 2
|
97 |
+
gluedata[:, -1] = smooth
|
98 |
+
raw_data[:, 1] = smooth
|
99 |
+
gluedata = np.append(gluedata, raw_data, axis=1)
|
100 |
+
#print(gluedata.shape)
|
101 |
+
filename2 = output
|
102 |
+
with open(filename2, 'w', newline='') as csvfile:
|
103 |
+
writer = csv.writer(csvfile)
|
104 |
+
writer.writerows(gluedata)
|
105 |
+
#print("GLUE DONE!" + filename2)
|
106 |
+
|
107 |
+
|
108 |
+
def save_data(data, filename):
|
109 |
+
with open(filename, 'w', newline='') as csvfile:
|
110 |
+
writer = csv.writer(csvfile)
|
111 |
+
writer.writerows(data)
|
112 |
+
|
113 |
+
def dataDelete(path):
|
114 |
+
try:
|
115 |
+
shutil.rmtree(path)
|
116 |
+
except OSError as e:
|
117 |
+
print(e)
|
118 |
+
else:
|
119 |
+
pass
|
120 |
+
#print("The directory is deleted successfully")
|
121 |
+
|
122 |
+
|
123 |
+
def decode_data(data, std_num, mode=5):
|
124 |
+
|
125 |
+
if mode == "ICUNet":
|
126 |
+
# 1. read name
|
127 |
+
model = cumbersome_model2.UNet1(n_channels=30, n_classes=30)
|
128 |
+
resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
|
129 |
+
# 2. load model
|
130 |
+
checkpoint = torch.load(resumeLoc, map_location='cpu')
|
131 |
+
model.load_state_dict(checkpoint['state_dict'], False)
|
132 |
+
model.eval()
|
133 |
+
# 3. decode strategy
|
134 |
+
with torch.no_grad():
|
135 |
+
data = data[np.newaxis, :, :]
|
136 |
+
data = torch.Tensor(data)
|
137 |
+
decode = model(data)
|
138 |
+
|
139 |
+
|
140 |
+
elif mode == "UNetpp" or mode == "AttUnet":
|
141 |
+
# 1. read name
|
142 |
+
if mode == "UNetpp":
|
143 |
+
model = UNet_family.NestedUNet3(num_classes=30)
|
144 |
+
elif mode == "AttUnet":
|
145 |
+
model = UNet_attention.UNetpp3_Transformer(num_classes=30)
|
146 |
+
resumeLoc = './model/'+ mode + '/modelsave' + '/checkpoint.pth.tar'
|
147 |
+
# 2. load model
|
148 |
+
checkpoint = torch.load(resumeLoc, map_location='cpu')
|
149 |
+
model.load_state_dict(checkpoint['state_dict'], False)
|
150 |
+
model.eval()
|
151 |
+
# 3. decode strategy
|
152 |
+
with torch.no_grad():
|
153 |
+
data = data[np.newaxis, :, :]
|
154 |
+
data = torch.Tensor(data)
|
155 |
+
decode1, decode2, decode = model(data)
|
156 |
+
|
157 |
+
|
158 |
+
elif mode == "EEGART":
|
159 |
+
# 1. read name
|
160 |
+
resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
|
161 |
+
# 2. load model
|
162 |
+
checkpoint = torch.load(resumeLoc, map_location='cpu')
|
163 |
+
model = tf_model.make_model(30, 30, N=2)
|
164 |
+
model.load_state_dict(checkpoint['state_dict'])
|
165 |
+
model.eval()
|
166 |
+
# 3. decode strategy
|
167 |
+
with torch.no_grad():
|
168 |
+
data = torch.FloatTensor(data)
|
169 |
+
data = data.unsqueeze(0)
|
170 |
+
src = data
|
171 |
+
tgt = data
|
172 |
+
batch = tf_data.Batch(src, tgt, 0)
|
173 |
+
out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
|
174 |
+
decode = model.generator(out)
|
175 |
+
decode = decode.permute(0, 2, 1)
|
176 |
+
#add_tensor = torch.zeros(1, 30, 1)
|
177 |
+
#decode = torch.cat((decode, add_tensor), dim=2)
|
178 |
+
|
179 |
+
# 4. numpy
|
180 |
+
#print(decode.shape)
|
181 |
+
decode = np.array(decode.cpu()).astype(np.float64)
|
182 |
+
return decode
|
183 |
+
|
184 |
+
def preprocessing(filename, samplerate):
|
185 |
+
# establish temp folder
|
186 |
+
try:
|
187 |
+
os.mkdir("./temp2/")
|
188 |
+
except OSError as e:
|
189 |
+
dataDelete("./temp2/")
|
190 |
+
os.mkdir("./temp2/")
|
191 |
+
print(e)
|
192 |
+
|
193 |
+
# read data
|
194 |
+
signal = read_train_data(filename)
|
195 |
+
#print(signal.shape)
|
196 |
+
# resample
|
197 |
+
signal = resample(signal, samplerate)
|
198 |
+
#print(signal.shape)
|
199 |
+
# FIR_filter
|
200 |
+
signal = FIR_filter(signal, 1, 50)
|
201 |
+
#print(signal.shape)
|
202 |
+
# cutting data
|
203 |
+
total_file_num = cut_data(signal)
|
204 |
+
|
205 |
+
return total_file_num
|
206 |
+
|
207 |
+
|
208 |
+
# model = tf.keras.models.load_model('./denoise_model/')
|
209 |
+
def reconstruct(model_name, total, outputfile):
|
210 |
+
# -------------------decode_data---------------------------
|
211 |
+
second1 = time.time()
|
212 |
+
for i in range(total):
|
213 |
+
file_name = './temp2/{}.csv'.format(str(i))
|
214 |
+
data_noise = read_train_data(file_name)
|
215 |
+
|
216 |
+
std = np.std(data_noise)
|
217 |
+
avg = np.average(data_noise)
|
218 |
+
|
219 |
+
data_noise = (data_noise-avg)/std
|
220 |
+
|
221 |
+
# Deep Learning Artifact Removal
|
222 |
+
d_data = decode_data(data_noise, std, model_name)
|
223 |
+
d_data = d_data[0]
|
224 |
+
|
225 |
+
outputname = "./temp2/output{}.csv".format(str(i))
|
226 |
+
save_data(d_data, outputname)
|
227 |
+
|
228 |
+
# --------------------glue_data----------------------------
|
229 |
+
glue_data("./temp2/", total, outputfile)
|
230 |
+
# -------------------delete_data---------------------------
|
231 |
+
dataDelete("./temp2/")
|
232 |
+
second2 = time.time()
|
233 |
+
|
234 |
+
print("Using", model_name,"model to reconstruct", outputfile, " has been success in", second2 - second1, "sec(s)")
|
235 |
+
|