Skip to content

Commit

Permalink
Centralize function for generating a list vector of scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
vyasr committed Nov 29, 2023
1 parent 54a13e6 commit 4bed647
Showing 1 changed file with 19 additions and 25 deletions.
44 changes: 19 additions & 25 deletions python/cudf/cudf/_lib/pylibcudf/copying.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ from .table cimport Table
ctypedef const scalar constscalar


cdef vector[reference_wrapper[const scalar]] _make_scalar_vector(list source):
"""Make a vector of reference_wrapper[const scalar] from a list of scalars."""
if not isinstance(source, list) or not isinstance(source[0], Scalar):
raise ValueError("source must be a list[Scalar]")

cdef vector[reference_wrapper[const scalar]] c_scalars
c_scalars.reserve(len(source))
cdef Scalar slr
for slr in source:
c_scalars.push_back(
reference_wrapper[constscalar](dereference(slr.c_obj))
)
return c_scalars


# TODO: Is it OK to reference the corresponding libcudf algorithm in the
# documentation? Otherwise there's a lot of room for duplication.
cpdef Table gather(
Expand Down Expand Up @@ -98,24 +113,12 @@ cpdef Table table_scatter(Table source, Column scatter_map, Table target_table):
return Table.from_libcudf(move(c_result))


cdef _check_is_list_of_scalars(list source):
if not isinstance(source, list) or not isinstance(source[0], Scalar):
raise ValueError("source must be a list[Scalar]")


# TODO: Could generalize list to sequence
cpdef Table scalar_scatter(list source, Column scatter_map, Table target_table):
cdef unique_ptr[table] c_result
cdef vector[reference_wrapper[const scalar]] source_scalars
cdef Scalar slr

_check_is_list_of_scalars(source)

for slr in source:
source_scalars.push_back(
reference_wrapper[constscalar](dereference(slr.c_obj))
)
source_scalars = _make_scalar_vector(source)

cdef unique_ptr[table] c_result
with nogil:
c_result = move(
cpp_copying.scatter(
Expand Down Expand Up @@ -230,22 +233,13 @@ cpdef Table table_boolean_mask_scatter(Table input, Table target, Column boolean

# TODO: Could generalize list to sequence
cpdef Table scalar_boolean_mask_scatter(list input, Table target, Column boolean_mask):
_check_is_list_of_scalars(input)

cdef vector[reference_wrapper[const scalar]] c_scalars
c_scalars.reserve(len(input))

cdef Scalar slr
for slr in input:
c_scalars.push_back(
reference_wrapper[constscalar](dereference(slr.c_obj))
)
source_scalars = _make_scalar_vector(input)

cdef unique_ptr[table] result
with nogil:
result = move(
cpp_copying.boolean_mask_scatter(
c_scalars,
source_scalars,
target.view(),
boolean_mask.view(),
)
Expand Down

0 comments on commit 4bed647

Please sign in to comment.