Skip to content

Commit

Permalink
test: update tests with fixtures (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
rickstaa authored Aug 8, 2023
1 parent 2dfbf58 commit 14f9860
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pytest/
*.pytest_cache/
.coverage
.converage.*
pytest-coverage.txt
wandb/

Expand Down
10 changes: 5 additions & 5 deletions tests/algos/__snapshots__/test_algos.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@
)
# ---
# name: TestTorchAlgos.test_reproducibility[stable_learning_control.algos.pytorch.sac.sac].1
-0.3054710924625397
-0.17646123468875885
# ---
# name: TestTorchAlgos.test_reproducibility[stable_learning_control.algos.pytorch.sac.sac].2
-0.5438500046730042
-0.7749449014663696
# ---
# name: TestTorchAlgos.test_reproducibility[stable_learning_control.algos.pytorch.sac.sac].3
-0.5551561713218689
-0.4696856439113617
# ---
# name: TestTorchAlgos.test_reproducibility[stable_learning_control.algos.pytorch.sac.sac].4
0.7831009030342102
0.8471362590789795
# ---
# name: TestTorchAlgos.test_reproducibility[stable_learning_control.algos.pytorch.sac.sac].5
-0.43937113881111145
-0.3496301770210266
# ---
10 changes: 5 additions & 5 deletions tests/algos/gpu/__snapshots__/test_algos_gpu.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@
)
# ---
# name: TestTorchAlgosGPU.test_reproducibility[gpu-stable_learning_control.algos.pytorch.sac.sac].1
-0.7098895311355591
-0.5978871583938599
# ---
# name: TestTorchAlgosGPU.test_reproducibility[gpu-stable_learning_control.algos.pytorch.sac.sac].2
-0.1639314442873001
-0.613029956817627
# ---
# name: TestTorchAlgosGPU.test_reproducibility[gpu-stable_learning_control.algos.pytorch.sac.sac].3
-0.7671857476234436
-0.7007529735565186
# ---
# name: TestTorchAlgosGPU.test_reproducibility[gpu-stable_learning_control.algos.pytorch.sac.sac].4
0.415353924036026
0.7336636781692505
# ---
# name: TestTorchAlgosGPU.test_reproducibility[gpu-stable_learning_control.algos.pytorch.sac.sac].5
-0.8201737403869629
-0.8068016171455383
# ---
21 changes: 13 additions & 8 deletions tests/algos/gpu/test_algos_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,26 @@
@pytest.mark.parametrize("algo", ALGOS)
@pytest.mark.parametrize("device", ["gpu"])
class TestTorchAlgosGPU:
env = gym.make("Pendulum-v1") # Used because it is a simple environment.
@pytest.fixture
def env(self):
"""Create Pendulum environment."""
env = gym.make("Pendulum-v1") # Used because it is a simple environment.

# Seed the environment.
env.np_random, seed = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)
# Seed the environment.
env.np_random, _ = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)

def test_reproducibility(self, algo, device, snapshot):
return env

def test_reproducibility(self, algo, device, snapshot, env):
"""Checks if the algorithm is still working as expected."""
# Import the algorithm run function.
run = getattr(importlib.import_module(algo), algo.split(".")[-1])

# Run the algorithm.
trained_policy, replay_buffer = run(
lambda: self.env,
lambda: env,
seed=0,
epochs=1,
update_after=400,
Expand All @@ -48,5 +53,5 @@ def test_reproducibility(self, algo, device, snapshot):

# Test if the actions returned by the policy are the same.
for _ in range(5):
action = trained_policy.get_action(self.env.observation_space.sample())
action = trained_policy.get_action(env.observation_space.sample())
assert action == snapshot
21 changes: 13 additions & 8 deletions tests/algos/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,26 @@

@pytest.mark.parametrize("algo", ALGOS)
class TestTorchAlgos:
env = gym.make("Pendulum-v1") # Used because it is a simple environment.
@pytest.fixture
def env(self):
"""Create Pendulum environment."""
env = gym.make("Pendulum-v1") # Used because it is a simple environment.

# Seed the environment.
env.np_random, seed = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)
# Seed the environment.
env.np_random, _ = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)

