From 308dd8c1c94f2c5dbb6727b9b2b181901fbce259 Mon Sep 17 00:00:00 2001 From: andermi Date: Wed, 30 Aug 2023 16:20:56 -0700 Subject: [PATCH] add matrix_mode: False option to sim params yaml to allow vectorized mode (#158) Signed-off-by: Michael Anderson --- buoy_gazebo/scripts/mbari_wec_batch | 37 +++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/buoy_gazebo/scripts/mbari_wec_batch b/buoy_gazebo/scripts/mbari_wec_batch index eba8f864..a61aee49 100755 --- a/buoy_gazebo/scripts/mbari_wec_batch +++ b/buoy_gazebo/scripts/mbari_wec_batch @@ -45,6 +45,14 @@ from rclpy.node import Node DEFAULT_PHYSICS_MAX_STEP_SIZE = 0.001 # seconds +def to_shape(a, shape): + y_ = shape[0] + y = a.shape[0] + y_pad = (y_-y) + return np.pad(a, (0, y_pad), + mode = 'constant', constant_values=a[-1]) + + class MonoChromatic(object): def __init__(self, A, T): self.A, self.T = A, T @@ -349,13 +357,28 @@ def generate_simulations(sim_params_yaml): os.path.join(batch_results_dir, sim_params_date_yaml)) - # generate test matrix - batch_params = list(zip(*[param.ravel() for param in np.meshgrid(physics_step, - door_state, - scale_factor, - battery_state, - mean_piston_position, - incident_waves)])) + if 'matrix_mode' not in sim_params or \ + ('matrix_mode' in sim_params and sim_params['matrix_mode']): + # generate test matrix + batch_params = list(zip(*[param.ravel() for param in np.meshgrid(physics_step, + door_state, + scale_factor, + battery_state, + mean_piston_position, + incident_waves)])) + elif 'matrix_mode' in sim_params and not sim_params['matrix_mode']: + # generate test arrays + batch_params = [physics_step, + door_state, + scale_factor, + battery_state, + mean_piston_position, + incident_waves] + shape = max(batch_params, key=len).shape + for idx, param in enumerate(batch_params): + if np.array(param).shape != shape: + batch_params[idx] = to_shape(np.array(param), shape) + batch_params = list(zip(*[np.array(param).ravel() for param in batch_params])) node.get_logger().info(f'Generated {len(batch_params)} simulation runs') node.get_logger().debug('PhysicsStep, PhysicsRTF, Seed, Duration, DoorState, ScaleFactor' +