Multiplicative integration #5

Open
MartinuzziFrancesco opened this issue Nov 4, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@MartinuzziFrancesco
Owner

As described here, multiplicative integration modifies the recurrent equation from

$$h_{t} = \sigma(W \cdot x_t + U \cdot h_{t-1} + b)$$

to

$$h_{t} = \sigma(W \cdot x_t \odot U \cdot h_{t-1} + b)$$
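To make the difference concrete, here is a small Julia sketch of one step under each equation. It takes σ as a keyword (`tanh` by default, which is an assumption, since the equations leave σ generic), and `additive_step` / `multiplicative_step` are hypothetical helper names, not part of the package.

```julia
# One step of a vanilla recurrent cell under both formulations.
# Hypothetical helpers for illustration; σ defaults to tanh (an assumption).
additive_step(W, U, b, x, h; σ=tanh) = σ.(W * x .+ U * h .+ b)
multiplicative_step(W, U, b, x, h; σ=tanh) = σ.((W * x) .* (U * h) .+ b)

W = [1.0 0.0; 0.0 2.0]   # input kernel (hidden × input)
U = [0.5 0.0; 0.0 0.5]   # recurrent kernel (hidden × hidden)
b = [0.0, 0.0]
x = [1.0, 1.0]
h = [2.0, 4.0]

# W*x = [1, 2] and U*h = [1, 2]; σ = identity keeps the numbers exact
h_add = additive_step(W, U, b, x, h; σ=identity)        # [1, 2] + [1, 2] = [2, 4]
h_mul = multiplicative_step(W, U, b, x, h; σ=identity)  # [1, 2] ⊙ [1, 2] = [1, 4]

# With a zero initial state, U*h0 = 0, so the multiplicative pre-activation is just b
h0 = zeros(2)
h_first = multiplicative_step(W, U, [0.3, -0.1], x, h0; σ=identity)
```

One consequence worth keeping in mind: with a zero initial state, $U \cdot h_0 = 0$, so under multiplicative integration the first step's pre-activation reduces to $b$, whatever the input.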

This applies to virtually every cell that uses the recurrent equation, which is all of them. It would be nice to let the user choose, for each cell, which version to use (`:normal` or `:multiplicative_integration`, I guess), but doing so would require abstracting a lot out of the current implementation.
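One possible way to expose the choice is to map the user-facing symbol to a small marker type once, at construction, and let dispatch pick the combination rule. The sketch below uses hypothetical names (`AdditiveIntegration`, `integration_mode`, `combine`, `SimpleCell`) and a vanilla tanh cell rather than any of the package's cells.

```julia
# Hypothetical marker types for the two integration modes.
struct AdditiveIntegration end
struct MultiplicativeIntegration end

# Map the user-facing symbol to a marker instance; reject anything else.
function integration_mode(s::Symbol)
    s === :normal && return AdditiveIntegration()
    s === :multiplicative_integration && return MultiplicativeIntegration()
    throw(ArgumentError("unknown integration mode :$s"))
end

# Combine the input projection `wx` and recurrent projection `uh`, then add the bias.
combine(::AdditiveIntegration, wx, uh, b) = wx .+ uh .+ b
combine(::MultiplicativeIntegration, wx, uh, b) = wx .* uh .+ b

# A minimal vanilla cell whose type carries the mode, so `combine` dispatches statically.
struct SimpleCell{M, T}
    W::Matrix{T}
    U::Matrix{T}
    b::Vector{T}
    mode::M
end

SimpleCell(W, U, b; integration::Symbol=:normal) =
    SimpleCell(W, U, b, integration_mode(integration))

(c::SimpleCell)(x, h) = tanh.(combine(c.mode, c.W * x, c.U * h, c.b))

# Usage: same weights, two modes
W = [1.0 0.0; 0.0 2.0]; U = [0.5 0.0; 0.0 0.5]; b = zeros(2)
x = [1.0, 1.0]; h = [2.0, 4.0]
add_cell = SimpleCell(W, U, b)
mi_cell = SimpleCell(W, U, b; integration=:multiplicative_integration)
```

Because the marker is stored with a concrete type in the cell's type, `combine` is chosen by dispatch rather than by checking the symbol on every step; the symbol is only inspected once, in the constructor.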

MartinuzziFrancesco added the enhancement label on Dec 13, 2024
@MartinuzziFrancesco
Owner Author

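```julia
using Flux: glorot_uniform, create_bias, _size_check  # Flux internals, as used in the package
using MLUtils: chunk
using NNlib: sigmoid_fast, tanh_fast

# Assumed for this sketch: a parametric supertype carrying the integration mode,
# and a singleton marker selecting multiplicative integration.
abstract type AbstractRecurrentCell{MI} end
struct MultiplicativeIntegration end

struct MGUCell{I, H, V, MI} <: AbstractRecurrentCell{MI}
    Wi::I
    Wh::H
    bias::V
end

function MGUCell((input_size, hidden_size)::Pair{<:Int, <:Int};
        init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform,
        bias::Bool=true, integration=nothing)
    Wi = init_kernel(hidden_size * 2, input_size)
    Wh = init_recurrent_kernel(hidden_size * 2, hidden_size)
    b = create_bias(Wi, bias, size(Wi, 1))
    # `integration` is `nothing` or `MultiplicativeIntegration()`; both are
    # singleton values, so they can be stored as a type parameter
    return MGUCell{typeof(Wi), typeof(Wh), typeof(b), integration}(Wi, Wh, b)
end

function (mgu::MGUCell{I, H, V, MI})(inp::AbstractVecOrMat, state) where {I, H, V, MI}
    _size_check(mgu, inp, 1 => size(mgu.Wi, 2))
    Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias
    # split the stacked kernels and bias into forget-gate and candidate halves;
    # the bias stays out of the product and is added after integration
    gxs = chunk(Wi * inp, 2; dims=1)
    ghs = chunk(Wh, 2; dims=1)
    bs = b isa AbstractArray ? chunk(b, 2; dims=1) : (b, b)  # b is `false` when bias=false

    forget_gate = sigmoid_fast.(compute_recurrence(MI, gxs[1], ghs[1] * state) .+ bs[1])
    candidate_state = tanh_fast.(
        compute_recurrence(MI, gxs[2], ghs[2] * (forget_gate .* state)) .+ bs[2])
    new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state
    return new_state, new_state
end

# multiplicative integration: (W * x) ⊙ (U * h)
compute_recurrence(::MultiplicativeIntegration, A, B) = A .* B
# standard recurrence: W * x + U * h
compute_recurrence(::Nothing, A, B) = A .+ B
```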
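Applying the multiplicative form to both gates of the MGU cell gives the equations below, a sketch of what the cell is meant to compute. Here $W_f, W_c$ are the two row blocks of `Wi`, $U_f, U_c$ the two row blocks of `Wh`, $b_f, b_c$ the two halves of `bias`, and $\sigma$ is the sigmoid; the bias is kept outside the product, as in the equation $h_{t} = \sigma(W \cdot x_t \odot U \cdot h_{t-1} + b)$. Turning the $\odot$ that joins the input and recurrent projections back into $+$ gives the standard MGU update.

$$f_t = \sigma(W_f \cdot x_t \odot U_f \cdot h_{t-1} + b_f)$$

$$\tilde{h}_t = \tanh(W_c \cdot x_t \odot U_c \cdot (f_t \odot h_{t-1}) + b_c)$$

$$h_t = f_t \odot h_{t-1} + (1 - f_t) \odot \tilde{h}_t$$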
