Commit dc7cacd

new 3d export with memory optimizations

tasansal committed Oct 26, 2022
1 parent f315777 commit dc7cacd
Showing 4 changed files with 159 additions and 114 deletions.
src/mdio/converters/mdio.py (61 additions & 54 deletions)
@@ -3,20 +3,23 @@

 from __future__ import annotations

-import uuid
 from os import path

+import dask.array as da
 import numpy as np
 from dask.array.core import Array
 from dask.base import compute_as_if_collection
-from dask.core import flatten
 from dask.highlevelgraph import HighLevelGraph
 from tqdm.dask import TqdmCallback

 from mdio import MDIOReader
-from mdio.segy._workers import write_block_to_segy
+from mdio.segy._workers import chunk_to_sgy_stack
+from mdio.segy.byte_utils import ByteOrder
+from mdio.segy.byte_utils import Dtype
 from mdio.segy.creation import concat_files
 from mdio.segy.creation import mdio_spec_to_segy
+from mdio.segy.creation import preprocess_headers
+from mdio.segy.creation import preprocess_samples


 try:
@@ -95,7 +98,7 @@ def mdio_to_segy(  # noqa: C901
         ...     mdio_path_or_buffer="prefix2/file.mdio",
         ...     output_segy_path="prefix/file.segy",
         ...     selection_mask=boolean_mask,
-        ...     out_sample_format="ieee32",
+        ...     out_sample_format="float32",
         ... )
     """
@@ -137,8 +140,6 @@ def mdio_to_segy(  # noqa: C901
     else:
         mdio, sample_format = mdio_spec_to_segy(*creation_args)

-    num_samp = mdio.shape[-1]
-
     live_mask = mdio.live_mask.compute()

     if selection_mask is not None:
@@ -158,76 +159,77 @@ def mdio_to_segy(  # noqa: C901
         dim_slices += (slice(start, stop),)

     # Lazily pull the data with limits now, and limit mask so it's the same shape.
-    live_mask, headers, traces = mdio[dim_slices]
+    live_mask, headers, samples = mdio[dim_slices]
+    live_mask = live_mask.rechunk(headers.chunksize)

     if selection_mask is not None:
         selection_mask = selection_mask[dim_slices]
         live_mask = live_mask & selection_mask

-    # Now we flatten the data in the slowest changing axis (i.e. 0)
-    # TODO: Add support for flipping these, if user wants
-    axis = 0
-
-    # Get new chunksizes for sequential array
-    seq_trc_chunks = tuple(
-        (dim_chunks if idx == axis else (sum(dim_chunks),))
-        for idx, dim_chunks in enumerate(traces.chunks)
-    )
-
-    # We must unify chunks with "trc_chunks" here because
-    # headers and live mask may have different chunking.
-    # We don't take the time axis for headers / live
-    # Still lazy computation
-    traces_seq = traces.rechunk(seq_trc_chunks)
-    headers_seq = headers.rechunk(seq_trc_chunks[:-1])
-    live_seq = live_mask.rechunk(seq_trc_chunks[:-1])
-
-    # Build a Dask graph to do the computation
-    # Name of task. Using uuid1 is important because
-    # we could potentially generate these from different machines
-    task_name = "block-to-sgy-part-" + str(uuid.uuid1())
+    write_task_name = "write_sgy_block"

+    out_dtype = Dtype[out_sample_format.upper()]
+    out_byteorder = ByteOrder[endian.upper()]

+    samples_proc = da.blockwise(
+        preprocess_samples,
+        "ijk",
+        samples,
+        "ijk",
+        live_mask,
+        "ij",
+        out_dtype=out_dtype,
+        out_byteorder=out_byteorder,
+    )

-    trace_keys = flatten(traces_seq.__dask_keys__())
-    header_keys = flatten(headers_seq.__dask_keys__())
-    live_keys = flatten(live_seq.__dask_keys__())
+    headers_proc = da.blockwise(
+        preprocess_headers,
+        "ij",
+        headers,
+        "ij",
+        live_mask,
+        "ij",
+        out_byteorder=out_byteorder,
+    )

-    all_keys = zip(trace_keys, header_keys, live_keys)
+    sample_keys = samples_proc.__dask_keys__()
+    header_keys = headers_proc.__dask_keys__()
+    live_keys = live_mask.__dask_keys__()

     # tmp file root
     out_dir = path.dirname(output_segy_path)

     task_graph_dict = {}
-    block_file_paths = []
-    for idx, (trace_key, header_key, live_key) in enumerate(all_keys):
-        block_file_name = f".{idx}_{uuid.uuid1()}._segyblock"
-        block_file_path = path.join(out_dir, block_file_name)
-        block_file_paths.append(block_file_path)
-
-        block_args = (
-            block_file_path,
-            trace_key,
-            header_key,
-            live_key,
-            num_samp,
-            sample_format,
-            endian,
-        )
-
-        task_graph_dict[(task_name, idx)] = (write_block_to_segy,) + block_args
+    for row in range(live_mask.blocks.shape[0]):
+        for col in range(live_mask.blocks.shape[1]):
+            block_args = (
+                sample_keys[row][col][0],
+                header_keys[row][col],
+                live_keys[row][col],
+                out_dir,
+                row,
+                col,
+            )
+
+            task_graph_dict[(write_task_name, row, col)] = (
+                chunk_to_sgy_stack,
+            ) + block_args

     # Make actual graph
     task_graph = HighLevelGraph.from_collections(
-        task_name,
+        write_task_name,
         task_graph_dict,
-        dependencies=[traces_seq, headers_seq, live_seq],
+        dependencies=[samples_proc, headers_proc, live_mask],
     )

     # Note this doesn't work with distributed.
     tqdm_kw = dict(unit="block", dynamic_ncols=True)
     block_progress = TqdmCallback(desc="Step 1 / 2 Writing Blocks", **tqdm_kw)

     with block_progress:
-        block_exists = compute_as_if_collection(
+        results = compute_as_if_collection(
             cls=Array,
             dsk=task_graph,
             keys=list(task_graph_dict),
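
The memory optimization is visible in this hunk: instead of rechunking traces into whole-axis sequential blocks (the removed seq_trc_chunks logic), preprocess_samples and preprocess_headers are mapped over the existing chunks with da.blockwise, so no block ever has to grow beyond one chunk. A self-contained sketch of the same broadcasting pattern, with a toy function standing in for preprocess_samples:

    import dask.array as da
    import numpy as np

    def zero_dead(block: np.ndarray, mask: np.ndarray) -> np.ndarray:
        # Toy stand-in for preprocess_samples: purely per-chunk work.
        out = block.astype("float32")
        out[~mask] = 0.0
        return out

    samples = da.random.random((6, 6, 10), chunks=(3, 3, 10))
    live = da.ones((6, 6), dtype=bool, chunks=(3, 3))

    # Each "ijk" block of samples is paired with its matching "ij" block
    # of the live mask, mirroring samples_proc/headers_proc above.
    processed = da.blockwise(
        zero_dead, "ijk", samples, "ijk", live, "ij", dtype="float32"
    )
    print(processed.compute().shape)  # (6, 6, 10)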
@@ -236,10 +238,15 @@ def mdio_to_segy(  # noqa: C901

     concat_file_paths = [output_segy_path]

-    for filename, is_full in zip(block_file_paths, block_exists):
-        if not is_full:
-            continue
-        concat_file_paths.append(filename)
+    concat_list = []
+    for block in results:
+        for file, exists in block:
+            if exists:
+                concat_list.append(file)
+
+    concat_list.sort()
+
+    concat_file_paths += concat_list

     if client is not None:
         _ = client.submit(concat_files, concat_file_paths).result()
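
Step 1 writes each chunk to its own temporary block file and returns only small (path, exists) tuples; step 2 concatenates the surviving files behind the SEG-Y headers created by mdio_spec_to_segy. The graph is driven by the low-level HighLevelGraph plus compute_as_if_collection pattern. A runnable sketch of that pattern, with a hypothetical demo-write task standing in for chunk_to_sgy_stack:

    import dask.array as da
    from dask.array.core import Array
    from dask.base import compute_as_if_collection
    from dask.highlevelgraph import HighLevelGraph

    def demo_write(block):
        # Stand-in for chunk_to_sgy_stack: consumes one materialized
        # block and returns only a tiny summary of it.
        return float(block.sum())

    arr = da.ones((4, 4), chunks=(2, 2))
    keys = arr.__dask_keys__()  # nested [row][col] keys, as in the loop above

    dsk = {
        ("demo-write", r, c): (demo_write, keys[r][c])
        for r in range(arr.numblocks[0])
        for c in range(arr.numblocks[1])
    }

    graph = HighLevelGraph.from_collections("demo-write", dsk, dependencies=[arr])

    # Runs with dask.array's scheduler; returns one result per task key.
    results = compute_as_if_collection(Array, graph, list(dsk))
    print(results)  # [4.0, 4.0, 4.0, 4.0]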
src/mdio/segy/_workers.py (63 additions & 53 deletions)
@@ -3,18 +3,20 @@

 from __future__ import annotations

+from os import path
 from typing import Any
 from typing import Sequence
+from uuid import uuid1

 import numpy as np
 import segyio
 from numpy.typing import ArrayLike
+from numpy.typing import NDArray
 from zarr import Array

 from mdio.constants import UINT32_MAX
 from mdio.core import Grid
 from mdio.segy.byte_utils import ByteOrder
-from mdio.segy.ibm_float import ieee2ibm


 def header_scan_worker(
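
The new imports serve the temporary block files written below: path joins the output directory, and uuid1 (which embeds the host identifier and a timestamp) keeps names unique even when blocks are written from different machines. An illustration of the file-name format used in chunk_to_sgy_stack, with hypothetical indices:

    from uuid import uuid1

    row, idx, col = 0, 3, 1  # hypothetical block coordinates
    f_name = f".{row:05d}_{idx:05d}_{col:05d}_{uuid1()}.sgyblock"
    print(f_name)  # e.g. ".00000_00003_00001_<uuid>.sgyblock"
    # The zero-padded prefix means a plain lexicographic sort restores
    # (row, idx, col) order, which is what concat_list.sort() in
    # converters/mdio.py relies on.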
@@ -197,69 +199,77 @@ def trace_worker(
     return count, chunk_sum, chunk_sum_squares, min_val, max_val


-def write_block_to_segy(
-    block_out_path: str,
-    traces: np.ndarray,
-    headers: np.ndarray,
-    live: np.ndarray,
-    num_samp: int,
-    sample_format: str,
-    endian: str,
-) -> int:
-    """Write a block of traces to a SEG-Y file without text and binary headers.
+def traces_to_file(
+    samples: NDArray,
+    headers: NDArray,
+    live: NDArray,
+    out_path: str,
+) -> None:
+    """Interlace headers and samples to form traces and write them out.

     Args:
-        block_out_path: Path to write the block.
-        traces: Trace data.
-        headers: Headers for `traces`.
-        live: Live mask for `traces`.
-        num_samp: Number of samples in traces.
-        sample_format: Sample output format. Must be in {"ibm", "ieee"}.
-        endian: Endianness of the sample format. Must be in {"little", "big"}.
-
-    Returns:
-        Returns the integer "1" if successful. Returns "0" if all traces
-        in the block were zero.
-
-    Raises:
-        OSError: if unsupported SEG-Y sample format is found
+        samples: Sample data.
+        headers: Header data.
+        live: Live mask.
+        out_path: Path to the output file.
     """
-    if np.count_nonzero(live) == 0:
-        return 0
-
-    # Drop dead traces, this also makes data sequential.
-    traces = traces[live]
-    headers = headers[live]
-    live = live[live]
-
-    # Handle float formats
-    if sample_format == 1:  # IBM
-        trace_dtype = num_samp * np.dtype("uint32")
-        traces = ieee2ibm(traces)
-    elif sample_format == 5:  # IEEE
-        trace_dtype = num_samp * traces.dtype
-    else:
-        raise OSError("Unknown SEG-Y sample format")
-
     full_dtype = {
         "names": ("header", "pad", "trace"),
-        "formats": [headers.dtype, np.dtype("int64"), trace_dtype],
+        "formats": [
+            headers.dtype,
+            np.dtype("int64"),
+            samples.shape[-1] * samples.dtype,
+        ],
     }
     full_dtype = np.dtype(full_dtype)

-    full_trace = np.empty(len(live), dtype=full_dtype)
-    full_trace["header"] = headers
-    full_trace["pad"].fill(0)
-    full_trace["trace"] = traces
-
-    if endian == "big":
-        full_trace.byteswap(inplace=True)
-
-    # This will write the SEG-Y file with the same order as MDIO.
-    with open(block_out_path, "wb") as out_file:
-        out_file.write(full_trace.tobytes())
-
-    return 1
+    n_live = np.count_nonzero(live)
+    trace = np.empty(n_live, dtype=full_dtype)
+
+    trace["header"] = headers[live]
+    trace["pad"] = 0
+    trace["trace"] = samples[live]
+
+    with open(out_path, mode="wb") as fp:
+        trace.tofile(fp)
+
+
+def chunk_to_sgy_stack(
+    samples: NDArray,
+    headers: NDArray,
+    live: NDArray,
+    out_root: str,
+    row: int,
+    col: int,
+) -> list[tuple[str, int]]:
+    """Convert a partial chunk (block) to a stack of SEG-Y traces.
+
+    Args:
+        samples: Sample data.
+        headers: Header data.
+        live: Live mask.
+        out_root: Root directory for output file.
+        row: Row index of chunk block within full array.
+        col: Col index of chunk block within full array.
+
+    Returns:
+        List of (path, exists) tuples created in this function.
+    """
+    block_files = []
+
+    for idx, (s, h, l) in enumerate(zip(samples, headers, live)):
+        f_name = f".{row:05d}_{idx:05d}_{col:05d}_{str(uuid1())}.sgyblock"
+        f_path = path.join(out_root, f_name)
+
+        if np.count_nonzero(l) == 0:
+            block_files.append((f_path, 0))
+            continue
+
+        block_files.append((f_path, 1))
+        traces_to_file(s, h, l, f_path)
+
+    return block_files
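
traces_to_file builds one NumPy structured dtype per block so a single tofile call writes header, pad, and samples contiguously for every trace. A minimal sketch of that layout, with hypothetical sizes (a 240-byte header stand-in and 5 float32 samples per trace):

    import numpy as np

    header_dtype = np.dtype("V240")  # stand-in for the real header dtype
    full_dtype = np.dtype(
        {
            "names": ("header", "pad", "trace"),
            "formats": [header_dtype, np.dtype("int64"), 5 * np.dtype("float32")],
        }
    )

    trace = np.empty(3, dtype=full_dtype)  # three hypothetical live traces
    trace["pad"] = 0

    # Fields are packed back to back, so writing the array emits
    # header + pad + samples for each trace in order.
    assert full_dtype.itemsize == 240 + 8 + 5 * 4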


 # tqdm only works properly with pool.map
