Skip to content

Commit

Permalink
Fix writer issue in xgboost hook (aws#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
Edward J Kim authored Nov 1, 2019
1 parent e7ed26e commit bdf5459
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 30 deletions.
6 changes: 0 additions & 6 deletions tests/xgboost/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def test_default_hook(monkeypatch):
assert hook.out_dir == DEFAULT_SAGEMAKER_TORNASOLE_PATH


@pytest.mark.slow # 0:05 to run
def test_hook_save_all(tmpdir):
reset_collections()
save_config = SaveConfig(save_steps=[0, 1, 2, 3])
Expand All @@ -82,7 +81,6 @@ def test_hook_save_all(tmpdir):
assert len(collections["all"].tensor_names) == len(tensors)


@pytest.mark.slow # 0:05 to run
def test_hook_save_config_collections(tmpdir):
reset_collections()
out_dir = os.path.join(tmpdir, str(uuid.uuid4()))
Expand All @@ -101,7 +99,6 @@ def test_hook_save_config_collections(tmpdir):
assert all(step % 3 == 0 for step in fimp_steps[:-1])


@pytest.mark.slow # 0:05 to run
def test_hook_shap(tmpdir):
np.random.seed(42)
train_data = np.random.rand(10, 10)
Expand All @@ -120,7 +117,6 @@ def test_hook_shap(tmpdir):
assert any(t.endswith("/average_shap") for t in tensors)


@pytest.mark.slow # 0:05 to run
def test_hook_validation(tmpdir):
np.random.seed(42)
train_data = np.random.rand(5, 10)
Expand Down Expand Up @@ -149,7 +145,6 @@ def test_hook_validation(tmpdir):
assert "predictions" in tensors


@pytest.mark.slow # 0:05 to run
def test_hook_tree_model(tmpdir):
np.random.seed(42)
train_data = np.random.rand(5, 10)
Expand All @@ -172,7 +167,6 @@ def test_hook_tree_model(tmpdir):
assert "trees/{}".format(col) in tensors


@pytest.mark.slow # 0:05 to run
def test_hook_params(tmpdir):
np.random.seed(42)
train_data = np.random.rand(5, 10)
Expand Down
45 changes: 21 additions & 24 deletions tornasole/xgboost/hook.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import atexit
import os
from typing import Optional, List, Union, Tuple, Dict, Any
import numpy as np
Expand Down Expand Up @@ -97,9 +96,7 @@ def __init__(
self.hyperparameters = hyperparameters
self.train_data = self._validate_data(train_data)
self.validation_data = self._validate_data(validation_data)
# as we do cleanup ourselves at end of job
self.worker = self.get_worker_name()
atexit.unregister(self._cleanup)
set_hook(self)

def __call__(self, env: CallbackEnv) -> None:
Expand All @@ -113,16 +110,23 @@ def get_worker_name(self):

@classmethod
def hook_from_config(cls, json_config_path=None):
"""Relies on the existence of a JSON file.
First, check json_config_path. If it is not None:
if the file exists, use that;
if the file does not exist, throw an error.
Otherwise (json_config_path is None), check the filepath set by a SageMaker environment variable:
if that file exists, use it;
otherwise, return None.
"""
return create_hook_from_json_config(
cls, get_collection_manager(), json_config_path=json_config_path
)

def _cleanup(self):
# todo: this second export should go
self.export_collections()
training_has_ended(self.out_dir)

def _is_last_step(self, env: CallbackEnv) -> bool:
# env.iteration: current boosting round.
# env.end_iteration: round # when training will end. this is always num_round + 1. # noqa: E501
return env.iteration + 1 == env.end_iteration

def _is_collection_being_saved_for_step(self, name):
Expand All @@ -134,25 +138,23 @@ def _increment_step(self, iteration):
self._collections_to_save_for_step = None

def _callback(self, env: CallbackEnv) -> None:
# env.rank: rabit rank of the node/process. master node has rank 0.
# env.iteration: current boosting round.
# env.begin_iteration: round # when training started. this is always 0.
# env.end_iteration: round # when training will end. this is always num_round + 1. # noqa: E501
# env.model: model object.
# Write the tensors from the previous step if the writer is still available.
self._close_writer()

if not self.prepared_collections:
# at this point we need all collections to be ready
# this may not be the case at creation of hook
# as user's code after hook might add collections
self._prepare_collections()
self.prepared_collections = True

if not self.exported_collections:
self._increment_step(env.iteration)

if self.last_saved_step is not None and not self.exported_collections:
self.export_collections()
self.exported_collections = True

self._increment_step(env.iteration)

if not self._is_last_step(env) and not self._get_collections_to_save_for_step():
if not self._get_collections_to_save_for_step():
self.logger.debug("Skipping iteration {}".format(self.step))
return

Expand All @@ -179,12 +181,7 @@ def _callback(self, env: CallbackEnv) -> None:
if self._is_collection_being_saved_for_step(CollectionKeys.TREES):
self.write_tree_model(env)

if not self._is_last_step(env):
self._close_writer()

if self._is_last_step(env):
self._cleanup()

self.last_saved_step = self.step
self.logger.info("Saved iteration {}.".format(self.step))

def write_hyperparameters(self, env: CallbackEnv):
Expand Down Expand Up @@ -251,7 +248,7 @@ def _write_for_tensor(self, tensor_name, tensor_value, save_collections):

@staticmethod
def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs):
raise NotImplementedError("Reductions are not support by XGBoost hook")
raise NotImplementedError("Reductions are not supported by XGBoost hook")

@staticmethod
def _make_numpy_array(tensor_value):
Expand Down

0 comments on commit bdf5459

Please sign in to comment.