Skip to content

Commit

Permalink
Merge pull request #300 from jorenham/refactor-inference
Browse files Browse the repository at this point in the history
Refactor `lmo.inference` as a submodule
  • Loading branch information
jorenham authored Aug 24, 2024
2 parents 7c8f952 + e12ba18 commit af9d885
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 173 deletions.
2 changes: 1 addition & 1 deletion lmo/_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


if TYPE_CHECKING:
from .typing import np as lnpt
import lmo.typing.np as lnpt


__all__ = (
Expand Down
55 changes: 27 additions & 28 deletions lmo/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import math
import sys
from typing import TYPE_CHECKING, Any, Final, TypeAlias, cast
from typing import TYPE_CHECKING, Any, Final, TypeAlias, cast, overload

import numpy as np
import numpy.typing as npt
import optype.numpy as onpt


if sys.version_info >= (3, 13):
Expand All @@ -15,14 +14,10 @@
from typing_extensions import LiteralString, TypeVar

if TYPE_CHECKING:
from .typing import (
AnyAWeights,
AnyFWeights,
AnyOrder,
AnyOrderND,
AnyTrim,
np as lnpt,
)
import optype.numpy as onpt

import lmo.typing as lmt
import lmo.typing.np as lnpt

__all__ = (
'clean_order',
Expand All @@ -40,17 +35,17 @@


_SCT = TypeVar('_SCT', bound=np.generic)
_SCT_uifc = TypeVar('_SCT_uifc', bound=np.number[Any])
_SCT_ui = TypeVar('_SCT_ui', bound=np.integer[Any], default=np.int_)
_SCT_f = TypeVar('_SCT_f', bound=np.floating[Any], default=np.float64)
_SCT_uifc = TypeVar('_SCT_uifc', bound='lnpt.Number')
_SCT_ui = TypeVar('_SCT_ui', bound='lnpt.Int', default=np.int_)
_SCT_f = TypeVar('_SCT_f', bound='lnpt.Float', default=np.float64)

_DT_f = TypeVar('_DT_f', bound=np.dtype[np.floating[Any]])
_DT_f = TypeVar('_DT_f', bound=np.dtype['lnpt.Float'])

_SizeT = TypeVar('_SizeT', bound=int)

_ShapeT = TypeVar('_ShapeT', bound=onpt.AtLeast0D)
_ShapeT1 = TypeVar('_ShapeT1', bound=onpt.AtLeast1D)
_ShapeT2 = TypeVar('_ShapeT2', bound=onpt.AtLeast2D)
_ShapeT = TypeVar('_ShapeT', bound='onpt.AtLeast0D')
_ShapeT1 = TypeVar('_ShapeT1', bound='onpt.AtLeast1D')
_ShapeT2 = TypeVar('_ShapeT2', bound='onpt.AtLeast2D')

_DType: TypeAlias = np.dtype[_SCT] | type[_SCT]

Expand Down Expand Up @@ -148,7 +143,7 @@ def _apply_aweights(

def _sort_like(
a: onpt.Array[_ShapeT1, _SCT_uifc],
i: onpt.Array[tuple[int], np.integer[Any]],
i: onpt.Array[tuple[int], lnpt.Int],
/,
axis: int | None,
) -> onpt.Array[_ShapeT1, _SCT_uifc]:
Expand Down Expand Up @@ -183,10 +178,10 @@ def ordered( # noqa: C901
y: lnpt.AnyArrayFloat | None = None,
/,
axis: int | None = None,
dtype: _DType[np.floating[Any]] | None = None,
dtype: _DType[lnpt.Float] | None = None,
*,
fweights: AnyFWeights | None = None,
aweights: AnyAWeights | None = None,
fweights: lmt.AnyFWeights | None = None,
aweights: lmt.AnyAWeights | None = None,
sort: lnpt.SortKind | bool = True,
) -> onpt.Array[onpt.AtLeast1D, lnpt.Float]:
"""
Expand Down Expand Up @@ -254,7 +249,7 @@ def ordered( # noqa: C901


def clean_order(
r: AnyOrder,
r: lmt.AnyOrder,
/,
name: LiteralString = 'r',
rmin: int = 0,
Expand All @@ -268,11 +263,11 @@ def clean_order(


def clean_orders(
r: AnyOrderND,
r: lmt.AnyOrderND,
/,
name: str = 'r',
rmin: int = 0,
dtype: _DType[_SCT_ui] = np.intp,
dtype: _DType[_SCT_ui] = np.int_,
) -> onpt.Array[Any, _SCT_ui]:
"""Validates and cleans an array-like of (L-)moment orders."""
_r = np.asarray_chkfinite(r, dtype=dtype)
Expand All @@ -291,7 +286,11 @@ def clean_orders(
)


def clean_trim(trim: AnyTrim, /) -> tuple[int, int] | tuple[float, float]:
@overload
def clean_trim(trim: lmt.AnyTrimInt, /) -> tuple[int, int]: ...
@overload
def clean_trim(trim: lmt.AnyTrimFloat, /) -> tuple[float, float]: ...
def clean_trim(trim: lmt.AnyTrim, /) -> tuple[int, int] | tuple[float, float]:
"""
Validates and cleans the passed trim, and returns a 2-tuple of either ints
or floats.
Expand Down Expand Up @@ -329,7 +328,7 @@ def clean_trim(trim: AnyTrim, /) -> tuple[int, int] | tuple[float, float]:


def moments_to_ratio(
rs: onpt.Array[Any, np.integer[Any]],
rs: onpt.Array[tuple[int, ...], lnpt.Int],
l_rs: onpt.Array[onpt.AtLeast1D, _SCT_f],
/,
) -> _SCT_f | npt.NDArray[_SCT_f]:
Expand All @@ -355,7 +354,7 @@ def moments_to_ratio(


def moments_to_stats_cov(
t_0r: onpt.Array[tuple[int], np.floating[Any]],
t_0r: onpt.Array[tuple[int], lnpt.Float],
ll_kr: onpt.Array[_ShapeT2, _SCT_f],
) -> onpt.Array[_ShapeT2, _SCT_f]:
# t_0r are the L-ratios for r = 0, 1, ..., R (t_0r[0] == 1 / L-scale)
Expand Down Expand Up @@ -387,7 +386,7 @@ def moments_to_stats_cov(
def l_stats_orders(
num: _SizeT,
/,
dtype: _DType[_SCT_ui] = np.intp,
dtype: _DType[_SCT_ui] = np.int_,
) -> tuple[
onpt.Array[tuple[_SizeT], _SCT_ui],
onpt.Array[tuple[_SizeT], _SCT_ui],
Expand Down
28 changes: 11 additions & 17 deletions lmo/contrib/scipy_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@
import lmo.typing as lmt
import lmo.typing.np as lnpt
import lmo.typing.scipy as lspt
from lmo import (
inference,
l_moment as l_moment_est,
)
from lmo import inference
from lmo._lm import l_moment as l_moment_est
from lmo._utils import (
clean_order,
clean_orders,
Expand Down Expand Up @@ -1089,33 +1087,29 @@ def l_fit(
self,
data: lnpt.AnyVectorInt | lnpt.AnyVectorFloat,
*args: float,
n_extra: int = 0,
trim: lmt.AnyTrim = 0,
n_extra: int = ...,
trim: lmt.AnyTrimInt = ...,
full_output: Literal[True],
fit_kwargs: Mapping[str, Any] | None = None,
fit_kwargs: Mapping[str, Any] | None = ...,
**kwds: Any,
) -> tuple[float, ...]:
...

) -> tuple[float, ...]: ...
@overload
def l_fit(
self,
data: lnpt.AnyVectorInt | lnpt.AnyVectorFloat,
*args: float,
n_extra: int = 0,
trim: lmt.AnyTrim = 0,
n_extra: int = ...,
trim: lmt.AnyTrimInt = ...,
full_output: bool = ...,
fit_kwargs: Mapping[str, Any] | None = None,
fit_kwargs: Mapping[str, Any] | None = ...,
**kwds: Any,
) -> tuple[float, ...]:
...

) -> tuple[float, ...]: ...
def l_fit(
self,
data: lnpt.AnyVectorInt | lnpt.AnyVectorFloat,
*args: float,
n_extra: int = 0,
trim: lmt.AnyTrim = 0,
trim: lmt.AnyTrimInt = 0,
full_output: bool = False,
fit_kwargs: Mapping[str, Any] | None = None,
random_state: int | np.random.Generator | None = None,
Expand Down
8 changes: 8 additions & 0 deletions lmo/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Statistical inference for parametric probability distributions."""

from __future__ import annotations

from ._l_gmm import GMMResult, fit


__all__ = 'GMMResult', 'fit'
71 changes: 29 additions & 42 deletions lmo/inference.py → lmo/inference/_l_gmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
"""
Parametric inference using the (Generalized) Method of L-Moments, L-(G)MM.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
Expand All @@ -10,28 +6,21 @@
import numpy.typing as npt
from scipy import optimize, special

from ._lm import l_moment as l_moment_est
from ._lm_co import l_coscale as l_coscale_est
from ._utils import clean_orders, clean_trim
from .diagnostic import HypothesisTestResult, l_moment_bounds
from .theoretical import l_moment_from_ppf
from .typing import scipy as lspt
import lmo.typing as lmt
import lmo.typing.np as lnpt
import lmo.typing.scipy as lspt
from lmo._lm import l_moment as l_moment_est
from lmo._lm_co import l_coscale as l_coscale_est
from lmo._utils import clean_orders, clean_trim
from lmo.diagnostic import HypothesisTestResult, l_moment_bounds
from lmo.theoretical import l_moment_from_ppf


if TYPE_CHECKING:
from collections.abc import Callable

from .typing import (
AnyOrderND,
AnyTrim,
np as lnpt,
)


__all__ = (
'GMMResult',
'fit',
)
__all__ = 'GMMResult', 'fit'


_ArrF8: TypeAlias = npt.NDArray[np.float64]
Expand Down Expand Up @@ -156,26 +145,27 @@ def AICc(self) -> float: # noqa: N802
def _loss_step(
args: _ArrF8,
l_fn: Callable[..., _ArrF8],
r: npt.NDArray[np.int64],
r: npt.NDArray[np.intp],
l_r: _ArrF8,
trim: AnyTrim,
trim: lmt.AnyTrim,
w_rr: _ArrF8,
) -> float:
) -> np.float64:
lmbda_r = l_fn(r, *args, trim=trim)

if not np.all(np.isfinite(lmbda_r)):
msg = f'failed to find the L-moments of ppf{args} with {trim=}'
raise ValueError(msg)

g_r = lmbda_r - l_r
return cast(float, np.sqrt(g_r.T @ w_rr @ g_r))
return np.sqrt(cast(np.float64, g_r.T @ w_rr @ g_r))


def _get_l_moment_fn(ppf: lspt.RVFunction[...]) -> Callable[..., _ArrF8]:
def l_moment_fn(
r: AnyOrderND,
r: lmt.AnyOrderND,
/,
*args: Any,
trim: AnyTrim = 0,
trim: lmt.AnyTrim = 0,
) -> _ArrF8:
return l_moment_from_ppf(lambda q: ppf(q, *args), r, trim=trim)

Expand All @@ -184,7 +174,8 @@ def l_moment_fn(

def _get_weights_mc(
y: _ArrF8,
r: npt.NDArray[np.int64],
r: npt.NDArray[np.intp],
/,
trim: tuple[int, int] | tuple[float, float] = (0, 0),
) -> _ArrF8:
l_r = l_moment_est(
Expand Down Expand Up @@ -219,21 +210,16 @@ def fit( # noqa: C901
args0: lnpt.AnyVectorFloat,
n_obs: int,
l_moments: lnpt.AnyVectorFloat,
r: AnyOrderND | None = None,
trim: AnyTrim = 0,
r: lmt.AnyOrderND | None = None,
trim: int | tuple[int, int] = 0,
*,
k: int | None = None,
k_max: int = 50,
l_tol: float = 1e-4,

l_moment_fn: Callable[..., _ArrF8] | None = None,
n_mc_samples: int = 9999,
random_state: (
int
| np.random.Generator
| np.random.RandomState
| None
) = None,
random_state: lnpt.Seed | None = None,
**kwds: Any,
) -> GMMResult:
r"""
Expand Down Expand Up @@ -281,12 +267,13 @@ def fit( # noqa: C901
- Raise on minimization error, warn on failed k-step convergence
- Optional `integrality` kwarg with boolean mask for integral params.
- Implement CUE: Continuously Updating GMM (i.e. implement and
use `_loss_cue()`, then run with `k=1`).
use `_loss_cue()`, then run with `k=1`). See
https://github.com/jorenham/Lmo/issues/299
Parameters:
ppf:
The (vectorized) quantile function of the probability distribution,
with signature `(*args: float, q: T) -> T`.
with signature `(q: T, *params: float) -> T`.
args0:
Initial estimate of the distribution's parameter values.
n_obs:
Expand All @@ -312,7 +299,7 @@ def fit( # noqa: C901
Will be ignored if $k$ is specified or if `n_extra=0`.
l_moment_fn:
Function for parametric L-moment calculation, with signature:
`(r: int64[], *args, trim: float[2] | int[2]) -> float64[]`.
`(r: intp[:], *args, trim: float[2] | int[2]) -> float64[:]`.
n_mc_samples:
The number of Monte-Carlo (MC) samples drawn from the
distribution to form the weight matrix in step $k > 1$.
Expand Down Expand Up @@ -346,9 +333,9 @@ def fit( # noqa: C901
raise ValueError(msg)

if r is None:
_r = np.arange(1, len(l_r) + 1)
_r = np.arange(1, len(l_r) + 1, dtype=np.intp)
else:
_r = clean_orders(np.asarray(r, np.int64))
_r = clean_orders(np.asarray(r, np.intp))

_r_nonzero = _r != 0
l_r, _r = l_r[_r_nonzero], _r[_r_nonzero]
Expand All @@ -358,7 +345,7 @@ def fit( # noqa: C901
raise ValueError(msg)

_trim = clean_trim(trim)
_r = np.arange(1, n_con + 1, dtype=np.int64)
_r = np.arange(1, n_con + 1, dtype=np.intp)

# Individual L-moment "natural" scaling constants, making their magnitudes
# order- and trim-agnostic (used in the convergence criterion)
Expand Down Expand Up @@ -448,7 +435,7 @@ def fit( # noqa: C901
success=success,
statistic=fun**2,
eps=eps,
n_samp=cast(int, n_obs - sum(_trim)),
n_samp=n_obs - int(sum(_trim)),
n_step=_k,
n_iter=i,
weights=w_rr,
Expand Down
2 changes: 1 addition & 1 deletion lmo/ostats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
if TYPE_CHECKING:
import optype.numpy as onpt

from .typing import np as lnpt
import lmo.typing.np as lnpt


__all__ = 'weights', 'from_cdf'
Expand Down
Loading

0 comments on commit af9d885

Please sign in to comment.