Skip to content

Commit fbe9db6

Browse files
committed
[SPARK-23628][SQL] Fix calculateParamLength
1 parent 404f7e2 commit fbe9db6

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,14 +1253,15 @@ class CodegenContext {
12531253
*/
12541254
def calculateParamLength(params: Seq[Expression]): Int = {
12551255
def paramLengthForExpr(input: Expression): Int = {
1256-
// For a nullable expression, we need to pass in an extra boolean parameter.
1257-
(if (input.nullable) 1 else 0) + javaType(input.dataType) match {
1256+
val javaParamLength = javaType(input.dataType) match {
12581257
case JAVA_LONG | JAVA_DOUBLE => 2
12591258
case _ => 1
12601259
}
1260+
// For a nullable expression, we need to pass in an extra boolean parameter.
1261+
(if (input.nullable) 1 else 0) + javaParamLength
12611262
}
12621263
// Initial value is 1 for `this`.
1263-
1 + params.map(paramLengthForExpr(_)).sum
1264+
1 + params.map(paramLengthForExpr).sum
12641265
}
12651266

12661267
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,4 +436,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
436436
ctx.addImmutableStateIfNotExists("String", mutableState2)
437437
assert(ctx.inlinedMutableStates.length == 2)
438438
}
439+
440+
test("SPARK-23628: calculateParamLength should compute properly the param length") {
441+
val ctx = new CodegenContext
442+
assert(ctx.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101)
443+
assert(ctx.calculateParamLength(Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
444+
}
439445
}

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
2525
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
2626
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
2727
import org.apache.spark.sql.expressions.scalalang.typed
28-
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
28+
import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max}
2929
import org.apache.spark.sql.internal.SQLConf
3030
import org.apache.spark.sql.test.SharedSQLContext
3131
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -249,12 +249,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
249249
}
250250

251251
test("Skip splitting consume function when parameter number exceeds JVM limit") {
252-
import testImplicits._
253-
254-
Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) =>
252+
// Since every field is nullable, we have 2 params for each input column (one for the value
253+
// and one for the isNull variable).
254+
Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) =>
255255
withTempPath { dir =>
256256
val path = dir.getCanonicalPath
257-
spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*)
257+
spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*)
258258
.write.mode(SaveMode.Overwrite).parquet(path)
259259

260260
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
@@ -263,10 +263,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
263263
val df = spark.read.parquet(path).selectExpr(projection: _*)
264264

265265
val plan = df.queryExecution.executedPlan
266-
val wholeStageCodeGenExec = plan.find(p => p match {
267-
case wp: WholeStageCodegenExec => true
266+
val wholeStageCodeGenExec = plan.find {
267+
case _: WholeStageCodegenExec => true
268268
case _ => false
269-
})
269+
}
270270
assert(wholeStageCodeGenExec.isDefined)
271271
val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
272272
assert(code.body.contains("project_doConsume") == hasSplit)

0 commit comments

Comments
 (0)