Skip to content

Commit

Permalink
Fix _check_start_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 14, 2021
1 parent a0e36d0 commit 77d6c60
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,11 @@ def sample(
def _check_start_shape(model, start):
if not isinstance(start, dict):
raise TypeError("start argument must be a dict or an array-like of dicts")

# Filter "non-input" variables
initial_point = model.initial_point
start = {k: v for k, v in deepcopy(start).items() if k in initial_point}

e = ""
for var in model.basic_RVs:
var_shape = model.fastfn(var.shape)(start)
Expand Down

0 comments on commit 77d6c60

Please sign in to comment.