From 60941066b5e52999fa4bc243d84d1903cfcc1b2d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 16 Jun 2024 04:57:55 +0800 Subject: [PATCH] [backport] Allow unaligned pointer if the array is empty. (#10418) (#10424) Co-authored-by: Philip Hyunsu Cho --- src/data/array_interface.h | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/data/array_interface.h b/src/data/array_interface.h index fafe0b6acc8e..f96ecd0c86a8 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -6,28 +6,25 @@ #ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_ #define XGBOOST_DATA_ARRAY_INTERFACE_H_ -#include -#include // for size_t -#include -#include // for numeric_limits -#include -#include +#include // for all_of, transform, fill +#include // for size_t +#include // for int32_t, int64_t, ... +#include // for numeric_limits +#include // for map +#include // for string #include // for alignment_of, remove_pointer_t, invoke_result_t -#include -#include +#include // for vector -#include "../common/bitfield.h" // for RBitField8 -#include "../common/common.h" +#include "../common/bitfield.h" // for RBitField8 #include "../common/error_msg.h" // for NoF128 -#include "xgboost/base.h" -#include "xgboost/data.h" -#include "xgboost/json.h" -#include "xgboost/linalg.h" -#include "xgboost/logging.h" -#include "xgboost/span.h" +#include "xgboost/json.h" // for Json +#include "xgboost/linalg.h" // for CalcStride, TensorView +#include "xgboost/logging.h" // for CHECK +#include "xgboost/span.h" // for Span +#include "xgboost/string_view.h" // for StringView #if defined(XGBOOST_USE_CUDA) -#include "cuda_fp16.h" +#include "cuda_fp16.h" // for __half #endif namespace xgboost { @@ -410,7 +407,7 @@ class ArrayInterface { auto typestr = get(array.at("typestr")); this->AssignType(StringView{typestr}); ArrayInterfaceHandler::ExtractShape(array, shape); - size_t itemsize = typestr[2] - '0'; + std::size_t itemsize = typestr[2] - '0'; is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides); n = linalg::detail::CalcSize(shape); @@ -419,7 +416,9 @@ class ArrayInterface { auto alignment = this->ElementAlignment(); auto ptr = reinterpret_cast(this->data); - CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment."; + if (!std::all_of(this->shape, this->shape + D, [](auto v) { return v == 0; })) { + CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment."; + } if (allow_mask) { common::Span s_mask;