move f.lit into spark session
Signed-off-by: Haoyang Li <[email protected]>
thirtiseven committed Oct 11, 2023
1 parent 3c43343 commit 815f85c
Showing 2 changed files with 9 additions and 8 deletions.
9 changes: 5 additions & 4 deletions integration_tests/README.md
@@ -454,10 +454,11 @@ The marks you care about are all in marks.py
For the most part you can ignore this file. It provides the underlying Spark session to operations that need it, but most tests should interact with
it through `asserts.py`.

-All data generation should occur within a Spark session. Typically this is done by passing a
-lambda to functions in `asserts.py` such as `assert_gpu_and_cpu_are_equal_collect`. However,
-for scalar generation like `gen_scalars`, you may need to put it in a `with_cpu_session`. It is
-because negative scale decimals can have problems if called from outside of `with_spark_session`.
+All data generation and Spark function calls should occur within a Spark session. Typically
+this is done by passing a lambda to functions in `asserts.py` such as
+`assert_gpu_and_cpu_are_equal_collect`. However, for scalar generation like `gen_scalars`, you
+may need to put it in a `with_cpu_session`. It is because negative scale decimals can have
+problems when calling `f.lit` from outside of `with_spark_session`.

## Guidelines for Testing

8 changes: 4 additions & 4 deletions integration_tests/src/main/python/collection_ops_test.py
@@ -67,8 +67,8 @@ def test_concat_double_list_with_lit(dg):

 @pytest.mark.parametrize('data_gen', non_nested_array_gens, ids=idfn)
 def test_concat_list_with_lit(data_gen):
-    lit_col1 = f.lit(with_cpu_session(lambda spark: gen_scalar(data_gen))).cast(data_gen.data_type)
-    lit_col2 = f.lit(with_cpu_session(lambda spark: gen_scalar(data_gen))).cast(data_gen.data_type)
+    lit_col1 = with_cpu_session(lambda spark: f.lit(gen_scalar(data_gen))).cast(data_gen.data_type)
+    lit_col2 = with_cpu_session(lambda spark: f.lit(gen_scalar(data_gen))).cast(data_gen.data_type)
 
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: binary_op_df(spark, data_gen).select(
@@ -106,8 +106,8 @@ def test_map_concat(data_gen):

 @pytest.mark.parametrize('data_gen', map_gens_sample + decimal_64_map_gens + decimal_128_map_gens, ids=idfn)
 def test_map_concat_with_lit(data_gen):
-    lit_col1 = f.lit(with_cpu_session(lambda spark: gen_scalar(data_gen))).cast(data_gen.data_type)
-    lit_col2 = f.lit(with_cpu_session(lambda spark: gen_scalar(data_gen))).cast(data_gen.data_type)
+    lit_col1 = with_cpu_session(lambda spark: f.lit(gen_scalar(data_gen))).cast(data_gen.data_type)
+    lit_col2 = with_cpu_session(lambda spark: f.lit(gen_scalar(data_gen))).cast(data_gen.data_type)
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: binary_op_df(spark, data_gen).select(
             f.map_concat(f.col('a'), f.col('b'), lit_col1),
