Skip to content

Commit

Permalink
Merge pull request #126 from bryevdv/bryanv/scatter
Browse files Browse the repository at this point in the history
Add scatter for Series
  • Loading branch information
bryevdv authored Mar 19, 2021
2 parents 646e862 + ee6df5c commit 60c7aa1
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 47 deletions.
18 changes: 17 additions & 1 deletion modules/cudf/src/node_cudf/table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,24 @@ class Table : public Napi::ObjectWrap<Table> {
cudf::size_type threshold,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const;

// table/copying.cpp
/**
 * @brief Gather rows of this Table into a new Table according to `gather_map`.
 *
 * Defined in table/copying.cpp, where it forwards to `cudf::gather` using a
 * table_view of this Table.
 *
 * @param gather_map Column passed to `cudf::gather` as the gather map
 *   (presumably integral row indices -- confirm against the libcudf docs).
 * @param bounds_policy Policy forwarded to `cudf::gather` for out-of-bounds
 *   entries in `gather_map` (default: DONT_CHECK).
 * @param mr Device memory resource forwarded to `cudf::gather`
 *   (default: the current device resource).
 * @return The Table wrapping the result of `cudf::gather`.
 */
ObjectUnwrap<Table> gather(
Column const& gather_map,
cudf::out_of_bounds_policy bounds_policy = cudf::out_of_bounds_policy::DONT_CHECK,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const;

/**
 * @brief Scatter scalar values into the rows of this Table named by `indices`,
 * returning the result as a new Table.
 *
 * Defined in table/copying.cpp, where it forwards to the scalar overload of
 * `cudf::scatter` with this Table's view as the target. A `cudf::logic_error`
 * raised by libcudf is rethrown as a `Napi::Error` carrying the same message.
 *
 * @param source Scalars to scatter (presumably one per column of this Table --
 *   confirm against the `cudf::scatter` requirements).
 * @param indices Column whose view is passed to `cudf::scatter` as the scatter map.
 * @param check_bounds Forwarded to `cudf::scatter` (default: false).
 * @param mr Device memory resource forwarded to `cudf::scatter`
 *   (default: the current device resource).
 * @return The Table wrapping the result of `cudf::scatter`.
 */
ObjectUnwrap<Table> scatter(
std::vector<std::reference_wrapper<const cudf::scalar>> const& source,
Column const& indices,
bool check_bounds = false,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const;

/**
 * @brief Scatter the rows of `source` into the rows of this Table named by
 * `indices`, returning the result as a new Table.
 *
 * Defined in table/copying.cpp, where it forwards to the table overload of
 * `cudf::scatter` with this Table's view as the target. A `cudf::logic_error`
 * raised by libcudf is rethrown as a `Napi::Error` carrying the same message.
 *
 * @param source Table whose view supplies the values to scatter.
 * @param indices Column whose view is passed to `cudf::scatter` as the scatter map.
 * @param check_bounds Forwarded to `cudf::scatter` (default: false).
 * @param mr Device memory resource forwarded to `cudf::scatter`
 *   (default: the current device resource).
 * @return The Table wrapping the result of `cudf::scatter`.
 */
ObjectUnwrap<Table> scatter(
Table const& source,
Column const& indices,
bool check_bounds = false,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const;

private:
static Napi::FunctionReference constructor;

Expand All @@ -187,10 +200,13 @@ class Table : public Napi::ObjectWrap<Table> {
Napi::Value num_columns(Napi::CallbackInfo const& info);
Napi::Value num_rows(Napi::CallbackInfo const& info);
Napi::Value select(Napi::CallbackInfo const& info);
Napi::Value gather(Napi::CallbackInfo const& info);
Napi::Value get_column(Napi::CallbackInfo const& info);
Napi::Value drop_nulls(Napi::CallbackInfo const& info);
Napi::Value drop_nans(Napi::CallbackInfo const& info);
// table/copying.cpp
// N-API entry points for the copying operations. Each one validates the
// JavaScript arguments and then calls the C++ overload declared above.
// JS `gather`: takes a Column; a BOOL8 Column is applied as a boolean mask instead.
Napi::Value gather(Napi::CallbackInfo const& info);
// JS `scatterScalar`: (Array of Scalars, Column indices, check_bounds[, memory resource]).
Napi::Value scatter_scalar(Napi::CallbackInfo const& info);
// JS `scatterTable`: (Table source, Column indices, check_bounds[, memory resource]).
Napi::Value scatter_table(Napi::CallbackInfo const& info);

static Napi::Value read_csv(Napi::CallbackInfo const& info);
Napi::Value write_csv(Napi::CallbackInfo const& info);
Expand Down
59 changes: 56 additions & 3 deletions modules/cudf/src/series.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {VectorType} from 'apache-arrow/interfaces';
import {Column, ColumnProps} from './column';
import {fromArrow} from './column/from_arrow';
import {DataFrame} from './data_frame';
import {Scalar} from './scalar';
import {Table} from './table';
import {
Bool8,
Expand Down Expand Up @@ -139,7 +140,7 @@ export class AbstractSeries<T extends DataType = any> {
}

/** @ignore */
public readonly _col: Column<T>;
public _col: Column<T>;

protected constructor(input: SeriesProps<T>|Column<T>|arrow.Vector<T>) {
this._col = asColumn<T>(input);
Expand Down Expand Up @@ -194,6 +195,52 @@ export class AbstractSeries<T extends DataType = any> {
return this.__construct(this._col.gather(selection._col));
}

/**
 * Scatters a single scalar value into this Series at the given indices.
 *
 * This Series is updated in place: its underlying column is replaced by the scattered result.
 *
 * @param value The scalar value to write into this Series.
 * @param indices A Series of Int32 indices (or an Array of numbers) naming the rows of this
 *   Series to be replaced by `value`.
 * @param check_bounds Optionally perform bounds checking on the indices and throw an error if any
 *   of its values are out of bounds (default: false).
 * @param memoryResource An optional MemoryResource used to allocate the result's device memory.
 */
scatter(value: T['scalarType'],
        indices: Series<Int32>|number[],
        check_bounds?: boolean,
        memoryResource?: MemoryResource): void;
/**
 * Scatters a Series of values into this Series at the given indices.
 *
 * This Series is updated in place: its underlying column is replaced by the scattered result.
 *
 * @param values A Series of values to write into this Series.
 * @param indices A Series of Int32 indices (or an Array of numbers) naming the rows of this
 *   Series to be replaced by `values`.
 * @param check_bounds Optionally perform bounds checking on the indices and throw an error if any
 *   of its values are out of bounds (default: false).
 * @param memoryResource An optional MemoryResource used to allocate the result's device memory.
 */
scatter(values: Series<T>,
        indices: Series<Int32>|number[],
        check_bounds?: boolean,
        memoryResource?: MemoryResource): void;

scatter(source: Series<T>|T['scalarType'],
        indices: Series<Int32>|number[],
        check_bounds = false,
        memoryResource?: MemoryResource): void {
  // Scatter is implemented on Table, so wrap this Series' column in a one-column Table.
  const target = new Table({columns: [this._col]});
  // A plain Array of indices is first converted into an Int32 Series.
  const positions =
    indices instanceof Series ? indices : new Series({type: new Int32, data: indices});

  let scattered: Table;
  if (source instanceof Series) {
    const values = new Table({columns: [source._col]});
    scattered    = target.scatterTable(values, positions._col, check_bounds, memoryResource);
  } else {
    const scalars = [new Scalar({type: this.type, value: source})];
    scattered     = target.scatterScalar(scalars, positions._col, check_bounds, memoryResource);
  }

  // Adopt the single result column as this Series' new backing column.
  this._col = scattered.getColumnByIndex(0);
}

/**
* Return a sub-selection of this Series using the specified boolean mask.
*
Expand All @@ -205,11 +252,17 @@ export class AbstractSeries<T extends DataType = any> {
/**
* Return a value at the specified index to host memory
*
* @param index
* @param index the index in this Series to return a value for
*/
getValue(index: number) { return this._col.getValue(index); }

// setValue(index: number, value?: this[0] | null);
/**
 * Overwrite the element at `index` with `value`.
 *
 * Implemented as a `scatter` of `value` using a one-element index list.
 *
 * @param index the index in this Series to set a value for
 * @param value the value to set at `index`
 */
setValue(index: number, value: T['scalarType']): void {
  const position = [index];
  this.scatter(value, position);
}

/**
* Copy the underlying device memory to host, and return an Iterator of the values.
Expand Down
14 changes: 2 additions & 12 deletions modules/cudf/src/table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Napi::Object Table::Init(Napi::Env env, Napi::Object exports) {
{
InstanceAccessor<&Table::num_columns>("numColumns"),
InstanceAccessor<&Table::num_rows>("numRows"),
InstanceMethod<&Table::scatter_scalar>("scatterScalar"),
InstanceMethod<&Table::scatter_table>("scatterTable"),
InstanceMethod<&Table::gather>("gather"),
InstanceMethod<&Table::get_column>("getColumnByIndex"),
InstanceMethod<&Table::to_arrow>("toArrow"),
Expand Down Expand Up @@ -137,18 +139,6 @@ Napi::Value Table::num_columns(Napi::CallbackInfo const& info) {

// N-API accessor registered in Table::Init as the JavaScript property `numRows`:
// returns the result of the C++ num_rows() converted to a JS value through CPPToNapi.
Napi::Value Table::num_rows(Napi::CallbackInfo const& info) { return CPPToNapi(info)(num_rows()); }

Napi::Value Table::gather(Napi::CallbackInfo const& info) {
CallbackArgs args{info};
if (!Column::is_instance(args[0])) {
throw Napi::Error::New(info.Env(), "gather selection argument expects a Column");
}
auto& selection = *Column::Unwrap(args[0]);
if (selection.type().id() == cudf::type_id::BOOL8) {
return this->apply_boolean_mask(selection)->Value();
}
return this->gather(selection)->Value();
}

Napi::Value Table::get_column(Napi::CallbackInfo const& info) {
cudf::size_type i = CallbackArgs{info}[0];
if (i >= num_columns_) { throw Napi::Error::New(info.Env(), "Column index out of bounds"); }
Expand Down
44 changes: 35 additions & 9 deletions modules/cudf/src/table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import {MemoryResource} from '@rapidsai/rmm';
import CUDF from './addon';
import {Column} from './column';
import {Scalar} from './scalar';
import {CSVTypeMap, ReadCSVOptions, WriteCSVOptions} from './types/csv';
import {
Bool8,
DataType,
IndexType,
Int32,
} from './types/dtypes';
import {
NullOrder,
} from './types/enums';
import {Bool8, DataType, IndexType, Int32} from './types/dtypes';
import {NullOrder} from './types/enums';
import {TypeMap} from './types/mappings';

export type ToArrowMetadata = [string | number, ToArrowMetadata[]?];

Expand Down Expand Up @@ -71,6 +67,36 @@ export interface Table {
*/
gather(selection: Column<IndexType|Bool8>): Table;

/**
 * Scatters a row of Scalar values into the rows of this Table selected by `indices`, and
 * returns the result as a new Table.
 *
 * @param source An Array of Scalars to be scattered into this Table (presumably one Scalar per
 *   column of this Table -- TODO confirm against the libcudf `scatter` requirements).
 * @param indices A column of integral indices that indicate the rows in this Table to be
 *   replaced by `source`.
 * @param check_bounds Optionally perform bounds checking on the indices and throw an error if any
 *   of its values are out of bounds (default: false).
 *   NOTE(review): the native binding (`Table::scatter_scalar`) throws unless it receives 3 or 4
 *   arguments, so in practice this argument must be passed explicitly.
 * @param memoryResource An optional MemoryResource used to allocate the result's device memory.
 * @returns A Table holding the scattered result.
 */
scatterScalar<T extends TypeMap = any>(source: (Scalar<T[keyof T]>)[],
indices: Column<Int32>,
check_bounds?: boolean,
memoryResource?: MemoryResource): Table;

/**
 * Scatters the rows of a source Table into the rows of this Table selected by `indices`, and
 * returns the result as a new Table.
 *
 * @param source A Table of values to be scattered into this Table.
 * @param indices A column of integral indices that indicate the rows in this Table to be
 *   replaced by the rows of `source`.
 * @param check_bounds Optionally perform bounds checking on the indices and throw an error if any
 *   of its values are out of bounds (default: false).
 *   NOTE(review): the native binding (`Table::scatter_table`) throws unless it receives 3 or 4
 *   arguments, so in practice this argument must be passed explicitly.
 * @param memoryResource An optional MemoryResource used to allocate the result's device memory.
 * @returns A Table holding the scattered result.
 */
scatterTable(source: Table,
indices: Column<Int32>,
check_bounds?: boolean,
memoryResource?: MemoryResource): Table;

/**
* Get the Column at a specified index
*
Expand Down
89 changes: 89 additions & 0 deletions modules/cudf/src/table/copying.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <node_cudf/column.hpp>
#include <node_cudf/table.hpp>
#include <node_cudf/utilities/napi_to_cpp.hpp>

#include <cudf/copying.hpp>
#include <cudf/table/table_view.hpp>
Expand All @@ -28,4 +29,92 @@ ObjectUnwrap<Table> Table::gather(Column const& gather_map,
return Table::New(cudf::gather(cudf::table_view{{*this}}, gather_map, bounds_policy, mr));
}

// Scatter `source` scalars into the rows of this Table named by `indices` and
// wrap the table produced by cudf::scatter in a new Table.
//
// Only cudf::logic_error is translated into a Napi::Error (same message); any
// other exception type propagates to the caller unchanged.
ObjectUnwrap<Table> Table::scatter(
  std::vector<std::reference_wrapper<const cudf::scalar>> const& source,
  Column const& indices,
  bool check_bounds,
  rmm::mr::device_memory_resource* mr) const {
  try {
    // This Table is the scatter target; cudf returns the result as a separate table.
    auto scattered = cudf::scatter(source, indices.view(), this->view(), check_bounds, mr);
    return Table::New(std::move(scattered));
  } catch (cudf::logic_error const& e) {
    NAPI_THROW(Napi::Error::New(Env(), e.what()));
  }
}

// Scatter the rows of `source` into the rows of this Table named by `indices`
// and wrap the table produced by cudf::scatter in a new Table.
//
// Only cudf::logic_error is translated into a Napi::Error (same message); any
// other exception type propagates to the caller unchanged.
ObjectUnwrap<Table> Table::scatter(Table const& source,
                                   Column const& indices,
                                   bool check_bounds,
                                   rmm::mr::device_memory_resource* mr) const {
  try {
    // This Table is the scatter target; cudf returns the result as a separate table.
    auto scattered =
      cudf::scatter(source.view(), indices.view(), this->view(), check_bounds, mr);
    return Table::New(std::move(scattered));
  } catch (cudf::logic_error const& e) {
    NAPI_THROW(Napi::Error::New(Env(), e.what()));
  }
}

// N-API entry point for `gather`.
//
// Expects a Column as the first JS argument and throws a Napi::Error otherwise.
// A BOOL8 Column is applied as a boolean mask; any other Column is used as a
// gather map.
Napi::Value Table::gather(Napi::CallbackInfo const& info) {
  CallbackArgs args{info};

  bool const have_column = Column::is_instance(args[0]);
  if (!have_column) {
    throw Napi::Error::New(info.Env(), "gather selection argument expects a Column");
  }

  auto& selection            = *Column::Unwrap(args[0]);
  bool const is_boolean_mask = selection.type().id() == cudf::type_id::BOOL8;

  if (!is_boolean_mask) { return this->gather(selection)->Value(); }
  return this->apply_boolean_mask(selection)->Value();
}

// N-API entry point for `scatterScalar`.
//
// JS arguments:
//   [0] Array of Scalars to scatter
//   [1] Column of indices
//   [2] check_bounds flag
//   [3] memory resource (may be omitted)
//
// Throws a Napi::Error when the argument count is not 3 or 4, when [0] is not
// an Array consisting only of Scalars, or when [1] is not a Column. Returns the
// JS value of the Table produced by the scalar overload of Table::scatter.
Napi::Value Table::scatter_scalar(Napi::CallbackInfo const& info) {
  CallbackArgs args{info};

  if (args.Length() != 3 and args.Length() != 4) {
    NAPI_THROW(Napi::Error::New(info.Env(),
                                "scatter_scalar expects a vector of scalars, a Column, and "
                                "optionally a bool and memory resource"));
  }

  if (!args[0].IsArray()) {
    throw Napi::Error::New(info.Env(), "scatter_scalar source argument expects an array");
  }
  auto const source_array = args[0].As<Napi::Array>();
  // Query the length once; it is both the loop bound and the final vector size.
  auto const num_scalars = source_array.Length();

  std::vector<std::reference_wrapper<const cudf::scalar>> source{};
  source.reserve(num_scalars);

  for (uint32_t i = 0; i < num_scalars; ++i) {
    // Fetch each element exactly once so the value that is validated is the
    // same value that is unwrapped (and to avoid a redundant N-API call).
    auto const element = source_array.Get(i);
    if (!Scalar::is_instance(element)) {
      throw Napi::Error::New(info.Env(),
                             "scatter_scalar source argument expects an array of scalars");
    }
    auto& scalar = *Scalar::Unwrap(element.ToObject());
    source.push_back(scalar);
  }

  if (!Column::is_instance(args[1])) {
    throw Napi::Error::New(info.Env(), "scatter_scalar indices argument expects a Column");
  }
  auto& indices                       = *Column::Unwrap(args[1]);
  bool check_bounds                   = args[2];
  rmm::mr::device_memory_resource* mr = args[3];
  return scatter(source, indices, check_bounds, mr)->Value();
}

// N-API entry point for `scatterTable`.
//
// JS arguments:
//   [0] source Table
//   [1] Column of indices
//   [2] check_bounds flag
//   [3] memory resource (may be omitted)
//
// Throws a Napi::Error when the argument count is not 3 or 4, when [0] is not
// a Table, or when [1] is not a Column. Returns the JS value of the Table
// produced by the table overload of Table::scatter.
Napi::Value Table::scatter_table(Napi::CallbackInfo const& info) {
  CallbackArgs args{info};

  auto const argc = args.Length();
  if (!(argc == 3 or argc == 4)) {
    NAPI_THROW(Napi::Error::New(
      info.Env(),
      "scatter_table expects a Table, a Column, and optionally a bool and memory resource"));
  }

  // Validate in argument order so the first offending argument is the one reported.
  bool const source_is_table = Table::is_instance(args[0]);
  if (!source_is_table) {
    throw Napi::Error::New(info.Env(), "scatter_table source argument expects a Table");
  }
  auto& source = *Table::Unwrap(args[0]);

  bool const indices_is_column = Column::is_instance(args[1]);
  if (!indices_is_column) {
    throw Napi::Error::New(info.Env(), "scatter_table indices argument expects a Column");
  }
  auto& indices = *Column::Unwrap(args[1]);

  bool check_bounds                   = args[2];
  rmm::mr::device_memory_resource* mr = args[3];

  auto result = scatter(source, indices, check_bounds, mr);
  return result->Value();
}

} // namespace nv
Loading

0 comments on commit 60c7aa1

Please sign in to comment.