Skip to content

Commit

Permalink
Merge branch 'main' of github.com:DynareJulia/Dynare.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelJuillard committed Nov 21, 2023
2 parents 88e5a65 + 7b9d677 commit 9e4d8b1
Show file tree
Hide file tree
Showing 14 changed files with 579 additions and 221 deletions.
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.8.1"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AxisArrayTables = "af8da316-43a4-49f0-bd76-7de0cd630fd6"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BipartiteMatching = "79040ab4-24c8-4c92-950c-d48b5991a0f6"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand Down Expand Up @@ -74,10 +75,12 @@ PardisoSolver = ["MKL", "Pardiso"]
PathSolver = "PATHSolver"

[compat]
AbstractMCMC = "4.4.0"
AdvancedMH = "0.7.4"
AbstractMCMC = "4.4.0, 5"
AdvancedMH = "0.7.4, 0.8"
AxisArrayTables = "0.1.2"
AxisArrays = "0.4.7"
BenchmarkTools = "1.2.2"
BipartiteMatching = "0.1"
CSV = "0.10.2"
DataFrames = "1.3.2"
Distances = "0.10.7"
Expand Down
6 changes: 4 additions & 2 deletions src/Dynare.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module Dynare

using ExtendedDates
using Reexport
@reexport using ExtendedDates

using Logging
using Printf

Expand All @@ -19,7 +21,7 @@ include("utils.jl")
include("dynare_functions.jl")
include("dynare_containers.jl")
include("accessors.jl")
export irf, simulation
export forecast, irf, simulation, smoother
include("model.jl")
export get_abc, get_de
include("symboltable.jl")
Expand Down
25 changes: 11 additions & 14 deletions src/DynareParser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,19 @@ function get_varobs(modeljson::Dict{String,Any})
return varobs
end

function isfile_recent(filename, modfilename)
pathname = joinpath(modfilename, "model/julia", filename)
return isfile(pathname)
end

function check_function_files!(modfileinfo::ModFileInfo, modfilename::String)
dirname = modfilename * "/model/julia/"
if isfile(dirname * "SparseDynamicResid!.jl")
modfileinfo.has_dynamic_file = true
end
if isfile(dirname * "SparseStaticResid!.jl")
modfileinfo.has_static_file = true
end
if isfile(dirname * "DynamicSetAuxiliarySeries.jl")
modfileinfo.has_auxiliary_variables = true
if !isfile(dirname * "SetAuxiliaryVariables.jl")
error(dirname * "SetAuxiliaryVariables.jl is missing")
end
end
if isfile(dirname * "SteadyState2.jl")
modfileinfo.has_steadystate_file = true
modfileinfo.has_dynamic_file = isfile_recent("SparseDynamicResid!.jl", modfilename)
modfileinfo.has_static_file = isfile_recent("SparseStaticResid!.jl", modfilename)
modfileinfo.has_steadystate_file = isfile_recent("SteadyState2.jl", modfilename)
modfileinfo.has_auxiliary_variables = isfile_recent("DynamicSetAuxiliarySeries.jl", modfilename)
if modfileinfo.has_auxiliary_variables && !isfile(dirname * "SetAuxiliaryVariables.jl")
error(dirname * "SetAuxiliaryVariables.jl is missing")
end
end

