Skip to content

Commit

Permalink
Add quantile_at_value function
Browse files Browse the repository at this point in the history
Co-authored-by: Peizhen Guo <[email protected]>
  • Loading branch information
2 people authored and Praveen2112 committed Mar 29, 2023
1 parent 500a04a commit 6158839
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@
*/
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.stats.QuantileDigest;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.Description;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlNullable;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.StandardTypes;

import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.operator.aggregation.FloatingPointBitsConverterUtil.doubleToSortableLong;
import static io.trino.operator.aggregation.FloatingPointBitsConverterUtil.floatToSortableInt;
import static io.trino.operator.aggregation.FloatingPointBitsConverterUtil.sortableIntToFloat;
import static io.trino.operator.aggregation.FloatingPointBitsConverterUtil.sortableLongToDouble;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
Expand All @@ -30,6 +35,7 @@
import static io.trino.spi.type.RealType.REAL;
import static io.trino.util.Failures.checkCondition;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Float.intBitsToFloat;

public final class QuantileDigestFunctions
{
Expand Down Expand Up @@ -62,6 +68,38 @@ public static long valueAtQuantileBigint(@SqlType("qdigest(bigint)") Slice input
return new QuantileDigest(input).getQuantile(quantile);
}

@ScalarFunction("quantile_at_value")
@Description("Given an input x between min/max values of qdigest, find which quantile is represented by that value")
@SqlType(StandardTypes.DOUBLE)
@SqlNullable
public static Double quantileAtValueDouble(@SqlType("qdigest(double)") Slice input, @SqlType(StandardTypes.DOUBLE) double value)
{
return quantileAtValueBigint(input, doubleToSortableLong(value));
}

@ScalarFunction("quantile_at_value")
@Description("Given an input x between min/max values of qdigest, find which quantile is represented by that value")
@SqlType(StandardTypes.DOUBLE)
@SqlNullable
public static Double quantileAtValueReal(@SqlType("qdigest(real)") Slice input, @SqlType(StandardTypes.REAL) long value)
{
return quantileAtValueBigint(input, floatToSortableInt(intBitsToFloat((int) value)));
}

@ScalarFunction("quantile_at_value")
@Description("Given an input x between min/max values of qdigest, find which quantile is represented by that value")
@SqlType(StandardTypes.DOUBLE)
@SqlNullable
public static Double quantileAtValueBigint(@SqlType("qdigest(bigint)") Slice input, @SqlType(StandardTypes.BIGINT) long value)
{
QuantileDigest digest = new QuantileDigest(input);
if (digest.getCount() == 0 || value > digest.getMax() || value < digest.getMin()) {
return null;
}
double bucketCount = getOnlyElement(digest.getHistogram(ImmutableList.of(value))).getCount();
return bucketCount / digest.getCount();
}

@ScalarFunction("values_at_quantiles")
@Description("For each input q between [0, 1], find the value whose rank in the sorted sequence of the n values represented by the qdigest is qn.")
@SqlType("array(double)")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation;

import io.airlift.stats.QuantileDigest;
import io.trino.spi.type.SqlVarbinary;
import io.trino.sql.query.QueryAssertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.util.stream.IntStream;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;

@TestInstance(PER_CLASS)
public class TestQuantileDigestFunctions
{
private QueryAssertions assertions;

@BeforeAll
public void init()
{
assertions = new QueryAssertions();
}

@AfterAll
public void teardown()
{
assertions.close();
assertions = null;
}

@Test
public void testQuantileAtValueBigint()
{
QuantileDigest qdigest = new QuantileDigest(1);
addAll(qdigest, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(bigint)), 20)")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isNull();
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(bigint)), 6)")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isEqualTo(0.6);
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(bigint)), -1)")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isNull();
}

@Test
public void testQuantileAtValueDouble()
{
QuantileDigest qdigest = new QuantileDigest(1);
IntStream.of(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
.mapToLong(FloatingPointBitsConverterUtil::doubleToSortableLong)
.forEach(qdigest::add);
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(double)), 5.6)")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isEqualTo(0.6);
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(double)), -1.23)")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isNull();
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(double)), 12.3)")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isNull();
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(double)), nan())")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isNull();
}

@Test
public void testQuantileAtValueBigintWithEmptyDigest()
{
QuantileDigest qdigest = new QuantileDigest(1);
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(bigint)), 5)")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isNull();
}

@Test
public void testQuantileRoundTrip()
{
QuantileDigest qdigest = new QuantileDigest(1);
addAll(qdigest, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

assertThat(assertions
.expression("value_at_quantile(CAST(a AS qdigest(bigint)), quantile_at_value(CAST(a AS qdigest(bigint)), 6))")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isEqualTo(6L);
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(bigint)), value_at_quantile(CAST(a AS qdigest(bigint)), .6))")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isEqualTo(.6);

qdigest = new QuantileDigest(1);
IntStream.range(0, 10)
.mapToLong(FloatingPointBitsConverterUtil::doubleToSortableLong)
.forEach(qdigest::add);

assertThat(assertions
.expression("value_at_quantile(CAST(a AS qdigest(double)),quantile_at_value(CAST(a AS qdigest(double)), 5.6))")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isEqualTo(6.);
assertThat(assertions
.expression("quantile_at_value(CAST(a AS qdigest(double)),value_at_quantile(CAST(a AS qdigest(double)), .6))")
.binding("a", "X'%s'".formatted(toHexString(qdigest))))
.isEqualTo(.6);
}

private static void addAll(QuantileDigest digest, long... values)
{
for (long value : values) {
digest.add(value);
}
}

private static String toHexString(QuantileDigest qdigest)
{
return new SqlVarbinary(qdigest.serialize().getBytes()).toString().replaceAll("\\s+", " ");
}
}
6 changes: 6 additions & 0 deletions docs/src/main/sphinx/functions/qdigest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ Functions
Returns the approximate percentile value from the quantile digest given
the number ``quantile`` between 0 and 1.

.. function:: quantile_at_value(qdigest(T), T) -> quantile

Returns the approximate ``quantile`` number between 0 and 1 from the
quantile digest given an input value. Null is returned if the quantile digest
is empty or the input value is outside of the range of the quantile digest.

.. function:: values_at_quantiles(qdigest(T), quantiles) -> array(T)

Returns the approximate percentile values as an array given the input
Expand Down

0 comments on commit 6158839

Please sign in to comment.