Skip to content

Commit

Permalink
fix: update import statements and clean up PyBulletGrow
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Dec 28, 2024
1 parent 316bb51 commit 27f20c2
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 17 deletions.
4 changes: 2 additions & 2 deletions predicators/envs/pybullet_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
create_single_arm_pybullet_robot
from predicators.settings import CFG
from predicators.structs import Action, Array, EnvironmentTask, Object, \
Predicate, State, Type
from predicators.utils import NSPredicate, RawState, VLMQuery
Predicate, State, Type, NSPredicate
from predicators.utils import RawState, VLMQuery


class PyBulletBalanceEnv(PyBulletEnv, BalanceEnv):
Expand Down
6 changes: 3 additions & 3 deletions predicators/envs/pybullet_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from predicators import utils
from predicators.envs import BaseEnv
from predicators.pybullet_helpers.camera import create_gui_connection
from predicators.pybullet_helpers.geometry import Pose3D, Quaternion
from predicators.pybullet_helpers.geometry import Pose, Pose3D, Quaternion
from predicators.pybullet_helpers.link import get_link_state
from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot, \
create_single_arm_pybullet_robot
from predicators.settings import CFG
from predicators.structs import Action, Array, EnvironmentTask, Mask, Object, \
Observation, Pose, State, Video
Observation, State, Video
from predicators.utils import PyBulletState


Expand Down Expand Up @@ -173,7 +173,7 @@ def get_pos_feature(state, feature_name):
rz = get_pos_feature(state, "z")

