# [SPARK-16896][SQL] Handle duplicated field names in header consistently with null or empty strings in CSV

## What changes were proposed in this pull request?

Currently, the CSV datasource allows loading a header that contains duplicated empty strings or fields equal to `nullValue`. It'd be great if it could deal with normal duplicated field names as well.

This PR proposes handling duplicates consistently with the existing behaviour, taking case sensitivity (`spark.sql.caseSensitive`) into account, as below:

The data below:

```
fieldA,fieldB,,FIELDA,fieldA,,
1,2,3,4,5,6,7
```

is parsed as below:

```scala
spark.read.format("csv").option("header", "true").load("test.csv").show()
```

- when `spark.sql.caseSensitive` is `false` (the default):

  ```
  +-------+------+---+-------+-------+---+---+
  |fieldA0|fieldB|_c2|FIELDA3|fieldA4|_c5|_c6|
  +-------+------+---+-------+-------+---+---+
  |      1|     2|  3|      4|      5|  6|  7|
  +-------+------+---+-------+-------+---+---+
  ```

- when `spark.sql.caseSensitive` is `true`:

  ```
  +-------+------+---+-------+-------+---+---+
  |fieldA0|fieldB|_c2| FIELDA|fieldA4|_c5|_c6|
  +-------+------+---+-------+-------+---+---+
  |      1|     2|  3|      4|      5|  6|  7|
  +-------+------+---+-------+-------+---+---+
  ```
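To reproduce both outputs above, the config can be toggled at runtime before re-reading the file. A minimal sketch, assuming an active `SparkSession` named `spark` and that `test.csv` contains the data shown above:

```scala
// Sketch: read the same header under both case-sensitivity settings.
// Assumes `spark` is an existing SparkSession and `test.csv` holds the data above.
spark.conf.set("spark.sql.caseSensitive", "false")  // the default
spark.read.format("csv").option("header", "true").load("test.csv").show()

spark.conf.set("spark.sql.caseSensitive", "true")
spark.read.format("csv").option("header", "true").load("test.csv").show()
```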

**In more detail:**

There is a good reference for this problem: `read.csv()` in R. So, I initially wanted to propose similar behaviour.

In the case of R, the CSV data below:

```
fieldA,fieldB,,fieldA,fieldA,,
1,2,3,4,5,6,7
```

is parsed as below:

```r
test <- read.csv(file="test.csv",header=TRUE,sep=",")
> test
  fieldA fieldB X fieldA.1 fieldA.2 X.1 X.2
1      1      2 3        4        5   6   7
```

However, the Spark CSV datasource already handles duplicated empty strings and `nullValue` in field names. So the data below:

```
,,,fieldA,,fieldB,
1,2,3,4,5,6,7
```

is parsed as below:

```scala
spark.read.format("csv").option("header", "true").load("test.csv").show()
```
```
+---+---+---+------+---+------+---+
|_c0|_c1|_c2|fieldA|_c4|fieldB|_c6|
+---+---+---+------+---+------+---+
|  1|  2|  3|     4|  5|     6|  7|
+---+---+---+------+---+------+---+
```

R numbers each duplicate with its own counter, whereas Spark appends the field's position for all `nullValue` and empty-string fields.
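
To make the difference concrete, here is a plain-Scala sketch (not the patch itself) of the two numbering schemes applied to the header `fieldA,fieldB,,fieldA,fieldA,,`; the Spark side mirrors the renaming rule added by this PR, and the R side only imitates `read.csv`'s `name`, `name.1`, `name.2` convention:

```scala
// Paste into a Scala REPL. Empty strings stand in for the blank header fields.
val header = Seq("fieldA", "fieldB", "", "fieldA", "fieldA", "", "")

// Spark-style (this PR): blank names and duplicated names get their *position* appended.
val duplicates = {
  val names = header.filter(_.nonEmpty)
  names.diff(names.distinct).distinct
}
val sparkNames = header.zipWithIndex.map {
  case (name, index) if name.isEmpty              => s"_c$index"
  case (name, index) if duplicates.contains(name) => s"$name$index"
  case (name, _)                                  => name
}
println(sparkNames.mkString(","))  // fieldA0,fieldB,_c2,fieldA3,fieldA4,_c5,_c6

// R-style (read.csv): each repeated name gets its own per-duplicate counter instead.
val counts = scala.collection.mutable.Map.empty[String, Int]
val rNames = header.map { name =>
  val base = if (name.isEmpty) "X" else name  // R uses "X" for blank names
  val seen = counts.getOrElse(base, 0)
  counts(base) = seen + 1
  if (seen == 0) base else s"$base.$seen"
}
println(rNames.mkString(","))  // fieldA,fieldB,X,fieldA.1,fieldA.2,X.1,X.2
```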

In terms of case sensitivity, R seems to be case-sensitive (and this does not appear to be configurable). The data below:

```
a,a,a,A,A
1,2,3,4,5
```

is parsed as below:

```r
test <- read.csv(file="test.csv",header=TRUE,sep=",")
> test
  a a.1 a.2 A A.1
1 1   2   3 4   5
```

## How was this patch tested?

Unit tests in `CSVSuite`.

Author: hyukjinkwon <[email protected]>

Closes #14745 from HyukjinKwon/SPARK-16896.
HyukjinKwon authored and cloud-fan committed Oct 11, 2016
1 parent d5ec4a3 commit 90217f9
Showing 3 changed files with 74 additions and 11 deletions.
```diff
@@ -59,28 +59,60 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val rdd = baseRdd(sparkSession, csvOptions, paths)
     val firstLine = findFirstLine(csvOptions, rdd)
     val firstRow = new CsvReader(csvOptions).parseLine(firstLine)

-    val header = if (csvOptions.headerFlag) {
-      firstRow.zipWithIndex.map { case (value, index) =>
-        if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value
-      }
-    } else {
-      firstRow.zipWithIndex.map { case (value, index) => s"_c$index" }
-    }
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+    val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)

     val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
     val schema = if (csvOptions.inferSchemaFlag) {
       CSVInferSchema.infer(parsedRdd, header, csvOptions)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
-        StructField(fieldName.toString, StringType, nullable = true)
+        StructField(fieldName, StringType, nullable = true)
       }
       StructType(schemaFields)
     }
     Some(schema)
   }

+  /**
+   * Generates a header from the given row which is null-safe and duplicate-safe.
+   */
+  private def makeSafeHeader(
+      row: Array[String],
+      options: CSVOptions,
+      caseSensitive: Boolean): Array[String] = {
+    if (options.headerFlag) {
+      val duplicates = {
+        val headerNames = row.filter(_ != null)
+          .map(name => if (caseSensitive) name else name.toLowerCase)
+        headerNames.diff(headerNames.distinct).distinct
+      }
+
+      row.zipWithIndex.map { case (value, index) =>
+        if (value == null || value.isEmpty || value == options.nullValue) {
+          // When there are empty strings or the values set in `nullValue`, put the
+          // index as the suffix.
+          s"_c$index"
+        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+          // When there are case-insensitive duplicates, put the index as the suffix.
+          s"$value$index"
+        } else if (duplicates.contains(value)) {
+          // When there are duplicates, put the index as the suffix.
+          s"$value$index"
+        } else {
+          value
+        }
+      }
+    } else {
+      row.zipWithIndex.map { case (_, index) =>
+        // Uses default column names, "_c#" where # is its position of fields
+        // when header option is disabled.
+        s"_c$index"
+      }
+    }
+  }
+
   override def prepareWrite(
       sparkSession: SparkSession,
       job: Job,
```
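
For reference, the duplicate detection in `makeSafeHeader` above boils down to `headerNames.diff(headerNames.distinct).distinct`, which keeps exactly one copy of every name that occurs more than once. A standalone sketch of that expression applied to the example header, with `null` standing in for the blank fields (as the parser typically returns them):

```scala
// Plain-Scala sketch of the duplicate detection used in makeSafeHeader.
object DuplicateHeaderCheck {
  def main(args: Array[String]): Unit = {
    val row = Array("fieldA", "fieldB", null, "FIELDA", "fieldA", null, null)
    val caseSensitive = false

    val headerNames = row.filter(_ != null)
      .map(name => if (caseSensitive) name else name.toLowerCase)
    // `diff(distinct)` removes one occurrence of each name, leaving only the extras;
    // the trailing `distinct` collapses repeated extras into a single entry.
    val duplicates = headerNames.diff(headerNames.distinct).distinct

    println(duplicates.mkString(", "))  // prints: fielda
  }
}
```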
```diff
@@ -28,6 +28,7 @@ import org.apache.hadoop.io.compress.GzipCodec

 import org.apache.spark.SparkException
 import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
@@ -856,4 +857,36 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       checkAnswer(stringTimestampsWithFormat, expectedStringTimestampsWithFormat)
     }
   }
+
+  test("load duplicated field names consistently with null or empty strings - case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      withTempPath { path =>
+        Seq("a,a,c,A,b,B").toDF().write.text(path.getAbsolutePath)
+        val actualSchema = spark.read
+          .format("csv")
+          .option("header", true)
+          .load(path.getAbsolutePath)
+          .schema
+        val fields = Seq("a0", "a1", "c", "A", "b", "B").map(StructField(_, StringType, true))
+        val expectedSchema = StructType(fields)
+        assert(actualSchema == expectedSchema)
+      }
+    }
+  }
+
+  test("load duplicated field names consistently with null or empty strings - case insensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        Seq("a,A,c,A,b,B").toDF().write.text(path.getAbsolutePath)
+        val actualSchema = spark.read
+          .format("csv")
+          .option("header", true)
+          .load(path.getAbsolutePath)
+          .schema
+        val fields = Seq("a0", "A1", "c", "A3", "b4", "B5").map(StructField(_, StringType, true))
+        val expectedSchema = StructType(fields)
+        assert(actualSchema == expectedSchema)
+      }
+    }
+  }
 }
```
```diff
@@ -18,8 +18,6 @@
 package org.apache.spark.sql.execution.datasources.csv

 import java.math.BigDecimal
-import java.sql.{Date, Timestamp}
-import java.text.SimpleDateFormat
 import java.util.Locale

 import org.apache.spark.SparkFunSuite
```
