From 1c2291c04a6cd738e32d11052e8d8f0b9363324a Mon Sep 17 00:00:00 2001 From: Alexander Condello Date: Thu, 10 Jun 2021 16:22:20 -0700 Subject: [PATCH] Allow load method to work on registered model types --- dimod/binary/binary_quadratic_model.py | 6 ++- dimod/discrete/discrete_quadratic_model.py | 6 ++- dimod/serialization/fileview.py | 46 +++++++++++++++++++--- tests/test_serialization_fileview.py | 19 +++++++++ 4 files changed, 70 insertions(+), 7 deletions(-) diff --git a/dimod/binary/binary_quadratic_model.py b/dimod/binary/binary_quadratic_model.py index f84ac43f6..f26459fd0 100644 --- a/dimod/binary/binary_quadratic_model.py +++ b/dimod/binary/binary_quadratic_model.py @@ -37,7 +37,7 @@ from dimod.binary.pybqm import pyBQM from dimod.binary.vartypeview import VartypeView from dimod.decorators import forwarding_method -from dimod.serialization.fileview import SpooledTemporaryFile, _BytesIO, VariablesSection +from dimod.serialization.fileview import SpooledTemporaryFile, _BytesIO, VariablesSection, load from dimod.sym import Eq, Ge, Le from dimod.typing import Bias, Variable from dimod.variables import Variables, iter_deserialize_variables @@ -1909,3 +1909,7 @@ def as_bqm(*args, cls: None = None, copy: bool = False, return bqm return BinaryQuadraticModel(*args, dtype=dtype) + + +# register fileview loader +load.register(BQM_MAGIC_PREFIX, BinaryQuadraticModel.from_file) diff --git a/dimod/discrete/discrete_quadratic_model.py b/dimod/discrete/discrete_quadratic_model.py index 26dd909a2..d629bfc01 100644 --- a/dimod/discrete/discrete_quadratic_model.py +++ b/dimod/discrete/discrete_quadratic_model.py @@ -25,7 +25,7 @@ from dimod.discrete.cydiscrete_quadratic_model import cyDiscreteQuadraticModel from dimod.sampleset import as_samples -from dimod.serialization.fileview import VariablesSection, _BytesIO, SpooledTemporaryFile +from dimod.serialization.fileview import VariablesSection, _BytesIO, SpooledTemporaryFile, load from dimod.variables import Variables from typing import List, Tuple, Union, Generator @@ -803,3 +803,7 @@ def to_numpy_vectors(self): DQM = DiscreteQuadraticModel # alias + + +# register fileview loader +load.register(DQM_MAGIC_PREFIX, DiscreteQuadraticModel.from_file) diff --git a/dimod/serialization/fileview.py b/dimod/serialization/fileview.py index 674d7a698..272f453ae 100644 --- a/dimod/serialization/fileview.py +++ b/dimod/serialization/fileview.py @@ -19,6 +19,8 @@ import tempfile import warnings +from typing import ByteString, Callable, Mapping + import numpy as np from dimod.variables import iter_deserialize_variables, iter_serialize_variables @@ -221,8 +223,16 @@ def seekable(self): return True +_loaders: Mapping[bytes, Callable] = dict() + + +def register(prefix: bytes, loader: Callable): + """Register a new loader.""" + _loaders[prefix] = loader + + def load(fp, cls=None): - """Load a binary quadratic model from a file. + """Load a model from a file. Args: fp (bytes-like/file-like): @@ -233,9 +243,35 @@ def load(fp, cls=None): Deprecated keyword argument. Is ignored. Returns: - The loaded bqm. + The loaded model. """ - # todo: handle DQM - from dimod.binary.binary_quadratic_model import BinaryQuadraticModel - return BinaryQuadraticModel.from_file(fp) + if cls is not None: + warnings.warn("'cls' keyword argument is deprecated and ignored", + DeprecationWarning, stacklevel=2) + + if isinstance(fp, ByteString): + file_like: BinaryIO = _BytesIO(fp) # type: ignore[assignment] + else: + file_like = fp + + if not file_like.seekable: + raise ValueError("expected file-like to be seekable") + + pos = file_like.tell() + + lengths = sorted(set(map(len, _loaders))) + for num_bytes in lengths: + prefix = file_like.read(num_bytes) + file_like.seek(pos) + + try: + return _loaders[prefix](file_like) + except KeyError: + pass + + raise ValueError("cannot load the given file-like") + + +# for slightly more explicit naming +load.register = register diff --git a/tests/test_serialization_fileview.py b/tests/test_serialization_fileview.py index a9051e481..3e4408fbd 100644 --- a/tests/test_serialization_fileview.py +++ b/tests/test_serialization_fileview.py @@ -173,3 +173,22 @@ def test_unhashable_variables(self, name, BQM, version): new = load(fv) self.assertEqual(new, bqm) + + +class TestLoad(unittest.TestCase): + def test_bqm(self): + bqm = BinaryQuadraticModel({'a': -1}, {'ab': 1}, 7, 'SPIN') + self.assertEqual(bqm, load(bqm.to_file())) + + def test_dqm(self): + dqm = dimod.DiscreteQuadraticModel() + dqm.add_variable(5, 'a') + dqm.add_variable(6, 'b') + dqm.set_quadratic_case('a', 0, 'b', 5, 1.5) + + new = load(dqm.to_file()) + + self.assertEqual(dqm.num_variables(), new.num_variables()) + self.assertEqual(dqm.num_cases(), new.num_cases()) + self.assertEqual(dqm.get_quadratic_case('a', 0, 'b', 5), + new.get_quadratic_case('a', 0, 'b', 5))