From a2d1315fed1dd0094117bc638cf558ef1cee9d50 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 27 Aug 2018 10:37:06 -0700 Subject: [PATCH 1/5] updated string indexer enum to match spark --- .../salesforce/op/dsl/RichTextFeature.scala | 4 +- .../stages/impl/feature/OpStringIndexer.scala | 2 +- .../impl/feature/OpStringIndexerTest.scala | 52 +++++++++++++++++++ 3 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala diff --git a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala index 00581f0473..925867e526 100644 --- a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala +++ b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala @@ -256,10 +256,10 @@ trait RichTextFeature { */ def indexed( unseenName: String = OpStringIndexerNoFilter.UnseenNameDefault, - handleInvalid: StringIndexerHandleInvalid = StringIndexerHandleInvalid.NoFilter + handleInvalid: StringIndexerHandleInvalid = StringIndexerHandleInvalid.Keep ): FeatureLike[RealNN] = { handleInvalid match { - case StringIndexerHandleInvalid.NoFilter => f.transformWith( + case StringIndexerHandleInvalid.Keep => f.transformWith( new OpStringIndexerNoFilter[T]().setUnseenName(unseenName) ) case _ => f.transformWith(new OpStringIndexer[T]().setHandleInvalid(handleInvalid)) diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala index 617a20d923..b65eb81fc8 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala @@ -73,5 +73,5 @@ object StringIndexerHandleInvalid extends Enum[StringIndexerHandleInvalid] { val values = findValues case object Skip extends StringIndexerHandleInvalid case object Error extends StringIndexerHandleInvalid - case object NoFilter extends StringIndexerHandleInvalid + case object Keep extends StringIndexerHandleInvalid } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala new file mode 100644 index 0000000000..f500606dab --- /dev/null +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * * Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.salesforce.op.stages.impl.feature + +import com.salesforce.op.features.types.Text +import com.salesforce.op.test.TestSparkContext +import org.apache.spark.ml.feature.StringIndexer +import org.junit.runner.RunWith +import org.scalatest.FlatSpec +import org.scalatest.junit.JUnitRunner + +@RunWith(classOf[JUnitRunner]) +class OpStringIndexerTest extends FlatSpec with TestSparkContext{ + + Spec[OpStringIndexer[_]] should "correctly set the wrapped spark stage params" in { + val indexer = new OpStringIndexer[Text]() + indexer.setHandleInvalid(StringIndexerHandleInvalid.Skip) + indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Skip.entryName.toLowerCase + indexer.setHandleInvalid(StringIndexerHandleInvalid.Error) + indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Error.entryName.toLowerCase + indexer.setHandleInvalid(StringIndexerHandleInvalid.Keep) + indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Skip.entryName.toLowerCase + } + +} From 547a645aa3a4d72042da19a39ff6c85e05128df2 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 27 Aug 2018 10:49:12 -0700 Subject: [PATCH 2/5] minor fix --- .../main/scala/com/salesforce/op/dsl/RichTextFeature.scala | 2 +- .../op/stages/impl/feature/OpStringIndexer.scala | 4 ++++ .../op/stages/impl/feature/OpStringIndexerTest.scala | 7 ++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala index 925867e526..cbea4ac85f 100644 --- a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala +++ b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala @@ -259,7 +259,7 @@ trait RichTextFeature { handleInvalid: StringIndexerHandleInvalid = StringIndexerHandleInvalid.Keep ): FeatureLike[RealNN] = { handleInvalid match { - case StringIndexerHandleInvalid.Keep => f.transformWith( + case StringIndexerHandleInvalid.NoFilter => f.transformWith( new OpStringIndexerNoFilter[T]().setUnseenName(unseenName) ) case _ => f.transformWith(new OpStringIndexer[T]().setHandleInvalid(handleInvalid)) diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala index b65eb81fc8..30aa60b9b3 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala @@ -32,6 +32,7 @@ package com.salesforce.op.stages.impl.feature import com.salesforce.op.UID import com.salesforce.op.features.types._ +import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid.{Keep, Skip} import com.salesforce.op.stages.sparkwrappers.specific.OpEstimatorWrapper import enumeratum._ import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel} @@ -62,6 +63,8 @@ class OpStringIndexer[T <: Text] * @return this stage */ def setHandleInvalid(value: StringIndexerHandleInvalid): this.type = { + assert(Seq(StringIndexerHandleInvalid.Skip, StringIndexerHandleInvalid.Error, StringIndexerHandleInvalid.Keep) + .contains(value), "OpStringIndexer only supports Skip, Error, and Keep for handle invalid") getSparkMlStage().get.setHandleInvalid(value.entryName.toLowerCase) this } @@ -74,4 +77,5 @@ object StringIndexerHandleInvalid extends Enum[StringIndexerHandleInvalid] { case object Skip extends StringIndexerHandleInvalid case object Error extends StringIndexerHandleInvalid case object Keep extends StringIndexerHandleInvalid + case object NoFilter extends StringIndexerHandleInvalid } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala index f500606dab..a6ed100f8e 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala @@ -46,7 +46,12 @@ class OpStringIndexerTest extends FlatSpec with TestSparkContext{ indexer.setHandleInvalid(StringIndexerHandleInvalid.Error) indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Error.entryName.toLowerCase indexer.setHandleInvalid(StringIndexerHandleInvalid.Keep) - indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Skip.entryName.toLowerCase + indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Keep.entryName.toLowerCase + } + + it should "throw an error if you try to set noFilter as the indexer" in { + val indexer = new OpStringIndexer[Text]() + intercept[AssertionError](indexer.setHandleInvalid(StringIndexerHandleInvalid.NoFilter)) } } From 7f343fef8980adcc64d428d4aecd75fa517e02a1 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 27 Aug 2018 10:53:39 -0700 Subject: [PATCH 3/5] put back default --- core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala index cbea4ac85f..00581f0473 100644 --- a/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala +++ b/core/src/main/scala/com/salesforce/op/dsl/RichTextFeature.scala @@ -256,7 +256,7 @@ trait RichTextFeature { */ def indexed( unseenName: String = OpStringIndexerNoFilter.UnseenNameDefault, - handleInvalid: StringIndexerHandleInvalid = StringIndexerHandleInvalid.Keep + handleInvalid: StringIndexerHandleInvalid = StringIndexerHandleInvalid.NoFilter ): FeatureLike[RealNN] = { handleInvalid match { case StringIndexerHandleInvalid.NoFilter => f.transformWith( From 14d9a03fa2b6c23dd9ca2fea1fef381cc29d83f7 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 27 Aug 2018 11:13:18 -0700 Subject: [PATCH 4/5] moved tests --- .../stages/impl/feature/OpStringIndexer.scala | 5 ++-- .../feature/OpStringIndexerNoFilterTest.scala | 20 +------------- .../impl/feature/OpStringIndexerTest.scala | 27 +++++++++++++++++-- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala index 30aa60b9b3..1bed512cf9 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala @@ -32,7 +32,7 @@ package com.salesforce.op.stages.impl.feature import com.salesforce.op.UID import com.salesforce.op.features.types._ -import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid.{Keep, Skip} +import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid._ import com.salesforce.op.stages.sparkwrappers.specific.OpEstimatorWrapper import enumeratum._ import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel} @@ -63,8 +63,7 @@ class OpStringIndexer[T <: Text] * @return this stage */ def setHandleInvalid(value: StringIndexerHandleInvalid): this.type = { - assert(Seq(StringIndexerHandleInvalid.Skip, StringIndexerHandleInvalid.Error, StringIndexerHandleInvalid.Keep) - .contains(value), "OpStringIndexer only supports Skip, Error, and Keep for handle invalid") + assert(Seq(Skip, Error, Keep).contains(value), "OpStringIndexer only supports Skip, Error, and Keep for handle invalid") getSparkMlStage().get.setHandleInvalid(value.entryName.toLowerCase) this } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala index 5ee027a1c6..b7590e02ce 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerNoFilterTest.scala @@ -38,8 +38,8 @@ import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} import com.salesforce.op.utils.spark.RichDataset._ import org.apache.spark.ml.feature.StringIndexerModel import org.junit.runner.RunWith +import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner -import org.scalatest.{Assertions, FlatSpec, Matchers} @RunWith(classOf[JUnitRunner]) @@ -90,22 +90,4 @@ class OpStringIndexerNoFilterTest extends FlatSpec with TestSparkContext { indices shouldBe expectedNew } - - Spec[OpStringIndexer[_]] should "correctly index a text column" in { - val stringIndexer = new OpStringIndexer[Text]().setInput(txtF) - val indices = stringIndexer.fit(ds).transform(ds).collect(stringIndexer.getOutput()) - - indices shouldBe expected - } - - it should "correctly deinxed a numeric column" in { - val indexedStage = new OpStringIndexer[Text]().setInput(txtF) - val indexed = indexedStage.getOutput() - val indices = indexedStage.fit(ds).transform(ds) - val deindexedStage = new OpIndexToString().setInput(indexed) - val deindexed = deindexedStage.getOutput() - val deindexedData = deindexedStage.transform(indices).collect(deindexed) - deindexedData shouldBe txtData - } - } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala index a6ed100f8e..86df7d7d94 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala @@ -29,16 +29,22 @@ */ package com.salesforce.op.stages.impl.feature +import com.salesforce.op.features.types._ import com.salesforce.op.features.types.Text -import com.salesforce.op.test.TestSparkContext -import org.apache.spark.ml.feature.StringIndexer +import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext} import org.junit.runner.RunWith import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner +import com.salesforce.op.utils.spark.RichDataset._ @RunWith(classOf[JUnitRunner]) class OpStringIndexerTest extends FlatSpec with TestSparkContext{ + val txtData = Seq("a", "b", "c", "a", "a", "c").map(_.toText) + val (ds, txtF) = TestFeatureBuilder(txtData) + val expected = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN) + + Spec[OpStringIndexer[_]] should "correctly set the wrapped spark stage params" in { val indexer = new OpStringIndexer[Text]() indexer.setHandleInvalid(StringIndexerHandleInvalid.Skip) @@ -54,4 +60,21 @@ class OpStringIndexerTest extends FlatSpec with TestSparkContext{ intercept[AssertionError](indexer.setHandleInvalid(StringIndexerHandleInvalid.NoFilter)) } + it should "correctly index a text column" in { + val stringIndexer = new OpStringIndexer[Text]().setInput(txtF) + val indices = stringIndexer.fit(ds).transform(ds).collect(stringIndexer.getOutput()) + + indices shouldBe expected + } + + it should "correctly deinxed a numeric column" in { + val indexedStage = new OpStringIndexer[Text]().setInput(txtF) + val indexed = indexedStage.getOutput() + val indices = indexedStage.fit(ds).transform(ds) + val deindexedStage = new OpIndexToString().setInput(indexed) + val deindexed = deindexedStage.getOutput() + val deindexedData = deindexedStage.transform(indices).collect(deindexed) + deindexedData shouldBe txtData + } + } From 9ece16be82a1ceacc57edafd29a77053c7d02d50 Mon Sep 17 00:00:00 2001 From: leahmcguire Date: Mon, 27 Aug 2018 11:15:58 -0700 Subject: [PATCH 5/5] fixed import --- .../salesforce/op/stages/impl/feature/OpStringIndexer.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala index 1bed512cf9..f03b54b1a9 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala @@ -32,7 +32,7 @@ package com.salesforce.op.stages.impl.feature import com.salesforce.op.UID import com.salesforce.op.features.types._ -import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid._ +import com.salesforce.op.stages.impl.feature.{StringIndexerHandleInvalid => Inv} import com.salesforce.op.stages.sparkwrappers.specific.OpEstimatorWrapper import enumeratum._ import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel} @@ -63,7 +63,8 @@ class OpStringIndexer[T <: Text] * @return this stage */ def setHandleInvalid(value: StringIndexerHandleInvalid): this.type = { - assert(Seq(Skip, Error, Keep).contains(value), "OpStringIndexer only supports Skip, Error, and Keep for handle invalid") + assert(Seq(Inv.Skip, Inv.Error, Inv.Keep).contains(value), + "OpStringIndexer only supports Skip, Error, and Keep for handle invalid") getSparkMlStage().get.setHandleInvalid(value.entryName.toLowerCase) this }