diff --git a/tf_agents/environments/gym_wrapper.py b/tf_agents/environments/gym_wrapper.py index 9caef96e7..aaf01cfce 100644 --- a/tf_agents/environments/gym_wrapper.py +++ b/tf_agents/environments/gym_wrapper.py @@ -111,6 +111,8 @@ def nested_spec(spec, child_name): dtype = space.dtype else: dtype = dtype_map.get(gym.spaces.Box, np.float32) + if dtype == tf.string: + return specs.ArraySpec(shape=space.shape, dtype=dtype, name=name) minimum = np.asarray(space.low, dtype=dtype) maximum = np.asarray(space.high, dtype=dtype) if simplify_box_bounds: