"""Unit tests for the HDF5 loading helpers in ``stnn.data.preprocessing``."""
import os
import tempfile
import unittest

import h5py
import numpy as np

from stnn.data.preprocessing import get_data_from_file, load_data, load_training_data


class TestGetDataFromFile(unittest.TestCase):
    """Shape and argument-validation tests for the data-loading functions.

    Covers ``get_data_from_file``, ``load_data`` and ``load_training_data``.
    Each test runs against freshly written HDF5 fixtures filled with random
    values; only array shapes and raised exception types are asserted.
    """

    def setUp(self):
        """Write two complete fixture files and one lacking 'ibf'/'obf'."""
        self.nx1, self.nx2, self.nx3 = 30, 20, 16
        self.Nsamples = 10
        self.temp_file = self._make_h5_file()
        self.temp_file1 = self._make_h5_file()
        self.bad_file = self._make_h5_file(with_ibf_obf=False)

    def tearDown(self):
        """Close the fixture handles.

        The handles are already closed right after creation (``close`` is
        idempotent); the files themselves are deleted by the cleanups
        registered in ``_make_h5_file``.
        """
        for handle in (self.temp_file, self.temp_file1, self.bad_file):
            handle.close()

    @staticmethod
    def _remove_file(path):
        """Delete ``path``, tolerating a file that is already gone."""
        try:
            os.remove(path)
        except FileNotFoundError:
            # Best-effort cleanup: nothing left to remove.
            pass

    def _make_h5_file(self, with_ibf_obf=True):
        """Create a temporary HDF5 fixture and return its (closed) handle.

        The file always holds 'ell', 'a1', 'a2' of shape ``(Nsamples,)`` and
        'rho' of shape ``(Nsamples, nx1, nx2)``. When ``with_ibf_obf`` is true
        it also holds 'ibf' and 'obf' of shape ``(Nsamples, nx2, nx3 // 2)``.

        The returned object is kept only for its ``.name``; the file is
        removed automatically once the test finishes.
        """
        handle = tempfile.NamedTemporaryFile(delete=False)
        # Release our own descriptor first so h5py is the only opener of the path.
        handle.close()
        # delete=False means nobody else removes the file; register the cleanup
        # immediately so it still runs if the write below (or setUp) fails.
        self.addCleanup(self._remove_file, handle.name)

        n = self.Nsamples
        with h5py.File(handle.name, 'w') as f:
            for key in ('ell', 'a1', 'a2'):
                f.create_dataset(key, data=np.random.rand(n))
            f.create_dataset('rho', data=np.random.rand(n, self.nx1, self.nx2))
            if with_ibf_obf:
                for key in ('ibf', 'obf'):
                    f.create_dataset(key, data=np.random.rand(n, self.nx2, self.nx3 // 2))
        return handle

    def test_missing_datasets(self):
        """A file without 'ibf'/'obf' is rejected with ValueError."""
        with self.assertRaises(ValueError):
            get_data_from_file(self.bad_file.name, self.nx2, self.nx3)

    def test_data_extraction_shapes(self):
        """With the default Nrange every returned array covers all samples."""
        result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3)
        self.assertEqual(result[0].shape, (self.Nsamples,))
        self.assertEqual(result[1].shape, (self.Nsamples,))
        self.assertEqual(result[2].shape, (self.Nsamples,))
        self.assertEqual(result[3].shape, (self.Nsamples, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(result[4].shape, (self.Nsamples, self.nx1, self.nx2))

    def test_nrange_parameter(self):
        """Nrange=(start, stop) restricts the output to ``stop - start`` samples."""
        Nrange = (2, 5)
        result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange=Nrange)
        expected_size = Nrange[1] - Nrange[0]
        # The three 1-D arrays must each have exactly the selected length.
        for index in range(3):
            with self.subTest(index=index):
                self.assertEqual(result[index].shape, (expected_size,))
        # NOTE(review): assumes Nrange slices the leading axis of the
        # multi-dimensional outputs as well -- confirm against the library.
        for index in (3, 4):
            with self.subTest(index=index):
                self.assertEqual(result[index].shape[0], expected_size)

    def test_list_input(self):
        """get_data_from_file takes a single file; list arguments raise TypeError."""
        file_list = [self.temp_file.name, self.temp_file1.name]
        Nrange_list = [(0, -1), (0, -1)]
        with self.assertRaises(TypeError):
            # noinspection PyTypeChecker
            _ = get_data_from_file(file_list, self.nx2, self.nx3, Nrange=Nrange_list)

    def test_invalid_Nrange(self):
        """Malformed Nrange values raise TypeError."""
        Nrange_list = [(0, -1), (0, -1)]
        with self.assertRaises(TypeError):
            _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange=Nrange_list)

        # The former `(1)` entry was dropped: without a trailing comma it is
        # just the int 1, which is already in the list.
        invalid_ranges = [(0, 1, 1), 1, (1.5, 3), (3, 1.5), (1.5, 1.5), 'x']
        for Nrange in invalid_ranges:
            with self.subTest(Nrange=Nrange):
                with self.assertRaises(TypeError):
                    _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange=Nrange)

        # List variants are built only from iterable inputs. Calling list() on
        # an int raises TypeError by itself, which previously satisfied
        # assertRaises without get_data_from_file ever being invoked.
        list_variants = [list(r) for r in invalid_ranges if isinstance(r, (tuple, str))]
        for Nrange in list_variants:
            with self.subTest(Nrange=Nrange):
                with self.assertRaises(TypeError):
                    _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange=Nrange)

    def test_good_data_load(self):
        """load_data concatenates both files; load_training_data splits by test_size."""
        files = [self.temp_file.name, self.temp_file1.name]
        Nrange_list = [(0, None), (0, self.Nsamples)]
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0
        total = 2 * self.Nsamples

        params, bf, rho = load_data(
            files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list=Nrange_list
        )
        self.assertEqual(params.shape, (total, 3))
        self.assertEqual(bf.shape, (total, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho.shape, (total, self.nx1, self.nx2))

        test_size = 0.3
        (params_train, bf_train, rho_train,
         params_test, bf_test, rho_test) = load_training_data(
            files, self.nx2, self.nx3, ell1, ell2, a1, a2,
            test_size=test_size, Nrange_list=Nrange_list
        )
        Ntest = int(test_size * 2 * self.Nsamples)
        Ntrain = total - Ntest
        self.assertEqual(params_train.shape, (Ntrain, 3))
        self.assertEqual(bf_train.shape, (Ntrain, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho_train.shape, (Ntrain, self.nx1, self.nx2))
        self.assertEqual(params_test.shape, (Ntest, 3))
        self.assertEqual(bf_test.shape, (Ntest, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho_test.shape, (Ntest, self.nx1, self.nx2))

    def test_bad_data_load(self):
        """Malformed Nrange_list raises TypeError; invalid test_size raises ValueError."""
        files = [self.temp_file.name, self.temp_file1.name]
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0

        # A single (start, stop) pair -- as a tuple or as a flat list -- is not
        # a per-file list of ranges.
        Nrange_list = (0, -1)
        with self.assertRaises(TypeError):
            _ = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list=Nrange_list)
        with self.assertRaises(TypeError):
            _ = load_data(
                files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list=list(Nrange_list)
            )

        Nrange_list = [(0, -1), (0, -1)]
        for test_size in [-1, 0.0, 1.5]:
            with self.subTest(test_size=test_size):
                with self.assertRaises(ValueError):
                    _ = load_training_data(
                        files, self.nx2, self.nx3, ell1, ell2, a1, a2,
                        test_size=test_size, Nrange_list=Nrange_list
                    )


if __name__ == '__main__':
    unittest.main()