Skip to content

Commit

Permalink
Add support for row subscript
Browse files Browse the repository at this point in the history
Allows accessing row fields with subscript expression
ROW (1, 'a', true)[2] --> 'a'
  • Loading branch information
kasiafi authored and martint committed Jun 4, 2019
1 parent de75aef commit 85355ff
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
import static io.prestosql.util.DateTimeUtils.parseTimestampLiteral;
import static io.prestosql.util.DateTimeUtils.timeHasTimeZone;
import static io.prestosql.util.DateTimeUtils.timestampHasTimeZone;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.util.Collections.unmodifiableMap;
import static java.util.Collections.unmodifiableSet;
Expand Down Expand Up @@ -639,6 +640,28 @@ private Type getVarcharType(Expression value, StackableAstVisitorContext<Context
/**
 * Resolves the type of a subscript expression ({@code base[index]}).
 * <p>
 * A subscript on a ROW has no dedicated operator, so its type is worked out here:
 * the index must be a {@link LongLiteral} whose analyzed type is INTEGER and whose
 * value lies between 1 and the number of row fields; the result is the type of the
 * field at that 1-based position. For any other base type the result comes from the
 * SUBSCRIPT operator.
 *
 * @throws SemanticException if the base is a ROW and the index is not a literal,
 * is not of type INTEGER, is not positive, or exceeds the number of fields
 */
@Override
protected Type visitSubscriptExpression(SubscriptExpression node, StackableAstVisitorContext<Context> context)
{
    Type baseType = process(node.getBase(), context);

    // Subscript on Array or Map uses an operator to resolve Type.
    if (!(baseType instanceof RowType)) {
        return getOperator(context, node, SUBSCRIPT, node.getBase(), node.getIndex());
    }

    // From here on the base is a ROW; the field type is looked up by position.
    Expression index = node.getIndex();
    if (!(index instanceof LongLiteral)) {
        throw new SemanticException(INVALID_PARAMETER_USAGE, index, "Subscript expression on ROW requires a constant index");
    }

    Type indexType = process(index, context);
    if (!indexType.equals(INTEGER)) {
        throw new SemanticException(TYPE_MISMATCH, index, "Subscript expression on ROW requires integer index, found %s", indexType);
    }

    int ordinal = toIntExact(((LongLiteral) index).getValue());
    if (ordinal <= 0) {
        throw new SemanticException(INVALID_PARAMETER_USAGE, index, "Invalid subscript index: %s. ROW indices start at 1", ordinal);
    }

    List<Type> fieldTypes = baseType.getTypeParameters();
    if (ordinal > fieldTypes.size()) {
        throw new SemanticException(INVALID_PARAMETER_USAGE, index, "Subscript index out of bounds: %s, max value is %s", ordinal, fieldTypes.size());
    }

    // ROW subscripts are 1-based, the type parameter list is 0-based.
    return setExpressionType(node, fieldTypes.get(ordinal - 1));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.DereferenceExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.Identifier;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.SubscriptExpression;

import java.util.Map;
import java.util.Optional;

import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
 * Replaces subscript expressions on Row with a dereference.
 * <p>
 * When the selected field is unnamed, the base is first cast to a Row whose
 * fields carry generated names ({@code f0}, {@code f1}, ...) and keep their
 * original types {@code T0}, {@code T1}, ...:
 * <pre>
 *     ROW (1, 'a', 2) [2]
 * </pre>
 * is transformed into:
 * <pre>
 *     (CAST (ROW (1, 'a', 2) AS ROW (f0 T0, f1 T1, f2 T2))).f1
 * </pre>
 * When the selected field already has a name, no cast is added and the base
 * is dereferenced by that name.
 */
public class DesugarRowSubscriptRewriter
{
    private DesugarRowSubscriptRewriter() {}

    /**
     * Computes the expression types with the given analyzer and rewrites the expression.
     */
    public static Expression rewrite(Expression expression, Session session, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator)
    {
        requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        return rewrite(expression, typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), expression));
    }

    /**
     * Rewrites the expression using previously computed expression types.
     */
    public static Expression rewrite(Expression expression, Map<NodeRef<Expression>, Type> expressionTypes)
    {
        Visitor visitor = new Visitor(expressionTypes);
        return ExpressionTreeRewriter.rewriteWith(visitor, expression);
    }

    private static class Visitor
            extends ExpressionRewriter<Void>
    {
        private final Map<NodeRef<Expression>, Type> expressionTypes;

        public Visitor(Map<NodeRef<Expression>, Type> expressionTypes)
        {
            requireNonNull(expressionTypes, "expressionTypes is null");
            this.expressionTypes = ImmutableMap.copyOf(expressionTypes);
        }

        @Override
        public Expression rewriteSubscriptExpression(SubscriptExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
        {
            Expression base = node.getBase();
            Type baseType = expressionTypes.get(NodeRef.of(base));

            // Only subscripts whose base is a Row are replaced; anything else is kept as is.
            Expression replacement = node;
            if (baseType instanceof RowType) {
                replacement = dereferenceField(base, (RowType) baseType, node.getIndex());
            }

            // Continue with the default rewrite so that nested expressions are visited as well.
            return treeRewriter.defaultRewrite(replacement, context);
        }

        // Builds the dereference that selects the field addressed by the 1-based literal index.
        private static Expression dereferenceField(Expression base, RowType rowType, Expression index)
        {
            int position = toIntExact(((LongLiteral) index).getValue() - 1);
            Optional<String> declaredName = rowType.getFields().get(position).getName();

            // Do not cast if the selected field is named
            if (declaredName.isPresent()) {
                return new DereferenceExpression(base, new Identifier(declaredName.get()));
            }

            Cast cast = new Cast(base, withGeneratedNames(rowType).getTypeSignature().toString());
            return new DereferenceExpression(cast, new Identifier(generatedName(position)));
        }

        // Returns a Row type with the same field types where every field gets a generated name.
        private static RowType withGeneratedNames(RowType rowType)
        {
            int fieldCount = rowType.getFields().size();
            ImmutableList.Builder<RowType.Field> renamed = new ImmutableList.Builder<>();
            for (int ordinal = 0; ordinal < fieldCount; ordinal++) {
                renamed.add(new RowType.Field(Optional.of(generatedName(ordinal)), rowType.getTypeParameters().get(ordinal)));
            }
            return RowType.from(renamed.build());
        }

        // Single source of the generated field names, shared by the cast and the dereference.
        private static String generatedName(int position)
        {
            return "f" + position;
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.block.RowBlockBuilder;
import io.prestosql.spi.block.SingleRowBlock;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.type.ArrayType;
Expand Down Expand Up @@ -120,6 +121,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.prestosql.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL;
import static io.prestosql.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.prestosql.spi.type.TypeSignature.parseTypeSignature;
import static io.prestosql.spi.type.TypeUtils.readNativeValue;
Expand All @@ -134,6 +136,7 @@
import static io.prestosql.type.LikeFunctions.isLikePattern;
import static io.prestosql.type.LikeFunctions.unescapeLiteralLikePattern;
import static io.prestosql.util.Failures.checkCondition;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -1168,6 +1171,18 @@ protected Object visitSubscriptExpression(SubscriptExpression node, Object conte
return new SubscriptExpression(toExpression(base, type(node.getBase())), toExpression(index, type(node.getIndex())));
}

// Subscript on Row hasn't got a dedicated operator. It is interpreted by hand.
if (base instanceof SingleRowBlock) {
SingleRowBlock row = (SingleRowBlock) base;
int position = toIntExact((long) index - 1);
if (position < 0 || position >= row.getPositionCount()) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "ROW index out of bounds: " + (position + 1));
}
Type returnType = type(node.getBase()).getTypeParameters().get(position);
return readNativeValue(returnType, row, position);
}

// Subscript on Array or Map is interpreted using operator.
return invokeOperator(OperatorType.SUBSCRIPT, types(node.getBase(), node.getIndex()), ImmutableList.of(base, index));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.prestosql.sql.planner.iterative.rule.DesugarCurrentPath;
import io.prestosql.sql.planner.iterative.rule.DesugarCurrentUser;
import io.prestosql.sql.planner.iterative.rule.DesugarLambdaExpression;
import io.prestosql.sql.planner.iterative.rule.DesugarRowSubscript;
import io.prestosql.sql.planner.iterative.rule.DesugarTryExpression;
import io.prestosql.sql.planner.iterative.rule.DetermineJoinDistributionType;
import io.prestosql.sql.planner.iterative.rule.DetermineSemiJoinDistributionType;
Expand Down Expand Up @@ -286,6 +287,7 @@ public PlanOptimizers(
.addAll(new DesugarCurrentUser().rules())
.addAll(new DesugarCurrentPath().rules())
.addAll(new DesugarTryExpression().rules())
.addAll(new DesugarRowSubscript(typeAnalyzer).rules())
.build()),
new IterativeOptimizer(
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import io.prestosql.sql.planner.DesugarRowSubscriptRewriter;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.iterative.Rule;

import java.util.Set;

import static java.util.Objects.requireNonNull;

/**
 * Rule set that runs {@link DesugarRowSubscriptRewriter} through the project,
 * aggregation, filter, join and values expression-rewrite rules inherited from
 * {@link ExpressionRewriteRuleSet}.
 */
public class DesugarRowSubscript
        extends ExpressionRewriteRuleSet
{
    public DesugarRowSubscript(TypeAnalyzer typeAnalyzer)
    {
        super(createRewrite(typeAnalyzer));
    }

    @Override
    public Set<Rule<?>> rules()
    {
        return ImmutableSet.<Rule<?>>builder()
                .add(projectExpressionRewrite())
                .add(aggregationExpressionRewrite())
                .add(filterExpressionRewrite())
                .add(joinExpressionRewrite())
                .add(valuesExpressionRewrite())
                .build();
    }

    // Adapts the static rewriter to the rule set's rewriter callback; the analyzer is validated up front.
    private static ExpressionRewriter createRewrite(TypeAnalyzer typeAnalyzer)
    {
        requireNonNull(typeAnalyzer, "typeAnalyzer is null");

        return (expression, context) -> DesugarRowSubscriptRewriter.rewrite(
                expression,
                context.getSession(),
                typeAnalyzer,
                context.getSymbolAllocator());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1363,6 +1363,13 @@ public void testRowConstructor()
optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]");
}

@Test
public void testRowSubscript()
{
    // Subscript on a flat row literal
    String flatRowSubscript = "ROW (1, 'a', true)[3]";
    assertOptimizedEquals(flatRowSubscript, "true");

    // Chained subscripts on nested row literals
    String nestedRowSubscript = "ROW (1, 'a', ROW (2, 'b', ROW (3, 'c')))[3][3][2]";
    assertOptimizedEquals(nestedRowSubscript, "'c'");
}

@Test(expectedExceptions = PrestoException.class)
public void testArraySubscriptConstantNegativeIndex()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,13 @@ public void testExpressions()
// row type
assertFails(TYPE_MISMATCH, "SELECT t.x.f1 FROM (VALUES 1) t(x)");
assertFails(TYPE_MISMATCH, "SELECT x.f1 FROM (VALUES 1) t(x)");

// subscript on Row
assertFails(INVALID_PARAMETER_USAGE, "line 1:20: Subscript expression on ROW requires a constant index", "SELECT ROW(1, 'a')[x]");
assertFails(TYPE_MISMATCH, "line 1:20: Subscript expression on ROW requires integer index, found bigint", "SELECT ROW(1, 'a')[9999999999]");
assertFails(INVALID_PARAMETER_USAGE, "line 1:20: Invalid subscript index: -1. ROW indices start at 1", "SELECT ROW(1, 'a')[-1]");
assertFails(INVALID_PARAMETER_USAGE, "line 1:20: Invalid subscript index: 0. ROW indices start at 1", "SELECT ROW(1, 'a')[0]");
assertFails(INVALID_PARAMETER_USAGE, "line 1:20: Subscript index out of bounds: 5, max value is 2", "SELECT ROW(1, 'a')[5]");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,14 @@ public void testArraySubscript()
}
}

@Test
public void testRowSubscript()
{
    // Expected tree: a subscript whose base is the row constructor and whose index is the literal 1.
    Row base = new Row(ImmutableList.of(new LongLiteral("1"), new StringLiteral("a"), new BooleanLiteral("true")));
    LongLiteral index = new LongLiteral("1");

    assertExpression("ROW (1, 'a', true)[1]", new SubscriptExpression(base, index));
}

@Test
public void testDouble()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,27 @@ public void testMapSubscript()
assertQuery("SELECT map(array[(1,2)], array['a'])[(1,2)]", "SELECT 'a'");
}

