Skip to content

Commit

Permalink
Refactor JSON functions to use parameter array
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi committed May 27, 2022
1 parent 811fb3a commit dce2a6e
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
import java.util.Objects;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
import static com.google.common.base.Preconditions.checkArgument;

public class IrNamedJsonVariable
extends IrPathNode
{
private final String name;
private final int index;

@JsonCreator
public IrNamedJsonVariable(@JsonProperty("name") String name, @JsonProperty("type") Optional<Type> type)
public IrNamedJsonVariable(@JsonProperty("index") int index, @JsonProperty("type") Optional<Type> type)
{
super(type);
this.name = requireNonNull(name, "name is null");
checkArgument(index >= 0, "parameter index is negative");
this.index = index;
}

@Override
Expand All @@ -41,9 +42,9 @@ protected <R, C> R accept(IrJsonPathVisitor<R, C> visitor, C context)
}

@JsonProperty
public String getName()
public int getIndex()
{
return name;
return index;
}

@Override
Expand All @@ -56,12 +57,12 @@ public boolean equals(Object obj)
return false;
}
IrNamedJsonVariable other = (IrNamedJsonVariable) obj;
return Objects.equals(this.name, other.name);
return this.index == other.index;
}

@Override
public int hashCode()
{
return Objects.hash(name);
return Objects.hash(index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
import java.util.Objects;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
import static com.google.common.base.Preconditions.checkArgument;

public class IrNamedValueVariable
extends IrPathNode
{
private final String name;
private final int index;

@JsonCreator
public IrNamedValueVariable(@JsonProperty("name") String name, @JsonProperty("type") Optional<Type> type)
public IrNamedValueVariable(@JsonProperty("index") int index, @JsonProperty("type") Optional<Type> type)
{
super(type);
this.name = requireNonNull(name, "name is null");
checkArgument(index >= 0, "parameter index is negative");
this.index = index;
}

@Override
Expand All @@ -41,9 +42,9 @@ protected <R, C> R accept(IrJsonPathVisitor<R, C> visitor, C context)
}

@JsonProperty
public String getName()
public int getIndex()
{
return name;
return index;
}

@Override
Expand All @@ -56,12 +57,12 @@ public boolean equals(Object obj)
return false;
}
IrNamedValueVariable other = (IrNamedValueVariable) obj;
return Objects.equals(this.name, other.name);
return this.index == other.index;
}

