# AIEEG / utils.py
import csv
import os
import shutil
import time

import numpy as np
import torch
from scipy.signal import decimate, resample_poly, firwin, lfilter

from model import cumbersome_model2
from model import UNet_family
from model import UNet_attention
from model import tf_model
from model import tf_data
# pin computation to the first GPU; fall back to CPU when CUDA is unavailable
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def resample(signal, fs, tgt_fs):
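    """Resample a multi-channel signal from `fs` to the target rate `tgt_fs`.

    Operates channel by channel and assumes the two rates are related by an
    integer factor; non-integer ratios are truncated by int().
    """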
    # downsample by an integer factor when the source rate is higher
    if fs > tgt_fs:
        q = int(fs / tgt_fs)  # downsampling factor
        signal_new = [decimate(ch, q) for ch in signal]
    # upsample by an integer factor when the source rate is lower
    elif fs < tgt_fs:
        p = int(tgt_fs / fs)  # upsampling factor
        signal_new = [resample_poly(ch, p, 1) for ch in signal]
    else:
        signal_new = signal
    signal_new = np.array(signal_new).astype(np.float64)
    return signal_new
def FIR_filter(signal, lowcut, highcut):
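    """Band-pass filter each channel between `lowcut` and `highcut` Hz.

    A 1000-tap FIR filter is designed for the 256 Hz working rate and applied
    causally with lfilter, which introduces a group delay of
    (numtaps - 1) / 2 samples.
    """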
fs = 256.0
# Number of FIR filter taps
numtaps = 1000
# Use firwin to create a bandpass FIR filter
fir_coeff = firwin(numtaps, [lowcut, highcut], pass_zero=False, fs=fs)
# Apply the filter to signal:
filtered_signal = lfilter(fir_coeff, 1.0, signal)
return filtered_signal
def read_train_data(file_name):
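    """Read a CSV file into a float64 array of shape (channels, samples)."""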
with open(file_name, 'r', newline='') as f:
lines = csv.reader(f)
data = []
for line in lines:
data.append(line)
data = np.array(data).astype(np.float64)
return data
def cut_data(filepath, raw_data):
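    """Split the recording into non-overlapping 1024-sample segments.

    Each segment is written to temp2/<i>.csv; any remainder shorter than
    1024 samples is discarded. Returns the number of segments written.
    """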
raw_data = np.array(raw_data).astype(np.float64)
total = int(len(raw_data[0]) / 1024)
for i in range(total):
table = raw_data[:, i * 1024:(i + 1) * 1024]
filename = filepath + 'temp2/' + str(i) + '.csv'
with open(filename, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerows(table)
return total
def glue_data(file_name, total):
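    """Stitch the decoded segments in `file_name` back into one recording,
    averaging the values at each seam to soften segment boundaries."""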
    gluedata = None
    for i in range(total):
        segment_name = file_name + 'output{}.csv'.format(str(i))
        raw_data = read_train_data(segment_name)
        if i == 0:
            gluedata = raw_data
        else:
            # average the last glued column with column 1 of the new segment
            # so the transition between segments is smoothed
            smooth = (gluedata[:, -1] + raw_data[:, 1]) / 2
            gluedata[:, -1] = smooth
            raw_data[:, 1] = smooth
            gluedata = np.append(gluedata, raw_data, axis=1)
    return gluedata
def save_data(data, filename):
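    """Write a 2-D array to `filename` as CSV, one row per channel."""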
with open(filename, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerows(data)
def dataDelete(path):
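    """Recursively delete `path`, silently ignoring a missing directory."""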
    try:
        shutil.rmtree(path)
    except OSError:
        pass  # nothing to delete, or the directory is in use
def decode_data(data, std_num, mode="ICUNet"):
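    """Run the selected artifact-removal network on one normalized segment.

    `mode` chooses the checkpoint: "ICUNet", "ICUNet++", "ICUNet_attn" or
    "ART". `std_num` is currently unused. Returns the decoded segment as a
    float64 numpy array with the batch dimension still attached.
    """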
    if mode == "ICUNet":
        # 1. build the network
        model = cumbersome_model2.UNet1(n_channels=30, n_classes=30).to(device)
        resumeLoc = './model/ICUNet/modelsave/checkpoint.pth.tar'
# 2. load model
checkpoint = torch.load(resumeLoc, map_location=device)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()
# 3. decode strategy
with torch.no_grad():
data = data[np.newaxis, :, :]
data = torch.Tensor(data).to(device)
decode = model(data)
    elif mode in ("ICUNet++", "ICUNet_attn"):
        # 1. build the network
        if mode == "ICUNet++":
            model = UNet_family.NestedUNet3(num_classes=30).to(device)
        else:
            model = UNet_attention.UNetpp3_Transformer(num_classes=30).to(device)
        resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
# 2. load model
checkpoint = torch.load(resumeLoc, map_location=device)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()
# 3. decode strategy
with torch.no_grad():
data = data[np.newaxis, :, :]
data = torch.Tensor(data).to(device)
            decode1, decode2, decode = model(data)  # the network returns three outputs; only the final one is kept
    elif mode == "ART":
        # 1. locate the checkpoint
        resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
# 2. load model
checkpoint = torch.load(resumeLoc, map_location=device)
model = tf_model.make_model(30, 30, N=2).to(device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
# 3. decode strategy
with torch.no_grad():
data = torch.FloatTensor(data).to(device)
data = data.unsqueeze(0)
src = data
tgt = data # you can modify to randomize data
batch = tf_data.Batch(src, tgt, 0)
out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
decode = model.generator(out)
            decode = decode.permute(0, 2, 1)
            # the decoder output is one sample short; pad it back with a zero column
            add_tensor = torch.zeros(1, 30, 1).to(device)
            decode = torch.cat((decode, add_tensor), dim=2)
    # 4. move the result to the CPU as a float64 numpy array
    decode = np.array(decode.cpu()).astype(np.float64)
return decode
def reorder_data(raw_data, mapping_result):
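    """Map the raw recording onto the fixed 30-channel montage.

    For each target channel, `mapping_result` either points at one original
    channel, lists several source channels to average, or gives None, in
    which case the row is left as zeros.
    """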
new_data = np.zeros((30, raw_data.shape[1]))
zero_arr = np.zeros((1, raw_data.shape[1]))
for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
        if flag:
            # direct copy of an original channel
            new_data[i, :] = raw_data[indices[0], :]
        elif indices[0] is None:
            # no usable source channel; leave the row as zeros
            new_data[i, :] = zero_arr
        else:
            # synthesize the channel as the mean of the listed source channels
            data = [raw_data[idx, :] for idx in indices]
            new_data[i, :] = np.mean(data, axis=0)
return new_data
def preprocessing(filepath, inputfile, samplerate, mapping_result):
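    """Prepare a recording for reconstruction: map channels onto the
    30-channel montage, resample to 256 Hz, band-pass filter to 1-50 Hz,
    and cut the result into 1024-sample segments under temp2/.
    Returns the number of segments written."""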
    # start from a clean temp folder (dataDelete ignores a missing directory)
    dataDelete(filepath + "temp2/")
    os.mkdir(filepath + "temp2/")
    # read data
    signal = read_train_data(inputfile)
    # channel mapping
    signal = reorder_data(signal, mapping_result)
    # resample to the model's 256 Hz working rate
    signal = resample(signal, samplerate, 256)
    # band-pass FIR filter, 1-50 Hz
    signal = FIR_filter(signal, 1, 50)
    # cut into 1024-sample segments
    total_file_num = cut_data(filepath, signal)
return total_file_num
def restore_order(data, all_data, mapping_result):
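    """Write the reconstructed channels back to their original rows in
    `all_data`, leaving synthesized or missing channels untouched."""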
for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
        if flag:
all_data[indices[0], :] = data[i, :]
return all_data
def postprocessing(data, samplerate, outputfile, mapping_result, batch_cnt, channel_num):
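    """Resample the reconstruction back to the original rate, undo the
    channel mapping, and save (or update) the output CSV."""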
# resample to original sampling rate
data = resample(data, 256, samplerate)
    # reverse the channel mapping: start from zeros on the first batch,
    # otherwise accumulate into the file written by previous batches
    all_data = np.zeros((channel_num, data.shape[1])) if batch_cnt == 0 else read_train_data(outputfile)
all_data = restore_order(data, all_data, mapping_result)
# save data
save_data(all_data, outputfile)
def reconstruct(model_name, total, filepath, batch_cnt):
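    """Decode every segment in temp2/ with the chosen model, stitch the
    outputs back together, delete the temp files, and return the
    reconstructed (30, samples) array at the 256 Hz working rate."""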
# -------------------decode_data---------------------------
second1 = time.time()
for i in range(total):
file_name = filepath + 'temp2/{}.csv'.format(str(i))
data_noise = read_train_data(file_name)
        # z-score normalize each segment before decoding
        std = np.std(data_noise)
        avg = np.average(data_noise)
        data_noise = (data_noise - avg) / std
# Deep Learning Artifact Removal
d_data = decode_data(data_noise, std, model_name)
        d_data = d_data[0]  # drop the batch dimension
outputname = filepath + 'temp2/output{}.csv'.format(str(i))
save_data(d_data, outputname)
# --------------------glue_data----------------------------
data = glue_data(filepath+"temp2/", total)
# -------------------delete_data---------------------------
dataDelete(filepath+"temp2/")
second2 = time.time()
print(f"Using {model_name} model to reconstruct batch-{batch_cnt+1} has been success in {second2 - second1} sec(s)")
return data
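# A minimal usage sketch, assuming an identity channel mapping for a recording
# that already has the expected 30 channels; './sample/' and the file names are
# hypothetical, and `mapping_result` would normally come from the caller's
# electrode-layout matching step.
if __name__ == "__main__":
    mapping_result = {
        "index": [[i] for i in range(30)],
        "isOriginalData": [True] * 30,
    }
    filepath = './sample/'  # hypothetical working folder
    total = preprocessing(filepath, filepath + 'input.csv', 256, mapping_result)
    clean = reconstruct('ICUNet', total, filepath, batch_cnt=0)
    postprocessing(clean, 256, filepath + 'output.csv', mapping_result, 0, 30)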