diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java index bfa5711261de..6487d3edaa96 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionRegistry.java @@ -161,6 +161,7 @@ import io.trino.operator.scalar.VersionFunction; import io.trino.operator.scalar.WilsonInterval; import io.trino.operator.scalar.WordStemFunction; +import io.trino.operator.scalar.XPathFunctions; import io.trino.operator.scalar.time.LocalTimeFunction; import io.trino.operator.scalar.time.TimeFunctions; import io.trino.operator.scalar.time.TimeOperators; @@ -481,6 +482,7 @@ public FunctionRegistry( .scalar(SplitToMultimapFunction.class) .scalars(VarbinaryFunctions.class) .scalars(UrlFunctions.class) + .scalars(XPathFunctions.class) .scalars(MathFunctions.class) .scalar(MathFunctions.Abs.class) .scalar(MathFunctions.Sign.class) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/XPathFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/XPathFunctions.java new file mode 100644 index 000000000000..efcac65cee11 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/XPathFunctions.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import io.airlift.slice.Slice; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.Description; +import io.trino.spi.function.LiteralParameters; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlNullable; +import io.trino.spi.function.SqlType; +import io.trino.spi.type.StandardTypes; +import io.trino.util.UDFXPathUtil; +import org.w3c.dom.NodeList; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.type.VarcharType.VARCHAR; + +public final class XPathFunctions +{ + private XPathFunctions() {} + + @SqlNullable + @Description("Returns the text contents of the first xml node that matches the xpath expression") + @ScalarFunction("xpath_string") + @LiteralParameters({"x", "y"}) + @SqlType("varchar(x)") + public static Slice xpathString(@SqlType("varchar(x)") Slice xml, @SqlType("varchar(y)") Slice path) + { + UDFXPathUtil xpath = new UDFXPathUtil(); + String s = xpath.evalString(xml.toStringUtf8(), path.toStringUtf8()); + if (s == null) { + return null; + } + return utf8Slice(s); + } + + @SqlNullable + @Description("Returns a string array of values within xml nodes that match the xpath expression") + @ScalarFunction("xpath") + @LiteralParameters({"x", "y"}) + @SqlType("array(varchar)") + public static Block xpath(@SqlType("varchar(y)") Slice xml, @SqlType("varchar(x)") Slice path) + { + List initList = eval(xml.toStringUtf8(), path.toStringUtf8()); + + BlockBuilder builder = VARCHAR.createBlockBuilder(null, initList.size()); + for (String value : initList) { + if (value == null) { + builder.appendNull(); + } + else { + VARCHAR.writeSlice(builder, utf8Slice(value)); + } + } + return builder.build(); + } + + @SqlNullable + @Description("Returns true if the XPath expression evaluates to true, or if a matching node is found") + @ScalarFunction("xpath_boolean") + @LiteralParameters({"x", "y"}) + @SqlType(StandardTypes.BOOLEAN) + public static Boolean xpath_boolean(@SqlType("varchar(y)") Slice xml, @SqlType("varchar(x)") Slice path) + { + UDFXPathUtil xpath = new UDFXPathUtil(); + return xpath.evalBoolean(xml.toStringUtf8(), path.toStringUtf8()); + } + + @SqlNullable + @Description("Return a double value, or the value 0.0 if no match is found, or NaN for non-numeric match") + @ScalarFunction("xpath_double") + @LiteralParameters({"x", "y"}) + @SqlType(StandardTypes.DOUBLE) + public static Double xpath_double(@SqlType("varchar(y)") Slice xml, @SqlType("varchar(x)") Slice path) + { + UDFXPathUtil xpath = new UDFXPathUtil(); + return xpath.evalNumber(xml.toStringUtf8(), path.toStringUtf8()); + } + + private static List eval(String xml, String path) + { + UDFXPathUtil xpath = new UDFXPathUtil(); + NodeList nodeList = xpath.evalNodeList(xml, path); + if (nodeList == null) { + return Collections.emptyList(); + } + + return Stream.iterate(0, i -> i < nodeList.getLength(), i -> i + 1) + .map(nodeList::item) + .filter(field -> field.getNodeValue() != null) + .map(field -> field.getNodeValue()) + .collect(Collectors.toList()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/util/UDFXPathUtil.java b/core/trino-main/src/main/java/io/trino/util/UDFXPathUtil.java new file mode 100644 index 000000000000..ad819d5b5f76 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/util/UDFXPathUtil.java @@ -0,0 +1,243 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.util; + +import io.trino.spi.TrinoException; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; +import org.xml.sax.InputSource; +import org.xml.sax.SAXParseException; + +import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; +import javax.xml.xpath.XPath; +import javax.xml.xpath.XPathConstants; +import javax.xml.xpath.XPathExpression; +import javax.xml.xpath.XPathExpressionException; +import javax.xml.xpath.XPathFactory; + +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; + +/** + * Utility class for all XPath UDFs. Each UDF instance should keep an instance + * of this class. + */ +public class UDFXPathUtil +{ + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder; + private XPath xpath = XPathFactory.newInstance().newXPath(); + + private XPathExpression expression; + private String oldPath; + + public Object eval(String xml, String path, QName qname) + { + if (xml == null || path == null || qname == null) { + return null; + } + + if (xml.length() == 0 || path.length() == 0) { + return null; + } + + if (!path.equals(oldPath)) { + try { + expression = xpath.compile(path); + } + catch (XPathExpressionException e) { + expression = null; + } + oldPath = path; + } + + if (expression == null) { + return null; + } + + if (builder == null) { + try { + initializeDocumentBuilderFactory(); + builder = documentBuilderFactory.newDocumentBuilder(); + } + catch (ParserConfigurationException e) { + throw new RuntimeException("Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + ReusableStringReader reader = new ReusableStringReader(); + InputSource inputSource = new InputSource(reader); + reader.set(xml); + + try { + return expression.evaluate(builder.parse(inputSource), qname); + } + catch (XPathExpressionException ex) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Invalid expression '" + oldPath + "'", ex); + } + catch (SAXParseException ex) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Error parsing xml data '" + xml + "'", ex); + } + catch (Exception ex) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Error loading expression '" + oldPath + "'", ex); + } + } + + private void initializeDocumentBuilderFactory() throws ParserConfigurationException + { + documentBuilderFactory.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + documentBuilderFactory.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + + public Boolean evalBoolean(String xml, String path) + { + return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); + } + + public String evalString(String xml, String path) + { + return (String) eval(xml, path, XPathConstants.STRING); + } + + public Double evalNumber(String xml, String path) + { + return (Double) eval(xml, path, XPathConstants.NUMBER); + } + + public Node evalNode(String xml, String path) + { + return (Node) eval(xml, path, XPathConstants.NODE); + } + + public NodeList evalNodeList(String xml, String path) + { + return (NodeList) eval(xml, path, XPathConstants.NODESET); + } + + /** + * Reusable, threadsafe version of {@link StringReader}. + */ + public static class ReusableStringReader + extends Reader + { + private String streamData; + private int length = -1; + private int next; + private int mark; + + public void set(String stringData) + { + this.streamData = stringData; + this.length = stringData.length(); + this.mark = 0; + this.next = 0; + } + + /** Check to make sure that the stream has not been closed */ + private void ensureOpen() throws IOException + { + if (streamData == null) { + throw new IOException("Stream closed"); + } + } + + @Override + public int read() throws IOException + { + ensureOpen(); + if (next >= length) { + return -1; + } + return streamData.charAt(next++); + } + + @Override + public int read(char[] cbuf, int off, int len) throws IOException + { + ensureOpen(); + if ((off < 0) || (off > cbuf.length) || (len < 0) + || ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + else if (len == 0) { + return 0; + } + if (next >= length) { + return -1; + } + int n = Math.min(length - next, len); + streamData.getChars(next, next + n, cbuf, off); + next += n; + return n; + } + + @Override + public long skip(long ns) throws IOException + { + ensureOpen(); + if (next >= length) { + return 0; + } + // Bound skip by beginning and end of the source + long n = Math.min(length - next, ns); + n = Math.max(-next, n); + next += n; + return n; + } + + @Override + public boolean ready() throws IOException + { + ensureOpen(); + return true; + } + + @Override + public boolean markSupported() + { + return true; + } + + @Override + public void mark(int readAheadLimit) throws IOException + { + if (readAheadLimit < 0) { + throw new IllegalArgumentException("Read-ahead limit < 0"); + } + ensureOpen(); + mark = next; + } + + @Override + public void reset() throws IOException + { + ensureOpen(); + next = mark; + } + + @Override + public void close() + { + streamData = null; + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestXPathFunction.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestXPathFunction.java new file mode 100644 index 000000000000..72052402401e --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestXPathFunction.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.VarcharType; +import org.testng.annotations.Test; + +import java.util.Arrays; + +import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.VarcharType.VARCHAR; + +public class TestXPathFunction + extends AbstractTestFunctions +{ + @Test + public void testXpathString() + { + assertFunction("xpath_string('bbcc','a/b')", VarcharType.createVarcharType(25), "bb"); + assertFunction("xpath_string('bbcc','a')", VarcharType.createVarcharType(25), "bbcc"); + assertFunction("xpath_string('bbcc','a/d')", VarcharType.createVarcharType(25), ""); + assertFunction("xpath_string('b1b2','//b')", VarcharType.createVarcharType(25), "b1"); + assertFunction("xpath_string('b1b2','a/b[2]')", VarcharType.createVarcharType(25), "b2"); + assertFunction("xpath_string('b1b2','a/b[@id=\"b_2\"]')", VarcharType.createVarcharType(34), "b2"); + assertInvalidFunction("xpath_string('\"\"','detail/reason')", INVALID_FUNCTION_ARGUMENT, "Error parsing xml data '\"\"'"); + } + + @Test + public void testXpath() + { + assertFunction("xpath('b1b2','//@id')", + new ArrayType(VARCHAR), + Arrays.asList("foo", "bar")); + assertFunction("xpath('','/descendant::c/ancestor::b/@id')", + new ArrayType(VARCHAR), + Arrays.asList("1", "2")); + assertFunction("xpath('b1b2','a/*')", + new ArrayType(VARCHAR), + Arrays.asList()); + assertFunction("xpath('b1b2','a/*/text()')", + new ArrayType(VARCHAR), + Arrays.asList("b1", "b2")); + assertFunction("xpath('b1b2b3c1c2','a/*[@class=\"bb\"]/text()')", + new ArrayType(VARCHAR), + Arrays.asList("b1", "c1")); + } + + @Test + public void testXpathBoolean() + { + assertFunction("xpath_boolean('b','a/b')", + BOOLEAN, + true); + assertFunction("xpath_boolean('b','a/c')", + BOOLEAN, + false); + assertFunction("xpath_boolean('b','a/b = \"b\"')", + BOOLEAN, + true); + assertFunction("xpath_boolean('10','a/b < 10')", + BOOLEAN, + false); + } + + @Test + public void testXpathDouble() + { + assertFunction("xpath_double('b','a = 10')", + DoubleType.DOUBLE, + 0.0); + assertFunction("xpath_double('this is not a number','a')", + DoubleType.DOUBLE, + Double.NaN); + assertFunction("xpath_double('200000000040000000000','a/b * a/c')", + DoubleType.DOUBLE, + 8.0E19); + } +}