
Higher order derivative products? #1102

Closed

rdangovs opened this issue Mar 27, 2020 · 4 comments

@rdangovs

Suppose I have a loss function L(d, x) that takes data d and parameters x. Can I compute the following?

∇_x L(b, x - η ∇_x L(a, x))

I.e. can I differentiate through the gradient? Please find below a PyTorch example of how to achieve that for a = 1 and b = 2.

import torch

# define the loss as a function of data and param
def loss(data, param):
    return data * (param ** 3)

# inner and outer data
a = 1. 
b = 2.

# init the param and make a copy in `y`
x = torch.tensor([1.], requires_grad=True)
y = x

# finetune `y` for one step
# note that `create_graph=True` which allows graph of the derivative 
# to be computed, and thus allowing higher order derivative products
innerloss = loss(a, y)
grad = torch.autograd.grad(innerloss, y, create_graph=True)[0]
y = y - 1 * grad

# get gradients for the original param `x`
outerloss = loss(b, y)
print(torch.autograd.grad(outerloss, x)[0]) # result is -120

I wonder whether this behavior could be reproduced easily in Flux for a large class of loss functions, say starting from this one. Here is one unsuccessful attempt of mine.

η = 1.

function innerloss(a, x)
    sum(a .* x .^ 3)
end

function outerloss(a, b, x)
    innergs = gradient(params(x)) do 
        innerloss(a, x)
    end    
    adaptedx = x - η * innergs[x]
    innerloss(b, adaptedx)
end

a = [1.]; b = [2.]; x = [1.];

gs = gradient(params(x)) do
    outerloss(a, b, x)
end

print(gs[x], '\n') # [-552.0]

It seems to me that gradient does not differentiate through innergs properly here. I am afraid my understanding of gradient is currently too limited to make this work. Could you help me? Any advice is appreciated. Thanks!

@CarloLucibello (Member)

I modified outerloss from your example; is this the expected result?

η = 1.

function innerloss(a, x)
    sum(a .* x .^ 3)
end

function outerloss(a, b, x)
    g = gradient(x -> innerloss(a, x), x)[1]
    adaptedx = x - η * g
    innerloss(b, adaptedx)
end

a = [1.]; b = [2.]; x = [1.];

gs = gradient(params(x)) do
    outerloss(a, b, x)
end

println(gs[x]) # [-264.0]

@rdangovs (Author)

@CarloLucibello: thanks! I am afraid this is not exactly what we are looking for. Given the example above, the computation one would like to do is the full derivative

∇_x L(b, x - η ∇_x L(a, x))

including the dependence on x through the inner gradient (which gives -120 in the PyTorch example, not -264).

I guess another way to tackle this is to write the chain rule explicitly:

∇_x L(b, x - η ∇_x L(a, x)) = (I - η H_x L(a, x)) ∇L(b, y), evaluated at y = x - η ∇_x L(a, x)

So then my code would have to compute the Hessian explicitly; a rough sketch for the toy loss is below. It seems that this issue is similar to #129.
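
To make the chain rule concrete for the toy loss above, here is a small check with the gradient and Hessian written out by hand for this particular loss (so it does not solve the general problem, it only verifies the numbers):

# toy loss L(d, x) = d * x^3, so the gradient is 3d*x^2 and the Hessian is 6d*x
a, b, x, η = 1.0, 2.0, 1.0, 1.0

g  = 3a * x^2                 # inner gradient ∇_x L(a, x)
xadapted = x - η * g          # adapted parameter after one inner step
go = 3b * xadapted^2          # gradient of the outer loss at the adapted parameter
H  = 6a * x                   # Hessian of the inner loss at x (a scalar here)

println((1 - η * H) * go)     # -120.0, matching the PyTorch result above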

Any thoughts on how I can solve this elegantly? Thanks!

@lssimoes

@rdangovs would perhaps Zygote.hessian suit you?
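
For instance, something along the lines of the chain rule you wrote out (an untested sketch, reusing innerloss and the values from the snippets above, and assuming Zygote.hessian accepts the same closure as gradient here):

using Zygote, LinearAlgebra

innerloss(a, x) = sum(a .* x .^ 3)
a = [1.]; b = [2.]; x = [1.]; η = 1.

g  = Zygote.gradient(x -> innerloss(a, x), x)[1]     # inner gradient
H  = Zygote.hessian(x -> innerloss(a, x), x)         # 1×1 Hessian of the inner loss
go = Zygote.gradient(y -> innerloss(b, y), x .- η .* g)[1]

println((I - η .* H) * go)                           # [-120.0]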

@rdangovs (Author)

@lssimoes Thanks! Will give it a go!
