Support mixed model for RandomForestClassificationModel (#200)
* Support mixed model for RandomForestClassificationModel

This PR only supports converting trees when impurity is gini.

Signed-off-by: Bobby Wang <[email protected]>

* fix bug

---------

Signed-off-by: Bobby Wang <[email protected]>
wbo4958 authored Apr 12, 2023
1 parent ff89615 commit 34ce8bc
Showing 6 changed files with 323 additions and 56 deletions.
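In short: a model trained on GPU by spark-rapids-ml's RandomForestClassifier can now be translated into a stock Spark MLlib model through the new cpu() method (gini impurity only for now). A minimal usage sketch — it assumes a GPU-enabled Spark cluster with spark-rapids-ml installed; the two-row DataFrame and the keyword-style constructor argument are illustrative assumptions, not taken from this diff:

```python
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

from spark_rapids_ml.classification import RandomForestClassifier

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(Vectors.dense([0.0, 1.0]), 0), (Vectors.dense([1.0, 0.0]), 1)],
    ["features", "label"],
)

# Train on GPU. Only impurity="gini" is convertible; cpu() raises a
# ValueError for other impurities (see the diff below).
gpu_model = RandomForestClassifier(impurity="gini").fit(df)

# cpu() rebuilds the cuML trees (carried as JSON) into a regular
# pyspark.ml.classification.RandomForestClassificationModel.
cpu_model = gpu_model.cpu()

# The new single-vector helpers (predict, predictRaw, predictProbability,
# evaluate, ...) on the GPU model delegate to this converted CPU model.
print(gpu_model.predict(Vectors.dense([0.0, 1.0])))
print(cpu_model.predictProbability(Vectors.dense([0.0, 1.0])))
```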
151 changes: 145 additions & 6 deletions python/src/spark_rapids_ml/classification.py
@@ -13,28 +13,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Literal, Tuple, Type, Union
import json
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import cudf
import numpy as np
import pandas as pd
from pyspark import Row
from pyspark.ml.classification import _RandomForestClassifierParams
from pyspark.ml.param.shared import HasProbabilityCol
from pyspark.ml.classification import (
BinaryRandomForestClassificationSummary,
DecisionTreeClassificationModel,
)
from pyspark.ml.classification import (
RandomForestClassificationModel as SparkRandomForestClassificationModel,
)
from pyspark.ml.classification import (
RandomForestClassificationSummary,
_RandomForestClassifierParams,
)
from pyspark.ml.linalg import Vector
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col
from pyspark.sql.types import DoubleType, FloatType, IntegerType, IntegralType

from spark_rapids_ml.core import CumlT, alias, pred
from spark_rapids_ml.tree import (
from .core import CumlT, alias, pred
from .tree import (
_RandomForestClass,
_RandomForestCumlParams,
_RandomForestEstimator,
_RandomForestModel,
)
from .utils import _get_spark_session, java_uid, translate_trees


class _RFClassifierParams(_RandomForestClassifierParams, HasProbabilityCol):
class _RFClassifierParams(
_RandomForestClassifierParams, HasProbabilityCol, HasRawPredictionCol
):
def __init__(self, *args: Any):
super().__init__(*args)

@@ -46,6 +61,14 @@ def setProbabilityCol(
"""
return self._set(probabilityCol=value)

def setRawPredictionCol(
self: "_RFClassifierParams", value: str
) -> "_RFClassifierParams":
"""
Sets the value of :py:attr:`rawPredictionCol`.
"""
return self._set(rawPredictionCol=value)


class RandomForestClassifier(
_RandomForestClass,
@@ -164,15 +187,66 @@ def __init__(
n_cols: int,
dtype: str,
treelite_model: str,
model_json: List[str],
num_classes: int,
):
super().__init__(
dtype=dtype,
n_cols=n_cols,
treelite_model=treelite_model,
model_json=model_json,
num_classes=num_classes,
)
self._num_classes = num_classes
self._model_json = model_json
self._rf_spark_model: Optional[SparkRandomForestClassificationModel] = None

def cpu(self) -> SparkRandomForestClassificationModel:
if self.getImpurity() != "gini":
# TODO, support entropy impurity
raise ValueError(
"Can't convert to Spark RandomForestClassificationModel"
" when impurity is not gini"
)

if self._rf_spark_model is None:
sc = _get_spark_session().sparkContext
assert sc._jvm is not None
assert sc._gateway is not None

uid = java_uid(sc, "rfc")

# Convert cuml trees to Spark trees
trees = [
translate_trees(sc, trees)
for trees_json in self._model_json
for trees in json.loads(trees_json)
]

# Wrap the trees into Spark DecisionTreeClassificationModel
decision_trees = [
sc._jvm.org.apache.spark.ml.classification.DecisionTreeClassificationModel(
uid, tree, self.numFeatures, self._num_classes
)
for tree in trees
]
object_class = (
sc._jvm.org.apache.spark.ml.classification.DecisionTreeClassificationModel
)
java_trees = sc._gateway.new_array(object_class, len(decision_trees))
for i in range(len(decision_trees)):
java_trees[i] = decision_trees[i]

# Create the Spark RandomForestClassificationModel
java_rf_model = sc._jvm.org.apache.spark.ml.classification.RandomForestClassificationModel(
uid,
java_trees,
self.numFeatures,
self._num_classes,
)
self._rf_spark_model = SparkRandomForestClassificationModel(java_rf_model)
self._copyValues(self._rf_spark_model)
return self._rf_spark_model

def _get_cuml_transform_func(
self, dataset: DataFrame
@@ -210,3 +284,68 @@ def hasSummary(self) -> bool:
def numClasses(self) -> int:
"""Number of classes (values which the label can take)."""
return self._num_classes

def predict(self, value: Vector) -> float:
"""
Predict label for the given features.
"""
return self.cpu().predict(value)

def predictLeaf(self, value: Vector) -> float:
"""
Predict the indices of the leaves corresponding to the feature vector.
"""
return self.cpu().predictLeaf(value)

def predictRaw(self, value: Vector) -> Vector:
"""
Raw prediction for each possible label.
"""
return self.cpu().predictRaw(value)

def predictProbability(self, value: Vector) -> Vector:
"""
Predict the probability of each class given the features.
"""
return self.cpu().predictProbability(value)

def evaluate(
self, dataset: DataFrame
) -> Union[
BinaryRandomForestClassificationSummary, RandomForestClassificationSummary
]:
"""
Evaluates the model on a test dataset.

Parameters
----------
dataset : :py:class:`pyspark.sql.DataFrame`
Test dataset to evaluate model on.
"""
return self.cpu().evaluate(dataset)

@property
def featureImportances(self) -> Vector:
"""
Estimate of the importance of each feature.
Each feature's importance is the average of its importance across all trees in the ensemble.
The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
(Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
and follows the implementation from scikit-learn.

See Also
--------
DecisionTreeClassificationModel.featureImportances
"""
return self.cpu().featureImportances

@property
def treeWeights(self) -> List[float]:
"""Return the weights for each tree"""
return self.cpu().treeWeights

@property
def trees(self) -> List[DecisionTreeClassificationModel]: # type: ignore
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return self.cpu().trees
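A py4j detail in cpu() above is easy to miss: the Java RandomForestClassificationModel constructor takes a DecisionTreeClassificationModel[], and py4j does not auto-convert a Python list into a typed Java array — hence the sc._gateway.new_array call and the element-by-element copy. A standalone sketch of the same pattern, using java.lang.String so it runs without any ML classes (only the array mechanics are the point here):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext
assert sc._gateway is not None

# Allocate a JVM-side String[2]; Java constructors that expect a typed
# array will not accept a plain Python list.
string_class = sc._gateway.jvm.java.lang.String
java_array = sc._gateway.new_array(string_class, 2)
java_array[0] = "tree-0"
java_array[1] = "tree-1"

# py4j's JavaArray supports indexing and len() from the Python side.
print(len(java_array), java_array[0], java_array[1])
```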
6 changes: 5 additions & 1 deletion python/src/spark_rapids_ml/regression.py
@@ -612,7 +612,11 @@ def __init__(
dtype: str,
treelite_model: str,
):
super().__init__(dtype=dtype, n_cols=n_cols, treelite_model=treelite_model)
super().__init__(
dtype=dtype,
n_cols=n_cols,
treelite_model=treelite_model,
)

def _is_classification(self) -> bool:
return False
34 changes: 30 additions & 4 deletions python/src/spark_rapids_ml/tree.py
@@ -14,6 +14,7 @@
# limitations under the License.
#
import base64
import json
import math
import pickle
from abc import abstractmethod
@@ -26,7 +27,13 @@
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol
from pyspark.ml.tree import _DecisionTreeModel
from pyspark.sql import DataFrame
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from pyspark.sql.types import (
ArrayType,
IntegerType,
StringType,
StructField,
StructType,
)

from spark_rapids_ml.core import (
CumlInputType,
@@ -235,11 +242,21 @@ def _rf_fit(
serialized_model = rf._get_serialized_model()
pickled_model = pickle.dumps(serialized_model)
msg = base64.b64encode(pickled_model).decode("utf-8")
messages = context.allGather(msg)
trees = rf.get_json()
data = {"model_bytes": msg, "model_json": trees}
messages = context.allGather(json.dumps(data))

# concatenate the random forests on worker 0
if part_id == 0:
mod_bytes = [pickle.loads(base64.b64decode(i)) for i in messages]
mod_bytes = []
mod_jsons = []
for msg in messages:
data = json.loads(msg)
mod_bytes.append(
pickle.loads(base64.b64decode(data["model_bytes"]))
)
mod_jsons.append(data["model_json"])

all_tl_mod_handles = [rf._tl_handle_from_bytes(i) for i in mod_bytes]
rf._concatenate_treelite_handle(all_tl_mod_handles)

@@ -257,6 +274,7 @@ def _rf_fit(
}
if is_classification:
result["num_classes"] = rf.num_classes
result["model_json"] = [mod_jsons]
return result
else:
return {}
@@ -271,6 +289,8 @@ def _out_schema(self) -> Union[StructType, str]:
]
if self._is_classification():
fields.append(StructField("num_classes", IntegerType(), False))
fields.append(StructField("model_json", ArrayType(StringType()), False))

return StructType(fields)

def _require_nccl_ucx(self) -> Tuple[bool, bool]:
Expand All @@ -286,6 +306,7 @@ def __init__(
n_cols: int,
dtype: str,
treelite_model: str,
model_json: List[str] = [],
num_classes: int = -1, # only for classification
):
if self._is_classification():
Expand All @@ -294,9 +315,14 @@ def __init__(
n_cols=n_cols,
treelite_model=treelite_model,
num_classes=num_classes,
model_json=model_json,
)
else:
super().__init__(dtype=dtype, n_cols=n_cols, treelite_model=treelite_model)
super().__init__(
dtype=dtype,
n_cols=n_cols,
treelite_model=treelite_model,
)

self.treelite_model = treelite_model
