From 22a9f350eb807cd83b1ee0b35b90aa64b6c28809 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Thu, 17 Mar 2022 09:18:27 -0700 Subject: [PATCH 1/3] Simplify column binary operations (#10421) This PR contains a large set of changes aimed at simplifying the logic used to normalize binary operands and perform binary operations at the column level. Many previously disparate code paths have been consolidated, while others have simply been moved into the narrowest possible scope so that they no longer affect all operations. In particular, all dtype-specific logic has been moved from `Frame._colwise_binop` into the appropriate column dtypes. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Michael Wang (https://github.com/isVoid) - Charles Blackmon-Luca (https://github.com/charlesbluca) URL: https://github.com/rapidsai/cudf/pull/10421 --- python/cudf/cudf/api/types.py | 2 +- python/cudf/cudf/core/column/categorical.py | 42 ++++--- python/cudf/cudf/core/column/column.py | 5 + python/cudf/cudf/core/column/datetime.py | 21 ++-- python/cudf/cudf/core/column/decimal.py | 84 +++++++------ python/cudf/cudf/core/column/lists.py | 5 +- python/cudf/cudf/core/column/numerical.py | 116 ++++++++++-------- python/cudf/cudf/core/column/string.py | 64 +++++++--- python/cudf/cudf/core/column/timedelta.py | 73 +++-------- python/cudf/cudf/core/dataframe.py | 3 +- python/cudf/cudf/core/frame.py | 101 ++------------- python/cudf/cudf/core/series.py | 2 +- python/cudf/cudf/tests/test_categorical.py | 11 +- python/cudf/cudf/tests/test_timedelta.py | 2 +- .../dask_cudf/tests/test_accessor.py | 7 +- 15 files changed, 230 insertions(+), 308 deletions(-) diff --git a/python/cudf/cudf/api/types.py b/python/cudf/cudf/api/types.py index 6d5387591cb..fad2a973681 100644 --- a/python/cudf/cudf/api/types.py +++ b/python/cudf/cudf/api/types.py @@ -248,7 +248,7 @@ def _union_categoricals( is_datetime64tz_dtype = pd_types.is_datetime64tz_dtype is_extension_type = 
pd_types.is_extension_type is_extension_array_dtype = pd_types.is_extension_array_dtype -is_float_dtype = pd_types.is_float_dtype +is_float_dtype = _wrap_pandas_is_dtype_api(pd_types.is_float_dtype) is_int64_dtype = pd_types.is_int64_dtype is_integer_dtype = _wrap_pandas_is_dtype_api(pd_types.is_integer_dtype) is_object_dtype = pd_types.is_object_dtype diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index 27cd6b631bc..caab2294484 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -878,38 +878,48 @@ def slice( def binary_operator( self, op: str, rhs, reflect: bool = False ) -> ColumnBase: - if not (self.ordered and rhs.ordered) and op not in ( - "eq", - "ne", - "NULL_EQUALS", - ): - if op in ("lt", "gt", "le", "ge"): - raise TypeError( - "Unordered Categoricals can only compare equality or not" - ) + if op not in {"eq", "ne", "lt", "le", "gt", "ge", "NULL_EQUALS"}: raise TypeError( - f"Series of dtype `{self.dtype}` cannot perform the " - f"operation: {op}" + "Series of dtype `category` cannot perform the operation: " + f"{op}" + ) + rhs = self._wrap_binop_normalization(rhs) + # TODO: This is currently just here to make mypy happy, but eventually + # we'll need to properly establish the APIs for these methods. + if not isinstance(rhs, CategoricalColumn): + raise ValueError + # Note: at this stage we are guaranteed that the dtypes are equal. + if not self.ordered and op not in {"eq", "ne", "NULL_EQUALS"}: + raise TypeError( + "The only binary operations supported by unordered " + "categorical columns are equality and inequality." 
) - if self.dtype != rhs.dtype: - raise TypeError("Categoricals can only compare with the same type") return self.as_numerical.binary_operator(op, rhs.as_numerical) def normalize_binop_value(self, other: ScalarLike) -> CategoricalColumn: - + if isinstance(other, column.ColumnBase): + if not isinstance(other, CategoricalColumn): + raise ValueError( + "Binary operations with categorical columns require both " + "columns to be categorical." + ) + if other.dtype != self.dtype: + raise TypeError( + "Categoricals can only compare with the same type" + ) + return other if isinstance(other, np.ndarray) and other.ndim == 0: other = other.item() ary = cudf.utils.utils.scalar_broadcast_to( self._encode(other), size=len(self), dtype=self.codes.dtype ) - col = column.build_categorical_column( + return column.build_categorical_column( categories=self.dtype.categories._values, codes=column.as_column(ary), mask=self.base_mask, ordered=self.dtype.ordered, ) - return col def sort_by_values( self, ascending: bool = True, na_position="last" diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 1c1845373e1..92075eef1b4 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -528,6 +528,11 @@ def __setitem__(self, key: Any, value: Any): if out: self._mimic_inplace(out, inplace=True) + def _wrap_binop_normalization(self, other): + if other is cudf.NA: + return cudf.Scalar(other, dtype=self.dtype) + return self.normalize_binop_value(other) + def _scatter_by_slice( self, key: Slice, value: Union[cudf.core.scalar.Scalar, ColumnBase] ) -> Optional[ColumnBase]: diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index 4ed296ceb52..63c934b2b6c 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -225,12 +225,11 @@ def round(self, freq: str) -> ColumnBase: return libcudf.datetime.round_datetime(self, freq) def 
normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: - if isinstance(other, cudf.Scalar): + if isinstance(other, (cudf.Scalar, ColumnBase, cudf.DateOffset)): return other if isinstance(other, np.ndarray) and other.ndim == 0: other = other.item() - if isinstance(other, dt.datetime): other = np.datetime64(other) elif isinstance(other, dt.timedelta): @@ -239,8 +238,7 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: other = other.to_datetime64() elif isinstance(other, pd.Timedelta): other = other.to_timedelta64() - elif isinstance(other, cudf.DateOffset): - return other + if isinstance(other, np.datetime64): if np.isnat(other): return cudf.Scalar(None, dtype=self.dtype) @@ -250,7 +248,7 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: elif isinstance(other, np.timedelta64): other_time_unit = cudf.utils.dtypes.get_time_unit(other) - if other_time_unit not in ("s", "ms", "ns", "us"): + if other_time_unit not in {"s", "ms", "ns", "us"}: other = other.astype("timedelta64[s]") if np.isnat(other): @@ -259,8 +257,8 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike: return cudf.Scalar(other) elif other is None: return cudf.Scalar(other, dtype=self.dtype) - else: - raise TypeError(f"cannot normalize {type(other)}") + + raise TypeError(f"cannot normalize {type(other)}") @property def as_numerical(self) -> "cudf.core.column.NumericalColumn": @@ -390,11 +388,13 @@ def binary_operator( rhs: Union[ColumnBase, "cudf.Scalar"], reflect: bool = False, ) -> ColumnBase: + rhs = self._wrap_binop_normalization(rhs) if isinstance(rhs, cudf.DateOffset): return rhs._datetime_binop(self, op, reflect=reflect) + lhs: Union[ScalarLike, ColumnBase] = self - if op in ("eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"): - out_dtype = cudf.dtype(np.bool_) # type: Dtype + if op in {"eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"}: + out_dtype: Dtype = cudf.dtype(np.bool_) elif op == "add" and 
pd.api.types.is_timedelta64_dtype(rhs.dtype): out_dtype = cudf.core.column.timedelta._timedelta_add_result_dtype( rhs, lhs @@ -418,8 +418,7 @@ def binary_operator( f" the operation {op}" ) - if reflect: - lhs, rhs = rhs, lhs + lhs, rhs = (self, rhs) if not reflect else (rhs, self) return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype) def fillna( diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index a31eaa52641..e011afbd0ff 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -63,46 +63,23 @@ def as_string_column( def binary_operator(self, op, other, reflect=False): if reflect: self, other = other, self - - if not isinstance( - other, - ( - DecimalBaseColumn, - cudf.core.column.NumericalColumn, - cudf.Scalar, - ), - ): - raise TypeError( - f"Operator {op} not supported between" - f"{str(type(self))} and {str(type(other))}" - ) - elif isinstance( - other, cudf.core.column.NumericalColumn - ) and not is_integer_dtype(other.dtype): - raise TypeError( - f"Only decimal and integer column is supported for {op}." - ) - if isinstance(other, cudf.core.column.NumericalColumn): - other = other.as_decimal_column( - self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0) - ) - if not isinstance(self.dtype, other.dtype.__class__): - if ( - self.dtype.precision == other.dtype.precision - and self.dtype.scale == other.dtype.scale - ): - other = other.astype(self.dtype) + # Decimals in libcudf don't support truediv, see + # https://github.com/rapidsai/cudf/pull/7435 for explanation. + op = op.replace("true", "") + other = self._wrap_binop_normalization(other) # Binary Arithmetics between decimal columns. 
`Scale` and `precision` # are computed outside of libcudf try: - if op in ("add", "sub", "mul", "div"): + if op in {"add", "sub", "mul", "div"}: output_type = _get_decimal_type(self.dtype, other.dtype, op) result = libcudf.binaryop.binaryop( self, other, op, output_type ) + # TODO: Why is this necessary? Why isn't the result's + # precision already set correctly based on output_type? result.dtype.precision = output_type.precision - elif op in ("eq", "ne", "lt", "gt", "le", "ge"): + elif op in {"eq", "ne", "lt", "gt", "le", "ge"}: result = libcudf.binaryop.binaryop(self, other, op, bool) except RuntimeError as e: if "Unsupported operator for these types" in str(e): @@ -140,14 +117,41 @@ def fillna( return result._with_type_metadata(self.dtype) def normalize_binop_value(self, other): - if is_scalar(other) and isinstance(other, (int, np.int, Decimal)): - return cudf.Scalar(Decimal(other)) - elif isinstance(other, cudf.Scalar) and isinstance( - other.dtype, cudf.core.dtypes.DecimalDtype + if isinstance(other, ColumnBase): + if isinstance(other, cudf.core.column.NumericalColumn): + if not is_integer_dtype(other.dtype): + raise TypeError( + "Decimal columns only support binary operations with " + "integer numerical columns." + ) + other = other.as_decimal_column( + self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0) + ) + elif not isinstance(other, DecimalBaseColumn): + raise TypeError( + f"Binary operations are not supported between" + f"{str(type(self))} and {str(type(other))}" + ) + elif not isinstance(self.dtype, other.dtype.__class__): + # This branch occurs if we have a DecimalBaseColumn of a + # different size (e.g. 64 instead of 32). + if ( + self.dtype.precision == other.dtype.precision + and self.dtype.scale == other.dtype.scale + ): + other = other.astype(self.dtype) + + return other + if isinstance(other, cudf.Scalar) and isinstance( + # TODO: Should it be possible to cast scalars of other numerical + # types to decimal? 
+ other.dtype, + cudf.core.dtypes.DecimalDtype, ): return other - else: - raise TypeError(f"cannot normalize {type(other)}") + elif is_scalar(other) and isinstance(other, (int, Decimal)): + return cudf.Scalar(Decimal(other)) + raise TypeError(f"cannot normalize {type(other)}") def _decimal_quantile( self, q: Union[float, Sequence[float]], interpolation: str, exact: bool @@ -256,12 +260,6 @@ def _with_type_metadata( class Decimal64Column(DecimalBaseColumn): dtype: Decimal64Dtype - def __truediv__(self, other): - # TODO: This override is not sufficient. While it will change the - # behavior of x / y for two decimal columns, it will not affect - # col1.binary_operator(col2), which is how Series/Index will call this. - return self.binary_operator("div", other) - def __setitem__(self, key, value): if isinstance(value, np.integer): value = int(value) diff --git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index a541967076d..53ab79542e2 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -133,7 +133,7 @@ def binary_operator( Name: val, dtype: list """ - + other = self._wrap_binop_normalization(other) if isinstance(other.dtype, ListDtype): if binop == "add": return concatenate_rows( @@ -254,6 +254,9 @@ def __cuda_array_interface__(self): "Lists are not yet supported via `__cuda_array_interface__`" ) + def normalize_binop_value(self, other): + return other + def _with_type_metadata( self: "cudf.core.column.ListColumn", dtype: Dtype ) -> "cudf.core.column.ListColumn": diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 4e8acbf2634..015524b841e 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -22,7 +22,12 @@ from cudf import _lib as libcudf from cudf._lib.stream_compaction import drop_nulls from cudf._typing import BinaryOperand, ColumnLike, Dtype, DtypeObj, ScalarLike -from 
cudf.api.types import is_integer_dtype, is_number +from cudf.api.types import ( + is_bool_dtype, + is_float_dtype, + is_integer_dtype, + is_number, +) from cudf.core.buffer import Buffer from cudf.core.column import ( ColumnBase, @@ -31,12 +36,7 @@ column, string, ) -from cudf.core.dtypes import ( - CategoricalDtype, - Decimal32Dtype, - Decimal64Dtype, - Decimal128Dtype, -) +from cudf.core.dtypes import CategoricalDtype from cudf.utils import cudautils, utils from cudf.utils.dtypes import ( NUMERIC_TYPES, @@ -153,53 +153,45 @@ def unary_operator(self, unaryop: Union[str, Callable]) -> ColumnBase: def binary_operator( self, binop: str, rhs: BinaryOperand, reflect: bool = False, ) -> ColumnBase: - int_dtypes = [ - cudf.dtype("int8"), - cudf.dtype("int16"), - cudf.dtype("int32"), - cudf.dtype("int64"), - cudf.dtype("uint8"), - cudf.dtype("uint16"), - cudf.dtype("uint32"), - cudf.dtype("uint64"), - ] - if rhs is None: - out_dtype = self.dtype - else: - if not ( - isinstance( - rhs, - ( - NumericalColumn, - cudf.Scalar, - cudf.core.column.DecimalBaseColumn, - ), - ) - or np.isscalar(rhs) - ): - msg = "{!r} operator not supported between {} and {}" - raise TypeError(msg.format(binop, type(self), type(rhs))) - if isinstance(rhs, cudf.core.column.Decimal128Column): - lhs: Union[ScalarLike, ColumnBase] = self.as_decimal_column( - Decimal128Dtype(Decimal128Dtype.MAX_PRECISION, 0) - ) - return lhs.binary_operator(binop, rhs) - elif isinstance(rhs, cudf.core.column.Decimal64Column): - lhs = self.as_decimal_column( - Decimal64Dtype(Decimal64Dtype.MAX_PRECISION, 0) - ) - return lhs.binary_operator(binop, rhs) - elif isinstance(rhs, cudf.core.column.Decimal32Column): - lhs = self.as_decimal_column( - Decimal32Dtype(Decimal32Dtype.MAX_PRECISION, 0) + int_float_dtype_mapping = { + np.int8: np.float32, + np.int16: np.float32, + np.int32: np.float32, + np.int64: np.float64, + np.uint8: np.float32, + np.uint16: np.float32, + np.uint32: np.float64, + np.uint64: np.float64, + np.bool_: 
np.float32, + } + + if binop in {"truediv", "rtruediv"}: + # Division with integer types results in a suitable float. + if (truediv_type := int_float_dtype_mapping.get(self.dtype.type)) : + return self.astype(truediv_type).binary_operator( + binop, rhs, reflect ) - return lhs.binary_operator(binop, rhs) + + rhs = self._wrap_binop_normalization(rhs) + out_dtype = self.dtype + if rhs is not None: + if isinstance(rhs, cudf.core.column.DecimalBaseColumn): + dtyp = rhs.dtype.__class__(rhs.dtype.MAX_PRECISION, 0) + return self.as_decimal_column(dtyp).binary_operator(binop, rhs) + out_dtype = np.result_type(self.dtype, rhs.dtype) - if binop in ["mod", "floordiv"]: + if binop in {"mod", "floordiv"}: tmp = self if reflect else rhs - if (tmp.dtype in int_dtypes) and ( - (np.isscalar(tmp) and (0 == tmp)) - or ((isinstance(tmp, NumericalColumn)) and (0.0 in tmp)) + # Guard against division by zero for integers. + if ( + (tmp.dtype.type in int_float_dtype_mapping) + and (tmp.dtype.type != np.bool_) + and ( + (np.isscalar(tmp) and (0 == tmp)) + or ( + (isinstance(tmp, NumericalColumn)) and (0.0 in tmp) + ) + ) ): out_dtype = cudf.dtype("float64") @@ -215,6 +207,17 @@ def binary_operator( "NULL_EQUALS", }: out_dtype = "bool" + + if binop in {"and", "or", "xor"}: + if is_float_dtype(self.dtype) or is_float_dtype(rhs): + raise TypeError( + f"Operation 'bitwise {binop}' not supported between " + f"{self.dtype.type.__name__} and " + f"{rhs.dtype.type.__name__}" + ) + if is_bool_dtype(self.dtype) or is_bool_dtype(rhs): + out_dtype = "bool" + lhs, rhs = (self, rhs) if not reflect else (rhs, self) return libcudf.binaryop.binaryop(lhs, rhs, binop, out_dtype) @@ -228,6 +231,15 @@ def nans_to_nulls(self: NumericalColumn) -> NumericalColumn: def normalize_binop_value( self, other: ScalarLike ) -> Union[ColumnBase, ScalarLike]: + if isinstance(other, ColumnBase): + if not isinstance( + other, (NumericalColumn, cudf.core.column.DecimalBaseColumn,), + ): + raise TypeError( + f"Binary 
operations are not supported between " + f"{type(self)}and {type(other)}" + ) + return other if other is None: return other if isinstance(other, cudf.Scalar): @@ -410,7 +422,7 @@ def find_and_replace( if replacement_col.null_count == len(replacement_col): replacement_col = replacement_col.astype(self.dtype) - if type(to_replace_col) != type(replacement_col): + if not isinstance(to_replace_col, type(replacement_col)): raise TypeError( f"to_replace and value should be of same types," f"got to_replace dtype: {to_replace_col.dtype} and " diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 8f017376c6a..82be924dfbc 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -5426,32 +5426,57 @@ def find_first_value( def find_last_value(self, value: ScalarLike, closest: bool = False) -> int: return self._find_first_and_last(value)[1] - def normalize_binop_value(self, other) -> "column.ColumnBase": - # fastpath: gpu scalar - if isinstance(other, cudf.Scalar) and other.dtype == "object": - return column.as_column(other, length=len(self)) - if isinstance(other, column.ColumnBase): - return other.astype(self.dtype) - elif isinstance(other, str) or other is None: - col = utils.scalar_broadcast_to( + def normalize_binop_value( + self, other + ) -> Union[column.ColumnBase, cudf.Scalar]: + if ( + isinstance(other, (column.ColumnBase, cudf.Scalar)) + and other.dtype == "object" + ): + return other + if isinstance(other, str) or other is None: + return utils.scalar_broadcast_to( other, size=len(self), dtype="object" ) - return col - elif isinstance(other, np.ndarray) and other.ndim == 0: - col = utils.scalar_broadcast_to( + if isinstance(other, np.ndarray) and other.ndim == 0: + return utils.scalar_broadcast_to( other.item(), size=len(self), dtype="object" ) - return col - else: - raise TypeError(f"cannot broadcast {type(other)}") + raise TypeError(f"cannot broadcast {type(other)}") def 
binary_operator( self, op: str, rhs, reflect: bool = False ) -> "column.ColumnBase": - lhs = self - if reflect: - lhs, rhs = rhs, lhs + # Handle object columns that are empty or all nulls when performing + # binary operations + # See https://github.com/pandas-dev/pandas/issues/46332 + if self.null_count == len(self): + if op in { + "add", + "sub", + "mul", + "mod", + "pow", + "truediv", + "floordiv", + "radd", + "rsub", + "rmul", + "rmod", + "rpow", + "rtruediv", + "rfloordiv", + }: + return self + elif op in {"eq", "lt", "le", "gt", "ge"}: + return self.notnull() + elif op == "ne": + return self.isnull() + + rhs = self._wrap_binop_normalization(rhs) + if isinstance(rhs, (StringColumn, str, cudf.Scalar)): + lhs, rhs = (rhs, self) if reflect else (self, rhs) if op == "add": return cast( "column.ColumnBase", @@ -5461,13 +5486,12 @@ def binary_operator( na_rep=cudf.Scalar(None, "str"), ), ) - elif op in ("eq", "ne", "gt", "lt", "ge", "le", "NULL_EQUALS"): + elif op in {"eq", "ne", "gt", "lt", "ge", "le", "NULL_EQUALS"}: return libcudf.binaryop.binaryop( lhs=lhs, rhs=rhs, op=op, dtype="bool" ) - raise TypeError( - f"{op} operator not supported between {type(self)} and {type(rhs)}" + f"{op} not supported between {type(self)} and {type(rhs)}" ) @copy_docstring(column.ColumnBase.view) diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index 7a5f777e88e..d6aedf7f4f4 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -126,35 +126,6 @@ def to_pandas( return pd_series - def _binary_op_floordiv( - self, rhs: BinaryOperand - ) -> Tuple["column.ColumnBase", BinaryOperand, DtypeObj]: - lhs = self # type: column.ColumnBase - if pd.api.types.is_timedelta64_dtype(rhs.dtype): - common_dtype = determine_out_dtype(self.dtype, rhs.dtype) - lhs = lhs.astype(common_dtype).astype("float64") - if isinstance(rhs, cudf.Scalar): - if rhs.is_valid(): - rhs = cudf.Scalar( - 
np.timedelta64(rhs.value) - .astype(common_dtype) - .astype("float64") - ) - else: - rhs = cudf.Scalar(None, "float64") - else: - rhs = rhs.astype(common_dtype).astype("float64") - out_dtype = cudf.dtype("int64") - elif rhs.dtype.kind in ("f", "i", "u"): - out_dtype = self.dtype - else: - raise TypeError( - f"Floor Division of {self.dtype} with {rhs.dtype} " - f"cannot be performed." - ) - - return lhs, rhs, out_dtype - def _binary_op_mul(self, rhs: BinaryOperand) -> DtypeObj: if rhs.dtype.kind in ("f", "i", "u"): out_dtype = self.dtype @@ -177,27 +148,16 @@ def _binary_op_mod(self, rhs: BinaryOperand) -> DtypeObj: ) return out_dtype - def _binary_op_eq_ne(self, rhs: BinaryOperand) -> DtypeObj: - if pd.api.types.is_timedelta64_dtype(rhs.dtype): - out_dtype = np.bool_ - else: - raise TypeError( - f"Equality of {self.dtype} with {rhs.dtype} " - f"cannot be performed." - ) - return out_dtype - - def _binary_op_lt_gt_le_ge(self, rhs: BinaryOperand) -> DtypeObj: + def _binary_op_lt_gt_le_ge_eq_ne(self, rhs: BinaryOperand) -> DtypeObj: if pd.api.types.is_timedelta64_dtype(rhs.dtype): return np.bool_ - else: - raise TypeError( - f"Invalid comparison between dtype={self.dtype}" - f" and {rhs.dtype}" - ) + raise TypeError( + f"Invalid comparison between dtype={self.dtype}" + f" and {rhs.dtype}" + ) - def _binary_op_truediv( - self, rhs: BinaryOperand + def _binary_op_div( + self, rhs: BinaryOperand, op: str ) -> Tuple["column.ColumnBase", BinaryOperand, DtypeObj]: lhs = self # type: column.ColumnBase if pd.api.types.is_timedelta64_dtype(rhs.dtype): @@ -211,7 +171,7 @@ def _binary_op_truediv( else: rhs = rhs.astype(common_dtype).astype("float64") - out_dtype = cudf.dtype("float64") + out_dtype = cudf.dtype("float64" if op == "truediv" else "int64") elif rhs.dtype.kind in ("f", "i", "u"): out_dtype = self.dtype else: @@ -225,20 +185,17 @@ def _binary_op_truediv( def binary_operator( self, op: str, rhs: BinaryOperand, reflect: bool = False ) -> "column.ColumnBase": + rhs = 
self._wrap_binop_normalization(rhs) lhs, rhs = self, rhs - if op in ("eq", "ne"): - out_dtype = self._binary_op_eq_ne(rhs) - elif op in ("lt", "gt", "le", "ge", "NULL_EQUALS"): - out_dtype = self._binary_op_lt_gt_le_ge(rhs) + if op in {"eq", "ne", "lt", "gt", "le", "ge", "NULL_EQUALS"}: + out_dtype = self._binary_op_lt_gt_le_ge_eq_ne(rhs) elif op == "mul": out_dtype = self._binary_op_mul(rhs) elif op == "mod": out_dtype = self._binary_op_mod(rhs) - elif op == "truediv": - lhs, rhs, out_dtype = self._binary_op_truediv(rhs) # type: ignore - elif op == "floordiv": - lhs, rhs, out_dtype = self._binary_op_floordiv(rhs) # type: ignore + elif op in {"truediv", "floordiv"}: + lhs, rhs, out_dtype = self._binary_op_div(rhs, op) # type: ignore op = "truediv" elif op == "add": out_dtype = _timedelta_add_result_dtype(lhs, rhs) @@ -256,12 +213,10 @@ def binary_operator( return libcudf.binaryop.binaryop(lhs, rhs, op, out_dtype) def normalize_binop_value(self, other) -> BinaryOperand: - if isinstance(other, cudf.Scalar): + if isinstance(other, (ColumnBase, cudf.Scalar)): return other - if isinstance(other, np.ndarray) and other.ndim == 0: other = other.item() - if isinstance(other, dt.timedelta): other = np.timedelta64(other) elif isinstance(other, pd.Timestamp): diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 57d591dd3e7..39ae9c774e5 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -1843,7 +1843,8 @@ def _make_operands_and_index_for_binop( # implementation assumes that binary operations between a column and # NULL are always commutative, even for binops (like subtraction) that # are normally anticommutative. - # TODO: We probably should support pandas DataFrame/Series objects. + # TODO: The above should no longer be necessary once we switch to + # properly invoking the operator since we can then rely on reflection. 
if isinstance(rhs, Sequence): # TODO: Consider validating sequence length (pandas does). operands = { diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 07cc3ea71cd..84b3bc03fbf 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -33,7 +33,6 @@ is_dict_like, is_dtype_equal, is_scalar, - issubdtype, ) from cudf.core.column import ( ColumnBase, @@ -3428,65 +3427,9 @@ def _colwise_binop( col, (left_column, right_column, reflect, fill_value), ) in operands.items(): - - # Handle object columns that are empty or - # all nulls when performing binary operations - if ( - left_column.dtype == "object" - and left_column.null_count == len(left_column) - and fill_value is None - ): - if fn in ( - "add", - "sub", - "mul", - "mod", - "pow", - "truediv", - "floordiv", - ): - output[col] = left_column - elif fn in ("eq", "lt", "le", "gt", "ge"): - output[col] = left_column.notnull() - elif fn == "ne": - output[col] = left_column.isnull() - continue - - if right_column is cudf.NA: - right_column = cudf.Scalar( - right_column, dtype=left_column.dtype - ) - elif not isinstance(right_column, ColumnBase): - right_column = left_column.normalize_binop_value(right_column) - - fn_apply = fn - if fn == "truediv": - # Decimals in libcudf don't support truediv, see - # https://github.com/rapidsai/cudf/pull/7435 for explanation. - if is_decimal_dtype(left_column.dtype): - fn_apply = "div" - - # Division with integer types results in a suitable float. 
- truediv_type = { - np.int8: np.float32, - np.int16: np.float32, - np.int32: np.float32, - np.int64: np.float64, - np.uint8: np.float32, - np.uint16: np.float32, - np.uint32: np.float64, - np.uint64: np.float64, - np.bool_: np.float32, - }.get(left_column.dtype.type) - if truediv_type is not None: - left_column = left_column.astype(truediv_type) - output_mask = None if fill_value is not None: - if is_scalar(right_column): - if left_column.nullable: - left_column = left_column.fillna(fill_value) - else: + if isinstance(right_column, ColumnBase): # If both columns are nullable, pandas semantics dictate # that nulls that are present in both left_column and # right_column are not filled. @@ -3500,42 +3443,15 @@ def _colwise_binop( left_column = left_column.fillna(fill_value) elif right_column.nullable: right_column = right_column.fillna(fill_value) + else: + if left_column.nullable: + left_column = left_column.fillna(fill_value) - # For bitwise operations we must verify whether the input column - # types are valid, and if so, whether we need to coerce the output - # columns to booleans. - coerce_to_bool = False - if fn_apply in {"and", "or", "xor"}: - err_msg = ( - f"Operation 'bitwise {fn_apply}' not supported between " - f"{left_column.dtype.type.__name__} and {{}}" - ) - if right_column is None: - raise TypeError(err_msg.format(type(None))) - - try: - left_is_bool = issubdtype(left_column.dtype, np.bool_) - right_is_bool = issubdtype(right_column.dtype, np.bool_) - except TypeError: - raise TypeError(err_msg.format(type(right_column))) - - coerce_to_bool = left_is_bool or right_is_bool - - if not ( - (left_is_bool or issubdtype(left_column.dtype, np.integer)) - and ( - right_is_bool - or issubdtype(right_column.dtype, np.integer) - ) - ): - raise TypeError( - err_msg.format(right_column.dtype.type.__name__) - ) + # TODO: Disable logical and binary operators between columns that + # are not numerical using the new binops mixin. 
outcol = ( - left_column.binary_operator( - fn_apply, right_column, reflect=reflect - ) + left_column.binary_operator(fn, right_column, reflect=reflect) if right_column is not None else column_empty( left_column.size, left_column.dtype, masked=True @@ -3545,9 +3461,6 @@ def _colwise_binop( if output_mask is not None: outcol = outcol.set_mask(output_mask) - if coerce_to_bool: - outcol = outcol.astype(np.bool_) - output[col] = outcol return output diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index b3b73b8961c..ef5850ecc17 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -2655,7 +2655,7 @@ def value_counts( 3.0 0.500000 2.0 0.333333 1.0 0.166667 - dtype: float64 + dtype: float32 To include ``NA`` value counts, pass ``dropna=False``: diff --git a/python/cudf/cudf/tests/test_categorical.py b/python/cudf/cudf/tests/test_categorical.py index 19a5cd4a49d..5bceaac45c7 100644 --- a/python/cudf/cudf/tests/test_categorical.py +++ b/python/cudf/cudf/tests/test_categorical.py @@ -133,6 +133,7 @@ def test_categorical_compare_unordered(): rfunc=operator.lt, lfunc_args_and_kwargs=([pdsr, pdsr],), rfunc_args_and_kwargs=([sr, sr],), + compare_error_message=False, ) @@ -178,9 +179,7 @@ def test_categorical_binary_add(): rfunc=operator.add, lfunc_args_and_kwargs=([pdsr, pdsr],), rfunc_args_and_kwargs=([sr, sr],), - expected_error_message=( - "Series of dtype `category` cannot perform the operation: add" - ), + compare_error_message=False, ) @@ -258,9 +257,7 @@ def test_cat_series_binop_error(): rfunc=operator.add, lfunc_args_and_kwargs=([pdf["a"], pdf["b"]],), rfunc_args_and_kwargs=([df["a"], df["b"]],), - expected_error_message=( - "Series of dtype `category` cannot perform the operation: add" - ), + compare_error_message=False, ) # lhs is numerical @@ -269,7 +266,7 @@ def test_cat_series_binop_error(): rfunc=operator.add, lfunc_args_and_kwargs=([pdf["b"], pdf["a"]],), rfunc_args_and_kwargs=([df["b"], df["a"]],), - 
expected_error_message="'add' operator not supported", + compare_error_message=False, ) diff --git a/python/cudf/cudf/tests/test_timedelta.py b/python/cudf/cudf/tests/test_timedelta.py index 2dc7bdaeba4..e371cd16180 100644 --- a/python/cudf/cudf/tests/test_timedelta.py +++ b/python/cudf/cudf/tests/test_timedelta.py @@ -1252,7 +1252,7 @@ def test_timedelta_invalid_ops(): lfunc_args_and_kwargs=([psr, dt_psr],), rfunc_args_and_kwargs=([sr, dt_sr],), expected_error_message=re.escape( - f"Floor Division of {sr.dtype} with {dt_sr.dtype} " + f"Division of {sr.dtype} with {dt_sr.dtype} " f"cannot be performed." ), ) diff --git a/python/dask_cudf/dask_cudf/tests/test_accessor.py b/python/dask_cudf/dask_cudf/tests/test_accessor.py index c7342818610..db4b655fcbd 100644 --- a/python/dask_cudf/dask_cudf/tests/test_accessor.py +++ b/python/dask_cudf/dask_cudf/tests/test_accessor.py @@ -1,3 +1,5 @@ +# Copyright (c) 2019-2022, NVIDIA CORPORATION. + import numpy as np import pandas as pd import pytest @@ -190,7 +192,10 @@ def test_categorical_compare_unordered(data): with pytest.raises( (TypeError, ValueError), - match="Unordered Categoricals can only compare equality or not", + match=( + "The only binary operations supported by unordered categorical " + "columns are equality and inequality." + ), ): dsr < dsr From 9a60671038641b917fbd0f7b3400eb877cb7f9e7 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Thu, 17 Mar 2022 19:50:17 -0400 Subject: [PATCH 2/3] Fix cudf::shift to handle offset greater than column size (#10414) Closes #10314 Fixes logic to handle `abs(offset) > input.size()` when passed to `cudf::shift`. As mentioned in #10314 this was causing an unexpected exception: ``` C++ exception with description "parallel_for failed: cudaErrorInvalidConfiguration: invalid configuration argument" thrown in the test body. ``` The behavior now fills the entire output column with the input scalar value. 
If the scalar is null, then the column is filled with null entries. The logic added here did not require changing or adding any new kernel functions. Additional gtests were added to `shift_tests.cpp` as well. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - MithunR (https://github.com/mythrocks) - Yunsong Wang (https://github.com/PointKernel) URL: https://github.com/rapidsai/cudf/pull/10414 --- cpp/include/cudf/copying.hpp | 4 +- cpp/src/copying/shift.cu | 5 +- cpp/src/strings/copying/shift.cu | 5 +- cpp/tests/copying/shift_tests.cpp | 78 ++++++++++++++------------ python/cudf/cudf/core/column/column.py | 8 --- 5 files changed, 52 insertions(+), 48 deletions(-) diff --git a/cpp/include/cudf/copying.hpp b/cpp/include/cudf/copying.hpp index 850a11426af..2e559afef4f 100644 --- a/cpp/include/cudf/copying.hpp +++ b/cpp/include/cudf/copying.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -371,7 +371,7 @@ std::unique_ptr copy_range( * @param fill_value Fill value for indeterminable outputs. * @param mr Device memory resource used to allocate the returned result's device memory * - * @throw cudf::logic_error if @p input dtype is not fixed-with. + * @throw cudf::logic_error if @p input dtype is neither fixed-width nor string type * @throw cudf::logic_error if @p fill_value dtype does not match @p input dtype. */ std::unique_ptr shift( diff --git a/cpp/src/copying/shift.cu b/cpp/src/copying/shift.cu index dacc1d07447..38fb16f66f4 100644 --- a/cpp/src/copying/shift.cu +++ b/cpp/src/copying/shift.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,7 +70,7 @@ struct shift_functor { std::unique_ptr> operator()(Args&&...) { - CUDF_FAIL("shift does not support non-fixed-width types."); + CUDF_FAIL("shift only supports fixed-width or string types."); } template @@ -125,6 +125,7 @@ struct shift_functor { // avoid assigning elements we know to be invalid. if (not scalar_is_valid) { + if (std::abs(offset) > size) { return output; } if (offset > 0) { index_begin = thrust::make_counting_iterator(offset); data = data + offset; diff --git a/cpp/src/strings/copying/shift.cu b/cpp/src/strings/copying/shift.cu index 024c8d2924d..bdcf01bd336 100644 --- a/cpp/src/strings/copying/shift.cu +++ b/cpp/src/strings/copying/shift.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -94,6 +94,9 @@ std::unique_ptr shift(strings_column_view const& input, { auto d_fill_str = static_cast(fill_value).value(stream); + // adjust offset when greater than the size of the input + if (std::abs(offset) > input.size()) { offset = input.size(); } + // output offsets column is the same size as the input auto const input_offsets = cudf::detail::slice( diff --git a/cpp/tests/copying/shift_tests.cpp b/cpp/tests/copying/shift_tests.cpp index 256f9129cbf..47615a584a4 100644 --- a/cpp/tests/copying/shift_tests.cpp +++ b/cpp/tests/copying/shift_tests.cpp @@ -20,16 +20,14 @@ #include #include -#include #include #include +#include #include -#include #include #include -#include using cudf::test::fixed_width_column_wrapper; using TestTypes = cudf::test::Types; @@ -72,28 +70,12 @@ constexpr auto lowest() } template -struct ShiftTest : public cudf::test::BaseFixture { +struct ShiftTestsTyped : public cudf::test::BaseFixture { }; -TYPED_TEST_SUITE(ShiftTest, cudf::test::FixedWidthTypes); +TYPED_TEST_SUITE(ShiftTestsTyped, cudf::test::FixedWidthTypes); -TYPED_TEST(ShiftTest, OneColumnEmpty) -{ - using T = TypeParam; - - std::vector vals{}; - std::vector mask{}; - - auto input = fixed_width_column_wrapper{}; - auto expected = fixed_width_column_wrapper(vals.begin(), vals.end(), mask.begin()); - - auto fill = make_scalar(); - auto actual = cudf::shift(input, 5, *fill); - - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual); -} - -TYPED_TEST(ShiftTest, TwoColumnsEmpty) +TYPED_TEST(ShiftTestsTyped, ColumnEmpty) { using T = TypeParam; @@ -109,7 +91,7 @@ TYPED_TEST(ShiftTest, TwoColumnsEmpty) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual); } -TYPED_TEST(ShiftTest, OneColumn) +TYPED_TEST(ShiftTestsTyped, NonNullColumn) { using T = TypeParam; @@ -134,7 +116,7 @@ TYPED_TEST(ShiftTest, OneColumn) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual); } -TYPED_TEST(ShiftTest, OneColumnNegativeShift) +TYPED_TEST(ShiftTestsTyped, NegativeShift) { using T = TypeParam; @@ -159,7 +141,7 
@@ TYPED_TEST(ShiftTest, OneColumnNegativeShift) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual); } -TYPED_TEST(ShiftTest, OneColumnNullFill) +TYPED_TEST(ShiftTestsTyped, NullScalar) { using T = TypeParam; @@ -186,7 +168,7 @@ TYPED_TEST(ShiftTest, OneColumnNullFill) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual); } -TYPED_TEST(ShiftTest, TwoColumnsNullableInput) +TYPED_TEST(ShiftTestsTyped, NullableColumn) { using T = TypeParam; @@ -199,25 +181,21 @@ TYPED_TEST(ShiftTest, TwoColumnsNullableInput) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *actual); } -TYPED_TEST(ShiftTest, MismatchFillValueDtypes) +TYPED_TEST(ShiftTestsTyped, MismatchFillValueDtypes) { using T = TypeParam; - if (std::is_same_v) { return; } - auto input = fixed_width_column_wrapper{}; - auto fill = make_scalar(); + auto fill = cudf::string_scalar(""); - std::unique_ptr output; - - EXPECT_THROW(output = cudf::shift(input, 5, *fill), cudf::logic_error); + EXPECT_THROW(cudf::shift(input, 5, fill), cudf::logic_error); } -struct ShiftTestNonFixedWidth : public cudf::test::BaseFixture { +struct ShiftTests : public cudf::test::BaseFixture { }; -TEST_F(ShiftTestNonFixedWidth, StringsShiftTest) +TEST_F(ShiftTests, StringsShiftTest) { auto input = cudf::test::strings_column_wrapper({"", "bb", "ccc", "ddddddé", ""}, {0, 1, 1, 1, 0}); @@ -243,3 +221,33 @@ TEST_F(ShiftTestNonFixedWidth, StringsShiftTest) auto sliced_left = cudf::test::strings_column_wrapper({"ccc", "ddddddé", "xx"}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(sliced_left, *results); } + +TEST_F(ShiftTests, OffsetGreaterThanSize) +{ + auto const input_str = + cudf::test::strings_column_wrapper({"", "bb", "ccc", "ddé", ""}, {0, 1, 1, 1, 0}); + auto results = cudf::shift(input_str, 6, cudf::string_scalar("xx")); + auto expected_str = cudf::test::strings_column_wrapper({"xx", "xx", "xx", "xx", "xx"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_str, *results); + results = cudf::shift(input_str, -6, cudf::string_scalar("xx")); + 
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_str, *results); + + results = cudf::shift(input_str, 6, cudf::string_scalar("", false)); + expected_str = cudf::test::strings_column_wrapper({"", "", "", "", ""}, {0, 0, 0, 0, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_str, *results); + results = cudf::shift(input_str, -6, cudf::string_scalar("", false)); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_str, *results); + + auto const input = fixed_width_column_wrapper({0, 2, 3, 4, 0}, {0, 1, 1, 1, 0}); + results = cudf::shift(input, 6, cudf::numeric_scalar(9)); + auto expected = fixed_width_column_wrapper({9, 9, 9, 9, 9}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, *results); + results = cudf::shift(input, -6, cudf::numeric_scalar(9)); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, *results); + + results = cudf::shift(input, 6, cudf::numeric_scalar(0, false)); + expected = fixed_width_column_wrapper({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, *results); + results = cudf::shift(input, -6, cudf::numeric_scalar(0, false)); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, *results); +} diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 92075eef1b4..01a450ce1d0 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -349,14 +349,6 @@ def _fill( return self def shift(self, offset: int, fill_value: ScalarLike) -> ColumnBase: - # libcudf currently doesn't handle case when offset > len(df) - # ticket to fix the bug in link below: - # https://github.com/rapidsai/cudf/issues/10314 - if abs(offset) > len(self): - if fill_value is None: - return column_empty_like(self, masked=True) - else: - return full(len(self), fill_value, dtype=self.dtype) return libcudf.copying.shift(self, offset, fill_value) @property From 621d26faa6a1b0831e071b61fd0a9b7f1493c195 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> 
Date: Thu, 17 Mar 2022 19:51:49 -0400 Subject: [PATCH 3/3] Add nvtext::byte_pair_encoding API (#10270) Reference #9657 Add the `nvtext::byte_pair_encoding` API. This is not the BPE tokenizer but just the encoding function. The tokenizer will be a larger effort that will probably span multiple PRs. Providing the encoder here to be evaluated independently. Theoretically, this API could be used like the following to achieve a _similar_ BPE tokenizer behavior perhaps: ``` input = strings to tokenize mps = nvtext::load_merge_pairs_file("merges.txt"); bpe = nvtext::byte_pair_encoding( input, mps ); vocab = nvtext::load_vocabulary_file( "hashed_vocab.txt" ); result = nvtext::subword_tokenize( bpe, vocab, max_length, stride, lower_case, truncate, max_rows ); ``` Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Bradley Dice (https://github.com/bdice) - https://github.com/nvdbaranec URL: https://github.com/rapidsai/cudf/pull/10270 --- cpp/CMakeLists.txt | 2 + cpp/include/cudf/strings/detail/combine.hpp | 16 +- cpp/include/nvtext/bpe_tokenize.hpp | 122 +++++ cpp/src/text/subword/bpe_tokenizer.cu | 556 ++++++++++++++++++++ cpp/src/text/subword/bpe_tokenizer.cuh | 59 +++ cpp/src/text/subword/load_merges_file.cu | 187 +++++++ cpp/tests/CMakeLists.txt | 1 + cpp/tests/text/bpe_tests.cpp | 111 ++++ 8 files changed, 1053 insertions(+), 1 deletion(-) create mode 100644 cpp/include/nvtext/bpe_tokenize.hpp create mode 100644 cpp/src/text/subword/bpe_tokenizer.cu create mode 100644 cpp/src/text/subword/bpe_tokenizer.cuh create mode 100644 cpp/src/text/subword/load_merges_file.cu create mode 100644 cpp/tests/text/bpe_tests.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8b8198782ba..5e523cd4e65 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -480,8 +480,10 @@ add_library( src/text/normalize.cu src/text/replace.cu src/text/stemmer.cu + src/text/subword/bpe_tokenizer.cu 
src/text/subword/data_normalizer.cu src/text/subword/load_hash_file.cu + src/text/subword/load_merges_file.cu src/text/subword/subword_tokenize.cu src/text/subword/wordpiece_tokenizer.cu src/text/tokenize.cu diff --git a/cpp/include/cudf/strings/detail/combine.hpp b/cpp/include/cudf/strings/detail/combine.hpp index d6bdf398886..50f9a70e21c 100644 --- a/cpp/include/cudf/strings/detail/combine.hpp +++ b/cpp/include/cudf/strings/detail/combine.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,6 +54,20 @@ std::unique_ptr join_strings( rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +/** + * @copydoc join_list_elements(table_view const&,string_scalar const&,string_scalar + * const&,separator_on_nulls,output_if_empty_list,rmm::mr::device_memory_resource*) + * + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr join_list_elements(lists_column_view const& lists_strings_column, + string_scalar const& separator, + string_scalar const& narep, + separator_on_nulls separate_nulls, + output_if_empty_list empty_list_policy, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); + } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/include/nvtext/bpe_tokenize.hpp b/cpp/include/nvtext/bpe_tokenize.hpp new file mode 100644 index 00000000000..23fcd3acd03 --- /dev/null +++ b/cpp/include/nvtext/bpe_tokenize.hpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace nvtext { + +/** + * @addtogroup nvtext_tokenize + * @{ + * @file + */ + +/** + * @brief The table of merge pairs for the BPE encoder. + * + * To create an instance, call @ref nvtext::load_merge_pairs_file + */ +struct bpe_merge_pairs { + struct bpe_merge_pairs_impl; + std::unique_ptr impl{}; + + bpe_merge_pairs(std::unique_ptr&& input, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + + bpe_merge_pairs(cudf::strings_column_view const& input, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + + ~bpe_merge_pairs(); + + cudf::size_type get_size(); + std::size_t get_map_size(); +}; + +/** + * @brief Create a nvtext::bpe_merge_pairs from an input file. + * + * The file should contain a pair of strings per line separated by + * a single space. + * + * Example: + * @code{.txt} + * e n + * i t + * i s + * e s + * en t + * c e + * es t + * en ce + * T h + * Th is + * t est + * s ent + * ... + * @endcode + * + * The pairs are expected to be ordered in the file by their rank + * relative to each other. A pair earlier in the file has priority over + * any pairs below it. + * + * @param filename_merges Local file path of pairs encoded in UTF-8. + * @param mr Memory resource to allocate any returned objects. 
+ */ +std::unique_ptr load_merge_pairs_file( + std::string const& filename_merges, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @brief Byte pair encode the input strings. + * + * This will split each string on whitespace, perform the encoding, + * and then build the output column using the given `separator`. + * + * The encoding algorithm rebuilds each string by matching substrings + * in the `merge_pairs` table and iteratively removing the minimum ranked pair + * until no pairs are left. Then, a space is inserted between the remaining + * pairs before the result is joined to make the output string. + * + * @code{.pseudo} + * mps = load_merge_pairs_file("merges.txt") // see doxygen for example contents + * input = ["test sentence", "thisis test"] + * result = byte_pair_encoding(input, mps) + * result is now ["test sent ence", "this is test"] + * @endcode + * + * @throw cudf::logic_error if `merge_pairs` is empty + * @throw cudf::logic_error if `separator` is invalid + * + * @param input Strings to encode. + * @param merge_pairs Created by a call to @ref nvtext::load_merge_pairs_file. + * @param separator String used to build the output after encoding. + * Default is a space. + * @param mr Memory resource to allocate any returned objects. + */ +std::unique_ptr byte_pair_encoding( + cudf::strings_column_view const& input, + bpe_merge_pairs const& merges_pairs, + cudf::string_scalar const& separator = cudf::string_scalar(" "), + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** @} */ // end of group +} // namespace nvtext diff --git a/cpp/src/text/subword/bpe_tokenizer.cu b/cpp/src/text/subword/bpe_tokenizer.cu new file mode 100644 index 00000000000..c9a1d685f2e --- /dev/null +++ b/cpp/src/text/subword/bpe_tokenizer.cu @@ -0,0 +1,556 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace nvtext { +namespace detail { + +namespace { + +template +constexpr bool is_whitespace(CharType ch) +{ + return ch <= ' '; +} + +/** + * @brief Resolve a substring up to the first whitespace character. + * + * This will return a substring of the input starting with the first byte + * up to the first whitespace character found or the end of the string. + * Any whitespace is expected only at the end of the string. + * + * @param d_str Input string to resolve. + * @return Substring of the input excluding any trailing whitespace. + */ +__device__ cudf::string_view get_first_token(cudf::string_view const& d_str) +{ + auto const begin = d_str.data(); + auto const end = thrust::find_if( + thrust::seq, begin, begin + d_str.size_bytes(), [](auto ch) { return is_whitespace(ch); }); + auto const size = static_cast(thrust::distance(begin, end)); + return cudf::string_view(begin, size); +} + +/** + * @brief Main byte pair encoding algorithm function for each string. + * + * @see The byte_pair_encoding_fn::operator() function below for details. 
+ */ +struct byte_pair_encoding_fn { + cudf::column_device_view const d_merges; + cudf::column_device_view const d_strings; + merge_pairs_map_type::device_view const d_map; + cudf::size_type* d_sizes; // output size of encoded string + string_hasher_type const hasher; + cudf::size_type* d_byte_indices; + + /** + * @brief Parse the merge pair into components. + * + * The two substrings are separated by a single space. + * + * @param idx Index of merge pair to dissect. + * @return The left and right halves of the merge pair. + */ + __device__ thrust::pair dissect_merge_pair( + cudf::size_type idx) + { + auto const d_pair = d_merges.element(idx); + auto const lhs = d_pair.data(); + auto const end_str = d_pair.data() + d_pair.size_bytes(); + auto const rhs = thrust::find(thrust::seq, lhs, end_str, ' '); // space always expected + // check for malformed pair entry to prevent segfault + if (rhs == end_str) { return thrust::make_pair(cudf::string_view{}, cudf::string_view{}); } + auto const lhs_size = static_cast(thrust::distance(lhs, rhs)); + auto const rhs_size = static_cast(thrust::distance(rhs + 1, end_str)); + return thrust::make_pair(cudf::string_view(lhs, lhs_size), + cudf::string_view(rhs + 1, rhs_size)); + } + + /** + * @brief Get the next substring of the given string. + * + * This will find the next sequence of characters identified by the + * given byte indices iterator values. The beginning of the sequence + * starts at `begin` and the end of the sequence is the first non-zero + * index found between (begin,end) exclusive. + * + * @tparam Iterator The byte indices iterator type + * @param begin Start of indices to check + * @param end End of indices to check + * @param d_str String to substring + * @return The substring found. 
+ */ + template + __device__ cudf::string_view next_substr(Iterator begin, + Iterator end, + cudf::string_view const& d_str) + { + auto const next = thrust::find_if(thrust::seq, begin + 1, end, [](auto v) { return v != 0; }); + auto const size = static_cast(thrust::distance(begin, next)); + return cudf::string_view(d_str.data() + *begin, size); + } + + /** + * @brief Compute the hash over the input strings. + * + * The input strings are combined with a space to produce hash for matching + * a merge pair within the `d_map`. + * + * @param lhs First string. + * @param rhs Second string. + * @return The hash value to match with `d_map`. + */ + __device__ hash_value_type compute_hash(cudf::string_view const& lhs, + cudf::string_view const& rhs) + { + __shared__ char shmem[48 * 1024]; // max for Pascal + auto const total_size = lhs.size_bytes() + rhs.size_bytes() + 1; + auto const thread_memory_size = static_cast(sizeof(shmem) / blockDim.x); + + // Edge case check. + // Empirically found only two merge pair strings that were greater than 70 bytes + // and they both looked like ignorable errors. Double check this analysis with Vibhu. + if (thread_memory_size < total_size) { return 0; } + + // build the target string in shared memory + char* ptr = &shmem[threadIdx.x * thread_memory_size]; + + // build a temp string like: temp = lhs + ' ' + rhs + memcpy(ptr, lhs.data(), lhs.size_bytes()); + memcpy(ptr + lhs.size_bytes(), " ", 1); + memcpy(ptr + lhs.size_bytes() + 1, rhs.data(), rhs.size_bytes()); + + auto const d_hash_str = cudf::string_view(ptr, total_size); + return hasher(d_hash_str); // return the hash for the temp string + } + + /** + * @brief Byte encode each string. + * + * Each string is iteratively scanned for the minimum rank of adjacent substring pairs + * as found within the `d_map` table. Once the minimum pair is located, that pair + * is removed -- virtually by zero-ing the index value between any matching adjacent pairs. 
+ * + * The iteration ends once there are no more adjacent pairs or there are no more + * matches found in `d_map`. At the end, the indices for each string reflect the + * encoding pattern and can be used to build the output. + * + * This function also computes the size of the encoded output of each string + * by simply counting the number of non-zero indices values remaining. This saves + * an extra kernel launch normally required to compute the offsets of the output column. + * + * @param idx The index of the string in `d_strings` to encode + */ + __device__ void operator()(cudf::size_type idx) + { + if (d_strings.is_null(idx)) { return; } + auto const d_str = get_first_token(d_strings.element(idx)); + if (d_str.empty()) { return; } + + auto const offset = d_strings.child(cudf::strings_column_view::offsets_column_index) + .element(idx); + auto const d_indices = d_byte_indices + offset; + + // initialize the byte indices for this string; + // set the index value to 0 for any intermediate UTF-8 bytes + thrust::transform(thrust::seq, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(d_str.size_bytes()), + d_indices, + [data = d_str.data()](auto idx) { + auto const byte = static_cast(data[idx]); + return cudf::strings::detail::is_begin_utf8_char(byte) ? 
idx : 0; + }); + + auto const begin = d_indices; + auto const end = d_indices + d_str.size_bytes(); + + // keep processing the string until there are no more adjacent pairs found in d_map + cudf::size_type min_rank = 0; + while (min_rank < cuda::std::numeric_limits::max()) { + // initialize working variables + min_rank = cuda::std::numeric_limits::max(); + + auto lhs = next_substr(begin, end, d_str); + auto itr = begin + lhs.size_bytes(); + + auto min_itr = itr; // these are set along with + auto min_size = lhs.size_bytes(); // the min_rank variable + + // check each adjacent pair against the d_map + while (itr < end) { + auto const rhs = next_substr(itr, end, d_str); + if (rhs.empty()) break; // no more adjacent pairs + + auto const hash = compute_hash(lhs, rhs); + auto const map_itr = d_map.find(hash); + if (map_itr != d_map.end()) { + // found a match; record the rank (and other min_ vars) + auto const rank = static_cast(map_itr->second); + if (rank < min_rank) { + min_rank = rank; + min_itr = itr; + min_size = rhs.size_bytes(); + } + } + // next substring + lhs = rhs; + itr += rhs.size_bytes(); + } + + // if any pair matched, remove every occurrence from the string + if (min_rank < cuda::std::numeric_limits::max()) { + // remove the first pair we found + itr = min_itr; + *itr = 0; + + // continue scanning for other occurrences in the remainder of the string + itr += min_size; + if (itr < end) { + auto const d_pair = dissect_merge_pair(min_rank); + + lhs = next_substr(itr, end, d_str); + itr += lhs.size_bytes(); + while (itr < end) { + auto rhs = next_substr(itr, end, d_str); + if (d_pair.first == lhs && d_pair.second == rhs) { + *itr = 0; // removes the pair from this string + itr += rhs.size_bytes(); + if (itr >= end) { break; } // done checking for pairs + // skip to the next adjacent pair + rhs = next_substr(itr, end, d_str); + } + // next substring + lhs = rhs; + itr += rhs.size_bytes(); + } + } + } + } + + // compute and store the output size for this 
string's encoding + auto const encoded_size = d_str.size_bytes() + // number of original bytes + + thrust::count_if( // number of non-zero byte indices + thrust::seq, + d_indices, + d_indices + d_str.size_bytes(), + [](auto v) { return v != 0; }); + d_sizes[idx] = static_cast(encoded_size); + } +}; + +/** + * @brief Build the output string encoding. + * + * This copies each string to the output inserting a space at each non-zero byte index. + * + * @code{.txt} + * d_strings = ["helloworld", "testthis"] + * d_byte_indices = [ 0000050000 00004000] + * result is ["hello world", "test this"] + * @endcode + */ +struct build_encoding_fn { + cudf::column_device_view const d_strings; + cudf::size_type const* d_byte_indices; + cudf::offset_type const* d_offsets; + char* d_chars{}; + + __device__ void operator()(cudf::size_type idx) + { + if (d_strings.is_null(idx)) { return; } + auto const d_str = get_first_token(d_strings.element(idx)); + if (d_str.empty()) { return; } + + auto const offset = d_strings.child(cudf::strings_column_view::offsets_column_index) + .element(idx); + auto const d_indices = d_byte_indices + offset; + auto d_output = d_chars ? d_chars + d_offsets[idx] : nullptr; + + // copy chars while indices[i]==0, + // insert space each time indices[i]!=0 + auto const begin = d_indices; + auto const end = d_indices + d_str.size_bytes(); + auto d_input = d_str.data(); + *d_output++ = *d_input++; + auto itr = begin + 1; + while (itr < end) { + if (*itr++) *d_output++ = ' '; + *d_output++ = *d_input++; + } + // https://github.com/rapidsai/cudf/pull/10270/files#r826319405 + } +}; + +/** + * @brief Perform byte pair encoding on each string in the input column. + * + * The result is a strings column of the same size where each string has been encoded. + * + * The encoding is performed iteratively. Each pass determines the string's lowest + * ranked merge pair as determined by the strings in `merges_table`. 
This pair + * is removed (virtually) from each string before starting the next iteration. + * + * Once all pairs have been exhausted for all strings, the output is constructed from + * the results by adding spaces between each remaining pair in each string. + * + * @param input Strings to encode. + * @param merge_pairs Merge pairs data and map used for encoding. + * @param stream CUDA stream used for device memory operations and kernel launches + */ +std::unique_ptr byte_pair_encoding( + cudf::strings_column_view const& input, + bpe_merge_pairs::bpe_merge_pairs_impl const& merge_pairs, + rmm::cuda_stream_view stream) +{ + CUDF_EXPECTS(!merge_pairs.get_merge_pairs().is_empty(), "Merge pairs table must not be empty"); + + // build working vector to hold index values per byte + rmm::device_uvector d_byte_indices(input.chars().size(), stream); + + auto const d_merges = cudf::column_device_view::create(merge_pairs.get_merge_pairs(), stream); + auto const d_strings = cudf::column_device_view::create(input.parent(), stream); + + auto offsets = cudf::make_numeric_column(cudf::data_type{cudf::type_to_id()}, + static_cast(input.size() + 1), + cudf::mask_state::UNALLOCATED, + stream, + rmm::mr::get_current_device_resource()); + auto d_offsets = offsets->mutable_view().data(); + + byte_pair_encoding_fn fn{*d_merges, + *d_strings, + merge_pairs.get_merge_pairs_map(), + d_offsets, + string_hasher_type{}, + d_byte_indices.data()}; + thrust::for_each_n( + rmm::exec_policy(stream), thrust::make_counting_iterator(0), input.size(), fn); + + // build the output: add spaces between the remaining pairs in each string + thrust::exclusive_scan( + rmm::exec_policy(stream), d_offsets, d_offsets + input.size() + 1, d_offsets); + + auto const bytes = + cudf::detail::get_value(offsets->view(), input.size(), stream); + auto chars = cudf::strings::detail::create_chars_child_column( + bytes, stream, rmm::mr::get_current_device_resource()); + auto d_chars = chars->mutable_view().data(); +
thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + input.size(), + build_encoding_fn{*d_strings, d_byte_indices.data(), d_offsets, d_chars}); + + return make_strings_column( + input.size(), std::move(offsets), std::move(chars), 0, rmm::device_buffer{}); +} + +/** + * @brief Detect space to not-space transitions inside each string. + * + * This handles sliced input and null strings as well. + * It is parallelized over bytes and returns true only for valid left edges + * -- non-space preceded by a space. + */ +struct edge_of_space_fn { + cudf::column_device_view const d_strings; + __device__ bool operator()(cudf::offset_type offset) + { + auto const d_chars = + d_strings.child(cudf::strings_column_view::chars_column_index).data(); + if (is_whitespace(d_chars[offset]) || !is_whitespace(d_chars[offset - 1])) { return false; } + + auto const offsets = d_strings.child(cudf::strings_column_view::offsets_column_index); + auto const d_offsets = offsets.data() + d_strings.offset(); + // ignore offsets outside sliced range + if (offset < d_offsets[0] || offset >= d_offsets[d_strings.size()]) { return false; } + + auto itr = + thrust::lower_bound(thrust::seq, d_offsets, d_offsets + d_strings.size() + 1, offset); + // ignore offsets at existing string boundaries + if (*itr == offset) { return false; } + + // count only edges for valid strings + auto const index = static_cast(thrust::distance(d_offsets, itr)) - 1; + return d_strings.is_valid(index); + } +}; + +/** + * @brief Create new offsets by identifying substrings by whitespace. + * + * This is similar to cudf::strings::split_record but does not fully split + * and only returns new offsets. The behavior is more like a view-only slice + * of the chars child with the result still including trailing delimiters. + * + * The encoding algorithm ignores the trailing whitespace of each string. + * + * @param input Strings to tokenize. 
+ * @param stream CUDA stream used for device memory operations and kernel launches + * @return New offsets including those at the edge of each space. + */ +std::unique_ptr space_offsets(cudf::strings_column_view const& input, + cudf::column_device_view const& d_strings, + rmm::cuda_stream_view stream) +{ + // count space offsets + auto const begin = thrust::make_counting_iterator(1); + auto const end = thrust::make_counting_iterator(input.chars().size()); + edge_of_space_fn edge_of_space{d_strings}; + auto const space_count = thrust::count_if(rmm::exec_policy(stream), begin, end, edge_of_space); + + // copy space offsets + rmm::device_uvector space_offsets(space_count, stream); + thrust::copy_if(rmm::exec_policy(stream), begin, end, space_offsets.data(), edge_of_space); + + // create output offsets + auto result = + cudf::make_numeric_column(cudf::data_type{cudf::type_to_id()}, + static_cast(space_count + input.size() + 1), + cudf::mask_state::UNALLOCATED, + stream, + rmm::mr::get_current_device_resource()); + + // combine current offsets with space offsets + thrust::merge(rmm::exec_policy(stream), + input.offsets_begin(), + input.offsets_end(), + space_offsets.begin(), + space_offsets.end(), + result->mutable_view().begin()); + + return result; +} + +/** + * @brief Build new offsets that can be used to build a list column for calling join. + * + * This essentially returns the number of tokens for each string. 
+ */ +struct list_offsets_fn { + cudf::column_device_view const d_strings; + __device__ cudf::size_type operator()(cudf::size_type idx) + { + if (d_strings.is_null(idx)) return 0; + auto const d_str = d_strings.element(idx); + if (d_str.empty()) return 1; // empty is a single valid result + + auto const begin = thrust::make_counting_iterator(1); + auto const end = thrust::make_counting_iterator(d_str.size_bytes()); + + // this counts the number of non-adjacent delimiters + auto const result = + thrust::count_if(thrust::seq, begin, end, [data = d_str.data()](auto chidx) { + return !is_whitespace(data[chidx]) && is_whitespace(data[chidx - 1]); + }); + return static_cast(result) + 1; + } +}; + +} // namespace + +std::unique_ptr byte_pair_encoding(cudf::strings_column_view const& input, + bpe_merge_pairs const& merge_pairs, + cudf::string_scalar const& separator, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + if (input.is_empty() || input.chars_size() == 0) + return cudf::make_empty_column(cudf::type_id::STRING); + + auto const d_strings = cudf::column_device_view::create(input.parent(), stream); + auto const offsets = space_offsets(input, *d_strings, stream); + + // build a view using the new offsets and the current input chars column + auto const input_view = cudf::column_view(cudf::data_type{cudf::type_id::STRING}, + offsets->size() - 1, + nullptr, // no parent data + nullptr, // null-mask + 0, // null-count + 0, // offset + {offsets->view(), input.chars()}); + + // run BPE on this view + auto const bpe_column = + byte_pair_encoding(cudf::strings_column_view(input_view), *(merge_pairs.impl), stream); + + // recombine the result: + // compute the offsets needed to build a list view + auto const list_offsets = [d_strings = *d_strings, stream] { + auto offsets_itr = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), list_offsets_fn{d_strings}); + return cudf::strings::detail::make_offsets_child_column( + offsets_itr, 
offsets_itr + d_strings.size(), stream, rmm::mr::get_current_device_resource()); + }(); + + // build a list column_view using the BPE output and the list_offsets + auto const list_join = cudf::column_view(cudf::data_type{cudf::type_id::LIST}, + input.size(), + nullptr, // no parent data in list column + input.null_mask(), + input.null_count(), + 0, + {list_offsets->view(), bpe_column->view()}); + + // build the output strings column + auto result = + cudf::strings::detail::join_list_elements(cudf::lists_column_view(list_join), + separator, + cudf::string_scalar(""), + cudf::strings::separator_on_nulls::NO, + cudf::strings::output_if_empty_list::EMPTY_STRING, + stream, + mr); + return result; +} + +} // namespace detail + +std::unique_ptr byte_pair_encoding(cudf::strings_column_view const& input, + bpe_merge_pairs const& merges_table, + cudf::string_scalar const& separator, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::byte_pair_encoding(input, merges_table, separator, rmm::cuda_stream_default, mr); +} + +} // namespace nvtext diff --git a/cpp/src/text/subword/bpe_tokenizer.cuh b/cpp/src/text/subword/bpe_tokenizer.cuh new file mode 100644 index 00000000000..31cc29a8d8a --- /dev/null +++ b/cpp/src/text/subword/bpe_tokenizer.cuh @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include + +#include +#include + +#include + +#include +#include +#include + +#include + +namespace nvtext { +namespace detail { + +// allocator type handed to the hash map when it is constructed +using hash_table_allocator_type = rmm::mr::stream_allocator_adaptor>; + +// map type used to look up merge pairs +using merge_pairs_map_type = cuco::static_map; + +// hash function type applied to the merge-pair strings +using string_hasher_type = MurmurHash3_32; + +} // namespace detail + +/** + * @brief Private state of bpe_merge_pairs: the merge-pairs data and the map kept with it. + */ +struct bpe_merge_pairs::bpe_merge_pairs_impl { + std::unique_ptr const merge_pairs; // owned; const so it is never reseated after construction + std::unique_ptr merge_pairs_map; // owned lookup map + + // takes both objects by rvalue reference + bpe_merge_pairs_impl(std::unique_ptr&& merge_pairs, + std::unique_ptr&& merge_pairs_map); + + // non-owning accessors: a view of the merge-pairs data and a device view of the map + auto get_merge_pairs() const { return merge_pairs->view(); } + auto get_merge_pairs_map() const { return merge_pairs_map->get_device_view(); } +}; + +} // namespace nvtext diff --git a/cpp/src/text/subword/load_merges_file.cu b/cpp/src/text/subword/load_merges_file.cu new file mode 100644 index 00000000000..bdcbe45df64 --- /dev/null +++ b/cpp/src/text/subword/load_merges_file.cu @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace nvtext { +namespace detail { + +namespace { + +/** + * @brief Maps a row index to the pair (hash of that row's string, row index). + */ +struct make_pair_function { + /** + * @brief Hash the merge pair entry + */ + __device__ cuco::pair_type operator()(cudf::size_type idx) + { + auto const result = _hasher(d_strings.element(idx)); + return cuco::make_pair(result, idx); + } + + string_hasher_type const _hasher; + cudf::column_device_view const d_strings; +}; + +/** + * @brief Loads a text file of merge-pairs into a strings column. + * + * The line position in the file indicates the pair's rank. + * A leading line that starts with `#version` is skipped, and reading stops at the + * first empty line or at the end of the file, whichever comes first. + * + * @code{.pseudo} + * Format of the file: + * #version .. + * a1 a2 + * b1 b2 + * c1 c2 + * ... + * @endcode + * + * Fails (via CUDF_EXPECTS) if the file cannot be opened or holds no merge-pair lines. + * + * @param filename_merges Path to text file containing merge-pairs + * @param stream CUDA stream used for the host-to-device copies + * @param mr Device memory resource passed to the host-to-device copies + * @return object containing table elements for the BPE function + */ +std::unique_ptr load_file_to_column(std::string const& filename_merges, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + std::ifstream merges_file(filename_merges); + CUDF_EXPECTS(merges_file.good(), "Could not open " + filename_merges); + + std::vector chars{}; + std::vector offsets(1, 0); + + std::string line; + std::getline(merges_file, line); + std::string version = "#version"; + // A failed std::getline leaves `line` unchanged, so clear it explicitly; otherwise a file + // holding only the version header would keep that header in `line` and load it as data. + if (line.substr(0, version.size()).compare(version) == 0) { + if (!std::getline(merges_file, line)) { line.clear(); } + } + + // This is a text file delimited only by CR/LF. + // TODO: Look into using the CSV reader to load the strings column instead. 
+ while (!line.empty()) { + chars.insert(chars.end(), std::cbegin(line), std::cend(line)); + offsets.push_back(offsets.back() + line.length()); + // At end-of-file std::getline fails without modifying `line`; clear it so the loop ends + // instead of appending the last line forever when the file has no trailing newline. + if (!std::getline(merges_file, line)) { line.clear(); } + } + + CUDF_EXPECTS(!chars.empty(), "No data found in " + filename_merges); + + auto d_chars = cudf::detail::make_device_uvector_async(chars, stream, mr); + auto d_offsets = cudf::detail::make_device_uvector_async(offsets, stream, mr); + // NOTE(review): neither stream nor mr is passed to make_strings_column here -- confirm the + // overload's defaults are ordered after the asynchronous copies above and that mr is not expected. + return cudf::make_strings_column(d_chars, d_offsets); +} + +/** + * @brief Builds the lookup map with one (string hash, row index) entry per row of the input. + */ +std::unique_ptr initialize_merge_pairs_map( + cudf::strings_column_view const& input, rmm::cuda_stream_view stream) +{ + // Ensure capacity is at least (size/0.7) as documented here: + // https://github.com/NVIDIA/cuCollections/blob/6ec8b6dcdeceea07ab4456d32461a05c18864411/include/cuco/static_map.cuh#L179-L182 + auto merge_pairs_map = std::make_unique( + static_cast(input.size() * 2), // capacity is 2x; + std::numeric_limits::max(), // empty key; + -1, // empty value is not used + hash_table_allocator_type{default_allocator{}, stream}, + stream.value()); + + auto d_strings = cudf::column_device_view::create(input.parent(), stream); + make_pair_function pair_func{string_hasher_type{}, *d_strings}; + auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func); + + merge_pairs_map->insert(iter, + iter + input.size(), + cuco::detail::MurmurHash3_32{}, + thrust::equal_to{}, + stream.value()); + + return merge_pairs_map; +} + +// Takes ownership of the column and builds its lookup map on the given stream. +std::unique_ptr create_bpe_merge_pairs_impl( + std::unique_ptr&& input, rmm::cuda_stream_view stream) +{ + auto merge_pairs = initialize_merge_pairs_map(cudf::strings_column_view(input->view()), stream); + return std::make_unique(std::move(input), + std::move(merge_pairs)); +} + +// Overload for a non-owning view: first constructs an owned object from input.parent() using +// stream and mr, then delegates to the overload above. +std::unique_ptr create_bpe_merge_pairs_impl( + cudf::strings_column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + return create_bpe_merge_pairs_impl(std::make_unique(input.parent(), stream, mr), + stream); +} + +} // namespace + 
+/** + * @brief Loads the merges file into a column and wraps that column in the returned object. + */ +std::unique_ptr load_merge_pairs_file(std::string const& filename_merges, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto input_column = load_file_to_column(filename_merges, stream, mr); + return std::make_unique(std::move(input_column), stream, mr); +} + +} // namespace detail + +// Public API: forwards to the detail implementation using rmm::cuda_stream_default. +std::unique_ptr load_merge_pairs_file(std::string const& filename_merges, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::load_merge_pairs_file(filename_merges, rmm::cuda_stream_default, mr); +} + +// Takes ownership of the merge-pairs data and of the lookup map. +bpe_merge_pairs::bpe_merge_pairs_impl::bpe_merge_pairs_impl( + std::unique_ptr&& merge_pairs, + std::unique_ptr&& merge_pairs_map) + : merge_pairs(std::move(merge_pairs)), merge_pairs_map(std::move(merge_pairs_map)) +{ +} + +// Adopts an already-owned column; the memory-resource parameter is unnamed and unused here. +bpe_merge_pairs::bpe_merge_pairs(std::unique_ptr&& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource*) + : impl(detail::create_bpe_merge_pairs_impl(std::move(input), stream)) +{ +} + +// Builds from a strings view; stream and mr are forwarded to create_bpe_merge_pairs_impl. +bpe_merge_pairs::bpe_merge_pairs(cudf::strings_column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + : impl(detail::create_bpe_merge_pairs_impl(input, stream, mr)) +{ +} + +bpe_merge_pairs::~bpe_merge_pairs() = default; + +// sizes reported by the merge-pairs data and by the lookup map, respectively +cudf::size_type bpe_merge_pairs::get_size() { return impl->merge_pairs->size(); } +std::size_t bpe_merge_pairs::get_map_size() { return impl->merge_pairs_map->get_size(); } + +} // namespace nvtext diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 24013da62b9..9120d3b3836 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -416,6 +416,7 @@ ConfigureTest(STRUCTS_TEST structs/structs_column_tests.cpp structs/utilities_te # * nvtext test ----------------------------------------------------------------------------------- ConfigureTest( TEXT_TEST + text/bpe_tests.cpp text/edit_distance_tests.cpp text/ngrams_tests.cpp text/ngrams_tokenize_tests.cpp diff --git a/cpp/tests/text/bpe_tests.cpp 
b/cpp/tests/text/bpe_tests.cpp new file mode 100644 index 00000000000..07f3a41f0e2 --- /dev/null +++ b/cpp/tests/text/bpe_tests.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +#include +#include + +// Fixture shared by the byte-pair-encoding tests below. +struct TextBPETokenize : public cudf::test::BaseFixture { +}; + +// Encodes a column containing tabs, newlines, hyphens, an empty row and a null row, +// then repeats the check on a slice of the same column. +TEST_F(TextBPETokenize, BytePairEncoding) +{ + // partial table based on values from https://huggingface.co/gpt2/raw/main/merges.txt + auto mpt = cudf::test::strings_column_wrapper({ + "e n", // 12 + "i t", // 14 + "i s", // 15 + "e s", // 18 + "en t", // 42 + "c e", // 88 + "es t", // 139 + "en ce", // 338 + "T h", // 561 + "Th is", // 956 + "t est", // 9032 + "s ent", // 33830 + }); + + nvtext::bpe_merge_pairs merge_pairs{cudf::strings_column_view(mpt)}; + + // row 4 is null in both the input and the expected output + auto validity = cudf::test::iterators::null_at(4); + cudf::test::strings_column_wrapper input({" This\tis it\n", + "This is test-sentence-1", + "This is test sentence-2", + "This-is test sentence 3", + "", + ""}, + validity); + auto sv = cudf::strings_column_view(input); + + auto results = nvtext::byte_pair_encoding(sv, merge_pairs); + + auto expected = cudf::test::strings_column_wrapper({" This is it", + "This is test - sent ence - 1", + "This is test sent ence - 2", + "This - is test sent ence 3", + "", + ""}, + validity); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); + + // run again on rows [1, 4) of the input and compare with the matching expected rows + auto sliced = 
cudf::slice(input, {1, 4}).front(); + auto sliced_expected = cudf::slice(expected, {1, 4}).front(); + + results = nvtext::byte_pair_encoding(cudf::strings_column_view(sliced), merge_pairs); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), sliced_expected); +} + +// Same encoding, but with an explicit separator passed as the third argument. +TEST_F(TextBPETokenize, BytePairEncodingSeparator) +{ + auto mpt = cudf::test::strings_column_wrapper( + {"e n", "i t", "e s", "en t", "c e", "es t", "en ce", "t est", "s ent"}); + nvtext::bpe_merge_pairs merge_pairs{cudf::strings_column_view(mpt)}; + + cudf::test::strings_column_wrapper input( + {"test-sentence-1", "test sentence-2", "test sentence 3", " test sentence 4 "}); + auto sv = cudf::strings_column_view(input); + + auto results = nvtext::byte_pair_encoding(sv, merge_pairs, std::string(" Ġ")); + + auto expected = cudf::test::strings_column_wrapper( + {"test - sent ence - 1", "test Ġsent ence - 2", "test Ġsent ence Ġ3", " Ġtest Ġsent ence Ġ4"}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected); +} + +// Merge pairs built from an owned column (release()); an empty input must give zero rows. +TEST_F(TextBPETokenize, BPE_Empty) +{ + auto mpt = cudf::test::strings_column_wrapper({"i s", "i t"}); + nvtext::bpe_merge_pairs merge_pairs{mpt.release()}; + auto empty = cudf::make_empty_column(cudf::type_id::STRING); + auto results = nvtext::byte_pair_encoding(cudf::strings_column_view(empty->view()), merge_pairs); + EXPECT_EQ(0, results->size()); +} + +// Encoding with merge pairs built from an empty column is expected to throw cudf::logic_error. +TEST_F(TextBPETokenize, BPE_Error) +{ + auto empty = cudf::make_empty_column(cudf::type_id::STRING); + nvtext::bpe_merge_pairs merge_pairs{std::move(empty)}; + cudf::test::strings_column_wrapper input({"isit"}); + EXPECT_THROW(nvtext::byte_pair_encoding(cudf::strings_column_view(input), merge_pairs), + cudf::logic_error); +}