diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 64f1cb31edaa..4883c4fba605 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -3503,7 +3503,21 @@ def predict(self, data, start_iteration=0, num_iteration=None,
                                  raw_score, pred_leaf, pred_contrib,
                                  data_has_header, is_reshape)
 
-    def refit(self, data, label, decay_rate=0.9, **kwargs):
+    def refit(
+        self,
+        data,
+        label,
+        decay_rate=0.9,
+        reference=None,
+        weight=None,
+        group=None,
+        init_score=None,
+        feature_name='auto',
+        categorical_feature='auto',
+        dataset_params=None,
+        free_raw_data=True,
+        **kwargs
+    ):
         """Refit the existing Booster by new data.
 
         Parameters
@@ -3516,6 +3530,35 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
         decay_rate : float, optional (default=0.9)
             Decay rate of refit,
             will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
+        reference : Dataset or None, optional (default=None)
+            Reference for ``data``.
+        weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
+            Weight for each ``data`` instance. Weights should be non-negative,
+            because the Hessian value multiplied by the weight is supposed to be non-negative.
+        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+            Group/query size for ``data``.
+            Only used in the learning-to-rank task.
+            sum(group) = n_samples.
+            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
+            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
+        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
+            Init score for ``data``.
+        feature_name : list of str, or 'auto', optional (default="auto")
+            Feature names for ``data``.
+            If 'auto' and data is pandas DataFrame, data column names are used.
+        categorical_feature : list of str or int, or 'auto', optional (default="auto")
+            Categorical features for ``data``.
+            If list of int, interpreted as indices.
+            If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
+            If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
+            All values in categorical features should be less than int32 max value (2147483647).
+            Large values could be memory consuming. Consider using consecutive integers starting from zero.
+            All negative values in categorical features will be treated as missing values.
+            The output cannot be monotonically constrained with respect to a categorical feature.
+        dataset_params : dict or None, optional (default=None)
+            Other parameters for Dataset ``data``.
+        free_raw_data : bool, optional (default=True)
+            If True, raw data is freed after constructing the inner Dataset for ``data``.
         **kwargs
             Other parameters for refit.
             These parameters will be passed to ``predict`` method.
@@ -3527,6 +3570,8 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
         """
         if self.__set_objective_to_none:
             raise LightGBMError('Cannot refit due to null objective function.')
+        if dataset_params is None:
+            dataset_params = {}
         predictor = self._to_predictor(deepcopy(kwargs))
         leaf_preds = predictor.predict(data, -1, pred_leaf=True)
         nrow, ncol = leaf_preds.shape
@@ -3540,7 +3585,19 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
             default_value=None
         )
         new_params["linear_tree"] = bool(out_is_linear.value)
-        train_set = Dataset(data, label, params=new_params)
+        new_params.update(dataset_params)
+        train_set = Dataset(
+            data=data,
+            label=label,
+            reference=reference,
+            weight=weight,
+            group=group,
+            init_score=init_score,
+            feature_name=feature_name,
+            categorical_feature=categorical_feature,
+            params=new_params,
+            free_raw_data=free_raw_data,
+        )
         new_params['refit_decay_rate'] = decay_rate
         new_booster = Booster(new_params, train_set)
         # Copy models
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 2d0ed6d86293..280f19af989a 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -1545,6 +1545,40 @@ def test_refit():
     assert err_pred > new_err_pred
 
 
+def test_refit_dataset_params():
+    # check refit accepts dataset_params
+    X, y = load_breast_cancer(return_X_y=True)
+    lgb_train = lgb.Dataset(X, y, init_score=np.zeros(y.size))
+    train_params = {
+        'objective': 'binary',
+        'verbose': -1,
+        'seed': 123
+    }
+    gbm = lgb.train(train_params, lgb_train, num_boost_round=10)
+    non_weight_err_pred = log_loss(y, gbm.predict(X))
+    refit_weight = np.random.rand(y.shape[0])
+    dataset_params = {
+        'max_bin': 260,
+        'min_data_in_bin': 5,
+        'data_random_seed': 123,
+    }
+    new_gbm = gbm.refit(
+        data=X,
+        label=y,
+        weight=refit_weight,
+        dataset_params=dataset_params,
+        decay_rate=0.0,
+    )
+    weight_err_pred = log_loss(y, new_gbm.predict(X))
+    train_set_params = new_gbm.train_set.get_params()
+    stored_weights = new_gbm.train_set.get_weight()
+    assert weight_err_pred != non_weight_err_pred
+    assert train_set_params["max_bin"] == 260
+    assert train_set_params["min_data_in_bin"] == 5
+    assert train_set_params["data_random_seed"] == 123
+    np.testing.assert_allclose(stored_weights, refit_weight)
+
+
 def test_mape_rf():
     X, y = load_boston(return_X_y=True)
     params = {
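
Usage note: a minimal sketch of how the extended refit signature is meant to be called once this patch is applied. The decay rate, weights, and dataset_params values below are illustrative only, chosen to mirror the test above; they are not recommendations.

# Sketch only: assumes this patch is applied on top of current LightGBM master.
import numpy as np
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)
booster = lgb.train(
    {'objective': 'binary', 'verbose': -1},
    lgb.Dataset(X, y),
    num_boost_round=10,
)

# Refit the existing trees on re-weighted data; dataset_params is merged
# into the params used to construct the inner Dataset for the new data.
refitted = booster.refit(
    data=X,
    label=y,
    weight=np.random.rand(y.shape[0]),  # per-instance, non-negative weights
    dataset_params={'max_bin': 260},    # forwarded to the inner Dataset
    decay_rate=0.5,                     # blend old and new leaf outputs equally
)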