Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantile loss #157

Closed
jbbremnes opened this issue Nov 1, 2023 · 2 comments
Closed

Quantile loss #157

jbbremnes opened this issue Nov 1, 2023 · 2 comments

Comments

@jbbremnes
Copy link

I would like to test SimpleChains for prediction of multiple quantiles in a weather forecasting use case. By looking at custom_loss_layer.md I have tried to implement a quantile loss struct/function. However, I get an error indicating that a method is missing and I am not sure how to proceed, please see below. Any ideas?

using SimpleChains

struct QuantileLoss{T, Y<:AbstractVector{T}} <: SimpleChains.AbstractLoss{T}
    targets::Y
    prob::Y          # quantile levels
end
target(loss::QuantileLoss) = loss.targets
(::QuantileLoss)(y::AbstractVector, prob::AbstractVector) = QuantileLoss(y, prob) 
Base.show(io::IO, loss::QuantileLoss) = print(io, "QuantileLoss: $(loss.prob)")

#  quantile loss for multiple levels
function calculate_loss(loss::QuantileLoss, qt)
    y = loss.targets
    prob = loss.prob
    total_loss = zero(eltype(qt))
    for i in eachindex(y) 
        for j in axes(qt, 1)
            dev = y[i] - qt[j,i]
            cmp = dev > zero(eltype(qt))
            total_loss += ifelse(cmp, 1-prob[j], -prob[j]) * dev
        end
    end
    return total_loss
end

function (loss::QuantileLoss)(previous_layer_output::AbstractArray{T}, p::Ptr, pu) where {T}
    total_loss = calculate_loss(loss, previous_layer_output)
    total_loss, p, pu
end

function SimpleChains.layer_output_size(::Val{T}, sl::QuantileLoss, s::Tuple) where {T}
    SimpleChains._layer_output_size_no_temp(Val{T}(), sl, s)
end

function SimpleChains.forward_layer_output_size(::Val{T}, sl::QuantileLoss, s) where {T}
    SimpleChains._layer_output_size_no_temp(Val{T}(), sl, s)
end

function SimpleChains.chain_valgrad!(
    __,
    previous_layer_output::AbstractArray{T},
    layers::Tuple{QuantileLoss},
    _::Ptr,
    pu::Ptr{UInt8},
) where {T}
    loss = getfield(layers, 1)
    total_loss = calculate_loss(loss, previous_layer_output)
    y = loss.targets
    prob = loss.prob
    for i in eachindex(y)
        qt_i = previous_layer_output[:, i]  # ?                    
        for j in eachindex(qt_i)
            qt_i[j] = -prob[j]
            if y[i] < qt_i[j]
                qt_i[j] += one(T)
            end
        end
        previous_layer_output[:, i] = qt_i           
    end
    return total_loss, previous_layer_output, pu
end

##  simple example
prob = Float32[0.1, 0.5, 0.9];
x = rand(Float32, 2, 1000);
y = Float32.(sin.(pi*x[1,:]) .+ 0.1 .* randn(1000));

model = SimpleChain(static(2),
                    TurboDense{true}(tanh, 20),
                    TurboDense{true}(tanh, length(prob)))
model_loss = SimpleChains.add_loss(model, QuantileLoss(y, prob))

p = SimpleChains.init_params(model_loss);
G = SimpleChains.alloc_threaded_grad(model_loss);
opt = SimpleChains.ADAM(0.001)
SimpleChains.train_batched!(G, p, model_loss, x, opt, 100) 


julia> SimpleChains.train_batched!(G, p, model_loss, x, opt, 100)
ERROR: MethodError: no method matching length(::Nothing)

Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator})
   @ Base abstractdict.jl:58
  length(::Union{SparseArrays.FixedSparseVector{Tv, Ti}, SparseArrays.SparseVector{Tv, Ti}} where {Tv, Ti})
   @ SparseArrays ~/apps/julia-1.9.3/share/julia/stdlib/v1.9/SparseArrays/src/sparsevector.jl:95
  length(::Union{ArrayInterface.BidiagonalIndex, ArrayInterface.TridiagonalIndex})
   @ ArrayInterface ~/.julia/packages/ArrayInterface/lyyrK/src/ArrayInterface.jl:666
  ...

