Skip to content

Commit

Permalink
Starts TMCMC porting
Browse files Browse the repository at this point in the history
  • Loading branch information
dimtsap committed Nov 23, 2022
1 parent cfdf3df commit b87cf47
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 164 deletions.
69 changes: 49 additions & 20 deletions src/UQpy/sampling/mcmc/MetropolisHastings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ class MetropolisHastings(MCMC):

@beartype
def __init__(
self,
pdf_target: Union[Callable, list[Callable]] = None,
log_pdf_target: Union[Callable, list[Callable]] = None,
args_target: tuple = None,
burn_length: Annotated[int, Is[lambda x: x >= 0]] = 0,
jump: int = 1,
dimension: int = None,
seed: list = None,
save_log_pdf: bool = False,
concatenate_chains: bool = True,
n_chains: int = None,
proposal: Distribution = None,
proposal_is_symmetric: bool = False,
random_state: RandomStateType = None,
nsamples: PositiveInteger = None,
nsamples_per_chain: PositiveInteger = None,
self,
pdf_target: Union[Callable, list[Callable]] = None,
log_pdf_target: Union[Callable, list[Callable]] = None,
args_target: tuple = None,
burn_length: Annotated[int, Is[lambda x: x >= 0]] = 0,
jump: int = 1,
dimension: int = None,
seed: list = None,
save_log_pdf: bool = False,
concatenate_chains: bool = True,
n_chains: int = None,
proposal: Distribution = None,
proposal_is_symmetric: bool = False,
random_state: RandomStateType = None,
nsamples: PositiveInteger = None,
nsamples_per_chain: PositiveInteger = None,
):
"""
Metropolis-Hastings algorithm :cite:`MCMC1` :cite:`MCMC2`
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
self.logger.info("\nUQpy: Initialization of " + self.__class__.__name__ + " algorithm complete.")

if (nsamples is not None) or (nsamples_per_chain is not None):
self.run(nsamples=nsamples, nsamples_per_chain=nsamples_per_chain,)
self.run(nsamples=nsamples, nsamples_per_chain=nsamples_per_chain, )

def run_one_iteration(self, current_state: np.ndarray, current_log_pdf: np.ndarray):
"""
Expand Down Expand Up @@ -144,11 +144,11 @@ def run_one_iteration(self, current_state: np.ndarray, current_log_pdf: np.ndarr
) # this vector will be used to compute accept_ratio of each chain
unif_rvs = (
Uniform()
.rvs(nsamples=self.n_chains, random_state=self.random_state)
.reshape((-1,))
.rvs(nsamples=self.n_chains, random_state=self.random_state)
.reshape((-1,))
)
for nc, (cand, log_p_cand, r_) in enumerate(
zip(candidate, log_p_candidate, log_ratios)
zip(candidate, log_p_candidate, log_ratios)
):
accept = np.log(unif_rvs[nc]) < r_
if accept:
Expand All @@ -159,3 +159,32 @@ def run_one_iteration(self, current_state: np.ndarray, current_log_pdf: np.ndarr
self._update_acceptance_rate(accept_vec)

return current_state, current_log_pdf

def __copy__(self, **kwargs):
    """
    Build a new sampler of the same class from this sampler's attributes, with optional keyword overrides.

    For each constructor argument listed below, the value passed in ``kwargs`` is used when it is present and
    not :any:`None`; otherwise the attribute of the same name is read from ``self``. Omitted keywords are
    treated exactly like keywords passed as :any:`None`, so callers only need to supply what they want to
    change. Keywords that are not in the list below are ignored.

    :param kwargs: Optional overrides for any of ``pdf_target``, ``log_pdf_target``, ``args_target``,
     ``burn_length``, ``jump``, ``dimension``, ``seed``, ``save_log_pdf``, ``concatenate_chains``, ``n_chains``,
     ``proposal``, ``proposal_is_symmetric``, ``random_state``, ``nsamples``, ``nsamples_per_chain``.
    :return: A new instance of ``self.__class__`` constructed from the merged arguments.
    """
    constructor_fields = (
        'pdf_target', 'log_pdf_target', 'args_target', 'burn_length', 'jump', 'dimension', 'seed',
        'save_log_pdf', 'concatenate_chains', 'n_chains', 'proposal', 'proposal_is_symmetric',
        'random_state', 'nsamples', 'nsamples_per_chain',
    )

    init_arguments = {}
    for field in constructor_fields:
        # ``dict.get`` instead of ``kwargs[field]``: a missing keyword must fall back to the current attribute
        # rather than raise KeyError.
        override = kwargs.get(field)
        # Compare against None (not truthiness) so that overrides such as ``False`` or ``0`` are honoured.
        init_arguments[field] = getattr(self, field) if override is None else override

    # NOTE(review): ``nsamples`` / ``nsamples_per_chain`` are forwarded as in the original code; the visible
    # ``__init__`` calls ``run`` when either is not None, so the copy may start sampling on construction.
    return self.__class__(**init_arguments)

25 changes: 11 additions & 14 deletions src/UQpy/sampling/mcmc/baseclass/MCMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from UQpy.distributions import Distribution
from UQpy.utilities.ValidationTypes import *
from UQpy.utilities.Utilities import process_random_state
from abc import ABC
from abc import ABC, abstractmethod


class MCMC(ABC):
Expand Down Expand Up @@ -177,29 +177,22 @@ def _concatenate_chains(self):
return None

def _unconcatenate_chains(self):
    """
    Reshape the stored samples (and the stored log-pdf values, if saved) to a per-chain layout.

    ``self.samples`` is reassigned to its reshape of shape ``(-1, self.n_chains, self.dimension)``. When
    ``self.save_log_pdf`` is truthy, ``self.log_pdf_values`` is reassigned to its reshape of shape
    ``(-1, self.n_chains)``. Both reshapes use C (row-major) ordering.

    :return: :any:`None`; the results are stored on the instance.
    """
    # Each array is reshaped exactly once; repeating the same reshape is a no-op and was redundant.
    per_chain_shape = (-1, self.n_chains, self.dimension)
    self.samples = self.samples.reshape(per_chain_shape, order="C")
    if self.save_log_pdf:
        self.log_pdf_values = self.log_pdf_values.reshape((-1, self.n_chains), order="C")
    return None

def _initialize_samples(self, nsamples, nsamples_per_chain):
if ((nsamples is not None) and (nsamples_per_chain is not None)) \
or (nsamples is None and nsamples_per_chain is None):
raise ValueError("UQpy: Either nsamples or nsamples_per_chain must be provided (not both)")
if nsamples_per_chain is not None:
if not (isinstance(nsamples_per_chain, int) and nsamples_per_chain >= 0):
raise TypeError("UQpy: nsamples_per_chain must be an integer >= 0.")
nsamples = int(nsamples_per_chain * self.n_chains)
else:
if nsamples_per_chain is None:
if not (isinstance(nsamples, int) and nsamples >= 0):
raise TypeError("UQpy: nsamples must be an integer >= 0.")
nsamples_per_chain = int(np.ceil(nsamples / self.n_chains))
nsamples = int(nsamples_per_chain * self.n_chains)

elif not (isinstance(nsamples_per_chain, int) and nsamples_per_chain >= 0):
raise TypeError("UQpy: nsamples_per_chain must be an integer >= 0.")
nsamples = int(nsamples_per_chain * self.n_chains)
if self.samples is None: # very first call of run, set current_state as the seed and initialize self.samples
self.samples = np.zeros((nsamples_per_chain, self.n_chains, self.dimension))
if self.save_log_pdf:
Expand Down Expand Up @@ -316,3 +309,7 @@ def _check_methods_proposal(proposal_distribution):
raise AttributeError("UQpy: The proposal should have a log_pdf or pdf method")
proposal_distribution.log_pdf = lambda x: np.log(
np.maximum(proposal_distribution.pdf(x), 10 ** (-320) * np.ones((x.shape[0],))))

@abstractmethod
def __copy__(self, **kwargs):
    """
    Abstract copy hook; the declaration is marked with :any:`abstractmethod`, so subclasses must override it.

    :param kwargs: Keyword arguments accepted for the benefit of overriding implementations; this placeholder
     body does not read them.
    :return: :any:`None` from this placeholder body.
    """
63 changes: 24 additions & 39 deletions src/UQpy/sampling/tempering_mcmc/ParallelTemperingMCMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ParallelTemperingMCMC(TemperingMCMC):
"""
Parallel-Tempering MCMC
This algorithms runs the chains sampling from various tempered distributions in parallel. Periodically during the
This algorithm runs the chains sampling from various tempered distributions in parallel. Periodically during the
run, the different temperatures swap members of their ensemble in a way that
preserves detailed balance. The chains closer to the reference chain (hot chains) can sample from regions that have
low probability under the target and thus allow a better exploration of the parameter space, while the cold chains
Expand All @@ -35,13 +35,18 @@ class ParallelTemperingMCMC(TemperingMCMC):
"""

def __init__(self, niter_between_sweeps, pdf_intermediate=None, log_pdf_intermediate=None, args_pdf_intermediate=(),
distribution_reference=None, nburn=0, jump=1, dimension=None, seed=None,
save_log_pdf=False, nsamples=None, nsamples_per_chain=None, nchains=None, verbose=False,
random_state=None, temper_param_list=None, n_temper_params=None, mcmc_class=MetropolisHastings, **kwargs_mcmc):
distribution_reference=None,
save_log_pdf=False, nsamples=None, nsamples_per_chain=None,
random_state=None,
temper_param_list=None, n_temper_params=None,
sampler: Union[MCMC, list[MCMC]] = None):

