Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update ray syntax to get tests passing in CI
Browse files Browse the repository at this point in the history
Michael-T-McCann committed Sep 27, 2023
1 parent 83d129f commit 3757233
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions scico/test/test_ray_tune.py
Original file line number Diff line number Diff line change
@@ -7,19 +7,18 @@

try:
import ray
from scico.ray import report, tune
from scico.ray import train, tune

ray.init(num_cpus=1)
except ImportError as e:
pytest.skip("ray.tune not installed", allow_module_level=True)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_random_run():
def eval_params(config, reporter):
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
reporter(cost=cost)
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
@@ -40,12 +39,11 @@ def eval_params(config, reporter):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_random_tune():
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
report({"cost": cost})
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
@@ -66,12 +64,11 @@ def eval_params(config):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_hyperopt_run():
def eval_params(config, reporter):
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
reporter(cost=cost)
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
@@ -90,12 +87,11 @@ def eval_params(config, reporter):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_hyperopt_tune():
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
report({"cost": cost})
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
resources = {"gpu": 0, "cpu": 1}
@@ -115,12 +111,11 @@ def eval_params(config):
assert np.abs(best_config["y"] - 0.5) < 0.25


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_hyperopt_tune_alt_init():
def eval_params(config):
x, y = config["x"], config["y"]
cost = x**2 + (y - 0.5) ** 2
report({"cost": cost})
train.report({"cost": cost})

config = {"x": tune.uniform(-1, 1), "y": tune.uniform(-1, 1)}
tuner = tune.Tuner(

0 comments on commit 3757233

Please sign in to comment.