diff --git a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala index 91335afe4e6..14e0d4e0970 100644 --- a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala +++ b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala @@ -16,21 +16,22 @@ package org.apache.spark.sql.tests.datagen +import com.fasterxml.jackson.core.{JsonFactoryBuilder, JsonParser, JsonToken} +import com.fasterxml.jackson.core.json.JsonReadFeature import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} import java.time.{Duration, Instant, LocalDate, LocalDateTime} import java.util - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode import scala.util.Random -import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, XXH64} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{approx_count_distinct, avg, coalesce, col, count, lit, stddev, struct, transform, udf, when} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.XORShiftRandom @@ -79,22 +80,28 @@ class RowLocation(val rowNum: Long, val subRows: Array[Int] = null) { * hash. This makes the generated data correlated for all column/child columns. * @param tableNum a unique ID for the table this is a part of. * @param columnNum the location of the column in the data being generated + * @param substringNum the location of the substring column * @param correlatedKeyGroup the correlated key group this column is a part of, if any. */ -case class ColumnLocation(tableNum: Int, columnNum: Int, correlatedKeyGroup: Option[Long] = None) { - def forNextColumn(): ColumnLocation = ColumnLocation(tableNum, columnNum + 1) +case class ColumnLocation(tableNum: Int, + columnNum: Int, + substringNum: Int, + correlatedKeyGroup: Option[Long] = None) { + def forNextColumn(): ColumnLocation = ColumnLocation(tableNum, columnNum + 1, 0) + def forNextSubstring: ColumnLocation = ColumnLocation(tableNum, columnNum, substringNum + 1) /** * Create a new ColumnLocation that is specifically for a given key group */ def forCorrelatedKeyGroup(keyGroup: Long): ColumnLocation = - ColumnLocation(tableNum, columnNum, Some(keyGroup)) + ColumnLocation(tableNum, columnNum, substringNum, Some(keyGroup)) /** * Hash the location into a single long value. */ - lazy val hashLoc: Long = XXH64.hashLong(tableNum, correlatedKeyGroup.getOrElse(columnNum)) + lazy val hashLoc: Long = XXH64.hashLong(tableNum, + correlatedKeyGroup.getOrElse(XXH64.hashLong(columnNum, substringNum))) } /** @@ -115,6 +122,9 @@ case class ColumnConf(columnLoc: ColumnLocation, def forNextColumn(nullable: Boolean): ColumnConf = ColumnConf(columnLoc.forNextColumn(), nullable, numTableRows) + def forNextSubstring: ColumnConf = + ColumnConf(columnLoc.forNextSubstring, nullable = true, numTableRows) + /** * Create a new configuration based on this, but for a given correlated key group. */ @@ -303,6 +313,23 @@ case class VarLengthGeneratorFunction(minLength: Int, maxLength: Int) extends } } +case class StdDevLengthGen(mean: Double, + stdDev: Double, + mapping: LocationToSeedMapping = null) extends + LengthGeneratorFunction { + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): LengthGeneratorFunction = + StdDevLengthGen(mean, stdDev, mapping) + + override def apply(rowLoc: RowLocation): Int = { + val r = DataGen.getRandomFor(rowLoc, mapping) + val g = r.nextGaussian() // g has a mean of 0 and a stddev of 1.0 + val adjusted = mean + (g * stdDev) + // If the range of seed is too small compared to the stddev and mean we will + // end up with an invalid distribution, but they asked for it. + math.max(0, math.round(adjusted).toInt) + } +} + /** * Generate nulls with a given probability. * @param prob 0.0 to 1.0 for how often nulls should appear in the output. @@ -562,11 +589,8 @@ case class DataGenExpr(child: Expression, } } -/** - * Base class for generating a column/sub-column. This holds configuration for the column, - * and handles what is needed to convert it into GeneratorFunction - */ -abstract class DataGen(var conf: ColumnConf, +abstract class CommonDataGen( + var conf: ColumnConf, defaultValueRange: Option[(Any, Any)], var seedMapping: LocationToSeedMapping = FlatDistribution(), var nullMapping: LocationToSeedMapping = FlatDistribution(), @@ -576,26 +600,25 @@ abstract class DataGen(var conf: ColumnConf, protected var valueRange: Option[(Any, Any)] = defaultValueRange /** - * Set a value range for this data gen. + * Set a value range */ - def setValueRange(min: Any, max: Any): DataGen = { + def setValueRange(min: Any, max: Any): CommonDataGen = { valueRange = Some((min, max)) this } /** - * Set a custom GeneratorFunction to use for this column. + * Set a custom GeneratorFunction */ - def setValueGen(f: GeneratorFunction): DataGen = { + def setValueGen(f: GeneratorFunction): CommonDataGen = { userProvidedValueGen = Some(f) this } /** - * Set a NullGeneratorFunction for this column. This will not be used - * if the column is not nullable. + * Set a NullGeneratorFunction */ - def setNullGen(f: NullGeneratorFunction): DataGen = { + def setNullGen(f: NullGeneratorFunction): CommonDataGen = { this.userProvidedNullGen = Some(f) this } @@ -604,12 +627,12 @@ abstract class DataGen(var conf: ColumnConf, * Set the probability of a null appearing in the output. The probability should be * 0.0 to 1.0. */ - def setNullProbability(probability: Double): DataGen = { + def setNullProbability(probability: Double): CommonDataGen = { this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability)) this } - def setNullProbabilityRecursively(probability: Double): DataGen = { + def setNullProbabilityRecursively(probability: Double): CommonDataGen = { this.userProvidedNullGen = Some(NullProbabilityGenerationFunction(probability)) children.foreach { case (_, dataGen) => @@ -621,7 +644,7 @@ abstract class DataGen(var conf: ColumnConf, /** * Set a specific location to seed mapping for the value generation. */ - def setSeedMapping(seedMapping: LocationToSeedMapping): DataGen = { + def setSeedMapping(seedMapping: LocationToSeedMapping): CommonDataGen = { this.seedMapping = seedMapping this } @@ -629,7 +652,7 @@ abstract class DataGen(var conf: ColumnConf, /** * Set a specific location to seed mapping for the null generation. */ - def setNullMapping(nullMapping: LocationToSeedMapping): DataGen = { + def setNullMapping(nullMapping: LocationToSeedMapping): CommonDataGen = { this.nullMapping = nullMapping this } @@ -638,7 +661,7 @@ abstract class DataGen(var conf: ColumnConf, * Set a specific LengthGeneratorFunction to use. This will only be used if * the datatype needs a length. */ - def setLengthGen(lengthGen: LengthGeneratorFunction): DataGen = { + def setLengthGen(lengthGen: LengthGeneratorFunction): CommonDataGen = { this.lengthGen = lengthGen this } @@ -646,25 +669,30 @@ abstract class DataGen(var conf: ColumnConf, /** * Set the length generation to be a fixed length. */ - def setLength(len: Int): DataGen = { + def setLength(len: Int): CommonDataGen = { this.lengthGen = FixedLengthGeneratorFunction(len) this } - def setLength(minLen: Int, maxLen: Int) = { + def setLength(minLen: Int, maxLen: Int): CommonDataGen = { this.lengthGen = VarLengthGeneratorFunction(minLen, maxLen) this } + def setGaussianLength(mean: Double, stdDev: Double): CommonDataGen = { + this.lengthGen = StdDevLengthGen(mean, stdDev) + this + } + /** * Add this column to a specific correlated key group. This should not be * called directly by users. */ def setCorrelatedKeyGroup(keyGroup: Long, - minSeed: Long, maxSeed: Long, - seedMapping: LocationToSeedMapping): DataGen = { + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): CommonDataGen = { conf = conf.forCorrelatedKeyGroup(keyGroup) - .forSeedRange(minSeed, maxSeed) + .forSeedRange(minSeed, maxSeed) this.seedMapping = seedMapping this } @@ -672,7 +700,7 @@ abstract class DataGen(var conf: ColumnConf, /** * Set a range of seed values that should be returned by the LocationToSeedMapping */ - def setSeedRange(min: Long, max: Long): DataGen = { + def setSeedRange(min: Long, max: Long): CommonDataGen = { conf = conf.forSeedRange(min, max) this } @@ -681,7 +709,7 @@ abstract class DataGen(var conf: ColumnConf, * Get the default value generator for this specific data gen. */ protected def getValGen: GeneratorFunction - def children: Seq[(String, DataGen)] + def children: Seq[(String, CommonDataGen)] /** * Get the final ready to use GeneratorFunction for the data generator. @@ -690,8 +718,8 @@ abstract class DataGen(var conf: ColumnConf, val sm = seedMapping.withColumnConf(conf) val lg = lengthGen.withLocationToSeedMapping(sm) var valGen = userProvidedValueGen.getOrElse(getValGen) - .withLocationToSeedMapping(sm) - .withLengthGeneratorFunction(lg) + .withLocationToSeedMapping(sm) + .withLengthGeneratorFunction(lg) valueRange.foreach { case (min, max) => valGen = valGen.withValueRange(min, max) @@ -700,35 +728,75 @@ abstract class DataGen(var conf: ColumnConf, val nullColConf = conf.forNulls val nm = nullMapping.withColumnConf(nullColConf) userProvidedNullGen.get - .withWrapped(valGen) - .withLocationToSeedMapping(nm) + .withWrapped(valGen) + .withLocationToSeedMapping(nm) } else { valGen } } - /** - * Get the data type for this column - */ - def dataType: DataType - /** * Is this column nullable or not. */ def nullable: Boolean = conf.nullable /** - * Get a child column for a given name, if it has one. + * Get a child for a given name, if it has one. */ - final def apply(name: String): DataGen = { + final def apply(name: String): CommonDataGen = { get(name).getOrElse{ throw new IllegalStateException(s"Could not find a child $name for $this") } } - def get(name: String): Option[DataGen] = None + def get(name: String): Option[CommonDataGen] = None +} + + +/** + * Base class for generating a column/sub-column. This holds configuration + * for the column, and handles what is needed to convert it into GeneratorFunction + */ +abstract class DataGen( + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)], + seedMapping: LocationToSeedMapping = FlatDistribution(), + nullMapping: LocationToSeedMapping = FlatDistribution(), + lengthGen: LengthGeneratorFunction = FixedLengthGeneratorFunction(10)) extends + CommonDataGen(conf, defaultValueRange, seedMapping, nullMapping, lengthGen) { + + /** + * Get the data type for this column + */ + def dataType: DataType + + override def get(name: String): Option[DataGen] = None + + def getSubstringGen: Option[SubstringDataGen] = None + + def substringGen: SubstringDataGen = + getSubstringGen.getOrElse( + throw new IllegalArgumentException("substring data gen was not set")) + + def setSubstringGen(f : ColumnConf => SubstringDataGen): Unit = + setSubstringGen(Option(f(conf.forNextSubstring))) + + def setSubstringGen(subgen: Option[SubstringDataGen]): Unit = + throw new IllegalArgumentException("substring data gens can only be set for a STRING") } +/** + * Base class for generating a sub-string. This holds configuration + * for the substring, and handles what is needed to convert it into a GeneratorFunction + */ +abstract class SubstringDataGen( + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)], + seedMapping: LocationToSeedMapping = FlatDistribution(), + nullMapping: LocationToSeedMapping = FlatDistribution(), + lengthGen: LengthGeneratorFunction = FixedLengthGeneratorFunction(10)) extends + CommonDataGen(conf, defaultValueRange, seedMapping, nullMapping, lengthGen) {} + /** * A special GeneratorFunction that just returns the computed seed. This is helpful for * debugging distributions or if you want long values without any abstraction in between. @@ -1494,155 +1562,866 @@ class FloatGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) override def children: Seq[(String, DataGen)] = Seq.empty } -trait JSONType { - def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit -} +case class JsonPathElement(name: String, is_array: Boolean) +case class JsonLevel(path: Array[JsonPathElement], data_type: String, length: Int, value: String) {} + +object JsonColumnStats { + private def printHelp(): Unit = { + println("JSON Fingerprinting Tool:") + println("PARAMS: ") + println(" is a path to a Spark dataframe to read in") + println(" is a path in a Spark file system to write out fingerprint data to.") + println() + println("OPTIONS:") + println(" --json= where is the name of a top level String column") + println(" --anon= where is a SEED used to anonymize the JSON keys ") + println(" and column names.") + println(" --input_format= where is parquet or ORC. Defaults to parquet.") + println(" --overwrite to enable overwriting the fingerprint output.") + println(" --debug to enable some debug information to be printed out") + println(" --help to print out this help message") + println() + } + + def main(args: Array[String]): Unit = { + var inputPath = Option.empty[String] + var outputPath = Option.empty[String] + val jsonColumns = ArrayBuffer.empty[String] + var anonSeed = Option.empty[Long] + var debug = false + var argsDone = false + var format = "parquet" + var overwrite = false + + args.foreach { + case a if !argsDone && a.startsWith("--json=") => + jsonColumns += a.substring("--json=".length) + case a if !argsDone && a.startsWith("--anon=") => + anonSeed = Some(a.substring("--anon=".length).toLong) + case a if !argsDone && a.startsWith("--input_format=") => + format = a.substring("--input_format=".length).toLowerCase(java.util.Locale.US) + case "--overwrite" if !argsDone => + overwrite = true + case "--debug" if !argsDone => + debug = true + case "--help" if !argsDone => + printHelp() + System.exit(0) + case "--" if !argsDone => + argsDone = true + case a if !argsDone && a.startsWith("--") => // "--" was covered above already + println(s"ERROR $a is not a supported argument") + printHelp() + System.exit(-1) + case a if inputPath.isEmpty => + inputPath = Some(a) + case a if outputPath.isEmpty => + outputPath = Some(a) + case a => + println(s"ERROR only two arguments are supported. Found $a") + printHelp() + System.exit(-1) + } + if (outputPath.isEmpty) { + println("ERROR both an inputPath and an outputPath are required") + printHelp() + System.exit(-1) + } + + val spark = SparkSession.builder.getOrCreate() + spark.sparkContext.setLogLevel("WARN") + + val df = spark.read.format(format).load(inputPath.get) + jsonColumns.foreach { column => + val fp = fingerPrint(df, df(column), anonSeed) + val name = anonSeed.map(s => anonymizeString(column, s)).getOrElse(column) + val fullOutPath = s"${outputPath.get}/$name" + var writer = fp.write + if (overwrite) { + writer = writer.mode("overwrite") + } + if (debug) { + anonSeed.foreach { s => + println(s"Keys and columns will be anonymized with seed $s") + } + println(s"Writing $column fingerprint to $fullOutPath") + spark.time(writer.parquet(fullOutPath)) + println(s"Wrote ${spark.read.parquet(fullOutPath).count} rows") + spark.read.parquet(fullOutPath).show() + } else { + writer.parquet(fullOutPath) + } + } + } -object JSONType { - def selectType(depth: Int, - maxDepth: Int, - r: Random): JSONType = { - val toSelectFrom = if (depth < maxDepth) { - Seq(QuotedJSONString, JSONLong, JSONDouble, JSONArray, JSONObject) - } else { - Seq(QuotedJSONString, JSONLong, JSONDouble) - } - val index = r.nextInt(toSelectFrom.length) - toSelectFrom(index) - } -} - -object QuotedJSONString extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - val strValue = r.nextString(r.nextInt(maxStringLength + 1)) - .replace("\\", "\\\\") - .replace("\"", "\\\"") - .replace("\n", "\\n") - .replace("\r", "\\r") - .replace("\b", "\\b") - .replace("\f", "\\f") - sb.append('"') - sb.append(strValue) - sb.append('"') - } -} - -object JSONLong extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - sb.append(r.nextLong()) - } -} - -object JSONDouble extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - sb.append(r.nextDouble() * 4096.0) - } -} - -object JSONArray extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - val childType = JSONType.selectType(depth, maxDepth, r) - val length = r.nextInt(maxArrayLength + 1) - sb.append("[") + case class JsonNodeStats(count: Long, meanLen: Double, stdDevLength: Double, dc: Long) + + class JsonNode() { + private val forDataType = + mutable.HashMap[String, (JsonNodeStats, mutable.HashMap[String, JsonNode])]() + + def getChild(name: String, isArray: Boolean): JsonNode = { + val dt = if (isArray) { "ARRAY" } else { "OBJECT" } + val typed = forDataType.getOrElse(dt, + throw new IllegalArgumentException(s"$dt is not a set data type yet.")) + typed._2.getOrElse(name, + throw new IllegalArgumentException(s"$name is not a child when the type is $dt")) + } + + def contains(name: String, isArray: Boolean): Boolean = { + val dt = if (isArray) { "ARRAY" } else { "OBJECT" } + forDataType.get(dt).exists { children => + children._2.contains(name) + } + } + + def addChild(name: String, isArray: Boolean): JsonNode = { + val dt = if (isArray) { "ARRAY" } else { "OBJECT" } + val found = forDataType.getOrElse(dt, + throw new IllegalArgumentException(s"$dt was not already added as a data type")) + if (found._2.contains(name)) { + throw new IllegalArgumentException(s"$dt already has a child named $name") + } + val node = new JsonNode() + found._2.put(name, node) + node + } + + def addChoice(dt: String, stats: JsonNodeStats): Unit = { + if (forDataType.contains(dt)) { + throw new IllegalArgumentException(s"$dt was already added as a data type") + } + forDataType.put(dt, (stats, new mutable.HashMap[String, JsonNode]())) + } + + override def toString: String = { + forDataType.toString() + } + + def totalCount: Long = { + forDataType.values.map{ case (stats, _) => stats.count}.sum + } + + private def makeNoChoiceGenRecursive(dt: String, + children: mutable.HashMap[String, JsonNode], + cc: ColumnConf): (SubstringDataGen, ColumnConf) = { + var c = cc + val ret = dt match { + case "LONG" => new JSONLongGen(c) + case "DOUBLE" => new JSONDoubleGen(c) + case "BOOLEAN" => new JSONBoolGen(c) + case "NULL" => new JSONNullGen(false, c) + case "VALUE_NULL" => new JSONNullGen(true, c) + case "ERROR" => new JSONErrorGen(c) + case "STRING" => new JSONStringGen(c) + case "ARRAY" => + val child = if (children.isEmpty) { + // A corner case, we will just make it a BOOL column and it will be ignored + val tmp = new JSONBoolGen(c) + c = c.forNextSubstring + tmp + } else { + val tmp = children.values.head.makeGenRecursive(c) + c = tmp._2 + tmp._1 + } + new JSONArrayGen(child, c) + case "OBJECT" => + val childGens = if (children.isEmpty) { + Seq.empty + } else { + children.toSeq.map { + case (k, node) => + val tmp = node.makeGenRecursive(c) + c = tmp._2 + (k, tmp._1) + } + } + new JSONObjectGen(childGens, c) + case other => + throw new IllegalArgumentException(s"$other is not a leaf node type") + } + (ret, c.forNextSubstring) + } + + private def makeGenRecursive(cc: ColumnConf): (SubstringDataGen, ColumnConf) = { + var c = cc + // We are going to recursively walk the tree for all of the values. + if (forDataType.size == 1) { + // We don't need a choice at all. This makes it simpler.. + val (dt, (_, children)) = forDataType.head + makeNoChoiceGenRecursive(dt, children, c) + } else { + val totalSum = forDataType.map(f => f._2._1.count).sum.toDouble + var runningSum = 0L + val allChoices = ArrayBuffer[(Double, String, SubstringDataGen)]() + forDataType.foreach { + case (dt, (stats, children)) => + val tmp = makeNoChoiceGenRecursive(dt, children, c) + c = tmp._2 + runningSum += stats.count + allChoices.append((runningSum/totalSum, dt, tmp._1)) + } + + val ret = new JSONChoiceGen(allChoices.toSeq, c) + (ret, c.forNextSubstring) + } + } + + def makeGen(cc: ColumnConf): SubstringDataGen = { + val (ret, _) = makeGenRecursive(cc) + ret + } + + def setStatsSingle(dg: CommonDataGen, + dt: String, + stats: JsonNodeStats, + nullPct: Double): Unit = { + + val includeLength = dt != "OBJECT" && dt != "BOOLEAN" && dt != "NULL" && dt != "VALUE_NULL" + val includeNullPct = nullPct > 0.0 + if (includeLength) { + dg.setGaussianLength(stats.meanLen, stats.stdDevLength) + } + if (includeNullPct) { + dg.setNullProbability(nullPct) + } + dg.setSeedRange(1, stats.dc) + } + + def setStats(dg: CommonDataGen, + parentCount: Option[Long]): Unit = { + // We are going to recursively walk the tree... + if (forDataType.size == 1) { + // We don't need a choice at all. This makes it simpler.. + val (dt, (stats, children)) = forDataType.head + val nullPct = parentCount.map { pc => + (pc - stats.count).toDouble/pc + }.getOrElse(0.0) + setStatsSingle(dg, dt, stats, nullPct) + val myCount = if (dt == "OBJECT") { + Some(totalCount) + } else { + None + } + children.foreach { + case (name, node) => + node.setStats(dg(name), myCount) + } + } else { + // We have choices to make between different types. + // The null percent cannot be calculated for each individual choice + // but is calculated on the group as a whole instead + parentCount.foreach { pc => + val tc = totalCount + val choiceNullPct = (pc - tc).toDouble / pc + if (choiceNullPct > 0.0) { + dg.setNullProbability(choiceNullPct) + } + } + forDataType.foreach { + case (dt, (stats, children)) => + // When there is a choice the name to access it is the data type + val choiceDg = dg(dt) + setStatsSingle(choiceDg, dt, stats, 0.0) + children.foreach { + case (name, node) => + val myCount = if (dt == "OBJECT") { + // Here we only want the count for the OBJECTs + Some(stats.count) + } else { + None + } + node.setStats(choiceDg(name), myCount) + } + } + } + } + } + + private lazy val jsonFactory = new JsonFactoryBuilder() + // The two options below enabled for Hive compatibility + .enable(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS) + .enable(JsonReadFeature.ALLOW_SINGLE_QUOTES) + .build() + + private def processNext(parser: JsonParser, + currentPath: ArrayBuffer[JsonPathElement], + output: ArrayBuffer[JsonLevel]): Unit = { + parser.currentToken() match { + case JsonToken.START_OBJECT => + parser.nextToken() + while (parser.currentToken() != JsonToken.END_OBJECT) { + processNext(parser, currentPath, output) + } + output.append(JsonLevel(currentPath.toArray, "OBJECT", 0, "")) + parser.nextToken() + case JsonToken.START_ARRAY => + currentPath.append(JsonPathElement("data", is_array = true)) + parser.nextToken() + var length = 0 + while (parser.currentToken() != JsonToken.END_ARRAY) { + length += 1 + processNext(parser, currentPath, output) + } + currentPath.remove(currentPath.length - 1) + output.append(JsonLevel(currentPath.toArray, "ARRAY", length, "")) + parser.nextToken() + case JsonToken.FIELD_NAME => + currentPath.append(JsonPathElement(parser.getCurrentName, is_array = false)) + parser.nextToken() + processNext(parser, currentPath, output) + currentPath.remove(currentPath.length - 1) + case JsonToken.VALUE_NUMBER_INT => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "LONG", length, parser.getValueAsString)) + parser.nextToken() + case JsonToken.VALUE_NUMBER_FLOAT => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "DOUBLE", length, parser.getValueAsString)) + parser.nextToken() + case JsonToken.VALUE_TRUE | JsonToken.VALUE_FALSE => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "BOOLEAN", length, parser.getValueAsString)) + parser.nextToken() + case JsonToken.VALUE_NULL | null => + output.append(JsonLevel(currentPath.toArray, "VALUE_NULL", 4, "NULL")) + parser.nextToken() + case JsonToken.VALUE_STRING => + val length = parser.getValueAsString.getBytes("UTF-8").length + output.append(JsonLevel(currentPath.toArray, "STRING", length, parser.getValueAsString)) + parser.nextToken() + case other => + throw new IllegalStateException(s"DON'T KNOW HOW TO DEAL WITH $other") + } + } + + def jsonStatsUdf(json: String): Array[JsonLevel] = { + val output = new ArrayBuffer[JsonLevel]() + try { + val currentPath = new ArrayBuffer[JsonPathElement]() + if (json == null) { + output.append(JsonLevel(Array.empty, "NULL", 0, "")) + } else { + val parser = jsonFactory.createParser(json) + try { + parser.nextToken() + processNext(parser, currentPath, output) + } finally { + parser.close() + } + } + } catch { + case _: com.fasterxml.jackson.core.JsonParseException => + output.clear() + output.append(JsonLevel(Array.empty, "ERROR", json.getBytes("UTF-8").length, json)) + } + output.toArray + } + + private lazy val extractPaths = udf(json => jsonStatsUdf(json)) + + def anonymizeString(str: String, seed: Long): String = { + val length = str.length + val data = new Array[Byte](length) + val hash = XXH64.hashLong(str.hashCode, seed) + val r = new Random() + r.setSeed(hash) (0 until length).foreach { i => - if (i > 0) { - sb.append(",") + val tmp = r.nextInt(16) + data(i) = (tmp + 'A').toByte + } + new String(data) + } + + private lazy val anonPath = udf((str, seed) => anonymizeString(str, seed)) + + def anonymizeFingerPrint(df: DataFrame, anonSeed: Long): DataFrame = { + df.withColumn("tmp", transform(col("path"), + o => { + val name = o("name") + val isArray = o("is_array") + val anon = anonPath(name, lit(anonSeed)) + val newName = when(isArray, name).otherwise(anon).alias("name") + struct(newName, isArray) + })) + .drop("path").withColumnRenamed("tmp", "path") + .orderBy("path", "dt") + .selectExpr("path", "dt","c","mean_len","stddev_len","distinct","version") + } + + def fingerPrint(df: DataFrame, column: Column, anonymize: Option[Long] = None): DataFrame = { + val ret = df.select(extractPaths(column).alias("paths")) + .selectExpr("explode_outer(paths) as p") + .selectExpr("p.path as path", "p.data_type as dt", "p.length as len", "p.value as value") + .groupBy(col("path"), col("dt")).agg( + count(lit(1)).alias("c"), + avg(col("len")).alias("mean_len"), + coalesce(stddev(col("len")), lit(0.0)).alias("stddev_len"), + approx_count_distinct(col("value")).alias("distinct")) + .orderBy("path", "dt").withColumn("version", lit("0.1")) + .selectExpr("path", "dt","c","mean_len","stddev_len","distinct","version") + + anonymize.map { anonSeed => + anonymizeFingerPrint(ret, anonSeed) + }.getOrElse(ret) + } + + def apply(aggForColumn: DataFrame, genColumn: ColumnGen): Unit = + apply(aggForColumn, genColumn.dataGen) + + private val expectedSchema = StructType.fromDDL( + "path ARRAY>," + + "dt STRING," + + "c BIGINT," + + "mean_len DOUBLE," + + "stddev_len DOUBLE," + + "distinct BIGINT," + + "version STRING") + + def apply(aggForColumn: DataFrame, gen: DataGen): Unit = { + val aggData = aggForColumn.orderBy("path", "dt").collect() + val rootNode: JsonNode = new JsonNode() + assert(aggData.length > 0) + val schema = aggData.head.schema + assert(schema.length == expectedSchema.length) + schema.fields.zip(expectedSchema.fields).foreach { + case(found, expected) => + assert(found.name == expected.name) + // TODO we can worry about the exact types later if we need to + } + assert(aggData.head.getString(6) == "0.1") + aggData.foreach { row => + val fullPath = row.getAs[mutable.WrappedArray[Row]](0) + val parsedPath = fullPath.map(r => (r.getString(0), r.getBoolean(1))).toList + val dt = row.getString(1) + val count = row.getLong(2) + val meanLen = row.getDouble(3) + val stdLen = row.getDouble(4) + val dc = row.getLong(5) + + val stats = JsonNodeStats(count, meanLen, stdLen, dc) + var currentNode = rootNode + // Find everything up to the last path element + if (parsedPath.length > 1) { + parsedPath.slice(0, parsedPath.length - 1).foreach { + case (name, isArray) => + currentNode = currentNode.getChild(name, isArray) + } + } + + if (parsedPath.nonEmpty) { + // For the last path element (that is not the root element) we might need to add it + // as a child + val (name, isArray) = parsedPath.last + if (!currentNode.contains(name, isArray)) { + currentNode.addChild(name, isArray) + } + currentNode = currentNode.getChild(name, isArray) } - childType.appendRandomValue(sb, i, maxStringLength, maxArrayLength, maxObjectLength, - depth + 1, maxDepth, r) + currentNode.addChoice(dt, stats) } - sb.append("]") + + gen.setSubstringGen(cc => rootNode.makeGen(cc)) + rootNode.setStats(gen.substringGen, None) } } -object JSONObject extends JSONType { - override def appendRandomValue(sb: StringBuilder, - index: Int, - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - depth: Int, - maxDepth: Int, - r: Random): Unit = { - val length = r.nextInt(maxObjectLength) + 1 - sb.append("{") - (0 until length).foreach { i => - if (i > 0) { - sb.append(",") + +case class JSONStringGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val r = DataGen.getRandomFor(rowLoc, mapping) + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + // Value range is 32 (Space) to 126 (~) + buffer(at) = (r.nextInt(126 - 31) + 32).toByte + at += 1 + } + val strVal = new String(buffer, 0, len) + .replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\b", "\\b") + .replace("\f", "\\f") + '"' + strVal + '"' + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONStringGenFunc = + JSONStringGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONStringGenFunc = + JSONStringGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONStringGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONStringGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONLongGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = math.max(lengthGen(rowLoc), 1) // We need at least 1 long for a valid value + val r = DataGen.getRandomFor(rowLoc, mapping) + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + if (at == 0) { + // No leading 0's + buffer(at) = (r.nextInt(9) + '1').toByte + } else { + buffer(at) = (r.nextInt(10) + '0').toByte } - sb.append("\"key_") - sb.append(i) - sb.append("_") - sb.append(depth ) - sb.append("\":") - val childType = JSONType.selectType(depth, maxDepth, r) - childType.appendRandomValue(sb, i, maxStringLength, maxArrayLength, maxObjectLength, - depth + 1, maxDepth, r) + at += 1 } - sb.append("}") + new String(buffer, 0, len) } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONLongGenFunc = + JSONLongGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONLongGenFunc = + JSONLongGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") } -case class JSONGenFunc( - maxStringLength: Int, - maxArrayLength: Int, - maxObjectLength: Int, - maxDepth: Int, - lengthGen: LengthGeneratorFunction = null, - mapping: LocationToSeedMapping = null) extends GeneratorFunction { +class JSONLongGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONLongGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONDoubleGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { override def apply(rowLoc: RowLocation): Any = { + val len = math.max(lengthGen(rowLoc), 3) // We have to have at least 3 chars NUM.NUM val r = DataGen.getRandomFor(rowLoc, mapping) - val sb = new StringBuilder() - JSONObject.appendRandomValue(sb, 0, maxStringLength, maxArrayLength, maxObjectLength, - 0, maxDepth, r) - // For now I am going to have some hard coded keys - UTF8String.fromString(sb.toString()) + val beforeLen = if (len == 3) { 1 } else { r.nextInt(len - 3) + 1 } + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + if (at == 0) { + // No leading 0's + buffer(at) = (r.nextInt(9) + '1').toByte + } else if (at == beforeLen) { + buffer(at) = '.' + } else { + buffer(at) = (r.nextInt(10) + '0').toByte + } + at += 1 + } + UTF8String.fromBytes(buffer, 0, len) } - override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = - JSONGenFunc(maxStringLength, maxArrayLength, maxObjectLength, maxDepth, lengthGen, mapping) + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONDoubleGenFunc = + JSONDoubleGenFunc(lengthGen, mapping) - override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = - JSONGenFunc(maxStringLength, maxArrayLength, maxObjectLength, maxDepth, lengthGen, mapping) + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONDoubleGenFunc = + JSONDoubleGenFunc(lengthGen, mapping) override def withValueRange(min: Any, max: Any): GeneratorFunction = - throw new IllegalArgumentException("value ranges are not supported for strings") + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONDoubleGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONDoubleGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONBoolGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val r = DataGen.getRandomFor(rowLoc, mapping) + val ret = if (r.nextBoolean()) "true" else "false" + UTF8String.fromString(ret) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONBoolGenFunc = + JSONBoolGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONBoolGenFunc = + JSONBoolGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONBoolGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONBoolGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONNullGenFunc(nullAsString: Boolean, + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = + if (nullAsString) { + UTF8String.fromString("null") + } else { + null + } + + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONNullGenFunc = + JSONNullGenFunc(nullAsString, lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONNullGenFunc = + JSONNullGenFunc(nullAsString, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONNullGen(nullAsString: Boolean, + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONNullGenFunc(nullAsString) + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONErrorGenFunc(lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val r = DataGen.getRandomFor(rowLoc, mapping) + val buffer = new Array[Byte](len) + var at = 0 + while (at < len) { + // Value range is 32 (Space) to 126 (~) + // But it is almost impossible to show up as valid JSON + buffer(at) = (r.nextInt(126 - 31) + 32).toByte + at += 1 + } + UTF8String.fromBytes(buffer, 0, len) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONErrorGenFunc = + JSONErrorGenFunc(lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONErrorGenFunc = + JSONErrorGenFunc(lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONErrorGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override protected def getValGen: GeneratorFunction = JSONErrorGenFunc() + + override def children: Seq[(String, SubstringDataGen)] = Seq.empty +} + +case class JSONArrayGenFunc(child: GeneratorFunction, + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + val len = lengthGen(rowLoc) + val data = new Array[String](len) + val childRowLoc = rowLoc.withNewChild() + var i = 0 + while (i < len) { + childRowLoc.setLastChildIndex(i) + val v = child(childRowLoc) + if (v == null) { + // A null in an array must look like "null" + data(i) = "null" + } else { + data(i) = v.toString + } + i += 1 + } + val ret = data.mkString("[", ",", "]") + UTF8String.fromString(ret) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONArrayGenFunc = + JSONArrayGenFunc(child, lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONArrayGenFunc = + JSONArrayGenFunc(child, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONArrayGen(child: SubstringDataGen, + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): SubstringDataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + child.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + this + } + + override protected def getValGen: GeneratorFunction = JSONArrayGenFunc(child.getGen) + + override def get(name: String): Option[SubstringDataGen] = { + if ("data".equalsIgnoreCase(name) || "child".equalsIgnoreCase(name)) { + Some(child) + } else { + None + } + } + + override def children: Seq[(String, SubstringDataGen)] = Seq(("data", child)) +} + +case class JSONObjectGenFunc(childGens: Array[(String, GeneratorFunction)], + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + override def apply(rowLoc: RowLocation): Any = { + // TODO randomize the order of the children??? + // TODO duplicate child values??? + // The row location does not change for a struct/object + val data = childGens.map { + case (k, gen) => + val key = k.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\b", "\\b") + .replace("\f", "\\f") + val v = gen.apply(rowLoc) + if (v == null) { + "" + } else { + '"' + key + "\":" + v + } + } + val ret = data.filterNot(_.isEmpty).mkString("{",",","}") + UTF8String.fromString(ret) + } + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONObjectGenFunc = + JSONObjectGenFunc(childGens, lengthGen, mapping) + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONObjectGenFunc = + JSONObjectGenFunc(childGens, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONObjectGen(val children: Seq[(String, SubstringDataGen)], + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): SubstringDataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + children.foreach { + case (_, gen) => + gen.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + } + this + } + + override def get(name: String): Option[SubstringDataGen] = + children.collectFirst { + case (childName, dataGen) if childName.equalsIgnoreCase(name) => dataGen + } + + override protected def getValGen: GeneratorFunction = { + val childGens = children.map(c => (c._1, c._2.getGen)).toArray + JSONObjectGenFunc(childGens) + } +} + +case class JSONChoiceGenFunc(choices: List[(Double, GeneratorFunction)], + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + override def apply(rowLoc: RowLocation): Any = { + val r = DataGen.getRandomFor(rowLoc, mapping) + val l = r.nextDouble() + var index = 0 + while (choices(index)._1 < l) { + index += 1 + } + val childRowLoc = rowLoc.withNewChild() + choices(index)._2(childRowLoc) + } + + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): JSONChoiceGenFunc = + JSONChoiceGenFunc(choices, lengthGen, mapping) + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): JSONChoiceGenFunc = + JSONChoiceGenFunc(choices, lengthGen, mapping) + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for JSON") +} + +class JSONChoiceGen(val choices: Seq[(Double, String, SubstringDataGen)], + conf: ColumnConf, + defaultValueRange: Option[(Any, Any)] = None) + extends SubstringDataGen(conf, defaultValueRange) { + + override val children: Seq[(String, SubstringDataGen)] = + choices.map { case (_, name, gen) => (name, gen) } + + override def setCorrelatedKeyGroup(keyGroup: Long, + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): SubstringDataGen = { + super.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + children.foreach { + case (_, gen) => + gen.setCorrelatedKeyGroup(keyGroup, minSeed, maxSeed, seedMapping) + } + this + } + + override def get(name: String): Option[SubstringDataGen] = + children.collectFirst { + case (childName, dataGen) if childName.equalsIgnoreCase(name) => dataGen + } + + override protected def getValGen: GeneratorFunction = { + val childGens = choices.map(c => (c._1, c._3.getGen)).toList + JSONChoiceGenFunc(childGens) + } } case class ASCIIGenFunc( @@ -1672,14 +2451,46 @@ case class ASCIIGenFunc( throw new IllegalArgumentException("value ranges are not supported for strings") } -class StringGen(conf: ColumnConf, defaultValueRange: Option[(Any, Any)]) - extends DataGen(conf, defaultValueRange) { +/** + * This is here to wrap the substring gen function so that its length/settings + * are the ones used when generating a string, and not what was set for the string. + */ +case class SubstringGenFunc( + substringGen: GeneratorFunction, + lengthGen: LengthGeneratorFunction = null, + mapping: LocationToSeedMapping = null) extends GeneratorFunction { + + override def apply(rowLoc: RowLocation): Any = { + substringGen(rowLoc) + } + + // The length and location seed mapping are just ignored for this... + override def withLengthGeneratorFunction(lengthGen: LengthGeneratorFunction): GeneratorFunction = + this + + override def withLocationToSeedMapping(mapping: LocationToSeedMapping): GeneratorFunction = + this + + override def withValueRange(min: Any, max: Any): GeneratorFunction = + throw new IllegalArgumentException("value ranges are not supported for strings") +} + +class StringGen(conf: ColumnConf, + defaultValueRange: Option[(Any, Any)], + var substringDataGen: Option[SubstringDataGen] = None) + extends DataGen(conf, defaultValueRange) { override def dataType: DataType = StringType - override protected def getValGen: GeneratorFunction = ASCIIGenFunc() + override protected def getValGen: GeneratorFunction = + substringDataGen.map(s => SubstringGenFunc(s.getGen)).getOrElse(ASCIIGenFunc()) override def children: Seq[(String, DataGen)] = Seq.empty + + override def setSubstringGen(subgen: Option[SubstringDataGen]): Unit = + substringDataGen = subgen + + override def getSubstringGen: Option[SubstringDataGen] = substringDataGen } case class StructGenFunc(childGens: Array[GeneratorFunction]) extends GeneratorFunction { @@ -1854,7 +2665,6 @@ class MapGen(key: DataGen, override def children: Seq[(String, DataGen)] = Seq(("key", key), ("value", value)) } - object ColumnGen { private def genInternal(rowNumber: Column, dataType: DataType, @@ -1869,8 +2679,8 @@ object ColumnGen { */ class ColumnGen(val dataGen: DataGen) { def setCorrelatedKeyGroup(kg: Long, - minSeed: Long, maxSeed: Long, - seedMapping: LocationToSeedMapping): ColumnGen = { + minSeed: Long, maxSeed: Long, + seedMapping: LocationToSeedMapping): ColumnGen = { dataGen.setCorrelatedKeyGroup(kg, minSeed, maxSeed, seedMapping) this } @@ -1930,6 +2740,11 @@ class ColumnGen(val dataGen: DataGen) { this } + def setGaussianLength(mean: Double, stdDev: Double): ColumnGen = { + dataGen.setGaussianLength(mean, stdDev) + this + } + final def apply(name: String): DataGen = { get(name).getOrElse { throw new IllegalArgumentException(s"$name not a child of $this") @@ -1941,8 +2756,16 @@ class ColumnGen(val dataGen: DataGen) { def gen(rowNumber: Column): Column = { ColumnGen.genInternal(rowNumber, dataGen.dataType, dataGen.nullable, dataGen.getGen) } + + def getSubstring: Option[SubstringDataGen] = dataGen.getSubstringGen + + def substringGen: SubstringDataGen = dataGen.substringGen + + def setSubstringGen(f : ColumnConf => SubstringDataGen): Unit = + dataGen.setSubstringGen(f) } + sealed trait KeyGroupType /** @@ -2192,7 +3015,7 @@ object DBGen { numRows: Long, mapping: OrderedTypeMapping): Seq[(String, ColumnGen)] = { // a bit of a hack with the column num so that we update it before each time... - var conf = ColumnConf(ColumnLocation(tableId, -1), true, numRows) + var conf = ColumnConf(ColumnLocation(tableId, -1, 0), true, numRows) st.toArray.map { sf => if (!mapping.canMap(sf.dataType, mapping)) { throw new IllegalArgumentException(s"$sf is not supported at this time")