Skip to content

Commit

Permalink
fix pylint mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Dec 26, 2024
1 parent ef08175 commit dd8bce2
Show file tree
Hide file tree
Showing 18 changed files with 377 additions and 389 deletions.
6 changes: 2 additions & 4 deletions predicators/approaches/vlm_open_loop_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ def _append_to_prompt_state_imgs_list(state: State) -> None:
text,
font=font)[2:]
# Create a new image with additional space for text!
new_image = PIL.Image.new("RGB",
(width, height + int(text_height) +
10),
"white")
new_image = PIL.Image.new(
"RGB", (width, height + int(text_height) + 10), "white")
new_image.paste(pil_img, (0, 0))
draw = ImageDraw.Draw(new_image)
text_x = (width - text_width) / 2
Expand Down
7 changes: 6 additions & 1 deletion predicators/cogman.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ def is_learning_based(self) -> bool:
"""See BaseApproach docstring."""
return self._approach.is_learning_based

@property
def get_approach_name(self) -> str:
"""See BaseApproach docstring."""
return self._approach.get_name()

def learn_from_offline_dataset(self, dataset: Dataset) -> None:
"""See BaseApproach docstring."""
return self._approach.learn_from_offline_dataset(dataset)
Expand Down Expand Up @@ -207,7 +212,7 @@ def run_episode_and_get_observations(
env.reset(train_or_test, task_idx)
if monitor is not None:
monitor.reset(train_or_test, task_idx)
render_obs = cogman._approach.get_name() == "oracle" and\
render_obs = cogman.get_approach_name == "oracle" and\
CFG.offline_data_method == "geo_and_demo_with_vlm_imgs"
if isinstance(env, PyBulletEnv):
obs = env.get_observation(render=render_obs)
Expand Down
5 changes: 1 addition & 4 deletions predicators/datasets/generate_atom_trajs_with_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,10 +1083,7 @@ def create_ground_atom_data_from_generated_demos(
raise NotImplementedError(
f"Cropped images not implemented for {CFG.env}.")
if CFG.env in ["pybullet_coffee"]:
state_imgs.append([
img_arr # type: ignore
for img_arr in state.simulator_state["images"]
])
state_imgs.append(list(state.simulator_state['images']))
else:
state_imgs.append([
PIL.Image.fromarray(img_arr) # type: ignore
Expand Down
4 changes: 3 additions & 1 deletion predicators/envs/cluttered_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def _check_collisions(cls,
colliding_can = None
colliding_can_max_dist = float("-inf")
for can in state:
if ignored_can is not None and can == ignored_can or not cls._Untrashed_holds(state, [can]):
if ignored_can is not None and can == ignored_can or \
not cls._Untrashed_holds(
state, [can]):
continue
this_x = state.get(can, "pose_x")
this_y = state.get(can, "pose_y")
Expand Down
52 changes: 35 additions & 17 deletions predicators/envs/coffee.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class CoffeeEnv(BaseEnv):
cord_link_length = 0.02
cord_segment_gap = 0.00
cord_start_x = machine_x - machine_x_len / 2 - 4 * cord_link_length
cord_start_y = machine_y - machine_y_len
cord_start_y = machine_y - machine_y_len
cord_start_z = z_lb + cord_link_length / 2
plug_x = cord_start_x - (num_cord_links - 1) * cord_link_length -\
cord_segment_gap * (num_cord_links - 1)
Expand All @@ -84,7 +84,9 @@ class CoffeeEnv(BaseEnv):

@classmethod
def jug_height(cls) -> float:
"""The height of the jug."""
return 0.15 * (cls.z_ub - cls.z_lb)

jug_init_x_lb: ClassVar[float] = machine_x - machine_x_len + init_padding
jug_init_x_ub: ClassVar[float] = machine_x + machine_x_len - init_padding
jug_init_y_lb: ClassVar[float] = y_lb + jug_radius + pick_jug_y_padding + \
Expand All @@ -97,7 +99,9 @@ def jug_height(cls) -> float:
# jug_handle_height: ClassVar[float] = 3 * jug_height / 4
@classmethod
def jug_handle_height(cls) -> float:
"""The height of the jug handle."""
return 3 * cls.jug_height() / 4

jug_handle_radius: ClassVar[float] = 1e-1 # just for rendering
# Dispense area settings.
dispense_area_x: ClassVar[float] = machine_x + machine_x_len / 2
Expand All @@ -119,8 +123,10 @@ def jug_handle_height(cls) -> float:
# jug_handle_height)
@classmethod
def pour_z_offset(cls) -> float:
"""The z offset for pouring liquid into a cup."""
return 1.1 * (cls.cup_capacity_ub + cls.jug_height() -\
cls.jug_handle_height())

pour_velocity: ClassVar[float] = cup_capacity_ub / 10.0
max_position_vel: ClassVar[float] = 2.5
max_angular_vel: ClassVar[float] = tilt_ub
Expand All @@ -130,12 +136,11 @@ def __init__(self, use_gui: bool = True) -> None:
super().__init__(use_gui)

# Types
self._table_type = Type(
"table", [])
self._robot_type = Type(
"robot", ["x", "y", "z", "tilt", "wrist", "fingers"])
self._jug_type = Type(
"jug", ["x", "y", "z", "rot", "is_held", "is_filled"])
self._table_type = Type("table", [])
self._robot_type = Type("robot",
["x", "y", "z", "tilt", "wrist", "fingers"])
self._jug_type = Type("jug",
["x", "y", "z", "rot", "is_held", "is_filled"])
self._machine_type = Type("coffee_machine", ["is_on"])
self._cup_type = Type("cup", [
"x", "y", "z", "capacity_liquid", "target_liquid", "current_liquid"
Expand All @@ -144,7 +149,7 @@ def __init__(self, use_gui: bool = True) -> None:

# Predicates
self._PluggedIn = Predicate("PluggedIn", [self._plug_type],
self._PluggedIn_holds)
self._PluggedIn_holds)
self._CupFilled = Predicate("CupFilled", [self._cup_type],
self._CupFilled_holds)
self._Holding = Predicate("Holding",
Expand Down Expand Up @@ -328,11 +333,21 @@ def _generate_test_tasks(self) -> List[EnvironmentTask]:
@property
def predicates(self) -> Set[Predicate]:
return {
self._CupFilled, self._JugInMachine, self._Holding,
self._MachineOn, self._OnTable, self._HandEmpty, self._JugFilled,
self._RobotAboveCup, self._JugAboveCup, self._NotAboveCup,
self._PressingButton, self._Twisting, self._NotSameCup,
self._JugPickable, self._PluggedIn,
self._CupFilled,
self._JugInMachine,
self._Holding,
self._MachineOn,
self._OnTable,
self._HandEmpty,
self._JugFilled,
self._RobotAboveCup,
self._JugAboveCup,
self._NotAboveCup,
self._PressingButton,
self._Twisting,
self._NotSameCup,
self._JugPickable,
self._PluggedIn,
}

@property
Expand Down Expand Up @@ -489,6 +504,7 @@ def _get_tasks(self,
num_cups_lst: List[int],
rng: np.random.Generator,
is_train: bool = False) -> List[EnvironmentTask]:
del is_train # unused
tasks = []
# Create the parts of the initial state that do not change between
# tasks, which includes the robot and the machine.
Expand Down Expand Up @@ -520,7 +536,7 @@ def _get_tasks(self,
# GroundAtom(self._JugFilled, [self._jug]),
# GroundAtom(self._PluggedIn, [self._plug]),
GroundAtom(self._JugInMachine, [self._jug, self._machine]),
}
}
else:
goal = {GroundAtom(self._CupFilled, [c]) for c in cups}
# Sample initial positions for cups, making sure to keep them
Expand Down Expand Up @@ -618,12 +634,14 @@ def _Holding_holds(state: State, objects: Sequence[Object]) -> bool:
_, jug = objects
return state.get(jug, "is_held") > 0.5

def _PluggedIn_holds(self, state: State, objects: Sequence[Object]) -> bool:
def _PluggedIn_holds(self, state: State,
objects: Sequence[Object]) -> bool:
plug, = objects
plug_x = state.get(plug, "x")
plug_y = state.get(plug, "y")
plug_z = state.get(plug, "z")
sq_dist = np.sum(np.subtract((plug_x, plug_y, plug_z),
sq_dist = np.sum(
np.subtract((plug_x, plug_y, plug_z),
(self.socket_x, self.socket_y, self.socket_z))**2)
return bool(sq_dist < self.plugged_in_tol)

Expand Down Expand Up @@ -743,7 +761,7 @@ def _get_jug_handle_grasp(cls, state: State,
# Orient pointing down.
rot = state.get(jug, "rot") - np.pi / 2
target_x = state.get(jug, "x") + np.cos(rot) * cls.jug_handle_offset
target_y = state.get(jug, "y") + np.sin(rot) * cls.jug_handle_offset
target_y = state.get(jug, "y") + np.sin(rot) * cls.jug_handle_offset
if CFG.coffee_use_pixelated_jug:
target_y -= 0.02
target_z = cls.z_lb + cls.jug_handle_height()
Expand Down
Loading

0 comments on commit dd8bce2

Please sign in to comment.