Skip to content

Commit

Permalink
Add support for CI8 data going into lsl.correlator._core.XEngine2/3.
Browse files Browse the repository at this point in the history
  • Loading branch information
jaycedowell committed Jan 30, 2025
1 parent 89ac781 commit 04d4561
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 12 deletions.
17 changes: 17 additions & 0 deletions lsl/correlator/blas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,20 @@ inline void blas_dotc_sub(const int N,
* cblas_zdotc_sub(N, X, incX, Y, incY, dotc);
*/
}

// blas_dotc_sub but for 8+8-bit complex integers, aka CI8
template<typename OType>
inline void blas_dotc_sub(const int N,
const int8_t* X, const int incX,
const int8_t* Y, const int incY,
OType* dotc) {
OType PX, PY, accum = 0.0;
for(int i=0; i<N; i++) {
PX = OType(*X, *(X+1));
PY = OType(*Y, *(Y+1));
accum += conj(PX) * PY;
X += 2*incX;
Y += 2*incY;
}
*dotc = accum;
}
28 changes: 17 additions & 11 deletions lsl/correlator/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,8 @@ void compute_xengine_two(long nStand,
// Setup
Py_BEGIN_ALLOW_THREADS

constexpr int scale = std::is_same<InType, int8_t>::value ? 2 : 1;

// Mapper for baseline number to stand 1, stand 2
long s1, s2, mapper[nBL][2];
long k = 0;
Expand Down Expand Up @@ -790,7 +792,7 @@ void compute_xengine_two(long nStand,
}

for(c=0; c<nChan; c++) {
blas_dotc_sub(nFFT, (data2 + mapper[bl][1]*nChan*nFFT + c*nFFT), 1, (data1 + mapper[bl][0]*nChan*nFFT + c*nFFT), 1, &tempVis);
blas_dotc_sub(nFFT, (data2 + scale*mapper[bl][1]*nChan*nFFT + scale*c*nFFT), 1, (data1 + scale*mapper[bl][0]*nChan*nFFT + scale*c*nFFT), 1, &tempVis);
*(dataA + bl*nChan + c) = tempVis / (float) nActVis;
}
}
Expand All @@ -814,6 +816,8 @@ void compute_xengine_three(long nStand,
// Setup
Py_BEGIN_ALLOW_THREADS

constexpr int scale = std::is_same<InType, int8_t>::value ? 2 : 1;

// Mapper for baseline number to stand 1, stand 2
long s1, s2, mapper[nBL][2];
long k = 0;
Expand Down Expand Up @@ -852,19 +856,19 @@ void compute_xengine_three(long nStand,

for(c=0; c<nChan; c++) {
// XX
blas_dotc_sub(nFFT, (dataX + mapper[bl][1]*nChan*nFFT + c*nFFT), 1, (dataX + mapper[bl][0]*nChan*nFFT + c*nFFT), 1, &tempVis);
blas_dotc_sub(nFFT, (dataX + scale*mapper[bl][1]*nChan*nFFT + scale*c*nFFT), 1, (dataX + scale*mapper[bl][0]*nChan*nFFT + scale*c*nFFT), 1, &tempVis);
*(dataA + 0*nBL*nChan + bl*nChan + c) = tempVis / (float) nActVisPureX;

// XY
blas_dotc_sub(nFFT, (dataY + mapper[bl][1]*nChan*nFFT + c*nFFT), 1, (dataX + mapper[bl][0]*nChan*nFFT + c*nFFT), 1, &tempVis);
blas_dotc_sub(nFFT, (dataY + scale*mapper[bl][1]*nChan*nFFT + scale*c*nFFT), 1, (dataX + scale*mapper[bl][0]*nChan*nFFT + scale*c*nFFT), 1, &tempVis);
*(dataA + 1*nBL*nChan + bl*nChan + c) = tempVis / (float) nActVisCross0;

// YX
blas_dotc_sub(nFFT, (dataX + mapper[bl][1]*nChan*nFFT + c*nFFT), 1, (dataY + mapper[bl][0]*nChan*nFFT + c*nFFT), 1, &tempVis);
blas_dotc_sub(nFFT, (dataX + scale*mapper[bl][1]*nChan*nFFT + scale*c*nFFT), 1, (dataY + scale*mapper[bl][0]*nChan*nFFT + scale*c*nFFT), 1, &tempVis);
*(dataA + 2*nBL*nChan + bl*nChan + c) = tempVis / (float) nActVisCross1;

// YY
blas_dotc_sub(nFFT, (dataY + mapper[bl][1]*nChan*nFFT + c*nFFT), 1, (dataY + mapper[bl][0]*nChan*nFFT + c*nFFT), 1, &tempVis);
blas_dotc_sub(nFFT, (dataY + scale*mapper[bl][1]*nChan*nFFT + scale*c*nFFT), 1, (dataY + scale*mapper[bl][0]*nChan*nFFT + scale*c*nFFT), 1, &tempVis);
*(dataA + 3*nBL*nChan + bl*nChan + c) = tempVis / (float) nActVisPureY;
}
}
Expand All @@ -888,10 +892,10 @@ static PyObject *XEngine2(PyObject *self, PyObject *args) {
// Bring the data into C and make it usable
data1 = (PyArrayObject *) PyArray_ContiguousFromObject(signals1,
PyArray_TYPE((PyArrayObject *) signals1),
3, 3);
3, 4);
data2 = (PyArrayObject *) PyArray_ContiguousFromObject(signals2,
PyArray_TYPE((PyArrayObject *) signals1),
3, 3);
3, 4);
valid1 = (PyArrayObject *) PyArray_ContiguousFromObject(sigValid1, NPY_UINT8, 2, 2);
valid2 = (PyArrayObject *) PyArray_ContiguousFromObject(sigValid2, NPY_UINT8, 2, 2);
if( data1 == NULL ) {
Expand Down Expand Up @@ -935,7 +939,8 @@ static PyObject *XEngine2(PyObject *self, PyObject *args) {
(unsigned char *) PyArray_DATA(valid2), \
(Complex32 *) PyArray_DATA(vis))

switch( PyArray_TYPE(data1) ){
switch( PyArray_LSL_TYPE(data1, 3) ){
case( LSL_CI8 ): LAUNCH_XENGINE_TWO(int8_t); break;
case( NPY_COMPLEX64 ): LAUNCH_XENGINE_TWO(Complex32); break;
case( NPY_COMPLEX128 ): LAUNCH_XENGINE_TWO(Complex64); break;
default: PyErr_Format(PyExc_RuntimeError, "Unsupport input data type"); goto fail;
Expand Down Expand Up @@ -999,10 +1004,10 @@ static PyObject *XEngine3(PyObject *self, PyObject *args) {
// Bring the data into C and make it usable
dataX = (PyArrayObject *) PyArray_ContiguousFromObject(signalsX,
PyArray_TYPE((PyArrayObject *) signalsX),
3, 3);
3, 4);
dataY = (PyArrayObject *) PyArray_ContiguousFromObject(signalsY,
PyArray_TYPE((PyArrayObject *) signalsX),
3, 3);
3, 4);
validX = (PyArrayObject *) PyArray_ContiguousFromObject(sigValidX, NPY_UINT8, 2, 2);
validY = (PyArrayObject *) PyArray_ContiguousFromObject(sigValidY, NPY_UINT8, 2, 2);
if( dataX == NULL ) {
Expand Down Expand Up @@ -1047,7 +1052,8 @@ static PyObject *XEngine3(PyObject *self, PyObject *args) {
(unsigned char *) PyArray_DATA(validY), \
(Complex32 *) PyArray_DATA(vis))

switch( PyArray_TYPE(dataX) ){
switch( PyArray_LSL_TYPE(dataX, 3) ){
case( LSL_CI8 ): LAUNCH_XENGINE_THREE(int8_t); break;
case( NPY_COMPLEX64 ): LAUNCH_XENGINE_THREE(Complex32); break;
case( NPY_COMPLEX128 ): LAUNCH_XENGINE_THREE(Complex64); break;
default: PyErr_Format(PyExc_RuntimeError, "Unsupport input data type"); goto fail;
Expand Down
48 changes: 47 additions & 1 deletion tests/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from lsl.common.data_access import DataAccess
from lsl.common import stations
from lsl.correlator import fx
from lsl.correlator import fx, _core
from lsl.reader.base import CI8
import lsl.testing

Expand Down Expand Up @@ -1009,6 +1009,51 @@ def wndw2(L):
self.run_correlator_test_complex(dtype, nchan=259, window=wndw2)


class _XEngine_tests(unittest.TestCase):
"""A unittest.TestCase collection of unit tests for the lsl.correlator._core.XEngine2 and 3 functions on CI8 data."""

nAnt = 8

def setUp(self):
"""Turn off all numpy and python warnings."""

np.seterr(all='ignore')
warnings.simplefilter('ignore')
np.random.seed(1234)

def test_xengine_2(self, nchan=256):
"""Test that _core.XEngine2 works with CI8 data"""

fakeData = 10.0*np.random.rand(self.nAnt,nchan,16) + 3.0
fakeData = np.round(fakeData).astype(np.float32)
fakeDataCI8 = fakeData.astype(np.int8).view(CI8)
fakeDataCF32 = fakeData.view(np.complex64)
fakeValid = np.ones((self.nAnt,8), dtype=np.uint8)

if fakeDataCI8.dtype == CI8:
fakeDataCI8 = fakeDataCI8.view(np.int8)
fakeDataCI8 = fakeDataCI8.reshape(fakeDataCI8.shape[:-1]+(-1,2))
cps = _core.XEngine2(fakeDataCI8, fakeDataCI8, fakeValid, fakeValid)
cps2 = _core.XEngine2(fakeDataCF32, fakeDataCF32, fakeValid, fakeValid)
lsl.testing.assert_allclose(cps, cps2)

def test_xengine_3(self, nchan=256):
"""Test that _core.XEngine3 works with CI8 data"""

fakeData = 10.0*np.random.rand(self.nAnt,nchan,16) + 3.0
fakeData = np.round(fakeData).astype(np.float32)
fakeDataCI8 = fakeData.astype(np.int8).view(CI8)
fakeDataCF32 = fakeData.view(np.complex64)
fakeValid = np.ones((self.nAnt,8), dtype=np.uint8)

if fakeDataCI8.dtype == CI8:
fakeDataCI8 = fakeDataCI8.view(np.int8)
fakeDataCI8 = fakeDataCI8.reshape(fakeDataCI8.shape[:-1]+(-1,2))
cps = _core.XEngine3(fakeDataCI8, fakeDataCI8, fakeValid, fakeValid)
cps2 = _core.XEngine3(fakeDataCF32, fakeDataCF32, fakeValid, fakeValid)
lsl.testing.assert_allclose(cps, cps2)


class fx_test_suite(unittest.TestSuite):
"""A unittest.TestSuite class which contains all of the lsl.correlator.fx
units tests."""
Expand All @@ -1021,6 +1066,7 @@ def __init__(self):
self.addTests(loader.loadTestsFromTestCase(StokesMaster_tests))
self.addTests(loader.loadTestsFromTestCase(FXMaster_tests))
self.addTests(loader.loadTestsFromTestCase(FXStokes_tests))
self.addTests(loader.loadTestsFromTestCase(_XEngine_tests))


if __name__ == '__main__':
Expand Down

0 comments on commit 04d4561

Please sign in to comment.