Test preprocessing (#26)
* fixing #24

* fixing afb subtraction for test data

* bumping bug fix version number

* division of test data by training standard deviation
htjb authored Jan 16, 2024
1 parent fd89b48 commit bbe7934
Showing 2 changed files with 9 additions and 11 deletions.
README.rst: 2 changes (1 addition & 1 deletion)
@@ -7,7 +7,7 @@ Introduction

:globalemu: Robust Global 21-cm Signal Emulation
:Author: Harry Thomas Jones Bevins
-:Version: 1.8.0
+:Version: 1.8.1
:Homepage: https://github.com/htjb/globalemu
:Documentation: https://globalemu.readthedocs.io/

globalemu/preprocess.py: 18 changes (8 additions & 10 deletions)
@@ -73,13 +73,9 @@ class process():
data set or not. Set to True by default as this is advised for
training both neutral fraction and global signal emulators.
-logs: **list / default: [0, 1, 2]**
+logs: **list / default: []**
| The indices corresponding to the astrophysical parameters in
"train_data.txt" that need to be logged. The default assumes
that the first three columns in "train_data.txt" are
:math:`{f_*}` (star formation efficiency),
:math:`{V_c}` (minimum virial circular velocity) and
:math:`{f_x}` (X-ray efficieny).
"train_data.txt" that need to be logged.
"""

def __init__(self, num, z, **kwargs):
@@ -137,7 +133,7 @@ def __init__(self, num, z, **kwargs):
if type(bool_kwargs[i]) is not bool:
raise TypeError(bool_strings[i] + " must be a bool.")

-self.logs = kwargs.pop('logs', [0, 1, 2])
+self.logs = kwargs.pop('logs', [])
if type(self.logs) is not list:
raise TypeError("'logs' must be a list.")

@@ -170,7 +166,6 @@ def load_data(file):
train_data = full_train_data.copy()
if self.preprocess_settings['AFB'] is True:
train_labels = full_train_labels.copy() - res.deltaT
-test_labels -= res.deltaT
else:
train_labels = full_train_labels.copy()
else:
@@ -189,10 +184,14 @@ def load_data(file):
train_data.append(full_train_data[i, :])
if self.preprocess_settings['AFB'] is True:
train_labels.append(full_train_labels[i] - res.deltaT)

else:
train_labels.append(full_train_labels[i])
train_data, train_labels = np.array(train_data), \
np.array(train_labels)

+if self.preprocess_settings['AFB'] is True:
+    test_labels = test_labels.copy() - res.deltaT

log_train_data = []
for i in range(train_data.shape[1]):
@@ -268,9 +267,8 @@ def load_data(file):
norm_train_labels = norm_train_labels.flatten()
np.save(self.base_dir + 'labels_stds.npy', labels_stds)

-test_labels_stds = test_labels.std()
norm_test_labels = [
-test_labels[i, :]/test_labels_stds
+test_labels[i, :]/labels_stds
for i in range(test_labels.shape[0])]
norm_test_labels = np.array(norm_test_labels)
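
Taken together, the preprocess.py changes apply the training-set statistics to the test set: when AFB is enabled, the astrophysical foreground baseline (res.deltaT) is now subtracted from the test labels in every branch, and the test labels are normalised by the training labels' standard deviation rather than by their own. A minimal standalone sketch of that idea, using made-up array shapes and a placeholder delta_T baseline rather than the actual globalemu internals:

import numpy as np

# Toy stand-ins: rows are signals, columns are redshift bins (shapes are illustrative).
train_labels = np.random.rand(100, 451)
test_labels = np.random.rand(20, 451)
delta_T = np.random.rand(451)                 # placeholder for the AFB baseline (res.deltaT)

# Subtract the astrophysical foreground baseline from BOTH sets.
train_labels = train_labels - delta_T
test_labels = test_labels - delta_T           # previously only subtracted in one branch

# Normalise BOTH sets by the training labels' standard deviation.
labels_stds = train_labels.std()
norm_train_labels = train_labels / labels_stds
norm_test_labels = test_labels / labels_stds  # previously divided by test_labels.std()

Scaling the test set with training-derived statistics keeps it on the same scale the network was trained on and avoids folding test-set information into the normalisation.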

