Skip to content

Commit

Permalink
Dynamic evaluation to support divide-and-conquer on a search space.
Browse files Browse the repository at this point in the history
Example:

```
def fun():
  return pg.oneof([1, 2], hints='ssd1') + pg.oneof([3, 4], hints='ssd2')

ssd1 = pg.hyper.DynamicEvaluationContext(where=lambda x: x.hints == 'ssd1')
ssd2 = pg.hyper.DynamicEvaluationContext(where=lambda x: x.hints == 'ssd2')

# Partitioning the search space into ssd1 and ssd2.
with ssd1.collect():
  with ssd2.collect():
    fun()

# Nested search.
for ex1, f1 in pg.sample(ssd1, algorithm1):
  rs = []
  for ex2, f2 in pg.sample(ssd2, algorithm2):
    with ex1():
      with ex2():
        r = fun()
        f2(r)
        rs.append(r)
  f1(sum(rs))
```

PiperOrigin-RevId: 480115321
  • Loading branch information
daiyip authored and pyglove authors committed Oct 10, 2022
1 parent 31990e5 commit 81d2f1a
Show file tree
Hide file tree
Showing 2 changed files with 343 additions and 89 deletions.
299 changes: 210 additions & 89 deletions pyglove/core/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,6 +2208,17 @@ def foo():
best_foo, best_reward = example, reward
"""

class _AnnoymousHyperNameAccumulator:
  """Name accumulator for anonymous hyper primitives.

  Hands out the names `decision_0`, `decision_1`, ... in call order. The
  counter lives in an object (rather than a plain int) so that one
  accumulator instance can be shared by reference between contexts.
  """

  def __init__(self):
    # Index that will be embedded in the next generated name.
    self.index = 0

  def next_name(self):
    """Returns the next `decision_<index>` name and advances the index."""
    name = f'decision_{self.index}'
    self.index += 1
    return name

def __init__(self,
where: Optional[Callable[[HyperPrimitive], bool]] = None,
require_hyper_name: bool = False,
Expand Down Expand Up @@ -2236,10 +2247,17 @@ def __init__(self,
self._where = where
self._require_hyper_name: bool = require_hyper_name
self._name_to_hyper: Dict[Text, HyperPrimitive] = dict()
self._annoymous_hyper_name_accumulator: int = 0
self._annoymous_hyper_name_accumulator = (
DynamicEvaluationContext._AnnoymousHyperNameAccumulator())
self._hyper_dict = symbolic.Dict() if dna_spec is None else None
self._dna_spec: Optional[geno.DNASpec] = dna_spec
self._per_thread = per_thread
self._decision_getter = None

@property
def per_thread(self) -> bool:
  """Whether this context collects/applies decisions per thread."""
  return self._per_thread

@property
def dna_spec(self) -> geno.DNASpec:
Expand All @@ -2257,8 +2275,7 @@ def _decision_name(self, hyper_primitive: HyperPrimitive) -> Text:
raise ValueError(
f'\'name\' must be specified for hyper '
f'primitive {hyper_primitive!r}.')
name = f'decision_{self._annoymous_hyper_name_accumulator}'
self._annoymous_hyper_name_accumulator += 1
name = self._annoymous_hyper_name_accumulator.next_name()
return name

@property
Expand Down Expand Up @@ -2296,71 +2313,80 @@ def collect(self):
f'`collect` cannot be called on a dynamic evaluation context that is '
f'using an external DNASpec: {self._dna_spec}.')

with self._collect() as sub_space:
# Ensure per-thread dynamic evaluation context will not be used
# together with process-level dynamic evaluation context.
_dynamic_evaluation_stack.ensure_thread_safety(self)

self._hyper_dict = {}
with dynamic_evaluate(self.add_decision_point, per_thread=self._per_thread):
try:
# Push current context to dynamic evaluation stack so nested context
# can defer unresolved hyper primitive to current context.
_dynamic_evaluation_stack.push(self)
yield self._hyper_dict
finally:
# NOTE(daiyip): when registering new hyper primitives in the sub-space,
# the keys are already ensured not to conflict with the keys in current
# search space. Therefore it's safe to update current space.
self._hyper_dict.update(sub_space)

finally:
# Invalidate DNASpec.
self._dna_spec = None

def _collect(self):
"""A context manager for collecting hyper primitive within the scope."""
hyper_dict = symbolic.Dict()
# Pop current context from dynamic evaluation stack.
_dynamic_evaluation_stack.pop(self)

def _register_child(c):
def add_decision_point(self, hyper_primitive: HyperPrimitive):
"""Registers a parameter with current context and return its first value."""
def _add_child_decision_point(c):
if isinstance(c, types.LambdaType):
s = schema.get_signature(c)
if not s.args and not s.has_wildcard_args:
with self._collect() as child_hyper:
sub_context = DynamicEvaluationContext(
where=self._where, per_thread=self._per_thread)
sub_context._annoymous_hyper_name_accumulator = ( # pylint: disable=protected-access
self._annoymous_hyper_name_accumulator)
with sub_context.collect() as hyper_dict:
v = c()
return (v, child_hyper)
return (v, hyper_dict)
return (c, c)

def _register_hyper_primitive(hyper_primitive):
"""Registers a decision point from an hyper_primitive."""
if self._where and not self._where(hyper_primitive):
# Skip hyper primitives that do not pass the `where` predicate.
return hyper_primitive

if isinstance(hyper_primitive, Template):
return hyper_primitive.value

assert isinstance(hyper_primitive, HyperPrimitive), hyper_primitive
name = self._decision_name(hyper_primitive)
if isinstance(hyper_primitive, Choices):
candidate_values, candidates = zip(
*[_register_child(c) for c in hyper_primitive.candidates])
if hyper_primitive.choices_distinct:
assert hyper_primitive.num_choices <= len(hyper_primitive.candidates)
v = [candidate_values[i] for i in range(hyper_primitive.num_choices)]
else:
v = [candidate_values[0]] * hyper_primitive.num_choices
hyper_primitive = hyper_primitive.clone(deep=True, override={
'candidates': list(candidates)
})
first_value = v[0] if isinstance(hyper_primitive, ChoiceValue) else v
elif isinstance(hyper_primitive, Float):
first_value = hyper_primitive.min_value
if self._where and not self._where(hyper_primitive):
# Delegate the resolution of hyper primitives that do not pass
# the `where` predicate to its parent context.
parent_context = _dynamic_evaluation_stack.get_parent(self)
if parent_context is not None:
return parent_context.add_decision_point(hyper_primitive)
return hyper_primitive

if isinstance(hyper_primitive, Template):
return hyper_primitive.value

assert isinstance(hyper_primitive, HyperPrimitive), hyper_primitive
name = self._decision_name(hyper_primitive)
if isinstance(hyper_primitive, Choices):
candidate_values, candidates = zip(
*[_add_child_decision_point(c) for c in hyper_primitive.candidates])
if hyper_primitive.choices_distinct:
assert hyper_primitive.num_choices <= len(hyper_primitive.candidates)
v = [candidate_values[i] for i in range(hyper_primitive.num_choices)]
else:
assert isinstance(hyper_primitive, CustomHyper), hyper_primitive
first_value = hyper_primitive.decode(hyper_primitive.first_dna())
v = [candidate_values[0]] * hyper_primitive.num_choices
hyper_primitive = hyper_primitive.clone(deep=True, override={
'candidates': list(candidates)
})
first_value = v[0] if isinstance(hyper_primitive, ChoiceValue) else v
elif isinstance(hyper_primitive, Float):
first_value = hyper_primitive.min_value
else:
assert isinstance(hyper_primitive, CustomHyper), hyper_primitive
first_value = hyper_primitive.decode(hyper_primitive.first_dna())

if (name in self._name_to_hyper
and hyper_primitive != self._name_to_hyper[name]):
raise ValueError(
f'Found different hyper primitives under the same name {name!r}: '
f'Instance1={self._name_to_hyper[name]!r}, '
f'Instance2={hyper_primitive!r}.')
hyper_dict[name] = hyper_primitive
self._name_to_hyper[name] = hyper_primitive
return first_value
return dynamic_evaluate(
_register_hyper_primitive, hyper_dict, per_thread=self._per_thread)
if (name in self._name_to_hyper
and hyper_primitive != self._name_to_hyper[name]):
raise ValueError(
f'Found different hyper primitives under the same name {name!r}: '
f'Instance1={self._name_to_hyper[name]!r}, '
f'Instance2={hyper_primitive!r}.')
self._hyper_dict[name] = hyper_primitive
self._name_to_hyper[name] = hyper_primitive
return first_value

def _decision_getter_and_evaluation_finalizer(
self, decisions: Union[geno.DNA, List[Union[int, float, str]]]):
Expand Down Expand Up @@ -2461,6 +2487,7 @@ def err_on_unused_decisions():
f'Found extra decision values that are not used: {remaining!r}')
return get_decision_by_position, err_on_unused_decisions

@contextlib.contextmanager
def apply(
self, decisions: Union[geno.DNA, List[Union[int, float, str]]]):
"""Context manager for applying decisions.
Expand All @@ -2482,65 +2509,159 @@ def fun():
decisions: A DNA or a list of numbers or strings as decisions for current
search space.
Returns:
Context manager for applying decisions to the function that defines the
search space.
Yields:
None
"""
if not isinstance(decisions, (geno.DNA, list)):
raise ValueError('`decisions` should be a DNA or a list of numbers.')

get_decision, evaluation_finalizer = (
# Ensure per-thread dynamic evaluation context will not be used
# together with process-level dynamic evaluation context.
_dynamic_evaluation_stack.ensure_thread_safety(self)

get_current_decision, evaluation_finalizer = (
self._decision_getter_and_evaluation_finalizer(decisions))

has_errors = False
with dynamic_evaluate(self.evaluate, per_thread=self._per_thread):
try:
# Set decision getter for current decision.
self._decision_getter = get_current_decision

# Push current context to dynamic evaluation stack so nested context
# can delegate evaluate to current context.
_dynamic_evaluation_stack.push(self)

yield
except Exception:
has_errors = True
raise
finally:
# Pop current context from dynamic evaluation stack.
_dynamic_evaluation_stack.pop(self)

# Reset decisions.
self._decision_getter = None

# Call evaluation finalizer to make sure all decisions are used.
if not has_errors:
evaluation_finalizer()

def evaluate(self, hyper_primitive: HyperPrimitive):
"""Evaluates a hyper primitive based on current decisions."""
if self._decision_getter is None:
raise ValueError(
'`evaluate` needs to be called under the `apply` context.')

get_current_decision = self._decision_getter
def _apply_child(c):
if isinstance(c, types.LambdaType):
s = schema.get_signature(c)
if not s.args and not s.has_wildcard_args:
return c()
return c

def _apply_decision(hyper_primitive: HyperPrimitive):
"""Apply a decision value to an hyper_primitive object."""
if self._where and not self._where(hyper_primitive):
# Skip hyper primitives that do not pass the `where` predicate.
return hyper_primitive

if isinstance(hyper_primitive, Float):
return get_decision(hyper_primitive)

if isinstance(hyper_primitive, CustomHyper):
return hyper_primitive.decode(geno.DNA(get_decision(hyper_primitive)))

assert isinstance(hyper_primitive, Choices)
value = symbolic.List()
for i in range(hyper_primitive.num_choices):
# NOTE(daiyip): during registering the hyper primitives when
# constructing the search space, we will need to evaluate every
# candidate in order to pick up sub search spaces correctly, which is
# not necessary for `pg.DynamicEvaluationContext.apply`.
value.append(_apply_child(
hyper_primitive.candidates[get_decision(hyper_primitive, i)]))
if isinstance(hyper_primitive, ChoiceValue):
assert len(value) == 1
value = value[0]
return value
return dynamic_evaluate(
_apply_decision,
exit_fn=evaluation_finalizer,
per_thread=self._per_thread)
if self._where and not self._where(hyper_primitive):
# Delegate the resolution of hyper primitives that do not pass
# the `where` predicate to its parent context.
parent_context = _dynamic_evaluation_stack.get_parent(self)
if parent_context is not None:
return parent_context.evaluate(hyper_primitive)
return hyper_primitive

if isinstance(hyper_primitive, Float):
return get_current_decision(hyper_primitive)

if isinstance(hyper_primitive, CustomHyper):
return hyper_primitive.decode(
geno.DNA(get_current_decision(hyper_primitive)))

assert isinstance(hyper_primitive, Choices), hyper_primitive
value = symbolic.List()
for i in range(hyper_primitive.num_choices):
# NOTE(daiyip): during registering the hyper primitives when
# constructing the search space, we will need to evaluate every
# candidate in order to pick up sub search spaces correctly, which is
# not necessary for `pg.DynamicEvaluationContext.apply`.
value.append(_apply_child(
hyper_primitive.candidates[get_current_decision(hyper_primitive, i)]))
if isinstance(hyper_primitive, ChoiceValue):
assert len(value) == 1
value = value[0]
return value


# We maintain a stack of dynamic evaluation contexts to support search space
# combination.
class _DynamicEvaluationStack:
  """Dynamic evaluation stack used for dealing with nested evaluation.

  Two stacks are maintained: a single process-level list for contexts whose
  `per_thread` is False, and a lazily created list stored on
  `_thread_local_state` for contexts whose `per_thread` is True.
  """

  # Attribute name under which the per-thread stack is stored on
  # `_thread_local_state`.
  _TLS_KEY = 'dynamic_evaluation_stack'

  def __init__(self):
    self._global_stack = []

  def ensure_thread_safety(self, context: 'DynamicEvaluationContext'):
    """Checks that `context` is not mixed with contexts of the other kind.

    Args:
      context: The context that is about to be entered.

    Raises:
      ValueError: If `context` is per-thread while the process-level stack is
        non-empty, or `context` is process-level while the thread-local stack
        is non-empty.
    """
    if ((context.per_thread and self._global_stack)
        or (not context.per_thread and self._local_stack)):
      raise ValueError(
          'Nested dynamic evaluation contexts must be either all per-thread '
          'or all process-wise. Please check the `per_thread` argument of '
          'the `pg.hyper.DynamicEvaluationContext` objects being used.')

  @property
  def _local_stack(self):
    """Returns thread-local stack, creating it on first access."""
    stack = getattr(_thread_local_state, self._TLS_KEY, None)
    if stack is None:
      stack = []
      setattr(_thread_local_state, self._TLS_KEY, stack)
    return stack

  def _stack_for(self, context: 'DynamicEvaluationContext'):
    """Returns the stack (thread-local or process-level) used by `context`."""
    return self._local_stack if context.per_thread else self._global_stack

  def push(self, context: 'DynamicEvaluationContext'):
    """Pushes the context to the stack."""
    self._stack_for(context).append(context)

  def pop(self, context: 'DynamicEvaluationContext'):
    """Pops the context from the stack.

    Args:
      context: The context expected to be at the top of its stack.
    """
    stack = self._stack_for(context)
    assert stack
    stack_top = stack.pop(-1)
    # Contexts must be exited in the reverse order they were entered.
    assert stack_top is context, (stack_top, context)

  def get_parent(
      self,
      context: 'DynamicEvaluationContext'
  ) -> Optional['DynamicEvaluationContext']:
    """Returns the parent context of the input context.

    Args:
      context: The context whose enclosing context is requested.

    Returns:
      The entry directly below the topmost occurrence of `context` in its
      stack, or None if `context` is at the bottom of the stack or is not on
      the stack at all.
    """
    stack = self._stack_for(context)
    # Scan from the top of the stack. Index 0 is excluded since the bottom
    # entry has no parent.
    for i in reversed(range(1, len(stack))):
      if stack[i] is context:
        return stack[i - 1]
    return None


# Process-wide dynamic evaluation stack: a single module-level instance that
# dynamic evaluation contexts push themselves onto and pop themselves from.
_dynamic_evaluation_stack = _DynamicEvaluationStack()


def trace(
fun: Callable[[], Any],
*,
where: Optional[Callable[[HyperPrimitive], bool]] = None,
require_hyper_name: bool = False,
per_thread: bool = True
) -> DynamicEvaluationContext:
per_thread: bool = True) -> DynamicEvaluationContext:
"""Trace the hyper primitives called within a function by executing it.
See examples in :class:`pyglove.hyper.DynamicEvaluationContext`.
Args:
fun: Function in which the search space is defined.
where: A callable object that decide whether a hyper primitive should be
included when being instantiated under `collect`.
If None, all hyper primitives under `collect` will be included.
require_hyper_name: If True, all hyper primitives defined in this scope
will need to carry their names, which is usually a good idea when the
function that instantiates the hyper primitives needs to be called multiple
Expand All @@ -2552,7 +2673,7 @@ def trace(
A DynamicEvaluationContext that can be passed to `pg.sample`.
"""
context = DynamicEvaluationContext(
require_hyper_name=require_hyper_name, per_thread=per_thread)
where=where, require_hyper_name=require_hyper_name, per_thread=per_thread)
with context.collect():
fun()
return context
Expand Down
Loading

0 comments on commit 81d2f1a

Please sign in to comment.