[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 8, 2024
1 parent 2bdad0b commit 6a97165
Showing 19 changed files with 182 additions and 126 deletions.
1 change: 1 addition & 0 deletions profit/al/active_learning.py
@@ -4,6 +4,7 @@
 In order to get the most out of the least number of training points, the next point is inferred by calculating an
 acquisition function like the minimization of local variance or expected improvement.
 """
+
 import numpy as np
 from abc import abstractmethod
 from warnings import warn
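This hunk, like the matching ones in profit/defaults.py, profit/sur/encoders.py and profit/sur/gp/backend/init_kernels.py further down, only inserts a blank line between a docstring and the first statement that follows it. That is consistent with the 2024 stable style of the Black formatter, which the pre-commit hooks most likely applied here; no behaviour changes. A minimal sketch of the rule, assuming Black >= 24.1 and using a made-up module rather than project code:

"""Hypothetical module docstring."""

import math  # Black 24.x keeps exactly one blank line between the docstring above and this import


def circumference(radius: float) -> float:
    """Return a circle's circumference; defined only so the sketch is a complete module."""
    return 2 * math.pi * radius

The same result can usually be reproduced locally with `pre-commit run --all-files` in a checkout that contains the project's hook configuration.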
8 changes: 5 additions & 3 deletions profit/al/mcmc_al.py
@@ -73,9 +73,11 @@ def __init__(
             var.name for var in variables.list if var.kind.lower() == "activelearning"
         ]
         Xpred = [
-            np.linspace(*var.constraints, nsearch)
-            if var.name in al_keys
-            else np.unique(var.value)
+            (
+                np.linspace(*var.constraints, nsearch)
+                if var.name in al_keys
+                else np.unique(var.value)
+            )
             for var in variables.input_list
         ]
         self.Xpred = np.hstack(
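Several of the following hunks (profit/config.py, profit/run/local.py, profit/run/slurm.py, profit/sur/gp/gpy_surrogate.py, profit/sur/gp/sklearn_surrogate.py) apply another rule of the same style: a conditional expression that spans multiple lines is wrapped in parentheses. Here the expression sits inside a list comprehension; the parentheses group it visually but do not change how it evaluates. A self-contained sketch of the pattern, using made-up stand-ins (al_keys, inputs, nsearch) instead of proFit's Variable objects:

import numpy as np

# Hypothetical stand-ins: (name, bounds, observed values) tuples instead of profit's Variable objects.
al_keys = {"x"}
inputs = [("x", (0.0, 1.0), np.array([0.3])), ("y", None, np.array([1.0, 1.0, 2.0]))]
nsearch = 5

Xpred = [
    (
        np.linspace(*bounds, nsearch)  # search grid for active-learning inputs
        if name in al_keys
        else np.unique(values)  # fixed inputs keep their (deduplicated) observed values
    )
    for name, bounds, values in inputs
]
print([x.shape for x in Xpred])  # [(5,), (2,)] -- identical with or without the added parentheses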
10 changes: 5 additions & 5 deletions profit/config.py
@@ -536,11 +536,11 @@ def process_entries(self, base_config):
                 {
                     "class": name,
                     "columns": select,
-                    "parameters": {
-                        k: float(v) for k, v in config.get("parameters", {})
-                    }
-                    if not isinstance(config, str)
-                    else {},
+                    "parameters": (
+                        {k: float(v) for k, v in config.get("parameters", {})}
+                        if not isinstance(config, str)
+                        else {}
+                    ),
                 }
             )

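Formatting aside, one detail in this hunk may deserve a second look: iterating a plain dict yields its keys, not (key, value) pairs, so "for k, v in config.get("parameters", {})" only unpacks cleanly if parameters is already stored as a sequence of pairs; if it is a mapping, .items() would be the usual spelling. An illustration with a made-up config dict, not proFit's actual config object:

# Made-up stand-in for an encoder entry; the real config may store "parameters" differently.
config = {"class": "Normalization", "parameters": {"mu": "0.5", "std": "0.1"}}

# Iterating the dict itself yields the key strings "mu" and "std", so the
# comprehension from the diff would try to unpack those strings instead of pairs:
#     {k: float(v) for k, v in config.get("parameters", {})}

# Iterating .items() unpacks (key, value) pairs:
parameters = {k: float(v) for k, v in config.get("parameters", {}).items()}
print(parameters)  # {'mu': 0.5, 'std': 0.1}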
1 change: 1 addition & 0 deletions profit/defaults.py
@@ -1,4 +1,5 @@
 """Global default configuration values."""
+
 from os import path, getcwd
 
 # Base Config
8 changes: 5 additions & 3 deletions profit/run/local.py
@@ -35,9 +35,11 @@ def __repr__(self):
         return (
             f"<{self.__class__.__name__} (" + ", debug"
             if self.debug
-            else "" + f", {self.command}"
-            if self.command != "profit-worker"
-            else "" + ")>"
+            else (
+                "" + f", {self.command}"
+                if self.command != "profit-worker"
+                else "" + ")>"
+            )
         )
 
     @property
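The parentheses added in this __repr__ (and in the one in profit/run/slurm.py below) also leave behaviour untouched: + binds tighter than a conditional expression, and chained conditionals associate to the right, so Black merely spells out the grouping Python was already using. That the resulting string can still come out lopsided (for example, missing its closing ")>" when debug is set) is a quirk of the original expression, not something this commit introduces. A small runnable check of the equivalence, using plain local variables instead of the runner object:

# Stand-alone demonstration with made-up values, not project code.
debug, command = True, "profit-worker"

unparenthesized = (
    "<LocalRunner (" + ", debug"
    if debug
    else "" + f", {command}"
    if command != "profit-worker"
    else "" + ")>"
)

# The 2024 Black style only makes the existing right-associative grouping explicit:
parenthesized = (
    "<LocalRunner (" + ", debug"
    if debug
    else ("" + f", {command}" if command != "profit-worker" else "" + ")>")
)

assert unparenthesized == parenthesized == "<LocalRunner (, debug"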
10 changes: 5 additions & 5 deletions profit/run/slurm.py
@@ -68,11 +68,11 @@ def __repr__(self):
         return (
             f"<{self.__class__.__name__} (" + f", {self.cpus} cpus" + ", OpenMP"
             if self.openmp
-            else "" + ", debug"
-            if self.debug
-            else "" + ", custom script"
-            if self.custom
-            else "" + ")>"
+            else (
+                "" + ", debug"
+                if self.debug
+                else "" + ", custom script" if self.custom else "" + ")>"
+            )
         )
 
     @property
1 change: 1 addition & 0 deletions profit/sur/encoders.py
@@ -16,6 +16,7 @@ class Encoder(CustomABC):
     Attributes:
         label (str): Label of the encoder class.
     """
+
     labels = {}
 
     def __init__(self, columns, parameters=None):
1 change: 1 addition & 0 deletions profit/sur/gp/backend/init_kernels.py
@@ -4,6 +4,7 @@
 that is written to `kernels_base.f90`. This code has to
 be compiled via `make` subsequently.
 """
+
 # %%
 from sympy import symbols, sqrt, exp, diff
 from sympy.utilities.codegen import codegen
16 changes: 9 additions & 7 deletions profit/sur/gp/gpy_surrogate.py
@@ -155,13 +155,15 @@ def select_kernel(self, kernel):
         kern = []
         for key in full_str:
             kern += [
-                key
-                if key in ("+", "*")
-                else "self.GPy.kern.{}({}, lengthscale={}, variance={})".format(
-                    key,
-                    self.ndim,
-                    self.hyperparameters.get("length_scale", [1]),
-                    self.hyperparameters.get("sigma_f", 1) ** 2,
+                (
+                    key
+                    if key in ("+", "*")
+                    else "self.GPy.kern.{}({}, lengthscale={}, variance={})".format(
+                        key,
+                        self.ndim,
+                        self.hyperparameters.get("length_scale", [1]),
+                        self.hyperparameters.get("sigma_f", 1) ** 2,
+                    )
                 )
             ]
         return eval("".join(kern))
10 changes: 6 additions & 4 deletions profit/sur/gp/sklearn_surrogate.py
@@ -144,10 +144,12 @@ def select_kernel(self, kernel):
             kernel = []
             for key in full_str:
                 kernel += [
-                    key
-                    if key in ("+", "*")
-                    else getattr(sklearn_kernels, key)(
-                        length_scale=self.hyperparameters["length_scale"]
+                    (
+                        key
+                        if key in ("+", "*")
+                        else getattr(sklearn_kernels, key)(
+                            length_scale=self.hyperparameters["length_scale"]
+                        )
                     )
                 ]
         except AttributeError:
(The diffs for the remaining 9 changed files are not shown.)
