Add plots on multiclass predict_proba
ArturoAmorQ committed Oct 10, 2023
1 parent 12c3f33 commit 8c987c6
Showing 1 changed file with 72 additions and 0 deletions.

python_scripts/trees_classification.py
@@ -240,3 +240,75 @@
#
# In the next exercise, you will increase the tree depth to get an intuition on
# how such a parameter affects the space partitioning.
#
# Finally, we can try to visualize the output of `predict_proba` for a
# multiclass problem. For a K-class problem there are K probability outputs
# for each data point, and visualizing all of them on a single plot can
# quickly become tricky to interpret. It is then common to instead produce K
# separate plots, one for each class, in a one-vs-rest (or one-vs-all)
# fashion. Below we build such plots by hand with matplotlib, since
# `DecisionBoundaryDisplay` does not support this multiclass setting yet (see
# the note at the end of this section).
#
# For example, in the plot below, the first column shows in red the certainty
# of classifying a data point as belonging to the "Adelie" class. Notice that
# the logistic regression is more certain than our under-fitting tree in this
# case. Indeed, the shallow tree is unsure between the "Adelie" and
# "Chinstrap" classes. In the same column, the blue color represents the
# certainty of **not** belonging to the "Adelie" class. The same logic applies
# to the other columns.

# %%
import numpy as np

# Compare the two models fitted earlier in this notebook: a logistic
# regression and a shallow decision tree.
classifiers = {
    "logistic": linear_model,
    "tree": tree,
}
n_classifiers = len(classifiers)

# Build a 100x100 grid covering the range of culmen length and depth, then
# flatten it into a dataframe of feature values.
xx = np.linspace(30, 60, 100)
yy = np.linspace(10, 23, 100)
xx, yy = np.meshgrid(xx, yy)
Xfull = pd.DataFrame(
    {"Culmen Length (mm)": xx.ravel(), "Culmen Depth (mm)": yy.ravel()}
)

plt.figure(figsize=(12, 4))
plt.subplots_adjust(bottom=0.2, top=0.95)

for index, (name, classifier) in enumerate(classifiers.items()):
    classifier.fit(data_train, target_train)
    probas = classifier.predict_proba(Xfull)
    n_classes = len(np.unique(classifier.classes_))

    # One subplot per class, in a one-vs-rest fashion.
    for k in range(n_classes):
        plt.subplot(n_classifiers, n_classes, index * n_classes + k + 1)
        plt.title(f"Class {classifier.classes_[k]}")
        if k == 0:
            plt.ylabel(name)
        # Show the probability of belonging to class k over the grid.
        imshow_handle = plt.imshow(
            probas[:, k].reshape((100, 100)),
            extent=(30, 60, 10, 23),
            vmin=0.0,
            vmax=1.0,
            origin="lower",
            cmap="RdBu_r",
        )
        plt.xticks(())
        plt.yticks(())
        # Overlay the test points that truly belong to class k.
        idx = target_test == classifier.classes_[k]
        plt.scatter(
            data_test["Culmen Length (mm)"].loc[idx],
            data_test["Culmen Depth (mm)"].loc[idx],
            marker="o",
            c="w",
            edgecolor="k",
        )

ax = plt.axes([0.15, 0.04, 0.7, 0.05])
plt.colorbar(imshow_handle, cax=ax, orientation="horizontal")
_ = plt.title("Probability")

# %% [markdown]
# In scikit-learn v1.4, `DecisionBoundaryDisplay` will support a
# `class_of_interest` parameter that allows, in particular, visualizing
# `predict_proba` in multiclass settings, as sketched below.
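
# %% [markdown]
# The following cell is a minimal sketch of that upcoming API, assuming
# scikit-learn >= 1.4 is installed; the estimator, data and class name are the
# ones defined earlier in this notebook. It draws the probability map for a
# single class directly, without the manual meshgrid construction used above.

# %%
from sklearn.inspection import DecisionBoundaryDisplay

# Plot the estimated probability of the "Adelie" class for the shallow tree;
# `class_of_interest` selects which of the K probability columns to display.
DecisionBoundaryDisplay.from_estimator(
    tree,
    data_train,
    response_method="predict_proba",
    class_of_interest="Adelie",
    cmap="RdBu_r",
)
_ = plt.title("Probability of Adelie (tree)")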
