diff --git a/code/3-model-fairness-deep-dive.ipynb b/code/3-model-fairness-deep-dive.ipynb index a0a5cf37..4d3a039f 100644 --- a/code/3-model-fairness-deep-dive.ipynb +++ b/code/3-model-fairness-deep-dive.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "1a9e96a8", + "id": "d38f3973", "metadata": {}, "source": [ "# Model fairness: hands-on" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "1afaf976", + "id": "06672ffb", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::: questions \n", @@ -36,7 +36,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d3ff3de7", + "id": "6478e3d9", "metadata": {}, "outputs": [], "source": [ @@ -56,7 +56,7 @@ }, { "cell_type": "markdown", - "id": "c7d68382", + "id": "6bd00d02", "metadata": {}, "source": [ "## Scenario and data\n", @@ -79,7 +79,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ca9852bc", + "id": "ac10ae1e", "metadata": {}, "outputs": [], "source": [ @@ -89,7 +89,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b2c4d220", + "id": "6b9132d0", "metadata": {}, "outputs": [], "source": [ @@ -111,7 +111,7 @@ }, { "cell_type": "markdown", - "id": "382eb78b", + "id": "5f3349fa", "metadata": {}, "source": [ "Check object type." @@ -120,7 +120,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5a27bf97", + "id": "92062bb0", "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "markdown", - "id": "75df8b9c", + "id": "a7b31cbe", "metadata": {}, "source": [ "Preview data." @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e787962b", + "id": "faa3ff9e", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "1e4a5c9d", + "id": "33bd3e1d", "metadata": {}, "source": [ "Show details about the data." @@ -156,7 +156,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76249b4c", + "id": "0eb4f9c5", "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "markdown", - "id": "3879442a", + "id": "ba5b1820", "metadata": {}, "source": [ "Next, we will look at whether the dataset contains bias; i.e., does the outcome 'UTILIZATION' take on a positive value more frequently for one racial group than another?\n", @@ -200,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "348e3eba", + "id": "444e7e22", "metadata": {}, "outputs": [], "source": [ @@ -210,7 +210,7 @@ }, { "cell_type": "markdown", - "id": "ad2cc5c6", + "id": "a1b52994", "metadata": {}, "source": [ "Some initial import error may occur since we're using the CPU-only version of torch. If you run the import statement twice it should correct itself. We've coded this as a try/except statement below." @@ -219,7 +219,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d4305065", + "id": "fea86d50", "metadata": {}, "outputs": [], "source": [ @@ -236,7 +236,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66cc8daf", + "id": "1512badd", "metadata": {}, "outputs": [], "source": [ @@ -251,7 +251,7 @@ }, { "cell_type": "markdown", - "id": "1d024d5c", + "id": "053e83d2", "metadata": {}, "source": [ "We see that the disparate impact is about 0.53, which means the privileged group has the favorable outcome at about 2x the rate as the unprivileged group does. 
\n", @@ -265,7 +265,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70b27f85", + "id": "517826a2", "metadata": {}, "outputs": [], "source": [ @@ -278,7 +278,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8bc847df", + "id": "13db4586", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +294,7 @@ }, { "cell_type": "markdown", - "id": "fc3ec8b8", + "id": "f9f73c0a", "metadata": {}, "source": [ "### Validate the model\n", @@ -309,7 +309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f9249f09", + "id": "15fb1d66", "metadata": {}, "outputs": [], "source": [ @@ -320,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3aa78f88", + "id": "58985d89", "metadata": {}, "outputs": [], "source": [ @@ -369,7 +369,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dcb17dbd", + "id": "848fb10f", "metadata": {}, "outputs": [], "source": [ @@ -382,7 +382,7 @@ }, { "cell_type": "markdown", - "id": "777376a2", + "id": "f65affb1", "metadata": {}, "source": [ "We will plot `val_metrics`. The x-axis will be the threshold we use to output the label 1 (i.e., if the raw score is larger than the threshold, we output 1). \n", @@ -395,7 +395,7 @@ { "cell_type": "code", "execution_count": null, - "id": "04a5ba07", + "id": "9d88c4ab", "metadata": {}, "outputs": [], "source": [ @@ -432,7 +432,7 @@ }, { "cell_type": "markdown", - "id": "759ba4aa", + "id": "29120263", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::: challenge\n", @@ -465,7 +465,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6f2d3fa7", + "id": "194d715e", "metadata": {}, "outputs": [], "source": [ @@ -490,7 +490,7 @@ }, { "cell_type": "markdown", - "id": "352bd671", + "id": "3c3de8d2", "metadata": {}, "source": [ "### Test the model\n", @@ -501,7 +501,7 @@ { "cell_type": "code", "execution_count": null, - "id": "118dff41", + "id": "eb8605f7", "metadata": {}, "outputs": [], "source": [ @@ -513,7 +513,7 @@ }, { "cell_type": "markdown", - "id": "392996fe", + "id": "9382f5b3", "metadata": {}, "source": [ "## Mitigate bias with in-processing\n", @@ -526,7 +526,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d6f4c53e", + "id": "a5ec1331", "metadata": {}, "outputs": [], "source": [ @@ -536,7 +536,7 @@ { "cell_type": "code", "execution_count": null, - "id": "82601d3f", + "id": "f5379292", "metadata": {}, "outputs": [], "source": [ @@ -548,7 +548,7 @@ }, { "cell_type": "markdown", - "id": "5907ae3a", + "id": "a991dd7d", "metadata": {}, "source": [ "We'll also define metrics for the reweighted data and print out the disparate impact of the dataset." @@ -557,7 +557,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a100e722", + "id": "24013c7a", "metadata": {}, "outputs": [], "source": [ @@ -572,7 +572,7 @@ }, { "cell_type": "markdown", - "id": "3424f60a", + "id": "ed2ca8ff", "metadata": {}, "source": [ "Then, we'll train a model, validate it, and evaluate of the test data." 
@@ -581,7 +581,7 @@ { "cell_type": "code", "execution_count": null, - "id": "357969f6", + "id": "0c283b19", "metadata": {}, "outputs": [], "source": [ @@ -596,7 +596,7 @@ { "cell_type": "code", "execution_count": null, - "id": "908f50fd", + "id": "e40227b4", "metadata": {}, "outputs": [], "source": [ @@ -611,7 +611,7 @@ { "cell_type": "code", "execution_count": null, - "id": "02de17a8", + "id": "9d73009d", "metadata": {}, "outputs": [], "source": [ @@ -630,7 +630,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5cfea4dc", + "id": "6bbbd9f3", "metadata": {}, "outputs": [], "source": [ @@ -639,7 +639,7 @@ }, { "cell_type": "markdown", - "id": "b6ec9cc8", + "id": "0315aac3", "metadata": {}, "source": [ "### Test" @@ -648,7 +648,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bc0428da", + "id": "b226fc7a", "metadata": {}, "outputs": [], "source": [ @@ -660,7 +660,7 @@ }, { "cell_type": "markdown", - "id": "7228f558", + "id": "8fb38a1f", "metadata": {}, "source": [ "We see that the disparate impact score on the test data is better after reweighting than it was originally.\n", @@ -680,7 +680,7 @@ { "cell_type": "code", "execution_count": null, - "id": "03ca2a29", + "id": "98380eba", "metadata": {}, "outputs": [], "source": [ @@ -690,7 +690,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a91dc69b", + "id": "ce6c59b2", "metadata": {}, "outputs": [], "source": [ @@ -703,7 +703,7 @@ }, { "cell_type": "markdown", - "id": "fe419916", + "id": "d7df746a", "metadata": {}, "source": [ "Next, we fit the ThresholdOptimizer object to the validation data." @@ -712,7 +712,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32ba17cb", + "id": "6bf5b7bf", "metadata": {}, "outputs": [], "source": [ @@ -722,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "b9fc7ac2", + "id": "d483fc67", "metadata": {}, "source": [ "Then, we'll create a helper function, `mini_test` to allow us to call the `describe_metrics` function even though we are no longer evaluating our method as a variety of thresholds.\n", @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9637dcbb", + "id": "9ce0e264", "metadata": {}, "outputs": [], "source": [ @@ -766,7 +766,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e0e8e83f", + "id": "6881d5e3", "metadata": {}, "outputs": [], "source": [ @@ -782,7 +782,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b79a28c3", + "id": "04d8a75a", "metadata": {}, "outputs": [], "source": [ @@ -793,7 +793,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33c82c74", + "id": "184740e9", "metadata": {}, "outputs": [], "source": [ @@ -804,7 +804,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cead3118", + "id": "bb9ee6c5", "metadata": {}, "outputs": [], "source": [ @@ -815,7 +815,7 @@ }, { "cell_type": "markdown", - "id": "35d87767", + "id": "79c72808", "metadata": {}, "source": [ "Scroll up and see how these results compare with the original classifier and with the in-processing technique. 
\n", @@ -828,7 +828,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b7b5cbe0", + "id": "9ac1164e", "metadata": {}, "outputs": [], "source": [ @@ -838,7 +838,7 @@ }, { "cell_type": "markdown", - "id": "68b2b119", + "id": "e5645bf6", "metadata": {}, "source": [ "Recall that a value of 1 in the Race column corresponds to White people, while a value of 0 corresponds to non-White people.\n", diff --git a/code/5a-explainable-AI-method-overview.ipynb b/code/5a-explainable-AI-method-overview.ipynb deleted file mode 100644 index 2de45a24..00000000 --- a/code/5a-explainable-AI-method-overview.ipynb +++ /dev/null @@ -1,279 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "8d1b0ad5", - "metadata": {}, - "source": [ - "# Explainability methods overview" - ] - }, - { - "cell_type": "markdown", - "id": "4914d498", - "metadata": {}, - "source": [ - ":::::::::::::::::::::::::::::::::::::: questions \n", - "\n", - "- What are the major categories of explainability methods, and how do they differ?\n", - "- How do you determine which explainability method to use for a specific use case?\n", - "- What are the trade-offs between black-box and white-box approaches to explainability?\n", - "- How do post-hoc explanation methods compare to inherently interpretable models in terms of utility and reliability?\n", - "\n", - "::::::::::::::::::::::::::::::::::::::::::::::::\n", - "\n", - "::::::::::::::::::::::::::::::::::::: objectives\n", - "\n", - "- Understand the key differences between black-box and white-box explanation methods.\n", - "- Explore the trade-offs between post-hoc explainability and inherent interpretability in models.\n", - "- Identify and categorize different explainability techniques based on their scope, model access, and approach.\n", - "- Learn when to apply specific explainability techniques for various machine learning tasks.\n", - " \n", - "::::::::::::::::::::::::::::::::::::::::::::::::\n", - "\n", - "## Fantastic Explainability Methods and Where to Use Them\n", - "\n", - "We will now take a bird's-eye view of explainability methods that are widely applied on complex models like neural networks. \n", - "We will get a sense of when to use which kind of method, and what the tradeoffs between these methods are. \n", - "\n", - "\n", - "## Three axes of use cases for understanding model behavior\n", - "\n", - "When deciding which explainability method to use, it is helpful to define your setting along three axes. \n", - "This helps in understanding the context in which the model is being used, and the kind of insights you are looking to gain from the model.\n", - "\n", - "### Inherently Interpretable vs Post Hoc Explainable\n", - "\n", - "Understanding the tradeoff between interpretability and complexity is crucial in machine learning. \n", - "Simple models like decision trees, random forests, and linear regression offer transparency and ease of understanding, making them ideal for explaining predictions to stakeholders. \n", - "In contrast, neural networks, while powerful, lack interpretability due to their complexity. \n", - "Post hoc explainable techniques can be applied to neural networks to provide explanations for predictions, but it's essential to recognize that using such methods involves a tradeoff between model complexity and interpretability. 
\n", - "\n", - "Striking the right balance between these factors is key to selecting the most suitable model for a given task, considering both its predictive performance and the need for interpretability.\n", - "\n", - "![The tradeoff between Interpretability and Complexity](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/e5-interpretability-vs-complexity.png){alt='_Credits: AAAI 2021 Tutorial on Explaining Machine Learning Predictions: State of the Art, Challenges, Opportunities._'}\n", - "\n", - "### Local vs Global Explanations\n", - "Local explanations focus on describing model behavior within a specific neighborhood, providing insights into individual predictions. \n", - "Conversely, global explanations aim to elucidate overall model behavior, offering a broader perspective. \n", - "While global explanations may be more comprehensive, they run the risk of being overly complex. \n", - "\n", - "Both types of explanations are valuable for uncovering biases and ensuring that the model makes predictions for the right reasons. \n", - "The tradeoff between local and global explanations has a long history in statistics, with methods like linear regression (global) and kernel smoothing (local) illustrating the importance of considering both perspectives in statistical analysis.\n", - "\n", - "### Black box vs White Box Approaches\n", - "Techniques that require access to model internals (e.g., model architecture and model weights) are called \"white box\" while techniques that only need query access to the model are called \"black box\". \n", - "Even without access to the model weights, black box or top down approaches can shed a lot of light on model behavior. \n", - "For example, by simply evaluating the model on certain kinds of data, high level biases or trends in the model’s decision making process can be unearthed. \n", - "\n", - "White box approaches use the weights and activations of the model to understand its behavior. \n", - "These classes or methods are more complex and diverse, and we will discuss them in more detail later in this episode.\n", - "Some large models are closed-source due to commercial or safety concerns; for example, users can’t get access to the weights of GPT-4. This limits the use of white box explanations for such models.\n", - "\n", - "\n", - "## Classes of Explainability Methods for Understanding Model Behavior\n", - "\n", - "### Diagnostic Testing\n", - "\n", - "This is the simplest approach towards explaining model behavior. \n", - "This involves applying a series of unit tests to the model, where each test is a sample input where you know what the correct output should be.\n", - "By identifying test examples that break the heuristics the model relies on (called counterfactuals), you can gain insights into the high-level behavior of the model.\n", - "\n", - "**Example Methods:** [Counterfactuals](https://arxiv.org/abs/1902.01007), [Unit tests](https://arxiv.org/abs/2005.04118)\n", - "\n", - "**Pros and Cons:**\n", - "These methods allow for gaining insights into the high-level behavior of the model without the needing access to model weights.\n", - "This is especially useful with recent powerful closed-source models like GPT-4. 
\n", - "One challenge with this approach is that it is hard to identify in advance what heuristics a model may depend on.\n", - "\n", - "\n", - "### Baking interpretability into models\n", - "\n", - "Some recent research has focused on tweaking highly complex models like neural networks, towards making them more interpretable inherently. \n", - "One such example with language models involves training the model to generate rationales for its prediction, in addition to its original prediction.\n", - "This approach has gained some traction, and there are even [public benchmarks](https://arxiv.org/abs/1911.03429) for evaluating the quality of these generated rationales.\n", - "\n", - "**Example methods:** [Rationales with WT5](https://arxiv.org/abs/2004.14546), [Older approaches for rationales](https://arxiv.org/abs/1606.04155) \n", - "\n", - "**Pros and cons:**\n", - "These models hope to achieve the best of both worlds: complex models that are also inherently interpretable.\n", - "However, research in this direction is still new, and there are no established and reliable approaches for real world applications just yet. \n", - "\n", - "\n", - "### Identifying Decision Rules of the Model:\n", - "\n", - "In this class of methods, we try find a set of rules that generally explain the decision making process of the model. \n", - "Loosely, these rules would be of the form \"if a specific condition is met, then the model will predict a certain class\".\n", - "\n", - "\n", - "**Example methods:** [Anchors](https://aaai.org/papers/11491-anchors-high-precision-model-agnostic-explanations/), [Universal Adversarial Triggers](https://arxiv.org/abs/1908.07125)\n", - "\n", - "![Example use of anchors (table from [Ribeiro et al.](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/))](https://raw.githubusercontent.com/carpentries-incubator/fair-explainable-ml/main/images/e5-anchors-example.png){alt='Table caption: \"Generated anchors for Tabular datasets\". Table shows the following rules: for the adult dataset, predict less than 50K if no capital gain or loss and never married. Predict over 50K if country is US, married, and work hours over 45. For RCDV dataset, predict not rearrested if person has no priors, no prison violations, and crime not against property. Predict re-arrested if person is male, black, has 1-5 priors, is not married, and the crime not against property. For the Lending dataset, predict bad loan if FICO score is less than 650. Predict good loan if FICO score is between 650 and 700 and loan amount is between 5400 and 10000.'}\n", - "\n", - "**Pros and cons:**\n", - "Some global rules help find \"bugs\" in the model, or identify high level biases. But finding such broad coverage rules is challenging. \n", - "Furthermore, these rules only showcase the model's weaknesses, but give next to no insight as to why these weaknesses exist.\n", - "\n", - "\n", - "\n", - "### Visualizing model weights or representations\n", - "Just like how a picture tells a thousand words, visualizations can help encapsulate complex model behavior in a simple image. 
\n", - "Visualizations are commonly used in explaining neural networks, where the weights or data representations of the model are directly visualized.\n", - "Many such approaches involve reducing the high-dimensional weights or representations to a 2D or 3D space, using techniques like PCA, tSNE, or UMAP.\n", - "Alternatively, these visualizations can retain their high dimensional representation, but use color or size to identify which dimensions or neurons are more important.\n", - "\n", - "**Example methods:** [Visualizing attention heatmaps](https://arxiv.org/abs/1612.08220), Weight visualizations, Model activation visualizations\n", - "\n", - "![Example usage of visualizing attention heatmaps for part-of-speech (POS) identification task using word2vec-encoded vectors. Each cell is a unit in a neural network (each row is a layer and each column is a dimension). Darker colors indicates that a unit is more importance for predictive accuracy (table from [Li et al.](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/1612.08220).)](https://raw.githubusercontent.com/carpentries-incubator/fair-explainable-ml/main/images/e5-visualization-heatmap.png){alt='Image shows a grid with 3 rows and 50 columns. Each cell is colored on a scale of -1.5 (white) to 0.9 (dark blue). Darker colors are concentrated in the first row in seemingly-random columns.'}\n", - "\n", - "**Pros and cons:**\n", - "Gleaning model behaviour from visualizations is very intuitive and user-friendly, and visualizations sometimes have interactive interfaces.\n", - "However, visualizations can be misleading, especially when high-dimensional vectors are reduced to 2D, leading to a loss of information (crowding issue).\n", - "\n", - "An iconic debate exemplifying the validity of visualizations has centered around attention heatmaps. \n", - "Research has shown them to be [unreliable](https://arxiv.org/abs/1902.10186), and then [reliable again](https://arxiv.org/abs/1908.04626). (Check out the titles of these papers!)\n", - "Thus, visualization can only be used as an additional step in an analysis, and not as a standalone method.\n", - "\n", - "\n", - "### Understanding the impact of training examples\n", - "These techniques unearth which training data instances caused the model to generate a specific prediction for a given sample. At a high level, these techniques mathematically identify what training samples that -- if removed from the training process -- are most influential for causing a particular prediction.\n", - "\n", - "**Example methods:** [Influence functions](https://arxiv.org/abs/1703.04730), [Representer point selection](https://arxiv.org/abs/1811.09720)\n", - "\n", - "![Example usage of representer point selection. The image on the left is a test image that is misclassified as a deer (the true label is antelope). The image on the right is the most influential training point. We see that this image is labeled \"zebra,\" but contains both zebras and antelopes. (example adapted from [Yeh et al.](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/1811.09720).)](https://raw.githubusercontent.com/carpentries-incubator/fair-explainable-ml/main/images/e5-influence.png){alt='Two images. On the left, several antelope are standing in the background on a grassy field. 
On the right, several zebra graze in a field in the background, while there is one antelope in the foreground and other antelope in the background.'}\n", - "\n", - "**Pros and cons:**\n", - "The insights from these approaches are actionable - by identifying the data responsible for a prediction, it can help correct labels or annotation artifacts in that data.\n", - "Unfortunately, these methods scale poorly with the size of the model and training data, quickly becoming computationally expensive.\n", - "Furthermore, even knowing which datapoints had a high influence on a prediction, we don’t know what it was about that datapoint that caused the influence.\n", - "\n", - "\n", - "### Understanding the impact of a single example:\n", - "For a single input, what parts of the input were most important in generating the model's prediction? \n", - "These methods study the signal sent by various features to the model, and observe how the model reacts to changes in these features. \n", - "\n", - "**Example methods:** [Saliency Maps](https://arxiv.org/abs/1312.6034), [LIME](https://arxiv.org/abs/1602.04938)/[SHAP](https://arxiv.org/abs/1705.07874), Perturbations ([Input reduction](https://arxiv.org/abs/1804.07781), [Adversarial Perturbations](https://arxiv.org/abs/1712.06751))\n", - "\n", - "These methods can be further subdivided into two categories: gradient-based methods that rely on white-box model access to directly see the impact of changing a single input, and perturbation-based methods that manually perturb an input and re-query the model to see how the prediction changes. \n", - "\n", - "![Example saliency maps. The right 4 columns show the result of different saliency method techniques, where red dots indicate regions that are influential for predicting \"dog\" and blue dots indicate regions that are influential for predicting \"cat\". The image creators argue that their method, SmoothGrad, is most effective at mapping model behavior to images. (Image taken from [Smilkov et al.](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/1706.03825))](https://raw.githubusercontent.com/carpentries-incubator/fair-explainable-ml/main/images/e5-smoothgrad.png){alt='Two rows images (5 images per row). Leftmost column shows two different pictures, each containing a cat and a dog. Remaining columns show the saliency maps using different techniques (VanillaGrad, InteGrad, GuidedBackProp, and SmoothGrad). Each saliency map has red dots (indicated regions that are influential for predicting \"dog\") and blue dots (influential for predicting \"cat\"). All methods except GuidedBackProp have good overlap between the respective dots and where the animals appear in the image. SmoothGrad has the most precise mapping.'}\n", - "\n", - "**Pros and cons:**\n", - "These methods are fast to compute, and flexible in their use across models.\n", - "However, the insights gained from these methods are not actionable - knowing which part of the input caused the prediction does not highlight why that part caused it.\n", - "On finding issues in the prediction process, it is also hard to pick up on if there is an underlying issue in the model, or just the specific inputs tested on.\n", - "Relatedly, these methods can be unstable, and can even be [fooled by adversarial examples](https://proceedings.neurips.cc/paper_files/paper/2019/hash/7fea637fd6d02b8f0adf6f7dc36aed93-Abstract.html). 
\n", - "\n", - "\n", - "### Probing internal representations\n", - "\n", - "As the name suggests, this class of methods aims to probe the internals of a model, to discover what kind of information or knowledge is stored inside the model. \n", - "Probes are often administered to a specific component of the model, like a set of neurons or layers within a neural network. \n", - "\n", - "**Example methods:** [Probing classifiers](https://direct.mit.edu/coli/article/48/1/207/107571/Probing-Classifiers-Promises-Shortcomings-and), [Causal tracing](https://proceedings.neurips.cc/paper/2020/hash/92650b2e92217715fe312e6fa7b90d82-Abstract.html)\n", - "\n", - "![Example probe output. The image shows the result from probing three attention heads. We see that gender stereotypes are encoded into the model because the heads that are important for nurse and farmer change depending on the final pronoun. Specifically, Head 5-10 attends to the stereotypical gender assignment while Head 4-6 attends to the anti-stereotypical gender assignment. (Image taken from [Vig et al.](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/92650b2e92217715fe312e6fa7b90d82-Abstract.html))](https://raw.githubusercontent.com/carpentries-incubator/fair-explainable-ml/main/images/e5-probe.png){alt='The phrase \"The nurse examined the farmer for injuries because PRONOUN\" is shown twice, once with PRONOUN=she and once with PRONOUN=he. Each word is annotated with the importance of three different attention heads. The distribution of which heads are important with each pronoun differs for all words, but especially for nurse and farmer.'}\n", - "\n", - "**Pros and cons:**\n", - "Probes have shown that it is possible to find highly interpretable components in a complex model, e.g., MLP layers in transformers have been shown to store factual knowledge in a structured manner.\n", - "However, there is no systematic way of finding interpretable components, and many components may remain elusive to humans to understand.\n", - "Furthermore, the model components that have been shown to contain certain knowledge may not actually play a role in the model's prediction.\n", - "\n", - ":::: callout\n", - "\n", - "### Is that all?\n", - "Nope! We've discussed a few of the common explanation techniques, but many others exist. In particular, specialized model architectures often need their own explanation algorithms. \n", - "For instance, [Yuan et al.](https://ieeexplore.ieee.org/abstract/document/9875989?casa_token=BiFHRXv7_9gAAAAA:wPV-PXOpCLFg2g1qYgEQ7QF_LKZs32cOXEJBvwjK3z43sXeaGfvQ9e1QePW03MTLq4lrUsh4Jw) give an overview of different explanation techniques for graph neural networks (GNNs). \n", - "\n", - "::::::::::::\n", - "\n", - "\n", - ":::::::::::::::::::::::::::::::::::::: challenge\n", - "\n", - "### Classifying explanation techniques\n", - "\n", - "For each of the explanation techniques described above, discuss the following with a partner:\n", - "\n", - "* Does it require black-box or white-box model access? \n", - "* Are the explanations it provides global or local? \n", - "* Is the technique post-hoc or does it rely on inherent interpretability of the model?\n", - "\n", - "::::::::::::::::::::::::::::::::::::::::::::::::::\n", - "\n", - ":::::::::::::: solution\n", - "\n", - "### Solution\n", - "\n", - "| Approach | Post Hoc or Inherently Interpretable? | Local or Global? | White Box or Black Box? 
|\n", - "|--------------------------------------------------------------------------------------------------|---------------------------------------|------------------|-------------------------|\n", - "| [Diagnostic Testing](#diagnostic-testing) | Post Hoc | Global | Black Box |\n", - "| [Baking interpretability into models](#baking-interpretability-into-models) | Inherently Interpretable | Local | White Box |\n", - "| [Identifying Decision Rules of the Model](#identifying-decision-rules-of-the-model) | Post Hoc | Both | White Box | \n", - "| [Visualizing model weights or representations](#visualizing-model-weights-or-representations) | Post Hoc | Global | White Box |\n", - "| [Understanding the impact of training examples](#understanding-the-impact-of-training-examples) | Post Hoc | Local | White Box |\n", - "| [Understanding the impact of a single example](#understanding-the-impact-of-a-single-example) | Post Hoc | Local | Both |\n", - "| [Probing internal representations of a model](#probing-internal-representations) | Post Hoc | Global/Local | White Box |\n", - "\n", - "\n", - ":::::::::::::::::::::::::\n", - "\n", - "What explanation should you use when? There is no simple answer, as it depends upon your goals (i.e., why you need an explanation), who the audience is, the model architecture, and the availability of model internals (e.g., there is no white-box access to ChatGPT unless you work for Open AI!). The next exercise asks you to consider different scenarios and discuss what explanation techniques are appropriate.\n", - "\n", - "\n", - ":::::::::::::::::::::::::::::::::::::: challenge\n", - "\n", - "Think about the following scenarios and suggest which explainability method would be most appropriate to use, and what information could be gained from that method. Furthermore, think about the limitations of your findings.\n", - "\n", - "_Note:_ These are open-ended questions, and there is no correct answer. Feel free to break into discussion groups to discuss the scenarios.\n", - "\n", - "[//]: # ([These are open-ended questions. Participants are encouraged to discuss if current explainability methods are sufficient to provide guidance.])\n", - "[//]: # (Given a series of scenarios, suggest which explainability method would be most appropriate to use, and what information could be gained from that method. Furthermore, highlight the limitations of your findings. )\n", - "[//]: # ([Use responses from survey for scenarios])\n", - "[//]: # ([Vision + Text + Tabular examples])\n", - "\n", - "**Scenario 1**: Suppose that you are an ML engineer working at a tech company. A fast-food chain company consults with you about sentimental analysis based on feedback they collected on Yelp and their survey. You use an open sourced LLM such as Llama-2 and finetune it on the review text data. The fast-food company asks to provide explanations for the model: \n", - "Is there any outlier review? How does each review in the data affect the finetuned model?\n", - "Which part of the language in the review indicates that a customer likes or dislikes the food? Can you score the food quality according to the reviews?\n", - "Does the review show a trend over time? What item is gaining popularity or losing popularity?\n", - "Q: Can you suggest a few explainability methods that may be useful for answering these questions?\n", - "\n", - "[//]: # ([These are open-ended questions. 
Participants are encouraged to discuss if current explainability methods are sufficient to provide guidance.])\n", - "\n", - "**Scenario 2**: Suppose that you are a radiologist who analyzes medical images of patients with the help of machine learning models. You use black-box models (e.g., CNNs, Vision Transformers) to complement human expertise and get useful information before making high-stake decisions. \n", - "Which areas of a medical image most likely explains the output of a black-box? \n", - "Can we visualize and understand what features are captured by the intermediate components of the black-box models?\n", - "How do we know if there is a distribution shift? How can we tell if an image is an out-of-distribution example?\n", - "Q: Can you suggest a few explainability methods that may be useful for answering these questions?\n", - "\n", - "**Scenario 3**: Suppose that you work on genomics and you just collected samples of single-cell data into a table: each row records gene expression levels, and each column represents a single cell. You are interested in scientific hypotheses about evolution of cells. You believe that only a few genes are playing a role in your study. \n", - "What exploratory data analysis techniques would you use to examine the dataset?\n", - "How do you check whether there are potential outliers, irregularities in the dataset?\n", - "You believe that only a few genes are playing a role in your study. What can you do to find the set of most explanatory genes?\n", - "How do you know if there is clustering, and if there is a trajectory of changes in the cells? \n", - "Q: Can you explain the decisions you make for each method you use?\n", - "\n", - "::::::::::::::::::::::::::::::::::::::::::::::::::\n", - "\n", - "\n", - "## Summary\n", - "\n", - "There are many available explanation techniques and they differ along three dimensions: model access (white-box or black-box), explanation scope (global or local), and approach (inherently interpretable or post-hoc). There's often no objectively-right answer of which explanation technique to use in a given situation, as the different methods have different tradeoffs. \n", - "\n", - "\n", - "### References and Further Reading\n", - "\n", - "This lesson provides a gentle overview into the world of explainability methods. If you'd like to know more, here are some resources to get you started:\n", - "\n", - "- Tutorials on Explainability:\n", - " - [Wallace, E., Gardner, M., & Singh, S. (2020, November). Interpreting predictions of NLP models. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Tutorial Abstracts (pp. 20-23).](https://github.com/Eric-Wallace/interpretability-tutorial-emnlp2020/blob/master/tutorial_slides.pdf)\n", - " - [Lakkaraju, H., Adebayo, J., & Singh, S. (2020). Explaining machine learning predictions: State-of-the-art, challenges, and opportunities. NeurIPS Tutorial.](https://explainml-tutorial.github.io/aaai21)\n", - " - [Belinkov, Y., Gehrmann, S., & Pavlick, E. (2020, July). Interpretability and analysis in neural NLP. In Proceedings of the 58th annual meeting of the association for computational linguistics: tutorial abstracts (pp. 1-5).](https://sebastiangehrmann.github.io/assets/files/acl_2020_interpretability_tutorial.pdf)\n", - "- Research papers:\n", - " - [Holtzman, A., West, P., & Zettlemoyer, L. (2023). Generative Models as a Complex Systems Science: How can we make sense of large language model behavior?. 
arXiv preprint arXiv:2308.00189.](https://arxiv.org/abs/2308.00189)" - ] - } - ], - "metadata": {}, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/code/5b-deep-dive-into-methods.ipynb b/code/5b-deep-dive-into-methods.ipynb deleted file mode 100644 index 765d0e74..00000000 --- a/code/5b-deep-dive-into-methods.ipynb +++ /dev/null @@ -1,148 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "3f877931", - "metadata": {}, - "source": [ - "# Explainability methods: deep dive" - ] - }, - { - "cell_type": "markdown", - "id": "c192ad91", - "metadata": {}, - "source": [ - ":::::::::::::::::::::::::::::::::::::: questions \n", - "\n", - "\n", - "- How can we identify which parts of an input contribute most to a model’s prediction? \n", - "- What insights can saliency maps, GradCAM, and similar techniques provide about model behavior? \n", - "- What are the strengths and limitations of gradient-based explainability methods? \n", - "- How can probing classifiers help us understand what a model has learned? \n", - "- What are the limitations of probing classifiers, and how can they be addressed? \n", - "\n", - "::::::::::::::::::::::::::::::::::::::::::::::::\n", - "\n", - "::::::::::::::::::::::::::::::::::::: objectives\n", - "\n", - "- Explain how saliency maps and GradCAM work and their applications in understanding model predictions. \n", - "- Introduce GradCAM as a method to visualize the important features used by a model. \n", - "- Understand the concept of probing classifiers and how they assess the representations learned by models. \n", - "\n", - "::::::::::::::::::::::::::::::::::::::::::::::::\n", - "\n", - "## A Deep Dive into Methods for Understanding Model Behaviour\n", - "\n", - "In the previous section, we scratched the surface of explainability methods, introducing you to the broad classes of methods designed to understand different aspects of a model's behavior.\n", - "\n", - "Now, we will dive deeper into two widely used methods, each one which answers one key question: \n", - "\n", - "## What part of my input causes this prediction?\n", - "\n", - "When a model makes a prediction, we often want to know which parts of the input were most important in generating that prediction.\n", - "This helps confirm if the model is making its predictions for the right reasons. \n", - "Sometimes, models use features totally unrelated to the task for their prediction - these are known as 'spurious correlations'.\n", - "For example, a model might predict that a picture contains a dog because it was taken in a park, and not because there is actually a dog in the picture.\n", - "\n", - "**[Saliency Maps](https://arxiv.org/abs/1312.6034)** are among the most simple and popular methods used towards this end. 
\n", - "We will be working with a more sophisticated version of this method, known as **[GradCAM](https://arxiv.org/abs/1610.02391)**.\n", - "\n", - "#### Method and Examples\n", - "\n", - "A saliency map is a kind of visualization - it is a heatmap across the input that shows which parts of the input are most important in generating the model's prediction.\n", - "They can be calculated using the gradients of a neural network, or by perturbing the input to any ML model and observing how the model reacts to these perturbations.\n", - "The key intuition is that if a small change in a part of the input causes a large change in the model's prediction, then that part of the input is important for the prediction.\n", - "Gradients are useful in this because they provide a signal towards how much the model's prediction would change if the input was changed slightly.\n", - "\n", - "For example, in an image classification task, a saliency map can be used to highlight the parts of the image that the model is focusing on to make its prediction.\n", - "In a text classification task, a saliency map can be used to highlight the words or phrases that are most important for the model's prediction.\n", - "\n", - "GradCAM is an extension of this idea, which uses the gradients of the final layer of a convolutional neural network to generate a heatmap that highlights the important regions of an image.\n", - "This heatmap can be overlaid on the original image to visualize which parts of the image are most important for the model's prediction.\n", - "\n", - "Other variants of this method include [Integrated Gradients](https://arxiv.org/abs/1703.01365), [SmoothGrad](https://arxiv.org/pdf/1806.03000), and others, which are designed to provide more robust and reliable explanations for model predictions.\n", - "However, GradCAM is a good starting point for understanding how saliency maps work, and is a popularly used approach.\n", - "\n", - "Alternative approaches, which may not directly generate heatmaps, include [LIME](https://arxiv.org/abs/1602.04938) and [SHAP](https://arxiv.org/abs/1705.07874), which are also popular and recommended for further reading. \n", - "\n", - "#### Limitations and Extensions\n", - "\n", - "Gradient based saliency methods like GradCam are fast to compute, requiring only a handful of backpropagation steps on the model to generate the heatmap.\n", - "The method is also model-agnostic, meaning it can be applied to any model that can be trained using gradient descent.\n", - "Additionally, the results obtained from these methods are intuitive and easy to understand, making them useful for explaining model predictions to non-experts.\n", - "\n", - "However, their use is limited to models that can be trained using gradient descent, and have white-box access. 
\n", - "It is also difficult to apply these methods to tasks beyond classification, making their application limited with many recent\n", - "generative models (think LLMs).\n", - "\n", - "Another limitation is that the insights gained from these methods are not actionable - knowing which part of the input caused the prediction does not highlight why that part caused it.\n", - "On finding issues in the prediction process, it is also hard to pick up on if there is an underlying issue in the model, or just the specific inputs tested on.\n", - "\n", - "\n", - "## What part of my model causes this prediction?\n", - "\n", - "When a model makes a correct prediction on a task it has been trained on (known as a 'downstream task'), \n", - "**[Probing classifiers](https://direct.mit.edu/coli/article/48/1/207/107571/Probing-Classifiers-Promises-Shortcomings-and)** can be used to identify if the model actually contains the relevant information or knowledge required \n", - "to make that prediction, or if it is just making a lucky guess.\n", - "Furthermore, probes can be used to identify the specific components of the model that contain this relevant information, \n", - "providing crucial insights for developing better models over time.\n", - "\n", - "#### Method and Examples\n", - "\n", - "A neural network takes its input as a series of vectors, or representations, and transforms them through a series of layers to produce an output.\n", - "The job of the main body of the neural network is to develop representations that are as useful for the downstream task as possible, \n", - "so that the final few layers of the network can make a good prediction.\n", - "\n", - "This essentially means that a good quality representation is one that _already_ contains all the information required to make a good prediction. \n", - "In other words, the features or representations from the model are easily separable by a simple classifier. And that classifier is what we call \n", - "a 'probe'. A probe is a simple model that uses the representations of the model as input, and tries to learn the downstream task from them.\n", - "The probe itself is designed to be too easy to learn the task on its own. This means, that the only way the probe get perform well on this task is if \n", - "the representations it is given are already good enough to make the prediction.\n", - "\n", - "These representations can be taken from any part of the model. Generally, using representations from the last layer of a neural network help identify if\n", - "the model even contains the information to make predictions for the downstream task. \n", - "However, this can be extended further: probing the representations from different layers of the model can help identify where in the model the\n", - "information is stored, and how it is transformed through the model.\n", - "\n", - "Probes have been frequently used in the domain of NLP, where they have been used to check if language models contain certain kinds of linguistic information. \n", - "These probes can be designed with varying levels of complexity. 
For example, simple probes have shown language models to contain information \n", - "about simple syntactical features like [Part of Speech tags](https://aclanthology.org/D15-1246.pdf), and more complex probes have shown models to contain entire [Parse trees](https://aclanthology.org/N19-1419.pdf) of sentences.\n", - "\n", - "#### Limitations and Extensions\n", - "\n", - "One large challenge in using probes is identifying the correct architectural design of the probe. Too simple, and \n", - "it may not be able to learn the downstream task at all. Too complex, and it may be able to learn the task even if the \n", - "model does not contain the information required to make the prediction.\n", - "\n", - "Another large limitation is that even if a probe is able to learn the downstream task, it does not mean that the model\n", - "is actually using the information contained in the representations to make the prediction. \n", - "So essentially, a probe can only tell us if a part of the model _can_ make the prediction, not if it _does_ make the prediction.\n", - "\n", - "A new approach known as **[Causal Tracing](https://proceedings.neurips.cc/paper/2020/hash/92650b2e92217715fe312e6fa7b90d82-Abstract.html)** \n", - "addresses this limitation. The objective of this approach is similar to probes: attempting to understand which part of a model contains \n", - "information relevant to a downstream task. The approach involves iterating through all parts of the model being examined (e.g. all layers\n", - "of a model), and disrupting the information flow through that part of the model. (This could be as easy as adding some kind of noise on top of the \n", - "weights of that model component). If the model performance on the downstream task suddenly drops on disrupting a specific model component, \n", - "we know for sure that that component not only contains the information required to make the prediction, but that the model is actually using that\n", - "information to make the prediction.\n", - "\n", - "\n", - ":::::::::::::::::::::::::::::::::::::: challenge\n", - "\n", - "Now, it's time to try implementing these methods yourself! Pick one of the following problems to work on:\n", - "\n", - "- [Train your own linear probe to check if BERT stores the required knowledge for sentiment analysis.](https://carpentries-incubator.github.io/fair-explainable-ml/5c-probes.html)\n", - "- [Use GradCAM on a trained model to check if the model is using the right features to make predictions.](https://carpentries-incubator.github.io/fair-explainable-ml/5d-gradcam.html)\n", - "\n", - "It's time to get your hands dirty now. 
Good luck, and have fun!\n", - "\n", - "\n", - "::::::::::::::::::::::::::::::::::::::::::::::::::" - ] - } - ], - "metadata": {}, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/code/5c-probes.ipynb b/code/5c-probes.ipynb index af92ac74..b6e15c66 100644 --- a/code/5c-probes.ipynb +++ b/code/5c-probes.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "d098f24e", + "id": "8f22377e", "metadata": {}, "source": [ "# Explainability methods: linear probe" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "b870b01b", + "id": "3f7a541d", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::: questions \n", @@ -29,7 +29,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d5244974", + "id": "bdd101f9", "metadata": {}, "outputs": [], "source": [ @@ -51,7 +51,7 @@ }, { "cell_type": "markdown", - "id": "b9b6474b", + "id": "b01bfbad", "metadata": {}, "source": [ "Now, let's set the random seed to ensure reproducibility. Setting random seeds is like setting a starting point for your machine learning adventure. It ensures that every time you train your model, it starts from the same place, using the same random numbers, making your results consistent and comparable." @@ -60,7 +60,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8a61c064", + "id": "3bfcd804", "metadata": {}, "outputs": [], "source": [ @@ -70,7 +70,7 @@ }, { "cell_type": "markdown", - "id": "069a82ba", + "id": "d80250f6", "metadata": {}, "source": [ "##### Loading the Dataset\n", @@ -81,7 +81,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe8ca9a8", + "id": "8eb08128", "metadata": {}, "outputs": [], "source": [ @@ -108,7 +108,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bb2f2751", + "id": "fd768987", "metadata": {}, "outputs": [], "source": [ @@ -117,7 +117,7 @@ }, { "cell_type": "markdown", - "id": "fc2bd075", + "id": "9c63eb60", "metadata": {}, "source": [ "##### Loading the Model\n", @@ -132,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "efde1f9d", + "id": "6fc272aa", "metadata": {}, "outputs": [], "source": [ @@ -155,7 +155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9825c2a2", + "id": "46aad0a5", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +165,7 @@ }, { "cell_type": "markdown", - "id": "7859a2c2", + "id": "7a5f2539", "metadata": {}, "source": [ "Let's see what the model's architecture looks like. How many layers does it have?" @@ -174,7 +174,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f4c2beb", + "id": "53e9764a", "metadata": {}, "outputs": [], "source": [ @@ -183,7 +183,7 @@ }, { "cell_type": "markdown", - "id": "02d3317f", + "id": "c895dae9", "metadata": {}, "source": [ "Let's see if your answer matches the actual number of layers in the model." 
@@ -192,7 +192,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7474a7a2", + "id": "69e55967", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "38c22a12", + "id": "cdef0b95", "metadata": {}, "source": [ "##### Setting up the Probe\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8205c177", + "id": "56f142de", "metadata": {}, "outputs": [], "source": [ @@ -260,7 +260,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ec07bb2e", + "id": "cc2a7a6b", "metadata": {}, "outputs": [], "source": [ @@ -306,7 +306,7 @@ }, { "cell_type": "markdown", - "id": "e132be2c", + "id": "ade6219c", "metadata": {}, "source": [ "Now, it's finally time to define our probe! We set this up as a class, where the probe itself is an object of this class. \n", @@ -318,7 +318,7 @@ { "cell_type": "code", "execution_count": null, - "id": "78a2d7f4", + "id": "0ce14e70", "metadata": {}, "outputs": [], "source": [ @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d367d0f4", + "id": "c1a7d0dd", "metadata": {}, "outputs": [], "source": [ @@ -480,7 +480,7 @@ }, { "cell_type": "markdown", - "id": "5baabd5a", + "id": "b7b861b5", "metadata": {}, "source": [ "##### Analysing the model using Probes\n", @@ -492,7 +492,7 @@ { "cell_type": "code", "execution_count": null, - "id": "027d56ad", + "id": "e8b16cb4", "metadata": {}, "outputs": [], "source": [ @@ -529,7 +529,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ea3b0248", + "id": "65c7f01d", "metadata": {}, "outputs": [], "source": [ @@ -544,7 +544,7 @@ }, { "cell_type": "markdown", - "id": "47b8508b", + "id": "78c7faca", "metadata": {}, "source": [ "Which layer has the best accuracy? What does this tell us about the model?\n", @@ -557,7 +557,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85d8720c", + "id": "b1bc5534", "metadata": {}, "outputs": [], "source": [ diff --git a/code/5d-gradcam.ipynb b/code/5d-gradcam.ipynb index aaa55069..9411478d 100644 --- a/code/5d-gradcam.ipynb +++ b/code/5d-gradcam.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2100f187", + "id": "41914176", "metadata": {}, "source": [ "# Explainability methods: GradCAM" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "75471ec5", + "id": "f5810acb", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::: questions \n", @@ -29,7 +29,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3fd5f019", + "id": "e08230fd", "metadata": {}, "outputs": [], "source": [ @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0e19cdc1", + "id": "266ccc02", "metadata": {}, "outputs": [], "source": [ @@ -48,10 +48,10 @@ "import requests\n", "\n", "# Packages to view and process images\n", + "import matplotlib.pyplot as plt\n", "import cv2\n", "import numpy as np\n", "from PIL import Image\n", - "from google.colab.patches import cv2_imshow\n", "\n", "# Packages to load the model\n", "import torch\n", @@ -66,16 +66,16 @@ { "cell_type": "code", "execution_count": null, - "id": "66215f28", + "id": "5d98d540", "metadata": {}, "outputs": [], "source": [ - "device = 'gpu' if torch.cuda.is_available() else 'cpu'" + "device = 'cpu' # we're using the CPU only version of this workshop" ] }, { "cell_type": "markdown", - "id": "1cf9ba5c", + "id": "b5210e0e", "metadata": {}, "source": [ "##### Load Model\n", @@ -87,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0535ef84", + "id": "6af6a4e0", 
"metadata": {}, "outputs": [], "source": [ @@ -96,7 +96,7 @@ }, { "cell_type": "markdown", - "id": "a295513e", + "id": "87c870bb", "metadata": {}, "source": [ "##### Load Test Image" @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f4580aed", + "id": "dc0fab50", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +118,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b6f880ce", + "id": "c666cc65", "metadata": {}, "outputs": [], "source": [ @@ -136,7 +136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39604f0b", + "id": "ee270465", "metadata": {}, "outputs": [], "source": [ @@ -145,7 +145,7 @@ }, { "cell_type": "markdown", - "id": "3ea98e40", + "id": "a2935599", "metadata": {}, "source": [ "### Grad-CAM Time!" @@ -154,7 +154,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28988373", + "id": "bfe54323", "metadata": {}, "outputs": [], "source": [ @@ -166,7 +166,7 @@ }, { "cell_type": "markdown", - "id": "b203ca69", + "id": "10d918e2", "metadata": {}, "source": [ "Here we want to interpret what the model as a whole is doing (not what a specific layer is doing).\n", @@ -179,7 +179,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f6ff250", + "id": "b6875c4b", "metadata": {}, "outputs": [], "source": [ @@ -188,7 +188,7 @@ }, { "cell_type": "markdown", - "id": "f38711b7", + "id": "077d0953", "metadata": {}, "source": [ "We also want to pick a label for the CAM - this is the class we want to visualize the activation for.\n", @@ -200,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b74179b1", + "id": "5b725fc8", "metadata": {}, "outputs": [], "source": [ @@ -212,7 +212,7 @@ }, { "cell_type": "markdown", - "id": "0dc504a7", + "id": "fe0c3d4a", "metadata": {}, "source": [ "Well, that's a lot! To simplify things, we have already picked out the indices of a few interesting classes.\n", @@ -231,7 +231,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a556986", + "id": "cd43e1e1", "metadata": {}, "outputs": [], "source": [ @@ -242,33 +242,34 @@ { "cell_type": "code", "execution_count": null, - "id": "ef9e892d", + "id": "62bc4816", "metadata": {}, "outputs": [], "source": [ "def viz_gradcam(model, target_layers, class_id):\n", "\n", - " if class_id is None:\n", - " targets = None\n", - " else:\n", - " targets = [ClassifierOutputTarget(class_id)]\n", + " if class_id is None:\n", + " targets = None\n", + " else:\n", + " targets = [ClassifierOutputTarget(class_id)]\n", "\n", - " cam_algorithm = GradCAM\n", - " with cam_algorithm(model=model, target_layers=target_layers) as cam:\n", - " grayscale_cam = cam(input_tensor=input_tensor,\n", - " targets=targets)\n", + " cam_algorithm = GradCAM\n", + " with cam_algorithm(model=model, target_layers=target_layers) as cam:\n", + " grayscale_cam = cam(input_tensor=input_tensor, targets=targets)\n", "\n", - " grayscale_cam = grayscale_cam[0, :]\n", + " grayscale_cam = grayscale_cam[0, :]\n", "\n", - " cam_image = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)\n", - " cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)\n", + " cam_image = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)\n", + " cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)\n", "\n", - " cv2_imshow(cam_image)" + " plt.imshow(cam_image)\n", + " plt.axis(\"off\")\n", + " plt.show()" ] }, { "cell_type": "markdown", - "id": "3fc9950d", + "id": "692bcb33", "metadata": {}, "source": [ "Finally, we can start visualizing! 
Let's begin by seeing what parts of the image the model looks at to make its most confident prediction." @@ -277,7 +278,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47afa122", + "id": "e576d877", "metadata": {}, "outputs": [], "source": [ @@ -286,7 +287,7 @@ }, { "cell_type": "markdown", - "id": "927c61d1", + "id": "fe1e51de", "metadata": {}, "source": [ "Interesting, it looks like the model totally ignores the cat and makes a prediction based on the dog.\n", @@ -298,7 +299,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2fe8f7ab", + "id": "684ffbdc", "metadata": {}, "outputs": [], "source": [ @@ -307,7 +308,7 @@ }, { "cell_type": "markdown", - "id": "8d85ede7", + "id": "647cd559", "metadata": {}, "source": [ "The model is indeed looking at the cat when asked to predict the class \"Tabby Cat\" (`class_id=281`)!\n", @@ -320,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5eaab5d", + "id": "38249114", "metadata": {}, "outputs": [], "source": [ @@ -329,7 +330,7 @@ }, { "cell_type": "markdown", - "id": "c427626b", + "id": "ac230875", "metadata": {}, "source": [ "It can! However, it seems to also think of the shelf behind the dog as a door.\n", @@ -340,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff38c4d7", + "id": "cda15914", "metadata": {}, "outputs": [], "source": [ @@ -349,7 +350,7 @@ }, { "cell_type": "markdown", - "id": "6db312f2", + "id": "20ab566f", "metadata": {}, "source": [ "Looks like our analysis has revealed a shortcoming of the model! It seems to perceive cats and street signs similarly.\n", @@ -360,7 +361,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65a0f812", + "id": "c36d61c3", "metadata": {}, "outputs": [], "source": [ @@ -369,7 +370,7 @@ }, { "cell_type": "markdown", - "id": "1a0381d8", + "id": "37d892d6", "metadata": {}, "source": [ "Explaining model predictions through visualization techniques like this can be very subjective and prone to error. However, this still provides some degree of insight that a completely black box model would not provide.\n", diff --git a/code/7a-OOD-detection-overview.ipynb b/code/7a-OOD-detection-overview.ipynb index 8bf08031..d2dcab22 100644 --- a/code/7a-OOD-detection-overview.ipynb +++ b/code/7a-OOD-detection-overview.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "ca1b45be", + "id": "c49105b1", "metadata": {}, "source": [ "# OOD detection: overview" ] }, { "cell_type": "markdown", - "id": "f24230d0", + "id": "cdaf8000", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::::: questions\n", @@ -38,19 +38,58 @@ "## How OOD data manifests in ML pipelines\n", "The difference between in-distribution (ID) and OOD data can arise from:\n", "\n", - "- **Semantic shift**: The OOD sample belongs to a class that was not present during training.\n", - "- **Covariate shift**: The OOD sample comes from a domain where the input feature distribution is drastically different from the training data.\n", + "- **Semantic shift**: The OOD sample belongs to a class that was not present during training (classification). With continuous prediction/regression, semantic shift occurs when the underlying relationship between X and Y changes. \n", + "- **Covariate shift**: The OOD sample comes from a domain where the input feature distribution is drastically different from the training data. 
The input feature distribution changes, but the underlying relationship between X and Y stays the same.\n", + "\n", + "Semantic shift often co-occurs with covariate shift.\n", + "\n", + ":::::::::::::::::::::::::::::::::::::: challenge\n", + "\n", + "### Distinguishing Semantic Shift vs. Covariate Shift\n", + "\n", + "You trained a model using the CIFAR-10 dataset to classify images into 10 classes (e.g., airplanes, dogs, trucks). Now, you deploy the model to classify images found on the internet. Consider the following scenarios and classify each as **Semantic Shift**, **Covariate Shift**, or **Both**. Provide reasoning for your choice.\n", + "\n", + "1. **Scenario A**: The internet dataset contains images of drones, which were not present in the CIFAR-10 dataset. The model struggles to classify them.\n", + " \n", + "2. **Scenario B**: The internet dataset has dog images, but these dogs are primarily captured in outdoor settings with unfamiliar backgrounds and lighting conditions compared to the training data.\n", + " \n", + "3. **Scenario C**: The internet dataset contains images of hybrid animals (e.g., \"wolf-dogs\") that do not belong to any CIFAR-10 class. The model predicts incorrectly.\n", + "\n", + "4. **Scenario D**: The internet dataset includes high-resolution images of airplanes, while the CIFAR-10 dataset contains only low-resolution airplane images. The model performs poorly on these new airplane images.\n", + "\n", + "5. **Scenario E**: A researcher retrains the CIFAR-10 model using an updated dataset where labels for \"trucks\" are now redefined to include pickup trucks, which were previously excluded. The new labels confuse the original model.\n", + "\n", + "::::::::::::::::::::::::::::::::::::::::::::::::::\n", + "\n", + ":::::::::::::::::::::::::::::::::::::: solution\n", + "\n", + "1. **Scenario A**: **Semantic Shift** \n", + " - Drones represent a new class not seen during training, so the model encounters a semantic shift.\n", + "\n", + "2. **Scenario B**: **Covariate Shift** \n", + " - The distribution of input features (e.g., lighting, background) changes, but the semantic relationship (e.g., dogs are still dogs) remains intact.\n", + "\n", + "3. **Scenario C**: **Both** \n", + " - Hybrid animals represent a semantic shift (new class), and unfamiliar feature distributions (e.g., traits of wolves and dogs combined) also introduce covariate shift.\n", + "\n", + "4. **Scenario D**: **Covariate Shift** \n", + " - The resolution of the images (input features) changes, but the semantic class of airplanes remains consistent.\n", + "\n", + "5. **Scenario E**: **Semantic Shift** \n", + " - The relationship between input features and class labels has changed, as the definition of the \"truck\" class has been altered.\n", + "\n", + "::::::::::::::::::::::::::::::::::::::::::::::::::\n", "\n", "## Why does OOD data matter?\n", "Models trained on a specific distribution might make incorrect predictions on OOD data, leading to unreliable outputs. In critical applications (e.g., healthcare, autonomous driving), encountering OOD data without proper handling can have severe consequences.\n", "\n", "### Ex1: Tesla crashes into jet\n", - "In April 2022, a [Tesla Model Y crashed into a $3.5 million private jet](https://www.newsweek.com/video-tesla-smart-summon-mode-ramming-3m-jet-viewed-34m-times-1700310 ) at an aviation trade show in Spokane, Washington, while operating on the \"Smart Summon\" feature. 
The feature allows Tesla vehicles to autonomously navigate parking lots to their owners, but in this case, it resulted in a significant mishap.\n", - "- The Tesla was summoned by its owner using the Tesla app, which requires holding down a button to keep the car moving. The car continued to move forward even after making contact with the jet, pushing the expensive aircraft and causing notable damage.\n", - "- The crash highlighted several issues with Tesla's Smart Summon feature, particularly its object detection capabilities. The system failed to recognize and appropriately react to the presence of the jet, a problem that has been observed in other scenarios where the car's sensors struggle with objects that are lifted off the ground or have unusual shapes.\n", + "In April 2022, a [Tesla Model Y crashed into a $3.5 million private jet](https://www.newsweek.com/video-tesla-smart-summon-mode-ramming-3m-jet-viewed-34m-times-1700310 ) at an aviation trade show in Spokane, Washington, while operating on the \"Smart Summon\" feature. The feature allows Tesla vehicles to \"autonomously\" navigate parking lots to their owners, but in this case, it resulted in a significant mishap. The car continued to move forward even after making contact with the jet, pushing the expensive aircraft and causing notable damage. \n", + "\n", + "The crash highlighted several issues with Tesla's Smart Summon feature, particularly its object detection capabilities. The system failed to recognize and appropriately react to the presence of the jet, a problem that has been observed in other scenarios where the car's sensors struggle with objects that are lifted off the ground or have unusual shapes.\n", "\n", "### Ex2: IBM Watson for Oncology\n", - "Around a decade ago, the excitement surrounding AI in healthcare often exceeded its actual capabilities. In 2016, IBM launched Watson for Oncology, an AI-powered platform for treatment recommendations, to much public enthusiasm. However, it soon became apparent that the system was both costly and unreliable, frequently generating flawed advice while operating as an opaque \"black box. IBM Watson for Oncology faced several issues due to OOD data. The system was primarily trained on data from Memorial Sloan Kettering Cancer Center (MSK), which did not generalize well to other healthcare settings. This led to the following problems:\n", + "Around a decade ago, the excitement surrounding AI in healthcare often exceeded its actual capabilities. In 2016, IBM launched Watson for Oncology, an AI-powered platform for treatment recommendations, to much public enthusiasm. However, it soon became apparent that the system was both costly and unreliable, frequently generating flawed advice while operating as an opaque \"black box\". IBM Watson for Oncology faced several issues due to OOD data. The system was primarily trained on data from Memorial Sloan Kettering Cancer Center (MSK), which did not generalize well to other healthcare settings. This led to:\n", "\n", "1. Unsafe recommendations: Watson for Oncology provided treatment recommendations that were not safe or aligned with standard care guidelines in many cases outside of MSK. This happened because the training data was not representative of the diverse medical practices and patient populations in different regions\n", "2. Bias in training data: The system's recommendations were biased towards the practices at MSK, failing to account for different treatment protocols and patient needs elsewhere. 
This bias is a classic example of an OOD issue, where the model encounters data (patients and treatments) during deployment that significantly differ from its training data\n", diff --git a/code/7b-OOD-detection-softmax.ipynb b/code/7b-OOD-detection-softmax.ipynb index ba75d31f..9b8dbed4 100644 --- a/code/7b-OOD-detection-softmax.ipynb +++ b/code/7b-OOD-detection-softmax.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e1848f26", + "id": "81889ab0", "metadata": {}, "source": [ "# OOD detection: softmax" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "bf3ffe32", + "id": "70bb1443", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::::: questions\n", @@ -37,7 +37,7 @@ "# Example 1: Softmax scores\n", "Softmax-based methods are among the most widely used techniques for out-of-distribution (OOD) detection, leveraging the probabilistic outputs of a model to differentiate between in-distribution (ID) and OOD data. These methods are inherently tied to models employing a softmax activation function in their final layer, such as logistic regression or neural networks with a classification output layer. Softmax produces a normalized probability distribution across the classes, which can then be thresholded to identify OOD instances.\n", "\n", - "In this lesson, we will train a logistic regression model to classify images from the Fashion MNIST dataset and explore how its softmax outputs can signal whether a given input belongs to the ID classes (e.g., T-shirts or pants) or is OOD (e.g., sandals). While softmax is most naturally applied in models with a logistic activation, alternative approaches, such as applying softmax-like operations post hoc to models with different architectures, are occasionally used. However, these alternatives are less common and may require additional considerations, such as recalibrating the output scores.\n", + "In this lesson, we will train a logistic regression model to classify images from the Fashion MNIST dataset and explore how its softmax outputs can signal whether a given input belongs to the ID classes (e.g., T-shirts or pants) or is OOD (e.g., sandals). While softmax is most naturally applied in models with a logistic activation, alternative approaches, such as applying softmax-like operations post hoc to models with different architectures, are occasionally used. However, these alternatives are less common and may require additional considerations.\n", "\n", "By focusing on logistic regression, we aim to illustrate the fundamental principles of softmax-based OOD detection in a simple and interpretable context before extending these ideas to more complex architectures." 
] @@ -45,7 +45,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dec1102d", + "id": "126fbf1f", "metadata": {}, "outputs": [], "source": [ @@ -58,7 +58,7 @@ }, { "cell_type": "markdown", - "id": "3104f274", + "id": "9c1b7be0", "metadata": {}, "source": [ "### Prepare the ID (train and test) and OOD data\n", @@ -69,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3eaf6cc", + "id": "00cea03c", "metadata": {}, "outputs": [], "source": [ @@ -133,19 +133,19 @@ { "cell_type": "code", "execution_count": null, - "id": "78182985", + "id": "9187394f", "metadata": {}, "outputs": [], "source": [ "train_data, test_data, ood_data, train_labels, test_labels, ood_labels = prep_ID_OOD_datasests([0,1], [5])\n", "fig = plot_data_sample(train_data, ood_data)\n", - "fig.savefig('../images/OOD-detection_image-data-preview.png', dpi=300, bbox_inches='tight')\n", + "#fig.savefig('../images/OOD-detection_image-data-preview.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "759dfc2f", + "id": "ada05288", "metadata": {}, "source": [ "![Preview of image dataset](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_image-data-preview.png)\n", @@ -161,7 +161,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b455c9d7", + "id": "99f9990b", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c8e3cedf", + "id": "c5caec51", "metadata": {}, "outputs": [], "source": [ @@ -201,13 +201,13 @@ "plt.xlabel('First Principal Component')\n", "plt.ylabel('Second Principal Component')\n", "plt.title('PCA of In-Distribution and OOD Data')\n", - "plt.savefig('../images/OOD-detection_PCA-image-dataset.png', dpi=300, bbox_inches='tight')\n", + "#plt.savefig('../images/OOD-detection_PCA-image-dataset.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "125d6526", + "id": "66b7b614", "metadata": {}, "source": [ "![PCA visualization](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_PCA-image-dataset.png)\n", @@ -221,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c116dce", + "id": "d337b4a6", "metadata": {}, "outputs": [], "source": [ @@ -231,7 +231,7 @@ }, { "cell_type": "markdown", - "id": "fa334626", + "id": "32f5ef0e", "metadata": {}, "source": [ "Before we worry about the impact of OOD data, let's first verify that we have a reasonably accurate model for the ID data." 
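The softmax-threshold idea this episode builds on can be summarized in a few lines. The sketch below is a minimal illustration, assuming a fitted scikit-learn `LogisticRegression` named `model` and the `test_data` (ID) and `ood_data` (OOD) arrays already flattened to shape `(n_samples, n_features)`; the variable names and the example threshold of 0.9 are illustrative, not the notebook's exact code.

```python
import numpy as np

def max_softmax_scores(model, X):
    """Return the highest predicted class probability for each sample."""
    proba = model.predict_proba(X)   # shape: (n_samples, n_classes)
    return proba.max(axis=1)         # confidence of the predicted class

id_scores = max_softmax_scores(model, test_data)
ood_scores = max_softmax_scores(model, ood_data)

threshold = 0.9  # example value; the episode later tunes this on a threshold sweep
id_flag_rate = (id_scores < threshold).mean()    # ID samples wrongly flagged as OOD
ood_flag_rate = (ood_scores < threshold).mean()  # OOD samples correctly flagged
print(f"ID flagged as OOD: {id_flag_rate:.2%}; OOD flagged as OOD: {ood_flag_rate:.2%}")
```

Samples whose top softmax probability falls below the threshold are treated as OOD; the histograms, confusion matrices, and threshold sweeps that follow simply make this decision rule more systematic.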
@@ -240,7 +240,7 @@ { "cell_type": "code", "execution_count": null, - "id": "746bfd5e", + "id": "ae05eacb", "metadata": {}, "outputs": [], "source": [ @@ -253,7 +253,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fab6beac", + "id": "9c7b577f", "metadata": {}, "outputs": [], "source": [ @@ -263,13 +263,13 @@ "cm = confusion_matrix(test_labels, in_dist_preds, labels=[0, 1])\n", "disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['T-shirt/top', 'Pants'])\n", "disp.plot(cmap=plt.cm.Blues)\n", - "plt.savefig('../images/OOD-detection_ID-confusion-matrix.png', dpi=300, bbox_inches='tight')\n", + "#plt.savefig('../images/OOD-detection_ID-confusion-matrix.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "8ca6e884", + "id": "a8f48984", "metadata": {}, "source": [ "![ID confusion matrix](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_ID-confusion-matrix.png)\n", @@ -282,7 +282,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22490644", + "id": "a73561a0", "metadata": {}, "outputs": [], "source": [ @@ -305,7 +305,7 @@ }, { "cell_type": "markdown", - "id": "6c12b7ea", + "id": "44fc470f", "metadata": {}, "source": [ "Based on the difference in averages here, it looks like softmax may provide at least a somewhat useful signal in separating ID and OOD data. Let's take a closer look by plotting histograms of all probability scores across our classes of interest (ID-Tshirt, ID-Pants, and OOD)." @@ -314,7 +314,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14089422", + "id": "e682c757", "metadata": {}, "outputs": [], "source": [ @@ -342,14 +342,14 @@ "\n", "# Adjusting layout\n", "plt.tight_layout()\n", - "plt.savefig('../images/OOD-detection_histograms.png', dpi=300, bbox_inches='tight')\n", + "#plt.savefig('../images/OOD-detection_histograms.png', dpi=300, bbox_inches='tight')\n", "# Displaying the plot\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "66e5537a", + "id": "2b37c28f", "metadata": {}, "source": [ "![Histograms of ID oand OOD data](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_histograms.png)\n", @@ -359,7 +359,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eab76070", + "id": "86662cd4", "metadata": {}, "outputs": [], "source": [ @@ -392,7 +392,7 @@ "plt.title('Probability Density Distributions for OOD and ID Data')\n", "plt.legend()\n", "\n", - "plt.savefig('../images/OOD-detection_PSDs.png', dpi=300, bbox_inches='tight')\n", + "#plt.savefig('../images/OOD-detection_PSDs.png', dpi=300, bbox_inches='tight')\n", "\n", "# Displaying the plot\n", "plt.show()" @@ -400,7 +400,7 @@ }, { "cell_type": "markdown", - "id": "28306058", + "id": "5e77fc87", "metadata": {}, "source": [ "![Probability densities](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_PSDs.png)\n", @@ -415,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ba41e8c", + "id": "8e3b961e", "metadata": {}, "outputs": [], "source": [ @@ -429,7 +429,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dee14994", + "id": "d69ecd2f", "metadata": {}, "outputs": [], "source": [ @@ -457,7 +457,7 @@ "disp.plot(cmap=plt.cm.Blues)\n", "plt.title('Confusion Matrix for OOD and ID Classification')\n", "\n", - "plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix1.png', dpi=300, bbox_inches='tight')\n", + 
"#plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix1.png', dpi=300, bbox_inches='tight')\n", "\n", "plt.show()\n", "\n", @@ -471,7 +471,7 @@ }, { "cell_type": "markdown", - "id": "5c89a9e1", + "id": "3faaa719", "metadata": {}, "source": [ "![Probability densities](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_ID-OOD-confusion-matrix1.png)\n", @@ -509,7 +509,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9e93e3eb", + "id": "23e686f5", "metadata": {}, "outputs": [], "source": [ @@ -544,7 +544,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6b133222", + "id": "306c6984", "metadata": {}, "outputs": [], "source": [ @@ -558,7 +558,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bdf2f44e", + "id": "6c67cf54", "metadata": {}, "outputs": [], "source": [ @@ -600,17 +600,17 @@ { "cell_type": "code", "execution_count": null, - "id": "279a2a82", + "id": "4e9c6814", "metadata": {}, "outputs": [], "source": [ "fig, best_f1_threshold, best_precision_threshold, best_recall_threshold = plot_metrics_vs_thresholds(thresholds, f1_scores, precisions, recalls, 'Softmax')\n", - "fig.savefig('../images/OOD-detection_metrics_vs_softmax-thresholds.png', dpi=300, bbox_inches='tight')" + "#fig.savefig('../images/OOD-detection_metrics_vs_softmax-thresholds.png', dpi=300, bbox_inches='tight')" ] }, { "cell_type": "markdown", - "id": "bd1d54c4", + "id": "d7bd49a1", "metadata": {}, "source": [ "![OOD-detection_metrics_vs_softmax-thresholds](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_metrics_vs_softmax-thresholds.png)" @@ -619,7 +619,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f627fe32", + "id": "94217096", "metadata": {}, "outputs": [], "source": [ @@ -644,13 +644,13 @@ "disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[\"Shirt\", \"Pants\", \"OOD\"])\n", "disp.plot(cmap=plt.cm.Blues)\n", "plt.title('Confusion Matrix for OOD and ID Classification')\n", - "plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix2.png', dpi=300, bbox_inches='tight')\n", + "#plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix2.png', dpi=300, bbox_inches='tight')\n", "plt.show()" ] }, { "cell_type": "markdown", - "id": "d2ae0fa7", + "id": "548c783c", "metadata": {}, "source": [ "![Optimized threshold confusion matrix](https://github.com/carpentries-incubator/fair-explainable-ml/raw/main/images/OOD-detection_ID-OOD-confusion-matrix2.png)\n", diff --git a/code/7c-OOD-detection-energy.ipynb b/code/7c-OOD-detection-energy.ipynb index 4413fa7b..6fa88ef7 100644 --- a/code/7c-OOD-detection-energy.ipynb +++ b/code/7c-OOD-detection-energy.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "22d4de61", + "id": "ccc984f3", "metadata": {}, "source": [ "# OOD detection: energy" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "590c9805", + "id": "805419b9", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::::: questions\n", @@ -79,7 +79,7 @@ { "cell_type": "code", "execution_count": null, - "id": "830e983e", + "id": "70a8695d", "metadata": {}, "outputs": [], "source": [ @@ -91,7 +91,7 @@ }, { "cell_type": "markdown", - "id": "54935992", + "id": "854b2422", "metadata": {}, "source": [ "## Visualizing OOD and ID data\n", @@ -107,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7413c644", + "id": "8549d05b", "metadata": {}, "outputs": [], "source": [ @@ -137,7 +137,7 @@ }, { 
"cell_type": "markdown", - "id": "eac9932a", + "id": "65cb9cb2", "metadata": {}, "source": [ "The warning message indicates that UMAP has overridden the n_jobs parameter to 1 due to the random_state being set. This behavior ensures reproducibility by using a single job. If you want to avoid the warning and still use parallelism, you can remove the random_state parameter. However, removing random_state will mean that the results might not be reproducible." @@ -146,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b05eaaa4", + "id": "29c3899a", "metadata": {}, "outputs": [], "source": [ @@ -173,7 +173,7 @@ }, { "cell_type": "markdown", - "id": "ad3efe2b", + "id": "77d12c25", "metadata": {}, "source": [ "## Train CNN" @@ -182,7 +182,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3d54cc1", + "id": "658ad1b1", "metadata": {}, "outputs": [], "source": [ @@ -250,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b659308d", + "id": "d1e047d0", "metadata": {}, "outputs": [], "source": [ @@ -294,7 +294,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebc30d3b", + "id": "97724cb2", "metadata": {}, "outputs": [], "source": [ @@ -370,7 +370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "167ce843", + "id": "efb701af", "metadata": {}, "outputs": [], "source": [ @@ -450,7 +450,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3cdf34d9", + "id": "b2b69c07", "metadata": {}, "outputs": [], "source": [ @@ -590,7 +590,7 @@ }, { "cell_type": "markdown", - "id": "fdd0239f", + "id": "1a6bbd26", "metadata": {}, "source": [ "## Limitations of our approach thus far\n", diff --git a/code/7d-OOD-detection-distance-based.ipynb b/code/7d-OOD-detection-distance-based.ipynb index b612f975..dbb49638 100644 --- a/code/7d-OOD-detection-distance-based.ipynb +++ b/code/7d-OOD-detection-distance-based.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "63dd7d14", + "id": "f6a11671", "metadata": {}, "source": [ "# OOD detection: distance-based" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "2a2bfae1", + "id": "0c55cccd", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::::: questions\n", @@ -98,7 +98,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6be1ca41", + "id": "ab190f15", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42cea007", + "id": "8a748932", "metadata": {}, "outputs": [], "source": [ @@ -206,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51bf85b9", + "id": "e46af5a5", "metadata": {}, "outputs": [], "source": [ @@ -223,7 +223,7 @@ }, { "cell_type": "markdown", - "id": "1784513f", + "id": "5bae87d3", "metadata": {}, "source": [ "### 2) Extracting learned features\n", @@ -245,7 +245,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99fbc310", + "id": "3b7b3535", "metadata": {}, "outputs": [], "source": [ @@ -278,7 +278,7 @@ }, { "cell_type": "markdown", - "id": "02aa24b6", + "id": "e91e701b", "metadata": {}, "source": [ "### 3) Dimensionality Reduction and Visualization:\n", @@ -291,7 +291,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16cba146", + "id": "2c2efabc", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +318,7 @@ { "cell_type": "code", "execution_count": null, - "id": "271be2e7", + "id": "83b2523d", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "aa097308", + "id": "27857c8c", 
"metadata": {}, "source": [ "## Neural network trained with contrastive learning\n", @@ -422,7 +422,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2dbd940b", + "id": "bf198e25", "metadata": {}, "outputs": [], "source": [ @@ -499,7 +499,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7426d496", + "id": "1ac824a3", "metadata": {}, "outputs": [], "source": [ @@ -572,7 +572,7 @@ }, { "cell_type": "markdown", - "id": "88790561", + "id": "bd4d5b67", "metadata": {}, "source": [ "### 2) Extracting learned features\n", @@ -586,7 +586,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f0c2bd2", + "id": "e87a52de", "metadata": {}, "outputs": [], "source": [ @@ -634,7 +634,7 @@ }, { "cell_type": "markdown", - "id": "228f2448", + "id": "cb8c1824", "metadata": {}, "source": [ "### 3) Dimensionality Reduction and Visualization:\n", @@ -647,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f2ee75c", + "id": "be79fc19", "metadata": {}, "outputs": [], "source": [ @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "f0ecef5d", + "id": "e828b2b5", "metadata": {}, "source": [ "# Limitations of Threshold-Based OOD Detection Methods\n", diff --git a/code/8-releasing-a-model.ipynb b/code/8-releasing-a-model.ipynb index 8621b177..df322db3 100644 --- a/code/8-releasing-a-model.ipynb +++ b/code/8-releasing-a-model.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "de17371c", + "id": "69445e2a", "metadata": {}, "source": [ "# Documenting and releasing a model" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "5c3e8c82", + "id": "e509a764", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::: questions \n", @@ -91,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b74ff435", + "id": "9ed23b1d", "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,7 @@ }, { "cell_type": "markdown", - "id": "a19234d3", + "id": "f8b92c4b", "metadata": {}, "source": [ "Initialize model by calling the class with configuration settings." @@ -125,7 +125,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ea7e753e", + "id": "2048f687", "metadata": {}, "outputs": [], "source": [ @@ -136,7 +136,7 @@ }, { "cell_type": "markdown", - "id": "0eb433fb", + "id": "cbe8282a", "metadata": {}, "source": [ "We can then write a function to save out the model. We'll need both the model weights and the model's configuration (hyperparameter settings). We'll save the configurations as a json since a key/value format is convenient here." 
@@ -145,7 +145,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9074e8b", + "id": "2c2b5a2f", "metadata": {}, "outputs": [], "source": [ @@ -163,7 +163,7 @@ { "cell_type": "code", "execution_count": null, - "id": "129ad499", + "id": "10d144a5", "metadata": {}, "outputs": [], "source": [ @@ -173,7 +173,7 @@ }, { "cell_type": "markdown", - "id": "ac37cbca", + "id": "b2c282e2", "metadata": {}, "source": [ "To load the model back in, we can write another function" @@ -182,7 +182,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79090ad6", + "id": "63a2b442", "metadata": {}, "outputs": [], "source": [ @@ -201,7 +201,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a245db8", + "id": "06bb748f", "metadata": {}, "outputs": [], "source": [ @@ -214,7 +214,7 @@ }, { "cell_type": "markdown", - "id": "54536b19", + "id": "004a23fd", "metadata": {}, "source": [ "## Saving a model to Hugging Face\n", @@ -234,7 +234,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7746c2c5", + "id": "535546ad", "metadata": {}, "outputs": [], "source": [ @@ -243,7 +243,7 @@ }, { "cell_type": "markdown", - "id": "dab7d19d", + "id": "86cd1634", "metadata": {}, "source": [ "You might get a message saying you cannot authenticate through git-credential as no helper is defined on your machine. This warning message should not stop you from being able to complete this episode, but it may mean that the token won't be stored on your machine for future use. \n", @@ -256,7 +256,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5ff47773", + "id": "6bec3057", "metadata": {}, "outputs": [], "source": [ @@ -277,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6ccfa199", + "id": "e200b0ed", "metadata": {}, "outputs": [], "source": [ @@ -290,7 +290,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d217884", + "id": "1247a471", "metadata": {}, "outputs": [], "source": [ @@ -300,7 +300,7 @@ }, { "cell_type": "markdown", - "id": "e4f630fa", + "id": "65be6741", "metadata": {}, "source": [ "**Verifying**: To check your work, head back over to your Hugging Face account and click your profile icon in the top-right of the website. Click \"Profile\" from there to view all of your uploaded models. 
Alternatively, you can search for your username (or model name) from the [Model Hub](https://huggingface.co/models).\n", @@ -311,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b33fcc07", + "id": "f6f5583a", "metadata": {}, "outputs": [], "source": [ @@ -321,7 +321,7 @@ }, { "cell_type": "markdown", - "id": "c8eeef9a", + "id": "f8969f18", "metadata": {}, "source": [ "## Uploading transformer models to Hugging Face\n", @@ -336,7 +336,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4853ef79", + "id": "ed0af79a", "metadata": {}, "outputs": [], "source": [ @@ -363,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1d0c7a21", + "id": "31d916fd", "metadata": {}, "outputs": [], "source": [ @@ -382,7 +382,7 @@ }, { "cell_type": "markdown", - "id": "4b486b6a", + "id": "bf09e46a", "metadata": {}, "source": [ ":::::::::::::::::::::::::::::::::::::: challenge\n", diff --git a/code/conversion-scripts/batch_convert_md_to_ipynb.py b/code/conversion-scripts/batch_convert_md_to_ipynb.py index 2d4a1d31..a72e1da7 100644 --- a/code/conversion-scripts/batch_convert_md_to_ipynb.py +++ b/code/conversion-scripts/batch_convert_md_to_ipynb.py @@ -54,8 +54,10 @@ def batch_convert_md_to_ipynb(input_directory, output_directory, image_url, excl "1-preparing-to-train.md", "2-model-eval-and-fairness.md", "4-explainability-vs-interpretability.md", + "5a-explainable-AI-method-overview.md", "5b-deep-dive-into-methods.md", "6-confidence-intervals.md", + "7a-OOD-detection-overview.md", "7e-OOD-detection-algo-design.md" ]
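As a compact reference for the transformer-upload workflow covered at the end of the model-release episode, the sketch below shows the built-in `push_to_hub` route, assuming you have already authenticated (e.g., via `notebook_login()`). The checkpoint and repository names are placeholders, not the notebook's exact values.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Any pretrained checkpoint works here; distilbert-base-uncased is just an example
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# push_to_hub creates the repository (if needed) and uploads the weights and config
model.push_to_hub("your-username/your-model-name")
# Upload the tokenizer files to the same repository so others can reuse the model directly
tokenizer.push_to_hub("your-username/your-model-name")
```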