diff --git a/utils/src/main/scala/com/salesforce/op/test/TestCommon.scala b/utils/src/main/scala/com/salesforce/op/test/TestCommon.scala index 80df6791ea..6699b25671 100644 --- a/utils/src/main/scala/com/salesforce/op/test/TestCommon.scala +++ b/utils/src/main/scala/com/salesforce/op/test/TestCommon.scala @@ -45,6 +45,11 @@ import scala.reflect.{ClassTag, _} */ trait TestCommon extends Matchers with Assertions { + /** + * Returns the resource directory path + */ + protected def resourceDir: String = "src/test/resources" + /** * Set logging level for */ @@ -91,7 +96,7 @@ trait TestCommon extends Matchers with Assertions { * @param name resource name * @return resource file */ - def resourceFile(parent: String = "src/test/resources", name: String): File = { + def resourceFile(parent: String = resourceDir, name: String): File = { val file = new File(parent, name) if (!file.canRead) throw new IllegalStateException(s"File $file unreadable") file @@ -106,7 +111,7 @@ trait TestCommon extends Matchers with Assertions { * @return resource file */ @deprecated("Use loadResource", "3.2.3") - def resourceString(parent: String = "src/test/resources", noSpaces: Boolean = true, name: String): String = { + def resourceString(parent: String = resourceDir, noSpaces: Boolean = true, name: String): String = { val file = resourceFile(parent = parent, name = name) val contents = Source.fromFile(file, "UTF-8").mkString if (noSpaces) contents.replaceAll("\\s", "") else contents diff --git a/utils/src/main/scala/com/salesforce/op/utils/avro/RichGenericRecord.scala b/utils/src/main/scala/com/salesforce/op/utils/avro/RichGenericRecord.scala index bd810cc1f7..e3ada3688b 100644 --- a/utils/src/main/scala/com/salesforce/op/utils/avro/RichGenericRecord.scala +++ b/utils/src/main/scala/com/salesforce/op/utils/avro/RichGenericRecord.scala @@ -31,10 +31,6 @@ package com.salesforce.op.utils.avro import org.apache.avro.generic.GenericRecord -import scala.collection.JavaConverters._ - - -import scala.util.Try object RichGenericRecord { @@ -55,7 +51,7 @@ object RichGenericRecord { def getValue[T](fieldName: String): Option[T] = { // Check that field exists in schema require(Option(r.getSchema.getField(fieldName)).isDefined, - s"${fieldName} is not found in Avro schema!") + s"$fieldName is not found in Avro schema!") val field = Option(r.get(fieldName)) (field map { diff --git a/utils/src/test/resources/PassengerDataContentTypeMisMatch.csv b/utils/src/test/resources/PassengerDataContentTypeMisMatch.csv new file mode 100644 index 0000000000..397c6e09ef --- /dev/null +++ b/utils/src/test/resources/PassengerDataContentTypeMisMatch.csv @@ -0,0 +1,5 @@ +1,false,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S +2,fail,1,"Cumings, Mrs. John Bradley (Florence Briggs Thayer)",female,38,1,0,PC 17599,71.2833,C85,C +3,true,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S +4,true,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35,1,0,113803,53.1,C123,S +5,false,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S diff --git a/utils/src/test/resources/PassengerDataModifiedDataTypes.csv b/utils/src/test/resources/PassengerDataModifiedDataTypes.csv new file mode 100644 index 0000000000..279782e66a --- /dev/null +++ b/utils/src/test/resources/PassengerDataModifiedDataTypes.csv @@ -0,0 +1,5 @@ +1,false,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S +2,true,1,"Cumings, Mrs. John Bradley (Florence Briggs Thayer)",female,38,1,0,PC 17599,71.2833,C85,C +3,true,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S +4,true,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35,1,0,113803,53.1,C123,S +5,false,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S diff --git a/utils/src/test/resources/PassengerSchemaInvalidField.avsc b/utils/src/test/resources/PassengerSchemaInvalidField.avsc new file mode 100644 index 0000000000..5b5046dbc3 --- /dev/null +++ b/utils/src/test/resources/PassengerSchemaInvalidField.avsc @@ -0,0 +1,46 @@ +{ + "type" : "record", + "name" : "Passenger", + "namespace" : "com.salesforce.app.schema", + "fields" : [ { + "name" : "PassengerId", + "type" : [ "int", "null" ] + }, { + "name" : "Survived", + "type" : "boolean", + "default": false + }, { + "name" : "Pclass", + "type" : [ "int", "null" ] + }, { + "name" : "Name", + "type" : [ "string", "null" ] + }, { + "name" : "Sex", + "type" : [ "string", "null" ] + }, { + "name" : "Age", + "type" : [ "double", "null" ] + }, { + "name" : "SibSp", + "type" : [ "int", "null" ] + }, { + "name" : "Parch", + "type" : [ "long", "null" ] + }, { + "name" : "Ticket", + "type" : [ "string", "null" ] + }, { + "name" : "Fare", + "type" : [ "float", "null" ] + }, { + "name" : "Cabin", + "type" : [ "string", "null" ] + }, { + "name" : "Embarked", + "type" : [ "string", "null" ] + }, { + "name" : "FailTest", + "type" : [ "string", "null" ] + } ] +} diff --git a/utils/src/test/resources/PassengerSchemaModifiedDataTypes.avsc b/utils/src/test/resources/PassengerSchemaModifiedDataTypes.avsc new file mode 100644 index 0000000000..b318bf182e --- /dev/null +++ b/utils/src/test/resources/PassengerSchemaModifiedDataTypes.avsc @@ -0,0 +1,43 @@ +{ + "type" : "record", + "name" : "Passenger", + "namespace" : "com.salesforce.app.schema", + "fields" : [ { + "name" : "PassengerId", + "type" : [ "int", "null" ] + }, { + "name" : "Survived", + "type" : "boolean", + "default": false + }, { + "name" : "Pclass", + "type" : [ "int", "null" ] + }, { + "name" : "Name", + "type" : [ "string", "null" ] + }, { + "name" : "Sex", + "type" : [ "string", "null" ] + }, { + "name" : "Age", + "type" : [ "double", "null" ] + }, { + "name" : "SibSp", + "type" : [ "int", "null" ] + }, { + "name" : "Parch", + "type" : [ "long", "null" ] + }, { + "name" : "Ticket", + "type" : [ "string", "null" ] + }, { + "name" : "Fare", + "type" : [ "float", "null" ] + }, { + "name" : "Cabin", + "type" : [ "string", "null" ] + }, { + "name" : "Embarked", + "type" : [ "string", "null" ] + } ] +} diff --git a/utils/src/test/resources/PassengerSchemaNestedTypeCSV.avsc b/utils/src/test/resources/PassengerSchemaNestedTypeCSV.avsc new file mode 100644 index 0000000000..90655a7e93 --- /dev/null +++ b/utils/src/test/resources/PassengerSchemaNestedTypeCSV.avsc @@ -0,0 +1,47 @@ +{ + "type" : "record", + "name" : "Passenger", + "namespace" : "com.salesforce.app.schema", + "fields" : [ { + "name" : "PassengerId", + "type" : [ "int", "null" ] + }, { + "name" : "Survived", + "type" : "boolean", + "default": false + }, { + "name" : "Pclass", + "type" : [ "int", "null" ] + }, { + "name" : "Name", + "type" : [ "string", "null" ] + }, { + "name" : "Sex", + "type" : { + "name": "Sex", + "type": "enum", + "symbols": [ "male", "female" ] + } + }, { + "name" : "Age", + "type" : [ "double", "null" ] + }, { + "name" : "SibSp", + "type" : [ "int", "null" ] + }, { + "name" : "Parch", + "type" : [ "int", "null" ] + }, { + "name" : "Ticket", + "type" : [ "string", "null" ] + }, { + "name" : "Fare", + "type" : [ "double", "null" ] + }, { + "name" : "Cabin", + "type" : [ "string", "null" ] + }, { + "name" : "Embarked", + "type" : [ "string", "null" ] + } ] +} diff --git a/utils/src/test/scala/com/salesforce/op/utils/avro/RichGenericRecordTest.scala b/utils/src/test/scala/com/salesforce/op/utils/avro/RichGenericRecordTest.scala index 17432bcd66..a2ed6ea496 100644 --- a/utils/src/test/scala/com/salesforce/op/utils/avro/RichGenericRecordTest.scala +++ b/utils/src/test/scala/com/salesforce/op/utils/avro/RichGenericRecordTest.scala @@ -34,8 +34,8 @@ import com.salesforce.op.test.{TestCommon, TestSparkContext} import com.salesforce.op.utils.io.avro.AvroInOut import org.apache.avro.generic.GenericRecord import org.junit.runner.RunWith -import org.scalatest.{FlatSpec, Matchers} import org.scalatest.junit.JUnitRunner +import org.scalatest.{FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) @@ -51,33 +51,42 @@ class RichGenericRecordTest extends FlatSpec val firstRow = passengerData.first Spec[RichGenericRecord] should "get value of Int" in { - val id = firstRow.getValue[Int]("passengerId") id shouldBe Some(1) } + it should "get value of Double" in { val survived = firstRow.getValue[Double]("survived") survived shouldBe Some(0.0) } + it should "get value of Long" in { val height = firstRow.getValue[Long]("height") height shouldBe Some(168L) } + it should "get value of String" in { val gender = firstRow.getValue[String]("gender") gender shouldBe Some("Female") } + it should "get value of Char" in { val gender = firstRow.getValue[Char]("gender") gender shouldBe Some("Female") } + it should "get value of Float" in { val age = firstRow.getValue[Float]("age") age shouldBe Some(32.0) } + it should "get value of Short" in { val weight = firstRow.getValue[Short]("weight") weight shouldBe Some(67) + } + it should "throw error for invalid field" in { + val error = intercept[IllegalArgumentException](firstRow.getValue[Short]("invalidField")) + error.getMessage shouldBe "requirement failed: invalidField is not found in Avro schema!" } } diff --git a/utils/src/test/scala/com/salesforce/op/utils/io/csv/CSVToAvroTest.scala b/utils/src/test/scala/com/salesforce/op/utils/io/csv/CSVToAvroTest.scala new file mode 100644 index 0000000000..b858d9f656 --- /dev/null +++ b/utils/src/test/scala/com/salesforce/op/utils/io/csv/CSVToAvroTest.scala @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.utils.io.csv + +import com.salesforce.op.test.{Passenger, TestSparkContext} +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.junit.runner.RunWith +import org.scalatest.FlatSpec +import org.scalatest.junit.JUnitRunner + +@RunWith(classOf[JUnitRunner]) +class CSVToAvroTest extends FlatSpec with TestSparkContext { + val avroSchema: String = loadFile(s"$resourceDir/PassengerSchemaModifiedDataTypes.avsc") + val csvReader: CSVInOut = new CSVInOut(CSVOptions(header = true)) + lazy val csvRDD: RDD[Seq[String]] = csvReader.readRDD(s"$resourceDir/PassengerDataModifiedDataTypes.csv") + lazy val csvFileRecordCount: Long = csvRDD.count + + Spec(CSVToAvro.getClass) should "convert RDD[Seq[String]] to RDD[GenericRecord]" in { + val res = CSVToAvro.toAvro(csvRDD, avroSchema) + res shouldBe a[RDD[_]] + res.count shouldBe csvFileRecordCount + } + + it should "convert RDD[Seq[String]] to RDD[T]" in { + val res = CSVToAvro.toAvroTyped[Passenger](csvRDD, avroSchema) + res shouldBe a[RDD[_]] + res.count shouldBe csvFileRecordCount + } + + it should "throw an error for nested schema" in { + val invalidAvroSchema = loadFile(s"$resourceDir/PassengerSchemaNestedTypeCSV.avsc") + val exceptionMsg = "CSV should be a flat file and not have nested records (unsupported column(Sex schemaType=ENUM)" + val error = intercept[SparkException](CSVToAvro.toAvro(csvRDD, invalidAvroSchema).count()) + error.getCause.getMessage shouldBe exceptionMsg + } + + it should "throw an error for mis-matching schema fields" in { + val invalidAvroSchema = loadFile(s"$resourceDir/PassengerSchemaInvalidField.avsc") + val error = intercept[SparkException](CSVToAvro.toAvro(csvRDD, invalidAvroSchema).count()) + error.getCause.getMessage shouldBe "Mismatch number of fields in csv record and avro schema" + } + + it should "throw an error for bad data" in { + val invalidDataRDD = csvReader.readRDD(s"$resourceDir/PassengerDataContentTypeMisMatch.csv") + val error = intercept[SparkException](CSVToAvro.toAvro(invalidDataRDD, avroSchema).count()) + error.getCause.getMessage shouldBe "Boolean column not actually a boolean. Invalid value: 'fail'" + } +}