Feature/skmeans #99
@@ -0,0 +1,121 @@
/*
Part of the Fluid Corpus Manipulation Project (http://www.flucoma.org/)
Copyright 2017-2019 University of Huddersfield.
Licensed under the BSD-3 License.
See license.md file in the project root for full license information.

This project has received funding from the European Research Council (ERC)
under the European Union’s Horizon 2020 research and innovation programme
(grant agreement No 725899).
*/

#pragma once

#include "../public/KMeans.hpp"
#include "../util/FluidEigenMappings.hpp"
#include "../../data/FluidDataSet.hpp"
#include "../../data/FluidIndex.hpp"
#include "../../data/FluidTensor.hpp"
#include "../../data/TensorTypes.hpp"
#include <Eigen/Core>
#include <queue>
#include <string>

namespace fluid {
namespace algorithm {

class SKMeans : public KMeans
{

public:
  // Run spherical k-means: assign each point to the cluster whose mean
  // gives the largest dot product, then recompute and re-normalise the means
  void train(const FluidDataSet<std::string, double, 1>& dataset, index k,
             index maxIter)
  {
    using namespace Eigen;
    using namespace _impl;
    assert(!mTrained || (dataset.pointSize() == mDims && mK == k));
    MatrixXd dataPoints = asEigen<Matrix>(dataset.getData());
    MatrixXd dataPointsT = dataPoints.transpose();
    if (mTrained)
    {
      // warm start: assign against the embedding of the existing means
      mEmbedding = mMeans.matrix() * dataPointsT;
      mAssignments = assignClusters(mEmbedding);
    }
    else
    {
      mK = k;
      mDims = dataset.pointSize();
      initMeans(dataPoints);
    }

    while (maxIter-- > 0)
    {
      mEmbedding = mMeans.matrix() * dataPointsT;
      auto assignments = assignClusters(mEmbedding);
      if (!changed(assignments)) { break; }
      mAssignments = assignments;
      updateEmbedding();
      computeMeans(dataPoints);
    }
    mTrained = true;
  }

  // Encode input points against the trained means: similarity minus alpha,
  // with negative values clamped to zero (Coates & Ng's soft threshold
  // encoding, discussed below)
  void transform(RealMatrixView data, RealMatrixView out,
                 double alpha = 0.25) const
  {
    using namespace Eigen;
    MatrixXd points = _impl::asEigen<Matrix>(data).transpose();
    MatrixXd embedding = (mMeans.matrix() * points).array() - alpha;

Review comment: Are we sure about baking in the encoding scheme from Coates and Ng here? I guess the argument in favour is that it enables recreating their feature learning scheme with the fewest objects. The arguments against would be that it's not strictly part of skmeans, but was a separate step used by C&N specifically in the feature learning setting, and they do discuss alternatives. Obviously having it here doesn't preclude using an alternative scheme, because

Reply: The soft thresholding function is similar to neural network activation functions, so NNFuncs is where it would fit best, but I don't think these functions deserve their own client, so in practice I would still see this as part of the SKMeans client. So it would not help with code duplication. An interesting idea (maybe for the future?) could be to have a feature learning client that could use different learning techniques and encodings. For the moment, maybe it can be introduced as an option for skmeans. We could also use an MLP as an autoencoder for feature learning.

Reply: Ok, let's roll with what we have and see how we get on. I like the idea of some future feature learning object that could make it easy to explore options and manage some of the complexity / fiddliness.
    embedding = (embedding.array() > 0).select(embedding, 0).transpose();
    out = _impl::asFluid(embedding);

Review comment (suggested change): New copy assignment operator for

  }
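
For reference, a minimal standalone sketch of the soft threshold step discussed above, written as an activation-style free function in the spirit of the NNFuncs suggestion. It is illustrative only; the function name and free-function form are assumptions, not part of this PR:

#include <Eigen/Core>

// max(0, z - alpha): shift the cluster similarities down by alpha and
// clamp negative values to zero, as in Coates & Ng's feature encoding
inline Eigen::MatrixXd softThreshold(const Eigen::MatrixXd& embedding,
                                     double alpha)
{
  return (embedding.array() - alpha).max(0.0).matrix();
}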

private:
  // Initialise with uniform random cluster assignments, then compute the
  // corresponding normalised means
  void initMeans(Eigen::MatrixXd& dataPoints)
  {
    using namespace Eigen;
    mMeans = ArrayXXd::Zero(mK, mDims);
    // ArrayXd::Random draws from [-1, 1]; rescale to [0, mK - 1] and round
    mAssignments =
        ((0.5 + (0.5 * ArrayXd::Random(dataPoints.rows()))) * (mK - 1))
            .round()
            .cast<int>();
    // one-hot indicator matrix: row = cluster, column = point
    mEmbedding = MatrixXd::Zero(mK, dataPoints.rows());
    for (index i = 0; i < dataPoints.rows(); i++)
      mEmbedding(mAssignments(i), i) = 1;
    computeMeans(dataPoints);
  }

  // Keep, for each point, only the embedding entry of its assigned cluster
  void updateEmbedding()
  {
    for (index i = 0; i < mAssignments.size(); i++)
    {
      double val = mEmbedding(mAssignments(i), i);
      mEmbedding.col(i).setZero();
      mEmbedding(mAssignments(i), i) = val;
    }
  }

  // For each point (column), pick the row (cluster) with the largest value
  Eigen::VectorXi assignClusters(Eigen::MatrixXd& embedding) const
  {
    Eigen::VectorXi assignments = Eigen::VectorXi::Zero(embedding.cols());
    for (index i = 0; i < embedding.cols(); i++)
    {
      Eigen::VectorXd::Index maxIndex;
      embedding.col(i).maxCoeff(&maxIndex);
      assignments(i) = static_cast<int>(maxIndex);
    }
    return assignments;
  }

  // New means are sums of the points weighted by the (masked) embedding,
  // projected back onto the unit sphere
  void computeMeans(Eigen::MatrixXd& dataPoints)
  {
    mMeans = mEmbedding * dataPoints;
    mMeans.matrix().rowwise().normalize();
  }

private:
  Eigen::MatrixXd mEmbedding;
};
} // namespace algorithm
} // namespace fluid
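
A hypothetical usage sketch of the class as it stands in this PR (a populated FluidDataSet named dataset is assumed, as is passing FluidTensor objects where views are expected; error handling elided):

// train a 16-cluster model, then soft-threshold-encode the same points
fluid::algorithm::SKMeans model;
model.train(dataset, 16, 100);   // k = 16, at most 100 iterations
fluid::RealMatrix encoded(dataset.size(), 16);
model.transform(dataset.getData(), encoded, 0.25);  // alpha = 0.25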
@@ -2,6 +2,7 @@

 #include <algorithms/public/KDTree.hpp>
 #include <algorithms/public/KMeans.hpp>
+#include <algorithms/public/SKMeans.hpp>
 #include <algorithms/public/Normalization.hpp>
 #include <algorithms/public/RobustScaling.hpp>
 #include <algorithms/public/PCA.hpp>

@@ -186,6 +187,31 @@ void from_json(const nlohmann::json &j, KMeans &kmeans) {
   kmeans.setMeans(means);
 }

+// SKMeans
+void to_json(nlohmann::json &j, const SKMeans &skmeans) {
+  RealMatrix means(skmeans.getK(), skmeans.dims());
+  skmeans.getMeans(means);
+  j["means"] = RealMatrixView(means);
+  j["rows"] = means.rows();
+  j["cols"] = means.cols();
+}
+
+bool check_json(const nlohmann::json &j, const SKMeans &) {
+  return fluid::check_json(j,
+    {"rows", "cols", "means"},
+    {JSONTypes::NUMBER, JSONTypes::NUMBER, JSONTypes::ARRAY}
+  );
+}
+
+void from_json(const nlohmann::json &j, SKMeans &skmeans) {
+  index rows = j.at("rows");
+  index cols = j.at("cols");

Review comment (suggested change): The implicit conversion in
+  RealMatrix means(rows, cols);
+  j.at("means").get_to(means);
+  skmeans.setMeans(means);
+}
+
+
 // Normalize
 void to_json(nlohmann::json &j, const Normalization &normalization) {
   RealVector dataMin(normalization.dims());

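For clarity, a hypothetical round-trip sketch using the SKMeans serializers added above (assumes a trained model named skmeans; check_json guards from_json just as the existing loaders do):

// serialise a trained model, validate the JSON, then restore it
nlohmann::json j;
to_json(j, skmeans);          // writes "means", "rows", "cols"
SKMeans restored;
if (check_json(j, restored))  // required keys and JSON types present?
  from_json(j, restored);
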
@@ -296,20 +322,21 @@ void to_json(nlohmann::json &j, const PCA &pca) {
   index cols = pca.size();
   RealMatrix bases(rows, cols);
-  RealVector values(cols);
+  RealVector explainedVariance(cols);
   RealVector mean(rows);
   pca.getBases(bases);
-  pca.getValues(values);
+  pca.getExplainedVariance(explainedVariance);
   pca.getMean(mean);
   j["bases"] = RealMatrixView(bases);
-  j["values"] = RealVectorView(values);
+  j["explainedvariance"] = RealVectorView(explainedVariance);
   j["mean"] = RealVectorView(mean);
   j["rows"] = rows;
   j["cols"] = cols;
 }

 bool check_json(const nlohmann::json &j, const PCA &) {
   return fluid::check_json(j,
-    {"rows","cols", "bases", "values", "mean"},
+    {"rows","cols", "bases", "explainedvariance", "mean"},
     {JSONTypes::NUMBER, JSONTypes::NUMBER, JSONTypes::ARRAY, JSONTypes::ARRAY, JSONTypes::ARRAY}
   );
 }

@@ -319,11 +346,11 @@ void from_json(const nlohmann::json &j, PCA &pca) {
   index cols = j.at("cols");
   RealMatrix bases(rows, cols);
   RealVector mean(rows);
-  RealVector values(cols);
+  RealVector explainedVariance(cols);
   j.at("mean").get_to(mean);
-  j.at("values").get_to(values);
+  j.at("explainedvariance").get_to(explainedVariance);
   j.at("bases").get_to(bases);
-  pca.init(bases, values, mean);
+  pca.init(bases, explainedVariance, mean);
 }