LLR for non-orthonormal basis #430

Merged: 6 commits, Mar 11, 2022
4 changes: 2 additions & 2 deletions examples/plot_kernel_regression.py

@@ -109,8 +109,8 @@
 print('Score NW:', nw_res)
 
 ##############################################################################
-# For Local Linear Regression, FDataBasis representation with an orthonormal
-# basis should be used (for the previous cases it is possible to use either
+# For Local Linear Regression, FDataBasis representation with a basis should be
+# used (for the previous cases it is possible to use either
 # FDataGrid or FDataBasis).
 #
 # For basis, Fourier basis with 10 elements has been selected. Note that the
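The example now runs with any FDataBasis representation, orthonormal or not. A minimal sketch of the setup this comment describes, assuming synthetic data (the random coefficients, domain, and bandwidth below are illustrative, not values taken from the example script):

import numpy as np

from skfda.misc.hat_matrix import LocalLinearRegressionHatMatrix
from skfda.ml.regression import KernelRegression
from skfda.representation.basis import FDataBasis, Fourier

# Three toy curves expressed in a Fourier basis with 10 elements
basis = Fourier(n_basis=10, domain_range=(0, 1))
rng = np.random.default_rng(0)
X_train = FDataBasis(basis=basis, coefficients=rng.normal(size=(3, 10)))
y_train = np.array([1.0, 2.0, 3.0])

kr = KernelRegression(
    kernel_estimator=LocalLinearRegressionHatMatrix(bandwidth=1.0),
)
kr.fit(X_train, y_train)
print(kr.predict(X_train))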
97 changes: 43 additions & 54 deletions skfda/misc/hat_matrix.py

@@ -14,10 +14,9 @@
 import numpy as np
 from sklearn.base import BaseEstimator, RegressorMixin
 
-from skfda.representation._functional_data import FData
-from skfda.representation.basis import FDataBasis
-
-from ..representation._typing import GridPoints, GridPointsLike
+from ..representation._functional_data import FData
+from ..representation._typing import GridPoints, GridPointsLike, NDArrayFloat
+from ..representation.basis import FDataBasis
 from . import kernels
 
 
@@ -36,21 +35,21 @@ def __init__(
         self,
         *,
         bandwidth: Optional[float] = None,
-        kernel: Callable[[np.ndarray], np.ndarray] = kernels.normal,
+        kernel: Callable[[NDArrayFloat], NDArrayFloat] = kernels.normal,
     ):
         self.bandwidth = bandwidth
         self.kernel = kernel
 
     def __call__(
         self,
         *,
-        delta_x: np.ndarray,
+        delta_x: NDArrayFloat,
         X_train: Optional[Union[FData, GridPointsLike]] = None,
         X: Optional[Union[FData, GridPointsLike]] = None,
-        y_train: Optional[np.ndarray] = None,
-        weights: Optional[np.ndarray] = None,
+        y_train: Optional[NDArrayFloat] = None,
+        weights: Optional[NDArrayFloat] = None,
         _cv: bool = False,
-    ) -> np.ndarray:
+    ) -> NDArrayFloat:
         r"""
         Calculate the hat matrix or the prediction.
 
@@ -99,8 +98,8 @@ def __call__(
     def _hat_matrix_function_not_normalized(
         self,
         *,
-        delta_x: np.ndarray,
-    ) -> np.ndarray:
+        delta_x: NDArrayFloat,
+    ) -> NDArrayFloat:
         pass
 
 
@@ -141,8 +140,8 @@ class NadarayaWatsonHatMatrix(HatMatrix):
     def _hat_matrix_function_not_normalized(
         self,
         *,
-        delta_x: np.ndarray,
-    ) -> np.ndarray:
+        delta_x: NDArrayFloat,
+    ) -> NDArrayFloat:
 
         if self.bandwidth is None:
             percentage = 15
@@ -185,7 +184,7 @@ class LocalLinearRegressionHatMatrix(HatMatrix):
     For **kernel regression** algorithm:
 
     Given functional data, :math:`(X_1, X_2, ..., X_n)` where each function
-    is expressed in a orthonormal basis with :math:`J` elements and scalar
+    is expressed in a basis with :math:`J` elements and scalar
     response :math:`Y = (y_1, y_2, ..., y_n)`.
 
     It is desired to estimate the values
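This relaxation works because the model only needs L2 inner products, and those can always be routed through the Gram matrix of the basis: for functions with coefficient vectors c and d, the inner product is c @ G @ d, where G[j, k] = <phi_j, phi_k>. For an orthonormal basis G is the identity, which recovers the old behaviour. A small sketch checking that identity (the Monomial basis and the coefficient values are arbitrary choices for illustration):

import numpy as np

from skfda.misc import inner_product
from skfda.representation.basis import FDataBasis, Monomial

basis = Monomial(n_basis=3, domain_range=(0, 1))
G = basis.inner_product_matrix()  # G[j, k] = <phi_j, phi_k> = 1 / (j + k + 1)

c = np.array([[1.0, 2.0, 3.0]])
d = np.array([[4.0, 0.0, 1.0]])
x = FDataBasis(basis=basis, coefficients=c)
y = FDataBasis(basis=basis, coefficients=d)

print(inner_product(x, y))  # functional inner product
print(c @ G @ d.T)          # same value through the Gram matrix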
@@ -222,13 +221,13 @@ class LocalLinearRegressionHatMatrix(HatMatrix):
     def __call__(  # noqa: D102
         self,
         *,
-        delta_x: np.ndarray,
+        delta_x: NDArrayFloat,
         X_train: Optional[Union[FDataBasis, GridPoints]] = None,
         X: Optional[Union[FDataBasis, GridPoints]] = None,
-        y_train: Optional[np.ndarray] = None,
-        weights: Optional[np.ndarray] = None,
+        y_train: Optional[NDArrayFloat] = None,
+        weights: Optional[NDArrayFloat] = None,
         _cv: bool = False,
-    ) -> np.ndarray:
+    ) -> NDArrayFloat:
 
         if self.bandwidth is None:
             percentage = 15
@@ -243,10 +242,23 @@ def __call__(  # noqa: D102
             m1 = X_train.coefficients
             m2 = X.coefficients
 
+            # Subtract previous matrices obtaining a 3D matrix
+            # The i-th element contains the matrix X_train - X[i]
+            C = m1 - m2[:, np.newaxis]
+
+            inner_product_matrix = X_train.basis.inner_product_matrix()
+
+            # Calculate new coefficients taking into account cross-products
+            # if the basis is orthonormal, C would not change
+            C = C @ inner_product_matrix
+
+            # Adding a column of ones in the first position of all matrices
+            dims = (C.shape[0], C.shape[1], 1)
+            C = np.concatenate((np.ones(dims), C), axis=-1)
+
             return self._solve_least_squares(
                 delta_x=delta_x,
-                m1=m1,
-                m2=m2,
+                coefs=C,
                 y_train=y_train,
             )

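The inserted block above is the core of the fix: the coefficient differences are multiplied by the Gram matrix of the basis, so cross-products between non-orthogonal basis functions enter the regression. When the basis is orthonormal the Gram matrix is the identity and the product is a no-op, preserving the previous results. A quick sketch contrasting the two cases (basis sizes and domains are arbitrary):

import numpy as np

from skfda.representation.basis import Fourier, Monomial

# The Fourier basis over its full period is orthonormal: its Gram matrix
# is the identity, so C @ G leaves the coefficients unchanged.
fourier = Fourier(n_basis=3, domain_range=(0, 1))
print(np.round(fourier.inner_product_matrix(), 10))

# Monomials 1, t, t^2 on [0, 1] are not orthogonal: G[j, k] = 1 / (j + k + 1),
# so C @ G genuinely mixes the coefficient differences.
monomial = Monomial(n_basis=3, domain_range=(0, 1))
print(monomial.inner_product_matrix())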
@@ -264,39 +276,16 @@ def __call__(  # noqa: D102
 
     def _solve_least_squares(
         self,
-        delta_x: np.ndarray,
-        m1: np.ndarray,
-        m2: np.ndarray,
-        y_train: np.ndarray,
-    ) -> np.ndarray:
+        delta_x: NDArrayFloat,
+        coefs: NDArrayFloat,
+        y_train: NDArrayFloat,
+    ) -> NDArrayFloat:
 
         W = np.sqrt(self.kernel(delta_x / self.bandwidth))
 
-        # Adding a column of ones to m1
-        m1 = np.concatenate(
-            (
-                np.ones(m1.shape[0])[:, np.newaxis],
-                m1,
-            ),
-            axis=1,
-        )
-
-        # Adding a column of zeros to m2
-        m2 = np.concatenate(
-            (
-                np.zeros(m2.shape[0])[:, np.newaxis],
-                m2,
-            ),
-            axis=1,
-        )
-
-        # Subtract previous matrices obtaining a 3D matrix
-        # The i-th element contains the matrix X_train - X[i]
-        C = m1 - m2[:, np.newaxis]
-
         # A x = b
-        # Where x = (a, b_1, ..., b_J)
-        A = (C.T * W.T).T
+        # Where x = (a, b_1, ..., b_J).
+        A = (coefs.T * W.T).T
         b = np.einsum('ij, j... -> ij...', W, y_train)
 
         # For Ax = b calculates x that minimize the square error
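Since the coefficient differences arrive already Gram-corrected and prepended with a column of ones, the fit is an ordinary weighted least-squares problem for each test point: scaling the rows of A and b by the square roots of the kernel weights makes plain lstsq minimise sum_i w_i * (a + b . C_i - y_i)^2. A standalone NumPy sketch of that solve for a single test point (all names and sizes are illustrative):

import numpy as np

rng = np.random.default_rng(0)
n, J = 5, 3
coef_deltas = rng.normal(size=(n, J))    # Gram-corrected rows of X_train - X
y_train = rng.normal(size=n)
weights = rng.uniform(0.5, 1.0, size=n)  # kernel weights k(delta_x / h)

# Row-scaling by sqrt(w) turns the weighted problem into plain lstsq
W = np.sqrt(weights)
A = W[:, np.newaxis] * np.hstack([np.ones((n, 1)), coef_deltas])
b = W * y_train

x, *_ = np.linalg.lstsq(A, b, rcond=None)
print(x[0])  # the intercept a is the local prediction at the test point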
@@ -312,8 +301,8 @@ def _solve_least_squares(
     def _hat_matrix_function_not_normalized(
         self,
         *,
-        delta_x: np.ndarray,
-    ) -> np.ndarray:
+        delta_x: NDArrayFloat,
+    ) -> NDArrayFloat:
 
         if self.bandwidth is None:
             percentage = 15
@@ -369,16 +358,16 @@ def __init__(
         self,
         *,
         n_neighbors: Optional[int] = None,
-        kernel: Callable[[np.ndarray], np.ndarray] = kernels.uniform,
+        kernel: Callable[[NDArrayFloat], NDArrayFloat] = kernels.uniform,
     ):
         self.n_neighbors = n_neighbors
         self.kernel = kernel
 
     def _hat_matrix_function_not_normalized(
         self,
         *,
-        delta_x: np.ndarray,
-    ) -> np.ndarray:
+        delta_x: NDArrayFloat,
+    ) -> NDArrayFloat:
 
         input_points_len = delta_x.shape[1]
 
8 changes: 4 additions & 4 deletions skfda/ml/regression/_kernel_regression.py

@@ -6,10 +6,10 @@
 from sklearn.base import BaseEstimator, RegressorMixin
 from sklearn.utils.validation import check_is_fitted
 
-from skfda.misc.hat_matrix import HatMatrix, NadarayaWatsonHatMatrix
-from skfda.misc.metrics import PairwiseMetric, l2_distance
-from skfda.misc.metrics._typing import Metric
-from skfda.representation._functional_data import FData
+from ...misc.hat_matrix import HatMatrix, NadarayaWatsonHatMatrix
+from ...misc.metrics import PairwiseMetric, l2_distance
+from ...misc.metrics._typing import Metric
+from ...representation._functional_data import FData
 
 
 class KernelRegression(
26 changes: 24 additions & 2 deletions tests/test_kernel_regression.py

@@ -15,7 +15,7 @@
 from skfda.misc.kernels import normal, uniform
 from skfda.misc.metrics import l2_distance
 from skfda.ml.regression import KernelRegression
-from skfda.representation.basis import FDataBasis, Fourier
+from skfda.representation.basis import FDataBasis, Fourier, Monomial
 from skfda.representation.grid import FDataGrid
 
 
@@ -80,7 +80,7 @@ def _llr_alt(
 
         C = np.concatenate(
             (
-                (np.ones(fd_train.n_samples))[:, np.newaxis],
+                np.ones(fd_train.n_samples)[:, np.newaxis],
                 (fd_train - fd_test[i]).coefficients,
             ),
             axis=1,
@@ -313,3 +313,25 @@ def test_knn_r(self) -> None:
         ]
 
         np.testing.assert_almost_equal(y, result_R, decimal=6)
+
+
+class TestNonOrthonormalBasisLLR(unittest.TestCase):
+    """Test LocalLinearRegression method with a non-orthonormal basis."""
+
+    def test_llr_non_orthonormal(self) -> None:
+        """Test LocalLinearRegression with monomial basis."""
+        coef1 = [[1, 5, 8], [4, 6, 6], [9, 4, 1]]
+        coef2 = [[6, 3, 5]]
+        basis = Monomial(n_basis=3, domain_range=(0, 3))
+
+        X_train = FDataBasis(coefficients=coef1, basis=basis)
+        X = FDataBasis(coefficients=coef2, basis=basis)
+        y_train = np.array([8, 6, 1])
+
+        llr = LocalLinearRegressionHatMatrix(
+            bandwidth=100,
+            kernel=uniform,
+        )
+        kr = KernelRegression(kernel_estimator=llr)
+        kr.fit(X_train, y_train)
+        np.testing.assert_almost_equal(kr.predict(X), 4.35735166)
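The expected value in this test can be cross-checked by hand: with a uniform kernel and a bandwidth far larger than any pairwise distance, every training sample receives the same weight, so the prediction reduces to the intercept of a plain least-squares fit on the Gram-corrected coefficient differences. A sketch of that check, assuming the hat matrix resolves the system with np.linalg.lstsq as the diff above indicates:

import numpy as np

from skfda.representation.basis import Monomial

basis = Monomial(n_basis=3, domain_range=(0, 3))
G = basis.inner_product_matrix()  # G[j, k] = 3**(j + k + 1) / (j + k + 1)

coef1 = np.array([[1, 5, 8], [4, 6, 6], [9, 4, 1]], dtype=float)
coef2 = np.array([6, 3, 5], dtype=float)
y_train = np.array([8, 6, 1], dtype=float)

C = (coef1 - coef2) @ G
A = np.hstack([np.ones((3, 1)), C])
x, *_ = np.linalg.lstsq(A, y_train, rcond=None)
print(x[0])  # should be close to 4.35735166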