Skip to content

Commit

Permalink
Merge pull request #29 from tk2lab/dev
Browse files Browse the repository at this point in the history
Add integral/cfraction/asyn and make it default
  • Loading branch information
tk2lab authored May 25, 2024
2 parents af1320a + 2824d5b commit ba00c91
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/codeql-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.10', '3.11', '3.12']

steps:
- name: Check out repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/upload.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:

steps:
- name: Check out repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v3
Expand Down
69 changes: 23 additions & 46 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "logbesselk"
version = "3.3.0-dev"
version = "3.3.0"
description = "Provide function to calculate the modified Bessel function of the second kind"
license = "Apache-2.0"
authors = ["TAKEKAWA Takashi <[email protected]>"]
Expand All @@ -12,7 +12,7 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.dependencies]
python = ">=3.8,<3.12"
python = ">=3.8,<3.13"
#tensorflow = ">=2.6,<2.12"
#jax = "^0.2+cuda"

Expand All @@ -36,48 +36,18 @@ legacy_tox_ini = """
min_version = 4.0
isolated_build = True
env_list =
py{38,39,310,}-tf{26,27,28,29,210,211}
py{38,39,310,311,}-jax{3,4}
py{310,311,312}-jax{3,4}
py{310,311,312}-tf{29,210,211,212,213,214,215,216}
lint
eval_tf
eval_jax
[gh-actions]
python =
3.8: py38-tf26, py38-tf211, py38-jax3, py38-jax4
3.9: py39-tf26
3.10: py310-tf29, py310-tf211
3.11: py311-jax3, py311-jax4
3.10: py310-jax3, py310-tf29
3.12: py312-jax4, py312-tf216
[testenv:lint]
skip_install = True
deps =
isort
black
flake8
pyproject-flake8
commands =
isort src
black src
pflake8 src
[testenv:eval_jax]
deps =
jax[cuda]
pandas
install_command =
pip install --upgrade -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html {opts} {packages}
commands =
{envpython} eval/eval_jax.py
[testenv:eval_tf]
deps =
tensorflow
pandas
commands =
{envpython} eval/eval_tensorflow.py
[testenv:py{38,39,310,311,}-jax{3,4}]
[testenv:py{310,311,312}-jax{3,4}]
deps =
jax3: jax[cuda] (>=0.3,<0.4)
jax4: jax[cuda] (>=0.4,<0.5)
Expand All @@ -88,27 +58,34 @@ install_command =
commands =
{envpython} -m pytest tests/test_jax.py {posargs}
[testenv:py{38,39,310,}-tf{26,27,28,29,210,211}]
[testenv:py{310,311,312}-tf{29,210,211,212,213,214,215,216}]
deps =
tf26: tensorflow (>=2.6,<2.7)
tf27: tensorflow (>=2.7,<2.8)
tf28: tensorflow (>=2.8,<2.9)
tf29: tensorflow (>=2.9,<2.10)
tf210: tensorflow (>=2.10,<2.11)
tf211: tensorflow (>=2.11,<2.12)
tf212: tensorflow (>=2.12,<2.13)
tf213: tensorflow (>=2.13,<2.14)
tf214: tensorflow (>=2.14,<2.15)
tf215: tensorflow (>=2.15,<2.16)
tf216: tensorflow (>=2.16,<2.17)
pytest
pandas
commands =
{envpython} -m pytest tests/test_tensorflow.py {posargs}
[testenv:notebook]
basepython = python3.11
[testenv:eval_jax]
deps =
jax[cuda]
jupyterlab
jupyterlab_nvdashboard
pandas
install_command =
pip install --upgrade -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html {opts} {packages}
commands =
jupyter lab --no-browser
{envpython} eval/eval_jax.py
[testenv:eval_tf]
deps =
tensorflow
pandas
commands =
{envpython} eval/eval_tensorflow.py
"""
4 changes: 1 addition & 3 deletions src/logbesselk/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .integral import (
from .ica import (
bessel_ke,
bessel_kratio,
log_abs_deriv_bessel_k,
log_bessel_k,
)
from .misc import (
Expand All @@ -11,7 +10,6 @@
__all__ = [
"bessel_ke",
"bessel_kratio",
"log_abs_deriv_bessel_k",
"log_bessel_k",
"sign_deriv_bessel_k",
]
68 changes: 68 additions & 0 deletions src/logbesselk/jax/ica.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import jax.lax as lax

from .asymptotic import log_bessel_k_naive as log_k_large_v
from .cfraction import log_bessel_ku as log_ku_small_v
from .integral import log_abs_deriv_bessel_k as log_k_small_x
from .math import (
fround,
is_finite,
)
from .misc import (
log_bessel_recurrence,
)
from .utils import (
result_type,
)
from .wrap import (
wrap_bessel_ke,
wrap_bessel_kratio,
wrap_log_bessel_k,
)

__all__ = [
"log_bessel_k",
"bessel_kratio",
"bessel_ke",
]


@wrap_log_bessel_k
def log_bessel_k(v, x):
"""
Combination of Integrate, Continued fraction and Asymptotic expansion.
"""

def small_x_case():
return log_k_small_x(v, x)

def large_x_case():
def small_v_case():
n = fround(v)
u = (v - n).astype(dtype)
u_ = lax.cond(small_v, lambda: u, lambda: dtype(1 / 2))
logk0, logk1 = log_ku_small_v(u_, x)
return log_bessel_recurrence(logk0, logk1, u, n, x)[0]

def large_v_case():
v_ = lax.cond(large_v, lambda: v.astype(dtype), lambda: dtype(0))
return log_k_large_v(v_, x)

return lax.cond(small_v, small_v_case, large_v_case)

dtype = result_type(v, x)
finite = is_finite(v) & is_finite(x) & (x > 0)
large_v_ = v >= 25
large_x_ = x >= 100

small_x = finite & ~large_x_
small_v = finite & large_x_ & ~large_v_
large_v = finite & large_x_ & large_v_
return lax.cond(small_x, small_x_case, large_x_case)


def bessel_kratio(v, x, d=1):
return wrap_bessel_kratio(log_bessel_k, v, x, d)


def bessel_ke(v, x):
return wrap_bessel_ke(log_bessel_k, v, x)

0 comments on commit ba00c91

Please sign in to comment.