Skip to content

Commit

Permalink
Type annotations
Browse files Browse the repository at this point in the history
Changed State to a typed NamedTuple
  • Loading branch information
MichaelJFishman committed Dec 30, 2022
1 parent ec37dd9 commit 2712bfd
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pddlgym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def render(self, *args, **kwargs):
if self._render:
return self._render(self._state.literals, *args, **kwargs)

def _handle_derived_literals(self, state):
def _handle_derived_literals(self, state: State):
# first remove any old derived literals since they're outdated
to_remove = set()
for lit in state.literals:
Expand Down
11 changes: 6 additions & 5 deletions pddlgym/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
DerivedPredicate, NoChange)

import re

from typing import List, Any, Optional, Dict

FAST_DOWNWARD_STR = """
(define (problem {problem}) (:domain {domain})
Expand All @@ -33,7 +33,7 @@
class Operator:
"""Class to hold an operator.
"""
def __init__(self, name, params, preconds, effects):
def __init__(self, name: str, params: List[Type], preconds: List[Literal], effects: List[Literal]):
self.name = name # string
self.params = params # list of structs.Type objects
self.preconds = preconds # structs.Literal representing preconditions
Expand Down Expand Up @@ -99,6 +99,7 @@ def _create_preconds_pddl_str(self, preconds):
class PDDLParser:
"""PDDL parsing class.
"""
predicates: Dict[str, Predicate]
def _parse_into_literal(self, string, params, is_effect=False):
"""Parse the given string (representing either preconditions or effects)
into a literal. Check against params to make sure typing is correct.
Expand Down Expand Up @@ -317,13 +318,13 @@ def __init__(self, domain_name=None, types=None, type_hierarchy=None, predicates
# String of domain name.
self.domain_name = domain_name
# Dict from type name -> structs.Type object.
self.types = types
self.types: Dict[str, Type] = types
# Dict from supertype -> immediate subtypes.
self.type_hierarchy = type_hierarchy
# Dict from predicate name -> structs.Predicate object.
self.predicates = predicates
self.predicates: Dict[str, Predicate] = predicates
# Dict from operator name -> Operator object (class defined above).
self.operators = operators
self.operators: Dict[str, Operator] = operators
# Action predicate names (not part of standard PDDL)
self.actions = actions
# Constant objects, shared across problems
Expand Down
41 changes: 21 additions & 20 deletions pddlgym/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
each episode, since objects, and therefore possible
groundings, may change with each new PDDL problem.
"""
from pddlgym.structs import LiteralConjunction, Literal, ground_literal
from pddlgym.parser import PDDLProblemParser
from pddlgym.structs import LiteralConjunction, Literal, ground_literal, Predicate, State
from pddlgym.parser import PDDLProblemParser, PDDLDomain
from pddlgym.downward_translate.instantiate import explore as downward_explore
from pddlgym.downward_translate.pddl_parser import open as downward_open
from pddlgym.utils import nostdout
Expand All @@ -15,14 +15,15 @@
import os
import tempfile
import itertools
from typing import Set, Collection, Callable

TMP_PDDL_DIR = "/dev/shm" if os.path.exists("/dev/shm") else None


class LiteralSpace(Space):

def __init__(self, predicates,
lit_valid_test=lambda state,lit: True,
def __init__(self, predicates: Collection[Predicate],
lit_valid_test: Callable[[State, Literal], bool] =lambda state,lit: True,
type_hierarchy=None,
type_to_parent_types=None):
self.predicates = sorted(predicates)
Expand All @@ -33,10 +34,10 @@ def __init__(self, predicates,
self._type_to_parent_types = type_to_parent_types
super().__init__()

def reset_initial_state(self, initial_state):
def reset_initial_state(self, initial_state: State):
self._objects = None

def _update_objects_from_state(self, state):
def _update_objects_from_state(self, state: State):
"""Given a state, extract the objects and if they have changed,
recompute all ground literals
"""
Expand All @@ -58,7 +59,7 @@ def _update_objects_from_state(self, state):
self._objects = state.objects
self._all_ground_literals = sorted(self._compute_all_ground_literals(state))

def sample_literal(self, state):
def sample_literal(self, state: State) -> Literal:
while True:
num_lits = len(self._all_ground_literals)
idx = self.np_random.choice(num_lits)
Expand All @@ -67,19 +68,19 @@ def sample_literal(self, state):
break
return lit

def sample(self, state):
def sample(self, state: State) -> Literal:
self._update_objects_from_state(state)
return self.sample_literal(state)

def all_ground_literals(self, state, valid_only=True):
def all_ground_literals(self, state: State, valid_only: bool=True) -> set[Literal]:
self._update_objects_from_state(state)
if not valid_only:
return set(self._all_ground_literals)
return set(l for l in self._all_ground_literals \
if self._lit_valid_test(state, l))

def _compute_all_ground_literals(self, state):
all_ground_literals = set()
def _compute_all_ground_literals(self, state: State) -> set[Literal]:
all_ground_literals: Set[Literal] = set()
for predicate in self.predicates:
choices = [self._type_to_objs[vt] for vt in predicate.var_types]
for choice in itertools.product(*choices):
Expand All @@ -95,7 +96,7 @@ class LiteralActionSpace(LiteralSpace):
For now, assumes operators_as_actions.
"""
def __init__(self, domain, predicates,
def __init__(self, domain: PDDLDomain, predicates: Collection[Predicate],
type_hierarchy=None, type_to_parent_types=None):
self.domain = domain
self._initial_state = None
Expand All @@ -116,11 +117,11 @@ def __init__(self, domain, predicates,
type_hierarchy=type_hierarchy,
type_to_parent_types=type_to_parent_types)

def reset_initial_state(self, initial_state):
def reset_initial_state(self, initial_state: State) -> None:
super().reset_initial_state(initial_state)
self._initial_state = initial_state

def _update_objects_from_state(self, state):
def _update_objects_from_state(self, state: State) -> None:
# Check whether the objects have changed
# If so, we need to recompute things
if state.objects == self._objects:
Expand Down Expand Up @@ -148,18 +149,18 @@ def _update_objects_from_state(self, state):
self._ground_action_to_pos_preconds[ground_action] = pos_preconds
self._ground_action_to_neg_preconds[ground_action] = neg_preconds

def sample_literal(self, state):
def sample_literal(self, state: State) -> Literal:
valid_literals = self.all_ground_literals(state)
valid_literals = list(sorted(valid_literals))
return valid_literals[self.np_random.choice(len(valid_literals))]

def sample(self, state):
def sample(self, state: State) -> Literal:
return self.sample_literal(state)

def all_ground_literals(self, state, valid_only=True):
def all_ground_literals(self, state: State, valid_only: bool=True) -> set[Literal]:
self._update_objects_from_state(state)
assert valid_only, "The point of this class is to avoid the cross product!"
valid_literals = set()
valid_literals: Set[Literal] = set()
for ground_action in self._all_ground_literals:
pos_preconds = self._ground_action_to_pos_preconds[ground_action]
if not pos_preconds.issubset(state.literals):
Expand All @@ -170,7 +171,7 @@ def all_ground_literals(self, state, valid_only=True):
valid_literals.add(ground_action)
return valid_literals

def _compute_all_ground_literals(self, state):
def _compute_all_ground_literals(self, state: State) -> set[Literal]:
"""Call FastDownward's instantiator.
"""
# Generate temporary files to hand over to instantiator.
Expand All @@ -193,7 +194,7 @@ def _compute_all_ground_literals(self, state):
_, _, actions, _, _ = downward_explore(task)
# Post-process to our representation.
obj_name_to_obj = {obj.name: obj for obj in state.objects}
all_ground_literals = set()
all_ground_literals: Set[Literal] = set()
for action in actions:
name = action.name.strip().strip("()").split()
pred_name, obj_names = name[0], name[1:]
Expand Down
20 changes: 15 additions & 5 deletions pddlgym/structs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

"""Python classes for common PDDL structures"""
from collections import namedtuple
import itertools
import numpy as np

from typing import NamedTuple, FrozenSet, Collection, Any


### PDDL Types, Objects, Variables ###
class Type(str):
Expand Down Expand Up @@ -475,24 +479,30 @@ def max(self):
### States ###

# A State is a frozenset of ground literals and a frozenset of objects
class State(namedtuple("State", ["literals", "objects", "goal"])):
__slots__ = ()
# class State(namedtuple("State", ["literals", "objects", "goal"])):
class State(NamedTuple):
literals: FrozenSet[Literal]
objects: FrozenSet[TypedEntity]
goal: Any
# MF: We can't easily type goal yet bc it can be a Literal, LiteralConjunction, etc
# Ideally we'd make a parent class for all of those.


def with_literals(self, literals):
def with_literals(self, literals: Collection[Literal]) -> State:
"""
Return a new state that has the same objects and goal as the given one,
but has the given set of literals instead of state.literals.
"""
return self._replace(literals=frozenset(literals))

def with_objects(self, objects):
def with_objects(self, objects: Collection[TypedEntity]) -> State:
"""
Return a new state that has the same literals and goal as the given one,
but has the given set of objects instead of state.objects.
"""
return self._replace(objects=frozenset(objects))

def with_goal(self, goal):
def with_goal(self, goal) -> State:
"""
Return a new state that has the same literals and objects as the given
one, but has the given goal instead of state.goal.
Expand Down

0 comments on commit 2712bfd

Please sign in to comment.