Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes small learning rate lower bound bug #105

Merged
merged 3 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ dist/
# node related
node_modules/

# Files to ignore
.coverage
*.pyc
# IDE files to ignore
.vscode
*.code-workspace

# Other
Expand Down
62 changes: 0 additions & 62 deletions .vscode/launch.json

This file was deleted.

23 changes: 0 additions & 23 deletions .vscode/settings.json

This file was deleted.

25 changes: 8 additions & 17 deletions machine_learning_control/control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,7 @@ def update(self, data):
q2_pi_targ = self.ac_targ.Q2(o_, a2)
if self._opt_type.lower() == "minimize":
q_pi_targ = torch.max(
q1_pi_targ,
q2_pi_targ,
q1_pi_targ, q2_pi_targ,
) # Use max clipping to prevent overestimation bias.
else:
q_pi_targ = torch.min(
Expand Down Expand Up @@ -656,9 +655,7 @@ def load_state_dict(self, state_dict, restore_lagrance_multipliers=True):
try:
super().load_state_dict(state_dict)
except AttributeError as e:
raise type(e)(
"The 'state_dict' could not be loaded successfully.",
) from e
raise type(e)("The 'state_dict' could not be loaded successfully.",) from e

def _update_targets(self):
"""Updates the target networks based on a Exponential moving average
Expand Down Expand Up @@ -721,7 +718,7 @@ def bound_lr(
if lr_alpha_final is not None:
if self._log_alpha_optimizer.param_groups[0]["lr"] < lr_a_final:
self._log_alpha_optimizer.param_groups[0]["lr"] = lr_a_final
if lr_labda_final is not None:
if lr_labda_final is not None and self._use_lyapunov:
if self._log_labda_optimizer.param_groups[0]["lr"] < lr_labda_final:
self._log_labda_optimizer.param_groups[0]["lr"] = lr_labda_final

Expand All @@ -736,8 +733,7 @@ def alpha(self):
def alpha(self, set_val):
"""Property used to make sure alpha and log_alpha are related."""
self.log_alpha.data = torch.as_tensor(
np.log(1e-37 if set_val < 1e-37 else set_val),
dtype=self.log_alpha.dtype,
np.log(1e-37 if set_val < 1e-37 else set_val), dtype=self.log_alpha.dtype,
)

@property
Expand All @@ -753,8 +749,7 @@ def labda(self):
def labda(self, set_val):
"""Property used to make sure labda and log_labda are related."""
self.log_labda.data = torch.as_tensor(
np.log(1e-37 if set_val < 1e-37 else set_val),
dtype=self.log_labda.dtype,
np.log(1e-37 if set_val < 1e-37 else set_val), dtype=self.log_labda.dtype,
)

@property
Expand Down Expand Up @@ -1179,9 +1174,7 @@ def lac(
policy, test_env, num_test_episodes, max_ep_len=max_ep_len
)
logger.store(
TestEpRet=eps_ret,
TestEpLen=eps_len,
extend=True,
TestEpRet=eps_ret, TestEpLen=eps_len, extend=True,
)

# Epoch based learning rate decay
Expand Down Expand Up @@ -1218,13 +1211,11 @@ def lac(
logger.log_tabular("LossQ", average_only=True)
if adaptive_temperature:
logger.log_tabular(
"LossAlpha",
average_only=True,
"LossAlpha", average_only=True,
)
if use_lyapunov:
logger.log_tabular(
"LossLambda",
average_only=True,
"LossLambda", average_only=True,
)
if use_lyapunov:
logger.log_tabular("LVals", with_min_and_max=True)
Expand Down
2 changes: 1 addition & 1 deletion machine_learning_control/control/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def plot_data(
Changes the colorscheme and the default legend style, though.
"""
plt.legend(loc="best").set_draggable(True)
# plt.legend(loc='upper center', ncol=3, handlelength=1,
# plt.legend(loc='upper center', ncol=6, handlelength=1, mode="expand"
# borderaxespad=0., prop={'size': 13})

"""
Expand Down
28 changes: 24 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
"""Setup file for the 'machine_learning_control' python package.
"""

import os.path as osp
import re
import sys

from setuptools import find_namespace_packages, setup

# Script settings
stand_alone_ns_pkgs = ["simzoo"]


def submodules_available(submodules):
    """Verify that each git submodule has been pulled; print an error and
    terminate the script (exit code 1) if any is missing.

    Note: despite the original "throws warning" phrasing, this does not raise
    or warn — it prints to stdout and calls ``sys.exit(1)`` on the first
    missing submodule.

    Args:
        submodules: Iterable of submodule directory names expected under the
            ``machine_learning_control`` package folder.
    """
    for submodule in submodules:
        # A pulled submodule is detected by the presence of its own setup.py
        # inside machine_learning_control/<submodule>/.
        submodule_setup_path = osp.join(
            osp.abspath(osp.dirname(__file__)),
            "machine_learning_control",
            submodule,
            "setup.py",
        )

        if not osp.exists(submodule_setup_path):
            # Missing setup.py means the submodule was never initialized;
            # abort so the package is not installed in a broken state.
            print("Could not find {}".format(submodule_setup_path))
            print("Did you run 'git submodule update --init --recursive'?")
            sys.exit(1)


# Add extra virtual shortened package for each stand-alone namespace package
# NOTE: This only works if you don't have a __init__.py file in your parent folder and
# stand alone_ns_pkgs folder.
Expand All @@ -32,9 +52,9 @@
if short_child not in PACKAGES:
PACKAGES.append(short_child)

# Throw warning if submodules were not pulled
submodules_available(stand_alone_ns_pkgs)

setup(
packages=PACKAGES,
package_dir=PACKAGE_DIR,
packages=PACKAGES, package_dir=PACKAGE_DIR,
)

# TODO: Add submodule check like in pytorch!