ORC encrypted write should fallback to CPU [databricks] (#5604)
* ORC encrypted write should fallback to CPU

Signed-off-by: Raza Jafri <[email protected]>

* Skip the test instead of xfail

Signed-off-by: Raza Jafri <[email protected]>

* Skip ORC encryption test for DB10.4

Signed-off-by: Raza Jafri <[email protected]>

* Addressed review comments

Signed-off-by: Raza Jafri <[email protected]>

Co-authored-by: Raza Jafri <[email protected]>
razajafri authored Jun 2, 2022
1 parent 91c1230 commit db7b611
Showing 2 changed files with 39 additions and 0 deletions.
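In short: the GPU ORC writer does not support column encryption, so any write that sets one of the four encryption-related options is now tagged to fall back to the CPU's DataWritingCommandExec. A minimal PySpark sketch of a write that would trigger the fallback (the output path, key name, and column are illustrative, and a RAPIDS-enabled session is assumed):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # assumes the RAPIDS plugin is enabled

df = spark.range(100).withColumnRenamed("id", "a")
(df.write
    .option("orc.key.provider", "hadoop")   # any one of the four options is enough
    .option("orc.encrypt", "pii:a")         # encrypt column "a" with key "pii"
    .mode("overwrite")
    .orc("/tmp/orc_encrypted"))             # executes on the CPU, not the GPU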
28 changes: 28 additions & 0 deletions integration_tests/src/main/python/orc_write_test.py
@@ -15,6 +15,7 @@
import pytest

from asserts import assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_write
from spark_session import is_databricks104_or_later
from datetime import date, datetime, timezone
from data_gen import *
from marks import *
@@ -166,3 +167,30 @@ def create_empty_df(spark, path):
            lambda spark, path: spark.read.orc(path),
            data_path,
            conf={'spark.rapids.sql.format.orc.write.enabled': True})

@allow_non_gpu('DataWritingCommandExec')
@pytest.mark.parametrize("path", ["", "kms://http@localhost:9600/kms"])
@pytest.mark.parametrize("provider", ["", "hadoop"])
@pytest.mark.parametrize("encrypt", ["", "pii:a"])
@pytest.mark.parametrize("mask", ["", "sha256:a"])
@pytest.mark.skipif(is_databricks104_or_later(), reason="The test will fail on Databricks 10.4 because `HadoopShimsPre2_3$NullKeyProvider` is loaded")
def test_orc_write_encryption_fallback(spark_tmp_path, spark_tmp_table_factory, path, provider, encrypt, mask):
    def write_func(spark, write_path):
        writer = unary_op_df(spark, gen).coalesce(1).write
        if path != "":
            writer.option("hadoop.security.key.provider.path", path)
        if provider != "":
            writer.option("orc.key.provider", provider)
        if encrypt != "":
            writer.option("orc.encrypt", encrypt)
        if mask != "":
            writer.option("orc.mask", mask)
        writer.format("orc").mode('overwrite').option("path", write_path).saveAsTable(spark_tmp_table_factory.get())
if path == "" and provider == "" and encrypt == "" and mask == "":
pytest.skip("Skip this test when none of the encryption confs are set")
gen = IntegerGen()
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_fallback_write(write_func,
lambda spark, path: spark.read.orc(path),
data_path,
'DataWritingCommandExec')
11 changes: 11 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcFileFormat.scala
@@ -58,6 +58,17 @@ object GpuOrcFileFormat extends Logging {
s"${RapidsConf.ENABLE_ORC_WRITE} to true")
}

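    // ORC column encryption is implemented only by the CPU writer; if any of these
    // options is set, the plan must be tagged to fall back to the CPU.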
    val keyProviderPath = options.getOrElse("hadoop.security.key.provider.path", "")
    val keyProvider = options.getOrElse("orc.key.provider", "")
    val encrypt = options.getOrElse("orc.encrypt", "")
    val mask = options.getOrElse("orc.mask", "")

    if (keyProvider.nonEmpty || keyProviderPath.nonEmpty || encrypt.nonEmpty || mask.nonEmpty) {
      meta.willNotWorkOnGpu("Encryption is not yet supported on GPU. If encrypted ORC " +
        "writes are not required, unset \"hadoop.security.key.provider.path\", " +
        "\"orc.key.provider\", \"orc.encrypt\", and \"orc.mask\".")
    }

    FileFormatChecks.tag(meta, schema, OrcFormatType, WriteFileOp)

    val sqlConf = spark.sessionState.conf
