Skip to content

Commit

Permalink
fix bug with periods
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelJuillard committed Nov 10, 2023
1 parent ab175df commit 1ee0d02
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
16 changes: 10 additions & 6 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function data!(datafile::AbstractString;
startperiod = start
if typeof(last) != Int || last > typemin(Int)
# start option and last option are used
@assert Ts == Ta "error in data!(): last must have the same frequency as the datafile"
@assert Tl == Ta "error in data!(): last must have the same frequency as the datafile"
lastperiod = last
end
elseif typeof(last) != Int || last > typemin(Int)
Expand Down Expand Up @@ -136,8 +136,10 @@ function find_letter_in_period(period::AbstractString)
return nothing
end

function identify_period_type(period::Union{AbstractString, Number})
if typeof(period) <: Number
function identify_period_type(period::Union{AbstractString, Number, Date})
if typeof(period) <: Date
return Date
elseif typeof(period) <: Number
if isinteger(period)
if period == 1
return Undated
Expand Down Expand Up @@ -186,7 +188,7 @@ function identify_period_type(period::Union{AbstractString, Number})
end
end

function periodparse(period::Union{AbstractString, Number})::ExtendedDates.PeriodsSinceEpoch
function periodparse(period::Union{AbstractString, Number, Date})::ExtendedDates.PeriodsSinceEpoch
period_type = identify_period_type(period)
if period_type == YearSE
if typeof(period) <: Number
Expand All @@ -204,6 +206,8 @@ function periodparse(period::Union{AbstractString, Number})::ExtendedDates.Perio
return parse(period_type, period)
elseif period_type == DaySE
return parse(period_type, period)
elseif period_type == Date
return DaySE(period)
elseif period_type == Undated
return Undated(period)
else
Expand All @@ -214,13 +218,13 @@ end
function MyAxisArrayTable(filename)
table = CSV.File(filename)
cols = AxisArrayTables.Tables.columnnames(table)
data = AxisArrayTables.Tables.matrix((;ntuple(i -> Symbol(i) => Tables.getcolumn(table, i), length(cols))...))
data = (AxisArrayTables.Tables.matrix((;ntuple(i -> Symbol(i) => Tables.getcolumn(table, i), length(cols))...)))
for (icol, name) in enumerate(cols)
if uppercase(String(name)) in ["DATE", "DATES", "PERIOD", "PERIODS", "TIME", "COLUMN1"]
rows = []
foreach(x -> push!(rows, periodparse(x)), data[:, icol])
k = union(1:icol-1, icol+1:size(data,2))
aat = AxisArrayTable(data[:, k], rows, cols[k])
aat = AxisArrayTable(Matrix{Union{Float64, Missing}}(data[:, k]), rows, cols[k])
return aat
end
end
Expand Down
6 changes: 3 additions & 3 deletions src/filters/kalman/kalman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ struct CalibSmootherOptions
last_obs::Int64
function CalibSmootherOptions(options::Dict{String,Any})
datafile = ""
first_obs = 1
last_obs = 0
first_obs = Undated(typemin(Int))
last_obs = Undated(typemin(Int))
for (k, v) in pairs(options)
if k == "datafile"
datafile = v::String
Expand Down Expand Up @@ -128,7 +128,7 @@ function calibsmoother!(; context=context,
last = nobs
presample = 0
data_pattern = Vector{Vector{Int64}}(undef, 0)
Yt = adjoint(Y)
Yt = copy(adjoint(Y))
for i = 1:nobs
push!(data_pattern, findall(.!ismissing.(Yt[:, i])))
end
Expand Down
8 changes: 2 additions & 6 deletions src/forecast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,14 @@ function recursive_forecasting!(; periods::Integer,
data_ = get_data!(context, datafile, data, variables, first_obs, last_obs, nobs)
first_period == Undated(typemin(Int)) && (first_period = row_labels(data_)[1])
last_period == Undated(typemin(Int)) && (last_period = row_labels(data_)[end])
T = typeof(first_period)
T = typeof(first_period).parameters[1]
empty!(results.forecast)
for p = first_period:last_period
Y = forecasting_(context=context, periods=periods, forecast_mode=calibsmoother, first_obs=first_obs, last_obs=p, data = data_, order=order)
if p == first_period
results.initial_smoother = copy(results.smoother)
end
if T <: Dates.UTInstant
p1 = T(p.periods.value + periods)
else
p1 = p + periods
end
p1 = p + T(periods)
push!(results.forecast, AxisArrayTable(Y,
p:p1,
[Symbol(v) for v in get_endogenous(context.symboltable)]))
Expand Down

0 comments on commit 1ee0d02

Please sign in to comment.