[SPARK-34708][SQL] Code-gen for left semi/anti broadcast nested loop join (build right side)

### What changes were proposed in this pull request?

This PR adds code-gen support for left semi / left anti BroadcastNestedLoopJoin when the build side is the right side. The execution code path for building the left side cannot fit into the whole-stage code-gen framework, so only the build-right case is covered here. Reference: the iterator (non-code-gen) code path is `BroadcastNestedLoopJoinExec.leftExistenceJoin()` with `BuildRight`.
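For intuition, here is a minimal sketch of the existence-join semantics that this code path implements (not the actual Spark implementation; `streamedRows`, `buildRows`, `condition`, and `exists` are illustrative stand-ins):

```
// Left semi keeps a streamed row iff some broadcast row satisfies the
// condition; left anti keeps it iff no broadcast row does. The generated
// bnlj_doConsume_0 method shown later compiles this loop inline, breaking
// out as soon as the first match is found.
def leftExistenceJoin(
    streamedRows: Iterator[Long],        // streamed (left) side values
    buildRows: Array[Long],              // broadcast (right) side values
    condition: (Long, Long) => Boolean,  // join condition
    exists: Boolean                      // true = left semi, false = left anti
): Iterator[Long] =
  streamedRows.filter { streamed =>
    buildRows.exists(build => condition(streamed, build)) == exists
  }
```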
### Why are the changes needed?

Improve query CPU performance. Tested with a simple query:

```
val N = 20 << 20
val M = 1 << 4
val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left semi broadcast nested loop join", N) {
  spark.range(N).selectExpr(s"id as k1").join(
    dim, col("k1") + 1 <= col("k2"), "left_semi")
}
```

Seeing a 5x run-time improvement:

```
Running benchmark: left semi broadcast nested loop join
  Running case: left semi broadcast nested loop join codegen off
  Stopped after 2 iterations, 6958 ms
  Running case: left semi broadcast nested loop join codegen on
  Stopped after 5 iterations, 3383 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
left semi broadcast nested loop join:              Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------------------------------------------------------
left semi broadcast nested loop join codegen off            3434           3479          65          6.1         163.7       1.0X
left semi broadcast nested loop join codegen on              672            677           5         31.2          32.1       5.1X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Changed the existing unit test in `ExistenceJoinSuite.scala` to cover all code paths:

* left semi/anti + empty right side + empty condition
* left semi/anti + non-empty right side + empty condition
* left semi/anti + right side + non-empty condition

Added a unit test in `WholeStageCodegenSuite.scala` to make sure code-gen for broadcast nested loop join takes effect, and tested the multiple-join case as well. Example query:

```
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_semi").explain("codegen")
```

Example generated code (`bnlj_doConsume_0` method): this is for left semi join. The generated code for left anti join is mostly the same, except that line 55 below becomes `if (bnlj_findMatchedRow_0 == false) {`.

```
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) ==
*(2) Project [id#0L AS k1#2L]
+- *(2) BroadcastNestedLoopJoin BuildRight, LeftSemi, ((id#0L + 1) <= k2#6L)
   :- *(2) Range (0, 4, step=1, splits=2)
   +- BroadcastExchange IdentityBroadcastMode, [id=#23]
      +- *(1) Project [id#4L AS k2#6L]
         +- *(1) Range (0, 3, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 031 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     boolean bnlj_findMatchedRow_0 = false;
/* 038 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */
/* 041 */       long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 042 */
/* 043 */       long bnlj_value_3 = -1L;
/* 044 */
/* 045 */       bnlj_value_3 = bnlj_expr_0_0 + 1L;
/* 046 */
/* 047 */       boolean bnlj_value_2 = false;
/* 048 */       bnlj_value_2 = bnlj_value_3 <= bnlj_value_1;
/* 049 */       if (!(false || !bnlj_value_2))
/* 050 */       {
/* 051 */         bnlj_findMatchedRow_0 = true;
/* 052 */         break;
/* 053 */       }
/* 054 */     }
/* 055 */     if (bnlj_findMatchedRow_0 == true) {
/* 056 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 057 */
/* 058 */       // common sub-expressions
/* 059 */
/* 060 */       range_mutableStateArray_0[3].reset();
/* 061 */
/* 062 */       range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 063 */       append((range_mutableStateArray_0[3].getRow()).copy());
/* 064 */
/* 065 */     }
/* 066 */
/* 067 */   }
/* 068 */
/* 069 */   private void initRange(int idx) {
/* 070 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 071 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 072 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 073 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 074 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 075 */     long partitionEnd;
/* 076 */
/* 077 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 078 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 079 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 080 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 081 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 082 */     } else {
/* 083 */       range_nextIndex_0 = st.longValue();
/* 084 */     }
/* 085 */     range_batchEnd_0 = range_nextIndex_0;
/* 086 */
/* 087 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 088 */     .multiply(step).add(start);
/* 089 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 090 */       partitionEnd = Long.MAX_VALUE;
/* 091 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 092 */       partitionEnd = Long.MIN_VALUE;
/* 093 */     } else {
/* 094 */       partitionEnd = end.longValue();
/* 095 */     }
/* 096 */
/* 097 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 098 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 099 */     range_numElementsTodo_0 = startToEnd.divide(step).longValue();
/* 100 */     if (range_numElementsTodo_0 < 0) {
/* 101 */       range_numElementsTodo_0 = 0;
/* 102 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 103 */       range_numElementsTodo_0++;
/* 104 */     }
/* 105 */   }
/* 106 */
/* 107 */   protected void processNext() throws java.io.IOException {
/* 108 */     // initialize Range
/* 109 */     if (!range_initRange_0) {
/* 110 */       range_initRange_0 = true;
/* 111 */       initRange(partitionIndex);
/* 112 */     }
/* 113 */
/* 114 */     while (true) {
/* 115 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 116 */         long range_nextBatchTodo_0;
/* 117 */         if (range_numElementsTodo_0 > 1000L) {
/* 118 */           range_nextBatchTodo_0 = 1000L;
/* 119 */           range_numElementsTodo_0 -= 1000L;
/* 120 */         } else {
/* 121 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 122 */           range_numElementsTodo_0 = 0;
/* 123 */           if (range_nextBatchTodo_0 == 0) break;
/* 124 */         }
/* 125 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 126 */       }
/* 127 */
/* 128 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 129 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 130 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 131 */
/* 132 */         bnlj_doConsume_0(range_value_0);
/* 133 */
/* 134 */         if (shouldStop()) {
/* 135 */           range_nextIndex_0 = range_value_0 + 1L;
/* 136 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 137 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 138 */           return;
/* 139 */         }
/* 140 */
/* 141 */       }
/* 142 */       range_nextIndex_0 = range_batchEnd_0;
/* 143 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 144 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 145 */       range_taskContext_0.killTaskIfInterrupted();
/* 146 */     }
/* 147 */   }
/* 148 */
/* 149 */ }
```
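As a side note, one way to confirm that the join actually runs inside whole-stage code-gen is to inspect the executed plan, similar in spirit to the `WholeStageCodegenSuite` check. A sketch, reusing the `df1`/`df2` frames from the example query above:

```
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec

val joined = df1.join(df2, $"k1" + 1 <= $"k2", "left_semi")
// With this change, the BroadcastNestedLoopJoin node should appear inside a
// WholeStageCodegen subtree instead of running through the iterator path.
val codegenApplied = joined.queryExecution.executedPlan.find {
  case w: WholeStageCodegenExec =>
    w.child.find(_.isInstanceOf[BroadcastNestedLoopJoinExec]).isDefined
  case _ => false
}.isDefined
assert(codegenApplied)
```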
Closes #31874 from c21/code-semi-anti.

Authored-by: Cheng Su <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>