Skip to content

Commit

Permalink
Code restructure: break "tuning.py" into smaller files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 492486521
  • Loading branch information
daiyip authored and pyglove authors committed Dec 2, 2022
1 parent 5026b3d commit 9603443
Show file tree
Hide file tree
Showing 10 changed files with 1,601 additions and 1,431 deletions.
1,288 changes: 0 additions & 1,288 deletions pyglove/core/tuning.py

This file was deleted.

58 changes: 58 additions & 0 deletions pyglove/core/tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distributed tuning with pluggable backends.

:func:`pyglove.iter` provides an interface for sampling examples from a search
space within a process. To support distributed tuning, PyGlove introduces
:func:`pyglove.sample`, which is almost identical but offers more features:

  * Allow multiple worker processes (a.k.a. workers) to collaborate on a
    search with failover handling.
  * Each worker can process different trials, or can cowork on the same trials
    via work groups.
  * Provide APIs for communicating between the co-workers.
  * Provide an API for retrieving the search results.
  * Provide a pluggable backend system for supporting user infrastructures.
"""

# pylint: disable=g-bad-import-order

# User facing APIs for tuning.
from pyglove.core.tuning.sample import sample
from pyglove.core.tuning.backend import poll_result

from pyglove.core.tuning.backend import default_backend
from pyglove.core.tuning.backend import set_default_backend

# Tuning protocols.
from pyglove.core.tuning.protocols import Measurement
from pyglove.core.tuning.protocols import Trial
from pyglove.core.tuning.protocols import Result
from pyglove.core.tuning.protocols import Feedback
from pyglove.core.tuning.protocols import RaceConditionError

# Interface for early stopping.
from pyglove.core.tuning.early_stopping import EarlyStoppingPolicy

# Interfaces for tuning backend developers.
from pyglove.core.tuning.backend import Backend
from pyglove.core.tuning.backend import BackendFactory
from pyglove.core.tuning.backend import add_backend
from pyglove.core.tuning.backend import available_backends

# Importing local backend.
import pyglove.core.tuning.local_backend

# pylint: enable=g-bad-import-order
115 changes: 115 additions & 0 deletions pyglove/core/tuning/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for tuning backend and backend factory."""

import abc
from typing import List, Optional, Sequence

from pyglove.core import geno
from pyglove.core.tuning.early_stopping import EarlyStoppingPolicy
from pyglove.core.tuning.protocols import Feedback
from pyglove.core.tuning.protocols import Result


class Backend(metaclass=abc.ABCMeta):
  """Interface for the tuning backend.

  A backend is configured through `setup` and hands out a `Feedback` object
  from each `next` call. Both methods are abstract, so concrete backends must
  implement them.
  """

  @abc.abstractmethod
  def setup(self,
            name: Optional[str],
            group_id: Optional[str],
            dna_spec: geno.DNASpec,
            algorithm: geno.DNAGenerator,
            metrics_to_optimize: Sequence[str],
            early_stopping_policy: Optional[EarlyStoppingPolicy] = None,
            num_examples: Optional[int] = None) -> None:
    """Sets up the current backend for an existing or a new sampling.

    Args:
      name: A unique string used as the identifier of the sampling instance.
        NOTE(review): annotated as optional, but how implementations treat
        None is not visible here - confirm against concrete backends.
      group_id: An optional group id for the current process.
      dna_spec: DNASpec for the current sampling.
      algorithm: Search algorithm used for the current sampling.
      metrics_to_optimize: Names of the metrics to optimize.
      early_stopping_policy: An optional early stopping policy.
      num_examples: Max number of examples to sample. Infinite if None.
    """

  @abc.abstractmethod
  def next(self) -> Feedback:
    """Gets the feedback object for the next sample."""


