Skip to content

Commit

Permalink
Do not use the red-blue colormap for a 3-class classification decisio…
Browse files Browse the repository at this point in the history
…n boundary
  • Loading branch information
ogrisel committed Oct 17, 2023
1 parent 307e84c commit 3d5e638
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions python_scripts/trees_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,22 @@

# %%
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

from sklearn.inspection import DecisionBoundaryDisplay

tab10_norm = mpl.colors.Normalize(vmin=-0.5, vmax=8.5)
# create a palette to be used in the scatterplot
palette = ["tab:red", "tab:blue", "black"]

DecisionBoundaryDisplay.from_estimator(
linear_model, data_train, response_method="predict", cmap="RdBu", alpha=0.5
palette = ["tab:blue", "tab:green", "tab:orange"]

dbd = DecisionBoundaryDisplay.from_estimator(
linear_model,
data_train,
response_method="predict",
cmap="tab10",
norm=tab10_norm,
alpha=0.5,
)
sns.scatterplot(
data=penguins,
Expand Down Expand Up @@ -105,7 +112,12 @@

# %%
DecisionBoundaryDisplay.from_estimator(
tree, data_train, response_method="predict", cmap="RdBu", alpha=0.5
tree,
data_train,
response_method="predict",
cmap="tab10",
norm=tab10_norm,
alpha=0.5,
)
sns.scatterplot(
data=penguins,
Expand Down

0 comments on commit 3d5e638

Please sign in to comment.