Stacktrace:
 [1] maybe_static(f::typeof(StaticArrayInterface.known_length), g::typeof(length), x::Nothing)
   @ Static ~/.julia/packages/Static/tQ45Q/src/Static.jl:762
 [2] static_length(x::Nothing)
   @ StaticArrayInterface ~/.julia/packages/StaticArrayInterface/IzhAV/src/size.jl:204
 [3] _maybe_size(#unused#::Base.HasLength, a::Nothing)
   @ StaticArrayInterface ~/.julia/packages/StaticArrayInterface/IzhAV/src/size.jl:29
 [4] static_size(a::Nothing)
   @ StaticArrayInterface ~/.julia/packages/StaticArrayInterface/IzhAV/src/size.jl:23
 [5] train_batched!(g::StrideArray{Float32, 2, (1, 2), Tuple{Static.StaticInt{123}, Int64}, Tuple{Nothing, StrideArraysCore.StrideReset{Int64}}, Tuple{Static.StaticInt{1}, Static.StaticInt{1}}, Vector{Float32}}, p::StrideArraysCore.StaticStrideArray{Float32, 1, (1,), Tuple{Static.StaticInt{123}}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, 123}, _chn::SimpleChain{Tuple{Static.StaticInt{2}}, Tuple{TurboDense{true, Static.StaticInt{20}, typeof(tanh)}, TurboDense{true, Static.StaticInt{3}, typeof(tanh)}, QuantileLoss{Float32, Vector{Float32}}}}, X::Matrix{Float32}, opt::SimpleChains.ADAM, iters::Int64; batchsize::Nothing, leaveofflast::Bool)
   @ SimpleChains ~/.julia/packages/SimpleChains/nYZgm/src/optimize.jl:692
 [6] train_batched!(g::StrideArray{Float32, 2, (1, 2), Tuple{Static.StaticInt{123}, Int64}, Tuple{Nothing, StrideArraysCore.StrideReset{Int64}}, Tuple{Static.StaticInt{1}, Static.StaticInt{1}}, Vector{Float32}}, p::StrideArraysCore.StaticStrideArray{Float32, 1, (1,), Tuple{Static.StaticInt{123}}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, 123}, _chn::SimpleChain{Tuple{Static.StaticInt{2}}, Tuple{TurboDense{true, Static.StaticInt{20}, typeof(tanh)}, TurboDense{true, Static.StaticInt{3}, typeof(tanh)}, QuantileLoss{Float32, Vector{Float32}}}}, X::Matrix{Float32}, opt::SimpleChains.ADAM, iters::Int64)
   @ SimpleChains ~/.julia/packages/SimpleChains/nYZgm/src/optimize.jl:656
 [7] top-level scope
   @ REPL[19]:1
@chriselrod
Copy link
Contributor

Was this completed?

@jbbremnes
Copy link
Author

No, I got a bit further, but then had to focus on other work. I am still interested, though. The calculate_loss function was also not correct. Here is the latest.

#  quantile loss for multiple levels
function calculate_loss(loss::QuantileLoss, qt::AbstractArray)
    y = loss.y
    prob = loss.prob
    T = eltype(qt)
    total_loss = zero(T)
    for i in eachindex(y) 
        for j in axes(qt, 1)
            dev = y[i] - qt[j,i]
            cmp = dev < zero(T)
            total_loss += ifelse(cmp, prob[j] - one(T), prob[j]) * dev
        end
    end
    return total_loss
end
function SimpleChains.chain_valgrad!(
    __,
    arg::AbstractArray{T},        # previous_layer_output
    layers::Tuple{QuantileLoss},
    _::Ptr,
    pu::Ptr{UInt8}
) where {T}
    loss = getfield(layers, 1)
    #total_loss = calculate_loss(loss, arg)
    y    = getfield(loss, :y)
    prob = getfield(loss, :prob)
    total_loss = zero(T)
    for i in eachindex(y)
        for j in axes(arg, 1)
            dev = y[i] - arg[j, i]
            cmp = dev < 0
            arg[j, i] = ifelse(cmp, -prob[j] + one(T), -prob[j])
            total_loss += ifelse(cmp, prob[j] - one(T), prob[j]) * dev
        end
    end
    #println(size(y))
    return total_loss, arg, pu
end

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants