Skip to content

Commit

Permalink
Fix VirtualColumn related issues in window expressions (apache#15119)
Browse files Browse the repository at this point in the history
for some exotic queries like:

  SELECT
  	'_'||dim1,
    MIN(cast(0 as double)) OVER (),
    MIN(cast((cnt||cnt) as bigint)) OVER ()
  FROM foo
the compilation resulted in NPEs, mostly because VirtualColumns were not handled properly
  • Loading branch information
kgyrtkirk authored Oct 23, 2023
1 parent c8e4584 commit b95035f
Show file tree
Hide file tree
Showing 23 changed files with 449 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.apache.druid.frame.write.columnar;

import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.write.UnsupportedColumnTypeException;
import org.apache.druid.java.util.common.ISE;
Expand Down Expand Up @@ -167,9 +166,7 @@ private static ComplexFrameColumnWriter makeComplexWriter(

private static boolean hasNullsForNumericWriter(final ColumnCapabilities capabilities)
{
if (NullHandling.replaceWithDefault()) {
return false;
} else if (capabilities == null) {
if (capabilities == null) {
return true;
} else if (capabilities.getType().isNumeric()) {
return capabilities.hasNulls().isMaybeTrue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
@JsonSubTypes.Type(name = "naivePartition", value = NaivePartitioningOperatorFactory.class),
@JsonSubTypes.Type(name = "naiveSort", value = NaiveSortOperatorFactory.class),
@JsonSubTypes.Type(name = "window", value = WindowOperatorFactory.class),
@JsonSubTypes.Type(name = "scan", value = ScanOperatorFactory.class),
})
public interface OperatorFactory
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public Query<RowsAndColumns> withQuerySegmentSpec(QuerySegmentSpec spec)
{
return new WindowOperatorQuery(
getDataSource(),
getQuerySegmentSpec(),
spec,
getContext(),
rowSignature,
operators,
Expand All @@ -217,6 +217,18 @@ public Query<RowsAndColumns> withDataSource(DataSource dataSource)
);
}

/**
 * Creates a new {@link WindowOperatorQuery} that uses the supplied operator list.
 * The data source, segment spec, context, row signature and leaf operators are
 * carried over from this query unchanged.
 *
 * @param operators the operator factories the new query should hold
 * @return the newly constructed query
 */
public Query<RowsAndColumns> withOperators(List<OperatorFactory> operators)
{
  final WindowOperatorQuery rebuilt = new WindowOperatorQuery(
      getDataSource(),
      getQuerySegmentSpec(),
      getContext(),
      rowSignature,
      operators,
      leafOperators
  );
  return rebuilt;
}

@Override
public boolean equals(Object o)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,17 @@ public QueryRunner<RowsAndColumns> mergeResults(QueryRunner<RowsAndColumns> runn
(queryPlus, responseContext) -> {
final WindowOperatorQuery query = (WindowOperatorQuery) queryPlus.getQuery();
final List<OperatorFactory> opFactories = query.getOperators();
if (opFactories.isEmpty()) {
return runner.run(queryPlus, responseContext);
}

Supplier<Operator> opSupplier = () -> {
Operator retVal = new SequenceOperator(runner.run(queryPlus, responseContext));
Operator retVal = new SequenceOperator(
runner.run(
queryPlus.withQuery(query.withOperators(new ArrayList<OperatorFactory>())),
responseContext
)
);
for (OperatorFactory operatorFactory : opFactories) {
retVal = operatorFactory.wrap(retVal);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.druid.query.rowsandcols;

import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.Frame;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory;
Expand Down Expand Up @@ -101,6 +102,11 @@ public Collection<String> getColumnNames()
return viewableColumns == null ? base.getColumnNames() : viewableColumns;
}

/**
 * Exposes the {@code base} rows-and-columns object this instance currently refers to.
 *
 * @return the value of the {@code base} field at the time of the call
 */
public RowsAndColumns getBase()
{
  return this.base;
}

@Override
public int numRows()
{
Expand All @@ -115,7 +121,6 @@ public Column findColumn(String name)
if (viewableColumns != null && !viewableColumns.contains(name)) {
return null;
}

maybeMaterialize();
return base.findColumn(name);
}
Expand Down Expand Up @@ -158,7 +163,7 @@ public WireTransferable toWireTransferable()

private void maybeMaterialize()
{
if (!(interval == null && filter == null && limit == -1 && ordering == null)) {
if (needsMaterialization()) {
final Pair<byte[], RowSignature> thePair = materialize();
if (thePair == null) {
reset(new EmptyRowsAndColumns());
Expand All @@ -168,6 +173,11 @@ private void maybeMaterialize()
}
}

/**
 * Tells whether anything is pending on this instance: an interval, a filter, a limit
 * other than {@code -1}, an ordering, or virtual columns.
 *
 * @return {@code true} if at least one of those is set, {@code false} if none is
 */
private boolean needsMaterialization()
{
  final boolean nothingPending = interval == null
                                 && filter == null
                                 && limit == -1
                                 && ordering == null
                                 && virtualColumns == null;
  return !nothingPending;
}

private Pair<byte[], RowSignature> materialize()
{
if (ordering != null) {
Expand All @@ -180,7 +190,6 @@ private Pair<byte[], RowSignature> materialize()
} else {
return materializeStorageAdapter(as);
}

}

private void reset(RowsAndColumns rac)
Expand All @@ -200,13 +209,26 @@ private Pair<byte[], RowSignature> materializeStorageAdapter(StorageAdapter as)
final Sequence<Cursor> cursors = as.makeCursors(
filter,
interval == null ? Intervals.ETERNITY : interval,
virtualColumns,
virtualColumns == null ? VirtualColumns.EMPTY : virtualColumns,
Granularities.ALL,
false,
null
);

Collection<String> cols = viewableColumns == null ? base.getColumnNames() : viewableColumns;

final Collection<String> cols;
if (viewableColumns != null) {
cols = viewableColumns;
} else {
if (virtualColumns == null) {
cols = base.getColumnNames();
} else {
cols = ImmutableList.<String>builder()
.addAll(base.getColumnNames())
.addAll(virtualColumns.getColumnNames())
.build();
}
}
AtomicReference<RowSignature> siggy = new AtomicReference<>(null);

FrameWriter writer = cursors.accumulate(null, (accumulated, in) -> {
Expand All @@ -222,9 +244,18 @@ private Pair<byte[], RowSignature> materializeStorageAdapter(StorageAdapter as)
final RowSignature.Builder sigBob = RowSignature.builder();

for (String col : cols) {
final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(col);
ColumnCapabilities capabilities;
capabilities = columnSelectorFactory.getColumnCapabilities(col);
if (capabilities != null) {
sigBob.add(col, capabilities.toColumnType());
continue;
}
if (virtualColumns != null) {
capabilities = virtualColumns.getColumnCapabilities(columnSelectorFactory, col);
if (capabilities != null) {
sigBob.add(col, capabilities.toColumnType());
continue;
}
}
}
final RowSignature signature = sigBob.build();
Expand Down Expand Up @@ -350,12 +381,12 @@ private Pair<byte[], RowSignature> naiveMaterialize(RowsAndColumns rac)
final RowSignature.Builder sigBob = RowSignature.builder();
final ArenaMemoryAllocatorFactory memFactory = new ArenaMemoryAllocatorFactory(200 << 20);


for (String column : columnsToGenerate) {
final Column racColumn = rac.findColumn(column);
if (racColumn == null) {
continue;
}

sigBob.add(column, racColumn.toAccessor().getType());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@

import org.apache.druid.frame.Frame;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.read.columnar.FrameColumnReaders;
import org.apache.druid.frame.segment.FrameStorageAdapter;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.query.rowsandcols.RowsAndColumns;
import org.apache.druid.query.rowsandcols.column.Column;
import org.apache.druid.segment.StorageAdapter;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;

Expand Down Expand Up @@ -77,10 +81,14 @@ public Column findColumn(String name)

}

/**
 * Offers an alternate view of this object as the requested class.
 * Only {@link StorageAdapter} is handled: a new {@link FrameStorageAdapter} is built over
 * {@code frame} with a reader created from {@code signature} and an interval of
 * {@link Intervals#ETERNITY}. Any other request yields {@code null}.
 *
 * @param clazz the class the caller wants this object viewed as
 * @return a storage adapter when {@code clazz} is {@code StorageAdapter.class}, else {@code null}
 */
@SuppressWarnings("unchecked")
@Nullable
@Override
public <T> T as(Class<T> clazz)
{
  if (!StorageAdapter.class.equals(clazz)) {
    return null;
  }
  final FrameReader reader = FrameReader.create(signature);
  // The cast is guarded by the class check above.
  return (T) new FrameStorageAdapter(frame, reader, Intervals.ETERNITY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public ColumnValueSelector makeColumnValueSelector(@Nonnull String columnName)
{
return withColumnAccessor(columnName, columnAccessor -> {
if (columnAccessor == null) {
return DimensionSelector.constant(null);
return DimensionSelector.nilSelector();
} else {
final ColumnType type = columnAccessor.getType();
switch (type.getType()) {
Expand All @@ -160,16 +160,22 @@ public ColumnValueSelector makeColumnValueSelector(@Nonnull String columnName)
@Override
public ColumnCapabilities getColumnCapabilities(String column)
{
return withColumnAccessor(column, columnAccessor ->
new ColumnCapabilitiesImpl()
return withColumnAccessor(column, columnAccessor -> {
if (columnAccessor == null) {
return null;
} else {
return new ColumnCapabilitiesImpl()
.setType(columnAccessor.getType())
.setHasMultipleValues(false)
.setDictionaryEncoded(false)
.setHasBitmapIndexes(false));
.setHasBitmapIndexes(false);
}
});
}

private <T> T withColumnAccessor(String column, Function<ColumnAccessor, T> fn)
{
@Nullable
ColumnAccessor retVal = accessorCache.get(column);
if (retVal == null) {
Column racColumn = rac.findColumn(column);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@

import javax.annotation.Nullable;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Class allowing lookup and usage of virtual columns.
Expand Down Expand Up @@ -112,6 +114,11 @@ public static VirtualColumns create(List<VirtualColumn> virtualColumns)
return new VirtualColumns(ImmutableList.copyOf(virtualColumns), withDotSupport, withoutDotSupport);
}

/**
 * Varargs convenience overload: wraps the given virtual columns in a list view and
 * delegates to {@link #create(List)}.
 *
 * @param virtualColumns the virtual columns to include
 * @return whatever {@link #create(List)} returns for those columns
 */
public static VirtualColumns create(VirtualColumn... virtualColumns)
{
  final List<VirtualColumn> columnList = Arrays.asList(virtualColumns);
  return create(columnList);
}

public static VirtualColumns nullToEmpty(@Nullable VirtualColumns virtualColumns)
{
return virtualColumns == null ? EMPTY : virtualColumns;
Expand Down Expand Up @@ -519,4 +526,14 @@ public boolean equals(Object obj)
((VirtualColumns) obj).virtualColumns.isEmpty();
}
}

/**
 * Reports whether this instance holds no virtual columns.
 *
 * @return {@code true} when the underlying {@code virtualColumns} collection is empty
 */
public boolean isEmpty()
{
  return this.virtualColumns.isEmpty();
}

/**
 * Collects the output name of every virtual column held by this instance, in the order
 * the underlying collection streams them.
 *
 * @return a list with one output name per virtual column
 */
public List<String> getColumnNames()
{
  return virtualColumns
      .stream()
      .map(VirtualColumn::getOutputName)
      .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ public static DruidExceptionMatcher invalidSqlInput()
return invalidInput().expectContext("sourceType", "sql");
}

/**
 * Builds a matcher constructed with the {@code DEVELOPER} persona, the {@code DEFENSIVE}
 * category and the string {@code "general"}.
 *
 * @return a new {@link DruidExceptionMatcher} configured with those three values
 */
public static DruidExceptionMatcher defensive()
{
  final DruidException.Persona persona = DruidException.Persona.DEVELOPER;
  final DruidException.Category category = DruidException.Category.DEFENSIVE;
  return new DruidExceptionMatcher(persona, category, "general");
}

private final AllOf<DruidException> delegate;
private final ArrayList<Matcher<? super DruidException>> matcherList;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,23 @@

package org.apache.druid.query.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.spec.LegacySegmentSpec;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.query.spec.QuerySegmentSpec;
import org.apache.druid.segment.column.RowSignature;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

/**
Expand Down Expand Up @@ -107,6 +112,22 @@ public void withDataSource()
Assert.assertSame(newDs, query.withDataSource(newDs).getDataSource());
}

/**
 * Verifies that {@code withQuerySegmentSpec} returns a query carrying exactly the
 * segment spec instance that was passed in.
 */
@Test
public void withQuerySpec()
{
  final QuerySegmentSpec spec = new MultipleIntervalSegmentSpec(Collections.emptyList());
  final WindowOperatorQuery rewritten = (WindowOperatorQuery) query.withQuerySegmentSpec(spec);
  Assert.assertSame(spec, rewritten.getQuerySegmentSpec());
}

/**
 * Verifies that {@code withOperators} returns a query carrying exactly the operator
 * list instance that was passed in.
 */
@Test
public void withOperators()
{
  final List<OperatorFactory> operators = ImmutableList.of(
      new NaivePartitioningOperatorFactory(Collections.singletonList("some"))
  );
  final WindowOperatorQuery rewritten = (WindowOperatorQuery) query.withOperators(operators);
  Assert.assertSame(operators, rewritten.getOperators());
}

@Test
public void testEquals()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.google.common.collect.Lists;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.rowsandcols.concrete.FrameRowsAndColumns;
import org.apache.druid.query.rowsandcols.concrete.FrameRowsAndColumnsTest;
import org.junit.Assert;
import org.junit.Test;

Expand Down Expand Up @@ -63,7 +65,8 @@ private static ArrayList<Object[]> getMakers()
new Object[]{MapOfColumnsRowsAndColumns.class, Function.identity()},
new Object[]{ArrayListRowsAndColumns.class, ArrayListRowsAndColumnsTest.MAKER},
new Object[]{ConcatRowsAndColumns.class, ConcatRowsAndColumnsTest.MAKER},
new Object[]{RearrangedRowsAndColumns.class, RearrangedRowsAndColumnsTest.MAKER}
new Object[]{RearrangedRowsAndColumns.class, RearrangedRowsAndColumnsTest.MAKER},
new Object[]{FrameRowsAndColumns.class, FrameRowsAndColumnsTest.MAKER}
);
}

Expand Down
Loading

0 comments on commit b95035f

Please sign in to comment.