Skip to content

Commit

Permalink
fixed the bug urgh
Browse files Browse the repository at this point in the history
  • Loading branch information
behinger committed Jan 15, 2025
1 parent dc10ff6 commit dce1262
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ function _residuals(::Type{T}, yhat, y) where {T<:UnfoldModel}#; ContinuousTimeT
n_y = size(y, 2)
@debug n_yhat n_y
if n_yhat >= n_y
@debug "n_yhat > n_y" size(y) size.(_split_data(yhat, n_y))
@debug "n_yhat > n_y, yhat is longer" size(y) size.(_split_data(yhat, n_y))
return y .- _split_data(yhat, n_y)[1]
else
@debug "n_y < n_yhat"
@debug "n_yhat < n_y, y is longer"
yA, yB = _split_data(y, n_yhat)
@debug size(yA) size(yB)
res = y .- yA
return cat(res, .-yB; dims = 2)
res = yA .- yhat
return cat(res, yB; dims = 2)
end

end
Expand Down
13 changes: 13 additions & 0 deletions test/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,17 @@ pt = Unfold.result_to_table(m, p, repeat([evts], 2))
@test maximum(abs.(data_e .- (resids_e.+predict(m_mul)[1])[1, :, :])) < 0.0000001


##


@test all(Unfold._residuals(UnfoldModel, [1 2 3; 3 4 5], [1 2 3; 3 4 5]) .== 0)

# y longer
res = Unfold._residuals(UnfoldModel, [1 2 3; 3 4 5], [1 2 3 4; 3 4 5 6])
@test all(res[:, 1:3] .== 0)
@test res[:, 4] == [4, 6]

# yhat longer
@test all(Unfold._residuals(UnfoldModel, [1 2 3 4; 3 4 5 6], [1 2 3; 3 4 5]) .== 0)

end

0 comments on commit dce1262

Please sign in to comment.