diff --git a/hybrid/composers.py b/hybrid/composers.py index 2b92f49..4d8e55d 100644 --- a/hybrid/composers.py +++ b/hybrid/composers.py @@ -94,9 +94,10 @@ def next(self, states, **runopts): synthesis_en = thesis_en # input sanity check - # TODO: convert to hard input validation - assert len(thesis) == len(antithesis) - assert state_thesis.problem == state_antithesis.problem + if len(thesis) != len(antithesis): + raise ValueError("thesis-antithesis length mismatch") + if state_thesis.problem != state_antithesis.problem: + raise ValueError("thesis and antithesis refer to different problems") diff = {v for v in thesis if thesis[v] != antithesis[v]} @@ -116,7 +117,8 @@ def next(self, states, **runopts): synthesis_samples = SampleSet.from_samples_bqm(synthesis, bqm) # calculation sanity check - assert synthesis_samples.first.energy == synthesis_en + if synthesis_samples.first.energy != synthesis_en: + logger.error("Synthesis error: lowest energy sample is not on synthesis path.") return state_thesis.updated(samples=synthesis_samples) diff --git a/hybrid/core.py b/hybrid/core.py index ad1a1ba..fb0ee2b 100644 --- a/hybrid/core.py +++ b/hybrid/core.py @@ -663,9 +663,10 @@ def __init__(self, sampler, fields, **sample_kwargs): if not isinstance(sampler, dimod.Sampler): raise TypeError("'sampler' should be 'dimod.Sampler'") try: - assert len(tuple(fields)) == 2 + if len(tuple(fields)) != 2: + raise ValueError except: - raise ValueError("'fields' should be two-tuple with input/output state fields") + raise ValueError("'fields' should be a two-tuple with input/output state fields") self.sampler = sampler self.input, self.output = fields diff --git a/tests/test_core.py b/tests/test_core.py index 9ae38ee..1ba2bb6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -435,13 +435,14 @@ def test_validation(self): class TestHybridRunnable(unittest.TestCase): bqm = dimod.BinaryQuadraticModel({}, {'ab': 1, 'bc': 1, 'ca': -1}, 0, dimod.SPIN) init_state = State.from_sample(min_sample(bqm), bqm) + ground_energy = dimod.ExactSolver().sample(bqm).first.energy def test_generic(self): runnable = HybridRunnable(TabuSampler(), fields=('problem', 'samples')) response = runnable.run(self.init_state) self.assertIsInstance(response, concurrent.futures.Future) - self.assertEqual(response.result().samples.record[0].energy, -3.0) + self.assertEqual(response.result().samples.record[0].energy, self.ground_energy) def test_validation(self): with self.assertRaises(TypeError): @@ -462,7 +463,7 @@ def test_problem_sampler_runnable(self): response = runnable.run(self.init_state) self.assertIsInstance(response, concurrent.futures.Future) - self.assertEqual(response.result().samples.record[0].energy, -3.0) + self.assertEqual(response.result().samples.record[0].energy, self.ground_energy) def test_subproblem_sampler_runnable(self): runnable = HybridSubproblemRunnable(TabuSampler()) @@ -470,46 +471,48 @@ def test_subproblem_sampler_runnable(self): response = runnable.run(state) self.assertIsInstance(response, concurrent.futures.Future) - self.assertEqual(response.result().subsamples.record[0].energy, -3.0) + self.assertEqual(response.result().subsamples.record[0].energy, self.ground_energy) def test_runnable_composition(self): runnable = IdentityDecomposer() | HybridSubproblemRunnable(TabuSampler()) | IdentityComposer() response = runnable.run(self.init_state) self.assertIsInstance(response, concurrent.futures.Future) - self.assertEqual(response.result().samples.record[0].energy, -3.0) + self.assertEqual(response.result().samples.record[0].energy, self.ground_energy) def test_racing_workflow_with_oracle_subsolver(self): + class ExactSolver(dimod.ExactSolver): + """Exact solver that returns only the ground state.""" + def sample(self, bqm): + return super().sample(bqm).truncate(1) + workflow = hybrid.LoopUntilNoImprovement(hybrid.RacingBranches( hybrid.InterruptableTabuSampler(), hybrid.EnergyImpactDecomposer(size=1) - | HybridSubproblemRunnable(dimod.ExactSolver()) + | HybridSubproblemRunnable(ExactSolver()) | hybrid.SplatComposer() ) | hybrid.ArgMin(), convergence=3) - state = State.from_sample(min_sample(self.bqm), self.bqm) - response = workflow.run(state) + response = workflow.run(self.init_state) self.assertIsInstance(response, concurrent.futures.Future) - self.assertEqual(response.result().samples.record[0].energy, -3.0) + self.assertEqual(response.result().samples.record[0].energy, self.ground_energy) def test_sampling_parameters_filtering(self): class Sampler(dimod.ExactSolver): """Exact solver that fails if a sampling parameter is provided.""" parameters = {} def sample(self, bqm): - return super().sample(bqm) + return super().sample(bqm).truncate(1) - workflow = hybrid.LoopUntilNoImprovement(hybrid.RacingBranches( - hybrid.InterruptableTabuSampler(), - hybrid.EnergyImpactDecomposer(size=1) - | HybridSubproblemRunnable(Sampler()) - | hybrid.SplatComposer() - ) | hybrid.ArgMin(), convergence=3) - state = State.from_sample(min_sample(self.bqm), self.bqm) - response = workflow.run(state) + workflow = ( + hybrid.IdentityDecomposer() + | HybridSubproblemRunnable(Sampler(), unknown_sampler_argument=1) + | hybrid.IdentityComposer() + ) + response = workflow.run(self.init_state) self.assertIsInstance(response, concurrent.futures.Future) - self.assertEqual(response.result().samples.record[0].energy, -3.0) + self.assertEqual(response.result().samples.record[0].energy, self.ground_energy) class TestLogging(unittest.TestCase):