Skip to content

Commit

Permalink
can make circuit demos
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Jan 1, 2025
1 parent 68e0590 commit cf9a52b
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 46 deletions.
1 change: 0 additions & 1 deletion predicators/approaches/random_actions_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def is_learning_based(self) -> bool:
def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:

def _policy(_: State) -> Action:
breakpoint()
return Action(self._action_space.sample())

return _policy
6 changes: 3 additions & 3 deletions predicators/envs/assets/urdf/bulb_box_snap.urdf
Original file line number Diff line number Diff line change
Expand Up @@ -1535,21 +1535,21 @@
<joint name="start_metal_joint" type="fixed">
<parent link="middle_plastic"/>
<child link="start_metal"/>
<origin xyz="-0.07500000000000001 0 0" rpy="0 0 0"/>
<origin xyz="-0.075 0 0" rpy="0 0 0"/>
</joint>


<joint name="end_metal_joint" type="fixed">
<parent link="middle_plastic"/>
<child link="end_metal"/>
<origin xyz="0.07500000000000001 0 0" rpy="0 0 0"/>
<origin xyz="0.075 0 0" rpy="0 0 0"/>
</joint>


<joint name="line_joint" type="fixed">
<parent link="middle_plastic"/>
<child link="line"/>
<origin xyz="0 0 0.025050000000000003" rpy="0 0 0"/>
<origin xyz="0 0 0.025" rpy="0 0 0"/>
</joint>


Expand Down
14 changes: 7 additions & 7 deletions predicators/envs/assets/urdf/snap_connector4.urdf
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<visual>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<box size="0.30000000000000004 0.05 0.05" />
<box size="0.30 0.05 0.05" />
</geometry>
<material name="middle_plastic_material">
<color rgba="0.0 0.5 1.0 1.0"/>
Expand All @@ -15,7 +15,7 @@
<collision>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<box size="0.30000000000000004 0.05 0.05" />
<box size="0.3 0.05 0.05" />
</geometry>
</collision>
</link>
Expand Down Expand Up @@ -64,7 +64,7 @@
<visual>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<box size="0.30000000000000004 0.01 0.0001" />
<box size="0.3 0.01 0.0001" />
</geometry>
<material name="line_material">
<color rgba="0.8 0.8 0.8 1.0"/>
Expand All @@ -73,7 +73,7 @@
<collision>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<box size="0.30000000000000004 0.01 0.0001" />
<box size="0.3 0.01 0.0001" />
</geometry>
</collision>
</link>
Expand All @@ -84,21 +84,21 @@
<joint name="start_metal_joint" type="fixed">
<parent link="middle_plastic"/>
<child link="start_metal"/>
<origin xyz="-0.17500000000000002 0 0" rpy="0 0 0"/>
<origin xyz="-0.1750 0 0" rpy="0 0 0"/>
</joint>


<joint name="end_metal_joint" type="fixed">
<parent link="middle_plastic"/>
<child link="end_metal"/>
<origin xyz="0.17500000000000002 0 0" rpy="0 0 0"/>
<origin xyz="0.175 0 0" rpy="0 0 0"/>
</joint>


<joint name="line_joint" type="fixed">
<parent link="middle_plastic"/>
<child link="line"/>
<origin xyz="0 0 0.025050000000000003" rpy="0 0 0"/>
<origin xyz="0 0 0.025" rpy="0 0 0"/>
</joint>

