Skip to content

Commit

Permalink
[REVIEW] Adapt to changes in cudf.core.buffer.Buffer (rapidsai#5154)
Browse files Browse the repository at this point in the history
This PR adapts to the breaking changes being introduced in: rapidsai/cudf#12587

Authors:
   - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
   - Lawrence Mitchell (https://github.com/wence-)
   - AJ Schmidt (https://github.com/ajschmidt8)
   - Dante Gama Dessavre (https://github.com/dantegd)
  • Loading branch information
galipremsagar authored Jan 26, 2023
1 parent 71edf93 commit 3d0fa32
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def subtract_valid(input_array, valid_bool_array, sub_val):
input_array[pos] = input_array[pos] - sub_val


@cudf.core.buffer.acquire_spill_lock()
def get_stem_series(word_str_ser, suffix_len, can_replace_mask):
"""
word_str_ser: input string column
Expand All @@ -95,8 +96,8 @@ def get_stem_series(word_str_ser, suffix_len, can_replace_mask):
start_series = cudf.Series(cp.zeros(len(word_str_ser), dtype=cp.int32))
end_ser = word_str_ser.str.len()

end_ar = end_ser._column.data_array_view
can_replace_mask_ar = can_replace_mask._column.data_array_view
end_ar = end_ser._column.data_array_view(mode="read")
can_replace_mask_ar = can_replace_mask._column.data_array_view(mode="read")

subtract_valid[NBLCK, NTHRD](end_ar, can_replace_mask_ar, suffix_len)
return word_str_ser.str.slice_from(
Expand Down
6 changes: 5 additions & 1 deletion python/cuml/tests/test_input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ def check_numpy_order(ary, order):
def check_ptr(a, b, input_type):
if input_type == 'cudf':
for (_, col_a), (_, col_b) in zip(a._data.items(), b._data.items()):
assert col_a.base_data.ptr == col_b.base_data.ptr
with cudf.core.buffer.acquire_spill_lock():
assert (
col_a.base_data.get_ptr(mode="read") ==
col_b.base_data.get_ptr(mode="read")
)
else:
def get_ptr(x):
try:
Expand Down

0 comments on commit 3d0fa32

Please sign in to comment.