Skip to content

Commit

Permalink
#740 Make sure unsigned binary fields can fit data types.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Jan 22, 2025
1 parent 6d48657 commit 888d212
Show file tree
Hide file tree
Showing 12 changed files with 174 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ object BinaryNumberDecoders {
if (v<0) null else v
}

/**
  * Decodes a 4-byte big-endian unsigned binary integer into a `Long`.
  *
  * The result is widened to 64 bits, so every unsigned 32-bit value
  * (0 to 4294967295) can be represented.
  *
  * @param bytes the raw field bytes; only the first 4 bytes are read
  * @return the decoded value, or `null` if fewer than 4 bytes are supplied
  */
def decodeBinaryUnsignedIntBigEndianAsLong(bytes: Array[Byte]): java.lang.Long = {
  if (bytes.length < 4) {
    null
  } else {
    // Each byte is masked into a non-negative Long before shifting, so the combined
    // value always lies in [0, 2^32 - 1] and can never be negative.
    val v: Long = ((bytes(0) & 255L) << 24) | ((bytes(1) & 255L) << 16) | ((bytes(2) & 255L) << 8) | (bytes(3) & 255L)
    v
  }
}

def decodeBinaryUnsignedIntLittleEndian(bytes: Array[Byte]): Integer = {
if (bytes.length < 4) {
return null
Expand All @@ -90,6 +98,14 @@ object BinaryNumberDecoders {
if (v<0) null else v
}

/**
  * Decodes a 4-byte little-endian unsigned binary integer into a `Long`.
  *
  * The result is widened to 64 bits, so every unsigned 32-bit value
  * (0 to 4294967295) can be represented.
  *
  * @param bytes the raw field bytes; only the first 4 bytes are read
  * @return the decoded value, or `null` if fewer than 4 bytes are supplied
  */
def decodeBinaryUnsignedIntLittleEndianAsLong(bytes: Array[Byte]): java.lang.Long = {
  if (bytes.length < 4) {
    null
  } else {
    // Each byte is masked into a non-negative Long before shifting, so the combined
    // value always lies in [0, 2^32 - 1] and can never be negative.
    val v: Long = ((bytes(3) & 255L) << 24) | ((bytes(2) & 255L) << 16) | ((bytes(1) & 255L) << 8) | (bytes(0) & 255L)
    v
  }
}

def decodeBinarySignedLongBigEndian(bytes: Array[Byte]): java.lang.Long = {
if (bytes.length < 8) {
return null
Expand All @@ -112,6 +128,13 @@ object BinaryNumberDecoders {
if (v < 0L) null else v
}

/**
  * Decodes an 8-byte big-endian unsigned binary integer into a `BigDecimal`.
  *
  * A decimal is used because the upper half of the unsigned 64-bit range does not fit a signed `Long`.
  *
  * @param bytes the raw field bytes; only the first 8 bytes are read
  * @return the decoded non-negative value, or `null` if fewer than 8 bytes are supplied
  */
def decodeBinaryUnsignedLongBigEndianAsDecimal(bytes: Array[Byte]): BigDecimal = {
  if (bytes.length < 8) {
    null
  } else {
    // Limit the magnitude to the 8-byte field width so a longer buffer cannot produce a wider number.
    // Signum 1 makes BigInt treat the bytes as an unsigned big-endian magnitude.
    BigDecimal(BigInt(1, bytes.take(8)))
  }
}

def decodeBinaryUnsignedLongLittleEndian(bytes: Array[Byte]): java.lang.Long = {
if (bytes.length < 8) {
return null
Expand All @@ -120,6 +143,13 @@ object BinaryNumberDecoders {
if (v < 0L) null else v
}

/**
  * Decodes an 8-byte little-endian unsigned binary integer into a `BigDecimal`.
  *
  * A decimal is used because the upper half of the unsigned 64-bit range does not fit a signed `Long`.
  *
  * @param bytes the raw field bytes; only the first 8 bytes are read
  * @return the decoded non-negative value, or `null` if fewer than 8 bytes are supplied
  */
def decodeBinaryUnsignedLongLittleEndianAsDecimal(bytes: Array[Byte]): BigDecimal = {
  if (bytes.length < 8) {
    null
  } else {
    // Truncate to the 8-byte field width BEFORE reversing; otherwise any extra trailing bytes
    // would end up as the most significant bytes of the magnitude.
    BigDecimal(BigInt(1, bytes.take(8).reverse))
  }
}

def decodeBinaryAribtraryPrecision(bytes: Array[Byte], isBigEndian: Boolean, isSigned: Boolean): BigDecimal = {
if (bytes.length == 0) {
return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package za.co.absa.cobrix.cobol.parser.decoders
import java.nio.charset.{Charset, StandardCharsets}
import za.co.absa.cobrix.cobol.parser.ast.datatype._
import za.co.absa.cobrix.cobol.parser.common.Constants
import za.co.absa.cobrix.cobol.parser.common.Constants.{maxIntegerPrecision, maxLongPrecision}
import za.co.absa.cobrix.cobol.parser.decoders.FloatingPointFormat.FloatingPointFormat
import za.co.absa.cobrix.cobol.parser.encoding._
import za.co.absa.cobrix.cobol.parser.encoding.codepage.{CodePage, CodePageCommon}
Expand Down Expand Up @@ -255,26 +256,32 @@ object DecoderSelector {
val isSigned = signPosition.nonEmpty

val numOfBytes = BinaryUtils.getBytesCount(compact, precision, isSigned, isExplicitDecimalPt = false, isSignSeparate = false)
val isMaxUnsignedPrecision = precision == maxIntegerPrecision || precision == maxLongPrecision

val decoder = if (strictIntegralPrecision) {
(a: Array[Byte]) => BinaryNumberDecoders.decodeBinaryAribtraryPrecision(a, isBigEndian, isSigned)
} else {
(isSigned, isBigEndian, numOfBytes) match {
case (true, true, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, true, 2) => BinaryNumberDecoders.decodeBinarySignedShortBigEndian _
case (true, true, 4) => BinaryNumberDecoders.decodeBinarySignedIntBigEndian _
case (true, true, 8) => BinaryNumberDecoders.decodeBinarySignedLongBigEndian _
case (true, false, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, false, 2) => BinaryNumberDecoders.decodeBinarySignedShortLittleEndian _
case (true, false, 4) => BinaryNumberDecoders.decodeBinarySignedIntLittleEndian _
case (true, false, 8) => BinaryNumberDecoders.decodeBinarySignedLongLittleEndian _
case (false, true, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, true, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortBigEndian _
case (false, true, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntBigEndian _
case (false, true, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongBigEndian _
case (false, false, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, false, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortLittleEndian _
case (false, false, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntLittleEndian _
case (false, false, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongLittleEndian _
(isSigned, isBigEndian, isMaxUnsignedPrecision, numOfBytes) match {
case (true, true, _, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, true, _, 2) => BinaryNumberDecoders.decodeBinarySignedShortBigEndian _
case (true, true, _, 4) => BinaryNumberDecoders.decodeBinarySignedIntBigEndian _
case (true, true, _, 8) => BinaryNumberDecoders.decodeBinarySignedLongBigEndian _
case (true, false, _, 1) => BinaryNumberDecoders.decodeSignedByte _
case (true, false, _, 2) => BinaryNumberDecoders.decodeBinarySignedShortLittleEndian _
case (true, false, _, 4) => BinaryNumberDecoders.decodeBinarySignedIntLittleEndian _
case (true, false, _, 8) => BinaryNumberDecoders.decodeBinarySignedLongLittleEndian _
case (false, true, _, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, true, _, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortBigEndian _
case (false, true, false, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntBigEndian _
case (false, true, true, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntBigEndianAsLong _
case (false, true, false, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongBigEndian _
case (false, true, true, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongBigEndianAsDecimal _
case (false, false, _, 1) => BinaryNumberDecoders.decodeUnsignedByte _
case (false, false, _, 2) => BinaryNumberDecoders.decodeBinaryUnsignedShortLittleEndian _
case (false, false, false, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntLittleEndian _
case (false, false, true, 4) => BinaryNumberDecoders.decodeBinaryUnsignedIntLittleEndianAsLong _
case (false, false, false, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongLittleEndian _
case (false, false, true, 8) => BinaryNumberDecoders.decodeBinaryUnsignedLongLittleEndianAsDecimal _
case _ =>
(a: Array[Byte]) => BinaryNumberDecoders.decodeBinaryAribtraryPrecision(a, isBigEndian, isSigned)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,14 @@ class BinaryDecoderSpec extends AnyFunSuite {
val decoderUnsignedShort = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 3, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderSignedInt = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 8, compact = Some(COMP4())), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntBe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 8, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntBeAsLong = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 9, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntLe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 8, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedIntLeAsLong = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 9, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderSignedLong = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 15, compact = Some(COMP4())), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongBe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 15, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongBeAsBig = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 18, compact = Some(COMP5()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongLe = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 15, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)
val decoderUnsignedLongLeAsBig = DecoderSelector.getIntegralDecoder(integralType.copy(precision = 18, compact = Some(COMP9()), signPosition = None), strictSignOverpunch = false, improvedNullDetection = false, strictIntegralPrecision = false)

val num1 = decoderSignedByte(Array(0x10).map(_.toByte))
assert(num1.isInstanceOf[Integer])
Expand Down Expand Up @@ -501,10 +505,18 @@ class BinaryDecoderSpec extends AnyFunSuite {
assert(num9.isInstanceOf[Integer])
assert(num9.asInstanceOf[Integer] == 9437184)

val num9a = decoderUnsignedIntBeAsLong(Array(0x00, 0x90, 0x00, 0x00).map(_.toByte))
assert(num9a.isInstanceOf[java.lang.Long])
assert(num9a.asInstanceOf[java.lang.Long] == 9437184L)

val num10 = decoderUnsignedIntLe(Array(0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num10.isInstanceOf[Integer])
assert(num10.asInstanceOf[Integer] == 9437184)

val num10a = decoderUnsignedIntLeAsLong(Array(0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num10a.isInstanceOf[java.lang.Long])
assert(num10a.asInstanceOf[java.lang.Long] == 9437184L)

val num11 = decoderSignedLong(Array(0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00).map(_.toByte))
assert(num11.isInstanceOf[Long])
assert(num11.asInstanceOf[Long] == 72057594037927936L)
Expand All @@ -517,9 +529,17 @@ class BinaryDecoderSpec extends AnyFunSuite {
assert(num13.isInstanceOf[Long])
assert(num13.asInstanceOf[Long] == 40532396646334464L)

val num13a = decoderUnsignedLongBeAsBig(Array(0x00, 0x90, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00).map(_.toByte))
assert(num13a.isInstanceOf[BigDecimal])
assert(num13a.asInstanceOf[BigDecimal] == BigDecimal("40532396646334464"))

val num14 = decoderUnsignedLongLe(Array(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num14.isInstanceOf[Long])
assert(num14.asInstanceOf[Long] == 40532396646334464L)

val num14a = decoderUnsignedLongLeAsBig(Array(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x90, 0x00).map(_.toByte))
assert(num14a.isInstanceOf[BigDecimal])
assert(num14a.asInstanceOf[BigDecimal] == BigDecimal("40532396646334464"))
}

test("Test Binary strict integral precision numbers") {
Expand Down
2 changes: 1 addition & 1 deletion data/test17_expected/test17a_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
} ]
Expand Down
2 changes: 1 addition & 1 deletion data/test17_expected/test17b_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
} ]
Expand Down
2 changes: 1 addition & 1 deletion data/test17_expected/test17c_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
2 changes: 1 addition & 1 deletion data/test18 special_char_expected/test18a_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
}
}, {
"name" : "TAXPAYER",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
} ]
Expand Down
4 changes: 2 additions & 2 deletions data/test24_expected/test24_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@
}
}, {
"name" : "NUM_BIN_INT07",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand Down Expand Up @@ -760,7 +760,7 @@
}
}, {
"name" : "NUM_BIN_INT11",
"type" : "long",
"type" : "decimal(20,0)",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
4 changes: 2 additions & 2 deletions data/test24_expected/test24b_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@
}
}, {
"name" : "NUM_BIN_INT07",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand Down Expand Up @@ -760,7 +760,7 @@
}
}, {
"name" : "NUM_BIN_INT11",
"type" : "long",
"type" : "decimal(20,0)",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
4 changes: 2 additions & 2 deletions data/test6_expected/test6_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
"metadata" : { }
}, {
"name" : "NUM_BIN_INT07",
"type" : "integer",
"type" : "long",
"nullable" : true,
"metadata" : { }
}, {
Expand All @@ -319,7 +319,7 @@
"metadata" : { }
}, {
"name" : "NUM_BIN_INT11",
"type" : "long",
"type" : "decimal(20,0)",
"nullable" : true,
"metadata" : { }
}, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.sql.types._
import za.co.absa.cobrix.cobol.internal.Logging
import za.co.absa.cobrix.cobol.parser.Copybook
import za.co.absa.cobrix.cobol.parser.ast._
import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, COMP1, COMP2, Decimal, Integral}
import za.co.absa.cobrix.cobol.parser.ast.datatype.{AlphaNumeric, COMP1, COMP2, COMP4, COMP5, COMP9, Decimal, Integral}
import za.co.absa.cobrix.cobol.parser.common.Constants
import za.co.absa.cobrix.cobol.parser.encoding.RAW
import za.co.absa.cobrix.cobol.parser.policies.MetadataPolicy
Expand Down Expand Up @@ -66,23 +66,10 @@ class CobolSchema(copybook: Copybook,
@throws(classOf[IllegalStateException])
private[this] lazy val sparkSchema = createSparkSchema()

/** Flat Spark schema built lazily by flattening every top-level child of the copybook AST. */
@throws(classOf[IllegalStateException])
private[this] lazy val sparkFlatSchema = {
  val topLevelStatements = copybook.ast.children.toArray
  // Each top-level statement is treated as a group; its name becomes the field-name prefix.
  val flatFields = topLevelStatements.flatMap { statement =>
    val rootGroup = statement.asInstanceOf[Group]
    parseGroupFlat(rootGroup, s"${statement.name}_")
  }
  StructType(flatFields)
}

/** Returns the lazily created Spark schema held in `sparkSchema`. */
def getSparkSchema: StructType = sparkSchema

/** Returns the lazily created flat Spark schema held in `sparkFlatSchema`. */
def getSparkFlatSchema: StructType = sparkFlatSchema

@throws(classOf[IllegalStateException])
private def createSparkSchema(): StructType = {
val records = for (record <- copybook.getRootRecords) yield {
Expand Down Expand Up @@ -200,12 +187,16 @@ class CobolSchema(copybook: Copybook,
case dt: Integral if strictIntegralPrecision =>
DecimalType(precision = dt.precision, scale = 0)
// Maps a non-strict integral type to a Spark type based on its precision; unsigned binary fields
// at exactly the maximum int/long precision are widened to the next larger type.
case dt: Integral =>
// True when the field uses one of the binary usages COMP-4, COMP-5 or COMP-9.
val isBinary = dt.compact.exists(c => c == COMP4() || c == COMP5() || c == COMP9())
if (dt.precision > Constants.maxLongPrecision) {
DecimalType(precision = dt.precision, scale = 0)
} else if (dt.precision == Constants.maxLongPrecision && isBinary && dt.signPosition.isEmpty) { // promoting unsigned long to decimal(precision + 2) to be able to fit any value
DecimalType(precision = dt.precision + 2, scale = 0)
} else if (dt.precision > Constants.maxIntegerPrecision) {
LongType
}
else {
} else if (dt.precision == Constants.maxIntegerPrecision && isBinary && dt.signPosition.isEmpty) { // promoting unsigned int to long to be able to fit any value
LongType
} else {
IntegerType
}
case _ => throw new IllegalStateException("Unknown AST object")
Expand Down Expand Up @@ -290,53 +281,6 @@ class CobolSchema(copybook: Copybook,
})
childSegments
}

/**
  * Flattens a group into a list of Spark fields, one entry per non-filler primitive.
  *
  * Nested groups are traversed recursively. Field names are built by concatenating `structPath`,
  * the names of nested groups (with a 1-based index for array groups) and the primitive's name.
  *
  * @param group      the group whose children are flattened
  * @param structPath the name prefix accumulated from the enclosing groups
  * @return the flattened fields in declaration order
  * @throws IllegalStateException if a primitive has a data type that is not Decimal, AlphaNumeric or Integral
  */
@throws(classOf[IllegalStateException])
private def parseGroupFlat(group: Group, structPath: String = ""): ArrayBuffer[StructField] = {
  val fields = new ArrayBuffer[StructField]()
  for (field <- group.children if !field.isFiller) {
    field match {
      case group: Group =>
        if (group.isArray) {
          // Array groups are unrolled: one set of fields per element, indexed from 1.
          for (i <- Range(1, group.arrayMaxSize + 1)) {
            val path = s"$structPath${group.name}_${i}_"
            fields ++= parseGroupFlat(group, path)
          }
        } else {
          val path = s"$structPath${group.name}_"
          fields ++= parseGroupFlat(group, path)
        }
      case s: Primitive =>
        val dataType: DataType = s.dataType match {
          case d: Decimal =>
            DecimalType(d.getEffectivePrecision, d.getEffectiveScale)
          case a: AlphaNumeric =>
            // RAW-encoded alphanumeric fields are exposed as binary, everything else as string.
            a.enc match {
              case Some(RAW) => BinaryType
              case _ => StringType
            }
          case dt: Integral =>
            if (dt.precision > Constants.maxIntegerPrecision) {
              LongType
            }
            else {
              IntegerType
            }
          case _ => throw new IllegalStateException("Unknown AST object")
        }
        val path = s"$structPath" //${group.name}_"
        if (s.isArray) {
          // NOTE(review): each unrolled element is still typed as ArrayType(dataType); confirm this is intended.
          for (i <- Range(1, s.arrayMaxSize + 1)) {
            // The field name must interpolate the primitive's name; the previous text lacked the '$'
            // before '{s.name}' and emitted that text literally.
            fields += StructField(s"$path${s.name}_$i", ArrayType(dataType), nullable = true)
          }
        } else {
          fields += StructField(s"$path${s.name}", dataType, nullable = true)
        }
    }
  }

  fields
}
}

object CobolSchema {
Expand Down
Loading

0 comments on commit 888d212

Please sign in to comment.