Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format Python code with psf/black push #106

Merged
merged 1 commit into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions machine_learning_control/control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ def update(self, data):
q2_pi_targ = self.ac_targ.Q2(o_, a2)
if self._opt_type.lower() == "minimize":
q_pi_targ = torch.max(
q1_pi_targ, q2_pi_targ,
q1_pi_targ,
q2_pi_targ,
) # Use max clipping to prevent overestimation bias.
else:
q_pi_targ = torch.min(
Expand Down Expand Up @@ -655,7 +656,9 @@ def load_state_dict(self, state_dict, restore_lagrance_multipliers=True):
try:
super().load_state_dict(state_dict)
except AttributeError as e:
raise type(e)("The 'state_dict' could not be loaded successfully.",) from e
raise type(e)(
"The 'state_dict' could not be loaded successfully.",
) from e

def _update_targets(self):
"""Updates the target networks based on a Exponential moving average
Expand Down Expand Up @@ -733,7 +736,8 @@ def alpha(self):
def alpha(self, set_val):
"""Property used to make sure alpha and log_alpha are related."""
self.log_alpha.data = torch.as_tensor(
np.log(1e-37 if set_val < 1e-37 else set_val), dtype=self.log_alpha.dtype,
np.log(1e-37 if set_val < 1e-37 else set_val),
dtype=self.log_alpha.dtype,
)

@property
Expand All @@ -749,7 +753,8 @@ def labda(self):
def labda(self, set_val):
"""Property used to make sure labda and log_labda are related."""
self.log_labda.data = torch.as_tensor(
np.log(1e-37 if set_val < 1e-37 else set_val), dtype=self.log_labda.dtype,
np.log(1e-37 if set_val < 1e-37 else set_val),
dtype=self.log_labda.dtype,
)

@property
Expand Down Expand Up @@ -1174,7 +1179,9 @@ def lac(
policy, test_env, num_test_episodes, max_ep_len=max_ep_len
)
logger.store(
TestEpRet=eps_ret, TestEpLen=eps_len, extend=True,
TestEpRet=eps_ret,
TestEpLen=eps_len,
extend=True,
)

# Epoch based learning rate decay
Expand Down Expand Up @@ -1211,11 +1218,13 @@ def lac(
logger.log_tabular("LossQ", average_only=True)
if adaptive_temperature:
logger.log_tabular(
"LossAlpha", average_only=True,
"LossAlpha",
average_only=True,
)
if use_lyapunov:
logger.log_tabular(
"LossLambda", average_only=True,
"LossLambda",
average_only=True,
)
if use_lyapunov:
logger.log_tabular("LVals", with_min_and_max=True)
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ def submodules_available(submodules):
submodules_available(stand_alone_ns_pkgs)

setup(
packages=PACKAGES, package_dir=PACKAGE_DIR,
packages=PACKAGES,
package_dir=PACKAGE_DIR,
)