From d51abc7f7bb3960d00b5a9dc4a336866d75bcc93 Mon Sep 17 00:00:00 2001 From: LenkaNovak Date: Fri, 1 Sep 2023 16:47:13 -0700 Subject: [PATCH] VF getindex func + test move func to indices.jl prior format make output a FieldVector rev + type stable Co-authored-by: @charleskawczynski add name test rm parent idx in test convert to Array fix fix cuda test with @allowscalar --- src/Fields/indices.jl | 8 +++++++- test/Fields/field.jl | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/Fields/indices.jl b/src/Fields/indices.jl index 86d15537c9..92bfcc2b5f 100644 --- a/src/Fields/indices.jl +++ b/src/Fields/indices.jl @@ -17,7 +17,13 @@ end Base.@propagate_inbounds Base.getindex(field::Field, colidx::ColumnIndex) = column(field, colidx) - +Base.@propagate_inbounds function Base.getindex( + fv::FieldVector{T}, + colidx::ColumnIndex, +) where {T} + values = map(x -> x[colidx], _values(fv)) + return FieldVector{T, typeof(values)}(values) +end Base.@propagate_inbounds function column( field::SpectralElementField1D, colidx::ColumnIndex{1}, diff --git a/test/Fields/field.jl b/test/Fields/field.jl index 760cfef3bc..575f3048c2 100644 --- a/test/Fields/field.jl +++ b/test/Fields/field.jl @@ -1,4 +1,6 @@ using Test +using JET + using ClimaComms using OrderedCollections using StaticArrays, IntervalSets @@ -12,6 +14,7 @@ using LinearAlgebra: norm using Statistics: mean using ForwardDiff using CUDA +using CUDA: @allowscalar include( joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"), @@ -267,6 +270,37 @@ end @test Y.k.z === 3.0 end +function call_getcolumn(fv, colidx) + @allowscalar fvcol = fv[colidx] + nothing +end +function call_getproperty(fv) + fva = fv.c.a + nothing +end +@testset "FieldVector getindex" begin + cspace = TU.CenterExtrudedFiniteDifferenceSpace(Float32) + fspace = Spaces.ExtrudedFiniteDifferenceSpace{Spaces.CellFace}(cspace) + c = fill((a = Float32(1), b = Float32(2)), cspace) + f = fill((x = Float32(1), y = Float32(2)), fspace) + fv = Fields.FieldVector(; c, f) + colidx = Fields.ColumnIndex((1, 1), 1) # arbitrary index + + @allowscalar @test all(parent(fv.c.a[colidx]) .== Float32(1)) + @allowscalar @test all(parent(fv.f.y[colidx]) .== Float32(2)) + @allowscalar @test propertynames(fv) == propertynames(fv[colidx]) + + # JET tests + # prerequisite + call_getproperty(fv) # compile first + @test_opt call_getproperty(fv) + + call_getcolumn(fv, colidx) # compile first + @test_opt call_getcolumn(fv, colidx) + p = @allocated call_getcolumn(fv, colidx) + @test p ≤ 32 +end + @testset "FieldVector array_type" begin device = ClimaComms.device() context = ClimaComms.SingletonCommsContext(device)