{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cff435f2-0261-4160-8e34-b5c56af2bb16",
   "metadata": {},
   "source": [
    "# Evaluation Metrics for Classification\n",
    "\n",
    "* Use the same data as in the previous session: churn prediction (identify clients who want to leave the company)\n",
    "* Data: https://www.kaggle.com/blastchar/telco-customer-churn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9c606de-f885-4881-ac03-2f91ee5dd3ff",
   "metadata": {},
   "source": [
    "## Review of the Previous Session\n",
    "\n",
    "* Metric: a function that compares the actual values with the predicted values and outputs a single number that measures how good the model is\n",
    "\n",
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d7a4bf0f-5824-4905-a374-b50d4279d4f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mutual_info_score\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d49cb1f5-a3e5-4f0f-877d-1d67d1284379",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9541f68c-d164-4097-a35e-c9894e4fba1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"../data/WA_Fn-UseC_-Telco-Customer-Churn.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b7472f5-def7-46c2-8d86-74f954f318c4",
   "metadata": {},
   "source": [
    "### Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0715e0db-c953-47be-b645-97d7c5ab0570",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unify column names\n",
    "df.columns = df.columns.str.lower().str.replace(\" \", \"_\")\n",
    "# List of all categorical columns\n",
    "categorical_columns = list(df.dtypes[df.dtypes == object].index)\n",
    "# Unify the values in the categorical columns\n",
    "for c in categorical_columns:\n",
    "    df[c] = df[c].str.lower().str.replace(\" \", \"_\")\n",
    "\n",
    "# Convert \"totalcharges\" to numeric and fill missing values with 0\n",
    "df.totalcharges = pd.to_numeric(df[\"totalcharges\"], errors=\"coerce\")\n",
    "df.totalcharges = df.totalcharges.fillna(0)\n",
    "\n",
    "# Change \"churn\" to type int\n",
    "df[\"churn\"] = (df[\"churn\"] == \"yes\").astype(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e47b2cf-747a-4970-ba4d-f02cbfe79f8d",
   "metadata": {},
   "source": [
    "### Model Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "61ddad15-3767-4ac5-abff-b86e7bfa986c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 80% train + val = train_full, 20% test\n",
    "df_train_full, df_test = train_test_split(df, test_size=0.2, random_state=1)\n",
    "# 75% train, 25% val out of train_full\n",
    "# 60% train, 20% val, 20% test out of df\n",
    "df_train, df_val = train_test_split(df_train_full, test_size=0.25, random_state=1)\n",
    "\n",
    "# reset index\n",
    "df_train = df_train.reset_index(drop=True)\n",
    "df_val = df_val.reset_index(drop=True)\n",
    "df_test = df_test.reset_index(drop=True)\n",
    "\n",
    "y_train = df_train[\"churn\"]\n",
    "y_val = df_val[\"churn\"]\n",
    "y_test = df_test[\"churn\"]\n",
    "\n",
    "# Delete \"churn\" from df_train, df_val, df_test (not from df)\n",
    "del df_train[\"churn\"]\n",
    "del df_val[\"churn\"]\n",
    "del df_test[\"churn\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "81d44f24-cb7a-4b52-8b83-cffdd07e395a",
   "metadata": {},
   "outputs": [],
   "source": [
    "numerical = [\"tenure\", \"monthlycharges\", \"totalcharges\"]\n",
    "categorical = ['gender', 'seniorcitizen', 'partner', 'dependents',\n",
    "       'phoneservice', 'multiplelines', 'internetservice',\n",
    "       'onlinesecurity', 'onlinebackup', 'deviceprotection', 'techsupport',\n",
    "       'streamingtv', 'streamingmovies', 'contract', 'paperlessbilling',\n",
    "       'paymentmethod']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4ed56df1-a78f-4988-898d-6f82e0f2a0d0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jens/miniconda3/envs/ml-zoomcamp/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LogisticRegression()"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First convert to dictionary\n",
    "train_dicts = df_train[categorical+numerical].to_dict(orient=\"records\")\n",
    "\n",
    "dv = DictVectorizer(sparse=False) # don't use sparse matrix\n",
    "dv.fit(train_dicts)\n",
    "X_train = dv.transform(train_dicts)\n",
    "\n",
    "val_dicts = df_val[categorical+numerical].to_dict(orient=\"records\")\n",
    "X_val = dv.transform(val_dicts)\n",
    "\n",
    "model = LogisticRegression()\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a514ca21-d8b8-46c2-9590-1c6e5d3b9ec5",
   "metadata": {},
   "source": [
    "### Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f3963b4f-51dc-4c2b-8933-fa2ad8f7580d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8034066713981547"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = model.predict_proba(X_val)[:,1]\n",
    "churn_decision = (y_pred > 0.5)\n",
    "(y_val == churn_decision).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3aade31-a89c-4af7-a64a-3c6d39266d4f",
   "metadata": {},
   "source": [
    "## Accuracy and Dummy Model\n",
    "\n",
    "* Evaluate the model on different thresholds\n",
    "* Check the accuracy of the dummy baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "548c0052-9abd-42ce-98fd-0088bad90367",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1409"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# How many validation values do we have?\n",
    "len(y_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7bfe43fc-db2a-4f80-8373-4116ee4fdc3a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1132"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# How many correct decisions did we make?\n",
    "(y_val == churn_decision).sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03bc157d-37be-4159-ac1a-8073e082a340",
   "metadata": {},
   "source": [
    "* Accruracy: Correct Prediction / All Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "94043d4b-ba0e-4964-9935-a9de58f26b1c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8034066713981547"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1132/1409"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dadb639-28c5-4a86-8c3b-3ff18a0d0aa1",
   "metadata": {},
   "source": [
    "* Is the threshold of 0.5 good?\n",
    "* Test the model with different thresholds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b0df11ad-abe4-4f59-b393-343b1889e597",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.  , 0.05, 0.1 , 0.15, 0.2 , 0.25, 0.3 , 0.35, 0.4 , 0.45, 0.5 ,\n",
       "       0.55, 0.6 , 0.65, 0.7 , 0.75, 0.8 , 0.85, 0.9 , 0.95, 1.  ])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "thresholds = np.linspace(0,1,21)\n",
    "thresholds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c8bc491a-bec8-43b7-b842-a0710395dc7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "threshold: 0.00, score: 0.274\n",
      "threshold: 0.05, score: 0.510\n",
      "threshold: 0.10, score: 0.591\n",
      "threshold: 0.15, score: 0.666\n",
      "threshold: 0.20, score: 0.710\n",
      "threshold: 0.25, score: 0.739\n",
      "threshold: 0.30, score: 0.760\n",
      "threshold: 0.35, score: 0.772\n",
      "threshold: 0.40, score: 0.785\n",
      "threshold: 0.45, score: 0.794\n",
      "threshold: 0.50, score: 0.803\n",
      "threshold: 0.55, score: 0.801\n",
      "threshold: 0.60, score: 0.796\n",
      "threshold: 0.65, score: 0.786\n",
      "threshold: 0.70, score: 0.766\n",
      "threshold: 0.75, score: 0.744\n",
      "threshold: 0.80, score: 0.734\n",
      "threshold: 0.85, score: 0.726\n",
      "threshold: 0.90, score: 0.726\n",
      "threshold: 0.95, score: 0.726\n",
      "threshold: 1.00, score: 0.726\n"
     ]
    }
   ],
   "source": [
    "scores = []\n",
    "for t in thresholds:\n",
    "    churn_decision = (y_pred > t)\n",
    "    score = (y_val == churn_decision).mean()\n",
    "    print(f\"threshold: {t:.2f}, score: {score:.3f}\")\n",
    "    scores.append(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b06d517f-5b3b-4914-9f96-3249960af188",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfV0lEQVR4nO3de3SddZ3v8fc3SZMmTZq2ubVN06T0SotAaSgCIhe5FB1F1qAic2Bk9CAqXmZ5Y85ZM7NmuWZGD47OOYDW6gDqUauCo/VMoSClgFShF6BQmqRp6CUtubRpm537ZX/PH3u3hJDS3XYnT/azP6+1srKf/fzY+/sju5/88nt+z/OYuyMiIqkvI+gCREQkORToIiIhoUAXEQkJBbqISEgo0EVEQiIrqDcuLi72qqqqoN5eRCQlbdmy5aC7l4y0L7BAr6qqYvPmzUG9vYhISjKzPSfapykXEZGQUKCLiISEAl1EJCQU6CIiIaFAFxEJCQW6iEhIJBToZrbCzGrNrN7M7h5hf6GZ/d7MXjaz7WZ2e/JLFRGRd3LSdehmlgncD1wDNAKbzGyNu782pNnngNfc/YNmVgLUmtnP3L1vVKoWGQVHu/t5ce9hth9oJzPDmJSTRX5OJvk5E5iUk0lB/Ht+Thb5E7PInZCJmQVdtshxiZxYtByod/cGADNbDdwADA10Bwos9unOB9qAgSTXKpI07s7rBzvZsucwW/ceZsuew+xs6eBUbg+QYTApOxbusfCPfZVNnkhVUR6zi/KoKppEZVEeU/KyR68zInGJBHo5sG/IdiNw0bA29wFrgANAAfAxd48mpUKRJOjuG2Rb4xG27D3M1j2H2br3CG2dsT8gJ0/M4oLKqXzw3Jksq5zKuRVTyDSjo3eAjt4BOnsHiPTEvnf2vfl46P7Y40EiPf08V3+QR7b2vOX9C3MnxEN+Uuz7tDyqimNhX5Kfo5G+JEUigT7SJ234OOY64CXgKmAu8ISZPevu7W95IbM7gDsAZs+efcrFiiTqjaPdbNkTG3lv3RObRhmIxj62c0sm8b5FpSyrnMqyyqnMLcknI+PtH/Pc7ExKCnJO6/17+gfZ29bF7oOdse+HOtlzqIuX9x3hv7YdIDrkX1Bediazp+VRWZTHwrICllZO5YKKqRTmTTit95b0lUigNwIVQ7ZnERuJD3U78E2P3c+u3sxeBxYBLwxt5O6rgFUA1dXVuvedJEX/YJQdb7S/JcAPHI2NkCdOyOD8iil8+vKzWFY5laUVU5k6afSnPyZOyGRBWQELygretq9vIMr+I93sPtTJ3kNvhv3Olg6eeK35eNjPL81nWeVULoj/4jmreJJG8vKOEgn0TcB8M5sD7AduBm4Z1mYv8D7gWTMrAxYCDcksVOSYts4+XozPe2/Zc5iXG4/Q0x+b4Sufksuyqmn899lTqK6cxqIZBUzIHF+rc7OzMphTPIk5xZPetq+jd4Bt+47E+rb3MGtfeYPVm2IznlPzJnDB7DcD/rxZU8jNzhzr8mUcO2mgu/uAmd0FrAMygQfcfbuZ3RnfvxL4BvCQmb1CbIrm6+5+cBTrljQRjTq7WjuOh/eWvYdpaO0EICvDWFJeyC3LK+Mj2SnMKMwNuOIzk5+TxSXzirlkXjHw1v4fO3j7ZE0LEOv/4pmTuWB2LOAvmjON0skTgyxfAmZ+Kof1k6i6utp1+VwZyd5DXTz+WhN/rD/I1j2Hae+JLZiaNin7eHgtq5zKubMKmTgh/Uaohzv7eHHfkL9Q9h2lu3+QzAzjpgtmcddV86iYlhd0mTJKzGyLu1ePuE+BLkFzd7YfaOfx7U08/lozNU0RIDaHXF017XiAVxXlaQ55BP2DUWreiPDI1kZ+/sJeolHnI9Wz+NyV85g1VcEeNgp0GXcGBqO8sLuNx7c388Rrzew/0k2GQXXVNK5dXMa1i6czu0hhdKqajvbw/Q31/OKFfTjOR6sr+NyV85g5JbWnouRN
CnQZF7r7Bnm6rpXHX2tifU0LR7r6yc7K4L3zi7l2yXTet6iUovzTWyYob3XgSDff21DPLzftwzA+dmEFn71ybsofYxAFugSorbOPJ3c08/hrzTy7s5We/iiTJ2Zx9dllXLukjMvmlzApJ7A7IYbe/iPd3P9UPb/atI8MM265aDafuWIuZTp4mrIU6DJm3GOrMp7c0cKTNS1s3t1G1GFG4cTYVMqS6SyfM23cLSUMu31tXdz/VD0Pb2kkM+PNYC8tULCnGgW6jKregUGeb2hjfU0L62ta2NvWBcCi6QVcfXYZ1y2Zzjnlk3VAcxzYe6iLe9fv5Dcv7icrw7j13ZV8+vK5p31GrIw9BbokXUukhw01rTxZ08yzOw/S1TdITlYGl84r5qpFpVy5qJRyHYgbt3Yf7OTe9fX854uNZGdlcNvFVdx5+VymjcFZtHJmFOhyxqLR2NLCJ2uaWV/TwrbGo0BsKuWqRaVctaiUS+YW68zFFNPQ2sG96+v53Uv7ycvO4pPvmcOnLptDwURdR2a8UqDLaXF3NtS2sm57bFVKS6QXM1haMSUe4mWcPaNAUykhsLM5wneeqOPRV5uYkjeBz1w+l9surtIv6HFIgS6nxN15uq6Vf3u8jlf2H6UgJ4v3LijhqkWlXLGwREsLQ+yVxqN8+/Fanq5rpbQgh89fNY+PXTib7CwdxB4vFOiSsBdeb+Pb62p5YXcbs6bm8qWrF/Ch82bqH3Saeb7hEN9+vJZNuw8za2ouf3v1Aj68tJzMES4zLGNLgS4n9er+2MhsQ20rJQU5fEEjs7R37C+1e9bVsv1AO/NK8/nyNQtYcc50TbMFSIEuJ1TfEps7XftKbO70zsvn8teaO5UholHnse1N/Nvjtexq7eRd5YV85bqFvHd+sYI9AAp0eZt9bV387yd38putjeROyOSTl53Fpy6bw2StbpATGBiM8tuXDvDdJ+rYf6Sb5XOm8dXrFnJh1bSgS0srCnQ5rqW9h/uequcXL+zFzPjriyu58/K5OtApCesdGOSXm/Zx7/p6WiO9XLGwhI8sq+DiuUVaxz4GFOjC4c4+Vj6zix9v3M3AoPPRCyv4/FXzdLEmOW3dfYP8+E+7Wfn0Lo509QOweMZkLp1XxCXzilleNU3X6RkFCvQ01tM/yKpnGvjhMw109A3w4fPL+dLV86ksevvtz0ROR/9glG2NR9lYf5Dndh1k654j9A1Gycowls6ewiVzi7lkbhFLZ0/VQfYkUKCnqdqmCF9c/SI1TRGuXVzGl69dyMLpb79psUgydfcNsmXPYZ7bdZCN9Qd5Zf9Rog65EzK5cM40Lp1bxKXzilk8YzIZWgZ5yt4p0PX3UAi5Oz/50x7+Ze0OCiZm8eDtF3LlwtKgy5I0kZudyXvmF/Oe+bH7oh7t7ufPDYfYWH+QjbsO8a+P1gAwJW8CF59VxLLKqSyZWcjimZMpzNVB+TOhQA+Zgx29fO3hbayvaeHKhSX8r5vO05X0JFCFuRO4bsl0rlsyHYgdmN+46xDPxQP+0VebjretmJbLkhmFLJk5mSXlk1kys5DSghwtj0yQplxCZENtC1/59Tbae/r5n+8/m9surtQ/BBn3Dnb0sv1AO9sPHI1933+U3Ye6ju8vzs9m8cx4yM+MhXzltLy0na7RlEvI9fQP8s1Ha3ho424WlhXws09dpLlySRnF+TlcvqCEyxeUHH8u0tPPjjcib4b8gXZ++EwDA9HYADQ/J4uzZxRQMS0PI/WC/apFpXzg3BlJf10FeoobeuDzE5dUcff1i5g4QWd5SmormDiB5XOmsXzOmyct9Q4MsrO54y0h/3xDW4BVnr55pfmj8roK9BR17MDnP6/dwWQd+JQ0kJOVyTnlhZxTXhh0KeOWAj0FtUZ6+drDL/NUbStXLizhno+cR7HO9BRJewr0FPNUTQtfffhl2nsG+KcPLdGBTxE5ToGeIoYe+Fw0vYCfferdOvApIm+hQE8Brx/s5M6fbqG2OcLt
l1bx9RU68Ckib6dAH+caD3dxyw//TO9AlIduv5ArdOBTRE4goSvlmNkKM6s1s3ozu3uE/V81s5fiX6+a2aCZ6SLJZ6gl0sN/+9HzdPYO8H8/eZHCXETe0UkD3cwygfuB64HFwMfNbPHQNu5+j7uf7+7nA38HPO3uqblAdJw40tXHbf/xAi2RXh68fTmLZ04OuiQRGecSGaEvB+rdvcHd+4DVwA3v0P7jwC+SUVy66ugd4BMPbqKhtZNVt1azrHJq0CWJSApIJNDLgX1Dthvjz72NmeUBK4BHTrD/DjPbbGabW1tbT7XWtNDTP8gdP9nMK/uPct8tS49fsU5E5GQSCfSRFjmf6IpeHwSeO9F0i7uvcvdqd68uKSkZqUla6x+MctfPt7Jx1yG+/ZFzuTZ+dToRkUQkEuiNQMWQ7VnAgRO0vRlNt5yWaNT5yq9f5g87WvjGDUu4cemsoEsSkRSTSKBvAuab2RwzyyYW2muGNzKzQuBy4HfJLTH83J2//92r/O6lA3xtxUJuvbgq6JJEJAWddB26uw+Y2V3AOiATeMDdt5vZnfH9K+NNbwQed/fOUas2hNydbz5Ww8+e38tnrpjLZ6+YF3RJIpKiEjqxyN3XAmuHPbdy2PZDwEPJKixdfG/DLn7wdAO3vruSr123MOhyRCSF6RbcAfrxxt3cs66WG5eW808fWqKLbInIGVGgB+SRLY3845rtXLO4jHtuOjdtb6clIsmjQA/AY6828dWHX+bSeUXc+/GlZGXqxyAiZ05JMsaeqWvlC794kfMrprDq1mpdNVFEkkaBPoY2727jjp9uZm5pPg9+YjmTcnSxSxFJHgX6GNl+4Ci3P7iJmYW5/ORvllOYNyHokkQkZBToY2BgMMqXVr9E/sQsfvqpiygp0P0/RST5FOhj4JGtjexs6eAf/mIx5VNygy5HREJKgT7KuvsG+c4TdSydPYUV5+hiWyIyehToo+zBja/T3N7L3SsW6cQhERlVCvRRdLizj+9v2MX7FpVy0VlFQZcjIiGnQB9F9z1VT2fvAF+/flHQpYhIGlCgj5J9bV389E97uGnZLBaUFQRdjoikAQX6KPnOE3WYwd9esyDoUkQkTSjQR8H2A0f57Uv7uf3SOcwo1DJFERkbCvRR8M1HayjMncBnrpgbdCkikkYU6En2x50HeXbnQe66ch6FuTq9X0TGjgI9iaJR55uP7aB8Si63XlwZdDkikmYU6En0+20HeHV/O1++dgE5WbosroiMLQV6kvQNRPn247WcPWMyHz6/POhyRCQNKdCT5GfP72FfWzd3X79It5MTkUAo0JMg0tPPvevruWRuEe+dXxx0OSKSphToSfCDpxto6+zj764/WxfgEpHAKNDPUEt7Dz/6YwMfPG8m75pVGHQ5IpLGFOhn6Lt/2Mlg1PnqtQuDLkVE0pwC/QzUt3Twq837+KuLKpldlBd0OSKS5hToZ+CedTXkTsjk81fNC7oUEREF+unasqeNddub+fR7z6IoXzd9FpHgJRToZrbCzGrNrN7M7j5BmyvM7CUz225mTye3zPHF3fnXtTWUFOTwycvmBF2OiAgAWSdrYGaZwP3ANUAjsMnM1rj7a0PaTAG+B6xw971mVjpK9Y4Lf9jRwuY9h/nnG88hL/uk/wtFRMZEIiP05UC9uze4ex+wGrhhWJtbgN+4+14Ad29Jbpnjx8BglG89VsNZJZP4WHVF0OWIiByXSKCXA/uGbDfGnxtqATDVzDaY2RYzu22kFzKzO8xss5ltbm1tPb2KA/bwlkbqWzr42nWLyMrUIQgRGT8SSaSRTn30YdtZwDLgA8B1wN+b2dvuvebuq9y92t2rS0pKTrnYoHX3DfLdP9RxwewpXLekLOhyRETeIpEJ4EZg6NzCLODACG0Ounsn0GlmzwDnAXVJqXKceOC512lu7+W+Wy7QKf4iMu4kMkLfBMw3szlmlg3cDKwZ1uZ3wGVmlmVmecBFwI7klhqso939rHx6F1ef
XcqFVdOCLkdE5G1OOkJ39wEzuwtYB2QCD7j7djO7M75/pbvvMLPHgG1AFPiRu786moWPtf/44+tEegb4sk7xF5FxKqE1d+6+Flg77LmVw7bvAe5JXmnjx9Gufh784+tcf850zp4xOehyRERGpGUaCfjRHxuI9A7wxavnB12KiMgJKdBP4khXHw8+t5sPvGsGi6ZrdC4i45cC/SR++GwDnX0anYvI+KdAfwdtnX08FB+dLygrCLocEZF3pEB/Bz98toGu/kG++D6NzkVk/FOgn8Chjl5+vHE3Hzx3JvM1OheRFKBAP4FVzzbQ3T/IFzQ6F5EUoUAfwcGOXn6ycQ8fOm8m80rzgy5HRCQhCvQRrHqmgd4Bjc5FJLUo0IdpjfTykz/t5obzy5lbotG5iKQOBfowP3h6F30DUd34WURSjgJ9iJb2Hn765z18eGk5Z2l0LiIpRoE+xPef3sVA1PnCVZo7F5HUo0CPa27v4WfP7+XGpeVUFU8KuhwRkVOmQI/7/oZdDEZdc+cikrIU6EDT0R5+/sJe/vKCciqLNDoXkdSkQAe+t6GeaNT5vObORSSFpX2gHzjSzeoX9nHTsllUTMsLuhwRkdOW9oH+vQ31RN353JWaOxeR1JbWgb7/SDe/3LSPj1RXaHQuIikvrQP9/qfqAbhLK1tEJATSNtAbD3fx6837+Gh1BeVTcoMuR0TkjKVtoN//VD2Gae5cREIjLQN9X1sXv97cyMcurGCmRuciEhJpGej3ra8nw4zPXjk36FJERJIm7QJ976EuHt7ayMeXVzCjUKNzEQmPtAv0e9fvJDPD+KzmzkUkZNIq0A919PKbF/dzy/LZlE2eGHQ5IiJJlVCgm9kKM6s1s3ozu3uE/VeY2VEzeyn+9Q/JL/XM7XgjwmDUuXZJWdCliIgkXdbJGphZJnA/cA3QCGwyszXu/tqwps+6+1+MQo1JU9PUDsDCsoKAKxERSb5ERujLgXp3b3D3PmA1cMPoljU66pojFOdnU5SfE3QpIiJJl0iglwP7hmw3xp8b7mIze9nMHjWzJUmpLslqmztYOF2jcxEJp0QC3UZ4zodtbwUq3f084F7gtyO+kNkdZrbZzDa3traeUqFnKhp1djZHWKDpFhEJqUQCvRGoGLI9CzgwtIG7t7t7R/zxWmCCmRUPfyF3X+Xu1e5eXVJScgZln7r9R7rp6hvU/LmIhFYigb4JmG9mc8wsG7gZWDO0gZlNNzOLP14ef91DyS72TNQ2RQBYoCkXEQmpk65ycfcBM7sLWAdkAg+4+3YzuzO+fyVwE/AZMxsAuoGb3X34tEygaptjgT6/ND/gSkRERsdJAx2OT6OsHfbcyiGP7wPuS25pyVXbFKF8Si4FEycEXYqIyKhImzNF65ojWuEiIqGWFoHePxhlV6uWLIpIuKVFoO8+2En/oGuFi4iEWloE+rEDolqDLiJhlh6B3hQhM8M4q2RS0KWIiIyatAn0qqI8Jk7IDLoUEZFRkxaBrhUuIpIOQh/o3X2D7Gnr0vy5iIRe6AO9vqUDd1ikEbqIhFzoA10rXEQkXYQ+0OuaI2RnZVBZpBUuIhJuoQ/0mqYI80vzycwY6bLuIiLhEfpAr2uK6AxREUkLoQ70o139NLX36BroIpIWQh3odS2xA6IaoYtIOgh1oB+7S5FOKhKRdBDqQK9rjlCQk8WMwolBlyIiMupCHeg1TREWTC8gfrtTEZFQC22guzt1zRGdUCQiaSO0gd4a6eVIVz8Ly3RTaBFJD6EN9OOn/OuAqIikifAGepOWLIpIegltoNc1RyjOz6EoPyfoUkRExkRoA722KcLC6Zo/F5H0EcpAj0aduuYOrXARkbQSykBvPNxNd/+g5s9FJK2EMtC1wkVE0lEoA71OdykSkTQUykCvbYowa2ou+TlZQZciIjJmEgp0M1thZrVmVm9md79DuwvNbNDMbkpeiaeurlk3tRCR9HPSQDezTOB+4HpgMfBxM1t8gnbfAtYlu8hT
0T8YZVdrh+bPRSTtJDJCXw7Uu3uDu/cBq4EbRmj3eeARoCWJ9Z2y1w920j/oGqGLSNpJJNDLgX1Dthvjzx1nZuXAjcDKd3ohM7vDzDab2ebW1tZTrTUhx0751wFREUk3iQT6SBcT92Hb/w583d0H3+mF3H2Vu1e7e3VJSUmCJZ6auuYImRnGWSWTRuX1RUTGq0SWgTQCFUO2ZwEHhrWpBlbHbyRRDLzfzAbc/bfJKPJU1DZFqCrKY+KEzLF+axGRQCUS6JuA+WY2B9gP3AzcMrSBu8859tjMHgL+XxBhDrER+pKZhUG8tYhIoE465eLuA8BdxFav7AB+5e7bzexOM7tztAs8FV19A+xp69L8uYikpYTOvHH3tcDaYc+NeADU3T9x5mWdnvqWDtzRVRZFJC2F6kxRrXARkXQWqkCva46QnZVBZZFWuIhI+glVoNc2dzC/NJ/MjJFWWoqIhFuoAr2uSddwEZH0FZpAP9rVT1N7Dwt1DRcRSVOhCXTd1EJE0l3oAl1TLiKSrkIT6HVNEQpysphRODHoUkREAhGaQK9tjrBgegHx68mIiKSdUAS6u1PXHNEJRSKS1kIR6K2RXo509bOwTKf8i0j6CkWg18RP+V84fXLAlYiIBCcUgV53bMmiRugiksZCEei1TRGK83Moys8JuhQRkcCEItDrmiO6ZK6IpL2UD/Ro1Klr7tAKFxFJeykf6I2Hu+nuH9QZoiKS9lI+0Gua2gFdw0VEJOUD/c0VLgp0EUlvKR/otc0dzJqaS35OQrdHFREJrZQPdN3UQkQkJqUDvW8gyq7WDs2fi4iQ4oG++1AnA1HXCF1EhBQP9GPXcNEBURGRFA/0uqYImRnG3NJJQZciIhK4lA702uYIc4onkZOVGXQpIiKBS+lAr2vWChcRkWNSNtC7+gbY29al+XMRkbiEAt3MVphZrZnVm9ndI+y/wcy2mdlLZrbZzN6T/FLfqr6lA3d0lUURkbiTnl5pZpnA/cA1QCOwyczWuPtrQ5o9Caxxdzezc4FfAYtGo+BjarXCRUTkLRIZoS8H6t29wd37gNXADUMbuHuHu3t8cxLgjLLapgg5WRlUFmmFi4gIJBbo5cC+IduN8efewsxuNLMa4L+AvxnphczsjviUzObW1tbTqfe42uYI88vyycywM3odEZGwSCTQR0rMt43A3f0/3X0R8GHgGyO9kLuvcvdqd68uKSk5pUKHq2uOaLpFRGSIRAK9EagYsj0LOHCixu7+DDDXzIrPsLYTOtLVR3N7r5YsiogMkUigbwLmm9kcM8sGbgbWDG1gZvPMzOKPLwCygUPJLvaYuuYOQDe1EBEZ6qSrXNx9wMzuAtYBmcAD7r7dzO6M718J/CVwm5n1A93Ax4YcJE262vhNLTRCFxF5U0J3hXD3tcDaYc+tHPL4W8C3klvaidU2tVOQk8WMwolj9ZYiIuNeSp4pWtcUuwZ6fJZHRERIwUB3d2qbIyzU/LmIyFukXKC3RHo52t2v+XMRkWFSLtB1yr+IyMhSLtDzsjO5ZnGZplxERIZJaJXLeFJdNY3qqmlBlyEiMu6k3AhdRERGpkAXEQkJBbqISEgo0EVEQkKBLiISEgp0EZGQUKCLiISEAl1EJCRsFC9b/s5vbNYK7DnN/7wYOJjEclKB+pwe1Of0cCZ9rnT3Ee/hGVignwkz2+zu1UHXMZbU5/SgPqeH0eqzplxEREJCgS4iEhKpGuirgi4gAOpzelCf08Oo9Dkl59BFROTtUnWELiIiwyjQRURCYlwHupmtMLNaM6s3s7tH2G9m9n/i+7eZ2QVB1JlMCfT5r+J93WZmG83svCDqTKaT9XlIuwvNbNDMbhrL+kZDIn02syvM7CUz225mT491jcmWwGe70Mx+b2Yvx/t8exB1JouZPWBmLWb26gn2Jz+/3H1cfgGZwC7gLCAbeBlYPKzN+4FHAQPeDTwfdN1j0OdLgKnxx9enQ5+HtFsPrAVuCrruMfg5TwFeA2bH
t0uDrnsM+vw/gG/FH5cAbUB20LWfQZ/fC1wAvHqC/UnPr/E8Ql8O1Lt7g7v3AauBG4a1uQH4icf8GZhiZjPGutAkOmmf3X2jux+Ob/4ZmDXGNSZbIj9ngM8DjwAtY1ncKEmkz7cAv3H3vQDunur9TqTPDhSYmQH5xAJ9YGzLTB53f4ZYH04k6fk1ngO9HNg3ZLsx/typtkklp9qfTxL7DZ/KTtpnMysHbgRWjmFdoymRn/MCYKqZbTCzLWZ225hVNzoS6fN9wNnAAeAV4IvuHh2b8gKR9PwazzeJthGeG77GMpE2qSTh/pjZlcQC/T2jWtHoS6TP/w583d0HY4O3lJdIn7OAZcD7gFzgT2b2Z3evG+3iRkkifb4OeAm4CpgLPGFmz7p7+yjXFpSk59d4DvRGoGLI9ixiv7lPtU0qSag/ZnYu8CPgenc/NEa1jZZE+lwNrI6HeTHwfjMbcPffjkmFyZfoZ/ugu3cCnWb2DHAekKqBnkifbwe+6bEJ5nozex1YBLwwNiWOuaTn13iectkEzDezOWaWDdwMrBnWZg1wW/xo8buBo+7+xlgXmkQn7bOZzQZ+A9yawqO1oU7aZ3ef4+5V7l4FPAx8NoXDHBL7bP8OuMzMsswsD7gI2DHGdSZTIn3eS+wvEsysDFgINIxplWMr6fk1bkfo7j5gZncB64gdIX/A3beb2Z3x/SuJrXh4P1APdBH7DZ+yEuzzPwBFwPfiI9YBT+Er1SXY51BJpM/uvsPMHgO2AVHgR+4+4vK3VJDgz/kbwENm9gqx6Yivu3vKXlbXzH4BXAEUm1kj8I/ABBi9/NKp/yIiITGep1xEROQUKNBFREJCgS4iEhIKdBGRkFCgi4iEhAJdRCQkFOgiIiHx/wEWiTpFN6hTxwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(thresholds, scores);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5863d4d-9faa-453e-84bf-bd07c341abe6",
   "metadata": {},
   "source": [
    "* Here 0.5 is the best threshold; the accuracy is lower on either side of it\n",
    "* Scikit-learn provides a function for the accuracy (above we computed it by hand)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e0b0c194-da7f-4800-9c80-c3f9dc01a040",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "15a4ffa1-c321-4cac-8454-938d466dc308",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8034066713981547"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_val, (y_pred > 0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eb06644-8976-4f7d-ba85-6371313bd1e7",
   "metadata": {},
   "source": [
    "* Special cases: threshold = 0 and threshold = 1\n",
    "* threshold = 1 means we predict that no customer is churning\n",
    "    * For this case (dummy model) the accuracy is ~73%\n",
    "* threshold = 0 means we predict that all customers are churning\n",
    "    * For this case the accuracy is ~27%\n",
    "* Our model reaches an accuracy of ~80%, only about 7 percentage points better than the dummy model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5fcc8336-5ba7-49ca-a159-d20a1f6e1208",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7260468417317246"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Share of non-churning users\n",
    "1 - y_val.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55cd79f1-bbb4-4362-ba22-4310120c2489",
   "metadata": {},
   "source": [
    "* In this data set there are many more non-churning than churning users\n",
    "* We have class imbalance in this data set\n",
    "    * With class imbalance, always predicting the majority class already gives a fairly high accuracy, so accuracy can be a misleading metric"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "739c4736-7a1e-4348-80c7-dbf753f256dd",
   "metadata": {},
   "source": [
    "## Confusion Table\n",
    "\n",
    "* Different types of errors and correct decisions\n",
    "* Arranging them in a table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd238f6d-a0c1-48c9-8e67-7696eb4c31bb",
   "metadata": {},
   "source": [
    "![confusion_table](Screenshot1.png \"Confusion Table\")\n",
    "![confusion_table](Screenshot2.png \"Confusion Table\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1f0809e8-5685-4e01-9562-ca5182b43cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "actual_positive = (y_val == 1)\n",
    "actual_negative = (y_val == 0)\n",
    "\n",
    "t = 0.5\n",
    "predict_positive = (y_pred >= t)\n",
    "predict_negative = (y_pred < t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b5efff4b-f501-4517-ab8a-c5eb1c480349",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "210"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tp = (predict_positive & actual_positive).sum()\n",
    "tp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "76eb0d5e-477d-4707-a2da-9b8679729b15",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "922"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tn = (predict_negative & actual_negative).sum()\n",
    "tn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c1280b4c-afcc-4974-8de7-7de00db78535",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "101"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fp = (predict_positive & actual_negative).sum()\n",
    "fp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a58fcd48-35a5-4997-81d7-35381228ccd2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "176"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fn = (predict_negative & actual_positive).sum()\n",
    "fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bfda861c-c0bb-47f9-bae5-af260995fc0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[922, 101],\n",
       "       [176, 210]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confusion_matrix = np.array([\n",
    "                        [tn, fp],\n",
    "                        [fn, tp]])\n",
    "confusion_matrix"
   ]
  },
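  {
   "cell_type": "markdown",
   "id": "sketch-confusion-matrix-text",
   "metadata": {},
   "source": [
    "As a cross-check, the same table can be produced with scikit-learn. The sketch below is an addition and makes some assumptions: the label vectors `toy_true` and `toy_pred` are rebuilt from the four counts computed above instead of using the notebook's data, and the function is called through the `metrics` module so that it does not clash with the `confusion_matrix` variable. In the notebook itself, passing `y_val` and `y_pred >= 0.5` should give the same layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-confusion-matrix-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "\n",
    "# Rebuild label vectors from the counts above: tn=922, fp=101, fn=176, tp=210\n",
    "toy_counts = [922, 101, 176, 210]\n",
    "toy_true = np.repeat([0, 0, 1, 1], toy_counts)  # actual classes\n",
    "toy_pred = np.repeat([0, 1, 0, 1], toy_counts)  # predicted classes\n",
    "\n",
    "# Rows are actual classes, columns are predicted classes: [[tn, fp], [fn, tp]]\n",
    "metrics.confusion_matrix(toy_true, toy_pred)"
   ]
  },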
  {
   "cell_type": "markdown",
   "id": "af374736-2fdf-4618-b55d-6d7920b31a41",
   "metadata": {},
   "source": [
    "* We have many more false negatives than false positives\n",
    "    * False positives are predicted to churn, but they are not going to. We would send them a discount e-mail even though it is not necessary, and lose profit\n",
    "    * False negatives are churning customers that we predict to stay. They do not receive an e-mail and leave, so we lose profit in this case as well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "7e9d4922-c2c6-4e0d-8499-71911cae226e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.65, 0.07],\n",
       "       [0.12, 0.15]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# in relative numbers\n",
    "(confusion_matrix / confusion_matrix.sum()).round(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de4279d1-2c39-4f05-9a9b-42a10d698e6c",
   "metadata": {},
   "source": [
    "![confusion_table](Screenshot3.png \"Confusion Table\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11ab231a-0bed-43f9-83e9-11c62ad8500f",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Precision and Recall\n",
    "\n",
    "* Very useful for binary classification problems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "45a945de-f6a2-4284-842a-6426369ec348",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8034066713981547"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy\n",
    "(tp + tn) / (tp + tn + fp + fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d90ff83-3261-4dab-9a68-d5b686fd57e2",
   "metadata": {},
   "source": [
    "Now we will look at other metrics:\n",
    "* **Precision: Fraction of positive predictions that are correct**\n",
    "    * tp / (tp + fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "3dc1eb1b-a05c-4ab1-9205-2389a3444a28",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6752411575562701"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p = tp / (tp + fp)\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "276bb7e7-d262-45c1-8d7c-e40e803df077",
   "metadata": {},
   "source": [
    "* This means that (1-p)~32% of the customers we predict to churn will get a promotional e-mail, although they are not going to churn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "918b9191-7a9d-45f0-a9c4-db000310cfa4",
   "metadata": {},
   "source": [
    "* **Recall: Fraction of correctly identified positive examples**\n",
    "    * tp / (tp + fn) = tp / (#positive observations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e454a016-f544-4022-9ad2-27c4ffe25021",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5440414507772021"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r = tp / (tp + fn)\n",
    "r"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fd9970c-194f-43cc-a1b1-9b5e449ec833",
   "metadata": {},
   "source": [
    "* This means that we failed to identify (1-r)~46% of people who are churning "
   ]
  },
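  {
   "cell_type": "markdown",
   "id": "sketch-precision-recall-text",
   "metadata": {},
   "source": [
    "Both metrics are also available in scikit-learn as `precision_score` and `recall_score`. The following is a small sketch under stated assumptions, not the session's own code: the vectors `demo_true` and `demo_pred` are reconstructed from the confusion table counts at threshold 0.5 (tn=922, fp=101, fn=176, tp=210), so the results should agree with `p` and `r` computed by hand above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-precision-recall-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import precision_score, recall_score\n",
    "\n",
    "# Assumed label vectors rebuilt from the counts: tn=922, fp=101, fn=176, tp=210\n",
    "demo_counts = [922, 101, 176, 210]\n",
    "demo_true = np.repeat([0, 0, 1, 1], demo_counts)  # actual classes\n",
    "demo_pred = np.repeat([0, 1, 0, 1], demo_counts)  # predicted classes\n",
    "\n",
    "demo_precision = precision_score(demo_true, demo_pred)  # tp / (tp + fp)\n",
    "demo_recall = recall_score(demo_true, demo_pred)  # tp / (tp + fn)\n",
    "demo_precision, demo_recall"
   ]
  },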
  {
   "cell_type": "markdown",
   "id": "37e197a4-8ac7-40fe-a272-a9898aae2e0b",
   "metadata": {},
   "source": [
    "## ROC Curves\n",
    "\n",
    "### TPR and FPR\n",
    "\n",
    "* TPR: True positive rate TP / (FN + TP) (Nr of true positives by all positives) == RECALL\n",
    "* FPR: False positive rate FP / (TN + FP) (Nr of false positives by all negatives)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62f61b68-5658-4884-9f80-9c21dacb6abf",
   "metadata": {},
   "source": [
    "![confusion_table](Screenshot4.png \"Confusion Table\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a45ed148-67e0-4669-9037-a5932df0d3d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5440414507772021"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tpr = tp / (fn + tp)\n",
    "tpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5ce0bbcc-ef02-45e8-b465-9f5e105ab797",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.09872922776148582"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fpr = fp / (tn + fp)\n",
    "fpr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d31259ea-b0c6-4fbf-b22f-410a1a6c81f4",
   "metadata": {},
   "source": [
    "* We now calculated these numbers for the threshold 0.5\n",
    "* ROC curves compares results different thresholds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "a1a5309c-de64-4139-98be-83bea83cdd39",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = []\n",
    "thresholds = np.linspace(0,1,101)\n",
    "\n",
    "for t in thresholds:\n",
    "    actual_positive = (y_val == 1)\n",
    "    actual_negative = (y_val == 0)\n",
    "\n",
    "    predict_positive = (y_pred >= t)\n",
    "    predict_negative = (y_pred < t)\n",
    "    \n",
    "    tp = (predict_positive & actual_positive).sum()\n",
    "    tn = (predict_negative & actual_negative).sum()\n",
    "    \n",
    "    fp = (predict_positive & actual_negative).sum()\n",
    "    fn = (predict_negative & actual_positive).sum()\n",
    "    \n",
    "    scores.append((t, tp, fp, fn, tn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "daa7a3c0-0f72-4a8e-8230-e081211f168a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0.0, 386, 1023, 0, 0),\n",
       " (0.01, 385, 914, 1, 109),\n",
       " (0.02, 384, 830, 2, 193),\n",
       " (0.03, 383, 766, 3, 257),\n",
       " (0.04, 381, 715, 5, 308),\n",
       " (0.05, 379, 683, 7, 340),\n",
       " (0.06, 377, 661, 9, 362),\n",
       " (0.07, 372, 640, 14, 383),\n",
       " (0.08, 371, 613, 15, 410),\n",
       " (0.09, 369, 580, 17, 443),\n",
       " (0.1, 366, 556, 20, 467),\n",
       " (0.11, 365, 528, 21, 495),\n",
       " (0.12, 365, 509, 21, 514),\n",
       " (0.13, 360, 477, 26, 546),\n",
       " (0.14, 355, 453, 31, 570),\n",
       " (0.15, 351, 435, 35, 588),\n",
       " (0.16, 347, 419, 39, 604),\n",
       " (0.17, 346, 401, 40, 622),\n",
       " (0.18, 344, 384, 42, 639),\n",
       " (0.19, 338, 369, 48, 654),\n",
       " (0.2, 333, 356, 53, 667),\n",
       " (0.21, 329, 341, 57, 682),\n",
       " (0.22, 323, 322, 63, 701),\n",
       " (0.23, 320, 313, 66, 710),\n",
       " (0.24, 316, 304, 70, 719),\n",
       " (0.25, 309, 291, 77, 732),\n",
       " (0.26, 304, 281, 82, 742),\n",
       " (0.27, 303, 270, 83, 753),\n",
       " (0.28, 296, 256, 90, 767),\n",
       " (0.29, 291, 245, 95, 778),\n",
       " (0.3, 284, 236, 102, 787),\n",
       " (0.31, 280, 230, 106, 793),\n",
       " (0.32, 278, 226, 108, 797),\n",
       " (0.33, 276, 221, 110, 802),\n",
       " (0.34, 274, 213, 112, 810),\n",
       " (0.35000000000000003, 272, 207, 114, 816),\n",
       " (0.36, 267, 201, 119, 822),\n",
       " (0.37, 265, 197, 121, 826),\n",
       " (0.38, 260, 185, 126, 838),\n",
       " (0.39, 252, 179, 134, 844),\n",
       " (0.4, 249, 166, 137, 857),\n",
       " (0.41000000000000003, 246, 159, 140, 864),\n",
       " (0.42, 243, 157, 143, 866),\n",
       " (0.43, 241, 150, 145, 873),\n",
       " (0.44, 234, 147, 152, 876),\n",
       " (0.45, 230, 134, 156, 889),\n",
       " (0.46, 224, 125, 162, 898),\n",
       " (0.47000000000000003, 218, 120, 168, 903),\n",
       " (0.48, 217, 115, 169, 908),\n",
       " (0.49, 213, 110, 173, 913),\n",
       " (0.5, 210, 101, 176, 922),\n",
       " (0.51, 207, 99, 179, 924),\n",
       " (0.52, 204, 93, 182, 930),\n",
       " (0.53, 196, 91, 190, 932),\n",
       " (0.54, 194, 86, 192, 937),\n",
       " (0.55, 185, 79, 201, 944),\n",
       " (0.56, 182, 76, 204, 947),\n",
       " (0.5700000000000001, 176, 68, 210, 955),\n",
       " (0.58, 171, 61, 215, 962),\n",
       " (0.59, 163, 59, 223, 964),\n",
       " (0.6, 151, 53, 235, 970),\n",
       " (0.61, 145, 49, 241, 974),\n",
       " (0.62, 141, 46, 245, 977),\n",
       " (0.63, 133, 40, 253, 983),\n",
       " (0.64, 125, 37, 261, 986),\n",
       " (0.65, 119, 34, 267, 989),\n",
       " (0.66, 114, 31, 272, 992),\n",
       " (0.67, 105, 29, 281, 994),\n",
       " (0.68, 94, 26, 292, 997),\n",
       " (0.6900000000000001, 88, 25, 298, 998),\n",
       " (0.7000000000000001, 76, 20, 310, 1003),\n",
       " (0.71, 63, 14, 323, 1009),\n",
       " (0.72, 57, 11, 329, 1012),\n",
       " (0.73, 47, 10, 339, 1013),\n",
       " (0.74, 41, 8, 345, 1015),\n",
       " (0.75, 33, 7, 353, 1016),\n",
       " (0.76, 30, 6, 356, 1017),\n",
       " (0.77, 25, 5, 361, 1018),\n",
       " (0.78, 19, 3, 367, 1020),\n",
       " (0.79, 15, 2, 371, 1021),\n",
       " (0.8, 13, 2, 373, 1021),\n",
       " (0.81, 6, 0, 380, 1023),\n",
       " (0.8200000000000001, 5, 0, 381, 1023),\n",
       " (0.8300000000000001, 3, 0, 383, 1023),\n",
       " (0.84, 0, 0, 386, 1023),\n",
       " (0.85, 0, 0, 386, 1023),\n",
       " (0.86, 0, 0, 386, 1023),\n",
       " (0.87, 0, 0, 386, 1023),\n",
       " (0.88, 0, 0, 386, 1023),\n",
       " (0.89, 0, 0, 386, 1023),\n",
       " (0.9, 0, 0, 386, 1023),\n",
       " (0.91, 0, 0, 386, 1023),\n",
       " (0.92, 0, 0, 386, 1023),\n",
       " (0.93, 0, 0, 386, 1023),\n",
       " (0.9400000000000001, 0, 0, 386, 1023),\n",
       " (0.9500000000000001, 0, 0, 386, 1023),\n",
       " (0.96, 0, 0, 386, 1023),\n",
       " (0.97, 0, 0, 386, 1023),\n",
       " (0.98, 0, 0, 386, 1023),\n",
       " (0.99, 0, 0, 386, 1023),\n",
       " (1.0, 0, 0, 386, 1023)]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "3ab7ac85-2d57-4446-b73a-ad948ae0b0e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>threshold</th>\n",
       "      <th>tp</th>\n",
       "      <th>fp</th>\n",
       "      <th>fn</th>\n",
       "      <th>tn</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.00</td>\n",
       "      <td>386</td>\n",
       "      <td>1023</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.01</td>\n",
       "      <td>385</td>\n",
       "      <td>914</td>\n",
       "      <td>1</td>\n",
       "      <td>109</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.02</td>\n",
       "      <td>384</td>\n",
       "      <td>830</td>\n",
       "      <td>2</td>\n",
       "      <td>193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.03</td>\n",
       "      <td>383</td>\n",
       "      <td>766</td>\n",
       "      <td>3</td>\n",
       "      <td>257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.04</td>\n",
       "      <td>381</td>\n",
       "      <td>715</td>\n",
       "      <td>5</td>\n",
       "      <td>308</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   threshold   tp    fp  fn   tn\n",
       "0       0.00  386  1023   0    0\n",
       "1       0.01  385   914   1  109\n",
       "2       0.02  384   830   2  193\n",
       "3       0.03  383   766   3  257\n",
       "4       0.04  381   715   5  308"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# put scores into data frame\n",
    "columns =[\"threshold\", \"tp\", \"fp\", \"fn\", \"tn\"]\n",
    "\n",
    "df_scores = pd.DataFrame(scores, columns=columns)\n",
    "df_scores.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e3edb2d0-cda9-4492-b42f-52a7d9a9e8a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_scores[\"tpr\"] = df_scores.tp/(df_scores.tp + df_scores.fn)\n",
    "df_scores[\"fpr\"] = df_scores.fp/(df_scores.fp + df_scores.tn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "3261064c-3572-4bec-a305-911c5317ac60",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>threshold</th>\n",
       "      <th>tp</th>\n",
       "      <th>fp</th>\n",
       "      <th>fn</th>\n",
       "      <th>tn</th>\n",
       "      <th>tpr</th>\n",
       "      <th>fpr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.00</td>\n",
       "      <td>386</td>\n",
       "      <td>1023</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.01</td>\n",
       "      <td>385</td>\n",
       "      <td>914</td>\n",
       "      <td>1</td>\n",
       "      <td>109</td>\n",
       "      <td>0.997409</td>\n",
       "      <td>0.893451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.02</td>\n",
       "      <td>384</td>\n",
       "      <td>830</td>\n",
       "      <td>2</td>\n",
       "      <td>193</td>\n",
       "      <td>0.994819</td>\n",
       "      <td>0.811339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.03</td>\n",
       "      <td>383</td>\n",
       "      <td>766</td>\n",
       "      <td>3</td>\n",
       "      <td>257</td>\n",
       "      <td>0.992228</td>\n",
       "      <td>0.748778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.04</td>\n",
       "      <td>381</td>\n",
       "      <td>715</td>\n",
       "      <td>5</td>\n",
       "      <td>308</td>\n",
       "      <td>0.987047</td>\n",
       "      <td>0.698925</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   threshold   tp    fp  fn   tn       tpr       fpr\n",
       "0       0.00  386  1023   0    0  1.000000  1.000000\n",
       "1       0.01  385   914   1  109  0.997409  0.893451\n",
       "2       0.02  384   830   2  193  0.994819  0.811339\n",
       "3       0.03  383   766   3  257  0.992228  0.748778\n",
       "4       0.04  381   715   5  308  0.987047  0.698925"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_scores.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "d7fdbccb-a4ce-4919-b2db-78c8df848f6b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(df_scores[\"threshold\"], df_scores[\"tpr\"], label=\"TPR\")\n",
    "plt.plot(df_scores[\"threshold\"], df_scores[\"fpr\"], label=\"FPR\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a33eeb0b-feaf-484b-ac7f-61c10c35e663",
   "metadata": {},
   "source": [
    "* threshold=0 is the dummy model that predicts everyone is churning. In this case both TPR and FPR are 1.\n",
    "* We want to minimize FPR and maximize TPR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7b2fe23-289d-416f-b088-ab60cf51b4c5",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Random Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "cba5f791-3afc-4bc1-aa73-2d9666fefa72",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([4.17022005e-01, 7.20324493e-01, 1.14374817e-04, ...,\n",
       "       7.73916250e-01, 3.34276405e-01, 8.89982208e-02])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.seed(1)\n",
    "y_rand = np.random.uniform(0, 1, size=len(y_val))\n",
    "y_rand"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "0f1c8e4c-1382-4793-8b01-2c5e5e71b3c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5017743080198722"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy of random model\n",
    "((y_rand >= 0.5) == y_val).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "55ab9059-13b3-4ba3-81cd-0f4f08d8a503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# write tpr, fpr into a dataframe\n",
    "# same code as before, wrapped in a function\n",
    "def tpr_fpr_dataframe(y_val, y_pred):\n",
    "    scores = []\n",
    "    thresholds = np.linspace(0,1,101)\n",
    "\n",
    "    for t in thresholds:\n",
    "        actual_positive = (y_val == 1)\n",
    "        actual_negative = (y_val == 0)\n",
    "\n",
    "        predict_positive = (y_pred >= t)\n",
    "        predict_negative = (y_pred < t)\n",
    "    \n",
    "        tp = (predict_positive & actual_positive).sum()\n",
    "        tn = (predict_negative & actual_negative).sum()\n",
    "    \n",
    "        fp = (predict_positive & actual_negative).sum()\n",
    "        fn = (predict_negative & actual_positive).sum()\n",
    "    \n",
    "        scores.append((t, tp, fp, fn, tn))\n",
    "        \n",
    "    columns =[\"threshold\", \"tp\", \"fp\", \"fn\", \"tn\"]\n",
    "    df_scores = pd.DataFrame(scores, columns=columns)\n",
    "        \n",
    "    df_scores[\"tpr\"] = df_scores.tp/(df_scores.tp + df_scores.fn)\n",
    "    df_scores[\"fpr\"] = df_scores.fp/(df_scores.fp + df_scores.tn)\n",
    "    \n",
    "    return df_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "9c0b8e0a-f96b-4af2-8912-8ff759666655",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_rand = tpr_fpr_dataframe(y_val, y_rand)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "3eba1737-0699-4790-91a4-931c9cf43d48",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(df_rand[\"threshold\"], df_rand[\"tpr\"], label=\"TPR\")\n",
    "plt.plot(df_rand[\"threshold\"], df_rand[\"fpr\"], label=\"FPR\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c42d9e17-ad71-4533-a804-21257274a465",
   "metadata": {},
   "source": [
    "### Ideal Model\n",
    "* All predictions are correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "bcf089d9-59b2-461b-8a66-b000b10c50d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1023, 386)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_neg = (y_val == 0).sum()\n",
    "num_pos = (y_val == 1).sum()\n",
    "num_neg, num_pos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "ae00074b-f951-49d7-9af4-f76a9bb8d202",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, ..., 1, 1, 1])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# create ideal validation set\n",
    "y_ideal = np.repeat([0,1],[num_neg, num_pos])\n",
    "y_ideal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "15496b90-a744-4770-96bf-d5b0e6de3a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create predictions (numbers between 0 and 1)\n",
    "y_ideal_pred = np.linspace(0, 1, len(y_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "55bb59b9-0adf-45f8-bfd8-163140b8544e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# accuracy of ideal model\n",
    "# 72.6% of the customers are not churning (1023 of 1409), so we cut at 0.726\n",
    "((y_ideal_pred >= 0.726) == y_ideal).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "a1d8936d-d4e7-41c9-a7a1-4472ab2b53a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ideal =  tpr_fpr_dataframe(y_ideal, y_ideal_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "81c84263-c5ad-4212-9fde-c4905145d7dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>threshold</th>\n",
       "      <th>tp</th>\n",
       "      <th>fp</th>\n",
       "      <th>fn</th>\n",
       "      <th>tn</th>\n",
       "      <th>tpr</th>\n",
       "      <th>fpr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.00</td>\n",
       "      <td>386</td>\n",
       "      <td>1023</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.01</td>\n",
       "      <td>386</td>\n",
       "      <td>1008</td>\n",
       "      <td>0</td>\n",
       "      <td>15</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.985337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.02</td>\n",
       "      <td>386</td>\n",
       "      <td>994</td>\n",
       "      <td>0</td>\n",
       "      <td>29</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.971652</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.03</td>\n",
       "      <td>386</td>\n",
       "      <td>980</td>\n",
       "      <td>0</td>\n",
       "      <td>43</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.957967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.04</td>\n",
       "      <td>386</td>\n",
       "      <td>966</td>\n",
       "      <td>0</td>\n",
       "      <td>57</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.944282</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   threshold   tp    fp  fn  tn  tpr       fpr\n",
       "0       0.00  386  1023   0   0  1.0  1.000000\n",
       "1       0.01  386  1008   0  15  1.0  0.985337\n",
       "2       0.02  386   994   0  29  1.0  0.971652\n",
       "3       0.03  386   980   0  43  1.0  0.957967\n",
       "4       0.04  386   966   0  57  1.0  0.944282"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_ideal.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "a3ca2a6c-640e-4a8c-bfef-0f41910db7c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(df_ideal[\"threshold\"], df_ideal[\"tpr\"], label=\"TPR\")\n",
    "plt.plot(df_ideal[\"threshold\"], df_ideal[\"fpr\"], label=\"FPR\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e40c7294-24b6-4b11-804d-0bdf95e35581",
   "metadata": {},
   "source": [
    "* The ideal model separates the two classes at threshold 0.726\n",
    "* At this threshold TPR is maximized (1) and FPR is minimized (0)\n",
    "* Such an ideal model does not exist in reality, but it helps us to see how good our model is"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4363131-ecaf-44c7-aea2-5cc665c13dbc",
   "metadata": {},
   "source": [
    "### Putting everything together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "9e469e37-84ff-4cf1-84b7-63b5558af2f4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(df_scores[\"threshold\"], df_scores[\"tpr\"], label=\"TPR\")\n",
    "plt.plot(df_scores[\"threshold\"], df_scores[\"fpr\"], label=\"FPR\")\n",
    "\n",
    "#plt.plot(df_rand[\"threshold\"], df_rand[\"tpr\"], label=\"TPR\")\n",
    "#plt.plot(df_rand[\"threshold\"], df_rand[\"fpr\"], label=\"FPR\")\n",
    "\n",
    "plt.plot(df_ideal[\"threshold\"], df_ideal[\"tpr\"], label=\"TPR ideal\", color=\"black\")\n",
    "plt.plot(df_ideal[\"threshold\"], df_ideal[\"fpr\"], label=\"FPR ideal\", color=\"black\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faf97765-3835-466c-a105-865fcff92777",
   "metadata": {},
   "source": [
    "* Comparing two models this way is not always intuitive\n",
    "* The two models use different thresholds\n",
    "    * Instead of comparing per threshold, plot TPR against FPR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "c9a179e8-4b85-4df6-8ad9-eebe1c78901b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.plot(df_scores.fpr, df_scores.tpr, label=\"model\")\n",
    "plt.plot(df_rand.fpr, df_rand.tpr, label=\"random\")\n",
    "plt.plot(df_ideal.fpr, df_ideal.tpr, label=\"ideal\")\n",
    "plt.xlabel(\"FPR\")\n",
    "plt.ylabel(\"TPR\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "936e4a82-9ae3-4402-a065-9968e83ca572",
   "metadata": {},
   "source": [
    "* We want to get as close as possible to the ideal model\n",
    "* To see how good a model is, we can compare it with the random baseline in the following plot:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "bcc7efec-0378-4be2-9731-772eb93bd0a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.plot(df_scores.fpr, df_scores.tpr, label=\"model\")\n",
    "plt.plot([0, 1], [0,1], label=\"random\")\n",
    "plt.xlabel(\"FPR\")\n",
    "plt.ylabel(\"TPR\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc6adf70-1e73-436a-bb01-8b128341cf56",
   "metadata": {},
   "source": [
    "* We can use sklearn to calculate the ROC curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "3608c17c-7052-48fc-844a-81eef9a46db3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "f655cabe-ce25-41a1-848b-f9fdbae97540",
   "metadata": {},
   "outputs": [],
   "source": [
    "fpr, tpr, thresholds = roc_curve(y_val, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "54c46165-8b39-4d68-9051-b26d736d9d77",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(5,5))\n",
    "plt.plot(fpr, tpr, label=\"model\")\n",
    "plt.plot([0, 1], [0,1], label=\"random\")\n",
    "plt.xlabel(\"FPR\")\n",
    "plt.ylabel(\"TPR\")\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab48cd56-e124-45ed-a5d6-77ddd9d83545",
   "metadata": {},
   "source": [
    "* The small differences occur because we evaluated only 101 evenly spaced thresholds, whereas sklearn derives its thresholds from the distinct predicted scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93ce16f8-f51e-4dfe-9498-8edf09dcffcb",
   "metadata": {},
   "source": [
    "## ROC AUC\n",
    "\n",
    "* Area under the ROC curve: a useful metric for binary classification\n",
    "    * For the random model, the area under the curve is AUC=0.5\n",
    "    * For the ideal model, the area under the curve is AUC=1\n",
    "    * I.e. for any useful model AUC lies between 0.5 and 1; a value below 0.5 means the model ranks worse than random\n",
    "* We can use sklearn to calculate AUC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "5d993a0c-b4cd-4b5f-8d0e-95c60deeebe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "f62b6fdd-b1dc-4cb7-a9f6-1b0aedc61dc1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8438099868820242"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# for the scores calculated by sklearn\n",
    "auc(fpr, tpr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "a390ec4a-a3ba-4fb4-8d3b-f0b0d9d02888",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8438391098010017"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# for our manually calculated scores\n",
    "auc(df_scores.fpr, df_scores.tpr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "f8348150-4153-4a28-a623-d7d89992fd1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9999430203759136"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# for the ideal model\n",
    "auc(df_ideal.fpr, df_ideal.tpr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d2da908-0fb6-46f7-baf4-1fb604c003dc",
   "metadata": {},
   "source": [
    "* The differences are again due to the lower precision of our manual calculation (only 101 thresholds)\n",
    "* This calculation can also be done in one step, directly from the actual values and the predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "d0ef035d-a7fc-4ab4-b98f-1ac3388a6e2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "a604126e-541b-4892-a1ce-ef59a056c01b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8438099868820242"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "roc_auc_score(y_val, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ff55088-fa09-4400-be34-b3010fba2bd9",
   "metadata": {},
   "source": [
    "* The above line is a shortcut for:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "4f9043a6-9c6a-4233-8478-42e68e3d8f39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8438099868820242"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fpr, tpr, thresholds = roc_curve(y_val, y_pred)\n",
    "auc(fpr, tpr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "442adb06-b440-4306-8df8-6b5b0229c3d4",
   "metadata": {},
   "source": [
    "**AUC Interpretation**\n",
    "\n",
    "* AUC is the probability that a randomly selected positive sample has a higher score than a randomly selected negative sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "bacd7007-5dbf-44d2-a08c-999082bb0745",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scores for negative samples\n",
    "neg = y_pred[y_val==0]\n",
    "# scores for positive samples\n",
    "pos = y_pred[y_val==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "b00c5d9b-633e-46c9-86fb-25bf845ce664",
   "metadata": {},
   "outputs": [],
   "source": [
    "# randomly select one positive and one negative sample\n",
    "import random\n",
    "\n",
    "pos_ind = random.randint(0, len(pos) -1)\n",
    "neg_ind = random.randint(0, len(neg) -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "bc3579ac-3be3-4b4d-8b03-a555c5fac729",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compare the scores of the positive and negative sample\n",
    "pos[pos_ind] > neg[neg_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "b9e20165-74f5-409e-a32f-c84e682fa6de",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8397"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Do this many times ...\n",
    "n = 10000\n",
    "success = 0\n",
    "for i in range(n):\n",
    "    pos_ind = random.randint(0, len(pos) -1)\n",
    "    neg_ind = random.randint(0, len(neg) -1)\n",
    "    \n",
    "    if pos[pos_ind] > neg[neg_ind]:\n",
    "        success += 1\n",
    "        \n",
    "success/n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a469d591-d4e1-476d-b007-1133be3e0ec3",
   "metadata": {},
   "source": [
    "* Our result is pretty close to the previously calculated AUC value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "1cfc4110-63a1-40b1-a5d9-bae7d21d376d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# alternative vectorized implementation\n",
    "n = 50000\n",
    "pos_ind = np.random.randint(0, len(pos), size=n)\n",
    "neg_ind = np.random.randint(0, len(neg), size=n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "514be844-29ef-41e8-9b70-8b0e95bf2b8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.84484"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(pos[pos_ind] > neg[neg_ind]).mean()"
   ]
  },
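  {
   "cell_type": "markdown",
   "id": "auc-probability-note",
   "metadata": {},
   "source": [
    "The sampling experiment above can be read as a Monte Carlo estimate of a known property of the AUC. As a sketch, writing $s(x)$ for the model score (a symbol introduced here for illustration), $x^{+}$ for a randomly drawn positive example and $x^{-}$ for a randomly drawn negative one:\n",
    "\n",
    "$$\\mathrm{AUC} = P\\big(s(x^{+}) > s(x^{-})\\big) + \\tfrac{1}{2} P\\big(s(x^{+}) = s(x^{-})\\big)$$\n",
    "\n",
    "The loop approximates the first term by the fraction of $n$ sampled pairs in which the positive example scores higher:\n",
    "\n",
    "$$\\widehat{\\mathrm{AUC}} = \\frac{1}{n} \\sum_{i=1}^{n} \\mathbb{1}\\big[s(x_i^{+}) > s(x_i^{-})\\big]$$\n",
    "\n",
    "The strict comparison counts ties as failures, which should matter little as long as the predicted probabilities rarely coincide."
   ]
  },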
  {
   "cell_type": "markdown",
   "id": "acbd1364-4c2f-4782-8747-09cdba32af77",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Cross Validation\n",
    "\n",
    "* Evaluating the same model on different subsets of the data\n",
    "* Getting the average score and the spread of the scores across the subsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "d52bd146-1126-4145-b609-5762483f1fd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train a model\n",
    "def train(df, y_train):\n",
    "    dicts = df[categorical + numerical].to_dict(orient=\"records\")\n",
    "    dv = DictVectorizer(sparse=False)\n",
    "    X_train = dv.fit_transform(dicts)\n",
    "    \n",
    "    model = LogisticRegression(solver=\"liblinear\")\n",
    "    model.fit(X_train, y_train)\n",
    "    \n",
    "    return dv, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "a6a000a6-73b4-4de1-98b7-88a2b14be957",
   "metadata": {},
   "outputs": [],
   "source": [
    "dv, model = train(df_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "66e2a623-bb7d-4014-a785-e8520d85ed5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make predictions \n",
    "def predict(df, dv, model):\n",
    "    dicts = df[categorical + numerical].to_dict(orient=\"records\")\n",
    "    X = dv.transform(dicts)\n",
    "    y_pred = model.predict_proba(X)[:,1]\n",
    "    return y_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc46c450-541f-464d-906c-793827bce60d",
   "metadata": {},
   "source": [
    "* For the k-fold split we can use `KFold` from sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "39a4bd14-12e1-468a-8269-b2722f66a529",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: tqdm in /home/jens/miniconda3/envs/ml-zoomcamp/lib/python3.9/site-packages (4.63.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "eee40d6f-fb02-4b7a-8453-932667ff90e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import KFold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "99e898c4-5a9f-4250-86aa-dd998ce84167",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "ca9bcfc2-b4c6-47fb-a60a-b736242653be",
   "metadata": {},
   "outputs": [],
   "source": [
    "kfold = KFold(n_splits=10, shuffle=True, random_state=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b267a65b-6c06-4619-b5e3-03d02a575d89",
   "metadata": {},
   "source": [
    "* ```kfold.split(df_train_full)``` returns a generator that yields one pair of index arrays (training, validation) per fold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "3813e526-d414-42ea-9636-8fb8e8387587",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_idx, val_idx = next(kfold.split(df_train_full))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "c10e0edc-c08d-45b8-8992-79deade64ee0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full data length: 5634, train data length: 5070, validation data length 564\n"
     ]
    }
   ],
   "source": [
    "print(f\"full data length: {len(df_train_full)}, train data length: {len(train_idx)}, validation data length {len(val_idx)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "4870de48-1af3-4ab2-9873-0f177a757a5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "10it [00:02,  4.19it/s]\n"
     ]
    }
   ],
   "source": [
    "# Loop through different folds\n",
    "scores = []\n",
    "for train_idx, val_idx in tqdm(kfold.split(df_train_full)):\n",
    "    df_train = df_train_full.iloc[train_idx]\n",
    "    df_val = df_train_full.iloc[val_idx]\n",
    "    \n",
    "    y_train = df_train.churn.values\n",
    "    y_val = df_val.churn.values\n",
    "    \n",
    "    dv, model = train(df_train, y_train)\n",
    "    y_pred = predict(df_val, dv, model)\n",
    "    \n",
    "    roc_auc = roc_auc_score(y_val, y_pred)\n",
    "    scores.append(roc_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "deed409d-79ea-4fdb-a521-7d873246a6bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.8493392490816277,\n",
       " 0.8413366336633662,\n",
       " 0.8590269587894291,\n",
       " 0.8330260883877869,\n",
       " 0.8242710918114144,\n",
       " 0.8416250416250417,\n",
       " 0.8437154021491371,\n",
       " 0.8223355471220746,\n",
       " 0.8450570623981029,\n",
       " 0.8611811367685119]"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "ea898a5e-9a02-4632-981b-e26a678b74e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean auc: 0.842, spread: 1.23%\n"
     ]
    }
   ],
   "source": [
    "print(f\"mean auc: {np.mean(scores):.3f}, spread: {np.std(scores)*100:.3}%\")"
   ]
  },
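  {
   "cell_type": "markdown",
   "id": "cv-mean-std-note",
   "metadata": {},
   "source": [
    "The two numbers printed above summarize the $k$ per-fold AUC values $a_1, \\dots, a_k$ (here $k = 10$; the symbols are introduced for illustration). Since `np.std` uses the population formula by default, they should correspond to:\n",
    "\n",
    "$$\\bar{a} = \\frac{1}{k} \\sum_{i=1}^{k} a_i, \\qquad \\sigma = \\sqrt{\\frac{1}{k} \\sum_{i=1}^{k} (a_i - \\bar{a})^2}$$\n",
    "\n",
    "The reported spread is $100 \\cdot \\sigma$, that is, the standard deviation expressed in percentage points of AUC."
   ]
  },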
  {
   "cell_type": "markdown",
   "id": "2d5aa5c9-cde0-447d-aaa2-7353b5584e51",
   "metadata": {},
   "source": [
    "* In logistic regression a regularization parameter \"C\" can be included; in sklearn, smaller values of \"C\" mean stronger regularization\n",
    "* We can then test different regularization parameters"
   ]
  },
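  {
   "cell_type": "markdown",
   "id": "regularization-objective-note",
   "metadata": {},
   "source": [
    "As a rough sketch of what `C` controls, assuming the default L2 penalty, labels recoded as $y_i \\in \\lbrace -1, 1 \\rbrace$, and ignoring details of how the solver treats the intercept $b$, scikit-learn's logistic regression minimizes an objective of the form:\n",
    "\n",
    "$$\\min_{w, b} \\; \\frac{1}{2} w^{\\top} w + C \\sum_{i=1}^{m} \\log\\big(1 + e^{-y_i (w^{\\top} x_i + b)}\\big)$$\n",
    "\n",
    "A small `C` lets the penalty term $\\frac{1}{2} w^{\\top} w$ dominate and shrinks the weights; a large `C` puts more emphasis on fitting the training data. The symbols $w$, $b$, $x_i$ and $m$ are notation chosen for this note."
   ]
  },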
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "81e7e6a1-b9f0-43c3-81e0-a80362a67987",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train a model with regularization parameter C\n",
    "def train2(df, y_train, C=1.0):\n",
    "    dicts = df[categorical + numerical].to_dict(orient=\"records\")\n",
    "    dv = DictVectorizer(sparse=False)\n",
    "    X_train = dv.fit_transform(dicts)\n",
    "    \n",
    "    model = LogisticRegression(solver=\"liblinear\", C=C)\n",
    "    model.fit(X_train, y_train)\n",
    "    \n",
    "    return dv, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "id": "2739be4e-794f-42d2-8140-f83e03583bb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  4.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regularization parameter C: 0.001\n",
      "mean auc: 0.825, spread: 1.31%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  4.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regularization parameter C: 0.01\n",
      "mean auc: 0.839, spread: 0.872%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  4.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regularization parameter C: 0.1\n",
      "mean auc: 0.841, spread: 0.748%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  3.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regularization parameter C: 0.5\n",
      "mean auc: 0.841, spread: 0.74%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  3.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regularization parameter C: 1\n",
      "mean auc: 0.841, spread: 0.739%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:01<00:00,  3.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regularization parameter C: 10\n",
      "mean auc: 0.841, spread: 0.745%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Loop through different values of C and, for each, through the folds\n",
    "n_splits = 5\n",
    "\n",
    "for C in [0.001, 0.01, 0.1, 0.5, 1, 10]: \n",
    "    scores = []\n",
    "    \n",
    "    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=1)\n",
    "    for train_idx, val_idx in tqdm(kfold.split(df_train_full), total=n_splits):\n",
    "        \n",
    "        df_train = df_train_full.iloc[train_idx]\n",
    "        df_val = df_train_full.iloc[val_idx]\n",
    "    \n",
    "        y_train = df_train.churn.values\n",
    "        y_val = df_val.churn.values\n",
    "    \n",
    "        dv, model = train2(df_train, y_train, C=C)\n",
    "        y_pred = predict(df_val, dv, model)\n",
    "    \n",
    "        roc_auc = roc_auc_score(y_val, y_pred)\n",
    "        scores.append(roc_auc)\n",
    "    \n",
    "    print(f\"Regularization parameter C: {C}\")    \n",
    "    print(f\"mean auc: {np.mean(scores):.3f}, spread: {np.std(scores)*100:.3}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f52f0065-c48a-4f2f-8361-8a56927a454f",
   "metadata": {},
   "source": [
    "* Train the final model on the full training data and evaluate it on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "c3cea9ed-c35c-4b0f-8e5c-f292164cd2cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8579400803839363"
      ]
     },
     "execution_count": 136,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dv, model = train2(df_train_full, df_train_full.churn.values, C=1)\n",
    "y_pred = predict(df_test, dv, model)\n",
    "    \n",
    "roc_auc = roc_auc_score(y_test, y_pred)\n",
    "roc_auc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42ae1539-1f7a-4c9f-87fe-65d381f146b2",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "* Metric: A single number that describes the performance of a model\n",
    "* Accuracy: fraction of correct answers; sometimes misleading\n",
    "* Precision and Recall are less misleading when we have class imbalance\n",
    "* ROC curve: A way to evaluate the performance at all thresholds; ok to use with class imbalance\n",
    "* K-fold Cross Validation: more reliable estimate for performance (mean + std)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0e291be-866a-4404-8ce5-9a2ed45b4d9a",
   "metadata": {},
   "source": [
    "## Explore More\n",
    "\n",
    "* Check precision and recall for the dummy model\n",
    "* F1 score = 2 * P * R / (P + R)\n",
    "* Evaluate precision and recall at different thresholds and plot P vs R; this gives the precision-recall curve (similar to the ROC curve)\n",
    "* Area under the PR curve is also a useful metric\n",
    "\n",
    "Other projects\n",
    "\n",
    "* Calculate the metrics for the datasets from the previous week\n"
   ]
  },
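  {
   "cell_type": "markdown",
   "id": "f1-definition-note",
   "metadata": {},
   "source": [
    "The quantities in the list above can be written out in terms of the confusion-matrix counts. As a reminder, with $TP$, $FP$ and $FN$ the numbers of true positives, false positives and false negatives at a given threshold:\n",
    "\n",
    "$$P = \\frac{TP}{TP + FP}, \\qquad R = \\frac{TP}{TP + FN}, \\qquad F_1 = \\frac{2 P R}{P + R}$$\n",
    "\n",
    "$F_1$ is the harmonic mean of precision and recall, so it is high only when both are high. For instance, $P = 0.5$ and $R = 1$ give $F_1 = 2 \\cdot 0.5 \\cdot 1 / 1.5 \\approx 0.67$, below the arithmetic mean of $0.75$."
   ]
  },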
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b83b0f-6670-48a4-9277-317004101ebf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-zoomcamp",
   "language": "python",
   "name": "ml-zoomcamp"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}