@Override
public int hashCode()
{
return Objects.hash(name);
return Objects.hash(index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Map;

import static io.trino.json.JsonInputErrorNode.JSON_ERROR;
import static io.trino.operator.scalar.json.JsonQueryFunction.getParametersMap;
import static io.trino.operator.scalar.json.ParameterUtil.getParametersArray;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
Expand Down Expand Up @@ -113,8 +112,8 @@ public static Boolean jsonExists(
if (inputExpression.equals(JSON_ERROR)) {
return handleError(errorBehavior, INPUT_ARGUMENT_ERROR); // ERROR ON ERROR was already handled by the input function
}
Map<String, Object> parameters = getParametersMap(parametersRowType, parametersRow); // TODO refactor
for (Object parameter : parameters.values()) {
Object[] parameters = getParametersArray(parametersRowType, parametersRow);
for (Object parameter : parameters) {
if (parameter.equals(JSON_ERROR)) {
return handleError(errorBehavior, PATH_PARAMETER_ERROR); // ERROR ON ERROR was already handled by the input function
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.NullNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.json.ir.IrJsonPath;
import io.trino.json.ir.TypedValue;
Expand All @@ -32,35 +30,29 @@
import io.trino.operator.scalar.ChoicesScalarFunctionImplementation;
import io.trino.operator.scalar.ScalarFunctionImplementation;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeSignature;
import io.trino.sql.planner.JsonPathEvaluator;
import io.trino.sql.planner.JsonPathEvaluator.PathEvaluationError;
import io.trino.sql.tree.JsonQuery.ArrayWrapperBehavior;
import io.trino.sql.tree.JsonQuery.EmptyOrErrorBehavior;
import io.trino.type.Json2016Type;
import io.trino.type.JsonPath2016Type;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.json.JsonEmptySequenceNode.EMPTY_SEQUENCE;
import static io.trino.json.JsonInputErrorNode.JSON_ERROR;
import static io.trino.json.ir.SqlJsonLiteralConverter.getJsonNode;
import static io.trino.operator.scalar.json.ParameterUtil.getParametersArray;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static io.trino.spi.type.StandardTypes.JSON_2016;
import static io.trino.spi.type.StandardTypes.TINYINT;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.sql.analyzer.ExpressionAnalyzer.JSON_NO_PARAMETERS_ROW_TYPE;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -140,8 +132,8 @@ public static JsonNode jsonQuery(
if (inputExpression.equals(JSON_ERROR)) {
return handleSpecialCase(errorBehavior, INPUT_ARGUMENT_ERROR); // ERROR ON ERROR was already handled by the input function
}
Map<String, Object> parameters = getParametersMap(parametersRowType, parametersRow); // TODO refactor
for (Object parameter : parameters.values()) {
Object[] parameters = getParametersArray(parametersRowType, parametersRow);
for (Object parameter : parameters) {
if (parameter.equals(JSON_ERROR)) {
return handleSpecialCase(errorBehavior, PATH_PARAMETER_ERROR); // ERROR ON ERROR was already handled by the input function
}
Expand Down Expand Up @@ -216,39 +208,4 @@ private static JsonNode handleSpecialCase(long behavior, TrinoException error)
}
throw new IllegalStateException("unexpected behavior");
}

public static Map<String, Object> getParametersMap(Type parametersRowType, Object parametersRow)
{
if (JSON_NO_PARAMETERS_ROW_TYPE.equals(parametersRowType)) {
return ImmutableMap.of();
}

RowType rowType = (RowType) parametersRowType;
Block row = (Block) parametersRow;
List<Block> parameterBlocks = row.getChildren();

ImmutableMap.Builder<String, Object> map = ImmutableMap.builder();
for (int i = 0; i < rowType.getFields().size(); i++) {
RowType.Field field = rowType.getFields().get(i);
String name = field.getName().orElseThrow(() -> new IllegalStateException("missing parameter name"));
Type type = field.getType();
Object value = readNativeValue(type, parameterBlocks.get(i), 0);
if (type.equals(Json2016Type.JSON_2016)) {
if (value == null) {
map.put(name, EMPTY_SEQUENCE); // null as JSON value shall produce an empty sequence
}
else {
map.put(name, value);
}
}
else if (value == null) {
map.put(name, NullNode.getInstance()); // null as a non-JSON value shall produce a JSON null
}
else {
map.put(name, TypedValue.fromValueAsObject(type, value));
}
}

return map.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,12 @@

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.json.JsonInputErrorNode.JSON_ERROR;
import static io.trino.json.ir.SqlJsonLiteralConverter.getTypedValue;
import static io.trino.operator.scalar.json.JsonQueryFunction.getParametersMap;
import static io.trino.operator.scalar.json.ParameterUtil.getParametersArray;
import static io.trino.spi.StandardErrorCode.JSON_VALUE_RESULT_ERROR;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
Expand Down Expand Up @@ -239,8 +238,8 @@ public static Object jsonValue(
if (inputExpression.equals(JSON_ERROR)) {
return handleSpecialCase(errorBehavior, errorDefault, INPUT_ARGUMENT_ERROR); // ERROR ON ERROR was already handled by the input function
}
Map<String, Object> parameters = getParametersMap(parametersRowType, parametersRow); // TODO refactor
for (Object parameter : parameters.values()) {
Object[] parameters = getParametersArray(parametersRowType, parametersRow);
for (Object parameter : parameters) {
if (parameter.equals(JSON_ERROR)) {
return handleSpecialCase(errorBehavior, errorDefault, PATH_PARAMETER_ERROR); // ERROR ON ERROR was already handled by the input function
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.scalar.json;

import com.fasterxml.jackson.databind.node.NullNode;
import io.trino.json.ir.TypedValue;
import io.trino.spi.block.Block;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.type.Json2016Type;

import java.util.List;

import static io.trino.json.JsonEmptySequenceNode.EMPTY_SEQUENCE;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.sql.analyzer.ExpressionAnalyzer.JSON_NO_PARAMETERS_ROW_TYPE;

public final class ParameterUtil
{
private ParameterUtil() {}

/**
* Converts the parameters passed to json path into appropriate values,
* respecting the proper SQL semantics for nulls in the context of
* a path parameter, and collects them in an array.
* <p>
* All non-null values are passed as-is. Conversions apply in the following cases:
* - null value with FORMAT option is converted into an empty JSON sequence
* - null value without FORMAT option is converted into a JSON null.
*
* @param parametersRowType type of the Block containing parameters
* @param parametersRow a Block containing parameters
* @return an array containing the converted values
*/
public static Object[] getParametersArray(Type parametersRowType, Object parametersRow)
{
if (JSON_NO_PARAMETERS_ROW_TYPE.equals(parametersRowType)) {
return new Object[] {};
}

RowType rowType = (RowType) parametersRowType;
Block row = (Block) parametersRow;
List<Block> parameterBlocks = row.getChildren();

Object[] array = new Object[rowType.getFields().size()];
for (int i = 0; i < rowType.getFields().size(); i++) {
Type type = rowType.getFields().get(i).getType();
Object value = readNativeValue(type, parameterBlocks.get(i), 0);
if (type.equals(Json2016Type.JSON_2016)) {
if (value == null) {
array[i] = EMPTY_SEQUENCE; // null as JSON value shall produce an empty sequence
}
else {
array[i] = value;
}
}
else if (value == null) {
array[i] = NullNode.getInstance(); // null as a non-JSON value shall produce a JSON null
}
else {
array[i] = TypedValue.fromValueAsObject(type, value);
}
}

return array;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
import io.trino.type.VarcharOperators;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

Expand Down Expand Up @@ -142,7 +141,7 @@
public class JsonPathEvaluator
{
private final JsonNode input;
private final Map<String, Object> parameters; // TODO refactor to Object[]
private final Object[] parameters;
private final Metadata metadata;
private final Session session;
private final ConnectorSession connectorSession;
Expand All @@ -151,10 +150,10 @@ public class JsonPathEvaluator
private final JsonPredicateEvaluator predicateEvaluator;
private int objectId;

public JsonPathEvaluator(JsonNode input, Map<String, Object> parameters, FunctionManager functionManager, Metadata metadata, TypeManager typeManager, ConnectorSession connectorSession)
public JsonPathEvaluator(JsonNode input, Object[] parameters, FunctionManager functionManager, Metadata metadata, TypeManager typeManager, ConnectorSession connectorSession)
{
this.input = requireNonNull(input, "input is null");
this.parameters = ImmutableMap.copyOf(requireNonNull(parameters, "parameters is null"));
this.parameters = requireNonNull(parameters, "parameters is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.connectorSession = requireNonNull(connectorSession, "connectorSession is null");
this.session = ((FullConnectorSession) connectorSession).getSession();
Expand Down Expand Up @@ -966,8 +965,8 @@ protected List<Object> visitIrMemberAccessor(IrMemberAccessor node, Context cont
@Override
protected List<Object> visitIrNamedJsonVariable(IrNamedJsonVariable node, Context context)
{
Object value = parameters.get(node.getName());
checkState(value != null, "missing value for parameter " + node.getName());
Object value = parameters[node.getIndex()];
checkState(value != null, "missing value for parameter");
checkState(value instanceof JsonNode, "expected JSON, got SQL value");

if (value.equals(EMPTY_SEQUENCE)) {
Expand All @@ -979,8 +978,8 @@ protected List<Object> visitIrNamedJsonVariable(IrNamedJsonVariable node, Contex
@Override
protected List<Object> visitIrNamedValueVariable(IrNamedValueVariable node, Context context)
{
Object value = parameters.get(node.getName());
checkState(value != null, "missing value for parameter " + node.getName());
Object value = parameters[node.getIndex()];
checkState(value != null, "missing value for parameter");
checkState(value instanceof TypedValue || value instanceof NullNode, "expected SQL value or JSON null, got non-null JSON");

return ImmutableList.of(value);
Expand Down
Loading

0 comments on commit dce2a6e

Please sign in to comment.