From d2856422037199e8e461e173361999f287281bca Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 5 Sep 2024 01:17:37 +0200 Subject: [PATCH] chore(rstaa2024): update experimental scripts This commit updates the experimental scripts. --- ...cost_alpha3_tune_experiment_lac_critic.yml | 3 +- ...ha3_tune_experiment_seed234_lac_critic.yml | 3 +- ...a3_tune_experiment_seed3658_lac_critic.yml | 3 +- ...3_tune_experiment_seed48104_lac_critic.yml | 3 +- ...ha3_tune_experiment_seed567_lac_critic.yml | 3 +- ...3_tune_experiment_seed78456_lac_critic.yml | 3 +- ...each_alpha3_tune_experiment_lac_critic.yml | 3 +- ...ha3_tune_experiment_seed234_lac_critic.yml | 3 +- ...a3_tune_experiment_seed3658_lac_critic.yml | 3 +- ...3_tune_experiment_seed48104_lac_critic.yml | 3 +- ...ha3_tune_experiment_seed567_lac_critic.yml | 3 +- ...3_tune_experiment_seed78456_lac_critic.yml | 3 +- ...c_cart_pole_extra_experiments_seed234.bash | 6 +- ..._cart_pole_extra_experiments_seed3658.bash | 6 +- ...cart_pole_extra_experiments_seed48104.bash | 6 +- ...c_cart_pole_extra_experiments_seed567.bash | 6 +- ...cart_pole_extra_experiments_seed78456.bash | 6 +- ...fetch_reach_extra_experiments_seed234.bash | 6 +- ...etch_reach_extra_experiments_seed3658.bash | 6 +- ...tch_reach_extra_experiments_seed48104.bash | 6 +- ...fetch_reach_extra_experiments_seed567.bash | 6 +- ...tch_reach_extra_experiments_seed78456.bash | 6 +- .../scripts/alpha3_tuning_data_analysis.py | 3 +- ..._tuning_han_2020_original_data_analysis.py | 116 +++++++++ ...extra_critic_layer_compare_data_analyis.py | 246 ++++++++++++++++++ 25 files changed, 417 insertions(+), 44 deletions(-) create mode 100644 experiments/staa_et_al_2024/scripts/alpha3_tuning_han_2020_original_data_analysis.py create mode 100644 experiments/staa_et_al_2024/scripts/sac_extra_critic_layer_compare_data_analyis.py diff --git a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_lac_critic.yml b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_lac_critic.yml index 9372d9c8..b6b66e37 100644 --- a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_lac_critic.yml +++ b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:CartPoleCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_lac_critic.yml b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_lac_critic.yml index c6382b03..caba99da 100644 --- a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_lac_critic.yml +++ b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:CartPoleCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_lac_critic.yml b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_lac_critic.yml index 50f30c7a..4b88e19f 100644 --- a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_lac_critic.yml +++ b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:CartPoleCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_lac_critic.yml b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_lac_critic.yml index 297355b8..d41e1354 100644 --- a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_lac_critic.yml +++ b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:CartPoleCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_lac_critic.yml b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_lac_critic.yml index 544d3cc5..caa73e18 100644 --- a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_lac_critic.yml +++ b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:CartPoleCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_lac_critic.yml b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_lac_critic.yml index 9ee56f4d..12eeb3ec 100644 --- a/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_lac_critic.yml +++ b/experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:CartPoleCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_lac_critic.yml b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_lac_critic.yml index eeda47aa..ef373e4c 100644 --- a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_lac_critic.yml +++ b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:FetchReachCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_lac_critic.yml b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_lac_critic.yml index 6637843b..b76c6897 100644 --- a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_lac_critic.yml +++ b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:FetchReachCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_lac_critic.yml b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_lac_critic.yml index 14612af6..950e2faf 100644 --- a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_lac_critic.yml +++ b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:FetchReachCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_lac_critic.yml b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_lac_critic.yml index 458cf641..34379989 100644 --- a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_lac_critic.yml +++ b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:FetchReachCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_lac_critic.yml b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_lac_critic.yml index 7470edce..2403cf27 100644 --- a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_lac_critic.yml +++ b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:FetchReachCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_lac_critic.yml b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_lac_critic.yml index 9b115c32..9309efeb 100644 --- a/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_lac_critic.yml +++ b/experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_lac_critic.yml @@ -4,7 +4,8 @@ env_name: "stable_gym:FetchReachCost-v1" ac_kwargs: hidden_sizes: actor: [256, 256] # NOTE: Using [256, 256] for consistency with the article. - critic: [64, 64, 16] + # critic: [64, 64, 16] + critic: [256, 256, 16] activation: actor: "nn.ReLU" critic: "nn.ReLU" diff --git a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed234.bash b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed234.bash index 5658f066..def43091 100644 --- a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed234.bash +++ b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed234.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed234_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed3658.bash b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed3658.bash index 6efd970a..d6a877fb 100644 --- a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed3658.bash +++ b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed3658.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed3658_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed48104.bash b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed48104.bash index 93a494a0..14dff283 100644 --- a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed48104.bash +++ b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed48104.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed48104_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed567.bash b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed567.bash index cddf9fbf..dcf95523 100644 --- a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed567.bash +++ b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed567.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed567_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed78456.bash b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed78456.bash index 470ae1e9..7ba1afab 100644 --- a/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed78456.bash +++ b/experiments/staa_et_al_2024/run_sac_cart_pole_extra_experiments_seed78456.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/different_steps_per_update/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/lac_critic/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/cartpole/sac_extra/sac_extra_all/han2020_reproduction_sac_cartpole_cost_alpha3_tune_experiment_seed78456_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed234.bash b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed234.bash index ba226edc..dd6d1e98 100644 --- a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed234.bash +++ b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed234.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed234_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed3658.bash b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed3658.bash index 4afe5127..3c5f912c 100644 --- a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed3658.bash +++ b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed3658.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed3658_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed48104.bash b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed48104.bash index 1c6e9a0a..0960a648 100644 --- a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed48104.bash +++ b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed48104.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed48104_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed567.bash b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed567.bash index f45eddc0..710e1a0e 100644 --- a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed567.bash +++ b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed567.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed567_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed78456.bash b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed78456.bash index e590fdc1..2b3fb42d 100644 --- a/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed78456.bash +++ b/experiments/staa_et_al_2024/run_sac_fetch_reach_extra_experiments_seed78456.bash @@ -1,4 +1,4 @@ -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_bigger_initial_alpha.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_different_steps_per_update.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/bigger_initial_alpha/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_bigger_initial_alpha.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/different_steps_per_update/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_different_steps_per_update.yml python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/lac_critic/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_lac_critic.yml -python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_sac_extra_all.yml +# python -m stable_learning_control.run --exp_cfg experiments/staa_et_al_2024/fetch_reach/sac_extra/sac_extra_all/han2020_reproduction_sac_fetch_reach_alpha3_tune_experiment_seed78456_sac_extra_all.yml diff --git a/experiments/staa_et_al_2024/scripts/alpha3_tuning_data_analysis.py b/experiments/staa_et_al_2024/scripts/alpha3_tuning_data_analysis.py index d6d91397..d20a2624 100644 --- a/experiments/staa_et_al_2024/scripts/alpha3_tuning_data_analysis.py +++ b/experiments/staa_et_al_2024/scripts/alpha3_tuning_data_analysis.py @@ -1,7 +1,6 @@ """This script is used to perform the data analysis of the alpha3 tuning experiments of my master thesis. """ - # import numpy as np import pandas as pd import matplotlib.pyplot as plt @@ -195,7 +194,7 @@ def calculate_condition_convergence_statistics( print("Creating {} plot...".format(env_name.replace("_", " ").title())) # Retrieve data directories and add legend column. - if not Path(DATA_DIR).exists(): + if not Path(DATA_DIR).resolve().exists(): raise FileNotFoundError(f"The data directory {DATA_DIR} does not exist.") all_data_folders = sorted( [str(f) for f in Path(DATA_DIR).iterdir() if f.is_dir()] diff --git a/experiments/staa_et_al_2024/scripts/alpha3_tuning_han_2020_original_data_analysis.py b/experiments/staa_et_al_2024/scripts/alpha3_tuning_han_2020_original_data_analysis.py new file mode 100644 index 00000000..63e4cd38 --- /dev/null +++ b/experiments/staa_et_al_2024/scripts/alpha3_tuning_han_2020_original_data_analysis.py @@ -0,0 +1,116 @@ +"""This script is used to perform the data analysis of the alpha3 tuning experiments of +my master thesis that were conducted using the original code of Han et al. (2020). +""" +# import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from pathlib import Path + +from stable_learning_control.utils.plot import get_all_datasets, plot_data + +# Script parameters. +DATA_DIRS = [ + "data/han_2020_original/cartpole_cost", + "data/han_2020_original/fetch_reach", + "data/han_2020_original/oscillator", + "data/han_2020_original/oscillator_complicated", +] # List of data directories to analyze. +SAVE_PLOT = True +ALPHA3_VALUES = [0.1, 0.3, 1.0] +CONVERGENCE_THRESHOLD = 0.95 # Define convergence threshold for metric. +MOVING_AVERAGE_WINDOW = 5 # Set moving average window for metric. +EXTRA_SUFFIX = "" # Add extra file suffix. + +# Select the metric to analyse. +# METRIC = "Performance" +METRIC = "lambda" +# METRIC = "AverageLambda" +X_AXIS = "total_timesteps" + + +def create_legend_strings(data_folders): + """Create legend strings for the alpha3 tuning plots. + + Args: + data_folders (List[str]): List of data folders. + + Returns: + List[str]: List of legend strings. + """ + return [ + f'$\\alpha_{3}$={folder.split("_")[-1].replace("alp", "").replace("-", ".")}' + for folder in data_folders + ] + + +def filter_folders_by_alpha3(data_folders, alpha3_values): + """Filter the data folders by alpha3 value. + + Args: + data_folders (List[str]): List of data folders. + alpha3_values (list[float]): Alpha value to filter by. + + Returns: + List[str]: List of filtered data folders. + """ + return [ + folder + for folder in data_folders + if any( + f"alp{str(alpha).replace('.', '-')}" in folder for alpha in alpha3_values + ) + ] + + +if __name__ == "__main__": + print(f"\nCreating '{METRIC}' plots...") + for idx, DATA_DIR in enumerate(DATA_DIRS): + parent_dir = Path(DATA_DIR).parent + env_name = DATA_DIR.split("/")[-1] + data_dir = f"{parent_dir}/plots/{env_name}" + metric_str = METRIC.replace(" ", "_").lower() + Path(data_dir).mkdir(parents=True, exist_ok=True) + print("Creating {} plot...".format(env_name.replace("_", " ").title())) + + # Retrieve data directories and add legend column. + if not Path(DATA_DIR).resolve().exists(): + raise FileNotFoundError(f"The data directory {DATA_DIR} does not exist.") + all_data_folders = sorted( + [str(f) for f in Path(DATA_DIR).iterdir() if f.is_dir()] + ) + if ALPHA3_VALUES: + data_folders = filter_folders_by_alpha3(all_data_folders, ALPHA3_VALUES) + else: + data_folders = all_data_folders + file_suffix = "_filtered" if len(data_folders) < len(all_data_folders) else "" + legend = create_legend_strings(data_folders) + data = get_all_datasets(data_folders, legend=legend) + + # Visualize the data using the SCL plot utilities. + fig_name = env_name.replace("_", " ").title() + fig = plt.figure(figsize=(10, 8), num=f"Alpha3 Tuning ({fig_name})") + palette = sns.color_palette("tab20", n_colors=len(legend)) + plt.tight_layout() + plot_data( + data, + xaxis=X_AXIS, + value=METRIC, + errorbar="ci", + smooth=MOVING_AVERAGE_WINDOW, + style="ticks", + palette=palette, + ) + plt.grid() + + # Save the plot as png file. + if SAVE_PLOT: + plot_file_path = ( + f"{data_dir}/alpha3_tune_{metric_str}_{env_name}_plot" + f"{file_suffix}{EXTRA_SUFFIX}.png" + ) + print(f"Saving plot to {plot_file_path}.") + plt.savefig(plot_file_path, bbox_inches="tight", dpi=300) + + plt.show() + print("Analysis completed.") diff --git a/experiments/staa_et_al_2024/scripts/sac_extra_critic_layer_compare_data_analyis.py b/experiments/staa_et_al_2024/scripts/sac_extra_critic_layer_compare_data_analyis.py new file mode 100644 index 00000000..45805a4d --- /dev/null +++ b/experiments/staa_et_al_2024/scripts/sac_extra_critic_layer_compare_data_analyis.py @@ -0,0 +1,246 @@ +"""This script is used to perform the data analysis to compare the performance +convergence of the regular SAC and the SAC with an additional critic layer. +""" + +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from pathlib import Path + +from stable_learning_control.utils.plot import get_all_datasets, plot_data + +# Script parameters. +DATA_DIRS = [ + # "data/han_reproduction_sac/extra/cartpole/han2020_reproduction_sac_cartpole_cost_alpha3_tune_exp_lac_critic", # noqa E501 + # "data/han_reproduction_sac/main/han2020_reproduction_sac_cartpole_cost_alpha3_tune_exp" # noqa E501 + # "data/han_reproduction_sac/extra/comp_oscillator/han2020_reproduction_sac_oscillator_complicated_alpha3_tune_exp_lac_critic", # noqa E501 + # "data/han_reproduction_sac/main/han2020_reproduction_sac_oscillator_complicated_alpha3_tune_exp", # noqa E501 + # "data/han_reproduction_sac/extra/fetch_reach/han2020_reproduction_sac_fetch_reach_alpha3_tune_exp_lac_critic", # noqa E501 + # "data/han_reproduction_sac/main/han2020_reproduction_sac_fetch_reach_alpha3_tune_exp", # noqa E501 + "data/han_reproduction_sac/extra/oscillator/han2020_reproduction_sac_oscillator_alpha3_tune_exp_lac_critic", # noqa E501 + "data/han_reproduction_sac/main/han2020_reproduction_sac_oscillator_alpha3_tune_exp", # noqa E501 +] # List of data directories to analyze. +# CONDITION_NAMES = ["$f_\phi=[64,64,16]$", "$f_\phi=[64,64]$"] # Change plot legend names. +CONDITION_NAMES = ["$f_\phi=[256,256,16]$", "$f_\phi=[256,256]$"] # Change plot legend names. +PLOT = True +SAVE_DATA = True +SAVE_STATISTICS = True +SAVE_PLOT = True +LAST_N_EPOCHS = [ + 10, + 10, + 10, + 10, + 10, +] # Epoch count for average final metric value calculation per condition. +CONVERGENCE_THRESHOLD = 0.95 # Define convergence threshold for metric. +MOVING_AVERAGE_WINDOW = 5 # Set moving average window for metric. +# LAST_N_PLOT_EPOCHS = 5 # Epochs to plot from training end. None plots all. +LAST_N_PLOT_EPOCHS = None # Epochs to plot from training end. None plots all. +EXTRA_SUFFIX = "" # Add extra file suffix. + +# Select the metric to analyse. +METRIC = "Performance" +# METRIC = "AverageTestEpLen" +X_AXIS = "TotalEnvInteracts" +# X_AXIS = "Epoch" + + +def calculate_condition_performance_statistics( + df, metric="AverageTestEpRet", num_epochs=10 +): + """Calculate the mean, std and max of the metric for each condition. + + Args: + df (pd.DataFrame): DataFrame to calculate the statistics from. + metric (str, optional): The metric to calculate the statistics for. Defaults to + "AverageTestEpRet". + num_epochs (int, optional): Number of epochs to consider. Defaults to 10. + + Returns: + pd.DataFrame: DataFrame with the mean, std, min and max of the metric for each + condition. + """ + # Filter the data to include only the last N epochs for each seed. + last_n_df = df.groupby(["Condition1", "Condition2"]).tail(num_epochs) + + # Calculate mean and std for each seed. + seed_means_stds = ( + last_n_df.groupby(["Condition1", "Condition2"])[metric] + .agg(["mean", "std"]) + .reset_index() + ) + + # Calculate mean, max and std across seeds for each condition. + condition_means_stds = ( + seed_means_stds.groupby("Condition1")["mean"] + .agg(["mean", "std", "min", "max"]) + .reset_index() + ) + condition_means_stds.columns = [ + "Condition", + "Mean of Seed Means", + "Std of Seed Means", + "Min of Seed Means", + "Max of Seed Means", + ] + return condition_means_stds + + +def calculate_condition_convergence_statistics( + df, + metric="Performance", + moving_avg_window=1, + convergence_threshold=0.95, + num_epochs_for_average=10, +): + """Calculate the mean and std of the convergence epoch for each condition. + + Args: + df (pd.DataFrame): DataFrame to calculate the statistics from. + metric (str, optional): The metric to calculate the statistics for. Defaults to + "Performance". + moving_avg_window (int, optional): Moving average window for the metric. + Defaults to 1. + convergence_threshold (float, optional): Convergence threshold for the metric. + Defaults to 0.95. + num_epochs_for_average (int, optional): Number of epochs to consider for the + average metric value. Defaults to 10. + + Returns: + pd.DataFrame: DataFrame with the mean and std of the convergence epoch for each + condition. + """ + # Calculate the convergence epoch per seed. + convergence_epochs = pd.DataFrame( + index=pd.MultiIndex.from_arrays([[], []], names=["Condition1", "Condition2"]) + ) + for (condition, seed), group in df.groupby(["Condition1", "Condition2"]): + mean = group[metric].rolling(window=moving_avg_window).mean() + metric_baseline = mean.tail(num_epochs_for_average).mean() + metric_convergence_threshold = (mean.max() - metric_baseline) * ( + 1 - convergence_threshold + ) + metric_baseline + converged_epoch = None + for epoch, metric_value in zip(group["Epoch"], mean): + if metric_value <= metric_convergence_threshold: + if converged_epoch is None: + converged_epoch = epoch + else: + converged_epoch = None + # Add the convergence epoch to the DataFrame + convergence_epochs.loc[(condition, seed), "Convergence Epoch"] = converged_epoch + + # Calculate the mean and std of convergence epochs across seeds for each condition. + convergence_epoch = ( + convergence_epochs.groupby("Condition1")["Convergence Epoch"] + .agg(["mean", "std"]) + .reset_index() + ) + convergence_epoch.columns = [ + "Condition", + "Convergence Epoch Mean", + "Convergence Epoch Std", + ] + return convergence_epoch + + +if __name__ == "__main__": + print(f"\nCreating '{METRIC}' plots...") + all_data = [] + for idx, DATA_DIR in enumerate(DATA_DIRS): + parent_dir = Path(DATA_DIR).parent + env_name = DATA_DIR.split("/")[-1] + data_dir = f"{parent_dir}/plots/{env_name}" + metric_str = METRIC.replace(" ", "_").lower() + Path(data_dir).mkdir(parents=True, exist_ok=True) + print("Creating {} plot...".format(env_name.replace("_", " ").title())) + + # Retrieve data directories and add legend column. + if not Path(DATA_DIR).resolve().exists(): + raise FileNotFoundError(f"The data directory {DATA_DIR} does not exist.") + data_folders = sorted([str(f) for f in Path(DATA_DIR).iterdir() if f.is_dir()]) + data = get_all_datasets( + data_folders, legend=[CONDITION_NAMES[idx] for _ in data_folders] + ) + + # Add a Condition column to each dataset + for df in data: + df["Condition"] = f"Condition_{idx + 1}" + + all_data.extend(data) + + # Combine all data into a single DataFrame + data_concat = pd.concat(all_data, ignore_index=True) + + # Store the data in csv file. + # NOTE: Can be useful to inspect the data in a spreadsheet software. + if SAVE_DATA: + data_file_path = ( + f"{data_dir}/alpha3_tune_{metric_str}_{env_name}_plot_data" + f"{EXTRA_SUFFIX}.csv" + ) + data_tmp = pd.concat(all_data, ignore_index=True) + print(f"Saving plot input data to {data_file_path}.") + data_tmp.to_csv(data_file_path, index=False) + + # Compute the statistics. + condition_performance_statistics = calculate_condition_performance_statistics( + data_concat, metric=METRIC, num_epochs=LAST_N_EPOCHS[0] + ) + condition_convergence_statistics = calculate_condition_convergence_statistics( + data_concat, + metric=METRIC, + moving_avg_window=MOVING_AVERAGE_WINDOW, + convergence_threshold=CONVERGENCE_THRESHOLD, + num_epochs_for_average=LAST_N_EPOCHS[0], + ) + + # Store the statistics in csv file. + statistics = pd.concat( + [condition_performance_statistics, condition_convergence_statistics], axis=1 + ) + statistics = statistics.loc[:, ~statistics.columns.duplicated()] + statistics_file_path = ( + f"{data_dir}/sac_extra_critic_compare_{metric_str}_{env_name}_statistics" + f"{EXTRA_SUFFIX}.csv" + ) + if SAVE_STATISTICS: + print(f"Saving statistics to {statistics_file_path}.") + statistics.to_csv(statistics_file_path, index=False) + + # Only keep the last N epochs for the plot. + if LAST_N_PLOT_EPOCHS: + all_data = [ + df[df["Epoch"] >= df["Epoch"].max() - LAST_N_PLOT_EPOCHS] for df in all_data + ] + + # Visualize the data using the SCL plot utilities. + if PLOT: + fig_name = env_name.replace("_", " ").title() + fig = plt.figure(figsize=(10, 8), num=f"Alpha3 Tuning ({fig_name})") + palette = sns.color_palette(["#1f77b4", "#ff7f0e"]) # Blue and Orange + plt.tight_layout() + plot_data( + all_data, + xaxis=X_AXIS, + value=METRIC, + errorbar="ci", + smooth=MOVING_AVERAGE_WINDOW, + style="ticks", + palette=palette, + ) + plt.grid() + + # Save the plot as png file. + if SAVE_PLOT: + prefix = "partial_" if LAST_N_PLOT_EPOCHS is not None else "" + plot_file_path = ( + f"{data_dir}/alpha3_tune_{metric_str}_{env_name}_{prefix}plot" + f"{EXTRA_SUFFIX}.png" + ) + print(f"Saving plot to {plot_file_path}.") + plt.savefig(plot_file_path, bbox_inches="tight", dpi=300) + + plt.show() + print("Analysis completed.")