diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 551ab335f..2ced7c94d 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -17,8 +17,8 @@ create_single_arm_pybullet_robot from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, Object, \ - Predicate, State, Type -from predicators.utils import NSPredicate, RawState, VLMQuery + Predicate, State, Type, NSPredicate +from predicators.utils import RawState, VLMQuery class PyBulletBalanceEnv(PyBulletEnv, BalanceEnv): diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index 4b12b5032..4ff3adb3b 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -15,13 +15,13 @@ from predicators import utils from predicators.envs import BaseEnv from predicators.pybullet_helpers.camera import create_gui_connection -from predicators.pybullet_helpers.geometry import Pose3D, Quaternion +from predicators.pybullet_helpers.geometry import Pose, Pose3D, Quaternion from predicators.pybullet_helpers.link import get_link_state from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot, \ create_single_arm_pybullet_robot from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, Mask, Object, \ - Observation, Pose, State, Video + Observation, State, Video from predicators.utils import PyBulletState @@ -173,7 +173,7 @@ def get_pos_feature(state, feature_name): rz = get_pos_feature(state, "z") # EE Orientation - _, default_tilt, default_wrist = p.getQuaternionFromEuler( + _, default_tilt, default_wrist = p.getEulerFromQuaternion( self.get_robot_ee_home_orn()) if "tilt" in self._robot.type.feature_names: tilt = state.get(self._robot, "tilt") diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index 1ea9cdf65..66f460255 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -1,5 +1,5 @@ -"""python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \ - +""" +python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \ --num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \ --sesame_max_skeletons_optimized 1 --make_failure_videos --video_fps 20 \ --pybullet_camera_height 900 --pybullet_camera_width 900 @@ -79,8 +79,6 @@ class PyBulletGrowEnv(PyBulletEnv): _camera_target: ClassVar[Pose3D] = (0.75, 1.25, 0.42) def __init__(self, use_gui: bool = True) -> None: - super().__init__(use_gui) - # Define Types: self._robot_type = Type("robot", ["x", "y", "z", "fingers", "tilt", "wrist"]) @@ -95,6 +93,8 @@ def __init__(self, use_gui: bool = True) -> None: self._red_jug = Object("red_jug", self._jug_type) self._blue_jug = Object("blue_jug", self._jug_type) + super().__init__(use_gui) + # Define Predicates self._Grown = Predicate("Grown", [self._cup_type], self._Grown_holds) self._Holding = Predicate("Holding", @@ -213,7 +213,7 @@ def _get_state(self) -> State: "x": rx, "y": ry, "z": rz, - "fingers": self._fingers_joint_to_state(rf), + "fingers": self._fingers_joint_to_state(self._pybullet_robot, rf), "tilt": tilt, "wrist": wrist } @@ -355,7 +355,6 @@ def step(self, action: Action, render_obs: bool = False) -> State: cx = next_state.get(cup_obj, "x") cy = next_state.get(cup_obj, "y") dist = np.hypot(jug_x - cx, jug_y - cy) - logging.debug(f"Dist to cup {cup_obj.name}: {dist}") if dist < 0.13: # "over" the cup # cup_color = next_state.get(cup_obj, "color") # if abs(cup_color - jug_color) < 0.1: diff --git a/predicators/pybullet_helpers/objects.py b/predicators/pybullet_helpers/objects.py index 85fb48f06..8c1157d4b 100644 --- a/predicators/pybullet_helpers/objects.py +++ b/predicators/pybullet_helpers/objects.py @@ -1,9 +1,9 @@ -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PybulletEnv +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose, Pose3D, Quaternion # import numpy as np @@ -12,8 +12,8 @@ def create_object(asset_path: str, position: Pose3D = (0, 0, 0), - orientation: Quaternion = PybulletEnv._default_orn, - color: Optional[Sequence[float, float, float, float]] = None, + orientation: Quaternion = PyBulletEnv._default_orn, + color: Optional[Tuple[float, float, float, float]] = None, scale: float = 0.2, use_fixed_base: bool = False, physics_client_id: int = 0) -> int: @@ -27,6 +27,8 @@ def create_object(asset_path: str, orientation, physicsClientId=physics_client_id) if color is not None: + # Change color of the base link (link_id = -1) + p.changeVisualShape(obj_id, -1, rgbaColor=color) # Change color of all links for link_id in range(p.getNumJoints(obj_id)): p.changeVisualShape(obj_id, link_id, rgbaColor=color) @@ -36,8 +38,8 @@ def create_object(asset_path: str, def update_object(obj_id: int, position: Pose3D, - orientation: Quaternion = PybulletEnv._default_orn, - color: Optional[Sequence[float, float, float, float]] = None, + orientation: Quaternion = PyBulletEnv._default_orn, + color: Optional[Tuple[float, float, float, float]] = None, physics_client_id: int = 0) -> None: """Update the position and orientation of an object.""" p.resetBasePositionAndOrientation(obj_id, diff --git a/predicators/structs.py b/predicators/structs.py index 70a1bf4d6..c94a04a4d 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -360,6 +360,72 @@ class VLMPredicate(Predicate): """ get_vlm_query_str: Callable[[Sequence[Object]], str] +# @dataclass(frozen=True, repr=False) +class NSPredicate(Predicate): + """Neuro-Symbolic Predicate.""" + + def __init__( + self, name: str, types: Sequence[Type], + _classifier: Callable[[RawState, Sequence[Object]], bool]) -> None: + self._original_classifier = _classifier + super().__init__(name, types, _MemoizedClassifier(_classifier)) + + @cached_property + def _hash(self) -> int: + # return hash(str(self)) + return hash(self.name + str(self.types)) + + def __hash__(self) -> int: + return self._hash + + def classifier_str(self) -> str: + """Get a string representation of the classifier.""" + clf_str = getsource(self._original_classifier) + clf_str = textwrap.dedent(clf_str) + clf_str = clf_str.replace("@staticmethod\n", "") + return clf_str + +@dataclass(frozen=True, order=False, repr=False) +class ConceptPredicate(Predicate): + """Struct defining a concept predicate""" + name: str + types: Sequence[Type] + # The classifier takes in a complete state and a sequence of objects + # representing the arguments. These objects should be the only ones + # treated "specially" by the classifier. + _classifier: Callable[[Set[GroundAtom], Sequence[Object]], + bool] = field(compare=False) + untransformed_predicate: Optional[Predicate] = field(default=None, + compare=False) + auxiliary_concepts: Optional[Set[ConceptPredicate]] = field(default=None, + compare=False) + + def update_auxiliary_concepts(self, + auxiliary_concepts: Set[ConceptPredicate]) -> ConceptPredicate: + """Create a new ConceptPredicate with updated auxiliary_concepts.""" + return replace(self, auxiliary_concepts=auxiliary_concepts) + + + @cached_property + def _hash(self) -> int: + # return hash(str(self)) + return hash(self.name + str(self.types)) + + def holds(self, state: Set[GroundAtom], objects: Sequence[Object]) -> bool: + """Public method for calling the classifier. + + Performs type checking first. + """ + assert len(objects) == self.arity + for obj, pred_type in zip(objects, self.types): + assert isinstance(obj, Object) + assert obj.is_instance(pred_type) + return self._classifier(state, objects) + + def _negated_classifier(self, state: Set[GroundAtom], + objects: Sequence[Object]) -> bool: + # Separate this into a named function for pickling reasons. + return not self._classifier(state, objects) @dataclass(frozen=True, repr=False, eq=False) class _Atom: diff --git a/predicators/utils.py b/predicators/utils.py index 79793ee67..66f59dfec 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1100,6 +1100,264 @@ def add_images_and_masks(self, unlabeled_image: PIL.Image.Image, BoundingBox = namedtuple('BoundingBox', 'left lower right upper') +@dataclass +class RawState(PyBulletState): + state_image: PIL.Image.Image = None + obj_mask_dict: Dict[Object, Mask] = field(default_factory=dict) + labeled_image: Optional[PIL.Image.Image] = None + option_history: Optional[List[str]] = None + bbox_features: Dict[Object, np.ndarray] = field( + default_factory=lambda: defaultdict(lambda: np.zeros(4))) + prev_state: Optional[RawState] = None + next_state: Optional[RawState] = None + + def __hash__(self): + # Convert the dictionary to a tuple of key-value pairs and hash it + # data_hash = hash(tuple(sorted(self.data.items()))) + data_tuple = tuple((k, tuple(v)) for k, v in sorted(self.data.items())) + if self.simulator_state is not None: + data_tuple += tuple(self.simulator_state) + data_hash = hash(data_tuple) + # # Hash the simulator_state + # simulator_state_hash = hash(self.simulator_state) + # Combine the two hashes + # return hash((data_hash, simulator_state_hash)) + return data_hash + + def evaluate_simple_assertion( + self, assertion: str, image: Tuple[BoundingBox, + Sequence[Object]]) -> VLMQuery: + """Given an assertion and an image, queries a VLM and returns whether + the assertion is true or false.""" + bbox, objs = image + return VLMQuery(assertion, bbox, objs) + + def generate_previous_option_message(self) -> str: + """Generate the message for the previous option.""" + assert self.option_history is not None + msg = "Evaluate the truth value of the following assertions in the "\ + "current state as depicted by the image" + if CFG.nsp_pred_include_prev_image_in_prompt and \ + self.prev_state is not None: + msg += " labeled with 'curr. state'" + if CFG.nsp_pred_include_state_str_in_prompt: + msg += " and the information below" + + msg += ".\n" + + if CFG.nsp_pred_include_state_str_in_prompt: + msg += f"We have the object positions and the robot's "\ + "proprioception:\n" + msg += self.dict_str(indent=2, + object_features=False, + use_object_id=True, + position_proprio_features=True) + msg += "\n" + + if len(self.option_history) == 0: + msg += "For context, this is at the beginning of a task, before "\ + "the robot has done anything.\n" + else: + # return f"For context, this is right after the robot has "\ + # f"successfully executed its [{', '.join(self.option_history[-2:])}]"\ + # f" option sequence." + # msg = f"For context, this state is right after the robot has "\ + # f"successfully executed its {self.option_history[-1]} action." + msg += "For context, the state is right after the robot has"\ + " successfully executed the action "\ + f"{self.option_history[-1]}." + if CFG.nsp_pred_include_state_str_in_prompt: + if self.prev_state is not None: + msg += " The object position and robot proprioception "\ + "before executing the action is:\n" + msg += self.prev_state.dict_str( + indent=2, + object_features=False, + use_object_id=True, + position_proprio_features=True) + msg += "\n" + if CFG.nsp_pred_include_prev_image_in_prompt: + msg += " The state before executing the action is depicted"\ + " by the image labeled with 'prev. state'." + msg += " Please carefully examine the images depicting the "\ + "'prev. state' and 'curr. state' before making a judgment." + msg += "\n" + msg += "The assertions to evaluate are:" + return msg + + def add_bbox_features(self) -> None: + """Add the features about the bounding box to the objects.""" + for obj, mask in self.obj_mask_dict.items(): + bbox = mask_to_bbox(mask) + for name, value in bbox._asdict().items(): + self.set(obj, f"bbox_{name}", value) + + def set(self, obj: Object, feature_name: str, feature_val: Any) -> None: + """Set the value of an object feature by name.""" + try: + idx = obj.type.feature_names.index(feature_name) + except: + breakpoint() + standard_feature_len = len(self.data[obj]) + if idx >= standard_feature_len: + # When setting the bounding box features for the first time + # So we'd first append 4 dimension and try to set again + self.bbox_features[obj][idx - standard_feature_len] = feature_val + else: + self.data[obj][idx] = feature_val + + def get(self, obj: Object, feature_name: str) -> Any: + idx = obj.type.feature_names.index(feature_name) + standard_feature_len = len(self.data[obj]) + if idx >= standard_feature_len: + return self.bbox_features[obj][idx - standard_feature_len] + else: + return self.data[obj][idx] + + def dict_str(self, + indent: int = 0, + object_features: bool = True, + use_object_id: bool = False, + position_proprio_features: bool = False) -> str: + """Return a dictionary representation of the state.""" + state_dict = {} + for obj in self: + obj_dict = {} + for attribute, value in zip( + obj.type.feature_names, + np.concatenate([self[obj], self.bbox_features[obj]]) + if self.bbox_features else self[obj]): + # include if it's proprioception feature, or position/bbox + # feature, or object_features is True + # if (obj.type.name == "robot" and \ + # attribute not in ["bbox_left", "bbox_right", "bbox_upper", + # "pose_x", "pose_y", "pose_z", "pose_y_norm", + # "bbox_lower"]) or object_features: + # # attribute in ["pose_x", "pose_y", "pose_z", "bbox_left", + # # "bbox_right", "bbox_upper", "bbox_lower"] or\ + # if isinstance(value, (float, int, np.float32)): + # value = round(float(value), 1) + # obj_dict[attribute] = value + if (position_proprio_features and attribute in [ + # "pose_x", "pose_y", "pose_z", "x", "y", "z", + "rot", + "fingers" + ]) or (object_features and attribute not in [ + "is_heavy", + # "grasp", + # "held", + # "is_held", + ]): + if isinstance(value, (float, int, np.float32)): + value = round(float(value), 1) + obj_dict[attribute] = value + + if use_object_id: obj_name = obj.id_name + else: obj_name = obj.name + state_dict[f"{obj_name}:{obj.type.name}"] = obj_dict + + # Create a string of n_space spaces + spaces = " " * indent + + # Create a PrettyPrinter with a large width + dict_str = spaces + "{" + n_keys = len(state_dict.keys()) + for i, (key, value) in enumerate(state_dict.items()): + value_str = ', '.join(f"'{k}': {v}" for k, v in value.items()) + if value_str == "": + content_str = f"'{key}'" + else: + content_str = f"'{key}': {{{value_str}}}" + if i == 0: + dict_str += f"{content_str},\n" + elif i == n_keys - 1: + dict_str += spaces + f" {content_str}" + else: + dict_str += spaces + f" {content_str},\n" + dict_str += "}" + return dict_str + + def __eq__(self, other): + # Compare the data and simulator_state + assert isinstance(other, RawState) + + if len(self.data) != len(other.data): + return False + + for key, value in self.data.items(): + if key not in other.data or not np.array_equal( + value, other.data[key]): + return False + + return self.simulator_state == other.simulator_state + + def label_all_objects(self): + state_ip = ImagePatch(self) + # state_ip.cropped_image_in_PIL.save(f"images/obs_before_label_all.png") + # labels = [obj.id for obj in self.obj_mask_dict.keys()] + # masks = self.obj_mask_dict.values() + # state_ip.label_all_objects(masks, labels) + state_ip.label_all_objects(self.obj_mask_dict) + # state_ip.label_object(mask, obj.id) + # state_ip.cropped_image_in_PIL.save(f"images/obs_after_label_all.png") + self.labeled_image = state_ip.cropped_image_in_PIL + + def copy(self) -> RawState: + pybullet_state_copy = super().copy() + # simulator_state_copy = list(self.joint_positions) + state_image_copy = copy.copy(self.state_image) + obj_mask_copy = copy.deepcopy(self.obj_mask_dict) + labeled_image_copy = copy.copy(self.labeled_image) + option_history_copy = copy.copy(self.option_history) + bbox_features_copy = copy.deepcopy(self.bbox_features) + prev_state_copy = self.prev_state.copy() if self.prev_state else None + return RawState(pybullet_state_copy.data, + pybullet_state_copy.simulator_state, state_image_copy, + obj_mask_copy, labeled_image_copy, option_history_copy, + bbox_features_copy, prev_state_copy) + + def get_obj_mask(self, object: Object) -> Mask: + """Return the mask for the object.""" + return self.obj_mask_dict[object] + + def get_obj_bbox(self, object: Object) -> BoundingBox: + """Get the bounding box of the object in the state image The origin is + bottom left corner--(0, 0)""" + mask = self.get_obj_mask(object) + return mask_to_bbox(mask) + + def crop_to_objects( + self, + objects: Sequence[Object], + # left_margin: int = 15, + # lower_margin: int = 15, + # right_margin: int = 15, + # top_margin: int = 20 + left_margin: int = 30, + lower_margin: int = 30, + right_margin: int = 30, + top_margin: int = 30) -> Tuple[BoundingBox, Sequence[Object]]: + + bboxes = [self.get_obj_bbox(obj) for obj in objects] + bbox = smallest_bbox_from_bboxes(bboxes) + return (BoundingBox( + max(bbox.left - left_margin, 0), max(bbox.lower - lower_margin, 0), + min(bbox.right + right_margin, self.state_image.width), + min(bbox.upper + top_margin, self.state_image.height)), objects) + + # state_ip = ImagePatch(self, attn_objects=objects) + # return state_ip.crop_to_objects(objects, left_margin, lower_margin, + # right_margin, top_margin) + + +@dataclass +class VLMQuery: + """A class to represent a query to a VLM.""" + query_str: str + attention_box: BoundingBox + attn_objects: Optional[List[Object]] = None + ground_atom: Optional[GroundAtom] = None + def mask_to_bbox(mask: Mask) -> BoundingBox: """Return the bounding box of the mask."""