Skip to content

Commit

Permalink
#678 Add a method to convert integrals to decimals in schema accordin…
Browse files Browse the repository at this point in the history
…g to the metadata.
  • Loading branch information
yruslan committed May 30, 2024
1 parent 8997831 commit 6f4d166
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 16 deletions.
11 changes: 10 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,16 @@ lazy val sparkCobol = (project in file("spark-cobol"))
log.info(s"Building with Spark ${sparkVersion(scalaVersion.value)}, Scala ${scalaVersion.value}")
sparkVersion(scalaVersion.value)
},
(Compile / compile) := ((Compile / compile) dependsOn printSparkVersion).value,
Compile / compile := ((Compile / compile) dependsOn printSparkVersion).value,
Compile / unmanagedSourceDirectories += {
  // Adds a Scala-version-specific source directory (scala_2.11, scala_2.12 or scala_2.13)
  // located under the main source directory. Any other Scala version fails the build.
  val sourceDir = (Compile / sourceDirectory).value
  CrossVersion.partialVersion(scalaVersion.value) match {
    case Some((2, 11)) => sourceDir / "scala_2.11"
    case Some((2, 12)) => sourceDir / "scala_2.12"
    case Some((2, 13)) => sourceDir / "scala_2.13"
    case _ => throw new RuntimeException("Unsupported Scala version")
  }
},
libraryDependencies ++= SparkCobolDependencies(scalaVersion.value) :+ getScalaDependency(scalaVersion.value),
dependencyOverrides ++= SparkCobolDependenciesOverride,
Test / fork := true, // Spark tests fail randomly otherwise
Expand Down
36 changes: 28 additions & 8 deletions spark-cobol/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,33 @@
</dependency>
</dependencies>

<build>
<resources>
<resource>
<directory>src/main/resources</directory>
<filtering>true</filtering>
</resource>
</resources>
</build>
<build>
<resources>
<resource>
<directory>src/main/resources</directory>
<filtering>true</filtering>
</resource>
</resources>
<plugins>
<!--
  Registers an additional source directory during the generate-sources phase.
  The directory name is derived from the scala.compat.version property, which is
  defined outside this snippet (presumably the Scala binary version, e.g. 2.12; TODO confirm).
-->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>src/main/scala_${scala.compat.version}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package za.co.absa.cobrix.spark.cobol.utils
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.SparkContext
import org.apache.spark.sql.functions.{concat_ws, expr, max}
import org.apache.spark.sql.functions.{array, col, expr, max, struct}
import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper.transform
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import za.co.absa.cobrix.cobol.internal.Logging
Expand Down Expand Up @@ -178,6 +179,48 @@ object SparkUtils extends Logging {
df.select(fields.toSeq: _*)
}

/**
  * Returns a dataframe with the same structure as `df` in which every primitive value
  * (anything that is neither a struct nor an array) has been replaced by the result of `f`.
  *
  * Structs are rebuilt field by field and arrays are mapped element by element, so the nesting
  * of the schema is preserved. Arrays are processed with `HofsWrapper.transform`; if that
  * implementation is not available in the current build, it fails when an array is encountered.
  *
  * @param df A dataframe to transform.
  * @param f  A function that receives the field description and the column holding its value,
  *           and returns the replacement column. For arrays of primitives the function is
  *           invoked for the array element, with a field that carries the name and the metadata
  *           of the array field and the data type of the element.
  * @return A dataframe with all primitive values mapped by `f`.
  */
def mapPrimitives(df: DataFrame)(f: (StructField, Column) => Column): DataFrame = {
  // Rebuilds a single field. `column` must reference the value of `field` itself.
  def mapField(column: Column, field: StructField): Column = {
    field.dataType match {
      case st: StructType =>
        // Children are addressed relative to the struct column by their own names.
        val columns = st.fields.map(child => mapField(column.getField(child.name), child))
        struct(columns: _*).as(field.name)
      case ar: ArrayType =>
        mapArray(ar, column, field).as(field.name)
      case _ =>
        f(field, column).as(field.name)
    }
  }

  // Rebuilds an array value element by element. `column` must reference the array itself,
  // and `field` is the schema field that owns the (possibly nested) array.
  def mapArray(arr: ArrayType, column: Column, field: StructField): Column = {
    arr.elementType match {
      case st: StructType =>
        transform(column, element => {
          val columns = st.fields.map(child => mapField(element.getField(child.name), child))
          struct(columns: _*)
        })
      case nested: ArrayType =>
        // Map the inner arrays in place instead of wrapping the column into a new array,
        // so the nesting depth stays the same.
        transform(column, element => mapArray(nested, element, field))
      case primitive =>
        // Keep the metadata of the owning field so that `f` can make decisions based on it.
        val elementField = StructField(field.name, primitive, metadata = field.metadata)
        transform(column, element => f(elementField, element))
    }
  }

  val columns = df.schema.fields.map(field => mapField(col(field.name), field))
  df.select(columns: _*)
}

