# Repository page residue: uploaded by caleb2, commit message "initial commit",
# revision d68c650.
import unittest
import numpy as np
import h5py
import tempfile
from stnn.data.preprocessing import get_data_from_file, load_data, load_training_data
class TestGetDataFromFile(unittest.TestCase):
    """Tests for ``get_data_from_file``, ``load_data`` and ``load_training_data``.

    Every test runs against small HDF5 fixture files filled with random data,
    so only array shapes and raised exception types are checked, never values.
    """

    def _make_h5_file(self, with_boundary_data = True):
        """Write one HDF5 fixture file and return its temporary-file handle.

        Args:
            with_boundary_data: When False, the 'ibf' and 'obf' datasets are
                left out, producing an incomplete file.

        Returns:
            The ``NamedTemporaryFile`` object; its ``.name`` is the file path.
        """
        # delete = False keeps the path valid after close(); the file is removed
        # together with the per-test temporary directory instead.
        handle = tempfile.NamedTemporaryFile(delete = False, dir = self._tmp_dir.name)
        with h5py.File(handle.name, 'w') as f:
            for key in ('ell', 'a1', 'a2'):
                f.create_dataset(key, data = np.random.rand(self.Nsamples))
            f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2))
            if with_boundary_data:
                for key in ('ibf', 'obf'):
                    f.create_dataset(key, data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2))
        return handle

    def setUp(self):
        """Create two complete fixture files and one lacking 'ibf'/'obf'."""
        self.nx1, self.nx2, self.nx3 = 30, 20, 16
        self.Nsamples = 10

        # All fixtures live in one directory that is deleted after each test,
        # so no temporary files are left behind on disk. Cleanups run after
        # tearDown, i.e. once the file handles have been closed.
        self._tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmp_dir.cleanup)

        self.temp_file = self._make_h5_file()
        self.temp_file1 = self._make_h5_file()
        self.bad_file = self._make_h5_file(with_boundary_data = False)

    def tearDown(self):
        """Close the fixture file handles opened in ``setUp``."""
        for handle in (self.temp_file, self.temp_file1, self.bad_file):
            handle.close()

    def test_missing_datasets(self):
        """A file without 'ibf'/'obf' is expected to raise ``ValueError``."""
        with self.assertRaises(ValueError):
            # Pass (nx2, nx3) like every other call so that only the missing
            # datasets distinguish this case from the valid ones.
            get_data_from_file(self.bad_file.name, self.nx2, self.nx3)

    def test_data_extraction_shapes(self):
        """The five returned arrays have the expected shapes for a full file."""
        result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3)
        expected_shapes = [
            (self.Nsamples,),
            (self.Nsamples,),
            (self.Nsamples,),
            # NOTE(review): presumably 'ibf' and 'obf' stacked along axis 1 - confirm.
            (self.Nsamples, 2 * self.nx2, self.nx3 // 2),
            (self.Nsamples, self.nx1, self.nx2),
        ]
        for index, shape in enumerate(expected_shapes):
            with self.subTest(index = index):
                self.assertEqual(result[index].shape, shape)

    def test_nrange_parameter(self):
        """``Nrange = (start, stop)`` selects ``stop - start`` samples."""
        Nrange = (2, 5)
        result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange)
        expected_size = Nrange[1] - Nrange[0]
        self.assertEqual(result[0].shape, (expected_size,))
        # NOTE(review): assumes Nrange slices every returned array along its
        # first axis, not only result[0] - confirm against the implementation.
        for index in range(5):
            with self.subTest(index = index):
                self.assertEqual(result[index].shape[0], expected_size)

    def test_list_input(self):
        """Passing lists of files/ranges to the single-file loader is rejected."""
        file_list = [self.temp_file.name, self.temp_file1.name]
        Nrange_list = [(0, -1), (0, -1)]
        with self.assertRaises(TypeError):
            # noinspection PyTypeChecker
            get_data_from_file(file_list, self.nx2, self.nx3, Nrange = Nrange_list)

    def test_invalid_Nrange(self):
        """Malformed ``Nrange`` values are expected to raise ``TypeError``."""
        path = self.temp_file.name
        with self.assertRaises(TypeError):
            get_data_from_file(path, self.nx2, self.nx3, Nrange = [(0, -1), (0, -1)])

        # (1,) is a genuine one-element tuple; a bare (1) is just the int 1.
        invalid_ranges = [(0, 1, 1), 1, (1,), (1.5, 3), (3, 1.5), (1.5, 1.5), 'x']
        for Nrange in invalid_ranges:
            with self.subTest(Nrange = Nrange):
                with self.assertRaises(TypeError):
                    get_data_from_file(path, self.nx2, self.nx3, Nrange = Nrange)

            # Convert outside assertRaises: list(1) itself raises TypeError and
            # would otherwise satisfy the assertion without calling the loader.
            try:
                as_list = list(Nrange)
            except TypeError:
                continue
            with self.subTest(Nrange = as_list):
                with self.assertRaises(TypeError):
                    get_data_from_file(path, self.nx2, self.nx3, Nrange = as_list)

    def test_good_data_load(self):
        """``load_data`` / ``load_training_data`` combine both files' samples."""
        files = [self.temp_file.name, self.temp_file1.name]
        Nrange_list = [(0, None), (0, self.Nsamples)]
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0
        Ntotal = 2 * self.Nsamples
        bf_tail = (2 * self.nx2, self.nx3 // 2)
        rho_tail = (self.nx1, self.nx2)

        params, bf, rho = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list)
        self.assertEqual(params.shape, (Ntotal, 3))
        self.assertEqual(bf.shape, (Ntotal,) + bf_tail)
        self.assertEqual(rho.shape, (Ntotal,) + rho_tail)

        test_size = 0.3
        (params_train, bf_train, rho_train,
         params_test, bf_test, rho_test) = load_training_data(files, self.nx2, self.nx3,
                                                              ell1, ell2, a1, a2, test_size = test_size,
                                                              Nrange_list = Nrange_list)
        Ntest = int(test_size * 2 * self.Nsamples)
        Ntrain = Ntotal - Ntest
        self.assertEqual(params_train.shape, (Ntrain, 3))
        self.assertEqual(bf_train.shape, (Ntrain,) + bf_tail)
        self.assertEqual(rho_train.shape, (Ntrain,) + rho_tail)
        self.assertEqual(params_test.shape, (Ntest, 3))
        self.assertEqual(bf_test.shape, (Ntest,) + bf_tail)
        self.assertEqual(rho_test.shape, (Ntest,) + rho_tail)

    def test_bad_data_load(self):
        """Malformed ``Nrange_list`` / ``test_size`` arguments are rejected."""
        files = [self.temp_file.name, self.temp_file1.name]
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0

        # A single range (as a tuple or as a flat list) instead of one range per file.
        single_range = (0, -1)
        for Nrange_list in (single_range, list(single_range)):
            with self.subTest(Nrange_list = Nrange_list):
                with self.assertRaises(TypeError):
                    load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list)

        Nrange_list = [(0, -1), (0, -1)]
        for test_size in [-1, 0.0, 1.5]:
            with self.subTest(test_size = test_size):
                with self.assertRaises(ValueError):
                    load_training_data(files, self.nx2, self.nx3,
                                       ell1, ell2, a1, a2, test_size = test_size, Nrange_list = Nrange_list)
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()