From 5e762c3591a41b7d9f29223f8276a4f8115dc1e4 Mon Sep 17 00:00:00 2001
From: Matthew Tovbin
Date: Mon, 22 Oct 2018 14:37:19 -0700
Subject: [PATCH] Replace assert with require (#159)

---
 .../op/filters/FeatureDistribution.scala           |  2 +-
 .../salesforce/op/filters/RawFeatureFilter.scala   | 16 ++++++++--------
 .../op/stages/impl/feature/OpStringIndexer.scala   |  2 +-
 .../impl/feature/SmartTextMapVectorizer.scala      |  2 +-
 .../impl/feature/SmartTextVectorizer.scala         |  4 ++--
 .../op/stages/impl/feature/Transmogrifier.scala    |  2 +-
 .../impl/insights/RecordInsightsCorr.scala         |  4 ++--
 .../impl/preparators/SanityCheckerMetadata.scala   |  2 +-
 .../impl/feature/OpStringIndexerTest.scala         |  2 +-
 .../impl/feature/SmartTextVectorizerTest.scala     |  6 +++---
 .../com/salesforce/op/features/FeatureLike.scala   |  2 +-
 .../op/stages/OpPipelineStageParams.scala          |  2 +-
 .../base/sequence/BinarySequenceEstimator.scala    |  2 +-
 .../sequence/BinarySequenceTransformer.scala       |  2 +-
 .../stages/base/sequence/SequenceEstimator.scala   |  2 +-
 .../base/sequence/SequenceTransformer.scala        |  2 +-
 .../op/utils/spark/OpVectorColumnMetadata.scala    |  8 ++++----
 .../op/utils/spark/OPVectorMetadataTest.scala      |  2 +-
 .../salesforce/op/readers/JoinedDataReader.scala   |  2 +-
 .../scala/com/salesforce/op/readers/Reader.scala   |  2 +-
 .../op/readers/JoinedReadersTest.scala             |  6 +++---
 21 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/core/src/main/scala/com/salesforce/op/filters/FeatureDistribution.scala b/core/src/main/scala/com/salesforce/op/filters/FeatureDistribution.scala
index 159f8f9208..50e788d08e 100644
--- a/core/src/main/scala/com/salesforce/op/filters/FeatureDistribution.scala
+++ b/core/src/main/scala/com/salesforce/op/filters/FeatureDistribution.scala
@@ -69,7 +69,7 @@ case class FeatureDistribution
    * @param fd distribution to compare to
    */
   def checkMatch(fd: FeatureDistribution): Unit =
-    assert(name == fd.name && key == fd.key, "Name and key must match to compare or combine FeatureDistribution")
+    require(name == fd.name && key == fd.key, "Name and key must match to compare or combine FeatureDistribution")
 
   /**
    * Get fill rate of feature
diff --git a/core/src/main/scala/com/salesforce/op/filters/RawFeatureFilter.scala b/core/src/main/scala/com/salesforce/op/filters/RawFeatureFilter.scala
index 7f5d48d6b1..a34681352c 100644
--- a/core/src/main/scala/com/salesforce/op/filters/RawFeatureFilter.scala
+++ b/core/src/main/scala/com/salesforce/op/filters/RawFeatureFilter.scala
@@ -103,18 +103,18 @@ class RawFeatureFilter[T]
   val timePeriod: Option[TimePeriod] = None
 ) extends Serializable {
 
-  assert(bins > 1 && bins <= FeatureDistribution.MaxBins, s"Invalid bin size $bins," +
+  require(bins > 1 && bins <= FeatureDistribution.MaxBins, s"Invalid bin size $bins," +
     s" bins must be between 1 and ${FeatureDistribution.MaxBins}")
-  assert(minFill >= 0.0 && minFill <= 1.0, s"Invalid minFill size $minFill, minFill must be between 0 and 1")
+  require(minFill >= 0.0 && minFill <= 1.0, s"Invalid minFill size $minFill, minFill must be between 0 and 1")
-  assert(maxFillDifference >= 0.0 && maxFillDifference <= 1.0, s"Invalid maxFillDifference size $maxFillDifference," +
+  require(maxFillDifference >= 0.0 && maxFillDifference <= 1.0, s"Invalid maxFillDifference size $maxFillDifference," +
     s" maxFillDifference must be between 0 and 1")
-  assert(maxFillRatioDiff >= 0.0, s"Invalid maxFillRatioDiff size $maxFillRatioDiff," +
+  require(maxFillRatioDiff >= 0.0, s"Invalid maxFillRatioDiff size $maxFillRatioDiff," +
     s" maxFillRatioDiff must be greater than 0.0")
-  assert(maxJSDivergence >= 0.0 && maxJSDivergence <= 1.0, s"Invalid maxJSDivergence size $maxJSDivergence," +
+  require(maxJSDivergence >= 0.0 && maxJSDivergence <= 1.0, s"Invalid maxJSDivergence size $maxJSDivergence," +
     s" maxJSDivergence must be between 0 and 1")
 
   ClosureUtils.checkSerializable(textBinsFormula) match {
-    case Failure(e) => throw new AssertionError("The argument textBinsFormula must be serializable", e)
+    case Failure(e) => throw new IllegalArgumentException("The argument textBinsFormula must be serializable", e)
     case ok => ok
   }
@@ -229,7 +229,7 @@ class RawFeatureFilter[T]
 
     val scoringUnfilled =
       if (scoringDistribs.nonEmpty) {
-        assert(scoringDistribs.length == featureSize, "scoring and training features must match")
+        require(scoringDistribs.length == featureSize, "scoring and training features must match")
         val su = scoringDistribs.map(_.fillRate() < minFill)
         logExcluded(su, s"Features excluded because scoring fill rate did not meet min required ($minFill)")
         su
@@ -284,7 +284,7 @@ class RawFeatureFilter[T]
 
     val trainData = trainingReader.generateDataFrame(rawFeatures, parameters).persist()
     log.info("Loaded training data")
-    assert(trainData.count() > 0, "RawFeatureFilter cannot work with empty training data")
+    require(trainData.count() > 0, "RawFeatureFilter cannot work with empty training data")
     val trainingSummary = computeFeatureStats(trainData, rawFeatures)
     log.info("Computed summary stats for training features")
     if (log.isDebugEnabled) {
diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala
index f03b54b1a9..a0359200a3 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpStringIndexer.scala
@@ -63,7 +63,7 @@ class OpStringIndexer[T <: Text]
    * @return this stage
    */
   def setHandleInvalid(value: StringIndexerHandleInvalid): this.type = {
-    assert(Seq(Inv.Skip, Inv.Error, Inv.Keep).contains(value),
+    require(Seq(Inv.Skip, Inv.Error, Inv.Keep).contains(value),
       "OpStringIndexer only supports Skip, Error, and Keep for handle invalid")
     getSparkMlStage().get.setHandleInvalid(value.entryName.toLowerCase)
     this
diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala
index f9b79cbd59..b2e4615bea 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextMapVectorizer.scala
@@ -149,7 +149,7 @@ class SmartTextMapVectorizer[T <: OPMap[String]]
   }
 
   def fitFn(dataset: Dataset[Seq[T#Value]]): SequenceModel[T, OPVector] = {
-    assert(!dataset.isEmpty, "Input dataset cannot be empty")
+    require(!dataset.isEmpty, "Input dataset cannot be empty")
     val maxCard = $(maxCardinality)
     val shouldCleanKeys = $(cleanKeys)
diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala
index f09e863cfc..a5f0157d50 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala
@@ -77,7 +77,7 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
   )
 
   def fitFn(dataset: Dataset[Seq[T#Value]]): SequenceModel[T, OPVector] = {
-    assert(!dataset.isEmpty, "Input dataset cannot be empty")
+    require(!dataset.isEmpty, "Input dataset cannot be empty")
     val maxCard = $(maxCardinality)
     val shouldCleanText = $(cleanText)
@@ -123,7 +123,7 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
   }
 
   private def makeVectorMetadata(smartTextParams: SmartTextVectorizerModelArgs): OpVectorMetadata = {
-    assert(inN.length == smartTextParams.isCategorical.length)
+    require(inN.length == smartTextParams.isCategorical.length)
 
     val (categoricalFeatures, textFeatures) =
       SmartTextVectorizer.partition[TransientFeature](inN, smartTextParams.isCategorical)
diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/Transmogrifier.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/Transmogrifier.scala
index 9bedcca1b2..347fce17bc 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/Transmogrifier.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/Transmogrifier.scala
@@ -439,7 +439,7 @@ case object VectorizerUtils {
    * @return one-hot vector with 1.0 in position value
    */
   def oneHot(pos: Int, size: Int): Array[Double] = {
-    assert(pos < size && pos >= 0, s"One-hot index lies outside the bounds of the vector: pos = $pos, size = $size")
+    require(pos < size && pos >= 0, s"One-hot index lies outside the bounds of the vector: pos = $pos, size = $size")
     val arr = new Array[Double](size)
     arr(pos) = 1.0
     arr
diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsCorr.scala b/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsCorr.scala
index d8135d850f..9e281a996e 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsCorr.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/insights/RecordInsightsCorr.scala
@@ -95,7 +95,7 @@ class RecordInsightsCorr(uid: String = UID[RecordInsightsCorr]) extends
 
   override def fitFn(dataset: Dataset[(OPVector#Value, OPVector#Value)]): RecordInsightsCorrModel = {
     val vectorMetadata = Try(OpVectorMetadata(getInputSchema()(in2.name)))
-    assert(vectorMetadata.isSuccess, s"first input feature must be a feature vector with OpVectorMetadata," +
+    require(vectorMetadata.isSuccess, s"first input feature must be a feature vector with OpVectorMetadata," +
       s"got error parsing metadata: ${vectorMetadata.failed.get}")
 
     val first = dataset.first()
@@ -139,7 +139,7 @@ private[op] final class RecordInsightsCorrModel
   private lazy val featureInfo = OpVectorMetadata(getInputSchema()(in2.name)).getColumnHistory()
 
   override def transformFn: (OPVector, OPVector) => TextMap = (_, features) => {
-    assert(featureInfo.size == features.value.size, "feature metadata size does not match feature size")
+    require(featureInfo.size == features.value.size, "feature metadata size does not match feature size")
     val normalizedFeatures = norm(features)
     val importance = scoreCorr.map{ _.zip(normalizedFeatures).map{ case (a, b) => if (a.isNaN) 0.0 else a * b} }
diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerMetadata.scala b/core/src/main/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerMetadata.scala
index da0ee2a1fd..adab507931 100644
--- a/core/src/main/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerMetadata.scala
+++ b/core/src/main/scala/com/salesforce/op/stages/impl/preparators/SanityCheckerMetadata.scala
@@ -271,7 +271,7 @@ case class Correlations
   nanCorrs: Seq[String],
   corrType: CorrelationType
 ) extends MetadataLike {
-  assert(featuresIn.length == values.length, "Feature names and correlation values arrays must have the same length")
+  require(featuresIn.length == values.length, "Feature names and correlation values arrays must have the same length")
 
   def this(corrs: Seq[(String, Double)], nans: Seq[String], corrType: CorrelationType) = this(
     featuresIn = corrs.map(_._1),
diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala
index 86df7d7d94..50c61fc8f9 100644
--- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala
+++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpStringIndexerTest.scala
@@ -57,7 +57,7 @@ class OpStringIndexerTest extends FlatSpec with TestSparkContext{
 
   it should "throw an error if you try to set noFilter as the indexer" in {
     val indexer = new OpStringIndexer[Text]()
-    intercept[AssertionError](indexer.setHandleInvalid(StringIndexerHandleInvalid.NoFilter))
+    intercept[IllegalArgumentException](indexer.setHandleInvalid(StringIndexerHandleInvalid.NoFilter))
   }
 
   it should "correctly index a text column" in {
diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala
index 672c17e25d..2bd0e421f1 100644
--- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala
+++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizerTest.scala
@@ -170,17 +170,17 @@ class SmartTextVectorizerTest
     regular shouldBe shortcut
   }
 
-  it should "fail with an assertion error" in {
+  it should "fail with an error" in {
     val emptyDF = inputData.filter(inputData("text1") === "").toDF()
 
     val smartVectorized = new SmartTextVectorizer()
       .setMaxCardinality(2).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(false)
       .setInput(f1, f2).getOutput()
 
-    val thrown = intercept[AssertionError] {
+    val thrown = intercept[IllegalArgumentException] {
       new OpWorkflow().setResultFeatures(smartVectorized).transform(emptyDF)
     }
-    assert(thrown.getMessage.contains("assertion failed"))
+    assert(thrown.getMessage.contains("requirement failed"))
   }
 
   it should "generate metadata correctly" in {
diff --git a/features/src/main/scala/com/salesforce/op/features/FeatureLike.scala b/features/src/main/scala/com/salesforce/op/features/FeatureLike.scala
index e3fcfde9af..a0d8712385 100644
--- a/features/src/main/scala/com/salesforce/op/features/FeatureLike.scala
+++ b/features/src/main/scala/com/salesforce/op/features/FeatureLike.scala
@@ -353,7 +353,7 @@ trait FeatureLike[O <: FeatureType] {
         if (acc.contains(f.uid)) acc else acc + (f.uid -> f)
       )
-    assert(checkFeatureOriginStageMatch(featuresByUid.values), "Some of your features had parent features that did" +
+    require(checkFeatureOriginStageMatch(featuresByUid.values), "Some of your features had parent features that did" +
       " not match the inputs to their origin stage. All stages must be a new instance when used to transform features")
 
     def logDebug(msg: String) = log.debug(s"[${this.uid}]: $msg")
diff --git a/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageParams.scala b/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageParams.scala
index a28e4dfa94..1579c1ba0a 100644
--- a/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageParams.scala
+++ b/features/src/main/scala/com/salesforce/op/stages/OpPipelineStageParams.scala
@@ -67,7 +67,7 @@ trait InputParams extends Params {
    * @return this stage
   */
   final protected def setInputFeatures[S <: OPFeature](features: Array[S]): this.type = {
-    assert(
+    require(
       checkInputLength(features),
       "Number of input features must match the number expected by this type of pipeline stage"
     )
diff --git a/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceEstimator.scala b/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceEstimator.scala
index eb1fc2260c..2e78be04bf 100644
--- a/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceEstimator.scala
+++ b/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceEstimator.scala
@@ -100,7 +100,7 @@ abstract class BinarySequenceEstimator[I1 <: FeatureType, I2 <: FeatureType, O <
    * @return a fitted model that will perform the transformation specified by the function defined in constructor fit
   */
   override def fit(dataset: Dataset[_]): BinarySequenceModel[I1, I2, O] = {
-    assert(getTransientFeatures.size > 1, "Inputs cannot be empty")
+    require(getTransientFeatures.size > 1, "Inputs cannot be empty")
     setInputSchema(dataset.schema).transformSchema(dataset.schema)
 
     val seqColumns = inN.map(feature => col(feature.name))
diff --git a/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformer.scala b/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformer.scala
index e31a7e5fd9..3c8ec96861 100644
--- a/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformer.scala
+++ b/features/src/main/scala/com/salesforce/op/stages/base/sequence/BinarySequenceTransformer.scala
@@ -76,7 +76,7 @@ trait OpTransformer2N[I1 <: FeatureType, I2 <: FeatureType, O <: FeatureType]
    * @return a new dataset containing a column for the transformed feature
   */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    assert(getTransientFeatures.size > 1, "Inputs cannot be empty")
+    require(getTransientFeatures.size > 1, "Inputs cannot be empty")
     val newSchema = setInputSchema(dataset.schema).transformSchema(dataset.schema)
     val functionUDF = FeatureSparkTypes.udf2N[I1, I2, O](transformFn)
     val meta = newSchema(getOutputFeatureName).metadata
diff --git a/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceEstimator.scala b/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceEstimator.scala
index 8937196646..f1ba75a41c 100644
--- a/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceEstimator.scala
+++ b/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceEstimator.scala
@@ -89,7 +89,7 @@ abstract class SequenceEstimator[I <: FeatureType, O <: FeatureType]
    * @return a fitted model that will perform the transformation specified by the function defined in constructor fit
   */
   override def fit(dataset: Dataset[_]): SequenceModel[I, O] = {
-    assert(inN.nonEmpty, "Inputs cannot be empty")
+    require(inN.nonEmpty, "Inputs cannot be empty")
     setInputSchema(dataset.schema).transformSchema(dataset.schema)
 
     val columns = inN.map(feature => col(feature.name))
diff --git a/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceTransformer.scala b/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceTransformer.scala
index 990018830d..0b2b00b2a1 100644
--- a/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceTransformer.scala
+++ b/features/src/main/scala/com/salesforce/op/stages/base/sequence/SequenceTransformer.scala
@@ -73,7 +73,7 @@ trait OpTransformerN[I <: FeatureType, O <: FeatureType]
    * @return a new dataset containing a column for the transformed feature
   */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    assert(inN.nonEmpty, "Inputs cannot be empty")
+    require(inN.nonEmpty, "Inputs cannot be empty")
     val newSchema = setInputSchema(dataset.schema).transformSchema(dataset.schema)
     val functionUDF = FeatureSparkTypes.udfN[I, O](transformFn)
     val meta = newSchema(getOutputFeatureName).metadata
diff --git a/features/src/main/scala/com/salesforce/op/utils/spark/OpVectorColumnMetadata.scala b/features/src/main/scala/com/salesforce/op/utils/spark/OpVectorColumnMetadata.scala
index bcb1c0e09b..679f411063 100644
--- a/features/src/main/scala/com/salesforce/op/utils/spark/OpVectorColumnMetadata.scala
+++ b/features/src/main/scala/com/salesforce/op/utils/spark/OpVectorColumnMetadata.scala
@@ -74,12 +74,12 @@ case class OpVectorColumnMetadata // TODO make separate case classes extending t
   index: Int = 0
 ) extends JsonLike {
 
-  assert(parentFeatureName.nonEmpty, "must provide parent feature name")
-  assert(parentFeatureType.nonEmpty, "must provide parent type name")
-  assert(parentFeatureName.length == parentFeatureType.length,
+  require(parentFeatureName.nonEmpty, "must provide parent feature name")
+  require(parentFeatureType.nonEmpty, "must provide parent type name")
+  require(parentFeatureName.length == parentFeatureType.length,
     s"must provide both type and name for every parent feature," +
     s" names: $parentFeatureName and types: $parentFeatureType do not have the same length")
-  assert(indicatorValue.isEmpty || descriptorValue.isEmpty, "cannot have both indicatorValue and descriptorValue")
+  require(indicatorValue.isEmpty || descriptorValue.isEmpty, "cannot have both indicatorValue and descriptorValue")
 
   /**
    * Convert this column into Spark metadata.
diff --git a/features/src/test/scala/com/salesforce/op/utils/spark/OPVectorMetadataTest.scala b/features/src/test/scala/com/salesforce/op/utils/spark/OPVectorMetadataTest.scala
index 2dcad0b38c..438119ee57 100644
--- a/features/src/test/scala/com/salesforce/op/utils/spark/OPVectorMetadataTest.scala
+++ b/features/src/test/scala/com/salesforce/op/utils/spark/OPVectorMetadataTest.scala
@@ -100,7 +100,7 @@ class OPVectorMetadataTest extends PropSpec with TestCommon with PropertyChecks
   property("column metadata cannot be created with empty parents or feature types") {
     forAll(vecColTupleGen) { (vct: OpVectorColumnTuple) =>
       if (!checkTuples(vct)) {
-        assertThrows[AssertionError] { OpVectorColumnMetadata(vct._1, vct._2, vct._3, vct._4, vct._5) }
+        assertThrows[IllegalArgumentException] { OpVectorColumnMetadata(vct._1, vct._2, vct._3, vct._4, vct._5) }
       }
     }
   }
diff --git a/readers/src/main/scala/com/salesforce/op/readers/JoinedDataReader.scala b/readers/src/main/scala/com/salesforce/op/readers/JoinedDataReader.scala
index 9166f922bf..ab4379f5d5 100644
--- a/readers/src/main/scala/com/salesforce/op/readers/JoinedDataReader.scala
+++ b/readers/src/main/scala/com/salesforce/op/readers/JoinedDataReader.scala
@@ -128,7 +128,7 @@ private[op] abstract class JoinedReader[T, U]
 
   final def subReaders: Seq[DataReader[_]] = {
     val allReaders = Seq(leftReader.subReaders, rightReader.subReaders).flatten
-    assert(allReaders.size == allReaders.distinct.size, "Cannot have duplicate readers in joins")
+    require(allReaders.size == allReaders.distinct.size, "Cannot have duplicate readers in joins")
     allReaders
   }
 
diff --git a/readers/src/main/scala/com/salesforce/op/readers/Reader.scala b/readers/src/main/scala/com/salesforce/op/readers/Reader.scala
index 110dc27701..25fe4ef803 100644
--- a/readers/src/main/scala/com/salesforce/op/readers/Reader.scala
+++ b/readers/src/main/scala/com/salesforce/op/readers/Reader.scala
@@ -150,7 +150,7 @@ trait Reader[T] extends ReaderType[T] {
   ): JoinedDataReader[T, U] = {
     val joinedReader =
       new JoinedDataReader[T, U](leftReader = this, rightReader = other, joinKeys = joinKeys, joinType = joinType)
-    assert(joinedReader.leftReader.subReaders
+    require(joinedReader.leftReader.subReaders
       .forall(r => r.fullTypeName != joinedReader.rightReader.fullTypeName),
       "All joins must be for readers of different objects - self joins are not supported"
     )
diff --git a/readers/src/test/scala/com/salesforce/op/readers/JoinedReadersTest.scala b/readers/src/test/scala/com/salesforce/op/readers/JoinedReadersTest.scala
index f9b0085679..95b732062c 100644
--- a/readers/src/test/scala/com/salesforce/op/readers/JoinedReadersTest.scala
+++ b/readers/src/test/scala/com/salesforce/op/readers/JoinedReadersTest.scala
@@ -84,19 +84,19 @@ class JoinedReadersTest extends FlatSpec with PassengerSparkFixtureTest {
   }
 
   it should "throw an error if you try to perform a self join" in {
-    a[AssertionError] should be thrownBy {
+    a[IllegalArgumentException] should be thrownBy {
       dataReader.innerJoin(dataReader)
     }
   }
 
   it should "throw an error if you try to use the same reader twice" in {
-    a[AssertionError] should be thrownBy {
+    a[IllegalArgumentException] should be thrownBy {
       dataReader.innerJoin(sparkReader).innerJoin(dataReader)
     }
   }
 
   it should "throw an error if you try to read the same data type twice with different readers" in {
-    a[AssertionError] should be thrownBy {
+    a[IllegalArgumentException] should be thrownBy {
       passengerReader.innerJoin(sparkReader).outerJoin(dataReader)
     }
   }
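
Note on the rationale behind this patch: Scala's assert is marked @elidable, so builds compiled with -Xdisable-assertions drop those checks entirely, and it throws AssertionError, which signals an internal bug rather than bad caller input. Predef.require always runs, throws IllegalArgumentException, and prefixes its message with "requirement failed" -- which is exactly what the updated tests above now intercept and match on. Below is a minimal standalone sketch of the difference; it is not part of the patch, and merely reuses the oneHot precondition from Transmogrifier.scala for illustration (the object name PreconditionDemo is invented).

import scala.util.{Failure, Try}

object PreconditionDemo extends App {

  // Same precondition style as oneHot in Transmogrifier.scala above:
  // require validates caller-supplied arguments and is never elided.
  def oneHot(pos: Int, size: Int): Array[Double] = {
    require(pos < size && pos >= 0, s"One-hot index lies outside the bounds of the vector: pos = $pos, size = $size")
    val arr = new Array[Double](size)
    arr(pos) = 1.0
    arr
  }

  Try(oneHot(pos = 5, size = 3)) match {
    // Predef.require prefixes its message, so this prints:
    // requirement failed: One-hot index lies outside the bounds of the vector: pos = 5, size = 3
    case Failure(e: IllegalArgumentException) => println(e.getMessage)
    case other => println(other)
  }

  // assert is for internal invariants only: it throws AssertionError and is
  // removed entirely when compiled with -Xdisable-assertions, so it must not
  // be the thing guarding user input.
  assert(oneHot(0, 2).sum == 1.0, "one-hot vector must sum to 1")
}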