Skip to content

Commit

Permalink
Add upper bound on number of CG steps (#404)
Browse files Browse the repository at this point in the history
* upper bound on cg_iters

* address comment
  • Loading branch information
eb8680 authored Dec 7, 2023
1 parent c5fe64b commit 117d645
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions chirho/robust/internals/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ def _flat_conjugate_gradient_solve(
p = torch.where(not_converged, r + mu * p, p)
rdotr = torch.where(not_converged, newrdotr, rdotr)

# rdotr = newrdotr
# if rdotr < residual_tol:
# break
return x


Expand Down Expand Up @@ -140,6 +137,11 @@ def linearize(
log_prob_params, func_log_prob = make_functional_call(log_prob)
score_fn = torch.func.grad(func_log_prob)

log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
if cg_iters is None:
cg_iters = log_prob_params_numel
else:
cg_iters = min(cg_iters, log_prob_params_numel)
cg_solver = functools.partial(
conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol
)
Expand Down

0 comments on commit 117d645

Please sign in to comment.