diff --git a/python_scripts/trees_classification.py b/python_scripts/trees_classification.py
index ce102c149..32ce9aa44 100644
--- a/python_scripts/trees_classification.py
+++ b/python_scripts/trees_classification.py
@@ -240,3 +240,75 @@
 #
 # In the next exercise, you will increase the tree depth to get an intuition on
 # how such parameter affects the space partitioning.
+#
+# Finally, we can try to visualize the output of `predict_proba` for a
+# multiclass problem. We could use `DecisionBoundaryDisplay`, except that for
+# a K-class problem you have K probability outputs for each data point.
+# Visualizing all of them on a single plot can quickly become hard to
+# interpret. It is then common to instead produce K separate plots, one for
+# each class, in a one-vs-rest (or one-vs-all) fashion.
+#
+# For example, in the plot below, the first column shows in red the certainty of
+# classifying a data point as belonging to the "Adelie" class. Notice that the
+# logistic regression is more certain than our under-fitting tree in this case.
+# Indeed, the shallow tree is unsure between classes "Adelie" and "Chinstrap".
+# In the same column, the blue color represents the certainty of **not**
+# belonging to the "Adelie" class. The same logic applies to the other columns.
+
+# %%
+import numpy as np
+
+classifiers = {
+    "logistic": linear_model,
+    "tree": tree,
+}
+n_classifiers = len(classifiers)
+
+xx = np.linspace(30, 60, 100)
+yy = np.linspace(10, 23, 100)
+xx, yy = np.meshgrid(xx, yy)
+Xfull = pd.DataFrame(
+    {"Culmen Length (mm)": xx.ravel(), "Culmen Depth (mm)": yy.ravel()}
+)
+
+plt.figure(figsize=(12, 4))
+plt.subplots_adjust(bottom=0.2, top=0.95)
+
+for index, (name, classifier) in enumerate(classifiers.items()):
+    classifier.fit(data_train, target_train)
+    target_pred = classifier.predict(data_test)
+    probas = classifier.predict_proba(Xfull)
+    n_classes = len(np.unique(classifier.classes_))
+
+    for k in range(n_classes):
+        plt.subplot(n_classifiers, n_classes, index * n_classes + k + 1)
+        plt.title(f"Class {classifier.classes_[k]}")
+        if k == 0:
+            plt.ylabel(name)
+        imshow_handle = plt.imshow(
+            probas[:, k].reshape((100, 100)),
+            extent=(30, 60, 10, 23),
+            vmin=0.0,
+            vmax=1.0,
+            origin="lower",
+            cmap="RdBu_r",
+        )
+        plt.xticks(())
+        plt.yticks(())
+        idx = target_test == classifier.classes_[k]
+        plt.scatter(
+            data_test["Culmen Length (mm)"].loc[idx],
+            data_test["Culmen Depth (mm)"].loc[idx],
+            marker="o",
+            c="w",
+            edgecolor="k",
+        )
+
+ax = plt.axes([0.15, 0.04, 0.7, 0.05])
+plt.colorbar(imshow_handle, cax=ax, orientation="horizontal")
+_ = plt.title("Probability")
+
+# %% [markdown]
+# In scikit-learn v1.4, `DecisionBoundaryDisplay` will support a
+# `class_of_interest` parameter that will allow, in particular, visualizing
+# `predict_proba` in multi-class settings.
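
Side note, not part of the diff above: once the notebook targets scikit-learn >= 1.4, the one-vs-rest panels could presumably be drawn directly with `DecisionBoundaryDisplay` and its `class_of_interest` parameter. The sketch below is only an illustration of that idea; it reuses the notebook's fitted `tree` classifier and `data_train` dataframe, and the styling arguments are assumptions rather than part of this change.

```python
# Hypothetical sketch, assuming scikit-learn >= 1.4 and the notebook's fitted
# `tree` classifier and two-column `data_train` dataframe.
import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay

_, axs = plt.subplots(ncols=len(tree.classes_), figsize=(12, 4))
for ax, class_name in zip(axs, tree.classes_):
    # One panel per class: the color encodes predict_proba for that class only
    # (one-vs-rest view), mirroring the manual imshow-based plot above.
    DecisionBoundaryDisplay.from_estimator(
        tree,
        data_train,
        response_method="predict_proba",
        class_of_interest=class_name,
        cmap="RdBu_r",
        vmin=0.0,
        vmax=1.0,
        ax=ax,
    )
    ax.set_title(f"Class {class_name}")
```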