diff --git a/katdal/__init__.py b/katdal/__init__.py
index 1695f1a2..b2589a9b 100644
--- a/katdal/__init__.py
+++ b/katdal/__init__.py
@@ -237,15 +237,6 @@
 from .visdatav4 import VisibilityDataV4
 
-# Clean up top-level namespace a bit
-_dataset, _concatdata, _sensordata = dataset, concatdata, sensordata
-_h5datav1, _h5datav2, _h5datav3 = h5datav1, h5datav2, h5datav3
-_categorical, _lazy_indexer = categorical, lazy_indexer
-_spectral_window, _visdatav4 = spectral_window, visdatav4
-del dataset, concatdata, sensordata, h5datav1, h5datav2, h5datav3
-del categorical, lazy_indexer, spectral_window, visdatav4
-
-
 # Setup library logger and add a print-like handler used when no logging is configured
 class _NoConfigFilter(_logging.Filter):
     """Filter which only allows event if top-level logging is not configured."""
diff --git a/katdal/applycal.py b/katdal/applycal.py
index 01d16640..3a651c89 100644
--- a/katdal/applycal.py
+++ b/katdal/applycal.py
@@ -18,16 +18,20 @@
 from __future__ import print_function, division, absolute_import
 from builtins import range, zip
 
-from functools import partial
-import copy
 import logging
+import itertools
+import operator
 
 import numpy as np
 import dask.array as da
+import dask.base
+import dask.utils
+import toolz
 import numba
 
 from .categorical import CategoricalData, ComparableArrayWrapper
 from .spectral_window import SpectralWindow
+from .flags import POSTPROC
 
 
 # A constant indicating invalid / absent gain (typically due to flagged data)
@@ -38,6 +42,106 @@
 logger = logging.getLogger(__name__)
 
 
+def _call_from_block_function(func, shape, num_chunks, chunk_location, array_location, func_kwargs):
+    block_info = {
+        'shape': shape,
+        'num-chunks': num_chunks,
+        'chunk-location': chunk_location,
+        'array-location': list(array_location)
+    }
+    return func(block_info, **func_kwargs)
+
+
+# This has been submitted to dask as https://github.com/dask/dask/pull/4476.
+# If it gets merged it can be used rather than copied here. There are also
+# unit tests there.
+def from_block_function(func, shape, chunks='auto', dtype=None, name=None, **kwargs):
+    """
+    Create an array from a function that builds individual blocks.
+
+    For each block, the function is passed a dictionary with information about
+    the block to construct, and should return a numpy array.
+
+    >>> block_info       # doctest: +SKIP
+    {'shape': (12, 20),
+     'num-chunks': (3, 4),
+     'chunk-location': (2, 1),
+     'array-location': [(8, 12), (5, 10)]
+    }
+
+    The values in the dictionary are respectively the shape of the full
+    array, the number of chunks in the full array in each dimension, the
+    position of this block in chunks, and the position in the array
+    (for example, the slice corresponding to ``8:12, 5:10``).
+
+    Parameters
+    ----------
+    func : callable
+        Function to produce every block in the array
+    shape : Tuple[int]
+        Shape of the resulting array.
+    chunks : tuple, optional
+        Chunk shape of resulting blocks. If not provided, a chunking scheme
+        is chosen automatically.
+    dtype : np.dtype, optional
+        The ``dtype`` of the output array. It is recommended to provide this.
+        If not provided, it will be inferred by applying the function to a
+        small set of fake data.
+    name : str, optional
+        The key name to use for the output array. If not provided, it will
+        be determined from `func`.
+    **kwargs :
+        Other keyword arguments to pass to function. Values must be constants
+        (not dask.arrays).
+
+    Examples
+    --------
+    This is a simplified version of :func:`eye` which only handles square
+    arrays with the ones on the main diagonal.
+
+    >>> def eye_chunk(block_info):
+    ...     location = block_info['array-location']
+    ...     r0, r1 = location[0]
+    ...     c0, c1 = location[1]
+    ...     if r0 == c0:
+    ...         return np.eye(r1 - r0, c1 - c0)
+    ...     else:
+    ...         return np.zeros((r1 - r0, c1 - c0))
+    >>> from_block_function(eye_chunk, (4, 4), chunks=2, dtype=float)
+    dask.array<eye_chunk, shape=(4, 4), dtype=float64, chunksize=(2, 2)>
+    >>> _.compute()
+    array([[1., 0., 0., 0.],
+           [0., 1., 0., 0.],
+           [0., 0., 1., 0.],
+           [0., 0., 0., 1.]])
+    """
+
+    name = '%s-%s' % (name or dask.utils.funcname(func),
+                      dask.base.tokenize(func, shape, dtype, chunks))
+
+    if dtype is None:
+        dummy_block_info = {
+            'shape': shape,
+            'num-chunks': shape,
+            'chunk-location': (0,) * len(shape),
+            'array-location': [(0, 1)] * len(shape)
+        }
+        dtype = da.apply_infer_dtype(func, [dummy_block_info], kwargs, 'from_block_function')
+
+    chunks = da.core.normalize_chunks(chunks, shape, dtype=dtype)
+    # Allow for shape=None when chunks are already in normalized form
+    shape = tuple(sum(bd) for bd in chunks)
+
+    keys = list(itertools.product([name], *[range(len(bd)) for bd in chunks]))
+    aggdims = [list(toolz.accumulate(operator.add, (0,) + bd)) for bd in chunks]
+    locdims = [list(zip(a[:-1], a[1:])) for a in aggdims]
+    locations = list(itertools.product(*locdims))
+    num_chunks = tuple(len(bd) for bd in chunks)
+    dsk = {key: (_call_from_block_function, func, shape, num_chunks, key[1:], location, kwargs)
+           for key, location in zip(keys, locations)}
+    return da.Array(dsk, name, chunks, dtype=dtype)
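
As a quick aside, a minimal illustration of the `block_info` contract (not part of the patch; the helper name and shapes are invented): each block fills itself from its absolute location in the full array, which is exactly how `_correction_block` below uses the dump axis.

    import numpy as np
    from katdal.applycal import from_block_function

    def row_index_chunk(block_info):
        # 'array-location' holds the (start, stop) range per axis for this block
        (r0, r1), (c0, c1) = block_info['array-location']
        # Fill each row of the block with its absolute row index
        return np.broadcast_to(np.arange(r0, r1)[:, np.newaxis], (r1 - r0, c1 - c0))

    x = from_block_function(row_index_chunk, shape=(4, 6), chunks=2, dtype=np.int_)
    assert (x.compute()[:, 0] == np.arange(4)).all()
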
 
 
 def complex_interp(x, xi, yi, left=None, right=None):
     """Piecewise linear interpolation of magnitude and phase of complex values.
@@ -273,8 +377,31 @@ def _correction_inputs_to_corrprods(g_per_cp, g_per_input, input1_index, input2_index):
                      * np.conj(g_per_input[i, input2_index[j]]))
 
 
-def calc_correction_per_corrprod(dump, channels, cache, inputs,
-                                 input1_index, input2_index, cal_products):
+class CorrectionParams(object):
+    """Data needed to compute corrections in :func:`calc_correction_per_corrprod`.
+
+    Once constructed, the data in this class must not be modified, as it will
+    be baked into dask graphs.
+
+    Parameters
+    ----------
+    inputs : list of str
+        Names of inputs, in the same order as the input axis of `products`
+    input1_index, input2_index : ndarray
+        Indices into `inputs` of first and second items of correlation product
+    products : dict
+        A dictionary (indexed by cal product name) of lists (indexed
+        by input) of sequences (indexed by dump) of numpy arrays, with
+        corrections to apply.
+    """
+    def __init__(self, inputs, input1_index, input2_index, products):
+        self.inputs = inputs
+        self.input1_index = input1_index
+        self.input2_index = input2_index
+        self.products = products
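
A toy construction (values invented) may help to visualise the nesting that `CorrectionParams` expects: two inputs, three correlation products, and a single 'G' product with per-input corrections sampled at four dumps.

    import numpy as np
    from katdal.applycal import CorrectionParams

    inputs = ['m000h', 'm000v']
    input1_index = np.array([0, 0, 1])    # first input of each corrprod
    input2_index = np.array([0, 1, 1])    # second input of each corrprod
    products = {'G': [[np.complex64(1.0)] * 4,    # corrections for m000h, per dump
                      [np.complex64(0.5)] * 4]}   # corrections for m000v, per dump
    params = CorrectionParams(inputs, input1_index, input2_index, products)
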
+
+
+def calc_correction_per_corrprod(dump, channels, params):
     """Gain correction per channel per correlation product for a given dump.
 
     This calculates an array of complex gain correction terms of shape
@@ -286,16 +413,10 @@
     ----------
     dump : int
         Dump index (applicable to full data set, i.e. absolute)
-    channels : list of int, length n_chans
+    channels : slice
         Channel indices (applicable to full data set, i.e. absolute)
-    cache : :class:`katdal.sensordata.SensorCache` object
-        Sensor cache, used to look up individual correction sensors
-    inputs : sequence of string
-        Correlator input labels
-    input1_index, input2_index : list of int, length n_corrprods
-        Indices into `inputs` of first and second items of correlation product
-    cal_products : sequence of string
-        Calibration products that will contribute to corrections
+    params : :class:`CorrectionParams`
+        Data for obtaining corrections to merge
 
     Returns
     -------
@@ -307,103 +428,115 @@
     Raises
     ------
     KeyError
         If input and/or cal product has no associated correction
     """
-    g_per_input = np.ones((len(inputs), len(channels)), dtype='complex64')
-    for product in cal_products:
-        for n, inp in enumerate(inputs):
-            sensor_name = 'Calibration/{}_correction_{}'.format(inp, product)
-            g_product = cache.get(sensor_name)[dump]
+    n_channels = channels.stop - channels.start
+    g_per_input = np.ones((len(params.inputs), n_channels), dtype='complex64')
+    for product in params.products.values():
+        for n in range(len(params.inputs)):
+            sensor = product[n]
+            g_product = sensor[dump]
             if np.shape(g_product) != ():
                 g_product = g_product[channels]
             g_per_input[n] *= g_product
-    # Ensure these are arrays for the benefit of numba
-    input1_index = np.asarray(input1_index)
-    input2_index = np.asarray(input2_index)
     # Transpose to (channel, input) order, and ensure C ordering
     g_per_input = np.ascontiguousarray(g_per_input.T)
-    g_per_cp = np.empty((len(channels), len(input1_index)), np.complex64)
-    _correction_inputs_to_corrprods(g_per_cp, g_per_input, input1_index, input2_index)
+    g_per_cp = np.empty((n_channels, len(params.input1_index)), dtype='complex64')
+    _correction_inputs_to_corrprods(g_per_cp, g_per_input, params.input1_index, params.input2_index)
     return g_per_cp
 
 
 @numba.jit(nopython=True, nogil=True)
-def apply_vis_correction(out, correction):
-    """Clean up and apply `correction` in-place to visibility data in `out`."""
+def apply_vis_correction(data, correction):
+    """Clean up and apply `correction` to visibility data in `data`."""
+    out = np.empty_like(data)
     for i in range(out.shape[0]):
         for j in range(out.shape[1]):
-            c = correction[i, j]
-            if not np.isnan(c):
-                out[i, j] *= c
+            for k in range(out.shape[2]):
+                c = correction[i, j, k]
+                if not np.isnan(c):
+                    out[i, j, k] = data[i, j, k] * c
+                else:
+                    out[i, j, k] = data[i, j, k]
+    return out
 
 
 @numba.jit(nopython=True, nogil=True)
-def apply_weights_correction(out, correction):
-    """Clean up and apply `correction` in-place to weight data in `out`."""
+def apply_weights_correction(data, correction):
+    """Clean up and apply `correction` to weight data in `data`."""
+    out = np.empty_like(data)
     for i in range(out.shape[0]):
         for j in range(out.shape[1]):
-            cc = correction[i, j]
-            c = cc.real**2 + cc.imag**2
-            if c > 0:    # Will be false if c is NaN
-                out[i, j] /= c
-            else:
-                out[i, j] = 0
+            for k in range(out.shape[2]):
+                cc = correction[i, j, k]
+                c = cc.real**2 + cc.imag**2
+                if c > 0:    # Will be false if c is NaN
+                    out[i, j, k] = data[i, j, k] / c
+                else:
+                    out[i, j, k] = 0
+    return out
 
 
 @numba.jit(nopython=True, nogil=True)
-def apply_flags_correction(out, correction):
-    """Update flag data in `out` to True wherever `correction` is invalid."""
+def apply_flags_correction(data, correction):
+    """Set POSTPROC flag wherever `correction` is invalid."""
+    out = np.copy(data)
     for i in range(out.shape[0]):
         for j in range(out.shape[1]):
-            out[i, j] |= np.isnan(correction[i, j])
+            for k in range(out.shape[2]):
+                if np.isnan(correction[i, j, k]):
+                    out[i, j, k] |= POSTPROC
+    return out
 
 
-def add_applycal_transform(indexer, cache, corrprods, cal_products,
-                           apply_correction):
-    """Add transform to indexer that applies calibration corrections.
+def _correction_block(block_info, params):
+    slices = tuple(slice(*l) for l in block_info['array-location'])
+    block_shape = tuple(s.stop - s.start for s in slices)
+    correction = np.empty(block_shape, np.complex64)
+    # TODO: make calc_correction_per_corrprod multi-dump aware
+    for n, dump in enumerate(range(slices[0].start, slices[0].stop)):
+        correction[n] = calc_correction_per_corrprod(dump, slices[1], params)
+    return correction
 
-    This adds a transform to the indexer which wraps the underlying data
-    (visibilities, weights or flags). The transform will apply all calibration
-    corrections specified in `cal_products` to each dask chunk individually.
-    The actual application method is also user-specified, which allows most
-    of the machinery to be reused between visibilities, weights and flags.
-    The time and frequency selections are salvaged from `indexer` but the
-    selected `corrprods` still needs to be passed in as a parameter to identify
-    the relevant inputs in order to access correction sensors.
+
+def calc_correction(chunks, cache, corrprods, cal_products):
+    """Create a dask array containing applycal corrections.
 
     Parameters
     ----------
-    indexer : :class:`katdal.lazy_indexer.DaskLazyIndexer` object
-        Indexer with underlying dask array that will be transformed
+    chunks : tuple of tuple of int
+        Chunking scheme of the resulting array, in normalized form (see
+        :func:`dask.array.core.normalize_chunks`).
     cache : :class:`katdal.sensordata.SensorCache` object
         Sensor cache, used to look up individual correction sensors
     corrprods : sequence of (string, string)
         Selected correlation products as pairs of correlator input labels
     cal_products : sequence of string
        Calibration products that will contribute to corrections
-    apply_correction : function, signature ``out = f(out, correction)``
-        Function that will actually apply correction to data from indexer
+
+    Returns
+    -------
+    corrections : dask array of complex64
+        Corrections to apply, with the given chunking scheme (except on the
+        baseline axis, which is never chunked)
     """
-    stage1_indices = tuple(k.nonzero()[0] for k in indexer.keep)
-    # Turn corrprods into a list of input labels and two lists of indices
+    shape = tuple(sum(bd) for bd in chunks)
+    if len(chunks[2]) > 1:
+        logger.warning('ignoring chunking on baseline axis')
+        chunks = (chunks[0], chunks[1], (shape[2],))
     inputs = sorted(set(np.ravel(corrprods)))
-    input1_index = [inputs.index(cp[0]) for cp in corrprods]
-    input2_index = [inputs.index(cp[1]) for cp in corrprods]
-    # Prevent cal_products from changing underneath us if caller changes theirs
-    cal_products = copy.deepcopy(cal_products)
-
-    def calibrate_chunk(chunk, block_info):
-        """Apply all specified calibration corrections to chunk."""
-        corrected_chunk = chunk.copy()
-        # Tuple of slices that cuts out `chunk` from full array
-        slices = tuple(slice(*l) for l in block_info[0]['array-location'])
-        dumps, chans, _ = tuple(i[s] for i, s in zip(stage1_indices, slices))
-        index1 = input1_index[slices[2]]
-        index2 = input2_index[slices[2]]
-        ccpc_args = (chans, cache, inputs, index1, index2, cal_products)
-        for n, dump in enumerate(dumps):
-            correction = calc_correction_per_corrprod(dump, *ccpc_args)
-            apply_correction(corrected_chunk[n], correction)
-        return corrected_chunk
-
-    transform = partial(da.map_blocks, calibrate_chunk, dtype=indexer.dtype)
-    transform.__name__ = 'applycal[{}]'.format(','.join(cal_products))
-    indexer.add_transform(transform)
+    input1_index = np.array([inputs.index(cp[0]) for cp in corrprods])
+    input2_index = np.array([inputs.index(cp[1]) for cp in corrprods])
+    products = {}
+    for product in cal_products:
+        products[product] = []
+        for i, inp in enumerate(inputs):
+            sensor_name = 'Calibration/{}_correction_{}'.format(inp, product)
+            sensor = cache.get(sensor_name)
+            # Indexing CategoricalData by dump is relatively slow (tens of
+            # microseconds), so expand it into a plain-old Python list.
+            if isinstance(sensor, CategoricalData):
+                data = [None] * sensor.events[-1]
+                for s, v in sensor.segments():
+                    for j in range(s.start, s.stop):
+                        data[j] = v
+            else:
+                data = sensor
+            products[product].append(data)
+    params = CorrectionParams(inputs, input1_index, input2_index, products)
+    name = 'corrections[{}]'.format(','.join(cal_products))
+    return from_block_function(
+        _correction_block, shape=shape, chunks=chunks, dtype=np.complex64, name=name,
+        params=params)
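
The resulting dask array is meant to be combined with the data blockwise, which is what `visdatav4.py` does further down in this patch. A self-contained sketch of that pattern with toy arrays (shapes and values invented; a real `corrections` array would come from `calc_correction`):

    import numpy as np
    import dask.array as da
    from katdal.applycal import apply_vis_correction

    vis = da.ones((4, 8, 6), dtype=np.complex64, chunks=(2, 4, 6))
    corrections = da.full((4, 8, 6), 2 + 0j, dtype=np.complex64, chunks=(2, 4, 6))
    corrected = da.core.elemwise(apply_vis_correction, vis, corrections, dtype=vis.dtype)
    assert (corrected.compute() == 2 + 0j).all()
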
diff --git a/katdal/datasources.py b/katdal/datasources.py
index 94e719c8..2eb91215 100644
--- a/katdal/datasources.py
+++ b/katdal/datasources.py
@@ -35,6 +35,7 @@
 from .sensordata import TelstateSensorData, TelstateToStr
 from .chunkstore_s3 import S3ChunkStore
 from .chunkstore_npy import NpyFileChunkStore
+from .flags import DATA_LOST
 
 logger = logging.getLogger(__name__)
 
@@ -98,7 +99,7 @@ def _apply_data_lost(orig_flags, lost, block_id):
         return orig_flags      # Common case - no data lost
     flags = orig_flags.copy()
     for idx in mark:
-        flags[idx] |= 8
+        flags[idx] |= DATA_LOST
     return flags
 
@@ -149,7 +150,7 @@ def weight_power_scale(block, auto_indices, index1, index2, out=None, tmp=None):
         (or any two dimensions then baseline). It must contain all the
         baselines of a stream.
     auto_indices, index1, index2 : np.ndarray
-        Arrays returned by :func:`corrprod_to_autocrr`
+        Arrays returned by :func:`corrprod_to_autocorr`
     out : np.ndarray, optional
         If specified, the output array, with same shape as `block` and dtype ``np.float32``
     tmp : np.ndarray, optional
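
Note that the magic number 8 previously hard-coded in `_apply_data_lost` is exactly the `data_lost` bit defined in the new `katdal.flags` module below, which is easy to sanity-check:

    from katdal.flags import DATA_LOST, DATA_LOST_BIT, NAMES

    assert DATA_LOST_BIT == 3
    assert DATA_LOST == 1 << DATA_LOST_BIT == 8
    assert NAMES[DATA_LOST_BIT] == 'data_lost'
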
diff --git a/katdal/flags.py b/katdal/flags.py
new file mode 100644
index 00000000..5cb2bd94
--- /dev/null
+++ b/katdal/flags.py
@@ -0,0 +1,44 @@
+################################################################################
+# Copyright (c) 2019, National Research Foundation (Square Kilometre Array)
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use
+# this file except in compliance with the License. You may obtain a copy
+# of the License at
+#
+#   https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""Definitions of flag bits"""
+
+NAMES = ('reserved0', 'static', 'cam', 'data_lost',
+         'ingest_rfi', 'predicted_rfi', 'cal_rfi', 'postproc')
+DESCRIPTIONS = ('reserved - bit 0',
+                'predefined static flag list',
+                'flag based on live CAM information',
+                'no data was received',
+                'RFI detected in ingest',
+                'RFI predicted from space based pollutants',
+                'RFI detected in calibration',
+                'some correction/postprocessing step could not be applied')
+
+STATIC_BIT = 1
+CAM_BIT = 2
+DATA_LOST_BIT = 3
+INGEST_RFI_BIT = 4
+PREDICTED_RFI_BIT = 5
+CAL_RFI_BIT = 6
+POSTPROC_BIT = 7
+
+STATIC = 1 << STATIC_BIT
+CAM = 1 << CAM_BIT
+DATA_LOST = 1 << DATA_LOST_BIT
+INGEST_RFI = 1 << INGEST_RFI_BIT
+PREDICTED_RFI = 1 << PREDICTED_RFI_BIT
+CAL_RFI = 1 << CAL_RFI_BIT
+POSTPROC = 1 << POSTPROC_BIT
diff --git a/katdal/h5datav2.py b/katdal/h5datav2.py
index 9cbe9fff..ab4cca65 100644
--- a/katdal/h5datav2.py
+++ b/katdal/h5datav2.py
@@ -31,6 +31,7 @@
 from .sensordata import RecordSensorData, SensorCache, to_str
 from .categorical import CategoricalData, sensor_to_categorical
 from .lazy_indexer import LazyIndexer, LazyTransform
+from .flags import NAMES as FLAG_NAMES, DESCRIPTIONS as FLAG_DESCRIPTIONS
 
 logger = logging.getLogger(__name__)
 
@@ -66,13 +67,6 @@ def _calc_azel(cache, name, ant):
 VIRTUAL_SENSORS = dict(DEFAULT_VIRTUAL_SENSORS)
 VIRTUAL_SENSORS.update({'Antennas/{ant}/az': _calc_azel, 'Antennas/{ant}/el': _calc_azel})
 
-FLAG_NAMES = ('reserved0', 'static', 'cam', 'reserved3', 'detected_rfi',
-              'predicted_rfi', 'reserved6', 'reserved7')
-FLAG_DESCRIPTIONS = ('reserved - bit 0', 'predefined static flag list',
-                     'flag based on live CAM information',
-                     'reserved - bit 3', 'RFI detected in the online system',
-                     'RFI predicted from space based pollutants',
-                     'reserved - bit 6', 'reserved - bit 7')
 WEIGHT_NAMES = ('precision',)
 WEIGHT_DESCRIPTIONS = ('visibility precision (inverse variance, i.e. 1 / sigma^2)',)
diff --git a/katdal/h5datav3.py b/katdal/h5datav3.py
index 357dbbec..e652b9fa 100644
--- a/katdal/h5datav3.py
+++ b/katdal/h5datav3.py
@@ -36,6 +36,7 @@
                          H5TelstateSensorData, telstate_decode, to_str)
 from .categorical import CategoricalData
 from .lazy_indexer import LazyIndexer, LazyTransform
+from .flags import NAMES as FLAG_NAMES, DESCRIPTIONS as FLAG_DESCRIPTIONS
 
 logger = logging.getLogger(__name__)
 
@@ -70,10 +71,6 @@ def _calc_azel(cache, name, ant):
 VIRTUAL_SENSORS = dict(DEFAULT_VIRTUAL_SENSORS)
 VIRTUAL_SENSORS.update({'Antennas/{ant}/az': _calc_azel, 'Antennas/{ant}/el': _calc_azel})
 
-FLAG_NAMES = ('reserved0', 'static', 'cam', 'data_lost', 'ingest_rfi', 'predicted_rfi', 'cal_rfi', 'reserved7')
-FLAG_DESCRIPTIONS = ('reserved - bit 0', 'predefined static flag list', 'flag based on live CAM information',
-                     'no data was received', 'RFI detected in ingest', 'RFI predicted from space based pollutants',
-                     'RFI detected in calibration', 'reserved - bit 7')
 WEIGHT_NAMES = ('precision',)
 WEIGHT_DESCRIPTIONS = ('visibility precision (inverse variance, i.e. 1 / sigma^2)',)
diff --git a/katdal/test/test_applycal.py b/katdal/test/test_applycal.py
index 0827936b..a80c43e4 100644
--- a/katdal/test/test_applycal.py
+++ b/katdal/test/test_applycal.py
@@ -26,13 +26,13 @@
 from katdal.spectral_window import SpectralWindow
 from katdal.sensordata import SensorCache
 from katdal.categorical import ComparableArrayWrapper, CategoricalData
-from katdal.lazy_indexer import DaskLazyIndexer
-from katdal.applycal import (complex_interp, calc_correction_per_corrprod,
+from katdal.applycal import (complex_interp,
                              has_cal_product, get_cal_product, INVALID_GAIN,
                              calc_delay_correction, calc_bandpass_correction,
                              calc_gain_correction, apply_vis_correction,
                              apply_weights_correction, apply_flags_correction,
-                             add_applycal_sensors, add_applycal_transform)
+                             add_applycal_sensors, calc_correction)
+from katdal.flags import POSTPROC
 
 
 POLS = ['v', 'h']
@@ -381,89 +381,59 @@ def test_unknown_inputs_and_products(self):
         self.cache.get(known_input + '_correction_K_unknown')
 
 
-class TestCorrectionPerCorrprod(object):
-    """Test :func:`~katdal.applycal.calc_correction_per_corrprod` function."""
+class TestCalcCorrection(object):
+    """Test :func:`~katdal.applycal.calc_correction` function."""
     def setup(self):
         self.cache = create_sensor_cache()
         add_applycal_sensors(self.cache, ATTRS, FREQS)
 
-    def test_correction_per_corrprod(self):
+    def test_calc_correction(self):
         dump = 15
-        channels = list(range(22, 38))
-        corrections = calc_correction_per_corrprod(dump, channels, self.cache,
-                                                   INPUTS, INDEX1, INDEX2,
-                                                   CAL_PRODUCTS)[np.newaxis]
+        channels = np.s_[22:38]
+        shape = (N_DUMPS, N_CHANS, N_CORRPRODS)
+        chunks = da.core.normalize_chunks((10, 5, -1), shape)
+        corrections = calc_correction(chunks, self.cache, CORRPRODS, CAL_PRODUCTS)
+        corrections = corrections[dump:dump+1, channels].compute()
         expected_corrections = corrections_per_corrprod([dump], channels)
         assert_array_equal(corrections, expected_corrections)
 
 
 class TestApplyCal(object):
-    """Test :func:`~katdal.applycal.add_applycal_transform` function."""
+    """Test :func:`~katdal.applycal.apply_vis_correction` and friends"""
     def setup(self):
         self.cache = create_sensor_cache()
         add_applycal_sensors(self.cache, ATTRS, FREQS)
-        time_keep = np.full(N_DUMPS, False, dtype=np.bool_)
-        time_keep[10:20] = True
-        freq_keep = np.full(N_CHANS, False, dtype=np.bool_)
-        freq_keep[22:38] = True
-        corrprod_keep = np.full(N_CORRPRODS, True, dtype=np.bool_)
-        # Throw out one antenna
-        for n, inp in enumerate(INPUTS):
-            if inp.startswith(ANTS[SKIP_ANT]):
-                corrprod_keep[INDEX1 == n] = False
-                corrprod_keep[INDEX2 == n] = False
-        self.stage1 = (time_keep, freq_keep, corrprod_keep)
-        # Apply stage 2 selection on top of stage 1
-        self.stage2 = np.s_[5:7, 2:5, :]
-        # List of selected correlation products
-        self.corrprods = [cp for n, cp in enumerate(CORRPRODS)
-                          if corrprod_keep[n]]
 
     def _applycal(self, array, apply_correction):
         """Calibrate `array` with `apply_correction` and return all factors."""
         array_dask = da.from_array(array, chunks=(10, 4, 6))
-        indexer = DaskLazyIndexer(array_dask, self.stage1)
-        add_applycal_transform(indexer, self.cache, self.corrprods,
-                               CAL_PRODUCTS, apply_correction)
-        calibrated_array = indexer[self.stage2]
-        stage1_indices = tuple(k.nonzero()[0] for k in self.stage1)
-        final_indices = tuple(i[s] for s, i in zip(self.stage2,
-                                                   stage1_indices))
-        # Quick & dirty oindex of array (yet another way doing axes in reverse)
-        selected_array = array
-        dims = reversed(range(array.ndim))
-        for dim, indices in zip(dims, reversed(final_indices)):
-            selected_array = np.take(selected_array, indices, axis=dim)
-        # Determine the corrections that would apply to selection
-        corrections = corrections_per_corrprod(*final_indices)
-        return calibrated_array, selected_array, corrections
+        correction = calc_correction(array_dask.chunks, self.cache, CORRPRODS, CAL_PRODUCTS)
+        corrected = da.core.elemwise(apply_correction, array_dask, correction, dtype=array_dask.dtype)
+        return corrected.compute(), correction.compute()
 
     def test_applycal_vis(self):
         vis_real = np.random.randn(N_DUMPS, N_CHANS, N_CORRPRODS)
         vis_imag = np.random.randn(N_DUMPS, N_CHANS, N_CORRPRODS)
         vis = np.asarray(vis_real + 1j * vis_imag, dtype='complex64')
-        calibrated_vis, expected_vis, corrections = self._applycal(
-            vis, apply_vis_correction)
+        calibrated_vis, corrections = self._applycal(vis, apply_vis_correction)
         # Leave visibilities alone where gains are NaN
         corrections[np.isnan(corrections)] = 1.0
-        expected_vis *= corrections
-        assert_array_equal(calibrated_vis, expected_vis)
+        vis *= corrections
+        assert_array_equal(calibrated_vis, vis)
 
     def test_applycal_weights(self):
         weights = np.random.rand(N_DUMPS, N_CHANS, N_CORRPRODS).astype('float32')
-        calibrated_weights, expected_weights, corrections = self._applycal(
-            weights, apply_weights_correction)
-        # Zero the weights where the gains are non-finite
+        calibrated_weights, corrections = self._applycal(weights, apply_weights_correction)
+        # Zero the weights where the gains are NaN or zero
         corrections2 = corrections.real ** 2 + corrections.imag ** 2
         corrections2[np.isnan(corrections2)] = np.inf
         corrections2[corrections2 == 0] = np.inf
-        expected_weights /= corrections2
-        assert_array_equal(calibrated_weights, expected_weights)
+        weights /= corrections2
+        assert_array_equal(calibrated_weights, weights)
 
     def test_applycal_flags(self):
-        flags = np.random.rand(N_DUMPS, N_CHANS, N_CORRPRODS) > 0.5
-        calibrated_flags, expected_flags, corrections = self._applycal(
-            flags, apply_flags_correction)
-        expected_flags[np.isnan(corrections)] = True
-        assert_array_equal(calibrated_flags, expected_flags)
+        flags = np.random.randint(0, 128, (N_DUMPS, N_CHANS, N_CORRPRODS), np.uint8)
+        calibrated_flags, corrections = self._applycal(flags, apply_flags_correction)
+        flags |= np.where(np.isnan(corrections), np.uint8(POSTPROC), np.uint8(0))
+        assert_array_equal(calibrated_flags, flags)
diff --git a/katdal/test/test_datasources.py b/katdal/test/test_datasources.py
index 07a8d2f3..c9f26c65 100644
--- a/katdal/test/test_datasources.py
+++ b/katdal/test/test_datasources.py
@@ -33,9 +33,7 @@
 from katdal.chunkstore import generate_chunks
 from katdal.chunkstore_npy import NpyFileChunkStore
 from katdal.datasources import ChunkStoreVisFlagsWeights, TelstateDataSource, view_l0_capture_stream
-
-
-DATA_LOST = 8      # TODO: introduce katdal.flags module for these
+from katdal.flags import DATA_LOST
 
 
 def ramp(shape, offset=1.0, slope=1.0, dtype=np.float_):
diff --git a/katdal/visdatav4.py b/katdal/visdatav4.py
index 6a8ccf6e..d6b2b5bc 100644
--- a/katdal/visdatav4.py
+++ b/katdal/visdatav4.py
@@ -27,13 +27,15 @@
 from .dataset import (DataSet, BrokenFile, Subarray,
                       DEFAULT_SENSOR_PROPS, DEFAULT_VIRTUAL_SENSORS,
                       _robust_target, _selection_to_list)
+from .datasources import VisFlagsWeights
 from .spectral_window import SpectralWindow
 from .sensordata import SensorCache
 from .categorical import CategoricalData
 from .lazy_indexer import DaskLazyIndexer
-from .applycal import (add_applycal_sensors, add_applycal_transform,
+from .applycal import (add_applycal_sensors, calc_correction,
                        apply_vis_correction, apply_weights_correction,
                        apply_flags_correction, has_cal_product, CAL_PRODUCTS)
+from .flags import NAMES as FLAG_NAMES, DESCRIPTIONS as FLAG_DESCRIPTIONS
 
 logger = logging.getLogger(__name__)
 
@@ -80,17 +82,6 @@ def _add_sensor_alias(cache, new_name, old_name):
 VIRTUAL_SENSORS = dict(DEFAULT_VIRTUAL_SENSORS)
 VIRTUAL_SENSORS.update({'Antennas/{ant}/az': _calc_azel, 'Antennas/{ant}/el': _calc_azel})
 
-FLAG_NAMES = ('reserved0', 'static', 'cam', 'data_lost',
-              'ingest_rfi', 'predicted_rfi', 'cal_rfi', 'reserved7')
-FLAG_DESCRIPTIONS = ('reserved - bit 0',
-                     'predefined static flag list',
-                     'flag based on live CAM information',
-                     'no data was received',
-                     'RFI detected in ingest',
-                     'RFI predicted from space based pollutants',
-                     'RFI detected in calibration',
-                     'reserved - bit 7')
-
 # -----------------------------------------------------------------------------
 # -- CLASS : VisibilityDataV4
 # -----------------------------------------------------------------------------
@@ -343,11 +334,32 @@
         available_products = [product for product in CAL_PRODUCTS
                               if has_cal_product(self.sensor, attrs, product)]
         self._applycal = _selection_to_list(applycal, all=available_products)
+        if not self.source.data or not self._applycal:
+            self._corrections = None
+            self._corrected = self.source.data
+        else:
+            self._corrections = calc_correction(self.source.data.vis.chunks, self.sensor,
+                                                self.subarrays[self.subarray].corr_products,
+                                                self._applycal)
+            corrected_vis = self._make_corrected(apply_vis_correction, self.source.data.vis)
+            corrected_flags = self._make_corrected(apply_flags_correction, self.source.data.flags)
+            corrected_weights = self._make_corrected(apply_weights_correction, self.source.data.weights)
+            name = self.source.data.name
+            # Acknowledge that the applycal step is making the L1 product
+            if 'sdp_l0' in name:
+                name = name.replace('sdp_l0', 'sdp_l1')
+            else:
+                name = name + ' (corrected)'
+            self._corrected = VisFlagsWeights(corrected_vis, corrected_flags, corrected_weights,
+                                              name=name)
 
         # Apply default selection and initialise all members that depend
         # on selection in the process
         self.select(spw=0, subarray=0, ants=obs_ants)
 
+    def _make_corrected(self, apply_correction, data):
+        return da.core.elemwise(apply_correction, data, self._corrections, dtype=data.dtype)
+
     @property
     def _flags_keep(self):
         # Reverse flag indices as np.packbits has bit 0 as the MSB (we want LSB)
@@ -411,26 +423,15 @@ def _set_keep(self, time_keep=None, freq_keep=None, corrprod_keep=None,
         stage1 = (self._time_keep, self._freq_keep, self._corrprod_keep)
         if update_all:
             # Cache dask graphs for the data fields
-            self._vis = DaskLazyIndexer(self.source.data.vis, stage1)
-            self._weights = DaskLazyIndexer(self.source.data.weights, stage1)
+            self._vis = DaskLazyIndexer(self._corrected.vis, stage1)
+            self._weights = DaskLazyIndexer(self._corrected.weights, stage1)
             flag_transforms = []
             if ~self._flags_select != 0:
                 # Copy so that the lambda isn't affected by future changes
                 select = self._flags_select.copy()
                 flag_transforms.append(lambda flags: da.bitwise_and(select, flags))
             flag_transforms.append(lambda flags: flags.view(np.bool_))
-            self._flags = DaskLazyIndexer(self.source.data.flags, stage1, flag_transforms)
-            if self._applycal:
-                add_applycal_transform(self._vis, self.sensor,
-                                       self.corr_products, self._applycal,
-                                       apply_vis_correction)
-                add_applycal_transform(self._weights, self.sensor,
-                                       self.corr_products, self._applycal,
-                                       apply_weights_correction)
-                add_applycal_transform(self._flags, self.sensor,
-                                       self.corr_products, self._applycal,
-                                       apply_flags_correction)
+            self._flags = DaskLazyIndexer(self._corrected.flags, stage1, flag_transforms)
 
     @property
     def timestamps(self):
diff --git a/scripts/mvf_read_benchmark.py b/scripts/mvf_read_benchmark.py
index 6f11efbc..b1f57cd6 100755
--- a/scripts/mvf_read_benchmark.py
+++ b/scripts/mvf_read_benchmark.py
@@ -6,9 +6,11 @@
 import logging
 import time
 
+import dask
+import numpy as np
+
 import katdal
 from katdal.lazy_indexer import DaskLazyIndexer
-import numpy as np
 
 
 parser = argparse.ArgumentParser()
@@ -18,9 +20,12 @@
 parser.add_argument('--dumps', type=int, help='Number of dumps to read')
 parser.add_argument('--joint', action='store_true', help='Load vis, weights, flags together')
 parser.add_argument('--applycal', help='Calibration solutions to apply')
+parser.add_argument('--workers', type=int, help='Number of dask workers')
 args = parser.parse_args()
 
 logging.basicConfig(level='INFO', format='%(asctime)s [%(levelname)s] %(message)s')
+if args.workers is not None:
+    dask.config.set(num_workers=args.workers)
 logging.info('Starting')
 kwargs = {}
 if args.applycal is not None:
@@ -31,6 +36,9 @@
     f.select(channels=np.s_[:args.channels])
 if args.dumps:
     f.select(dumps=np.s_[:args.dumps])
+# Trigger creation of the dask graphs, population of sensor cache for applycal etc.
+_ = (f.vis[0, 0, 0], f.weights[0, 0, 0], f.flags[0, 0, 0])
+logging.info('Selection complete')
 start = time.time()
 for st in range(0, f.shape[0], args.time):
     et = st + args.time
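
Taken together, these changes move applycal from per-chunk indexer transforms to corrections that are baked into the dask graph up front. A minimal end-to-end sketch (the filename is a placeholder; per `_selection_to_list` above, 'all' selects every available product):

    import numpy as np
    import katdal

    f = katdal.open('obs.rdb', applycal='all')
    f.select(dumps=np.s_[:2])
    vis = f.vis[:]    # corrections are already part of the dask graph
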