Skip to content

Commit

Permalink
Merge pull request #149 from GFNOrg/no_more_class_factories
Browse files Browse the repository at this point in the history
No more class factories
  • Loading branch information
josephdviviano authored Feb 16, 2024
2 parents eedc7e8 + ae3fa2e commit 3276492
Show file tree
Hide file tree
Showing 22 changed files with 531 additions and 380 deletions.
10 changes: 5 additions & 5 deletions docs/requirements_docs.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pre-commit
black
pytest
sphinx==5.3.0
myst-parser==0.18.1
sphinx_rtd_theme==1.1.1
sphinx-math-dollar==1.2.1
sphinx-autoapi==2.0.0
sphinx>=6.2.1
myst-parser
sphinx_rtd_theme
sphinx-math-dollar
sphinx-autoapi>=3.0.0
renku-sphinx-theme
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ myst-parser = { version = "*", optional = true }
pre-commit = { version = "*", optional = true }
pytest = { version = "*", optional = true }
renku-sphinx-theme = { version = "*", optional = true }
sphinx = { version = "*", optional = true }
sphinx = { version = ">=6.2.1", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
sphinx-autoapi = { version = "*", optional = true }
sphinx-autoapi = { version = ">=3.0.0", optional = true }
sphinx-math-dollar = { version = "*", optional = true }
tox = { version = "*", optional = true }

Expand Down Expand Up @@ -85,8 +85,6 @@ all = [
"Homepage" = "https://gfn.readthedocs.io/en/latest/"
"Bug Tracker" = "https://github.com/saleml/gfn/issues"



[tool.black]
py36 = true
include = '\.pyi?$'
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(
self.training_objects = Transitions(env)
self.objects_type = "transitions"
elif objects_type == "states":
self.training_objects = env.States.from_batch_shape((0,))
self.terminating_states = env.States.from_batch_shape((0,))
self.training_objects = env.states_from_batch_shape((0,))
self.terminating_states = env.states_from_batch_shape((0,))
self.objects_type = "states"
else:
raise ValueError(f"Unknown objects_type: {objects_type}")
Expand Down
14 changes: 8 additions & 6 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,11 @@ def __init__(
self.states = (
states
if states is not None
else env.States.from_batch_shape(batch_shape=(0, 0))
else env.states_from_batch_shape((0, 0))
)
assert len(self.states.batch_shape) == 2
self.actions = (
actions
if actions is not None
else env.Actions.make_dummy_actions(batch_shape=(0, 0))
actions if actions is not None else env.actions_from_batch_shape((0, 0))
)
assert len(self.actions.batch_shape) == 2
self.when_is_done = (
Expand Down Expand Up @@ -253,9 +251,13 @@ def extend(self, other: Trajectories) -> None:

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and is_tensor(other.estimator_outputs):
if self.estimator_outputs is None and isinstance(
other.estimator_outputs, Tensor
):
self.estimator_outputs = other.estimator_outputs
elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs):
elif isinstance(self.estimator_outputs, Tensor) and isinstance(
other.estimator_outputs, Tensor
):
batch_shape = self.actions.batch_shape
n_bs = len(batch_shape)
output_dtype = self.estimator_outputs.dtype
Expand Down
8 changes: 3 additions & 5 deletions src/gfn/containers/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,12 @@ def __init__(
self.states = (
states
if states is not None
else env.States.from_batch_shape(batch_shape=(0,))
else env.states_from_batch_shape(batch_shape=(0,))
)
assert len(self.states.batch_shape) == 1

self.actions = (
actions
if actions is not None
else env.Actions.make_dummy_actions(batch_shape=(0,))
actions if actions is not None else env.actions_from_batch_shape((0,))
)
self.is_done = (
is_done
Expand All @@ -85,7 +83,7 @@ def __init__(
self.next_states = (
next_states
if next_states is not None
else env.States.from_batch_shape(batch_shape=(0,))
else env.states_from_batch_shape(batch_shape=(0,))
)
assert (
len(self.next_states.batch_shape) == 1
Expand Down
Loading

0 comments on commit 3276492

Please sign in to comment.