From 4c9eca2d82a3210de1ab2bc5b03ac24cb7521179 Mon Sep 17 00:00:00 2001 From: YJ Shi Date: Tue, 26 Jul 2022 18:35:35 -0700 Subject: [PATCH] add unit test --- .../tvm/meta_schedule/cost_model/__init__.py | 2 +- .../tvm/meta_schedule/cost_model/xgb_model.py | 124 ++---------------- .../unittest/test_meta_schedule_cost_model.py | 92 ++++++++++++- 3 files changed, 103 insertions(+), 115 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py index 8fc6f04ac9558..47b418d5db129 100644 --- a/python/tvm/meta_schedule/cost_model/__init__.py +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -19,4 +19,4 @@ """ from .cost_model import CostModel, PyCostModel from .random_model import RandomModel -from .xgb_model import XGBModel +from .xgb_model import XGBModel, XGBoostCustomCallback, PackSum diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 8cec77a857361..50f54f5da0881 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -35,15 +35,15 @@ from ..utils import cpu_count, derived_object, shash2hex from .metric import max_curve -try: - from xgboost.callback import TrainingCallback # type: ignore -except ImportError: - class TrainingCallback: # type: ignore - pass +if TYPE_CHECKING: + try: + from xgboost.callback import TrainingCallback # type: ignore + except ImportError: + class TrainingCallback: # type: ignore + pass -if TYPE_CHECKING: import xgboost as xgb # type: ignore from ..tune_context import TuneContext @@ -674,114 +674,8 @@ def init(env: "xgb.core.CallbackEnv"): booster.set_attr(best_iteration=str(state["best_iteration"])) booster.set_attr(best_score=str(state["best_score"])) - def callback(env: "xgb.core.CallbackEnv"): - # pylint:disable = import-outside-toplevel - import xgboost as xgb - from xgboost.callback import _fmt_metric # type: ignore - from xgboost.core import EarlyStopException # type: ignore - - try: - from xgboost.training import aggcv # type: ignore - except ImportError: - from xgboost.callback import _aggcv as aggcv # type: ignore - # pylint:enable = import-outside-toplevel - - if not state: - init(env) - booster: xgb.Booster = env.model - iteration: int = env.iteration - cvfolds: List[xgb.training.CVPack] = env.cvfolds - ##### Evaluation ##### - # `eval_result` is a list of (key, score) - eval_result: List[Tuple[str, float]] = [] - if cvfolds is None: - eval_result = list( - itertools_chain.from_iterable( - [ - (key, float(value)) - for key, value in map( - lambda x: x.split(":"), - booster.eval_set( - evals=evals, - iteration=iteration, - feval=feval, - ).split()[1:], - ) - ] - for feval in fevals - ) - ) - else: - eval_result = list( - itertools_chain.from_iterable( - [ - (key, score) - for key, score, _std in aggcv( - fold.eval( - iteration=iteration, - feval=feval, - ) - for fold in cvfolds - ) - ] - for feval in fevals - ) - ) - eval_result = list(eval_result) - eval_result.sort(key=sort_key) - - ##### Print eval result ##### - if verbose_eval and iteration % verbose_eval == 0: - info = [] - for key, score in eval_result: - if "null" not in key: - info.append(f"{key}: {score:.6f}") - logger.debug("XGB iter %3d: %s", iteration, "\t".join(info)) - - ##### Choose score and do early stopping ##### - score = None - for key, _score in eval_result: - if key == focused_metric: - score = _score - break - assert score is not None - best_score = state["best_score"] - best_iteration = state["best_iteration"] - if score < best_score: - tab = "\t" # to work with f-string - msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}" - state["best_msg"] = msg - state["best_score"] = score - state["best_iteration"] = env.iteration - # save the property to attributes, so they will occur in checkpoint. - if env.model is not None: - env.model.set_attr( - best_score=str(state["best_score"]), - best_iteration=str(state["best_iteration"]), - best_msg=state["best_msg"], - ) - elif env.iteration - best_iteration >= early_stopping_rounds: - best_msg = state["best_msg"] - if verbose_eval and env.rank == 0: - logger.debug("XGB stopped. Best iteration: %s ", best_msg) - raise EarlyStopException(best_iteration) - - return callback - - -class XGBoostCallback(TrainingCallback): - """Base class for XGBoost callbacks.""" - - def __call__(self, env: "xgb.core.CallbackEnv"): - # Compatibility with xgboost < 1.3 - return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) - - def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict): - raise NotImplementedError - - -class XGBoostCustomCallback(XGBoostCallback): +class XGBoostCustomCallback(TrainingCallback): """Custom callback class for xgboost to support multiple custom evaluation functions""" def __init__( @@ -804,6 +698,10 @@ def __init__( if cvfolds is not None: self.aggregated_cv = None + def __call__(self, env: "xgb.core.CallbackEnv"): + # Compatibility with xgboost < 1.3 + return self.after_iteration(env.model, env.iteration, env.evaluation_result_list) + def init(self, model: "xgb.Booster"): """Internal function for intialization""" booster: "xgb.Booster" = model diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index d1d5581813245..91c84bbdb88bc 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -26,7 +26,13 @@ import pytest import tvm import tvm.testing -from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.cost_model import ( + PyCostModel, + RandomModel, + XGBModel, + XGBoostCustomCallback, + PackSum, +) from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.search_strategy import MeasureCandidate @@ -228,5 +234,89 @@ def test_meta_schedule_xgb_model_reupdate(): model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) +def test_meta_schedule_xgb_model_callback(): + import xgboost as xgb + from itertools import chain as itertools_chain + from functools import partial + + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=10) + update_sample_count = 20 + predict_sample_count = 30 + + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + with tempfile.NamedTemporaryFile() as path: + # Backup and train on new TrainingCallBack api + random_state = model.extractor.random_state # save feature extractor's random state + + model.save(path.name) + + old_booster = model.booster + xs = [ + x.numpy().astype("float32") + for x in extractor.extract_from( + TuneContext(), + [_dummy_candidate() for i in range(predict_sample_count)], + ) + ] + d_test = PackSum(xs=xs, ys=None) + pred1 = old_booster.predict(d_test.dmatrix) + + # Load and train on deprecated TrainingCallBack api + model.extractor.random_state = random_state # load feature extractor's random state + model.load(path.name) + d_train = PackSum( + xs=list(itertools_chain.from_iterable([g.features for g in model.data.values()])), + ys=np.concatenate( + [g.min_cost / g.costs for g in model.data.values()], + axis=0, + ), + ) + + def obj(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return d_train.obj_square_error(ys_pred) + + def rmse(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return d_train.rmse(ys_pred) + + def avg_peak_score(ys_pred: np.ndarray, d_train1: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return d_train.average_peak_score(ys_pred, model.average_peak_n) + + new_booster = xgb.train( + model.config.to_dict(), + d_train.dmatrix, + num_boost_round=10000, + obj=obj, + callbacks=[ + partial( + XGBoostCustomCallback( + early_stopping_rounds=model.early_stopping_rounds, + verbose_eval=model.verbose_eval, + fevals=[rmse, avg_peak_score], + evals=[(d_train.dmatrix, "tr")], + cvfolds=None, + ) + ) + ], + ) + + xs = [ + x.numpy().astype("float32") + for x in extractor.extract_from( + TuneContext(), + [_dummy_candidate() for i in range(predict_sample_count)], + ) + ] + d_test = PackSum(xs=xs, ys=None) + pred2 = new_booster.predict(d_test.dmatrix) + + assert np.allclose(pred1, pred2, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": tvm.testing.main()