From b9c056ee09c76ef45b0e9b6ba36d2c35764f321b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 7 Feb 2025 11:35:44 -0800 Subject: [PATCH] Use tensor_shape_to_c_string for error in check_mask_indices Rolling out for #7902 ghstack-source-id: 7f375bd1157c10efe7d833b45f8b474b3c3c1111 ghstack-comment-id: 2643854240 Pull Request resolved: https://github.com/pytorch/executorch/pull/8314 --- .../portable/cpu/util/advanced_index_util.cpp | 20 ++++++++++++++++--- kernels/portable/cpu/util/targets.bzl | 1 + 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/kernels/portable/cpu/util/advanced_index_util.cpp b/kernels/portable/cpu/util/advanced_index_util.cpp index e2eabec4bc..6af310dbf6 100644 --- a/kernels/portable/cpu/util/advanced_index_util.cpp +++ b/kernels/portable/cpu/util/advanced_index_util.cpp @@ -7,6 +7,7 @@ */ #include +#include #include namespace torch { @@ -49,9 +50,22 @@ bool check_mask_indices(const Tensor& in, TensorOptList indices) { ET_LOG_MSG_AND_RETURN_IF_FALSE( index.dim() > 0, "Zero-dimensional mask index not allowed"); for (auto j = 0; j < index.dim(); j++) { - ET_LOG_MSG_AND_RETURN_IF_FALSE( - index.size(j) == in.size(in_i + j), - "The shape of mask index must match the sizes of the corresponding input dimensions."); + if (index.size(j) != in.size(in_i + j)) { +#ifdef ET_LOG_ENABLED + auto mask_shape = executorch::runtime::tensor_shape_to_c_string( + executorch::runtime::Span( + index.sizes().data(), index.sizes().size())); + auto input_shape = executorch::runtime::tensor_shape_to_c_string( + executorch::runtime::Span( + in.sizes().data() + in_i, index.sizes().size())); +#endif // ET_LOG_ENABLED + ET_LOG( + Error, + "The shape of mask index %s must match the sizes of the corresponding input dimensions %s.", + mask_shape.data(), + input_shape.data()); + return false; + } } in_i += index.dim(); } else { diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 0115feb625..2c25d17156 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -117,6 +117,7 @@ def define_common_targets(): compiler_flags = ["-Wno-missing-prototypes"], deps = [ ":broadcast_util", + "//executorch/runtime/core/exec_aten/util:tensor_shape_to_c_string", "//executorch/runtime/kernel:kernel_includes", ], visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],