Skip to content

Commit

Permalink
FIX: correctly pass task type to data subsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeurer committed Apr 16, 2021
1 parent 802c337 commit 79627e1
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def fit(
if X_test is not None:
X_test, y_test = self.InputValidator.transform(X_test, y_test)

self._task = task

X, y = self.subsample_if_too_large(
X=X,
y=y,
Expand Down Expand Up @@ -625,8 +627,6 @@ def fit(
)

self._backend._make_internals_directory()

self._task = datamanager.info['task']
self._label_num = datamanager.info['label_num']

# == Pickle the data manager to speed up loading
Expand Down Expand Up @@ -840,7 +840,14 @@ def _fit_cleanup(self):
return

@staticmethod
def subsample_if_too_large(X, y, logger, seed, memory_limit, task):
def subsample_if_too_large(
X: SUPPORTED_FEAT_TYPES,
y: SUPPORTED_TARGET_TYPES,
logger,
seed: int,
memory_limit: int,
task: int,
):
if memory_limit and isinstance(X, np.ndarray):
if X.dtype == np.float32:
multiplier = 4
Expand Down Expand Up @@ -884,12 +891,14 @@ def subsample_if_too_large(X, y, logger, seed, memory_limit, task):
train_size=new_num_samples,
random_state=seed,
)
else:
elif task in REGRESSION_TASKS:
X, _, y, _ = sklearn.model_selection.train_test_split(
X, y,
train_size=new_num_samples,
random_state=seed,
)
else:
raise ValueError(task)
return X, y

def refit(self, X, y):
Expand Down

0 comments on commit 79627e1

Please sign in to comment.