From 77588fd309432ab434b8f8dafc00bc8eacfbcd88 Mon Sep 17 00:00:00 2001 From: vamsi-amazon Date: Fri, 4 Nov 2022 12:13:25 -0700 Subject: [PATCH] back quote fix Signed-off-by: vamsi-amazon --- .../sql/ppl/PrometheusCatalogCommandsIT.java | 12 ++--- .../response/PrometheusResponse.java | 20 +++++++- .../PrometheusDefaultImplementor.java | 8 +--- .../model/PrometheusResponseFieldNames.java | 3 ++ .../querybuilder/AggregationQueryBuilder.java | 8 +++- .../storage/PrometheusMetricScanTest.java | 47 +++++++++++++++++++ .../storage/PrometheusMetricTableTest.java | 27 +++++++++++ 7 files changed, 109 insertions(+), 16 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusCatalogCommandsIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusCatalogCommandsIT.java index 9e197bbb27..10c1e911ab 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusCatalogCommandsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/PrometheusCatalogCommandsIT.java @@ -46,12 +46,12 @@ public void testSourceMetricCommand() { @SneakyThrows public void testMetricAvgAggregationCommand() { JSONObject response = - executeQuery("source=my_prometheus.prometheus_http_requests_total | stats avg(@value) by span(@timestamp, 15s), handler, job"); + executeQuery("source=`my_prometheus`.`prometheus_http_requests_total` | stats avg(@value) as `agg` by span(@timestamp, 15s), `handler`, `job`"); verifySchema(response, - schema("avg(@value)", "double"), + schema("agg", "double"), schema("span(@timestamp,15s)", "timestamp"), - schema("handler", "string"), - schema("job", "string")); + schema("`handler`", "string"), + schema("`job`", "string")); Assertions.assertTrue(response.getInt("size") > 0); Assertions.assertEquals(4, response.getJSONArray("datarows").getJSONArray(0).length()); JSONArray firstRow = response.getJSONArray("datarows").getJSONArray(0); @@ -65,11 +65,11 @@ public void testMetricAvgAggregationCommand() { @SneakyThrows public void testMetricAvgAggregationCommandWithAlias() { JSONObject response = - executeQuery("source=my_prometheus.prometheus_http_requests_total | stats avg(@value) as agg by span(@timestamp, 15s), handler, job"); + executeQuery("source=my_prometheus.prometheus_http_requests_total | stats avg(@value) as agg by span(@timestamp, 15s), `handler`, job"); verifySchema(response, schema("agg", "double"), schema("span(@timestamp,15s)", "timestamp"), - schema("handler", "string"), + schema("`handler`", "string"), schema("job", "string")); Assertions.assertTrue(response.getInt("size") > 0); Assertions.assertEquals(4, response.getJSONArray("datarows").getJSONArray(0).length()); diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/response/PrometheusResponse.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/response/PrometheusResponse.java index e26e006403..ef7f19ba2f 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/response/PrometheusResponse.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/response/PrometheusResponse.java @@ -5,7 +5,6 @@ package org.opensearch.sql.prometheus.response; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.prometheus.data.constants.PrometheusFieldConstants.LABELS; @@ -16,6 +15,7 @@ import java.util.LinkedHashMap; import java.util.List; import lombok.NonNull; +import org.apache.commons.lang3.StringUtils; import org.json.JSONArray; import org.json.JSONObject; import org.opensearch.sql.data.model.ExprDoubleValue; @@ -26,6 +26,8 @@ import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.prometheus.storage.model.PrometheusResponseFieldNames; public class PrometheusResponse implements Iterable { @@ -100,7 +102,7 @@ public Iterator iterator() { private void insertLabels(LinkedHashMap linkedHashMap, JSONObject metric) { for (String key : metric.keySet()) { - linkedHashMap.put(key, new ExprStringValue(metric.getString(key))); + linkedHashMap.put(getKey(key), new ExprStringValue(metric.getString(key))); } } @@ -113,4 +115,18 @@ private ExprValue getValue(JSONArray jsonArray, Integer index, ExprType exprType return new ExprDoubleValue(jsonArray.getDouble(index)); } + private String getKey(String key) { + if (this.prometheusResponseFieldNames.getGroupByList() == null) { + return key; + } else { + return this.prometheusResponseFieldNames.getGroupByList().stream() + .filter(expression -> expression.getDelegated() instanceof ReferenceExpression) + .filter(expression + -> ((ReferenceExpression) expression.getDelegated()).getAttr().equals(key)) + .findFirst() + .map(NamedExpression::getName) + .orElse(key); + } + } + } diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/implementor/PrometheusDefaultImplementor.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/implementor/PrometheusDefaultImplementor.java index 071cd7ba8c..8cae250e5e 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/implementor/PrometheusDefaultImplementor.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/implementor/PrometheusDefaultImplementor.java @@ -7,10 +7,6 @@ package org.opensearch.sql.prometheus.storage.implementor; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.prometheus.data.constants.PrometheusFieldConstants.LABELS; - -import java.util.ArrayList; import java.util.List; import java.util.Optional; import lombok.RequiredArgsConstructor; @@ -18,14 +14,11 @@ import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; -import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.span.SpanExpression; import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; -import org.opensearch.sql.planner.physical.ProjectOperator; import org.opensearch.sql.prometheus.planner.logical.PrometheusLogicalMetricAgg; import org.opensearch.sql.prometheus.planner.logical.PrometheusLogicalMetricScan; import org.opensearch.sql.prometheus.storage.PrometheusMetricScan; @@ -130,6 +123,7 @@ private void setPrometheusResponseFieldNames(PrometheusLogicalMetricAgg node, prometheusResponseFieldNames.setValueFieldName(node.getAggregatorList().get(0).getName()); prometheusResponseFieldNames.setValueType(node.getAggregatorList().get(0).type()); prometheusResponseFieldNames.setTimestampFieldName(spanExpression.get().getNameOrAlias()); + prometheusResponseFieldNames.setGroupByList(node.getGroupByList()); context.setPrometheusResponseFieldNames(prometheusResponseFieldNames); } diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/model/PrometheusResponseFieldNames.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/model/PrometheusResponseFieldNames.java index 4276848aa2..d3a6ef184f 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/model/PrometheusResponseFieldNames.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/model/PrometheusResponseFieldNames.java @@ -11,9 +11,11 @@ import static org.opensearch.sql.prometheus.data.constants.PrometheusFieldConstants.TIMESTAMP; import static org.opensearch.sql.prometheus.data.constants.PrometheusFieldConstants.VALUE; +import java.util.List; import lombok.Getter; import lombok.Setter; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.NamedExpression; @Getter @@ -23,5 +25,6 @@ public class PrometheusResponseFieldNames { private String valueFieldName = VALUE; private ExprType valueType = DOUBLE; private String timestampFieldName = TIMESTAMP; + private List groupByList; } diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/querybuilder/AggregationQueryBuilder.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/querybuilder/AggregationQueryBuilder.java index 1aff9eca88..76c8c6872e 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/querybuilder/AggregationQueryBuilder.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/querybuilder/AggregationQueryBuilder.java @@ -7,11 +7,14 @@ package org.opensearch.sql.prometheus.storage.querybuilder; +import java.sql.Ref; import java.util.List; import java.util.Set; import java.util.stream.Collectors; import lombok.NoArgsConstructor; +import org.apache.commons.lang3.StringUtils; import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.span.SpanExpression; @@ -63,7 +66,10 @@ public static String build(List namedAggregatorList, if (groupByList.size() > 0) { aggregateQuery.append("by("); aggregateQuery.append( - groupByList.stream().map(NamedExpression::getName).collect(Collectors.joining(", "))); + groupByList.stream() + .filter(expression -> expression.getDelegated() instanceof ReferenceExpression) + .map(expression -> ((ReferenceExpression) expression.getDelegated()).getAttr()) + .collect(Collectors.joining(", "))); aggregateQuery.append(")"); } } diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricScanTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricScanTest.java index ac99a996af..984103df56 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricScanTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricScanTest.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.LONG; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.prometheus.constants.TestConstants.ENDTIME; import static org.opensearch.sql.prometheus.constants.TestConstants.QUERY; import static org.opensearch.sql.prometheus.constants.TestConstants.STARTTIME; @@ -22,6 +23,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import java.util.LinkedHashMap; import lombok.SneakyThrows; import org.json.JSONObject; @@ -36,6 +38,7 @@ import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.expression.DSL; import org.opensearch.sql.prometheus.client.PrometheusClient; import org.opensearch.sql.prometheus.storage.model.PrometheusResponseFieldNames; @@ -163,6 +166,49 @@ void testQueryResponseIteratorWithGivenPrometheusResponseWithLongInAggType() { Assertions.assertFalse(prometheusMetricScan.hasNext()); } + @Test + @SneakyThrows + void testQueryResponseIteratorWithGivenPrometheusResponseWithBackQuotedFieldNames() { + PrometheusResponseFieldNames prometheusResponseFieldNames + = new PrometheusResponseFieldNames(); + prometheusResponseFieldNames.setValueFieldName("testAgg"); + prometheusResponseFieldNames.setValueType(LONG); + prometheusResponseFieldNames.setTimestampFieldName(TIMESTAMP); + prometheusResponseFieldNames.setGroupByList( + Collections.singletonList(DSL.named("`instance`", DSL.ref("`instance`", STRING)))); + PrometheusMetricScan prometheusMetricScan = new PrometheusMetricScan(prometheusClient); + prometheusMetricScan.setPrometheusResponseFieldNames(prometheusResponseFieldNames); + prometheusMetricScan.getRequest().setPromQl(QUERY); + prometheusMetricScan.getRequest().setStartTime(STARTTIME); + prometheusMetricScan.getRequest().setEndTime(ENDTIME); + prometheusMetricScan.getRequest().setStep(STEP); + + when(prometheusClient.queryRange(any(), any(), any(), any())) + .thenReturn(new JSONObject(getJson("query_range_result.json"))); + prometheusMetricScan.open(); + Assertions.assertTrue(prometheusMetricScan.hasNext()); + ExprTupleValue firstRow = new ExprTupleValue(new LinkedHashMap<>() {{ + put(TIMESTAMP, new ExprTimestampValue(Instant.ofEpochMilli(1435781430781L))); + put("testAgg", new ExprLongValue(1)); + put("`instance`", new ExprStringValue("localhost:9090")); + put("__name__", new ExprStringValue("up")); + put("job", new ExprStringValue("prometheus")); + } + }); + assertEquals(firstRow, prometheusMetricScan.next()); + Assertions.assertTrue(prometheusMetricScan.hasNext()); + ExprTupleValue secondRow = new ExprTupleValue(new LinkedHashMap<>() {{ + put(TIMESTAMP, new ExprTimestampValue(Instant.ofEpochMilli(1435781430781L))); + put("testAgg", new ExprLongValue(0)); + put("`instance`", new ExprStringValue("localhost:9091")); + put("__name__", new ExprStringValue("up")); + put("job", new ExprStringValue("node")); + } + }); + assertEquals(secondRow, prometheusMetricScan.next()); + Assertions.assertFalse(prometheusMetricScan.hasNext()); + } + @Test @SneakyThrows void testQueryResponseIteratorForQueryRangeFunction() { @@ -247,6 +293,7 @@ void testEmptyQueryWithException() { runtimeException.getMessage()); } + @Test @SneakyThrows void testExplain() { diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java index ff5ae5dcf5..cff8a11610 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java @@ -739,4 +739,31 @@ void testOptimize() { assertEquals(inputPlan, optimizedPlan); } + @Test + void testImplementPrometheusQueryWithBackQuotedFieldNamesInStatsQuery() { + + PrometheusMetricTable prometheusMetricTable = + new PrometheusMetricTable(client, "prometheus_http_total_requests"); + + + // IndexScanAgg with Filter + PhysicalPlan plan = prometheusMetricTable.implement( + indexScanAgg("prometheus_http_total_requests", + dsl.and(dsl.equal(DSL.ref("code", STRING), DSL.literal(stringValue("200"))), + dsl.equal(DSL.ref("handler", STRING), DSL.literal(stringValue("/ready/")))), + ImmutableList + .of(named("AVG(@value)", + dsl.avg(DSL.ref("@value", INTEGER)))), + ImmutableList.of(named("`job`", DSL.ref("`job`", STRING)), + named("span", DSL.span(DSL.ref("@timestamp", ExprCoreType.TIMESTAMP), + DSL.literal(40), "s"))))); + assertTrue(plan instanceof PrometheusMetricScan); + PrometheusQueryRequest prometheusQueryRequest = ((PrometheusMetricScan) plan).getRequest(); + assertEquals( + "avg by(job) (avg_over_time" + + "(prometheus_http_total_requests{code=\"200\" , handler=\"/ready/\"}[40s]))", + prometheusQueryRequest.getPromQl()); + + } + }