From 3d5e638145dd6448e880d89dd1e7fa2bd081050e Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 17 Oct 2023 15:59:20 +0200 Subject: [PATCH] Do not use the red-blue colormap for a 3-class classification decision boundary --- python_scripts/trees_classification.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python_scripts/trees_classification.py b/python_scripts/trees_classification.py index 61ba1d8a2..0f4c24ed9 100644 --- a/python_scripts/trees_classification.py +++ b/python_scripts/trees_classification.py @@ -53,15 +53,22 @@ # %% import matplotlib.pyplot as plt +import matplotlib as mpl import seaborn as sns from sklearn.inspection import DecisionBoundaryDisplay +tab10_norm = mpl.colors.Normalize(vmin=-0.5, vmax=8.5) # create a palette to be used in the scatterplot -palette = ["tab:red", "tab:blue", "black"] - -DecisionBoundaryDisplay.from_estimator( - linear_model, data_train, response_method="predict", cmap="RdBu", alpha=0.5 +palette = ["tab:blue", "tab:green", "tab:orange"] + +dbd = DecisionBoundaryDisplay.from_estimator( + linear_model, + data_train, + response_method="predict", + cmap="tab10", + norm=tab10_norm, + alpha=0.5, ) sns.scatterplot( data=penguins, @@ -105,7 +112,12 @@ # %% DecisionBoundaryDisplay.from_estimator( - tree, data_train, response_method="predict", cmap="RdBu", alpha=0.5 + tree, + data_train, + response_method="predict", + cmap="tab10", + norm=tab10_norm, + alpha=0.5, ) sns.scatterplot( data=penguins,