From 4d418073122f02c79b2bbdba1c16874988ccc54d Mon Sep 17 00:00:00 2001 From: Raj Agrawal Date: Fri, 8 Dec 2023 10:31:36 -0800 Subject: [PATCH] Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error --- chirho/robust/internals/linearize.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 439440ba..09e8e521 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -94,17 +94,21 @@ def make_empirical_fisher_vp( randomness="different", ) + N = data[next(iter(data))].shape[0] # type: ignore + mean_vector = 1 / N * torch.ones(N) + def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: return batched_func_log_prob(params, data) - def jvp_fn(v: ParamDict) -> torch.Tensor: - return torch.func.jvp(bound_batched_func_log_prob, (log_prob_params,), (v,))[1] - - vjp_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] - def _empirical_fisher_vp(v: ParamDict) -> ParamDict: - jvp_log_prob_v = jvp_fn(v) - return vjp_fn(jvp_log_prob_v / jvp_log_prob_v.shape[0])[0] + def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: + return torch.func.jvp( + bound_batched_func_log_prob, (log_prob_params,), (v,) + )[1] + + # Perlmutter's trick + vjp_fn = torch.func.vjp(jvp_fn, log_prob_params)[1] + return vjp_fn(-1 * mean_vector)[0] # Fisher = -E[Hessian] return _empirical_fisher_vp