Skip to content

Commit

Permalink
add dataset version of inverse PCA (#125)
Browse files Browse the repository at this point in the history
* add dataset version of inverse PCA

* PCA: Add whitening to batch inverse transform
  • Loading branch information
weefuzzy authored May 3, 2022
1 parent 74c3d05 commit 114cb0c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
23 changes: 23 additions & 0 deletions include/algorithms/public/PCA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ class PCA
return variance / total;
}

void inverseProcess(RealMatrixView in, RealMatrixView out, bool whiten = false) const
{
using namespace Eigen;

if (in.cols() > dims()) return;
if (out.cols() < in.cols()) return;

if (!whiten)
_impl::asEigen<Matrix>(out) =
(_impl::asEigen<Matrix>(in) * mBases.transpose()).rowwise() +
mMean.transpose();

else
{
_impl::asEigen<Matrix>(out) =
(_impl::asEigen<Matrix>(in) *
(mExplainedVariance.sqrt().matrix().asDiagonal() *
mBases.transpose()))
.rowwise() +
mMean.transpose();
}
}

bool initialized() const { return mInitialized; }

void getBases(RealMatrixView out) const { out <<= _impl::asFluid(mBases); }
Expand Down
32 changes: 31 additions & 1 deletion include/clients/nrt/PCAClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ class PCAClient : public FluidBaseClient,
return result;
}

MessageResult<void> inverseTransform(InputDataSetClientRef sourceClient,
DataSetClientRef destClient) const
{

auto srcPtr = sourceClient.get().lock();
auto destPtr = destClient.get().lock();

if (srcPtr && destPtr)
{
auto srcDataSet = srcPtr->getDataSet();
if (srcDataSet.size() == 0) return Error<void>(EmptyDataSet);
if (!mAlgorithm.initialized()) return Error<void>(NoDataFitted);
StringVector ids{srcDataSet.getIds()};
RealMatrix paddedInput(srcPtr->size(), mAlgorithm.dims());
auto inputData = srcDataSet.getData();
paddedInput(Slice(0, inputData.rows()), Slice(0, inputData.cols())) <<=
inputData;
RealMatrix output(srcDataSet.size(), mAlgorithm.dims());
mAlgorithm.inverseProcess(paddedInput, output,get<kWhiten>() == 1);
FluidDataSet<string, double, 1> result(ids, output);
destPtr->setDataSet(result);
return {};
}
else
{
return Error<void>(NoDataSet);
}
}

MessageResult<void> transformPoint(InputBufferPtr in, BufferPtr out) const
{
index k = get<kNumDimensions>();
Expand Down Expand Up @@ -150,7 +179,7 @@ class PCAClient : public FluidBaseClient,
Result resizeResult = outBuf.resize(mAlgorithm.dims(), 1, outBuf.sampleRate());

mAlgorithm.inverseProcessFrame(src, dst, get<kWhiten>());
outBuf.samps(0,mAlgorithm.dims(),0)<< = dst;
outBuf.samps(0,mAlgorithm.dims(),0) <<= dst;
return OK();
}

Expand All @@ -160,6 +189,7 @@ class PCAClient : public FluidBaseClient,
makeMessage("fit", &PCAClient::fit),
makeMessage("transform", &PCAClient::transform),
makeMessage("fitTransform", &PCAClient::fitTransform),
makeMessage("inverseTransform",&PCAClient::inverseTransform),
makeMessage("transformPoint", &PCAClient::transformPoint),
makeMessage("inverseTransformPoint", &PCAClient::inverseTransformPoint),
makeMessage("cols", &PCAClient::dims),
Expand Down

0 comments on commit 114cb0c

Please sign in to comment.