Commit f8a3bfc: Update 7b-OOD-detection-softmax.md
qualiaMachine authored Dec 19, 2024 (parent d3d9bf0)
episodes/7b-OOD-detection-softmax.md (11 additions, 8 deletions)

::::::::::::::::::::::::::::::::::::::::::::::::::

## Leveraging softmax model outputs
Softmax-based methods are among the most widely used techniques for out-of-distribution (OOD) detection, leveraging the probabilistic outputs of a model to differentiate between in-distribution (ID) and OOD data. These methods are inherently tied to models employing a softmax activation function in their final layer, such as logistic regression or neural networks with a classification output layer.

The softmax function normalizes the logits (i.e., the raw outputs of the final layer, before any activation function is applied) so that each output falls between 0 and 1 and the outputs sum to 1. This lets us interpret the model's predictions as probabilities. Softmax probabilities are computed as:
$$
P(y = k \mid x) = \frac{\exp(f_k(x))}{ \sum_{j} \exp(f_j(x))}
$$
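As a quick illustration, this formula can be computed directly from a vector of logits (a minimal sketch using NumPy; subtracting the maximum logit before exponentiating is a standard trick for numerical stability and does not change the result):

```python
import numpy as np

def softmax(logits):
    """Convert raw logits to probabilities that sum to 1."""
    # Subtract the max logit for numerical stability (result is unchanged)
    shifted = logits - np.max(logits)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum()

logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
print(probs)        # largest logit gets the largest probability
print(probs.sum())  # probabilities sum to 1
```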


In this lesson, we will train a logistic regression model to classify images from the Fashion MNIST dataset and explore how its softmax outputs can signal whether a given input belongs to the ID classes (e.g., T-shirts or pants) or is OOD (e.g., sandals). While softmax is most naturally applied in models with a logistic activation, alternative approaches, such as applying softmax-like operations post hoc to models with different architectures, are occasionally used. However, these alternatives are less common and may require additional considerations. By focusing on logistic regression, we aim to illustrate the fundamental principles of softmax-based OOD detection in a simple and interpretable context before extending these ideas to more complex architectures.
## Prepare the ID (train and test) and OOD data
In order to determine a threshold that can separate ID data from OOD data (or to verify that new test data is ID), we need to sample data from both distributions. The OOD data used should be representative of potential new classes (i.e., semantic shift) that your model may encounter, or of distribution/covariate shifts observed in your application area.

* ID = T-shirts/Blouses, Pants
```python
def plot_data_sample(train_data, ood_data):
    ...  # function body collapsed in the diff view
```
Load and prepare the ID data (train and test sets containing shirts and pants) and the OOD data (sandals):
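As a sketch of the filtering step, the class indices follow the standard Fashion-MNIST labeling (0 = T-shirt/top, 1 = trouser, 5 = sandal); the `labels` array here is an illustrative stand-in for the labels of the loaded dataset:

```python
import numpy as np

# Stand-in labels; in the lesson these come from the Fashion-MNIST dataset
labels = np.array([0, 1, 5, 0, 5, 1, 1, 0])

ID_CLASSES = [0, 1]  # T-shirts/tops and trousers (pants)
OOD_CLASS = 5        # sandals

# Boolean masks select the ID and OOD subsets from the full dataset
id_mask = np.isin(labels, ID_CLASSES)
ood_mask = labels == OOD_CLASS

print(id_mask.sum(), "ID samples;", ood_mask.sum(), "OOD samples")
```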

## Why not just add the OOD class to the training dataset?
OOD data is, by definition, not part of the training distribution. It could encompass anything outside the known classes, which means you'd need to collect a representative dataset for "everything else" to train the OOD class. This is practically impossible because OOD data is often diverse and unbounded (e.g., new species, novel medical conditions, adversarial examples).

The key idea behind threshold-based methods is that we want to vet our model against a small sample of potential risk-cases, using known OOD data to determine an empirical threshold that *hopefully* extends to other OOD cases that may arise in real-world scenarios.
Plot a sample:
```python
fig = plot_data_sample(train_data, ood_data)
plt.show()
```
## Visualizing OOD and ID data with PCA

### PCA
PCA visualization can provide insights into how well a model is separating ID and OOD data. If the OOD data overlaps significantly with ID data in the PCA space, it might indicate that the model could struggle to correctly identify OOD samples.
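A minimal sketch of such a projection, using synthetic arrays as stand-ins for the flattened image data (variable names are illustrative, not the lesson's exact code):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
id_data_flat = rng.normal(0.0, 1.0, size=(100, 784))  # stand-in for flattened ID images
ood_data_flat = rng.normal(0.5, 1.0, size=(50, 784))  # stand-in for flattened OOD images

# Fit PCA on the ID data only, then project both sets into the same 2D space
pca = PCA(n_components=2).fit(id_data_flat)
id_2d = pca.transform(id_data_flat)
ood_2d = pca.transform(ood_data_flat)

plt.scatter(id_2d[:, 0], id_2d[:, 1], alpha=0.5, label='ID')
plt.scatter(ood_2d[:, 0], ood_2d[:, 1], alpha=0.5, label='OOD')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.legend()
plt.show()
```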
From this plot, we see that sandals are more likely to be confused with T-shirts than with pants.

* **Over-reliance on linear relationships**: Part of this has to do with the fact that we're only looking at linear relationships and treating each pixel as its own input feature, which is usually never a great idea when working with image data. In our next example, we'll switch to the more modern approach of CNNs.
* **Semantic gap != feature gap**: Another factor of note is that images with a wide semantic gap may not necessarily have a wide gap in terms of their visual features (e.g., ankle boots and bags might both be small, have leather, and have zippers). Part of an effective OOD detection scheme involves thinking carefully about what sorts of data contaminations may be observed by the model, and assessing how similar these contaminations may be to your desired class labels.

## Train and evaluate model on ID data
```python
model = LogisticRegression(max_iter=10, solver='lbfgs', multi_class='multinomial').fit(train_data_flat, train_labels) # 'lbfgs' is an efficient solver that works well for small to medium-sized datasets.
```

Unfortunately, we observe a significant amount of overlap between OOD data and high T-shirt probabilities.

For pants, the problem is much less severe. It looks like a low threshold (on this T-shirt probability scale) can separate nearly all OOD samples from being pants.
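The kind of overlap described above can be visualized by plotting histograms of the ID and OOD softmax scores side by side (a sketch with synthetic scores standing in for the model's `predict_proba` outputs):

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
# Stand-ins for P(T-shirt) taken from model.predict_proba(...)
id_scores = rng.beta(8, 2, size=500)   # ID samples: mostly high confidence
ood_scores = rng.beta(4, 4, size=200)  # OOD samples: spread across the mid-range

plt.hist(id_scores, bins=30, alpha=0.5, label='ID (T-shirts)')
plt.hist(ood_scores, bins=30, alpha=0.5, label='OOD (sandals)')
plt.xlabel('Softmax probability for T-shirt class')
plt.ylabel('Count')
plt.legend()
plt.show()
```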

## Setting a threshold
Let's put our observations to the test and produce a confusion matrix that includes ID-pants, ID-Tshirts, and OOD class labels. We'll start with a high threshold of 0.9 to see how that performs.
```python
def softmax_thresh_classifications(probs, threshold):
    ...  # function body collapsed in the diff view
```

What threshold is required to ensure that no OOD samples are incorrectly considered ID?

With a very conservative threshold, we can make sure very few OOD samples are incorrectly classified as ID. However, the flip side is that conservative thresholds tend to incorrectly classify many ID samples as being OOD. In this case, we incorrectly assume almost 20% of shirts are OOD samples.
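One way the thresholding step might look (a hypothetical sketch, not necessarily the lesson's exact `softmax_thresh_classifications` implementation): any sample whose maximum softmax probability falls below the threshold is labeled OOD, encoded here as -1.

```python
import numpy as np

def softmax_thresh_classifications(probs, threshold):
    """Assign the argmax class when the max softmax prob >= threshold, else -1 (OOD)."""
    max_probs = probs.max(axis=1)
    preds = probs.argmax(axis=1)
    return np.where(max_probs >= threshold, preds, -1)

probs = np.array([
    [0.95, 0.05],  # confident T-shirt -> class 0
    [0.10, 0.90],  # confident pants   -> class 1
    [0.55, 0.45],  # uncertain         -> OOD (-1)
])
print(softmax_thresh_classifications(probs, threshold=0.9))  # [ 0  1 -1]
```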

## Iterative threshold determination

In practice, selecting an appropriate threshold is an iterative process that balances the trade-off between correctly identifying in-distribution (ID) data and accurately flagging out-of-distribution (OOD) data. Here's how you can iteratively determine the threshold:
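The iterative process can be sketched as a sweep over candidate thresholds, recording at each value how many ID samples are kept and how many OOD samples are flagged (synthetic scores stand in for real model outputs):

```python
import numpy as np

rng = np.random.default_rng(0)
id_max_probs = rng.beta(8, 2, size=500)   # stand-in: max softmax prob on ID data
ood_max_probs = rng.beta(4, 4, size=200)  # stand-in: max softmax prob on OOD data

# Sweep thresholds and report the ID/OOD trade-off at each value
for threshold in np.arange(0.5, 1.0, 0.1):
    id_kept = np.mean(id_max_probs >= threshold)      # fraction of ID retained
    ood_flagged = np.mean(ood_max_probs < threshold)  # fraction of OOD caught
    print(f"threshold={threshold:.1f}  ID kept={id_kept:.2f}  OOD flagged={ood_flagged:.2f}")
```

Raising the threshold catches more OOD samples but discards more genuine ID samples, which is exactly the trade-off described above.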

```python
# (preceding lines collapsed in the diff view)
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix for OOD and ID Classification')
plt.show()
```

#### Discuss
How might you use these tools to ensure that a model trained on health data from hospital A will reliably predict new test data from hospital B?

:::::::::::::::::::::::::::::::::::::::: keypoints

- While simple and widely used, softmax-based methods have limitations, including sensitivity to threshold choices and reduced reliability in high-dimensional settings.
- Understanding softmax-based OOD detection lays the groundwork for exploring more advanced techniques like energy-based detection.

::::::::::::::::::::::::::::::::::::::::::::::::::
