Skip to content

Commit

Permalink
support multiple theta and lambda values
Browse files Browse the repository at this point in the history
  • Loading branch information
slowkow committed Feb 1, 2020
1 parent 8c756bb commit 3683982
Showing 1 changed file with 40 additions and 22 deletions.
62 changes: 40 additions & 22 deletions harmonypy/harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,42 +44,60 @@ def run_harmony(
"""Run Harmony.
"""

theta = None
lamb = None
sigma = 0.1
nclust = None
tau = 0
block_size = 0.05
max_iter_harmony = 10
max_iter_cluster = 200
epsilon_cluster = 1e-5
epsilon_harmony = 1e-4
plot_convergence = False
verbose = True
reference_values = None
cluster_prior = None
random_state = 0
# theta = None
# lamb = None
# sigma = 0.1
# nclust = None
# tau = 0
# block_size = 0.05
# max_iter_harmony = 10
# max_iter_cluster = 200
# epsilon_cluster = 1e-5
# epsilon_harmony = 1e-4
# plot_convergence = False
# verbose = True
# reference_values = None
# cluster_prior = None
# random_state = 0

N = meta_data.shape[0]
if data_mat.shape[1] != N:
data_mat = data_mat.T

phi = pd.get_dummies(meta_data[vars_use]).to_numpy().T

assert data_mat.shape[1] == N, \
"data_mat and meta_data do not have the same number of cells"

if nclust is None:
nclust = np.min([np.round(N / 30.0), 100]).astype(int)

if type(sigma) is float and nclust > 1:
sigma = np.repeat(sigma, nclust)

if isinstance(vars_use, str):
vars_use = [vars_use]

phi = pd.get_dummies(meta_data[vars_use]).to_numpy().T
phi_n = meta_data[vars_use].describe().loc['unique'].to_numpy().astype(int)

if theta is None:
theta = np.repeat(1, phi.shape[0])
theta = np.repeat([1] * len(phi_n), phi_n)
elif isinstance(theta, float) or isinstance(theta, int):
theta = np.repeat([theta] * len(phi_n), phi_n)
elif len(theta) == len(phi_n):
theta = np.repeat([theta], phi_n)

if lamb is None:
lamb = np.repeat(1, phi.shape[0])
assert len(theta) == np.sum(phi_n), \
"each batch variable must have a theta"

if type(sigma) is float and nclust > 1:
sigma = np.repeat(sigma, nclust)
if lamb is None:
lamb = np.repeat([1] * len(phi_n), phi_n)
elif isinstance(lamb, float) or isinstance(lamb, int):
lamb = np.repeat([lamb] * len(phi_n), phi_n)
elif len(lamb) == len(phi_n):
lamb = np.repeat([lamb], phi_n)

assert len(lamb) == np.sum(phi_n), \
"each batch variable must have a lambda"

# Number of items in each category.
N_b = phi.sum(axis = 1)
Expand Down

0 comments on commit 3683982

Please sign in to comment.