Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flaky tests #300

Merged
merged 4 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions hybrid/composers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ def next(self, states, **runopts):
synthesis_en = thesis_en

# input sanity check
# TODO: convert to hard input validation
assert len(thesis) == len(antithesis)
assert state_thesis.problem == state_antithesis.problem
if len(thesis) != len(antithesis):
raise ValueError("thesis-antithesis length mismatch")
if state_thesis.problem != state_antithesis.problem:
raise ValueError("thesis and antithesis refer to different problems")

diff = {v for v in thesis if thesis[v] != antithesis[v]}

Expand All @@ -116,7 +117,8 @@ def next(self, states, **runopts):
synthesis_samples = SampleSet.from_samples_bqm(synthesis, bqm)

# calculation sanity check
assert synthesis_samples.first.energy == synthesis_en
if synthesis_samples.first.energy != synthesis_en:
logger.error("Synthesis error: lowest energy sample is not on synthesis path.")

return state_thesis.updated(samples=synthesis_samples)

Expand Down
5 changes: 3 additions & 2 deletions hybrid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,10 @@ def __init__(self, sampler, fields, **sample_kwargs):
if not isinstance(sampler, dimod.Sampler):
raise TypeError("'sampler' should be 'dimod.Sampler'")
try:
assert len(tuple(fields)) == 2
if len(tuple(fields)) != 2:
raise ValueError
except:
raise ValueError("'fields' should be two-tuple with input/output state fields")
raise ValueError("'fields' should be a two-tuple with input/output state fields")

self.sampler = sampler
self.input, self.output = fields
Expand Down
39 changes: 21 additions & 18 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,13 +435,14 @@ def test_validation(self):
class TestHybridRunnable(unittest.TestCase):
bqm = dimod.BinaryQuadraticModel({}, {'ab': 1, 'bc': 1, 'ca': -1}, 0, dimod.SPIN)
init_state = State.from_sample(min_sample(bqm), bqm)
ground_energy = dimod.ExactSolver().sample(bqm).first.energy

def test_generic(self):
runnable = HybridRunnable(TabuSampler(), fields=('problem', 'samples'))
response = runnable.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_validation(self):
with self.assertRaises(TypeError):
Expand All @@ -462,54 +463,56 @@ def test_problem_sampler_runnable(self):
response = runnable.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_subproblem_sampler_runnable(self):
runnable = HybridSubproblemRunnable(TabuSampler())
state = self.init_state.updated(subproblem=self.bqm)
response = runnable.run(state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().subsamples.record[0].energy, -3.0)
self.assertEqual(response.result().subsamples.record[0].energy, self.ground_energy)

def test_runnable_composition(self):
runnable = IdentityDecomposer() | HybridSubproblemRunnable(TabuSampler()) | IdentityComposer()
response = runnable.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_racing_workflow_with_oracle_subsolver(self):
class ExactSolver(dimod.ExactSolver):
"""Exact solver that returns only the ground state."""
def sample(self, bqm):
return super().sample(bqm).truncate(1)

workflow = hybrid.LoopUntilNoImprovement(hybrid.RacingBranches(
hybrid.InterruptableTabuSampler(),
hybrid.EnergyImpactDecomposer(size=1)
| HybridSubproblemRunnable(dimod.ExactSolver())
| HybridSubproblemRunnable(ExactSolver())
| hybrid.SplatComposer()
) | hybrid.ArgMin(), convergence=3)
state = State.from_sample(min_sample(self.bqm), self.bqm)
response = workflow.run(state)
response = workflow.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_sampling_parameters_filtering(self):
class Sampler(dimod.ExactSolver):
"""Exact solver that fails if a sampling parameter is provided."""
parameters = {}
def sample(self, bqm):
return super().sample(bqm)
return super().sample(bqm).truncate(1)

workflow = hybrid.LoopUntilNoImprovement(hybrid.RacingBranches(
hybrid.InterruptableTabuSampler(),
hybrid.EnergyImpactDecomposer(size=1)
| HybridSubproblemRunnable(Sampler())
| hybrid.SplatComposer()
) | hybrid.ArgMin(), convergence=3)
state = State.from_sample(min_sample(self.bqm), self.bqm)
response = workflow.run(state)
workflow = (
hybrid.IdentityDecomposer()
| HybridSubproblemRunnable(Sampler(), unknown_sampler_argument=1)
| hybrid.IdentityComposer()
)
response = workflow.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)


class TestLogging(unittest.TestCase):
Expand Down