Skip to content

Commit

Permalink
Use Hessian formulation of Fisher information in `make_empirical_fish…
Browse files Browse the repository at this point in the history
…er_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error
  • Loading branch information
agrawalraj authored Dec 8, 2023
1 parent 3f0c83d commit 4d41807
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions chirho/robust/internals/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,21 @@ def make_empirical_fisher_vp(
randomness="different",
)

N = data[next(iter(data))].shape[0] # type: ignore
mean_vector = 1 / N * torch.ones(N)

def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor:
return batched_func_log_prob(params, data)

def jvp_fn(v: ParamDict) -> torch.Tensor:
return torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1]

vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1]

def _empirical_fisher_vp(v: ParamDict) -> ParamDict:
jvp_log_prob_v = jvp_fn(v)
return vjp_fn(jvp_log_prob_v / jvp_log_prob_v.shape[0])[0]
def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:
return torch.func.jvp(
bound_batched_func_log_prob, (log_prob_params,), (v,)
)[1]

# Perlmutter's trick
vjp_fn = torch.func.vjp(jvp_fn, log_prob_params)[1]
return vjp_fn(-1 * mean_vector)[0] # Fisher = -E[Hessian]

return _empirical_fisher_vp

Expand Down

0 comments on commit 4d41807

Please sign in to comment.