Skip to content

Commit

Permalink
Merge pull request #43 from gibsramen/move-fit-single
Browse files Browse the repository at this point in the history
Move _fit_single to directly under BaseModel
  • Loading branch information
gibsramen authored May 11, 2021
2 parents cd19815 + a190619 commit bdd1753
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions birdman/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,24 +159,9 @@ def _fit_parallel(
if dask_cluster is not None:
dask_cluster.scale(jobs=jobs)

@dask.delayed
def _fit_single(self, values):
dat = self.dat
dat["y"] = values.astype(int)
_fit = self.sm.sample(
chains=self.chains,
parallel_chains=1, # run all chains in serial
data=dat,
iter_warmup=self.num_warmup,
iter_sampling=self.num_iter,
seed=self.seed,
**sampler_args
)
return _fit

_fits = []
for v, i, d in self.table.iter(axis="observation"):
_fit = _fit_single(self, v)
_fit = dask.delayed(self._fit_single)(v, sampler_args)
_fits.append(_fit)

futures = dask.persist(*_fits)
Expand All @@ -185,6 +170,19 @@ def _fit_single(self, values):
self.dat["y"] = self.table.matrix_data.todense().T.astype(int)
return all_fits

def _fit_single(self, values, sampler_args):
dat = self.dat
dat["y"] = values.astype(int)
_fit = self.sm.sample(
chains=self.chains,
data=dat,
iter_warmup=self.num_warmup,
iter_sampling=self.num_iter,
seed=self.seed,
**sampler_args
)
return _fit

def to_inference_object(
self,
params: Sequence[str],
Expand Down

0 comments on commit bdd1753

Please sign in to comment.