From 866afebcc2325b02eec5d1bfabcc3be49df2a122 Mon Sep 17 00:00:00 2001 From: Quanyi Li Date: Wed, 22 Jan 2025 09:53:45 +0000 Subject: [PATCH] Detecting traffic light (#805) * new traffic light implementation * add test for checking lidar * add test and test success * format --- metadrive/base_class/base_object.py | 7 +- .../traffic_light/base_traffic_light.py | 10 +- metadrive/component/vehicle/base_vehicle.py | 3 +- metadrive/constants.py | 2 +- metadrive/policy/idm_policy.py | 18 ++- .../test_component/test_traffic_light.py | 132 +++++++++++++++++- .../vis_functionality/vis_traffic_light.py | 5 +- 7 files changed, 157 insertions(+), 20 deletions(-) diff --git a/metadrive/base_class/base_object.py b/metadrive/base_class/base_object.py index 7fa20f663..8d8927c6d 100644 --- a/metadrive/base_class/base_object.py +++ b/metadrive/base_class/base_object.py @@ -340,8 +340,11 @@ def velocity(self): """ Velocity, unit: m/s """ - velocity = self.body.get_linear_velocity() - return np.asarray([velocity[0], velocity[1]]) + if isinstance(self.body, BaseGhostBodyNode): + return np.array([0, 0]) + else: + velocity = self.body.get_linear_velocity() + return np.asarray([velocity[0], velocity[1]]) @property def velocity_km_h(self): diff --git a/metadrive/component/traffic_light/base_traffic_light.py b/metadrive/component/traffic_light/base_traffic_light.py index 803aed230..ba1fd201a 100644 --- a/metadrive/component/traffic_light/base_traffic_light.py +++ b/metadrive/component/traffic_light/base_traffic_light.py @@ -1,10 +1,10 @@ import numpy as np from metadrive.base_class.base_object import BaseObject -from metadrive.constants import CamMask -from metadrive.scenario.scenario_description import ScenarioDescription +from metadrive.constants import CamMask, CollisionGroup from metadrive.constants import MetaDriveType, Semantics from metadrive.engine.asset_loader import AssetLoader +from metadrive.scenario.scenario_description import ScenarioDescription from metadrive.utils.pg.utils import generate_static_box_physics_body @@ -50,7 +50,7 @@ def __init__( type_name=MetaDriveType.TRAFFIC_LIGHT, ghost_node=True, ) - self.add_body(air_wall, add_to_static_world=True) + self.add_body(air_wall, add_to_static_world=False) # add to dynamic world so the lidar can detect it if position is None: # auto determining @@ -97,6 +97,7 @@ def set_green(self): self.current_light = BaseTrafficLight.TRAFFIC_LIGHT_MODEL["green"].instanceTo(self.origin) self._try_draw_line([3 / 255, 255 / 255, 3 / 255]) self.status = MetaDriveType.LIGHT_GREEN + self._body.setIntoCollideMask(CollisionGroup.AllOff) # can not be detected by anything def set_red(self): if self.render: @@ -106,6 +107,7 @@ def set_red(self): self.current_light = BaseTrafficLight.TRAFFIC_LIGHT_MODEL["red"].instanceTo(self.origin) self._try_draw_line([252 / 255, 0 / 255, 0 / 255]) self.status = MetaDriveType.LIGHT_RED + self._body.setIntoCollideMask(CollisionGroup.InvisibleWall) # will be detected by lidar and object detector def set_yellow(self): if self.render: @@ -115,6 +117,7 @@ def set_yellow(self): self.current_light = BaseTrafficLight.TRAFFIC_LIGHT_MODEL["yellow"].instanceTo(self.origin) self._try_draw_line([252 / 255, 227 / 255, 3 / 255]) self.status = MetaDriveType.LIGHT_YELLOW + self._body.setIntoCollideMask(CollisionGroup.InvisibleWall) # will be detected by lidar and object detector def set_unknown(self): if self.render: @@ -123,6 +126,7 @@ def set_unknown(self): if self._show_model: self.current_light = BaseTrafficLight.TRAFFIC_LIGHT_MODEL["unknown"].instanceTo(self.origin) self.status = MetaDriveType.LIGHT_UNKNOWN + self._body.setIntoCollideMask(CollisionGroup.AllOff) # can not be detected by anything def destroy(self): super(BaseTrafficLight, self).destroy() diff --git a/metadrive/component/vehicle/base_vehicle.py b/metadrive/component/vehicle/base_vehicle.py index cf88172eb..a67fa4d32 100644 --- a/metadrive/component/vehicle/base_vehicle.py +++ b/metadrive/component/vehicle/base_vehicle.py @@ -50,7 +50,7 @@ def init_state_info(self): # traffic light self.red_light = False self.yellow_light = False - self.green_light = False + self.green_light = False # should always be False, since we don't detect green light # lane line detection self.on_yellow_continuous_line = False @@ -772,6 +772,7 @@ def _state_check(self): elif name == MetaDriveType.TRAFFIC_LIGHT: light = get_object_from_node(node) if light.status == MetaDriveType.LIGHT_GREEN: + raise ValueError("Green light should not be in the contact test!") self.green_light = True elif light.status == MetaDriveType.LIGHT_RED: self.red_light = True diff --git a/metadrive/constants.py b/metadrive/constants.py index 76bd48c1e..c61616091 100644 --- a/metadrive/constants.py +++ b/metadrive/constants.py @@ -170,7 +170,7 @@ def collision_rules(cls): (cls.BrokenLaneLine, cls.TrafficParticipants, True), (cls.BrokenLaneLine, cls.Crosswalk, False), - # vehicle collision + # vehicle contact (cls.Vehicle, cls.Vehicle, True), (cls.Vehicle, cls.LaneSurface, True), (cls.Vehicle, cls.ContinuousLaneLine, True), diff --git a/metadrive/policy/idm_policy.py b/metadrive/policy/idm_policy.py index 1b9c9c791..7045521fe 100644 --- a/metadrive/policy/idm_policy.py +++ b/metadrive/policy/idm_policy.py @@ -1,10 +1,11 @@ import numpy as np - +from metadrive.component.traffic_light.base_traffic_light import BaseTrafficLight from metadrive.component.lane.point_lane import PointLane from metadrive.component.vehicle.PID_controller import PIDController from metadrive.policy.base_policy import BasePolicy from metadrive.policy.manual_control_policy import ManualControlPolicy from metadrive.utils.math import not_zero, wrap_to_pi, norm +import logging class FrontBackObjects: @@ -336,6 +337,17 @@ def lane_change_policy(self, all_objects): next_lanes = self.control_object.navigation.next_ref_lanes lane_num_diff = len(current_lanes) - len(next_lanes) if next_lanes is not None else 0 + def lane_follow(): + # fall back to lane follow + self.target_speed = self.NORMAL_SPEED + self.overtake_timer += 1 + return surrounding_objects.front_object(), surrounding_objects.front_min_distance( + ), self.routing_target_lane + + if isinstance(surrounding_objects.front_object(), BaseTrafficLight): + # traffic light, go lane follow + return lane_follow() + # We have to perform lane changing because the number of lanes in next road is less than current road if lane_num_diff > 0: # lane num decreasing happened in left road or right road @@ -397,9 +409,7 @@ def lane_change_policy(self, all_objects): current_lanes[expect_lane_idx] # fall back to lane follow - self.target_speed = self.NORMAL_SPEED - self.overtake_timer += 1 - return surrounding_objects.front_object(), surrounding_objects.front_min_distance(), self.routing_target_lane + return lane_follow() class ManualControllableIDMPolicy(IDMPolicy): diff --git a/metadrive/tests/test_component/test_traffic_light.py b/metadrive/tests/test_component/test_traffic_light.py index 89cfce3bc..83a4369cd 100644 --- a/metadrive/tests/test_component/test_traffic_light.py +++ b/metadrive/tests/test_component/test_traffic_light.py @@ -1,9 +1,9 @@ -from metadrive.component.traffic_participants.pedestrian import Pedestrian from metadrive.component.traffic_light.base_traffic_light import BaseTrafficLight from metadrive.envs.metadrive_env import MetaDriveEnv +from metadrive.policy.idm_policy import IDMPolicy -def test_traffic_light(render=False, manual_control=False, debug=False): +def test_traffic_light_state_check(render=False, manual_control=False, debug=False): env = MetaDriveEnv( { "num_scenarios": 1, @@ -12,7 +12,6 @@ def test_traffic_light(render=False, manual_control=False, debug=False): "manual_control": manual_control, "use_render": render, "debug": debug, - "debug_static_world": debug, "map": "X", "window_size": (1200, 800), "vehicle_config": { @@ -27,16 +26,17 @@ def test_traffic_light(render=False, manual_control=False, debug=False): env.reset() light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0]) light.set_green() - test_success = False + test_success = True for s in range(1, 100): env.step([0, 1]) - if env.agent.green_light: - test_success = True + if env.agent.red_light or env.agent.yellow_light: + test_success = False break assert test_success light.destroy() # red test + env.reset() light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0]) light.set_red() test_success = False @@ -47,6 +47,7 @@ def test_traffic_light(render=False, manual_control=False, debug=False): break assert test_success light.destroy() + # yellow env.reset() light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0]) @@ -64,5 +65,122 @@ def test_traffic_light(render=False, manual_control=False, debug=False): env.close() +def test_traffic_light_detection(render=False, manual_control=False, debug=False): + env = MetaDriveEnv( + { + "num_scenarios": 1, + "traffic_density": 0., + "traffic_mode": "hybrid", + "manual_control": manual_control, + "use_render": render, + "debug": debug, + "map": "X", + "window_size": (1200, 800), + "vehicle_config": { + "enable_reverse": True, + "show_dest_mark": True + }, + } + ) + env.reset() + try: + # green + env.reset() + light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0]) + light.set_green() + test_success = True + for s in range(1, 100): + env.step([0, 1]) + if min(env.observations["default_agent"].cloud_points) < 0.99: + test_success = False + break + assert len(env.observations["default_agent"].detected_objects) == 0 + assert test_success + light.destroy() + + # red test + env.reset() + light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0]) + light.set_red() + test_success = False + for s in range(1, 100): + env.step([0, 1]) + if min(env.observations["default_agent"].cloud_points) < 0.5: + test_success = True + assert list(env.observations["default_agent"].detected_objects)[0].status == BaseTrafficLight.LIGHT_RED + break + assert test_success + light.destroy() + + # yellow + env.reset() + light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0]) + light.set_yellow() + test_success = False + for s in range(1, 100): + env.step([0, 1]) + if min(env.observations["default_agent"].cloud_points) < 0.5: + test_success = True + assert list(env.observations["default_agent"].detected_objects + )[0].status == BaseTrafficLight.LIGHT_YELLOW + break + assert test_success + light.destroy() + + finally: + env.close() + + +def test_idm_policy(render=False, debug=False): + env = MetaDriveEnv( + { + "num_scenarios": 1, + "traffic_density": 0., + "traffic_mode": "hybrid", + "agent_policy": IDMPolicy, + "use_render": render, + "debug": debug, + "map": "X", + "window_size": (1200, 800), + "show_coordinates": True, + "vehicle_config": { + "show_lidar": True, + "enable_reverse": True, + "show_dest_mark": True + }, + } + ) + env.reset() + try: + # green + env.reset() + light = env.engine.spawn_object(BaseTrafficLight, lane=env.current_map.road_network.graph[">>>"]["1X1_0_"][0]) + light.set_green() + for s in range(1, 1000): + if s == 30: + light.set_yellow() + elif s == 90: + light.set_red() + env.step([0, 1]) + if env.vehicle.red_light or env.vehicle.yellow_light: + raise ValueError("Vehicle should not stop at red light!") + assert env.vehicle.speed < 0.1 + + # move + light.set_green() + test_success = False + for s in range(1, 1000): + o, r, d, t, i = env.step([0, 1]) + if i["arrive_dest"]: + test_success = True + break + light.destroy() + assert test_success + finally: + env.close() + + if __name__ == "__main__": - test_traffic_light(True, manual_control=True) + # test_traffic_light_state_check(True, manual_control=False) + # test_traffic_light_detection(True, manual_control=False) + test_idm_policy(True) diff --git a/metadrive/tests/vis_functionality/vis_traffic_light.py b/metadrive/tests/vis_functionality/vis_traffic_light.py index 5a80bdfbc..92c140313 100644 --- a/metadrive/tests/vis_functionality/vis_traffic_light.py +++ b/metadrive/tests/vis_functionality/vis_traffic_light.py @@ -12,11 +12,12 @@ def vis_traffic_light(render=True, manual_control=False, debug=False): "manual_control": manual_control, "use_render": render, "debug": debug, - "debug_static_world": debug, + "debug_static_world": False, "map": "X", "window_size": (1200, 800), "show_coordinates": True, "vehicle_config": { + "show_lidar": True, "enable_reverse": True, "show_dest_mark": True }, @@ -60,4 +61,4 @@ def vis_traffic_light(render=True, manual_control=False, debug=False): if __name__ == "__main__": - vis_traffic_light(True, manual_control=True) + vis_traffic_light(True, manual_control=True, debug=True)