Skip to content

Commit

Permalink
Fix Zygote errors with correlation sensitivities
Browse files Browse the repository at this point in the history
  - avoid trivial correlation sensitivity calculation
  - add type annotations to correlation_holder
  - make correlation_holder non-differentiable
  • Loading branch information
FrameConsult authored and sschlenkrich committed Feb 24, 2024
1 parent 492df05 commit 8c9e4b9
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/DiffFusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ module Examples
end # module

include("chainrules/models.jl")
include("chainrules/termstructures.jl")
include("chainrules/simulations.jl")

"List of function names eligible for de-serialisation."
Expand Down
10 changes: 8 additions & 2 deletions src/analytics/Valuations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ function model_price_and_vegas(
# We need a correlation holder... for simplicity, we assume this is unique in the model
ch_alias = nothing
for m in model.models
if hasproperty(m, :correlation_holder) && !isnothing(m.correlation_holder)
if hasproperty(m, :correlation_holder) &&
!isnothing(m.correlation_holder) &&
length(m.correlation_holder.correlations) > 0 # avoid Zygote error by differentiating empty dict
#
ch_alias = m.correlation_holder.alias
break
end
Expand Down Expand Up @@ -299,7 +302,10 @@ function model_price_and_vegas_vector(
# We need a correlation holder... for simplicity, we assume this is unique in the model
ch_alias = nothing
for m in model.models
if hasproperty(m, :correlation_holder) && !isnothing(m.correlation_holder)
if hasproperty(m, :correlation_holder) &&
!isnothing(m.correlation_holder) &&
length(m.correlation_holder.correlations) > 0 # avoid Zygote error by differentiating empty dict
#
ch_alias = m.correlation_holder.alias
break
end
Expand Down
5 changes: 5 additions & 0 deletions src/chainrules/termstructures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

# do not differentiate trivial correlation holder setup
ChainRulesCore.@non_differentiable correlation_holder(alias::String,)
ChainRulesCore.@non_differentiable correlation_holder(alias::String, sep::String,)
ChainRulesCore.@non_differentiable correlation_holder(alias::String, sep::String, value_type::DataType)
20 changes: 10 additions & 10 deletions src/termstructures/correlation/CorrelationHolder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
alias::String
correlations::Dict{String, ModelValue}
sep::String
value_type::Type
value_type::DataType
end

A container holding correlation values.
Expand All @@ -20,25 +20,25 @@ struct CorrelationHolder <: CorrelationTermstructure
alias::String
correlations::Dict{String, ModelValue}
sep::String
value_type::Type
value_type::DataType
end


"""
correlation_holder(
alias::String,
correlations::Dict,
sep = "<>",
value_type = ModelValue,
sep::String = "<>",
value_type::DataType = ModelValue,
)

Create a CorrelationHolder object from dictionary.
"""
function correlation_holder(
alias::String,
correlations::Dict,
sep = "<>",
value_type = ModelValue,
sep::String = "<>",
value_type::DataType = ModelValue,
)
for (key, value) in correlations
@assert isa(value, value_type)
Expand All @@ -51,16 +51,16 @@ end
"""
correlation_holder(
alias::String,
sep = "<>",
value_type = ModelValue,
sep::String = "<>",
value_type::DataType = ModelValue,
)

Create an empty CorrelationHolder object.
"""
function correlation_holder(
alias::String,
sep = "<>",
value_type = ModelValue,
sep::String = "<>",
value_type::DataType = ModelValue,
)
return correlation_holder(alias, Dict{String, value_type}(), sep, value_type)
end
Expand Down

0 comments on commit 8c9e4b9

Please sign in to comment.