Support Python UDFs written in terms of rows (#9343)
DEPENDS on #9217

Introduces a row-like abstraction to the numba UDF pipeline, which enables functions of the following form:

```python
def f(row):
    return row['a'] + row['b']
```

to be applied to dataframes with the corresponding column labels using:

```python
df.apply(f, axis=1)
```


Removes the `nulludf` decorator and is therefore a breaking change. However, since `nulludf` was only recently introduced as a stopgap, the impact should be low. Users can still write functions the old way, but the inner function now requires the `numba.cuda.jit(device=True)` decorator to work when wrapped in a lambda:

```python
from numba import cuda

@cuda.jit(device=True)
def f(x, y):
    return x + y

df.apply(lambda row: f(row['a'], row['b']), axis=1)
```

Makes it so that pandas and cudf can consume the exact same UDF in the same way.
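As an illustration (a minimal sketch, not part of this change; it assumes a small DataFrame constructed inline), the same function object can be handed to both libraries:

```python
import pandas as pd

import cudf

def f(row):
    return row['a'] + row['b']

pdf = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
gdf = cudf.from_pandas(pdf)

# the identical UDF works on both frames
pandas_result = pdf.apply(f, axis=1)
cudf_result = gdf.apply(f, axis=1)
```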

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Graham Markall (https://github.com/gmarkall)
  - Ashwin Srinath (https://github.com/shwina)

URL: #9343
brandon-b-miller authored Oct 12, 2021
1 parent 5e46c7e commit 7fa2738
Showing 8 changed files with 435 additions and 205 deletions.
245 changes: 142 additions & 103 deletions docs/cudf/source/user_guide/guide-to-udfs.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dataframe.py
@@ -4856,7 +4856,7 @@ def apply(
if args or kwargs:
raise ValueError("args and kwargs are not yet supported.")

return cudf.Series(func(self))
return self._apply(func)

@applyutils.doc_apply()
def apply_rows(
19 changes: 14 additions & 5 deletions python/cudf/cudf/core/series.py
@@ -13,6 +13,7 @@
import cupy
import numpy as np
import pandas as pd
from numba import cuda
from pandas._config import get_option

import cudf
@@ -3372,7 +3373,10 @@ def apply(self, func, convert_dtype=True, args=(), **kwargs):
Notes
-----
UDFs are cached in memory to avoid recompilation. The first
call to the UDF will incur compilation overhead.
call to the UDF will incur compilation overhead. `func` may
call nested functions, but those nested functions must be decorated
with `numba.cuda.jit(device=True)`, otherwise numba will raise a
typing error.
Examples
--------
@@ -3425,16 +3429,21 @@ def apply(self, func, convert_dtype=True, args=(), **kwargs):
1 <NA>
2 4.5
dtype: float64
"""
if args or kwargs:
raise ValueError(
"UDFs using *args or **kwargs are not yet supported."
)

return super()._apply(func)
# these functions are generally written as functions of scalar
# values rather than rows. Rather than writing an entirely separate
# numba kernel that is not built around a row object, it's simpler
# to just turn this into the equivalent single-column dataframe case
name = self.name or "__temp_srname"
df = cudf.DataFrame({name: self})
f_ = cuda.jit(device=True)(func)

return df.apply(lambda row: f_(row[name]))

def applymap(self, udf, out_dtype=None):
"""Apply an elementwise function to transform the values in the Column.
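The `Series.apply` change above routes a scalar UDF through the same row-based machinery by treating the series as a single-column dataframe. A minimal usage sketch (not part of the diff; assumes a GPU and a series containing a null):

```python
import cudf

sr = cudf.Series([1, None, 3])

# a plain scalar UDF; Series.apply jit-compiles it as a device function
# and applies it through a single-column row kernel
def incr(x):
    return x + 1

result = sr.apply(incr)  # nulls propagate: the second element stays <NA>
```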
1 change: 1 addition & 0 deletions python/cudf/cudf/core/udf/lowering.py
@@ -269,6 +269,7 @@ def cast_masked_to_masked(context, builder, fromty, toty, val):


# Masked constructor for use in a kernel for testing
@lower_builtin(api.Masked, types.Boolean, types.boolean)
@lower_builtin(api.Masked, types.Number, types.boolean)
def masked_constructor(context, builder, sig, args):
ty = sig.return_type
187 changes: 139 additions & 48 deletions python/cudf/cudf/core/udf/pipeline.py
@@ -1,8 +1,10 @@
import math

import cachetools
import numpy as np
from numba import cuda
from numba.np import numpy_support
from numba.types import Tuple, boolean, int64, void
from numba.types import Record, Tuple, boolean, int64, void
from nvtx import annotate

from cudf.core.udf.api import Masked, pack_return
@@ -14,21 +16,67 @@
precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32)


def get_frame_row_type(fr):
"""
Get the numba `Record` type corresponding to a frame.
Models each column and its mask as a MaskedType and
models the row as a dictionary-like data structure
containing these MaskedTypes.
Large parts of this function are copied with comments
from the Numba internals and slightly modified to
account for the validity bools present in the final
struct.
"""

# Create the numpy structured type corresponding to the frame.
dtype = np.dtype([(name, col.dtype) for name, col in fr._data.items()])

fields = []
offset = 0

sizes = [val[0].itemsize for val in dtype.fields.values()]
for i, (name, info) in enumerate(dtype.fields.items()):
# *info* consists of the element dtype, its offset from the beginning
# of the record, and an optional "title" containing metadata.
# We ignore the offset in info because its value assumes no masking;
# instead, we compute the correct offset based on the masked type.
elemdtype = info[0]
title = info[2] if len(info) == 3 else None
ty = numpy_support.from_dtype(elemdtype)
infos = {
"type": MaskedType(ty),
"offset": offset,
"title": title,
}
fields.append((name, infos))

# increment offset by itemsize plus one byte for validity
offset += elemdtype.itemsize + 1

# Align the next member of the struct to be a multiple of the
# memory access size, per PTX ISA 7.4/5.4.5
if i < len(sizes) - 1:
next_itemsize = sizes[i + 1]
offset = int(math.ceil(offset / next_itemsize) * next_itemsize)

# Numba requires that structures are aligned for the CUDA target
_is_aligned_struct = True
return Record(fields, offset, _is_aligned_struct)
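To make the layout concrete, the offset computation above can be reproduced standalone (a sketch for illustration only, using a hypothetical frame with an int64 column `a` and a float64 column `b`): `a` occupies bytes 0-7 with its validity byte at offset 8, the next field is then aligned up to a multiple of 8, so `b` starts at offset 16 with its validity byte at offset 24.

```python
import math

import numpy as np

# hypothetical two-column frame: 'a' is int64, 'b' is float64
dtype = np.dtype([("a", np.int64), ("b", np.float64)])

offset = 0
sizes = [val[0].itemsize for val in dtype.fields.values()]
for i, (name, info) in enumerate(dtype.fields.items()):
    elemdtype = info[0]
    print(f"{name}: data at offset {offset}, "
          f"validity byte at offset {offset + elemdtype.itemsize}")
    # each field is followed by one byte holding its validity flag
    offset += elemdtype.itemsize + 1
    # align the next field to its own itemsize, as get_frame_row_type does
    if i < len(sizes) - 1:
        next_itemsize = sizes[i + 1]
        offset = int(math.ceil(offset / next_itemsize) * next_itemsize)
print(f"total record size: {offset} bytes")
```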


@annotate("NUMBA JIT", color="green", domain="cudf_python")
def get_udf_return_type(func, dtypes):
def get_udf_return_type(func, df):
"""
Get the return type of a masked UDF for a given set of argument dtypes. It
is assumed that a `MaskedType(dtype)` is passed to the function for each
input dtype.
"""
to_compiler_sig = tuple(
MaskedType(arg)
for arg in (numpy_support.from_dtype(np_type) for np_type in dtypes)
)
row_type = get_frame_row_type(df)

# Get the return type. The PTX is also returned by compile_udf, but is not
# needed here.
ptx, output_type = cudautils.compile_udf(func, to_compiler_sig)

ptx, output_type = cudautils.compile_udf(func, (row_type,))
if not isinstance(output_type, MaskedType):
numba_output_type = numpy_support.from_dtype(np.dtype(output_type))
else:
@@ -37,33 +85,6 @@ def get_udf_return_type(func, dtypes):
return numba_output_type


def nulludf(func):
"""
Mimic pandas API:
def f(x, y):
return x + y
df.apply(lambda row: f(row['x'], row['y']))
in this scheme, `row` is actually the whole dataframe
`DataFrame` sends `self` in as `row` and subsequently
we end up calling `f` on the resulting columns since
the dataframe is dict-like
"""

def wrapper(*args):
from cudf import DataFrame

# This probably creates copies but is fine for now
to_udf_table = DataFrame(
{idx: arg for idx, arg in zip(range(len(args)), args)}
)
# Frame._apply
return to_udf_table._apply(func)

return wrapper


def masked_array_type_from_col(col):
"""
Return a type representing a tuple of arrays,
@@ -109,8 +130,19 @@ def _kernel(retval, {input_columns}, {input_offsets}, size):
i = cuda.grid(1)
ret_data_arr, ret_mask_arr = retval
if i < size:
# Create a structured array with the desired fields
rows = cuda.local.array(1, dtype=row_type)
# one element of that array
row = rows[0]
{masked_input_initializers}
ret = {user_udf_call}
{row_initializers}
# pass the assembled row into the udf
ret = f_(row)
# pack up the return values and set them
ret_masked = pack_return(ret)
ret_data_arr[i] = ret_masked.value
ret_mask_arr[i] = ret_masked.valid
@@ -126,19 +158,52 @@ def _kernel(retval, {input_columns}, {input_offsets}, size):
masked_{idx} = Masked(d_{idx}[i], mask_get(m_{idx}, i + offset_{idx}))
"""

row_initializer_template = """\
row["{name}"] = masked_{idx}
"""

def _define_function(df, scalar_return=False):
# Create argument list for kernel
input_columns = ", ".join([f"input_col_{i}" for i in range(len(df._data))])
input_offsets = ", ".join([f"offset_{i}" for i in range(len(df._data))])

# Create argument list to pass to device function
args = ", ".join([f"masked_{i}" for i in range(len(df._data))])
user_udf_call = f"f_({args})"
def _define_function(fr, row_type, scalar_return=False):
"""
The kernel we want to JIT compile looks something like the following,
which is an example for two columns that both have nulls present
def _kernel(retval, input_col_0, input_col_1, offset_0, offset_1, size):
i = cuda.grid(1)
ret_data_arr, ret_mask_arr = retval
if i < size:
rows = cuda.local.array(1, dtype=row_type)
row = rows[0]
d_0, m_0 = input_col_0
masked_0 = Masked(d_0[i], mask_get(m_0, i + offset_0))
d_1, m_1 = input_col_1
masked_1 = Masked(d_1[i], mask_get(m_1, i + offset_1))
row["a"] = masked_0
row["b"] = masked_1
ret = f_(row)
ret_masked = pack_return(ret)
ret_data_arr[i] = ret_masked.value
ret_mask_arr[i] = ret_masked.valid
However, we do not always have two columns, and columns do not always
have an associated mask. Ideally, we would write one kernel and make use
of `*args` so that a single function could handle any number of columns,
but numba does not currently support `*args` and treats functions it JITs
as if `*args` were a single argument. Thus we are forced to write the
right functions dynamically at runtime and define them using `exec`.
"""
# Create argument list for kernel
input_columns = ", ".join([f"input_col_{i}" for i in range(len(fr._data))])
input_offsets = ", ".join([f"offset_{i}" for i in range(len(fr._data))])

# Generate the initializers for each device function argument
initializers = []
for i, col in enumerate(df._data.values()):
row_initializers = []
for i, (colname, col) in enumerate(fr._data.items()):
idx = str(i)
if col.mask is not None:
template = masked_input_initializer_template
@@ -149,14 +214,21 @@ def _define_function(df, scalar_return=False):

initializers.append(initializer)

row_initializer = row_initializer_template.format(
idx=idx, name=colname
)
row_initializers.append(row_initializer)

masked_input_initializers = "\n".join(initializers)
row_initializers = "\n".join(row_initializers)

# Incorporate all of the above into the kernel code template
d = {
"input_columns": input_columns,
"input_offsets": input_offsets,
"masked_input_initializers": masked_input_initializers,
"user_udf_call": user_udf_call,
"row_initializers": row_initializers,
"numba_rectype": row_type, # from global
}

return kernel_template.format(**d)
@@ -173,19 +245,32 @@ def compile_or_get(df, f):
If the UDF has already been compiled for the requested dtypes,
a cached version will be returned instead of running compilation.
CUDA kernels are void and do not return values. Thus, we need to
preallocate a column of the correct dtype and pass it in as one of
the kernel arguments. This creates a chicken-and-egg problem where
we need the column type to compile the kernel, but normally we would
be getting that type FROM compiling the kernel (and letting numba
determine it as a return value). As a workaround, we compile the UDF
itself outside the final kernel to invoke a full typing pass, which
unfortunately is difficult to do without running full compilation.
We then obtain the return type from that separate compilation and
use it to allocate an output column of the right dtype.
"""

# check to see if we already compiled this function
frame_dtypes = tuple(col.dtype for col in df._data.values())
cache_key = (
*cudautils.make_cache_key(f, frame_dtypes),
*(col.mask is None for col in df._data.values()),
*df._data.keys(),
)
if precompiled.get(cache_key) is not None:
kernel, scalar_return_type = precompiled[cache_key]
return kernel, scalar_return_type

numba_return_type = get_udf_return_type(f, frame_dtypes)
# precompile the user udf to get the right return type.
# could be a MaskedType or a scalar type.
numba_return_type = get_udf_return_type(f, df)

_is_scalar_return = not isinstance(numba_return_type, MaskedType)
scalar_return_type = (
@@ -194,9 +279,14 @@ def compile_or_get(df, f):
else numba_return_type.value_type
)

# this is the signature for the final full kernel compilation
sig = construct_signature(df, scalar_return_type)
f_ = cuda.jit(device=True)(f)

# this row type is used within the kernel to pack up the column and
# mask data into the dict like data structure the user udf expects
row_type = get_frame_row_type(df)

f_ = cuda.jit(device=True)(f)
# Dict of 'local' variables into which `_kernel` is defined
local_exec_context = {}
global_exec_context = {
Expand All @@ -205,9 +295,10 @@ def compile_or_get(df, f):
"Masked": Masked,
"mask_get": mask_get,
"pack_return": pack_return,
"row_type": row_type,
}
exec(
_define_function(df, scalar_return=_is_scalar_return),
_define_function(df, row_type, scalar_return=_is_scalar_return),
global_exec_context,
local_exec_context,
)
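The two-phase compilation described in the `compile_or_get` docstring can also be reproduced in isolation (a hedged sketch, not the cudf code path): numba is asked to compile the UDF on its own purely to learn the inferred return type, which the pipeline then uses to preallocate the output column before compiling the full kernel.

```python
from numba import cuda, types

def f(x):
    return x + 1.5

# compile the UDF by itself as a device function for an int64 argument;
# the generated PTX is discarded here, only the inferred return type matters
ptx, return_type = cuda.compile_ptx_for_current_device(
    f, (types.int64,), device=True
)
print(return_type)  # float64
```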
3 changes: 1 addition & 2 deletions python/cudf/cudf/core/udf/typing.py
@@ -111,10 +111,9 @@ def typeof_masked(val, c):
@cuda_decl_registry.register
class MaskedConstructor(ConcreteTemplate):
key = api.Masked

cases = [
nb_signature(MaskedType(t), t, types.boolean)
for t in (types.integer_domain | types.real_domain)
for t in (types.integer_domain | types.real_domain | {types.boolean})
]

