From 487341b55ba816349ca963acf0d09fbb9d86a668 Mon Sep 17 00:00:00 2001
From: Saulo Martiello Mastelini
Date: Mon, 18 Sep 2023 21:15:44 -0300
Subject: [PATCH] Simplify Adaptive Random Forest (#1323)

---
 docs/releases/unreleased.md | 4 +
 river/base/drift_detector.py | 3 +-
 river/forest/adaptive_random_forest.py | 456 ++++++++++---------
 river/forest/online_extra_trees.py | 7 +-
 4 files changed, 193 insertions(+), 277 deletions(-)

diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md
index 8a655bfa95..507f905d22 100644
--- a/docs/releases/unreleased.md
+++ b/docs/releases/unreleased.md
@@ -5,3 +5,7 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
 
 ## anomaly
 
 - Added `anomaly.LocalOutlierFactor`, which is an online version of the LOF algorithm for anomaly detection that matches the scikit-learn implementation.
+
+## forest
+
+- Simplify the inner structures of `forest.ARFClassifier` and `forest.ARFRegressor` by removing a redundant class hierarchy. Simplify how concept drift and warning logs can be accessed, both for individual trees and for the forest as a whole.
diff --git a/river/base/drift_detector.py b/river/base/drift_detector.py
index 62a511d5bb..25f0719f2a 100644
--- a/river/base/drift_detector.py
+++ b/river/base/drift_detector.py
@@ -8,7 +8,6 @@
 from __future__ import annotations
 
 import abc
-import numbers
 
 from . import base
 
@@ -58,7 +57,7 @@ class DriftDetector(_BaseDriftDetector):
     """A drift detector."""
 
     @abc.abstractmethod
-    def update(self, x: numbers.Number) -> DriftDetector:
+    def update(self, x: int | float) -> DriftDetector:
         """Update the detector with a single data point.
 
         Parameters
diff --git a/river/forest/adaptive_random_forest.py b/river/forest/adaptive_random_forest.py
index 413ba539e4..8e7a9c6645 100644
--- a/river/forest/adaptive_random_forest.py
+++ b/river/forest/adaptive_random_forest.py
@@ -47,11 +47,35 @@ def __init__(
         self.drift_detector = drift_detector
         self.warning_detector = warning_detector
         self.seed = seed
+        self._rng = random.Random(self.seed)
 
-        # Internal parameters
-        self._n_samples_seen = 0
-        self._base_member_class: ForestMemberClassifier | ForestMemberRegressor | None = None
+        self._warning_detectors: list[base.DriftDetector] = (
+            None  # type: ignore
+            if self.warning_detector is None
+            else [self.warning_detector.clone() for _ in range(self.n_models)]
+        )
+        self._drift_detectors: list[base.DriftDetector] = (
+            None  # type: ignore
+            if self.drift_detector is None
+            else [self.drift_detector.clone() for _ in range(self.n_models)]
+        )
+
+        # The background models
+        self._background: list[BaseTreeClassifier | BaseTreeRegressor | None] = (
+            None if self.warning_detector is None else [None] * self.n_models  # type: ignore
+        )
+
+        # Performance metrics used for weighted voting/aggregation
+        self._metrics = [self.metric.clone() for _ in range(self.n_models)]
+
+        # Drift and warning logging
+        self._warning_tracker: dict = (
+            collections.defaultdict(int) if self.warning_detector is not None else None  # type: ignore
+        )
+        self._drift_tracker: dict = (
+            collections.defaultdict(int) if self.drift_detector is not None else None  # type: ignore
+        )
 
     @property
     def _min_number_of_models(self):
@@ -64,49 +88,135 @@ def _unit_test_params(cls):
     def _unit_test_skips(self):
         return {"check_shuffle_features_no_impact"}
 
-    def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
-        self._n_samples_seen += 1
+    @abc.abstractmethod
+    def _drift_detector_input(
+        self,
+        tree_id: int,
+        y_true,
+        y_pred,
+    ) -> int | float:
+        raise 
NotImplementedError
+
+    @abc.abstractmethod
+    def _new_base_model(self) -> BaseTreeClassifier | BaseTreeRegressor:
+        raise NotImplementedError
+
+    def n_warnings_detected(self, tree_id: int | None = None) -> int | None:
+        """Get the total number of concept drift warnings detected, or, optionally, the number
+        detected by an individual tree.
+
+        If warning detection is disabled, `None` is returned.
+
+        Parameters
+        ----------
+        tree_id
+            The number of the base learner in the ensemble: `[0, self.n_models - 1]`. If `None`,
+            the total number of warnings is returned instead.
+
+        Returns
+        -------
+        The number of concept drift warnings detected.
+
+        """
+
+        if self.warning_detector is None:
+            return None
+
+        if tree_id is None:
+            return sum(self._warning_tracker.values())
+
+        return self._warning_tracker[tree_id]
+
+    def n_drifts_detected(self, tree_id: int | None = None) -> int | None:
+        """Get the total number of concept drifts detected, or, optionally, the number detected
+        by an individual tree.
+
+        If drift detection is disabled, `None` is returned.
+
+        Parameters
+        ----------
+        tree_id
+            The number of the base learner in the ensemble: `[0, self.n_models - 1]`. If `None`,
+            the total number of drifts is returned instead.
+
+        Returns
+        -------
+        The number of concept drifts detected.
+        """
+
+        if self.drift_detector is None:
+            return None
+
+        if tree_id is None:
+            return sum(self._drift_tracker.values())
+
+        return self._drift_tracker[tree_id]
+
+    def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
         if len(self) == 0:
             self._init_ensemble(sorted(x.keys()))
 
-        for model in self:
-            # Get prediction for instance
-            y_pred = (
-                model.predict_proba_one(x)
-                if isinstance(model.metric, metrics.base.ClassificationMetric)
-                and not model.metric.requires_labels
-                else model.predict_one(x)
-            )
+        for i, model in enumerate(self):
+            y_pred = model.predict_one(x)
 
             # Update performance evaluator
-            model.metric.update(y_true=y, y_pred=y_pred)
+            self._metrics[i].update(
+                y_true=y,
+                y_pred=model.predict_proba_one(x)
+                if isinstance(self.metric, metrics.base.ClassificationMetric)
+                and not self.metric.requires_labels
+                else y_pred,
+            )
 
             k = poisson(rate=self.lambda_value, rng=self._rng)
             if k > 0:
-                model.learn_one(x=x, y=y, sample_weight=k, n_samples_seen=self._n_samples_seen)
+                if self.warning_detector is not None and self._background[i] is not None:
+                    self._background[i].learn_one(x=x, y=y, sample_weight=k)  # type: ignore
+
+                model.learn_one(x=x, y=y, sample_weight=k)
+
+                drift_input = None
+                if self.drift_detector is not None and self.warning_detector is not None:
+                    drift_input = self._drift_detector_input(i, y, y_pred)
+                    self._warning_detectors[i].update(drift_input)
+
+                    if self._warning_detectors[i].drift_detected:
+                        self._background[i] = self._new_base_model()  # type: ignore
+                        # Reset the warning detector for the current object
+                        self._warning_detectors[i] = self.warning_detector.clone()
+
+                        # Update warning tracker
+                        self._warning_tracker[i] += 1
+
+                if self.drift_detector is not None:
+                    drift_input = (
+                        drift_input
+                        if drift_input is not None
+                        else self._drift_detector_input(i, y, y_pred)
+                    )
+                    self._drift_detectors[i].update(drift_input)
+
+                    if self._drift_detectors[i].drift_detected:
+                        if self.warning_detector is not None and self._background[i] is not None:
+                            self.data[i] = self._background[i]
+                            self._background[i] = None
+                            self._warning_detectors[i] = self.warning_detector.clone()
+                            self._drift_detectors[i] = self.drift_detector.clone()
+                            self._metrics[i] = self.metric.clone()
+                        else:
+                            
self.data[i] = self._new_base_model()
+                            self._drift_detectors[i] = self.drift_detector.clone()
+                            self._metrics[i] = self.metric.clone()
+
+                        # Update drift tracker
+                        self._drift_tracker[i] += 1
 
         return self
 
     def _init_ensemble(self, features: list):
         self._set_max_features(len(features))
-
-        self.data = [
-            self._base_member_class(  # type: ignore
-                index_original=i,
-                model=self._new_base_model(),
-                created_on=self._n_samples_seen,
-                drift_detector=self.drift_detector,
-                warning_detector=self.warning_detector,
-                is_background_learner=False,
-                metric=self.metric,
-            )
-            for i in range(self.n_models)
-        ]
-
-    @abc.abstractmethod
-    def _new_base_model(self):
-        raise NotImplementedError
+        self.data = [self._new_base_model() for _ in range(self.n_models)]
 
     def _set_max_features(self, n_features):
         if self.max_features == "sqrt":
@@ -230,12 +340,6 @@ def _new_leaf(self, initial_stats=None, parent=None):
             self.rng,
         )
 
-    def new_instance(self):
-        new_instance = self.clone()
-        # Use existing rng to enforce a different model
-        new_instance.rng = self.rng
-        return new_instance
-
 
 class BaseTreeRegressor(HoeffdingTreeRegressor):
     """ARF Hoeffding Tree regressor.
@@ -343,12 +447,6 @@ def _new_leaf(self, initial_stats=None, parent=None):  # noqa
 
         return new_adaptive
 
-    def new_instance(self):
-        new_instance = self.clone()
-        # Use existing rng to enforce a different model
-        new_instance.rng = self.rng
-        return new_instance
-
 
 class ARFClassifier(BaseForest, base.Classifier):
     """Adaptive Random Forest classifier.
@@ -472,7 +570,19 @@ class ARFClassifier(BaseForest, base.Classifier):
     >>> metric = metrics.Accuracy()
 
     >>> evaluate.progressive_val_score(dataset, model, metric)
-    Accuracy: 71.07%
+    Accuracy: 71.17%
+
+    The total number of warnings and drifts detected, respectively:
+    >>> model.n_warnings_detected(), model.n_drifts_detected()
+    (2, 1)
+
+    The number of warnings detected by tree number 2:
+    >>> model.n_warnings_detected(2)
+    1
+
+    And the corresponding number of actual concept drifts detected:
+    >>> model.n_drifts_detected(2)
+    1
 
     References
     ----------
@@ -523,9 +633,6 @@ def __init__(
             seed=seed,
         )
 
-        self._n_samples_seen = 0
-        self._base_member_class = ForestMemberClassifier  # type: ignore
-
         # Tree parameters
         self.grace_period = grace_period
         self.max_depth = max_depth
@@ -566,9 +673,9 @@ def predict_proba_one(self, x: dict) -> dict[base.typing.ClfTarget, float]:
             self._init_ensemble(sorted(x.keys()))
             return y_pred  # type: ignore
 
-        for model in self:
+        for i, model in enumerate(self):
             y_proba_temp = model.predict_proba_one(x)
-            metric_value = model.metric.get()
+            metric_value = self._metrics[i].get()
             if not self.disable_weighted_vote and metric_value > 0.0:
                 y_proba_temp = {k: val * metric_value for k, val in y_proba_temp.items()}
             y_pred.update(y_proba_temp)
@@ -601,9 +708,14 @@ def _new_base_model(self):
             rng=self._rng,
         )
 
+    def _drift_detector_input(
+        self, tree_id: int, y_true: base.typing.ClfTarget, y_pred: base.typing.ClfTarget
+    ) -> int | float:
+        return int(not y_true == y_pred)
+
 
 class ARFRegressor(BaseForest, base.Regressor):
-    r"""Adaptive Random Forest regressor.
+    """Adaptive Random Forest regressor.
 
     The 3 most important aspects of Adaptive Random Forest [^1] are:
 
@@ -621,7 +733,7 @@ class ARFRegressor(BaseForest, base.Regressor):
     predictions and check for concept drifts. The deviations of the predictions to the target are
    monitored and normalized in the [0, 1] range to fulfill ADWIN's requirements. 
We assume that the data subjected to the normalization follows - a normal distribution, and thus, lies within the interval of the mean $\pm3\sigma$. + a normal distribution, and thus, lies within the interval of the mean $\\pm3\\sigma$. Parameters ---------- @@ -742,7 +854,7 @@ class ARFRegressor(BaseForest, base.Regressor): >>> metric = metrics.MAE() >>> evaluate.progressive_val_score(dataset, model, metric) - MAE: 0.800378 + MAE: 0.788619 """ @@ -791,9 +903,6 @@ def __init__( seed=seed, ) - self._n_samples_seen = 0 - self._base_member_class = ForestMemberRegressor # type: ignore - # Tree parameters self.grace_period = grace_period self.max_depth = max_depth @@ -820,6 +929,9 @@ def __init__( f"Valid values are: {self._VALID_AGGREGATION_METHOD}" ) + # Used to normalize the input for the drift trackers + self._drift_norm = [stats.Var() for _ in range(self.n_models)] + @property def _mutable_attributes(self): return { @@ -842,10 +954,10 @@ def predict_one(self, x: dict) -> base.typing.RegTarget: if not self.disable_weighted_vote and self.aggregation_method != self._MEDIAN: weights = np.zeros(self.n_models) sum_weights = 0.0 - for idx, model in enumerate(self): - y_pred[idx] = model.predict_one(x) - weights[idx] = model.metric.get() - sum_weights += weights[idx] + for i, model in enumerate(self): + y_pred[i] = model.predict_one(x) + weights[i] = self._metrics[i].get() + sum_weights += weights[i] if sum_weights != 0: # The higher the error, the worse is the tree @@ -854,8 +966,8 @@ def predict_one(self, x: dict) -> base.typing.RegTarget: weights /= weights.sum() y_pred *= weights else: - for idx, model in enumerate(self): - y_pred[idx] = model.predict_one(x) + for i, model in enumerate(self): + y_pred[i] = model.predict_one(x) if self.aggregation_method == self._MEAN: y_pred = y_pred.mean() @@ -885,214 +997,19 @@ def _new_base_model(self): rng=self._rng, ) - @property - def valid_aggregation_method(self): - """Valid aggregation_method values.""" - return self._VALID_AGGREGATION_METHOD - - -class BaseForestMember: - """Base forest member class. - - This class represents a tree member of the forest. It includes a - base tree model, the background learner, drift detectors and performance - tracking parameters. - - The main purpose of this class is to train the foreground model. - Optionally, it monitors drift detection. Depending on the configuration, - if drift is detected then the foreground model is reset or replaced by a - background model. - - Parameters - ---------- - index_original - Tree index within the ensemble. - model - Tree learner. - created_on - Number of instances seen by the tree. - drift_detector - Drift Detection method. - warning_detector - Warning Detection method. - is_background_learner - True if the tree is a background learner. - metric - Metric to track performance. 
- - """ - - def __init__( - self, - index_original: int, - model: BaseTreeClassifier | BaseTreeRegressor, - created_on: int, - drift_detector: base.DriftDetector, - warning_detector: base.DriftDetector, - is_background_learner, - metric: metrics.base.MultiClassMetric | metrics.base.RegressionMetric, - ): - self.index_original = index_original - self.model = model - self.created_on = created_on - self.is_background_learner = is_background_learner - self.metric = metric.clone() - self.background_learner = None - - # Drift and warning detection - self.last_drift_on = 0 - self.last_warning_on = 0 - self.n_drifts_detected = 0 - self.n_warnings_detected = 0 - - # Initialize drift and warning detectors - if drift_detector is not None: - self._use_drift_detector = True - self.drift_detector = drift_detector.clone() - else: - self._use_drift_detector = False - self.drift_detector = None - - if warning_detector is not None: - self._use_background_learner = True - self.warning_detector = warning_detector.clone() - else: - self._use_background_learner = False - self.warning_detector = None - - def reset(self, n_samples_seen): - if self._use_background_learner and self.background_learner is not None: - # Replace foreground model with background model - self.model = self.background_learner.model - self.warning_detector = self.background_learner.warning_detector - self.drift_detector = self.background_learner.drift_detector - self.metric = self.background_learner.metric - self.created_on = self.background_learner.created_on - self.background_learner = None - else: - # Reset model - self.model = self.model.new_instance() - self.metric = self.metric.clone() - self.created_on = n_samples_seen - self.drift_detector = self.drift_detector.clone() - - def learn_one(self, x: dict, y: base.typing.Target, *, sample_weight: int, n_samples_seen: int): - self.model.learn_one(x, y, sample_weight=sample_weight) - - if self.background_learner: - # Train the background learner - self.background_learner.model.learn_one(x=x, y=y, sample_weight=sample_weight) - - if self._use_drift_detector and not self.is_background_learner: - drift_detector_input = self._drift_detector_input( - y_true=y, y_pred=self.model.predict_one(x) # type: ignore - ) - - # Check for warning only if use_background_learner is set - if self._use_background_learner: - self.warning_detector.update(drift_detector_input) - # Check if there was a (warning) change - if self.warning_detector.drift_detected: - self.last_warning_on = n_samples_seen - self.n_warnings_detected += 1 - # Create a new background learner object - self.background_learner = self.__class__( # type: ignore - index_original=self.index_original, - model=self.model.new_instance(), - created_on=n_samples_seen, - drift_detector=self.drift_detector, - warning_detector=self.warning_detector, - is_background_learner=True, - metric=self.metric, - ) - # Reset the warning detector for the current object - self.warning_detector = self.warning_detector.clone() - - # Update the drift detector - self.drift_detector.update(drift_detector_input) - - # Check if there was a change - if self.drift_detector.drift_detected: - self.last_drift_on = n_samples_seen - self.n_drifts_detected += 1 - self.reset(n_samples_seen) - - @abc.abstractmethod def _drift_detector_input( self, - y_true: base.typing.ClfTarget | base.typing.RegTarget, - y_pred: base.typing.ClfTarget | base.typing.RegTarget, - ): - raise NotImplementedError - - -class ForestMemberClassifier(BaseForestMember, base.Classifier): # type: ignore - 
"""Forest member class for classification""" - - def __init__( - self, - index_original: int, - model: BaseTreeClassifier, - created_on: int, - drift_detector: base.DriftDetector, - warning_detector: base.DriftDetector, - is_background_learner, - metric: metrics.base.MultiClassMetric, - ): - super().__init__( - index_original=index_original, - model=model, - created_on=created_on, - drift_detector=drift_detector, - warning_detector=warning_detector, - is_background_learner=is_background_learner, - metric=metric, - ) - - def _drift_detector_input( # type: ignore - self, y_true: base.typing.ClfTarget, y_pred: base.typing.ClfTarget - ): - return int(not y_true == y_pred) # Not correctly_classifies - - def predict_one(self, x): - return self.model.predict_one(x) - - def predict_proba_one(self, x): - return self.model.predict_proba_one(x) - - -class ForestMemberRegressor(BaseForestMember, base.Regressor): # type: ignore - """Forest member class for regression""" - - def __init__( - self, - index_original: int, - model: BaseTreeRegressor, - created_on: int, - drift_detector: base.DriftDetector, - warning_detector: base.DriftDetector, - is_background_learner, - metric: metrics.base.RegressionMetric, - ): - super().__init__( - index_original=index_original, - model=model, - created_on=created_on, - drift_detector=drift_detector, - warning_detector=warning_detector, - is_background_learner=is_background_learner, - metric=metric, - ) - self._var = stats.Var() # Used to track drift - - def _drift_detector_input(self, y_true: float, y_pred: float): # type: ignore + tree_id: int, + y_true: int | float, + y_pred: int | float, + ) -> int | float: drift_input = y_true - y_pred - self._var.update(drift_input) + self._drift_norm[tree_id].update(drift_input) - if self._var.mean.n == 1: + if self._drift_norm[tree_id].mean.n == 1: return 0.5 # The expected error is the normalized mean error - sd = math.sqrt(self._var.get()) + sd = math.sqrt(self._drift_norm[tree_id].get()) # We assume the error follows a normal distribution -> (empirical rule) # 99.73% of the values lie between [mean - 3*sd, mean + 3*sd]. 
We @@ -1100,10 +1017,7 @@ def _drift_detector_input(self, y_true: float, y_pred: float): # type: ignore # min-max norm to cope with ADWIN's requirements return (drift_input + 3 * sd) / (6 * sd) if sd > 0 else 0.5 - def reset(self, n_samples_seen): - super().reset(n_samples_seen) - # Reset the stats for the drift detector - self._var = stats.Var() - - def predict_one(self, x): - return self.model.predict_one(x) + @property + def valid_aggregation_method(self): + """Valid aggregation_method values.""" + return self._VALID_AGGREGATION_METHOD diff --git a/river/forest/online_extra_trees.py b/river/forest/online_extra_trees.py index cd7c6d7edc..ce904c9b83 100644 --- a/river/forest/online_extra_trees.py +++ b/river/forest/online_extra_trees.py @@ -3,7 +3,6 @@ import abc import collections import math -import numbers import random import sys @@ -221,7 +220,7 @@ def __weight_sampler_factory(self) -> Sampler: def _detection_mode_all( drift_detector: base.DriftDetector, warning_detector: base.DriftDetector, - detector_input: numbers.Number, + detector_input: int | float, ) -> tuple[bool, bool]: in_warning = warning_detector.update(detector_input).drift_detected in_drift = drift_detector.update(detector_input).drift_detected @@ -232,7 +231,7 @@ def _detection_mode_all( def _detection_mode_drop( drift_detector: base.DriftDetector, warning_detector: base.DriftDetector, - detector_input: numbers.Number, + detector_input: int | float, ) -> tuple[bool, bool]: in_drift = drift_detector.update(detector_input).drift_detected @@ -242,7 +241,7 @@ def _detection_mode_drop( def _detection_mode_off( drift_detector: base.DriftDetector, warning_detector: base.DriftDetector, - detector_input: numbers.Number, + detector_input: int | float, ) -> tuple[bool, bool]: return False, False
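
For readers who want to exercise the new logging API end to end, the following sketch is distilled from the updated `ARFClassifier` doctest above. The `synth.ConceptDriftStream` setup mirrors the docstring example; the counts it prints are illustrative, since they depend on the seeds and the river version in use.

from river import evaluate, forest, metrics
from river.datasets import synth

# A stream with a simulated concept drift, as in the ARFClassifier doctest.
dataset = synth.ConceptDriftStream(seed=42, position=500, width=40).take(1000)

model = forest.ARFClassifier(seed=8, leaf_prediction="mc")
metric = metrics.Accuracy()

print(evaluate.progressive_val_score(dataset, model, metric))

# Forest-wide totals. Each call returns None when the corresponding detector
# is disabled (drift_detector=None or warning_detector=None).
print(model.n_warnings_detected(), model.n_drifts_detected())

# Per-tree counts, indexed in [0, n_models - 1].
print(model.n_warnings_detected(2), model.n_drifts_detected(2))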
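
The regressor side feeds ADWIN a normalized error rather than the raw deviation. As the patch's comments explain, the error is assumed to be zero-mean normal, so about 99.73% of values fall within three standard deviations, and min-max scaling over [-3*sd, 3*sd] maps the error into ADWIN's expected [0, 1] range. Below is a standalone sketch of that normalization; the helper name is hypothetical, while the logic mirrors `ARFRegressor._drift_detector_input` as patched.

import math

from river import stats

def normalized_drift_input(err: float, running_var: stats.Var) -> float:
    # Track the running variance of the raw prediction errors.
    running_var.update(err)

    # With a single observation there is no spread information yet, so fall
    # back to the midpoint of ADWIN's [0, 1] input range.
    if running_var.mean.n == 1:
        return 0.5

    sd = math.sqrt(running_var.get())

    # Empirical rule: ~99.73% of a normal variable lies in [-3*sd, 3*sd].
    # Min-max scale over that interval to land in [0, 1].
    return (err + 3 * sd) / (6 * sd) if sd > 0 else 0.5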