Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix confusion matrix type #20584

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions keras/src/metrics/iou_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from keras.src import backend
from keras.src import initializers
from keras.src import ops
Expand Down Expand Up @@ -55,8 +57,8 @@ def __init__(
sparse_y_pred=True,
axis=-1,
):
# defaulting to float32 to avoid issues with confusion matrix
super().__init__(name=name, dtype=dtype or "float32")
# defaulting to int to avoid issues with confusion matrix
super().__init__(name=name, dtype=dtype or "int")
# Metric should be maximized during optimization.
self._direction = "up"
self.num_classes = num_classes
Expand All @@ -69,6 +71,7 @@ def __init__(
name="total_confusion_matrix",
shape=(num_classes, num_classes),
initializer=initializers.Zeros(),
dtype=self.dtype,
)

def update_state(self, y_true, y_pred, sample_weight=None):
Expand Down Expand Up @@ -102,7 +105,17 @@ def update_state(self, y_true, y_pred, sample_weight=None):

if sample_weight is None:
sample_weight = 1

else:
if (
hasattr(sample_weight, "dtype")
and "float" in str(sample_weight.dtype)
and "int" in str(self.dtype)
):
warnings.warn(
"You are passing weight as `float`, but dtype is `int`. "
"This may result in an incorrect weight due to type casting"
" Consider using integer weights."
)
sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype)

if len(sample_weight.shape) > 1:
Expand Down Expand Up @@ -131,7 +144,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
y_pred,
self.num_classes,
weights=sample_weight,
dtype="float32",
dtype=self.dtype,
)

return self.total_cm.assign(self.total_cm + current_cm)
Expand Down Expand Up @@ -272,10 +285,11 @@ def result(self):
denominator = ops.take_along_axis(
denominator, target_class_ids, axis=-1
)
denominator = ops.cast(denominator, dtype="float32")

# If the denominator is 0, we need to ignore the class.
num_valid_entries = ops.sum(
ops.cast(ops.greater(denominator, 1e-9), dtype=self.dtype)
ops.cast(ops.greater(denominator, 1e-9), dtype="float32")
)

