# seqdata.py
import warnings
from pathlib import Path
from typing import (
Any,
Dict,
Hashable,
Iterable,
Literal,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
cast,
)
import numpy as np
import pandas as pd
import polars as pl
import xarray as xr
import zarr
from numcodecs import Blosc
from seqdata._io.bed_ops import (
_bed_to_zarr,
_expand_regions,
_set_uniform_length_around_center,
read_bedlike,
)
from seqdata.types import FlatReader, PathType, RegionReader
from .utils import _filter_by_exact_dims, _filter_layers, _filter_uns
def open_zarr(
store: PathType,
group: Optional[str] = None,
synchronizer=None,
chunks: Optional[
Union[Literal["auto"], int, Mapping[str, int], Tuple[int, ...]]
] = "auto",
decode_cf=True,
mask_and_scale=False,
decode_times=True,
concat_characters=False,
decode_coords=True,
drop_variables: Optional[Union[str, Iterable[str]]] = None,
consolidated: Optional[bool] = None,
overwrite_encoded_chunks=False,
chunk_store: Optional[Union[MutableMapping, PathType]] = None,
storage_options: Optional[Dict[str, str]] = None,
decode_timedelta: Optional[bool] = None,
use_cftime: Optional[bool] = None,
zarr_version: Optional[int] = None,
**kwargs,
):
"""Open a SeqData object from disk.
Parameters
----------
store : str, Path
Path to the SeqData object.
group : str, optional
Name of the group to open, by default None
synchronizer : None, optional
Synchronizer to use, by default None
chunks : {None, True, False, int, dict, tuple}, optional
Chunking scheme to use, by default "auto"
decode_cf : bool, optional
Whether to decode CF conventions, by default True
mask_and_scale : bool, optional
Whether to mask and scale data, by default False
decode_times : bool, optional
Whether to decode times, by default True
concat_characters : bool, optional
Whether to concatenate characters, by default False
decode_coords : bool, optional
Whether to decode coordinates, by default True
drop_variables : {None, str, iterable}, optional
Variables to drop, by default None
consolidated : bool, optional
Whether to consolidate metadata, by default None
overwrite_encoded_chunks : bool, optional
Whether to overwrite encoded chunks, by default False
chunk_store : {None, MutableMapping, str, Path}, optional
Chunk store to use, by default None
storage_options : dict, optional
Storage options to use, by default None
decode_timedelta : bool, optional
Whether to decode timedeltas, by default None
use_cftime : bool, optional
Whether to use cftime, by default None
zarr_version : int, optional
Zarr version to use, by default None
Returns
-------
xr.Dataset
SeqData object
"""
ds = xr.open_zarr(
store=store,
group=group,
synchronizer=synchronizer,
chunks=chunks, # type: ignore
decode_cf=decode_cf,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
consolidated=consolidated,
overwrite_encoded_chunks=overwrite_encoded_chunks,
chunk_store=chunk_store,
storage_options=storage_options,
decode_timedelta=decode_timedelta,
use_cftime=use_cftime,
zarr_version=zarr_version,
**kwargs,
)
return ds
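# Example (illustrative sketch): lazily open an existing SeqData Zarr store. The
# store name below is a placeholder; any store written by `to_zarr`,
# `from_flat_files`, or `from_region_files` can be opened the same way.
#
#     sdata = open_zarr("my_seqdata.zarr")
#     print(sdata.sizes, list(sdata.data_vars))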
def to_zarr(
sdata: Union[xr.DataArray, xr.Dataset],
store: PathType,
chunk_store: Optional[Union[MutableMapping, PathType]] = None,
mode: Optional[Literal["w", "w-", "a", "r+"]] = None,
synchronizer: Optional[Any] = None,
group: Optional[str] = None,
encoding: Optional[Dict] = None,
compute=True,
consolidated: Optional[bool] = None,
append_dim: Optional[Hashable] = None,
region: Optional[Dict] = None,
safe_chunks=True,
storage_options: Optional[Dict] = None,
zarr_version: Optional[int] = None,
):
"""Write a xarray object to disk as a Zarr store.
Makes use of the `to_zarr` method of xarray objects, but modifies
the encoding for cases where the chunking is not uniform.
Parameters
----------
sdata : xr.DataArray, xr.Dataset
SeqData object to write to disk.
store : str, Path
Path to the SeqData object.
chunk_store : {None, MutableMapping, str, Path}, optional
Chunk store to use, by default None
mode : {None, "w", "w-", "a", "r+"}, optional
Mode to use, by default None
synchronizer : None, optional
Synchronizer to use, by default None
group : str, optional
Name of the group to open, by default None
encoding : dict, optional
Encoding to use, by default None
compute : bool, optional
Whether to compute, by default True
consolidated : bool, optional
Whether to consolidate metadata, by default None
append_dim : {None, str}, optional
Name of the append dimension, by default None
region : dict, optional
Region to use, by default None
safe_chunks : bool, optional
Whether to use safe chunks, by default True
storage_options : dict, optional
Storage options to use, by default None
zarr_version : int, optional
Zarr version to use, by default None
Returns
-------
None
"""
sdata = sdata.drop_encoding()
if isinstance(sdata, xr.Dataset):
for coord in sdata.coords.values():
if "_FillValue" in coord.attrs:
del coord.attrs["_FillValue"]
for arr in sdata.data_vars:
sdata[arr] = _uniform_chunking(sdata[arr])
else:
sdata = _uniform_chunking(sdata)
sdata.to_zarr(
store=store,
chunk_store=chunk_store,
mode=mode,
synchronizer=synchronizer,
group=group,
encoding=encoding,
compute=compute, # type: ignore
consolidated=consolidated,
append_dim=append_dim,
region=region,
safe_chunks=safe_chunks,
storage_options=storage_options,
zarr_version=zarr_version,
)
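# Example (illustrative sketch): round-trip a small in-memory dataset through a
# placeholder store path.
#
#     import numpy as np
#     import xarray as xr
#     ds = xr.Dataset({"cov": (("_sequence", "_length"), np.random.rand(8, 100))})
#     to_zarr(ds, "example.zarr", mode="w")
#     reopened = open_zarr("example.zarr")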
def _uniform_chunking(arr: xr.DataArray):
# rechunk if write requirements are broken. namely:
# - all chunks except the last are the same size
# - the final chunk is <= the size of the rest
# Use chunk size that is:
# 1. most frequent
# 2. to break ties, largest
if arr.chunksizes is not None:
new_chunks = {}
for dim, chunk in arr.chunksizes.items():
# pick the most frequent chunk size along this dimension,
# breaking ties by taking the largest
chunks, counts = np.unique(chunk, return_counts=True)
chunk_size = int(chunks[counts == counts.max()].max())
new_chunks[dim] = chunk_size
if new_chunks != arr.chunksizes:
arr = arr.chunk(new_chunks)
if "_FillValue" in arr.attrs:
del arr.attrs["_FillValue"]
return arr
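# Example (illustrative sketch): a chunking of (100, 100, 37) along a dimension
# already satisfies Zarr's write rules (uniform chunks with a possibly smaller
# final chunk), so the most frequent size, 100, is kept and the layout is unchanged.
#
#     arr = xr.DataArray(np.zeros(237), dims="_length").chunk(100)
#     arr = _uniform_chunking(arr)  # chunks remain (100, 100, 37)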
def from_flat_files(
*readers: FlatReader,
path: PathType,
fixed_length: bool,
sequence_dim: Optional[str] = None,
length_dim: Optional[str] = None,
overwrite=False,
) -> xr.Dataset:
"""Composable function to create a SeqData object from flat files.
Saves a SeqData to disk and opens it (without loading it into memory).
TODO: Tutorials coming soon.
Parameters
----------
*readers : FlatReader
Readers to use to create the SeqData object.
path : str, Path
Path to save this SeqData to.
fixed_length : bool
`True`: assume all sequences have the same length and infer it from the first
sequence.
`False`: write variable-length sequences.
sequence_dim : str, optional
Name of sequence dimension. Defaults to "_sequence".
length_dim : str, optional
Name of length dimension. Defaults to "_length". Ignored if `fixed_length` is False.
overwrite : bool, optional
Whether to overwrite existing arrays of the SeqData at `path`, by default False
Returns
-------
xr.Dataset
"""
sequence_dim = "_sequence" if sequence_dim is None else sequence_dim
if not fixed_length and length_dim is not None:
warnings.warn("Treating sequences as variable length, ignoring `length_dim`.")
elif fixed_length:
length_dim = "_length" if length_dim is None else length_dim
for reader in readers:
reader._write(
out=path,
fixed_length=fixed_length,
overwrite=overwrite,
sequence_dim=sequence_dim,
length_dim=length_dim,
)
zarr.consolidate_metadata(path) # type: ignore
ds = open_zarr(path)
return ds
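# Example (illustrative sketch): build and open a SeqData store from flat files.
# `SomeFlatReader` is a hypothetical stand-in for any concrete `FlatReader`
# implementation provided by seqdata; the file names are placeholders.
#
#     reader = SomeFlatReader(name="seqs", fasta="sequences.fa")  # hypothetical
#     sdata = from_flat_files(reader, path="seqs.zarr", fixed_length=True)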
def from_region_files(
*readers: RegionReader,
path: PathType,
fixed_length: Union[int, bool],
bed: Union[PathType, pl.DataFrame, pd.DataFrame],
max_jitter=0,
sequence_dim: Optional[str] = None,
length_dim: Optional[str] = None,
splice=False,
overwrite=False,
) -> xr.Dataset:
"""Composable function to create a SeqData object from region based files.
Saves a SeqData to disk and open it (without loading it into memory).
TODO: Tutorials coming soon.
Parameters
----------
*readers : RegionReader
Readers to use to create the SeqData object.
path : str, Path
Path to save this SeqData to.
fixed_length : int, bool
`int`: use regions of this length centered around those in the BED file.
`True`: assume all regions have the same length and infer it from the first
region in the BED file.
`False`: write variable-length sequences.
bed : str, Path, pd.DataFrame, pl.DataFrame
BED file or DataFrame matching the BED3+ specification describing what regions
to write.
max_jitter : int, optional
How much jitter to allow for the SeqData object by writing additional
flanking sequences, by default 0
sequence_dim : str, optional
Name of sequence dimension. Defaults to "_sequence".
length_dim : str, optional
Name of length dimension. Defaults to "_length".
splice : bool, optional
Whether to splice together regions that have the same `name` in the BED file, by
default False
overwrite : bool, optional
Whether to overwrite existing arrays of the SeqData at `path`, by default False
Returns
-------
xr.Dataset
"""
sequence_dim = "_sequence" if sequence_dim is None else sequence_dim
if not fixed_length and length_dim is not None:
warnings.warn("Treating sequences as variable length, ignoring `length_dim`.")
elif fixed_length:
length_dim = "_length" if length_dim is None else length_dim
root = zarr.open_group(path)
root.attrs["max_jitter"] = max_jitter
root.attrs["sequence_dim"] = sequence_dim
root.attrs["length_dim"] = length_dim
if isinstance(bed, (str, Path)):
_bed = read_bedlike(bed)
elif isinstance(bed, pd.DataFrame):
_bed = pl.from_pandas(bed)
else:
_bed = bed
if "strand" not in _bed:
_bed = _bed.with_columns(strand=pl.lit("+"))
if not splice:
if fixed_length is False:
_bed = _expand_regions(_bed, max_jitter)
else:
if fixed_length is True:
fixed_length = cast(
int,
_bed.item(0, "chromEnd") - _bed.item(0, "chromStart"),
)
fixed_length += 2 * max_jitter
_bed = _set_uniform_length_around_center(_bed, fixed_length)
_bed_to_zarr(
_bed,
root,
sequence_dim,
compressor=Blosc("zstd", clevel=7, shuffle=-1),
overwrite=overwrite,
)
else:
if max_jitter > 0:
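# When splicing, only the outermost coordinates of each `name` group are
# extended: the group's minimum chromStart is shifted left and its maximum
# chromEnd is shifted right by `max_jitter`; interior regions are untouched.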
_bed = _bed.with_columns(
pl.when(pl.col("chromStart") == pl.col("chromStart").min().over("name"))
.then(pl.col("chromStart").min().over("name") - max_jitter)
.otherwise(pl.col("chromStart"))
.alias("chromStart"),
pl.when(pl.col("chromEnd") == pl.col("chromEnd").max().over("name"))
.then(pl.col("chromEnd").max().over("name") + max_jitter)
.otherwise(pl.col("chromEnd"))
.alias("chromEnd"),
)
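# Collapse each `name` group into one row: keep the first value of every string
# column and aggregate all remaining columns into lists.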
bed_to_write = _bed.group_by("name").agg(
pl.col(pl.Utf8).first(), pl.exclude(pl.Utf8)
)
_bed_to_zarr(
bed_to_write,
root,
sequence_dim,
compressor=Blosc("zstd", clevel=7, shuffle=-1),
overwrite=overwrite,
)
for reader in readers:
reader._write(
out=path,
bed=_bed,
fixed_length=fixed_length,
sequence_dim=sequence_dim,
length_dim=length_dim,
overwrite=overwrite,
splice=splice,
)
zarr.consolidate_metadata(path) # type: ignore
ds = open_zarr(path)
return ds
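# Example (illustrative sketch): build and open a SeqData store from a BED file
# of regions. `SomeRegionReader` is a hypothetical stand-in for any concrete
# `RegionReader` implementation provided by seqdata; the file names are placeholders.
#
#     reader = SomeRegionReader(name="seqs", fasta="hg38.fa")  # hypothetical
#     sdata = from_region_files(
#         reader, path="peaks.zarr", fixed_length=1000, bed="peaks.bed", max_jitter=128
#     )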
@xr.register_dataset_accessor("sd")
class SeqDataAccessor:
def __init__(self, ds: xr.Dataset) -> None:
self._ds = ds
@property
def obs(self):
return _filter_by_exact_dims(self._ds, self._ds.attrs["sequence_dim"])
@property
def layers(self):
return _filter_layers(self._ds)
@property
def obsp(self):
return _filter_by_exact_dims(
self._ds, (self._ds.attrs["sequence_dim"], self._ds.attrs["sequence_dim"])
)
@property
def uns(self):
return _filter_uns(self._ds)
def __repr__(self) -> str:
return "SeqData accessor."
def merge_obs(
sdata: xr.Dataset,
obs: Union[xr.Dataset, pl.DataFrame],
on: Optional[str] = None,
left_on: Optional[str] = None,
right_on: Optional[str] = None,
how: Literal["inner", "left", "right", "outer", "exact"] = "inner",
):
"""Warning: This function is experimental and may change in the future.
Merge observations into a SeqData object along the sequence axis.
Parameters
----------
sdata : xr.Dataset
SeqData object.
obs : xr.Dataset, pl.DataFrame
Observations to merge.
on : str, optional
Column to merge on, by default None
left_on : str, optional
Column to merge on from the left dataset, by default None
right_on : str, optional
Column to merge on from the right dataset, by default None
how : {"inner", "left", "right", "outer", "exact"}, optional
Type of merge to perform, by default "inner"
Returns
-------
xr.Dataset
Merged SeqData object.
"""
if on is None and (left_on is None or right_on is None):
raise ValueError("Must provide either `on` or both `left_on` and `right_on`.")
if on is not None and (left_on is not None or right_on is not None):
raise ValueError("Cannot provide both `on` and `left_on` or `right_on`.")
if on is None:
assert left_on is not None
assert right_on is not None
else:
left_on = on
right_on = on
if left_on not in sdata.data_vars:
sdata = sdata.assign({left_on: np.arange(sdata.sizes[left_on])})
if left_on not in sdata.xindexes:
sdata = sdata.set_coords(left_on).set_xindex(left_on)
if isinstance(obs, pl.DataFrame):
obs_ = obs.to_pandas()
if obs_.index.name != right_on:
obs_ = obs_.set_index(right_on)
obs_.index.name = left_on
obs_ = obs_.to_xarray()
sdata_dim = sdata[left_on].dims[0]
obs_dim = obs_[left_on].dims[0]
if sdata_dim != obs_dim:
obs_ = obs_.rename({obs_dim: sdata_dim})
sdata = sdata.merge(obs_, join=how)
elif isinstance(obs, xr.Dataset):
if right_on not in obs.data_vars:
obs = obs.assign({right_on: np.arange(obs.sizes[right_on])})
if right_on not in obs.xindexes:
obs = (
obs.rename({right_on: left_on}).set_coords(left_on).set_xindex(left_on)
)
sdata = sdata.merge(obs, join=how)
return sdata
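# Example (illustrative sketch): attach per-sequence metadata keyed by a shared
# "name" variable, assuming `sdata` already has a per-sequence "name" data
# variable; column names and values are placeholders.
#
#     meta = pl.DataFrame({"name": ["seq1", "seq2"], "label": [0, 1]})
#     sdata = merge_obs(sdata, meta, on="name", how="left")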
def add_layers_from_files(
sdata: xr.Dataset,
*readers: Union[FlatReader, RegionReader],
path: PathType,
overwrite=False,
):
raise NotImplementedError
# if any(map(lambda r: isinstance(r, RegionReader), readers)):
# bed = sdata[["chrom", "chromStart", "chromEnd", "strand"]].to_dataframe()
# for reader in readers:
# if isinstance(reader, FlatReader):
# if reader.n_seqs is not None and reader.n_seqs != sdata.sizes["_sequence"]:
# raise ValueError(
# f"""Reader "{reader.name}" has a different number of sequences
# than this SeqData."""
# )
# _fixed_length = fixed_length is not False
# reader._write(out=path, fixed_length=_fixed_length, overwrite=overwrite)
# elif isinstance(reader, RegionReader):
# reader._write(
# out=path,
# bed=bed, # type: ignore
# overwrite=overwrite,
# )
# ds = xr.open_zarr(path, mask_and_scale=False, concat_characters=False)
# return ds