-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy path_jax_idct.py
65 lines (50 loc) · 1.84 KB
/
_jax_idct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""A custom JAX idct method."""
import jax.numpy as jnp
import scipy.fft as osp_fft
from jax import lax
from jax._src.numpy.util import _wraps
from jax._src.util import canonicalize_axis
# `_W4` copied from jax._src.scipy.fft
def _W4(N, k):
return jnp.exp(-0.5j * jnp.pi * k / N)
def _slice_ref_in_dim(x, start, stop, stride, axis):
return tuple(
slice(start, stop, stride) if dim == axis else slice(None, None)
for dim in range(x.ndim)
)
def _dct_interleave_inverse(x, axis):
inverse = True
if inverse:
N = x.shape[axis]
v0 = lax.slice_in_dim(x, None, (N + 1) // 2, 1, axis)
v1 = lax.rev(lax.slice_in_dim(x, (N + 1) // 2, None, 1, axis), (axis,))
ref0 = _slice_ref_in_dim(x, None, None, 2, axis)
ref1 = _slice_ref_in_dim(x, 1, None, 2, axis)
out = jnp.zeros(x.shape, dtype=x.dtype)
out = out.at[ref0].set(v0)
out = out.at[ref1].set(v1)
return out
# REFACTOR: there is probably a cleaner way to implement this.
@_wraps(osp_fft.idct)
def idct(x, norm=None, axis=-1):
    """Inverse of the type-2 DCT (a scaled DCT-III) along ``axis``.

    Mirrors ``scipy.fft.idct(x, type=2, norm=norm, axis=axis)``:

    * ``None`` / ``"backward"``: result is scaled by ``1 / (2N)`` so that
      ``idct(dct(x)) == x``;
    * ``"ortho"``: orthonormal transform;
    * ``"forward"``: no scaling (the ``1 / (2N)`` lives on the forward DCT).

    The transform uses one length-``N`` complex inverse FFT. Each coefficient
    is rotated by ``exp(+i*pi*k/(2N))``, the DC term is halved, ``ifft`` is
    applied, and the output is de-interleaved (even outputs from the front,
    odd outputs from the back reversed). Complex input is handled by
    transforming its real and imaginary parts separately, as scipy does.

    Raises:
      ValueError: if ``norm`` is not one of the supported modes.
    """
    if norm not in (None, "backward", "ortho", "forward"):
        raise ValueError(
            f"idct: unsupported norm {norm!r}; expected None, 'backward', "
            "'ortho' or 'forward'."
        )
    x = jnp.asarray(x)
    if jnp.iscomplexobj(x):
        # The DCT is real-linear; taking `.real` below would otherwise silently
        # drop the imaginary part of a complex input.
        return idct(x.real, norm=norm, axis=axis) + 1j * idct(
            x.imag, norm=norm, axis=axis
        )

    axis = canonicalize_axis(axis, x.ndim)
    N = x.shape[axis]
    other_axes = [a for a in range(x.ndim) if a != axis]
    k = lax.expand_dims(jnp.arange(N), other_axes)

    # Rotate coefficient k by exp(+i*pi*k/(2N)) and halve the DC term.
    V = _W4(N, -k) * x
    x0 = lax.slice_in_dim(x, None, 1, 1, axis) / 2
    V = V.at[_slice_ref_in_dim(V, None, 1, 1, axis)].set(x0)

    if norm == "ortho":
        # Relative to the backward result, the orthonormal DCT-III scales
        # X[0] by 2*sqrt(N) and every X[k>0] by sqrt(2N).
        scale = jnp.full((N,), jnp.sqrt(2 * N), dtype=V.real.dtype)
        scale = scale.at[0].set(2 * jnp.sqrt(N))
        V = V * lax.expand_dims(scale, other_axes)
    elif norm == "forward":
        # "forward" leaves the inverse unscaled: 2N times the backward result.
        V = V * (2 * N)

    v = jnp.fft.ifft(V, axis=axis)
    return _dct_interleave_inverse(v, axis).real