From b48335ab78ae2b3056d7063aeae7916c6dbd53e5 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 3 Dec 2024 16:43:16 +0800 Subject: [PATCH] Support some escape characters in search list when rewriting regexp_replace to string replace Signed-off-by: Haoyang Li --- integration_tests/src/main/python/regexp_test.py | 7 +++++-- .../main/scala/com/nvidia/spark/rapids/GpuOverrides.scala | 5 +++-- .../com/nvidia/spark/rapids/StringFunctionSuite.scala | 5 +++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index c2062605ca1..bfc8de85632 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -1012,7 +1012,9 @@ def test_regexp_replace_simple(regexp_enabled): 'REGEXP_REPLACE(a, "ab", "PROD")', 'REGEXP_REPLACE(a, "ae", "PROD")', 'REGEXP_REPLACE(a, "bc", "PROD")', - 'REGEXP_REPLACE(a, "fa", "PROD")' + 'REGEXP_REPLACE(a, "fa", "PROD")', + 'REGEXP_REPLACE(a, "a\n", "PROD")', + 'REGEXP_REPLACE(a, "\n", "PROD")' ), conf=conf ) @@ -1032,7 +1034,8 @@ def test_regexp_replace_multi_optimization(regexp_enabled): 'REGEXP_REPLACE(a, "aa|bb|cc|dd", "PROD")', 'REGEXP_REPLACE(a, "(aa|bb)|(cc|dd)", "PROD")', 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee", "PROD")', - 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee|ff", "PROD")' + 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee|ff", "PROD")', + 'REGEXP_REPLACE(a, "a\n|b\a|c\t", "PROD")' ), conf=conf ) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 45905f0b9e0..2d940ca9467 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -593,8 +593,9 @@ object GpuOverrides extends Logging { } def isSupportedStringReplacePattern(strLit: String): Boolean = { - // check for regex special characters, except for \u0000 which we can support - !regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern)) + // check for regex special characters, except for \u0000, \n, \r, and \t which we can support + val supported = Seq("\u0000", "\n", "\r", "\t") + !regexList.filterNot(supported.contains(_)).exists(pattern => strLit.contains(pattern)) } def isSupportedStringReplacePattern(exp: Expression): Boolean = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala index 3c3933946c5..25c8c10b26d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -207,7 +207,8 @@ class RegExpUtilsSuite extends AnyFunSuite { "aa|bb|cc|dd" -> Seq("aa", "bb", "cc", "dd"), "(aa|bb)|(cc|dd)" -> Seq("aa", "bb", "cc", "dd"), "aa|bb|cc|dd|ee" -> Seq("aa", "bb", "cc", "dd", "ee"), - "aa|bb|cc|dd|ee|ff" -> Seq("aa", "bb", "cc", "dd", "ee", "ff") + "aa|bb|cc|dd|ee|ff" -> Seq("aa", "bb", "cc", "dd", "ee", "ff"), + "a\n|b\t|c\r" -> Seq("a\n", "b\t", "c\r") ) regexChoices.foreach { case (pattern, choices) =>