Inconsistency in mode type parameters for ReverseModeSplit #1881

Closed
gdalle opened this issue Sep 23, 2024 · 9 comments · Fixed by #1895

Comments

@gdalle
Contributor

gdalle commented Sep 23, 2024

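The docstring and the type definition disagree on the order of the type parameters: `Width` and `RuntimeActivity` are swapped. The constants and helper functions defined right after the struct follow the docstring order, not the order in the `struct` definition.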

```julia
"""
    struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity}

Reverse mode differentiation.
- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward.
- `ReturnShadow`: Should Enzyme return the shadow return value from the augmented-forward.
- `RuntimeActivity`: Should Enzyme differentiate with runtime activity on (default off).
- `Width`: Batch Size (0 if to be automatically derived)
- `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived).
"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end
const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}()
const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}()
@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}()
@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}()
```
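Why the swap matters can be reproduced with a small self-contained mock. The `MockSplit` type and the `mock_*` functions below are hypothetical stand-ins, not EnzymeCore code: the struct declares `Width` before `RuntimeActivity` (like the `struct` line above), while the helper destructures in docstring order (like `ReverseSplitWidth`). Because Julia matches type parameters by position, not by the names written in `where`, the helper silently writes the new width into the `RuntimeActivity` slot.

```julia
# Hypothetical mock of the pattern in the issue; NOT the EnzymeCore definitions.
# Parameter order follows the `struct` line: Width before RuntimeActivity.
struct MockSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity} end

# Accessors that read parameters by their declared position.
mock_width(::MockSplit{RP,RS,W,RA}) where {RP,RS,W,RA} = W
mock_runtime_activity(::MockSplit{RP,RS,W,RA}) where {RP,RS,W,RA} = RA

# Helper written against the docstring order (RuntimeActivity before Width),
# mirroring `ReverseSplitWidth`. The names in `where` are only labels: Julia
# binds them positionally, so `RA` actually captures the Width slot here.
function mock_split_width(::MockSplit{RP,RS,RA,WidthO}, ::Val{Width}) where {RP,RS,RA,WidthO,Width}
    return MockSplit{RP,RS,RA,Width}()
end

m = MockSplit{false,true,0,false}()   # Width = 0, RuntimeActivity = false
m2 = mock_split_width(m, Val(2))      # intended: set Width to 2

println(mock_width(m2))               # prints 0 (Width unchanged)
println(mock_runtime_activity(m2))    # prints 2 (RuntimeActivity overwritten)
```

No error is raised, which is what makes this kind of mismatch easy to miss.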

@wsmoses
Member

wsmoses commented Sep 23, 2024

Yeah, that's definitely a bug; a PR to fix it is welcome!

That said, the internals of this struct are not considered public API and are subject to change.

@gdalle
Contributor Author

gdalle commented Sep 23, 2024

Do you mean that the precise order and nature of the type parameters in each Mode is not part of the API? If so, that needs to be stated explicitly in the docstring, because as things stand, it is documented and not marked private, so people (like me) rightfully assume it is public.

@wsmoses
Member

wsmoses commented Sep 23, 2024

Correct (and that's why we have helper functions).

@gdalle
Contributor Author

gdalle commented Sep 23, 2024

So if I understand correctly:

  • The type parameters are documented but not part of the public API, so people shouldn't use them
  • The helper functions are not documented but part of the public API, so people should use them

Does that sound about right? If so, I might add some warnings here and there when I open the PR.

@wsmoses
Member

wsmoses commented Sep 23, 2024

👍

@wsmoses
Member

wsmoses commented Sep 23, 2024

To be clear, the type parameters are generally stable, but they are subject to change if something needs to be added. The helper functions will always adapt to those changes, so they are the recommended way to manipulate modes, if needed.
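A minimal sketch of the pattern being recommended, using a hypothetical `SketchMode` type and hypothetical helper names (in the spirit of `ReverseSplitWidth` and `ReverseSplitModified`, but not EnzymeCore's API): only the accessors and setters defined next to the struct know the parameter order, and caller code goes through them instead of spelling out type parameters.

```julia
# Hypothetical mode type; the parameter order is treated as an internal detail.
struct SketchMode{ReturnPrimal,Width,RuntimeActivity} end

# Helpers: the only methods that depend on the positional layout.
mode_return_primal(::SketchMode{RP,W,RA}) where {RP,W,RA} = RP
mode_width(::SketchMode{RP,W,RA}) where {RP,W,RA} = W
mode_runtime_activity(::SketchMode{RP,W,RA}) where {RP,W,RA} = RA

# "Setters" return a new singleton instance with one parameter replaced.
set_width(::SketchMode{RP,W,RA}, ::Val{NewW}) where {RP,W,RA,NewW} = SketchMode{RP,NewW,RA}()
with_primal(::SketchMode{RP,W,RA}) where {RP,W,RA} = SketchMode{true,W,RA}()

# Caller code never writes SketchMode{...} parameters by position.
mode = SketchMode{false,0,false}()
mode2 = with_primal(set_width(mode, Val(4)))
```

If a parameter were later inserted or reordered, only the helper methods would need updating; a call such as `set_width(mode, Val(4))` would keep working unchanged.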

@gdalle
Contributor Author

gdalle commented Sep 23, 2024

While we're on helper functions, it turns out I need the following ones for DI (DifferentiationInterface.jl) to work. Any comments on the naming, and on whether they would be welcome in EnzymeCore?

```julia
function mode_noprimal(
    ::Type{ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}}
) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}
    return ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}()
end

function mode_withprimal(
    ::Type{ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}}
) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}
    return ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}()
end

function mode_noprimal(
    ::Type{ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}}
) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}
    return ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()
end

function mode_withprimal(
    ::Type{ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}}
) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}
    return ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()
end

function mode_noprimal(
    ::Type{
        ReverseModeSplit{
            ReturnPrimal,
            ReturnShadow,
            RuntimeActivity,
            Width,
            ModifiedBetween,
            ABI,
            ErrIfFuncWritten,
        },
    },
) where {
    ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten
}
    return ReverseModeSplit{
        false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten
    }()
end

function mode_withprimal(
    ::Type{
        ReverseModeSplit{
            ReturnPrimal,
            ReturnShadow,
            RuntimeActivity,
            Width,
            ModifiedBetween,
            ABI,
            ErrIfFuncWritten,
        },
    },
) where {
    ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten
}
    return ReverseModeSplit{
        true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten
    }()
end

function mode_split(
    ::Type{ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}}
) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}
    ReturnShadow = true
    Width = 0
    ModifiedBetween = true
    return ReverseModeSplit{
        ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten
    }()
end
```
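The helpers proposed above dispatch on `Type{...}` rather than on a mode instance, so a caller holding a mode value passes `typeof(mode)`. The mock below illustrates that calling convention; the `MockForward` type is a hypothetical stand-in with the same parameter layout as `ForwardMode` above, and the one-line instance method is an optional convenience that was not proposed in the thread.

```julia
# Hypothetical stand-in with the same parameter layout as the ForwardMode helpers above.
struct MockForward{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} end

# Same shape as the proposed `mode_noprimal`: dispatch on the type, return an instance.
function mode_noprimal(
    ::Type{MockForward{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}}
) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}
    return MockForward{false,ABI,ErrIfFuncWritten,RuntimeActivity}()
end

# Optional convenience (assumption): accept an instance by forwarding its concrete type.
mode_noprimal(m::MockForward) = mode_noprimal(typeof(m))

fwd = MockForward{true,:default,false,false}()
a = mode_noprimal(typeof(fwd))   # ReturnPrimal flipped to false, other parameters kept
b = mode_noprimal(fwd)           # same result via the instance method
```

Dispatching on the type lets the helpers be called from generated or type-level code where only `typeof(mode)` is at hand; the instance method simply forwards to it.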

@wsmoses
Member

wsmoses commented Sep 23, 2024 via email

@gdalle
Contributor Author

gdalle commented Sep 23, 2024

Sounds good. I'll open the PR tonight or tomorrow :)

2 participants