/**
  * Casts integral fields to decimals according to the field metadata.
  *
  * A field is converted when its data type is `LongType`, `IntegerType` or `ShortType` and its
  * metadata contains the "precision" key. The target type is `DecimalType(precision, 0)`.
  * All other fields are left unchanged.
  *
  * @param df A dataframe to convert.
  * @return A dataframe in which the matching integral fields are decimals.
  */
def covertIntegralToDecimal(df: DataFrame): DataFrame = {
  mapPrimitives(df) { (field, c) =>
    val isIntegral = field.dataType match {
      case LongType | IntegerType | ShortType => true
      case _ => false
    }

    if (isIntegral && field.metadata.contains("precision")) {
      val targetType = DecimalType(field.metadata.getLong("precision").toInt, 0)
      c.cast(targetType).as(field.name)
    } else {
      c
    }
  }
}

/**
* Given an instance of DataFrame returns a dataframe where all primitive fields are converted to String
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.cobrix.spark.cobol.utils.impl

import org.apache.spark.sql.Column

object HofsWrapper {
  /**
    * A placeholder for the array `transform` higher-order function, which maps every element of
    * an array column with the function `f`.
    *
    * This variant does not implement the transformation: the method is not available in
    * Scala 2.11 and Spark < 3.0, so every call fails.
    *
    * @param array A column of arrays (not used)
    * @param f     A function transforming individual elements of the array (not used)
    * @throws IllegalArgumentException on every invocation
    */
  def transform(array: Column, f: Column => Column): Column = {
    val message = "Array transformation is not available for Scala 2.11 and Spark < 3.0."
    throw new IllegalArgumentException(message)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.cobrix.spark.cobol.utils.impl

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{transform => sparkTransform}

object HofsWrapper {
  /**
    * Maps every element of an array column with the function `f`, similar to `map` from
    * functional programming. The call is delegated to Spark's own `transform` function.
    *
    * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala)
    *
    * @param array A column of arrays
    * @param f     A function transforming individual elements of the array
    * @return A column of arrays with transformed elements
    */
  def transform(array: Column, f: Column => Column): Column =
    sparkTransform(array, f)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.cobrix.spark.cobol.utils.impl

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{transform => sparkTransform}

object HofsWrapper {
  /**
    * Maps every element of an array column with the function `f`, similar to `map` from
    * functional programming. The call is delegated to Spark's own `transform` function.
    *
    * (The idea comes from https://github.com/AbsaOSS/spark-hats/blob/v0.3.0/src/main/scala_2.12/za/co/absa/spark/hats/HofsWrapper.scala)
    *
    * @param array A column of arrays
    * @param f     A function transforming individual elements of the array
    * @return A column of arrays with transformed elements
    */
  def transform(array: Column, f: Column => Column): Column =
    sparkTransform(array, f)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.cobrix.spark.cobol.source.fixtures

import org.scalatest.{Assertion, Suite}

/**
  * A test fixture with helpers for comparing multi-line texts while ignoring line breaks.
  * It must be mixed into a ScalaTest `Suite`.
  */
trait TextComparisonFixture {
  this: Suite =>

  /**
    * Compares two texts ignoring all carriage return and line feed characters.
    * On a mismatch the test fails with a side-by-side rendering of both texts.
    *
    * @param actual   The text produced by the code under test.
    * @param expected The expected text.
    * @return `succeed` when the texts match.
    */
  protected def compareText(actual: String, expected: String): Assertion = {
    if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) {
      fail(renderTextDifference(actual, expected))
    } else {
      succeed
    }
  }

  /**
    * Compares two texts ignoring all carriage return and line feed characters.
    * On a mismatch the test fails with the actual text followed by the expected text.
    *
    * @param actual   The text produced by the code under test.
    * @param expected The expected text.
    */
  protected def compareTextVertical(actual: String, expected: String): Unit = {
    if (actual.replaceAll("[\r\n]", "") != expected.replaceAll("[\r\n]", "")) {
      fail(s"ACTUAL:\n$actual\nEXPECTED: \n$expected")
    }
  }

  /**
    * Renders two texts side by side, one line per row, marking the rows that differ.
    * The original actual text is appended at the end.
    *
    * @param textActual   The text produced by the code under test.
    * @param textExpected The expected text.
    * @return The rendered comparison.
    */
  protected def renderTextDifference(textActual: String, textExpected: String): String = {
    // A literal replacement is used on purpose: with `replaceAll` the replacement "\\n" is
    // interpreted as an escaped letter 'n', not as a line feed.
    val t1 = textActual.replace("\r\n", "\n").split('\n')
    val t2 = textExpected.replace("\r\n", "\n").split('\n')

    val maxLen = Math.max(getMaxStrLen(t1), getMaxStrLen(t2))
    val header = s" ${rightPad("ACTUAL:", maxLen)} ${rightPad("EXPECTED:", maxLen)}\n"

    val stringBuilder = new StringBuilder
    stringBuilder.append(header)

    val linesCount = Math.max(t1.length, t2.length)
    for (i <- 0 until linesCount) {
      // The shorter text is padded with empty lines.
      val a = if (i < t1.length) t1(i) else ""
      val b = if (i < t2.length) t2(i) else ""

      val differs = a != b
      val marker1 = if (differs) ">" else " "
      val marker2 = if (differs) "<" else " "

      stringBuilder.append(s"$marker1${rightPad(a, maxLen)} ${rightPad(b, maxLen)}$marker2\n")
    }

    val footer = s"\nACTUAL:\n$textActual"
    stringBuilder.append(footer)
    stringBuilder.toString()
  }

  /**
    * Returns the length of the longest string, or 0 when there are no strings.
    *
    * @param text The strings to inspect.
    * @return The maximum string length.
    */
  def getMaxStrLen(text: Seq[String]): Int = {
    if (text.isEmpty) {
      0
    } else {
      text.maxBy(_.length).length
    }
  }

  /**
    * Makes a string exactly `length` characters long by appending spaces to a shorter string
    * or by truncating a longer one.
    *
    * @param s      The string to adjust.
    * @param length The target length.
    * @return The padded or truncated string.
    */
  def rightPad(s: String, length: Int): String = {
    if (s.length < length) {
      s + " " * (length - s.length)
    } else if (s.length > length) {
      s.take(length)
    } else {
      s
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@

package za.co.absa.cobrix.spark.cobol.utils

import org.apache.spark.sql.types.{ArrayType, LongType, MetadataBuilder, StringType, StructField, StructType}
import org.apache.spark.sql.types._
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
import org.slf4j.LoggerFactory
import za.co.absa.cobrix.spark.cobol.source.fixtures.BinaryFileFixture
import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase
import za.co.absa.cobrix.spark.cobol.source.fixtures.{BinaryFileFixture, TextComparisonFixture}
import za.co.absa.cobrix.spark.cobol.utils.TestUtils._

import java.nio.charset.StandardCharsets
import scala.collection.immutable
import scala.util.Properties

class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture {
class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixture with TextComparisonFixture {

import spark.implicits._

Expand Down Expand Up @@ -377,7 +377,7 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
assert(dfFlattened.count() == 0)
}

test("Schema with multiple OCCURS should properly determine array sized") {
test("Schema with multiple OCCURS should properly determine array sizes") {
val copyBook: String =
""" 01 RECORD.
| 02 COUNT PIC 9(1).
Expand Down Expand Up @@ -429,6 +429,46 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt
}
}

test("Integral to decimal conversion for complex schema") {
  val expectedSchema =
    """|root
       | |-- COUNT: decimal(1,0) (nullable = true)
       | |-- GROUP: array (nullable = true)
       | | |-- element: struct (containsNull = false)
       | | | |-- INNER_COUNT: decimal(1,0) (nullable = true)
       | | | |-- INNER_GROUP: array (nullable = true)
       | | | | |-- element: struct (containsNull = false)
       | | | | | |-- FIELD: decimal(1,0) (nullable = true)
       |""".stripMargin

  val copyBook: String =
    """ 01 RECORD.
      | 02 COUNT PIC 9(1).
      | 02 GROUP OCCURS 2 TIMES.
      | 03 INNER-COUNT PIC S9(1).
      | 03 INNER-GROUP OCCURS 3 TIMES.
      | 04 FIELD PIC 9.
      |""".stripMargin

  withTempTextFile("fletten", "test", StandardCharsets.UTF_8, "") { filePath =>
    val df = spark.read
      .format("cobol")
      .option("copybook_contents", copyBook)
      .option("pedantic", "true")
      .option("record_format", "D")
      .option("metadata", "extended")
      .load(filePath)

    // `Properties.versionString` has the form "version 2.x.y", so it can never start with "2.".
    // `versionNumberString` contains just the number (for example "2.11.12").
    if (!Properties.versionNumberString.startsWith("2.11.")) {
      // This method only works with Scala 2.12+ and Spark 3.0+
      val actualDf = SparkUtils.covertIntegralToDecimal(df)
      val actualSchema = actualDf.schema.treeString

      compareText(actualSchema, expectedSchema)
    }
  }
}

private def assertSchema(actualSchema: String, expectedSchema: String): Unit = {
if (actualSchema != expectedSchema) {
logger.error(s"EXPECTED:\n$expectedSchema")
Expand Down

0 comments on commit 6f4d166

Please sign in to comment.