
One-hot Encoding #69

Open
AlCap23 opened this issue Feb 17, 2023 · 2 comments
AlCap23 commented Feb 17, 2023

Hi there!

I've been trying to implement a one-hot encoding using StochasticAD. So far, I've failed 🥲.

I think it essentially boils down to this TODO in the src. After tinkering for a while, I've decided to ask for help, given that I did not come up with a good solution.

Cheers!

gaurav-arya (Owner) commented
Hey! Is it possible to provide a minimal example that you can't differentiate?

AlCap23 (Author) commented Feb 24, 2023

Hey! Sorry for being dormant. I think this might work as an MWE (and maybe as a general approach, though I still have to test).

using StochasticAD
using Distributions

# Simple stochastic program: sample a category, then score it.

# One-hot vector with a (possibly dual-valued) entry `val` at index `k`.
struct OneHot{T, K} <: AbstractVector{T}
    n::Int
    k::K
    val::T
end

# Strip the primal of `val` and re-center it at 1, so only the perturbation survives.
OneHot(n::Int, k::K, val::T = one(K)) where {T, K} =
    OneHot{T, K}(n, StochasticAD.value(k), val - StochasticAD.value(val) + 1)

Base.size(x::OneHot) = (x.n,)

Base.getindex(x::OneHot{T}, i::Int) where {T} = (x.k == i ? x.val : zero(T))

Base.argmax(x::OneHot) = x.k


_softmax(x) = begin
    y = exp.(x .- maximum(x))
    y ./ sum(y)
end

_logsoftmax(x) = begin
    y = x .- maximum(x)
    y .- log(sum(exp, y))
end

f(θ) = begin
    id = rand(Categorical(_softmax(θ)))
    @info id
    v = OneHot(length(θ), id, id)
    sum(v' * _logsoftmax(θ))
end

θ = randn(3)
f(θ)

m = StochasticModel(f, θ)

stochastic_gradient(m) # Returns a gradient, still have to check if it finds the right value though
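On the "still have to check" point: if the constructor above indeed re-centers the primal of `val` at 1 (my reading of it, so an assumption), then the primal of `f` is `_logsoftmax(θ)[id]`, and E[f(θ)] = Σᵢ pᵢ log pᵢ with p = softmax(θ), i.e. the negative entropy. A Base-only sketch of the closed-form gradient that the averaged `stochastic_gradient` output could be compared against:

```julia
# Closed-form reference for E[f(θ)] = Σ_i p_i log p_i, p = softmax(θ).
# Base-only; no StochasticAD needed. The derivative identity is
# d/dθ_j Σ_i p_i log p_i = p_j (log p_j - Σ_i p_i log p_i).

_softmax(x) = (y = exp.(x .- maximum(x)); y ./ sum(y))

# E[f(θ)] in closed form: the negative entropy of softmax(θ).
negentropy(θ) = (p = _softmax(θ); sum(p .* log.(p)))

# Analytic gradient of the negative entropy w.r.t. θ.
function negentropy_grad(θ)
    p = _softmax(θ)
    H = sum(p .* log.(p))
    return p .* (log.(p) .- H)
end

# Central finite differences confirm the analytic gradient.
θ = [0.3, -1.2, 0.5]
h = 1e-6
e(j, n) = [i == j ? 1.0 : 0.0 for i in 1:n]  # j-th standard basis vector
fd = [(negentropy(θ + h * e(j, 3)) - negentropy(θ - h * e(j, 3))) / (2h) for j in 1:3]
```

Averaging many fresh evaluations of `stochastic_gradient(m)` should then converge to `negentropy_grad(θ)` if the one-hot construction propagates derivatives correctly.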