</robot>
54 changes: 43 additions & 11 deletions predicators/envs/pybullet_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
In the simplest case, the lightbulb is automatically turned on when the
light is connected to both the positive and negative terminals of the
battery. The lightbulb and the battery are fixed, the wire is moveable.
python predicators/main.py --approach oracle --env pybullet_circuit \
--seed 0 --num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \
--sesame_max_skeletons_optimized 1 --make_failure_videos --video_fps 20 \
--pybullet_camera_height 900 --pybullet_camera_width 900 --debug
"""

import logging
Expand Down Expand Up @@ -55,6 +60,7 @@ class PyBulletCircuitEnv(PyBulletEnv):
[0.0, 0.0, np.pi / 2])
robot_init_tilt: ClassVar[float] = np.pi / 2
robot_init_wrist: ClassVar[float] = -np.pi / 2
max_angular_vel: ClassVar[float] = np.pi / 4

# Hard-coded finger states for open/close
open_fingers: ClassVar[float] = 0.4
Expand All @@ -64,7 +70,7 @@ class PyBulletCircuitEnv(PyBulletEnv):
_bulb_on_color: ClassVar[Tuple[float, float, float,
float]] = (1.0, 1.0, 0.0, 1.0) # yellow
_bulb_off_color: ClassVar[Tuple[float, float, float,
float]] = (1.0, 1.0, 1.0, 1.0) # white
float]] = (0.8, 0.8, 0.8, 1.0) # white

# Connector dimensions
snap_width: ClassVar[float] = 0.05
Expand All @@ -76,13 +82,13 @@ class PyBulletCircuitEnv(PyBulletEnv):
# Camera parameters
_camera_distance: ClassVar[float] = 1.3
_camera_yaw: ClassVar[float] = 70
_camera_pitch: ClassVar[float] = -38
_camera_pitch: ClassVar[float] = -50
_camera_target: ClassVar[Pose3D] = (0.75, 1.25, 0.42)

# --- CHANGED / ADDED ---
# Added "rot" to both the battery and light types.
_robot_type = Type("robot", ["x", "y", "z", "fingers", "tilt", "wrist"])
_wire_type = Type("wire", ["x", "y", "z", "rot"])
_wire_type = Type("wire", ["x", "y", "z", "rot", "is_held"])
_battery_type = Type("battery", ["x", "y", "z", "rot"])
_light_type = Type("light", ["x", "y", "z", "rot", "is_on"])

Expand All @@ -104,6 +110,11 @@ def __init__(self, use_gui: bool = True) -> None:
# self._Connected = Predicate("Connected",
# [self._light_type, self._battery_type],
# self._Connected_holds)
self._Holding = Predicate("Holding",
[self._robot_type, self._wire_type],
self._Holding_holds)
self._HandEmpty = Predicate("HandEmpty", [self._robot_type],
self._HandEmpty_holds)
self._ConnectedToLight = Predicate("ConnectedToLight",
[self._wire_type, self._light_type],
self._ConnectedToLight_holds)
Expand All @@ -117,7 +128,8 @@ def __init__(self, use_gui: bool = True) -> None:
# connected to the battery.

# Normal version used in the simulator
self._CircuitClosed = Predicate("CircuitClosed", [],
self._CircuitClosed = Predicate("CircuitClosed",
[self._light_type, self._battery_type],
self._CircuitClosed_holds)
# self._CircuitClosed_abs = ConceptPredicate("CircuitClosed",
# [self._wire_type, self._wire_type],
Expand All @@ -134,6 +146,8 @@ def predicates(self) -> Set[Predicate]:
return {
# If you want to define self._Connected, re-add it here
# self._Connected,
self._Holding,
self._HandEmpty,
self._LightOn,
self._ConnectedToLight,
self._ConnectedToBattery,
Expand Down Expand Up @@ -265,11 +279,13 @@ def _get_state(self) -> State:
for wire_obj in [self._wire1, self._wire2]:
(wx, wy, wz), orn = p.getBasePositionAndOrientation(
wire_obj.id, physicsClientId=self._physics_client_id)
is_held_val = 1.0 if wire_obj.id == self._held_obj_id else 0.0
state_dict[wire_obj] = {
"x": wx,
"y": wy,
"z": wz,
"rot": p.getEulerFromQuaternion(orn)[2],
"is_held": is_held_val
}

# Convert dictionary to a PyBulletState
Expand Down Expand Up @@ -319,6 +335,9 @@ def _reset_state(self, state: State) -> None:
position=(wx, wy, wz),
orientation=p.getQuaternionFromEuler([0, 0, rot]),
physics_client_id=self._physics_client_id)
if state.get(wire_obj, "is_held") > 0.5:
self._attach(wire_obj.id, self._pybullet_robot)
self._held_obj_id = wire_obj.id

# Check if re-creation matches
reconstructed_state = self._get_state()
Expand All @@ -344,6 +363,16 @@ def step(self, action: Action, render_obs: bool = False) -> State:

# -------------------------------------------------------------------------
# Predicates
@staticmethod
def _Holding_holds(state: State, objects: Sequence[Object]) -> bool:
_, wire = objects
return state.get(wire, "is_held") > 0.5

@staticmethod
def _HandEmpty_holds(state: State, objects: Sequence[Object]) -> bool:
robot, = objects
return state.get(robot, "fingers") > 0.2

@staticmethod
def _ConnectedToLight_holds(state: State,
objects: Sequence[Object]) -> bool:
Expand All @@ -366,7 +395,7 @@ def _ConnectedToLight_holds(state: State,
return False

# Correct x and y differences for connection
target_x_diff = PyBulletCircuitEnv.bulb_snap_length / 2 - \
target_x_diff = PyBulletCircuitEnv.wire_snap_length / 2 - \
PyBulletCircuitEnv.snap_width / 2
target_y_diff = PyBulletCircuitEnv.bulb_snap_length / 2 + \
PyBulletCircuitEnv.snap_width / 2
Expand Down Expand Up @@ -443,15 +472,15 @@ def _turn_bulb_on(self) -> None:
if self._light.id is not None:
p.changeVisualShape(
self._light.id,
-1, # all link indices
3, # all link indices
rgbaColor=self._bulb_on_color,
physicsClientId=self._physics_client_id)

def _turn_bulb_off(self) -> None:
if self._light.id is not None:
p.changeVisualShape(
self._light.id,
-1, # all link indices
3, # all link indices
rgbaColor=self._bulb_off_color,
physicsClientId=self._physics_client_id)

Expand Down Expand Up @@ -484,7 +513,7 @@ def _make_tasks(self, num_tasks: int,
# For randomization, tweak or keep rot=0.0 as needed
battery_dict = {
"x": battery_x,
"y": 1.35,
"y": 1.3,
"z": self.z_lb + self.snap_height / 2,
"rot": np.pi / 2,
}
Expand All @@ -495,20 +524,22 @@ def _make_tasks(self, num_tasks: int,
"y": 1.15, # lower region
"z": self.z_lb + self.snap_height / 2,
"rot": 0.0,
"is_held": 0.0,
}
wire2_dict = {
"x": 0.75,
"y": 1.55, # upper region
"y": self.y_ub - self.init_padding * 3, # upper region
"z": self.z_lb + self.snap_height / 2,
"rot": 0.0,
"is_held": 0.0,
}

# Light near upper region
bulb_x = battery_x + self.wire_snap_length - self.snap_width
# For randomization, tweak or keep rot=0.0 as needed
light_dict = {
"x": bulb_x,
"y": 1.35,
"y": 1.3,
"z": self.z_lb + self.snap_height / 2,
"rot": -np.pi / 2,
"is_on": 0.0,
Expand All @@ -525,7 +556,8 @@ def _make_tasks(self, num_tasks: int,

# The goal can be that the light is on.
goal_atoms = {
GroundAtom(self._LightOn, [self._light]),
# GroundAtom(self._LightOn, [self._light]),
GroundAtom(self._CircuitClosed, [self._light, self._battery]),
}
tasks.append(EnvironmentTask(init_state, goal_atoms))

Expand Down
8 changes: 4 additions & 4 deletions predicators/envs/pybullet_grow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \
"""
python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \
--num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \
--sesame_max_skeletons_optimized 1 --make_failure_videos --video_fps 20 \
--pybullet_camera_height 900 --pybullet_camera_width 900
Expand All @@ -17,8 +17,8 @@
from predicators.pybullet_helpers.objects import create_object, update_object
from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot
from predicators.settings import CFG
from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \
Predicate, State, Type
from predicators.structs import Action, EnvironmentTask, GroundAtom, \
Object, Predicate, State, Type


class PyBulletGrowEnv(PyBulletEnv):
Expand Down
34 changes: 19 additions & 15 deletions predicators/ground_truth_models/circuit/nsrts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,44 +34,46 @@ def get_nsrts(env_name: str, types: Dict[str, Type],
CircuitClosed = predicates["CircuitClosed"]

# Options
PickConnector = options["PickConnector"]
Pick = options["PickWire"]
Connect = options["Connect"]

nsrts = set()

# PickConnector
# PickWire
robot = Variable("?robot", robot_type)
connector = Variable("?connector", wire_type)
parameters = [robot, connector]
option_vars = [robot, connector]
option = PickConnector
wire = Variable("?wire", wire_type)
parameters = [robot, wire]
option_vars = [robot, wire]
option = Pick
preconditions = {
LiftedAtom(HandEmpty, [robot]),
}
add_effects = {
LiftedAtom(Holding, [robot, connector]),
LiftedAtom(Holding, [robot, wire]),
}
delete_effects = {
LiftedAtom(HandEmpty, [robot]),
}
pick_connector_nsrt = NSRT("PickConnector", parameters,
pick_wire_nsrt = NSRT("PickWire", parameters,
preconditions, add_effects, delete_effects,
set(), option, option_vars, null_sampler)
nsrts.add(pick_connector_nsrt)
nsrts.add(pick_wire_nsrt)

# ConnectFirstWire. Connect first wire to light and battery.
robot = Variable("?robot", robot_type)
wire = Variable("?wire", wire_type)
light = Variable("?light", light_type)
battery = Variable("?battery", battery_type)
parameters = [wire, light, battery]
option_vars = [wire, light, battery]
parameters = [robot, wire, light, battery]
option_vars = [robot, wire, light, battery]
option = Connect
preconditions = {
LiftedAtom(Holding, [robot, wire]),
# Should add one that says the distance between the terminals are
# close enough
}
add_effects = {
LiftedAtom(HandEmpty, [robot]),
LiftedAtom(ConnectedToLight, [wire, light]),
LiftedAtom(ConnectedToBattery, [wire, battery]),
}
Expand All @@ -85,18 +87,20 @@ def get_nsrts(env_name: str, types: Dict[str, Type],
nsrts.add(connect_first_wire_nsrt)

# hacky: connect second wire to light and power
robot = Variable("?robot", robot_type)
wire = Variable("?wire", wire_type)
light = Variable("?light", light_type)
battery = Variable("?battery", battery_type)
parameters = [wire, light, battery]
option_vars = [wire, light, battery]
parameters = [robot, wire, light, battery]
option_vars = [robot, wire, light, battery]
option = Connect
preconditions = {
LiftedAtom(Holding, [robot, wire]),
}
add_effects = {
LiftedAtom(CircuitClosed, []),
LiftedAtom(LightOn, [light]),
LiftedAtom(HandEmpty, [robot]),
LiftedAtom(CircuitClosed, [light, battery]),
# LiftedAtom(LightOn, [light]),
}
delete_effects = {
LiftedAtom(Holding, [robot, wire]),
Expand Down
Loading

0 comments on commit cf9a52b

Please sign in to comment.