Expand Down
50 changes: 43 additions & 7 deletions src/accessors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,55 @@ simulation(varnames::Tuple;
=#

## SMOOTHER

function smoother(; context = context,
firstperiod = AxisArrayTables.axes(context.results.model_results[1].smoother)[1][1],
lastperiod = AxisArrayTables.axes(context.results.model_results[1].smoother)[1][end],
firstperiod = row_labels(context.results.model_results[1].smoother)[1],
lastperiod = row_labels(context.results.model_results[1].smoother)[end],
)
return context.results.model_results[1].smoother[firstperiod..lastperiod]
return context.results.model_results[1].smoother[firstperiod:lastperiod]
end

function smoother(varnames; context = context,
firstperiod = AxisArrayTables.axes(context.results.model_results[1].smoother)[1][1],
lastperiod = AxisArrayTables.axes(context.results.model_results[1].smoother)[1][end],
firstperiod = row_labels(context.results.model_results[1].smoother)[1],
lastperiod = row_labels(context.results.model_results[1].smoother)[end],
)
return context.results.model_results[1].smoother[firstperiod:lastperiod, varnames]
end

## FORECAST
function forecast(; context = context,
firstperiod = row_labels(context.results.model_results[1].smoother)[1],
lastperiod = row_labels(context.results.model_results[1].smoother)[end],
informationperiod = Undated(typemin(Int))
)
return context.results.model_results[1].smoother[firstperiod..lastperiod, varnames]
forecast_ = context.results.model_results[1].forecast
if length(forecast_) == 1
return forecast_[1][firstperiod:lastperiod]
else
D = Dict((row_labels(d)[1] => d) for d in forecast_)
if informationperiod != Undated(typemin(Int))
return D[informationperiod]
else
return D
end
end
end

function forecast(varnames;
context = context,
firstperiod = row_labels(context.results.model_results[1].smoother)[1],
lastperiod = row_labels(context.results.model_results[1].smoother)[end],
informationperiod = Undated(typemin(Int))
)
forecast_ = context.results.model_results[1].forecast
if length(forecast_) == 1
return forecast_[1][firstperiod:lastperiod]
else
D = Dict((row_labels(d)[1] => d[varnames]) for d in forecast_)
if informationperiod != Undated(typemin(Int))
return D[informationperiod]
else
return D
end
end
end

176 changes: 149 additions & 27 deletions src/data.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,135 @@
import AxisArrays
using AxisArrayTables
using ExtendedDates

function data!(datafile::AbstractString;
context::Context = context,
variables::Vector{<:Union{String,Symbol}} = [],
start::PeriodsSinceEpoch = Undated(typemin(Int)),
last::PeriodsSinceEpoch = Undated(typemin(Int)),
nobs::Integer = 0,
)
aat = MyAxisArrayTable(datafile)
Ta = typeof(row_labels(aat)[1])
if Ta <: Dates.UTInstant
T = Ta.parameters[1]
elseif Ta == Int
T = Int
else
error("Unrecognized type")
end
Ts = typeof(start)
Tl = typeof(last)
ny = length(variables)
if typeof(start) != Int || start > typemin(Int)
# start option is used
@assert Ts == Ta "error in data!(): start must have the same frequency as the datafile"
startperiod = start
if typeof(last) != Int || last > typemin(Int)
# start option and last option are used
@assert Tl == Ta "error in data!(): last must have the same frequency as the datafile"
lastperiod = last
elseif nobs > 0
lastperiod = startperiod + T(nobs) - T(1)
else
lastperiod = row_labels(aat)[end]
end
elseif typeof(last) != Int || last > typemin(Int)
# start option isn't used but last option is used
@assert Tl == Ta "error in data!(): last must have the same frequency as the datafile"
lastperiod = last
if nobs > 0
startperiod = last - T(nobs) + T(1)
else
startperiod = row_labels(aat)[1]
end
else
# neither start option nor last option are used
startperiod = row_labels(aat)[1]
if nobs > 0
lastperiod = start + T(nobs) - T(1)
else
lastperiod = row_labels(aat)[end]
end
end
if length(variables) == 0
context.work.data = copy(aat[startperiod:lastperiod, :])
else
context.work.data = copy(aat[startperiod:lastperiod, Symbol.(variables)])
end
return context.work.data
end

function get_data!(context::Context,
datafile::String,
data::AxisArrayTable,
variables::Vector{<:Union{String, Symbol}},
first_obs::PeriodsSinceEpoch,
last_obs::PeriodsSinceEpoch,
nobs::Int
)
@assert isempty(datafile) || isempty(data) "datafile and data can't be used at the same time"

if !isempty(datafile)
data!(datafile,
context = context,
variables = variables,
start = first_obs,
last = last_obs,
nobs = nobs)
elseif !isempty(data)
first_obs == Undated(typemin(Int)) && (first_obs = row_labels(data)[1])
last_obs == Undated(typemin(Int)) && (last_obs = row_labels(data)[end])
context.work.data = copy(data[first_obs:last_obs, Symbol.(variables)])
else
error("needs datafile or data argument")
end
return context.work.data
end

function get_detrended_data(context::Context,
datafile::String,
data::AxisArrayTable,
variables::Vector{<:Union{String, Symbol}},
first_obs::PeriodsSinceEpoch,
last_obs::PeriodsSinceEpoch,
nobs::Int
)

endogenous_names = get_endogenous(context.symboltable)
trends = context.results.model_results[1].trends
steady_state = Vector(AxisArrays.AxisArray(trends.endogenous_steady_state, endogenous_names)[variables])
linear_trend = Vector(AxisArrays.AxisArray(trends.endogenous_linear_trend, endogenous_names)[variables])
quadratic_trend = Vector(AxisArrays.AxisArray(trends.endogenous_quadratic_trend, endogenous_names)[variables])

get_data!(context,
datafile,
data,
variables,
first_obs,
last_obs,
nobs
)
aat = copy(context.work.data)
if context.modfileinfo.has_trends
if !isempty(quadratic_trend)
remove_quadratic_trend!(aat,
adjoint(steady_state),
adjoint(linear_trend),
adjoint(quadratic_trend)
)
else
remove_linear_trend!(aat,
adjoint(steady_state),
adjoint(linear_trend),
)
end
else
aat .-= adjoint(steady_state)
end
return aat
end

function find_letter_in_period(period::AbstractString)
for c in period
if 'A' <= c <= 'z'
Expand All @@ -10,10 +139,16 @@ function find_letter_in_period(period::AbstractString)
return nothing
end

function identify_period_type(period::Union{AbstractString, Number})
if typeof(period) <: Number
function identify_period_type(period::Union{AbstractString, Number, Date})
if typeof(period) <: Date
return Date
elseif typeof(period) <: Number
if isinteger(period)
return Undated
if period == 1
return Undated
else
return YearSE
end
else
throw(ErrorException)
end
Expand Down Expand Up @@ -56,10 +191,14 @@ function identify_period_type(period::Union{AbstractString, Number})
end
end

function periodparse(period::Union{AbstractString, Number})::ExtendedDates.PeriodsSinceEpoch
function periodparse(period::Union{AbstractString, Number, Date})::ExtendedDates.PeriodsSinceEpoch
period_type = identify_period_type(period)
if period_type == YearSE
return parse(period_type, period)
if typeof(period) <: Number
return(YearSE(Int(period)))
else
return parse(period_type, period)
end
elseif period_type == SemesterSE
return parse(period_type, period)
elseif period_type == QuarterSE
Expand All @@ -70,8 +209,10 @@ function periodparse(period::Union{AbstractString, Number})::ExtendedDates.Perio
return parse(period_type, period)
elseif period_type == DaySE
return parse(period_type, period)
elseif period_type == Date
return DaySE(period)
elseif period_type == Undated
return Int(period)
return Undated(period)
else
throw(ErrorException)
end
Expand All @@ -80,13 +221,13 @@ end
function MyAxisArrayTable(filename)
table = CSV.File(filename)
cols = AxisArrayTables.Tables.columnnames(table)
data = AxisArrayTables.Tables.matrix((;ntuple(i -> Symbol(i) => Tables.getcolumn(table, i), length(cols))...))
data = (AxisArrayTables.Tables.matrix((;ntuple(i -> Symbol(i) => Tables.getcolumn(table, i), length(cols))...)))
for (icol, name) in enumerate(cols)
if uppercase(String(name)) in ["DATE", "DATES", "PERIOD", "PERIODS", "TIME", "COLUMN1"]
rows = []
foreach(x -> push!(rows, periodparse(x)), data[:, icol])
k = union(1:icol-1, icol+1:size(data,2))
aat = AxisArrayTable(data[:, k], rows, cols[k])
aat = AxisArrayTable(Matrix{Union{Float64, Missing}}(data[:, k]), rows, cols[k])
return aat
end
end
Expand All @@ -105,22 +246,3 @@ function is_continuous(periods::Vector{ExtendedDates.DatePeriod})
end
end
end

function get_data(
filename::String,
variables::Vector{String};
start::Int64 = 1,
last::Int64 = 0,
)
aat = MyAxisArrayTable(filename)
ny = length(variables)
if last == 0
last = size(aat, 1) - start + 1
end
nobs = last - start + 1
Y = Matrix{Union{Missing,Float64}}(undef, ny, nobs)
for (i, v) in enumerate(variables)
Y[i, :] .= Matrix(aat[:, Symbol(v)])[start:last]
end
return Y
end
Loading

0 comments on commit 9e4d8b1

Please sign in to comment.