diff --git a/dimod/sampleset.py b/dimod/sampleset.py index bfcd2cc0f..f79d62f61 100644 --- a/dimod/sampleset.py +++ b/dimod/sampleset.py @@ -237,7 +237,16 @@ def concatenate(samplesets, defaults=None): record = recfunctions.stack_arrays(records, defaults=defaults, asrecarray=True, usemask=False) - return SampleSet(record, variables, {}, vartype) + # Merge info, preserving conflicts as lists + info = {} + for k in set().union(*[s.info for s in samplesets]): + info[k] = [] + for s in samplesets: + if k in s.info and s.info[k] not in info[k]: + info[k] += [s.info[k]] + info[k] = info[k] if len(info[k])>1 else info[k].pop() + + return SampleSet(record, variables, info, vartype) def _iter_records(samplesets, vartype, variables): diff --git a/tests/test_sampleset.py b/tests/test_sampleset.py index 74672acad..8b460a134 100644 --- a/tests/test_sampleset.py +++ b/tests/test_sampleset.py @@ -900,6 +900,7 @@ def test_simple(self): out = dimod.SampleSet.from_samples([[-1, +1], [+1, -1], [+1, +1], [-1, -1]], dimod.SPIN, energy=[-1, -1, 1, 1]) self.assertEqual(comb, out) + self.assertEqual(comb.info, {}) np.testing.assert_array_equal(comb.record.sample, out.record.sample) def test_variables_order(self): @@ -941,6 +942,17 @@ def test_variables_order_and_vartype(self): self.assertEqual(comb, out) np.testing.assert_array_equal(comb.record.sample, out.record.sample) + def test_info(self): + ss0 = dimod.SampleSet.from_samples(([-1, +1], 'ab'), dimod.SPIN, info={}, energy=-1) + ss1 = dimod.SampleSet.from_samples(([-1, +1], 'ba'), dimod.SPIN, info={1:'a',2:['b','c']}, energy=-1) + ss2 = dimod.SampleSet.from_samples(([+1, +1], 'ab'), dimod.SPIN, info={3:'e',2:'d',4:[]}, energy=+1) + + comb = dimod.concatenate((ss0, ss1, ss2)) + + out_info = {1:'a',2:[['b','c'],'d'],3:'e',4:[]} + + self.assertEqual(comb.info, out_info) + def test_empty(self): with self.assertRaises(ValueError): dimod.concatenate([])