Calculating Jacobian of a model with respect to its parameters? #334
Update: I tried layer-wise Jacobians:

```python
func, params, buffers = make_functional_with_buffers(model)
params = list(params)

def jac_p(p, i):
    params[i] = p
    out = func(params, buffers, input_dict)['logits']
    return out

J0 = jacrev(jac_p)(params[0], 0)
```

And I got an out-of-memory error, whereas the normal Jacobian using PyTorch does not. Model size: ~400k parameters.
Another update:

```python
func, params, buffers = make_functional_with_buffers(model)
params = list(params)

def jac_p_o(p, i, o):
    params[i] = p
    out = func(params, buffers, input_dict)['logits']
    return out[o]

J0 = jacrev(jac_p_o)(params[0], 0, 0)
```

This runs, but outputs the following a couple of times:

Moreover, this doesn't output what I wanted; its output is of the shape ...
Hi @mohamad-amin, thanks for the issue -- we are happy to help. Can you clarify what exactly the quantity you're looking to compute is? From reading your messages, I think it is the following:

```python
import torch
import torch.nn as nn
from functorch import vmap, jacrev, make_functional_with_buffers

# simple model and data for demonstration
model = nn.Linear(3, 3)
data = torch.randn(100, 3)  # 100 datapoints

func, params, buffers = make_functional_with_buffers(model)

# Compute the jacobian w.r.t. the parameters for each datapoint
result = vmap(jacrev(func), (None, None, 0))(params, buffers, data)

# The jacobian w.r.t. model.weight has size (100, 3, 3, 3)
# - model.weight is (3, 3) and the output shape is (3,), so each jacobian has size (3, 3, 3)
# - 100 for the 100 datapoints.
result[0].shape

# The jacobian w.r.t. model.bias has size (100, 3, 3)
# - model.bias is (3,) and the output shape is (3,), so each jacobian has size (3, 3)
# - 100 for the 100 datapoints.
result[1].shape
```
What version of functorch and PyTorch are you running? I tested the example I sent on the latest PyTorch and functorch. (The latest PyTorch nightly is 1.11.0.dev20211212, and the latest functorch is the main branch.)
This is my pip info:

Name: torch
Name: functorch

Should I be using the PyTorch nightly build?
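For a quick check, one way to print the installed PyTorch version from Python (the functorch version can be checked with `pip show functorch` on the command line):

```python
# Print the installed PyTorch version; a nightly shows up as something like '1.11.0.dev...'
import torch
print(torch.__version__)
```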
Update: I just tried the Colab installation provided in PyTorch, but it's not working:
@mohamad-amin Your Colab link needs to be shared. The library is still in rapid development, so I wouldn't be shocked if there were fixes between 1.10 and the nightly.
Oh! Sorry about that.
There were a lot of fixes to functorch w.r.t. using ...
@mohamad-amin I understand a bit more about neural tangent kernels after reading through some papers. I can probably put together an example in the next week, but I'm curious about the use case -- given a neural network, we can compute its neural tangent kernel, which is some numeric quantity. What do researchers do with these quantities? It looks like some research derives a mathematical expression for the NTK and then uses that NTK in an SVM, but this particular case seems different from what we're doing here.
@zou3519 That's nice, thanks for the effort! So what we're computing here is the "empirical" NTK, which I think was first derived (or shown to be useful) here: https://arxiv.org/pdf/1902.06720.pdf.

So this is just the first-order Taylor expansion of the neural network (w.r.t. its parameters around initialization). If your network is wide enough, this Taylor expansion will be a pretty good approximation of the actual neural network after being trained using SGD for some amount of time.

This can have a lot of use cases, some of which are mentioned here:

I think reading the paper I first mentioned will help a lot to get going with empirical NTKs, but if you want to know what the NTK actually is, I think this is the original NTK paper: https://arxiv.org/abs/1806.07572 (hard to read IMO, but brilliant once you get familiar with the notation). Let me know if there's anything else I can help with!
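In symbols, the linearization being referred to can be written as follows (notation mine; $\theta_0$ denotes the parameters at initialization and $J_{\theta_0} f(x)$ the Jacobian of the network outputs w.r.t. the parameters):

```latex
% First-order Taylor expansion around initialization,
% and the empirical NTK as the corresponding Jacobian product.
f(x;\theta) \approx f(x;\theta_0) + J_{\theta_0} f(x)\,(\theta - \theta_0),
\qquad
\hat{\Theta}_0(x, x') = J_{\theta_0} f(x)\, J_{\theta_0} f(x')^{\top}
```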
Hi @mohamad-amin! Thanks for the references to the papers and for the detailed explanation. I'm still working my way through the papers, but I think I have a working example of computing the empirical NTK.

I'm a bit confused about what the actual shape of the empirical NTK is. Let's say I had a simple nn.Linear(5, 7, bias=False) layer. Then the jacobian of each example has shape ... Is that correct, or is the reduction done over the entire jacobian, resulting in an empirical NTK of shape ...?
Hey, Great, looking forward to it!

I'm not too confident about this, but as far as I know, it depends on (how/if) you contract the jacobian. This is my personal understanding, which might or might not be helpful:

*: I've read somewhere that as long as your network's outputs are going to converge to one-hot in the infinite-width limit, then even for finite-width approximations with the empirical NTK, tracing over the covariance results in better accuracy, but I'm not convinced that this is true.
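To make the shape question concrete, here is a small sketch (my own illustration, reusing the nn.Linear(5, 7, bias=False) example from above; the batch sizes are arbitrary) of the two conventions being discussed -- keeping the full per-output covariance versus contracting it over the output dimension:

```python
import torch
import torch.nn as nn
from functorch import make_functional, vmap, jacrev

model = nn.Linear(5, 7, bias=False)
fnet, params = make_functional(model)

x1 = torch.randn(10, 5)  # N = 10 examples
x2 = torch.randn(20, 5)  # M = 20 examples

def fnet_single(params, x):
    # run a single example through the functional model
    return fnet(params, x.unsqueeze(0)).squeeze(0)

def flat_jac(x):
    # per-example jacobians w.r.t. the parameters, flattened to (num_examples, O, P)
    jac = vmap(jacrev(fnet_single), (None, 0))(params, x)
    return torch.cat([j.flatten(2) for j in jac], dim=2)

j1, j2 = flat_jac(x1), flat_jac(x2)                 # (10, 7, 35) and (20, 7, 35)

ntk_full = torch.einsum('Naf,Mbf->NMab', j1, j2)    # (10, 20, 7, 7): full covariance
ntk_trace = torch.einsum('Naf,Maf->NM', j1, j2)     # (10, 20): contracted over outputs
print(ntk_full.shape, ntk_trace.shape)
```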
Hey @mohamad-amin, I just wanted to give a quick status update (sorry for leaving you hanging!). We've been stuck on a bug, #417, that makes the NTK computation error out in functorch, but we're working through it :). Hopefully we'll fix the bug sometime next week.
@mohamad-amin we finally fixed #417! Here's the first version of the example of how to compute NTKs with functorch: https://github.com/pytorch/functorch/blob/cb876fad2b2a9269424c8212a82652f372db6dfa/notebooks/neural_tangent_kernels.ipynb. The code runs on the latest build of functorch (see "functorch main" in https://github.com/pytorch/functorch#installing-functorch-main for installation instructions) if you want to give it a try. We would greatly appreciate your feedback on the example! Ultimately I'd like to add it to our website as one of the tutorials (https://pytorch.org/functorch/nightly/).
@zou3519 Thanks for your tutorial on NTK computation. Following your procedure, I also implemented the PyTorch version without using functorch. Besides, to make the tutorial more complete, it would be extremely nice if you could also add the part about computing ...

```python
import torch
import torch.nn as nn
from torch.nn.utils import _stateless
import functools


def ntk(module: nn.Module, input1: torch.Tensor, input2: torch.Tensor,
        parameters: dict[str, nn.Parameter] = None,
        compute='full') -> torch.Tensor:
    einsum_expr: str = ''
    match compute:
        case 'full':
            einsum_expr = 'Naf,Mbf->NMab'
        case 'trace':
            einsum_expr = 'Naf,Maf->NM'
        case 'diagonal':
            einsum_expr = 'Naf,Maf->NMa'
        case _:
            raise ValueError(compute)

    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())

    def func(*params: torch.Tensor, _input: torch.Tensor = None):
        _output: torch.Tensor = _stateless.functional_call(
            module, {n: p for n, p in zip(keys, params)}, _input)
        return _output  # (N, C)

    jac1: tuple[torch.Tensor] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input1), values, vectorize=True)
    jac2: tuple[torch.Tensor] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input2), values, vectorize=True)
    jac1 = [j.flatten(2) for j in jac1]
    jac2 = [j.flatten(2) for j in jac2]
    result = torch.stack([torch.einsum(einsum_expr, j1, j2)
                          for j1, j2 in zip(jac1, jac2)]).sum(0)
    return result
```
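A usage sketch of the function above (my own, with a hypothetical toy model; it assumes a PyTorch build that ships torch.nn.utils._stateless and Python 3.10+ for match/case):

```python
# Hypothetical usage of the `ntk` helper defined above.
net = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 4))
x_train, x_test = torch.randn(8, 3), torch.randn(5, 3)

k_full = ntk(net, x_train, x_test, compute='full')    # shape (8, 5, 4, 4)
k_trace = ntk(net, x_train, x_test, compute='trace')  # shape (8, 5)
print(k_full.shape, k_trace.shape)
```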
@ain-soph Well, you're only computing the explicit empirical NTK in this example :P You may find computing the implicit NTK harder with existing PyTorch APIs. That being said, in many cases you can use both functorch and PyTorch core to compute the same things. It's up to you which API you prefer. Personally, I like functorch (i.e. JAX)-style autograd APIs when thinking about complicated gradient quantities, while I prefer PyTorch's imperative AD API for more traditional neural networks. The goal of functorch is to give you that choice (and also give you vmap :P)! I will note that the ...
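For anyone curious, here is a rough sketch (my own, not the tutorial's exact code) of the Jacobian-free style of computation being alluded to: each NTK block J(x1) J(x2)^T is assembled from vjp and jvp calls, so the Jacobians themselves are never materialized.

```python
import torch
import torch.nn as nn
from functorch import make_functional, vjp, jvp, vmap

# Toy setup; the model and shapes are just for illustration.
model = nn.Linear(5, 7, bias=False)
fnet, params = make_functional(model)

def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)

def ntk_block(params, x1, x2):
    # J(x1) @ J(x2).T for a single pair of examples, without forming either jacobian.
    out1, vjp_fn = vjp(lambda p: fnet_single(p, x1), params)

    def one_row(cotangent):
        # cotangent @ J(x1): a vector in parameter space ...
        vjps = vjp_fn(cotangent)
        # ... pushed forward through J(x2) to give one row of the block.
        _, jvps = jvp(lambda p: fnet_single(p, x2), (params,), vjps)
        return jvps

    basis = torch.eye(out1.numel(), dtype=out1.dtype)
    return vmap(one_row)(basis)  # (O, O)

x1, x2 = torch.randn(5), torch.randn(5)
print(ntk_block(params, x1, x2).shape)  # torch.Size([7, 7])
```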
@Chillee I've gradually come to understand that you are correct.
@ain-soph, thanks for sharing your code. However, it gives quite different results than the NTK tutorial. Is that normal?
@Fangwq

```python
result1 = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
result2 = ntk(net, x_train, x_test)
print((result1 - result2).abs().max())
print((result1 - result2).abs().sum())
```

Output:

I don't think that's a very different result.
Hey, I would like to calculate the mentioned Jacobians. Right now I'm trying this:

But this gives me the following error:

This is kind of maybe expected, as `params` is a tuple of tensors and not a single tensor. But `jax` does this by returning a dict (or I've not explored enough to see all the cases). But anyway, is it possible in functorch to do this? (A possible solution might be to loop over the parameters and do a param-wise Jacobian, but that's gotta be slow, I guess, right?)

However, this works for me (slow, but it does the job), and I'm using it right now for calculating the Jacobians of a model with respect to its parameters:

For my use case, I would be super happy with just applying `vmap` to the inner for loop which iterates over `out.shape[1]`, namely the logits.

P.S.: I've tried using `torch.functional.jacobian` but I get the same error. Maybe I'm missing something here?

Use case: computing the neural tangent kernel https://en.wikipedia.org/wiki/Neural_tangent_kernel
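As an aside, a loop-based workaround along the lines described above might look roughly like this (my own reconstruction for illustration only -- the original snippet isn't shown, and the helper name is hypothetical):

```python
import torch

def slow_param_jacobians(model, out):
    # Rough, illustrative reconstruction of the "loop over the logits" approach:
    # `out` is assumed to be the (N, C) output of `model`, still attached to the graph.
    params = list(model.parameters())
    jacs = [out.new_zeros(out.shape[0], out.shape[1], *p.shape) for p in params]
    for n in range(out.shape[0]):
        for c in range(out.shape[1]):  # inner loop over the logits
            grads = torch.autograd.grad(out[n, c], params, retain_graph=True)
            for j, g in zip(jacs, grads):
                j[n, c] = g
    return jacs  # one (N, C, *param.shape) jacobian per parameter
```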