Skip to content

Commit

Permalink
Improve convergence of the Newton-based search for beta.
Browse files Browse the repository at this point in the history
We improve convergence by setting the initial beta based on the distance to the
furthest neighbor. This adjusts the initial guess to match the scale of the
distances, reducing the number of iterations to get to the right ballpark.

We also make sure to update the binary search bounds, even when Newton's method
is working properly. This ensures that, if we ever need to fall back to a binary
search, we can leverage previous Newton successes to narrow the interval.
  • Loading branch information
LTLA committed Sep 4, 2024
1 parent 95d4d20 commit c62b953
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 104 deletions.
174 changes: 107 additions & 67 deletions include/qdtsne/gaussian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@ namespace qdtsne {

namespace internal {

template<typename Index_, typename Float_>
void compute_gaussian_perplexity(NeighborList<Index_, Float_>& neighbors, Float_ perplexity, [[maybe_unused]] int nthreads) {
constexpr Float_ max_value = std::numeric_limits<Float_>::max();
constexpr Float_ tol = 1e-5;

const size_t num_points = neighbors.size();
// NOTE: the QDTSNE_R_PACKAGE_TESTING macro recapitulates the gaussian kernel
// of the Rtsne package so that we can get a more precise comparison to a
// trusted reference implementation. It should not be defined in production.

template<bool use_newton_ =
#ifndef QDTSNE_R_PACKAGE_TESTING
true
#else
false
#endif
, typename Index_, typename Float_>
void compute_gaussian_perplexity(NeighborList<Index_, Float_>& neighbors, Float_ perplexity, int nthreads) {
const size_t num_points = neighbors.size();
const Float_ log_perplexity = std::log(perplexity);

parallelize(nthreads, num_points, [&](int, size_t start, size_t length) -> void {
Expand All @@ -29,95 +36,128 @@ void compute_gaussian_perplexity(NeighborList<Index_, Float_>& neighbors, Float_
for (size_t n = start, end = start + length; n < end; ++n) {
auto& current = neighbors[n];
const int K = current.size();
if (K) {
squared_delta_dist.resize(K);
quad_delta_dist.resize(K);

// We adjust the probabilities by subtracting the first squared
// distance from everything. This avoids problems with underflow
// when converting distances to probabilities; it otherwise has no
// effect on the entropy or even the final probabilities because it
// just scales all probabilities up/down (and they need to be
// normalized anyway, so any scaling effect just cancels out).
const Float_ first = current[0].second;
const Float_ first2 = first * first;
if (K == 0) {
continue;
}

for (int m = 1; m < K; ++m) {
Float_ dist = current[m].second;
Float_ squared_delta_dist_raw = dist * dist - first2;
squared_delta_dist[m] = squared_delta_dist_raw;
quad_delta_dist[m] = squared_delta_dist_raw * squared_delta_dist_raw;
squared_delta_dist.resize(K);
quad_delta_dist.resize(K);

// We adjust the probabilities by subtracting the first squared
// distance from everything. This avoids problems with underflow
// when converting distances to probabilities; it otherwise has no
// effect on the entropy or even the final probabilities because it
// just scales all probabilities up/down (and they need to be
// normalized anyway, so any scaling effect just cancels out).
const Float_ first = current[0].second;
const Float_ first2 = first * first;

for (int m = 1; m < K; ++m) {
Float_ dist = current[m].second;
Float_ squared_delta_dist_raw = dist * dist - first2;
squared_delta_dist[m] = squared_delta_dist_raw;
quad_delta_dist[m] = squared_delta_dist_raw * squared_delta_dist_raw;
}

auto last_squared_delta = squared_delta_dist.back();
if (last_squared_delta == 0) {
    // The last neighbor's squared distance equals the first's, so we quit
    // early for this observation: the entropy doesn't depend on beta and
    // every neighbor simply gets the same probability.
    for (auto& x : current) {
        x.second = 1.0 / K;
    }

    // Move on to the next observation. A 'return' here would exit the
    // whole per-thread job and leave every remaining observation in
    // [start, start + length) holding raw distances instead of
    // probabilities.
    continue;
}

Float_ beta = 1.0;
Float_ min_beta = 0, max_beta = max_value;
Float_ sum_P = 0;
output.resize(K);
output[0] = 1;
// Choosing an initial beta that matches the scale of the (squared) distances.
// The choice of numerator is largely based on trial and error to see what
// minimizes the number of iterations in some simulated data.
Float_ beta =
#ifndef QDTSNE_R_PACKAGE_TESTING
3.0 / last_squared_delta
#else
1
#endif
;

constexpr Float_ max_value = std::numeric_limits<Float_>::max();
Float_ min_beta = 0, max_beta = max_value;
Float_ sum_P = 0;
output.resize(K);
output[0] = 1;

constexpr int max_iter = 200;
for (int iter = 0; iter < max_iter; ++iter) {
// We skip the first value because we know that squared_delta_dist[0] = 0
// (as we subtracted 'first') and thus output[0] = 1. We repeat this for
// all iterations from [1, K), e.g., squared_delta_dist, quad_delta_dist.
for (int m = 1; m < K; ++m) {
output[m] = std::exp(-beta * squared_delta_dist[m]);
}

for (int iter = 0; iter < 200; ++iter) {
// We skip the first value because we know that squared_delta_dist[0] = 0
// (as we subtracted 'first') and thus output[0] = 1.
for (int m = 1; m < K; ++m) {
output[m] = std::exp(-beta * squared_delta_dist[m]);
}
sum_P = std::accumulate(output.begin() + 1, output.end(), static_cast<Float_>(1));
const Float_ prod = std::inner_product(squared_delta_dist.begin() + 1, squared_delta_dist.end(), output.begin() + 1, static_cast<Float_>(0));
const Float_ entropy = beta * (prod / sum_P) + std::log(sum_P);

sum_P = std::accumulate(output.begin() + 1, output.end(), static_cast<Float_>(1));
const Float_ prod = std::inner_product(squared_delta_dist.begin() + 1, squared_delta_dist.end(), output.begin() + 1, static_cast<Float_>(0));
const Float_ entropy = beta * (prod / sum_P) + std::log(sum_P);
const Float_ diff = entropy - log_perplexity;
constexpr Float_ tol = 1e-5;
if (std::abs(diff) < tol) {
break;
}

const Float_ diff = entropy - log_perplexity;
if (std::abs(diff) < tol) {
break;
}
// Refining the search interval for a (potential) binary search
// later. We know that the entropy is monotonic decreasing with
// increasing beta, so if the difference from the target is
// positive, the current beta must be on the left of the root,
// and vice versa if the difference is negative.
if (diff > 0) {
min_beta = beta;
} else {
max_beta = beta;
}

bool nr_ok = false;
if constexpr(use_newton_) {
// Attempt a Newton-Raphson search first. Note to self: derivative was a bit
// painful but pops out nicely enough, use R's D() to prove it to yourself
// in the simple case of K = 2 where d0, d1 are the squared deltas.
// > D(expression(b * (d0 * exp(- b * d0) + d1 * exp(- b * d1)) / (exp(-b*d0) + exp(-b*d1)) + log(exp(-b*d0) + exp(-b*d1))), name="b")
bool nr_ok = false;
#ifndef QDTSNE_BETA_BINARY_SEARCH_ONLY
const Float_ prod2 = std::inner_product(quad_delta_dist.begin() + 1, quad_delta_dist.end(), output.begin() + 1, static_cast<Float_>(0)); // again, skipping first where delta^2 = 0.
const Float_ prod2 = std::inner_product(quad_delta_dist.begin() + 1, quad_delta_dist.end(), output.begin() + 1, static_cast<Float_>(0));
const Float_ d1 = - beta / sum_P * (prod2 - prod * prod / sum_P);

if (d1) {
const Float_ alt_beta = beta - (diff / d1); // if it overflows, we should get Inf or -Inf, so the following comparison should be fine.
if (alt_beta > min_beta && alt_beta < max_beta) {
beta = alt_beta;
nr_ok = true;
}
}
#endif
}

// Otherwise do a binary search.
if (!nr_ok) {
if (diff > 0) {
min_beta = beta;
if (max_beta == max_value) {
beta *= static_cast<Float_>(2);
} else {
beta += (max_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential overflow.
}
if (!nr_ok) {
// Doing the binary search, if Newton's failed or was not requested.
if (diff > 0) {
if (max_beta == max_value) {
beta *= static_cast<Float_>(2);
} else {
max_beta = beta;
beta += (min_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential overflow.
beta += (max_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential overflow.
}
}

if (std::isinf(beta)) {
// Avoid propagation of NaNs via Inf * 0.
for (int m = 1; m < K; ++m) {
output[m] = (squared_delta_dist[m] == 0);
}
break;
} else {
beta += (min_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential overflow.
}
}

// Row-normalize current row of P.
for (int m = 0; m < K; ++m) {
current[m].second = output[m] / sum_P;
if (std::isinf(beta)) {
// Avoid propagation of NaNs via Inf * 0.
for (int m = 1; m < K; ++m) {
output[m] = (squared_delta_dist[m] == 0);
}
break;
}
}

// Row-normalize current row of P.
for (int m = 0; m < K; ++m) {
current[m].second = output[m] / sum_P;
}
}
});

Expand Down
2 changes: 1 addition & 1 deletion tests/R/package/src/test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "Rcpp.h"

#define QDTSNE_BETA_BINARY_SEARCH_ONLY
#define QDTSNE_R_PACKAGE_TESTING
#include "qdtsne/qdtsne.hpp"

// [[Rcpp::export(rng=false)]]
Expand Down
112 changes: 76 additions & 36 deletions tests/src/gaussian.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,23 @@

class GaussianTest : public ::testing::TestWithParam<std::tuple<int, double> > {
protected:
inline static std::unique_ptr<knncolle::Prebuilt<int, int, double> > index;
inline static int N = 200;
inline static int D = 5;
inline static std::vector<double> X;

static void SetUpTestSuite() {
X.resize(N * D);
int D = 5;
std::vector<double> X(N * D);
std::mt19937_64 rng(42);
std::normal_distribution<> dist(0, 1);
for (auto& y : X) {
y = dist(rng);
}
index = knncolle::VptreeBuilder().build_unique(knncolle::SimpleMatrix(D, N, X.data()));
return;
}
};

TEST_P(GaussianTest, Gaussian) {
auto PARAM = GetParam();
size_t K = std::get<0>(PARAM);
double P = std::get<1>(PARAM);

auto index = knncolle::VptreeBuilder().build_unique(knncolle::SimpleMatrix(D, N, X.data()));
qdtsne::NeighborList<int, double> neighbors(N);
{
static qdtsne::NeighborList<int, double> get_neighbors(int K) {
qdtsne::NeighborList<int, double> neighbors(N);
std::vector<int> indices;
std::vector<double> distances;
auto searcher = index->initialize();
Expand All @@ -41,8 +35,16 @@ TEST_P(GaussianTest, Gaussian) {
neighbors[i].emplace_back(indices[k], distances[k]);
}
}
return neighbors;
}
};

TEST_P(GaussianTest, Newton) {
auto PARAM = GetParam();
size_t K = std::get<0>(PARAM);
double P = std::get<1>(PARAM);

auto neighbors = get_neighbors(K);
auto copy = neighbors;
qdtsne::internal::compute_gaussian_perplexity(neighbors, P, 1);
const double expected = std::log(P);
Expand All @@ -56,9 +58,9 @@ TEST_P(GaussianTest, Gaussian) {
sum += x.second;
}
entropy *= -1;
EXPECT_TRUE(std::abs(expected - entropy) < 1e-5);
EXPECT_TRUE(std::abs(sum - 1) < 1e-8);

EXPECT_LT(std::abs(expected - entropy), 1e-5);
EXPECT_LT(std::abs(sum - 1), 1e-8);
}

// Same result in parallel.
Expand All @@ -68,6 +70,34 @@ TEST_P(GaussianTest, Gaussian) {
}
}

TEST_P(GaussianTest, BinaryFallback) {
    // We need to test the binary fallback explicitly because I can't figure
    // out a scenario where Newton's method fails... though I can't prove that
    // it won't, hence the need for a fallback at all. Passing 'false' as the
    // first template argument disables the Newton step so that every
    // iteration goes through the binary search.
    const auto params = GetParam();
    const int K = std::get<0>(params); // number of neighbors per observation.
    const double P = std::get<1>(params); // target perplexity.

    auto neighbors = get_neighbors(K);
    qdtsne::internal::compute_gaussian_perplexity<false>(neighbors, P, 1);
    const double expected = std::log(P);

    // Checking that each observation's probabilities sum to unity and have
    // an entropy that matches the log-perplexity within the search tolerance.
    for (int i = 0; i < N; ++i) {
        double entropy = 0;
        double sum = 0;
        for (const auto& x : neighbors[i]) {
            entropy += x.second * std::log(x.second);
            sum += x.second;
        }
        entropy *= -1;

        EXPECT_LT(std::abs(expected - entropy), 1e-5);
        EXPECT_LT(std::abs(sum - 1), 1e-8);
    }
}

INSTANTIATE_TEST_SUITE_P(
Gaussian,
GaussianTest,
Expand All @@ -77,39 +107,49 @@ INSTANTIATE_TEST_SUITE_P(
)
);

TEST(GaussianTest, Overflow) {
{
qdtsne::NeighborList<int, float> neighbors(1);
auto& first = neighbors.front();
TEST(GaussianTest, Empty) {
    // A list holding a single observation to which no neighbors are added.
    // There is nothing to assert on; the call just needs to complete.
    constexpr float target_perplexity = 30;
    qdtsne::NeighborList<int, float> lonely(1);
    qdtsne::internal::compute_gaussian_perplexity(lonely, target_perplexity, 1);
}

// Lots of ties causes the beta search to overflow.
for (size_t i = 0; i < 90; ++i) {
first.emplace_back(i, 1);
}
TEST(GaussianTest, AllEqualDistances) {
qdtsne::NeighborList<int, float> neighbors(1);
auto& first = neighbors.front();

// Lots of ties causes the beta search to overflow.
for (size_t i = 0; i < 90; ++i) {
first.emplace_back(i, 1);
}

qdtsne::internal::compute_gaussian_perplexity(neighbors, static_cast<float>(30), 1);
qdtsne::internal::compute_gaussian_perplexity(neighbors, static_cast<float>(30), 1);

// Expect finite probabilities.
for (auto& x : neighbors.front()) {
EXPECT_EQ(x.second, neighbors.front().front().second);
}
// Expect finite probabilities.
for (auto& x : neighbors.front()) {
EXPECT_FLOAT_EQ(x.second, 1.0 / 90.0);
}
}

{
TEST(GaussianTest, ConvergenceFailure) {
// Really cranking down the perplexity (and thus forcing the beta
// search to try to get an impossible entropy). We test multiple
// 'leads' to ensure that we handle ties on the first distance.
for (int leads = 1; leads < 10; leads += 5) {
qdtsne::NeighborList<int, float> neighbors(1);
auto& first = neighbors.front();

first.emplace_back(0, 1);
for (size_t i = 1; i < 90; ++i) {
for (int i = 0; i < leads; ++i) {
first.emplace_back(0, 1);
}
for (int i = leads; i < 90; ++i) {
first.emplace_back(i, 1.0000001);
}

// Really cranking down the perplexity (and thus forcing the beta
// search to try to get an unreachable entropy).
qdtsne::internal::compute_gaussian_perplexity(neighbors, static_cast<float>(1), 1);
qdtsne::internal::compute_gaussian_perplexity(neighbors, static_cast<float>(0.5), 1);

EXPECT_TRUE(first.front().second > 0);
for (size_t i = 1; i < 90; ++i) {
for (int i = 0; i < leads; ++i) {
EXPECT_FLOAT_EQ(first[i].second, 1.0/leads);
}
for (int i = leads; i < 90; ++i) {
EXPECT_EQ(first[i].second, 0);
}
}
Expand Down

0 comments on commit c62b953

Please sign in to comment.