# EE Orientation
_, default_tilt, default_wrist = p.getQuaternionFromEuler(
_, default_tilt, default_wrist = p.getEulerFromQuaternion(
self.get_robot_ee_home_orn())
if "tilt" in self._robot.type.feature_names:
tilt = state.get(self._robot, "tilt")
Expand Down
11 changes: 5 additions & 6 deletions predicators/envs/pybullet_grow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \
"""
python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \
--num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \
--sesame_max_skeletons_optimized 1 --make_failure_videos --video_fps 20 \
--pybullet_camera_height 900 --pybullet_camera_width 900
Expand Down Expand Up @@ -79,8 +79,6 @@ class PyBulletGrowEnv(PyBulletEnv):
_camera_target: ClassVar[Pose3D] = (0.75, 1.25, 0.42)

def __init__(self, use_gui: bool = True) -> None:
super().__init__(use_gui)

# Define Types:
self._robot_type = Type("robot",
["x", "y", "z", "fingers", "tilt", "wrist"])
Expand All @@ -95,6 +93,8 @@ def __init__(self, use_gui: bool = True) -> None:
self._red_jug = Object("red_jug", self._jug_type)
self._blue_jug = Object("blue_jug", self._jug_type)

super().__init__(use_gui)

# Define Predicates
self._Grown = Predicate("Grown", [self._cup_type], self._Grown_holds)
self._Holding = Predicate("Holding",
Expand Down Expand Up @@ -213,7 +213,7 @@ def _get_state(self) -> State:
"x": rx,
"y": ry,
"z": rz,
"fingers": self._fingers_joint_to_state(rf),
"fingers": self._fingers_joint_to_state(self._pybullet_robot, rf),
"tilt": tilt,
"wrist": wrist
}
Expand Down Expand Up @@ -355,7 +355,6 @@ def step(self, action: Action, render_obs: bool = False) -> State:
cx = next_state.get(cup_obj, "x")
cy = next_state.get(cup_obj, "y")
dist = np.hypot(jug_x - cx, jug_y - cy)
logging.debug(f"Dist to cup {cup_obj.name}: {dist}")
if dist < 0.13: # "over" the cup
# cup_color = next_state.get(cup_obj, "color")
# if abs(cup_color - jug_color) < 0.1:
Expand Down
14 changes: 8 additions & 6 deletions predicators/pybullet_helpers/objects.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

import pybullet as p

from predicators import utils
from predicators.envs.pybullet_env import PybulletEnv
from predicators.envs.pybullet_env import PyBulletEnv
from predicators.pybullet_helpers.geometry import Pose, Pose3D, Quaternion

# import numpy as np
Expand All @@ -12,8 +12,8 @@

def create_object(asset_path: str,
position: Pose3D = (0, 0, 0),
orientation: Quaternion = PybulletEnv._default_orn,
color: Optional[Sequence[float, float, float, float]] = None,
orientation: Quaternion = PyBulletEnv._default_orn,
color: Optional[Tuple[float, float, float, float]] = None,
scale: float = 0.2,
use_fixed_base: bool = False,
physics_client_id: int = 0) -> int:
Expand All @@ -27,6 +27,8 @@ def create_object(asset_path: str,
orientation,
physicsClientId=physics_client_id)
if color is not None:
# Change color of the base link (link_id = -1)
p.changeVisualShape(obj_id, -1, rgbaColor=color)
# Change color of all links
for link_id in range(p.getNumJoints(obj_id)):
p.changeVisualShape(obj_id, link_id, rgbaColor=color)
Expand All @@ -36,8 +38,8 @@ def create_object(asset_path: str,

def update_object(obj_id: int,
position: Pose3D,
orientation: Quaternion = PybulletEnv._default_orn,
color: Optional[Sequence[float, float, float, float]] = None,
orientation: Quaternion = PyBulletEnv._default_orn,
color: Optional[Tuple[float, float, float, float]] = None,
physics_client_id: int = 0) -> None:
"""Update the position and orientation of an object."""
p.resetBasePositionAndOrientation(obj_id,
Expand Down
66 changes: 66 additions & 0 deletions predicators/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,72 @@ class VLMPredicate(Predicate):
"""
get_vlm_query_str: Callable[[Sequence[Object]], str]

# @dataclass(frozen=True, repr=False)
class NSPredicate(Predicate):
"""Neuro-Symbolic Predicate."""

def __init__(
self, name: str, types: Sequence[Type],
_classifier: Callable[[RawState, Sequence[Object]], bool]) -> None:
self._original_classifier = _classifier
super().__init__(name, types, _MemoizedClassifier(_classifier))

@cached_property
def _hash(self) -> int:
# return hash(str(self))
return hash(self.name + str(self.types))

def __hash__(self) -> int:
return self._hash

def classifier_str(self) -> str:
"""Get a string representation of the classifier."""
clf_str = getsource(self._original_classifier)
clf_str = textwrap.dedent(clf_str)
clf_str = clf_str.replace("@staticmethod\n", "")
return clf_str

@dataclass(frozen=True, order=False, repr=False)
class ConceptPredicate(Predicate):
"""Struct defining a concept predicate"""
name: str
types: Sequence[Type]
# The classifier takes in a complete state and a sequence of objects
# representing the arguments. These objects should be the only ones
# treated "specially" by the classifier.
_classifier: Callable[[Set[GroundAtom], Sequence[Object]],
bool] = field(compare=False)
untransformed_predicate: Optional[Predicate] = field(default=None,
compare=False)
auxiliary_concepts: Optional[Set[ConceptPredicate]] = field(default=None,
compare=False)

def update_auxiliary_concepts(self,
auxiliary_concepts: Set[ConceptPredicate]) -> ConceptPredicate:
"""Create a new ConceptPredicate with updated auxiliary_concepts."""
return replace(self, auxiliary_concepts=auxiliary_concepts)


@cached_property
def _hash(self) -> int:
# return hash(str(self))
return hash(self.name + str(self.types))

def holds(self, state: Set[GroundAtom], objects: Sequence[Object]) -> bool:
"""Public method for calling the classifier.
Performs type checking first.
"""
assert len(objects) == self.arity
for obj, pred_type in zip(objects, self.types):
assert isinstance(obj, Object)
assert obj.is_instance(pred_type)
return self._classifier(state, objects)

def _negated_classifier(self, state: Set[GroundAtom],
objects: Sequence[Object]) -> bool:
# Separate this into a named function for pickling reasons.
return not self._classifier(state, objects)

@dataclass(frozen=True, repr=False, eq=False)
class _Atom:
Expand Down
Loading

0 comments on commit 27f20c2

Please sign in to comment.