diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java index 867ae535ca80..03674b404dee 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java @@ -105,6 +105,7 @@ import io.trino.operator.scalar.ArraySortComparatorFunction; import io.trino.operator.scalar.ArraySortFunction; import io.trino.operator.scalar.ArrayToArrayCast; +import io.trino.operator.scalar.ArrayTrimFunction; import io.trino.operator.scalar.ArrayUnionFunction; import io.trino.operator.scalar.ArraysOverlapFunction; import io.trino.operator.scalar.BitwiseFunctions; @@ -548,6 +549,7 @@ public FunctionRegistry( .scalar(ArrayUnionFunction.class) .scalar(ArrayExceptFunction.class) .scalar(ArraySliceFunction.class) + .scalar(ArrayTrimFunction.class) .scalar(ArrayCombinationsFunction.class) .scalar(ArrayNgramsFunction.class) .scalar(ArrayAllMatchFunction.class) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTrimFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTrimFunction.java new file mode 100644 index 000000000000..9b253f1ba8be --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTrimFunction.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import io.trino.spi.block.Block; +import io.trino.spi.function.Description; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.StandardTypes; +import io.trino.spi.type.Type; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.util.Failures.checkCondition; +import static java.lang.Math.toIntExact; + +@ScalarFunction("trim_array") +@Description("Remove elements from the end of array") +public final class ArrayTrimFunction +{ + private ArrayTrimFunction() {} + + @TypeParameter("E") + @SqlType("array(E)") + public static Block trim( + @TypeParameter("E") Type type, + @SqlType("array(E)") Block array, + @SqlType(StandardTypes.BIGINT) long size) + { + checkCondition(size >= 0, INVALID_FUNCTION_ARGUMENT, "size must not be negative: %s", size); + checkCondition(size <= array.getPositionCount(), INVALID_FUNCTION_ARGUMENT, "size must not exceed array cardinality %s: %s", array.getPositionCount(), size); + + return array.getRegion(0, toIntExact(array.getPositionCount() - size)); + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayTrimFunction.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayTrimFunction.java new file mode 100644 index 000000000000..06beab4a306f --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayTrimFunction.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.type.ArrayType; +import org.testng.annotations.Test; + +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static java.util.Arrays.asList; + +public class TestArrayTrimFunction + extends AbstractTestFunctions +{ + @Test + public void testTrimArray() + { + assertFunction("trim_array(ARRAY[1, 2, 3, 4], 0)", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3, 4)); + assertFunction("trim_array(ARRAY[1, 2, 3, 4], 1)", new ArrayType(INTEGER), ImmutableList.of(1, 2, 3)); + assertFunction("trim_array(ARRAY[1, 2, 3, 4], 2)", new ArrayType(INTEGER), ImmutableList.of(1, 2)); + assertFunction("trim_array(ARRAY[1, 2, 3, 4], 3)", new ArrayType(INTEGER), ImmutableList.of(1)); + assertFunction("trim_array(ARRAY[1, 2, 3, 4], 4)", new ArrayType(INTEGER), ImmutableList.of()); + + assertFunction("trim_array(ARRAY['a', 'b', 'c', 'd'], 1)", new ArrayType(createVarcharType(1)), ImmutableList.of("a", "b", "c")); + assertFunction("trim_array(ARRAY['a', 'b', null, 'd'], 1)", new ArrayType(createVarcharType(1)), asList("a", "b", null)); + assertFunction("trim_array(ARRAY[ARRAY[1, 2, 3], ARRAY[4, 5, 6]], 1)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 2, 3))); + + assertInvalidFunction("trim_array(ARRAY[1, 2, 3, 4], 5)", "size must not exceed array cardinality 4: 5"); + assertInvalidFunction("trim_array(ARRAY[1, 2, 3, 4], -1)", "size must not be negative: -1"); + } +} diff --git a/docs/src/main/sphinx/functions/array.rst b/docs/src/main/sphinx/functions/array.rst index b4db6007fe32..8e78ed74590f 100644 --- a/docs/src/main/sphinx/functions/array.rst +++ b/docs/src/main/sphinx/functions/array.rst @@ -319,6 +319,16 @@ Array functions Subsets array ``x`` starting from index ``start`` (or starting from the end if ``start`` is negative) with a length of ``length``. +.. function:: trim_array(x, n) -> array + + Remove ``n`` elements from the end of array:: + + SELECT trim_array(ARRAY[1, 2, 3, 4], 1); + -- [1, 2, 3] + + SELECT trim_array(ARRAY[1, 2, 3, 4], 2); + -- [1, 2] + .. function:: transform(array(T), function(T,U)) -> array(U) Returns an array that is the result of applying ``function`` to each element of ``array``:: diff --git a/docs/src/main/sphinx/functions/list-by-topic.rst b/docs/src/main/sphinx/functions/list-by-topic.rst index 71bfb5d7566b..639d1d4e0721 100644 --- a/docs/src/main/sphinx/functions/list-by-topic.rst +++ b/docs/src/main/sphinx/functions/list-by-topic.rst @@ -84,6 +84,7 @@ For more details, see :doc:`array` * :func:`shuffle` * :func:`slice` * :func:`transform` +* :func:`trim_array` * :func:`zip` * :func:`zip_with` diff --git a/docs/src/main/sphinx/functions/list.rst b/docs/src/main/sphinx/functions/list.rst index 491cccb674a0..d7847ce0fda2 100644 --- a/docs/src/main/sphinx/functions/list.rst +++ b/docs/src/main/sphinx/functions/list.rst @@ -490,6 +490,7 @@ T - :func:`transform_values` - :func:`translate` - :func:`trim` +- :func:`trim_array` - :func:`truncate` - :ref:`try ` - :func:`try_cast`