
Commit

#394 Add a workaround implementation
yruslan committed Jul 8, 2021
1 parent 998a116 commit 0561251
Showing 8 changed files with 61 additions and 49 deletions.
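
In short, the workaround lets the spark-cobol data source read from several input directories in one pass: a new "paths" option (a comma-separated list) complements the existing "path" option, CobolParameters.sourcePath becomes sourcePaths: Seq[String], and the relation, scanners, and validator are updated to iterate over all supplied directories. Below is a minimal usage sketch, not part of the commit: it assumes a SparkSession named spark is in scope, the spark-cobol data source is on the classpath, and /data/part1 and /data/part2 are illustrative directories; the option names are taken from CobolParametersParser and the test at the bottom of the diff.

val copybook =
  """ 01 R.
       03 A PIC X(2).
  """

val df = spark.read
  .format("cobol")
  .option("copybook_contents", copybook)      // a copybook path can be passed via "copybook" instead
  .option("encoding", "ascii")
  .option("paths", "/data/part1,/data/part2") // new: comma-separated list of input directories
  .load()                                     // no path argument here: "path" and "paths" are mutually exclusive

Passing an argument to .load() sets the implicit "path" parameter, so combining it with "paths" makes CobolParametersValidator throw an IllegalStateException.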
@@ -54,7 +54,7 @@ case class CobolParameters(
copybookPath: Option[String],
multiCopybookPath: Seq[String],
copybookContent: Option[String],
sourcePath: Option[String],
sourcePaths: Seq[String],
isText: Boolean,
isEbcdic: Boolean,
ebcdicCodePage: String,
@@ -42,6 +42,7 @@ object CobolParametersParser {
val PARAM_MULTI_COPYBOOK_PATH = "copybooks"
val PARAM_COPYBOOK_CONTENTS = "copybook_contents"
val PARAM_SOURCE_PATH = "path"
val PARAM_SOURCE_PATHS = "paths"
val PARAM_ENCODING = "encoding"
val PARAM_PEDANTIC = "pedantic"
val PARAM_RECORD_LENGTH_FIELD = "record_length_field"
@@ -208,11 +209,13 @@ object CobolParametersParser {
}
}

val paths = getParameter(PARAM_SOURCE_PATHS, params).map(_.split(',')).getOrElse(Array(getParameter(PARAM_SOURCE_PATH, params).get))

val cobolParameters = CobolParameters(
getParameter(PARAM_COPYBOOK_PATH, params),
params.getOrElse(PARAM_MULTI_COPYBOOK_PATH, "").split(','),
getParameter(PARAM_COPYBOOK_CONTENTS, params),
getParameter(PARAM_SOURCE_PATH, params),
paths,
params.getOrElse(PARAM_IS_TEXT, "false").toBoolean,
isEbcdic,
ebcdicCodePageName,
@@ -37,19 +37,19 @@ import scala.util.control.NonFatal


class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
private def writeObject(out: ObjectOutputStream): Unit =
private def writeObject(out: ObjectOutputStream): Unit =
try {
out.defaultWriteObject()
value.write(out)
out.defaultWriteObject()
value.write(out)
} catch {
case NonFatal(e) =>
throw new IOException(e)
}

private def readObject(in: ObjectInputStream): Unit =
private def readObject(in: ObjectInputStream): Unit =
try {
value = new Configuration(false)
value.readFields(in)
value = new Configuration(false)
value.readFields(in)
} catch {
case NonFatal(e) =>
throw new IOException(e)
@@ -63,18 +63,18 @@ class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
*
* Its constructor is expected to change after the hierarchy of [[za.co.absa.cobrix.spark.cobol.reader.Reader]] is put in place.
*/
class CobolRelation(sourceDir: String,
class CobolRelation(sourceDirs: Seq[String],
cobolReader: Reader,
localityParams: LocalityParameters,
debugIgnoreFileSize: Boolean
)(@transient val sqlContext: SQLContext)
extends BaseRelation
with Serializable
with TableScan {
with Serializable
with TableScan {

private val logger = LoggerFactory.getLogger(this.getClass)

private val filesList = getListFilesWithOrder(sourceDir)
private val filesList = getListFilesWithOrder(sourceDirs)

private lazy val indexes: RDD[SparseIndexEntry] = IndexBuilder.buildIndex(filesList, cobolReader, sqlContext)(localityParams)

@@ -83,12 +83,11 @@ class CobolRelation(sourceDir: String,
}

override def buildScan(): RDD[Row] = {

cobolReader match {
case blockReader: FixedLenTextReader =>
CobolScanners.buildScanForTextFiles(blockReader, sourceDir, parseRecords, sqlContext)
CobolScanners.buildScanForTextFiles(blockReader, sourceDirs, parseRecords, sqlContext)
case blockReader: FixedLenReader =>
CobolScanners.buildScanForFixedLength(blockReader, sourceDir, parseRecords, debugIgnoreFileSize, sqlContext)
CobolScanners.buildScanForFixedLength(blockReader, sourceDirs, parseRecords, debugIgnoreFileSize, sqlContext)
case streamReader: VarLenReader if streamReader.isIndexGenerationNeeded =>
CobolScanners.buildScanForVarLenIndex(streamReader, indexes, filesList, sqlContext)
case streamReader: VarLenReader =>
@@ -104,13 +103,15 @@ class CobolRelation(sourceDir: String,
*
* The List contains [[za.co.absa.cobrix.spark.cobol.source.types.FileWithOrder]] instances.
*/
private def getListFilesWithOrder(sourceDir: String): Array[FileWithOrder] = {
private def getListFilesWithOrder(sourceDirs: Seq[String]): Array[FileWithOrder] = {
val allFiles = sourceDirs.flatMap(sourceDir => {
FileUtils
.getFiles(sourceDir, sqlContext.sparkContext.hadoopConfiguration, isRecursiveRetrieval)
}).toArray

FileUtils
.getFiles(sourceDir, sqlContext.sparkContext.hadoopConfiguration, isRecursiveRetrieval)
allFiles
.zipWithIndex
.map(file => FileWithOrder(file._1, file._2))
.toArray
}

/**
@@ -55,7 +55,7 @@ class DefaultSource
val cobolParameters = CobolParametersParser.parse(new Parameters(parameters))
CobolParametersValidator.checkSanity(cobolParameters)

new CobolRelation(parameters(PARAM_SOURCE_PATH),
new CobolRelation(cobolParameters.sourcePaths,
buildEitherReader(sqlContext.sparkSession, cobolParameters),
LocalityParameters.extract(cobolParameters),
cobolParameters.debugIgnoreFileSize)(sqlContext)
@@ -32,8 +32,7 @@ import za.co.absa.cobrix.spark.cobol.utils.FileNameUtils
object CobolParametersValidator {

def checkSanity(params: CobolParameters) = {

if (params.sourcePath.isEmpty) {
if (params.sourcePaths.isEmpty) {
throw new IllegalArgumentException("Data source path must be specified.")
}

@@ -53,8 +52,14 @@
val copyBookPathFileName = parameters.get(PARAM_COPYBOOK_PATH)
val copyBookMultiPathFileNames = parameters.get(PARAM_MULTI_COPYBOOK_PATH)

parameters.getOrElse(PARAM_SOURCE_PATH, throw new IllegalStateException(s"Cannot define path to source files: missing " +
s"parameter: '$PARAM_SOURCE_PATH'"))
if (!parameters.isDefinedAt(PARAM_SOURCE_PATH) && !parameters.isDefinedAt(PARAM_SOURCE_PATHS)) {
throw new IllegalStateException(s"Cannot define path to data files: missing " +
s"parameter: '$PARAM_SOURCE_PATH' or '$PARAM_SOURCE_PATHS'. It is automatically set when you invoke .load()")
}

if (parameters.isDefinedAt(PARAM_SOURCE_PATH) && parameters.isDefinedAt(PARAM_SOURCE_PATHS)) {
throw new IllegalStateException(s"Only one of '$PARAM_SOURCE_PATH' or '$PARAM_SOURCE_PATHS' should be defined.")
}

def validatePath(fileName: String): Unit = {
val (isLocalFS, copyBookFileName) = FileNameUtils.getCopyBookFileName(fileName)
@@ -74,7 +74,7 @@ private[source] object CobolScanners {
})
}

private[source] def buildScanForFixedLength(reader: FixedLenReader, sourceDir: String,
private[source] def buildScanForFixedLength(reader: FixedLenReader, sourceDirs: Seq[String],
recordParser: (FixedLenReader, RDD[Array[Byte]]) => RDD[Row],
debugIgnoreFileSize: Boolean,
sqlContext: SQLContext): RDD[Row] = {
@@ -85,21 +85,24 @@

val recordSize = reader.getRecordSize

if (!debugIgnoreFileSize && areThereNonDivisibleFiles(sourceDir, sqlContext.sparkContext.hadoopConfiguration, recordSize)) {
throw new IllegalArgumentException(s"There are some files in $sourceDir that are NOT DIVISIBLE by the RECORD SIZE calculated from the copybook ($recordSize bytes per record). Check the logs for the names of the files.")
}
sourceDirs.foreach(sourceDir => {
if (!debugIgnoreFileSize && areThereNonDivisibleFiles(sourceDir, sqlContext.sparkContext.hadoopConfiguration, recordSize)) {
throw new IllegalArgumentException(s"There are some files in $sourceDir that are NOT DIVISIBLE by the RECORD SIZE calculated from the copybook ($recordSize bytes per record). Check the logs for the names of the files.")
}
})

val records = sqlContext.sparkContext.binaryRecords(sourceDir, recordSize, sqlContext.sparkContext.hadoopConfiguration)
val records = sourceDirs.map(sourceDir => sqlContext.sparkContext.binaryRecords(sourceDir, recordSize, sqlContext.sparkContext.hadoopConfiguration))
.reduce((a ,b) => a.union(b))
recordParser(reader, records)
}

private[source] def buildScanForTextFiles(reader: FixedLenReader, sourceDir: String,
private[source] def buildScanForTextFiles(reader: FixedLenReader, sourceDirs: Seq[String],
recordParser: (FixedLenReader, RDD[Array[Byte]]) => RDD[Row],
sqlContext: SQLContext): RDD[Row] = {
sqlContext.read.text()

val rddText = sqlContext.sparkContext
.textFile(sourceDir)
val rddText = sourceDirs.map(sourceDir => sqlContext.sparkContext.textFile(sourceDir))
.reduce((a,b) => a.union(b))

val records = rddText
.filter(str => str.length > 0)
@@ -61,7 +61,7 @@ class CobolRelationSpec extends SparkCobolTestBase with Serializable {

it should "return an RDD[Row] if data are correct" in {
val testReader: FixedLenReader = new DummyFixedLenReader(sparkSchema, cobolSchema, testData)(() => Unit)
val relation = new CobolRelation(copybookFile.getParentFile.getAbsolutePath,
val relation = new CobolRelation(Seq(copybookFile.getParentFile.getAbsolutePath),
testReader,
localityParams = localityParams,
debugIgnoreFileSize = false)(sqlContext)
@@ -85,7 +85,7 @@
it should "manage exceptions from Reader" in {
val exceptionMessage = "exception expected message"
val testReader: FixedLenReader = new DummyFixedLenReader(sparkSchema, cobolSchema, testData)(() => throw new Exception(exceptionMessage))
val relation = new CobolRelation(copybookFile.getParentFile.getAbsolutePath,
val relation = new CobolRelation(Seq(copybookFile.getParentFile.getAbsolutePath),
testReader,
localityParams = localityParams,
debugIgnoreFileSize = false)(sqlContext)
@@ -100,7 +100,7 @@
val absentField = "absentField"
val modifiedSparkSchema = sparkSchema.add(StructField(absentField, StringType, false))
val testReader: FixedLenReader = new DummyFixedLenReader(modifiedSparkSchema, cobolSchema, testData)(() => Unit)
val relation = new CobolRelation(copybookFile.getParentFile.getAbsolutePath,
val relation = new CobolRelation(Seq(copybookFile.getParentFile.getAbsolutePath),
testReader,
localityParams = localityParams,
debugIgnoreFileSize = false)(sqlContext)
@@ -17,6 +17,7 @@
package za.co.absa.cobrix.spark.cobol.source.integration

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.scalatest.WordSpec
import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
import za.co.absa.cobrix.spark.cobol.source.fixtures.BinaryFileFixture
@@ -29,27 +30,26 @@ class Test28MultipartLoadSpec extends WordSpec with SparkTestBase with BinaryFileFixture {
private val copybook =
""" 01 R.
03 A PIC X(2).
03 B PIC X(1).
"""
private val data1 = "AABBBCCDDDEEFFFZYY"
private val data2 = "BAABBBCCDDDEEFFFZY"
private val data1 = "010203040506070809"
private val data2 = "101112131415161718"

"Multipart path spec" should {
"load avv available copybooks" in {
val expected = """"""
val expected = """[{"A":"01"},{"A":"02"},{"A":"03"},{"A":"04"},{"A":"05"},{"A":"06"},{"A":"07"},{"A":"08"},{"A":"09"},{"A":"10"},{"A":"11"},{"A":"12"},{"A":"13"},{"A":"14"},{"A":"15"},{"A":"16"},{"A":"17"},{"A":"18"}]"""

withTempBinFile("rec_len1", ".dat", data1.getBytes) { tmpFileName1 =>
withTempBinFile("rec_len2", ".dat", data2.getBytes) { tmpFileName2 =>
val df = getDataFrame(Seq(tmpFileName1, tmpFileName2))

val actual = df
.orderBy(col("A"))
.toJSON
.collect()
.mkString("[", ",", "]")

intercept[IllegalStateException] {
val df = getDataFrame(Seq(tmpFileName1, tmpFileName1))

val actual = df.toJSON.collect().mkString("[", ",", "]")
}

//assert(df.count() == 12)
//assert(actual == expected)
assert(df.count() == 18)
assert(actual == expected)
}
}
}
@@ -61,10 +61,10 @@ class Test28MultipartLoadSpec extends WordSpec with SparkTestBase with BinaryFileFixture {
.format("cobol")
.option("copybook_contents", copybook)
.option("encoding", "ascii")
.option("record_length", "2")
.option("schema_retention_policy", "collapse_root")
.option("paths", inputPaths.mkString(","))
.options(extraOptions)
.load(inputPaths: _*)
.load()
}
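
For reference, the CobolScanners change above makes the fixed-length scan validate record-size divisibility for each input directory and then union the per-directory RDDs, so records follow the order in which the paths were supplied. A rough sketch of that pattern, not part of the commit: it assumes a SparkSession named spark, illustrative directory names, and a record size already derived from the copybook.

val recordSize = 2                                 // illustrative; Cobrix derives it from the copybook
val sourceDirs = Seq("/data/part1", "/data/part2") // illustrative input directories

val records = sourceDirs
  .map(dir => spark.sparkContext.binaryRecords(dir, recordSize, spark.sparkContext.hadoopConfiguration))
  .reduce(_ union _)                               // one RDD[Array[Byte]] spanning all directories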

