From b67b68e979ea1a9566cc3f726614505631df6a54 Mon Sep 17 00:00:00 2001 From: Michael Fishman Date: Thu, 29 Dec 2022 20:00:15 -0500 Subject: [PATCH] Type annotations Changed State to a typed NamedTuple --- pddlgym/core.py | 2 +- pddlgym/parser.py | 11 ++++++----- pddlgym/spaces.py | 41 +++++++++++++++++++++-------------------- pddlgym/structs.py | 20 +++++++++++++++----- 4 files changed, 43 insertions(+), 31 deletions(-) diff --git a/pddlgym/core.py b/pddlgym/core.py index 2f87e974..cebb9d43 100644 --- a/pddlgym/core.py +++ b/pddlgym/core.py @@ -546,7 +546,7 @@ def render(self, *args, **kwargs): if self._render: return self._render(self._state.literals, *args, **kwargs) - def _handle_derived_literals(self, state): + def _handle_derived_literals(self, state: State): # first remove any old derived literals since they're outdated to_remove = set() for lit in state.literals: diff --git a/pddlgym/parser.py b/pddlgym/parser.py index b315ca3d..2a1fdbed 100644 --- a/pddlgym/parser.py +++ b/pddlgym/parser.py @@ -6,7 +6,7 @@ DerivedPredicate, NoChange) import re - +from typing import List, Any, Optional, Dict FAST_DOWNWARD_STR = """ (define (problem {problem}) (:domain {domain}) @@ -33,7 +33,7 @@ class Operator: """Class to hold an operator. """ - def __init__(self, name, params, preconds, effects): + def __init__(self, name: str, params: List[Type], preconds: List[Literal], effects: List[Literal]): self.name = name # string self.params = params # list of structs.Type objects self.preconds = preconds # structs.Literal representing preconditions @@ -99,6 +99,7 @@ def _create_preconds_pddl_str(self, preconds): class PDDLParser: """PDDL parsing class. """ + predicates: Dict[str, Predicate] def _parse_into_literal(self, string, params, is_effect=False): """Parse the given string (representing either preconditions or effects) into a literal. Check against params to make sure typing is correct. @@ -317,13 +318,13 @@ def __init__(self, domain_name=None, types=None, type_hierarchy=None, predicates # String of domain name. self.domain_name = domain_name # Dict from type name -> structs.Type object. - self.types = types + self.types: Dict[str, Type] = types # Dict from supertype -> immediate subtypes. self.type_hierarchy = type_hierarchy # Dict from predicate name -> structs.Predicate object. - self.predicates = predicates + self.predicates: Dict[str, Predicate] = predicates # Dict from operator name -> Operator object (class defined above). - self.operators = operators + self.operators: Dict[str, Operator] = operators # Action predicate names (not part of standard PDDL) self.actions = actions # Constant objects, shared across problems diff --git a/pddlgym/spaces.py b/pddlgym/spaces.py index e5dbd828..5b7084d9 100644 --- a/pddlgym/spaces.py +++ b/pddlgym/spaces.py @@ -4,8 +4,8 @@ each episode, since objects, and therefore possible groundings, may change with each new PDDL problem. """ -from pddlgym.structs import LiteralConjunction, Literal, ground_literal -from pddlgym.parser import PDDLProblemParser +from pddlgym.structs import LiteralConjunction, Literal, ground_literal, Predicate, State +from pddlgym.parser import PDDLProblemParser, PDDLDomain from pddlgym.downward_translate.instantiate import explore as downward_explore from pddlgym.downward_translate.pddl_parser import open as downward_open from pddlgym.utils import nostdout @@ -15,14 +15,15 @@ import os import tempfile import itertools +from typing import Set, Collection, Callable TMP_PDDL_DIR = "/dev/shm" if os.path.exists("/dev/shm") else None class LiteralSpace(Space): - def __init__(self, predicates, - lit_valid_test=lambda state,lit: True, + def __init__(self, predicates: Collection[Predicate], + lit_valid_test: Callable[[State, Literal], bool] =lambda state,lit: True, type_hierarchy=None, type_to_parent_types=None): self.predicates = sorted(predicates) @@ -33,10 +34,10 @@ def __init__(self, predicates, self._type_to_parent_types = type_to_parent_types super().__init__() - def reset_initial_state(self, initial_state): + def reset_initial_state(self, initial_state: State): self._objects = None - def _update_objects_from_state(self, state): + def _update_objects_from_state(self, state: State): """Given a state, extract the objects and if they have changed, recompute all ground literals """ @@ -58,7 +59,7 @@ def _update_objects_from_state(self, state): self._objects = state.objects self._all_ground_literals = sorted(self._compute_all_ground_literals(state)) - def sample_literal(self, state): + def sample_literal(self, state: State) -> Literal: while True: num_lits = len(self._all_ground_literals) idx = self.np_random.choice(num_lits) @@ -67,19 +68,19 @@ def sample_literal(self, state): break return lit - def sample(self, state): + def sample(self, state: State) -> Literal: self._update_objects_from_state(state) return self.sample_literal(state) - def all_ground_literals(self, state, valid_only=True): + def all_ground_literals(self, state: State, valid_only: bool=True) -> set[Literal]: self._update_objects_from_state(state) if not valid_only: return set(self._all_ground_literals) return set(l for l in self._all_ground_literals \ if self._lit_valid_test(state, l)) - def _compute_all_ground_literals(self, state): - all_ground_literals = set() + def _compute_all_ground_literals(self, state: State) -> set[Literal]: + all_ground_literals: Set[Literal] = set() for predicate in self.predicates: choices = [self._type_to_objs[vt] for vt in predicate.var_types] for choice in itertools.product(*choices): @@ -95,7 +96,7 @@ class LiteralActionSpace(LiteralSpace): For now, assumes operators_as_actions. """ - def __init__(self, domain, predicates, + def __init__(self, domain: PDDLDomain, predicates: Collection[Predicate], type_hierarchy=None, type_to_parent_types=None): self.domain = domain self._initial_state = None @@ -116,11 +117,11 @@ def __init__(self, domain, predicates, type_hierarchy=type_hierarchy, type_to_parent_types=type_to_parent_types) - def reset_initial_state(self, initial_state): + def reset_initial_state(self, initial_state: State) -> None: super().reset_initial_state(initial_state) self._initial_state = initial_state - def _update_objects_from_state(self, state): + def _update_objects_from_state(self, state: State) -> None: # Check whether the objects have changed # If so, we need to recompute things if state.objects == self._objects: @@ -148,18 +149,18 @@ def _update_objects_from_state(self, state): self._ground_action_to_pos_preconds[ground_action] = pos_preconds self._ground_action_to_neg_preconds[ground_action] = neg_preconds - def sample_literal(self, state): + def sample_literal(self, state: State) -> Literal: valid_literals = self.all_ground_literals(state) valid_literals = list(sorted(valid_literals)) return valid_literals[self.np_random.choice(len(valid_literals))] - def sample(self, state): + def sample(self, state: State) -> Literal: return self.sample_literal(state) - def all_ground_literals(self, state, valid_only=True): + def all_ground_literals(self, state: State, valid_only: bool=True) -> set[Literal]: self._update_objects_from_state(state) assert valid_only, "The point of this class is to avoid the cross product!" - valid_literals = set() + valid_literals: Set[Literal] = set() for ground_action in self._all_ground_literals: pos_preconds = self._ground_action_to_pos_preconds[ground_action] if not pos_preconds.issubset(state.literals): @@ -170,7 +171,7 @@ def all_ground_literals(self, state, valid_only=True): valid_literals.add(ground_action) return valid_literals - def _compute_all_ground_literals(self, state): + def _compute_all_ground_literals(self, state: State) -> set[Literal]: """Call FastDownward's instantiator. """ # Generate temporary files to hand over to instantiator. @@ -193,7 +194,7 @@ def _compute_all_ground_literals(self, state): _, _, actions, _, _ = downward_explore(task) # Post-process to our representation. obj_name_to_obj = {obj.name: obj for obj in state.objects} - all_ground_literals = set() + all_ground_literals: Set[Literal] = set() for action in actions: name = action.name.strip().strip("()").split() pred_name, obj_names = name[0], name[1:] diff --git a/pddlgym/structs.py b/pddlgym/structs.py index 49943103..bce2d520 100644 --- a/pddlgym/structs.py +++ b/pddlgym/structs.py @@ -1,8 +1,12 @@ +from __future__ import annotations + """Python classes for common PDDL structures""" from collections import namedtuple import itertools import numpy as np +from typing import NamedTuple, FrozenSet, Collection, Any + ### PDDL Types, Objects, Variables ### class Type(str): @@ -475,24 +479,30 @@ def max(self): ### States ### # A State is a frozenset of ground literals and a frozenset of objects -class State(namedtuple("State", ["literals", "objects", "goal"])): - __slots__ = () +# class State(namedtuple("State", ["literals", "objects", "goal"])): +class State(NamedTuple): + literals: FrozenSet[Literal] + objects: FrozenSet[TypedEntity] + goal: Any + # MF: We can't easily type goal yet bc it can be a Literal, LiteralConjunction, etc + # Ideally we'd make a parent class for all of those. + - def with_literals(self, literals): + def with_literals(self, literals: Collection[Literal]) -> State: """ Return a new state that has the same objects and goal as the given one, but has the given set of literals instead of state.literals. """ return self._replace(literals=frozenset(literals)) - def with_objects(self, objects): + def with_objects(self, objects: Collection[TypedEntity]) -> State: """ Return a new state that has the same literals and goal as the given one, but has the given set of objects instead of state.objects. """ return self._replace(objects=frozenset(objects)) - def with_goal(self, goal): + def with_goal(self, goal) -> State: """ Return a new state that has the same literals and objects as the given one, but has the given goal instead of state.goal.