@Test
public void testRowSubscript()
{
    // Expected result shared by the queries that select the 'a' field
    String selectsA = "SELECT 'a'";

    // Row with unnamed fields
    assertQuery("SELECT ROW (1, 'a', true)[2]", selectsA);
    assertQuery("SELECT r[2] FROM (VALUES (ROW (ROW (1, 'a', true)))) AS v(r)", selectsA);
    assertQuery("SELECT r[1], r[2] FROM (SELECT ROW (name, regionkey) FROM nation ORDER BY name LIMIT 1) t(r)", "VALUES ('ALGERIA', 0)");

    // Row with named fields
    assertQuery("SELECT (CAST (ROW (1, 'a', 2 ) AS ROW (field1 bigint, field2 varchar(1), field3 bigint)))[2]", selectsA);

    // Nested Row
    assertQuery("SELECT ROW (1, 'a', ROW (false, 2, 'b'))[3][3]", "SELECT 'b'");

    // Row subscript used in a filter condition
    assertQuery("SELECT orderstatus FROM orders WHERE ROW (orderkey, custkey)[1] = 100", "SELECT 'O'");

    // Row subscript used in a join condition
    assertQuery("SELECT n.name, r.name FROM nation n JOIN region r ON ROW (n.name, n.regionkey)[2] = ROW (r.name, r.regionkey)[2] ORDER BY n.name LIMIT 1", "VALUES ('ALGERIA', 'AFRICA')");
}

@Test
public void testVarbinary()
{
Expand Down

0 comments on commit 85355ff

Please sign in to comment.