Skip to content

Commit

Permalink
Added Type Hints to baseclass modules
Browse files Browse the repository at this point in the history
  • Loading branch information
Prateek Bhustali committed May 23, 2022
1 parent 13d229e commit aff7d09
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 40 deletions.
27 changes: 26 additions & 1 deletion src/UQpy/sensitivity/baseclass/pickfreeze.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,36 @@
import copy
from typing import Union

from beartype import beartype

def generate_pick_freeze_samples(dist_obj, n_samples, random_state=None):
from UQpy.distributions.collection import JointIndependent
from UQpy.utilities.ValidationTypes import (
RandomStateType,
PositiveInteger,
)


@beartype
def generate_pick_freeze_samples(
dist_obj: Union[JointIndependent, Union[list, tuple]],
n_samples: PositiveInteger,
random_state: RandomStateType = None,
):

"""
Generate samples to be used in the Pick-and-Freeze algorithm.
**Inputs**:
* **dist_obj** (`JointIndependent` or `list` or `tuple`):
A distribution object or a list or tuple of distribution objects.
* **n_samples** (`int`):
The number of samples to be generated.
* **random_state** (`None` or `int` or `numpy.random.RandomState`):
A random seed or a `numpy.random.RandomState` object.
**Outputs:**
* **A_samples** (`ndarray`):
Expand Down
68 changes: 29 additions & 39 deletions src/UQpy/sensitivity/baseclass/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,38 @@
import numpy as np
import scipy.stats

from typing import Union
from beartype import beartype

from UQpy.distributions import *
from UQpy.utilities.ValidationTypes import (
PositiveFloat,
RandomStateType,
PositiveInteger,
NumpyFloatArray,
NumpyIntArray,
)
from UQpy.run_model import RunModel
from UQpy.distributions.baseclass import DistributionContinuous1D
from UQpy.distributions.collection import JointIndependent


class Sensitivity:
@beartype
def __init__(
self, runmodel_object, dist_object, random_state=None, **kwargs
self,
runmodel_object: RunModel,
dist_object: Union[JointIndependent, Union[list, tuple]],
random_state: RandomStateType = None,
**kwargs,
) -> None:

# Check RunModel object
if not isinstance(runmodel_object, RunModel):
raise TypeError("UQpy: runmodel_object must be an object of class RunModel")

self.runmodel_object = runmodel_object

# Check distributions
if isinstance(dist_object, list):
for i in range(len(dist_object)):
if not isinstance(dist_object[i], (DistributionContinuous1D, JointIndependent)):
raise TypeError(
"UQpy: A ``DistributionContinuous1D`` or ``JointInd`` object "
"must be provided."
)
else:
if not isinstance(dist_object, (DistributionContinuous1D, JointIndependent)):
raise TypeError(
"UQpy: A ``DistributionContinuous1D`` or ``JointInd`` object must be provided."
)

self.dist_object = dist_object

# Check random state
self.random_state = random_state
if isinstance(self.random_state, int):
self.random_state = np.random.RandomState(self.random_state)
elif not (
self.random_state is None
or isinstance(self.random_state, np.random.RandomState)
):
raise TypeError(
"UQpy: random state should be None, an integer or np.random.RandomState object"
)

# wrapper created for convenience to generate model evaluations
def _run_model(self, samples):
@beartype
def _run_model(self, samples: Union[NumpyFloatArray, NumpyIntArray]):
"""Generate model evaluations for a set of samples.
**Inputs**:
Expand All @@ -83,7 +69,8 @@ def _run_model(self, samples):
return model_evals

@staticmethod
def bootstrap_sample_generator_1D(samples):
@beartype
def bootstrap_sample_generator_1D(samples: Union[NumpyFloatArray, NumpyIntArray]):
"""Generate bootstrap samples.
Generators are used to avoid copying the entire array.
Expand Down Expand Up @@ -113,7 +100,8 @@ def bootstrap_sample_generator_1D(samples):
yield samples[_indices]

@staticmethod
def bootstrap_sample_generator_2D(samples):
@beartype
def bootstrap_sample_generator_2D(samples: Union[NumpyFloatArray, NumpyIntArray]):
"""Generate bootstrap samples.
Generators are used to avoid copying the entire array.
Expand Down Expand Up @@ -156,7 +144,8 @@ def bootstrap_sample_generator_2D(samples):
yield samples[_indices, cols]

@staticmethod
def bootstrap_sample_generator_3D(samples):
@beartype
def bootstrap_sample_generator_3D(samples: Union[NumpyFloatArray, NumpyIntArray]):
"""Generate bootstrap samples.
Generators are used to avoid copying the entire array.
Expand Down Expand Up @@ -190,13 +179,14 @@ def bootstrap_sample_generator_3D(samples):

yield samples[:, _indices, cols]

@beartype
def bootstrapping(
self,
estimator,
estimator_inputs,
qoi_mean,
num_bootstrap_samples,
confidence_level=0.95,
qoi_mean: Union[NumpyFloatArray, NumpyIntArray],
num_bootstrap_samples: PositiveInteger = None,
confidence_level: PositiveFloat = 0.95,
**kwargs,
):

Expand Down

0 comments on commit aff7d09

Please sign in to comment.