
No more class factories #149

Merged: 30 commits into master, Feb 16, 2024

Conversation

Collaborator

@josephdviviano josephdviviano commented Nov 29, 2023

To be merged after #147

  • make_States_class and make_Actions_class no longer need to be defined by the user - all relevant logic is submitted directly to the Env or DiscreteEnv subclass. (Of course, the user can still override the default DefaultEnvState and DefaultEnvAction classes returned by make_States_class and make_Actions_class if and only if they require bespoke functionality, but this is not expected to be the normal workflow.)
  • I've removed the maskless_ prefix from step and backward_step.
  • I've removed "sensible defaults" from anything the user must implement themselves for the logic to work. The goal here is to ensure the code fails loudly, not silently.

As a result of this, multiple methods are offloaded from the States class into the Env, and make_random_states_tensor must be passed from the Env to the States class, which accounts for a large number of these diffs.

As an example, see the below Env definition for the line environment, which is complete:

from typing import Literal

import torch
from torch.distributions import Normal
from torchtyping import TensorType as TT

from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States


class Line(Env):
    """Mixture of Gaussians Line environment."""
    def __init__(
        self,
        mus: list,
        sigmas: list,
        init_value: float,
        n_steps_per_trajectory: int = 5,
        device_str: Literal["cpu", "cuda"] = "cpu",
    ):
        assert len(mus) == len(sigmas)
        self.mus = torch.tensor(mus)
        self.sigmas = torch.tensor(sigmas)
        self.n_steps_per_trajectory = n_steps_per_trajectory
        self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)]

        s0 = torch.tensor([init_value, 0.0], device=torch.device(device_str))
        dummy_action = torch.tensor([float("inf")], device=torch.device(device_str))
        exit_action = torch.tensor([-float("inf")], device=torch.device(device_str))
        super().__init__(
            s0=s0,
            state_shape=(2,),  # [x_pos, step_counter].
            action_shape=(1,),  # [x_pos]
            dummy_action=dummy_action,
            exit_action=exit_action,
        )  # sf is -inf by default.

    def step(
        self, states: States, actions: Actions
    ) -> TT["batch_shape", 2, torch.float]:
        states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze(-1)  # x position.
        states.tensor[..., 1] = states.tensor[..., 1] + 1  # Step counter.
        return states.tensor

    def backward_step(
        self, states: States, actions: Actions
    ) -> TT["batch_shape", 2, torch.float]:
        states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze(-1)  # x position.
        states.tensor[..., 1] = states.tensor[..., 1] - 1  # Step counter.
        return states.tensor

    def is_action_valid(self, states: States, actions: Actions, backward: bool = False) -> bool:
        # Can't take a backward step at the beginning of a trajectory.
        if torch.any(states[~actions.is_exit].is_initial_state) and backward:
            return False

        return True

    def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
        s = final_states.tensor[..., 0]
        log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape)
        for i, m in enumerate(self.mixture):
            log_rewards[i] = m.log_prob(s)

        return torch.logsumexp(log_rewards, 0)

    @property
    def log_partition(self) -> float:
        """Log partition: the log of the number of Gaussians."""
        return torch.tensor(len(self.mus)).log().item()

…_step functions private. This may not be the best solution, as they are accessed externally by other parts of the library. Mask updating is now handled by the DiscreteEnv. Generic make_States_class and make_Actions_class methods are added to both Env and DiscreteEnv.
…e_random_state_tensor is now a function passed to the States class, since inheritance can no longer be relied on to override the default method.
@josephdviviano josephdviviano added the enhancement New feature or request label Nov 29, 2023
@josephdviviano josephdviviano self-assigned this Nov 29, 2023
@marpaia marpaia mentioned this pull request Nov 30, 2023
@marpaia marpaia changed the base branch from master to rethinking_sampling November 30, 2023 09:27
Collaborator

marpaia commented Nov 30, 2023

FYI @josephdviviano I changed the "base" branch to rethinking_sampling instead of master. This allows us to view this PR's changes in isolation. When you merge #147, this PR will automatically update to be based off of master again! Alternatively, you can merge this PR into #147 and then merge #147 into master, which will have the same effect. I would suggest merging #147 first, though, and then iterating on / merging this PR in isolation 🙌

Collaborator

@marpaia marpaia left a comment


I think this is a fantastic PR, and the new API for defining environments represents a huge improvement to this codebase 🙌 I left a few comments about the naming of the make_xxx_class functions, which I think would improve the API as well, but in general, this is awesome. Thank you for making this change!

@@ -16,11 +16,6 @@
from gfn.containers.transitions import Transitions


def is_tensor(t) -> bool:
Collaborator


Nice change 👌

src/gfn/env.py Outdated
raise NotImplementedError

# Optionally implemented by the user when advanced functionality is required.
def make_States_class(self) -> type[States]:
Collaborator


I know that the class is States, but I would still advocate that this method be named make_states_class, to be more in line with PEP 8.

Collaborator Author


Yes, that makes sense to me. The naming bugged me as well. Thanks.

src/gfn/env.py Outdated
make_random_states_tensor = env.make_random_states_tensor

return DefaultEnvState

def make_Actions_class(self) -> type[Actions]:
Collaborator


Same comment here - I know that the class is Actions, but I would still advocate that this method be named make_actions_class, to be more in line with PEP 8.

src/gfn/env.py Outdated
n_actions = env.n_actions
device = env.device

return DiscreteEnvStates

def make_Actions_class(self) -> type[Actions]:
Collaborator


Same comment here - I know that the class is Actions, but I would still advocate that this method be named make_actions_class, to be more in line with PEP 8.

# if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad.
# p.grad.data.clamp_(
# -gradient_clip_value, gradient_clip_value
# ).nan_to_num_(0.0)
Collaborator


I would suggest deleting this code or adding a comment explaining why it's not included in the example.

Collaborator Author


Oops -- this is a mistake. It shouldn't be commented out. Good catch :)

@josephdviviano
Copy link
Collaborator Author

I implemented the renaming and also realized I needed to update the documentation which is now fixed.

Base automatically changed from rethinking_sampling to master February 16, 2024 18:16
@josephdviviano josephdviviano merged commit 3276492 into master Feb 16, 2024
3 checks passed
@josephdviviano josephdviviano deleted the no_more_class_factories branch February 16, 2024 19:18
@saleml saleml mentioned this pull request Feb 24, 2024