Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Second round of applycal optimisations #224

Merged
merged 13 commits into from
Feb 20, 2019
10 changes: 1 addition & 9 deletions katdal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,7 @@
from .h5datav2 import H5DataV2
from .h5datav3 import H5DataV3
from .visdatav4 import VisibilityDataV4


# Clean up top-level namespace a bit
bmerry marked this conversation as resolved.
Show resolved Hide resolved
_dataset, _concatdata, _sensordata = dataset, concatdata, sensordata
_h5datav1, _h5datav2, _h5datav3 = h5datav1, h5datav2, h5datav3
_categorical, _lazy_indexer = categorical, lazy_indexer
_spectral_window, _visdatav4 = spectral_window, visdatav4
del dataset, concatdata, sensordata, h5datav1, h5datav2, h5datav3
del categorical, lazy_indexer, spectral_window, visdatav4
from .flags import FLAG_NAMES, FLAG_DESCRIPTIONS


# Setup library logger and add a print-like handler used when no logging is configured
Expand Down
289 changes: 211 additions & 78 deletions katdal/applycal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@
from __future__ import print_function, division, absolute_import
from builtins import range, zip

from functools import partial
import copy
import logging
import itertools
import operator

import numpy as np
import dask.array as da
import dask.base
import dask.utils
import toolz
import numba

from .categorical import CategoricalData, ComparableArrayWrapper
from .spectral_window import SpectralWindow
from .flags import POSTPROC


# A constant indicating invalid / absent gain (typically due to flagged data)
Expand All @@ -38,6 +42,106 @@
logger = logging.getLogger(__name__)


def _call_from_block_function(func, shape, num_chunks, chunk_location, array_location, func_kwargs):
ludwigschwardt marked this conversation as resolved.
Show resolved Hide resolved
block_info = {
'shape': shape,
'num-chunks': num_chunks,
'chunk-location': chunk_location,
'array-location': list(array_location)
}
return func(block_info, **func_kwargs)


# This has been submitted to dask as https://github.com/dask/dask/pull/4476.
# If it gets merged it can be used rather than copied here. There are also
# unit tests there.
def from_block_function(func, shape, chunks='auto', dtype=None, name=None, **kwargs):
    """
    Create an array from a function that builds individual blocks.

    For each block, the function is passed a dictionary with information about
    the block to construct, and should return a numpy array.

    >>> block_info  # doctest: +SKIP
    {'shape': (12, 20),
     'num-chunks': (3, 4),
     'chunk-location': (2, 1),
     'array-location': [(8, 12), (5, 10)]
    }

    The values in the dictionary are respectively the shape of the full
    array, the number of chunks in the full array in each dimension, the
    position of this block in chunks, and the position in the array
    (for example, the slice corresponding to ``8:12, 5:10``).

    Parameters
    ----------
    func : callable
        Function to produce every block in the array
    shape : Tuple[int]
        Shape of the resulting array.
    chunks : tuple, optional
        Chunk shape of resulting blocks. If not provided, a chunking scheme
        is chosen automatically.
    dtype : np.dtype, optional
        The ``dtype`` of the output array. It is recommended to provide this.
        If not provided, will be inferred by applying the function to a small
        set of fake data.
    name : str, optional
        The key name to use for the output array. If not provided,
        will be determined from `func`. A token derived from the other
        arguments (including `kwargs`) is always appended to it.
    **kwargs :
        Other keyword arguments to pass to function. Values must be constants
        (not dask.arrays)

    Returns
    -------
    array : :class:`dask.array.Array`
        Array with one task per block, each calling `func`

    Examples
    --------
    This is a simplified version of :func:`eye` which only handles square
    arrays with the ones on the main diagonal.

    >>> def eye_chunk(block_info):
    ...     location = block_info['array-location']
    ...     r0, r1 = location[0]
    ...     c0, c1 = location[1]
    ...     if r0 == c0:
    ...         return np.eye(r1 - r0, c1 - c0)
    ...     else:
    ...         return np.zeros((r1 - r0, c1 - c0))
    >>> from_block_function(eye_chunk, (4, 4), chunks=2, dtype=float)
    dask.array<eye_chunk, shape=(4, 4), dtype=float64, chunksize=(2, 2)>
    >>> _.compute()
    array([[1., 0., 0., 0.],
           [0., 1., 0., 0.],
           [0., 0., 1., 0.],
           [0., 0., 0., 1.]])
    """
    # The keyword arguments change what each block contains, so they must
    # contribute to the token. Otherwise two arrays that differ only in
    # kwargs would share a name and dask would treat them as the same array.
    name = '%s-%s' % (name or dask.utils.funcname(func),
                      dask.base.tokenize(func, shape, dtype, chunks, kwargs))

    if dtype is None:
        # Describe a fake single-element block at the origin of an array
        # chunked into unit-sized blocks, purely to probe the output dtype.
        dummy_block_info = {
            'shape': shape,
            'num-chunks': shape,
            'chunk-location': (0,) * len(shape),
            'array-location': [(0, 1)] * len(shape)
        }
        # Reference the helper via da.core, the module where it is defined
        # (like normalize_chunks below).
        dtype = da.core.apply_infer_dtype(func, [dummy_block_info], kwargs,
                                          'from_block_function')

    chunks = da.core.normalize_chunks(chunks, shape, dtype=dtype)
    # Allow for shape=None when chunks are already in normalized form
    shape = tuple(sum(bd) for bd in chunks)

    # One graph key per block: (name, i, j, ...)
    keys = list(itertools.product([name], *[range(len(bd)) for bd in chunks]))
    # Cumulative chunk boundaries per dimension, e.g. (4, 4, 4) -> [0, 4, 8, 12]
    aggdims = [list(toolz.accumulate(operator.add, (0,) + bd)) for bd in chunks]
    # Pair up consecutive boundaries into (start, stop) ranges per dimension
    locdims = [list(zip(a[:-1], a[1:])) for a in aggdims]
    # Same iteration order as `keys`, so the two can be zipped together
    locations = list(itertools.product(*locdims))
    num_chunks = tuple(len(bd) for bd in chunks)
    dsk = {key: (_call_from_block_function, func, shape, num_chunks, key[1:], location, kwargs)
           for key, location in zip(keys, locations)}
    return da.Array(dsk, name, chunks, dtype=dtype)


