Skip to content

Commit

Permalink
Export open_mfdataset (#470)
Browse files Browse the repository at this point in the history
* upgrade and export open_mfdataset

* use newest YAXArrayBase

* Fix unrelated test

* add dependabot

* Apply suggestions from code review

* fix locally

---------

Co-authored-by: Lazaro Alonso <[email protected]>
  • Loading branch information
meggart and lazarusA authored Nov 26, 2024
1 parent f4253c2 commit d7fd921
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 39 deletions.
7 changes: 7 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ Statistics = "1"
StatsBase = "0.32, 0.33, 0.34"
Tables = "0.2, 1.0"
WeightedOnlineStats = "0.3, 0.4, 0.5, 0.6"
YAXArrayBase = "0.6, 0.7"
YAXArrayBase = "0.7.5"
julia = "1.9"
135 changes: 97 additions & 38 deletions src/DatasetAPI/Datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using DiskArrays: DiskArrays, GridChunks
using Glob: glob
using DimensionalData: DimensionalData as DD

export Dataset, Cube, open_dataset, to_dataset, savecube, savedataset
export Dataset, Cube, open_dataset, to_dataset, savecube, savedataset, open_mfdataset

"""
Dataset object which stores an `OrderedDict` of YAXArrays with Symbol keys.
Expand Down Expand Up @@ -253,7 +253,7 @@ function collectdims(g)
varnames = get_varnames(g)
foreach(varnames) do k
d = get_var_dims(g, k)
v = get_var_handle(g, k)
v = get_var_handle(g, k, persist=false)
for (len, dname) in zip(size(v), d)
if !occursin("bnd", dname) && !occursin("bounds", dname)
datts = if dname in varnames
Expand All @@ -277,7 +277,7 @@ function toaxis(dimname, g, offs, len)
if !haskey(g, dimname)
return DD.rebuild(DD.name2dim(axname), 1:len)
end
ar = get_var_handle(g, dimname)
ar = get_var_handle(g, dimname, persist=false)
aratts = get_var_attrs(g, dimname)
if match(r"^(days)|(hours)|(seconds)|(months) since",lowercase(get(aratts,"units",""))) !== nothing
tsteps = try
Expand Down Expand Up @@ -337,6 +337,63 @@ open_mfdataset(g::AbstractString; kwargs...) = open_mfdataset(_glob(g); kwargs..
open_mfdataset(g::Vector{<:AbstractString}; kwargs...) =
merge_datasets(map(i -> open_dataset(i; kwargs...), g))

function merge_new_axis(alldatasets, firstcube,var,mergedim)
newdim = DD.rebuild(mergedim,1:length(alldatasets))
alldiskarrays = map(ds->ds.cubes[var].data,alldatasets).data
newda = diskstack(alldiskarrays)
newdims = (DD.dim(firstcube)...,newdim)
YAXArray(newdims,newda,deepcopy(firstcube.properties))
end
function merge_existing_axis(alldatasets,firstcube,var,mergedim)
allaxvals = map(ds->DD.dims(ds.cubes[var],mergedim).val,alldatasets)
newaxvals = reduce(vcat,allaxvals)
newdim = DD.rebuild(mergedim,newaxvals)
alldiskarrays = map(ds->ds.cubes[var].data,alldatasets)
istack = DD.dimnum(firstcube,mergedim)
newshape = ntuple(i->i!=istack ? 1 : length(alldiskarrays),ndims(firstcube))
newda = DiskArrays.ConcatDiskArray(reshape(alldiskarrays,newshape))
newdims = Base.setindex(firstcube.axes,newdim,istack)
YAXArray(newdims,newda,deepcopy(firstcube.properties))
end

"""
open_mfdataset(files::DD.DimVector{<:AbstractString}; kwargs...)
Opens and concatenates a list of dataset paths along the dimension specified in `files`.
This method can be used when the generic glob-based version of open_mfdataset fails
or is too slow.
For example, to concatenate a list of annual NetCDF files along the `Ti` dimension,
one can use:
````julia
files = ["1990.nc","1991.nc","1992.nc"]
open_mfdataset(DD.DimArray(files,DD.Ti()))
````
alternatively, if the dimension to concatenate along does not exist yet, the
dimension provided in the input arg is used:
````julia
files = ["a.nc","b.nc","c.nc"]
open_mfdataset(DD.DimArray(files,DD.Dim{:NewDim}(["a","b","c"])))
````
"""
function open_mfdataset(vec::DD.DimVector{<:AbstractString};kwargs...)
alldatasets = open_dataset.(vec;kwargs...);
fi = first(alldatasets)
mergedim = DD.dims(alldatasets) |> only
ars = map(collect(keys(fi.cubes))) do var
cfi = fi.cubes[var]
mergedar = if DD.dims(cfi,mergedim) !== nothing
merge_existing_axis(alldatasets,cfi,var,mergedim)
else
merge_new_axis(alldatasets,cfi,var,mergedim)
end
var => mergedar
end
Dataset(;ars...)
end


"""
open_dataset(g; driver=:all)
Expand All @@ -345,44 +402,46 @@ Open the dataset at `g` with the given `driver`.
The default driver will search for available drivers and tries to detect the useable driver from the filename extension.
"""
function open_dataset(g; driver = :all)
g = YAXArrayBase.to_dataset(g, driver = driver)
isempty(get_varnames(g)) && throw(ArgumentError("Group does not contain datasets."))
dimlist = collectdims(g)
dnames = string.(keys(dimlist))
varlist = filter(get_varnames(g)) do vn
upname = uppercase(vn)
!occursin("BNDS", upname) &&
!occursin("BOUNDS", upname) &&
!any(i -> isequal(upname, uppercase(i)), dnames)
end
allcubes = OrderedDict{Symbol,YAXArray}()
for vname in varlist
vardims = get_var_dims(g, vname)
iax = tuple(collect(dimlist[vd].ax for vd in vardims)...)
offs = [dimlist[vd].offs for vd in vardims]
subs = if all(iszero, offs)
nothing
else
ntuple(i -> (offs[i]+1):(offs[i]+length(iax[i])), length(offs))
end
ar = get_var_handle(g, vname)
att = get_var_attrs(g, vname)
if subs !== nothing
ar = view(ar, subs...)
dsopen = YAXArrayBase.to_dataset(g, driver = driver)
YAXArrayBase.open_dataset_handle(dsopen) do g
isempty(get_varnames(g)) && throw(ArgumentError("Group does not contain datasets."))
dimlist = collectdims(g)
dnames = string.(keys(dimlist))
varlist = filter(get_varnames(g)) do vn
upname = uppercase(vn)
!occursin("BNDS", upname) &&
!occursin("BOUNDS", upname) &&
!any(i -> isequal(upname, uppercase(i)), dnames)
end
if !haskey(att, "name")
att["name"] = vname
end
atts = propfromattr(att)
if any(in(keys(atts)), ["missing_value", "scale_factor", "add_offset"])
ar = CFDiskArray(ar, atts)
allcubes = OrderedDict{Symbol,YAXArray}()
for vname in varlist
vardims = get_var_dims(g, vname)
iax = tuple(collect(dimlist[vd].ax for vd in vardims)...)
offs = [dimlist[vd].offs for vd in vardims]
subs = if all(iszero, offs)
nothing
else
ntuple(i -> (offs[i]+1):(offs[i]+length(iax[i])), length(offs))
end
ar = get_var_handle(g, vname,persist=true)
att = get_var_attrs(g, vname)
if subs !== nothing
ar = view(ar, subs...)
end
if !haskey(att, "name")
att["name"] = vname
end
atts = propfromattr(att)
if any(in(keys(atts)), ["missing_value", "scale_factor", "add_offset"])
ar = CFDiskArray(ar, atts)
end
allcubes[Symbol(vname)] = YAXArray(iax, ar, atts, cleaner = CleanMe[])
end
allcubes[Symbol(vname)] = YAXArray(iax, ar, atts, cleaner = CleanMe[])
gatts = YAXArrayBase.get_global_attrs(g)
gatts = Dict{String,Any}(string(k)=>v for (k,v) in gatts)
sdimlist = Dict(DD.name(v.ax) => v.ax for (k, v) in dimlist)
Dataset(allcubes, sdimlist,gatts)
end
gatts = YAXArrayBase.get_global_attrs(g)
gatts = Dict{String,Any}(string(k)=>v for (k,v) in gatts)
sdimlist = Dict(DD.name(v.ax) => v.ax for (k, v) in dimlist)
Dataset(allcubes, sdimlist,gatts)
end
#Base.getindex(x::Dataset; kwargs...) = subsetcube(x; kwargs...)
YAXDataset(; kwargs...) = Dataset(YAXArrays.YAXDefaults.cubedir[]; kwargs...)
Expand Down
1 change: 1 addition & 0 deletions test/DAT/mapcube.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
end

@testset "Error shown in parallel" begin
import Zarr
x,y,z = X(1:4), Y(1:5), Z(1:6)
a1 = YAXArray((x,y,z), rand(4,5,6))
indims = InDims("x")
Expand Down

0 comments on commit d7fd921

Please sign in to comment.