diff --git a/src/data_loading.py b/src/data_loading.py index 854723b..43c5d29 100644 --- a/src/data_loading.py +++ b/src/data_loading.py @@ -205,7 +205,7 @@ def __handle_tristan_v2_spectra(spectra_file_path: pathlib.Path, spectra_file: h if not data_log_scale: spectral_data = np.log10(spectral_data) - spectral_data /= spectra_file['ebins'] + spectral_data /= (spectra_file['ebins'][:])[:, np.newaxis] return spectral_data # ============================================================================= diff --git a/tests/data/tristan_v2/single_directory/spec.tot.00070 b/tests/data/tristan_v2/single_directory/spec.tot.00070 index 608f0ce..5c498c1 100644 Binary files a/tests/data/tristan_v2/single_directory/spec.tot.00070 and b/tests/data/tristan_v2/single_directory/spec.tot.00070 differ diff --git a/tests/data/tristan_v2/standard_structure/spec/spec.tot.00070 b/tests/data/tristan_v2/standard_structure/spec/spec.tot.00070 index 608f0ce..5c498c1 100644 Binary files a/tests/data/tristan_v2/standard_structure/spec/spec.tot.00070 and b/tests/data/tristan_v2/standard_structure/spec/spec.tot.00070 differ diff --git a/tests/test_data_loading.py b/tests/test_data_loading.py index aca19fb..8e6f3b3 100644 --- a/tests/test_data_loading.py +++ b/tests/test_data_loading.py @@ -204,7 +204,7 @@ def test_handle_tristan_v2_spectra(): for i,dataset_name in enumerate(dataset_names): test_spectra = data_loading.__handle_tristan_v2_spectra(file_path, file, dataset_name, cli_args) - assert np.array_equiv(test_spectra, np.full(5,i+1)) + assert np.array_equiv(test_spectra, np.full(5,i+2)) def test_handle_tristan_v2_spectra_cli_args(): parser = argparse.ArgumentParser() @@ -219,7 +219,7 @@ def test_handle_tristan_v2_spectra_cli_args(): for i,dataset_name in enumerate(dataset_names): test_spectra = data_loading.__handle_tristan_v2_spectra(file_path, file, dataset_name, cli_args) - assert np.array_equiv(test_spectra, np.full(5,i+1)) + assert np.array_equiv(test_spectra, np.full(5,i+2)) def test_handle_tristan_v2_spectra_raise_ValueError(): with pytest.raises(ValueError):