def complex_interp(x, xi, yi, left=None, right=None):
"""Piecewise linear interpolation of magnitude and phase of complex values.

Expand Down Expand Up @@ -273,8 +377,31 @@ def _correction_inputs_to_corrprods(g_per_cp, g_per_input, input1_index, input2_
* np.conj(g_per_input[i, input2_index[j]]))


def calc_correction_per_corrprod(dump, channels, cache, inputs,
input1_index, input2_index, cal_products):
class CorrectionParams(object):
    """Bundle of data consumed by :func:`calc_correction_per_corrprod`.

    Instances get embedded in dask graphs, so treat them as immutable:
    none of the attributes may be modified after construction.

    Parameters
    ----------
    inputs : list of str
        Names of inputs, in the same order as the input axis of `products`
    input1_index, input2_index : ndarray
        Indices into `inputs` of first and second items of correlation product
    products : dict
        A dictionary (indexed by cal product name) of lists (indexed
        by input) of sequences (indexed by dump) of numpy arrays, with
        corrections to apply.
    """

    def __init__(self, inputs, input1_index, input2_index, products):
        # Plain attribute storage; no copies are made of the arguments
        self.products = products
        self.inputs = inputs
        self.input1_index, self.input2_index = input1_index, input2_index


def calc_correction_per_corrprod(dump, channels, params):
"""Gain correction per channel per correlation product for a given dump.

