Skip to content

Commit

Permalink
progress toward getting tests to play well with error-producing behav…
Browse files Browse the repository at this point in the history
…ior when trace ID has multiple instances of the same locus ID; #392
  • Loading branch information
vreuter committed Dec 10, 2024
1 parent aebdbb0 commit b4dffae
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions src/test/scala/TestComputeLocusPairwiseDistances.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import at.ac.oeaw.imba.gerlich.looptrace.csv.ColumnNames.TraceGroupColumnName
import at.ac.oeaw.imba.gerlich.looptrace.instances.all.given
import at.ac.oeaw.imba.gerlich.looptrace.space.*
import at.ac.oeaw.imba.gerlich.looptrace.syntax.all.*
import at.ac.oeaw.imba.gerlich.looptrace.ComputeLocusPairwiseDistances.Input.getGroupingKey

/**
* Tests for the simple pairwise distances computation program, for locus-specific spots
Expand Down Expand Up @@ -245,7 +246,7 @@ class TestComputeLocusPairwiseDistances extends AnyFunSuite, ScalaCheckPropertyC
}
case rs => throw new Exception(s"Got ${rs.length} records when taking pairs!")
}
whenever(possiblePrefixes.nonEmpty):
whenever(possiblePrefixes.nonEmpty && recordsSatisfyUniqueLocusPerTraceConstraint(records)):
val obsMsg = intercept[Exception]{ inputRecordsToOutputRecords(NonnegativeInt.indexed(records)) }.getMessage
possiblePrefixes.exists(obsMsg.startsWith) shouldBe true
}
Expand Down Expand Up @@ -289,11 +290,12 @@ class TestComputeLocusPairwiseDistances extends AnyFunSuite, ScalaCheckPropertyC
/* To encourage collisions, narrow the choices for grouping components. */
given arbPosAndTrace: Arbitrary[(PositionName, TraceId)] = genPosTracePairOneOrOther.toArbitrary
given arbRegion: Arbitrary[RegionId] = Gen.oneOf(40, 41).map(RegionId.unsafe).toArbitrary
forAll (Gen.choose(10, 100).flatMap(Gen.listOfN(_, arbitrary[Input.GoodRecord]))) { (records: List[Input.GoodRecord]) =>
val indexedRecords = NonnegativeInt.indexed(records)
val getKey = indexedRecords.map(_.swap).toMap.apply.andThen(Input.getGroupingKey)
val observed = inputRecordsToOutputRecords(indexedRecords)
observed.filter{ r => getKey(r.inputIndex1) === getKey(r.inputIndex2) } shouldEqual observed
forAll (Gen.choose(10, 100).flatMap(Gen.resize(_, arbitrary[List[Input.GoodRecord]]))) {
(records: List[Input.GoodRecord]) =>
val indexedRecords = NonnegativeInt.indexed(records)
val getKey = indexedRecords.map(_.swap).toMap.apply.andThen(Input.getGroupingKey)
val observed = inputRecordsToOutputRecords(indexedRecords)
observed.filter{ r => getKey(r.inputIndex1) === getKey(r.inputIndex2) } shouldEqual observed
}
}

Expand All @@ -303,7 +305,7 @@ class TestComputeLocusPairwiseDistances extends AnyFunSuite, ScalaCheckPropertyC
given arbRegion: Arbitrary[RegionId] = Gen.oneOf(40, 41).map(RegionId.unsafe).toArbitrary
given arbTime: Arbitrary[ImagingTimepoint] = Gen.const(ImagingTimepoint(NonnegativeInt(10))).toArbitrary

forAll (Gen.choose(10, 100).flatMap(Gen.listOfN(_, arbitrary[Input.GoodRecord]))) {
forAll (Gen.choose(10, 100).flatMap(Gen.resize(_, arbitrary[List[Input.GoodRecord]]))) {
(records: List[Input.GoodRecord]) => inputRecordsToOutputRecords(NonnegativeInt.indexed(records)).toList shouldEqual List()
}
}
Expand All @@ -322,6 +324,17 @@ class TestComputeLocusPairwiseDistances extends AnyFunSuite, ScalaCheckPropertyC
Input.GoodRecord(pos, TraceGroupOptional.empty, trace, reg, loc, pt)
}

given arbitraryForGoodInputRecords(using Arbitrary[Input.GoodRecord]): Arbitrary[List[Input.GoodRecord]] =
arbitrary[List[Input.GoodRecord]]
.suchThat(recordsSatisfyUniqueLocusPerTraceConstraint)
.toArbitrary

private def recordsSatisfyUniqueLocusPerTraceConstraint(records: List[Input.GoodRecord]): Boolean =
records.groupBy(getGroupingKey).view.values.forall(_.combinations(2).forall{
case r1 :: r2 :: Nil => !(r1.trace === r2.trace && r1.locus === r2.locus)
case rs => throw new Exception(s"Got ${rs.length} records when taking combinations of 2!")
})

private def genPosTracePairOneOrOther: Gen[(PositionName, TraceId)] =
val posNames = List("P0001.zarr", "P0002.zarr") map PositionName.unsafe
val traceIds = List(2, 3) map TraceId.unsafe
Expand Down

0 comments on commit b4dffae

Please sign in to comment.