Skip to content

Commit

Permalink
feat: plot decision tree (#876)
Browse files Browse the repository at this point in the history
Closes #856 

### Summary of Changes

- Added plot method for DecisionTreeClassifier and
DecisionTreeRegressor.
- Added tests for both

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Saman Hushi <[email protected]>
Co-authored-by: peplaul0 <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
5 people authored Jul 12, 2024
1 parent e93299f commit d3f81dc
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds.data.image.containers import Image
from safeds.exceptions._ml import ModelNotFittedError
from safeds.ml.classical._bases import _DecisionTreeBase

from ._classifier import Classifier
Expand Down Expand Up @@ -71,3 +73,41 @@ def _get_sklearn_model(self) -> ClassifierMixin:
max_depth=self._max_depth,
min_samples_leaf=self._min_sample_count_in_leaves,
)

# ------------------------------------------------------------------------------------------------------------------
# Plot
# ------------------------------------------------------------------------------------------------------------------

def plot(self) -> Image:
"""
Get the image of the decision tree.
Returns
-------
plot:
The decision tree figure as an image.
Raises
------
ModelNotFittedError:
If model is not fitted.
"""
if not self.is_fitted:
raise ModelNotFittedError

from io import BytesIO

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

plot_tree(self._wrapped_model)

# save plot fig bytes in buffer
with BytesIO() as buffer:
plt.savefig(buffer)
image = buffer.getvalue()

# prevent forced plot from sklearn showing
plt.close()

return Image.from_bytes(image)
40 changes: 40 additions & 0 deletions src/safeds/ml/classical/regression/_decision_tree_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds.data.image.containers import Image
from safeds.exceptions._ml import ModelNotFittedError
from safeds.ml.classical._bases import _DecisionTreeBase

from ._regressor import Regressor
Expand Down Expand Up @@ -71,3 +73,41 @@ def _get_sklearn_model(self) -> RegressorMixin:
max_depth=self._max_depth,
min_samples_leaf=self._min_sample_count_in_leaves,
)

# ------------------------------------------------------------------------------------------------------------------
# Plot
# ------------------------------------------------------------------------------------------------------------------

def plot(self) -> Image:
"""
Get the image of the decision tree.
Returns
-------
plot:
The decision tree figure as an image.
Raises
------
ModelNotFittedError:
If model is not fitted.
"""
if not self.is_fitted:
raise ModelNotFittedError

from io import BytesIO

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

plot_tree(self._wrapped_model)

# save plot fig bytes in buffer
with BytesIO() as buffer:
plt.savefig(buffer)
image = buffer.getvalue()

# prevent forced plot from sklearn showing
plt.close()

return Image.from_bytes(image)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 22 additions & 1 deletion tests/safeds/ml/classical/classification/test_decision_tree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from safeds.data.labeled.containers import TabularDataset
from safeds.data.tabular.containers import Table
from safeds.exceptions import OutOfBoundsError
from safeds.exceptions import ModelNotFittedError, OutOfBoundsError
from safeds.ml.classical.classification import DecisionTreeClassifier
from syrupy import SnapshotAssertion

from tests.helpers import os_mac, skip_if_os


@pytest.fixture()
Expand Down Expand Up @@ -41,3 +44,21 @@ def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None
def test_should_raise_if_less_than_or_equal_to_0(self, min_sample_count_in_leaves: int) -> None:
with pytest.raises(OutOfBoundsError):
DecisionTreeClassifier(min_sample_count_in_leaves=min_sample_count_in_leaves)


class TestPlot:
def test_should_raise_if_model_is_not_fitted(self) -> None:
model = DecisionTreeClassifier()
with pytest.raises(ModelNotFittedError):
model.plot()

def test_should_check_that_plot_image_is_same_as_snapshot(
self,
training_set: TabularDataset,
snapshot_png_image: SnapshotAssertion,
) -> None:
skip_if_os([os_mac])

fitted_model = DecisionTreeClassifier().fit(training_set)
image = fitted_model.plot()
assert image == snapshot_png_image
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 18 additions & 1 deletion tests/safeds/ml/classical/regression/test_decision_tree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from safeds.data.labeled.containers import TabularDataset
from safeds.data.tabular.containers import Table
from safeds.exceptions import OutOfBoundsError
from safeds.exceptions import ModelNotFittedError, OutOfBoundsError
from safeds.ml.classical.regression import DecisionTreeRegressor
from syrupy import SnapshotAssertion


@pytest.fixture()
Expand Down Expand Up @@ -41,3 +42,19 @@ def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None
def test_should_raise_if_less_than_or_equal_to_0(self, min_sample_count_in_leaves: int) -> None:
with pytest.raises(OutOfBoundsError):
DecisionTreeRegressor(min_sample_count_in_leaves=min_sample_count_in_leaves)


class TestPlot:
def test_should_raise_if_model_is_not_fitted(self) -> None:
model = DecisionTreeRegressor()
with pytest.raises(ModelNotFittedError):
model.plot()

def test_should_check_that_plot_image_is_same_as_snapshot(
self,
training_set: TabularDataset,
snapshot_png_image: SnapshotAssertion,
) -> None:
fitted_model = DecisionTreeRegressor().fit(training_set)
image = fitted_model.plot()
assert image == snapshot_png_image

0 comments on commit d3f81dc

Please sign in to comment.