From 5026b3d604372109c3e3ff8f8e0528efb369f831 Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Thu, 1 Dec 2022 09:04:52 -0800 Subject: [PATCH] Code restructure: break 'hyper.py' into smaller files. PiperOrigin-RevId: 492211440 --- pyglove/core/hyper.py | 2957 ----------------- pyglove/core/hyper/__init__.py | 118 + pyglove/core/hyper/base.py | 198 ++ pyglove/core/hyper/categorical.py | 696 ++++ pyglove/core/hyper/categorical_test.py | 404 +++ pyglove/core/hyper/custom.py | 159 + pyglove/core/hyper/custom_test.py | 111 + pyglove/core/hyper/derived.py | 154 + pyglove/core/hyper/derived_test.py | 137 + pyglove/core/hyper/dynamic_evaluation.py | 588 ++++ pyglove/core/hyper/dynamic_evaluation_test.py | 523 +++ pyglove/core/hyper/evolvable.py | 278 ++ pyglove/core/hyper/evolvable_test.py | 235 ++ pyglove/core/hyper/iter.py | 193 ++ pyglove/core/hyper/iter_test.py | 135 + pyglove/core/hyper/numerical.py | 219 ++ pyglove/core/hyper/numerical_test.py | 134 + pyglove/core/hyper/object_template.py | 577 ++++ pyglove/core/hyper/object_template_test.py | 269 ++ pyglove/core/hyper_test.py | 1829 ---------- 20 files changed, 5128 insertions(+), 4786 deletions(-) delete mode 100644 pyglove/core/hyper.py create mode 100644 pyglove/core/hyper/__init__.py create mode 100644 pyglove/core/hyper/base.py create mode 100644 pyglove/core/hyper/categorical.py create mode 100644 pyglove/core/hyper/categorical_test.py create mode 100644 pyglove/core/hyper/custom.py create mode 100644 pyglove/core/hyper/custom_test.py create mode 100644 pyglove/core/hyper/derived.py create mode 100644 pyglove/core/hyper/derived_test.py create mode 100644 pyglove/core/hyper/dynamic_evaluation.py create mode 100644 pyglove/core/hyper/dynamic_evaluation_test.py create mode 100644 pyglove/core/hyper/evolvable.py create mode 100644 pyglove/core/hyper/evolvable_test.py create mode 100644 pyglove/core/hyper/iter.py create mode 100644 pyglove/core/hyper/iter_test.py create mode 100644 pyglove/core/hyper/numerical.py create mode 100644 pyglove/core/hyper/numerical_test.py create mode 100644 pyglove/core/hyper/object_template.py create mode 100644 pyglove/core/hyper/object_template_test.py delete mode 100644 pyglove/core/hyper_test.py diff --git a/pyglove/core/hyper.py b/pyglove/core/hyper.py deleted file mode 100644 index 4883ef7..0000000 --- a/pyglove/core/hyper.py +++ /dev/null @@ -1,2957 +0,0 @@ -# Copyright 2019 The PyGlove Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Hyper objects: representing template-based object space. - -In PyGlove, an object space is represented by a hyper object, which is an -symbolic object that is placeheld by hyper primitives -(:class:`pyglove.hyper.HyperPrimitive`). Through hyper objects, object templates -(:class:`pyglove.hyper.ObjectTemplate`) can be obtained to generate objects -based on program genomes (:class:`pyglove.DNA`). - - .. graphviz:: - :align: center - - digraph hypers { - node [shape="box"]; - edge [arrowtail="empty" arrowhead="none" dir="back" style="dashed"]; - hyper [label="HyperValue" href="hyper_value.html"]; - template [label="ObjectTemplate" href="object_template.html"]; - primitive [label="HyperPrimitive" href="hyper_primitive.html"]; - choices [label="Choices" href="choices.html"]; - oneof [label="OneOf" href="oneof_class.html"]; - manyof [label="ManyOf" href="manyof_class.html"]; - float [label="Float" href="float.html"]; - custom [label="CustomHyper" href="custom_hyper.html"]; - hyper -> template; - hyper -> primitive; - primitive -> choices; - choices -> oneof; - choices -> manyof; - primitive -> float; - primitive -> custom - } - -Hyper values map 1:1 to genotypes as the following: - -+-------------------------------------+----------------------------------------+ -| Hyper class | Genotype class | -+=====================================+========================================+ -|:class:`pyglove.hyper.HyperValue` |:class:`pyglove.DNASpec` | -+-------------------------------------+----------------------------------------+ -|:class:`pyglove.hyper.ObjectTemplate`|:class:`pyglove.geno.Space` | -+-------------------------------------+----------------------------------------+ -|:class:`pyglove.hyper.HyperPrimitive`|:class:`pyglove.geno.DecisionPoint` | -+-------------------------------------+----------------------------------------+ -|:class:`pyglove.hyper.Choices` |:class:`pyglove.geno.Choices` | -+-------------------------------------+----------------------------------------+ -|:class:`pyglove.hyper.Float` |:class:`pyglove.geno.Float` | -+-------------------------------------+----------------------------------------+ -|:class:`pyglove.hyper.CustomHyper` :class:`pyglove.geno.CustomDecisionPoint` | -+------------------------------------------------------------------------------+ -""" - -import abc -import contextlib -import copy -import dataclasses -import enum -import numbers -import random -import threading -import types -import typing -from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union - -from pyglove.core import geno -from pyglove.core import object_utils -from pyglove.core import symbolic -from pyglove.core import typing as schema - - -# Disable implicit str concat in Tuple as it's used for multi-line docstr for -# symbolic members. -# pylint: disable=implicit-str-concat - - -class HyperValue(symbolic.NonDeterministic): # pytype: disable=ignored-metaclass - """Base class for a hyper value. - - Hyper value represents a space of objects, which is essential for - programmatically generating objects. It can encode a concrete object into a - DNA, or decode a DNA into a concrete object. - - DNA is a nestable numeric interface we use to generate object (see `geno.py`). - Each position in the DNA represents either the index of a choice, or a value - itself is numeric. There could be multiple choices standing side-by-side, - representing knobs on different parts of an object, or choices being chained, - forming conditional choice spaces, which can be described by a tree structure. - - Hyper values form a tree as the following: - - .. graphviz:: - - digraph relationship { - template [label="ObjectTemplate" href="object_template.html"]; - primitive [label="HyperPrimitive" href="hyper_primitive.html"]; - choices [label="OneOf/ManyOf" href="choices.html"]; - float [label="Float" href="float_class.html"]; - custom [label="CustomHyper" href="custom_hyper.html"]; - template -> primitive [label="elements (1:*)"]; - primitive -> choices [dir="back" arrowtail="empty" style="dashed"]; - primitive -> float [dir="back" arrowtail="empty" style="dashed"]; - primitive -> custom [dir="back" arrowtail="empty" style="dashed"]; - choices -> template [label="candidates (1:*)"]; - } - """ - - __metaclass__ = abc.ABCMeta - - def __init__(self): - # DNA and decoded value are states for __call__. - # Though `decode` and `encode` methods are stateless. - self._dna = None - self._decoded_value = None - - def set_dna(self, dna: geno.DNA) -> None: - """Use this DNA to generate value. - - NOTE(daiyip): self._dna is only used in __call__. - Thus 'set_dna' can be called multiple times to generate different values. - - Args: - dna: DNA to use to decode the value. - """ - self._dna = dna - # Invalidate decoded value when DNA is refreshed. - self._decoded_value = None - - @property - def dna(self) -> geno.DNA: - """Returns the DNA that is being used by this hyper value.""" - return self._dna - - def __call__(self) -> Any: - """Generate value from DNA provided by set_dna.""" - if self._decoded_value is None: - if self._dna is None: - raise ValueError( - '\'set_dna\' should be called to set a DNA before \'__call__\'.') - self._decoded_value = self.decode(self._dna) - return self._decoded_value - - def decode(self, dna: geno.DNA) -> Any: - """Decode a value from a DNA.""" - self.set_dna(dna) - return self._decode() - - @abc.abstractmethod - def _decode(self) -> Any: - """Decode using self.dna.""" - - @abc.abstractmethod - def encode(self, value: Any) -> geno.DNA: - """Encode a value into a DNA. - - Args: - value: A value that conforms to the hyper value definition. - - Returns: - DNA for the value. - """ - - @abc.abstractmethod - def dna_spec(self, - location: Optional[object_utils.KeyPath] = None) -> geno.DNASpec: - """Get DNA spec of DNA that is decodable/encodable by this hyper value.""" - - -@symbolic.members([ - ('name', schema.Str().noneable(), - 'Name of the hyper primitive. Useful in define-by-run mode to identify a' - 'decision point in the search space - that is - different instances with ' - 'the same name will refer to the same decision point in the search space ' - 'under define-by-run mode. ' - 'Please refer to `pg.hyper.trace` for details.'), - ('hints', schema.Any(default=None), 'Generator hints') -]) -class HyperPrimitive(symbolic.Object, HyperValue): - """Base class for hyper primitives. - - A hyper primitive is a pure symbolic object which represents an object - generation rule. It correspond to a decision point - (:class:`pyglove.geno.DecisionPoint`) in the algorithm's view. - - Child classes: - - * :class:`pyglove.hyper.Choices` - - * :class:`pyglove.hyper.OneOf` - * :class:`pyglove.hyper.ManyOf` - * :class:`pyglove.hyper.Float` - * :class:`pyglove.hyper.CustomHyper` - """ - - def __new__(cls, *args, **kwargs) -> Any: - """Overrides __new__ for supporting dynamic evaluation mode. - - Args: - *args: Positional arguments passed to init the custom hyper. - **kwargs: Keyword arguments passed to init the custom hyper. - - Returns: - A dynamic evaluated value according to current `dynamic_evaluate` context. - """ - dynamic_evaluate_fn = getattr( - _thread_local_state, - _TLS_KEY_DYNAMIC_EVALUATE_FN, - _global_dynamic_evaluate_fn) - - if dynamic_evaluate_fn is None: - return super().__new__(cls) - else: - hyper_value = object.__new__(cls) - cls.__init__(hyper_value, *args, **kwargs) - return dynamic_evaluate_fn(hyper_value) # pylint: disable=not-callable - - def _sym_clone(self, deep: bool, memo=None) -> 'HyperPrimitive': - """Overrides _sym_clone to force no dynamic evaluation.""" - kwargs = dict() - for k, v in self._sym_attributes.items(): - if deep or isinstance(v, symbolic.Symbolic): - v = symbolic.clone(v, deep, memo) - kwargs[k] = v - - # NOTE(daiyip): instead of calling self.__class__(...), - # we manually create a new instance without invoking dynamic - # evaluation. - new_value = object.__new__(self.__class__) - new_value.__init__( # pylint: disable=unexpected-keyword-arg - allow_partial=self._allow_partial, sealed=self._sealed, **kwargs) - return new_value - - -@symbolic.members([ - ('num_choices', schema.Int(min_value=0).noneable(), - 'Number of choices to make. If set to None, any number of choices is ' - 'acceptable.'), - ('candidates', schema.List(schema.Any()), - 'Candidate values, which may contain nested hyper values.' - 'Candidate can customize its display value (literal) by implementing the ' - '`pg.Formattable` interface.'), - ('choices_distinct', schema.Bool(True), 'Whether choices are distinct.'), - ('choices_sorted', schema.Bool(False), 'Whether choices are sorted.'), - ('where', schema.Callable([schema.Object(HyperPrimitive)], - returns=schema.Bool()).noneable(), - 'Callable object to filter nested hyper values. If None, all nested hyper ' - 'value will be included in the encoding/decoding process. Otherwise only ' - 'the hyper values on which `where` returns True will be included. `where` ' - 'can be useful to partition a search space into separate optimization ' - 'processes. Please see `ObjectTemplate` docstr for details.') -]) -class Choices(HyperPrimitive): - """Categorical choices from a list of candidates. - - Example:: - - # A single categorical choice: - v = pg.oneof([1, 2, 3]) - - # A multiple categorical choice as a list: - vs = pg.manyof(2, [1, 2, 3]) - - # A hierarchical categorical choice: - v2 = pg.oneof([ - 'foo', - 'bar', - pg.manyof(2, [1, 2, 3]) - ]) - - See also: - - * :class:`pyglove.hyper.OneOf` - * :class:`pyglove.hyper.ManyOf` - * :func:`pyglove.oneof` - * :func:`pyglove.manyof` - * :func:`pyglove.permutate` - """ - - def _on_bound(self): - """On members are bound.""" - super()._on_bound() - if self.num_choices > len(self.candidates) and self.choices_distinct: - raise ValueError( - f'{len(self.candidates)} candidates cannot produce ' - f'{self.num_choices} distinct choices.') - self._candidate_templates = [ - ObjectTemplate(c, where=self.where) for c in self.candidates - ] - # ValueSpec for candidate. - self._value_spec = None - - def _update_children_paths( - self, old_path: object_utils.KeyPath, new_path: object_utils.KeyPath): - """Customized logic to update children paths.""" - super()._update_children_paths(old_path, new_path) - for t in self._candidate_templates: - t.root_path = self.sym_path - - @property - def candidate_templates(self): - """Returns candidate templates.""" - return self._candidate_templates - - @property - def is_leaf(self) -> bool: - """Returns whether this is a leaf node.""" - for t in self._candidate_templates: - if not t.is_constant: - return False - return True - - def dna_spec(self, - location: Optional[object_utils.KeyPath] = None) -> geno.Choices: - """Returns corresponding DNASpec.""" - return geno.Choices( - num_choices=self.num_choices, - candidates=[ct.dna_spec() for ct in self._candidate_templates], - literal_values=[self._literal_value(c) - for i, c in enumerate(self.candidates)], - distinct=self.choices_distinct, - sorted=self.choices_sorted, - hints=self.hints, - name=self.name, - location=location or object_utils.KeyPath()) - - def _literal_value( - self, candidate: Any, max_len: int = 120) -> Union[int, float, str]: - """Returns literal value for candidate.""" - if isinstance(candidate, numbers.Number): - return candidate - - literal = object_utils.format(candidate, compact=True, - hide_default_values=True, - hide_missing_values=True, - strip_object_id=True) - if len(literal) > max_len: - literal = literal[:max_len - 3] + '...' - return literal - - def _decode(self) -> List[Any]: - """Decode a DNA into a list of object.""" - dna = self._dna - if self.num_choices == 1: - # Single choice. - if not isinstance(dna.value, int): - raise ValueError( - object_utils.message_on_path( - f'Did you forget to specify values for conditional choices?\n' - f'Expect integer for {self.__class__.__name__}. ' - f'Encountered: {dna!r}.', self.sym_path)) - if dna.value >= len(self.candidates): - raise ValueError( - object_utils.message_on_path( - f'Choice out of range. Value: {dna.value!r}, ' - f'Candidates: {len(self.candidates)}.', self.sym_path)) - choices = [self._candidate_templates[dna.value].decode( - geno.DNA(None, dna.children))] - else: - # Multi choices. - if len(dna.children) != self.num_choices: - raise ValueError( - object_utils.message_on_path( - f'Number of DNA child values does not match the number of ' - f'choices. Child values: {dna.children!r}, ' - f'Choices: {self.num_choices}.', self.sym_path)) - if self.choices_distinct or self.choices_sorted: - sub_dna_values = [s.value for s in dna] - if (self.choices_distinct - and len(set(sub_dna_values)) != len(dna.children)): - raise ValueError( - object_utils.message_on_path( - f'DNA child values should be distinct. ' - f'Encountered: {sub_dna_values}.', self.sym_path)) - if self.choices_sorted and sorted(sub_dna_values) != sub_dna_values: - raise ValueError( - object_utils.message_on_path( - f'DNA child values should be sorted. ' - f'Encountered: {sub_dna_values}.', self.sym_path)) - choices = [] - for i, sub_dna in enumerate(dna): - if not isinstance(sub_dna.value, int): - raise ValueError( - object_utils.message_on_path( - f'Choice value should be int. ' - f'Encountered: {sub_dna.value}.', - object_utils.KeyPath(i, self.sym_path))) - if sub_dna.value >= len(self.candidates): - raise ValueError( - object_utils.message_on_path( - f'Choice out of range. Value: {sub_dna.value}, ' - f'Candidates: {len(self.candidates)}.', - object_utils.KeyPath(i, self.sym_path))) - choices.append(self._candidate_templates[sub_dna.value].decode( - geno.DNA(None, sub_dna.children))) - return choices - - def encode(self, value: List[Any]) -> geno.DNA: - """Encode a list of values into DNA. - - Example:: - - # DNA of an object containing a single OneOf. - # {'a': 1} => DNA(0) - { - 'a': one_of([1, 2]) - } - - - # DNA of an object containing multiple OneOfs. - # {'b': 1, 'c': bar} => DNA([0, 1]) - { - 'b': pg.oneof([1, 2]), - 'c': pg.oneof(['foo', 'bar']) - } - - # DNA of an object containing conditional space. - # {'a': {'b': 1} => DNA(0, 0, 0)]) - # {'a': {'b': [4, 7]} => DNA(1, [(0, 1), 2]) - # {'a': {'b': 'bar'} => DNA(2) - { - 'a': { - 'b': pg.oneof([ - pg.oneof([ - pg.oneof([1, 2]), - pg.oneof(3, 4)]), - pg.manyof(2, [ - pg.oneof([4, 5]), - 6, - 7 - ]), - ]), - 'bar', - ]) - } - } - - Args: - value: A list of value that can match choice candidates. - - Returns: - Encoded DNA. - - Raises: - ValueError if value cannot be encoded. - """ - if not isinstance(value, list): - raise ValueError( - object_utils.message_on_path( - f'Cannot encode value: value should be a list type. ' - f'Encountered: {value!r}.', self.sym_path)) - choices = [] - if self.num_choices is not None and len(value) != self.num_choices: - raise ValueError( - object_utils.message_on_path( - f'Length of input list is different from the number of choices ' - f'({self.num_choices}). Encountered: {value}.', self.sym_path)) - for v in value: - choice_id = None - child_dna = None - for i, b in enumerate(self._candidate_templates): - succeeded, child_dna = b.try_encode(v) - if succeeded: - choice_id = i - break - if child_dna is None: - raise ValueError( - object_utils.message_on_path( - f'Cannot encode value: no candidates matches with ' - f'the value. Value: {v!r}, Candidates: {self.candidates}.', - self.sym_path)) - choices.append(geno.DNA(choice_id, [child_dna])) - return geno.DNA(None, choices) - - -@symbolic.members( - [], - init_arg_list=[ - 'num_choices', 'candidates', 'choices_distinct', - 'choices_sorted', 'hints' - ], - # TODO(daiyip): Change to 'ManyOf' once existing code migrates to ManyOf. - serialization_key='hyper.ManyOf', - additional_keys=['pyglove.generators.genetic.ChoiceList'] -) -class ManyOf(Choices): - """N Choose K. - - Example:: - - # Chooses 2 distinct candidates. - v = pg.manyof(2, [1, 2, 3]) - - # Chooses 2 non-distinct candidates. - v = pg.manyof(2, [1, 2, 3], distinct=False) - - # Chooses 2 distinct candidates sorted by their indices. - v = pg.manyof(2, [1, 2, 3], sorted=True) - - # Permutates the candidates. - v = pg.permutate([1, 2, 3]) - - # A hierarchical categorical choice: - v2 = pg.manyof(2, [ - 'foo', - 'bar', - pg.oneof([1, 2, 3]) - ]) - - See also: - - * :func:`pyglove.manyof` - * :func:`pyglove.permutate` - * :class:`pyglove.hyper.Choices` - * :class:`pyglove.hyper.OneOf` - * :class:`pyglove.hyper.Float` - * :class:`pyglove.hyper.CustomHyper` - """ - - def custom_apply( - self, - path: object_utils.KeyPath, - value_spec: schema.ValueSpec, - allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, schema.Field, Any], Any]] = None - ) -> Tuple[bool, 'Choices']: - """Validate candidates during value_spec binding time.""" - # Check if value_spec directly accepts `self`. - if value_spec.value_type and isinstance(self, value_spec.value_type): - return (False, self) - - if self._value_spec: - src_spec = self._value_spec - dest_spec = value_spec - if not dest_spec.is_compatible(src_spec): - raise TypeError( - object_utils.message_on_path( - f'Cannot bind an incompatible value spec {dest_spec} ' - f'to {self.__class__.__name__} with bound spec {src_spec}.', - path)) - return (False, self) - - list_spec = typing.cast(schema.List, - schema.ensure_value_spec( - value_spec, schema.List(schema.Any()), path)) - if list_spec: - for i, c in enumerate(self.candidates): - list_spec.element.value.apply( - c, - self._allow_partial, - root_path=path + f'candidates[{i}]') - self._value_spec = list_spec - return (False, self) - - -@symbolic.members( - [ - ('num_choices', 1) - ], - init_arg_list=['candidates', 'hints', 'where'], - serialization_key='hyper.OneOf', - additional_keys=['pyglove.generators.genetic.ChoiceValue'] -) -class OneOf(Choices): - """N Choose 1. - - Example:: - - # A single categorical choice: - v = pg.oneof([1, 2, 3]) - - # A hierarchical categorical choice: - v2 = pg.oneof([ - 'foo', - 'bar', - pg.oneof([1, 2, 3]) - ]) - - See also: - - * :func:`pyglove.oneof` - * :class:`pyglove.hyper.Choices` - * :class:`pyglove.hyper.ManyOf` - * :class:`pyglove.hyper.Float` - * :class:`pyglove.hyper.CustomHyper` - """ - - def _on_bound(self): - """Event triggered when members are bound.""" - super()._on_bound() - assert self.num_choices == 1 - - def _decode(self) -> Any: - """Decode a DNA into an object.""" - return super()._decode()[0] - - def encode(self, value: Any) -> geno.DNA: - """Encode a value into a DNA.""" - # NOTE(daiyip): Single choice DNA will automatically be pulled - # up from children to current node. Thus we simply returns - # encoded DNA from parent node. - return super().encode([value]) - - def custom_apply( - self, - path: object_utils.KeyPath, - value_spec: schema.ValueSpec, - allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, schema.Field, Any], Any]] = None - ) -> Tuple[bool, 'OneOf']: - """Validate candidates during value_spec binding time.""" - # Check if value_spec directly accepts `self`. - if value_spec.value_type and isinstance(self, value_spec.value_type): - return (False, self) - - if self._value_spec: - if not value_spec.is_compatible(self._value_spec): - raise TypeError( - object_utils.message_on_path( - f'Cannot bind an incompatible value spec {value_spec} ' - f'to {self.__class__.__name__} with bound ' - f'spec {self._value_spec}.', path)) - return (False, self) - - for i, c in enumerate(self.candidates): - value_spec.apply( - c, - self._allow_partial, - root_path=path + f'candidates[{i}]') - self._value_spec = value_spec - return (False, self) - - -@symbolic.members( - [ - ('min_value', schema.Float(), 'Minimum acceptable value.'), - ('max_value', schema.Float(), 'Maximum acceptable value.'), - geno.float_scale_spec('scale'), - ], - init_arg_list=['min_value', 'max_value', 'scale', 'name', 'hints'], - serialization_key='hyper.Float', - additional_keys=['pyglove.generators.genetic.Float'] -) -class Float(HyperPrimitive): - """A continuous value within a range. - - Example:: - - # A float value between between 0.0 and 1.0. - v = pg.floatv(0.0, 1.0) - - See also: - - * :func:`pyglove.floatv` - * :class:`pyglove.hyper.Choices` - * :class:`pyglove.hyper.OneOf` - * :class:`pyglove.hyper.ManyOf` - * :class:`pyglove.hyper.CustomHyper` - """ - - def _on_bound(self): - """Constructor.""" - super()._on_bound() - if self.min_value > self.max_value: - raise ValueError( - f'\'min_value\' ({self.min_value}) is greater than \'max_value\' ' - f'({self.max_value}).') - if self.scale in ['log', 'rlog'] and self.min_value <= 0: - raise ValueError( - f'\'min_value\' must be positive when `scale` is {self.scale!r}. ' - f'encountered: {self.min_value}.') - - def dna_spec(self, - location: Optional[object_utils.KeyPath] = None) -> geno.Float: - """Returns corresponding DNASpec.""" - return geno.Float( - min_value=self.min_value, - max_value=self.max_value, - scale=self.scale, - hints=self.hints, - name=self.name, - location=location or object_utils.KeyPath()) - - def _decode(self) -> float: - """Decode a DNA into a float value.""" - dna = self._dna - if not isinstance(dna.value, float): - raise ValueError( - object_utils.message_on_path( - f'Expect float value. Encountered: {dna.value}.', self.sym_path)) - if dna.value < self.min_value: - raise ValueError( - object_utils.message_on_path( - f'DNA value should be no less than {self.min_value}. ' - f'Encountered {dna.value}.', self.sym_path)) - - if dna.value > self.max_value: - raise ValueError( - object_utils.message_on_path( - f'DNA value should be no greater than {self.max_value}. ' - f'Encountered {dna.value}.', self.sym_path)) - return dna.value - - def encode(self, value: float) -> geno.DNA: - """Encode a float value into a DNA.""" - if not isinstance(value, float): - raise ValueError( - object_utils.message_on_path( - f'Value should be float to be encoded for {self!r}. ' - f'Encountered {value}.', self.sym_path)) - if value < self.min_value: - raise ValueError( - object_utils.message_on_path( - f'Value should be no less than {self.min_value}. ' - f'Encountered {value}.', self.sym_path)) - if value > self.max_value: - raise ValueError( - object_utils.message_on_path( - f'Value should be no greater than {self.max_value}. ' - f'Encountered {value}.', self.sym_path)) - return geno.DNA(value) - - def custom_apply( - self, - path: object_utils.KeyPath, - value_spec: schema.ValueSpec, - allow_partial: bool = False, - child_transform: Optional[Callable[ - [object_utils.KeyPath, schema.Field, Any], Any]] = None - ) -> Tuple[bool, 'Float']: - """Validate candidates during value_spec binding time.""" - del allow_partial - del child_transform - # Check if value_spec directly accepts `self`. - if value_spec.value_type and isinstance(self, value_spec.value_type): - return (False, self) - - float_spec = typing.cast( - schema.Float, schema.ensure_value_spec( - value_spec, schema.Float(), path)) - if float_spec: - if (float_spec.min_value is not None - and self.min_value < float_spec.min_value): - raise ValueError( - object_utils.message_on_path( - f'Float.min_value ({self.min_value}) should be no less than ' - f'the min value ({float_spec.min_value}) of value spec: ' - f'{float_spec}.', path)) - if (float_spec.max_value is not None - and self.max_value > float_spec.max_value): - raise ValueError( - object_utils.message_on_path( - f'Float.max_value ({self.max_value}) should be no greater than ' - f'the max value ({float_spec.max_value}) of value spec: ' - f'{float_spec}.', path)) - return (False, self) - - def is_leaf(self) -> bool: - """Returns whether this is a leaf node.""" - return True - - -class CustomHyper(HyperPrimitive): - """User-defined hyper primitive. - - User-defined hyper primitive is useful when users want to have full control - on the semantics and genome encoding of the search space. For example, the - decision points are of variable length, which is not yet supported by - built-in hyper primitives. - - To use user-defined hyper primitive is simple, the user should: - - - 1) Subclass `CustomHyper` and implements the - :meth:`pyglove.hyper.CustomHyper.custom_decode` method. - It's optional to implement the - :meth:`pyglove.hyper.CustomHyper.custom_encode` method, which is only - necessary when the user want to encoder a material object into a DNA. - 2) Introduce a DNAGenerator that can generate DNA for the - :class:`pyglove.geno.CustomDecisionPoint`. - - For example, the following code tries to find an optimal sub-sequence of an - integer sequence by their sums:: - - import random - - class IntSequence(pg.hyper.CustomHyper): - - def custom_decode(self, dna): - return [int(v) for v in dna.value.split(',') if v != ''] - - class SubSequence(pg.evolution.Mutator): - - def mutate(self, dna): - genome = dna.value - items = genome.split(',') - start = random.randint(0, len(items)) - end = random.randint(start, len(items)) - new_genome = ','.join(items[start:end]) - return pg.DNA(new_genome, spec=dna.spec) - - @pg.geno.dna_generator - def initial_population(): - yield pg.DNA('12,-34,56,-2,100,98', spec=dna_spec) - - algo = pg.evolution.Evolution( - (pg.evolution.selectors.Random(10) - >> pg.evolution.selectors.Top(1) - >> SubSequence()), - population_init=initial_population(), - population_update=pg.evolution.selectors.Last(20)) - - best_reward, best_example = None, None - for int_seq, feedback in pg.sample(IntSequence(), algo, num_examples=100): - reward = sum(int_seq) - if best_reward is None or best_reward < reward: - best_reward, best_example = reward, int_seq - feedback(reward) - - print(best_reward, best_example) - - Please note that user-defined hyper value can be used together with PyGlove's - built-in hyper primitives, for example:: - - pg.oneof([IntSequence(), None]) - - Therefore it's also a mechanism to extend PyGlove's search space definitions. - """ - - def _decode(self): - if not isinstance(self.dna.value, str): - raise ValueError( - f'{self.__class__} expects string type DNA. ' - f'Encountered {self.dna!r}.') - return self.custom_decode(self.dna) - - @abc.abstractmethod - def custom_decode(self, dna: geno.DNA) -> Any: - """Decode a DNA whose value is a string of user-defined genome.""" - - def encode(self, value: Any) -> geno.DNA: - """Encode value into DNA with user-defined genome.""" - return self.custom_encode(value) - - def custom_encode(self, value: Any) -> geno.DNA: - """Encode value to user defined genome.""" - raise NotImplementedError( - f'\'custom_encode\' is not supported by {self.__class__.__name__!r}.') - - def dna_spec( - self, location: Optional[object_utils.KeyPath] = None) -> geno.DNASpec: - """Always returns CustomDecisionPoint for CustomHyper.""" - return geno.CustomDecisionPoint( - hyper_type=self.__class__.__name__, - next_dna_fn=self.next_dna, - random_dna_fn=self.random_dna, - hints=self.hints, name=self.name, location=location) - - def first_dna(self) -> geno.DNA: - """Returns the first DNA of current sub-space. - - Returns: - A string-valued DNA. - """ - if self.next_dna.__code__ is CustomHyper.next_dna.__code__: - raise NotImplementedError( - f'{self.__class__!r} must implement method `next_dna` to be used in ' - f'dynamic evaluation mode.') - return self.next_dna(None) - - def next_dna(self, dna: Optional[geno.DNA] = None) -> Optional[geno.DNA]: - """Subclasses should override this method to support pg.Sweeping.""" - raise NotImplementedError( - f'`next_dna` is not implemented in f{self.__class__!r}') - - def random_dna( - self, - random_generator: Union[types.ModuleType, random.Random, None] = None, - previous_dna: Optional[geno.DNA] = None) -> geno.DNA: - """Subclasses should override this method to support pg.random_dna.""" - raise NotImplementedError( - f'`random_dna` is not implemented in {self.__class__!r}') - - def custom_apply( - self, - path: object_utils.KeyPath, - value_spec: schema.ValueSpec, - allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, schema.Field, Any], Any]] = None - ) -> Tuple[bool, 'CustomHyper']: - """Validate candidates during value_spec binding time.""" - del path, value_spec, allow_partial, child_transform - # Allow custom hyper to be assigned to any type. - return (False, self) - - -class MutationType(str, enum.Enum): - """Mutation type.""" - REPLACE = 0 - INSERT = 1 - DELETE = 2 - - -@dataclasses.dataclass -class MutationPoint: - """Internal class that encapsulates the information for a mutation point. - - Attributes: - mutation_type: The type of the mutation. - location: The location where the mutation will take place. - old_value: The value of the mutation point before mutation. - parent: The parent node of the mutation point. - """ - mutation_type: 'MutationType' - location: object_utils.KeyPath - old_value: Any - parent: symbolic.Symbolic - - -class Evolvable(CustomHyper): - """Hyper primitive for evolving an arbitrary symbolic object.""" - - def _on_bound(self): - super()._on_bound() - self._weights = self.weights or (lambda mt, k, v, p: 1.0) - - def custom_decode(self, dna: geno.DNA) -> Any: - assert isinstance(dna.value, str) - # TODO(daiyip): consider compression. - return symbolic.from_json_str(dna.value) - - def custom_encode(self, value: Any) -> geno.DNA: - return geno.DNA(symbolic.to_json_str(value)) - - def mutation_points_and_weights( - self, - value: symbolic.Symbolic) -> Tuple[List[MutationPoint], List[float]]: - """Returns mutation points with weights for a symbolic tree.""" - mutation_points: List[MutationPoint] = [] - mutation_weights: List[float] = [] - - def _choose_mutation_point(k: object_utils.KeyPath, - v: Any, - p: Optional[symbolic.Symbolic]): - """Visiting function for a symbolic node.""" - def _add_point(mt: MutationType, k=k, v=v, p=p): - assert p is not None - mutation_points.append(MutationPoint(mt, k, v, p)) - mutation_weights.append(self._weights(mt, k, v, p)) - - if p is not None: - # Stopping mutating current branch if metadata said so. - f = p.sym_attr_field(k.key) - if f and f.metadata and 'no_mutation' in f.metadata: - return symbolic.TraverseAction.CONTINUE - _add_point(MutationType.REPLACE) - - # Special handle list traversal to add insertion and deletion. - if isinstance(v, symbolic.List): - if v.value_spec: - spec = v.value_spec - reached_max_size = spec.max_size and len(v) == spec.max_size - reached_min_size = spec.min_size and len(v) == spec.min_size - else: - reached_max_size = False - reached_min_size = False - - for i, cv in enumerate(v): - ck = object_utils.KeyPath(i, parent=k) - if not reached_max_size: - _add_point(MutationType.INSERT, - k=ck, v=object_utils.MISSING_VALUE, p=v) - - if not reached_min_size: - _add_point(MutationType.DELETE, k=ck, v=cv, p=v) - - # Replace type and value will be added in traverse. - symbolic.traverse(cv, _choose_mutation_point, root_path=ck, parent=v) - if not reached_max_size and i == len(v) - 1: - _add_point(MutationType.INSERT, - k=object_utils.KeyPath(i + 1, parent=k), - v=object_utils.MISSING_VALUE, - p=v) - return symbolic.TraverseAction.CONTINUE - return symbolic.TraverseAction.ENTER - - # First-order traverse the symbolic tree to compute - # the mutation points and weights. - symbolic.traverse(value, _choose_mutation_point) - return mutation_points, mutation_weights - - def first_dna(self) -> geno.DNA: - """Returns the first DNA of current sub-space.""" - return self.custom_encode(self.initial_value) - - def random_dna( - self, - random_generator: Union[types.ModuleType, random.Random, None] = None, - previous_dna: Optional[geno.DNA] = None) -> geno.DNA: - """Generates a random DNA.""" - random_generator = random_generator or random - if previous_dna is None: - return self.first_dna() - return self.custom_encode( - self.mutate(self.custom_decode(previous_dna), random_generator)) - - def mutate( - self, - value: symbolic.Symbolic, - random_generator: Union[types.ModuleType, random.Random, None] = None - ) -> symbolic.Symbolic: - """Returns the next value for a symbolic value.""" - r = random_generator or random - points, weights = self.mutation_points_and_weights(value) - [point] = r.choices(points, weights, k=1) - - # Mutating value. - if point.mutation_type == MutationType.REPLACE: - assert point.location, point - value.rebind({ - str(point.location): self.node_transform( - point.location, point.old_value, point.parent)}) - elif point.mutation_type == MutationType.INSERT: - assert isinstance(point.parent, symbolic.List), point - assert point.old_value == object_utils.MISSING_VALUE, point - assert isinstance(point.location.key, int), point - with symbolic.allow_writable_accessors(): - point.parent.insert( - point.location.key, - self.node_transform(point.location, point.old_value, point.parent)) - else: - assert point.mutation_type == MutationType.DELETE, point - assert isinstance(point.parent, symbolic.List), point - assert isinstance(point.location.key, int), point - with symbolic.allow_writable_accessors(): - del point.parent[point.location.key] - return value - - -# We defer members declaration for Evolvable since the weights will reference -# the definition of MutationType. -symbolic.members([ - ('initial_value', schema.Object(symbolic.Symbolic), - 'Symbolic value to involve.'), - ('node_transform', schema.Callable( - [], - returns=schema.Any()), - ''), - ('weights', schema.Callable( - [ - schema.Object(MutationType), - schema.Object(object_utils.KeyPath), - schema.Any().noneable(), - schema.Object(symbolic.Symbolic) - ], returns=schema.Float(min_value=0.0)).noneable(), - ('An optional callable object that returns the unnormalized (e.g. ' - 'the sum of all probabilities do not have to sum to 1.0) mutation ' - 'probabilities for all the nodes in the symbolic tree, based on ' - '(mutation type, location, old value, parent node). If None, all the ' - 'locations and mutation types will be sampled uniformly.')), -])(Evolvable) - - -@symbolic.members([ - ('reference_paths', schema.List(schema.Object(object_utils.KeyPath)), - 'Paths of referenced values, which are relative paths searched from ' - 'current node to root.') -]) -class DerivedValue(symbolic.Object, schema.CustomTyping): - """Base class of value that references to other values in object tree.""" - - @abc.abstractmethod - def derive(self, *args: Any) -> Any: - """Derive the value from referenced values.""" - - def resolve(self, - reference_path_or_paths: Optional[Union[Text, List[Text]]] = None - ) -> Union[ - Tuple[symbolic.Symbolic, object_utils.KeyPath], - List[Tuple[symbolic.Symbolic, object_utils.KeyPath]]]: - """Resolve reference paths based on the location of this node. - - Args: - reference_path_or_paths: (Optional) a string or KeyPath as a reference - path or a list of strings or KeyPath objects as a list of - reference paths. - If this argument is not provided, prebound reference paths of this - object will be used. - - Returns: - A tuple (or list of tuple) of (resolved parent, resolved full path) - """ - single_input = False - if reference_path_or_paths is None: - reference_paths = self.reference_paths - elif isinstance(reference_path_or_paths, str): - reference_paths = [object_utils.KeyPath.parse(reference_path_or_paths)] - single_input = True - elif isinstance(reference_path_or_paths, object_utils.KeyPath): - reference_paths = [reference_path_or_paths] - single_input = True - elif isinstance(reference_path_or_paths, list): - paths = [] - for path in reference_path_or_paths: - if isinstance(path, str): - path = object_utils.KeyPath.parse(path) - elif not isinstance(path, object_utils.KeyPath): - raise ValueError('Argument \'reference_path_or_paths\' must be None, ' - 'a string, KeyPath object, a list of strings, or a ' - 'list of KeyPath objects.') - paths.append(path) - reference_paths = paths - else: - raise ValueError('Argument \'reference_path_or_paths\' must be None, ' - 'a string, KeyPath object, a list of strings, or a ' - 'list of KeyPath objects.') - - resolved_paths = [] - for reference_path in reference_paths: - parent = self.sym_parent - while parent is not None and not reference_path.exists(parent): - parent = getattr(parent, 'sym_parent', None) - if parent is None: - raise ValueError( - f'Cannot resolve \'{reference_path}\': parent not found.') - resolved_paths.append((parent, parent.sym_path + reference_path)) - return resolved_paths if not single_input else resolved_paths[0] - - def __call__(self): - """Generate value by deriving values from reference paths.""" - referenced_values = [] - for reference_path, (parent, _) in zip( - self.reference_paths, self.resolve()): - referenced_value = reference_path.query(parent) - - # Make sure referenced value does not have referenced value. - # NOTE(daiyip): We can support dependencies between derived values - # in future if needed. - if not object_utils.traverse( - referenced_value, self._contains_not_derived_value): - raise ValueError( - f'Derived value (path={referenced_value.sym_path}) should not ' - f'reference derived values. ' - f'Encountered: {referenced_value}, ' - f'Referenced at path {self.sym_path}.') - referenced_values.append(referenced_value) - return self.derive(*referenced_values) - - def _contains_not_derived_value( - self, path: object_utils.KeyPath, value: Any) -> bool: - """Returns whether a value contains derived value.""" - if isinstance(value, DerivedValue): - return False - elif isinstance(value, symbolic.Object): - for k, v in value.sym_items(): - if not object_utils.traverse( - v, self._contains_not_derived_value, - root_path=object_utils.KeyPath(k, path)): - return False - return True - - -class ValueReference(DerivedValue): - """Class that represents a value referencing another value.""" - - def _on_bound(self): - """Custom init.""" - super()._on_bound() - if len(self.reference_paths) != 1: - raise ValueError( - f'Argument \'reference_paths\' should have exact 1 ' - f'item. Encountered: {self.reference_paths}') - - def derive(self, referenced_value: Any) -> Any: - """Derive value by return a copy of the referenced value.""" - return copy.copy(referenced_value) - - def custom_apply( - self, - path: object_utils.KeyPath, - value_spec: schema.ValueSpec, - allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, schema.Field, Any], Any]] = None - ) -> Tuple[bool, 'DerivedValue']: - """Implement schema.CustomTyping interface.""" - # TODO(daiyip): perform possible static analysis on referenced paths. - del path, value_spec, allow_partial, child_transform - return (False, self) - - -def reference(reference_path: Text) -> ValueReference: - """Create a referenced value from a referenced path.""" - return ValueReference(reference_paths=[reference_path]) - - -class ObjectTemplate(HyperValue, object_utils.Formattable): - """Object template that encodes and decodes symbolic values. - - An object template can be created from a hyper value, which is a symbolic - object with some parts placeheld by hyper primitives. For example:: - - x = A(a=0, - b=pg.oneof(['foo', 'bar']), - c=pg.manyof(2, [1, 2, 3, 4, 5, 6]), - d=pg.floatv(0.1, 0.5), - e=pg.oneof([ - { - 'f': pg.oneof([True, False]), - } - { - 'g': pg.manyof(2, [B(), C(), D()], distinct=False), - 'h': pg.manyof(2, [0, 1, 2], sorted=True), - } - ]) - }) - t = pg.template(x) - - In this example, the root template have 4 children hyper primitives associated - with keys 'b', 'c', 'd' and 'e', while the hyper primitive 'e' have 3 children - associated with keys 'f', 'g' and 'h', creating a conditional search space. - - Thus the DNA shape is determined by the definition of template, described - by geno.DNASpec. In this case, the DNA spec of this template looks like:: - - pg.geno.space([ - pg.geno.oneof([ # Spec for 'b'. - pg.geno.constant(), # A constant template for 'foo'. - pg.geno.constant(), # A constant template for 'bar'. - ]), - pg.geno.manyof([ # Spec for 'c'. - pg.geno.constant(), # A constant template for 1. - pg.geno.constant(), # A constant template for 2. - pg.geno.constant(), # A constant template for 3. - pg.geno.constant(), # A constant template for 4. - pg.geno.constant(), # A constant template for 5. - pg.geno.constant(), # A constant template for 6. - ]), - pg.geno.floatv(0.1, 0.5), # Spec for 'd'. - pg.geno.oneof([ # Spec for 'e'. - pg.geno.space([ - pg.geno.oneof([ # Spec for 'f'. - pg.geno.constant(), # A constant template for True. - pg.geno.constant(), # A constant template for False. - ]) - ]), - pg.geno.space([ - pg.geno.manyof(2, [ # Spec for 'g'. - pg.geno.constant(), # A constant template for B(). - pg.geno.constant(), # A constant template for C(). - pg.geno.constant(), # A constant template for D(). - ], distinct=False) # choices of the same value can - # be selected multiple times. - pg.geno.manyof(2, [ # Spec for 'h'. - pg.geno.constant(), # A constant template for 0. - pg.geno.constant(), # A constant template for 1. - pg.geno.constant(), # A constant template for 2. - ], sorted=True) # acceptable choices needs to be sorted, - # which enables using choices as set (of - # possibly repeated values). - ]) - ]) - - It may generate DNA as the following: - DNA([0, [0, 2], 0.1, (0, 0)]) - - A template can also work only on a subset of hyper primitives from the input - value through the `where` function. This is useful to partition a search space - into parts for separate optimization. - - For example:: - - t = pg.hyper.ObjectTemplate( - A(a=pg.oneof([1, 2]), b=pg.oneof([3, 4])), - where=lambda e: e.root_path == 'a') - assert t.dna_spec() == pg.geno.space([ - pg.geno.oneof(location='a', candidates=[ - pg.geno.constant(), # For a=1 - pg.geno.constant(), # For a=2 - ], literal_values=['(0/2) 1', '(1/2) 2']) - ]) - assert t.decode(pg.DNA(0)) == A(a=1, b=pg.oneof([3, 4])) - """ - - def __init__(self, - value: Any, - compute_derived: bool = False, - where: Optional[Callable[[HyperPrimitive], bool]] = None): - """Constructor. - - Args: - value: Value (maybe) annotated with generators to use as template. - compute_derived: Whether to compute derived value at this level. - We only want to compute derived value at root level since reference path - may go out of scope of a non-root ObjectTemplate. - where: Function to filter hyper primitives. If None, all hyper primitives - from `value` will be included in the encoding/decoding process. - Otherwise only the hyper primitives on which 'where' returns True will - be included. `where` can be useful to partition a search space into - separate optimization processes. - Please see 'ObjectTemplate' docstr for details. - """ - super().__init__() - self._value = value - self._root_path = object_utils.KeyPath() - self._compute_derived = compute_derived - self._where = where - self._parse_generators() - - @property - def root_path(self) -> object_utils.KeyPath: - """Returns root path.""" - return self._root_path - - @root_path.setter - def root_path(self, path: object_utils.KeyPath): - """Set root path.""" - self._root_path = path - - def _parse_generators(self) -> None: - """Parse generators from its templated value.""" - hyper_primitives = [] - def _extract_immediate_child_hyper_primitives( - path: object_utils.KeyPath, value: Any) -> bool: - """Extract top-level hyper primitives.""" - if (isinstance(value, HyperValue) - and (not self._where or self._where(value))): - # Apply where clause to child choices. - if isinstance(value, Choices) and self._where: - value = value.clone().rebind(where=self._where) - hyper_primitives.append((path, value)) - elif isinstance(value, symbolic.Object): - for k, v in value.sym_items(): - object_utils.traverse( - v, _extract_immediate_child_hyper_primitives, - root_path=object_utils.KeyPath(k, path)) - return True - - object_utils.traverse( - self._value, _extract_immediate_child_hyper_primitives) - self._hyper_primitives = hyper_primitives - - @property - def value(self) -> Any: - """Returns templated value.""" - return self._value - - @property - def hyper_primitives(self) -> List[Tuple[Text, HyperValue]]: - """Returns hyper primitives in tuple (relative path, hyper primitive).""" - return self._hyper_primitives - - @property - def is_constant(self) -> bool: - """Returns whether current template is constant value.""" - return not self._hyper_primitives - - def dna_spec( - self, location: Optional[object_utils.KeyPath] = None) -> geno.Space: - """Return DNA spec (geno.Space) from this template.""" - return geno.Space( - elements=[ - primitive.dna_spec(primitive_location) - for primitive_location, primitive in self._hyper_primitives - ], - location=location or object_utils.KeyPath()) - - def _decode(self) -> Any: - """Decode DNA into a value.""" - dna = self._dna - if not self._hyper_primitives and (dna.value is not None or dna.children): - raise ValueError( - object_utils.message_on_path( - f'Encountered extra DNA value to decode: {dna!r}', - self._root_path)) - - # Compute hyper primitive values first. - rebind_dict = {} - if len(self._hyper_primitives) == 1: - primitive_location, primitive = self._hyper_primitives[0] - rebind_dict[primitive_location.path] = primitive.decode(dna) - else: - if len(dna.children) != len(self._hyper_primitives): - raise ValueError( - object_utils.message_on_path( - f'The length of child values ({len(dna.children)}) is ' - f'different from the number of hyper primitives ' - f'({len(self._hyper_primitives)}) in ObjectTemplate. ' - f'DNA={dna!r}, ObjectTemplate={self!r}.', self._root_path)) - for i, (primitive_location, primitive) in enumerate( - self._hyper_primitives): - rebind_dict[primitive_location.path] = ( - primitive.decode(dna.children[i])) - - if rebind_dict: - if len(rebind_dict) == 1 and '' in rebind_dict: - # NOTE(daiyip): Special handle the case when the root value needs to be - # replaced. For example: `template(oneof([0, 1])).decode(geno.DNA(0))` - # should return 0 instead of rebinding the root `OneOf` object. - value = rebind_dict[''] - else: - # NOTE(daiyip): Instead of deep copying the whole object (with hyper - # primitives), we can cherry-pick only non-hyper parts. Unless we saw - # performance issues it's not worthy to optimize this. - value = symbolic.clone(self._value, deep=True) - value.rebind(rebind_dict) - copied = True - else: - assert self.is_constant - value = self._value - copied = False - - # Compute derived values if needed. - if self._compute_derived: - # TODO(daiyip): Currently derived value parsing is done at decode time, - # which can be optimized by moving to template creation time. - derived_values = [] - def _extract_derived_values( - path: object_utils.KeyPath, value: Any) -> bool: - """Extract top-level primitives.""" - if isinstance(value, DerivedValue): - derived_values.append((path, value)) - elif isinstance(value, symbolic.Object): - for k, v in value.sym_items(): - object_utils.traverse( - v, _extract_derived_values, - root_path=object_utils.KeyPath(k, path)) - return True - object_utils.traverse(value, _extract_derived_values) - - if derived_values: - if not copied: - value = symbolic.clone(value, deep=True) - rebind_dict = {} - for path, derived_value in derived_values: - rebind_dict[path.path] = derived_value() - assert rebind_dict - value.rebind(rebind_dict) - return value - - def encode(self, value: Any) -> geno.DNA: - """Encode a value into a DNA. - - Example:: - - # DNA of a constant template: - template = pg.hyper.ObjectTemplate({'a': 0}) - assert template.encode({'a': 0}) == pg.DNA(None) - # Raises: Unmatched value between template and input. - template.encode({'a': 1}) - - # DNA of a template containing only one pg.oneof. - template = pg.hyper.ObjectTemplate({'a': pg.oneof([1, 2])}) - assert template.encode({'a': 1}) == pg.DNA(0) - - # DNA of a template containing only one pg.oneof. - template = pg.hyper.ObjectTemplate({'a': pg.floatv(0.1, 1.0)}) - assert template.encode({'a': 0.5}) == pg.DNA(0.5) - - Args: - value: Value to encode. - - Returns: - Encoded DNA. - - Raises: - ValueError if value cannot be encoded by this template. - """ - children = [] - def _encode(path: object_utils.KeyPath, - template_value: Any, - input_value: Any) -> Any: - """Encode input value according to template value.""" - if (schema.MISSING_VALUE == input_value - and schema.MISSING_VALUE != template_value): - raise ValueError( - f'Value is missing from input. Path=\'{path}\'.') - if (isinstance(template_value, HyperValue) - and (not self._where or self._where(template_value))): - children.append(template_value.encode(input_value)) - elif isinstance(template_value, DerivedValue): - if self._compute_derived: - referenced_values = [ - reference_path.query(value) - for _, reference_path in template_value.resolve() - ] - derived_value = template_value.derive(*referenced_values) - if derived_value != input_value: - raise ValueError( - f'Unmatched derived value between template and input. ' - f'(Path=\'{path}\', Template={template_value!r}, ' - f'ComputedValue={derived_value!r}, Input={input_value!r})') - # For template that doesn't compute derived value, it get passed over - # to parent template who may be able to handle. - elif isinstance(template_value, symbolic.Object): - if type(input_value) is not type(template_value): - raise ValueError( - f'Unmatched Object type between template and input: ' - f'(Path=\'{path}\', Template={template_value!r}, ' - f'Input={input_value!r})') - template_keys = set(template_value.sym_keys()) - value_keys = set(input_value.sym_keys()) - if template_keys != value_keys: - raise ValueError( - f'Unmatched Object keys between template value and input ' - f'value. (Path=\'{path}\', ' - f'TemplateOnlyKeys={template_keys - value_keys}, ' - f'InputOnlyKeys={value_keys - template_keys})') - for key in template_value.sym_keys(): - object_utils.merge_tree( - template_value.sym_getattr(key), - input_value.sym_getattr(key), - _encode, root_path=object_utils.KeyPath(key, path)) - elif isinstance(template_value, symbolic.Dict): - # Do nothing since merge will iterate all elements in dict and list. - if not isinstance(input_value, dict): - raise ValueError( - f'Unmatched dict between template value and input ' - f'value. (Path=\'{path}\', Template={template_value!r}, ' - f'Input={input_value!r})') - elif isinstance(template_value, symbolic.List): - if (not isinstance(input_value, list) - or len(input_value) != len(template_value)): - raise ValueError( - f'Unmatched list between template value and input ' - f'value. (Path=\'{path}\', Template={template_value!r}, ' - f'Input={input_value!r})') - for i, template_item in enumerate(template_value): - object_utils.merge_tree( - template_item, input_value[i], _encode, - root_path=object_utils.KeyPath(i, path)) - else: - if template_value != input_value: - raise ValueError( - f'Unmatched value between template and input. ' - f'(Path=\'{path}\', ' - f'Template={object_utils.quote_if_str(template_value)}, ' - f'Input={object_utils.quote_if_str(input_value)})') - return template_value - object_utils.merge_tree( - self._value, value, _encode, root_path=self._root_path) - return geno.DNA(None, children) - - def try_encode(self, value: Any) -> Tuple[bool, geno.DNA]: - """Try to encode a value without raise Exception.""" - try: - dna = self.encode(value) - return (True, dna) - except ValueError: - return (False, None) # pytype: disable=bad-return-type - except KeyError: - return (False, None) # pytype: disable=bad-return-type - - def __eq__(self, other): - """Operator ==.""" - if not isinstance(other, self.__class__): - return False - return self.value == other.value - - def __ne__(self, other): - """Operator !=.""" - return not self.__eq__(other) - - def format(self, - compact: bool = False, - verbose: bool = True, - root_indent: int = 0, - **kwargs) -> Text: - """Format this object.""" - details = object_utils.format( - self._value, compact, verbose, root_indent, **kwargs) - return f'{self.__class__.__name__}(value={details})' - - def custom_apply( - self, - path: object_utils.KeyPath, - value_spec: schema.ValueSpec, - allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, schema.Field, Any], Any]] = None - ) -> Tuple[bool, 'ObjectTemplate']: - """Validate candidates during value_spec binding time.""" - # Check if value_spec directly accepts `self`. - if not value_spec.value_type or not isinstance(self, value_spec.value_type): - value_spec.apply( - self._value, - allow_partial, - root_path=self.root_path) - return (False, self) - - -# TODO(daiyip): For backward compatibility, remove after legacy dependencies -# are updated. -ChoiceList = ManyOf -ChoiceValue = OneOf -Template = ObjectTemplate - - -# -# Helper methods for creating hyper values. -# - - -def oneof(candidates: Iterable[Any], - *, - name: Optional[Text] = None, - hints: Optional[Any] = None) -> Any: - """N choose 1. - - Example:: - - @pg.members([ - ('x', pg.typing.Int()) - ]) - class A(pg.Object): - pass - - # A single categorical choice: - v = pg.oneof([1, 2, 3]) - - # A complex type as candidate. - v1 = pg.oneof(['a', {'x': 1}, A(1)]) - - # A hierarchical categorical choice: - v2 = pg.oneof([ - 'foo', - 'bar', - A(pg.oneof([1, 2, 3])) - ]) - - See also: - - * :class:`pyglove.hyper.OneOf` - * :func:`pyglove.manyof` - * :func:`pyglove.floatv` - * :func:`pyglove.permutate` - * :func:`pyglove.evolve` - - .. note:: - - Under symbolic mode (by default), `pg.oneof` returns a ``pg.hyper.OneOf`` - object. Under dynamic evaluation mode, which is called under the context of - :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or - :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to - a concrete candidate value. - - To use conditional search space in dynamic evaluation mode, the candidate - should be wrapped with a `lambda` function, which is not necessary under - symbolic mode. For example:: - - pg.oneof([lambda: pg.oneof([0, 1], name='sub'), 2], name='root') - - Args: - candidates: Candidates to select from. Items of candidate can be any type, - therefore it can have nested hyper primitives, which forms a hierarchical - search space. - name: A name that can be used to identify a decision point in the search - space. This is needed when the code to instantiate the same hyper - primitive may be called multiple times under a - `pg.DynamicEvaluationContext.collect` context or under a - `pg.DynamicEvaluationContext.apply` context. - hints: An optional value which acts as a hint for the controller. - - Returns: - In symbolic mode, this function returns a `ChoiceValue`. - In dynamic evaluation mode, this function returns one of the items in - `candidates`. - If evaluated under a `pg.DynamicEvaluationContext.apply` scope, - this function will return the selected candidate. - If evaluated under a `pg.DynamicEvaluationContext.collect` - scope, it will return the first candidate. - """ - return OneOf(candidates=list(candidates), name=name, hints=hints) - - -one_of = oneof - - -def manyof(k: int, - candidates: Iterable[Any], - distinct: bool = True, - sorted: bool = False, # pylint: disable=redefined-builtin - *, - name: Optional[Text] = None, - hints: Optional[Any] = None, - **kwargs) -> Any: - """N choose K. - - Example:: - - @pg.members([ - ('x', pg.typing.Int()) - ]) - class A(pg.Object): - pass - - # Chooses 2 distinct candidates. - v = pg.manyof(2, [1, 2, 3]) - - # Chooses 2 non-distinct candidates. - v = pg.manyof(2, [1, 2, 3], distinct=False) - - # Chooses 2 distinct candidates sorted by their indices. - v = pg.manyof(2, [1, 2, 3], sorted=True) - - # A complex type as candidate. - v1 = pg.manyof(2, ['a', {'x': 1}, A(1)]) - - # A hierarchical categorical choice: - v2 = pg.manyof(2, [ - 'foo', - 'bar', - A(pg.oneof([1, 2, 3])) - ]) - - .. note:: - - Under symbolic mode (by default), `pg.manyof` returns a ``pg.hyper.ManyOf`` - object. Under dynamic evaluation mode, which is called under the context of - :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or - :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to - a concrete candidate value. - - To use conditional search space in dynamic evaluate mode, the candidate - should be wrapped with a `lambda` function, which is not necessary under - symbolic mode. For example:: - - pg.manyof(2, [ - lambda: pg.oneof([0, 1], name='sub_a'), - lambda: pg.floatv(0.0, 1.0, name='sub_b'), - lambda: pg.manyof(2, ['a', 'b', 'c'], name='sub_c') - ], name='root') - - See also: - - * :class:`pyglove.hyper.ManyOf` - * :func:`pyglove.manyof` - * :func:`pyglove.floatv` - * :func:`pyglove.permutate` - * :func:`pyglove.evolve` - - Args: - k: number of choices to make. Should be no larger than the length of - `candidates` unless `choice_distinct` is set to False, - candidates: Candidates to select from. Items of candidate can be any type, - therefore it can have nested hyper primitives, which forms a hierarchical - search space. - distinct: If True, each choice needs to be unique. - sorted: If True, choices are sorted by their indices in the - candidates. - name: A name that can be used to identify a decision point in the search - space. This is needed when the code to instantiate the same hyper - primitive may be called multiple times under a - `pg.DynamicEvaluationContext.collect` context or a - `pg.DynamicEvaluationContext.apply` context. - hints: An optional value which acts as a hint for the controller. - **kwargs: Keyword arguments for backward compatibility. - `choices_distinct`: Old name for `distinct`. - `choices_sorted`: Old name for `sorted`. - - Returns: - In symbolic mode, this function returns a `Choices`. - In dynamic evaluate mode, this function returns a list of items in - `candidates`. - If evaluated under a `pg.DynamicEvaluationContext.apply` scope, - this function will return a list of selected candidates. - If evaluated under a `pg.DynamicEvaluationContext.collect` - scope, it will return a list of the first valid combination from the - `candidates`. For example:: - - # Evaluates to [0, 1, 2]. - manyof(3, range(5)) - - # Evaluates to [0, 0, 0]. - manyof(3, range(5), distinct=False) - """ - choices_distinct = kwargs.pop('choices_distinct', distinct) - choices_sorted = kwargs.pop('choices_sorted', sorted) - return ManyOf( - num_choices=k, - candidates=list(candidates), - choices_distinct=choices_distinct, - choices_sorted=choices_sorted, - name=name, - hints=hints) - - -sublist_of = manyof - - -def permutate(candidates: Iterable[Any], - name: Optional[Text] = None, - hints: Optional[Any] = None) -> Any: - """Permuatation of candidates. - - Example:: - - @pg.members([ - ('x', pg.typing.Int()) - ]) - class A(pg.Object): - pass - - # Permutates the candidates. - v = pg.permutate([1, 2, 3]) - - # A complex type as candidate. - v1 = pg.permutate(['a', {'x': 1}, A(1)]) - - # A hierarchical categorical choice: - v2 = pg.permutate([ - 'foo', - 'bar', - A(pg.oneof([1, 2, 3])) - ]) - - .. note:: - - Under symbolic mode (by default), `pg.manyof` returns a ``pg.hyper.ManyOf`` - object. Under dynamic evaluate mode, which is called under the context of - :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or - :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to - a concrete candidate value. - - To use conditional search space in dynamic evaluate mode, the candidate - should be wrapped with a `lambda` function, which is not necessary under - symbolic mode. For example:: - - pg.permutate([ - lambda: pg.oneof([0, 1], name='sub_a'), - lambda: pg.floatv(0.0, 1.0, name='sub_b'), - lambda: pg.manyof(2, ['a', 'b', 'c'], name='sub_c') - ], name='root') - - See also: - - * :class:`pyglove.hyper.ManyOf` - * :func:`pyglove.oneof` - * :func:`pyglove.manyof` - * :func:`pyglove.floatv` - * :func:`pyglove.evolve` - - Args: - candidates: Candidates to select from. Items of candidate can be any type, - therefore it can have nested hyper primitives, which forms a hierarchical - search space. - name: A name that can be used to identify a decision point in the search - space. This is needed when the code to instantiate the same hyper - primitive may be called multiple times under a - `pg.DynamicEvaluationContext.collect` context or a - `pg.DynamicEvaluationContext.apply` context. - hints: An optional value which acts as a hint for the controller. - - Returns: - In symbolic mode, this function returns a `Choices`. - In dynamic evaluate mode, this function returns a permutation from - `candidates`. - If evaluated under an `pg.DynamicEvaluationContext.apply` scope, - this function will return a permutation of candidates based on controller - decisions. - If evaluated under a `pg.DynamicEvaluationContext.collect` - scope, it will return the first valid permutation. - For example:: - - # Evaluates to [0, 1, 2, 3, 4]. - permutate(range(5), name='numbers') - """ - candidates = list(candidates) - return manyof( - len(candidates), candidates, - choices_distinct=True, choices_sorted=False, name=name, hints=hints) - - -def floatv(min_value: float, - max_value: float, - scale: Optional[Text] = None, - *, - name: Optional[Text] = None, - hints: Optional[Any] = None) -> Any: - """A continuous value within a range. - - Example:: - - # A continuous value within [0.0, 1.0] - v = pg.floatv(0.0, 1.0) - - See also: - - * :class:`pyglove.hyper.Float` - * :func:`pyglove.oneof` - * :func:`pyglove.manyof` - * :func:`pyglove.permutate` - * :func:`pyglove.evolve` - - .. note:: - - Under symbolic mode (by default), `pg.floatv` returns a ``pg.hyper.Float`` - object. Under dynamic evaluate mode, which is called under the context of - :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or - :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to - a concrete candidate value. - - Args: - min_value: Minimum acceptable value (inclusive). - max_value: Maximum acceptable value (inclusive). - scale: An optional string as the scale of the range. Supported values - are None, 'linear', 'log', and 'rlog'. - If None, the feasible space is unscaled. - If `linear`, the feasible space is mapped to [0, 1] linearly. - If `log`, the feasible space is mapped to [0, 1] logarithmically with - formula `x -> log(x / min) / log(max / min)`. - If `rlog`, the feasible space is mapped to [0, 1] "reverse" - logarithmically, resulting in values close to `max_value` spread - out more than the points near the `min_value`, with formula: - x -> 1.0 - log((max + min - x) / min) / log (max / min). - `min_value` must be positive if `scale` is not None. - Also, it depends on the search algorithm to decide whether this - information is used or not. - name: A name that can be used to identify a decision point in the search - space. This is needed when the code to instantiate the same hyper - primitive may be called multiple times under a - `pg.DynamicEvaluationContext.collect` context or a - `pg.DynamicEvaluationContext.apply` context. - hints: An optional value which acts as a hint for the controller. - - Returns: - In symbolic mode, this function returns a `Float`. - In dynamic evaluate mode, this function returns a float value that is no - less than the `min_value` and no greater than the `max_value`. - If evaluated under an `pg.DynamicEvaluationContext.apply` scope, - this function will return a chosen float value from the controller - decisions. - If evaluated under a `pg.DynamicEvaluationContext.collect` - scope, it will return `min_value`. - """ - return Float( - min_value=min_value, max_value=max_value, - scale=scale, name=name, hints=hints) - - -# For backward compatibility -float_value = floatv - - -def evolve( - initial_value: symbolic.Symbolic, - node_transform: Callable[ - [ - object_utils.KeyPath, # Location. - Any, # Old value. - # pg.MISSING_VALUE for insertion. - symbolic.Symbolic, # Parent node. - ], - Any # Replacement. - ], - *, - weights: Optional[Callable[ - [ - MutationType, # Mutation type. - object_utils.KeyPath, # Location. - Any, # Value. - symbolic.Symbolic, # Parent. - ], - float # Mutation weight. - ]] = None, # pylint: disable=bad-whitespace - name: Optional[Text] = None, - hints: Optional[Any] = None) -> Evolvable: - """An evolvable symbolic value. - - Example:: - - @pg.symbolize - @dataclasses.dataclass - class Foo: - x: int - y: int - - @pg.symbolize - @dataclasses.dataclass - class Bar: - a: int - b: int - - # Defines possible transitions. - def node_transform(location, value, parent): - if isinstance(value, Foo) - return Bar(value.x, value.y) - if location.key == 'x': - return random.choice([1, 2, 3]) - if location.key == 'y': - return random.choice([3, 4, 5]) - - v = pg.evolve(Foo(1, 3), node_transform) - - See also: - - * :class:`pyglove.hyper.Evolvable` - * :func:`pyglove.oneof` - * :func:`pyglove.manyof` - * :func:`pyglove.permutate` - * :func:`pyglove.floatv` - - Args: - initial_value: The initial value to evolve. - node_transform: A callable object that takes information of the value to - operate (e.g. location, old value, parent node) and returns a new value as - a replacement for the node. Such information allows users to not only - access the mutation node, but the entire symbolic tree if needed, allowing - complex mutation rules to be written with ease - for example - check - adjacent nodes while modifying a list element. This function is designed - to take care of both node replacements and node insertions. When insertion - happens, the old value for the location will be `pg.MISSING_VALUE`. See - `pg.composing.SeenObjectReplacer` as an example. - weights: An optional callable object that returns the unnormalized (e.g. - the sum of all probabilities don't have to sum to 1.0) mutation - probabilities for all the nodes in the symbolic tree, based on (mutation - type, location, old value, parent node), If None, all the locations and - mutation types will be sampled uniformly. - name: An optional name of the decision point. - hints: An optional hints for the decision point. - - Returns: - A `pg.hyper.Evolvable` object. - """ - return Evolvable( - initial_value=initial_value, node_transform=node_transform, - weights=weights, name=name, hints=hints) - - -def template( - value: Any, - where: Optional[Callable[[HyperPrimitive], bool]] = None) -> ObjectTemplate: - """Creates an object template from the input. - - Example:: - - d = pg.Dict(x=pg.oneof(['a', 'b', 'c'], y=pg.manyof(2, range(4)))) - t = pg.template(d) - - assert t.dna_spec() == pg.geno.space([ - pg.geno.oneof([ - pg.geno.constant(), - pg.geno.constant(), - pg.geno.constant(), - ], location='x'), - pg.geno.manyof([ - pg.geno.constant(), - pg.geno.constant(), - pg.geno.constant(), - pg.geno.constant(), - ], location='y') - ]) - - assert t.encode(pg.Dict(x='a', y=0)) == pg.DNA([0, 0]) - assert t.decode(pg.DNA([0, 0])) == pg.Dict(x='a', y=0) - - t = pg.template(d, where=lambda x: isinstance(x, pg.hyper.ManyOf)) - assert t.dna_spec() == pg.geno.space([ - pg.geno.manyof([ - pg.geno.constant(), - pg.geno.constant(), - pg.geno.constant(), - pg.geno.constant(), - ], location='y') - ]) - assert t.encode(pg.Dict(x=pg.oneof(['a', 'b', 'c']), y=0)) == pg.DNA(0) - assert t.decode(pg.DNA(0)) == pg.Dict(x=pg.oneof(['a', 'b', 'c']), y=0) - - Args: - value: A value based on which the template is created. - where: Function to filter hyper values. If None, all hyper primitives from - `value` will be included in the encoding/decoding process. Otherwise - only the hyper values on which 'where' returns True will be included. - `where` can be useful to partition a search space into separate - optimization processes. Please see 'ObjectTemplate' docstr for details. - - Returns: - A template object. - """ - return ObjectTemplate(value, compute_derived=True, where=where) - - -# -# Helper methods for operating on hyper values. -# - - -def dna_spec( - value: Any, - where: Optional[Callable[[HyperPrimitive], bool]] = None) -> geno.DNASpec: - """Returns the DNASpec from a (maybe) hyper value. - - Example:: - - hyper = pg.Dict(x=pg.oneof([1, 2, 3]), y=pg.oneof(['a', 'b'])) - spec = pg.dna_spec(hyper) - - assert spec.space_size == 6 - assert len(spec.decision_points) == 2 - print(spec.decision_points) - - # Select a partial space with `where` argument. - spec = pg.dna_spec(hyper, where=lambda x: len(x.candidates) == 2) - - assert spec.space_size == 2 - assert len(spec.decision_points) == 1 - - See also: - - * :class:`pyglove.DNASpec` - * :class:`pyglove.DNA` - - Args: - value: A (maybe) hyper value. - where: Function to filter hyper primitives. If None, all hyper primitives - from `value` will be included in the encoding/decoding process. Otherwise - only the hyper primitives on which 'where' returns True will be included. - `where` can be very useful to partition a search space into separate - optimization processes. Please see 'Template' docstr for details. - - Returns: - A DNASpec object, which represents the search space from algorithm's view. - """ - return template(value, where).dna_spec() - - -# NOTE(daiyip): For backward compatibility, we use `search_space` as an alias -# for `dna_spec`. Once downstream users are updated to call `dna_spec`, we will -# remove this method. -search_space = dna_spec - - -def materialize( - value: Any, - parameters: Union[geno.DNA, Dict[Text, Any]], - use_literal_values: bool = True, - where: Optional[Callable[[HyperPrimitive], bool]] = None) -> Any: - """Materialize a (maybe) hyper value using a DNA or parameter dict. - - Example:: - - hyper_dict = pg.Dict(x=pg.oneof(['a', 'b']), y=pg.floatv(0.0, 1.0)) - - # Materialize using DNA. - assert pg.materialize( - hyper_dict, pg.DNA([0, 0.5])) == pg.Dict(x='a', y=0.5) - - # Materialize usign key value pairs. - # See `pg.DNA.from_dict` for more details. - assert pg.materialize( - hyper_dict, {'x': 0, 'y': 0.5}) == pg.Dict(x='a', y=0.5) - - # Partially materialize. - v = pg.materialize( - hyper_dict, pg.DNA(0), where=lambda x: isinstance(x, pg.hyper.OneOf)) - assert v == pg.Dict(x='a', y=pg.floatv(0.0, 1.0)) - - Args: - value: A (maybe) hyper value - parameters: A DNA object or a dict of string (key path) to a - string (in format of '/' for - `geno.Choices`, or '' for `geno.Float`), or their literal - values when `use_literal_values` is set to True. - use_literal_values: Applicable when `parameters` is a dict. If True, the - values in the dict will be from `geno.Choices.literal_values` for - `geno.Choices`. - where: Function to filter hyper primitives. If None, all hyper primitives - from `value` will be included in the encoding/decoding process. Otherwise - only the hyper primitives on which 'where' returns True will be included. - `where` can be useful to partition a search space into separate - optimization processes. Please see 'Template' docstr for details. - - Returns: - A materialized value. - - Raises: - TypeError: if parameters is not a DNA or dict. - ValueError: if parameters cannot be decoded. - """ - t = template(value, where) - if isinstance(parameters, dict): - dna = geno.DNA.from_parameters( - parameters=parameters, - dna_spec=t.dna_spec(), - use_literal_values=use_literal_values) - else: - dna = parameters - - if not isinstance(dna, geno.DNA): - raise TypeError( - f'\'parameters\' must be a DNA or a dict of string to DNA values. ' - f'Encountered: {dna!r}.') - return t.decode(dna) - - -def iterate(hyper_value: Any, - num_examples: Optional[int] = None, - algorithm: Optional[geno.DNAGenerator] = None, - where: Optional[Callable[[HyperPrimitive], bool]] = None, - force_feedback: bool = False): - """Iterate a hyper value based on an algorithm. - - Example:: - - hyper_dict = pg.Dict(x=pg.oneof([1, 2, 3]), y=pg.oneof(['a', 'b'])) - - # Get all examples from the hyper_dict. - assert list(pg.iter(hyper_dict)) == [ - pg.Dict(x=1, y='a'), - pg.Dict(x=1, y='b'), - pg.Dict(x=2, y='a'), - pg.Dict(x=2, y='b'), - pg.Dict(x=3, y='a'), - pg.Dict(x=3, y='b'), - ] - - # Get the first two examples. - assert list(pg.iter(hyper_dict, 2)) == [ - pg.Dict(x=1, y='a'), - pg.Dict(x=1, y='b'), - ] - - # Random sample examples, which is equivalent to `pg.random_sample`. - list(pg.iter(hyper_dict, 2, pg.geno.Random())) - - # Iterate examples with feedback loop. - for d, feedback in pg.iter( - hyper_dict, 10, - pg.evolution.regularized_evolution(pg.evolution.mutators.Uniform())): - feedback(d.x) - - # Only materialize selected parts. - assert list( - pg.iter(hyper_dict, where=lambda x: len(x.candidates) == 2)) == [ - pg.Dict(x=pg.oneof([1, 2, 3]), y='a'), - pg.Dict(x=pg.oneof([1, 2, 3]), y='b'), - ] - - ``pg.iter`` distinguishes from `pg.sample` in that it's designed - for simple in-process iteration, which is handy for quickly generating - examples from algorithms without maintaining trail states. On the contrary, - `pg.sample` is designed for distributed sampling, with parallel workers and - failover handling. - - Args: - hyper_value: A hyper value that represents a space of instances. - num_examples: An optional integer as the max number of examples to - propose. If None, propose will return an iterator of infinite examples. - algorithm: An optional DNA generator. If None, Sweeping will be used, which - iterates examples in order. - where: Function to filter hyper primitives. If None, all hyper primitives - from `value` will be included in the encoding/decoding process. Otherwise - only the hyper primitives on which 'where' returns True will be included. - `where` can be useful to partition a search space into separate - optimization processes. Please see 'Template' docstr for details. - force_feedback: If True, always return the Feedback object together - with the example, this is useful when the user want to pass different - DNAGenerators to `pg.iter` and want to handle them uniformly. - - Yields: - A tuple of (example, feedback_fn) if the algorithm needs a feedback or - `force_feedback` is True, otherwise the example. - - Raises: - ValueError: when `hyper_value` is a constant value. - """ - if isinstance(hyper_value, DynamicEvaluationContext): - dynamic_evaluation_context = hyper_value - spec = hyper_value.dna_spec - t = None - else: - t = template(hyper_value, where) - if t.is_constant: - raise ValueError( - f'\'hyper_value\' is a constant value: {hyper_value!r}.') - dynamic_evaluation_context = None - spec = t.dna_spec() - - if algorithm is None: - algorithm = geno.Sweeping() - - # NOTE(daiyip): algorithm can continue if it's already set up with the same - # DNASpec, or we will setup the algorithm with the DNASpec from the template. - if algorithm.dna_spec is None: - algorithm.setup(spec) - elif symbolic.ne(spec, algorithm.dna_spec): - raise ValueError( - f'{algorithm!r} has been set up with a different DNASpec. ' - f'Existing: {algorithm.dna_spec!r}, New: {spec!r}.') - - count = 0 - while num_examples is None or count < num_examples: - try: - count += 1 - dna = algorithm.propose() - if t is not None: - example = t.decode(dna) - else: - assert dynamic_evaluation_context is not None - example = lambda: dynamic_evaluation_context.apply(dna) - if force_feedback or algorithm.needs_feedback: - yield example, Feedback(algorithm, dna) - else: - yield example - except StopIteration: - return - - -class Feedback: - """Feedback object.""" - - def __init__(self, algorithm: geno.DNAGenerator, dna: geno.DNA): - """Creates a feedback object.""" - self._algorithm = algorithm - self._dna = dna - - def __call__(self, reward: Union[float, Tuple[float, ...]]): - """Call to feedback reward.""" - self._algorithm.feedback(self._dna, reward) - - @property - def dna(self) -> geno.DNA: - """Returns DNA.""" - return self._dna - - -def random_sample( - value: Any, - num_examples: Optional[int] = None, - where: Optional[Callable[[HyperPrimitive], bool]] = None, - seed: Optional[int] = None): - """Returns an iterator of random sampled examples. - - Example:: - - hyper_dict = pg.Dict(x=pg.oneof(range(3)), y=pg.floatv(0.0, 1.0)) - - # Generate one random example from the hyper_dict. - d = next(pg.random_sample(hyper_dict)) - - # Generate 5 random examples with random seed. - ds = list(pg.random_sample(hyper_dict, 5, seed=1)) - - # Generate 3 random examples of `x` with `y` intact. - ds = list(pg.random_sample(hyper_dict, 3, - where=lambda x: isinstance(x, pg.hyper.OneOf))) - - - Args: - value: A (maybe) hyper value. - num_examples: An optional integer as number of examples to propose. If None, - propose will return an iterator that iterates forever. - where: Function to filter hyper primitives. If None, all hyper primitives in - `value` will be included in the encoding/decoding process. Otherwise only - the hyper primitives on which 'where' returns True will be included. - `where` can be useful to partition a search space into separate - optimization processes. Please see 'Template' docstr for details. - seed: An optional integer as random seed. - - Returns: - Iterator of random examples. - """ - return iterate( - value, num_examples, geno.Random(seed), where=where) - -# -# Methods for dynamically evaluting hyper values. -# - - -_thread_local_state = threading.local() -_TLS_KEY_DYNAMIC_EVALUATE_FN = 'dynamic_evaluate_fn' -_global_dynamic_evaluate_fn = None - - -@contextlib.contextmanager -def dynamic_evaluate(evaluate_fn: Optional[Callable[[HyperValue], Any]], - yield_value: Optional[Any] = None, - exit_fn: Optional[Callable[[], None]] = None, - per_thread: bool = True): - """Eagerly evaluate hyper primitives within current scope. - - Example:: - - global_indices = [0] - def evaluate_fn(x: pg.hyper.HyperPrimitive): - if isinstance(x, pg.hyper.OneOf): - return x.candidates[global_indices[0]] - raise NotImplementedError() - - with pg.hyper.dynamic_evaluate(evaluate_fn): - assert 0 = pg.oneof([0, 1, 2]) - - Please see :meth:`pyglove.DynamicEvaluationContext.apply` as an example - for using this method. - - Args: - evaluate_fn: A callable object that evaluates a hyper value such as - oneof, manyof, floatv, and etc. into a concrete value. - yield_value: Value to yield return. - exit_fn: A callable object to be called when exiting the context scope. - per_thread: If True, the context manager will be applied to current thread - only. Otherwise, it will be applied on current process. - - Yields: - `yield_value` from the argument. - """ - global _global_dynamic_evaluate_fn - if evaluate_fn is not None and not callable(evaluate_fn): - raise ValueError( - f'\'evaluate_fn\' must be either None or a callable object. ' - f'Encountered: {evaluate_fn!r}.') - if exit_fn is not None and not callable(exit_fn): - raise ValueError( - f'\'exit_fn\' must be a callable object. Encountered: {exit_fn!r}.') - if per_thread: - old_evaluate_fn = getattr( - _thread_local_state, _TLS_KEY_DYNAMIC_EVALUATE_FN, None) - else: - old_evaluate_fn = _global_dynamic_evaluate_fn - - has_errors = False - try: - if per_thread: - setattr(_thread_local_state, _TLS_KEY_DYNAMIC_EVALUATE_FN, evaluate_fn) - else: - _global_dynamic_evaluate_fn = evaluate_fn - yield yield_value - except Exception: - has_errors = True - raise - finally: - if per_thread: - setattr( - _thread_local_state, _TLS_KEY_DYNAMIC_EVALUATE_FN, old_evaluate_fn) - else: - _global_dynamic_evaluate_fn = old_evaluate_fn - if not has_errors and exit_fn is not None: - exit_fn() - - -class DynamicEvaluationContext: - """Context for dynamic evaluation of hyper primitives. - - Example:: - - import pyglove as pg - - # Define a function that implicitly declares a search space. - def foo(): - return pg.oneof(range(-10, 10)) ** 2 + pg.oneof(range(-10, 10)) ** 2 - - # Define the search space by running the `foo` once. - search_space = pg.hyper.DynamicEvaluationContext() - with search_space.collect(): - _ = foo() - - # Create a search algorithm. - search_algorithm = pg.evolution.regularized_evolution( - pg.evolution.mutators.Uniform(), population_size=32, tournament_size=16) - - # Define the feedback loop. - best_foo, best_reward = None, None - for example, feedback in pg.sample( - search_space, search_algorithm, num_examples=100): - # Call to `example` returns a context manager - # under which the `program` is connected with - # current search algorithm decisions. - with example(): - reward = foo() - feedback(reward) - if best_reward is None or best_reward < reward: - best_foo, best_reward = example, reward - """ - - class _AnnoymousHyperNameAccumulator: - """Name accumulator for annoymous hyper primitives.""" - - def __init__(self): - self.index = 0 - - def next_name(self): - name = f'decision_{self.index}' - self.index += 1 - return name - - def __init__(self, - where: Optional[Callable[[HyperPrimitive], bool]] = None, - require_hyper_name: bool = False, - per_thread: bool = True, - dna_spec: Optional[geno.DNASpec] = None) -> None: # pylint: disable=redefined-outer-name - """Create a dynamic evaluation context. - - Args: - where: A callable object that decide whether a hyper primitive should be - included when being instantiated under `collect`. - If None, all hyper primitives under `collect` will be - included. - require_hyper_name: If True, all hyper primitives (e.g. pg.oneof) must - come with a `name`. This option helps to eliminate errors when a - function that contains hyper primitive definition may be called multiple - times. Since hyper primitives sharing the same name will be registered - to the same decision point, repeated call to the hyper primitive - definition will not matter. - per_thread: If True, the context manager will be applied to current thread - only. Otherwise, it will be applied on current process. - dna_spec: External provided search space. If None, the dynamic evaluation - context can be used to create new search space via `colelct` context - manager. Otherwise, current context will use the provided DNASpec to - apply decisions. - """ - self._where = where - self._require_hyper_name: bool = require_hyper_name - self._name_to_hyper: Dict[Text, HyperPrimitive] = dict() - self._annoymous_hyper_name_accumulator = ( - DynamicEvaluationContext._AnnoymousHyperNameAccumulator()) - self._hyper_dict = symbolic.Dict() if dna_spec is None else None - self._dna_spec: Optional[geno.DNASpec] = dna_spec - self._per_thread = per_thread - self._decision_getter = None - - @property - def per_thread(self) -> bool: - """Returns True if current context collects/applies decisions per thread.""" - return self._per_thread - - @property - def dna_spec(self) -> geno.DNASpec: - """Returns the DNASpec of the search space defined so far.""" - if self._dna_spec is None: - assert self._hyper_dict is not None - self._dna_spec = dna_spec(self._hyper_dict) - return self._dna_spec - - def _decision_name(self, hyper_primitive: HyperPrimitive) -> Text: - """Get the name for a decision point.""" - name = hyper_primitive.name - if name is None: - if self._require_hyper_name: - raise ValueError( - f'\'name\' must be specified for hyper ' - f'primitive {hyper_primitive!r}.') - name = self._annoymous_hyper_name_accumulator.next_name() - return name - - @property - def is_external(self) -> bool: - """Returns True if the search space is defined by an external DNASpec.""" - return self._hyper_dict is None - - @property - def hyper_dict(self) -> Optional[symbolic.Dict]: - """Returns collected hyper primitives as a dict. - - None if current context is controlled by an external DNASpec. - """ - return self._hyper_dict - - @contextlib.contextmanager - def collect(self): - """A context manager for collecting hyper primitives within this context. - - Example:: - - context = DynamicEvaluationContext() - with context.collect(): - x = pg.oneof([1, 2, 3]) + pg.oneof([4, 5, 6]) - - # Will print 1 + 4 = 5. Meanwhile 2 hyper primitives will be registered - # in the search space represented by the context. - print(x) - - Yields: - The hyper dict representing the search space. - """ - if self.is_external: - raise ValueError( - f'`collect` cannot be called on a dynamic evaluation context that is ' - f'using an external DNASpec: {self._dna_spec}.') - - # Ensure per-thread dynamic evaluation context will not be used - # together with process-level dynamic evaluation context. - _dynamic_evaluation_stack.ensure_thread_safety(self) - - self._hyper_dict = {} - with dynamic_evaluate(self.add_decision_point, per_thread=self._per_thread): - try: - # Push current context to dynamic evaluatoin stack so nested context - # can defer unresolved hyper primitive to current context. - _dynamic_evaluation_stack.push(self) - yield self._hyper_dict - - finally: - # Invalidate DNASpec. - self._dna_spec = None - - # Pop current context from dynamic evaluatoin stack. - _dynamic_evaluation_stack.pop(self) - - def add_decision_point(self, hyper_primitive: HyperPrimitive): - """Registers a parameter with current context and return its first value.""" - def _add_child_decision_point(c): - if isinstance(c, types.LambdaType): - s = schema.get_signature(c) - if not s.args and not s.has_wildcard_args: - sub_context = DynamicEvaluationContext( - where=self._where, per_thread=self._per_thread) - sub_context._annoymous_hyper_name_accumulator = ( # pylint: disable=protected-access - self._annoymous_hyper_name_accumulator) - with sub_context.collect() as hyper_dict: - v = c() - return (v, hyper_dict) - return (c, c) - - if self._where and not self._where(hyper_primitive): - # Delegate the resolution of hyper primitives that do not pass - # the `where` predicate to its parent context. - parent_context = _dynamic_evaluation_stack.get_parent(self) - if parent_context is not None: - return parent_context.add_decision_point(hyper_primitive) - return hyper_primitive - - if isinstance(hyper_primitive, Template): - return hyper_primitive.value - - assert isinstance(hyper_primitive, HyperPrimitive), hyper_primitive - name = self._decision_name(hyper_primitive) - if isinstance(hyper_primitive, Choices): - candidate_values, candidates = zip( - *[_add_child_decision_point(c) for c in hyper_primitive.candidates]) - if hyper_primitive.choices_distinct: - assert hyper_primitive.num_choices <= len(hyper_primitive.candidates) - v = [candidate_values[i] for i in range(hyper_primitive.num_choices)] - else: - v = [candidate_values[0]] * hyper_primitive.num_choices - hyper_primitive = hyper_primitive.clone(deep=True, override={ - 'candidates': list(candidates) - }) - first_value = v[0] if isinstance(hyper_primitive, ChoiceValue) else v - elif isinstance(hyper_primitive, Float): - first_value = hyper_primitive.min_value - else: - assert isinstance(hyper_primitive, CustomHyper), hyper_primitive - first_value = hyper_primitive.decode(hyper_primitive.first_dna()) - - if (name in self._name_to_hyper - and hyper_primitive != self._name_to_hyper[name]): - raise ValueError( - f'Found different hyper primitives under the same name {name!r}: ' - f'Instance1={self._name_to_hyper[name]!r}, ' - f'Instance2={hyper_primitive!r}.') - self._hyper_dict[name] = hyper_primitive - self._name_to_hyper[name] = hyper_primitive - return first_value - - def _decision_getter_and_evaluation_finalizer( - self, decisions: Union[geno.DNA, List[Union[int, float, str]]]): - """Returns decision getter based on input decisions.""" - # NOTE(daiyip): when hyper primitives are required to carry names, we do - # decision lookup from the DNA dict. This allows the decision points - # to appear in any order other than strictly following the order of their - # appearences during the search space inspection. - if self._require_hyper_name: - if isinstance(decisions, list): - dna = geno.DNA.from_numbers(decisions, self.dna_spec) - else: - dna = decisions - dna.use_spec(self.dna_spec) - decision_dict = dna.to_dict( - key_type='name_or_id', multi_choice_key='parent') - - used_decision_names = set() - def get_decision_from_dict( - hyper_primitive, sub_index: Optional[int] = None - ) -> Union[int, float, str]: - name = hyper_primitive.name - assert name is not None, hyper_primitive - if name not in decision_dict: - raise ValueError( - f'Hyper primitive {hyper_primitive!r} is not defined during ' - f'search space inspection (pg.hyper.DynamicEvaluationContext.' - f'collect()). Please make sure `collect` and `apply` are applied ' - f'to the same function.') - - # We use assertion here since DNA is validated with `self.dna_spec`. - # User errors should be caught by `dna.use_spec`. - decision = decision_dict[name] - used_decision_names.add(name) - if (not isinstance(hyper_primitive, Choices) - or hyper_primitive.num_choices == 1): - return decision - assert isinstance(decision, list), (hyper_primitive, decision) - assert len(decision) == hyper_primitive.num_choices, ( - hyper_primitive, decision) - return decision[sub_index] - - def err_on_unused_decisions(): - if len(used_decision_names) != len(decision_dict): - remaining = {k: v for k, v in decision_dict.items() - if k not in used_decision_names} - raise ValueError( - f'Found extra decision values that are not used. {remaining!r}') - return get_decision_from_dict, err_on_unused_decisions - else: - if isinstance(decisions, geno.DNA): - decision_list = decisions.to_numbers() - else: - decision_list = decisions - value_context = dict(pos=0, value_cache={}) - - def get_decision_by_position( - hyper_primitive, sub_index: Optional[int] = None - ) -> Union[int, float, str]: - if sub_index is None or hyper_primitive.name is None: - name = hyper_primitive.name - else: - name = f'{hyper_primitive.name}:{sub_index}' - if name is None or name not in value_context['value_cache']: - if value_context['pos'] >= len(decision_list): - raise ValueError( - f'No decision is provided for {hyper_primitive!r}.') - decision = decision_list[value_context['pos']] - value_context['pos'] += 1 - if name is not None: - value_context['value_cache'][name] = decision - else: - decision = value_context['value_cache'][name] - - if (isinstance(hyper_primitive, Float) - and not isinstance(decision, float)): - raise ValueError( - f'Expect float-type decision for {hyper_primitive!r}, ' - f'encoutered {decision!r}.') - if (isinstance(hyper_primitive, CustomHyper) - and not isinstance(decision, str)): - raise ValueError( - f'Expect string-type decision for {hyper_primitive!r}, ' - f'encountered {decision!r}.') - if (isinstance(hyper_primitive, Choices) - and not (isinstance(decision, int) - and decision < len(hyper_primitive.candidates))): - raise ValueError( - f'Expect int-type decision in range ' - f'[0, {len(hyper_primitive.candidates)}) for choice {sub_index} ' - f'of {hyper_primitive!r}, encountered {decision!r}.') - return decision - - def err_on_unused_decisions(): - if value_context['pos'] != len(decision_list): - remaining = decision_list[value_context['pos']:] - raise ValueError( - f'Found extra decision values that are not used: {remaining!r}') - return get_decision_by_position, err_on_unused_decisions - - @contextlib.contextmanager - def apply( - self, decisions: Union[geno.DNA, List[Union[int, float, str]]]): - """Context manager for applying decisions. - - Example:: - - def fun(): - return pg.oneof([1, 2, 3]) + pg.oneof([4, 5, 6]) - - context = DynamicEvaluationContext() - with context.collect(): - fun() - - with context.apply([0, 1]): - # Will print 6 (1 + 5). - print(fun()) - - Args: - decisions: A DNA or a list of numbers or strings as decisions for currrent - search space. - - Yields: - None - """ - if not isinstance(decisions, (geno.DNA, list)): - raise ValueError('`decisions` should be a DNA or a list of numbers.') - - # Ensure per-thread dynamic evaluation context will not be used - # together with process-level dynamic evaluation context. - _dynamic_evaluation_stack.ensure_thread_safety(self) - - get_current_decision, evaluation_finalizer = ( - self._decision_getter_and_evaluation_finalizer(decisions)) - - has_errors = False - with dynamic_evaluate(self.evaluate, per_thread=self._per_thread): - try: - # Set decision getter for current decision. - self._decision_getter = get_current_decision - - # Push current context to dynamic evaluation stack so nested context - # can delegate evaluate to current context. - _dynamic_evaluation_stack.push(self) - - yield - except Exception: - has_errors = True - raise - finally: - # Pop current context from dynamic evaluatoin stack. - _dynamic_evaluation_stack.pop(self) - - # Reset decisions. - self._decision_getter = None - - # Call evaluation finalizer to make sure all decisions are used. - if not has_errors: - evaluation_finalizer() - - def evaluate(self, hyper_primitive: HyperPrimitive): - """Evaluates a hyper primitive based on current decisions.""" - if self._decision_getter is None: - raise ValueError( - '`evaluate` needs to be called under the `apply` context.') - - get_current_decision = self._decision_getter - def _apply_child(c): - if isinstance(c, types.LambdaType): - s = schema.get_signature(c) - if not s.args and not s.has_wildcard_args: - return c() - return c - - if self._where and not self._where(hyper_primitive): - # Delegate the resolution of hyper primitives that do not pass - # the `where` predicate to its parent context. - parent_context = _dynamic_evaluation_stack.get_parent(self) - if parent_context is not None: - return parent_context.evaluate(hyper_primitive) - return hyper_primitive - - if isinstance(hyper_primitive, Float): - return get_current_decision(hyper_primitive) - - if isinstance(hyper_primitive, CustomHyper): - return hyper_primitive.decode( - geno.DNA(get_current_decision(hyper_primitive))) - - assert isinstance(hyper_primitive, Choices), hyper_primitive - value = symbolic.List() - for i in range(hyper_primitive.num_choices): - # NOTE(daiyip): during registering the hyper primitives when - # constructing the search space, we will need to evaluate every - # candidate in order to pick up sub search spaces correctly, which is - # not necessary for `pg.DynamicEvaluationContext.apply`. - value.append(_apply_child( - hyper_primitive.candidates[get_current_decision(hyper_primitive, i)])) - if isinstance(hyper_primitive, ChoiceValue): - assert len(value) == 1 - value = value[0] - return value - - -# We maintain a stack of dynamic evaluation context for support search space -# combination -class _DynamicEvaluationStack: - """Dynamic evaluation stack used for dealing with nested evaluation.""" - - _TLS_KEY = 'dynamic_evaluation_stack' - - def __init__(self): - self._global_stack = [] - - def ensure_thread_safety(self, context: DynamicEvaluationContext): - if ((context.per_thread and self._global_stack) - or (not context.per_thread and self._local_stack)): - raise ValueError( - 'Nested dynamic evaluation contexts must be either all per-thread ' - 'or all process-wise. Please check the `per_thread` argument of ' - 'the `pg.hyper.DynamicEvaluationContext` objects being used.') - - @property - def _local_stack(self): - """Returns thread-local stack.""" - stack = getattr(_thread_local_state, self._TLS_KEY, None) - if stack is None: - stack = [] - setattr(_thread_local_state, self._TLS_KEY, stack) - return stack - - def push(self, context: DynamicEvaluationContext): - """Pushes the context to the stack.""" - stack = self._local_stack if context.per_thread else self._global_stack - stack.append(context) - - def pop(self, context: DynamicEvaluationContext): - """Pops the context from the stack.""" - stack = self._local_stack if context.per_thread else self._global_stack - assert stack - stack_top = stack.pop(-1) - assert stack_top is context, (stack_top, context) - - def get_parent( - self, - context: DynamicEvaluationContext) -> Optional[DynamicEvaluationContext]: - """Returns the parent context of the input context.""" - stack = self._local_stack if context.per_thread else self._global_stack - parent = None - for i in reversed(range(1, len(stack))): - if context is stack[i]: - parent = stack[i - 1] - break - return parent - - -# System-wise dynamic evaluation stack. -_dynamic_evaluation_stack = _DynamicEvaluationStack() - - -def trace( - fun: Callable[[], Any], - *, - where: Optional[Callable[[HyperPrimitive], bool]] = None, - require_hyper_name: bool = False, - per_thread: bool = True) -> DynamicEvaluationContext: - """Trace the hyper primitives called within a function by executing it. - - See examples in :class:`pyglove.hyper.DynamicEvaluationContext`. - - Args: - fun: Function in which the search space is defined. - where: A callable object that decide whether a hyper primitive should be - included when being instantiated under `collect`. - If None, all hyper primitives under `collect` will be included. - require_hyper_name: If True, all hyper primitives defined in this scope - will need to carry their names, which is usually a good idea when the - function that instantiates the hyper primtives need to be called multiple - times. - per_thread: If True, the context manager will be applied to current thread - only. Otherwise, it will be applied on current process. - - Returns: - An DynamicEvaluationContext that can be passed to `pg.sample`. - """ - context = DynamicEvaluationContext( - where=where, require_hyper_name=require_hyper_name, per_thread=per_thread) - with context.collect(): - fun() - return context - diff --git a/pyglove/core/hyper/__init__.py b/pyglove/core/hyper/__init__.py new file mode 100644 index 0000000..8ddd41a --- /dev/null +++ b/pyglove/core/hyper/__init__.py @@ -0,0 +1,118 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyper objects: representing template-based object space. + +In PyGlove, an object space is represented by a hyper object, which is an +symbolic object that is placeheld by hyper primitives +(:class:`pyglove.hyper.HyperPrimitive`). Through hyper objects, object templates +(:class:`pyglove.hyper.ObjectTemplate`) can be obtained to generate objects +based on program genomes (:class:`pyglove.DNA`). + + .. graphviz:: + :align: center + + digraph hypers { + node [shape="box"]; + edge [arrowtail="empty" arrowhead="none" dir="back" style="dashed"]; + hyper [label="HyperValue" href="hyper_value.html"]; + template [label="ObjectTemplate" href="object_template.html"]; + primitive [label="HyperPrimitive" href="hyper_primitive.html"]; + choices [label="Choices" href="choices.html"]; + oneof [label="OneOf" href="oneof_class.html"]; + manyof [label="ManyOf" href="manyof_class.html"]; + float [label="Float" href="float.html"]; + custom [label="CustomHyper" href="custom_hyper.html"]; + hyper -> template; + hyper -> primitive; + primitive -> choices; + choices -> oneof; + choices -> manyof; + primitive -> float; + primitive -> custom + } + +Hyper values map 1:1 to genotypes as the following: + ++-------------------------------------+----------------------------------------+ +| Hyper class | Genotype class | ++=====================================+========================================+ +|:class:`pyglove.hyper.HyperValue` |:class:`pyglove.DNASpec` | ++-------------------------------------+----------------------------------------+ +|:class:`pyglove.hyper.ObjectTemplate`|:class:`pyglove.geno.Space` | ++-------------------------------------+----------------------------------------+ +|:class:`pyglove.hyper.HyperPrimitive`|:class:`pyglove.geno.DecisionPoint` | ++-------------------------------------+----------------------------------------+ +|:class:`pyglove.hyper.Choices` |:class:`pyglove.geno.Choices` | ++-------------------------------------+----------------------------------------+ +|:class:`pyglove.hyper.Float` |:class:`pyglove.geno.Float` | ++-------------------------------------+----------------------------------------+ +|:class:`pyglove.hyper.CustomHyper` :class:`pyglove.geno.CustomDecisionPoint` | ++------------------------------------------------------------------------------+ +""" + +# pylint: disable=g-bad-import-order + +# The hyper value interface and hyper primitives. +from pyglove.core.hyper.base import HyperValue +from pyglove.core.hyper.base import HyperPrimitive + +from pyglove.core.hyper.categorical import Choices +from pyglove.core.hyper.categorical import OneOf +from pyglove.core.hyper.categorical import ManyOf +from pyglove.core.hyper.numerical import Float +from pyglove.core.hyper.custom import CustomHyper + +from pyglove.core.hyper.evolvable import Evolvable +from pyglove.core.hyper.evolvable import MutationType +from pyglove.core.hyper.evolvable import MutationPoint + +# Helper functions for creating hyper values. +from pyglove.core.hyper.categorical import oneof +from pyglove.core.hyper.categorical import manyof +from pyglove.core.hyper.categorical import permutate +from pyglove.core.hyper.numerical import floatv +from pyglove.core.hyper.evolvable import evolve + +# Object template and helper functions. +from pyglove.core.hyper.object_template import ObjectTemplate +from pyglove.core.hyper.object_template import template +from pyglove.core.hyper.object_template import materialize +from pyglove.core.hyper.object_template import dna_spec + +from pyglove.core.hyper.derived import DerivedValue +from pyglove.core.hyper.derived import ValueReference +from pyglove.core.hyper.derived import reference + +# Classes and functions for dynamic evaluation. +from pyglove.core.hyper.dynamic_evaluation import dynamic_evaluate +from pyglove.core.hyper.dynamic_evaluation import DynamicEvaluationContext +from pyglove.core.hyper.dynamic_evaluation import trace + + +# Helper functions for iterating examples from the search space. +from pyglove.core.hyper.iter import iterate +from pyglove.core.hyper.iter import random_sample + + +# Alias for backward compatibility: +ChoiceList = ManyOf +ChoiceValue = OneOf +Template = ObjectTemplate +one_of = oneof +sublist_of = manyof +float_value = floatv +search_space = dna_spec + + +# pylint: enable=g-bad-import-order diff --git a/pyglove/core/hyper/base.py b/pyglove/core/hyper/base.py new file mode 100644 index 0000000..3373e45 --- /dev/null +++ b/pyglove/core/hyper/base.py @@ -0,0 +1,198 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base of symbolic hyper values for representing client-side search spaces.""" + + +import abc +from typing import Any, Callable, Optional + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.object_utils import thread_local + + +class HyperValue(symbolic.NonDeterministic): # pytype: disable=ignored-metaclass + """Base class for a hyper value. + + Hyper value represents a space of objects, which is essential for + programmatically generating objects. It can encode a concrete object into a + DNA, or decode a DNA into a concrete object. + + DNA is a nestable numeric interface we use to generate object (see `geno.py`). + Each position in the DNA represents either the index of a choice, or a value + itself is numeric. There could be multiple choices standing side-by-side, + representing knobs on different parts of an object, or choices being chained, + forming conditional choice spaces, which can be described by a tree structure. + + Hyper values form a tree as the following: + + .. graphviz:: + + digraph relationship { + template [label="ObjectTemplate" href="object_template.html"]; + primitive [label="HyperPrimitive" href="hyper_primitive.html"]; + choices [label="OneOf/ManyOf" href="choices.html"]; + float [label="Float" href="float_class.html"]; + custom [label="CustomHyper" href="custom_hyper.html"]; + template -> primitive [label="elements (1:*)"]; + primitive -> choices [dir="back" arrowtail="empty" style="dashed"]; + primitive -> float [dir="back" arrowtail="empty" style="dashed"]; + primitive -> custom [dir="back" arrowtail="empty" style="dashed"]; + choices -> template [label="candidates (1:*)"]; + } + """ + + __metaclass__ = abc.ABCMeta + + def __init__(self): + # DNA and decoded value are states for __call__. + # Though `decode` and `encode` methods are stateless. + self._dna = None + self._decoded_value = None + + def set_dna(self, dna: geno.DNA) -> None: + """Use this DNA to generate value. + + NOTE(daiyip): self._dna is only used in __call__. + Thus 'set_dna' can be called multiple times to generate different values. + + Args: + dna: DNA to use to decode the value. + """ + self._dna = dna + # Invalidate decoded value when DNA is refreshed. + self._decoded_value = None + + @property + def dna(self) -> geno.DNA: + """Returns the DNA that is being used by this hyper value.""" + return self._dna + + def __call__(self) -> Any: + """Generate value from DNA provided by set_dna.""" + if self._decoded_value is None: + if self._dna is None: + raise ValueError( + '\'set_dna\' should be called to set a DNA before \'__call__\'.') + self._decoded_value = self.decode(self._dna) + return self._decoded_value + + def decode(self, dna: geno.DNA) -> Any: + """Decode a value from a DNA.""" + self.set_dna(dna) + return self._decode() + + @abc.abstractmethod + def _decode(self) -> Any: + """Decode using self.dna.""" + + @abc.abstractmethod + def encode(self, value: Any) -> geno.DNA: + """Encode a value into a DNA. + + Args: + value: A value that conforms to the hyper value definition. + + Returns: + DNA for the value. + """ + + @abc.abstractmethod + def dna_spec(self, + location: Optional[object_utils.KeyPath] = None) -> geno.DNASpec: + """Get DNA spec of DNA that is decodable/encodable by this hyper value.""" + + +@symbolic.members([ + ('name', pg_typing.Str().noneable(), + ('Name of the hyper primitive. Useful in define-by-run mode to identify a' + 'decision point in the search space - that is - different instances with ' + 'the same name will refer to the same decision point in the search space ' + 'under define-by-run mode. ' + 'Please refer to `pg.hyper.trace` for details.')), + ('hints', pg_typing.Any(default=None), 'Generator hints') +]) +class HyperPrimitive(symbolic.Object, HyperValue): + """Base class for hyper primitives. + + A hyper primitive is a pure symbolic object which represents an object + generation rule. It correspond to a decision point + (:class:`pyglove.geno.DecisionPoint`) in the algorithm's view. + + Child classes: + + * :class:`pyglove.hyper.Choices` + + * :class:`pyglove.hyper.OneOf` + * :class:`pyglove.hyper.ManyOf` + * :class:`pyglove.hyper.Float` + * :class:`pyglove.hyper.CustomHyper` + """ + + def __new__(cls, *args, **kwargs) -> Any: + """Overrides __new__ for supporting dynamic evaluation mode. + + Args: + *args: Positional arguments passed to init the custom hyper. + **kwargs: Keyword arguments passed to init the custom hyper. + + Returns: + A dynamic evaluated value according to current `dynamic_evaluate` context. + """ + dynamic_evaluate_fn = get_dynamic_evaluate_fn() + if dynamic_evaluate_fn is None: + return super().__new__(cls) # pylint: disable=no-value-for-parameter + else: + hyper_value = object.__new__(cls) + cls.__init__(hyper_value, *args, **kwargs) + return dynamic_evaluate_fn(hyper_value) # pylint: disable=not-callable + + def _sym_clone(self, deep: bool, memo=None) -> 'HyperPrimitive': + """Overrides _sym_clone to force no dynamic evaluation.""" + kwargs = dict() + for k, v in self._sym_attributes.items(): + if deep or isinstance(v, symbolic.Symbolic): + v = symbolic.clone(v, deep, memo) + kwargs[k] = v + + # NOTE(daiyip): instead of calling self.__class__(...), + # we manually create a new instance without invoking dynamic + # evaluation. + new_value = object.__new__(self.__class__) + new_value.__init__( # pylint: disable=unexpected-keyword-arg + allow_partial=self._allow_partial, sealed=self._sealed, **kwargs) + return new_value + + +_TLS_KEY_DYNAMIC_EVALUATE_FN = 'dynamic_evaluate_fn' +_global_dynamic_evaluate_fn = None + + +def set_dynamic_evaluate_fn( + fn: Optional[Callable[[HyperValue], Any]], per_thread: bool) -> None: + """Set current dynamic evaluate function.""" + global _global_dynamic_evaluate_fn + if per_thread: + assert _global_dynamic_evaluate_fn is None, _global_dynamic_evaluate_fn + thread_local.set_value(_TLS_KEY_DYNAMIC_EVALUATE_FN, fn) + else: + _global_dynamic_evaluate_fn = fn + + +def get_dynamic_evaluate_fn() -> Optional[Callable[[HyperValue], Any]]: + """Gets current dynamic evaluate function.""" + return thread_local.get_value( + _TLS_KEY_DYNAMIC_EVALUATE_FN, _global_dynamic_evaluate_fn) diff --git a/pyglove/core/hyper/categorical.py b/pyglove/core/hyper/categorical.py new file mode 100644 index 0000000..25d3e3c --- /dev/null +++ b/pyglove/core/hyper/categorical.py @@ -0,0 +1,696 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Categorical hyper primitives.""" + +import numbers +import typing +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper import base +from pyglove.core.hyper import object_template + + +@symbolic.members([ + ('num_choices', pg_typing.Int(min_value=0).noneable(), + ('Number of choices to make. If set to None, any number of choices is ' + 'acceptable.')), + ('candidates', pg_typing.List(pg_typing.Any()), + ('Candidate values, which may contain nested hyper values.' + 'Candidate can customize its display value (literal) by implementing the ' + '`pg.Formattable` interface.')), + ('choices_distinct', pg_typing.Bool(True), 'Whether choices are distinct.'), + ('choices_sorted', pg_typing.Bool(False), 'Whether choices are sorted.'), + ('where', pg_typing.Callable([pg_typing.Object(base.HyperPrimitive)], + returns=pg_typing.Bool()).noneable(), + ('Callable object to filter nested hyper values. If None, all nested ' + 'hyper value will be included in the encoding/decoding process. ' + 'Otherwise only the hyper values on which `where` returns True will be ' + 'included. `where` can be useful to partition a search space into ' + 'separate optimization processes. ' + 'Please see `ObjectTemplate` docstr for details.')) +]) +class Choices(base.HyperPrimitive): + """Categorical choices from a list of candidates. + + Example:: + + # A single categorical choice: + v = pg.oneof([1, 2, 3]) + + # A multiple categorical choice as a list: + vs = pg.manyof(2, [1, 2, 3]) + + # A hierarchical categorical choice: + v2 = pg.oneof([ + 'foo', + 'bar', + pg.manyof(2, [1, 2, 3]) + ]) + + See also: + + * :class:`pyglove.hyper.OneOf` + * :class:`pyglove.hyper.ManyOf` + * :func:`pyglove.oneof` + * :func:`pyglove.manyof` + * :func:`pyglove.permutate` + """ + + def _on_bound(self): + """On members are bound.""" + super()._on_bound() + if self.num_choices > len(self.candidates) and self.choices_distinct: + raise ValueError( + f'{len(self.candidates)} candidates cannot produce ' + f'{self.num_choices} distinct choices.') + self._candidate_templates = [ + object_template.ObjectTemplate(c, where=self.where) + for c in self.candidates] + # ValueSpec for candidate. + self._value_spec = None + + def _update_children_paths( + self, old_path: object_utils.KeyPath, new_path: object_utils.KeyPath): + """Customized logic to update children paths.""" + super()._update_children_paths(old_path, new_path) + for t in self._candidate_templates: + t.root_path = self.sym_path + + @property + def candidate_templates(self): + """Returns candidate templates.""" + return self._candidate_templates + + @property + def is_leaf(self) -> bool: + """Returns whether this is a leaf node.""" + for t in self._candidate_templates: + if not t.is_constant: + return False + return True + + def dna_spec(self, + location: Optional[object_utils.KeyPath] = None) -> geno.Choices: + """Returns corresponding DNASpec.""" + return geno.Choices( + num_choices=self.num_choices, + candidates=[ct.dna_spec() for ct in self._candidate_templates], + literal_values=[self._literal_value(c) + for i, c in enumerate(self.candidates)], + distinct=self.choices_distinct, + sorted=self.choices_sorted, + hints=self.hints, + name=self.name, + location=location or object_utils.KeyPath()) + + def _literal_value( + self, candidate: Any, max_len: int = 120) -> Union[int, float, str]: + """Returns literal value for candidate.""" + if isinstance(candidate, numbers.Number): + return candidate + + literal = object_utils.format(candidate, compact=True, + hide_default_values=True, + hide_missing_values=True, + strip_object_id=True) + if len(literal) > max_len: + literal = literal[:max_len - 3] + '...' + return literal + + def _decode(self) -> List[Any]: + """Decode a DNA into a list of object.""" + dna = self._dna + if self.num_choices == 1: + # Single choice. + if not isinstance(dna.value, int): + raise ValueError( + object_utils.message_on_path( + f'Did you forget to specify values for conditional choices?\n' + f'Expect integer for {self.__class__.__name__}. ' + f'Encountered: {dna!r}.', self.sym_path)) + if dna.value >= len(self.candidates): + raise ValueError( + object_utils.message_on_path( + f'Choice out of range. Value: {dna.value!r}, ' + f'Candidates: {len(self.candidates)}.', self.sym_path)) + choices = [self._candidate_templates[dna.value].decode( + geno.DNA(None, dna.children))] + else: + # Multi choices. + if len(dna.children) != self.num_choices: + raise ValueError( + object_utils.message_on_path( + f'Number of DNA child values does not match the number of ' + f'choices. Child values: {dna.children!r}, ' + f'Choices: {self.num_choices}.', self.sym_path)) + if self.choices_distinct or self.choices_sorted: + sub_dna_values = [s.value for s in dna] + if (self.choices_distinct + and len(set(sub_dna_values)) != len(dna.children)): + raise ValueError( + object_utils.message_on_path( + f'DNA child values should be distinct. ' + f'Encountered: {sub_dna_values}.', self.sym_path)) + if self.choices_sorted and sorted(sub_dna_values) != sub_dna_values: + raise ValueError( + object_utils.message_on_path( + f'DNA child values should be sorted. ' + f'Encountered: {sub_dna_values}.', self.sym_path)) + choices = [] + for i, sub_dna in enumerate(dna): + if not isinstance(sub_dna.value, int): + raise ValueError( + object_utils.message_on_path( + f'Choice value should be int. ' + f'Encountered: {sub_dna.value}.', + object_utils.KeyPath(i, self.sym_path))) + if sub_dna.value >= len(self.candidates): + raise ValueError( + object_utils.message_on_path( + f'Choice out of range. Value: {sub_dna.value}, ' + f'Candidates: {len(self.candidates)}.', + object_utils.KeyPath(i, self.sym_path))) + choices.append(self._candidate_templates[sub_dna.value].decode( + geno.DNA(None, sub_dna.children))) + return choices + + def encode(self, value: List[Any]) -> geno.DNA: + """Encode a list of values into DNA. + + Example:: + + # DNA of an object containing a single OneOf. + # {'a': 1} => DNA(0) + { + 'a': one_of([1, 2]) + } + + + # DNA of an object containing multiple OneOfs. + # {'b': 1, 'c': bar} => DNA([0, 1]) + { + 'b': pg.oneof([1, 2]), + 'c': pg.oneof(['foo', 'bar']) + } + + # DNA of an object containing conditional space. + # {'a': {'b': 1} => DNA(0, 0, 0)]) + # {'a': {'b': [4, 7]} => DNA(1, [(0, 1), 2]) + # {'a': {'b': 'bar'} => DNA(2) + { + 'a': { + 'b': pg.oneof([ + pg.oneof([ + pg.oneof([1, 2]), + pg.oneof(3, 4)]), + pg.manyof(2, [ + pg.oneof([4, 5]), + 6, + 7 + ]), + ]), + 'bar', + ]) + } + } + + Args: + value: A list of value that can match choice candidates. + + Returns: + Encoded DNA. + + Raises: + ValueError if value cannot be encoded. + """ + if not isinstance(value, list): + raise ValueError( + object_utils.message_on_path( + f'Cannot encode value: value should be a list type. ' + f'Encountered: {value!r}.', self.sym_path)) + choices = [] + if self.num_choices is not None and len(value) != self.num_choices: + raise ValueError( + object_utils.message_on_path( + f'Length of input list is different from the number of choices ' + f'({self.num_choices}). Encountered: {value}.', self.sym_path)) + for v in value: + choice_id = None + child_dna = None + for i, b in enumerate(self._candidate_templates): + succeeded, child_dna = b.try_encode(v) + if succeeded: + choice_id = i + break + if child_dna is None: + raise ValueError( + object_utils.message_on_path( + f'Cannot encode value: no candidates matches with ' + f'the value. Value: {v!r}, Candidates: {self.candidates}.', + self.sym_path)) + choices.append(geno.DNA(choice_id, [child_dna])) + return geno.DNA(None, choices) + + +@symbolic.members( + [], + init_arg_list=[ + 'num_choices', 'candidates', 'choices_distinct', + 'choices_sorted', 'hints' + ], + # TODO(daiyip): Change to 'ManyOf' once existing code migrates to ManyOf. + serialization_key='hyper.ManyOf', + additional_keys=['pyglove.generators.genetic.ChoiceList'] +) +class ManyOf(Choices): + """N Choose K. + + Example:: + + # Chooses 2 distinct candidates. + v = pg.manyof(2, [1, 2, 3]) + + # Chooses 2 non-distinct candidates. + v = pg.manyof(2, [1, 2, 3], distinct=False) + + # Chooses 2 distinct candidates sorted by their indices. + v = pg.manyof(2, [1, 2, 3], sorted=True) + + # Permutates the candidates. + v = pg.permutate([1, 2, 3]) + + # A hierarchical categorical choice: + v2 = pg.manyof(2, [ + 'foo', + 'bar', + pg.oneof([1, 2, 3]) + ]) + + See also: + + * :func:`pyglove.manyof` + * :func:`pyglove.permutate` + * :class:`pyglove.hyper.Choices` + * :class:`pyglove.hyper.OneOf` + * :class:`pyglove.hyper.Float` + * :class:`pyglove.hyper.CustomHyper` + """ + + def custom_apply( + self, + path: object_utils.KeyPath, + value_spec: pg_typing.ValueSpec, + allow_partial: bool, + child_transform: Optional[Callable[ + [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + ) -> Tuple[bool, 'Choices']: + """Validate candidates during value_spec binding time.""" + # Check if value_spec directly accepts `self`. + if value_spec.value_type and isinstance(self, value_spec.value_type): + return (False, self) + + if self._value_spec: + src_spec = self._value_spec + dest_spec = value_spec + if not dest_spec.is_compatible(src_spec): + raise TypeError( + object_utils.message_on_path( + f'Cannot bind an incompatible value spec {dest_spec} ' + f'to {self.__class__.__name__} with bound spec {src_spec}.', + path)) + return (False, self) + + list_spec = typing.cast( + pg_typing.List, + pg_typing.ensure_value_spec( + value_spec, pg_typing.List(pg_typing.Any()), path)) + if list_spec: + for i, c in enumerate(self.candidates): + list_spec.element.value.apply( + c, + self._allow_partial, + root_path=path + f'candidates[{i}]') + self._value_spec = list_spec + return (False, self) + + +@symbolic.members( + [ + ('num_choices', 1) + ], + init_arg_list=['candidates', 'hints', 'where'], + serialization_key='hyper.OneOf', + additional_keys=['pyglove.generators.genetic.ChoiceValue'] +) +class OneOf(Choices): + """N Choose 1. + + Example:: + + # A single categorical choice: + v = pg.oneof([1, 2, 3]) + + # A hierarchical categorical choice: + v2 = pg.oneof([ + 'foo', + 'bar', + pg.oneof([1, 2, 3]) + ]) + + See also: + + * :func:`pyglove.oneof` + * :class:`pyglove.hyper.Choices` + * :class:`pyglove.hyper.ManyOf` + * :class:`pyglove.hyper.Float` + * :class:`pyglove.hyper.CustomHyper` + """ + + def _on_bound(self): + """Event triggered when members are bound.""" + super()._on_bound() + assert self.num_choices == 1 + + def _decode(self) -> Any: + """Decode a DNA into an object.""" + return super()._decode()[0] + + def encode(self, value: Any) -> geno.DNA: + """Encode a value into a DNA.""" + # NOTE(daiyip): Single choice DNA will automatically be pulled + # up from children to current node. Thus we simply returns + # encoded DNA from parent node. + return super().encode([value]) + + def custom_apply( + self, + path: object_utils.KeyPath, + value_spec: pg_typing.ValueSpec, + allow_partial: bool, + child_transform: Optional[Callable[ + [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + ) -> Tuple[bool, 'OneOf']: + """Validate candidates during value_spec binding time.""" + # Check if value_spec directly accepts `self`. + if value_spec.value_type and isinstance(self, value_spec.value_type): + return (False, self) + + if self._value_spec: + if not value_spec.is_compatible(self._value_spec): + raise TypeError( + object_utils.message_on_path( + f'Cannot bind an incompatible value spec {value_spec} ' + f'to {self.__class__.__name__} with bound ' + f'spec {self._value_spec}.', path)) + return (False, self) + + for i, c in enumerate(self.candidates): + value_spec.apply( + c, + self._allow_partial, + root_path=path + f'candidates[{i}]') + self._value_spec = value_spec + return (False, self) + +# +# Helper methods for creating hyper values. +# + + +def oneof(candidates: Iterable[Any], + *, + name: Optional[str] = None, + hints: Optional[Any] = None) -> Any: + """N choose 1. + + Example:: + + @pg.members([ + ('x', pg.typing.Int()) + ]) + class A(pg.Object): + pass + + # A single categorical choice: + v = pg.oneof([1, 2, 3]) + + # A complex type as candidate. + v1 = pg.oneof(['a', {'x': 1}, A(1)]) + + # A hierarchical categorical choice: + v2 = pg.oneof([ + 'foo', + 'bar', + A(pg.oneof([1, 2, 3])) + ]) + + See also: + + * :class:`pyglove.hyper.OneOf` + * :func:`pyglove.manyof` + * :func:`pyglove.floatv` + * :func:`pyglove.permutate` + * :func:`pyglove.evolve` + + .. note:: + + Under symbolic mode (by default), `pg.oneof` returns a ``pg.hyper.OneOf`` + object. Under dynamic evaluation mode, which is called under the context of + :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or + :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to + a concrete candidate value. + + To use conditional search space in dynamic evaluation mode, the candidate + should be wrapped with a `lambda` function, which is not necessary under + symbolic mode. For example:: + + pg.oneof([lambda: pg.oneof([0, 1], name='sub'), 2], name='root') + + Args: + candidates: Candidates to select from. Items of candidate can be any type, + therefore it can have nested hyper primitives, which forms a hierarchical + search space. + name: A name that can be used to identify a decision point in the search + space. This is needed when the code to instantiate the same hyper + primitive may be called multiple times under a + `pg.DynamicEvaluationContext.collect` context or under a + `pg.DynamicEvaluationContext.apply` context. + hints: An optional value which acts as a hint for the controller. + + Returns: + In symbolic mode, this function returns a `ChoiceValue`. + In dynamic evaluation mode, this function returns one of the items in + `candidates`. + If evaluated under a `pg.DynamicEvaluationContext.apply` scope, + this function will return the selected candidate. + If evaluated under a `pg.DynamicEvaluationContext.collect` + scope, it will return the first candidate. + """ + return OneOf(candidates=list(candidates), name=name, hints=hints) + + +def manyof(k: int, + candidates: Iterable[Any], + distinct: bool = True, + sorted: bool = False, # pylint: disable=redefined-builtin + *, + name: Optional[str] = None, + hints: Optional[Any] = None, + **kwargs) -> Any: + """N choose K. + + Example:: + + @pg.members([ + ('x', pg.typing.Int()) + ]) + class A(pg.Object): + pass + + # Chooses 2 distinct candidates. + v = pg.manyof(2, [1, 2, 3]) + + # Chooses 2 non-distinct candidates. + v = pg.manyof(2, [1, 2, 3], distinct=False) + + # Chooses 2 distinct candidates sorted by their indices. + v = pg.manyof(2, [1, 2, 3], sorted=True) + + # A complex type as candidate. + v1 = pg.manyof(2, ['a', {'x': 1}, A(1)]) + + # A hierarchical categorical choice: + v2 = pg.manyof(2, [ + 'foo', + 'bar', + A(pg.oneof([1, 2, 3])) + ]) + + .. note:: + + Under symbolic mode (by default), `pg.manyof` returns a ``pg.hyper.ManyOf`` + object. Under dynamic evaluation mode, which is called under the context of + :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or + :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to + a concrete candidate value. + + To use conditional search space in dynamic evaluate mode, the candidate + should be wrapped with a `lambda` function, which is not necessary under + symbolic mode. For example:: + + pg.manyof(2, [ + lambda: pg.oneof([0, 1], name='sub_a'), + lambda: pg.floatv(0.0, 1.0, name='sub_b'), + lambda: pg.manyof(2, ['a', 'b', 'c'], name='sub_c') + ], name='root') + + See also: + + * :class:`pyglove.hyper.ManyOf` + * :func:`pyglove.manyof` + * :func:`pyglove.floatv` + * :func:`pyglove.permutate` + * :func:`pyglove.evolve` + + Args: + k: number of choices to make. Should be no larger than the length of + `candidates` unless `choice_distinct` is set to False, + candidates: Candidates to select from. Items of candidate can be any type, + therefore it can have nested hyper primitives, which forms a hierarchical + search space. + distinct: If True, each choice needs to be unique. + sorted: If True, choices are sorted by their indices in the + candidates. + name: A name that can be used to identify a decision point in the search + space. This is needed when the code to instantiate the same hyper + primitive may be called multiple times under a + `pg.DynamicEvaluationContext.collect` context or a + `pg.DynamicEvaluationContext.apply` context. + hints: An optional value which acts as a hint for the controller. + **kwargs: Keyword arguments for backward compatibility. + `choices_distinct`: Old name for `distinct`. + `choices_sorted`: Old name for `sorted`. + + Returns: + In symbolic mode, this function returns a `Choices`. + In dynamic evaluate mode, this function returns a list of items in + `candidates`. + If evaluated under a `pg.DynamicEvaluationContext.apply` scope, + this function will return a list of selected candidates. + If evaluated under a `pg.DynamicEvaluationContext.collect` + scope, it will return a list of the first valid combination from the + `candidates`. For example:: + + # Evaluates to [0, 1, 2]. + manyof(3, range(5)) + + # Evaluates to [0, 0, 0]. + manyof(3, range(5), distinct=False) + """ + choices_distinct = kwargs.pop('choices_distinct', distinct) + choices_sorted = kwargs.pop('choices_sorted', sorted) + return ManyOf( + num_choices=k, + candidates=list(candidates), + choices_distinct=choices_distinct, + choices_sorted=choices_sorted, + name=name, + hints=hints) + + +def permutate(candidates: Iterable[Any], + name: Optional[str] = None, + hints: Optional[Any] = None) -> Any: + """Permuatation of candidates. + + Example:: + + @pg.members([ + ('x', pg.typing.Int()) + ]) + class A(pg.Object): + pass + + # Permutates the candidates. + v = pg.permutate([1, 2, 3]) + + # A complex type as candidate. + v1 = pg.permutate(['a', {'x': 1}, A(1)]) + + # A hierarchical categorical choice: + v2 = pg.permutate([ + 'foo', + 'bar', + A(pg.oneof([1, 2, 3])) + ]) + + .. note:: + + Under symbolic mode (by default), `pg.manyof` returns a ``pg.hyper.ManyOf`` + object. Under dynamic evaluate mode, which is called under the context of + :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or + :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to + a concrete candidate value. + + To use conditional search space in dynamic evaluate mode, the candidate + should be wrapped with a `lambda` function, which is not necessary under + symbolic mode. For example:: + + pg.permutate([ + lambda: pg.oneof([0, 1], name='sub_a'), + lambda: pg.floatv(0.0, 1.0, name='sub_b'), + lambda: pg.manyof(2, ['a', 'b', 'c'], name='sub_c') + ], name='root') + + See also: + + * :class:`pyglove.hyper.ManyOf` + * :func:`pyglove.oneof` + * :func:`pyglove.manyof` + * :func:`pyglove.floatv` + * :func:`pyglove.evolve` + + Args: + candidates: Candidates to select from. Items of candidate can be any type, + therefore it can have nested hyper primitives, which forms a hierarchical + search space. + name: A name that can be used to identify a decision point in the search + space. This is needed when the code to instantiate the same hyper + primitive may be called multiple times under a + `pg.DynamicEvaluationContext.collect` context or a + `pg.DynamicEvaluationContext.apply` context. + hints: An optional value which acts as a hint for the controller. + + Returns: + In symbolic mode, this function returns a `Choices`. + In dynamic evaluate mode, this function returns a permutation from + `candidates`. + If evaluated under an `pg.DynamicEvaluationContext.apply` scope, + this function will return a permutation of candidates based on controller + decisions. + If evaluated under a `pg.DynamicEvaluationContext.collect` + scope, it will return the first valid permutation. + For example:: + + # Evaluates to [0, 1, 2, 3, 4]. + permutate(range(5), name='numbers') + """ + candidates = list(candidates) + return manyof( + len(candidates), candidates, + choices_distinct=True, choices_sorted=False, name=name, hints=hints) diff --git a/pyglove/core/hyper/categorical_test.py b/pyglove/core/hyper/categorical_test.py new file mode 100644 index 0000000..0d3d95b --- /dev/null +++ b/pyglove/core/hyper/categorical_test.py @@ -0,0 +1,404 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.Choices.""" + +import unittest +from pyglove.core import geno +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper.categorical import manyof +from pyglove.core.hyper.categorical import ManyOf +from pyglove.core.hyper.categorical import oneof +from pyglove.core.hyper.categorical import OneOf +from pyglove.core.hyper.numerical import floatv + + +class OneOfTest(unittest.TestCase): + """Tests for pg.oneof.""" + + def test_dna_spec(self): + + class C: + pass + + self.assertTrue(symbolic.eq( + oneof(candidates=[ + 'foo', + { + 'a': floatv(min_value=0.0, max_value=1.0), + 'b': oneof(candidates=[1, 2, 3]), + 'c': C() + }, + [floatv(min_value=1.0, max_value=2.0), 1.0], + ]).dna_spec('a.b'), + geno.Choices( + num_choices=1, + candidates=[ + geno.constant(), + geno.Space(elements=[ + geno.Float(min_value=0.0, max_value=1.0, location='a'), + geno.Choices( + num_choices=1, + candidates=[ + geno.constant(), + geno.constant(), + geno.constant() + ], + literal_values=[1, 2, 3], + location='b'), + ]), + geno.Space(elements=[ + geno.Float(min_value=1.0, max_value=2.0, location='[0]') + ]) + ], + literal_values=[ + '\'foo\'', + ('{a=Float(min_value=0.0, max_value=1.0), ' + 'b=OneOf(candidates=[0: 1, 1: 2, 2: 3]), ' + 'c=C(...)}'), + '[0: Float(min_value=1.0, max_value=2.0), 1: 1.0]', + ], + location='a.b'))) + + def test_decode(self): + choice_value = oneof(candidates=[ + 'foo', + { + 'a': floatv(min_value=0.0, max_value=1.0), + 'b': oneof(candidates=[1, 2, 3]), + }, + [floatv(min_value=1.0, max_value=2.0), 1.0], + ]) + + self.assertEqual(choice_value.decode(geno.DNA.parse(0)), 'foo') + + self.assertEqual( + choice_value.decode(geno.DNA.parse((1, [0.5, 0]))), { + 'a': 0.5, + 'b': 1 + }) + + self.assertEqual(choice_value.decode(geno.DNA.parse((2, 1.5))), [1.5, 1.0]) + + with self.assertRaisesRegex(ValueError, 'Choice out of range'): + choice_value.decode(geno.DNA.parse(5)) + + with self.assertRaisesRegex( + ValueError, 'Encountered extra DNA value to decode'): + choice_value.decode(geno.DNA.parse((0, 1))) + + with self.assertRaisesRegex( + ValueError, + 'The length of child values .* is different from the number ' + 'of hyper primitives'): + choice_value.decode(geno.DNA.parse((1, 0))) + + with self.assertRaisesRegex(ValueError, 'Expect float value'): + choice_value.decode(geno.DNA.parse((1, [1, 0]))) + + with self.assertRaisesRegex( + ValueError, + 'The length of child values .* is different from the number ' + 'of hyper primitives'): + choice_value.decode(geno.DNA.parse((1, [0.5, 1, 2]))) + + with self.assertRaisesRegex(ValueError, 'Expect float value'): + choice_value.decode(geno.DNA.parse(2)) + + with self.assertRaisesRegex( + ValueError, 'DNA value should be no greater than'): + choice_value.decode(geno.DNA.parse((2, 5.0))) + + def test_encode(self): + choice_value = oneof(candidates=[ + 'foo', + { + 'a': floatv(min_value=0.0, max_value=1.0), + 'b': oneof(candidates=[1, 2, 3]), + }, + [floatv(min_value=1.0, max_value=2.0), 1.0], + ]) + self.assertEqual(choice_value.encode('foo'), geno.DNA(0)) + self.assertEqual( + choice_value.encode({ + 'a': 0.5, + 'b': 1 + }), geno.DNA.parse((1, [0.5, 0]))) + self.assertEqual(choice_value.encode([1.5, 1.0]), geno.DNA.parse((2, 1.5))) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + choice_value.encode(['bar']) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + print(choice_value.encode({'a': 0.5})) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + choice_value.encode({'a': 1.8, 'b': 1}) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + choice_value.encode([1.0]) + + def test_assignment_compatibility(self): + sd = symbolic.Dict.partial( + value_spec=pg_typing.Dict([ + ('a', pg_typing.Str()), + ('b', pg_typing.Int()), + ('c', pg_typing.Union([pg_typing.Str(), pg_typing.Int()])), + ('d', pg_typing.Any()) + ])) + choice_value = oneof(candidates=[1, 'foo']) + sd.c = choice_value + sd.d = choice_value + + with self.assertRaisesRegex( + TypeError, 'Cannot bind an incompatible value spec'): + sd.a = choice_value + + with self.assertRaisesRegex( + TypeError, 'Cannot bind an incompatible value spec'): + sd.b = choice_value + + def test_custom_apply(self): + o = oneof([1, 2]) + self.assertIs(pg_typing.Object(OneOf).apply(o), o) + self.assertIs(pg_typing.Int().apply(o), o) + with self.assertRaisesRegex( + TypeError, r'Cannot bind an incompatible value spec Float\(\)'): + pg_typing.Float().apply(o) + + +class ManyOfTest(unittest.TestCase): + """Test for pg.manyof.""" + + def test_bad_init(self): + with self.assertRaisesRegex( + ValueError, '.* candidates cannot produce .* distinct choices'): + manyof(3, [1, 2], distinct=True) + + def test_dna_spec(self): + # Test simple choice list without nested encoders. + self.assertTrue(symbolic.eq( + manyof( + 2, ['foo', 1, 2, 'bar'], sorted=True, distinct=True).dna_spec(), + geno.manyof(2, [ + geno.constant(), + geno.constant(), + geno.constant(), + geno.constant() + ], literal_values=[ + '\'foo\'', 1, 2, '\'bar\'' + ], sorted=True, distinct=True))) + + # Test complex choice list with nested encoders. + self.assertTrue(symbolic.eq( + oneof([ + 'foo', + { + 'a': floatv(min_value=0.0, max_value=1.0), + 'b': oneof(candidates=[1, 2, 3]), + }, + [floatv(min_value=1.0, max_value=2.0, scale='linear'), 1.0], + ]).dna_spec('a.b'), + geno.oneof([ + geno.constant(), + geno.space([ + geno.floatv(min_value=0.0, max_value=1.0, location='a'), + geno.oneof([ + geno.constant(), + geno.constant(), + geno.constant() + ], literal_values=[1, 2, 3], location='b') + ]), + geno.floatv(1.0, 2.0, scale='linear', location='[0]') + ], literal_values=[ + '\'foo\'', + ('{a=Float(min_value=0.0, max_value=1.0), ' + 'b=OneOf(candidates=[0: 1, 1: 2, 2: 3])}'), + '[0: Float(min_value=1.0, max_value=2.0, scale=\'linear\'), 1: 1.0]', + ], location='a.b'))) + + def test_decode(self): + choice_list = manyof(2, [ + 'foo', 1, 2, 'bar' + ], choices_sorted=True, choices_distinct=True) + self.assertTrue(choice_list.is_leaf) + self.assertEqual(choice_list.decode(geno.DNA.parse([0, 1])), ['foo', 1]) + + with self.assertRaisesRegex( + ValueError, + 'Number of DNA child values does not match the number of choices'): + choice_list.decode(geno.DNA.parse([1, 0, 0])) + + with self.assertRaisesRegex(ValueError, 'Choice value should be int'): + choice_list.decode(geno.DNA.parse([0, 0.1])) + + with self.assertRaisesRegex(ValueError, 'Choice out of range'): + choice_list.decode(geno.DNA.parse([0, 5])) + + with self.assertRaisesRegex( + ValueError, 'DNA child values should be sorted'): + choice_list.decode(geno.DNA.parse([1, 0])) + + with self.assertRaisesRegex( + ValueError, 'DNA child values should be distinct'): + choice_list.decode(geno.DNA.parse([0, 0])) + + choice_list = manyof(1, [ + 'foo', + { + 'a': floatv(min_value=0.0, max_value=1.0), + 'b': oneof(candidates=[1, 2, 3]), + }, + [floatv(min_value=1.0, max_value=2.0), 1.0], + ]) + self.assertFalse(choice_list.is_leaf) + self.assertEqual(choice_list.decode(geno.DNA.parse(0)), ['foo']) + + self.assertEqual( + choice_list.decode(geno.DNA.parse((1, [0.5, 0]))), [{ + 'a': 0.5, + 'b': 1 + }]) + + self.assertEqual(choice_list.decode(geno.DNA.parse((2, 1.5))), [[1.5, 1.0]]) + + with self.assertRaisesRegex(ValueError, 'Choice out of range'): + choice_list.decode(geno.DNA.parse(5)) + + with self.assertRaisesRegex( + ValueError, 'Encountered extra DNA value to decode'): + choice_list.decode(geno.DNA.parse((0, 1))) + + with self.assertRaisesRegex( + ValueError, + 'The length of child values .* is different from the number ' + 'of hyper primitives'): + choice_list.decode(geno.DNA.parse((1, 0))) + + with self.assertRaisesRegex(ValueError, 'Expect float value'): + choice_list.decode(geno.DNA.parse((1, [1, 0]))) + + with self.assertRaisesRegex( + ValueError, + 'The length of child values .* is different from the number ' + 'of hyper primitives'): + choice_list.decode(geno.DNA.parse((1, [0.5, 1, 2]))) + + with self.assertRaisesRegex(ValueError, 'Expect float value'): + choice_list.decode(geno.DNA.parse(2)) + + with self.assertRaisesRegex( + ValueError, 'DNA value should be no greater than'): + choice_list.decode(geno.DNA.parse((2, 5.0))) + + def test_encode(self): + choice_list = manyof(1, [ + 'foo', + { + 'a': floatv(min_value=0.0, max_value=1.0), + 'b': oneof(candidates=[1, 2, 3]), + }, + [floatv(min_value=1.0, max_value=2.0), 1.0], + ]) + self.assertEqual(choice_list.encode(['foo']), geno.DNA(0)) + self.assertEqual( + choice_list.encode([{ + 'a': 0.5, + 'b': 1 + }]), geno.DNA.parse((1, [0.5, 0]))) + self.assertEqual(choice_list.encode([[1.5, 1.0]]), geno.DNA.parse((2, 1.5))) + + with self.assertRaisesRegex( + ValueError, 'Cannot encode value: value should be a list type'): + choice_list.encode('bar') + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + choice_list.encode(['bar']) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + print(choice_list.encode([{'a': 0.5}])) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + choice_list.encode([{'a': 1.8, 'b': 1}]) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + choice_list.encode([[1.0]]) + + choice_list = manyof(2, ['a', 'b', 'c']) + self.assertEqual(choice_list.encode(['a', 'c']), geno.DNA.parse([0, 2])) + with self.assertRaisesRegex( + ValueError, + 'Length of input list is different from the number of choices'): + choice_list.encode(['a']) + + def test_assignment_compatibility(self): + """Test drop-in type compatibility.""" + sd = symbolic.Dict.partial( + value_spec=pg_typing.Dict([ + ('a', pg_typing.Int()), + ('b', pg_typing.List(pg_typing.Int())), + ('c', pg_typing.List(pg_typing.Union( + [pg_typing.Str(), pg_typing.Int()]))), + ('d', pg_typing.Any()) + ])) + choice_list = manyof(2, [1, 'foo']) + sd.c = choice_list + sd.d = choice_list + + with self.assertRaisesRegex( + TypeError, 'Cannot bind an incompatible value spec Int\\(\\)'): + sd.a = choice_list + + with self.assertRaisesRegex( + TypeError, + 'Cannot bind an incompatible value spec List\\(Int\\(\\)\\)'): + sd.b = choice_list + + def test_custom_apply(self): + l = manyof(2, [1, 2, 3]) + self.assertIs(pg_typing.Object(ManyOf).apply(l), l) + self.assertIs(pg_typing.List(pg_typing.Int()).apply(l), l) + with self.assertRaisesRegex( + TypeError, r'Cannot bind an incompatible value spec List\(Float\(\)\)'): + pg_typing.List(pg_typing.Float()).apply(l) + + class A: + pass + + class B: + pass + + t = oneof([B()]) + self.assertEqual( + pg_typing.Union([pg_typing.Object(A), pg_typing.Object(B)]).apply(t), t) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper/custom.py b/pyglove/core/hyper/custom.py new file mode 100644 index 0000000..acf1543 --- /dev/null +++ b/pyglove/core/hyper/custom.py @@ -0,0 +1,159 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom hyper primitives.""" + +import abc +import random +import types +from typing import Any, Callable, Optional, Tuple, Union + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import typing as pg_typing +from pyglove.core.hyper import base + + +class CustomHyper(base.HyperPrimitive): + """User-defined hyper primitive. + + User-defined hyper primitive is useful when users want to have full control + on the semantics and genome encoding of the search space. For example, the + decision points are of variable length, which is not yet supported by + built-in hyper primitives. + + To use user-defined hyper primitive is simple, the user should: + + + 1) Subclass `CustomHyper` and implements the + :meth:`pyglove.hyper.CustomHyper.custom_decode` method. + It's optional to implement the + :meth:`pyglove.hyper.CustomHyper.custom_encode` method, which is only + necessary when the user want to encoder a material object into a DNA. + 2) Introduce a DNAGenerator that can generate DNA for the + :class:`pyglove.geno.CustomDecisionPoint`. + + For example, the following code tries to find an optimal sub-sequence of an + integer sequence by their sums:: + + import random + + class IntSequence(pg.hyper.CustomHyper): + + def custom_decode(self, dna): + return [int(v) for v in dna.value.split(',') if v != ''] + + class SubSequence(pg.evolution.Mutator): + + def mutate(self, dna): + genome = dna.value + items = genome.split(',') + start = random.randint(0, len(items)) + end = random.randint(start, len(items)) + new_genome = ','.join(items[start:end]) + return pg.DNA(new_genome, spec=dna.spec) + + @pg.geno.dna_generator + def initial_population(): + yield pg.DNA('12,-34,56,-2,100,98', spec=dna_spec) + + algo = pg.evolution.Evolution( + (pg.evolution.selectors.Random(10) + >> pg.evolution.selectors.Top(1) + >> SubSequence()), + population_init=initial_population(), + population_update=pg.evolution.selectors.Last(20)) + + best_reward, best_example = None, None + for int_seq, feedback in pg.sample(IntSequence(), algo, num_examples=100): + reward = sum(int_seq) + if best_reward is None or best_reward < reward: + best_reward, best_example = reward, int_seq + feedback(reward) + + print(best_reward, best_example) + + Please note that user-defined hyper value can be used together with PyGlove's + built-in hyper primitives, for example:: + + pg.oneof([IntSequence(), None]) + + Therefore it's also a mechanism to extend PyGlove's search space definitions. + """ + + def _decode(self): + if not isinstance(self.dna.value, str): + raise ValueError( + f'{self.__class__} expects string type DNA. ' + f'Encountered {self.dna!r}.') + return self.custom_decode(self.dna) + + @abc.abstractmethod + def custom_decode(self, dna: geno.DNA) -> Any: + """Decode a DNA whose value is a string of user-defined genome.""" + + def encode(self, value: Any) -> geno.DNA: + """Encode value into DNA with user-defined genome.""" + return self.custom_encode(value) + + def custom_encode(self, value: Any) -> geno.DNA: + """Encode value to user defined genome.""" + raise NotImplementedError( + f'\'custom_encode\' is not supported by {self.__class__.__name__!r}.') + + def dna_spec( + self, location: Optional[object_utils.KeyPath] = None) -> geno.DNASpec: + """Always returns CustomDecisionPoint for CustomHyper.""" + return geno.CustomDecisionPoint( + hyper_type=self.__class__.__name__, + next_dna_fn=self.next_dna, + random_dna_fn=self.random_dna, + hints=self.hints, name=self.name, location=location) + + def first_dna(self) -> geno.DNA: + """Returns the first DNA of current sub-space. + + Returns: + A string-valued DNA. + """ + if self.next_dna.__code__ is CustomHyper.next_dna.__code__: + raise NotImplementedError( + f'{self.__class__!r} must implement method `next_dna` to be used in ' + f'dynamic evaluation mode.') + return self.next_dna(None) + + def next_dna(self, dna: Optional[geno.DNA] = None) -> Optional[geno.DNA]: + """Subclasses should override this method to support pg.Sweeping.""" + raise NotImplementedError( + f'`next_dna` is not implemented in f{self.__class__!r}') + + def random_dna( + self, + random_generator: Union[types.ModuleType, random.Random, None] = None, + previous_dna: Optional[geno.DNA] = None) -> geno.DNA: + """Subclasses should override this method to support pg.random_dna.""" + raise NotImplementedError( + f'`random_dna` is not implemented in {self.__class__!r}') + + def custom_apply( + self, + path: object_utils.KeyPath, + value_spec: pg_typing.ValueSpec, + allow_partial: bool, + child_transform: Optional[Callable[ + [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + ) -> Tuple[bool, 'CustomHyper']: + """Validate candidates during value_spec binding time.""" + del path, value_spec, allow_partial, child_transform + # Allow custom hyper to be assigned to any type. + return (False, self) diff --git a/pyglove/core/hyper/custom_test.py b/pyglove/core/hyper/custom_test.py new file mode 100644 index 0000000..fd32fbb --- /dev/null +++ b/pyglove/core/hyper/custom_test.py @@ -0,0 +1,111 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.CustomHyper.""" + +import random +import unittest + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import symbolic + +from pyglove.core.hyper.categorical import oneof +from pyglove.core.hyper.custom import CustomHyper +from pyglove.core.hyper.iter import iterate +from pyglove.core.hyper.object_template import materialize + + +class IntSequence(CustomHyper): + + def custom_decode(self, dna): + return [int(v) for v in dna.value.split(',')] + + +class IntSequenceWithEncode(IntSequence): + + def custom_encode(self, value): + return geno.DNA(','.join([str(v) for v in value])) + + def next_dna(self, dna): + if dna is None: + return geno.DNA(','.join([str(i) for i in range(5)])) + v = self.custom_decode(dna) + v.append(len(v)) + return self._create_dna(v) + + def random_dna(self, random_generator, previous_dna): + del previous_dna + k = random_generator.randint(0, 10) + v = random_generator.choices(list(range(10)), k=k) + return self._create_dna(v) + + def _create_dna(self, numbers): + return geno.DNA(','.join([str(n) for n in numbers])) + + +class CustomHyperTest(unittest.TestCase): + """Test for CustomHyper.""" + + def test_dna_spec(self): + self.assertTrue(symbolic.eq( + IntSequence(hints='x').dna_spec('a'), + geno.CustomDecisionPoint( + hyper_type='IntSequence', + location=object_utils.KeyPath('a'), + hints='x'))) + + def test_decode(self): + self.assertEqual(IntSequence().decode(geno.DNA('0,1,2')), [0, 1, 2]) + self.assertEqual(IntSequence().decode(geno.DNA('0')), [0]) + with self.assertRaisesRegex(ValueError, '.* expects string type DNA'): + IntSequence().decode(geno.DNA(1)) + + def test_encode(self): + self.assertEqual( + IntSequenceWithEncode().encode([0, 1, 2]), geno.DNA('0,1,2')) + + with self.assertRaisesRegex( + NotImplementedError, '\'custom_encode\' is not supported by'): + _ = IntSequence().encode([0, 1, 2]) + + def test_random_dna(self): + self.assertEqual( + geno.random_dna( + IntSequenceWithEncode().dna_spec('a'), random.Random(1)), + geno.DNA('5,8')) + + with self.assertRaisesRegex( + NotImplementedError, '`random_dna` is not implemented in .*'): + geno.random_dna(IntSequence().dna_spec('a')) + + def test_iter(self): + self.assertEqual(IntSequenceWithEncode().first_dna(), geno.DNA('0,1,2,3,4')) + self.assertEqual( + list(iterate(IntSequenceWithEncode(), 3)), + [[0, 1, 2, 3, 4], + [0, 1, 2, 3, 4, 5], + [0, 1, 2, 3, 4, 5, 6]]) + + with self.assertRaisesRegex( + NotImplementedError, '`next_dna` is not implemented in .*'): + next(iterate(IntSequence())) + + def test_interop_with_other_primitives(self): + v = oneof([IntSequence(), 1, 2]) + self.assertEqual(materialize(v, geno.DNA(1)), 1) + self.assertEqual(materialize(v, geno.DNA((0, '3,4'))), [3, 4]) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper/derived.py b/pyglove/core/hyper/derived.py new file mode 100644 index 0000000..277dc32 --- /dev/null +++ b/pyglove/core/hyper/derived.py @@ -0,0 +1,154 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Derived value from other hyper primitives.""" + +import abc +import copy +from typing import Any, Callable, List, Optional, Tuple, Union + +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing + + +@symbolic.members([ + ('reference_paths', pg_typing.List(pg_typing.Object(object_utils.KeyPath)), + ('Paths of referenced values, which are relative paths searched from ' + 'current node to root.')) +]) +class DerivedValue(symbolic.Object, pg_typing.CustomTyping): + """Base class of value that references to other values in object tree.""" + + @abc.abstractmethod + def derive(self, *args: Any) -> Any: + """Derive the value from referenced values.""" + + def resolve( + self, reference_path_or_paths: Optional[Union[str, List[str]]] = None + ) -> Union[Tuple[symbolic.Symbolic, object_utils.KeyPath], + List[Tuple[symbolic.Symbolic, object_utils.KeyPath]]]: + """Resolve reference paths based on the location of this node. + + Args: + reference_path_or_paths: (Optional) a string or KeyPath as a reference + path or a list of strings or KeyPath objects as a list of + reference paths. + If this argument is not provided, prebound reference paths of this + object will be used. + + Returns: + A tuple (or list of tuple) of (resolved parent, resolved full path) + """ + single_input = False + if reference_path_or_paths is None: + reference_paths = self.reference_paths + elif isinstance(reference_path_or_paths, str): + reference_paths = [object_utils.KeyPath.parse(reference_path_or_paths)] + single_input = True + elif isinstance(reference_path_or_paths, object_utils.KeyPath): + reference_paths = [reference_path_or_paths] + single_input = True + elif isinstance(reference_path_or_paths, list): + paths = [] + for path in reference_path_or_paths: + if isinstance(path, str): + path = object_utils.KeyPath.parse(path) + elif not isinstance(path, object_utils.KeyPath): + raise ValueError('Argument \'reference_path_or_paths\' must be None, ' + 'a string, KeyPath object, a list of strings, or a ' + 'list of KeyPath objects.') + paths.append(path) + reference_paths = paths + else: + raise ValueError('Argument \'reference_path_or_paths\' must be None, ' + 'a string, KeyPath object, a list of strings, or a ' + 'list of KeyPath objects.') + + resolved_paths = [] + for reference_path in reference_paths: + parent = self.sym_parent + while parent is not None and not reference_path.exists(parent): + parent = getattr(parent, 'sym_parent', None) + if parent is None: + raise ValueError( + f'Cannot resolve \'{reference_path}\': parent not found.') + resolved_paths.append((parent, parent.sym_path + reference_path)) + return resolved_paths if not single_input else resolved_paths[0] + + def __call__(self): + """Generate value by deriving values from reference paths.""" + referenced_values = [] + for reference_path, (parent, _) in zip( + self.reference_paths, self.resolve()): + referenced_value = reference_path.query(parent) + + # Make sure referenced value does not have referenced value. + # NOTE(daiyip): We can support dependencies between derived values + # in future if needed. + if not object_utils.traverse( + referenced_value, self._contains_not_derived_value): + raise ValueError( + f'Derived value (path={referenced_value.sym_path}) should not ' + f'reference derived values. ' + f'Encountered: {referenced_value}, ' + f'Referenced at path {self.sym_path}.') + referenced_values.append(referenced_value) + return self.derive(*referenced_values) + + def _contains_not_derived_value( + self, path: object_utils.KeyPath, value: Any) -> bool: + """Returns whether a value contains derived value.""" + if isinstance(value, DerivedValue): + return False + elif isinstance(value, symbolic.Object): + for k, v in value.sym_items(): + if not object_utils.traverse( + v, self._contains_not_derived_value, + root_path=object_utils.KeyPath(k, path)): + return False + return True + + +class ValueReference(DerivedValue): + """Class that represents a value referencing another value.""" + + def _on_bound(self): + """Custom init.""" + super()._on_bound() + if len(self.reference_paths) != 1: + raise ValueError( + f'Argument \'reference_paths\' should have exact 1 ' + f'item. Encountered: {self.reference_paths}') + + def derive(self, referenced_value: Any) -> Any: + """Derive value by return a copy of the referenced value.""" + return copy.copy(referenced_value) + + def custom_apply( + self, + path: object_utils.KeyPath, + value_spec: pg_typing.ValueSpec, + allow_partial: bool, + child_transform: Optional[Callable[ + [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + ) -> Tuple[bool, 'DerivedValue']: + """Implement pg_typing.CustomTyping interface.""" + # TODO(daiyip): perform possible static analysis on referenced paths. + del path, value_spec, allow_partial, child_transform + return (False, self) + + +def reference(reference_path: str) -> ValueReference: + """Create a referenced value from a referenced path.""" + return ValueReference(reference_paths=[reference_path]) diff --git a/pyglove/core/hyper/derived_test.py b/pyglove/core/hyper/derived_test.py new file mode 100644 index 0000000..eedd233 --- /dev/null +++ b/pyglove/core/hyper/derived_test.py @@ -0,0 +1,137 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.ValueReference.""" + +import unittest + +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper.derived import ValueReference + + +class ValueReferenceTest(unittest.TestCase): + """Tests for pg.hyper.ValueReference.""" + + def test_resolve(self): + sd = symbolic.Dict({'c': [ + { + 'x': [{ + 'z': 0 + }], + }, + { + 'x': [{ + 'z': 1 + }] + }, + ]}) + sd.a = ValueReference(reference_paths=['c[0].x[0].z']) + self.assertEqual(sd.a.resolve(), [(sd, 'c[0].x[0].z')]) + + # References refer to the same relative path under different parent. + ref = ValueReference(reference_paths=['x[0].z']) + sd.c[0].y = ref + sd.c[1].y = ref + self.assertEqual(sd.c[0].y.resolve(), [(sd.c[0], 'c[0].x[0].z')]) + self.assertEqual(sd.c[1].y.resolve(), [(sd.c[1], 'c[1].x[0].z')]) + # Resolve references from this point. + self.assertEqual(sd.c[0].y.resolve(object_utils.KeyPath(0)), (sd.c, 'c[0]')) + self.assertEqual(sd.c[0].y.resolve('[0]'), (sd.c, 'c[0]')) + self.assertEqual( + sd.c[0].y.resolve(['[0]', '[1]']), [(sd.c, 'c[0]'), (sd.c, 'c[1]')]) + + # Bad inputs. + with self.assertRaisesRegex( + ValueError, + 'Argument \'reference_path_or_paths\' must be None, a string, KeyPath ' + 'object, a list of strings, or a list of KeyPath objects.'): + sd.c[0].y.resolve([1]) + + with self.assertRaisesRegex( + ValueError, + 'Argument \'reference_path_or_paths\' must be None, a string, KeyPath ' + 'object, a list of strings, or a list of KeyPath objects.'): + sd.c[0].y.resolve(1) + + with self.assertRaisesRegex( + ValueError, 'Cannot resolve .*: parent not found.'): + ValueReference(reference_paths=['x[0].z']).resolve() + + def test_call(self): + + @symbolic.members([('a', pg_typing.Int(), 'Field a.')]) + class A(symbolic.Object): + pass + + sd = symbolic.Dict({'c': [ + { + 'x': [{ + 'z': 0 + }], + }, + { + 'x': [{ + 'z': A(a=1) + }] + }, + ]}) + sd.a = ValueReference(reference_paths=['c[0].x[0].z']) + self.assertEqual(sd.a(), 0) + + # References refer to the same relative path under different parent. + ref = ValueReference(reference_paths=['x[0]']) + sd.c[0].y = ref + sd.c[1].y = ref + self.assertEqual(sd.c[0].y(), {'z': 0}) + self.assertEqual(sd.c[1].y(), {'z': A(a=1)}) + + # References to another reference is not supported. + sd.c[1].z = ValueReference(reference_paths=['y']) + with self.assertRaisesRegex( + ValueError, + 'Derived value .* should not reference derived values'): + sd.c[1].z() + + sd.c[1].z = ValueReference(reference_paths=['c']) + with self.assertRaisesRegex( + ValueError, + 'Derived value .* should not reference derived values'): + sd.c[1].z() + + def test_assignment_compatibility(self): + sd = symbolic.Dict.partial( + x=0, + value_spec=pg_typing.Dict([ + ('x', pg_typing.Int()), + ('y', pg_typing.Int()), + ('z', pg_typing.Str()) + ])) + + sd.y = ValueReference(['x']) + # TODO(daiyip): Enable this test once static analysis is done + # on derived values. + # with self.assertRaisesRegexp( + # TypeError, ''): + # sd.z = ValueReference(['x']) + + def test_bad_init(self): + with self.assertRaisesRegex( + ValueError, + 'Argument \'reference_paths\' should have exact 1 item'): + ValueReference([]) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper/dynamic_evaluation.py b/pyglove/core/hyper/dynamic_evaluation.py new file mode 100644 index 0000000..a1e5d5a --- /dev/null +++ b/pyglove/core/hyper/dynamic_evaluation.py @@ -0,0 +1,588 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dynamic evaluation for hyper primitives.""" + +import contextlib +import types +from typing import Any, Callable, Dict, Iterator, List, Optional, Union + +from pyglove.core import geno +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper import base +from pyglove.core.hyper import categorical +from pyglove.core.hyper import custom +from pyglove.core.hyper import numerical +from pyglove.core.hyper import object_template +from pyglove.core.object_utils import thread_local + + +@contextlib.contextmanager +def dynamic_evaluate(evaluate_fn: Optional[Callable[[base.HyperValue], Any]], + yield_value: Any = None, + exit_fn: Optional[Callable[[], None]] = None, + per_thread: bool = True) -> Iterator[Any]: + """Eagerly evaluate hyper primitives within current scope. + + Example:: + + global_indices = [0] + def evaluate_fn(x: pg.hyper.HyperPrimitive): + if isinstance(x, pg.hyper.OneOf): + return x.candidates[global_indices[0]] + raise NotImplementedError() + + with pg.hyper.dynamic_evaluate(evaluate_fn): + assert 0 = pg.oneof([0, 1, 2]) + + Please see :meth:`pyglove.DynamicEvaluationContext.apply` as an example + for using this method. + + Args: + evaluate_fn: A callable object that evaluates a hyper value such as + oneof, manyof, floatv, and etc. into a concrete value. + yield_value: Value to yield return. + exit_fn: A callable object to be called when exiting the context scope. + per_thread: If True, the context manager will be applied to current thread + only. Otherwise, it will be applied on current process. + + Yields: + `yield_value` from the argument. + """ + if evaluate_fn is not None and not callable(evaluate_fn): + raise ValueError( + f'\'evaluate_fn\' must be either None or a callable object. ' + f'Encountered: {evaluate_fn!r}.') + if exit_fn is not None and not callable(exit_fn): + raise ValueError( + f'\'exit_fn\' must be a callable object. Encountered: {exit_fn!r}.') + old_evaluate_fn = base.get_dynamic_evaluate_fn() + has_errors = False + try: + base.set_dynamic_evaluate_fn(evaluate_fn, per_thread) + yield yield_value + except Exception: + has_errors = True + raise + finally: + base.set_dynamic_evaluate_fn(old_evaluate_fn, per_thread) + if not has_errors and exit_fn is not None: + exit_fn() + + +class DynamicEvaluationContext: + """Context for dynamic evaluation of hyper primitives. + + Example:: + + import pyglove as pg + + # Define a function that implicitly declares a search space. + def foo(): + return pg.oneof(range(-10, 10)) ** 2 + pg.oneof(range(-10, 10)) ** 2 + + # Define the search space by running the `foo` once. + search_space = pg.hyper.DynamicEvaluationContext() + with search_space.collect(): + _ = foo() + + # Create a search algorithm. + search_algorithm = pg.evolution.regularized_evolution( + pg.evolution.mutators.Uniform(), population_size=32, tournament_size=16) + + # Define the feedback loop. + best_foo, best_reward = None, None + for example, feedback in pg.sample( + search_space, search_algorithm, num_examples=100): + # Call to `example` returns a context manager + # under which the `program` is connected with + # current search algorithm decisions. + with example(): + reward = foo() + feedback(reward) + if best_reward is None or best_reward < reward: + best_foo, best_reward = example, reward + """ + + class _AnnoymousHyperNameAccumulator: + """Name accumulator for annoymous hyper primitives.""" + + def __init__(self): + self.index = 0 + + def next_name(self): + name = f'decision_{self.index}' + self.index += 1 + return name + + def __init__(self, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None, + require_hyper_name: bool = False, + per_thread: bool = True, + dna_spec: Optional[geno.DNASpec] = None) -> None: # pylint: disable=redefined-outer-name + """Create a dynamic evaluation context. + + Args: + where: A callable object that decide whether a hyper primitive should be + included when being instantiated under `collect`. + If None, all hyper primitives under `collect` will be + included. + require_hyper_name: If True, all hyper primitives (e.g. pg.oneof) must + come with a `name`. This option helps to eliminate errors when a + function that contains hyper primitive definition may be called multiple + times. Since hyper primitives sharing the same name will be registered + to the same decision point, repeated call to the hyper primitive + definition will not matter. + per_thread: If True, the context manager will be applied to current thread + only. Otherwise, it will be applied on current process. + dna_spec: External provided search space. If None, the dynamic evaluation + context can be used to create new search space via `colelct` context + manager. Otherwise, current context will use the provided DNASpec to + apply decisions. + """ + self._where = where + self._require_hyper_name: bool = require_hyper_name + self._name_to_hyper: Dict[str, base.HyperPrimitive] = dict() + self._annoymous_hyper_name_accumulator = ( + DynamicEvaluationContext._AnnoymousHyperNameAccumulator()) + self._hyper_dict = symbolic.Dict() if dna_spec is None else None + self._dna_spec: Optional[geno.DNASpec] = dna_spec + self._per_thread = per_thread + self._decision_getter = None + + @property + def per_thread(self) -> bool: + """Returns True if current context collects/applies decisions per thread.""" + return self._per_thread + + @property + def dna_spec(self) -> geno.DNASpec: + """Returns the DNASpec of the search space defined so far.""" + if self._dna_spec is None: + assert self._hyper_dict is not None + self._dna_spec = object_template.dna_spec(self._hyper_dict) + return self._dna_spec + + def _decision_name(self, hyper_primitive: base.HyperPrimitive) -> str: + """Get the name for a decision point.""" + name = hyper_primitive.name + if name is None: + if self._require_hyper_name: + raise ValueError( + f'\'name\' must be specified for hyper ' + f'primitive {hyper_primitive!r}.') + name = self._annoymous_hyper_name_accumulator.next_name() + return name + + @property + def is_external(self) -> bool: + """Returns True if the search space is defined by an external DNASpec.""" + return self._hyper_dict is None + + @property + def hyper_dict(self) -> Optional[symbolic.Dict]: + """Returns collected hyper primitives as a dict. + + None if current context is controlled by an external DNASpec. + """ + return self._hyper_dict + + @contextlib.contextmanager + def collect(self): + """A context manager for collecting hyper primitives within this context. + + Example:: + + context = DynamicEvaluationContext() + with context.collect(): + x = pg.oneof([1, 2, 3]) + pg.oneof([4, 5, 6]) + + # Will print 1 + 4 = 5. Meanwhile 2 hyper primitives will be registered + # in the search space represented by the context. + print(x) + + Yields: + The hyper dict representing the search space. + """ + if self.is_external: + raise ValueError( + f'`collect` cannot be called on a dynamic evaluation context that is ' + f'using an external DNASpec: {self._dna_spec}.') + + # Ensure per-thread dynamic evaluation context will not be used + # together with process-level dynamic evaluation context. + _dynamic_evaluation_stack.ensure_thread_safety(self) + + self._hyper_dict = {} + with dynamic_evaluate(self.add_decision_point, per_thread=self._per_thread): + try: + # Push current context to dynamic evaluatoin stack so nested context + # can defer unresolved hyper primitive to current context. + _dynamic_evaluation_stack.push(self) + yield self._hyper_dict + + finally: + # Invalidate DNASpec. + self._dna_spec = None + + # Pop current context from dynamic evaluatoin stack. + _dynamic_evaluation_stack.pop(self) + + def add_decision_point(self, hyper_primitive: base.HyperPrimitive): + """Registers a parameter with current context and return its first value.""" + def _add_child_decision_point(c): + if isinstance(c, types.LambdaType): + s = pg_typing.get_signature(c) + if not s.args and not s.has_wildcard_args: + sub_context = DynamicEvaluationContext( + where=self._where, per_thread=self._per_thread) + sub_context._annoymous_hyper_name_accumulator = ( # pylint: disable=protected-access + self._annoymous_hyper_name_accumulator) + with sub_context.collect() as hyper_dict: + v = c() + return (v, hyper_dict) + return (c, c) + + if self._where and not self._where(hyper_primitive): + # Delegate the resolution of hyper primitives that do not pass + # the `where` predicate to its parent context. + parent_context = _dynamic_evaluation_stack.get_parent(self) + if parent_context is not None: + return parent_context.add_decision_point(hyper_primitive) + return hyper_primitive + + if isinstance(hyper_primitive, object_template.ObjectTemplate): + return hyper_primitive.value + + assert isinstance(hyper_primitive, base.HyperPrimitive), hyper_primitive + name = self._decision_name(hyper_primitive) + if isinstance(hyper_primitive, categorical.Choices): + candidate_values, candidates = zip( + *[_add_child_decision_point(c) for c in hyper_primitive.candidates]) + if hyper_primitive.choices_distinct: + assert hyper_primitive.num_choices <= len(hyper_primitive.candidates) + v = [candidate_values[i] for i in range(hyper_primitive.num_choices)] + else: + v = [candidate_values[0]] * hyper_primitive.num_choices + hyper_primitive = hyper_primitive.clone(deep=True, override={ + 'candidates': list(candidates) + }) + first_value = v[0] if isinstance( + hyper_primitive, categorical.OneOf) else v + elif isinstance(hyper_primitive, numerical.Float): + first_value = hyper_primitive.min_value + else: + assert isinstance(hyper_primitive, custom.CustomHyper), hyper_primitive + first_value = hyper_primitive.decode(hyper_primitive.first_dna()) + + if (name in self._name_to_hyper + and hyper_primitive != self._name_to_hyper[name]): + raise ValueError( + f'Found different hyper primitives under the same name {name!r}: ' + f'Instance1={self._name_to_hyper[name]!r}, ' + f'Instance2={hyper_primitive!r}.') + self._hyper_dict[name] = hyper_primitive + self._name_to_hyper[name] = hyper_primitive + return first_value + + def _decision_getter_and_evaluation_finalizer( + self, decisions: Union[geno.DNA, List[Union[int, float, str]]]): + """Returns decision getter based on input decisions.""" + # NOTE(daiyip): when hyper primitives are required to carry names, we do + # decision lookup from the DNA dict. This allows the decision points + # to appear in any order other than strictly following the order of their + # appearences during the search space inspection. + if self._require_hyper_name: + if isinstance(decisions, list): + dna = geno.DNA.from_numbers(decisions, self.dna_spec) + else: + dna = decisions + dna.use_spec(self.dna_spec) + decision_dict = dna.to_dict( + key_type='name_or_id', multi_choice_key='parent') + + used_decision_names = set() + def get_decision_from_dict( + hyper_primitive, sub_index: Optional[int] = None + ) -> Union[int, float, str]: + name = hyper_primitive.name + assert name is not None, hyper_primitive + if name not in decision_dict: + raise ValueError( + f'Hyper primitive {hyper_primitive!r} is not defined during ' + f'search space inspection (pg.hyper.DynamicEvaluationContext.' + f'collect()). Please make sure `collect` and `apply` are applied ' + f'to the same function.') + + # We use assertion here since DNA is validated with `self.dna_spec`. + # User errors should be caught by `dna.use_spec`. + decision = decision_dict[name] + used_decision_names.add(name) + if (not isinstance(hyper_primitive, categorical.Choices) + or hyper_primitive.num_choices == 1): + return decision + assert isinstance(decision, list), (hyper_primitive, decision) + assert len(decision) == hyper_primitive.num_choices, ( + hyper_primitive, decision) + return decision[sub_index] + + def err_on_unused_decisions(): + if len(used_decision_names) != len(decision_dict): + remaining = {k: v for k, v in decision_dict.items() + if k not in used_decision_names} + raise ValueError( + f'Found extra decision values that are not used. {remaining!r}') + return get_decision_from_dict, err_on_unused_decisions + else: + if isinstance(decisions, geno.DNA): + decision_list = decisions.to_numbers() + else: + decision_list = decisions + value_context = dict(pos=0, value_cache={}) + + def get_decision_by_position( + hyper_primitive, sub_index: Optional[int] = None + ) -> Union[int, float, str]: + if sub_index is None or hyper_primitive.name is None: + name = hyper_primitive.name + else: + name = f'{hyper_primitive.name}:{sub_index}' + if name is None or name not in value_context['value_cache']: + if value_context['pos'] >= len(decision_list): + raise ValueError( + f'No decision is provided for {hyper_primitive!r}.') + decision = decision_list[value_context['pos']] + value_context['pos'] += 1 + if name is not None: + value_context['value_cache'][name] = decision + else: + decision = value_context['value_cache'][name] + + if (isinstance(hyper_primitive, numerical.Float) + and not isinstance(decision, float)): + raise ValueError( + f'Expect float-type decision for {hyper_primitive!r}, ' + f'encoutered {decision!r}.') + if (isinstance(hyper_primitive, custom.CustomHyper) + and not isinstance(decision, str)): + raise ValueError( + f'Expect string-type decision for {hyper_primitive!r}, ' + f'encountered {decision!r}.') + if (isinstance(hyper_primitive, categorical.Choices) + and not (isinstance(decision, int) + and decision < len(hyper_primitive.candidates))): + raise ValueError( + f'Expect int-type decision in range ' + f'[0, {len(hyper_primitive.candidates)}) for choice {sub_index} ' + f'of {hyper_primitive!r}, encountered {decision!r}.') + return decision + + def err_on_unused_decisions(): + if value_context['pos'] != len(decision_list): + remaining = decision_list[value_context['pos']:] + raise ValueError( + f'Found extra decision values that are not used: {remaining!r}') + return get_decision_by_position, err_on_unused_decisions + + @contextlib.contextmanager + def apply( + self, decisions: Union[geno.DNA, List[Union[int, float, str]]]): + """Context manager for applying decisions. + + Example:: + + def fun(): + return pg.oneof([1, 2, 3]) + pg.oneof([4, 5, 6]) + + context = DynamicEvaluationContext() + with context.collect(): + fun() + + with context.apply([0, 1]): + # Will print 6 (1 + 5). + print(fun()) + + Args: + decisions: A DNA or a list of numbers or strings as decisions for currrent + search space. + + Yields: + None + """ + if not isinstance(decisions, (geno.DNA, list)): + raise ValueError('`decisions` should be a DNA or a list of numbers.') + + # Ensure per-thread dynamic evaluation context will not be used + # together with process-level dynamic evaluation context. + _dynamic_evaluation_stack.ensure_thread_safety(self) + + get_current_decision, evaluation_finalizer = ( + self._decision_getter_and_evaluation_finalizer(decisions)) + + has_errors = False + with dynamic_evaluate(self.evaluate, per_thread=self._per_thread): + try: + # Set decision getter for current decision. + self._decision_getter = get_current_decision + + # Push current context to dynamic evaluation stack so nested context + # can delegate evaluate to current context. + _dynamic_evaluation_stack.push(self) + + yield + except Exception: + has_errors = True + raise + finally: + # Pop current context from dynamic evaluatoin stack. + _dynamic_evaluation_stack.pop(self) + + # Reset decisions. + self._decision_getter = None + + # Call evaluation finalizer to make sure all decisions are used. + if not has_errors: + evaluation_finalizer() + + def evaluate(self, hyper_primitive: base.HyperPrimitive): + """Evaluates a hyper primitive based on current decisions.""" + if self._decision_getter is None: + raise ValueError( + '`evaluate` needs to be called under the `apply` context.') + + get_current_decision = self._decision_getter + def _apply_child(c): + if isinstance(c, types.LambdaType): + s = pg_typing.get_signature(c) + if not s.args and not s.has_wildcard_args: + return c() + return c + + if self._where and not self._where(hyper_primitive): + # Delegate the resolution of hyper primitives that do not pass + # the `where` predicate to its parent context. + parent_context = _dynamic_evaluation_stack.get_parent(self) + if parent_context is not None: + return parent_context.evaluate(hyper_primitive) + return hyper_primitive + + if isinstance(hyper_primitive, numerical.Float): + return get_current_decision(hyper_primitive) + + if isinstance(hyper_primitive, custom.CustomHyper): + return hyper_primitive.decode( + geno.DNA(get_current_decision(hyper_primitive))) + + assert isinstance(hyper_primitive, categorical.Choices), hyper_primitive + value = symbolic.List() + for i in range(hyper_primitive.num_choices): + # NOTE(daiyip): during registering the hyper primitives when + # constructing the search space, we will need to evaluate every + # candidate in order to pick up sub search spaces correctly, which is + # not necessary for `pg.DynamicEvaluationContext.apply`. + value.append(_apply_child( + hyper_primitive.candidates[get_current_decision(hyper_primitive, i)])) + if isinstance(hyper_primitive, categorical.OneOf): + assert len(value) == 1 + value = value[0] + return value + + +# We maintain a stack of dynamic evaluation context for support search space +# combination +class _DynamicEvaluationStack: + """Dynamic evaluation stack used for dealing with nested evaluation.""" + + _TLS_KEY = 'dynamic_evaluation_stack' + + def __init__(self): + self._global_stack = [] + + def ensure_thread_safety(self, context: DynamicEvaluationContext): + if ((context.per_thread and self._global_stack) + or (not context.per_thread and self._local_stack)): + raise ValueError( + 'Nested dynamic evaluation contexts must be either all per-thread ' + 'or all process-wise. Please check the `per_thread` argument of ' + 'the `pg.hyper.DynamicEvaluationContext` objects being used.') + + @property + def _local_stack(self): + """Returns thread-local stack.""" + stack = thread_local.get_value(self._TLS_KEY, None) + if stack is None: + stack = [] + thread_local.set_value(self._TLS_KEY, stack) + return stack + + def push(self, context: DynamicEvaluationContext): + """Pushes the context to the stack.""" + stack = self._local_stack if context.per_thread else self._global_stack + stack.append(context) + + def pop(self, context: DynamicEvaluationContext): + """Pops the context from the stack.""" + stack = self._local_stack if context.per_thread else self._global_stack + assert stack + stack_top = stack.pop(-1) + assert stack_top is context, (stack_top, context) + + def get_parent( + self, + context: DynamicEvaluationContext) -> Optional[DynamicEvaluationContext]: + """Returns the parent context of the input context.""" + stack = self._local_stack if context.per_thread else self._global_stack + parent = None + for i in reversed(range(1, len(stack))): + if context is stack[i]: + parent = stack[i - 1] + break + return parent + + +# System-wise dynamic evaluation stack. +_dynamic_evaluation_stack = _DynamicEvaluationStack() + + +def trace( + fun: Callable[[], Any], + *, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None, + require_hyper_name: bool = False, + per_thread: bool = True) -> DynamicEvaluationContext: + """Trace the hyper primitives called within a function by executing it. + + See examples in :class:`pyglove.hyper.DynamicEvaluationContext`. + + Args: + fun: Function in which the search space is defined. + where: A callable object that decide whether a hyper primitive should be + included when being instantiated under `collect`. + If None, all hyper primitives under `collect` will be included. + require_hyper_name: If True, all hyper primitives defined in this scope + will need to carry their names, which is usually a good idea when the + function that instantiates the hyper primtives need to be called multiple + times. + per_thread: If True, the context manager will be applied to current thread + only. Otherwise, it will be applied on current process. + + Returns: + An DynamicEvaluationContext that can be passed to `pg.sample`. + """ + context = DynamicEvaluationContext( + where=where, require_hyper_name=require_hyper_name, per_thread=per_thread) + with context.collect(): + fun() + return context + diff --git a/pyglove/core/hyper/dynamic_evaluation_test.py b/pyglove/core/hyper/dynamic_evaluation_test.py new file mode 100644 index 0000000..c7583ce --- /dev/null +++ b/pyglove/core/hyper/dynamic_evaluation_test.py @@ -0,0 +1,523 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.DynamicEvaluationContext.""" + +import threading +import unittest + +from pyglove.core import geno +from pyglove.core import symbolic +from pyglove.core.hyper.categorical import manyof +from pyglove.core.hyper.categorical import oneof +from pyglove.core.hyper.categorical import OneOf +from pyglove.core.hyper.categorical import permutate +from pyglove.core.hyper.custom import CustomHyper +from pyglove.core.hyper.dynamic_evaluation import dynamic_evaluate as pg_dynamic_evaluate +from pyglove.core.hyper.dynamic_evaluation import DynamicEvaluationContext +from pyglove.core.hyper.dynamic_evaluation import trace as pg_trace +from pyglove.core.hyper.numerical import floatv +from pyglove.core.hyper.object_template import template as pg_template + + +class DynamicEvaluationTest(unittest.TestCase): + """Dynamic evaluation test.""" + + def test_dynamic_evaluate(self): + with self.assertRaisesRegex( + ValueError, '\'evaluate_fn\' must be either None or a callable object'): + with pg_dynamic_evaluate(1): + pass + + with self.assertRaisesRegex( + ValueError, '\'exit_fn\' must be a callable object'): + with pg_dynamic_evaluate(None, exit_fn=1): + pass + + def test_evaluated_values_during_collect(self): + with DynamicEvaluationContext().collect(): + self.assertEqual(oneof([0, 1]), 0) + self.assertEqual(oneof([{'x': oneof(['a', 'b'])}, 1]), + {'x': 'a'}) + self.assertEqual(manyof(2, [0, 1, 3]), [0, 1]) + self.assertEqual(manyof(4, [0, 1, 3], distinct=False), + [0, 0, 0, 0]) + self.assertEqual(permutate([0, 1, 2]), [0, 1, 2]) + self.assertEqual(floatv(0.0, 1.0), 0.0) + + def test_per_thread_collect_and_apply(self): + def thread_fun(): + context = DynamicEvaluationContext() + with context.collect(): + oneof(range(10)) + + with context.apply([3]): + self.assertEqual(oneof(range(10)), 3) + + threads = [] + for _ in range(10): + thread = threading.Thread(target=thread_fun) + threads.append(thread) + thread.start() + for t in threads: + t.join() + + def test_process_wise_collect_and_apply(self): + def thread_fun(): + _ = oneof(range(10)) + + context = DynamicEvaluationContext(per_thread=False) + with context.collect() as hyper_dict: + threads = [] + for _ in range(10): + thread = threading.Thread(target=thread_fun) + threads.append(thread) + thread.start() + for t in threads: + t.join() + self.assertEqual(len(hyper_dict), 10) + + def test_search_space_defined_without_hyper_name(self): + def fun(): + x = oneof([1, 2, 3]) + 1 + y = sum(manyof(2, [2, 4, 6, 8], name='y')) + z = floatv(min_value=1.0, max_value=2.0) + return x + y + z + + # Test dynamic evaluation by allowing reentry (all hyper primitives will + # be registered twice). + context = DynamicEvaluationContext() + with context.collect() as hyper_dict: + result = fun() + result = fun() + + # 1 + 1 + 2 + 4 + 1.0 + self.assertEqual(result, 9.0) + self.assertEqual(hyper_dict, { + 'decision_0': oneof([1, 2, 3]), + 'y': manyof(2, [2, 4, 6, 8], name='y'), + 'decision_1': floatv(min_value=1.0, max_value=2.0), + 'decision_2': oneof([1, 2, 3]), + 'decision_3': floatv(min_value=1.0, max_value=2.0), + }) + + with context.apply(geno.DNA.parse( + [1, [0, 2], 1.5, 0, 1.8])): + # 2 + 1 + 2 + 6 + 1.5 + self.assertEqual(fun(), 12.5) + # 1 + 1 + 2 + 6 + 1.8 + self.assertEqual(fun(), 11.8) + + def test_search_space_defined_with_hyper_name(self): + def fun(): + x = oneof([1, 2, 3], name='a') + 1 + y = sum(manyof(2, [2, 4, 6, 8], name='b')) + z = floatv(min_value=1.0, max_value=2.0, name='c') + return x + y + z + + # Test dynamic evaluation by disallowing reentry (all hyper primitives will + # be registered only once). + context = DynamicEvaluationContext(require_hyper_name=True) + with context.collect() as hyper_dict: + with self.assertRaisesRegex( + ValueError, '\'name\' must be specified for hyper primitive'): + oneof([1, 2, 3]) + result = fun() + result = fun() + + # 1 + 1 + 2 + 4 + 1.0 + self.assertEqual(result, 9.0) + self.assertEqual(hyper_dict, symbolic.Dict( + a=oneof([1, 2, 3], name='a'), + b=manyof(2, [2, 4, 6, 8], name='b'), + c=floatv(min_value=1.0, max_value=2.0, name='c'))) + with context.apply(geno.DNA.parse([1, [0, 2], 1.5])): + # We can call fun multiple times since decision will be bound to each + # name just once. + # 2 + 1 + 2 + 6 + 1.5 + self.assertEqual(fun(), 12.5) + self.assertEqual(fun(), 12.5) + self.assertEqual(fun(), 12.5) + + def test_hierarchical_search_space(self): + def fun(): + return oneof([ + lambda: sum(manyof(2, [2, 4, 6, 8])), + lambda: oneof([3, 7]), + lambda: floatv(min_value=1.0, max_value=2.0), + 10]) + oneof([11, 22]) + + context = DynamicEvaluationContext() + with context.collect() as hyper_dict: + result = fun() + # 2 + 4 + 11 + self.assertEqual(result, 17) + self.assertEqual(hyper_dict, { + 'decision_0': oneof([ + # NOTE(daiyip): child decisions within candidates are always in + # form of list. + { + 'decision_1': manyof(2, [2, 4, 6, 8]), + }, + { + 'decision_2': oneof([3, 7]) + }, + { + 'decision_3': floatv(min_value=1.0, max_value=2.0) + }, + 10, + ]), + 'decision_4': oneof([11, 22]) + }) + + with context.apply(geno.DNA.parse([(0, [1, 3]), 0])): + # 4 + 8 + 11 + self.assertEqual(fun(), 23) + + # Use list-form decisions. + with context.apply([0, 1, 3, 0]): + # 4 + 8 + 11 + self.assertEqual(fun(), 23) + + with context.apply(geno.DNA.parse([(1, 1), 1])): + # 7 + 22 + self.assertEqual(fun(), 29) + + with context.apply(geno.DNA.parse([(2, 1.5), 0])): + # 1.5 + 11 + self.assertEqual(fun(), 12.5) + + with context.apply(geno.DNA.parse([3, 1])): + # 10 + 22 + self.assertEqual(fun(), 32) + + with self.assertRaisesRegex( + ValueError, '`decisions` should be a DNA or a list of numbers.'): + with context.apply(3): + fun() + + with self.assertRaisesRegex( + ValueError, 'No decision is provided for .*'): + with context.apply(geno.DNA.parse(3)): + fun() + + with self.assertRaisesRegex( + ValueError, 'Expect float-type decision for .*'): + with context.apply([2, 0, 1]): + fun() + + with self.assertRaisesRegex( + ValueError, 'Expect int-type decision in range .*'): + with context.apply([5, 0.5, 0]): + fun() + + with self.assertRaisesRegex( + ValueError, 'Found extra decision values that are not used.*'): + with context.apply(geno.DNA.parse([(1, 1), 1, 1])): + fun() + + def test_hierarchical_search_space_with_hyper_name(self): + def fun(): + return oneof([ + lambda: sum(manyof(2, [2, 4, 6, 8], name='a1')), + lambda: oneof([3, 7], name='a2'), + lambda: floatv(min_value=1.0, max_value=2.0, name='a3.xx'), + 10], name='a') + oneof([11, 22], name='b') + + context = DynamicEvaluationContext(require_hyper_name=True) + with context.collect() as hyper_dict: + result = fun() + result = fun() + + # 2 + 4 + 11 + self.assertEqual(result, 17) + self.assertEqual(hyper_dict, { + 'a': oneof([ + # NOTE(daiyip): child decisions within candidates are always in + # form of list. + {'a1': manyof(2, [2, 4, 6, 8], name='a1')}, + {'a2': oneof([3, 7], name='a2')}, + {'a3.xx': floatv(min_value=1.0, max_value=2.0, name='a3.xx')}, + 10, + ], name='a'), + 'b': oneof([11, 22], name='b') + }) + + with context.apply(geno.DNA.parse([(0, [1, 3]), 0])): + # 4 + 8 + 11 + self.assertEqual(fun(), 23) + self.assertEqual(fun(), 23) + self.assertEqual(fun(), 23) + + # Use list form. + with context.apply([0, 1, 3, 0]): + # 4 + 8 + 11 + self.assertEqual(fun(), 23) + self.assertEqual(fun(), 23) + self.assertEqual(fun(), 23) + + with context.apply(geno.DNA.parse([(1, 1), 1])): + # 7 + 22 + self.assertEqual(fun(), 29) + self.assertEqual(fun(), 29) + + with context.apply(geno.DNA.parse([(2, 1.5), 0])): + # 1.5 + 11 + self.assertEqual(fun(), 12.5) + self.assertEqual(fun(), 12.5) + + with context.apply(geno.DNA.parse([3, 1])): + # 10 + 22 + self.assertEqual(fun(), 32) + self.assertEqual(fun(), 32) + + with self.assertRaisesRegex( + ValueError, '`decisions` should be a DNA or a list of numbers.'): + with context.apply(3): + fun() + + with self.assertRaisesRegex( + ValueError, 'DNA value type mismatch'): + with context.apply(geno.DNA.parse(3)): + fun() + + with self.assertRaisesRegex( + ValueError, 'Found extra decision values that are not used'): + with context.apply(context.dna_spec.first_dna()): + # Do not consume any decision points from the search space. + _ = 1 + + with self.assertRaisesRegex( + ValueError, + 'Hyper primitive .* is not defined during search space inspection'): + with context.apply(context.dna_spec.first_dna()): + # Do not consume any decision points from the search space. + _ = oneof(range(5), name='uknown') + + def test_where_statement(self): + context = DynamicEvaluationContext( + where=lambda x: getattr(x, 'name') != 'x') + with context.collect(): + self.assertEqual(oneof(range(10)), 0) + self.assertIsInstance(oneof(range(5), name='x'), OneOf) + + with context.apply([1]): + self.assertEqual(oneof(range(10)), 1) + self.assertIsInstance(oneof(range(5), name='x'), OneOf) + + def test_trace(self): + def fun(): + return oneof([-1, 0, 1]) * oneof([-1, 0, 3]) + 1 + + self.assertEqual( + pg_trace(fun).hyper_dict, + { + 'decision_0': oneof([-1, 0, 1]), + 'decision_1': oneof([-1, 0, 3]) + }) + + def test_dynamic_evaluation_with_custom_hyper(self): + + class IntList(CustomHyper): + + def custom_decode(self, dna): + return [int(x) for x in dna.value.split(':')] + + def first_dna(self): + return geno.DNA('0:1:2:3') + + def fun(): + return sum(IntList()) + oneof([0, 1]) + floatv(-1., 1.) + + context = DynamicEvaluationContext() + with context.collect(): + fun() + + self.assertEqual( + context.hyper_dict, + { + 'decision_0': IntList(), + 'decision_1': oneof([0, 1]), + 'decision_2': floatv(-1., 1.) + }) + with context.apply(geno.DNA(['1:2:3:4', 1, 0.5])): + self.assertEqual(fun(), 1 + 2 + 3 + 4 + 1 + 0.5) + + with self.assertRaisesRegex( + ValueError, 'Expect string-type decision for .*'): + with context.apply(geno.DNA([0, 1, 0.5])): + fun() + + class IntListWithoutFirstDNA(CustomHyper): + + def custom_decode(self, dna): + return [int(x) for x in dna.value.split(':')] + + context = DynamicEvaluationContext() + with self.assertRaisesRegex( + NotImplementedError, + '.* must implement method `next_dna` to be used in ' + 'dynamic evaluation mode'): + with context.collect(): + IntListWithoutFirstDNA() + + def test_dynamic_evaluation_with_external_dna_spec(self): + def fun(): + return oneof(range(5), name='x') + oneof(range(3), name='y') + + context = pg_trace(fun, require_hyper_name=True, per_thread=True) + self.assertFalse(context.is_external) + self.assertIsNotNone(context.hyper_dict) + + search_space_str = symbolic.to_json_str(context.dna_spec) + + context2 = DynamicEvaluationContext( + require_hyper_name=True, per_thread=True, + dna_spec=symbolic.from_json_str(search_space_str)) + self.assertTrue(context2.is_external) + self.assertIsNone(context2.hyper_dict) + + with self.assertRaisesRegex( + ValueError, + '`collect` cannot be called .* is using an external DNASpec'): + with context2.collect(): + fun() + + with context2.apply(geno.DNA([1, 2])): + self.assertEqual(fun(), 3) + + def test_search_space_partitioning_without_hyper_name(self): + def fun(): + return sum([ + oneof([1, 2, 3], hints='ssd1'), + oneof([4, 5], hints='ssd2'), + ]) + + context1 = DynamicEvaluationContext(where=lambda x: x.hints == 'ssd1') + context2 = DynamicEvaluationContext(where=lambda x: x.hints == 'ssd2') + with context1.collect(): + with context2.collect(): + self.assertEqual(fun(), 1 + 4) + + self.assertEqual( + context1.hyper_dict, { + 'decision_0': oneof([1, 2, 3], hints='ssd1') + }) + self.assertEqual( + context2.hyper_dict, { + 'decision_0': oneof([4, 5], hints='ssd2') + }) + with context1.apply(geno.DNA(2)): + with context2.apply(geno.DNA(1)): + self.assertEqual(fun(), 3 + 5) + + def test_search_space_partitioning_with_hyper_name(self): + def fun(): + return sum([ + oneof([1, 2, 3], name='x', hints='ssd1'), + oneof([4, 5], name='y', hints='ssd2'), + ]) + + context1 = DynamicEvaluationContext(where=lambda x: x.hints == 'ssd1') + context2 = DynamicEvaluationContext(where=lambda x: x.hints == 'ssd2') + with context1.collect(): + with context2.collect(): + self.assertEqual(fun(), 1 + 4) + + self.assertEqual( + context1.hyper_dict, { + 'x': oneof([1, 2, 3], name='x', hints='ssd1') + }) + self.assertEqual( + context2.hyper_dict, { + 'y': oneof([4, 5], name='y', hints='ssd2') + }) + with context1.apply(geno.DNA(2)): + with context2.apply(geno.DNA(1)): + self.assertEqual(fun(), 3 + 5) + + def test_hierarchial_search_space_with_partitioning(self): + def fun(): + return sum([ + oneof([ + lambda: oneof([1, 2, 3], name='y', hints='ssd1'), + lambda: oneof([4, 5, 6], name='z', hints='ssd1'), + ], name='x', hints='ssd1'), + oneof([7, 8], name='p', hints='ssd2'), + oneof([9, 10], name='q', hints='ssd2'), + ]) + context1 = DynamicEvaluationContext(where=lambda x: x.hints == 'ssd1') + context2 = DynamicEvaluationContext(where=lambda x: x.hints == 'ssd2') + with context1.collect(): + with context2.collect(): + self.assertEqual(fun(), 1 + 7 + 9) + + self.assertEqual( + context1.hyper_dict, { + 'x': oneof([ + {'y': oneof([1, 2, 3], name='y', hints='ssd1')}, + {'z': oneof([4, 5, 6], name='z', hints='ssd1')}, + ], name='x', hints='ssd1') + }) + self.assertEqual( + context2.hyper_dict, { + 'p': oneof([7, 8], name='p', hints='ssd2'), + 'q': oneof([9, 10], name='q', hints='ssd2') + }) + with context1.apply(geno.DNA((1, 1))): + with context2.apply(geno.DNA([0, 1])): + self.assertEqual(fun(), 5 + 7 + 10) + + def test_search_space_partitioning_with_different_per_thread_settings(self): + context1 = DynamicEvaluationContext(per_thread=True) + context2 = DynamicEvaluationContext(per_thread=False) + + def fun(): + return oneof([1, 2, 3]) + + with self.assertRaisesRegex( + ValueError, + 'Nested dynamic evaluation contexts must be either .*'): + with context1.collect(): + with context2.collect(): + fun() + + def test_manual_decision_point_registration(self): + context = DynamicEvaluationContext() + self.assertEqual( + context.add_decision_point(oneof([1, 2, 3])), 1) + self.assertEqual( + context.add_decision_point(oneof(['a', 'b'], name='x')), 'a') + self.assertEqual( + context.add_decision_point(pg_template(1)), 1) + + with self.assertRaisesRegex( + ValueError, 'Found different hyper primitives under the same name'): + context.add_decision_point(oneof(['foo', 'bar'], name='x')) + + self.assertEqual(context.hyper_dict, { + 'decision_0': oneof([1, 2, 3]), + 'x': oneof(['a', 'b'], name='x'), + }) + + with self.assertRaisesRegex( + ValueError, '`evaluate` needs to be called under the `apply` context'): + context.evaluate(oneof([1, 2, 3])) + + with context.apply([1, 1]): + self.assertEqual(context.evaluate(context.hyper_dict['decision_0']), 2) + self.assertEqual(context.evaluate(context.hyper_dict['x']), 'b') + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper/evolvable.py b/pyglove/core/hyper/evolvable.py new file mode 100644 index 0000000..48e04c0 --- /dev/null +++ b/pyglove/core/hyper/evolvable.py @@ -0,0 +1,278 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evolvable symbolic values.""" + +import dataclasses +import enum +import random +import types +from typing import Any, Callable, List, Optional, Tuple, Union + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper import custom + + +class MutationType(str, enum.Enum): + """Mutation type.""" + REPLACE = 0 + INSERT = 1 + DELETE = 2 + + +@dataclasses.dataclass +class MutationPoint: + """Internal class that encapsulates the information for a mutation point. + + Attributes: + mutation_type: The type of the mutation. + location: The location where the mutation will take place. + old_value: The value of the mutation point before mutation. + parent: The parent node of the mutation point. + """ + mutation_type: 'MutationType' + location: object_utils.KeyPath + old_value: Any + parent: symbolic.Symbolic + + +class Evolvable(custom.CustomHyper): + """Hyper primitive for evolving an arbitrary symbolic object.""" + + def _on_bound(self): + super()._on_bound() + self._weights = self.weights or (lambda mt, k, v, p: 1.0) + + def custom_decode(self, dna: geno.DNA) -> Any: + assert isinstance(dna.value, str) + # TODO(daiyip): consider compression. + return symbolic.from_json_str(dna.value) + + def custom_encode(self, value: Any) -> geno.DNA: + return geno.DNA(symbolic.to_json_str(value)) + + def mutation_points_and_weights( + self, + value: symbolic.Symbolic) -> Tuple[List[MutationPoint], List[float]]: + """Returns mutation points with weights for a symbolic tree.""" + mutation_points: List[MutationPoint] = [] + mutation_weights: List[float] = [] + + def _choose_mutation_point(k: object_utils.KeyPath, + v: Any, + p: Optional[symbolic.Symbolic]): + """Visiting function for a symbolic node.""" + def _add_point(mt: MutationType, k=k, v=v, p=p): + assert p is not None + mutation_points.append(MutationPoint(mt, k, v, p)) + mutation_weights.append(self._weights(mt, k, v, p)) + + if p is not None: + # Stopping mutating current branch if metadata said so. + f = p.sym_attr_field(k.key) + if f and f.metadata and 'no_mutation' in f.metadata: + return symbolic.TraverseAction.CONTINUE + _add_point(MutationType.REPLACE) + + # Special handle list traversal to add insertion and deletion. + if isinstance(v, symbolic.List): + if v.value_spec: + spec = v.value_spec + reached_max_size = spec.max_size and len(v) == spec.max_size + reached_min_size = spec.min_size and len(v) == spec.min_size + else: + reached_max_size = False + reached_min_size = False + + for i, cv in enumerate(v): + ck = object_utils.KeyPath(i, parent=k) + if not reached_max_size: + _add_point(MutationType.INSERT, + k=ck, v=object_utils.MISSING_VALUE, p=v) + + if not reached_min_size: + _add_point(MutationType.DELETE, k=ck, v=cv, p=v) + + # Replace type and value will be added in traverse. + symbolic.traverse(cv, _choose_mutation_point, root_path=ck, parent=v) + if not reached_max_size and i == len(v) - 1: + _add_point(MutationType.INSERT, + k=object_utils.KeyPath(i + 1, parent=k), + v=object_utils.MISSING_VALUE, + p=v) + return symbolic.TraverseAction.CONTINUE + return symbolic.TraverseAction.ENTER + + # First-order traverse the symbolic tree to compute + # the mutation points and weights. + symbolic.traverse(value, _choose_mutation_point) + return mutation_points, mutation_weights + + def first_dna(self) -> geno.DNA: + """Returns the first DNA of current sub-space.""" + return self.custom_encode(self.initial_value) + + def random_dna( + self, + random_generator: Union[types.ModuleType, random.Random, None] = None, + previous_dna: Optional[geno.DNA] = None) -> geno.DNA: + """Generates a random DNA.""" + random_generator = random_generator or random + if previous_dna is None: + return self.first_dna() + return self.custom_encode( + self.mutate(self.custom_decode(previous_dna), random_generator)) + + def mutate( + self, + value: symbolic.Symbolic, + random_generator: Union[types.ModuleType, random.Random, None] = None + ) -> symbolic.Symbolic: + """Returns the next value for a symbolic value.""" + r = random_generator or random + points, weights = self.mutation_points_and_weights(value) + [point] = r.choices(points, weights, k=1) + + # Mutating value. + if point.mutation_type == MutationType.REPLACE: + assert point.location, point + value.rebind({ + str(point.location): self.node_transform( + point.location, point.old_value, point.parent)}) + elif point.mutation_type == MutationType.INSERT: + assert isinstance(point.parent, symbolic.List), point + assert point.old_value == object_utils.MISSING_VALUE, point + assert isinstance(point.location.key, int), point + with symbolic.allow_writable_accessors(): + point.parent.insert( + point.location.key, + self.node_transform(point.location, point.old_value, point.parent)) + else: + assert point.mutation_type == MutationType.DELETE, point + assert isinstance(point.parent, symbolic.List), point + assert isinstance(point.location.key, int), point + with symbolic.allow_writable_accessors(): + del point.parent[point.location.key] + return value + + +# We defer members declaration for Evolvable since the weights will reference +# the definition of MutationType. +symbolic.members([ + ('initial_value', pg_typing.Object(symbolic.Symbolic), + 'Symbolic value to involve.'), + ('node_transform', pg_typing.Callable( + [], + returns=pg_typing.Any()), + ''), + ('weights', pg_typing.Callable( + [ + pg_typing.Object(MutationType), + pg_typing.Object(object_utils.KeyPath), + pg_typing.Any().noneable(), + pg_typing.Object(symbolic.Symbolic) + ], returns=pg_typing.Float(min_value=0.0)).noneable(), + ('An optional callable object that returns the unnormalized (e.g. ' + 'the sum of all probabilities do not have to sum to 1.0) mutation ' + 'probabilities for all the nodes in the symbolic tree, based on ' + '(mutation type, location, old value, parent node). If None, all the ' + 'locations and mutation types will be sampled uniformly.')), +])(Evolvable) + + +def evolve( + initial_value: symbolic.Symbolic, + node_transform: Callable[ + [ + object_utils.KeyPath, # Location. + Any, # Old value. + # pg.MISSING_VALUE for insertion. + symbolic.Symbolic, # Parent node. + ], + Any # Replacement. + ], + *, + weights: Optional[Callable[ + [ + MutationType, # Mutation type. + object_utils.KeyPath, # Location. + Any, # Value. + symbolic.Symbolic, # Parent. + ], + float # Mutation weight. + ]] = None, # pylint: disable=bad-whitespace + name: Optional[str] = None, + hints: Optional[Any] = None) -> Evolvable: + """An evolvable symbolic value. + + Example:: + + @pg.symbolize + @dataclasses.dataclass + class Foo: + x: int + y: int + + @pg.symbolize + @dataclasses.dataclass + class Bar: + a: int + b: int + + # Defines possible transitions. + def node_transform(location, value, parent): + if isinstance(value, Foo) + return Bar(value.x, value.y) + if location.key == 'x': + return random.choice([1, 2, 3]) + if location.key == 'y': + return random.choice([3, 4, 5]) + + v = pg.evolve(Foo(1, 3), node_transform) + + See also: + + * :class:`pyglove.hyper.Evolvable` + * :func:`pyglove.oneof` + * :func:`pyglove.manyof` + * :func:`pyglove.permutate` + * :func:`pyglove.floatv` + + Args: + initial_value: The initial value to evolve. + node_transform: A callable object that takes information of the value to + operate (e.g. location, old value, parent node) and returns a new value as + a replacement for the node. Such information allows users to not only + access the mutation node, but the entire symbolic tree if needed, allowing + complex mutation rules to be written with ease - for example - check + adjacent nodes while modifying a list element. This function is designed + to take care of both node replacements and node insertions. When insertion + happens, the old value for the location will be `pg.MISSING_VALUE`. See + `pg.composing.SeenObjectReplacer` as an example. + weights: An optional callable object that returns the unnormalized (e.g. + the sum of all probabilities don't have to sum to 1.0) mutation + probabilities for all the nodes in the symbolic tree, based on (mutation + type, location, old value, parent node), If None, all the locations and + mutation types will be sampled uniformly. + name: An optional name of the decision point. + hints: An optional hints for the decision point. + + Returns: + A `pg.hyper.Evolvable` object. + """ + return Evolvable( + initial_value=initial_value, node_transform=node_transform, + weights=weights, name=name, hints=hints) diff --git a/pyglove/core/hyper/evolvable_test.py b/pyglove/core/hyper/evolvable_test.py new file mode 100644 index 0000000..0a12fbb --- /dev/null +++ b/pyglove/core/hyper/evolvable_test.py @@ -0,0 +1,235 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.Evolvable.""" + +import random +import re +import unittest + +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper.evolvable import evolve +from pyglove.core.hyper.evolvable import MutationType + + +class Layer(symbolic.Object): + pass + + +@symbolic.members([ + ('layers', pg_typing.List(pg_typing.Object(Layer))), +]) +class Sequential(Layer): + pass + + +class Activation(Layer): + pass + + +class ReLU(Activation): + pass + + +class Swish(Activation): + pass + + +@symbolic.members([ + ('filters', pg_typing.Int(min_value=1)), + # `kernel_size` is marked as no_mutation, which should not appear as a + # mutation candidate. + ('kernel_size', pg_typing.Int(min_value=1), '', {'no_mutation': True}), + ('activation', pg_typing.Object(Activation).noneable()) +]) +class Conv(Layer): + pass + + +def seed_program(): + return Sequential([ + Conv(16, 3, ReLU()), + Conv(32, 5, Swish()), + Sequential([ + Conv(64, 7) + ]) + ]) + + +def mutate_at_location(mutation_type: MutationType, location: str): + def _weights(mt, k, v, p): + del v, p + if mt == mutation_type and re.match(location, str(k)): + return 1.0 + return 0.0 + return _weights + + +class EvolvableTest(unittest.TestCase): + """Tests for pg.hyper.Evolvable.""" + + def test_basics(self): + v = evolve( + seed_program(), lambda k, v, p: ReLU(), + weights=mutate_at_location(MutationType.REPLACE, r'^layers\[.*\]$')) + self.assertEqual( + seed_program(), v.custom_decode(v.custom_encode(seed_program()))) + self.assertEqual(v.first_dna(), v.custom_encode(seed_program())) + self.assertEqual(v.random_dna(), v.custom_encode(seed_program())) + self.assertEqual( + v.random_dna(random.Random(1), v.first_dna()), + v.custom_encode( + Sequential([ + ReLU(), + Conv(32, 5, Swish()), + Sequential([ + Conv(64, 7) + ]) + ]))) + + def test_replace(self): + v = evolve( + seed_program(), lambda k, v, p: ReLU(), + weights=mutate_at_location(MutationType.REPLACE, r'^layers\[1\]$')) + self.assertEqual( + v.mutate(seed_program()), + Sequential([ + Conv(16, 3, ReLU()), + ReLU(), + Sequential([ + Conv(64, 7) + ]) + ])) + + def test_insertion(self): + v = evolve( + seed_program(), lambda k, v, p: ReLU(), + weights=mutate_at_location(MutationType.INSERT, r'^layers\[1\]$')) + self.assertEqual( + v.mutate(seed_program()), + Sequential([ + Conv(16, 3, ReLU()), + ReLU(), + Conv(32, 5, Swish()), + Sequential([ + Conv(64, 7) + ]) + ])) + + def test_delete(self): + v = evolve( + seed_program(), lambda k, v, p: ReLU(), + weights=mutate_at_location(MutationType.DELETE, r'^layers\[1\]$')) + self.assertEqual( + v.mutate(seed_program(), random.Random(1)), + Sequential([ + Conv(16, 3, ReLU()), + Sequential([ + Conv(64, 7) + ]) + ])) + + def test_random_generator(self): + v = evolve( + seed_program(), lambda k, v, p: ReLU(), + weights=mutate_at_location(MutationType.REPLACE, r'^layers\[.*\]$')) + self.assertEqual( + v.mutate(seed_program(), random_generator=random.Random(1)), + Sequential([ + ReLU(), + Conv(32, 5, Swish()), + Sequential([ + Conv(64, 7) + ]) + ])) + + def test_mutation_points_and_weights(self): + v = evolve(seed_program(), lambda k, v, p: v, weights=lambda *x: 1.0) + points, weights = v.mutation_points_and_weights(seed_program()) + + # NOTE(daiyip): Conv.kernel_size is marked with 'no_mutation', thus + # it should not show here. + self.assertEqual([(p.mutation_type, p.location) for p in points], [ + (MutationType.REPLACE, 'layers'), + (MutationType.INSERT, 'layers[0]'), + (MutationType.DELETE, 'layers[0]'), + (MutationType.REPLACE, 'layers[0]'), + (MutationType.REPLACE, 'layers[0].filters'), + (MutationType.REPLACE, 'layers[0].activation'), + (MutationType.INSERT, 'layers[1]'), + (MutationType.DELETE, 'layers[1]'), + (MutationType.REPLACE, 'layers[1]'), + (MutationType.REPLACE, 'layers[1].filters'), + (MutationType.REPLACE, 'layers[1].activation'), + (MutationType.INSERT, 'layers[2]'), + (MutationType.DELETE, 'layers[2]'), + (MutationType.REPLACE, 'layers[2]'), + (MutationType.REPLACE, 'layers[2].layers'), + (MutationType.INSERT, 'layers[2].layers[0]'), + (MutationType.DELETE, 'layers[2].layers[0]'), + (MutationType.REPLACE, 'layers[2].layers[0]'), + (MutationType.REPLACE, 'layers[2].layers[0].filters'), + (MutationType.REPLACE, 'layers[2].layers[0].activation'), + (MutationType.INSERT, 'layers[2].layers[1]'), + (MutationType.INSERT, 'layers[3]'), + ]) + self.assertEqual(weights, [1.0] * len(points)) + + def test_mutation_points_and_weights_with_honoring_list_size(self): + # Non-typed list. There is no size limit. + v = evolve( + symbolic.List([]), lambda k, v, p: v, + weights=lambda *x: 1.0) + points, _ = v.mutation_points_and_weights(symbolic.List([1])) + self.assertEqual([(p.mutation_type, p.location) for p in points], [ + (MutationType.INSERT, '[0]'), + (MutationType.DELETE, '[0]'), + (MutationType.REPLACE, '[0]'), + (MutationType.INSERT, '[1]'), + ]) + + # Typed list with size limit. + value_spec = pg_typing.List(pg_typing.Int(), min_size=1, max_size=3) + points, _ = v.mutation_points_and_weights( + symbolic.List([1, 2], value_spec=value_spec)) + self.assertEqual([(p.mutation_type, p.location) for p in points], [ + (MutationType.INSERT, '[0]'), + (MutationType.DELETE, '[0]'), + (MutationType.REPLACE, '[0]'), + (MutationType.INSERT, '[1]'), + (MutationType.DELETE, '[1]'), + (MutationType.REPLACE, '[1]'), + (MutationType.INSERT, '[2]'), + ]) + points, _ = v.mutation_points_and_weights( + symbolic.List([1], value_spec=value_spec)) + self.assertEqual([(p.mutation_type, p.location) for p in points], [ + (MutationType.INSERT, '[0]'), + (MutationType.REPLACE, '[0]'), + (MutationType.INSERT, '[1]'), + ]) + points, _ = v.mutation_points_and_weights( + symbolic.List([1, 2, 3], value_spec=value_spec)) + self.assertEqual([(p.mutation_type, p.location) for p in points], [ + (MutationType.DELETE, '[0]'), + (MutationType.REPLACE, '[0]'), + (MutationType.DELETE, '[1]'), + (MutationType.REPLACE, '[1]'), + (MutationType.DELETE, '[2]'), + (MutationType.REPLACE, '[2]'), + ]) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper/iter.py b/pyglove/core/hyper/iter.py new file mode 100644 index 0000000..64a5cc4 --- /dev/null +++ b/pyglove/core/hyper/iter.py @@ -0,0 +1,193 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Iterating hyper values.""" + +from typing import Any, Callable, Optional, Tuple, Union + +from pyglove.core import geno +from pyglove.core import symbolic +from pyglove.core.hyper import base +from pyglove.core.hyper import dynamic_evaluation +from pyglove.core.hyper import object_template + + +def iterate(hyper_value: Any, + num_examples: Optional[int] = None, + algorithm: Optional[geno.DNAGenerator] = None, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None, + force_feedback: bool = False): + """Iterate a hyper value based on an algorithm. + + Example:: + + hyper_dict = pg.Dict(x=pg.oneof([1, 2, 3]), y=pg.oneof(['a', 'b'])) + + # Get all examples from the hyper_dict. + assert list(pg.iter(hyper_dict)) == [ + pg.Dict(x=1, y='a'), + pg.Dict(x=1, y='b'), + pg.Dict(x=2, y='a'), + pg.Dict(x=2, y='b'), + pg.Dict(x=3, y='a'), + pg.Dict(x=3, y='b'), + ] + + # Get the first two examples. + assert list(pg.iter(hyper_dict, 2)) == [ + pg.Dict(x=1, y='a'), + pg.Dict(x=1, y='b'), + ] + + # Random sample examples, which is equivalent to `pg.random_sample`. + list(pg.iter(hyper_dict, 2, pg.geno.Random())) + + # Iterate examples with feedback loop. + for d, feedback in pg.iter( + hyper_dict, 10, + pg.evolution.regularized_evolution(pg.evolution.mutators.Uniform())): + feedback(d.x) + + # Only materialize selected parts. + assert list( + pg.iter(hyper_dict, where=lambda x: len(x.candidates) == 2)) == [ + pg.Dict(x=pg.oneof([1, 2, 3]), y='a'), + pg.Dict(x=pg.oneof([1, 2, 3]), y='b'), + ] + + ``pg.iter`` distinguishes from `pg.sample` in that it's designed + for simple in-process iteration, which is handy for quickly generating + examples from algorithms without maintaining trail states. On the contrary, + `pg.sample` is designed for distributed sampling, with parallel workers and + failover handling. + + Args: + hyper_value: A hyper value that represents a space of instances. + num_examples: An optional integer as the max number of examples to + propose. If None, propose will return an iterator of infinite examples. + algorithm: An optional DNA generator. If None, Sweeping will be used, which + iterates examples in order. + where: Function to filter hyper primitives. If None, all hyper primitives + from `value` will be included in the encoding/decoding process. Otherwise + only the hyper primitives on which 'where' returns True will be included. + `where` can be useful to partition a search space into separate + optimization processes. Please see 'Template' docstr for details. + force_feedback: If True, always return the Feedback object together + with the example, this is useful when the user want to pass different + DNAGenerators to `pg.iter` and want to handle them uniformly. + + Yields: + A tuple of (example, feedback_fn) if the algorithm needs a feedback or + `force_feedback` is True, otherwise the example. + + Raises: + ValueError: when `hyper_value` is a constant value. + """ + if isinstance(hyper_value, dynamic_evaluation.DynamicEvaluationContext): + dynamic_evaluation_context = hyper_value + spec = hyper_value.dna_spec + t = None + else: + t = object_template.template(hyper_value, where) + if t.is_constant: + raise ValueError( + f'\'hyper_value\' is a constant value: {hyper_value!r}.') + dynamic_evaluation_context = None + spec = t.dna_spec() + + if algorithm is None: + algorithm = geno.Sweeping() + + # NOTE(daiyip): algorithm can continue if it's already set up with the same + # DNASpec, or we will setup the algorithm with the DNASpec from the template. + if algorithm.dna_spec is None: + algorithm.setup(spec) + elif symbolic.ne(spec, algorithm.dna_spec): + raise ValueError( + f'{algorithm!r} has been set up with a different DNASpec. ' + f'Existing: {algorithm.dna_spec!r}, New: {spec!r}.') + + count = 0 + while num_examples is None or count < num_examples: + try: + count += 1 + dna = algorithm.propose() + if t is not None: + example = t.decode(dna) + else: + assert dynamic_evaluation_context is not None + example = lambda: dynamic_evaluation_context.apply(dna) + if force_feedback or algorithm.needs_feedback: + yield example, Feedback(algorithm, dna) + else: + yield example + except StopIteration: + return + + +class Feedback: + """Feedback object.""" + + def __init__(self, algorithm: geno.DNAGenerator, dna: geno.DNA): + """Creates a feedback object.""" + self._algorithm = algorithm + self._dna = dna + + def __call__(self, reward: Union[float, Tuple[float, ...]]): + """Call to feedback reward.""" + self._algorithm.feedback(self._dna, reward) + + @property + def dna(self) -> geno.DNA: + """Returns DNA.""" + return self._dna + + +def random_sample( + value: Any, + num_examples: Optional[int] = None, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None, + seed: Optional[int] = None): + """Returns an iterator of random sampled examples. + + Example:: + + hyper_dict = pg.Dict(x=pg.oneof(range(3)), y=pg.floatv(0.0, 1.0)) + + # Generate one random example from the hyper_dict. + d = next(pg.random_sample(hyper_dict)) + + # Generate 5 random examples with random seed. + ds = list(pg.random_sample(hyper_dict, 5, seed=1)) + + # Generate 3 random examples of `x` with `y` intact. + ds = list(pg.random_sample(hyper_dict, 3, + where=lambda x: isinstance(x, pg.hyper.OneOf))) + + + Args: + value: A (maybe) hyper value. + num_examples: An optional integer as number of examples to propose. If None, + propose will return an iterator that iterates forever. + where: Function to filter hyper primitives. If None, all hyper primitives in + `value` will be included in the encoding/decoding process. Otherwise only + the hyper primitives on which 'where' returns True will be included. + `where` can be useful to partition a search space into separate + optimization processes. Please see 'Template' docstr for details. + seed: An optional integer as random seed. + + Returns: + Iterator of random examples. + """ + return iterate( + value, num_examples, geno.Random(seed), where=where) diff --git a/pyglove/core/hyper/iter_test.py b/pyglove/core/hyper/iter_test.py new file mode 100644 index 0000000..2bef98a --- /dev/null +++ b/pyglove/core/hyper/iter_test.py @@ -0,0 +1,135 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.Float.""" + +import unittest + +from pyglove.core import geno +from pyglove.core.hyper.categorical import oneof +from pyglove.core.hyper.dynamic_evaluation import trace as pg_trace +from pyglove.core.hyper.iter import iterate as pg_iterate +from pyglove.core.hyper.iter import random_sample as pg_random_sample + + +class IterateTest(unittest.TestCase): + """Tests for pg.iter.""" + + def test_iter_with_default_algorithm(self): + v = oneof(range(100)) + examples = list(pg_iterate(v)) + self.assertEqual(examples, list(range(100))) + + examples = list(pg_iterate(v, 10)) + self.assertEqual(examples, list(range(10))) + + def test_iter_with_custom_algorithm(self): + + class ConstantAlgorithm(geno.DNAGenerator): + + def _on_bound(self): + self._rewards = [] + + def _propose(self): + if len(self._rewards) == 100: + raise StopIteration() + return geno.DNA(0) + + def _feedback(self, dna, reward): + self._rewards.append(reward) + + @property + def rewards(self): + return self._rewards + + algo = ConstantAlgorithm() + examples = [] + for i, (x, feedback) in enumerate(pg_iterate(oneof([1, 3]), 5, algo)): + examples.append(x) + feedback(float(i)) + self.assertEqual(feedback.dna, geno.DNA(0)) + self.assertEqual(len(examples), 5) + self.assertEqual(examples, [1] * 5) + self.assertEqual(algo.rewards, [float(i) for i in range(5)]) + + for x, feedback in pg_iterate(oneof([1, 3]), algorithm=algo): + examples.append(x) + feedback(0.) + self.assertEqual(len(examples), 100) + + def test_iter_with_dynamic_evaluation(self): + def foo(): + return oneof([1, 3]) + examples = [] + for x in pg_iterate(pg_trace(foo)): + with x(): + examples.append(foo()) + self.assertEqual(examples, [1, 3]) + + def test_iter_with_continuation(self): + + class ConstantAlgorithm3(geno.DNAGenerator): + + def setup(self, dna_spec): + super().setup(dna_spec) + self.num_trials = 0 + + def propose(self): + self.num_trials += 1 + return geno.DNA(0) + + algo = ConstantAlgorithm3() + for unused_x in pg_iterate(oneof([1, 3]), 10, algo): + pass + for unused_x in pg_iterate(oneof([1, 3]), 10, algo): + pass + self.assertEqual(algo.num_trials, 20) + + def test_iter_with_forced_feedback(self): + + class ConstantAlgorithm2(geno.DNAGenerator): + + def propose(self): + return geno.DNA(0) + + algo = ConstantAlgorithm2() + examples = [] + for x, feedback in pg_iterate( + oneof([1, 3]), 10, algorithm=algo, force_feedback=True): + examples.append(x) + # No op. + feedback(0.) + self.assertEqual(len(examples), 10) + + def test_bad_iter(self): + with self.assertRaisesRegex( + ValueError, '\'hyper_value\' is a constant value'): + next(pg_iterate('foo')) + + algo = geno.Random() + next(pg_iterate(oneof([1, 2]), 1, algo)) + with self.assertRaisesRegex( + ValueError, '.* has been set up with a different DNASpec'): + next(pg_iterate(oneof([2, 3]), 10, algo)) + + +class RandomSampleTest(unittest.TestCase): + """Tests for pg.random_sample.""" + + def test_random_sample(self): + self.assertEqual( + list(pg_random_sample(oneof([0, 1]), 3, seed=123)), [0, 1, 0]) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper/numerical.py b/pyglove/core/hyper/numerical.py new file mode 100644 index 0000000..4c3c426 --- /dev/null +++ b/pyglove/core/hyper/numerical.py @@ -0,0 +1,219 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Numerical hyper primitives.""" + +import typing +from typing import Any, Callable, Optional, Tuple + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper import base + + +@symbolic.members( + [ + ('min_value', pg_typing.Float(), 'Minimum acceptable value.'), + ('max_value', pg_typing.Float(), 'Maximum acceptable value.'), + geno.float_scale_spec('scale'), + ], + init_arg_list=['min_value', 'max_value', 'scale', 'name', 'hints'], + serialization_key='hyper.Float', + additional_keys=['pyglove.generators.genetic.Float'] +) +class Float(base.HyperPrimitive): + """A continuous value within a range. + + Example:: + + # A float value between between 0.0 and 1.0. + v = pg.floatv(0.0, 1.0) + + See also: + + * :func:`pyglove.floatv` + * :class:`pyglove.hyper.Choices` + * :class:`pyglove.hyper.OneOf` + * :class:`pyglove.hyper.ManyOf` + * :class:`pyglove.hyper.CustomHyper` + """ + + def _on_bound(self): + """Constructor.""" + super()._on_bound() + if self.min_value > self.max_value: + raise ValueError( + f'\'min_value\' ({self.min_value}) is greater than \'max_value\' ' + f'({self.max_value}).') + if self.scale in ['log', 'rlog'] and self.min_value <= 0: + raise ValueError( + f'\'min_value\' must be positive when `scale` is {self.scale!r}. ' + f'encountered: {self.min_value}.') + + def dna_spec(self, + location: Optional[object_utils.KeyPath] = None) -> geno.Float: + """Returns corresponding DNASpec.""" + return geno.Float( + min_value=self.min_value, + max_value=self.max_value, + scale=self.scale, + hints=self.hints, + name=self.name, + location=location or object_utils.KeyPath()) + + def _decode(self) -> float: + """Decode a DNA into a float value.""" + dna = self._dna + if not isinstance(dna.value, float): + raise ValueError( + object_utils.message_on_path( + f'Expect float value. Encountered: {dna.value}.', self.sym_path)) + if dna.value < self.min_value: + raise ValueError( + object_utils.message_on_path( + f'DNA value should be no less than {self.min_value}. ' + f'Encountered {dna.value}.', self.sym_path)) + + if dna.value > self.max_value: + raise ValueError( + object_utils.message_on_path( + f'DNA value should be no greater than {self.max_value}. ' + f'Encountered {dna.value}.', self.sym_path)) + return dna.value + + def encode(self, value: float) -> geno.DNA: + """Encode a float value into a DNA.""" + if not isinstance(value, float): + raise ValueError( + object_utils.message_on_path( + f'Value should be float to be encoded for {self!r}. ' + f'Encountered {value}.', self.sym_path)) + if value < self.min_value: + raise ValueError( + object_utils.message_on_path( + f'Value should be no less than {self.min_value}. ' + f'Encountered {value}.', self.sym_path)) + if value > self.max_value: + raise ValueError( + object_utils.message_on_path( + f'Value should be no greater than {self.max_value}. ' + f'Encountered {value}.', self.sym_path)) + return geno.DNA(value) + + def custom_apply( + self, + path: object_utils.KeyPath, + value_spec: pg_typing.ValueSpec, + allow_partial: bool = False, + child_transform: Optional[Callable[ + [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + ) -> Tuple[bool, 'Float']: + """Validate candidates during value_spec binding time.""" + del allow_partial + del child_transform + # Check if value_spec directly accepts `self`. + if value_spec.value_type and isinstance(self, value_spec.value_type): + return (False, self) + + float_spec = typing.cast( + pg_typing.Float, pg_typing.ensure_value_spec( + value_spec, pg_typing.Float(), path)) + if float_spec: + if (float_spec.min_value is not None + and self.min_value < float_spec.min_value): + raise ValueError( + object_utils.message_on_path( + f'Float.min_value ({self.min_value}) should be no less than ' + f'the min value ({float_spec.min_value}) of value spec: ' + f'{float_spec}.', path)) + if (float_spec.max_value is not None + and self.max_value > float_spec.max_value): + raise ValueError( + object_utils.message_on_path( + f'Float.max_value ({self.max_value}) should be no greater than ' + f'the max value ({float_spec.max_value}) of value spec: ' + f'{float_spec}.', path)) + return (False, self) + + def is_leaf(self) -> bool: + """Returns whether this is a leaf node.""" + return True + + +def floatv(min_value: float, + max_value: float, + scale: Optional[str] = None, + *, + name: Optional[str] = None, + hints: Optional[Any] = None) -> Any: + """A continuous value within a range. + + Example:: + + # A continuous value within [0.0, 1.0] + v = pg.floatv(0.0, 1.0) + + See also: + + * :class:`pyglove.hyper.Float` + * :func:`pyglove.oneof` + * :func:`pyglove.manyof` + * :func:`pyglove.permutate` + * :func:`pyglove.evolve` + + .. note:: + + Under symbolic mode (by default), `pg.floatv` returns a ``pg.hyper.Float`` + object. Under dynamic evaluate mode, which is called under the context of + :meth:`pyglove.hyper.DynamicEvaluationContext.collect` or + :meth:`pyglove.hyper.DynamicEvaluationContext.apply`, it evaluates to + a concrete candidate value. + + Args: + min_value: Minimum acceptable value (inclusive). + max_value: Maximum acceptable value (inclusive). + scale: An optional string as the scale of the range. Supported values + are None, 'linear', 'log', and 'rlog'. + If None, the feasible space is unscaled. + If `linear`, the feasible space is mapped to [0, 1] linearly. + If `log`, the feasible space is mapped to [0, 1] logarithmically with + formula `x -> log(x / min) / log(max / min)`. + If `rlog`, the feasible space is mapped to [0, 1] "reverse" + logarithmically, resulting in values close to `max_value` spread + out more than the points near the `min_value`, with formula: + x -> 1.0 - log((max + min - x) / min) / log (max / min). + `min_value` must be positive if `scale` is not None. + Also, it depends on the search algorithm to decide whether this + information is used or not. + name: A name that can be used to identify a decision point in the search + space. This is needed when the code to instantiate the same hyper + primitive may be called multiple times under a + `pg.DynamicEvaluationContext.collect` context or a + `pg.DynamicEvaluationContext.apply` context. + hints: An optional value which acts as a hint for the controller. + + Returns: + In symbolic mode, this function returns a `Float`. + In dynamic evaluate mode, this function returns a float value that is no + less than the `min_value` and no greater than the `max_value`. + If evaluated under an `pg.DynamicEvaluationContext.apply` scope, + this function will return a chosen float value from the controller + decisions. + If evaluated under a `pg.DynamicEvaluationContext.collect` + scope, it will return `min_value`. + """ + return Float( + min_value=min_value, max_value=max_value, + scale=scale, name=name, hints=hints) diff --git a/pyglove/core/hyper/numerical_test.py b/pyglove/core/hyper/numerical_test.py new file mode 100644 index 0000000..fc1c78d --- /dev/null +++ b/pyglove/core/hyper/numerical_test.py @@ -0,0 +1,134 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.Float.""" + +import unittest + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper.numerical import Float +from pyglove.core.hyper.numerical import floatv + + +class FloatTest(unittest.TestCase): + """Test for hyper.Float.""" + + def test_basics(self): + v = floatv(0.0, 1.0) + self.assertEqual(v.min_value, 0.0) + self.assertEqual(v.max_value, 1.0) + self.assertIsNone(v.scale) + self.assertTrue(v.is_leaf) + + with self.assertRaisesRegex( + ValueError, '\'min_value\' .* is greater than \'max_value\' .*'): + floatv(min_value=1.0, max_value=0.0) + + def test_scale(self): + self.assertEqual(floatv(-1.0, 1.0, 'linear').scale, 'linear') + with self.assertRaisesRegex( + ValueError, '\'min_value\' must be positive'): + floatv(-1.0, 1.0, 'log') + + def test_dna_spec(self): + self.assertTrue(symbolic.eq( + floatv(0.0, 1.0).dna_spec('a'), + geno.Float( + location=object_utils.KeyPath('a'), + min_value=0.0, + max_value=1.0))) + + def test_decode(self): + v = floatv(0.0, 1.0) + self.assertEqual(v.decode(geno.DNA(0.0)), 0.0) + self.assertEqual(v.decode(geno.DNA(1.0)), 1.0) + + with self.assertRaisesRegex(ValueError, 'Expect float value'): + v.decode(geno.DNA(1)) + + with self.assertRaisesRegex( + ValueError, 'DNA value should be no less than'): + v.decode(geno.DNA(-1.0)) + + with self.assertRaisesRegex( + ValueError, 'DNA value should be no greater than'): + v.decode(geno.DNA(2.0)) + + def test_encode(self): + v = floatv(0.0, 1.0) + self.assertEqual(v.encode(0.0), geno.DNA(0.0)) + self.assertEqual(v.encode(1.0), geno.DNA(1.0)) + + with self.assertRaisesRegex( + ValueError, 'Value should be float to be encoded'): + v.encode('abc') + + with self.assertRaisesRegex( + ValueError, 'Value should be no less than'): + v.encode(-1.0) + + with self.assertRaisesRegex( + ValueError, 'Value should be no greater than'): + v.encode(2.0) + + def test_assignment_compatibility(self): + sd = symbolic.Dict.partial( + value_spec=pg_typing.Dict([ + ('a', pg_typing.Int()), + ('b', pg_typing.Float()), + ('c', pg_typing.Union([pg_typing.Str(), pg_typing.Float()])), + ('d', pg_typing.Any()), + ('e', pg_typing.Float(max_value=0.0)), + ('f', pg_typing.Float(min_value=1.0)) + ])) + v = floatv(min_value=0.0, max_value=1.0) + sd.b = v + sd.c = v + sd.d = v + + self.assertEqual(sd.b.sym_path, 'b') + self.assertEqual(sd.c.sym_path, 'c') + self.assertEqual(sd.d.sym_path, 'd') + with self.assertRaisesRegex( + TypeError, 'Source spec Float\\(\\) is not compatible with ' + 'destination spec Int\\(\\)'): + sd.a = v + + with self.assertRaisesRegex( + ValueError, + 'Float.max_value .* should be no greater than the max value'): + sd.e = v + + with self.assertRaisesRegex( + ValueError, + 'Float.min_value .* should be no less than the min value'): + sd.f = v + + def test_custom_apply(self): + v = floatv(min_value=0.0, max_value=1.0) + self.assertIs(pg_typing.Object(Float).apply(v), v) + self.assertIs(pg_typing.Float().apply(v), v) + with self.assertRaisesRegex( + TypeError, r'Source spec Float\(\) is not compatible'): + pg_typing.Int().apply(v) + + with self.assertRaisesRegex( + ValueError, r'.* should be no less than the min value'): + pg_typing.Float(min_value=2.0).apply(v) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper/object_template.py b/pyglove/core/hyper/object_template.py new file mode 100644 index 0000000..f82e3e7 --- /dev/null +++ b/pyglove/core/hyper/object_template.py @@ -0,0 +1,577 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Object template using hyper primitives.""" + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from pyglove.core import geno +from pyglove.core import object_utils +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper import base +from pyglove.core.hyper import derived + + +class ObjectTemplate(base.HyperValue, object_utils.Formattable): + """Object template that encodes and decodes symbolic values. + + An object template can be created from a hyper value, which is a symbolic + object with some parts placeheld by hyper primitives. For example:: + + x = A(a=0, + b=pg.oneof(['foo', 'bar']), + c=pg.manyof(2, [1, 2, 3, 4, 5, 6]), + d=pg.floatv(0.1, 0.5), + e=pg.oneof([ + { + 'f': pg.oneof([True, False]), + } + { + 'g': pg.manyof(2, [B(), C(), D()], distinct=False), + 'h': pg.manyof(2, [0, 1, 2], sorted=True), + } + ]) + }) + t = pg.template(x) + + In this example, the root template have 4 children hyper primitives associated + with keys 'b', 'c', 'd' and 'e', while the hyper primitive 'e' have 3 children + associated with keys 'f', 'g' and 'h', creating a conditional search space. + + Thus the DNA shape is determined by the definition of template, described + by geno.DNASpec. In this case, the DNA spec of this template looks like:: + + pg.geno.space([ + pg.geno.oneof([ # Spec for 'b'. + pg.geno.constant(), # A constant template for 'foo'. + pg.geno.constant(), # A constant template for 'bar'. + ]), + pg.geno.manyof([ # Spec for 'c'. + pg.geno.constant(), # A constant template for 1. + pg.geno.constant(), # A constant template for 2. + pg.geno.constant(), # A constant template for 3. + pg.geno.constant(), # A constant template for 4. + pg.geno.constant(), # A constant template for 5. + pg.geno.constant(), # A constant template for 6. + ]), + pg.geno.floatv(0.1, 0.5), # Spec for 'd'. + pg.geno.oneof([ # Spec for 'e'. + pg.geno.space([ + pg.geno.oneof([ # Spec for 'f'. + pg.geno.constant(), # A constant template for True. + pg.geno.constant(), # A constant template for False. + ]) + ]), + pg.geno.space([ + pg.geno.manyof(2, [ # Spec for 'g'. + pg.geno.constant(), # A constant template for B(). + pg.geno.constant(), # A constant template for C(). + pg.geno.constant(), # A constant template for D(). + ], distinct=False) # choices of the same value can + # be selected multiple times. + pg.geno.manyof(2, [ # Spec for 'h'. + pg.geno.constant(), # A constant template for 0. + pg.geno.constant(), # A constant template for 1. + pg.geno.constant(), # A constant template for 2. + ], sorted=True) # acceptable choices needs to be sorted, + # which enables using choices as set (of + # possibly repeated values). + ]) + ]) + + It may generate DNA as the following: + DNA([0, [0, 2], 0.1, (0, 0)]) + + A template can also work only on a subset of hyper primitives from the input + value through the `where` function. This is useful to partition a search space + into parts for separate optimization. + + For example:: + + t = pg.hyper.ObjectTemplate( + A(a=pg.oneof([1, 2]), b=pg.oneof([3, 4])), + where=lambda e: e.root_path == 'a') + assert t.dna_spec() == pg.geno.space([ + pg.geno.oneof(location='a', candidates=[ + pg.geno.constant(), # For a=1 + pg.geno.constant(), # For a=2 + ], literal_values=['(0/2) 1', '(1/2) 2']) + ]) + assert t.decode(pg.DNA(0)) == A(a=1, b=pg.oneof([3, 4])) + """ + + def __init__(self, + value: Any, + compute_derived: bool = False, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None): + """Constructor. + + Args: + value: Value (maybe) annotated with generators to use as template. + compute_derived: Whether to compute derived value at this level. + We only want to compute derived value at root level since reference path + may go out of scope of a non-root ObjectTemplate. + where: Function to filter hyper primitives. If None, all hyper primitives + from `value` will be included in the encoding/decoding process. + Otherwise only the hyper primitives on which 'where' returns True will + be included. `where` can be useful to partition a search space into + separate optimization processes. + Please see 'ObjectTemplate' docstr for details. + """ + super().__init__() + self._value = value + self._root_path = object_utils.KeyPath() + self._compute_derived = compute_derived + self._where = where + self._parse_generators() + + @property + def root_path(self) -> object_utils.KeyPath: + """Returns root path.""" + return self._root_path + + @root_path.setter + def root_path(self, path: object_utils.KeyPath): + """Set root path.""" + self._root_path = path + + def _parse_generators(self) -> None: + """Parse generators from its templated value.""" + hyper_primitives = [] + def _extract_immediate_child_hyper_primitives( + path: object_utils.KeyPath, value: Any) -> bool: + """Extract top-level hyper primitives.""" + if (isinstance(value, base.HyperValue) + and (not self._where or self._where(value))): + # Apply where clause to child choices. + if (self._where + and isinstance(value, base.HyperPrimitive) + and hasattr(value, 'where')): + value = value.clone().rebind(where=self._where) + hyper_primitives.append((path, value)) + elif isinstance(value, symbolic.Object): + for k, v in value.sym_items(): + object_utils.traverse( + v, _extract_immediate_child_hyper_primitives, + root_path=object_utils.KeyPath(k, path)) + return True + + object_utils.traverse( + self._value, _extract_immediate_child_hyper_primitives) + self._hyper_primitives = hyper_primitives + + @property + def value(self) -> Any: + """Returns templated value.""" + return self._value + + @property + def hyper_primitives(self) -> List[Tuple[str, base.HyperValue]]: + """Returns hyper primitives in tuple (relative path, hyper primitive).""" + return self._hyper_primitives + + @property + def is_constant(self) -> bool: + """Returns whether current template is constant value.""" + return not self._hyper_primitives + + def dna_spec( + self, location: Optional[object_utils.KeyPath] = None) -> geno.Space: + """Return DNA spec (geno.Space) from this template.""" + return geno.Space( + elements=[ + primitive.dna_spec(primitive_location) + for primitive_location, primitive in self._hyper_primitives + ], + location=location or object_utils.KeyPath()) + + def _decode(self) -> Any: + """Decode DNA into a value.""" + dna = self._dna + if not self._hyper_primitives and (dna.value is not None or dna.children): + raise ValueError( + object_utils.message_on_path( + f'Encountered extra DNA value to decode: {dna!r}', + self._root_path)) + + # Compute hyper primitive values first. + rebind_dict = {} + if len(self._hyper_primitives) == 1: + primitive_location, primitive = self._hyper_primitives[0] + rebind_dict[primitive_location.path] = primitive.decode(dna) + else: + if len(dna.children) != len(self._hyper_primitives): + raise ValueError( + object_utils.message_on_path( + f'The length of child values ({len(dna.children)}) is ' + f'different from the number of hyper primitives ' + f'({len(self._hyper_primitives)}) in ObjectTemplate. ' + f'DNA={dna!r}, ObjectTemplate={self!r}.', self._root_path)) + for i, (primitive_location, primitive) in enumerate( + self._hyper_primitives): + rebind_dict[primitive_location.path] = ( + primitive.decode(dna.children[i])) + + if rebind_dict: + if len(rebind_dict) == 1 and '' in rebind_dict: + # NOTE(daiyip): Special handle the case when the root value needs to be + # replaced. For example: `template(oneof([0, 1])).decode(geno.DNA(0))` + # should return 0 instead of rebinding the root `OneOf` object. + value = rebind_dict[''] + else: + # NOTE(daiyip): Instead of deep copying the whole object (with hyper + # primitives), we can cherry-pick only non-hyper parts. Unless we saw + # performance issues it's not worthy to optimize this. + value = symbolic.clone(self._value, deep=True) + value.rebind(rebind_dict) + copied = True + else: + assert self.is_constant + value = self._value + copied = False + + # Compute derived values if needed. + if self._compute_derived: + # TODO(daiyip): Currently derived value parsing is done at decode time, + # which can be optimized by moving to template creation time. + derived_values = [] + def _extract_derived_values( + path: object_utils.KeyPath, value: Any) -> bool: + """Extract top-level primitives.""" + if isinstance(value, derived.DerivedValue): + derived_values.append((path, value)) + elif isinstance(value, symbolic.Object): + for k, v in value.sym_items(): + object_utils.traverse( + v, _extract_derived_values, + root_path=object_utils.KeyPath(k, path)) + return True + object_utils.traverse(value, _extract_derived_values) + + if derived_values: + if not copied: + value = symbolic.clone(value, deep=True) + rebind_dict = {} + for path, derived_value in derived_values: + rebind_dict[path.path] = derived_value() + assert rebind_dict + value.rebind(rebind_dict) + return value + + def encode(self, value: Any) -> geno.DNA: + """Encode a value into a DNA. + + Example:: + + # DNA of a constant template: + template = pg.hyper.ObjectTemplate({'a': 0}) + assert template.encode({'a': 0}) == pg.DNA(None) + # Raises: Unmatched value between template and input. + template.encode({'a': 1}) + + # DNA of a template containing only one pg.oneof. + template = pg.hyper.ObjectTemplate({'a': pg.oneof([1, 2])}) + assert template.encode({'a': 1}) == pg.DNA(0) + + # DNA of a template containing only one pg.oneof. + template = pg.hyper.ObjectTemplate({'a': pg.floatv(0.1, 1.0)}) + assert template.encode({'a': 0.5}) == pg.DNA(0.5) + + Args: + value: Value to encode. + + Returns: + Encoded DNA. + + Raises: + ValueError if value cannot be encoded by this template. + """ + children = [] + def _encode(path: object_utils.KeyPath, + template_value: Any, + input_value: Any) -> Any: + """Encode input value according to template value.""" + if (pg_typing.MISSING_VALUE == input_value + and pg_typing.MISSING_VALUE != template_value): + raise ValueError( + f'Value is missing from input. Path=\'{path}\'.') + if (isinstance(template_value, base.HyperValue) + and (not self._where or self._where(template_value))): + children.append(template_value.encode(input_value)) + elif isinstance(template_value, derived.DerivedValue): + if self._compute_derived: + referenced_values = [ + reference_path.query(value) + for _, reference_path in template_value.resolve() + ] + derived_value = template_value.derive(*referenced_values) + if derived_value != input_value: + raise ValueError( + f'Unmatched derived value between template and input. ' + f'(Path=\'{path}\', Template={template_value!r}, ' + f'ComputedValue={derived_value!r}, Input={input_value!r})') + # For template that doesn't compute derived value, it get passed over + # to parent template who may be able to handle. + elif isinstance(template_value, symbolic.Object): + if type(input_value) is not type(template_value): + raise ValueError( + f'Unmatched Object type between template and input: ' + f'(Path=\'{path}\', Template={template_value!r}, ' + f'Input={input_value!r})') + template_keys = set(template_value.sym_keys()) + value_keys = set(input_value.sym_keys()) + if template_keys != value_keys: + raise ValueError( + f'Unmatched Object keys between template value and input ' + f'value. (Path=\'{path}\', ' + f'TemplateOnlyKeys={template_keys - value_keys}, ' + f'InputOnlyKeys={value_keys - template_keys})') + for key in template_value.sym_keys(): + object_utils.merge_tree( + template_value.sym_getattr(key), + input_value.sym_getattr(key), + _encode, root_path=object_utils.KeyPath(key, path)) + elif isinstance(template_value, symbolic.Dict): + # Do nothing since merge will iterate all elements in dict and list. + if not isinstance(input_value, dict): + raise ValueError( + f'Unmatched dict between template value and input ' + f'value. (Path=\'{path}\', Template={template_value!r}, ' + f'Input={input_value!r})') + elif isinstance(template_value, symbolic.List): + if (not isinstance(input_value, list) + or len(input_value) != len(template_value)): + raise ValueError( + f'Unmatched list between template value and input ' + f'value. (Path=\'{path}\', Template={template_value!r}, ' + f'Input={input_value!r})') + for i, template_item in enumerate(template_value): + object_utils.merge_tree( + template_item, input_value[i], _encode, + root_path=object_utils.KeyPath(i, path)) + else: + if template_value != input_value: + raise ValueError( + f'Unmatched value between template and input. ' + f'(Path=\'{path}\', ' + f'Template={object_utils.quote_if_str(template_value)}, ' + f'Input={object_utils.quote_if_str(input_value)})') + return template_value + object_utils.merge_tree( + self._value, value, _encode, root_path=self._root_path) + return geno.DNA(None, children) + + def try_encode(self, value: Any) -> Tuple[bool, geno.DNA]: + """Try to encode a value without raise Exception.""" + try: + dna = self.encode(value) + return (True, dna) + except ValueError: + return (False, None) # pytype: disable=bad-return-type + except KeyError: + return (False, None) # pytype: disable=bad-return-type + + def __eq__(self, other): + """Operator ==.""" + if not isinstance(other, self.__class__): + return False + return self.value == other.value + + def __ne__(self, other): + """Operator !=.""" + return not self.__eq__(other) + + def format(self, + compact: bool = False, + verbose: bool = True, + root_indent: int = 0, + **kwargs) -> str: + """Format this object.""" + details = object_utils.format( + self._value, compact, verbose, root_indent, **kwargs) + return f'{self.__class__.__name__}(value={details})' + + def custom_apply( + self, + path: object_utils.KeyPath, + value_spec: pg_typing.ValueSpec, + allow_partial: bool, + child_transform: Optional[Callable[ + [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + ) -> Tuple[bool, 'ObjectTemplate']: + """Validate candidates during value_spec binding time.""" + # Check if value_spec directly accepts `self`. + if not value_spec.value_type or not isinstance(self, value_spec.value_type): + value_spec.apply( + self._value, + allow_partial, + root_path=self.root_path) + return (False, self) + + +def template( + value: Any, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None + ) -> ObjectTemplate: + """Creates an object template from the input. + + Example:: + + d = pg.Dict(x=pg.oneof(['a', 'b', 'c'], y=pg.manyof(2, range(4)))) + t = pg.template(d) + + assert t.dna_spec() == pg.geno.space([ + pg.geno.oneof([ + pg.geno.constant(), + pg.geno.constant(), + pg.geno.constant(), + ], location='x'), + pg.geno.manyof([ + pg.geno.constant(), + pg.geno.constant(), + pg.geno.constant(), + pg.geno.constant(), + ], location='y') + ]) + + assert t.encode(pg.Dict(x='a', y=0)) == pg.DNA([0, 0]) + assert t.decode(pg.DNA([0, 0])) == pg.Dict(x='a', y=0) + + t = pg.template(d, where=lambda x: isinstance(x, pg.hyper.ManyOf)) + assert t.dna_spec() == pg.geno.space([ + pg.geno.manyof([ + pg.geno.constant(), + pg.geno.constant(), + pg.geno.constant(), + pg.geno.constant(), + ], location='y') + ]) + assert t.encode(pg.Dict(x=pg.oneof(['a', 'b', 'c']), y=0)) == pg.DNA(0) + assert t.decode(pg.DNA(0)) == pg.Dict(x=pg.oneof(['a', 'b', 'c']), y=0) + + Args: + value: A value based on which the template is created. + where: Function to filter hyper values. If None, all hyper primitives from + `value` will be included in the encoding/decoding process. Otherwise + only the hyper values on which 'where' returns True will be included. + `where` can be useful to partition a search space into separate + optimization processes. Please see 'ObjectTemplate' docstr for details. + + Returns: + A template object. + """ + return ObjectTemplate(value, compute_derived=True, where=where) + + +def dna_spec( + value: Any, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None + ) -> geno.DNASpec: + """Returns the DNASpec from a (maybe) hyper value. + + Example:: + + hyper = pg.Dict(x=pg.oneof([1, 2, 3]), y=pg.oneof(['a', 'b'])) + spec = pg.dna_spec(hyper) + + assert spec.space_size == 6 + assert len(spec.decision_points) == 2 + print(spec.decision_points) + + # Select a partial space with `where` argument. + spec = pg.dna_spec(hyper, where=lambda x: len(x.candidates) == 2) + + assert spec.space_size == 2 + assert len(spec.decision_points) == 1 + + See also: + + * :class:`pyglove.DNASpec` + * :class:`pyglove.DNA` + + Args: + value: A (maybe) hyper value. + where: Function to filter hyper primitives. If None, all hyper primitives + from `value` will be included in the encoding/decoding process. Otherwise + only the hyper primitives on which 'where' returns True will be included. + `where` can be very useful to partition a search space into separate + optimization processes. Please see 'Template' docstr for details. + + Returns: + A DNASpec object, which represents the search space from algorithm's view. + """ + return template(value, where).dna_spec() + + +def materialize( + value: Any, + parameters: Union[geno.DNA, Dict[str, Any]], + use_literal_values: bool = True, + where: Optional[Callable[[base.HyperPrimitive], bool]] = None) -> Any: + """Materialize a (maybe) hyper value using a DNA or parameter dict. + + Example:: + + hyper_dict = pg.Dict(x=pg.oneof(['a', 'b']), y=pg.floatv(0.0, 1.0)) + + # Materialize using DNA. + assert pg.materialize( + hyper_dict, pg.DNA([0, 0.5])) == pg.Dict(x='a', y=0.5) + + # Materialize usign key value pairs. + # See `pg.DNA.from_dict` for more details. + assert pg.materialize( + hyper_dict, {'x': 0, 'y': 0.5}) == pg.Dict(x='a', y=0.5) + + # Partially materialize. + v = pg.materialize( + hyper_dict, pg.DNA(0), where=lambda x: isinstance(x, pg.hyper.OneOf)) + assert v == pg.Dict(x='a', y=pg.floatv(0.0, 1.0)) + + Args: + value: A (maybe) hyper value + parameters: A DNA object or a dict of string (key path) to a + string (in format of '/' for + `geno.Choices`, or '' for `geno.Float`), or their literal + values when `use_literal_values` is set to True. + use_literal_values: Applicable when `parameters` is a dict. If True, the + values in the dict will be from `geno.Choices.literal_values` for + `geno.Choices`. + where: Function to filter hyper primitives. If None, all hyper primitives + from `value` will be included in the encoding/decoding process. Otherwise + only the hyper primitives on which 'where' returns True will be included. + `where` can be useful to partition a search space into separate + optimization processes. Please see 'Template' docstr for details. + + Returns: + A materialized value. + + Raises: + TypeError: if parameters is not a DNA or dict. + ValueError: if parameters cannot be decoded. + """ + t = template(value, where) + if isinstance(parameters, dict): + dna = geno.DNA.from_parameters( + parameters=parameters, + dna_spec=t.dna_spec(), + use_literal_values=use_literal_values) + else: + dna = parameters + + if not isinstance(dna, geno.DNA): + raise TypeError( + f'\'parameters\' must be a DNA or a dict of string to DNA values. ' + f'Encountered: {dna!r}.') + return t.decode(dna) diff --git a/pyglove/core/hyper/object_template_test.py b/pyglove/core/hyper/object_template_test.py new file mode 100644 index 0000000..fe3347a --- /dev/null +++ b/pyglove/core/hyper/object_template_test.py @@ -0,0 +1,269 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for pyglove.hyper.ObjectTemplate.""" + +import unittest + +from pyglove.core import geno +from pyglove.core import symbolic +from pyglove.core import typing as pg_typing +from pyglove.core.hyper.categorical import oneof +from pyglove.core.hyper.derived import ValueReference +from pyglove.core.hyper.numerical import floatv +from pyglove.core.hyper.object_template import dna_spec +from pyglove.core.hyper.object_template import materialize +from pyglove.core.hyper.object_template import ObjectTemplate +from pyglove.core.hyper.object_template import template + + +class ObjectTemplateTest(unittest.TestCase): + """Tests for pg.hyper.ObjectTemplate.""" + + def test_constant_template(self): + + @symbolic.members([('x', pg_typing.Int())]) + class A(symbolic.Object): + pass + + t = ObjectTemplate({'a': A(x=1)}) + self.assertEqual(t.value, {'a': A(x=1)}) + self.assertEqual(len(t.hyper_primitives), 0) + self.assertTrue(t.is_constant) + self.assertTrue(symbolic.eq(t.dna_spec(), geno.Space(elements=[]))) + self.assertEqual(t.root_path, '') + self.assertEqual(t.decode(geno.DNA(None)), {'a': A(x=1)}) + with self.assertRaisesRegex( + ValueError, 'Encountered extra DNA value to decode'): + t.decode(geno.DNA(0)) + self.assertEqual(t.encode({'a': A(x=1)}), geno.DNA(None)) + with self.assertRaisesRegex( + ValueError, 'Unmatched Object type between template and input'): + t.encode({'a': 1}) + + def test_simple_template(self): + v = symbolic.Dict({ + 'a': oneof(candidates=[0, 2.5]), + 'b': floatv(min_value=0.0, max_value=1.0) + }) + t = ObjectTemplate(v) + self.assertEqual(t.value, v) + self.assertFalse(t.is_constant) + self.assertEqual(len(t.hyper_primitives), 2) + self.assertTrue(symbolic.eq( + t.dna_spec(), + geno.Space(elements=[ + geno.Choices( + location='a', + num_choices=1, + candidates=[geno.constant(), geno.constant()], + literal_values=[0, 2.5]), + geno.Float(location='b', min_value=0.0, max_value=1.0) + ]))) + + # Test decode. + self.assertEqual(t.decode(geno.DNA.parse([0, 0.5])), {'a': 0, 'b': 0.5}) + self.assertEqual(t.decode(geno.DNA.parse([1, 0.3])), {'a': 2.5, 'b': 0.3}) + + with self.assertRaisesRegex(ValueError, 'Expect float value'): + t.decode(geno.DNA.parse([0, 0])) + + with self.assertRaisesRegex(ValueError, 'Expect integer for OneOf'): + t.decode(geno.DNA.parse([0.5, 0.0])) + + with self.assertRaisesRegex( + ValueError, + 'The length of child values .* is different from the number ' + 'of hyper primitives'): + t.decode(geno.DNA.parse([0])) + + # Test encode. + self.assertEqual(t.encode({'a': 0, 'b': 0.5}), geno.DNA.parse([0, 0.5])) + + with self.assertRaisesRegex( + ValueError, + 'Cannot encode value: no candidates matches with the value'): + t.encode({'a': 5, 'b': 0.5}) + + # Test set_dna. + dna = geno.DNA.parse([0, 0.5]) + t.set_dna(dna) + + # Test __call__ + self.assertEqual(t(), {'a': 0, 'b': 0.5}) + + # Check after call, child DNA are properly set. + self.assertEqual(t.dna, dna) + self.assertEqual(t.hyper_primitives[0][1].dna, dna.children[0]) + self.assertEqual(t.hyper_primitives[1][1].dna, dna.children[1]) + + t.set_dna(None) + with self.assertRaisesRegex( + ValueError, '\'set_dna\' should be called to set a DNA'): + t() + + def test_template_with_where_clause(self): + @symbolic.functor() + def foo(a, b): + return a + b + + ssd = foo( + a=oneof([ + oneof([0, 1]), + 2 + ]), + b=oneof([3, 4])) + + # Test template that operates on all. + t = template(ssd) + self.assertEqual(t.decode(geno.DNA.parse([(0, 1), 0])), foo(a=1, b=3)) + self.assertEqual(t.encode(foo(a=0, b=4)), geno.DNA.parse([(0, 0), 1])) + + # Test template that operates on `foo.a`. + t = template(ssd, lambda v: v.sym_path != 'b') + self.assertEqual(t.decode(geno.DNA(1)), foo(a=2, b=oneof([3, 4]))) + self.assertEqual(t.decode(geno.DNA.parse((0, 0))), + foo(a=0, b=oneof([3, 4]))) + self.assertEqual(t.encode(foo(a=1, b=oneof([3, 4]))), + geno.DNA.parse((0, 1))) + + # Test template that operates on `foo.a.candidates[0]` (the nested oneof). + t = template(ssd, lambda v: len(v.sym_path) == 3) + self.assertEqual(t.decode(geno.DNA(1)), + foo(a=oneof([1, 2]), b=oneof([3, 4]))) + self.assertEqual(t.encode(foo(a=oneof([0, 2]), + b=oneof([3, 4]))), + geno.DNA(0)) + + # Test template that operates on `foo.b`. + t = template(ssd, lambda v: v.sym_path == 'b') + self.assertEqual(t.decode(geno.DNA(0)), + foo(a=oneof([oneof([0, 1]), 2]), b=3)) + + self.assertEqual(t.encode(foo(a=oneof([oneof([0, 1]), 2]), + b=4)), + geno.DNA(1)) + + def test_template_with_derived_value(self): + @symbolic.members([(pg_typing.StrKey(), pg_typing.Int())]) + class A(symbolic.Object): + pass + + v = symbolic.Dict({ + 'a': oneof(candidates=[0, 1]), + 'b': floatv(min_value=0.0, max_value=1.0), + 'c': ValueReference(['a']), + 'd': A(x=1) + }) + t = ObjectTemplate(v, compute_derived=True) + self.assertEqual(t.value, v) + self.assertFalse(t.is_constant) + self.assertEqual(len(t.hyper_primitives), 2) + self.assertTrue(symbolic.eq( + t.dna_spec(), + geno.Space(elements=[ + geno.Choices( + location='a', + num_choices=1, + candidates=[geno.constant(), geno.constant()], + literal_values=[0, 1]), + geno.Float(location='b', min_value=0.0, max_value=1.0) + ]))) + + # Test decode. + self.assertEqual( + t.decode(geno.DNA.parse([0, 0.5])), { + 'a': 0, + 'b': 0.5, + 'c': 0, + 'd': A(x=1) + }) + + # Test encode. + self.assertEqual( + t.encode({ + 'a': 0, + 'b': 0.5, + 'c': 0, + 'd': A(x=1) + }), geno.DNA.parse([0, 0.5])) + + with self.assertRaisesRegex( + ValueError, + 'Unmatched derived value between template and input.'): + t.encode({'a': 0, 'b': 0.5, 'c': 1, 'd': A(x=1)}) + + with self.assertRaisesRegex( + ValueError, + 'Unmatched Object keys between template value and input value'): + t.encode({'a': 0, 'b': 0.5, 'c': 0, 'd': A(y=1)}) + + def test_assignment_compatibility(self): + sd = symbolic.Dict.partial( + value_spec=pg_typing.Dict([ + ('a', pg_typing.Dict([ + ('x', pg_typing.Int()) + ])), + ('b', pg_typing.Int()) + ])) + sd.a = ObjectTemplate({'x': oneof(candidates=[1, 2, 3, 4])}) + sd.a = ObjectTemplate({'x': 1}) + with self.assertRaisesRegex(TypeError, 'Expect .* but encountered .*'): + sd.a = ObjectTemplate({'x': 'foo'}) + + def test_custom_apply(self): + t = template(symbolic.Dict()) + self.assertIs(pg_typing.Object(ObjectTemplate).apply(t), t) + self.assertIs(pg_typing.Dict().apply(t), t) + with self.assertRaisesRegex( + ValueError, 'Dict .* cannot be assigned to an incompatible field'): + pg_typing.Int().apply(t) + + +class ObjectTemplateHelperTests(unittest.TestCase): + """Tests for object template related helpers.""" + + def test_dna_spec(self): + self.assertTrue(symbolic.eq( + dna_spec(symbolic.Dict(a=oneof([0, 1]))), + geno.Space(elements=[ + geno.Choices(location='a', num_choices=1, candidates=[ + geno.constant(), + geno.constant() + ], literal_values=[0, 1]) + ]))) + + def test_materialize(self): + v = symbolic.Dict(a=oneof([1, 3])) + # Materialize using DNA. + self.assertEqual( + materialize(v, geno.DNA.parse([0])), + {'a': 1}) + # Materialize using parameter dict with use_literal_values set to False. + self.assertEqual( + materialize(v, {'a': '1/2'}, use_literal_values=False), + {'a': 3}) + # Materialize using parameter dict with use_literal_values set to True. + self.assertEqual( + materialize(v, {'a': '1/2 (3)'}, use_literal_values=True), + {'a': 3}) + + # Bad parameters. + with self.assertRaisesRegex( + TypeError, + '\'parameters\' must be a DNA or a dict of string to DNA values. '): + materialize(v, 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyglove/core/hyper_test.py b/pyglove/core/hyper_test.py deleted file mode 100644 index 2c34d85..0000000 --- a/pyglove/core/hyper_test.py +++ /dev/null @@ -1,1829 +0,0 @@ -# Copyright 2019 The PyGlove Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for pyglove.hyper.""" - -import random -import re -import threading -import unittest - -from pyglove.core import geno -from pyglove.core import hyper -from pyglove.core import object_utils -from pyglove.core import symbolic -from pyglove.core import typing as schema - - -symbolic.allow_empty_field_description() -symbolic.allow_repeated_class_registration() - - -class ObjectTemplateTest(unittest.TestCase): - """Test for hyper.ObjectTemplate.""" - - def testConstantTemplate(self): - """Test basics.""" - - @symbolic.members([('x', schema.Int())]) - class A(symbolic.Object): - pass - - t = hyper.ObjectTemplate({'a': A(x=1)}) - self.assertEqual(t.value, {'a': A(x=1)}) - self.assertEqual(len(t.hyper_primitives), 0) - self.assertTrue(t.is_constant) - self.assertTrue(symbolic.eq(t.dna_spec(), geno.Space(elements=[]))) - self.assertEqual(t.root_path, '') - self.assertEqual(t.decode(geno.DNA(None)), {'a': A(x=1)}) - with self.assertRaisesRegex( - ValueError, 'Encountered extra DNA value to decode'): - t.decode(geno.DNA(0)) - self.assertEqual(t.encode({'a': A(x=1)}), geno.DNA(None)) - with self.assertRaisesRegex( - ValueError, 'Unmatched Object type between template and input'): - t.encode({'a': 1}) - - def testSimpleTemplate(self): - """Test simple template.""" - v = symbolic.Dict({ - 'a': hyper.oneof(candidates=[0, 2.5]), - 'b': hyper.floatv(min_value=0.0, max_value=1.0) - }) - t = hyper.ObjectTemplate(v) - self.assertEqual(t.value, v) - self.assertFalse(t.is_constant) - self.assertEqual(len(t.hyper_primitives), 2) - self.assertTrue(symbolic.eq( - t.dna_spec(), - geno.Space(elements=[ - geno.Choices( - location='a', - num_choices=1, - candidates=[geno.constant(), geno.constant()], - literal_values=[0, 2.5]), - geno.Float(location='b', min_value=0.0, max_value=1.0) - ]))) - - # Test decode. - self.assertEqual(t.decode(geno.DNA.parse([0, 0.5])), {'a': 0, 'b': 0.5}) - self.assertEqual(t.decode(geno.DNA.parse([1, 0.3])), {'a': 2.5, 'b': 0.3}) - - with self.assertRaisesRegex(ValueError, 'Expect float value'): - t.decode(geno.DNA.parse([0, 0])) - - with self.assertRaisesRegex(ValueError, 'Expect integer for OneOf'): - t.decode(geno.DNA.parse([0.5, 0.0])) - - with self.assertRaisesRegex( - ValueError, - 'The length of child values .* is different from the number ' - 'of hyper primitives'): - t.decode(geno.DNA.parse([0])) - - # Test encode. - self.assertEqual(t.encode({'a': 0, 'b': 0.5}), geno.DNA.parse([0, 0.5])) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - t.encode({'a': 5, 'b': 0.5}) - - # Test set_dna. - dna = geno.DNA.parse([0, 0.5]) - t.set_dna(dna) - - # Test __call__ - self.assertEqual(t(), {'a': 0, 'b': 0.5}) - - # Check after call, child DNA are properly set. - self.assertEqual(t.dna, dna) - self.assertEqual(t.hyper_primitives[0][1].dna, dna.children[0]) - self.assertEqual(t.hyper_primitives[1][1].dna, dna.children[1]) - - t.set_dna(None) - with self.assertRaisesRegex( - ValueError, '\'set_dna\' should be called to set a DNA'): - t() - - def testWhere(self): - """Test template with where clause.""" - @symbolic.functor() - def foo(a, b): - return a + b - - ssd = foo( - a=hyper.oneof([ - hyper.oneof([0, 1]), - 2 - ]), - b=hyper.oneof([3, 4])) - - # Test template that operates on all. - t = hyper.template(ssd) - self.assertEqual(t.decode(geno.DNA.parse([(0, 1), 0])), foo(a=1, b=3)) - self.assertEqual(t.encode(foo(a=0, b=4)), geno.DNA.parse([(0, 0), 1])) - - # Test template that operates on `foo.a`. - t = hyper.template(ssd, lambda v: v.sym_path != 'b') - self.assertEqual(t.decode(geno.DNA(1)), foo(a=2, b=hyper.oneof([3, 4]))) - self.assertEqual(t.decode(geno.DNA.parse((0, 0))), - foo(a=0, b=hyper.oneof([3, 4]))) - self.assertEqual(t.encode(foo(a=1, b=hyper.oneof([3, 4]))), - geno.DNA.parse((0, 1))) - - # Test template that operates on `foo.a.candidates[0]` (the nested oneof). - t = hyper.template(ssd, lambda v: len(v.sym_path) == 3) - self.assertEqual(t.decode(geno.DNA(1)), - foo(a=hyper.oneof([1, 2]), b=hyper.oneof([3, 4]))) - self.assertEqual(t.encode(foo(a=hyper.oneof([0, 2]), - b=hyper.oneof([3, 4]))), - geno.DNA(0)) - - # Test template that operates on `foo.b`. - t = hyper.template(ssd, lambda v: v.sym_path == 'b') - self.assertEqual(t.decode(geno.DNA(0)), - foo(a=hyper.oneof([hyper.oneof([0, 1]), 2]), b=3)) - - self.assertEqual(t.encode(foo(a=hyper.oneof([hyper.oneof([0, 1]), 2]), - b=4)), - geno.DNA(1)) - - def testDerived(self): - """Test template with derived value.""" - - @symbolic.members([(schema.StrKey(), schema.Int())]) - class A(symbolic.Object): - pass - - v = symbolic.Dict({ - 'a': hyper.oneof(candidates=[0, 1]), - 'b': hyper.floatv(min_value=0.0, max_value=1.0), - 'c': hyper.ValueReference(['a']), - 'd': A(x=1) - }) - t = hyper.ObjectTemplate(v, compute_derived=True) - self.assertEqual(t.value, v) - self.assertFalse(t.is_constant) - self.assertEqual(len(t.hyper_primitives), 2) - self.assertTrue(symbolic.eq( - t.dna_spec(), - geno.Space(elements=[ - geno.Choices( - location='a', - num_choices=1, - candidates=[geno.constant(), geno.constant()], - literal_values=[0, 1]), - geno.Float(location='b', min_value=0.0, max_value=1.0) - ]))) - - # Test decode. - self.assertEqual( - t.decode(geno.DNA.parse([0, 0.5])), { - 'a': 0, - 'b': 0.5, - 'c': 0, - 'd': A(x=1) - }) - - # Test encode. - self.assertEqual( - t.encode({ - 'a': 0, - 'b': 0.5, - 'c': 0, - 'd': A(x=1) - }), geno.DNA.parse([0, 0.5])) - - with self.assertRaisesRegex( - ValueError, - 'Unmatched derived value between template and input.'): - t.encode({'a': 0, 'b': 0.5, 'c': 1, 'd': A(x=1)}) - - with self.assertRaisesRegex( - ValueError, - 'Unmatched Object keys between template value and input value'): - t.encode({'a': 0, 'b': 0.5, 'c': 0, 'd': A(y=1)}) - - def testDropInCompatibility(self): - """Test drop in compatibility.""" - sd = symbolic.Dict.partial( - value_spec=schema.Dict([('a', schema.Dict([( - 'x', schema.Int())])), ('b', schema.Int())])) - sd.a = hyper.ObjectTemplate({'x': hyper.oneof(candidates=[1, 2, 3, 4])}) - sd.a = hyper.ObjectTemplate({'x': 1}) - with self.assertRaisesRegex(TypeError, 'Expect .* but encountered .*'): - sd.a = hyper.ObjectTemplate({'x': 'foo'}) - - def testCustomApply(self): - """Test custom_apply to ValueSpec.""" - t = hyper.template(symbolic.Dict()) - self.assertIs(schema.Object(hyper.Template).apply(t), t) - self.assertIs(schema.Dict().apply(t), t) - with self.assertRaisesRegex( - ValueError, 'Dict .* cannot be assigned to an incompatible field'): - schema.Int().apply(t) - - -class ManyOfTest(unittest.TestCase): - """Test for hyper.ManyOf.""" - - def testBasics(self): - """Test basics of ManyOf.""" - with self.assertRaisesRegex( - ValueError, '.* candidates cannot produce .* distinct choices'): - hyper.manyof(3, [1, 2], distinct=True) - - def testDNASpec(self): - """Test ManyOf.dna_spec().""" - - # Test simple choice list without nested encoders. - self.assertTrue(symbolic.eq( - hyper.manyof( - 2, ['foo', 1, 2, 'bar'], sorted=True, distinct=True).dna_spec(), - geno.manyof(2, [ - geno.constant(), - geno.constant(), - geno.constant(), - geno.constant() - ], literal_values=[ - '\'foo\'', 1, 2, '\'bar\'' - ], sorted=True, distinct=True))) - - # Test complex choice list with nested encoders. - self.assertTrue(symbolic.eq( - hyper.oneof([ - 'foo', - { - 'a': hyper.floatv(min_value=0.0, max_value=1.0), - 'b': hyper.oneof(candidates=[1, 2, 3]), - }, - [hyper.floatv(min_value=1.0, max_value=2.0, scale='linear'), 1.0], - ]).dna_spec('a.b'), - geno.oneof([ - geno.constant(), - geno.space([ - geno.floatv(min_value=0.0, max_value=1.0, location='a'), - geno.oneof([ - geno.constant(), - geno.constant(), - geno.constant() - ], literal_values=[1, 2, 3], location='b') - ]), - geno.floatv(1.0, 2.0, scale='linear', location='[0]') - ], literal_values=[ - '\'foo\'', - ('{a=Float(min_value=0.0, max_value=1.0), ' - 'b=OneOf(candidates=[0: 1, 1: 2, 2: 3])}'), - '[0: Float(min_value=1.0, max_value=2.0, scale=\'linear\'), 1: 1.0]', - ], location='a.b'))) - - def testDecode(self): - """Test ManyOf.decode().""" - choice_list = hyper.manyof(2, [ - 'foo', 1, 2, 'bar' - ], choices_sorted=True, choices_distinct=True) - self.assertTrue(choice_list.is_leaf) - self.assertEqual(choice_list.decode(geno.DNA.parse([0, 1])), ['foo', 1]) - - with self.assertRaisesRegex( - ValueError, - 'Number of DNA child values does not match the number of choices'): - choice_list.decode(geno.DNA.parse([1, 0, 0])) - - with self.assertRaisesRegex(ValueError, 'Choice value should be int'): - choice_list.decode(geno.DNA.parse([0, 0.1])) - - with self.assertRaisesRegex(ValueError, 'Choice out of range'): - choice_list.decode(geno.DNA.parse([0, 5])) - - with self.assertRaisesRegex( - ValueError, 'DNA child values should be sorted'): - choice_list.decode(geno.DNA.parse([1, 0])) - - with self.assertRaisesRegex( - ValueError, 'DNA child values should be distinct'): - choice_list.decode(geno.DNA.parse([0, 0])) - - choice_list = hyper.manyof(1, [ - 'foo', - { - 'a': hyper.floatv(min_value=0.0, max_value=1.0), - 'b': hyper.oneof(candidates=[1, 2, 3]), - }, - [hyper.floatv(min_value=1.0, max_value=2.0), 1.0], - ]) - self.assertFalse(choice_list.is_leaf) - self.assertEqual(choice_list.decode(geno.DNA.parse(0)), ['foo']) - - self.assertEqual( - choice_list.decode(geno.DNA.parse((1, [0.5, 0]))), [{ - 'a': 0.5, - 'b': 1 - }]) - - self.assertEqual(choice_list.decode(geno.DNA.parse((2, 1.5))), [[1.5, 1.0]]) - - with self.assertRaisesRegex(ValueError, 'Choice out of range'): - choice_list.decode(geno.DNA.parse(5)) - - with self.assertRaisesRegex( - ValueError, 'Encountered extra DNA value to decode'): - choice_list.decode(geno.DNA.parse((0, 1))) - - with self.assertRaisesRegex( - ValueError, - 'The length of child values .* is different from the number ' - 'of hyper primitives'): - choice_list.decode(geno.DNA.parse((1, 0))) - - with self.assertRaisesRegex(ValueError, 'Expect float value'): - choice_list.decode(geno.DNA.parse((1, [1, 0]))) - - with self.assertRaisesRegex( - ValueError, - 'The length of child values .* is different from the number ' - 'of hyper primitives'): - choice_list.decode(geno.DNA.parse((1, [0.5, 1, 2]))) - - with self.assertRaisesRegex(ValueError, 'Expect float value'): - choice_list.decode(geno.DNA.parse(2)) - - with self.assertRaisesRegex( - ValueError, 'DNA value should be no greater than'): - choice_list.decode(geno.DNA.parse((2, 5.0))) - - def testEncode(self): - """Test ManyOf.encode().""" - choice_list = hyper.manyof(1, [ - 'foo', - { - 'a': hyper.floatv(min_value=0.0, max_value=1.0), - 'b': hyper.oneof(candidates=[1, 2, 3]), - }, - [hyper.floatv(min_value=1.0, max_value=2.0), 1.0], - ]) - self.assertEqual(choice_list.encode(['foo']), geno.DNA(0)) - self.assertEqual( - choice_list.encode([{ - 'a': 0.5, - 'b': 1 - }]), geno.DNA.parse((1, [0.5, 0]))) - self.assertEqual(choice_list.encode([[1.5, 1.0]]), geno.DNA.parse((2, 1.5))) - - with self.assertRaisesRegex( - ValueError, 'Cannot encode value: value should be a list type'): - choice_list.encode('bar') - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - choice_list.encode(['bar']) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - print(choice_list.encode([{'a': 0.5}])) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - choice_list.encode([{'a': 1.8, 'b': 1}]) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - choice_list.encode([[1.0]]) - - choice_list = hyper.manyof(2, ['a', 'b', 'c']) - self.assertEqual(choice_list.encode(['a', 'c']), geno.DNA.parse([0, 2])) - with self.assertRaisesRegex( - ValueError, - 'Length of input list is different from the number of choices'): - choice_list.encode(['a']) - - def testDropInCompatibility(self): - """Test drop-in type compatibility.""" - sd = symbolic.Dict.partial( - value_spec=schema.Dict([( - 'a', schema.Int()), ('b', schema.List(schema.Int( - ))), ('c', schema.List(schema.Union( - [schema.Str(), schema.Int()]))), ('d', schema.Any())])) - choice_list = hyper.manyof(2, [1, 'foo']) - sd.c = choice_list - sd.d = choice_list - - with self.assertRaisesRegex( - TypeError, 'Cannot bind an incompatible value spec Int\\(\\)'): - sd.a = choice_list - - with self.assertRaisesRegex( - TypeError, - 'Cannot bind an incompatible value spec List\\(Int\\(\\)\\)'): - sd.b = choice_list - - def testCustomApply(self): - """test custom_apply on value specs.""" - l = hyper.manyof(2, [1, 2, 3]) - self.assertIs(schema.Object(hyper.ManyOf).apply(l), l) - self.assertIs(schema.List(schema.Int()).apply(l), l) - with self.assertRaisesRegex( - TypeError, r'Cannot bind an incompatible value spec List\(Float\(\)\)'): - schema.List(schema.Float()).apply(l) - - class A: - pass - - class B: - pass - - t = hyper.oneof([B()]) - self.assertEqual( - schema.Union([schema.Object(A), schema.Object(B)]).apply(t), t) - - -class OneOfTest(unittest.TestCase): - """Tests for hyper.OneOf.""" - - def testDNASpec(self): - """Test OneOf.dna_spec().""" - - class C: - pass - - self.assertTrue(symbolic.eq( - hyper.oneof(candidates=[ - 'foo', - { - 'a': hyper.floatv(min_value=0.0, max_value=1.0), - 'b': hyper.oneof(candidates=[1, 2, 3]), - 'c': C() - }, - [hyper.floatv(min_value=1.0, max_value=2.0), 1.0], - ]).dna_spec('a.b'), - geno.Choices( - num_choices=1, - candidates=[ - geno.constant(), - geno.Space(elements=[ - geno.Float(min_value=0.0, max_value=1.0, location='a'), - geno.Choices( - num_choices=1, - candidates=[ - geno.constant(), - geno.constant(), - geno.constant() - ], - literal_values=[1, 2, 3], - location='b'), - ]), - geno.Space(elements=[ - geno.Float(min_value=1.0, max_value=2.0, location='[0]') - ]) - ], - literal_values=[ - '\'foo\'', - ('{a=Float(min_value=0.0, max_value=1.0), ' - 'b=OneOf(candidates=[0: 1, 1: 2, 2: 3]), ' - 'c=C(...)}'), - '[0: Float(min_value=1.0, max_value=2.0), 1: 1.0]', - ], - location='a.b'))) - - def testDecode(self): - """Test OneOf.decode().""" - choice_value = hyper.oneof(candidates=[ - 'foo', - { - 'a': hyper.floatv(min_value=0.0, max_value=1.0), - 'b': hyper.oneof(candidates=[1, 2, 3]), - }, - [hyper.floatv(min_value=1.0, max_value=2.0), 1.0], - ]) - - self.assertEqual(choice_value.decode(geno.DNA.parse(0)), 'foo') - - self.assertEqual( - choice_value.decode(geno.DNA.parse((1, [0.5, 0]))), { - 'a': 0.5, - 'b': 1 - }) - - self.assertEqual(choice_value.decode(geno.DNA.parse((2, 1.5))), [1.5, 1.0]) - - with self.assertRaisesRegex(ValueError, 'Choice out of range'): - choice_value.decode(geno.DNA.parse(5)) - - with self.assertRaisesRegex( - ValueError, 'Encountered extra DNA value to decode'): - choice_value.decode(geno.DNA.parse((0, 1))) - - with self.assertRaisesRegex( - ValueError, - 'The length of child values .* is different from the number ' - 'of hyper primitives'): - choice_value.decode(geno.DNA.parse((1, 0))) - - with self.assertRaisesRegex(ValueError, 'Expect float value'): - choice_value.decode(geno.DNA.parse((1, [1, 0]))) - - with self.assertRaisesRegex( - ValueError, - 'The length of child values .* is different from the number ' - 'of hyper primitives'): - choice_value.decode(geno.DNA.parse((1, [0.5, 1, 2]))) - - with self.assertRaisesRegex(ValueError, 'Expect float value'): - choice_value.decode(geno.DNA.parse(2)) - - with self.assertRaisesRegex( - ValueError, 'DNA value should be no greater than'): - choice_value.decode(geno.DNA.parse((2, 5.0))) - - def testEncode(self): - """Test OneOf.encode().""" - choice_value = hyper.oneof(candidates=[ - 'foo', - { - 'a': hyper.floatv(min_value=0.0, max_value=1.0), - 'b': hyper.oneof(candidates=[1, 2, 3]), - }, - [hyper.floatv(min_value=1.0, max_value=2.0), 1.0], - ]) - self.assertEqual(choice_value.encode('foo'), geno.DNA(0)) - self.assertEqual( - choice_value.encode({ - 'a': 0.5, - 'b': 1 - }), geno.DNA.parse((1, [0.5, 0]))) - self.assertEqual(choice_value.encode([1.5, 1.0]), geno.DNA.parse((2, 1.5))) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - choice_value.encode(['bar']) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - print(choice_value.encode({'a': 0.5})) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - choice_value.encode({'a': 1.8, 'b': 1}) - - with self.assertRaisesRegex( - ValueError, - 'Cannot encode value: no candidates matches with the value'): - choice_value.encode([1.0]) - - def testDropInCompatibility(self): - """Test drop-in type compatibility.""" - sd = symbolic.Dict.partial( - value_spec=schema.Dict([('a', schema.Str()), ( - 'b', schema.Int()), ( - 'c', - schema.Union([schema.Str(), schema.Int()])), ('d', - schema.Any())])) - choice_value = hyper.oneof(candidates=[1, 'foo']) - sd.c = choice_value - sd.d = choice_value - - with self.assertRaisesRegex( - TypeError, 'Cannot bind an incompatible value spec'): - sd.a = choice_value - - with self.assertRaisesRegex( - TypeError, 'Cannot bind an incompatible value spec'): - sd.b = choice_value - - def testCustomApply(self): - """test custom_apply on value specs.""" - o = hyper.oneof([1, 2]) - self.assertIs(schema.Object(hyper.OneOf).apply(o), o) - self.assertIs(schema.Int().apply(o), o) - with self.assertRaisesRegex( - TypeError, r'Cannot bind an incompatible value spec Float\(\)'): - schema.Float().apply(o) - - -class FloatTest(unittest.TestCase): - """Test for hyper.Float.""" - - def setUp(self): - """Setup test.""" - super().setUp() - self._float = hyper.floatv(min_value=0.0, max_value=1.0) - - def testBasics(self): - """Test Float basics.""" - self.assertEqual(self._float.min_value, 0.0) - self.assertEqual(self._float.max_value, 1.0) - self.assertIsNone(self._float.scale) - self.assertTrue(self._float.is_leaf) - - with self.assertRaisesRegex( - ValueError, '\'min_value\' .* is greater than \'max_value\' .*'): - hyper.floatv(min_value=1.0, max_value=0.0) - - def testScale(self): - self.assertEqual(hyper.floatv(-1.0, 1.0, 'linear').scale, 'linear') - with self.assertRaisesRegex( - ValueError, '\'min_value\' must be positive'): - hyper.floatv(-1.0, 1.0, 'log') - - def testDNASpec(self): - """Test Float.dna_spec().""" - self.assertTrue(symbolic.eq( - self._float.dna_spec('a'), - geno.Float( - location=object_utils.KeyPath('a'), - min_value=self._float.min_value, - max_value=self._float.max_value))) - - def testDecode(self): - """Test Float.decode().""" - self.assertEqual(self._float.decode(geno.DNA(0.0)), 0.0) - self.assertEqual(self._float.decode(geno.DNA(1.0)), 1.0) - - with self.assertRaisesRegex(ValueError, 'Expect float value'): - self._float.decode(geno.DNA(1)) - - with self.assertRaisesRegex( - ValueError, 'DNA value should be no less than'): - self._float.decode(geno.DNA(-1.0)) - - with self.assertRaisesRegex( - ValueError, 'DNA value should be no greater than'): - self._float.decode(geno.DNA(2.0)) - - def testEncode(self): - """Test Float.encode().""" - self.assertEqual(self._float.encode(0.0), geno.DNA(0.0)) - self.assertEqual(self._float.encode(1.0), geno.DNA(1.0)) - - with self.assertRaisesRegex( - ValueError, 'Value should be float to be encoded'): - self._float.encode('abc') - - with self.assertRaisesRegex( - ValueError, 'Value should be no less than'): - self._float.encode(-1.0) - - with self.assertRaisesRegex( - ValueError, 'Value should be no greater than'): - self._float.encode(2.0) - - def testDropInCompatibility(self): - """Test drop-in type compatibility.""" - sd = symbolic.Dict.partial( - value_spec=schema.Dict([('a', schema.Int()), ('b', schema.Float( - )), ('c', - schema.Union([schema.Str(), schema.Float()])), ( - 'd', schema.Any()), ('e', schema.Float( - max_value=0.0)), ('f', schema.Float(min_value=1.0))])) - float_value = hyper.floatv(min_value=0.0, max_value=1.0) - sd.b = float_value - sd.c = float_value - sd.d = float_value - - self.assertEqual(sd.b.sym_path, 'b') - self.assertEqual(sd.c.sym_path, 'c') - self.assertEqual(sd.d.sym_path, 'd') - with self.assertRaisesRegex( - TypeError, 'Source spec Float\\(\\) is not compatible with ' - 'destination spec Int\\(\\)'): - sd.a = float_value - - with self.assertRaisesRegex( - ValueError, - 'Float.max_value .* should be no greater than the max value'): - sd.e = float_value - - with self.assertRaisesRegex( - ValueError, - 'Float.min_value .* should be no less than the min value'): - sd.f = float_value - - def testCustomApply(self): - """test custom_apply on value specs.""" - f = hyper.float_value(min_value=0.0, max_value=1.0) - self.assertIs(schema.Object(hyper.Float).apply(f), f) - self.assertIs(schema.Float().apply(f), f) - with self.assertRaisesRegex( - TypeError, r'Source spec Float\(\) is not compatible'): - schema.Int().apply(f) - - with self.assertRaisesRegex( - ValueError, r'.* should be no less than the min value'): - schema.Float(min_value=2.0).apply(f) - - -class CustomHyperTest(unittest.TestCase): - """Test for hyper.CustomHyper.""" - - def setUp(self): - """Setup test.""" - super().setUp() - - class IntSequence(hyper.CustomHyper): - - def _create_dna(self, numbers): - return geno.DNA(','.join([str(n) for n in numbers])) - - def custom_decode(self, dna): - return [int(v) for v in dna.value.split(',')] - - class IntSequenceWithEncode(IntSequence): - - def custom_encode(self, value): - return geno.DNA(','.join([str(v) for v in value])) - - def next_dna(self, dna): - if dna is None: - return geno.DNA(','.join([str(i) for i in range(5)])) - v = self.custom_decode(dna) - v.append(len(v)) - return self._create_dna(v) - - def random_dna(self, random_generator, previous_dna): - del previous_dna - k = random_generator.randint(0, 10) - v = random_generator.choices(list(range(10)), k=k) - return self._create_dna(v) - - self._int_sequence = IntSequence(hints='1,2,-3,4,5,-2,7') - self._int_sequence_with_encode = IntSequenceWithEncode( - hints='1,2,-3,4,5,-2,7') - - def testDNASpec(self): - """Test CustomHyper.dna_spec().""" - self.assertTrue(symbolic.eq( - self._int_sequence.dna_spec('a'), - geno.CustomDecisionPoint( - hyper_type='IntSequence', - location=object_utils.KeyPath('a'), - hints='1,2,-3,4,5,-2,7'))) - - def testDecode(self): - """Test CustomHyper.decode().""" - self.assertEqual( - self._int_sequence.decode(geno.DNA('0,1,2')), [0, 1, 2]) - self.assertEqual( - self._int_sequence.decode(geno.DNA('0')), [0]) - with self.assertRaisesRegex( - ValueError, '.* expects string type DNA'): - self._int_sequence.decode(geno.DNA(1)) - - def testEncode(self): - """Test CustomHyper.encode().""" - self.assertEqual( - self._int_sequence_with_encode.encode([0, 1, 2]), - geno.DNA('0,1,2')) - - with self.assertRaisesRegex( - NotImplementedError, '\'custom_encode\' is not supported by'): - _ = self._int_sequence.encode([0, 1, 2]) - - def testRandomDNA(self): - """Test working with pg.random_dna.""" - self.assertEqual( - geno.random_dna( - self._int_sequence_with_encode.dna_spec('a'), random.Random(1)), - geno.DNA('5,8')) - - with self.assertRaisesRegex( - NotImplementedError, '`random_dna` is not implemented in .*'): - geno.random_dna(self._int_sequence.dna_spec('a')) - - def testIter(self): - """Test working with pg.iter.""" - self.assertEqual( - self._int_sequence_with_encode.first_dna(), - geno.DNA('0,1,2,3,4')) - self.assertEqual( - list(hyper.iterate(self._int_sequence_with_encode, 3)), - [[0, 1, 2, 3, 4], - [0, 1, 2, 3, 4, 5], - [0, 1, 2, 3, 4, 5, 6]]) - - with self.assertRaisesRegex( - NotImplementedError, '`next_dna` is not implemented in .*'): - next(hyper.iterate(self._int_sequence)) - - def testCooperation(self): - """Test cooperation with pg.oneof.""" - hv = hyper.oneof([ - self._int_sequence, - 1, - 2 - ]) - self.assertEqual(hyper.materialize(hv, geno.DNA(1)), 1) - self.assertEqual(hyper.materialize(hv, geno.DNA((0, '3,4'))), [3, 4]) - - -# -# Classes used for evolvable tests. -# - - -class Layer(symbolic.Object): - pass - - -@symbolic.members([ - ('layers', schema.List(schema.Object(Layer))), -]) -class Sequential(Layer): - pass - - -class Activation(Layer): - pass - - -class ReLU(Activation): - pass - - -class Swish(Activation): - pass - - -@symbolic.members([ - ('filters', schema.Int(min_value=1)), - # `kernel_size` is marked as no_mutation, which should not appear as a - # mutation candidate. - ('kernel_size', schema.Int(min_value=1), '', {'no_mutation': True}), - ('activation', schema.Object(Activation).noneable()) -]) -class Conv(Layer): - pass - - -class EvolvableTest(unittest.TestCase): - """Tests for hyper.evolvable.""" - - def setUp(self): - super().setUp() - self._seed_program = Sequential([ - Conv(16, 3, ReLU()), - Conv(32, 5, Swish()), - Sequential([ - Conv(64, 7) - ]) - ]) - def mutate_at_location( - mutation_type: hyper.MutationType, location: str): - def _weights(mt, k, v, p): - del v, p - if mt == mutation_type and re.match(location, str(k)): - return 1.0 - return 0.0 - return _weights - self._mutate_at_location = mutate_at_location - - def testBasics(self): - v = hyper.evolve( - self._seed_program, lambda k, v, p: ReLU(), - weights=self._mutate_at_location( - hyper.MutationType.REPLACE, r'^layers\[.*\]$')) - self.assertEqual( - self._seed_program, - v.custom_decode(v.custom_encode(self._seed_program))) - self.assertEqual( - v.first_dna(), - v.custom_encode(self._seed_program)) - self.assertEqual(v.random_dna(), v.custom_encode(self._seed_program)) - self.assertEqual( - v.random_dna(random.Random(1), v.first_dna()), - v.custom_encode( - Sequential([ - ReLU(), - Conv(32, 5, Swish()), - Sequential([ - Conv(64, 7) - ]) - ]))) - - def testReplace(self): - v = hyper.evolve( - self._seed_program, lambda k, v, p: ReLU(), - weights=self._mutate_at_location( - hyper.MutationType.REPLACE, r'^layers\[1\]$')) - self.assertEqual( - v.mutate(self._seed_program), - Sequential([ - Conv(16, 3, ReLU()), - ReLU(), - Sequential([ - Conv(64, 7) - ]) - ])) - - def testInsertion(self): - v = hyper.evolve( - self._seed_program, lambda k, v, p: ReLU(), - weights=self._mutate_at_location( - hyper.MutationType.INSERT, r'^layers\[1\]$')) - self.assertEqual( - v.mutate(self._seed_program), - Sequential([ - Conv(16, 3, ReLU()), - ReLU(), - Conv(32, 5, Swish()), - Sequential([ - Conv(64, 7) - ]) - ])) - - def testDelete(self): - v = hyper.evolve( - self._seed_program, lambda k, v, p: ReLU(), - weights=self._mutate_at_location( - hyper.MutationType.DELETE, r'^layers\[1\]$')) - self.assertEqual( - v.mutate(self._seed_program, random.Random(1)), - Sequential([ - Conv(16, 3, ReLU()), - Sequential([ - Conv(64, 7) - ]) - ])) - - def testRandomGenerator(self): - v = hyper.evolve( - self._seed_program, lambda k, v, p: ReLU(), - weights=self._mutate_at_location( - hyper.MutationType.REPLACE, r'^layers\[.*\]$')) - self.assertEqual( - v.mutate(self._seed_program, random_generator=random.Random(1)), - Sequential([ - ReLU(), - Conv(32, 5, Swish()), - Sequential([ - Conv(64, 7) - ]) - ])) - - def testMutationPointsAndWeights(self): - v = hyper.evolve( - self._seed_program, - lambda k, v, p: v, - weights=lambda *x: 1.0) - points, weights = v.mutation_points_and_weights(self._seed_program) - - # NOTE(daiyip): Conv.kernel_size is marked with 'no_mutation', thus - # it should not show here. - self.assertEqual([(p.mutation_type, p.location) for p in points], [ - (hyper.MutationType.REPLACE, 'layers'), - (hyper.MutationType.INSERT, 'layers[0]'), - (hyper.MutationType.DELETE, 'layers[0]'), - (hyper.MutationType.REPLACE, 'layers[0]'), - (hyper.MutationType.REPLACE, 'layers[0].filters'), - (hyper.MutationType.REPLACE, 'layers[0].activation'), - (hyper.MutationType.INSERT, 'layers[1]'), - (hyper.MutationType.DELETE, 'layers[1]'), - (hyper.MutationType.REPLACE, 'layers[1]'), - (hyper.MutationType.REPLACE, 'layers[1].filters'), - (hyper.MutationType.REPLACE, 'layers[1].activation'), - (hyper.MutationType.INSERT, 'layers[2]'), - (hyper.MutationType.DELETE, 'layers[2]'), - (hyper.MutationType.REPLACE, 'layers[2]'), - (hyper.MutationType.REPLACE, 'layers[2].layers'), - (hyper.MutationType.INSERT, 'layers[2].layers[0]'), - (hyper.MutationType.DELETE, 'layers[2].layers[0]'), - (hyper.MutationType.REPLACE, 'layers[2].layers[0]'), - (hyper.MutationType.REPLACE, 'layers[2].layers[0].filters'), - (hyper.MutationType.REPLACE, 'layers[2].layers[0].activation'), - (hyper.MutationType.INSERT, 'layers[2].layers[1]'), - (hyper.MutationType.INSERT, 'layers[3]'), - ]) - self.assertEqual(weights, [1.0] * len(points)) - - def testMutationPointsAndWeightsWithHonoringListSize(self): - # Non-typed list. There is no size limit. - v = hyper.evolve( - symbolic.List([]), lambda k, v, p: v, - weights=lambda *x: 1.0) - points, _ = v.mutation_points_and_weights(symbolic.List([1])) - self.assertEqual([(p.mutation_type, p.location) for p in points], [ - (hyper.MutationType.INSERT, '[0]'), - (hyper.MutationType.DELETE, '[0]'), - (hyper.MutationType.REPLACE, '[0]'), - (hyper.MutationType.INSERT, '[1]'), - ]) - - # Typed list with size limit. - value_spec = schema.List(schema.Int(), min_size=1, max_size=3) - points, _ = v.mutation_points_and_weights( - symbolic.List([1, 2], value_spec=value_spec)) - self.assertEqual([(p.mutation_type, p.location) for p in points], [ - (hyper.MutationType.INSERT, '[0]'), - (hyper.MutationType.DELETE, '[0]'), - (hyper.MutationType.REPLACE, '[0]'), - (hyper.MutationType.INSERT, '[1]'), - (hyper.MutationType.DELETE, '[1]'), - (hyper.MutationType.REPLACE, '[1]'), - (hyper.MutationType.INSERT, '[2]'), - ]) - points, _ = v.mutation_points_and_weights( - symbolic.List([1], value_spec=value_spec)) - self.assertEqual([(p.mutation_type, p.location) for p in points], [ - (hyper.MutationType.INSERT, '[0]'), - (hyper.MutationType.REPLACE, '[0]'), - (hyper.MutationType.INSERT, '[1]'), - ]) - points, _ = v.mutation_points_and_weights( - symbolic.List([1, 2, 3], value_spec=value_spec)) - self.assertEqual([(p.mutation_type, p.location) for p in points], [ - (hyper.MutationType.DELETE, '[0]'), - (hyper.MutationType.REPLACE, '[0]'), - (hyper.MutationType.DELETE, '[1]'), - (hyper.MutationType.REPLACE, '[1]'), - (hyper.MutationType.DELETE, '[2]'), - (hyper.MutationType.REPLACE, '[2]'), - ]) - - -class TunableValueHelpersTests(unittest.TestCase): - """Tests for helper methods on tunable values.""" - - def testDNASpec(self): - """Test hyper.dna_spec.""" - v = symbolic.Dict(a=hyper.oneof([0, 1])) - self.assertTrue(symbolic.eq( - hyper.dna_spec(v), - geno.Space(elements=[ - geno.Choices(location='a', num_choices=1, candidates=[ - geno.constant(), - geno.constant() - ], literal_values=[0, 1]) - ]))) - - def testMaterialize(self): - """Test hyper.materialize.""" - v = symbolic.Dict(a=hyper.oneof([1, 3])) - # Materialize using DNA. - self.assertEqual( - hyper.materialize(v, geno.DNA.parse([0])), - {'a': 1}) - # Materialize using parameter dict with use_literal_values set to False. - self.assertEqual( - hyper.materialize(v, {'a': '1/2'}, use_literal_values=False), - {'a': 3}) - # Materialize using parameter dict with use_literal_values set to True. - self.assertEqual( - hyper.materialize(v, {'a': '1/2 (3)'}, use_literal_values=True), - {'a': 3}) - - # Bad parameters. - with self.assertRaisesRegex( - TypeError, - '\'parameters\' must be a DNA or a dict of string to DNA values. '): - hyper.materialize(v, 1) - - def testIterate(self): - """Test hyper.iterate.""" - # Test iterate with default algorithm (Sweeping) - v = hyper.oneof(range(100)) - examples = list(hyper.iterate(v)) - self.assertEqual(examples, list(range(100))) - - examples = list(hyper.iterate(v, 10)) - self.assertEqual(examples, list(range(10))) - - class ConstantAlgorithm(geno.DNAGenerator): - """An algorithm that always emit a constant DNA.""" - - def _on_bound(self): - self._rewards = [] - - def _propose(self): - if len(self._rewards) == 100: - raise StopIteration() - return geno.DNA(0) - - def _feedback(self, dna, reward): - self._rewards.append(reward) - - @property - def rewards(self): - return self._rewards - - # Test iterate with a custom algorithm. - v = hyper.oneof([1, 3]) - algo = ConstantAlgorithm() - examples = [] - for i, (x, feedback) in enumerate(hyper.iterate(v, 5, algo)): - examples.append(x) - feedback(float(i)) - self.assertEqual(feedback.dna, geno.DNA(0)) - self.assertEqual(len(examples), 5) - self.assertEqual(examples, [1] * 5) - self.assertEqual(algo.rewards, [float(i) for i in range(5)]) - - for x, feedback in hyper.iterate(v, algorithm=algo): - examples.append(x) - feedback(0.) - self.assertEqual(len(examples), 100) - - # Test iterate with dynamic evaluation. - def foo(): - return hyper.oneof([1, 3]) - examples = [] - for x in hyper.iterate(hyper.trace(foo)): - with x(): - examples.append(foo()) - self.assertEqual(examples, [1, 3]) - - with self.assertRaisesRegex( - ValueError, '\'hyper_value\' is a constant value'): - next(hyper.iterate('foo', algo)) - - # Test iterate on DNAGenerator that generate a no-op feedback. - class ConstantAlgorithm2(geno.DNAGenerator): - """An algorithm that always emit a constant DNA.""" - - def propose(self): - return geno.DNA(0) - - algo = ConstantAlgorithm2() - examples = [] - for x, feedback in hyper.iterate( - v, 10, algorithm=algo, force_feedback=True): - examples.append(x) - # No op. - feedback(0.) - self.assertEqual(len(examples), 10) - - # Test iterate with continuation. - class ConstantAlgorithm3(geno.DNAGenerator): - """An algorithm that always emit a constant DNA.""" - - def setup(self, dna_spec): - super().setup(dna_spec) - self.num_trials = 0 - - def propose(self): - self.num_trials += 1 - return geno.DNA(0) - - algo = ConstantAlgorithm3() - for unused_x in hyper.iterate(v, 10, algo): - pass - for unused_x in hyper.iterate(v, 10, algo): - pass - self.assertEqual(algo.num_trials, 20) - with self.assertRaisesRegex( - ValueError, '.* has been set up with a different DNASpec'): - next(hyper.iterate(hyper.oneof([2, 3]), 10, algo)) - - def testRandomSample(self): - """Test hyper.random_sample.""" - self.assertEqual( - list(hyper.random_sample(hyper.one_of([0, 1]), 3, seed=123)), - [0, 1, 0]) - - -class DynamicEvaluationTest(unittest.TestCase): - """Dynamic evaluation test.""" - - def testDynamicEvaluate(self): - """Test dynamic_evaluate.""" - with self.assertRaisesRegex( - ValueError, - '\'evaluate_fn\' must be either None or a callable object'): - with hyper.dynamic_evaluate(1): - pass - - with self.assertRaisesRegex( - ValueError, - '\'exit_fn\' must be a callable object'): - with hyper.dynamic_evaluate(None, exit_fn=1): - pass - - def testDynamicEvaluatedValues(self): - """Test dynamically evaluated values.""" - with hyper.DynamicEvaluationContext().collect(): - self.assertEqual(hyper.oneof([0, 1]), 0) - self.assertEqual(hyper.oneof([{'x': hyper.oneof(['a', 'b'])}, 1]), - {'x': 'a'}) - self.assertEqual(hyper.manyof(2, [0, 1, 3]), [0, 1]) - self.assertEqual(hyper.manyof(4, [0, 1, 3], distinct=False), - [0, 0, 0, 0]) - self.assertEqual(hyper.permutate([0, 1, 2]), [0, 1, 2]) - self.assertEqual(hyper.floatv(0.0, 1.0), 0.0) - - def testDefineByRunPerThread(self): - """Test DynamicEvaluationContext per-thread.""" - def thread_fun(): - context = hyper.DynamicEvaluationContext() - with context.collect(): - hyper.oneof(range(10)) - - with context.apply([3]): - self.assertEqual(hyper.oneof(range(10)), 3) - - threads = [] - for _ in range(10): - thread = threading.Thread(target=thread_fun) - threads.append(thread) - thread.start() - for t in threads: - t.join() - - def testDefineByRunPerProcess(self): - """Test DynamicEvaluationContext per-process.""" - def thread_fun(): - _ = hyper.oneof(range(10)) - - context = hyper.DynamicEvaluationContext(per_thread=False) - with context.collect() as hyper_dict: - threads = [] - for _ in range(10): - thread = threading.Thread(target=thread_fun) - threads.append(thread) - thread.start() - for t in threads: - t.join() - - self.assertEqual(len(hyper_dict), 10) - - def testIndependentDecisions(self): - """Test the search space of independent decisions.""" - def fun(): - x = hyper.oneof([1, 2, 3]) + 1 - y = sum(hyper.manyof(2, [2, 4, 6, 8], name='y')) - z = hyper.floatv(min_value=1.0, max_value=2.0) - return x + y + z - - # Test dynamic evaluation by allowing reentry (all hyper primitives will - # be registered twice). - context = hyper.DynamicEvaluationContext() - with context.collect() as hyper_dict: - result = fun() - result = fun() - - # 1 + 1 + 2 + 4 + 1.0 - self.assertEqual(result, 9.0) - self.assertEqual(hyper_dict, { - 'decision_0': hyper.oneof([1, 2, 3]), - 'y': hyper.manyof(2, [2, 4, 6, 8], name='y'), - 'decision_1': hyper.floatv(min_value=1.0, max_value=2.0), - 'decision_2': hyper.oneof([1, 2, 3]), - 'decision_3': hyper.floatv(min_value=1.0, max_value=2.0), - }) - - with context.apply(geno.DNA.parse( - [1, [0, 2], 1.5, 0, 1.8])): - # 2 + 1 + 2 + 6 + 1.5 - self.assertEqual(fun(), 12.5) - # 1 + 1 + 2 + 6 + 1.8 - self.assertEqual(fun(), 11.8) - - def testIndependentDecisionsWithRequiringHyperName(self): - """Test independent decisions with requiring hyper primitive name.""" - def fun(): - x = hyper.oneof([1, 2, 3], name='a') + 1 - y = sum(hyper.manyof(2, [2, 4, 6, 8], name='b')) - z = hyper.floatv(min_value=1.0, max_value=2.0, name='c') - return x + y + z - - # Test dynamic evaluation by disallowing reentry (all hyper primitives will - # be registered only once). - context = hyper.DynamicEvaluationContext(require_hyper_name=True) - with context.collect() as hyper_dict: - with self.assertRaisesRegex( - ValueError, '\'name\' must be specified for hyper primitive'): - hyper.oneof([1, 2, 3]) - result = fun() - result = fun() - - # 1 + 1 + 2 + 4 + 1.0 - self.assertEqual(result, 9.0) - self.assertEqual(hyper_dict, symbolic.Dict( - a=hyper.oneof([1, 2, 3], name='a'), - b=hyper.manyof(2, [2, 4, 6, 8], name='b'), - c=hyper.floatv(min_value=1.0, max_value=2.0, name='c'))) - with context.apply(geno.DNA.parse([1, [0, 2], 1.5])): - # We can call fun multiple times since decision will be bound to each - # name just once. - # 2 + 1 + 2 + 6 + 1.5 - self.assertEqual(fun(), 12.5) - self.assertEqual(fun(), 12.5) - self.assertEqual(fun(), 12.5) - - def testHierarchicalDecisions(self): - """Test hierarchical search space.""" - def fun(): - return hyper.oneof([ - lambda: sum(hyper.manyof(2, [2, 4, 6, 8])), - lambda: hyper.oneof([3, 7]), - lambda: hyper.floatv(min_value=1.0, max_value=2.0), - 10]) + hyper.oneof([11, 22]) - - context = hyper.DynamicEvaluationContext() - with context.collect() as hyper_dict: - result = fun() - # 2 + 4 + 11 - self.assertEqual(result, 17) - self.assertEqual(hyper_dict, { - 'decision_0': hyper.oneof([ - # NOTE(daiyip): child decisions within candidates are always in - # form of list. - { - 'decision_1': hyper.manyof(2, [2, 4, 6, 8]), - }, - { - 'decision_2': hyper.oneof([3, 7]) - }, - { - 'decision_3': hyper.floatv(min_value=1.0, max_value=2.0) - }, - 10, - ]), - 'decision_4': hyper.oneof([11, 22]) - }) - - with context.apply(geno.DNA.parse([(0, [1, 3]), 0])): - # 4 + 8 + 11 - self.assertEqual(fun(), 23) - - # Use list-form decisions. - with context.apply([0, 1, 3, 0]): - # 4 + 8 + 11 - self.assertEqual(fun(), 23) - - with context.apply(geno.DNA.parse([(1, 1), 1])): - # 7 + 22 - self.assertEqual(fun(), 29) - - with context.apply(geno.DNA.parse([(2, 1.5), 0])): - # 1.5 + 11 - self.assertEqual(fun(), 12.5) - - with context.apply(geno.DNA.parse([3, 1])): - # 10 + 22 - self.assertEqual(fun(), 32) - - with self.assertRaisesRegex( - ValueError, '`decisions` should be a DNA or a list of numbers.'): - with context.apply(3): - fun() - - with self.assertRaisesRegex( - ValueError, 'No decision is provided for .*'): - with context.apply(geno.DNA.parse(3)): - fun() - - with self.assertRaisesRegex( - ValueError, 'Expect float-type decision for .*'): - with context.apply([2, 0, 1]): - fun() - - with self.assertRaisesRegex( - ValueError, 'Expect int-type decision in range .*'): - with context.apply([5, 0.5, 0]): - fun() - - with self.assertRaisesRegex( - ValueError, 'Found extra decision values that are not used.*'): - with context.apply(geno.DNA.parse([(1, 1), 1, 1])): - fun() - - def testHierarchicalDecisionsWithRequiringHyperName(self): - """Test hierarchical search space.""" - def fun(): - return hyper.oneof([ - lambda: sum(hyper.manyof(2, [2, 4, 6, 8], name='a1')), - lambda: hyper.oneof([3, 7], name='a2'), - lambda: hyper.floatv(min_value=1.0, max_value=2.0, name='a3.xx'), - 10], name='a') + hyper.oneof([11, 22], name='b') - - context = hyper.DynamicEvaluationContext(require_hyper_name=True) - with context.collect() as hyper_dict: - result = fun() - result = fun() - - # 2 + 4 + 11 - self.assertEqual(result, 17) - self.assertEqual(hyper_dict, { - 'a': hyper.oneof([ - # NOTE(daiyip): child decisions within candidates are always in - # form of list. - {'a1': hyper.manyof(2, [2, 4, 6, 8], name='a1')}, - {'a2': hyper.oneof([3, 7], name='a2')}, - {'a3.xx': hyper.floatv(min_value=1.0, max_value=2.0, name='a3.xx')}, - 10, - ], name='a'), - 'b': hyper.oneof([11, 22], name='b') - }) - - with context.apply(geno.DNA.parse([(0, [1, 3]), 0])): - # 4 + 8 + 11 - self.assertEqual(fun(), 23) - self.assertEqual(fun(), 23) - self.assertEqual(fun(), 23) - - # Use list form. - with context.apply([0, 1, 3, 0]): - # 4 + 8 + 11 - self.assertEqual(fun(), 23) - self.assertEqual(fun(), 23) - self.assertEqual(fun(), 23) - - with context.apply(geno.DNA.parse([(1, 1), 1])): - # 7 + 22 - self.assertEqual(fun(), 29) - self.assertEqual(fun(), 29) - - with context.apply(geno.DNA.parse([(2, 1.5), 0])): - # 1.5 + 11 - self.assertEqual(fun(), 12.5) - self.assertEqual(fun(), 12.5) - - with context.apply(geno.DNA.parse([3, 1])): - # 10 + 22 - self.assertEqual(fun(), 32) - self.assertEqual(fun(), 32) - - with self.assertRaisesRegex( - ValueError, '`decisions` should be a DNA or a list of numbers.'): - with context.apply(3): - fun() - - with self.assertRaisesRegex( - ValueError, 'DNA value type mismatch'): - with context.apply(geno.DNA.parse(3)): - fun() - - with self.assertRaisesRegex( - ValueError, 'Found extra decision values that are not used'): - with context.apply(context.dna_spec.first_dna()): - # Do not consume any decision points from the search space. - _ = 1 - - with self.assertRaisesRegex( - ValueError, - 'Hyper primitive .* is not defined during search space inspection'): - with context.apply(context.dna_spec.first_dna()): - # Do not consume any decision points from the search space. - _ = hyper.oneof(range(5), name='uknown') - - def testWhereStatement(self): - """Test `where`.""" - context = hyper.DynamicEvaluationContext( - where=lambda x: getattr(x, 'name') != 'x') - with context.collect(): - self.assertEqual(hyper.oneof(range(10)), 0) - self.assertIsInstance(hyper.oneof(range(5), name='x'), hyper.OneOf) - - with context.apply([1]): - self.assertEqual(hyper.oneof(range(10)), 1) - self.assertIsInstance(hyper.oneof(range(5), name='x'), hyper.OneOf) - - def testTrace(self): - """Test `trace`.""" - def fun(): - return hyper.oneof([-1, 0, 1]) * hyper.oneof([-1, 0, 3]) + 1 - - self.assertEqual( - hyper.trace(fun).hyper_dict, - { - 'decision_0': hyper.oneof([-1, 0, 1]), - 'decision_1': hyper.oneof([-1, 0, 3]) - }) - - def testCustomHyper(self): - """Test dynamic evaluation with custom hyper.""" - - class IntList(hyper.CustomHyper): - - def custom_decode(self, dna): - return [int(x) for x in dna.value.split(':')] - - def first_dna(self): - return geno.DNA('0:1:2:3') - - def fun(): - return sum(IntList()) + hyper.oneof([0, 1]) + hyper.floatv(-1., 1.) - - context = hyper.DynamicEvaluationContext() - with context.collect(): - fun() - - self.assertEqual( - context.hyper_dict, - { - 'decision_0': IntList(), - 'decision_1': hyper.oneof([0, 1]), - 'decision_2': hyper.floatv(-1., 1.) - }) - with context.apply(geno.DNA(['1:2:3:4', 1, 0.5])): - self.assertEqual(fun(), 1 + 2 + 3 + 4 + 1 + 0.5) - - with self.assertRaisesRegex( - ValueError, 'Expect string-type decision for .*'): - with context.apply(geno.DNA([0, 1, 0.5])): - fun() - - class IntListWithoutFirstDNA(hyper.CustomHyper): - - def custom_decode(self, dna): - return [int(x) for x in dna.value.split(':')] - - context = hyper.DynamicEvaluationContext() - with self.assertRaisesRegex( - NotImplementedError, - '.* must implement method `next_dna` to be used in ' - 'dynamic evaluation mode'): - with context.collect(): - IntListWithoutFirstDNA() - - def testExternalDNASpec(self): - """Test dynamic evalaution with external DNASpec.""" - - def fun(): - return hyper.oneof(range(5), name='x') + hyper.oneof(range(3), name='y') - - context = hyper.trace(fun, require_hyper_name=True, per_thread=True) - self.assertFalse(context.is_external) - self.assertIsNotNone(context.hyper_dict) - - search_space_str = symbolic.to_json_str(context.dna_spec) - - context2 = hyper.DynamicEvaluationContext( - require_hyper_name=True, per_thread=True, - dna_spec=symbolic.from_json_str(search_space_str)) - self.assertTrue(context2.is_external) - self.assertIsNone(context2.hyper_dict) - - with self.assertRaisesRegex( - ValueError, - '`collect` cannot be called .* is using an external DNASpec'): - with context2.collect(): - fun() - - with context2.apply(geno.DNA([1, 2])): - self.assertEqual(fun(), 3) - - def testNestedDynamicEvaluationSimple(self): - """Test nested dynamic evaluation context.""" - def fun(): - return sum([ - hyper.oneof([1, 2, 3], hints='ssd1'), - hyper.oneof([4, 5], hints='ssd2'), - ]) - - context1 = hyper.DynamicEvaluationContext( - where=lambda x: x.hints == 'ssd1') - context2 = hyper.DynamicEvaluationContext( - where=lambda x: x.hints == 'ssd2') - with context1.collect(): - with context2.collect(): - self.assertEqual(fun(), 1 + 4) - - self.assertEqual( - context1.hyper_dict, { - 'decision_0': hyper.oneof([1, 2, 3], hints='ssd1') - }) - self.assertEqual( - context2.hyper_dict, { - 'decision_0': hyper.oneof([4, 5], hints='ssd2') - }) - with context1.apply(geno.DNA(2)): - with context2.apply(geno.DNA(1)): - self.assertEqual(fun(), 3 + 5) - - def testNestedDynamicEvaluationWithRequiredHyperName(self): - """Test nested dynamic evaluation context with required hyper name.""" - def fun(): - return sum([ - hyper.oneof([1, 2, 3], name='x', hints='ssd1'), - hyper.oneof([4, 5], name='y', hints='ssd2'), - ]) - - context1 = hyper.DynamicEvaluationContext( - where=lambda x: x.hints == 'ssd1') - context2 = hyper.DynamicEvaluationContext( - where=lambda x: x.hints == 'ssd2') - with context1.collect(): - with context2.collect(): - self.assertEqual(fun(), 1 + 4) - - self.assertEqual( - context1.hyper_dict, { - 'x': hyper.oneof([1, 2, 3], name='x', hints='ssd1') - }) - self.assertEqual( - context2.hyper_dict, { - 'y': hyper.oneof([4, 5], name='y', hints='ssd2') - }) - with context1.apply(geno.DNA(2)): - with context2.apply(geno.DNA(1)): - self.assertEqual(fun(), 3 + 5) - - def testNestedSearchSpaceInNestedDynamicEvaluationContext(self): - """Test nested search space in nested dynamic evaluation context.""" - def fun(): - return sum([ - hyper.oneof([ - lambda: hyper.oneof([1, 2, 3], name='y', hints='ssd1'), - lambda: hyper.oneof([4, 5, 6], name='z', hints='ssd1'), - ], name='x', hints='ssd1'), - hyper.oneof([7, 8], name='p', hints='ssd2'), - hyper.oneof([9, 10], name='q', hints='ssd2'), - ]) - context1 = hyper.DynamicEvaluationContext( - where=lambda x: x.hints == 'ssd1') - context2 = hyper.DynamicEvaluationContext( - where=lambda x: x.hints == 'ssd2') - with context1.collect(): - with context2.collect(): - self.assertEqual(fun(), 1 + 7 + 9) - - self.assertEqual( - context1.hyper_dict, { - 'x': hyper.oneof([ - {'y': hyper.oneof([1, 2, 3], name='y', hints='ssd1')}, - {'z': hyper.oneof([4, 5, 6], name='z', hints='ssd1')}, - ], name='x', hints='ssd1') - }) - self.assertEqual( - context2.hyper_dict, { - 'p': hyper.oneof([7, 8], name='p', hints='ssd2'), - 'q': hyper.oneof([9, 10], name='q', hints='ssd2') - }) - with context1.apply(geno.DNA((1, 1))): - with context2.apply(geno.DNA([0, 1])): - self.assertEqual(fun(), 5 + 7 + 10) - - def testNestedDynamicEvaluationWithDifferentPerThreadSetting(self): - """Test nested dynamic evaluation context with different per-thread.""" - context1 = hyper.DynamicEvaluationContext(per_thread=True) - context2 = hyper.DynamicEvaluationContext(per_thread=False) - - def fun(): - return hyper.oneof([1, 2, 3]) - - with self.assertRaisesRegex( - ValueError, - 'Nested dynamic evaluation contexts must be either .*'): - with context1.collect(): - with context2.collect(): - fun() - - def testDynamicEvaluationWithManualRegistry(self): - """Test dynamic evaluation context with manual registration.""" - context = hyper.DynamicEvaluationContext() - self.assertEqual( - context.add_decision_point(hyper.oneof([1, 2, 3])), 1) - self.assertEqual( - context.add_decision_point(hyper.oneof(['a', 'b'], name='x')), 'a') - self.assertEqual( - context.add_decision_point(hyper.template(1)), 1) - - with self.assertRaisesRegex( - ValueError, 'Found different hyper primitives under the same name'): - context.add_decision_point(hyper.oneof(['foo', 'bar'], name='x')) - - self.assertEqual(context.hyper_dict, { - 'decision_0': hyper.oneof([1, 2, 3]), - 'x': hyper.oneof(['a', 'b'], name='x'), - }) - - with self.assertRaisesRegex( - ValueError, '`evaluate` needs to be called under the `apply` context'): - context.evaluate(hyper.oneof([1, 2, 3])) - - with context.apply([1, 1]): - self.assertEqual(context.evaluate(context.hyper_dict['decision_0']), 2) - self.assertEqual(context.evaluate(context.hyper_dict['x']), 'b') - - -class ValueReferenceTest(unittest.TestCase): - """Tests for hyper.ValueReference classes.""" - - def testResolve(self): - """Test ValueReference.resolve.""" - sd = symbolic.Dict({'c': [ - { - 'x': [{ - 'z': 0 - }], - }, - { - 'x': [{ - 'z': 1 - }] - }, - ]}) - sd.a = hyper.ValueReference(reference_paths=['c[0].x[0].z']) - self.assertEqual(sd.a.resolve(), [(sd, 'c[0].x[0].z')]) - - # References refer to the same relative path under different parent. - ref = hyper.ValueReference(reference_paths=['x[0].z']) - sd.c[0].y = ref - sd.c[1].y = ref - self.assertEqual(sd.c[0].y.resolve(), [(sd.c[0], 'c[0].x[0].z')]) - self.assertEqual(sd.c[1].y.resolve(), [(sd.c[1], 'c[1].x[0].z')]) - # Resolve references from this point. - self.assertEqual(sd.c[0].y.resolve(object_utils.KeyPath(0)), (sd.c, 'c[0]')) - self.assertEqual(sd.c[0].y.resolve('[0]'), (sd.c, 'c[0]')) - self.assertEqual(sd.c[0].y.resolve(['[0]', '[1]']), [(sd.c, 'c[0]'), - (sd.c, 'c[1]')]) - - # Bad inputs. - with self.assertRaisesRegex( - ValueError, - 'Argument \'reference_path_or_paths\' must be None, a string, KeyPath ' - 'object, a list of strings, or a list of KeyPath objects.'): - sd.c[0].y.resolve([1]) - - with self.assertRaisesRegex( - ValueError, - 'Argument \'reference_path_or_paths\' must be None, a string, KeyPath ' - 'object, a list of strings, or a list of KeyPath objects.'): - sd.c[0].y.resolve(1) - - with self.assertRaisesRegex( - ValueError, 'Cannot resolve .*: parent not found.'): - hyper.ValueReference(reference_paths=['x[0].z']).resolve() - - def testCall(self): - """Test ValueReference.__call__.""" - - @symbolic.members([('a', schema.Int(), 'Field a.')]) - class A(symbolic.Object): - pass - - sd = symbolic.Dict({'c': [ - { - 'x': [{ - 'z': 0 - }], - }, - { - 'x': [{ - 'z': A(a=1) - }] - }, - ]}) - sd.a = hyper.ValueReference(reference_paths=['c[0].x[0].z']) - self.assertEqual(sd.a(), 0) - - # References refer to the same relative path under different parent. - ref = hyper.ValueReference(reference_paths=['x[0]']) - sd.c[0].y = ref - sd.c[1].y = ref - self.assertEqual(sd.c[0].y(), {'z': 0}) - self.assertEqual(sd.c[1].y(), {'z': A(a=1)}) - - # References to another reference is not supported. - sd.c[1].z = hyper.ValueReference(reference_paths=['y']) - with self.assertRaisesRegex( - ValueError, - 'Derived value .* should not reference derived values'): - sd.c[1].z() - - sd.c[1].z = hyper.ValueReference(reference_paths=['c']) - with self.assertRaisesRegex( - ValueError, - 'Derived value .* should not reference derived values'): - sd.c[1].z() - - def testSchemaCheck(self): - """Test for schema checking on derived value.""" - sd = symbolic.Dict.partial( - x=0, - value_spec=schema.Dict([('x', schema.Int()), ('y', schema.Int()), - ('z', schema.Str())])) - - sd.y = hyper.ValueReference(['x']) - # TODO(daiyip): Enable this test once static analysis is done - # on derived values. - # with self.assertRaisesRegexp( - # TypeError, ''): - # sd.z = hyper.ValueReference(['x']) - - def testBadInit(self): - """Test bad __init__.""" - with self.assertRaisesRegex( - ValueError, - 'Argument \'reference_paths\' should have exact 1 item'): - hyper.ValueReference([]) - -if __name__ == '__main__': - unittest.main()