Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
tk2lab committed Feb 5, 2023
1 parent 3055474 commit af1320a
Show file tree
Hide file tree
Showing 22 changed files with 766 additions and 179 deletions.
440 changes: 438 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

25 changes: 20 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ python = ">=3.8,<3.12"

[tool.isort]
profile = "black"
multi_line_output = 3
force_grid_wrap = 1

[tool.black]
line-length = 79
[tool.flake8]
max-line-length = 88
extend-ignore = ["E203"]

[tool.pytest.ini_options]
markers = [
Expand Down Expand Up @@ -51,11 +54,12 @@ skip_install = True
deps =
isort
black
autoflake
flake8
pyproject-flake8
commands =
isort --sl src
autoflake -ri --remove-all-unused-imports --ignore-init-module-imports src
isort src
black src
pflake8 src
[testenv:eval_jax]
deps =
Expand Down Expand Up @@ -96,4 +100,15 @@ deps =
pandas
commands =
{envpython} -m pytest tests/test_tensorflow.py {posargs}
[testenv:notebook]
basepython = python3.11
deps =
jax[cuda]
jupyterlab
jupyterlab_nvdashboard
install_command =
pip install --upgrade -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html {opts} {packages}
commands =
jupyter lab --no-browser
"""
22 changes: 17 additions & 5 deletions src/logbesselk/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from .integral import bessel_ke
from .integral import bessel_kratio
from .integral import log_abs_deriv_bessel_k
from .integral import log_bessel_k
from .misc import sign_deriv_bessel_k
from .integral import (
bessel_ke,
bessel_kratio,
log_abs_deriv_bessel_k,
log_bessel_k,
)
from .misc import (
sign_deriv_bessel_k,
)

__all__ = [
"bessel_ke",
"bessel_kratio",
"log_abs_deriv_bessel_k",
"log_bessel_k",
"sign_deriv_bessel_k",
]
20 changes: 13 additions & 7 deletions src/logbesselk/jax/asymptotic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import jax.lax as lax
import jax.numpy as jnp

from .math import fabs
from .math import log
from .math import sqrt
from .math import square
from .utils import epsilon
from .utils import result_type
from .wrap import wrap_log_bessel_k
from .math import (
fabs,
log,
sqrt,
square,
)
from .utils import (
epsilon,
result_type,
)
from .wrap import (
wrap_log_bessel_k,
)

__all__ = [
"log_bessel_k",
Expand Down
24 changes: 16 additions & 8 deletions src/logbesselk/jax/cfraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@

import jax.lax as lax

from .math import fabs
from .math import fround
from .math import log
from .math import square
from .misc import log_bessel_recurrence
from .utils import epsilon
from .utils import result_type
from .wrap import wrap_log_bessel_k
from .math import (
fabs,
fround,
log,
square,
)
from .misc import (
log_bessel_recurrence,
)
from .utils import (
epsilon,
result_type,
)
from .wrap import (
wrap_log_bessel_k,
)

__all__ = [
"log_bessel_k",
Expand Down
46 changes: 28 additions & 18 deletions src/logbesselk/jax/integral.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
import jax.lax as lax
import jax.numpy as jnp
from jax import grad

from .math import cosh
from .math import is_finite
from .math import log
from .math import log_cosh
from .math import log_sinh
from .math import maximum
from .math import square
from .utils import epsilon
from .utils import extend
from .utils import find_zero
from .utils import log_integrate
from .utils import result_type
from .wrap import wrap_bessel_ke
from .wrap import wrap_bessel_kratio
from .wrap import wrap_log_abs_deriv_bessel_k
from jax import (
grad,
)

from .math import (
cosh,
is_finite,
log,
log_cosh,
log_sinh,
maximum,
square,
)
from .utils import (
epsilon,
extend,
find_zero,
log_integrate,
result_type,
)
from .wrap import (
wrap_bessel_ke,
wrap_bessel_kratio,
wrap_log_abs_deriv_bessel_k,
)

__all__ = [
"log_bessel_k",
Expand Down Expand Up @@ -58,6 +66,9 @@ def func(t):
out += n * log_cosh(t)
return out

def mfunc(t):
return func(t) - th

scale = 0.1
tol = 1.0
max_iter = 10
Expand Down Expand Up @@ -85,7 +96,6 @@ def func(t):
tp = find_zero(deriv, start, delta, tol, max_iter)

th = func(tp) + log(eps) - tol
mfunc = lambda t: func(t) - th
tpl = maximum(tp - bins * eps, zero)
tpr = maximum(tp + bins * eps, tp * (1 + bins * eps))
mfunc_at_zero_is_negative = out_is_finite & (mfunc(zero) < 0)
Expand Down
38 changes: 23 additions & 15 deletions src/logbesselk/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,32 @@

import jax
from jax.numpy import abs as fabs
from jax.numpy import cosh
from jax.numpy import exp
from jax.numpy import expm1
from jax.numpy import inf
from jax.numpy import (
cosh,
exp,
expm1,
inf,
)
from jax.numpy import isfinite as is_finite
from jax.numpy import log
from jax.numpy import log1p
from jax.numpy import (
log,
log1p,
)
from jax.numpy import logaddexp as log_add_exp
from jax.numpy import maximum
from jax.numpy import nan
from jax.numpy import (
maximum,
nan,
)
from jax.numpy import round as fround
from jax.numpy import sign
from jax.numpy import sinc
from jax.numpy import sinh
from jax.numpy import sqrt
from jax.numpy import square
from jax.numpy import tanh
from jax.numpy import where
from jax.numpy import (
sign,
sinc,
sinh,
sqrt,
square,
tanh,
where,
)

__all__ = [
"fabs",
Expand Down
12 changes: 8 additions & 4 deletions src/logbesselk/jax/misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import jax.lax as lax

from .math import log
from .math import log_add_exp
from .math import sign
from .utils import result_type
from .math import (
log,
log_add_exp,
sign,
)
from .utils import (
result_type,
)

__all__ = [
"sign_deriv_bessel_k",
Expand Down
24 changes: 16 additions & 8 deletions src/logbesselk/jax/sca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@

from .asymptotic import log_bessel_k_naive as log_k_large_v
from .cfraction import log_bessel_ku as log_ku_large_x
from .math import fround
from .math import is_finite
from .math import log
from .misc import log_bessel_recurrence
from .math import (
fround,
is_finite,
log,
)
from .misc import (
log_bessel_recurrence,
)
from .series import log_bessel_ku as log_ku_small_x
from .utils import result_type
from .wrap import wrap_bessel_ke
from .wrap import wrap_bessel_kratio
from .wrap import wrap_log_bessel_k
from .utils import (
result_type,
)
from .wrap import (
wrap_bessel_ke,
wrap_bessel_kratio,
wrap_log_bessel_k,
)

__all__ = [
"log_bessel_k",
Expand Down
33 changes: 20 additions & 13 deletions src/logbesselk/jax/series.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
import jax.lax as lax

from .math import cosh
from .math import exp
from .math import fabs
from .math import fround
from .math import log
from .math import sinc
from .math import sinhc
from .math import square
from .misc import log_bessel_recurrence
from .utils import epsilon
from .utils import result_type
from .wrap import wrap_log_bessel_k
from .math import (
cosh,
exp,
fabs,
fround,
log,
sinc,
sinhc,
square,
)
from .misc import (
log_bessel_recurrence,
)
from .utils import (
epsilon,
result_type,
)
from .wrap import (
wrap_log_bessel_k,
)

__all__ = [
"log_bessel_k",
Expand All @@ -33,7 +41,6 @@ def log_bessel_k(v, x):


def log_bessel_ku(u, x):

def cond(args):
ku, kn, i, p, q, r, s = args
update = fabs(r * s) > eps * fabs(ku)
Expand Down
8 changes: 5 additions & 3 deletions src/logbesselk/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import jax.lax as lax
import jax.numpy as jnp

from .math import exp
from .math import fabs
from .math import log
from .math import (
exp,
fabs,
log,
)

__all__ = [
"result_type",
Expand Down
16 changes: 10 additions & 6 deletions src/logbesselk/jax/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
import jax
import jax.lax as lax

from .math import exp
from .math import fabs
from .math import inf
from .math import nan
from .math import sign
from .utils import result_type
from .math import (
exp,
fabs,
inf,
nan,
sign,
)
from .utils import (
result_type,
)

__all__ = [
"wrap_log_bessel_k",
Expand Down
22 changes: 17 additions & 5 deletions src/logbesselk/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from .integral import bessel_ke
from .integral import bessel_kratio
from .integral import log_abs_deriv_bessel_k
from .integral import log_bessel_k
from .misc import sign_deriv_bessel_k
from .integral import (
bessel_ke,
bessel_kratio,
log_abs_deriv_bessel_k,
log_bessel_k,
)
from .misc import (
sign_deriv_bessel_k,
)

__all__ = [
"bessel_ke",
"bessel_kratio",
"log_abs_deriv_bessel_k",
"log_bessel_k",
"sign_deriv_bessel_k",
]
Loading

0 comments on commit af1320a

Please sign in to comment.