class BackendFactory(metaclass=abc.ABCMeta):
  """Interface for tuning backend factory.

  Factory classes are registered by name through the `add_backend` decorator
  and are instantiated with no arguments by `create_backend_factory`.
  """

  @abc.abstractmethod
  def create(self, **kwargs) -> Backend:
    """Creates a tuning backend for an existing or a new sampling.

    Args:
      **kwargs: Backend-specific keyword arguments passed from `pg.sample`.

    Returns:
      A `Backend` instance.
    """

  @abc.abstractmethod
  def poll_result(self, name: str) -> Result:
    """Gets tuning result by a unique tuning identifier.

    NOTE(review): the module-level `poll_result` helper forwards its extra
    `**kwargs` to this method, while this signature accepts only `name` -
    confirm whether implementations are expected to accept more arguments.

    Args:
      name: The unique identifier of the tuning whose result is requested.

    Returns:
      The `Result` associated with `name`.
    """


# Maps a backend name to the factory class registered under it; entries are
# added by the `add_backend` decorator.
_backend_registry = {}

# Name of the backend that is used when a caller does not name one explicitly.
_default_backend_name = 'in-memory'


def add_backend(backend_name: str):
  """Returns a class decorator that registers a backend factory under a name.

  Args:
    backend_name: Key under which the decorated factory class is stored in the
      module-level backend registry. Registering under a name that is already
      present replaces the previous entry.

  Returns:
    A decorator that validates the decorated object, stores it in the registry
    and returns it unchanged.

  Raises:
    TypeError: Raised by the returned decorator if the decorated object is not
      a class derived from `BackendFactory`.
  """
  def _decorator(factory_cls):
    # `issubclass` raises its own generic TypeError when its first argument is
    # not a class (e.g. a function or an instance). Checking for a class first
    # makes every invalid input report the same descriptive message.
    if not (isinstance(factory_cls, type)
            and issubclass(factory_cls, BackendFactory)):
      raise TypeError(f'{factory_cls!r} is not a BackendFactory subclass.')
    _backend_registry[backend_name] = factory_cls
    return factory_cls
  return _decorator


def available_backends() -> List[str]:
  """Returns a new list with the names of all registered backends.

  The names follow the iteration order of the backend registry.
  """
  return list(_backend_registry)


def set_default_backend(backend_name: str):
  """Makes `backend_name` the default tuning backend.

  Args:
    backend_name: Name of a backend that has already been registered.

  Raises:
    ValueError: If no backend is registered under `backend_name`. The current
      default is left untouched in that case.
  """
  global _default_backend_name
  if backend_name in _backend_registry:
    _default_backend_name = backend_name
    return
  raise ValueError(f'Backend {backend_name!r} does not exist.')


def default_backend() -> str:
  """Gets the default tuning backend name.

  Returns:
    The current value of the module-level default backend name: 'in-memory'
    initially, or the name most recently accepted by `set_default_backend`.
  """
  return _default_backend_name


def create_backend_factory(backend_name: Optional[str]) -> BackendFactory:
  """Creates a new factory instance for the named backend.

  Args:
    backend_name: Name of a registered backend. If None or empty, the current
      default backend name (see `default_backend`) is used instead.

  Returns:
    A new instance of the factory class registered under the resolved name,
    constructed with no arguments.

  Raises:
    ValueError: If no backend is registered under the resolved name.
  """
  # Falsy names (None, '') fall back to the default backend.
  resolved_name = backend_name or default_backend()
  if resolved_name not in _backend_registry:
    raise ValueError(f'Backend {resolved_name!r} does not exist.')
  factory_cls = _backend_registry[resolved_name]
  return factory_cls()


def poll_result(
    name: str,
    backend: Optional[str] = None,
    **kwargs) -> Result:
  """Fetches the tuning result identified by `name` from a backend.

  Args:
    name: Identifier of the tuning whose result is requested.
    backend: Name of the backend to query. When None, the default backend is
      used.
    **kwargs: Extra keyword arguments forwarded unchanged to the factory's
      `poll_result` method.

  Returns:
    Whatever the selected backend factory's `poll_result` returns.

  Raises:
    ValueError: If the resolved backend name is not registered.
  """
  factory = create_backend_factory(backend)
  return factory.poll_result(name, **kwargs)
