Skip to content

Commit

Permalink
[backport] Allow unaligned pointer if the array is empty. (#10418) (#…
Browse files Browse the repository at this point in the history
…10424)

Co-authored-by: Philip Hyunsu Cho <[email protected]>
  • Loading branch information
trivialfis and hcho3 authored Jun 15, 2024
1 parent f789e50 commit 6094106
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions src/data/array_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,25 @@
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_

#include <algorithm>
#include <cstddef> // for size_t
#include <cstdint>
#include <limits> // for numeric_limits
#include <map>
#include <string>
#include <algorithm> // for all_of, transform, fill
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int64_t, ...
#include <limits> // for numeric_limits
#include <map> // for map
#include <string> // for string
#include <type_traits> // for alignment_of, remove_pointer_t, invoke_result_t
#include <utility>
#include <vector>
#include <vector> // for vector

#include "../common/bitfield.h" // for RBitField8
#include "../common/common.h"
#include "../common/bitfield.h" // for RBitField8
#include "../common/error_msg.h" // for NoF128
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"
#include "xgboost/logging.h"
#include "xgboost/span.h"
#include "xgboost/json.h" // for Json
#include "xgboost/linalg.h" // for CalcStride, TensorView
#include "xgboost/logging.h" // for CHECK
#include "xgboost/span.h" // for Span
#include "xgboost/string_view.h" // for StringView

#if defined(XGBOOST_USE_CUDA)
#include "cuda_fp16.h"
#include "cuda_fp16.h" // for __half
#endif

namespace xgboost {
Expand Down Expand Up @@ -410,7 +407,7 @@ class ArrayInterface {
auto typestr = get<String const>(array.at("typestr"));
this->AssignType(StringView{typestr});
ArrayInterfaceHandler::ExtractShape(array, shape);
size_t itemsize = typestr[2] - '0';
std::size_t itemsize = typestr[2] - '0';
is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides);
n = linalg::detail::CalcSize(shape);

Expand All @@ -419,7 +416,9 @@ class ArrayInterface {

auto alignment = this->ElementAlignment();
auto ptr = reinterpret_cast<uintptr_t>(this->data);
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
if (!std::all_of(this->shape, this->shape + D, [](auto v) { return v == 0; })) {
CHECK_EQ(ptr % alignment, 0) << "Input pointer misalignment.";
}

if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
Expand Down

0 comments on commit 6094106

Please sign in to comment.