iou = ops.divide(true_positives, denominator + backend.epsilon())
Expand Down Expand Up @@ -406,7 +420,8 @@ def update_state(self, y_true, y_pred, sample_weight=None):
Update op.
"""
y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)
# convert y_pred on float 32 and cast just after to dtype
y_pred = ops.convert_to_tensor(y_pred, dtype="float32")
y_pred = ops.cast(y_pred >= self.threshold, self.dtype)
return super().update_state(y_true, y_pred, sample_weight)

Expand Down
145 changes: 133 additions & 12 deletions keras/src/metrics/iou_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import models
from keras.src import testing
from keras.src.metrics import iou_metrics as metrics
from keras.src.ops import convert_to_tensor


class IoUTest(testing.TestCase):
Expand All @@ -25,9 +26,7 @@ def test_unweighted(self):
y_pred = [0, 1, 0, 1]
y_true = [0, 0, 1, 1]

obj = metrics.IoU(
num_classes=2, target_class_ids=[0, 1], dtype="float32"
)
obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])

result = obj(y_true, y_pred)

Expand Down Expand Up @@ -64,7 +63,9 @@ def test_multi_dim_input(self):
y_true = np.array([[0, 0], [1, 1]])
sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])

obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])
obj = metrics.IoU(
num_classes=2, target_class_ids=[0, 1], dtype="float32"
)

result = obj(y_true, y_pred, sample_weight=sample_weight)

Expand Down Expand Up @@ -136,7 +137,9 @@ def test_different_thresholds_weighted(self):
expected_result = (
0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
obj = metrics.BinaryIoU(
target_class_ids=[0, 1], threshold=0.3, dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand All @@ -150,7 +153,9 @@ def test_different_thresholds_weighted(self):
expected_result = (
0.5 / (0.5 + 0.7 - 0.5) + 0.3 / (0.5 + 0.3 - 0.3)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
obj = metrics.BinaryIoU(
target_class_ids=[0, 1], threshold=0.5, dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand Down Expand Up @@ -191,7 +196,9 @@ def test_multi_dim_input(self):
expected_result = (
0.2 / (0.6 + 0.3 - 0.2) + 0.3 / (0.4 + 0.7 - 0.3)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=threshold)
obj = metrics.BinaryIoU(
target_class_ids=[0, 1], threshold=threshold, dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand Down Expand Up @@ -281,7 +288,7 @@ def test_weighted(self):
y_true = np.array([0, 0, 1, 1])
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])

m_obj = metrics.MeanIoU(num_classes=2)
m_obj = metrics.MeanIoU(num_classes=2, dtype="float32")

result = m_obj(y_true, y_pred, sample_weight=sample_weight)

Expand All @@ -300,7 +307,7 @@ def test_weighted_ignore_class_1(self):
y_true = np.array([0, 0, 1, -1])
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])

m_obj = metrics.MeanIoU(num_classes=2, ignore_class=-1)
m_obj = metrics.MeanIoU(num_classes=2, ignore_class=-1, dtype="float32")

result = m_obj(y_true, y_pred, sample_weight=sample_weight)

Expand All @@ -319,7 +326,7 @@ def test_multi_dim_input(self):
y_true = np.array([[0, 0], [1, 1]])
sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])

m_obj = metrics.MeanIoU(num_classes=2)
m_obj = metrics.MeanIoU(num_classes=2, dtype="float32")

result = m_obj(y_true, y_pred, sample_weight=sample_weight)

Expand Down Expand Up @@ -351,6 +358,112 @@ def test_zero_and_non_zero_entries(self):
expected_result = (0 + 1 / (1 + 1 - 1)) / 1
self.assertAllClose(result, expected_result, atol=1e-3)

@staticmethod
def _confusion_matrix(y_true, y_pred, num_classes):
    """Compute a confusion matrix with plain numpy.

    Parameters:
    - y_true: array-like, true class labels.
    - y_pred: array-like, predicted class labels.
    - num_classes: int, number of classes.

    Returns:
    - np.ndarray of shape (num_classes, num_classes), where entry
      (i, j) counts samples whose true class is i and predicted
      class is j.
    """
    # Scatter-add a 1 into cell (true, pred) for every sample pair.
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(matrix, (y_true, y_pred), 1)
    return matrix

@staticmethod
def _get_big_chunk(dtype):
    """Build a large random segmentation problem, seed a `MeanIoU`
    metric's confusion-matrix variable with large starting counts,
    and run one `update_state` pass.

    The seed counts (~18.7 million) lie above 2**24, the largest
    range in which float32 represents every integer exactly, so an
    int accumulator stays exact where a float32 one can drift.

    Parameters:
    - dtype: dtype forwarded to `MeanIoU`; None falls back to "int"
      for the replacement confusion-matrix variable below.

    Returns a 6-tuple:
    - all_y_true: (10, 530, 530) labels drawn from {0, 1, 2}.
    - all_y_pred_arg: argmax class predictions, same shape.
    - mean_iou_metric: the `MeanIoU` instance after one update.
    - tmp_true, tmp_pred: the labels/predictions flattened to 1-D.
    - conf_matrix_start_point: the (3, 3) seed counts.
    """
    # Fixed seed keeps the generated chunk deterministic across runs.
    np.random.seed(14)
    all_y_true = np.random.choice([0, 1, 2], size=(10, 530, 530))
    # Generate random probabilities for each channel
    random_probs = np.random.rand(10, 530, 530, 3)
    # Normalize to ensure the last dimension sums to 1
    all_y_pred = random_probs / random_probs.sum(axis=-1, keepdims=True)
    # Convert predictions to class indices
    all_y_pred_arg = np.argmax(all_y_pred, axis=-1)
    mean_iou_metric = metrics.MeanIoU(num_classes=3, dtype=dtype)
    conf_matrix_start_point = np.array(
        [
            [18729664, 18728760, 18731196],
            [18727297, 18726105, 18728071],
            [18727917, 18717835, 18723155],
        ]
    )
    # Replace the metric's zero-initialized variable with one seeded
    # at the large starting counts. NOTE(review): a tensor is passed
    # as `initializer`; assumes `add_variable` accepts a constant
    # value there — confirm against the Keras variable API.
    mean_iou_metric.total_cm = mean_iou_metric.add_variable(
        name="total_confusion_matrix",
        shape=(3, 3),
        initializer=convert_to_tensor(conf_matrix_start_point),
        dtype=dtype or "int",
    )
    mean_iou_metric.update_state(all_y_true, all_y_pred_arg)
    # Flatten so callers can feed a plain-numpy confusion matrix.
    tmp_true = np.reshape(all_y_true, -1)
    tmp_pred = np.reshape(all_y_pred_arg, -1)
    return (
        all_y_true,
        all_y_pred_arg,
        mean_iou_metric,
        tmp_true,
        tmp_pred,
        conf_matrix_start_point,
    )

def test_big_chunk(self):
    """Confusion-matrix accumulation must stay exact with the default
    int dtype and is expected to lose precision with float32.

    The seeded starting counts (~18.7M) exceed 2**24, the last range
    in which float32 represents every integer exactly, so the
    float32 run must diverge from the exact numpy reference.
    """
    # dtype=None defaults the metric's confusion matrix to int.
    # The raw y arrays are not needed here, only the flattened views.
    (
        _,
        _,
        mean_iou_metric_all,
        tmp_true,
        tmp_pred,
        conf_matrix_start_point,
    ) = self._get_big_chunk(dtype=None)
    conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)
    # Exact reference: numpy confusion matrix plus the seeded counts.
    conf_matrix_manual = (
        self._confusion_matrix(tmp_true, tmp_pred, 3)
        + conf_matrix_start_point
    )
    self.assertTrue(
        np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
        msg="Confusion matrices do not match!",
    )
    # Same run accumulated in float32: the counts exceed float32's
    # exact-integer range, so the matrices must NOT match.
    (
        _,
        _,
        mean_iou_metric_all,
        tmp_true,
        tmp_pred,
        conf_matrix_start_point,
    ) = self._get_big_chunk(dtype="float32")
    conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)
    conf_matrix_manual = (
        self._confusion_matrix(tmp_true, tmp_pred, 3)
        + conf_matrix_start_point
    )
    self.assertFalse(
        np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
        msg="Confusion matrices match, but they should not!",
    )

def test_user_warning_float_weight(self):
    """Float sample weights on an int-dtype metric emit a warning."""
    metric = metrics.MeanIoU(num_classes=3)
    float_weights = np.array([0.2, 0.3, 0.4, 0.1])
    with pytest.warns(Warning, match=r"weight.*float.*int.*casting"):
        metric([0, 1, 1, 0], [0, 1, 1, 1], sample_weight=float_weights)


class OneHotIoUTest(testing.TestCase):
def test_unweighted(self):
Expand Down Expand Up @@ -385,7 +498,9 @@ def test_weighted(self):
# true_positives = [0, 0, 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives))
expected_result = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
obj = metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
obj = metrics.OneHotIoU(
num_classes=3, target_class_ids=[0, 2], dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand Down Expand Up @@ -439,6 +554,12 @@ def test_weighted(self):
expected_result = (
0.1 / (0.4 + 0.6 - 0.1) + 0 + 0.1 / (0.6 + 0.1 - 0.1)
) / 3
obj = metrics.OneHotMeanIoU(num_classes=3)
obj = metrics.OneHotMeanIoU(num_classes=3, dtype="float32")
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

# Check same result with int weights
sample_weight_int = [1, 2, 3, 3, 1]
obj_int = metrics.OneHotMeanIoU(num_classes=3)
result_int = obj_int(y_true, y_pred, sample_weight=sample_weight_int)
self.assertAllClose(result_int, expected_result, atol=1e-3)
Loading