57 changes: 57 additions & 0 deletions pyglove/core/tuning/backend_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pyglove.core.tuning.backend."""

import unittest
from pyglove.core.tuning import backend
from pyglove.core.tuning import local_backend # pylint: disable=unused-import


class BackendTest(unittest.TestCase):
  """Tests for pluggable backend."""

  def test_pluggable_backend(self):
    """Registers a fake backend and exercises default-backend switching.

    The backend registry and the default backend name are process-wide state,
    so every mutation made here is paired with a cleanup. Cleanups run even
    when an assertion fails, which keeps a failure in this test from leaking
    a modified registry or default into other tests.
    """
    self.assertEqual(backend.available_backends(), ['in-memory'])

    # Unregister the fake backend when the test ends. There is no public API
    # for removal, so the registry dict is accessed directly.
    self.addCleanup(
        backend._backend_registry.pop,  # pylint: disable=protected-access
        'test', None)

    @backend.add_backend('test')
    class TestBackendFactory(backend.BackendFactory):  # pylint: disable=unused-variable
      """A fake backend factory for testing."""

      def create(self, **kwargs):
        return None

      def poll_result(self, name):
        return None

    self.assertEqual(backend.available_backends(), ['in-memory', 'test'])
    self.assertEqual(backend.default_backend(), 'in-memory')

    # Restore the original default on exit. Cleanups run in reverse order of
    # registration, so the default is reset before 'test' is unregistered.
    self.addCleanup(backend.set_default_backend, backend.default_backend())
    backend.set_default_backend('test')
    self.assertEqual(backend.default_backend(), 'test')

    with self.assertRaisesRegex(
        ValueError, 'Backend .* does not exist'):
      backend.set_default_backend('non-exist-backend')

    with self.assertRaisesRegex(
        TypeError, '.* is not a BackendFactory subclass'):

      @backend.add_backend('bad')
      class BadBackendFactory:  # pylint: disable=unused-variable
        pass

    backend.set_default_backend('in-memory')
    self.assertEqual(backend.default_backend(), 'in-memory')


# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
64 changes: 64 additions & 0 deletions pyglove/core/tuning/early_stopping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2022 The PyGlove Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for early stopping policies."""

import abc
from typing import Iterable, Optional

from pyglove.core import geno
from pyglove.core import symbolic
from pyglove.core.tuning.protocols import Trial


class EarlyStoppingPolicy(symbolic.Object):
  """Base class for policies that decide whether a trial should stop early."""

  def setup(self, dna_spec: geno.DNASpec) -> None:
    """Prepares the policy for the search space described by `dna_spec`.

    This base implementation only remembers `dna_spec`, which is afterwards
    available through the `dna_spec` property.

    Args:
      dna_spec: DNASpec for DNA to propose.

    Raises:
      RuntimeError: if dna_spec is not supported. The base implementation
        never raises.
    """
    self._dna_spec = dna_spec

  @property
  def dna_spec(self) -> Optional[geno.DNASpec]:
    """The DNASpec stored by `setup`, or None when none has been stored."""
    try:
      return self._dna_spec
    except AttributeError:
      return None

  @abc.abstractmethod
  def should_stop_early(self, trial: Trial) -> bool:
    """Tells whether the input trial should stop early.

    Args:
      trial: The trial to decide on, based on its measurements.

    Returns:
      True if the trial should be stopped early, False otherwise.
    """

  def recover(self, history: Iterable[Trial]) -> None:
    """Rebuilds the policy state by replaying a trial history.

    Subclasses can override.

    NOTE: `recover` will always be called before the first `should_stop_early`
    is called. It could be called multiple times if there are multiple sources
    of history, e.g: trials from a previous study and existing trials from
    the current study.

    The default behavior calls `should_stop_early` on every trial in `history`
    whose status is 'COMPLETED', 'PENDING' or 'STOPPING', discarding the
    return value; trials in any other status are skipped.

    Args:
      history: An iterable object of trials.
    """
    replayable_statuses = ('COMPLETED', 'PENDING', 'STOPPING')
    for past_trial in history:
      if past_trial.status not in replayable_statuses:
        continue
      self.should_stop_early(past_trial)
Loading

0 comments on commit 9603443

Please sign in to comment.