Skip to content

Commit

Permalink
imporove: use while_loop for t loop
Browse files Browse the repository at this point in the history
  • Loading branch information
tk2lab committed Oct 25, 2022
1 parent 8436161 commit 1c28220
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "logbesselk"
version = "2.4.0"
version = "2.4.1"
description = "Provide function to calculate the modified Bessel function of the second kind"
license = "Apache-2.0"
authors = ["TAKEKAWA Takashi <[email protected]>"]
Expand Down
23 changes: 15 additions & 8 deletions src/logbesselk/integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,23 @@ def func_mth(t):
dt = tf.where(zero_exists & mask, scale, zero)
t1 = find_zero_with_extend(func_mth, tpp, dt, tol, max_iter)

for b in range(bins):
def funcb(b):
a = (2 * b + 1) / (2 * bins)
t = (1 - a) * t0 + a * t1
ft = func(t)
if b == 0:
out = tf.ones(shape, dtype)
fmax = ft
else:
out = tf.where(fmax > ft, out + tf.exp(ft - fmax), out * tk.exp(fmax - ft) + 1)
fmax = tf.where(fmax > ft, fmax, ft)
return func(t)

def cond(b, fmax, out):
return b < bins

def loop(b, fmax, out):
b += 1
ft = funcb(b)
out = tf.where(fmax > ft, out + tf.exp(ft - fmax), out * tk.exp(fmax - ft) + 1)
fmax = tf.where(fmax > ft, fmax, ft)
return b, fmax, out

init = 0, funcb(0), tf.ones(shape, dtype)
b, fmax, out = tf.while_loop(cond, loop, init)
h = (t1 - t0) / bins
out = tk.log(h) + fmax + tk.log(out)

Expand Down

0 comments on commit 1c28220

Please sign in to comment.