Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include chain break points in returned embedding context #447

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dwave/embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from dwave.embedding.diagnostic import diagnose_embedding, is_valid_embedding, verify_embedding

from dwave.embedding.chain_breaks import broken_chains
from dwave.embedding.chain_breaks import broken_chains, break_points
from dwave.embedding.chain_breaks import discard, majority_vote, weighted_random, MinimizeEnergy

from dwave.embedding.transforms import embed_bqm, embed_ising, embed_qubo, unembed_sampleset, EmbeddedStructure
Expand Down
66 changes: 66 additions & 0 deletions dwave/embedding/chain_breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,75 @@
'majority_vote',
'weighted_random',
'MinimizeEnergy',
'break_points',
]


def break_points(samples_like, embedding):
"""Identify breakpoints in each chain.

Args:
samples_like (dimod.typing.SamplesLike):
A collection of raw samples for the embedded problem.
Each sample's variables' values should be 0/1 or -1/+1.

embedding (dwave.embedding.transforms.EmbeddedStructure):
Mapping from source graph to target graph as a dict of form {s: [t, ...], ...},
where s is a source-model variable and t is a target-model variable.

Returns:
list: A list of `dict`. The size of the list is equal to number of input samples:

dict: A dictionary whose keys are variable labels of the logical BQM
(the problem you care about), and values are lists of 2-tuples `(u, v)`
representing edges in the target graph (the QPU graph). The existence
of an edge (u, v) indicates `u` and `v` disagree in its value, i.e.,
a break point in the chain.

The index of the list corresponds to the index of the sample in `samples_like`.

Examples:

>>> from dwave.embedding.transforms import EmbeddedStructure

>>> embedding = EmbeddedStructure([(0,1), (1,2)], {0: [0, 1, 2]})
>>> samples = np.array([[-1, +1, -1], [-1, -1, -1]], dtype=np.int8)
>>> dwave.embedding.break_points(samples, embedding)
[{0: [(0, 1), (1, 2)]}, {}]

>>> embedding = EmbeddedStructure([(0,1), (1,2), (0,2)], {0: [0, 1, 2]})
>>> samples = np.array([[-1, +1, -1], [-1, +1, +1]], dtype=np.int8)
>>> dwave.embedding.break_points(samples, embedding)
[{0: [(0, 1), (1, 2)]}, {0: [(0, 1), (0, 2)]}]

>>> samples = [{"a": -1, "b": +1, "c": -1}, {"a": -1, "b": -1, "c": -1},]
>>> target_edges = [("a", "b"), ("b", "c")]
>>> chains = {"x": ["a", "b", "c"]}
>>> embedding = EmbeddedStructure(target_edges, chains)
>>> dwave.embedding.break_points(samples, embedding)
[{"x": [("a", "b"), ("b", "c")]}, {}]
"""
result = []
samples, labels = dimod.as_samples(samples_like)
label_to_i = {label: idx for idx, label in enumerate(labels)}
for sample in samples:
bps = {}
for node in embedding.keys():
try:
chain_edges = embedding.chain_edges(node)
except AttributeError:
raise TypeError("'embedding' must be a dwave.embedding.EmbeddedStructure") from None

broken_edges = [(u, v) for u, v in chain_edges
if sample[label_to_i[u]] != sample[label_to_i[v]]]
if len(broken_edges) > 0:
bps[node] = broken_edges

result.append(bps)

return result


