From bc29bb7e3246b3116fb1f3a16df7e8dddffffcab Mon Sep 17 00:00:00 2001 From: Aman Malhotra Date: Wed, 5 Oct 2022 13:49:46 -0700 Subject: [PATCH] [Feature]Add Online Explainability notebooks for SageMaker Clarify (#3613) * Add Online Explainability notebooks for SageMaker Clarify * Correcting text in clean-up sections of online explainability example notebooks * Updating install commands for captum and sagemaker pypy packages * debug captum installation * change instance type Co-authored-by: Aaron Markham Co-authored-by: atqy <95724753+atqy@users.noreply.github.com> Co-authored-by: atqy --- sagemaker-clarify/index.rst | 9 + .../code/inference.py | 57 + .../code/requirements.txt | 3 + ...xplainability_with_sagemaker_clarify.ipynb | 1097 +++++++++++++++++ .../scripts/train.py | 109 ++ ...xplainability_with_sagemaker_clarify.ipynb | 1051 ++++++++++++++++ 6 files changed, 2326 insertions(+) create mode 100644 sagemaker-clarify/online_explainability/natural_language_processing/code/inference.py create mode 100644 sagemaker-clarify/online_explainability/natural_language_processing/code/requirements.txt create mode 100644 sagemaker-clarify/online_explainability/natural_language_processing/nlp_online_explainability_with_sagemaker_clarify.ipynb create mode 100644 sagemaker-clarify/online_explainability/natural_language_processing/scripts/train.py create mode 100644 sagemaker-clarify/online_explainability/tabular/tabular_online_explainability_with_sagemaker_clarify.ipynb diff --git a/sagemaker-clarify/index.rst b/sagemaker-clarify/index.rst index fc5d79b652..7bca215dcf 100644 --- a/sagemaker-clarify/index.rst +++ b/sagemaker-clarify/index.rst @@ -26,3 +26,12 @@ SageMaker Clarify Model Monitoring :maxdepth: 1 ../sagemaker_model_monitor/fairness_and_explainability/SageMaker-Model-Monitor-Fairness-and-Explainability + +SageMaker Clarify Online Explainability +--------------------------------------- + +.. toctree:: + :maxdepth: 1 + + online_explainability/tabular/tabular_online_explainability_with_sagemaker_clarify + online_explainability/natural_language_processing/nlp_online_explainability_with_sagemaker_clarify diff --git a/sagemaker-clarify/online_explainability/natural_language_processing/code/inference.py b/sagemaker-clarify/online_explainability/natural_language_processing/code/inference.py new file mode 100644 index 0000000000..79d4b72789 --- /dev/null +++ b/sagemaker-clarify/online_explainability/natural_language_processing/code/inference.py @@ -0,0 +1,57 @@ +from io import StringIO +import numpy as np +import os +import pandas as pd +import json +from transformers import AutoTokenizer, AutoModelForSequenceClassification +import torch +from typing import Any, Dict, List + + +def model_fn(model_dir: str) -> Dict[str, Any]: + """ + Load the model for inference + """ + model_path = os.path.join(model_dir, "model") + + # Load HuggingFace tokenizer. + tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") + + # Load HuggingFace model from disk. + model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + model_dict = {"model": model, "tokenizer": tokenizer} + return model_dict + + +def predict_fn(input_data: List, model: Dict) -> np.ndarray: + """ + Apply model to the incoming request + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = model["tokenizer"] + huggingface_model = model["model"] + + encoded_input = tokenizer(input_data, truncation=True, padding=True, max_length=128, return_tensors="pt") + encoded_input = {k: v.to(device) for k, v in encoded_input.items()} + with torch.no_grad(): + output = huggingface_model(input_ids=encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"]) + res = torch.nn.Softmax(dim=1)(output.logits).detach().cpu().numpy()[:, 1] + return res + + +def input_fn(request_body: str, request_content_type: str) -> List[str]: + """ + Deserialize and prepare the prediction input + """ + if request_content_type == "application/json": + sentences = [json.loads(request_body)] + + elif request_content_type == "text/csv": + # We have a single column with the text. + sentences = list(pd.read_csv(StringIO(request_body), header=None).values[:, 0].astype(str)) + else: + sentences = request_body + return sentences diff --git a/sagemaker-clarify/online_explainability/natural_language_processing/code/requirements.txt b/sagemaker-clarify/online_explainability/natural_language_processing/code/requirements.txt new file mode 100644 index 0000000000..721806cd03 --- /dev/null +++ b/sagemaker-clarify/online_explainability/natural_language_processing/code/requirements.txt @@ -0,0 +1,3 @@ +transformers==4.2.1 +torch==1.7.1 +pandas \ No newline at end of file diff --git a/sagemaker-clarify/online_explainability/natural_language_processing/nlp_online_explainability_with_sagemaker_clarify.ipynb b/sagemaker-clarify/online_explainability/natural_language_processing/nlp_online_explainability_with_sagemaker_clarify.ipynb new file mode 100644 index 0000000000..ac77aa7e77 --- /dev/null +++ b/sagemaker-clarify/online_explainability/natural_language_processing/nlp_online_explainability_with_sagemaker_clarify.ipynb @@ -0,0 +1,1097 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "# NLP Online Explainability with SageMaker Clarify" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* [Introduction](#Introduction)\n", + "* [General Setup](#General-Setup)\n", + " * [Install dependencies](#Install-dependencies)\n", + " * [Import libraries](#Import-libraries)\n", + " * [Set configurations](#Set-configurations)\n", + " * [Create serializer and deserializer](#Create-serializer-and-deserializer)\n", + " * [For visualization](#For-visualization)\n", + "* [Prepare data](#Prepare-data)\n", + " * [Download data](#Download-data)\n", + " * [Loading the data](#Loading-the-data)\n", + " * [Data preparation for model training](#Data-preparation-for-model-training)\n", + " * [Upload the dataset](#Upload-the-dataset)\n", + "* [Train and Deploy Hugging Face Model](#Train-and-Deploy-Hugging-Face-Model)\n", + " * [Train model with Hugging Face estimator](#Train-model-with-Hugging-Face-estimator)\n", + " * [Download the trained model files](#Download-the-trained-model-files)\n", + " * [Prepare model container definition](#Prepare-model-container-definition)\n", + "* [Create endpoint](#Create-endpoint)\n", + " * [Create model](#Create-model)\n", + " * [Create endpoint config](#Create-endpoint-config)\n", + " * [Create endpoint](#Create-endpoint)\n", + "* [Invoke endpoint](#Invoke-endpoint)\n", + " * [Single record request](#Single-record-request)\n", + " * [Single record request, no explanation](#Single-record-request,-no-explanation)\n", + " * [Batch request, explain both](#Batch-request,-explain-both)\n", + " * [Batch request, explain none](#Batch-request,-explain-none)\n", + " * [Batch request with more records, explain some of the records](#Batch-request-with-more-records,-explain-some-of-the-records)\n", + "* [Cleanup](#Cleanup)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "Amazon SageMaker Clarify helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models. \n", + "\n", + "SageMaker Clarify currently supports explainability for SageMaker models as an offline processing job. This example notebook showcases a new feature for explainability on a [SageMaker real-time inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) endpoint, a.k.a. online explainability.\n", + "\n", + "This example notebook walks you through: \n", + "1. Key terms and concepts needed to understand SageMaker Clarify\n", + "1. Trained the model on the Women's ecommerce clothing reviews dataset.\n", + "1. Create a model from trained model artifacts, create an endpoint configuration with the new SageMaker Clarify explainer configuration, and create an endpoint using the same explainer configuration.\n", + "1. Invoke the endpoint with single and batch request with different `EnableExplanations` query.\n", + "1. Explaining the importance of the various input features on the model's decision.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## General Setup\n", + "\n", + "We recommend you use `Python 3 (Data Science)` kernel on SageMaker Studio or `conda_python3` kernel on SageMaker Notebook Instance." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Install dependencies\n", + "\n", + "The following packages are required by data preparation and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"datasets[s3]==1.6.2\" \"transformers==4.6.1\" --upgrade" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Upgrade the SageMaker Python SDK, and captum is used to visualize the feature attributions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install sagemaker --upgrade\n", + "!pip install boto3 --upgrade\n", + "!pip install botocore --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install captum --upgrade" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "import csv\n", + "import pandas as pd\n", + "import numpy as np\n", + "import pprint\n", + "import tarfile\n", + "\n", + "from sagemaker.huggingface import HuggingFace\n", + "from datasets import Dataset\n", + "from datasets.filesystems import S3FileSystem\n", + "from captum.attr import visualization\n", + "from sklearn.model_selection import train_test_split\n", + "from sagemaker import get_execution_role, Session\n", + "from sagemaker.serializers import CSVSerializer\n", + "from sagemaker.deserializers import JSONDeserializer\n", + "from sagemaker.utils import unique_name_from_base" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "boto3_session = boto3.session.Session()\n", + "sagemaker_client = boto3.client(\"sagemaker\")\n", + "sagemaker_runtime_client = boto3.client(\"sagemaker-runtime\")\n", + "\n", + "# Initialize sagemaker session\n", + "sagemaker_session = Session(\n", + " boto_session=boto3_session,\n", + " sagemaker_client=sagemaker_client,\n", + " sagemaker_runtime_client=sagemaker_runtime_client,\n", + ")\n", + "\n", + "region = sagemaker_session.boto_region_name\n", + "print(f\"Region: {region}\")\n", + "\n", + "role = get_execution_role()\n", + "print(f\"Role: {role}\")\n", + "\n", + "prefix = unique_name_from_base(\"DEMO-NLP-Women-Clothing\")\n", + "\n", + "s3_bucket = sagemaker_session.default_bucket()\n", + "s3_prefix = f\"sagemaker/{prefix}\"\n", + "s3_key = f\"s3://{s3_bucket}/{s3_prefix}\"\n", + "print(f\"Demo S3 key: {s3_key}\")\n", + "\n", + "model_name = f\"{prefix}-model\"\n", + "print(f\"Demo model name: {model_name}\")\n", + "endpoint_config_name = f\"{prefix}-endpoint-config\"\n", + "print(f\"Demo endpoint config name: {endpoint_config_name}\")\n", + "endpoint_name = f\"{prefix}-endpoint\"\n", + "print(f\"Demo endpoint name: {endpoint_name}\")\n", + "\n", + "# SageMaker Clarify model directory name\n", + "model_path = \"model/\"\n", + "\n", + "# Instance type for training and hosting\n", + "instance_type = \"ml.m5.xlarge\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create serializer and deserializer\n", + "\n", + "CSV serializer to serialize test data to string" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "csv_serializer = CSVSerializer()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "JSON deserializer to deserialize invoke endpoint response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "json_deserializer = JSONDeserializer()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### For visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This method is a wrapper around the captum that helps produce visualizations for local explanations. It will\n", + "# visualize the attributions for the tokens with red or green colors for negative and positive attributions.\n", + "def visualization_record(\n", + " attributions, # list of attributions for the tokens\n", + " text, # list of tokens\n", + " pred, # the prediction value obtained from the endpoint\n", + " delta,\n", + " true_label, # the true label from the dataset\n", + " normalize=True, # normalizes the attributions so that the max absolute value is 1. Yields stronger colors.\n", + " max_frac_to_show=0.05, # what fraction of tokens to highlight, set to 1 for all.\n", + " match_to_pred=False, # whether to limit highlights to red for negative predictions and green for positive ones.\n", + " # By enabling `match_to_pred` you show what tokens contribute to a high/low prediction not those that oppose it.\n", + "):\n", + " if normalize:\n", + " attributions = attributions / max(max(attributions), max(-attributions))\n", + " if max_frac_to_show is not None and max_frac_to_show < 1:\n", + " num_show = int(max_frac_to_show * attributions.shape[0])\n", + " sal = attributions\n", + " if pred < 0.5:\n", + " sal = -sal\n", + " if not match_to_pred:\n", + " sal = np.abs(sal)\n", + " top_idxs = np.argsort(-sal)[:num_show]\n", + " mask = np.zeros_like(attributions)\n", + " mask[top_idxs] = 1\n", + " attributions = attributions * mask\n", + " return visualization.VisualizationDataRecord(\n", + " attributions,\n", + " pred,\n", + " int(pred > 0.5),\n", + " true_label,\n", + " attributions.sum() > 0,\n", + " attributions.sum(),\n", + " text,\n", + " delta,\n", + " )\n", + "\n", + "\n", + "def visualize_result(result, all_labels):\n", + " if not result[\"explanations\"]:\n", + " print(f\"No Clarify explanations for the record(s)\")\n", + " return\n", + " all_explanations = result[\"explanations\"][\"kernel_shap\"]\n", + " all_predictions = list(csv.reader(result[\"predictions\"][\"data\"].splitlines()))\n", + "\n", + " labels = []\n", + " predictions = []\n", + " explanations = []\n", + "\n", + " for i, expl in enumerate(all_explanations):\n", + " if expl:\n", + " labels.append(all_labels[i])\n", + " predictions.append(all_predictions[i])\n", + " explanations.append(all_explanations[i])\n", + "\n", + " attributions_dataset = [\n", + " np.array([attr[\"attribution\"][0] for attr in expl[0][\"attributions\"]])\n", + " for expl in explanations\n", + " ]\n", + " tokens_dataset = [\n", + " np.array([attr[\"description\"][\"partial_text\"] for attr in expl[0][\"attributions\"]])\n", + " for expl in explanations\n", + " ]\n", + "\n", + " # You can customize the following display settings\n", + " normalize = True\n", + " max_frac_to_show = 1\n", + " match_to_pred = False\n", + " vis = []\n", + " for attr, token, pred, label in zip(attributions_dataset, tokens_dataset, predictions, labels):\n", + " vis.append(\n", + " visualization_record(\n", + " attr, token, float(pred[0]), 0.0, label, normalize, max_frac_to_show, match_to_pred\n", + " )\n", + " )\n", + " _ = visualization.visualize_text(vis)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download data\n", + "Data Source: `https://www.kaggle.com/nicapotato/womens-ecommerce-clothing-reviews/`\n", + "\n", + "The Women’s E-Commerce Clothing Reviews dataset has been made available under a Creative Commons Public Domain license. A copy of the dataset has been saved in a sample data Amazon S3 bucket. In the first section of the notebook, we’ll walk through how to download the data and get started with building the ML workflow as a SageMaker pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! curl https://sagemaker-sample-files.s3.amazonaws.com/datasets/tabular/womens_clothing_ecommerce/Womens_Clothing_E-Commerce_Reviews.csv > womens_clothing_reviews_dataset.csv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(\"womens_clothing_reviews_dataset.csv\", index_col=[0])\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Context**\n", + "\n", + "The Women’s Clothing E-Commerce dataset contains reviews written by customers. Because the dataset contains real commercial data, it has been anonymized, and any references to the company in the review text and body have been replaced with “retailer”.\n", + "\n", + "\n", + "\n", + "**Content**\n", + "\n", + "The dataset contains 23486 rows and 10 columns. Each row corresponds to a customer review.\n", + "\n", + "The columns include:\n", + "\n", + "* Clothing ID: Integer Categorical variable that refers to the specific piece being reviewed.\n", + "* Age: Positive Integer variable of the reviewer's age.\n", + "* Title: String variable for the title of the review.\n", + "* Review Text: String variable for the review body.\n", + "* Rating: Positive Ordinal Integer variable for the product score granted by the customer from 1 Worst, to 5 Best.\n", + "* Recommended IND: Binary variable stating where the customer recommends the product where 1 is recommended, 0 is not recommended.\n", + "* Positive Feedback Count: Positive Integer documenting the number of other customers who found this review positive.\n", + "* Division Name: Categorical name of the product high level division.\n", + "* Department Name: Categorical name of the product department name.\n", + "* Class Name: Categorical name of the product class name.\n", + "\n", + "**Goal**\n", + "\n", + "To predict the sentiment of a review based on the text, and then explain the predictions using SageMaker Clarify." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data preparation for model training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Target Variable Creation\n", + "Since the dataset does not contain a column that indicates the sentiment of the customer reviews, lets create one. To do this, let's assume that reviews with a `Rating` of 4 or higher indicate positive sentiment and reviews with a `Rating` of 2 or lower indicate negative sentiment. Let's also assume that a `Rating` of 3 indicates neutral sentiment and exclude these rows from the dataset. Additionally, to predict the sentiment of a review, we are going to use the `Review Text` column; therefore let's remove rows that are empty in the `Review Text` column of the dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_target_column(df, min_positive_score, max_negative_score):\n", + " neutral_values = [i for i in range(max_negative_score + 1, min_positive_score)]\n", + " for neutral_value in neutral_values:\n", + " df = df[df[\"Rating\"] != neutral_value]\n", + " df[\"Sentiment\"] = df[\"Rating\"] >= min_positive_score\n", + " return df.replace({\"Sentiment\": {True: 1, False: 0}})\n", + "\n", + "\n", + "df = create_target_column(df, 4, 2)\n", + "df = df[~df[\"Review Text\"].isna()]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Train-Validation-Test splits" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The most common approach for model evaluation is using the train/validation/test split. Although this approach can be very effective in general, it can result in misleading results and potentially fail when used on classification problems with a severe class imbalance. Instead, the technique must be modified to stratify the sampling by the class label as below. Stratification ensures that all classes are well represented across the train, validation and test datasets.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target = \"Sentiment\"\n", + "cols = \"Review Text\"\n", + "\n", + "X = df[cols]\n", + "y = df[target]\n", + "\n", + "# Data split: 11%(val) of the 90% (train and test) of the dataset ~ 10%; resulting in 80:10:10split\n", + "test_dataset_size = 0.10\n", + "val_dataset_size = 0.11\n", + "RANDOM_STATE = 42\n", + "\n", + "# Stratified train-val-test split\n", + "X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=test_dataset_size, stratify=y, random_state=RANDOM_STATE\n", + ")\n", + "X_train, X_val, y_train, y_val = train_test_split(\n", + " X_train, y_train, test_size=val_dataset_size, stratify=y_train, random_state=RANDOM_STATE\n", + ")\n", + "\n", + "print(\n", + " \"Dataset: train \",\n", + " X_train.shape,\n", + " y_train.shape,\n", + " y_train.value_counts(dropna=False, normalize=True).to_dict(),\n", + ")\n", + "print(\n", + " \"Dataset: validation \",\n", + " X_val.shape,\n", + " y_val.shape,\n", + " y_val.value_counts(dropna=False, normalize=True).to_dict(),\n", + ")\n", + "print(\n", + " \"Dataset: test \",\n", + " X_test.shape,\n", + " y_test.shape,\n", + " y_test.value_counts(dropna=False, normalize=True).to_dict(),\n", + ")\n", + "\n", + "# Combine the independent columns with the label\n", + "df_train = pd.concat([X_train, y_train], axis=1).reset_index(drop=True)\n", + "df_test = pd.concat([X_test, y_test], axis=1).reset_index(drop=True)\n", + "df_val = pd.concat([X_val, y_val], axis=1).reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "headers = df_test.columns.to_list()\n", + "feature_headers = headers[0]\n", + "label_header = headers[1]\n", + "print(f\"Feature names: {feature_headers}\")\n", + "print(f\"Label name: {label_header}\")\n", + "print(f\"Test data (without label column):\")\n", + "test_data = df_test.iloc[:, :1]\n", + "test_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have split the dataset into train, test, and validation datasets. We use the train and validation datasets during training process, and run Clarify on the test dataset.\n", + "\n", + "In the cell below, we convert the Pandas DataFrames into Hugging Face Datasets for downstream modeling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = Dataset.from_pandas(df_train)\n", + "val_dataset = Dataset.from_pandas(df_val)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Upload the dataset\n", + "Here, we upload the prepared datasets to S3 buckets so that we can train the model with the Hugging Face Estimator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# S3 key prefix for the datasets\n", + "s3 = S3FileSystem()\n", + "\n", + "# save train_dataset to s3\n", + "training_input_path = f\"{s3_key}/train\"\n", + "print(f\"training input path: {training_input_path}\")\n", + "train_dataset.save_to_disk(training_input_path, fs=s3)\n", + "\n", + "# save val_dataset to s3\n", + "val_input_path = f\"{s3_key}/test\"\n", + "print(f\"validation input path: {val_input_path}\")\n", + "val_dataset.save_to_disk(val_input_path, fs=s3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train and Deploy Hugging Face Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this step of the workflow, we use the [Hugging Face Estimator](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html) to load the pre-trained `distilbert-base-uncased` model and fine-tune the model on our dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train model with Hugging Face estimator\n", + "The hyperparameters defined below are parameters that are passed to the custom PyTorch code in [`scripts/train.py`](./scripts/train.py). The only required parameter is `model_name`. The other parameters like `epoch`, `train_batch_size` all have default values which can be overridden by setting their values here.\n", + "\n", + "The training job requires GPU instance type. Here, we use `ml.g4dn.xlarge`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Hyperparameters passed into the training job\n", + "hyperparameters = {\"epochs\": 1, \"model_name\": \"distilbert-base-uncased\"}\n", + "\n", + "huggingface_estimator = HuggingFace(\n", + " entry_point=\"train.py\",\n", + " source_dir=\"scripts\",\n", + " instance_type=\"ml.g4dn.xlarge\",\n", + " instance_count=1,\n", + " transformers_version=\"4.6.1\",\n", + " pytorch_version=\"1.7.1\",\n", + " py_version=\"py36\",\n", + " role=role,\n", + " hyperparameters=hyperparameters,\n", + " disable_profiler=True,\n", + " debugger_hook_config=False,\n", + ")\n", + "\n", + "# starting the train job with our uploaded datasets as input\n", + "huggingface_estimator.fit({\"train\": training_input_path, \"test\": val_input_path}, logs=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download the trained model files" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! aws s3 cp {huggingface_estimator.model_data} model.tar.gz\n", + "! mkdir -p {model_path}\n", + "! tar -xvf model.tar.gz -C {model_path}/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare model container definition\n", + "\n", + "We are going to use the trained model files along with the HuggingFace Inference container to deploy the model to a SageMaker endpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with tarfile.open(\"hf_model.tar.gz\", mode=\"w:gz\") as archive:\n", + " archive.add(model_path, recursive=True)\n", + " archive.add(\"code/\")\n", + "directory_name = s3_prefix.split(\"/\")[-1]\n", + "zipped_model_path = sagemaker_session.upload_data(\n", + " path=\"hf_model.tar.gz\", key_prefix=directory_name + \"/hf-model-sm\"\n", + ")\n", + "zipped_model_path" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create a new model object and then update its model artifact and inference script. The model object will be used to create the SageMaker model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = huggingface_estimator.create_model(name=model_name)\n", + "container_def = model.prepare_container_def(instance_type=instance_type)\n", + "container_def[\"ModelDataUrl\"] = zipped_model_path\n", + "container_def[\"Environment\"][\"SAGEMAKER_PROGRAM\"] = \"inference.py\"\n", + "pprint.pprint(container_def)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Create endpoint" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create model\n", + "\n", + "The following parameters are required to create a SageMaker model:\n", + "\n", + "* `ExecutionRoleArn`: The ARN of the IAM role that Amazon SageMaker can assume to access the model artifacts/ docker images for deployment\n", + "\n", + "* `ModelName`: name of the SageMaker model.\n", + "\n", + "* `PrimaryContainer`: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.create_model(\n", + " ExecutionRoleArn=role,\n", + " ModelName=model_name,\n", + " PrimaryContainer=container_def,\n", + ")\n", + "print(f\"Model created: {model_name}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create endpoint config\n", + "\n", + "Create an endpoint configuration by calling the `create_endpoint_config` API. Here, supply the same `model_name` used in the `create_model` API call. The `create_endpoint_config` now supports the additional parameter `ClarifyExplainerConfig` to enable the Clarify explainer. The SHAP baseline is mandatory, it can be provided either as inline baseline data (the `ShapBaseline` parameter) or by a S3 baseline file (the `ShapBaselineUri` parameter). Please see the developer guide for the optional parameters." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we use a special token as the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "baseline = [[\"\"]]\n", + "print(f\"SHAP baseline: {baseline}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `TextConfig` configured with `sentence` level granularity (When granularity is `sentence`, each sentence is a feature, and we need a few sentences per review for good visualization) and the language as English." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.create_endpoint_config(\n", + " EndpointConfigName=endpoint_config_name,\n", + " ProductionVariants=[\n", + " {\n", + " \"VariantName\": \"TestVariant\",\n", + " \"ModelName\": model_name,\n", + " \"InitialInstanceCount\": 1,\n", + " \"InstanceType\": instance_type,\n", + " }\n", + " ],\n", + " ExplainerConfig={\n", + " \"ClarifyExplainerConfig\": {\n", + " \"InferenceConfig\": {\"FeatureTypes\": [\"text\"]},\n", + " \"ShapConfig\": {\n", + " \"ShapBaselineConfig\": {\"ShapBaseline\": csv_serializer.serialize(baseline)},\n", + " \"TextConfig\": {\"Granularity\": \"sentence\", \"Language\": \"en\"},\n", + " },\n", + " }\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create endpoint\n", + "\n", + "Once you have your model and endpoint configuration ready, use the `create_endpoint` API to create your endpoint. The `endpoint_name` must be unique within an AWS Region in your AWS account. The `create_endpoint` API is synchronous in nature and returns an immediate response with the endpoint status being `Creating` state." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.create_endpoint(\n", + " EndpointName=endpoint_name,\n", + " EndpointConfigName=endpoint_config_name,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Wait for the endpoint to be in \"InService\" state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_session.wait_for_endpoint(endpoint_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Invoke endpoint\n", + "\n", + "There are expanding business needs and legislative regulations that require explanations of _why_ a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.\n", + "\n", + "Kernel SHAP algorithm requires a baseline (also known as background dataset). By definition, `baseline` should either be a S3 URI to the baseline dataset file, or an in-place list of records. Baseline dataset type shall be the same as the original request data type, and baseline records shall only include features. \n", + "\n", + "Below are the several different combination of endpoint invocation, call them one by one and visualize the explanations by running the subsequent cell. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Single record request\n", + "\n", + "Put only one record in the request body, and then send the request to the endpoint to get its predictions and explanations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_records = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Accept=\"text/csv\",\n", + " Body=csv_serializer.serialize(test_data.iloc[:num_records, :].to_numpy()),\n", + ")\n", + "pprint.pprint(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, df_test[label_header][:num_records])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Single record request, no explanation\n", + "\n", + "Use the `EnableExplanations` parameter to disable the explanations for this request." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_records = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Accept=\"text/csv\",\n", + " Body=csv_serializer.serialize(test_data.iloc[:num_records, :].to_numpy()),\n", + " EnableExplanations=\"`false`\", # Do not provide explanations\n", + ")\n", + "pprint.pprint(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, df_test[label_header][:num_records])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batch request, explain both\n", + "\n", + "Put two records in the request body, and then send the request to the endpoint to get their predictions and explanations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_records = 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Accept=\"text/csv\",\n", + " Body=csv_serializer.serialize(test_data.iloc[:num_records, :].to_numpy()),\n", + ")\n", + "pprint.pprint(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, df_test[label_header][:num_records])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batch request with more records, explain some of the records\n", + "\n", + "Put a few more records to the request body, and then use the `EnableExplanations` expression to filter the records to be explained according to their predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_records = 4" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Accept=\"text/csv\",\n", + " Body=csv_serializer.serialize(test_data.iloc[:num_records, :].to_numpy()),\n", + " EnableExplanations=\"[0]>`0.99`\", # Explain a record only when its prediction meets the condition\n", + ")\n", + "pprint.pprint(response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, df_test[label_header][:num_records])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Cleanup\n", + "\n", + "Finally, don’t forget to clean up the resources we set up and used for this demo!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.delete_endpoint(EndpointName=endpoint_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.delete_model(ModelName=model_name)" + ] + } + ], + "metadata": { + "instance_type": "ml.m5.xlarge", + "kernelspec": { + "display_name": "Python 3 (Data Science)", + "language": "python", + "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + }, + "toc-autonumbering": false, + "toc-showmarkdowntxt": false + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sagemaker-clarify/online_explainability/natural_language_processing/scripts/train.py b/sagemaker-clarify/online_explainability/natural_language_processing/scripts/train.py new file mode 100644 index 0000000000..9c340cd114 --- /dev/null +++ b/sagemaker-clarify/online_explainability/natural_language_processing/scripts/train.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer +from sklearn.metrics import accuracy_score, precision_recall_fscore_support +from datasets import load_from_disk +import random +import logging +import sys +import argparse +import os +import torch + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # hyperparameters sent by the client are passed as command-line arguments to the script. + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--train_batch_size", type=int, default=32) + parser.add_argument("--eval_batch_size", type=int, default=64) + parser.add_argument("--warmup_steps", type=int, default=500) + parser.add_argument("--model_name", type=str) + parser.add_argument("--learning_rate", type=str, default=5e-5) + + # Data, model, and output directories + parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"]) + parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"]) + parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"]) + parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"]) + parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"]) + + args, _ = parser.parse_known_args() + + # Set up logging + logger = logging.getLogger(__name__) + logging.basicConfig( + level=logging.getLevelName("INFO"), + handlers=[logging.StreamHandler(sys.stdout)], + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + # load datasets + train_dataset = load_from_disk(args.training_dir) + test_dataset = load_from_disk(args.test_dir) + logger.info(f" loaded train_dataset length is: {len(train_dataset)}") + logger.info(f" loaded test_dataset length is: {len(test_dataset)}") + + # compute metrics function for binary classification + def compute_metrics(pred): + labels = pred.label_ids + preds = pred.predictions.argmax(-1) + precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary") + acc = accuracy_score(labels, preds) + return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall} + + # download model from model hub + model = AutoModelForSequenceClassification.from_pretrained(args.model_name) + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + # Tokenizer helper function + def tokenize(batch): + return tokenizer(batch['Review Text'], padding=True, truncation=True) + + # Tokenize the dataset + train_dataset = train_dataset.map(tokenize, batched=True) + test_dataset = test_dataset.map(tokenize, batched=True) + + # Set format for PyTorch + train_dataset = train_dataset.rename_column("Sentiment", "labels") + train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels']) + test_dataset = test_dataset.rename_column("Sentiment", "labels") + test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels']) + + # define training args + training_args = TrainingArguments( + output_dir=args.model_dir, + num_train_epochs=args.epochs, + per_device_train_batch_size=args.train_batch_size, + per_device_eval_batch_size=args.eval_batch_size, + warmup_steps=args.warmup_steps, + evaluation_strategy="epoch", + logging_dir=f"{args.output_data_dir}/logs", + learning_rate=float(args.learning_rate), + ) + + # create Trainer instance + trainer = Trainer( + model=model, + args=training_args, + compute_metrics=compute_metrics, + train_dataset=train_dataset, + eval_dataset=test_dataset, + tokenizer=tokenizer, + ) + + # train model + trainer.train() + + # evaluate model + eval_result = trainer.evaluate(eval_dataset=test_dataset) + + # writes eval result to file which can be accessed later in s3 ouput + with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer: + print(f"***** Eval results *****") + for key, value in sorted(eval_result.items()): + writer.write(f"{key} = {value}\n") + + # Saves the model to s3 + trainer.save_model(args.model_dir) diff --git a/sagemaker-clarify/online_explainability/tabular/tabular_online_explainability_with_sagemaker_clarify.ipynb b/sagemaker-clarify/online_explainability/tabular/tabular_online_explainability_with_sagemaker_clarify.ipynb new file mode 100644 index 0000000000..2264a3b67f --- /dev/null +++ b/sagemaker-clarify/online_explainability/tabular/tabular_online_explainability_with_sagemaker_clarify.ipynb @@ -0,0 +1,1051 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6a46bd98", + "metadata": { + "tags": [] + }, + "source": [ + "# Tabular Online Explainability with SageMaker Clarify" + ] + }, + { + "cell_type": "markdown", + "id": "e2789005", + "metadata": {}, + "source": [ + "* [Introduction](#Introduction)\n", + "* [General Setup](#General-Setup)\n", + " * [Install dependencies](#Install-dependencies)\n", + " * [Import libraries](#Import-libraries)\n", + " * [Set configurations](#Set-configurations)\n", + " * [Create serializer and deserializer](#Create-serializer-and-deserializer)\n", + " * [For visualization](#For-visualization)\n", + "* [Prepare data](#Prepare-data)\n", + " * [Download data](#Download-data)\n", + " * [Loading the data: Adult Dataset](#Loading-the-data:-Adult-Dataset)\n", + " * [Data inspection](#Data-inspection)\n", + " * [Encode and Upload the Dataset](#Encode-and-Upload-the-Dataset)\n", + "* [Train XGBoost Model](#Train-XGBoost-Model)\n", + "* [Create endpoint](#Create-endpoint)\n", + " * [Create model](#Create-model)\n", + " * [Create endpoint config](#Create-endpoint-config)\n", + " * [Create endpoint](#Create-endpoint)\n", + "* [Invoke endpoint](#Invoke-endpoint)\n", + " * [Single record request](#Single-record-request)\n", + " * [Single record request, no explanation](#Single-record-request,-no-explanation)\n", + " * [Batch request, explain both](#Batch-request,-explain-both)\n", + " * [Batch request, explain none](#Batch-request,-explain-none)\n", + " * [Batch request with more records, explain some of the records](#Batch-request-with-more-records,-explain-some-of-the-records)\n", + "* [Cleanup](#Cleanup)" + ] + }, + { + "cell_type": "markdown", + "id": "2116c025", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "Amazon SageMaker Clarify helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models. \n", + "\n", + "SageMaker Clarify currently supports explainability for SageMaker models as an offline processing job. This example notebook showcases a new feature for explainability on a [SageMaker real-time inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) endpoint, a.k.a. online explainability.\n", + "\n", + "This example notebook walks you through: \n", + "1. Key terms and concepts needed to understand SageMaker Clarify\n", + "1. Trained the model on a training dataset.\n", + "1. Create a model from trained model artifacts, create an endpoint configuration with the new SageMaker Clarify explainer configuration, and create an endpoint using the same explainer configuration.\n", + "1. Invoke the endpoint with single and batch request with different `EnableExplanations` query.\n", + "1. Explaining the importance of the various input features on the model's decision.\n", + "\n", + "\n", + "In doing so, the notebook will first train a [SageMaker XGBoost](https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html) model using training dataset, then use SageMaker Clarify to analyze a testing dataset in CSV format." + ] + }, + { + "cell_type": "markdown", + "id": "bf64fc0a", + "metadata": {}, + "source": [ + "## General Setup\n", + "\n", + "We recommend you use `Python 3 (Data Science)` kernel on SageMaker Studio or `conda_python3` kernel on SageMaker Notebook Instance." + ] + }, + { + "cell_type": "markdown", + "id": "48047964", + "metadata": {}, + "source": [ + "### Install dependencies\n", + "\n", + "Upgrade the SageMaker Python SDK. Install shap and matplotlib which are used to visualize the feature attributions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec349708", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install sagemaker --upgrade\n", + "!pip install boto3 --upgrade\n", + "!pip install botocore --upgrade\n", + "!pip install shap --upgrade" + ] + }, + { + "cell_type": "markdown", + "id": "7ec24d29", + "metadata": {}, + "source": [ + "### Import libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85c9245b", + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "import io\n", + "import os\n", + "import shap\n", + "import pprint\n", + "import pandas as pd\n", + "import numpy as np\n", + "from collections import OrderedDict\n", + "from sagemaker import get_execution_role, Session\n", + "from sagemaker.serializers import CSVSerializer\n", + "from sagemaker.deserializers import JSONDeserializer\n", + "from sagemaker.utils import unique_name_from_base" + ] + }, + { + "cell_type": "markdown", + "id": "b662efac", + "metadata": {}, + "source": [ + "### Set configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0787380b", + "metadata": {}, + "outputs": [], + "source": [ + "boto3_session = boto3.session.Session()\n", + "sagemaker_client = boto3.client(\"sagemaker\")\n", + "sagemaker_runtime_client = boto3.client(\"sagemaker-runtime\")\n", + "\n", + "# Initialize sagemaker session\n", + "sagemaker_session = Session(\n", + " boto_session=boto3_session,\n", + " sagemaker_client=sagemaker_client,\n", + " sagemaker_runtime_client=sagemaker_runtime_client,\n", + ")\n", + "\n", + "region = sagemaker_session.boto_region_name\n", + "print(f\"Region: {region}\")\n", + "\n", + "role = get_execution_role()\n", + "print(f\"Role: {role}\")\n", + "\n", + "s3_client = boto3.client(\"s3\")\n", + "\n", + "prefix = unique_name_from_base(\"DEMO-Tabular-Adult\")\n", + "\n", + "s3_bucket = sagemaker_session.default_bucket()\n", + "s3_prefix = f\"sagemaker/{prefix}\"\n", + "s3_key = f\"s3://{s3_bucket}/{s3_prefix}\"\n", + "print(f\"Demo S3 key: {s3_key}\")\n", + "\n", + "model_name = f\"{prefix}-model\"\n", + "print(f\"Demo model name: {model_name}\")\n", + "endpoint_config_name = f\"{prefix}-endpoint-config\"\n", + "print(f\"Demo endpoint config name: {endpoint_config_name}\")\n", + "endpoint_name = f\"{prefix}-endpoint\"\n", + "print(f\"Demo endpoint name: {endpoint_name}\")\n", + "\n", + "# Instance type for training and hosting\n", + "instance_type = \"ml.m5.xlarge\"" + ] + }, + { + "cell_type": "markdown", + "id": "7c44c759", + "metadata": {}, + "source": [ + "### Create serializer and deserializer" + ] + }, + { + "cell_type": "markdown", + "id": "9f0a7aba", + "metadata": {}, + "source": [ + "CSV serializer to serialize test data to string" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79569f81", + "metadata": {}, + "outputs": [], + "source": [ + "csv_serializer = CSVSerializer()" + ] + }, + { + "cell_type": "markdown", + "id": "6e64de94", + "metadata": {}, + "source": [ + "JSON deserializer to deserialize invoke endpoint response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "201230ef", + "metadata": {}, + "outputs": [], + "source": [ + "json_deserializer = JSONDeserializer()" + ] + }, + { + "cell_type": "markdown", + "id": "57480844", + "metadata": {}, + "source": [ + "### For visualization\n", + "\n", + "SHAP plots are useful visualization tools to interpret the explanations. For example, [SHAP additive force layout](https://shap.readthedocs.io/en/latest/generated/shap.plots.force.html) shows how each feature contributes to pushing the base value (also called the expected value which is the mean predictions of the training dataset) to the corresponding prediction. Features that push the prediction higher are in red color, while those push the prediction lower are in blue." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f780c60", + "metadata": {}, + "outputs": [], + "source": [ + "def force_plot(expected_value, shap_values, feature_data, feature_headers):\n", + " \"\"\"\n", + " Visualize the given SHAP values with an additive force layout.\n", + "\n", + " For more information: https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/Force%20Plot%20Colors.html\n", + " \"\"\"\n", + " force_plot_display = shap.plots.force(\n", + " base_value=expected_value,\n", + " shap_values=shap_values,\n", + " features=feature_data,\n", + " feature_names=feature_headers,\n", + " matplotlib=True,\n", + " )\n", + "\n", + "\n", + "def display_plots(explanations, expected_value, request_records, predictions):\n", + " \"\"\"\n", + " Display the Model Explainability plots\n", + " \"\"\"\n", + " per_request_shap_values = OrderedDict()\n", + " feature_headers = []\n", + " for i, record_output in enumerate(explanations):\n", + " per_record_shap_values = []\n", + " if record_output is not None:\n", + " feature_headers = []\n", + " for feature_attribution in record_output:\n", + " per_record_shap_values.append(\n", + " feature_attribution[\"attributions\"][0][\"attribution\"][0]\n", + " )\n", + " feature_headers.append(feature_attribution[\"feature_header\"])\n", + " per_request_shap_values[i] = per_record_shap_values\n", + "\n", + " for record_index, shap_values in per_request_shap_values.items():\n", + " print(\n", + " f\"Visualize the SHAP values for Record number {record_index + 1} with Model Prediction: {predictions[record_index][0]}\"\n", + " )\n", + " force_plot(\n", + " expected_value,\n", + " np.array(shap_values),\n", + " request_records.iloc[record_index],\n", + " feature_headers,\n", + " )\n", + "\n", + "\n", + "def visualize_result(result, request_records, expected_value):\n", + " \"\"\"\n", + " Visualize the output from the endpoint.\n", + " \"\"\"\n", + " predictions = pd.read_csv(io.StringIO(result[\"predictions\"][\"data\"]), header=None)\n", + " predictions = predictions.values.tolist()\n", + " print(f\"Model Inference output: \")\n", + " for i, model_output in enumerate(predictions):\n", + " print(f\"Record: {i + 1}\\tModel Prediction: {model_output[0]}\")\n", + "\n", + " if \"kernel_shap\" in result[\"explanations\"]:\n", + " explanations = result[\"explanations\"][\"kernel_shap\"]\n", + " display_plots(explanations, expected_value, request_records, predictions)\n", + " else:\n", + " print(f\"No Clarify explanations for the record(s)\")" + ] + }, + { + "cell_type": "markdown", + "id": "aefb875e", + "metadata": {}, + "source": [ + "## Prepare data\n", + "\n", + "### Download data\n", + "Data Source: [https://archive.ics.uci.edu/ml/machine-learning-databases/adult/](https://archive.ics.uci.edu/ml/machine-learning-databases/adult/)\n", + "\n", + "Let's __download__ the data and save it in the local folder with the name adult.data and adult.test from UCI repository$^{[2]}$.\n", + "\n", + "$^{[2]}$Dua Dheeru, and Efi Karra Taniskidou. \"[UCI Machine Learning Repository](http://archive.ics.uci.edu/ml)\". Irvine, CA: University of California, School of Information and Computer Science (2017)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b94940af", + "metadata": {}, + "outputs": [], + "source": [ + "adult_columns = [\n", + " \"Age\",\n", + " \"Workclass\",\n", + " \"fnlwgt\",\n", + " \"Education\",\n", + " \"Education-Num\",\n", + " \"Marital Status\",\n", + " \"Occupation\",\n", + " \"Relationship\",\n", + " \"Ethnic group\",\n", + " \"Sex\",\n", + " \"Capital Gain\",\n", + " \"Capital Loss\",\n", + " \"Hours per week\",\n", + " \"Country\",\n", + " \"Target\",\n", + "]\n", + "if not os.path.isfile(\"adult.data\"):\n", + " s3_client.download_file(\n", + " \"sagemaker-sample-files\", \"datasets/tabular/uci_adult/adult.data\", \"adult.data\"\n", + " )\n", + " print(f\"adult.data saved!\")\n", + "else:\n", + " print(f\"adult.data already on disk.\")\n", + "\n", + "if not os.path.isfile(\"adult.test\"):\n", + " s3_client.download_file(\n", + " \"sagemaker-sample-files\", \"datasets/tabular/uci_adult/adult.test\", \"adult.test\"\n", + " )\n", + " print(f\"adult.test saved!\")\n", + "else:\n", + " print(f\"adult.test already on disk.\")" + ] + }, + { + "cell_type": "markdown", + "id": "9b39109e", + "metadata": {}, + "source": [ + "### Loading the data: Adult Dataset\n", + "From the UCI repository of machine learning datasets, this database contains 14 features concerning demographic characteristics of 45,222 rows (32,561 for training and 12,661 for testing). The task is to predict whether a person has a yearly income that is more or less than $50,000.\n", + "\n", + "Here are the features and their possible values:\n", + "1. **Age**: continuous.\n", + "1. **Workclass**: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.\n", + "1. **Fnlwgt**: continuous (the number of people the census takers believe that observation represents).\n", + "1. **Education**: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.\n", + "1. **Education-num**: continuous.\n", + "1. **Marital-status**: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.\n", + "1. **Occupation**: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.\n", + "1. **Relationship**: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.\n", + "1. **Ethnic group**: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.\n", + "1. **Sex**: Female, Male.\n", + " * **Note**: this data is extracted from the 1994 Census and enforces a binary option on Sex\n", + "1. **Capital-gain**: continuous.\n", + "1. **Capital-loss**: continuous.\n", + "1. **Hours-per-week**: continuous.\n", + "1. **Native-country**: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.\n", + "\n", + "Next, we specify our binary prediction task: \n", + "15. **Target**: <=50,000, >$50,000." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f6d7062", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "training_data = pd.read_csv(\n", + " \"adult.data\", names=adult_columns, sep=r\"\\s*,\\s*\", engine=\"python\", na_values=\"?\"\n", + ").dropna()\n", + "\n", + "testing_data = pd.read_csv(\n", + " \"adult.test\", names=adult_columns, sep=r\"\\s*,\\s*\", engine=\"python\", na_values=\"?\", skiprows=1\n", + ").dropna()\n", + "\n", + "training_data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "9a74eb9c", + "metadata": {}, + "source": [ + "### Data inspection\n", + "Plotting histograms for the distribution of the different features is a good way to visualize the data. Let's plot a few of the features that can be considered _sensitive_. \n", + "Let's take a look specifically at the Sex feature of a census respondent. In the first plot we see that there are fewer Female respondents as a whole but especially in the positive outcomes, where they form ~$\\frac{1}{7}$th of respondents." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fce0b8f", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "training_data[\"Sex\"].value_counts().sort_values().plot(kind=\"bar\", title=\"Counts of Sex\", rot=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e05094f6", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "training_data[\"Sex\"].where(training_data[\"Target\"] == \">50K\").value_counts().sort_values().plot(\n", + " kind=\"bar\", title=\"Counts of Sex earning >$50K\", rot=0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2c52d24f", + "metadata": {}, + "source": [ + "### Encode and Upload the Dataset\n", + "Here we encode the training and test data. Encoding input data is not necessary for SageMaker Clarify, but is necessary for the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e258c0de", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import preprocessing\n", + "\n", + "\n", + "def number_encode_features(df):\n", + " result = df.copy()\n", + " encoders = {}\n", + " for column in result.columns:\n", + " if result.dtypes[column] == np.object:\n", + " encoders[column] = preprocessing.LabelEncoder()\n", + " # print('Column:', column, result[column])\n", + " result[column] = encoders[column].fit_transform(result[column].fillna(\"None\"))\n", + " return result, encoders\n", + "\n", + "\n", + "training_data = pd.concat([training_data[\"Target\"], training_data.drop([\"Target\"], axis=1)], axis=1)\n", + "training_data, _ = number_encode_features(training_data)\n", + "training_data.to_csv(\"train_data.csv\", index=False, header=False)\n", + "\n", + "testing_data, _ = number_encode_features(testing_data)\n", + "test_features = testing_data.drop([\"Target\"], axis=1)\n", + "test_target = testing_data[\"Target\"]\n", + "test_features.to_csv(\"test_features.csv\", index=False, header=False)" + ] + }, + { + "cell_type": "markdown", + "id": "f64e24d8", + "metadata": {}, + "source": [ + "A quick note about our encoding: the \"Female\" Sex value has been encoded as 0 and \"Male\" as 1." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90c3e935", + "metadata": {}, + "outputs": [], + "source": [ + "training_data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "01c8a06b", + "metadata": {}, + "source": [ + "Get the feature names and the label names from the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efe526bf", + "metadata": {}, + "outputs": [], + "source": [ + "feature_headers = testing_data.columns.to_list()\n", + "label_header = feature_headers.pop()\n", + "print(f\"Feature names: {feature_headers}\")\n", + "print(f\"Label name: {label_header}\")" + ] + }, + { + "cell_type": "markdown", + "id": "99ec8ebd", + "metadata": {}, + "source": [ + "Lastly, let's upload the data to S3 so that they can be used by the training job." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "637902e4", + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.s3 import S3Uploader\n", + "from sagemaker.inputs import TrainingInput\n", + "\n", + "train_uri = S3Uploader.upload(\"train_data.csv\", \"s3://{}/{}\".format(s3_bucket, prefix))\n", + "train_input = TrainingInput(train_uri, content_type=\"csv\")\n", + "test_uri = S3Uploader.upload(\"test_features.csv\", \"s3://{}/{}\".format(s3_bucket, prefix))" + ] + }, + { + "cell_type": "markdown", + "id": "a953f7d8", + "metadata": {}, + "source": [ + "## Train XGBoost Model\n", + "\n", + "Since our focus is on understanding how to use SageMaker Clarify, we keep it simple by using a standard XGBoost model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc99c39e", + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.image_uris import retrieve\n", + "from sagemaker.estimator import Estimator\n", + "\n", + "container = retrieve(\"xgboost\", region, version=\"1.3-1\")\n", + "xgb = Estimator(\n", + " container,\n", + " role,\n", + " instance_count=1,\n", + " instance_type=instance_type,\n", + " disable_profiler=True,\n", + " debugger_hook_config=False,\n", + ")\n", + "\n", + "xgb.set_hyperparameters(\n", + " max_depth=5,\n", + " eta=0.2,\n", + " gamma=4,\n", + " min_child_weight=6,\n", + " subsample=0.8,\n", + " objective=\"binary:logistic\",\n", + " num_round=800,\n", + ")\n", + "\n", + "xgb.fit({\"train\": train_input}, logs=False)" + ] + }, + { + "cell_type": "markdown", + "id": "796a8b78", + "metadata": {}, + "source": [ + "Create a new model object which will be used to create the SageMaker model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4f2e34b", + "metadata": {}, + "outputs": [], + "source": [ + "model = xgb.create_model(name=model_name)\n", + "container_def = model.prepare_container_def()\n", + "container_def" + ] + }, + { + "cell_type": "markdown", + "id": "2519d5f9", + "metadata": { + "tags": [] + }, + "source": [ + "## Create endpoint" + ] + }, + { + "cell_type": "markdown", + "id": "12c22d4a", + "metadata": {}, + "source": [ + "### Create model\n", + "\n", + "The following parameters are required to create a SageMaker model:\n", + "\n", + "* `ExecutionRoleArn`: The ARN of the IAM role that Amazon SageMaker can assume to access the model artifacts/ docker images for deployment\n", + "\n", + "* `ModelName`: name of the SageMaker model.\n", + "\n", + "* `PrimaryContainer`: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14991a1c", + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.create_model(\n", + " ExecutionRoleArn=role,\n", + " ModelName=model_name,\n", + " PrimaryContainer=container_def,\n", + ")\n", + "print(f\"Model created: {model_name}\")" + ] + }, + { + "cell_type": "markdown", + "id": "0c8ce1b6", + "metadata": {}, + "source": [ + "### Create endpoint config\n", + "\n", + "Create an endpoint configuration by calling the `create_endpoint_config` API. Here, supply the same `model_name` used in the `create_model` API call. The `create_endpoint_config` now supports the additional parameter `ClarifyExplainerConfig` to enable the Clarify explainer. The SHAP baseline is mandatory, it can be provided either as inline baseline data (the `ShapBaseline` parameter) or by a S3 baseline file (the `ShapBaselineUri` parameter). Please see the developer guide for the other parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e64aa1c3", + "metadata": {}, + "outputs": [], + "source": [ + "baseline = test_features.mean().to_list() # Inline baseline data\n", + "print(f\"Use the mean of the test data as the SHAP baseline: {baseline}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d196e917", + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.create_endpoint_config(\n", + " EndpointConfigName=endpoint_config_name,\n", + " ProductionVariants=[\n", + " {\n", + " \"VariantName\": \"TestVariant\",\n", + " \"ModelName\": model_name,\n", + " \"InitialInstanceCount\": 1,\n", + " \"InstanceType\": instance_type,\n", + " }\n", + " ],\n", + " ExplainerConfig={\n", + " \"ClarifyExplainerConfig\": {\n", + " # \"EnableExplanations\": \"`false`\", # By default explanations are enabled, but you can change the condition by this parameter.\n", + " \"InferenceConfig\": {\n", + " \"FeatureHeaders\": feature_headers,\n", + " },\n", + " \"ShapConfig\": {\n", + " \"ShapBaselineConfig\": {\n", + " \"ShapBaseline\": csv_serializer.serialize(baseline), # inline baseline data\n", + " }\n", + " },\n", + " }\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e1a7c19b", + "metadata": {}, + "source": [ + "### Create endpoint\n", + "\n", + "Once you have your model and endpoint configuration ready, use the `create_endpoint` API to create your endpoint. The `endpoint_name` must be unique within an AWS Region in your AWS account. The `create_endpoint` API is synchronous in nature and returns an immediate response with the endpoint status being `Creating` state." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70975b6f", + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.create_endpoint(\n", + " EndpointName=endpoint_name,\n", + " EndpointConfigName=endpoint_config_name,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "53fa01f8", + "metadata": {}, + "source": [ + "Wait for the endpoint to be in \"InService\" state." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6270f18a", + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_session.wait_for_endpoint(endpoint_name)" + ] + }, + { + "cell_type": "markdown", + "id": "a3742fae", + "metadata": { + "tags": [] + }, + "source": [ + "## Invoke endpoint\n", + "\n", + "There are expanding business needs and legislative regulations that require explanations of _why_ a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.\n", + "\n", + "Below are the several different combination of endpoint invocation, call them one by one and visualize the explanations by running the subsequent cell. " + ] + }, + { + "cell_type": "markdown", + "id": "36419959", + "metadata": {}, + "source": [ + "### Single record request\n", + "\n", + "Put only one record in the request body, and then send the request to the endpoint to get its predictions and explanations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38a0dc1e", + "metadata": {}, + "outputs": [], + "source": [ + "request_records = test_features.iloc[:1, :]\n", + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Body=csv_serializer.serialize(request_records.to_numpy()),\n", + ")\n", + "pprint.pprint(response)" + ] + }, + { + "cell_type": "markdown", + "id": "4202f8e2", + "metadata": {}, + "source": [ + "Print the response body which is JSON. Please see the developer guide for its schema." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8d231c0", + "metadata": {}, + "outputs": [], + "source": [ + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "markdown", + "id": "7f6693da", + "metadata": {}, + "source": [ + "Use SHAP plots to visualize the result. [SHAP additive force layout](https://shap.readthedocs.io/en/latest/generated/shap.plots.force.html) shows how each feature contributes to pushing the base value (also called the expected value which is the mean predictions of the training dataset) to the corresponding prediction. Features that push the prediction higher are in red color, while those push the prediction lower are in blue.\n", + "\n", + "The expected value is the average of the model predictions over the baseline. Here we predict the baseline data and then compute the expected value. Only the predictions are needed, so the `EnableExplanations` parameter is used to disable the explanations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2eaa7ced", + "metadata": {}, + "outputs": [], + "source": [ + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Body=csv_serializer.serialize(baseline),\n", + " EnableExplanations=\"`false`\", # Do not provide explanations\n", + ")\n", + "json_object = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "expected_value = float(\n", + " pd.read_csv(io.StringIO(json_object[\"predictions\"][\"data\"]), header=None)\n", + " .astype(float)\n", + " .mean(axis=1)\n", + ")\n", + "print(f\"expected value: {expected_value}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da8a8e1b", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, request_records, expected_value)" + ] + }, + { + "cell_type": "markdown", + "id": "3430f125", + "metadata": {}, + "source": [ + "### Single record request, no explanation\n", + "\n", + "Use the `EnableExplanations` parameter to disable the explanations for this request." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4cd5d44", + "metadata": {}, + "outputs": [], + "source": [ + "request_records = test_features.iloc[:1, :]\n", + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Body=csv_serializer.serialize(request_records.to_numpy()),\n", + " EnableExplanations=\"`false`\", # Do not provide explanations\n", + ")\n", + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3709a842", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, request_records, expected_value)" + ] + }, + { + "cell_type": "markdown", + "id": "a8824447", + "metadata": {}, + "source": [ + "### Batch request, explain both\n", + "\n", + "Put two records in the request body, and then send the request to the endpoint to get their predictions and explanations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61434d42", + "metadata": {}, + "outputs": [], + "source": [ + "request_records = test_features.iloc[:2, :]\n", + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Body=csv_serializer.serialize(request_records.to_numpy()),\n", + ")\n", + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff81310b", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, request_records, expected_value)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a4732f0", + "metadata": {}, + "outputs": [], + "source": [ + "request_records = test_features.iloc[:2, :]\n", + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Body=csv_serializer.serialize(request_records.to_numpy()),\n", + " EnableExplanations=\"`false`\", # Do not provide explanations\n", + ")\n", + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6caec6c0", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, request_records, expected_value)" + ] + }, + { + "cell_type": "markdown", + "id": "63eabecd", + "metadata": {}, + "source": [ + "### Batch request with more records, explain some of the records\n", + "\n", + "Put a few more records to the request body, and then use the `EnableExplanations` expression to filter the records to be explained according to their predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5922a2c0", + "metadata": {}, + "outputs": [], + "source": [ + "request_records = test_features.iloc[:70, :]\n", + "response = sagemaker_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"text/csv\",\n", + " Body=csv_serializer.serialize(request_records.to_numpy()),\n", + " EnableExplanations=\"[0]>`0.95`\", # Explain a record only when its prediction is greater than the threshold\n", + ")\n", + "result = json_deserializer.deserialize(response[\"Body\"], content_type=response[\"ContentType\"])\n", + "pprint.pprint(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d73b9455", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_result(result, request_records, expected_value)" + ] + }, + { + "cell_type": "markdown", + "id": "f50ee663", + "metadata": { + "tags": [] + }, + "source": [ + "## Cleanup\n", + "\n", + "Finally, don’t forget to clean up the resources we set up and used for this demo!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea70e043", + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.delete_endpoint(EndpointName=endpoint_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f7c3f69", + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "780e3045", + "metadata": {}, + "outputs": [], + "source": [ + "sagemaker_client.delete_model(ModelName=model_name)" + ] + } + ], + "metadata": { + "instance_type": "ml.t3.medium", + "kernelspec": { + "display_name": "Python 3 (Data Science)", + "language": "python", + "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/datascience-1.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}