Skip to content

Commit

Permalink
Move CheckAxes to utils.h
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <[email protected]>
  • Loading branch information
jantonguirao committed Jan 5, 2023
1 parent 8db68c1 commit 02b8d12
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 68 deletions.
43 changes: 42 additions & 1 deletion dali/kernels/common/utils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,6 +62,47 @@ OutShape GetStrides(const Shape& shape) {
return strides;
}


/**
* @brief Checks that axes only appear once and that they are within range.
*
* @param axes list of axis indices
* @param ndim dimensionality of the tensor(list) to which axes refer
*/
template <typename Axes>
inline void CheckAxes(Axes&& axes, int ndim) {
assert(ndim >= 0 && ndim <= 64);
uint64_t mask = 0;
for (auto a : axes) {
if (a < -ndim || a >= ndim)
throw std::out_of_range(make_string("Axis index out of range: ", a, " not in range [", -ndim,
"..", ndim - 1, "]"));
if (a < 0)
a += ndim;
uint64_t amask = 1_u64 << a;
if (mask & amask)
throw std::invalid_argument(make_string("Duplicate axis index ", a));
mask |= amask;
}
}


/**
* @brief Adjusts negative axis indices to the positive range.
* Negative indices are counted from the back.
*
* @param axes list of axis indices
* @param ndim dimensionality of the tensor(list) to which axes refer
*/
template <typename Axes>
void AdjustAxes(Axes& axes, int ndim) {
for (auto& a : axes) {
assert(a >= -ndim && a < ndim);
if (a < 0)
a += ndim;
}
}

} // namespace kernels
} // namespace dali

Expand Down
6 changes: 3 additions & 3 deletions dali/kernels/reduce/reduce_cpu.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -244,8 +244,8 @@ struct ReduceBaseCPU {
}

axes.copy_assign(_axes.begin(), _axes.end());
reduce_impl::CheckAxes(make_cspan(axes), ndim());
reduce_impl::AdjustAxes(make_span(axes), ndim());
CheckAxes(make_cspan(axes), ndim());
AdjustAxes(make_span(axes), ndim());
axis_mask = 0;
for (int axis : axes) {
axis_mask |= 1_u64 << axis;
Expand Down
26 changes: 1 addition & 25 deletions dali/kernels/reduce/reduce_gpu_impl_test.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -200,30 +200,6 @@ TEST(ReduceImplGPU, Simplify_NoOp) {
EXPECT_EQ(groups[2], std::make_pair(2, 1));
}

TEST(ReduceImpl, TestCheckAxes) {
EXPECT_NO_THROW(CheckAxes({}, 0));
int axes_0[] = { 0 };
int axes_01[] = { 0, 1 };
int axes_2[] = { 2 };
int axes_010[] = { 0, 1, 0 };
int axes_neg1[] = { -1 };
int axes_neg2[] = { -2 };
int axes_neg12[] = { -2, -1 };
int axes_neg121[] = {-2, -1, -2 };
int axes_pos0_neg1[] = {0, -1};
EXPECT_NO_THROW(CheckAxes(make_span(axes_0), 1));
EXPECT_NO_THROW(CheckAxes(make_span(axes_01), 2));
EXPECT_NO_THROW(CheckAxes(make_span(axes_neg1), 1));
EXPECT_NO_THROW(CheckAxes(make_span(axes_neg2), 2));
EXPECT_NO_THROW(CheckAxes(make_span(axes_neg12), 2));
EXPECT_NO_THROW(CheckAxes(make_span(axes_pos0_neg1), 2));
EXPECT_THROW(CheckAxes(make_span(axes_neg2), 1), std::out_of_range);
EXPECT_THROW(CheckAxes(make_span(axes_2), 2), std::out_of_range);
EXPECT_THROW(CheckAxes(make_span(axes_010), 2), std::invalid_argument);
EXPECT_THROW(CheckAxes(make_span(axes_neg121), 2), std::invalid_argument);
EXPECT_THROW(CheckAxes(make_span(axes_pos0_neg1), 1), std::invalid_argument);
}

TEST(ReduceImpl, TestCheckBatchReduce) {
TensorListShape<> tls = {{
{ 3, 3, 2, 4 },
Expand Down
40 changes: 1 addition & 39 deletions dali/kernels/reduce/reduce_setup_utils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -138,44 +138,6 @@ inline void CheckBatchReduce(const TensorListShape<> &tls, span<const int> axes)
}


/**
* @brief Checks that axes only appear once and that they are within range.
*
* @param axes list of axis indices
* @param ndim dimensionality of the tensor(list) to which axes refer
*/
inline void CheckAxes(span<const int> axes, int ndim) {
assert(ndim >= 0 && ndim <= 64);
uint64_t mask = 0;
for (auto a : axes) {
if (a < -ndim || a >= ndim)
throw std::out_of_range(make_string("Axis index out of range: ", a, " not in range [", -ndim,
"..", ndim - 1, "]"));
if (a < 0)
a += ndim;
uint64_t amask = 1_u64 << a;
if (mask & amask)
throw std::invalid_argument(make_string("Duplicate axis index ", a));
mask |= amask;
}
}


/**
* @brief Adjusts negative axis indices to the positive range.
* Negative indices are counted from the back.
*
* @param axes list of axis indices
* @param ndim dimensionality of the tensor(list) to which axes refer
*/
inline void AdjustAxes(span<int> axes, int ndim) {
for (auto& a : axes) {
assert(a >= -ndim && a < ndim);
if (a < 0)
a += ndim;
}
}

/**
* @brief Calculates the shape of the result of reduction under given parameters
*
Expand Down

0 comments on commit 02b8d12

Please sign in to comment.