diff --git a/lectures/tools/classification.md b/lectures/tools/classification.md index 31dd6293..87075283 100644 --- a/lectures/tools/classification.md +++ b/lectures/tools/classification.md @@ -378,12 +378,12 @@ could be properly detected from noise (i.e. enemy aircraft vs. noise). def plot_roc(mod, X, y): # predicted_probs is an N x 2 array, where N is number of observations # and 2 is number of classes - predicted_probs = mod.predict_proba(X_test) + predicted_probs = mod.predict_proba(X) # keep the second column, for label=1 predicted_prob1 = predicted_probs[:, 1] - fpr, tpr, _ = metrics.roc_curve(y_test, predicted_prob1) + fpr, tpr, _ = metrics.roc_curve(y, predicted_prob1) # Plot ROC curve fig, ax = plt.subplots()