Skip to content

Commit

Permalink
feat!: provide key columns as scalars (vs. vectors) to `RollingFormula` (#6375)
Browse files Browse the repository at this point in the history

### Example:

NOTE: `Sym` is a key column and is constant for each bucket. It is
presented to the UDF as a string (not a vector). `intCol` / `longCol`
are vectors containing the window data.

```
t_out = t.updateBy(UpdateByOperation.RollingFormula(prevTicks, postTicks,
        "out_val=sum(intCol) - max(longCol) + (Sym == null ? 0 : Sym.length())"), "Sym");
```
  • Loading branch information
lbooker42 authored Nov 19, 2024
1 parent 60a2948 commit 5d44879
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
//
package io.deephaven.engine.table.impl.updateby;

import io.deephaven.UncheckedDeephavenException;
import io.deephaven.api.updateby.UpdateByControl;
import io.deephaven.base.verify.Assert;
import io.deephaven.engine.exceptions.CancellationException;
Expand Down Expand Up @@ -83,9 +82,12 @@ class BucketedPartitionedUpdateByManager extends UpdateBy {
final PartitionedTable partitioned = source.partitionedAggBy(List.of(), true, null, byColumnNames);
final PartitionedTable transformed = partitioned.transform(t -> {
final long firstSourceRowKey = t.getRowSet().firstRowKey();
final Object[] bucketKeyValues = Arrays.stream(byColumnNames)
.map(colName -> t.getColumnSource(colName).get(firstSourceRowKey))
.toArray();
final String bucketDescription = BucketedPartitionedUpdateByManager.this + "-bucket-" +
Arrays.stream(byColumnNames)
.map(bcn -> Objects.toString(t.getColumnSource(bcn).get(firstSourceRowKey)))
Arrays.stream(bucketKeyValues)
.map(Objects::toString)
.collect(Collectors.joining(", ", "[", "]"));
UpdateByBucketHelper bucket = new UpdateByBucketHelper(
bucketDescription,
Expand All @@ -94,7 +96,8 @@ class BucketedPartitionedUpdateByManager extends UpdateBy {
resultSources,
timestampColumnName,
control,
this::onBucketFailure);
this::onBucketFailure,
bucketKeyValues);
// add this to the bucket list
synchronized (buckets) {
buckets.offer(bucket);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class UpdateByBucketHelper extends IntrusiveDoublyLinkedNode.Impl<UpdateByBucket
private final ColumnSource<?> timestampColumnSource;
private final ModifiedColumnSet timestampColumnSet;

/** Store boxed key values for this bucket */
private final Object[] bucketKeyValues;

/** Indicates this bucket needs to be processed (at least one window and operator are dirty) */
private boolean isDirty;
/** This rowset will store row keys where the timestamp is not null (will mirror the SSA contents) */
Expand All @@ -65,22 +68,25 @@ class UpdateByBucketHelper extends IntrusiveDoublyLinkedNode.Impl<UpdateByBucket
* @param resultSources the result sources
* @param timestampColumnName the timestamp column used for time-based operations
* @param control the control object.
* @param failureNotifier a consumer to notify of any failures
* @param bucketKeyValues the key values for this bucket (empty for zero-key)
*/

protected UpdateByBucketHelper(
@NotNull final String description,
@NotNull final QueryTable source,
@NotNull final UpdateByWindow[] windows,
@NotNull final Map<String, ? extends ColumnSource<?>> resultSources,
@Nullable final String timestampColumnName,
@NotNull final UpdateByControl control,
@NotNull final BiConsumer<Throwable, TableListener.Entry> failureNotifier) {
@NotNull final BiConsumer<Throwable, TableListener.Entry> failureNotifier,
@NotNull final Object[] bucketKeyValues) {
this.description = description;
this.source = source;
// some columns will have multiple inputs, such as time-based and Weighted computations
this.windows = windows;
this.control = control;
this.failureNotifier = failureNotifier;
this.bucketKeyValues = bucketKeyValues;

result = new QueryTable(source.getRowSet(), resultSources);

Expand Down Expand Up @@ -331,7 +337,8 @@ public void prepareForUpdate(final TableUpdate upstream, final boolean initialSt
timestampValidRowSet,
timestampsModified,
control.chunkCapacityOrDefault(),
initialStep);
initialStep,
bucketKeyValues);

// compute the affected/influenced operators and rowsets within this window
windows[winIdx].computeAffectedRowsAndOperators(windowContexts[winIdx], upstream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,46 @@ protected UpdateByOperator(
*/
public abstract void initializeSources(@NotNull Table source, @Nullable RowRedirection rowRedirection);

/**
 * Key-value-aware entry point for initializing the bucket context of a cumulative operator.
 * <p>
 * This default implementation does not use {@code bucketKeyValues}; it forwards the remaining arguments unchanged
 * to {@code initializeCumulative}. Operators that need the bucket key values can override this method.
 *
 * @param context the operator context, forwarded to {@code initializeCumulative}
 * @param firstUnmodifiedKey forwarded unchanged to {@code initializeCumulative}
 * @param firstUnmodifiedTimestamp forwarded unchanged to {@code initializeCumulative}
 * @param bucketRowSet forwarded unchanged to {@code initializeCumulative}
 * @param bucketKeyValues the key values for the bucket; ignored by this default implementation
 */
public void initializeCumulativeWithKeyValues(
        @NotNull final Context context,
        final long firstUnmodifiedKey,
        final long firstUnmodifiedTimestamp,
        @NotNull final RowSet bucketRowSet,
        @NotNull final Object[] bucketKeyValues) {
    // Key values are deliberately dropped here; only overriding operators consume them.
    initializeCumulative(context, firstUnmodifiedKey, firstUnmodifiedTimestamp, bucketRowSet);
}

/**
* Initialize the bucket context for a cumulative operator. This base implementation only resets the supplied
* context; the key, timestamp and row set arguments are not read here.
*
* @param context the operator context to reset
* @param firstUnmodifiedKey not used by this base implementation
* @param firstUnmodifiedTimestamp not used by this base implementation
* @param bucketRowSet not used by this base implementation
*/
public void initializeCumulative(@NotNull final Context context, final long firstUnmodifiedKey,
public void initializeCumulative(
@NotNull final Context context,
final long firstUnmodifiedKey,
final long firstUnmodifiedTimestamp,
@NotNull final RowSet bucketRowSet) {
context.reset();
}

/**
 * Key-value-aware entry point for initializing the bucket context of a windowed (rolling) operator.
 * <p>
 * This default implementation does not use {@code bucketKeyValues}; it forwards the context and row set unchanged
 * to {@code initializeRolling}. Operators that need the bucket key values can override this method.
 *
 * @param context the operator context, forwarded to {@code initializeRolling}
 * @param bucketRowSet forwarded unchanged to {@code initializeRolling}
 * @param bucketKeyValues the key values for the bucket; ignored by this default implementation
 */
public void initializeRollingWithKeyValues(
        @NotNull final Context context,
        @NotNull final RowSet bucketRowSet,
        @NotNull final Object[] bucketKeyValues) {
    // Key values are deliberately dropped here; only overriding operators consume them.
    initializeRolling(context, bucketRowSet);
}

/**
* Initialize the bucket context for a windowed operator. This base implementation only resets the supplied
* context; the row set argument is not read here.
*
* @param context the operator context to reset
* @param bucketRowSet not used by this base implementation
*/
public void initializeRolling(@NotNull final Context context,
public void initializeRolling(
@NotNull final Context context,
@NotNull final RowSet bucketRowSet) {
context.reset();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public class UpdateByOperatorFactory {
private final MatchPair[] groupByColumns;
@NotNull
private final UpdateByControl control;
private Map<String, ColumnDefinition<?>> vectorColumnNameMap;
private Map<String, ColumnDefinition<?>> vectorColumnDefinitions;

public UpdateByOperatorFactory(
@NotNull final TableDefinition tableDef,
Expand Down Expand Up @@ -1437,7 +1437,6 @@ private UpdateByOperator makeRollingFormulaOperator(@NotNull final MatchPair pai
private UpdateByOperator makeRollingFormulaMultiColumnOperator(
@NotNull final TableDefinition tableDef,
@NotNull final RollingFormulaSpec rs) {

final long prevWindowScaleUnits = rs.revWindowScale().getTimeScaleUnits();
final long fwdWindowScaleUnits = rs.fwdWindowScale().getTimeScaleUnits();

Expand All @@ -1446,42 +1445,58 @@ private UpdateByOperator makeRollingFormulaMultiColumnOperator(
// Create the select column from the formula
final SelectColumn selectColumn = SelectColumn.of(Selectable.parse(rs.formula()));

// Get or create a column definition map where the definitions are vectors of the original column types.
if (vectorColumnNameMap == null) {
vectorColumnNameMap = new HashMap<>();
columnDefinitionMap.forEach((key, value) -> {
final ColumnDefinition<?> columnDef = ColumnDefinition.fromGenericType(
key,
VectorFactory.forElementType(value.getDataType()).vectorType(),
value.getDataType());
vectorColumnNameMap.put(key, columnDef);
});
// Get or create a column definition map composed of vectors of the original column types (or scalars when
// part of the group_by columns).
final Set<String> groupByColumnSet =
Arrays.stream(groupByColumns).map(MatchPair::rightColumn).collect(Collectors.toSet());
if (vectorColumnDefinitions == null) {
vectorColumnDefinitions = tableDef.getColumnStream().collect(Collectors.toMap(
ColumnDefinition::getName,
(final ColumnDefinition<?> cd) -> groupByColumnSet.contains(cd.getName())
? cd
: ColumnDefinition.fromGenericType(
cd.getName(),
VectorFactory.forElementType(cd.getDataType()).vectorType(),
cd.getDataType())));
}

// Get the input column names and data types from the formula.
final String[] inputColumnNames =
selectColumn.initDef(vectorColumnNameMap, compilationProcessor).toArray(String[]::new);
// Get the input column names from the formula and provide them to the rolling formula operator
final String[] allInputColumns =
selectColumn.initDef(vectorColumnDefinitions, compilationProcessor).toArray(String[]::new);
if (!selectColumn.getColumnArrays().isEmpty()) {
throw new IllegalArgumentException("RollingFormulaMultiColumnOperator does not support column arrays ("
+ selectColumn.getColumnArrays() + ")");
}
if (selectColumn.hasVirtualRowVariables()) {
throw new IllegalArgumentException("RollingFormula does not support virtual row variables");
}
final Class<?>[] inputColumnTypes = new Class[inputColumnNames.length];
final Class<?>[] inputVectorTypes = new Class[inputColumnNames.length];

for (int i = 0; i < inputColumnNames.length; i++) {
final ColumnDefinition<?> columnDef = columnDefinitionMap.get(inputColumnNames[i]);
inputColumnTypes[i] = columnDef.getDataType();
inputVectorTypes[i] = vectorColumnNameMap.get(inputColumnNames[i]).getDataType();
final Map<Boolean, List<String>> partitioned = Arrays.stream(allInputColumns)
.collect(Collectors.partitioningBy(groupByColumnSet::contains));
final String[] inputKeyColumns = partitioned.get(true).toArray(String[]::new);
final String[] inputNonKeyColumns = partitioned.get(false).toArray(String[]::new);

final Class<?>[] inputKeyColumnTypes = new Class[inputKeyColumns.length];
final Class<?>[] inputKeyComponentTypes = new Class[inputKeyColumns.length];
for (int i = 0; i < inputKeyColumns.length; i++) {
final ColumnDefinition<?> columnDef = columnDefinitionMap.get(inputKeyColumns[i]);
inputKeyColumnTypes[i] = columnDef.getDataType();
inputKeyComponentTypes[i] = columnDef.getComponentType();
}

final Class<?>[] inputNonKeyColumnTypes = new Class[inputNonKeyColumns.length];
final Class<?>[] inputNonKeyVectorTypes = new Class[inputNonKeyColumns.length];
for (int i = 0; i < inputNonKeyColumns.length; i++) {
final ColumnDefinition<?> columnDef = columnDefinitionMap.get(inputNonKeyColumns[i]);
inputNonKeyColumnTypes[i] = columnDef.getDataType();
inputNonKeyVectorTypes[i] = vectorColumnDefinitions.get(inputNonKeyColumns[i]).getDataType();
}

final String[] affectingColumns;
if (rs.revWindowScale().timestampCol() == null) {
affectingColumns = inputColumnNames;
affectingColumns = inputNonKeyColumns;
} else {
affectingColumns = ArrayUtils.add(inputColumnNames, rs.revWindowScale().timestampCol());
affectingColumns = ArrayUtils.add(inputNonKeyColumns, rs.revWindowScale().timestampCol());
}

// Create a new column pair with the same name for the left and right columns
Expand All @@ -1494,9 +1509,12 @@ private UpdateByOperator makeRollingFormulaMultiColumnOperator(
prevWindowScaleUnits,
fwdWindowScaleUnits,
selectColumn,
inputColumnNames,
inputColumnTypes,
inputVectorTypes);
inputKeyColumns,
inputKeyColumnTypes,
inputKeyComponentTypes,
inputNonKeyColumns,
inputNonKeyColumnTypes,
inputNonKeyVectorTypes);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ static class UpdateByWindowBucketContext implements SafeCloseable {
protected final boolean timestampsModified;
/** Whether this is the creation phase of this window */
protected final boolean initialStep;

/** Store the key values for this bucket */
protected final Object[] bucketKeyValues;

/** An array of ColumnSources for each underlying operator */
protected ColumnSource<?>[] inputSources;
Expand All @@ -71,12 +72,14 @@ static class UpdateByWindowBucketContext implements SafeCloseable {
final TrackingRowSet timestampValidRowSet,
final boolean timestampsModified,
final int chunkSize,
final boolean initialStep) {
final boolean initialStep,
@NotNull final Object[] bucketKeyValues) {
this.sourceRowSet = sourceRowSet;
this.timestampColumnSource = timestampColumnSource;
this.timestampSsa = timestampSsa;
this.timestampValidRowSet = timestampValidRowSet;
this.timestampsModified = timestampsModified;
this.bucketKeyValues = bucketKeyValues;

workingChunkSize = chunkSize;
this.initialStep = initialStep;
Expand All @@ -91,13 +94,15 @@ public void close() {
}
}

/**
* Create the bucket context for this window.
* <p>
* NOTE(review): the parameters appear to mirror the {@code UpdateByWindowBucketContext} constructor arguments;
* confirm against each implementation, since an implementation may ignore some of them.
*
* @param bucketKeyValues the key values for the bucket
* @return the bucket context produced by the implementing window
*/
abstract UpdateByWindowBucketContext makeWindowContext(final TrackingRowSet sourceRowSet,
abstract UpdateByWindowBucketContext makeWindowContext(
final TrackingRowSet sourceRowSet,
final ColumnSource<?> timestampColumnSource,
final LongSegmentedSortedArray timestampSsa,
final TrackingRowSet timestampValidRowSet,
final boolean timestampsModified,
final int chunkSize,
final boolean isInitializeStep);
final boolean isInitializeStep,
final Object[] bucketKeyValues);

UpdateByWindow(final UpdateByOperator[] operators,
final int[][] operatorInputSourceSlots,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,17 @@ UpdateByWindowBucketContext makeWindowContext(final TrackingRowSet sourceRowSet,
final TrackingRowSet timestampValidRowSet,
final boolean timestampsModified,
final int chunkSize,
final boolean isInitializeStep) {
return new UpdateByWindowBucketContext(sourceRowSet, timestampColumnSource, timestampSsa, timestampValidRowSet,
timestampsModified, chunkSize, isInitializeStep);
final boolean isInitializeStep,
final Object[] bucketKeyValues) {
return new UpdateByWindowBucketContext(
sourceRowSet,
timestampColumnSource,
timestampSsa,
timestampValidRowSet,
timestampsModified,
chunkSize,
isInitializeStep,
bucketKeyValues);
}

@Override
Expand Down Expand Up @@ -192,7 +200,8 @@ void processWindowBucketOperatorSet(final UpdateByWindowBucketContext context,
continue;
}
UpdateByOperator cumOp = operators[opIdx];
cumOp.initializeCumulative(winOpContexts[ii], rowKey, timestamp, context.sourceRowSet);
cumOp.initializeCumulativeWithKeyValues(winOpContexts[ii], rowKey, timestamp, context.sourceRowSet,
context.bucketKeyValues);
}

while (affectedIt.hasMore()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ static class UpdateByWindowRollingBucketContext extends UpdateByWindowBucketCont
final TrackingRowSet timestampValidRowSet,
final boolean timestampsModified,
final int chunkSize,
final boolean initialStep) {
final boolean initialStep,
final Object[] bucketKeyValues) {
super(sourceRowSet,
timestampColumnSource,
timestampSsa,
timestampValidRowSet,
timestampsModified,
chunkSize,
initialStep);
initialStep,
bucketKeyValues);
}

@Override
Expand All @@ -60,7 +62,7 @@ public void close() {
}

UpdateByWindowRollingBase(@NotNull final UpdateByOperator[] operators,
@NotNull final int[][] operatorSourceSlots,
final int[][] operatorSourceSlots,
final long prevUnits,
final long fwdUnits,
@Nullable final String timestampColumnName) {
Expand Down Expand Up @@ -152,7 +154,7 @@ void processWindowBucketOperatorSet(final UpdateByWindowBucketContext context,
continue;
}
UpdateByOperator rollingOp = operators[opIdx];
rollingOp.initializeRolling(winOpContexts[ii], bucketRowSet);
rollingOp.initializeRollingWithKeyValues(winOpContexts[ii], bucketRowSet, context.bucketKeyValues);
}

int affectedChunkOffset = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ static class UpdateByWindowTicksBucketContext extends UpdateByWindowRollingBucke
private RowSet affectedRowPositions;
private RowSet influencerPositions;

/**
* Construct a ticks bucket context. Only the source row set, chunk size, initial-step flag and bucket key values
* come from the caller; the parent constructor receives {@code null} for three arguments and {@code false} for a
* fourth.
* <p>
* NOTE(review): those fixed arguments appear to be the timestamp-related parameters (timestamp column source,
* timestamp SSA, timestamp-valid row set, timestamps-modified flag) of the parent constructor; confirm.
*/
UpdateByWindowTicksBucketContext(final TrackingRowSet sourceRowSet,
final int chunkSize, final boolean initialStep) {
super(sourceRowSet, null, null, null, false, chunkSize, initialStep);
UpdateByWindowTicksBucketContext(
final TrackingRowSet sourceRowSet,
final int chunkSize,
final boolean initialStep,
final Object[] bucketKeyValues) {
super(sourceRowSet, null, null, null, false, chunkSize, initialStep, bucketKeyValues);
}

@Override
Expand Down Expand Up @@ -77,14 +80,16 @@ void finalizeWindowBucket(UpdateByWindowBucketContext context) {
}

/**
* Create a {@code UpdateByWindowTicksBucketContext} for this window. The timestamp-related arguments
* ({@code timestampColumnSource}, {@code timestampSsa}, {@code timestampValidRowSet}, {@code timestampsModified})
* are accepted but not passed on to the ticks context constructor.
*/
@Override
UpdateByWindowBucketContext makeWindowContext(final TrackingRowSet sourceRowSet,
UpdateByWindowBucketContext makeWindowContext(
final TrackingRowSet sourceRowSet,
final ColumnSource<?> timestampColumnSource,
final LongSegmentedSortedArray timestampSsa,
final TrackingRowSet timestampValidRowSet,
final boolean timestampsModified,
final int chunkSize,
final boolean isInitializeStep) {
return new UpdateByWindowTicksBucketContext(sourceRowSet, chunkSize, isInitializeStep);
final boolean isInitializeStep,
final Object[] bucketKeyValues) {
return new UpdateByWindowTicksBucketContext(sourceRowSet, chunkSize, isInitializeStep, bucketKeyValues);
}

private static WritableRowSet computeAffectedRowsTicks(final RowSet sourceSet, final RowSet invertedSubSet,
Expand Down
Loading

0 comments on commit 5d44879

Please sign in to comment.