Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixtures and __repr__ enhancements. #227

Merged
merged 9 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 104 additions & 17 deletions ciw/dists/distributions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ciw.auxiliary import *
from itertools import cycle
import numpy as np
'''Distributions available in Ciw.'''

import copy
from itertools import cycle
from operator import add, mul, sub, truediv
from random import (
expovariate,
Expand All @@ -11,7 +11,12 @@
lognormvariate,
weibullvariate,
)
from typing import List, NoReturn

import numpy as np

from ciw.auxiliary import *
from ciw.individual import Individual

class Distribution(object):
"""
Expand Down Expand Up @@ -99,7 +104,7 @@ def __init__(self, lower, upper):
self.upper = upper

def __repr__(self):
return "Uniform: {0}, {1}".format(self.lower, self.upper)
return f"Uniform({self.lower}, {self.upper})"

def sample(self, t=None, ind=None):
return uniform(self.lower, self.upper)
Expand All @@ -121,7 +126,7 @@ def __init__(self, value):
self.value = value

def __repr__(self):
return "Deterministic: {0}".format(self.value)
return f"Deterministic({self.value})"

def sample(self, t=None, ind=None):
return self.value
Expand Down Expand Up @@ -151,7 +156,7 @@ def __init__(self, lower, mode, upper):
self.upper = upper

def __repr__(self):
return "Triangular: {0}, {1}, {2}".format(self.lower, self.mode, self.upper)
return f"Triangular({self.lower}, {self.mode}, {self.upper})"

def sample(self, t=None, ind=None):
return triangular(self.lower, self.upper, self.mode)
Expand All @@ -173,7 +178,7 @@ def __init__(self, rate):
self.rate = rate

def __repr__(self):
return "Exponential: {0}".format(self.rate)
return f"Exponential({self.rate})"

def sample(self, t=None, ind=None):
return expovariate(self.rate)
Expand All @@ -193,7 +198,7 @@ def __init__(self, shape, scale):
self.scale = scale

def __repr__(self):
return "Gamma: {0}, {1}".format(self.shape, self.scale)
return f"Gamma({self.shape}, {self.scale})"

def sample(self, t=None, ind=None):
return gammavariate(self.shape, self.scale)
Expand All @@ -213,7 +218,7 @@ def __init__(self, mean, sd):
self.sd = sd

def __repr__(self):
return "Normal: {0}, {1}".format(self.mean, self.sd)
return f"Normal({self.mean}, {self.sd})"

def sample(self, t=None, ind=None):
return truncated_normal(self.mean, self.sd)
Expand All @@ -233,7 +238,7 @@ def __init__(self, mean, sd):
self.sd = sd

def __repr__(self):
return "Lognormal: {0}, {1}".format(self.mean, self.sd)
return f"Lognormal({self.mean}, {self.sd})"

def sample(self, t=None, ind=None):
return lognormvariate(self.mean, self.sd)
Expand All @@ -253,7 +258,7 @@ def __init__(self, scale, shape):
self.shape = shape

def __repr__(self):
return "Weibull: {0}, {1}".format(self.scale, self.shape)
return f"Weibull({self.scale}, {self.shape})"

def sample(self, t=None, ind=None):
return weibullvariate(self.scale, self.shape)
Expand Down Expand Up @@ -298,7 +303,10 @@ def __init__(self, sequence):
self.generator = cycle(self.sequence)

def __repr__(self):
return "Sequential"
if len(self.sequence) <= 3:
return f"Sequential({self.sequence})"
else:
return f"Sequential({self.sequence[0]}, ..., {self.sequence[-1]})"

def sample(self, t=None, ind=None):
return next(self.generator)
Expand All @@ -324,7 +332,7 @@ def __init__(self, values, probs):
self.probs = probs

def __repr__(self):
return "Pmf"
return f"Pmf({self.values}, {self.probs})"

def sample(self, t=None, ind=None):
return random_choice(self.values, self.probs)
Expand Down Expand Up @@ -420,7 +428,7 @@ def __init__(self, rate, num_phases):
super().__init__(initial_state, absorbing_matrix)

def __repr__(self):
return f"Erlang: {self.rate}, {self.num_phases}"
return f"Erlang({self.rate}, {self.num_phases})"


class HyperExponential(PhaseType):
Expand Down Expand Up @@ -611,7 +619,7 @@ def sample(self, t=None, ind=None):
return ciw.rng.poisson(lam=self.rate)

def __repr__(self):
return f"Poisson: {self.rate}"
return f"Poisson({self.rate})"


class Geometric(Distribution):
Expand All @@ -634,7 +642,7 @@ def sample(self, t=None, ind=None):
return ciw.rng.geometric(p=self.prob)

def __repr__(self):
return f"Geometric: {self.prob}"
return f"Geometric({self.prob})"


class Binomial(Distribution):
Expand Down Expand Up @@ -663,4 +671,83 @@ def sample(self, t=None, ind=None):
return ciw.rng.binomial(n=self.n, p=self.prob)

def __repr__(self):
return f"Binomial: {self.n}, {self.prob}"
return f"Binomial({self.n}, {self.prob})"


class MixtureDistribution(Distribution):
"""
A mixture distribution combining multiple probability distributions.

Parameters
----------
dists : List[Distribution]
A list of probability distributions to be combined in the mixture.
rhos : List[float]
A list of weights corresponding to the importance of each distribution in the mixture.
The weights must sum to 1.

Attributes
----------
rhos : List[float]
List of weights assigned to each distribution in the mixture.
dists : List[Distribution]
List of probability distributions in the mixture.

Methods
-------
sample(t: float, inds: List[Individual] = None) -> float:
Generate a random sample from the mixture distribution.

Notes
-----
The weights in `rhos` should sum to 1, indicating the relative importance of each distribution
in the mixture. The distributions in `dists` should be instances of `ciw.dists.Distribution`.
"""

def __init__(self, dists: List[Distribution], rhos: List[float]) -> NoReturn:
"""
Initialize the MixtureDistribution.

Parameters
----------
dists : List[Distribution]
A list of probability distributions to be combined in the mixture.
rhos : List[float]
A list of weights corresponding to the importance of each distribution in the mixture.
The weights must sum to 1.
"""
self.rhos = rhos
self.dists = dists

def sample(self, t: float, inds: List[Individual] = None) -> float:
"""
Generate a random sample from the mixture distribution.

Parameters
----------
t : float
The time parameter for the sample generation.
inds : List[Individual], optional
List of individuals associated with the sample, if applicable.

Returns
-------
float
A random sample from the mixture distribution.
"""
chosen_dist = random.choices(
population=self.dists,
weights=self.rhos,
k=1)[0]

return chosen_dist.sample(t, inds)

def __repr__(self):

dist_strs = [f'{rho} * {dist}' for rho,dist in zip(self.rhos, self.dists)]

if len(dist_strs) <= 3:
inside = ', '.join(dist_strs)
return f"Mixture({inside})"
else:
return f"Mixture({dist_strs[0]}, ..., {dist_strs[-1]})"
Loading
Loading