From b7e07e340f93440f6337f653a4a35497401b4819 Mon Sep 17 00:00:00 2001 From: Kevin Moore Date: Tue, 21 Jan 2020 14:30:33 -0800 Subject: [PATCH] Add support for ignoring text that looks like IDs in SmartTextMapVectorizer (#455) --- .../OPCollectionHashingVectorizer.scala | 25 ++-- .../impl/feature/SmartTextMapVectorizer.scala | 131 ++++++++++++++---- .../impl/feature/SmartTextVectorizer.scala | 4 +- .../feature/SmartTextMapVectorizerTest.scala | 113 ++++++++++++++- .../salesforce/op/features/types/Lists.scala | 9 ++ .../features/types/FeatureTypeValueTest.scala | 10 +- 6 files changed, 249 insertions(+), 43 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OPCollectionHashingVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OPCollectionHashingVectorizer.scala index c646db6d06..c3c4588c1d 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OPCollectionHashingVectorizer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OPCollectionHashingVectorizer.scala @@ -314,42 +314,49 @@ private[op] trait MapHashingFun extends HashingFun { protected def makeVectorColumnMetadata ( - features: Array[TransientFeature], + hashFeatures: Array[TransientFeature], + ignoreFeatures: Array[TransientFeature], params: HashingFunctionParams, - allKeys: Seq[Seq[String]], + hashKeys: Seq[Seq[String]], + ignoreKeys: Seq[Seq[String]], shouldTrackNulls: Boolean, shouldTrackLen: Boolean ): Array[OpVectorColumnMetadata] = { val numHashes = params.numFeatures - val numFeatures = allKeys.map(_.length).sum + val numFeatures = hashKeys.map(_.length).sum val hashColumns = if (isSharedHashSpace(params, Some(numFeatures))) { (0 until numHashes).map { i => OpVectorColumnMetadata( - parentFeatureName = features.map(_.name), - parentFeatureType = features.map(_.typeName), + parentFeatureName = hashFeatures.map(_.name), + parentFeatureType = hashFeatures.map(_.typeName), grouping = None, indicatorValue = None ) }.toArray } else { for { - (keys, f) <- allKeys.toArray.zip(features) + // Need to filter out empty key sequences since the hashFeatures only contain a map feature if one of their + // keys is to be hashed, but hashKeys contains a sequence per map (whether it's empty or not) + (keys, f) <- hashKeys.filter(_.nonEmpty).zip(hashFeatures) key <- keys i <- 0 until numHashes } yield f.toColumnMetaData().copy(grouping = Option(key)) - } + }.toArray + // All columns get null tracking or text length tracking, whether their contents are hashed or ignored + val allTextKeys = hashKeys.zip(ignoreKeys).map{ case(h, i) => h ++ i } + val allTextFeatures = hashFeatures ++ ignoreFeatures val nullColumns = if (shouldTrackNulls) { for { - (keys, f) <- allKeys.toArray.zip(features) + (keys, f) <- allTextKeys.toArray.zip(allTextFeatures) key <- keys } yield f.toColumnMetaData(isNull = true).copy(grouping = Option(key)) } else Array.empty[OpVectorColumnMetadata] val lenColumns = if (shouldTrackLen) { for { - (keys, f) <- allKeys.toArray.zip(features) + (keys, f) <- allTextKeys.toArray.zip(allTextFeatures) key <- keys } yield f.toColumnMetaData(descriptorValue = OpVectorColumnMetadata.TextLenString).copy(grouping = Option(key)) } else Array.empty[OpVectorColumnMetadata] diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala index f149f5abba..03f4b47d99 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala @@ -39,7 +39,7 @@ import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import com.twitter.algebird.Monoid._ import com.twitter.algebird.Operators._ -import com.twitter.algebird.Monoid +import com.twitter.algebird.{Monoid, Semigroup} import com.twitter.algebird.macros.caseclass import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.{Dataset, Encoder} @@ -63,7 +63,7 @@ class SmartTextMapVectorizer[T <: OPMap[String]] with PivotParams with CleanTextFun with SaveOthersParams with TrackNullsParam with MinSupportParam with TextTokenizerParams with TrackTextLenParam with HashingVectorizerParams with MapHashingFun with OneHotFun with MapStringPivotHelper - with MapVectorizerFuns[String, OPMap[String]] with MaxCardinalityParams { + with MapVectorizerFuns[String, OPMap[String]] with MaxCardinalityParams with MinLengthStdDevParams { private implicit val textMapStatsSeqEnc: Encoder[Array[TextMapStats]] = ExpressionEncoder[Array[TextMapStats]]() @@ -73,7 +73,7 @@ class SmartTextMapVectorizer[T <: OPMap[String]] ): TextMapStats = { val keyValueCounts = textMap.map{ case (k, v) => cleanTextFn(k, shouldCleanKeys) -> - TextStats(Map(cleanTextFn(v, shouldCleanValues) -> 1), Map(cleanTextFn(v, shouldCleanValues).length -> 1)) + TextStats(Map(cleanTextFn(v, shouldCleanValues) -> 1L), Map(cleanTextFn(v, shouldCleanValues).length -> 1L)) } TextMapStats(keyValueCounts) } @@ -104,25 +104,46 @@ class SmartTextMapVectorizer[T <: OPMap[String]] ) } else Array.empty[OpVectorColumnMetadata] - val textColumns = if (args.textFeatureInfo.flatten.nonEmpty) { + val allTextFeatureInfo = args.hashFeatureInfo.zip(args.ignoreFeatureInfo).map{ case (h, i) => h ++ i } + val allTextColumns = if (allTextFeatureInfo.flatten.nonEmpty) { val (mapFeatures, mapFeatureInfo) = - inN.toSeq.zip(args.textFeatureInfo).filter{ case (tf, featureInfoSeq) => featureInfoSeq.nonEmpty }.unzip + inN.toSeq.zip(allTextFeatureInfo).filter{ case (tf, featureInfoSeq) => featureInfoSeq.nonEmpty }.unzip val allKeys = mapFeatureInfo.map(_.map(_.key)) + + // Careful when zipping sequences like hashKeys (length = number of maps, always) and + // hashFeatures (length <= number of maps, depending on which ones contain keys to hash) + val hashKeys = args.hashFeatureInfo.map( + _.filter(_.vectorizationMethod == TextVectorizationMethod.Hash).map(_.key) + ) + val ignoreKeys = args.ignoreFeatureInfo.map( + _.filter(_.vectorizationMethod == TextVectorizationMethod.Ignore).map(_.key) + ) + + val hashFeatures = inN.toSeq.zip(args.hashFeatureInfo).filter { + case (tf, featureInfoSeq) => featureInfoSeq.nonEmpty + }.map(_._1) + val ignoreFeatures = inN.toSeq.zip(args.ignoreFeatureInfo).filter{ + case (tf, featureInfoSeq) => featureInfoSeq.nonEmpty + }.map(_._1) + makeVectorColumnMetadata( - features = mapFeatures.toArray, + hashFeatures = hashFeatures.toArray, + ignoreFeatures = ignoreFeatures.toArray, params = makeHashingParams(), - allKeys = allKeys, + hashKeys = hashKeys, + ignoreKeys = ignoreKeys, shouldTrackNulls = args.shouldTrackNulls, shouldTrackLen = $(trackTextLen) ) } else Array.empty[OpVectorColumnMetadata] - val columns = categoricalColumns ++ textColumns + val columns = categoricalColumns ++ allTextColumns OpVectorMetadata(getOutputFeatureName, columns, Transmogrifier.inputFeaturesToHistory(inN, stageName)) } def makeSmartTextMapVectorizerModelArgs(aggregatedStats: Array[TextMapStats]): SmartTextMapVectorizerModelArgs = { val maxCard = $(maxCardinality) + val minLenStdDev = $(minLengthStdDev) val minSup = $(minSupport) val shouldCleanKeys = $(cleanKeys) val shouldCleanValues = $(cleanText) @@ -130,14 +151,18 @@ class SmartTextMapVectorizer[T <: OPMap[String]] val allFeatureInfo = aggregatedStats.toSeq.map { textMapStats => textMapStats.keyValueCounts.toSeq.map { case (k, textStats) => - val isCat = textStats.valueCounts.size <= maxCard - val topVals = if (isCat) { + val vecMethod: TextVectorizationMethod = textStats match { + case _ if textStats.valueCounts.size <= maxCard => TextVectorizationMethod.Pivot + case _ if textStats.lengthStdDev < minLenStdDev => TextVectorizationMethod.Ignore + case _ => TextVectorizationMethod.Hash + } + val topVals = if (vecMethod == TextVectorizationMethod.Pivot) { textStats.valueCounts .filter { case (_, count) => count >= minSup } .toSeq.sortBy(v => -v._2 -> v._1) .take($(topK)).map(_._1).toArray } else Array.empty[String] - SmartTextFeatureInfo(key = k, isCategorical = isCat, topValues = topVals) + SmartTextFeatureInfo(key = k, vectorizationMethod = vecMethod, topValues = topVals) } } @@ -197,11 +222,15 @@ private[op] object TextMapStats { /** * Info about each feature within a text map * - * @param key name of a feature - * @param isCategorical indicate whether a feature is categorical or not - * @param topValues most common values of a feature (only for categoricals) + * @param key name of a feature + * @param vectorizationMethod method to use for text vectorization (either pivot, hashing, or ignoring) + * @param topValues most common values of a feature (only for categoricals) */ -case class SmartTextFeatureInfo(key: String, isCategorical: Boolean, topValues: Array[String]) extends JsonLike +case class SmartTextFeatureInfo( + key: String, + vectorizationMethod: TextVectorizationMethod, + topValues: Array[String] +) extends JsonLike /** @@ -221,11 +250,22 @@ case class SmartTextMapVectorizerModelArgs shouldTrackNulls: Boolean, hashingParams: HashingFunctionParams ) extends JsonLike { - val (categoricalFeatureInfo, textFeatureInfo) = allFeatureInfo.map{ featureInfoSeq => - featureInfoSeq.partition{_ .isCategorical } - }.unzip - val categoricalKeys = categoricalFeatureInfo.map(featureInfoSeq => featureInfoSeq.map(_.key)) - val textKeys = textFeatureInfo.map(featureInfoSeq => featureInfoSeq.map(_.key)) + // Partition allFeatureInfo into separate SmartTextFeatureInfo sequences corresponding to each vectorization type + val (categoricalFeatureInfo, hashFeatureInfo, ignoreFeatureInfo) = allFeatureInfo.map{ featureInfoSeq => + val groups = featureInfoSeq.groupBy(_.vectorizationMethod) + val catGroup = groups.getOrElse(TextVectorizationMethod.Pivot, Seq.empty) + val hashGroup = groups.getOrElse(TextVectorizationMethod.Hash, Seq.empty) + val ignoreGroup = groups.getOrElse(TextVectorizationMethod.Ignore, Seq.empty) + (catGroup, hashGroup, ignoreGroup) + }.unzip3 + + // Seq[Seq[String]] corresponding to the keys in each map that are treated with each vectorization type + val categoricalKeys = categoricalFeatureInfo.map(_.map(_.key)) + val hashKeys = hashFeatureInfo.map(_.map(_.key)) + val ignoreKeys = ignoreFeatureInfo.map(_.map(_.key)) + + // Combined keys for hashed and ignored features (everything that's not pivoted) + val textKeys = hashKeys.zip(ignoreKeys).map{ case (hk, ik) => hk ++ ik } } @@ -240,6 +280,25 @@ final class SmartTextMapVectorizerModel[T <: OPMap[String]] private[op] with MapHashingFun with TextMapPivotVectorizerModelFun[OPMap[String]] { + /** + * Storage for results of row partitioning + * + * @param categoricalMaps Sequence of maps that have at least one key that should be treated as a categorical + * @param categoricalKeys Sequence containing keys for each map that correspond to categorical features + * @param hashMaps Sequence of maps that have at least one key that should be hashed + * @param hashKeys Sequence containing keys for each map that correspond to hashed features + * @param ignoreMaps Sequence of maps that have at least one key that should be ignored + * @param ignoreKeys Sequence containing keys for each map that correspond to ignored features + */ + case class PartitionResult( + categoricalMaps: Seq[OPMap[String]], + categoricalKeys: Seq[Seq[String]], + hashMaps: Seq[OPMap[String]], + hashKeys: Seq[Seq[String]], + ignoreMaps: Seq[OPMap[String]], + ignoreKeys: Seq[Seq[String]] + ) + private val categoricalPivotFn = pivotFn( topValues = args.categoricalFeatureInfo.filter(_.nonEmpty).map(_.map(info => info.key -> info.topValues)), shouldCleanKeys = args.shouldCleanKeys, @@ -247,33 +306,47 @@ final class SmartTextMapVectorizerModel[T <: OPMap[String]] private[op] shouldTrackNulls = args.shouldTrackNulls ) - private def partitionRow(row: Seq[OPMap[String]]): - (Seq[OPMap[String]], Seq[Seq[String]], Seq[OPMap[String]], Seq[Seq[String]]) = { + private def partitionRow(row: Seq[OPMap[String]]): PartitionResult = { val (rowCategorical, keysCategorical) = row.view.zip(args.categoricalKeys).collect { case (elements, keys) if keys.nonEmpty => val filtered = elements.value.filter { case (k, v) => keys.contains(k) } (TextMap(filtered), keys) }.unzip - val (rowText, keysText) = - row.view.zip(args.textKeys).collect { case (elements, keys) if keys.nonEmpty => + val (rowHashedText, keysHashedText) = + row.view.zip(args.hashKeys).collect { case (elements, keys) if keys.nonEmpty => val filtered = elements.value.filter { case (k, v) => keys.contains(k) } (TextMap(filtered), keys) }.unzip - (rowCategorical.toList, keysCategorical.toList, rowText.toList, keysText.toList) + val (rowIgnoredText, keysIgnoredText) = + row.view.zip(args.ignoreKeys).collect { case (elements, keys) if keys.nonEmpty => + val filtered = elements.value.filter { case (k, v) => keys.contains(k) } + (TextMap(filtered), keys) + }.unzip + + PartitionResult(rowCategorical.toList, keysCategorical.toList, rowHashedText.toList, keysHashedText.toList, + rowIgnoredText.toList, keysIgnoredText.toList) } def transformFn: Seq[T] => OPVector = row => { - val (rowCategorical, keysCategorical, rowText, keysText) = partitionRow(row) + implicit val textListMonoid: Monoid[TextList] = TextList.monoid + + val PartitionResult(rowCategorical, keysCategorical, rowHash, keysHash, rowIgnore, keysIgnore) = partitionRow(row) + val keysText = keysHash + keysIgnore // Go algebird! val categoricalVector = categoricalPivotFn(rowCategorical) - val rowTextTokenized = rowText.map(_.value.map { case (k, v) => k -> tokenize(v.toText).tokens }) - val textVector = hash(rowTextTokenized, keysText, args.hashingParams) + + val rowHashTokenized = rowHash.map(_.value.map { case (k, v) => k -> tokenize(v.toText).tokens }) + val rowIgnoreTokenized = rowIgnore.map(_.value.map { case (k, v) => k -> tokenize(v.toText).tokens }) + val rowTextTokenized = rowHashTokenized + rowIgnoreTokenized // Go go algebird! + val hashVector = hash(rowHashTokenized, keysHash, args.hashingParams) + + // All columns get null tracking or text length tracking, whether their contents are hashed or ignored val textNullIndicatorsVector = if (args.shouldTrackNulls) getNullIndicatorsVector(keysText, rowTextTokenized) else OPVector.empty val textLenVector = if ($(trackTextLen)) getLenVector(keysText, rowTextTokenized) else OPVector.empty - categoricalVector.combine(textVector, textLenVector, textNullIndicatorsVector) + categoricalVector.combine(hashVector, textLenVector, textNullIndicatorsVector) } private def getNullIndicatorsVector(keysSeq: Seq[Seq[String]], inputs: Seq[Map[String, TextList]]): OPVector = { diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala index e95939d249..d75e42bf36 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala @@ -91,7 +91,7 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])( val (vectorizationMethods, topValues) = aggregatedStats.map { stats => val vecMethod: TextVectorizationMethod = stats match { case _ if stats.valueCounts.size <= maxCard => TextVectorizationMethod.Pivot - case _ if stats.lengthStdDev <= minLenStdDev => TextVectorizationMethod.Ignore + case _ if stats.lengthStdDev < minLenStdDev => TextVectorizationMethod.Ignore case _ => TextVectorizationMethod.Hash } val topValues = stats.valueCounts @@ -225,8 +225,6 @@ private[op] object TextStats { * Arguments for [[SmartTextVectorizerModel]] * * @param vectorizationMethods method to use for text vectorization (either pivot, hashing, or ignoring) - * @param isCategorical is feature a categorical or not - * @param isIgnorable is a text feature that we think is ignorable? high cardinality + low length variance * @param topValues top values to each feature * @param shouldCleanText should clean text value * @param shouldTrackNulls should track nulls diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala index 5584ea1565..a5af0c824c 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizerTest.scala @@ -32,13 +32,14 @@ package com.salesforce.op.stages.impl.feature import com.salesforce.op._ import com.salesforce.op.stages.base.sequence.SequenceModel -import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder, TestSparkContext} +import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder} import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.linalg.Vectors import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner import com.salesforce.op.features.types._ +import com.salesforce.op.testkit.RandomText @RunWith(classOf[JUnitRunner]) class SmartTextMapVectorizerTest @@ -73,6 +74,44 @@ class SmartTextMapVectorizerTest ) ) + /* + Generate some more complicated input data to check things a little closer. There are four text fields with + different token distributions: + + country: Uniformly distributed from a larger list of ~few hundred countries, should be hashed + categoricalText: Uniformly distributed from a small list of choices, should be pivoted (also has fixed lengths, + so serves as a test that the categorical check happens before the token length variance check) + textId: Uniformly distributed high cardinality Ids with fixed lengths, should be ignored + text: Uniformly distributed unicode strings with lengths ranging from 0-100, should be hashed + */ + val countryData: Seq[Text] = RandomText.countries.withProbabilityOfEmpty(0.2).limit(1000) + val categoricalTextData: Seq[Text] = RandomText.textFromDomain(domain = List("A", "B", "C", "D", "E", "F")) + .withProbabilityOfEmpty(0.2).limit(1000) + // Generate List containing elements like 040231, 040232, ... + val textIdData: Seq[Text] = RandomText.textFromDomain( + domain = (1 to 1000).map(x => "%06d".format(40230 + x)).toList + ).withProbabilityOfEmpty(0.2).limit(1000) + val textData: Seq[Text] = RandomText.strings(minLen = 0, maxLen = 100).withProbabilityOfEmpty(0.2).limit(1000) + val generatedData: Seq[(Text, Text, Text, Text)] = + countryData.zip(categoricalTextData).zip(textIdData).zip(textData).map { + case (((co, ca), id), te) => (co, ca, id, te) + } + + def mapifyText(textSeq: Seq[Text]): TextMap = { + textSeq.zipWithIndex.flatMap { + case (t, i) => t.value.map(tv => "f" + i.toString -> tv) + }.toMap.toTextMap + } + val mapData = generatedData.map { case (t1, t2, t3, t4) => mapifyText(Seq(t1, t2, t3, t4)) } + val (rawDF, rawTextMap) = TestFeatureBuilder("textMap", mapData) + + // Do the same thing with the data spread across many maps to test that they get combined correctly as well + val mapDataSeparate = generatedData.map { + case (t1, t2, t3, t4) => (mapifyText(Seq(t1, t2)), mapifyText(Seq(t3)), mapifyText(Seq(t4))) + } + val (rawDFSeparateMaps, rawTextMap1, rawTextMap2, rawTextMap3) = + TestFeatureBuilder("textMap1", "textMap2", "textMap3", mapDataSeparate) + /** * Estimator instance to be tested */ @@ -398,4 +437,76 @@ class SmartTextMapVectorizerTest result.foreach { case (vec1, vec2) => vec1 shouldBe vec2 } } + it should "detect and ignore fields that looks like machine-generated IDs by having a low token length variance " + + "when data is in a single TextMap" in { + val topKCategorial = 3 + val hashSize = 5 + + val smartVectorized = new SmartTextMapVectorizer() + .setMaxCardinality(10).setNumFeatures(hashSize).setMinSupport(10).setTopK(topKCategorial) + .setMinLengthStdDev(1.0) + .setAutoDetectLanguage(false).setMinTokenLength(1).setToLowercase(false) + .setTrackNulls(true).setTrackTextLen(true) + .setInput(rawTextMap).getOutput() + + val transformed = new OpWorkflow().setResultFeatures(smartVectorized).transform(rawDF) + val result = transformed.collect(smartVectorized) + + /* + Feature vector should have 16 components, corresponding to two hashed text fields, one categorical field, and + one ignored text field. + + Hashed text: (5 hash buckets + 1 length + 1 null indicator) = 7 elements + Categorical: (3 topK + 1 other + 1 null indicator) = 5 elements + Ignored text: (1 length + 1 null indicator) = 2 elements + */ + val featureVectorSize = 2 * (hashSize + 2) + (topKCategorial + 2) + 2 + val firstRes = result.head + firstRes.v.size shouldBe featureVectorSize + + val meta = OpVectorMetadata(transformed.schema(smartVectorized.name)) + meta.columns.length shouldBe featureVectorSize + meta.columns.slice(0, 5).forall(_.grouping.contains("categorical")) + meta.columns.slice(5, 10).forall(_.grouping.contains("country")) + meta.columns.slice(10, 15).forall(_.grouping.contains("text")) + meta.columns.slice(15, 18).forall(_.descriptorValue.contains(OpVectorColumnMetadata.TextLenString)) + meta.columns.slice(18, 21).forall(_.indicatorValue.contains(OpVectorColumnMetadata.NullString)) + } + + it should "detect and ignore fields that looks like machine-generated IDs by having a low token length variance " + + "when data is in many TextMaps" in { + val topKCategorial = 3 + val hashSize = 5 + + val smartVectorized = new SmartTextMapVectorizer() + .setMaxCardinality(10).setNumFeatures(hashSize).setMinSupport(10).setTopK(topKCategorial) + .setMinLengthStdDev(1.0) + .setAutoDetectLanguage(false).setMinTokenLength(1).setToLowercase(false) + .setTrackNulls(true).setTrackTextLen(true) + .setInput(rawTextMap1, rawTextMap2, rawTextMap3).getOutput() + + val transformed = new OpWorkflow().setResultFeatures(smartVectorized).transform(rawDFSeparateMaps) + val result = transformed.collect(smartVectorized) + + /* + Feature vector should have 16 components, corresponding to two hashed text fields, one categorical field, and + one ignored text field. + + Hashed text: (5 hash buckets + 1 length + 1 null indicator) = 7 elements + Categorical: (3 topK + 1 other + 1 null indicator) = 5 elements + Ignored text: (1 length + 1 null indicator) = 2 elements + */ + val featureVectorSize = 2 * (hashSize + 2) + (topKCategorial + 2) + 2 + val firstRes = result.head + firstRes.v.size shouldBe featureVectorSize + + val meta = OpVectorMetadata(transformed.schema(smartVectorized.name)) + meta.columns.length shouldBe featureVectorSize + meta.columns.slice(0, 5).forall(_.grouping.contains("categorical")) + meta.columns.slice(5, 10).forall(_.grouping.contains("country")) + meta.columns.slice(10, 15).forall(_.grouping.contains("text")) + meta.columns.slice(15, 18).forall(_.descriptorValue.contains(OpVectorColumnMetadata.TextLenString)) + meta.columns.slice(18, 21).forall(_.indicatorValue.contains(OpVectorColumnMetadata.NullString)) + } + } diff --git a/features/src/main/scala/com/salesforce/op/features/types/Lists.scala b/features/src/main/scala/com/salesforce/op/features/types/Lists.scala index eebf1f7a47..67147227c4 100644 --- a/features/src/main/scala/com/salesforce/op/features/types/Lists.scala +++ b/features/src/main/scala/com/salesforce/op/features/types/Lists.scala @@ -30,6 +30,8 @@ package com.salesforce.op.features.types +import com.twitter.algebird.{Monoid, SeqMonoid} + /** * A list of text values * @@ -41,6 +43,13 @@ class TextList(val value: Seq[String]) extends OPList[String] { object TextList { def apply(value: Seq[String]): TextList = new TextList(value) def empty: TextList = FeatureTypeDefaults.TextList + + def monoid: Monoid[TextList] = new Monoid[TextList] { + override def zero = TextList.empty + override def plus(left: TextList, right: TextList): TextList = { + TextList(left.value ++ right.value) + } + } } /** diff --git a/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeValueTest.scala b/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeValueTest.scala index ea1c597f2d..a1b6398a38 100644 --- a/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeValueTest.scala +++ b/features/src/test/scala/com/salesforce/op/features/types/FeatureTypeValueTest.scala @@ -32,6 +32,8 @@ package com.salesforce.op.features.types import com.salesforce.op.test.TestCommon import com.salesforce.op.utils.reflection.ReflectionUtils +import com.twitter.algebird.Monoid +import com.twitter.algebird.Operators._ import org.apache.lucene.geo.GeoUtils import org.apache.spark.ml.linalg.DenseVector import org.junit.runner.RunWith @@ -125,8 +127,14 @@ class FeatureTypeValueTest extends PropSpec with PropertyChecks with TestCommon } property("OPList types should correctly wrap their corresponding types") { + implicit val textListMonoid: Monoid[TextList] = TextList.monoid + forAll(geoGen) { x => checkVals(Geolocation(x), x) } - forAll(textListGen) { x => checkVals(TextList(x), x) } + forAll(textListGen) { x => + checkVals(TextList(x), x) + val tl = TextList(x) + tl + tl shouldBe TextList(x ++ x) // Test the monoid too + } forAll(longListGen) { x => checkVals(DateList(x), x) checkVals(DateTimeList(x), x)