def test_reproducibility(self, algo, snapshot):
return env

def test_reproducibility(self, algo, snapshot, env):
"""Checks if the algorithm is still working as expected."""
# Import the algorithm run function.
run = getattr(importlib.import_module(algo), algo.split(".")[-1])

# Run the algorithm.
trained_policy, replay_buffer = run(
lambda: self.env,
lambda: env,
seed=0,
epochs=1,
update_after=400,
Expand All @@ -47,5 +52,5 @@ def test_reproducibility(self, algo, snapshot):

# Test if the actions returned by the policy are the same.
for _ in range(5):
action = trained_policy.get_action(self.env.observation_space.sample())
action = trained_policy.get_action(env.observation_space.sample())
assert action == snapshot
21 changes: 13 additions & 8 deletions tests/algos/tf2/gpu/test_tf2_algos_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
@pytest.mark.parametrize("algo", ALGOS)
@pytest.mark.parametrize("device", ["gpu"])
class TestTF2AlgosGPU:
env = gym.make("Pendulum-v1") # Used because it is a simple environment.
@pytest.fixture
def env(self):
"""Create Pendulum environment."""
env = gym.make("Pendulum-v1") # Used because it is a simple environment.

# Seed the environment.
env.np_random, seed = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)
# Seed the environment.
env.np_random, _ = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)

def test_reproducibility(self, algo, device, snapshot):
return env

def test_reproducibility(self, algo, device, snapshot, env):
"""Checks if the algorithm is still working as expected."""
# Check if TensorFlow is available.
if not importlib.util.find_spec("tensorflow"):
Expand All @@ -45,7 +50,7 @@ def test_reproducibility(self, algo, device, snapshot):

# Run the algorithm.
trained_policy, replay_buffer = run(
lambda: self.env,
lambda: env,
seed=0,
epochs=1,
update_after=400,
Expand All @@ -59,5 +64,5 @@ def test_reproducibility(self, algo, device, snapshot):

# Test if the actions returned by the policy are the same.
for _ in range(5):
action = trained_policy.get_action(self.env.observation_space.sample())
action = trained_policy.get_action(env.observation_space.sample())
assert action.numpy() == snapshot
21 changes: 13 additions & 8 deletions tests/algos/tf2/test_tf2_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@

@pytest.mark.parametrize("algo", ALGOS)
class TestTF2Algos:
env = gym.make("Pendulum-v1") # Used because it is a simple environment.
@pytest.fixture
def env(self):
"""Create Pendulum environment."""
env = gym.make("Pendulum-v1") # Used because it is a simple environment.

# Seed the environment.
env.np_random, seed = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)
# Seed the environment.
env.np_random, _ = seeding.np_random(0)
env.action_space.seed(0)
env.observation_space.seed(0)

def test_reproducibility(self, algo, snapshot):
return env

def test_reproducibility(self, algo, snapshot, env):
"""Checks if the algorithm is still working as expected."""
# Check if TensorFlow is available.
if not importlib.util.find_spec("tensorflow"):
Expand All @@ -39,7 +44,7 @@ def test_reproducibility(self, algo, snapshot):

# Run the algorithm.
trained_policy, replay_buffer = run(
lambda: self.env,
lambda: env,
seed=0,
epochs=1,
update_after=400,
Expand All @@ -53,5 +58,5 @@ def test_reproducibility(self, algo, snapshot):

# Test if the actions returned by the policy are the same.
for _ in range(5):
action = trained_policy.get_action(self.env.observation_space.sample())
action = trained_policy.get_action(env.observation_space.sample())
assert action.numpy() == snapshot

0 comments on commit 14f9860

Please sign in to comment.