Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations #70

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pddlgym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def render(self, *args, **kwargs):
if self._render:
return self._render(self._state.literals, *args, **kwargs)

def _handle_derived_literals(self, state):
def _handle_derived_literals(self, state: State):
# first remove any old derived literals since they're outdated
to_remove = set()
for lit in state.literals:
Expand Down
11 changes: 6 additions & 5 deletions pddlgym/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
DerivedPredicate, NoChange)

import re

from typing import List, Any, Optional, Dict

FAST_DOWNWARD_STR = """
(define (problem {problem}) (:domain {domain})
Expand All @@ -33,7 +33,7 @@
class Operator:
"""Class to hold an operator.
"""
def __init__(self, name, params, preconds, effects):
def __init__(self, name: str, params: List[Type], preconds: List[Literal], effects: List[Literal]):
self.name = name # string
self.params = params # list of structs.Type objects
self.preconds = preconds # structs.Literal representing preconditions
Expand Down Expand Up @@ -99,6 +99,7 @@ def _create_preconds_pddl_str(self, preconds):
class PDDLParser:
"""PDDL parsing class.
"""
predicates: Dict[str, Predicate]
def _parse_into_literal(self, string, params, is_effect=False):
"""Parse the given string (representing either preconditions or effects)
into a literal. Check against params to make sure typing is correct.
Expand Down Expand Up @@ -317,13 +318,13 @@ def __init__(self, domain_name=None, types=None, type_hierarchy=None, predicates
# String of domain name.
self.domain_name = domain_name
# Dict from type name -> structs.Type object.
self.types = types
self.types: Dict[str, Type] = types
# Dict from supertype -> immediate subtypes.
self.type_hierarchy = type_hierarchy
# Dict from predicate name -> structs.Predicate object.
self.predicates = predicates
self.predicates: Dict[str, Predicate] = predicates
# Dict from operator name -> Operator object (class defined above).
self.operators = operators
self.operators: Dict[str, Operator] = operators
# Action predicate names (not part of standard PDDL)
self.actions = actions
# Constant objects, shared across problems
Expand Down
41 changes: 21 additions & 20 deletions pddlgym/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
each episode, since objects, and therefore possible
groundings, may change with each new PDDL problem.
"""
from pddlgym.structs import LiteralConjunction, Literal, ground_literal
from pddlgym.parser import PDDLProblemParser
from pddlgym.structs import LiteralConjunction, Literal, ground_literal, Predicate, State
from pddlgym.parser import PDDLProblemParser, PDDLDomain
from pddlgym.downward_translate.instantiate import explore as downward_explore
from pddlgym.downward_translate.pddl_parser import open as downward_open
from pddlgym.utils import nostdout
Expand All @@ -15,14 +15,15 @@
import os
import tempfile
import itertools
from typing import Set, Collection, Callable

TMP_PDDL_DIR = "/dev/shm" if os.path.exists("/dev/shm") else None


class LiteralSpace(Space):

def __init__(self, predicates,
lit_valid_test=lambda state,lit: True,
def __init__(self, predicates: Collection[Predicate],
lit_valid_test: Callable[[State, Literal], bool] =lambda state,lit: True,
type_hierarchy=None,
type_to_parent_types=None):
self.predicates = sorted(predicates)
Expand All @@ -33,10 +34,10 @@ def __init__(self, predicates,
self._type_to_parent_types = type_to_parent_types
super().__init__()

def reset_initial_state(self, initial_state):
def reset_initial_state(self, initial_state: State):
self._objects = None

def _update_objects_from_state(self, state):
def _update_objects_from_state(self, state: State):
"""Given a state, extract the objects and if they have changed,
recompute all ground literals
"""
Expand All @@ -58,7 +59,7 @@ def _update_objects_from_state(self, state):
self._objects = state.objects
self._all_ground_literals = sorted(self._compute_all_ground_literals(state))

def sample_literal(self, state):
def sample_literal(self, state: State) -> Literal:
while True:
num_lits = len(self._all_ground_literals)
idx = self.np_random.choice(num_lits)
Expand All @@ -67,19 +68,19 @@ def sample_literal(self, state):
break
return lit

def sample(self, state):
def sample(self, state: State) -> Literal:
self._update_objects_from_state(state)
return self.sample_literal(state)

def all_ground_literals(self, state, valid_only=True):
def all_ground_literals(self, state: State, valid_only: bool=True) -> set[Literal]:
self._update_objects_from_state(state)
if not valid_only:
return set(self._all_ground_literals)
return set(l for l in self._all_ground_literals \
if self._lit_valid_test(state, l))

def _compute_all_ground_literals(self, state):
all_ground_literals = set()
def _compute_all_ground_literals(self, state: State) -> set[Literal]:
all_ground_literals: Set[Literal] = set()
for predicate in self.predicates:
choices = [self._type_to_objs[vt] for vt in predicate.var_types]
for choice in itertools.product(*choices):
Expand All @@ -95,7 +96,7 @@ class LiteralActionSpace(LiteralSpace):

For now, assumes operators_as_actions.
"""
def __init__(self, domain, predicates,
def __init__(self, domain: PDDLDomain, predicates: Collection[Predicate],
type_hierarchy=None, type_to_parent_types=None):
self.domain = domain
self._initial_state = None
Expand All @@ -116,11 +117,11 @@ def __init__(self, domain, predicates,
type_hierarchy=type_hierarchy,
type_to_parent_types=type_to_parent_types)

def reset_initial_state(self, initial_state):
def reset_initial_state(self, initial_state: State) -> None:
super().reset_initial_state(initial_state)
self._initial_state = initial_state

def _update_objects_from_state(self, state):
def _update_objects_from_state(self, state: State) -> None:
# Check whether the objects have changed
# If so, we need to recompute things
if state.objects == self._objects:
Expand Down Expand Up @@ -148,18 +149,18 @@ def _update_objects_from_state(self, state):
self._ground_action_to_pos_preconds[ground_action] = pos_preconds
self._ground_action_to_neg_preconds[ground_action] = neg_preconds

def sample_literal(self, state):
def sample_literal(self, state: State) -> Literal:
valid_literals = self.all_ground_literals(state)
valid_literals = list(sorted(valid_literals))
return valid_literals[self.np_random.choice(len(valid_literals))]

def sample(self, state):
def sample(self, state: State) -> Literal:
return self.sample_literal(state)

def all_ground_literals(self, state, valid_only=True):
def all_ground_literals(self, state: State, valid_only: bool=True) -> set[Literal]:
self._update_objects_from_state(state)
assert valid_only, "The point of this class is to avoid the cross product!"
valid_literals = set()
valid_literals: Set[Literal] = set()
for ground_action in self._all_ground_literals:
pos_preconds = self._ground_action_to_pos_preconds[ground_action]
if not pos_preconds.issubset(state.literals):
Expand All @@ -170,7 +171,7 @@ def all_ground_literals(self, state, valid_only=True):
valid_literals.add(ground_action)
return valid_literals

def _compute_all_ground_literals(self, state):
def _compute_all_ground_literals(self, state: State) -> set[Literal]:
"""Call FastDownward's instantiator.
"""
# Generate temporary files to hand over to instantiator.
Expand All @@ -193,7 +194,7 @@ def _compute_all_ground_literals(self, state):
_, _, actions, _, _ = downward_explore(task)
# Post-process to our representation.
obj_name_to_obj = {obj.name: obj for obj in state.objects}
all_ground_literals = set()
all_ground_literals: Set[Literal] = set()
for action in actions:
name = action.name.strip().strip("()").split()
pred_name, obj_names = name[0], name[1:]
Expand Down
20 changes: 15 additions & 5 deletions pddlgym/structs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

"""Python classes for common PDDL structures"""
from collections import namedtuple
import itertools
import numpy as np

from typing import NamedTuple, FrozenSet, Collection, Any


### PDDL Types, Objects, Variables ###
class Type(str):
Expand Down Expand Up @@ -475,24 +479,30 @@ def max(self):
### States ###

# A State is a frozenset of ground literals and a frozenset of objects
class State(namedtuple("State", ["literals", "objects", "goal"])):
__slots__ = ()
# class State(namedtuple("State", ["literals", "objects", "goal"])):
class State(NamedTuple):
literals: FrozenSet[Literal]
objects: FrozenSet[TypedEntity]
goal: Any
# MF: We can't easily type goal yet bc it can be a Literal, LiteralConjunction, etc
# Ideally we'd make a parent class for all of those.


def with_literals(self, literals):
def with_literals(self, literals: Collection[Literal]) -> State:
"""
Return a new state that has the same objects and goal as the given one,
but has the given set of literals instead of state.literals.
"""
return self._replace(literals=frozenset(literals))

def with_objects(self, objects):
def with_objects(self, objects: Collection[TypedEntity]) -> State:
"""
Return a new state that has the same literals and goal as the given one,
but has the given set of objects instead of state.objects.
"""
return self._replace(objects=frozenset(objects))

def with_goal(self, goal):
def with_goal(self, goal) -> State:
"""
Return a new state that has the same literals and objects as the given
one, but has the given goal instead of state.goal.
Expand Down