Skip to content
This repository was archived by the owner on Mar 24, 2025. It is now read-only.

Commit f8d200c

Browse files
authored
Fix parsing of partial result when corrupted record field is present (#518)
1 parent b96dfad commit f8d200c

3 files changed

Lines changed: 87 additions & 8 deletions

File tree

src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,13 @@ private[xml] object StaxXmlParser extends Serializable {
115115
case PermissiveMode =>
116116
logger.debug("Malformed line cause:", cause)
117117
// The logic below is borrowed from Apache Spark's FailureSafeParser.
118-
val corruptFieldIndex = Try(schema.fieldIndex(options.columnNameOfCorruptRecord)).toOption
119-
val actualSchema = StructType(schema.filterNot(_.name == options.columnNameOfCorruptRecord))
120118
val resultRow = new Array[Any](schema.length)
121-
var i = 0
122-
while (i < actualSchema.length) {
123-
val from = actualSchema(i)
124-
resultRow(schema.fieldIndex(from.name)) = partialResult.map(_.get(i)).orNull
125-
i += 1
119+
schema.filterNot(_.name == options.columnNameOfCorruptRecord).foreach { from =>
120+
val sourceIndex = schema.fieldIndex(from.name)
121+
resultRow(sourceIndex) = partialResult.map(_.get(sourceIndex)).orNull
126122
}
127-
corruptFieldIndex.foreach(index => resultRow(index) = record)
123+
val corruptFieldIndex = Try(schema.fieldIndex(options.columnNameOfCorruptRecord)).toOption
124+
corruptFieldIndex.foreach(resultRow(_) = record)
128125
Some(Row.fromSeq(resultRow))
129126
}
130127
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<row id='0' xml:space='preserve'>
2+
<c2>1234</c2>
3+
<c3>Mark</c3>
4+
<c4>Mark</c4>
5+
<c5>Mark</c5>
6+
<c6>DOLLAR</c6>
7+
<c7>RT</c7>
8+
<c8>USD</c8>
9+
<c9>1</c9>
10+
<c11>3000</c11>
11+
<c20 m='8'></c20>
12+
<c20 m='9'></c20>
13+
<c46>20210207</c46>
14+
<c76>NO</c76>
15+
<c78>20210207</c78>
16+
<c85>14503</c85>
17+
<c93>USD</c93>
18+
<c95>USD</c95>
19+
<c99>LEGACY</c99>
20+
<c99 m='2'>IBAN</c99>
21+
<c100>sm342</c100>
22+
<c100 m='2'></c100>
23+
<c108>NO</c108>
24+
<c192>M</c192>
25+
<c193>46_STREET1</c193>
26+
<c194>0811241751</c194>
27+
<c195>46_STREET1</c195>
28+
<c196>SA0010001</c196>
29+
<c197>1</c197>
30+
</row>

src/test/scala/com/databricks/spark/xml/XmlSuite.scala

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll {
9191
private val whitespaceError = resDir + "whitespace_error.xml"
9292
private val mapAttribute = resDir + "map-attribute.xml"
9393
private val structWithOptChild = resDir + "struct_with_optional_child.xml"
94+
private val manualSchemaCorruptRecord = resDir + "manual_schema_corrupt_record.xml"
9495

9596
private val booksTag = "book"
9697
private val booksRootTag = "books"
@@ -1316,6 +1317,57 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll {
13161317
assert(df.selectExpr("SIZE(Bar)").collect().head.getInt(0) === 2)
13171318
}
13181319

1320+
test("Manual schema with corrupt record field works on permissive mode failure") {
1321+
// See issue #517
1322+
val schema = StructType(List(
1323+
StructField("_id", StringType),
1324+
StructField("_space", StringType),
1325+
StructField("c2", DoubleType),
1326+
StructField("c3", StringType),
1327+
StructField("c4", StringType),
1328+
StructField("c5", StringType),
1329+
StructField("c6", StringType),
1330+
StructField("c7", StringType),
1331+
StructField("c8", StringType),
1332+
StructField("c9", DoubleType),
1333+
StructField("c11", DoubleType),
1334+
StructField("c20", ArrayType(StructType(List(
1335+
StructField("_VALUE", StringType),
1336+
StructField("_m", IntegerType)))
1337+
)),
1338+
StructField("c46", StringType),
1339+
StructField("c76", StringType),
1340+
StructField("c78", StringType),
1341+
StructField("c85", DoubleType),
1342+
StructField("c93", StringType),
1343+
StructField("c95", StringType),
1344+
StructField("c99", ArrayType(StructType(List(
1345+
StructField("_VALUE", StringType),
1346+
StructField("_m", IntegerType)))
1347+
)),
1348+
StructField("c100", ArrayType(StructType(List(
1349+
StructField("_VALUE", StringType),
1350+
StructField("_m", IntegerType)))
1351+
)),
1352+
StructField("c108", StringType),
1353+
StructField("c192", DoubleType),
1354+
StructField("c193", StringType),
1355+
StructField("c194", StringType),
1356+
StructField("c195", StringType),
1357+
StructField("c196", StringType),
1358+
StructField("c197", DoubleType),
1359+
StructField("_corrupt_record", StringType)))
1360+
1361+
val df = spark.read
1362+
.option("inferSchema", false)
1363+
.option("rowTag", "row")
1364+
.schema(schema)
1365+
.xml(manualSchemaCorruptRecord)
1366+
1367+
// Assert it works at all
1368+
assert(df.collect().head.getAs[String]("_corrupt_record") !== null)
1369+
}
1370+
13191371
private def getLines(path: Path): Seq[String] = {
13201372
val source = Source.fromFile(path.toFile)
13211373
try {

0 commit comments

Comments (0)