🐛 Fixes small learning rate lower bound bug (#105)
* 🔧 Adds submodule pull check to the setup.py

* 🐛 Fixes lr lower bound bug

This commit fixes a small bug which caused an error to be thrown when
trying to enforce the learning rate lower bound.
rickstaa authored Feb 16, 2021
1 parent 9a9533f commit 53c2c8c
Showing 6 changed files with 35 additions and 110 deletions.
5 changes: 2 additions & 3 deletions .gitignore
@@ -21,9 +21,8 @@ dist/
# node related
node_modules/

# Files to ignore
.coverage
*.pyc
# IDE files to ignore
.vscode
*.code-workspace

# Other
62 changes: 0 additions & 62 deletions .vscode/launch.json

This file was deleted.

23 changes: 0 additions & 23 deletions .vscode/settings.json

This file was deleted.

25 changes: 8 additions & 17 deletions machine_learning_control/control/algos/pytorch/lac/lac.py
@@ -411,8 +411,7 @@ def update(self, data):
q2_pi_targ = self.ac_targ.Q2(o_, a2)
if self._opt_type.lower() == "minimize":
q_pi_targ = torch.max(
q1_pi_targ,
q2_pi_targ,
q1_pi_targ, q2_pi_targ,
) # Use max clipping to prevent overestimation bias.
else:
q_pi_targ = torch.min(
@@ -656,9 +655,7 @@ def load_state_dict(self, state_dict, restore_lagrance_multipliers=True):
try:
super().load_state_dict(state_dict)
except AttributeError as e:
raise type(e)(
"The 'state_dict' could not be loaded successfully.",
) from e
raise type(e)("The 'state_dict' could not be loaded successfully.",) from e

def _update_targets(self):
"""Updates the target networks based on a Exponential moving average
@@ -721,7 +718,7 @@ def bound_lr(
if lr_alpha_final is not None:
if self._log_alpha_optimizer.param_groups[0]["lr"] < lr_a_final:
self._log_alpha_optimizer.param_groups[0]["lr"] = lr_a_final
if lr_labda_final is not None:
if lr_labda_final is not None and self._use_lyapunov:
if self._log_labda_optimizer.param_groups[0]["lr"] < lr_labda_final:
self._log_labda_optimizer.param_groups[0]["lr"] = lr_labda_final

@@ -736,8 +733,7 @@ def alpha(self):
def alpha(self, set_val):
"""Property used to make sure alpha and log_alpha are related."""
self.log_alpha.data = torch.as_tensor(
np.log(1e-37 if set_val < 1e-37 else set_val),
dtype=self.log_alpha.dtype,
np.log(1e-37 if set_val < 1e-37 else set_val), dtype=self.log_alpha.dtype,
)

@property
@@ -753,8 +749,7 @@ def labda(self):
def labda(self, set_val):
"""Property used to make sure labda and log_labda are related."""
self.log_labda.data = torch.as_tensor(
np.log(1e-37 if set_val < 1e-37 else set_val),
dtype=self.log_labda.dtype,
np.log(1e-37 if set_val < 1e-37 else set_val), dtype=self.log_labda.dtype,
)

@property
@@ -1179,9 +1174,7 @@ def lac(
policy, test_env, num_test_episodes, max_ep_len=max_ep_len
)
logger.store(
TestEpRet=eps_ret,
TestEpLen=eps_len,
extend=True,
TestEpRet=eps_ret, TestEpLen=eps_len, extend=True,
)

# Epoch based learning rate decay
@@ -1218,13 +1211,11 @@ def lac(
logger.log_tabular("LossQ", average_only=True)
if adaptive_temperature:
logger.log_tabular(
"LossAlpha",
average_only=True,
"LossAlpha", average_only=True,
)
if use_lyapunov:
logger.log_tabular(
"LossLambda",
average_only=True,
"LossLambda", average_only=True,
)
if use_lyapunov:
logger.log_tabular("LVals", with_min_and_max=True)
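
Most of the lac.py changes above only collapse multi-line call arguments onto a single line. The first hunk, for example, reformats the clipped double-Q target selection without changing its behaviour: when costs are minimized the larger of the two target-critic estimates is used, and when rewards are maximized the smaller one is, so the bootstrapped target is always the pessimistic choice. A minimal, self-contained sketch of that selection (the tensors and the opt_type value are made-up stand-ins, not values from the repository):

import torch

# Hypothetical target-critic outputs for a batch of two transitions.
q1_pi_targ = torch.tensor([1.2, 0.7])
q2_pi_targ = torch.tensor([1.0, 0.9])

opt_type = "minimize"  # stand-in for self._opt_type in the diff
if opt_type.lower() == "minimize":
    # Costs are minimized: take the larger estimate (pessimistic for costs).
    q_pi_targ = torch.max(q1_pi_targ, q2_pi_targ)
else:
    # Rewards are maximized: take the smaller estimate (pessimistic for rewards).
    q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)

print(q_pi_targ)  # tensor([1.2000, 0.9000])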
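
The bound_lr hunk carries the actual fix: the lower bound on the labda learning rate is now only enforced when the Lyapunov critic is in use (self._use_lyapunov), so the method no longer touches an optimizer that is presumably only created in Lyapunov mode. The bounding itself is the usual PyTorch pattern of editing an optimizer's param_groups in place; a minimal sketch of that pattern under assumed names (lr_final, a plain Adam optimizer), not the repository's code:

import torch

# Decay a learning rate every epoch but clamp it from below, mirroring bound_lr.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

lr_final = 1e-4  # assumed lower bound on the learning rate

for epoch in range(10):
    # ... training for one epoch would run here ...
    optimizer.step()  # no-op in this sketch because no gradients were computed
    scheduler.step()  # exponential decay of the learning rate

    # Clamp the decayed learning rate so it never drops below the bound.
    if optimizer.param_groups[0]["lr"] < lr_final:
        optimizer.param_groups[0]["lr"] = lr_final

print(optimizer.param_groups[0]["lr"])  # 0.0001

The alpha and labda setter hunks are defensive in a similar way: the assigned value is clamped to 1e-37 before np.log is taken, so log_alpha and log_labda can never become -inf when a zero is passed in.
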
2 changes: 1 addition & 1 deletion machine_learning_control/control/utils/plot.py
@@ -94,7 +94,7 @@ def plot_data(
Changes the colorscheme and the default legend style, though.
"""
plt.legend(loc="best").set_draggable(True)
# plt.legend(loc='upper center', ncol=3, handlelength=1,
# plt.legend(loc='upper center', ncol=6, handlelength=1, mode="expand"
# borderaxespad=0., prop={'size': 13})

"""
28 changes: 24 additions & 4 deletions setup.py
@@ -1,12 +1,32 @@
"""Setup file for the 'machine_learning_control' python package.
"""

import os.path as osp
import re
import sys

from setuptools import find_namespace_packages, setup

# Script settings
stand_alone_ns_pkgs = ["simzoo"]


def submodules_available(submodules):
"""Throws warning and stops the script if any of the submodules is not present."""
for submodule in submodules:
submodule_setup_path = osp.join(
osp.abspath(osp.dirname(__file__)),
"machine_learning_control",
submodule,
"setup.py",
)

if not osp.exists(submodule_setup_path):
print("Could not find {}".format(submodule_setup_path))
print("Did you run 'git submodule update --init --recursive'?")
sys.exit(1)


# Add extra virtual shortened package for each stand-alone namespace package
# NOTE: This only works if you don't have a __init__.py file in your parent folder and
# stand_alone_ns_pkgs folder.
@@ -32,9 +52,9 @@
if short_child not in PACKAGES:
PACKAGES.append(short_child)

# Throw warning if submodules were not pulled
submodules_available(stand_alone_ns_pkgs)

setup(
packages=PACKAGES,
package_dir=PACKAGE_DIR,
packages=PACKAGES, package_dir=PACKAGE_DIR,
)

# TODO: Add submodule check like in pytorch!
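
The new submodules_available check makes a fresh clone fail fast with a clear hint when the simzoo submodule has not been pulled, rather than producing a confusing packaging error later on. Some projects go one step further and initialize missing submodules automatically; a rough sketch of that alternative (not what this commit does, and the paths below only mirror the diff as assumptions):

import os.path as osp
import subprocess
import sys

REPO_ROOT = osp.abspath(osp.dirname(__file__))


def ensure_submodules(submodules):
    """Try to initialize any submodule whose setup.py is missing (illustrative only)."""
    for name in submodules:
        marker = osp.join(REPO_ROOT, "machine_learning_control", name, "setup.py")
        if not osp.exists(marker):
            try:
                subprocess.check_call(
                    ["git", "submodule", "update", "--init", "--recursive"],
                    cwd=REPO_ROOT,
                )
            except (OSError, subprocess.CalledProcessError):
                print("Could not fetch the '{}' submodule.".format(name))
                print("Did you run 'git submodule update --init --recursive'?")
                sys.exit(1)


ensure_submodules(["simzoo"])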
