Skip to content

Commit e5e60ad

Browse files
committed
Ensure we pass a compatible pruned schema to ParquetRowConverter
1 parent 5e5d886 commit e5e60ad

4 files changed

Lines changed: 104 additions & 24 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ class ParquetFileFormat
310310
hadoopConf.set(
311311
SQLConf.SESSION_LOCAL_TIMEZONE.key,
312312
sparkSession.sessionState.conf.sessionLocalTimeZone)
313+
hadoopConf.setBoolean(
314+
SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
315+
sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
313316
hadoopConf.setBoolean(
314317
SQLConf.CASE_SENSITIVE.key,
315318
sparkSession.sessionState.conf.caseSensitiveAnalysis)
@@ -424,11 +427,12 @@ class ParquetFileFormat
424427
} else {
425428
logDebug(s"Falling back to parquet-mr")
426429
// ParquetRecordReader returns UnsafeRow
430+
val readSupport = new ParquetReadSupport(convertTz, usingVectorizedReader = false)
427431
val reader = if (pushed.isDefined && enableRecordFilter) {
428432
val parquetFilter = FilterCompat.get(pushed.get, null)
429-
new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz), parquetFilter)
433+
new ParquetRecordReader[UnsafeRow](readSupport, parquetFilter)
430434
} else {
431-
new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz))
435+
new ParquetRecordReader[UnsafeRow](readSupport)
432436
}
433437
val iter = new RecordReaderIterator(reader)
434438
// SPARK-23457 Register a task completion listener before `initialization`.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,34 +49,82 @@ import org.apache.spark.sql.types._
4949
* Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]]
5050
* to [[prepareForRead()]], but use a private `var` for simplicity.
5151
*/
52-
private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
52+
private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone],
53+
usingVectorizedReader: Boolean)
5354
extends ReadSupport[UnsafeRow] with Logging {
5455
private var catalystRequestedSchema: StructType = _
5556

5657
def this() {
5758
// We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only
5859
// used in the vectorized reader, where we get the convertTz value directly, and the value here
5960
// is ignored.
60-
this(None)
61+
this(None, usingVectorizedReader = true)
6162
}
6263

6364
/**
6465
* Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record
6566
* readers. Responsible for figuring out Parquet requested schema used for column pruning.
6667
*/
6768
override def init(context: InitContext): ReadContext = {
69+
val conf = context.getConfiguration
6870
catalystRequestedSchema = {
69-
val conf = context.getConfiguration
7071
val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA)
7172
assert(schemaString != null, "Parquet requested schema not set.")
7273
StructType.fromString(schemaString)
7374
}
7475

75-
val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key,
76+
val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
77+
SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get)
78+
val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key,
7679
SQLConf.CASE_SENSITIVE.defaultValue.get)
77-
val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema(
78-
context.getFileSchema, catalystRequestedSchema, caseSensitive)
79-
80+
val parquetFileSchema = context.getFileSchema
81+
val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema,
82+
catalystRequestedSchema, caseSensitive)
83+
84+
// As part of schema clipping, we add fields in catalystRequestedSchema which are missing
85+
// from parquetFileSchema to parquetClippedSchema. However, nested schema pruning requires
86+
// we ignore unrequested field data when reading from a Parquet file. Therefore we pass two
87+
schemas to ParquetRecordMaterializer: the schema of the file data we want to read
88+
// (parquetRequestedSchema), and the schema of the rows we want to return
89+
// (catalystRequestedSchema). The reader is responsible for reconciling the differences between
90+
// the two.
91+
//
92+
// Aside from checking whether schema pruning is enabled (schemaPruningEnabled), there
93+
// is an additional complication to constructing parquetRequestedSchema. The manner in which
94+
// Spark's two Parquet readers reconcile the differences between parquetRequestedSchema and
95+
catalystRequestedSchema differs. Spark's vectorized reader does not (currently) support
96+
// reading Parquet files with complex types in their schema. Further, it assumes that
97+
// parquetRequestedSchema includes all fields requested in catalystRequestedSchema. It includes
98+
// logic in its read path to skip fields in parquetRequestedSchema which are not present in the
99+
// file.
100+
//
101+
// Spark's parquet-mr based reader supports reading Parquet files of any kind of complex
102+
// schema, and it supports nested schema pruning as well. Unlike the vectorized reader, the
103+
// parquet-mr reader requires that parquetRequestedSchema include only those fields present in
104+
// the underlying parquetFileSchema. Therefore, in the case where we use the parquet-mr reader
105+
// we intersect the parquetClippedSchema with the parquetFileSchema to construct the
106+
// parquetRequestedSchema set in the ReadContext.
107+
val parquetRequestedSchema =
108+
if (schemaPruningEnabled && !usingVectorizedReader) {
109+
ParquetReadSupport.intersectParquetGroups(parquetClippedSchema, parquetFileSchema)
110+
.map(intersectionGroup =>
111+
new MessageType(intersectionGroup.getName, intersectionGroup.getFields))
112+
.getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE)
113+
} else {
114+
parquetClippedSchema
115+
}
116+
log.info {
117+
s"""Going to read the following fields from the Parquet file with the following schema:
118+
|Parquet file schema:
119+
|$parquetFileSchema
120+
|Parquet clipped schema:
121+
|$parquetClippedSchema
122+
|Parquet requested schema:
123+
|$parquetRequestedSchema
124+
|Catalyst requested schema:
125+
|${catalystRequestedSchema.treeString}
126+
""".stripMargin
127+
}
80128
new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava)
81129
}
82130

