Skip to content

Commit

Permalink
Added Type Hints to CVM module
Browse files Browse the repository at this point in the history
  • Loading branch information
Prateek Bhustali committed May 23, 2022
1 parent df66e93 commit 3584406
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions src/UQpy/sensitivity/cramer_von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,23 @@
"""

import logging
from typing import Union

import numpy as np
from beartype import beartype

from UQpy.sensitivity.baseclass.sensitivity import Sensitivity
from UQpy.sensitivity.baseclass.pickfreeze import generate_pick_freeze_samples
from UQpy.sensitivity.sobol import compute_first_order as compute_first_order_sobol
from UQpy.sensitivity.sobol import compute_total_order as compute_total_order_sobol
from UQpy.utilities.UQpyLoggingFormatter import UQpyLoggingFormatter
from UQpy.utilities.ValidationTypes import (
PositiveInteger,
PositiveFloat,
NumpyFloatArray,
NumpyIntArray,
)


# TODO: Sampling strategies

Expand Down Expand Up @@ -85,13 +94,14 @@ def __init__(
self.num_vars = None
"Number of input random variables, :class:`int`"

@beartype
def run(
self,
n_samples=1_000,
estimate_sobol_indices=False,
num_bootstrap_samples=None,
confidence_level=0.95,
disable_CVM_indices=False,
n_samples: PositiveInteger = 1_000,
estimate_sobol_indices: bool = False,
num_bootstrap_samples: PositiveInteger = None,
confidence_level: PositiveFloat = 0.95,
disable_CVM_indices: bool = False,
):

"""
Expand Down Expand Up @@ -243,7 +253,8 @@ def run(
return computed_indices

@staticmethod
def indicator_function(Y, W):
@beartype
def indicator_function(Y: Union[NumpyFloatArray, NumpyIntArray], w: float):
"""
Vectorized version of the indicator function.
Expand All @@ -253,22 +264,28 @@ def indicator_function(Y, W):
**Inputs:**
* **Y** (`ndarray`):
Vector of values of the random variable.
Array of values of the random variable.
Shape: `(N, 1)`
* **W** (`ndarray`):
Vector of values of the random variable.
Shape: `(N, 1)`
* **w** (`float`):
Value to compare with the array.
**Outputs:**
* **indicator** (`ndarray`):
Array of integers with truth values.
Shape: `(N, 1)`
"""
return (Y <= W.T).astype(int)
return (Y <= w).astype(int)

def pick_and_freeze_estimator(self, A_model_evals, W_model_evals, C_i_model_evals):
@beartype
def pick_and_freeze_estimator(
self,
A_model_evals: Union[NumpyFloatArray, NumpyIntArray],
W_model_evals: Union[NumpyFloatArray, NumpyIntArray],
C_i_model_evals: Union[NumpyFloatArray, NumpyIntArray],
):

"""
Compute the first order Cramér-von Mises indices
Expand Down

0 comments on commit 3584406

Please sign in to comment.