super().__init__(pdf_intermediate=pdf_intermediate, log_pdf_intermediate=log_pdf_intermediate,
args_pdf_intermediate=args_pdf_intermediate, distribution_reference=None, dimension=dimension,
args_pdf_intermediate=args_pdf_intermediate, distribution_reference=None,
save_log_pdf=save_log_pdf, random_state=random_state)
self.logger = logging.getLogger(__name__)
self.sampler = sampler

self.distribution_reference = distribution_reference
self.evaluate_log_reference = self._preprocess_reference(self.distribution_reference)

Expand All @@ -68,42 +73,31 @@ def __init__(self, niter_between_sweeps, pdf_intermediate=None, log_pdf_intermed
self.n_temper_params = len(self.temper_param_list)

# Initialize mcmc objects, need as many as number of temperatures
if not issubclass(mcmc_class, MCMC):
raise ValueError('UQpy: mcmc_class should be a subclass of MCMC.')
if not all((isinstance(val, (list, tuple)) and len(val) == self.n_temper_params)
for val in kwargs_mcmc.values()):
raise ValueError(
'UQpy: additional kwargs arguments should be mcmc algorithm specific inputs, given as lists of length '
'the number of temperatures.')
# if not all((isinstance(val, (list, tuple)) and len(val) == self.n_temper_params)
# for val in kwargs_mcmc.values()):
# raise ValueError(
# 'UQpy: additional kwargs arguments should be mcmc algorithm specific inputs, given as lists of length '
# 'the number of temperatures.')
# default value
if isinstance(mcmc_class, MetropolisHastings) and len(kwargs_mcmc) == 0:
kwargs_mcmc = {}
if isinstance(self.sampler, MetropolisHastings) and not kwargs_mcmc:
from UQpy.distributions import JointIndependent, Normal
kwargs_mcmc = {'proposal_is_symmetric': [True, ] * self.n_temper_params,
'proposal': [JointIndependent([Normal(scale=1. / np.sqrt(temper_param))] * dimension)
'proposal': [JointIndependent([Normal(scale=1. / np.sqrt(temper_param))] *
self.sampler.dimension)
for temper_param in self.temper_param_list]}

# Initialize algorithm specific inputs: target pdfs
self.thermodynamic_integration_results = None

self.mcmc_samplers = []
for i, temper_param in enumerate(self.temper_param_list):
# log_pdf_target = self._target_generator(
# self.evaluate_log_intermediate, self.evaluate_log_reference, temper_param)
log_pdf_target = (lambda x, temper_param=temper_param: self.evaluate_log_reference(
x) + self.evaluate_log_intermediate(x, temper_param))
self.mcmc_samplers.append(
mcmc_class(log_pdf_target=log_pdf_target,
dimension=dimension, seed=seed, nburn=nburn, jump=jump, save_log_pdf=save_log_pdf,
concat_chains=True, verbose=verbose, random_state=self.random_state, nchains=nchains,
**dict([(key, val[i]) for key, val in kwargs_mcmc.items()])))

# Samples connect to posterior samples, i.e. the chain with temperature 1.
# self.samples = self.mcmc_samplers[0].samples
# if self.save_log_pdf:
# self.log_pdf_values = self.mcmc_samplers[0].samples
self.mcmc_samplers.append(sampler.__copy__(log_pdf_target=log_pdf_target, concatenate_chains=True,
**kwargs_mcmc))

if self.verbose:
print('\nUQpy: Initialization of ' + self.__class__.__name__ + ' algorithm complete.')
self.logger.info('\nUQpy: Initialization of ' + self.__class__.__name__ + ' algorithm complete.')

# If nsamples is provided, run the algorithm
if (nsamples is not None) or (nsamples_per_chain is not None):
Expand Down Expand Up @@ -138,8 +132,7 @@ def _run(self, nsamples=None, nsamples_per_chain=None):
current_state.append(current_state_t.copy())
current_log_pdf.append(current_log_pdf_t.copy())

if self.verbose:
print('UQpy: Running MCMC...')
self.logger.info('UQpy: Running MCMC...')

# Run nsims iterations of the MCMC algorithm, starting at current_state
while self.mcmc_samplers[0].nsamples_per_chain < final_ns_per_chain:
Expand Down Expand Up @@ -181,8 +174,7 @@ def _run(self, nsamples=None, nsamples_per_chain=None):
# self.nsamples_per_chain += 1
# self.nsamples += self.nchains

if self.verbose:
print('UQpy: MCMC run successfully !')
self.logger.info('UQpy: MCMC run successfully !')

# Concatenate chains maybe
if self.mcmc_samplers[-1].concat_chains:
Expand Down Expand Up @@ -234,13 +226,6 @@ def evaluate_normalization_constant(self, compute_potential, log_Z0=None, nsampl
# use quadrature to integrate between 0 and 1
temper_param_list_for_integration = np.copy(np.array(self.temper_param_list))
log_pdf_averages = np.array(log_pdf_averages)
# if self.temper_param_list[-1] != 1.:
# log_pdf_averages = np.append(log_pdf_averages, log_pdf_averages[-1])
# slope_linear = (log_pdf_averages[-1]-log_pdf_averages[-2]) / (
# betas_for_integration[-1] - betas_for_integration[-2])
# log_pdf_averages = np.append(
# log_pdf_averages, log_pdf_averages[-1] + (1. - betas_for_integration[-1]) * slope_linear)
# betas_for_integration = np.append(betas_for_integration, 1.)
int_value = trapz(x=temper_param_list_for_integration, y=log_pdf_averages)
if log_Z0 is None:
samples_p0 = self.distribution_reference.rvs(nsamples=nsamples_from_p0)
Expand Down
Loading

0 comments on commit b87cf47

Please sign in to comment.