Replace assert with require (#159)
tovbinm authored Oct 22, 2018
1 parent 8264265 commit 5e762c3
Showing 21 changed files with 37 additions and 37 deletions.
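
In Scala this is more than a cosmetic rename: scala.Predef.assert throws java.lang.AssertionError and is annotated @elidable, so builds compiled with scalac's -Xelide-below flag drop assert calls entirely, silently removing the checks; scala.Predef.require throws java.lang.IllegalArgumentException, is never elided, and prefixes its message with "requirement failed: ". Argument validation therefore belongs in require. A minimal standalone sketch of the difference (not code from this repository):

    object AssertVsRequire extends App {
      // Runs the thunk and reports the exception it throws, if any.
      def outcome(run: => Unit): String =
        try { run; "no exception" } catch { case t: Throwable => t.toString }

      println(outcome(assert(1 > 2, "must hold")))
      // java.lang.AssertionError: assertion failed: must hold

      println(outcome(require(1 > 2, "must hold")))
      // java.lang.IllegalArgumentException: requirement failed: must hold
      // Unlike assert, this check also survives -Xelide-below builds.
    }

The same distinction drives the test changes below: intercept[AssertionError] becomes intercept[IllegalArgumentException], and message checks match "requirement failed" instead of "assertion failed".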

@@ -69,7 +69,7 @@ case class FeatureDistribution
* @param fd distribution to compare to
*/
def checkMatch(fd: FeatureDistribution): Unit =
- assert(name == fd.name && key == fd.key, "Name and key must match to compare or combine FeatureDistribution")
+ require(name == fd.name && key == fd.key, "Name and key must match to compare or combine FeatureDistribution")

/**
* Get fill rate of feature

@@ -103,18 +103,18 @@ class RawFeatureFilter[T]
val timePeriod: Option[TimePeriod] = None
) extends Serializable {

- assert(bins > 1 && bins <= FeatureDistribution.MaxBins, s"Invalid bin size $bins," +
+ require(bins > 1 && bins <= FeatureDistribution.MaxBins, s"Invalid bin size $bins," +
s" bins must be between 1 and ${FeatureDistribution.MaxBins}")
- assert(minFill >= 0.0 && minFill <= 1.0, s"Invalid minFill size $minFill, minFill must be between 0 and 1")
- assert(maxFillDifference >= 0.0 && maxFillDifference <= 1.0, s"Invalid maxFillDifference size $maxFillDifference," +
+ require(minFill >= 0.0 && minFill <= 1.0, s"Invalid minFill size $minFill, minFill must be between 0 and 1")
+ require(maxFillDifference >= 0.0 && maxFillDifference <= 1.0, s"Invalid maxFillDifference size $maxFillDifference," +
s" maxFillDifference must be between 0 and 1")
- assert(maxFillRatioDiff >= 0.0, s"Invalid maxFillRatioDiff size $maxFillRatioDiff," +
+ require(maxFillRatioDiff >= 0.0, s"Invalid maxFillRatioDiff size $maxFillRatioDiff," +
s" maxFillRatioDiff must be greater than 0.0")
- assert(maxJSDivergence >= 0.0 && maxJSDivergence <= 1.0, s"Invalid maxJSDivergence size $maxJSDivergence," +
+ require(maxJSDivergence >= 0.0 && maxJSDivergence <= 1.0, s"Invalid maxJSDivergence size $maxJSDivergence," +
s" maxJSDivergence must be between 0 and 1")

ClosureUtils.checkSerializable(textBinsFormula) match {
- case Failure(e) => throw new AssertionError("The argument textBinsFormula must be serializable", e)
+ case Failure(e) => throw new IllegalArgumentException("The argument textBinsFormula must be serializable", e)
case ok => ok
}


@@ -229,7 +229,7 @@ class RawFeatureFilter[T]

val scoringUnfilled =
if (scoringDistribs.nonEmpty) {
- assert(scoringDistribs.length == featureSize, "scoring and training features must match")
+ require(scoringDistribs.length == featureSize, "scoring and training features must match")
val su = scoringDistribs.map(_.fillRate() < minFill)
logExcluded(su, s"Features excluded because scoring fill rate did not meet min required ($minFill)")
su

@@ -284,7 +284,7 @@ class RawFeatureFilter[T]

val trainData = trainingReader.generateDataFrame(rawFeatures, parameters).persist()
log.info("Loaded training data")
- assert(trainData.count() > 0, "RawFeatureFilter cannot work with empty training data")
+ require(trainData.count() > 0, "RawFeatureFilter cannot work with empty training data")
val trainingSummary = computeFeatureStats(trainData, rawFeatures)
log.info("Computed summary stats for training features")
if (log.isDebugEnabled) {
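
The textBinsFormula check above follows the same principle: a non-serializable argument is a caller error, so it is now wrapped in IllegalArgumentException (keeping the original failure as the cause) rather than AssertionError. A rough sketch of such a check, assuming a plain Java-serialization probe (the real ClosureUtils.checkSerializable may work differently):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}
    import scala.util.{Failure, Try}

    // Probe: succeeds only if the object survives Java serialization.
    def checkSerializable(obj: AnyRef): Try[Unit] = Try {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      try out.writeObject(obj) finally out.close()
    }

    def validateFormula(textBinsFormula: AnyRef): Unit =
      checkSerializable(textBinsFormula) match {
        case Failure(e) =>
          // Surface as an argument error, preserving the cause for debugging.
          throw new IllegalArgumentException(
            "The argument textBinsFormula must be serializable", e)
        case _ => ()
      }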

@@ -63,7 +63,7 @@ class OpStringIndexer[T <: Text]
* @return this stage
*/
def setHandleInvalid(value: StringIndexerHandleInvalid): this.type = {
- assert(Seq(Inv.Skip, Inv.Error, Inv.Keep).contains(value),
+ require(Seq(Inv.Skip, Inv.Error, Inv.Keep).contains(value),
"OpStringIndexer only supports Skip, Error, and Keep for handle invalid")
getSparkMlStage().get.setHandleInvalid(value.entryName.toLowerCase)
this

@@ -149,7 +149,7 @@ class SmartTextMapVectorizer[T <: OPMap[String]]
}

def fitFn(dataset: Dataset[Seq[T#Value]]): SequenceModel[T, OPVector] = {
- assert(!dataset.isEmpty, "Input dataset cannot be empty")
+ require(!dataset.isEmpty, "Input dataset cannot be empty")

val maxCard = $(maxCardinality)
val shouldCleanKeys = $(cleanKeys)

@@ -77,7 +77,7 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
)

def fitFn(dataset: Dataset[Seq[T#Value]]): SequenceModel[T, OPVector] = {
- assert(!dataset.isEmpty, "Input dataset cannot be empty")
+ require(!dataset.isEmpty, "Input dataset cannot be empty")

val maxCard = $(maxCardinality)
val shouldCleanText = $(cleanText)

@@ -123,7 +123,7 @@ class SmartTextVectorizer[T <: Text](uid: String = UID[SmartTextVectorizer[T]])(
}

private def makeVectorMetadata(smartTextParams: SmartTextVectorizerModelArgs): OpVectorMetadata = {
- assert(inN.length == smartTextParams.isCategorical.length)
+ require(inN.length == smartTextParams.isCategorical.length)

val (categoricalFeatures, textFeatures) =
SmartTextVectorizer.partition[TransientFeature](inN, smartTextParams.isCategorical)

@@ -439,7 +439,7 @@ case object VectorizerUtils {
* @return one-hot vector with 1.0 in position value
*/
def oneHot(pos: Int, size: Int): Array[Double] = {
- assert(pos < size && pos >= 0, s"One-hot index lies outside the bounds of the vector: pos = $pos, size = $size")
+ require(pos < size && pos >= 0, s"One-hot index lies outside the bounds of the vector: pos = $pos, size = $size")
val arr = new Array[Double](size)
arr(pos) = 1.0
arr

@@ -95,7 +95,7 @@ class RecordInsightsCorr(uid: String = UID[RecordInsightsCorr]) extends
override def fitFn(dataset: Dataset[(OPVector#Value, OPVector#Value)]): RecordInsightsCorrModel = {

val vectorMetadata = Try(OpVectorMetadata(getInputSchema()(in2.name)))
- assert(vectorMetadata.isSuccess, s"first input feature must be a feature vector with OpVectorMetadata," +
+ require(vectorMetadata.isSuccess, s"first input feature must be a feature vector with OpVectorMetadata," +
s"got error parsing metadata: ${vectorMetadata.failed.get}")

val first = dataset.first()

@@ -139,7 +139,7 @@ private[op] final class RecordInsightsCorrModel
private lazy val featureInfo = OpVectorMetadata(getInputSchema()(in2.name)).getColumnHistory()

override def transformFn: (OPVector, OPVector) => TextMap = (_, features) => {
- assert(featureInfo.size == features.value.size, "feature metadata size does not match feature size")
+ require(featureInfo.size == features.value.size, "feature metadata size does not match feature size")

val normalizedFeatures = norm(features)
val importance = scoreCorr.map{ _.zip(normalizedFeatures).map{ case (a, b) => if (a.isNaN) 0.0 else a * b} }

@@ -271,7 +271,7 @@ case class Correlations
nanCorrs: Seq[String],
corrType: CorrelationType
) extends MetadataLike {
- assert(featuresIn.length == values.length, "Feature names and correlation values arrays must have the same length")
+ require(featuresIn.length == values.length, "Feature names and correlation values arrays must have the same length")

def this(corrs: Seq[(String, Double)], nans: Seq[String], corrType: CorrelationType) = this(
featuresIn = corrs.map(_._1),

@@ -57,7 +57,7 @@ class OpStringIndexerTest extends FlatSpec with TestSparkContext{

it should "throw an error if you try to set noFilter as the indexer" in {
val indexer = new OpStringIndexer[Text]()
- intercept[AssertionError](indexer.setHandleInvalid(StringIndexerHandleInvalid.NoFilter))
+ intercept[IllegalArgumentException](indexer.setHandleInvalid(StringIndexerHandleInvalid.NoFilter))
}

it should "correctly index a text column" in {

@@ -170,17 +170,17 @@ class SmartTextVectorizerTest
regular shouldBe shortcut
}

it should "fail with an assertion error" in {
it should "fail with an error" in {
val emptyDF = inputData.filter(inputData("text1") === "").toDF()

val smartVectorized = new SmartTextVectorizer()
.setMaxCardinality(2).setNumFeatures(4).setMinSupport(1).setTopK(2).setPrependFeatureName(false)
.setInput(f1, f2).getOutput()

- val thrown = intercept[AssertionError] {
+ val thrown = intercept[IllegalArgumentException] {
new OpWorkflow().setResultFeatures(smartVectorized).transform(emptyDF)
}
assert(thrown.getMessage.contains("assertion failed"))
assert(thrown.getMessage.contains("requirement failed"))
}

it should "generate metadata correctly" in {

@@ -353,7 +353,7 @@ trait FeatureLike[O <: FeatureType] {
if (acc.contains(f.uid)) acc else acc + (f.uid -> f)
)

- assert(checkFeatureOriginStageMatch(featuresByUid.values), "Some of your features had parent features that did" +
+ require(checkFeatureOriginStageMatch(featuresByUid.values), "Some of your features had parent features that did" +
" not match the inputs to their origin stage. All stages must be a new instance when used to transform features")

def logDebug(msg: String) = log.debug(s"[${this.uid}]: $msg")

@@ -67,7 +67,7 @@ trait InputParams extends Params {
* @return this stage
*/
final protected def setInputFeatures[S <: OPFeature](features: Array[S]): this.type = {
- assert(
+ require(
checkInputLength(features),
"Number of input features must match the number expected by this type of pipeline stage"
)

@@ -100,7 +100,7 @@ abstract class BinarySequenceEstimator[I1 <: FeatureType, I2 <: FeatureType, O <
* @return a fitted model that will perform the transformation specified by the function defined in constructor fit
*/
override def fit(dataset: Dataset[_]): BinarySequenceModel[I1, I2, O] = {
- assert(getTransientFeatures.size > 1, "Inputs cannot be empty")
+ require(getTransientFeatures.size > 1, "Inputs cannot be empty")
setInputSchema(dataset.schema).transformSchema(dataset.schema)

val seqColumns = inN.map(feature => col(feature.name))

@@ -76,7 +76,7 @@ trait OpTransformer2N[I1 <: FeatureType, I2 <: FeatureType, O <: FeatureType]
* @return a new dataset containing a column for the transformed feature
*/
override def transform(dataset: Dataset[_]): DataFrame = {
- assert(getTransientFeatures.size > 1, "Inputs cannot be empty")
+ require(getTransientFeatures.size > 1, "Inputs cannot be empty")
val newSchema = setInputSchema(dataset.schema).transformSchema(dataset.schema)
val functionUDF = FeatureSparkTypes.udf2N[I1, I2, O](transformFn)
val meta = newSchema(getOutputFeatureName).metadata

@@ -89,7 +89,7 @@ abstract class SequenceEstimator[I <: FeatureType, O <: FeatureType]
* @return a fitted model that will perform the transformation specified by the function defined in constructor fit
*/
override def fit(dataset: Dataset[_]): SequenceModel[I, O] = {
- assert(inN.nonEmpty, "Inputs cannot be empty")
+ require(inN.nonEmpty, "Inputs cannot be empty")
setInputSchema(dataset.schema).transformSchema(dataset.schema)

val columns = inN.map(feature => col(feature.name))

@@ -73,7 +73,7 @@ trait OpTransformerN[I <: FeatureType, O <: FeatureType]
* @return a new dataset containing a column for the transformed feature
*/
override def transform(dataset: Dataset[_]): DataFrame = {
- assert(inN.nonEmpty, "Inputs cannot be empty")
+ require(inN.nonEmpty, "Inputs cannot be empty")
val newSchema = setInputSchema(dataset.schema).transformSchema(dataset.schema)
val functionUDF = FeatureSparkTypes.udfN[I, O](transformFn)
val meta = newSchema(getOutputFeatureName).metadata

@@ -74,12 +74,12 @@ case class OpVectorColumnMetadata // TODO make separate case classes extending t
index: Int = 0
) extends JsonLike {

- assert(parentFeatureName.nonEmpty, "must provide parent feature name")
- assert(parentFeatureType.nonEmpty, "must provide parent type name")
- assert(parentFeatureName.length == parentFeatureType.length,
+ require(parentFeatureName.nonEmpty, "must provide parent feature name")
+ require(parentFeatureType.nonEmpty, "must provide parent type name")
+ require(parentFeatureName.length == parentFeatureType.length,
s"must provide both type and name for every parent feature," +
s" names: $parentFeatureName and types: $parentFeatureType do not have the same length")
- assert(indicatorValue.isEmpty || descriptorValue.isEmpty, "cannot have both indicatorValue and descriptorValue")
+ require(indicatorValue.isEmpty || descriptorValue.isEmpty, "cannot have both indicatorValue and descriptorValue")

/**
* Convert this column into Spark metadata.
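
Because these require calls sit in the case class body, they run at construction time, which is exactly what the property test below exercises. A minimal standalone sketch (hypothetical ColumnMeta, not the repository's class):

    // Invalid arguments now fail at instantiation with IllegalArgumentException.
    case class ColumnMeta(parentNames: Seq[String], parentTypes: Seq[String]) {
      require(parentNames.nonEmpty, "must provide parent feature name")
      require(parentNames.length == parentTypes.length,
        s"names $parentNames and types $parentTypes do not have the same length")
    }

    // ColumnMeta(Seq.empty, Seq.empty) throws:
    // java.lang.IllegalArgumentException: requirement failed: must provide parent feature name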

@@ -100,7 +100,7 @@ class OPVectorMetadataTest extends PropSpec with TestCommon with PropertyChecks
property("column metadata cannot be created with empty parents or feature types") {
forAll(vecColTupleGen) { (vct: OpVectorColumnTuple) =>
if (!checkTuples(vct)) {
- assertThrows[AssertionError] { OpVectorColumnMetadata(vct._1, vct._2, vct._3, vct._4, vct._5) }
+ assertThrows[IllegalArgumentException] { OpVectorColumnMetadata(vct._1, vct._2, vct._3, vct._4, vct._5) }
}
}
}

@@ -128,7 +128,7 @@ private[op] abstract class JoinedReader[T, U]

final def subReaders: Seq[DataReader[_]] = {
val allReaders = Seq(leftReader.subReaders, rightReader.subReaders).flatten
- assert(allReaders.size == allReaders.distinct.size, "Cannot have duplicate readers in joins")
+ require(allReaders.size == allReaders.distinct.size, "Cannot have duplicate readers in joins")
allReaders
}


@@ -150,7 +150,7 @@ trait Reader[T] extends ReaderType[T] {
): JoinedDataReader[T, U] = {
val joinedReader =
new JoinedDataReader[T, U](leftReader = this, rightReader = other, joinKeys = joinKeys, joinType = joinType)
- assert(joinedReader.leftReader.subReaders
+ require(joinedReader.leftReader.subReaders
.forall(r => r.fullTypeName != joinedReader.rightReader.fullTypeName),
"All joins must be for readers of different objects - self joins are not supported"
)

@@ -84,19 +84,19 @@ class JoinedReadersTest extends FlatSpec with PassengerSparkFixtureTest {
}

it should "throw an error if you try to perform a self join" in {
- a[AssertionError] should be thrownBy {
+ a[IllegalArgumentException] should be thrownBy {
dataReader.innerJoin(dataReader)
}
}

it should "throw an error if you try to use the same reader twice" in {
- a[AssertionError] should be thrownBy {
+ a[IllegalArgumentException] should be thrownBy {
dataReader.innerJoin(sparkReader).innerJoin(dataReader)
}
}

it should "throw an error if you try to read the same data type twice with different readers" in {
- a[AssertionError] should be thrownBy {
+ a[IllegalArgumentException] should be thrownBy {
passengerReader.innerJoin(sparkReader).outerJoin(dataReader)
}
}