From b7726827489f79b44cde6db40e7870be3dd7654e Mon Sep 17 00:00:00 2001 From: Shashwat Arghode Date: Sun, 4 Oct 2020 17:45:17 -0700 Subject: [PATCH] Apache DataSketches plugin for Trino --- .../trino-server/src/main/provisio/presto.xml | 6 + docs/src/main/sphinx/functions.rst | 1 + .../main/sphinx/functions/datasketches.rst | 39 +++ .../main/sphinx/functions/list-by-topic.rst | 8 + docs/src/main/sphinx/functions/list.rst | 2 + plugin/trino-datasketches/pom.xml | 101 ++++++++ .../datasketches/state/SketchState.java | 40 ++++ .../state/SketchStateFactory.java | 223 ++++++++++++++++++ .../state/SketchStateSerializer.java | 70 ++++++ .../plugin/datasketches/theta/Estimate.java | 48 ++++ .../theta/SketchFunctionsPlugin.java | 33 +++ .../plugin/datasketches/theta/Union.java | 54 +++++ .../datasketches/theta/UnionWithParams.java | 69 ++++++ .../datasketches/theta/TestMergeEstimate.java | 160 +++++++++++++ pom.xml | 1 + 15 files changed, 855 insertions(+) create mode 100644 docs/src/main/sphinx/functions/datasketches.rst create mode 100644 plugin/trino-datasketches/pom.xml create mode 100644 plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchState.java create mode 100644 plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateFactory.java create mode 100644 plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateSerializer.java create mode 100644 plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Estimate.java create mode 100644 plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/SketchFunctionsPlugin.java create mode 100644 plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Union.java create mode 100644 plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/UnionWithParams.java create mode 100644 plugin/trino-datasketches/src/test/java/io/trino/plugin/datasketches/theta/TestMergeEstimate.java diff --git a/core/trino-server/src/main/provisio/presto.xml b/core/trino-server/src/main/provisio/presto.xml index c6ecddc6a3cc..5823e0b4f6a0 100644 --- a/core/trino-server/src/main/provisio/presto.xml +++ b/core/trino-server/src/main/provisio/presto.xml @@ -62,6 +62,12 @@ + + + + + + diff --git a/docs/src/main/sphinx/functions.rst b/docs/src/main/sphinx/functions.rst index 4dbcf5d31a77..deadef1bf956 100644 --- a/docs/src/main/sphinx/functions.rst +++ b/docs/src/main/sphinx/functions.rst @@ -25,6 +25,7 @@ and the :doc:`SQL statement and syntax reference`. Conditional Conversion Date and time + Datasketches Decimal Geospatial HyperLogLog diff --git a/docs/src/main/sphinx/functions/datasketches.rst b/docs/src/main/sphinx/functions/datasketches.rst new file mode 100644 index 000000000000..496abdc67340 --- /dev/null +++ b/docs/src/main/sphinx/functions/datasketches.rst @@ -0,0 +1,39 @@ +====================== +DataSketches Functions +====================== +DataSketches is a high-performance library of stochastic streaming +algorithms commonly called ”sketches” in the data sciences. Sketches are +small, stateful programs that process massive data as a stream and can +provide approximate answers, with mathematical guarantees, to +computationally difficult queries orders-of-magnitude faster than +traditional, exact methods. +The DataSketches functions allows querying the fast and memory-efficient `Apache +DataSkecthes `_ +from Trino. Support for `Theta Sketch Framework `_ +is added, specifically :func:`theta_sketch_union` and :func:`theta_sketch_estimate` functions. +These functions are used in the ``count distinct`` queries using sketches. +Datasketches can be created using Hive or Pig using respective sketch APIs. + +DataSketches functions +---------------------- + +.. function:: theta_sketch_union(sketches) -> sketch + + Returns a single sketch which is a merged collection of sketches. + +.. function:: theta_sketch_estimate(sketch) -> double + + Returns the estimated value of the sketch. + +Example in Trino for using DataSketches +--------------------------------------- +Query:: + + sql + SELECT + o_orderdate as date, + theta_sketch_estimate(theta_sketch_union(o_custkey_sketch)) AS unique_user_count + SUM(o_totalprice) AS user_spent, + FROM tpch.sf100000.orders WHERE o_orderdate >= dateadd(day, -90, current_date) + GROUP BY o_orderdate; + diff --git a/docs/src/main/sphinx/functions/list-by-topic.rst b/docs/src/main/sphinx/functions/list-by-topic.rst index dc48b1054fde..5a1f2a7aa3b9 100644 --- a/docs/src/main/sphinx/functions/list-by-topic.rst +++ b/docs/src/main/sphinx/functions/list-by-topic.rst @@ -178,6 +178,14 @@ For more details, see :doc:`conversion` * :func:`try_cast` * :func:`typeof` +DataSketches +------------ + +For more details, see :doc:`datasketches` + +* :func:`theta_sketch_estimate` +* :func:`theta_sketch_union` + Date and time ------------- diff --git a/docs/src/main/sphinx/functions/list.rst b/docs/src/main/sphinx/functions/list.rst index 109f95a9da5c..ac2ea6944f5a 100644 --- a/docs/src/main/sphinx/functions/list.rst +++ b/docs/src/main/sphinx/functions/list.rst @@ -462,6 +462,8 @@ T - :func:`tan` - :func:`tanh` - :func:`tdigest_agg` +- :func:`theta_sketch_estimate` +- :func:`theta_sketch_union` - :func:`timestamp_objectid` - :func:`timezone_hour` - :func:`timezone_minute` diff --git a/plugin/trino-datasketches/pom.xml b/plugin/trino-datasketches/pom.xml new file mode 100644 index 000000000000..1681800550d4 --- /dev/null +++ b/plugin/trino-datasketches/pom.xml @@ -0,0 +1,101 @@ + + + 4.0.0 + + + io.trino + trino-root + 362-SNAPSHOT + ../../pom.xml + + + trino-datasketches + trino-datasketches + trino-plugin + http://datasketches.apache.org/ + + + ${project.parent.basedir} + + + + + io.trino + trino-array + + + + com.google.guava + guava + + + + org.apache.datasketches + datasketches-java + 2.0.0 + + + + org.apache.datasketches + datasketches-memory + 1.3.0 + + + + + io.trino + trino-spi + provided + + + + io.airlift + slice + provided + + + + org.openjdk.jol + jol-core + provided + + + + + io.trino + trino-hive + test + + + + io.trino + trino-hive + test-jar + test + + + + io.trino + trino-main + test + + + + io.trino + trino-testing + test + + + + io.trino.hadoop + hadoop-apache + test + + + + org.testng + testng + test + + + diff --git a/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchState.java b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchState.java new file mode 100644 index 000000000000..7233c62357d3 --- /dev/null +++ b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchState.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.state; + +import io.airlift.slice.Slice; +import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateMetadata; + +/** + * State object to keep track of sketch aggregations. + */ +@AccumulatorStateMetadata(stateSerializerClass = SketchStateSerializer.class, stateFactoryClass = SketchStateFactory.class) +public interface SketchState + extends AccumulatorState +{ + Slice getSketch(); + + int getNominalEntries(); + + long getSeed(); + + void setSketch(Slice value); + + void setNominalEntries(int value); + + void setSeed(long value); + + void merge(SketchState state); +} diff --git a/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateFactory.java b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateFactory.java new file mode 100644 index 000000000000..d4210606fb55 --- /dev/null +++ b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateFactory.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.state; + +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.array.ObjectBigArray; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.GroupedAccumulatorState; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.theta.SetOperation; +import org.apache.datasketches.theta.Sketch; +import org.apache.datasketches.theta.Union; +import org.openjdk.jol.info.ClassLayout; + +public class SketchStateFactory + implements AccumulatorStateFactory +{ + @Override + public SketchState createSingleState() + { + return new SingleSketchState(); + } + + @Override + public Class getSingleStateClass() + { + return SingleSketchState.class; + } + + @Override + public SketchState createGroupedState() + { + return new GroupedSketchState(); + } + + @Override + public Class getGroupedStateClass() + { + return GroupedSketchState.class; + } + + public static class GroupedSketchState + implements GroupedAccumulatorState, SketchState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedSketchState.class).instanceSize(); + private int nominalEntries; + private long size; + private long seed; + private long groupId; + private ObjectBigArray unions; + + public GroupedSketchState() + { + unions = new ObjectBigArray(); + } + + @Override + public void ensureCapacity(long size) + { + unions.ensureCapacity(size); + } + + @Override + public long getEstimatedSize() + { + return INSTANCE_SIZE + size + unions.sizeOf(); + } + + @Override + public Slice getSketch() + { + Union union = getUnion(); + if (union != null) { + return Slices.wrappedBuffer(union.getResult().toByteArray()); + } + return null; + } + + @Override + public int getNominalEntries() + { + return nominalEntries; + } + + @Override + public long getSeed() + { + return seed; + } + + public Union getUnion() + { + return unions.get(groupId); + } + + public void setMemoryUsage(int value) + { + size = value; + } + + @Override + public void setGroupId(long groupId) + { + this.groupId = groupId; + } + + @Override + public void setNominalEntries(int value) + { + nominalEntries = value; + } + + @Override + public void setSeed(long value) + { + seed = value; + } + + @Override + public void setSketch(Slice value) + { + addSketchToUnion(value, nominalEntries); + } + + private void addSketchToUnion(Slice value, int nominalEntries) + { + Union groupedUnion = getUnion(); + if (groupedUnion == null) { + groupedUnion = SetOperation.builder().setNominalEntries(nominalEntries).buildUnion(); + groupedUnion.update(Memory.wrap(value.getBytes())); + unions.set(groupId, groupedUnion); + setMemoryUsage(value.length()); + return; + } + groupedUnion.update(Memory.wrap(value.getBytes())); + setMemoryUsage(Math.max(value.length(), (int) size)); + } + + @Override + public void merge(SketchState otherState) + { + addSketchToUnion(otherState.getSketch(), otherState.getNominalEntries()); + } + } + + public static class SingleSketchState + implements SketchState + { + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SingleSketchState.class).instanceSize(); + private Slice sketch; + private int nominalEntries; + private long seed; + + @Override + public Slice getSketch() + { + return sketch; + } + + @Override + public int getNominalEntries() + { + return nominalEntries; + } + + @Override + public long getSeed() + { + return seed; + } + + @Override + public void setSketch(Slice value) + { + sketch = value; + } + + @Override + public void setNominalEntries(int value) + { + nominalEntries = value; + } + + @Override + public void setSeed(long value) + { + seed = value; + } + + @Override + public long getEstimatedSize() + { + long estimatedSize = INSTANCE_SIZE; + if (sketch != null) { + estimatedSize += sketch.getRetainedSize(); + } + return estimatedSize; + } + + @Override + public void merge(SketchState otherState) + { + int normEntries = Math.max(this.getNominalEntries(), otherState.getNominalEntries()); + Union union = SetOperation.builder().setSeed(this.getSeed()).setNominalEntries(normEntries).buildUnion(); + union.update(Memory.wrap(this.getSketch().getBytes())); + union.update(Memory.wrap(otherState.getSketch().getBytes())); + Sketch unionResult = union.getResult(); + this.setSketch(Slices.wrappedBuffer(unionResult.toByteArray())); + } + } +} diff --git a/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateSerializer.java b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateSerializer.java new file mode 100644 index 000000000000..a82b4b7627db --- /dev/null +++ b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/state/SketchStateSerializer.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.state; + +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.slice.SliceOutput; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.type.Type; + +import static io.airlift.slice.SizeOf.SIZE_OF_INT; +import static io.airlift.slice.SizeOf.SIZE_OF_LONG; +import static io.trino.spi.type.VarbinaryType.VARBINARY; + +public class SketchStateSerializer + implements AccumulatorStateSerializer +{ + @Override + public Type getSerializedType() + { + return VARBINARY; + } + + @Override + public void serialize(SketchState state, BlockBuilder out) + { + if (state.getSketch() == null) { + out.appendNull(); + } + else { + Slice slice = state.getSketch(); + + SliceOutput sliceOutput = new DynamicSliceOutput(SIZE_OF_LONG + SIZE_OF_INT + SIZE_OF_INT + slice.length()); + sliceOutput.appendInt(state.getNominalEntries()); + sliceOutput.appendLong(state.getSeed()); + + sliceOutput.appendInt(slice.length()); + sliceOutput.appendBytes(slice); + + VARBINARY.writeSlice(out, sliceOutput.slice()); + } + } + + @Override + public void deserialize(Block block, int index, SketchState state) + { + Slice slice = VARBINARY.getSlice(block, index); + SliceInput input = slice.getInput(); + + state.setNominalEntries(input.readInt()); + state.setSeed(input.readLong()); + + int sketchLength = input.readInt(); + state.setSketch(input.readSlice(sketchLength)); + } +} diff --git a/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Estimate.java b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Estimate.java new file mode 100644 index 000000000000..bc7ac7b38431 --- /dev/null +++ b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Estimate.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.theta; + +import io.airlift.slice.Slice; +import io.trino.spi.function.Description; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.theta.Sketches; + +import static org.apache.datasketches.Util.DEFAULT_UPDATE_SEED; + +public class Estimate +{ + private Estimate() {} + + @ScalarFunction("theta_sketch_estimate") + @Description("Converts sketch bytearrays to double estimate") + @SqlType(StandardTypes.DOUBLE) + public static double estimate(@SqlType(StandardTypes.VARBINARY) Slice inputValue) + { + return estimate(inputValue, DEFAULT_UPDATE_SEED); + } + + @ScalarFunction("theta_sketch_estimate") + @Description("Converts sketch bytearrays to double estimate") + @SqlType(StandardTypes.DOUBLE) + public static double estimate(@SqlType(StandardTypes.VARBINARY) Slice inputValue, @SqlType(StandardTypes.BIGINT) long seed) + { + if (inputValue.getBytes() == null || inputValue.getBytes().length == 0) { + return 0; + } + return Sketches.wrapSketch(Memory.wrap(inputValue.getBytes()), seed).getEstimate(); + } +} diff --git a/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/SketchFunctionsPlugin.java b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/SketchFunctionsPlugin.java new file mode 100644 index 000000000000..daed92eddd33 --- /dev/null +++ b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/SketchFunctionsPlugin.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.theta; + +import com.google.common.collect.ImmutableSet; +import io.trino.spi.Plugin; + +import java.util.Set; + +public class SketchFunctionsPlugin + implements Plugin +{ + @Override + public Set> getFunctions() + { + return ImmutableSet.>builder() + .add(Estimate.class) + .add(Union.class) + .add(UnionWithParams.class) + .build(); + } +} diff --git a/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Union.java b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Union.java new file mode 100644 index 000000000000..67f250bddc2c --- /dev/null +++ b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/Union.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.theta; + +import io.airlift.slice.Slice; +import io.trino.plugin.datasketches.state.SketchState; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; + +import static org.apache.datasketches.Util.DEFAULT_NOMINAL_ENTRIES; +import static org.apache.datasketches.Util.DEFAULT_UPDATE_SEED; + +@AggregationFunction("theta_sketch_union") +public final class Union +{ + private Union() {} + + @InputFunction + public static void input(@AggregationState SketchState state, @SqlType(StandardTypes.VARBINARY) Slice inputValue) + { + state.setNominalEntries(DEFAULT_NOMINAL_ENTRIES); + state.setSeed(DEFAULT_UPDATE_SEED); + state.setSketch(inputValue); + } + + @CombineFunction + public static void combine(@AggregationState SketchState state, SketchState otherState) + { + UnionWithParams.combine(state, otherState); + } + + @OutputFunction(StandardTypes.VARBINARY) + public static void output(@AggregationState SketchState state, BlockBuilder out) + { + UnionWithParams.output(state, out); + } +} diff --git a/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/UnionWithParams.java b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/UnionWithParams.java new file mode 100644 index 000000000000..ad013200ad0b --- /dev/null +++ b/plugin/trino-datasketches/src/main/java/io/trino/plugin/datasketches/theta/UnionWithParams.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.theta; + +import io.airlift.slice.Slice; +import io.trino.plugin.datasketches.state.SketchState; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; +import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.InputFunction; +import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; + +@AggregationFunction("theta_sketch_union") +public final class UnionWithParams +{ + private UnionWithParams() {} + + @InputFunction + public static void input(@AggregationState SketchState state, @SqlType(StandardTypes.VARBINARY) Slice inputValue, @SqlType(StandardTypes.INTEGER) Integer normEntries, @SqlType(StandardTypes.BIGINT) Long seed) + { + state.setNominalEntries(normEntries); + state.setSeed(seed); + state.setSketch(inputValue); + } + + @CombineFunction + public static void combine(@AggregationState SketchState state, SketchState otherState) + { + if (otherState == null || otherState.getSketch() == null) { + return; + } + + if (state == null || state.getSketch() == null) { + state.setSeed(otherState.getSeed()); + state.setNominalEntries(otherState.getNominalEntries()); + state.setSketch(otherState.getSketch()); + return; + } + + state.merge(otherState); + } + + @OutputFunction(StandardTypes.VARBINARY) + public static void output(@AggregationState SketchState state, BlockBuilder out) + { + Slice sketch = state.getSketch(); + if (sketch == null) { + out.appendNull(); + return; + } + + out.writeBytes(sketch, 0, sketch.length()); + out.closeEntry(); + } +} diff --git a/plugin/trino-datasketches/src/test/java/io/trino/plugin/datasketches/theta/TestMergeEstimate.java b/plugin/trino-datasketches/src/test/java/io/trino/plugin/datasketches/theta/TestMergeEstimate.java new file mode 100644 index 000000000000..2ce2f05703c3 --- /dev/null +++ b/plugin/trino-datasketches/src/test/java/io/trino/plugin/datasketches/theta/TestMergeEstimate.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.datasketches.theta; + +import io.airlift.slice.Slices; +import io.trino.Session; +import io.trino.plugin.datasketches.state.SketchState; +import io.trino.plugin.datasketches.state.SketchStateFactory; +import io.trino.plugin.hive.TestingHivePlugin; +import io.trino.plugin.hive.authentication.HiveIdentity; +import io.trino.plugin.hive.metastore.Database; +import io.trino.plugin.hive.metastore.HiveMetastore; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.security.PrincipalType; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import io.trino.testing.QueryRunner; +import org.apache.datasketches.theta.SetOperation; +import org.apache.datasketches.theta.Sketch; +import org.apache.datasketches.theta.UpdateSketch; +import org.apache.hadoop.io.BytesWritable; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.Arrays; + +import static io.trino.plugin.hive.metastore.file.FileHiveMetastore.createTestingFileHiveMetastore; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.testing.MaterializedResult.resultBuilder; +import static io.trino.testing.TestingConnectorSession.SESSION; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.apache.datasketches.Util.DEFAULT_NOMINAL_ENTRIES; +import static org.apache.datasketches.Util.DEFAULT_UPDATE_SEED; +import static org.testng.Assert.assertEquals; + +public class TestMergeEstimate + extends AbstractTestQueryFramework +{ + private SketchState state1; + private SketchState state2; + + private static final long TEST_SEED = 95869L; + private static final int TEST_ENTRIES = 2048; + + @BeforeClass + public void setup() + { + AccumulatorStateFactory factory = new SketchStateFactory(); + state1 = factory.createSingleState(); + state2 = factory.createSingleState(); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("hive") + .setSchema("default") + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).build(); + queryRunner.installPlugin(new SketchFunctionsPlugin()); + + File baseDir = queryRunner.getCoordinator().getBaseDataDir().resolve("hive_data").toFile(); + + HiveMetastore metastore = createTestingFileHiveMetastore(baseDir); + + metastore.createDatabase( + new HiveIdentity(SESSION), + Database.builder() + .setDatabaseName("default") + .setOwnerName("public") + .setOwnerType(PrincipalType.ROLE) + .build()); + queryRunner.installPlugin(new TestingHivePlugin(metastore)); + queryRunner.createCatalog("hive", "hive"); + return queryRunner; + } + + @Test + public void testSimpleMerge() + { + // 1000 unique keys + UpdateSketch sketch1 = UpdateSketch.builder().setSeed(TEST_SEED).build(); + for (int key = 0; key < 1000; key++) { + sketch1.update(key); + } + + // 1000 unique keys + // the first 500 unique keys overlap with sketch1 + UpdateSketch sketch2 = UpdateSketch.builder().setSeed(TEST_SEED).build(); + for (int key = 500; key < 1500; key++) { + sketch2.update(key); + } + + UnionWithParams.input(state1, Slices.wrappedBuffer(sketch1.compact().toByteArray()), TEST_ENTRIES, TEST_SEED); + UnionWithParams.input(state2, Slices.wrappedBuffer(sketch2.compact().toByteArray()), TEST_ENTRIES, TEST_SEED); + Union.combine(state1, state2); + double estimate = Estimate.estimate(state1.getSketch(), TEST_SEED); + + org.apache.datasketches.theta.Union union = SetOperation.builder().setSeed(TEST_SEED).buildUnion(); + union.update(sketch1); + union.update(sketch2); + Sketch unionResult = union.getResult(); + + assertEquals(unionResult.getEstimate(), estimate, 0); + } + + /** + * Test case similar to theta sketch hive documented test at https://datasketches.apache.org/docs/Theta/ThetaHiveUDFs.html + */ + @Test + public void testMergeAndEstimate() + { + int[] idACatagory = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + int[] idBCatagory = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + assertUpdate("CREATE TABLE sketch_intermediate (category VARCHAR, sketch VARBINARY)"); + assertUpdate("INSERT INTO sketch_intermediate VALUES ('a', X'" + getBytesWritableSketchUnion(idACatagory) + "')", 1); + assertUpdate("INSERT INTO sketch_intermediate VALUES ('b', X'" + getBytesWritableSketchUnion(idBCatagory) + "')", 1); + + MaterializedResult actualEstimateResult = computeActual("SELECT category, theta_sketch_estimate(sketch) FROM sketch_intermediate ORDER BY category ASC"); + assertEquals(actualEstimateResult.getRowCount(), 2); + MaterializedResult expectedEstimateResult = resultBuilder(getSession(), VARCHAR, DOUBLE) + .row("a", 10d) + .row("b", 10d) + .build(); + assertEquals(actualEstimateResult.getMaterializedRows(), expectedEstimateResult.getMaterializedRows()); + + double mergeEstimateResult = (double) computeScalar("SELECT theta_sketch_estimate(theta_sketch_union(sketch)) FROM sketch_intermediate"); + assertEquals(mergeEstimateResult, 15d); + } + + private BytesWritable getBytesWritableSketchUnion(int[] data) + { + org.apache.datasketches.theta.Union union = + SetOperation.builder().setSeed(DEFAULT_UPDATE_SEED).setNominalEntries(DEFAULT_NOMINAL_ENTRIES).buildUnion(); + Arrays.stream(data).forEach(e -> union.update(e)); + + byte[] resultSketch = union.getResult().toByteArray(); + BytesWritable resultBytes = new BytesWritable(); + resultBytes.set(resultSketch, 0, resultSketch.length); + return resultBytes; + } +} diff --git a/pom.xml b/pom.xml index 3cb91e0dd3fc..4e79781158d5 100644 --- a/pom.xml +++ b/pom.xml @@ -112,6 +112,7 @@ plugin/trino-blackhole plugin/trino-cassandra plugin/trino-clickhouse + plugin/trino-datasketches plugin/trino-druid plugin/trino-elasticsearch plugin/trino-example-http