-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from tk2lab/dev
Add integral/cfraction/asyn and make it default
- Loading branch information
Showing
6 changed files
with
96 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "logbesselk" | ||
version = "3.3.0-dev" | ||
version = "3.3.0" | ||
description = "Provide function to calculate the modified Bessel function of the second kind" | ||
license = "Apache-2.0" | ||
authors = ["TAKEKAWA Takashi <[email protected]>"] | ||
|
@@ -12,7 +12,7 @@ requires = ["poetry-core>=1.0.0"] | |
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.8,<3.12" | ||
python = ">=3.8,<3.13" | ||
#tensorflow = ">=2.6,<2.12" | ||
#jax = "^0.2+cuda" | ||
|
||
|
@@ -36,48 +36,18 @@ legacy_tox_ini = """ | |
min_version = 4.0 | ||
isolated_build = True | ||
env_list = | ||
py{38,39,310,}-tf{26,27,28,29,210,211} | ||
py{38,39,310,311,}-jax{3,4} | ||
py{310,311,312}-jax{3,4} | ||
py{310,311,312}-tf{29,210,211,212,213,214,215,216} | ||
lint | ||
eval_tf | ||
eval_jax | ||
[gh-actions] | ||
python = | ||
3.8: py38-tf26, py38-tf211, py38-jax3, py38-jax4 | ||
3.9: py39-tf26 | ||
3.10: py310-tf29, py310-tf211 | ||
3.11: py311-jax3, py311-jax4 | ||
3.10: py310-jax3, py310-tf29 | ||
3.12: py312-jax4, py312-tf216 | ||
[testenv:lint] | ||
skip_install = True | ||
deps = | ||
isort | ||
black | ||
flake8 | ||
pyproject-flake8 | ||
commands = | ||
isort src | ||
black src | ||
pflake8 src | ||
[testenv:eval_jax] | ||
deps = | ||
jax[cuda] | ||
pandas | ||
install_command = | ||
pip install --upgrade -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html {opts} {packages} | ||
commands = | ||
{envpython} eval/eval_jax.py | ||
[testenv:eval_tf] | ||
deps = | ||
tensorflow | ||
pandas | ||
commands = | ||
{envpython} eval/eval_tensorflow.py | ||
[testenv:py{38,39,310,311,}-jax{3,4}] | ||
[testenv:py{310,311,312}-jax{3,4}] | ||
deps = | ||
jax3: jax[cuda] (>=0.3,<0.4) | ||
jax4: jax[cuda] (>=0.4,<0.5) | ||
|
@@ -88,27 +58,34 @@ install_command = | |
commands = | ||
{envpython} -m pytest tests/test_jax.py {posargs} | ||
[testenv:py{38,39,310,}-tf{26,27,28,29,210,211}] | ||
[testenv:py{310,311,312}-tf{29,210,211,212,213,214,215,216}] | ||
deps = | ||
tf26: tensorflow (>=2.6,<2.7) | ||
tf27: tensorflow (>=2.7,<2.8) | ||
tf28: tensorflow (>=2.8,<2.9) | ||
tf29: tensorflow (>=2.9,<2.10) | ||
tf210: tensorflow (>=2.10,<2.11) | ||
tf211: tensorflow (>=2.11,<2.12) | ||
tf212: tensorflow (>=2.12,<2.13) | ||
tf213: tensorflow (>=2.13,<2.14) | ||
tf214: tensorflow (>=2.14,<2.15) | ||
tf215: tensorflow (>=2.15,<2.16) | ||
tf216: tensorflow (>=2.16,<2.17) | ||
pytest | ||
pandas | ||
commands = | ||
{envpython} -m pytest tests/test_tensorflow.py {posargs} | ||
[testenv:notebook] | ||
basepython = python3.11 | ||
[testenv:eval_jax] | ||
deps = | ||
jax[cuda] | ||
jupyterlab | ||
jupyterlab_nvdashboard | ||
pandas | ||
install_command = | ||
pip install --upgrade -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html {opts} {packages} | ||
commands = | ||
jupyter lab --no-browser | ||
{envpython} eval/eval_jax.py | ||
[testenv:eval_tf] | ||
deps = | ||
tensorflow | ||
pandas | ||
commands = | ||
{envpython} eval/eval_tensorflow.py | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import jax.lax as lax | ||
|
||
from .asymptotic import log_bessel_k_naive as log_k_large_v | ||
from .cfraction import log_bessel_ku as log_ku_small_v | ||
from .integral import log_abs_deriv_bessel_k as log_k_small_x | ||
from .math import ( | ||
fround, | ||
is_finite, | ||
) | ||
from .misc import ( | ||
log_bessel_recurrence, | ||
) | ||
from .utils import ( | ||
result_type, | ||
) | ||
from .wrap import ( | ||
wrap_bessel_ke, | ||
wrap_bessel_kratio, | ||
wrap_log_bessel_k, | ||
) | ||
|
||
__all__ = [ | ||
"log_bessel_k", | ||
"bessel_kratio", | ||
"bessel_ke", | ||
] | ||
|
||
|
||
@wrap_log_bessel_k | ||
def log_bessel_k(v, x): | ||
""" | ||
Combination of Integrate, Continued fraction and Asymptotic expansion. | ||
""" | ||
|
||
def small_x_case(): | ||
return log_k_small_x(v, x) | ||
|
||
def large_x_case(): | ||
def small_v_case(): | ||
n = fround(v) | ||
u = (v - n).astype(dtype) | ||
u_ = lax.cond(small_v, lambda: u, lambda: dtype(1 / 2)) | ||
logk0, logk1 = log_ku_small_v(u_, x) | ||
return log_bessel_recurrence(logk0, logk1, u, n, x)[0] | ||
|
||
def large_v_case(): | ||
v_ = lax.cond(large_v, lambda: v.astype(dtype), lambda: dtype(0)) | ||
return log_k_large_v(v_, x) | ||
|
||
return lax.cond(small_v, small_v_case, large_v_case) | ||
|
||
dtype = result_type(v, x) | ||
finite = is_finite(v) & is_finite(x) & (x > 0) | ||
large_v_ = v >= 25 | ||
large_x_ = x >= 100 | ||
|
||
small_x = finite & ~large_x_ | ||
small_v = finite & large_x_ & ~large_v_ | ||
large_v = finite & large_x_ & large_v_ | ||
return lax.cond(small_x, small_x_case, large_x_case) | ||
|
||
|
||
def bessel_kratio(v, x, d=1): | ||
return wrap_bessel_kratio(log_bessel_k, v, x, d) | ||
|
||
|
||
def bessel_ke(v, x): | ||
return wrap_bessel_ke(log_bessel_k, v, x) |