This calculates an array of complex gain correction terms of shape
Expand All @@ -286,16 +413,10 @@ def calc_correction_per_corrprod(dump, channels, cache, inputs,
----------
dump : int
Dump index (applicable to full data set, i.e. absolute)
channels : list of int, length n_chans
channels : slice
Channel indices (applicable to full data set, i.e. absolute)
cache : :class:`katdal.sensordata.SensorCache` object
Sensor cache, used to look up individual correction sensors
inputs : sequence of string
Correlator input labels
input1_index, input2_index : list of int, length n_corrprods
Indices into `inputs` of first and second items of correlation product
cal_products : sequence of string
Calibration products that will contribute to corrections
params : :class:`CorrectionParams`
Data for obtaining corrections to merge

Returns
-------
Expand All @@ -307,103 +428,115 @@ def calc_correction_per_corrprod(dump, channels, cache, inputs,
KeyError
If input and/or cal product has no associated correction
"""
g_per_input = np.ones((len(inputs), len(channels)), dtype='complex64')
for product in cal_products:
for n, inp in enumerate(inputs):
sensor_name = 'Calibration/{}_correction_{}'.format(inp, product)
g_product = cache.get(sensor_name)[dump]
n_channels = channels.stop - channels.start
g_per_input = np.ones((len(params.inputs), n_channels), dtype='complex64')
bmerry marked this conversation as resolved.
Show resolved Hide resolved
for product in params.products.values():
for n in range(len(params.inputs)):
sensor = product[n]
ludwigschwardt marked this conversation as resolved.
Show resolved Hide resolved
g_product = sensor[dump]
if np.shape(g_product) != ():
g_product = g_product[channels]
g_per_input[n] *= g_product
# Ensure these are arrays for the benefit of numba
input1_index = np.asarray(input1_index)
input2_index = np.asarray(input2_index)
# Transpose to (channel, input) order, and ensure C ordering
g_per_input = np.ascontiguousarray(g_per_input.T)
g_per_cp = np.empty((len(channels), len(input1_index)), np.complex64)
_correction_inputs_to_corrprods(g_per_cp, g_per_input, input1_index, input2_index)
g_per_cp = np.empty((n_channels, len(params.input1_index)), np.complex64)
_correction_inputs_to_corrprods(g_per_cp, g_per_input, params.input1_index, params.input2_index)
return g_per_cp


@numba.jit(nopython=True, nogil=True)
def apply_vis_correction(out, correction):
"""Clean up and apply `correction` in-place to visibility data in `out`."""
def apply_vis_correction(data, correction):
"""Clean up and apply `correction` to visibility data in `data`."""
bmerry marked this conversation as resolved.
Show resolved Hide resolved
out = np.empty_like(data)
for i in range(out.shape[0]):
for j in range(out.shape[1]):
c = correction[i, j]
if not np.isnan(c):
out[i, j] *= c
for k in range(out.shape[2]):
c = correction[i, j, k]
if not np.isnan(c):
out[i, j, k] = data[i, j, k] * c
else:
out[i, j, k] = data[i, j, k]
return out


@numba.jit(nopython=True, nogil=True)
def apply_weights_correction(out, correction):
"""Clean up and apply `correction` in-place to weight data in `out`."""
def apply_weights_correction(data, correction):
"""Clean up and apply `correction` to weight data in `data`."""
out = np.empty_like(data)
for i in range(out.shape[0]):
for j in range(out.shape[1]):
cc = correction[i, j]
c = cc.real**2 + cc.imag**2
if c > 0: # Will be false if c is NaN
out[i, j] /= c
else:
out[i, j] = 0
for k in range(out.shape[2]):
cc = correction[i, j, k]
c = cc.real**2 + cc.imag**2
if c > 0: # Will be false if c is NaN
out[i, j, k] = data[i, j, k] / c
else:
out[i, j, k] = 0
return out


@numba.jit(nopython=True, nogil=True)
def apply_flags_correction(out, correction):
"""Update flag data in `out` to True wherever `correction` is invalid."""
def apply_flags_correction(data, correction):
"""Update flag data to True wherever `correction` is invalid."""
bmerry marked this conversation as resolved.
Show resolved Hide resolved
out = np.copy(data)
for i in range(out.shape[0]):
for j in range(out.shape[1]):
out[i, j] |= np.isnan(correction[i, j])
for k in range(out.shape[2]):
if np.isnan(correction[i, j, k]):
out[i, j, k] |= POSTPROC
return out


def _correction_block(block_info, params):
    """Build one block of the corrections array for :func:`from_block_function`.

    Parameters
    ----------
    block_info : dict
        Block description; only its ``'array-location'`` entry (a sequence of
        ``(start, stop)`` pairs, one per dimension) is used here
    params : :class:`CorrectionParams`
        Data passed through to :func:`calc_correction_per_corrprod`

    Returns
    -------
    block : array of complex64
        Corrections for this block, with the shape given by the block's
        ``'array-location'`` ranges
    """
    extents = tuple(slice(*bounds) for bounds in block_info['array-location'])
    out = np.empty(tuple(e.stop - e.start for e in extents), np.complex64)
    first_dump, end_dump = extents[0].start, extents[0].stop
    # TODO: make calc_correction_per_corrprod multi-dump aware
    # The first axis is walked one dump at a time; the second axis is passed
    # on as a slice, and each call fills one row along the first axis.
    for row, dump in enumerate(range(first_dump, end_dump)):
        out[row] = calc_correction_per_corrprod(dump, extents[1], params)
    return out

def add_applycal_transform(indexer, cache, corrprods, cal_products,
apply_correction):
"""Add transform to indexer that applies calibration corrections.

This adds a transform to the indexer which wraps the underlying data
(visibilities, weights or flags). The transform will apply all calibration
corrections specified in `cal_products` to each dask chunk individually.
The actual application method is also user-specified, which allows most
of the machinery to be reused between visibilities, weights and flags.
The time and frequency selections are salvaged from `indexer` but the
selected `corrprods` still needs to be passed in as a parameter to identify
the relevant inputs in order to access correction sensors.
def calc_correction(chunks, cache, corrprods, cal_products):
"""Create a dask array containing applycal corrections.

Parameters
----------
indexer : :class:`katdal.lazy_indexer.DaskLazyIndexer` object
Indexer with underlying dask array that will be transformed
chunks : tuple of tuple of int
Chunking scheme of the resulting array, in normalized form (see
:func:`dask.array.core.normalize_chunks`).
cache : :class:`katdal.sensordata.SensorCache` object
Sensor cache, used to look up individual correction sensors
corrprods : sequence of (string, string)
Selected correlation products as pairs of correlator input labels
cal_products : sequence of string
Calibration products that will contribute to corrections
apply_correction : function, signature ``out = f(out, correction)``
Function that will actually apply correction to data from indexer
"""
stage1_indices = tuple(k.nonzero()[0] for k in indexer.keep)
# Turn corrprods into a list of input labels and two lists of indices
shape = tuple(sum(bd) for bd in chunks)
if len(chunks[2]) > 1:
bmerry marked this conversation as resolved.
Show resolved Hide resolved
logger.warning('ignoring chunking on baseline axis')
chunks = (chunks[0], chunks[1], (shape[2],))
inputs = sorted(set(np.ravel(corrprods)))
input1_index = [inputs.index(cp[0]) for cp in corrprods]
input2_index = [inputs.index(cp[1]) for cp in corrprods]
# Prevent cal_products from changing underneath us if caller changes theirs
cal_products = copy.deepcopy(cal_products)

def calibrate_chunk(chunk, block_info):
"""Apply all specified calibration corrections to chunk."""
corrected_chunk = chunk.copy()
# Tuple of slices that cuts out `chunk` from full array
slices = tuple(slice(*l) for l in block_info[0]['array-location'])
dumps, chans, _ = tuple(i[s] for i, s in zip(stage1_indices, slices))
index1 = input1_index[slices[2]]
index2 = input2_index[slices[2]]
ccpc_args = (chans, cache, inputs, index1, index2, cal_products)
for n, dump in enumerate(dumps):
correction = calc_correction_per_corrprod(dump, *ccpc_args)
apply_correction(corrected_chunk[n], correction)
return corrected_chunk

transform = partial(da.map_blocks, calibrate_chunk, dtype=indexer.dtype)
transform.__name__ = 'applycal[{}]'.format(','.join(cal_products))
indexer.add_transform(transform)
input1_index = np.array([inputs.index(cp[0]) for cp in corrprods])
input2_index = np.array([inputs.index(cp[1]) for cp in corrprods])
products = {}
for product in cal_products:
products[product] = []
for i, inp in enumerate(inputs):
sensor_name = 'Calibration/{}_correction_{}'.format(inp, product)
sensor = cache.get(sensor_name)
# Indexing CategoricalData by dump is relatively slow (tens of
# microseconds), so expand it into a plain-old Python list.
if isinstance(sensor, CategoricalData):
ludwigschwardt marked this conversation as resolved.
Show resolved Hide resolved
data = [None] * sensor.events[-1]
for s, v in sensor.segments():
for j in range(s.start, s.stop):
data[j] = v
bmerry marked this conversation as resolved.
Show resolved Hide resolved
else:
data = sensor
products[product].append(data)
params = CorrectionParams(inputs, input1_index, input2_index, products)

return from_block_function(
_correction_block, shape=shape, chunks=chunks, dtype=np.complex64, name='correction',
ludwigschwardt marked this conversation as resolved.
Show resolved Hide resolved
params=params)
Loading