diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index f342ec89a7..f7c9b8a677 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -161,6 +161,10 @@ AxisTensor(axes::Tuple{Vararg{AbstractAxis}}, components) = components::AbstractArray{<:Any, N}, ) where {N, A} = AxisTensor(A.instance, components) +# conversion of components +AxisTensor{T, N, A, S}(a::AxisTensor{<:Any, N, A, <:Any}) where {T, N, A, S} = + AxisTensor(axes(a), S(components(a))) +Base.convert(::Type{T}, a::AxisTensor) where {T <: AxisTensor} = T(a) Base.axes(a::AxisTensor) = getfield(a, :axes) Base.axes(::Type{AxisTensor{T, N, A, S}}) where {T, N, A, S} = A.instance diff --git a/test/Geometry/axistensors.jl b/test/Geometry/axistensors.jl index f935404de8..1d393b0547 100644 --- a/test/Geometry/axistensors.jl +++ b/test/Geometry/axistensors.jl @@ -15,6 +15,10 @@ ClimaCore.Geometry.assert_exact_transform() = true f(x) = x.u₁ + x.u₂ + x.u₃ @test_opt f(x) + ref = Ref(zero(x)) + ref[] = Geometry.Covariant12Vector(1, 2) # Int components instead of Float64 + @test ref[] == x + M = Geometry.Axis2Tensor( (Geometry.Cartesian12Axis(), Geometry.Covariant12Axis()), [1.0 0.0; 0.5 2.0],