def broken_chains(samples, chains):
"""Find the broken chains.

Expand Down
4 changes: 3 additions & 1 deletion dwave/system/composites/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from dwave.embedding import (target_to_source, unembed_sampleset, embed_bqm,
chain_to_quadratic, EmbeddedStructure)
from dwave.embedding.chain_breaks import break_points
from dwave.system.warnings import WarningHandler, WarningAction

__all__ = ('EmbeddingComposite',
Expand Down Expand Up @@ -289,7 +290,8 @@ def async_unembed(response):
if return_embedding:
sampleset.info['embedding_context'].update(
embedding_parameters=embedding_parameters,
chain_strength=embedding.chain_strength)
chain_strength=embedding.chain_strength,
break_points=break_points(response, embedding))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a non-trivial performance hit. IMO we either should not do this by default or we need to write a more performant implementation of break_points.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, ideally this would be a lazy proxy.

But FWIW, return_embedding does default to False.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, unless for instance the inspector is imported.

My inclination is to not include this in the embedding composite for now, but document how to use the TrackingComposite to get the information in the docstring of break_points().

Copy link
Contributor Author

@kevinchern kevinchern May 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the PR according to feedback except for this comment. Should I move the statistic to TrackingComposite instead?

I think the lazy proxy approach would require storing response. For a more performant implementation, I can wrap it in numba. Any suggestions for writing a more performant implementation?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the best thing is just to remove the statistic from the embedding composite altogether. I would then add an example to the break_points docstring showing how to calculate it, using the TrackingComposite to retrieve the relevant information.


if chain_break_fraction and len(sampleset):
warninghandler.issue("All samples have broken chains",
Expand Down
102 changes: 102 additions & 0 deletions tests/test_embedding_chain_breaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,108 @@
import numpy as np

import dwave.embedding
from dwave.embedding.transforms import EmbeddedStructure


class TestBreakPoints(unittest.TestCase):
def test_break_points_no_samples(self):
# No samples
target_edges = [(0, 1), (1, 2)]
chains = {0: [0, 1, 2]}
embedding = EmbeddedStructure(target_edges, chains)

samples = np.array([], dtype=np.int8)

break_points = dwave.embedding.break_points(samples, embedding)
answer = []

np.testing.assert_array_equal(answer, break_points)

def test_break_points_no_breaks(self):
# No breaks :D
target_edges = [(0, 1), (1, 2)]
chains = {0: [0, 1, 2]}
embedding = EmbeddedStructure(target_edges, chains)

samples = np.array([[+1, +1, +1],
[+1, +1, +1]], dtype=np.int8)

break_points = dwave.embedding.break_points(samples, embedding)
answer = [{}, {}]

np.testing.assert_array_equal(answer, break_points)

def test_break_points_chain(self):
# Target chain of length 3, one embedded variable
target_edges = [(0, 1), (1, 2)]
chains = {0: [0, 1, 2]}
embedding = EmbeddedStructure(target_edges, chains)

samples = np.array([[-1, +1, -1],
[-1, -1, -1]], dtype=np.int8)

break_points = dwave.embedding.break_points(samples, embedding)
answer = [{0: [(0, 1), (1, 2)]},
{}]

np.testing.assert_array_equal(answer, break_points)

def test_break_points_chain_string_labels(self):
# Target chain of length 3, one embedded variable, but labels are strings
target_edges = [("a", "b"), ("b", "c")]
chains = {"x": ["a", "b", "c"]}
embedding = EmbeddedStructure(target_edges, chains)

# samples = np.array([[-1, +1, -1],
# [-1, -1, -1]], dtype=np.int8)
samples = [{"a": -1, "b": +1, "c": -1},
{"a": -1, "b": -1, "c": -1},]

break_points = dwave.embedding.break_points(samples, embedding)
answer = [{"x": [("a", "b"), ("b", "c")]},
{}]

np.testing.assert_array_equal(answer, break_points)

def test_break_points_loop(self):
# Target triangle, one embedded variable
target_edges = [(0, 1), (1, 2), (0, 2)]
chains = {0: [0, 1, 2]}
embedding = EmbeddedStructure(target_edges, chains)

samples = np.array([[-1, +1, -1],
[-1, +1, +1]], dtype=np.int8)

break_points = dwave.embedding.break_points(samples, embedding)
answer = [{0: [(0, 1), (1, 2)]},
{0: [(0, 1), (0, 2)]}]
np.testing.assert_array_equal(answer, break_points)

def test_break_points_chain_2(self):
# Target triangle, two embedded variables
target_edges = [(0, 1), (1, 2)]
chains = {0: [0, 1], 1: [2]}

embedding = EmbeddedStructure(target_edges, chains)
samples = np.array([[-1, +1, -1],
[-1, -1, +1]], dtype=np.int8)

break_points = dwave.embedding.break_points(samples, embedding)
answer = [{0: [(0, 1)]},
{}]
np.testing.assert_array_equal(answer, break_points)

def test_break_points_loop_2(self):
# Target square, two embedded variables
target_edges = [(0, 1), (1, 2), (2, 3), (0, 3)]
chains = {0: [0, 1], 1: [2, 3]}
embedding = EmbeddedStructure(target_edges, chains)
samples = np.array([[-1, -1, +1, -1],
[-1, +1, +1, -1]], dtype=np.int8)
break_points = dwave.embedding.break_points(samples, embedding)
answer = [{1: [(2, 3)]}, {0: [(0, 1)],
1: [(2, 3)]}]
np.testing.assert_array_equal(answer, break_points)


class TestBrokenChains(unittest.TestCase):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_embedding_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def test_return_embedding(self):
embedding = sampleset.info['embedding_context']['embedding']
self.assertEqual(set(embedding), {'a', 'c'})

self.assertIn('break_points', sampleset.info['embedding_context'])
break_points = sampleset.info['embedding_context']['break_points']
self.assertEqual(break_points, [])

self.assertIn('chain_break_method', sampleset.info['embedding_context'])
self.assertEqual(sampleset.info['embedding_context']['chain_break_method'], 'majority_vote') # the default

Expand Down Expand Up @@ -277,6 +281,10 @@ def test_return_embedding_as_class_variable(self):
embedding = sampleset.info['embedding_context']['embedding']
self.assertEqual(set(embedding), {'a', 'c'})

self.assertIn('break_points', sampleset.info['embedding_context'])
break_points = sampleset.info['embedding_context']['break_points']
self.assertEqual(break_points, [])

self.assertIn('chain_break_method', sampleset.info['embedding_context'])
self.assertEqual(sampleset.info['embedding_context']['chain_break_method'], 'majority_vote') # the default

Expand Down