From d05224eb99284d821d304711577d630f19a12a79 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 27 Jun 2024 16:49:39 +0200 Subject: [PATCH] Text fun --- tests/tests_utils/test_utils.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/tests_utils/test_utils.py b/tests/tests_utils/test_utils.py index 9421dafdb..f7cdee559 100644 --- a/tests/tests_utils/test_utils.py +++ b/tests/tests_utils/test_utils.py @@ -1,5 +1,5 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors + """Tests for the direct.utils module.""" import pathlib @@ -9,7 +9,7 @@ import pytest import torch -from direct.utils import is_power_of_two, normalize_image, remove_keys, set_all_seeds +from direct.utils import is_power_of_two, normalize_image, remove_keys, set_all_seeds, reshape_array_to_shape from direct.utils.asserts import assert_complex from direct.utils.bbox import crop_to_largest from direct.utils.dataset import get_filenames_for_datasets_from_config @@ -126,3 +126,20 @@ def test_normalize_image(shape, eps): img = np.random.randn(*shape) normalized_img = normalize_image(img, eps) assert normalized_img.min() >= 0.0 and normalized_img.max() <= 1.0 + + +@pytest.mark.parametrize( + "array, requested_shape, expected_shape", + [ + (np.random.rand(4, 5), (4, 5, 1), (4, 5, 1)), + (np.random.rand(4, 5), (1, 4, 5, 1), (1, 4, 5, 1)), + (np.random.rand(2, 4, 5), (2, 4, 5, 1), (2, 4, 5, 1)), + (np.random.rand(3, 3), (1, 3, 1, 3, 1), (1, 3, 1, 3, 1)), + (np.random.rand(2, 3), (2, 1, 3), (2, 1, 3)), + (np.random.rand(4), (1, 1, 4, 1), (1, 1, 4, 1)), + (np.random.rand(6), (1, 6, 1), (1, 6, 1)), + ] +) +def test_reshape_array_to_shape(array, requested_shape, expected_shape): + result = reshape_array_to_shape(array, requested_shape) + assert result.shape == expected_shape \ No newline at end of file