@@ -93,13 +141,14 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
93141
log.debug(s"Preparing for read Parquet file with message type: $fileSchema")
94142
val parquetRequestedSchema = readContext.getRequestedSchema
95143

96-
logInfo {
97-
s"""Going to read the following fields from the Parquet file:
98-
|
99-
|Parquet form:
144+
log.info {
145+
s"""Going to read the following fields from the Parquet file with the following schema:
146+
|Parquet file schema:
147+
|$fileSchema
148+
|Parquet read schema:
100149
|$parquetRequestedSchema
101-
|Catalyst form:
102-
|$catalystRequestedSchema
150+
|Catalyst read schema:
151+
|${catalystRequestedSchema.treeString}
103152
""".stripMargin
104153
}
105154

@@ -322,6 +371,27 @@ private[parquet] object ParquetReadSupport {
322371
}
323372
}
324373

374+
/**
375+
* Computes the structural intersection between two Parquet group types.
376+
*/
377+
private def intersectParquetGroups(
378+
groupType1: GroupType, groupType2: GroupType): Option[GroupType] = {
379+
val fields =
380+
groupType1.getFields.asScala
381+
.filter(field => groupType2.containsField(field.getName))
382+
.flatMap {
383+
case field1: GroupType =>
384+
intersectParquetGroups(field1, groupType2.getType(field1.getName).asGroupType)
385+
case field1 => Some(field1)
386+
}
387+
388+
if (fields.nonEmpty) {
389+
Some(groupType1.withNewFields(fields.asJava))
390+
} else {
391+
None
392+
}
393+
}
394+
325395
def expandUDT(schema: StructType): StructType = {
326396
def expand(dataType: DataType): DataType = {
327397
dataType match {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ private[parquet] class ParquetRowConverter(
130130
extends ParquetGroupConverter(updater) with Logging {
131131

132132
assert(
133-
parquetType.getFieldCount == catalystType.length,
134-
s"""Field counts of the Parquet schema and the Catalyst schema don't match:
133+
parquetType.getFieldCount <= catalystType.length,
134+
s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema:
135135
|
136136
|Parquet schema:
137137
|$parquetType
@@ -182,18 +182,20 @@ private[parquet] class ParquetRowConverter(
182182

183183
// Converters for each field.
184184
private val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
185-
parquetType.getFields.asScala.zip(catalystType).zipWithIndex.map {
186-
case ((parquetFieldType, catalystField), ordinal) =>
187-
// Converted field value should be set to the `ordinal`-th cell of `currentRow`
188-
newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal))
185+
parquetType.getFields.asScala.map {
186+
case parquetField =>
187+
val fieldIndex = catalystType.fieldIndex(parquetField.getName)
188+
val catalystField = catalystType(fieldIndex)
189+
// Converted field value should be set to the `fieldIndex`-th cell of `currentRow`
190+
newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex))
189191
}.toArray
190192
}
191193

192194
override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex)
193195

194196
override def end(): Unit = {
195197
var i = 0
196-
while (i < currentRow.numFields) {
198+
while (i < fieldConverters.length) {
197199
fieldConverters(i).updater.end()
198200
i += 1
199201
}
@@ -202,11 +204,15 @@ private[parquet] class ParquetRowConverter(
202204

203205
override def start(): Unit = {
204206
var i = 0
205-
while (i < currentRow.numFields) {
207+
while (i < fieldConverters.length) {
206208
fieldConverters(i).updater.start()
207209
currentRow.setNullAt(i)
208210
i += 1
209211
}
212+
while (i < currentRow.numFields) {
213+
currentRow.setNullAt(i)
214+
i += 1
215+
}
210216
}
211217

212218
/**

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class ParquetSchemaPruningSuite
130130
Row("X.", 1) :: Row("Y.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil)
131131
}
132132

133-
ignore("partial schema intersection - select missing subfield") {
133+
testSchemaPruning("partial schema intersection - select missing subfield") {
134134
val query = sql("select name.middle, address from contacts where p=2")
135135
checkScan(query, "struct<name:struct<middle:string>,address:string>")
136136
checkAnswer(query.orderBy("id"),

0 commit comments

Comments
 (0)