Skip to content

Commit

Permalink
Improve convergence of the Newton-based search for beta.
Browse files Browse the repository at this point in the history
We improve convergence by setting the initial beta based on the distance to the
furthest neighbor. This adjusts the initial guess to match the scale of the
distances, reducing the number of iterations to get to the right ballpark.

We also make sure to update the binary search bounds, even when Newton's is
working properly. This ensures that, if we ever need to fall back to a binary
search, we can leverage previous Newton successes to narrow the interval.
  • Loading branch information
LTLA committed Sep 4, 2024
1 parent 95d4d20 commit 17295fc
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 104 deletions.
193 changes: 126 additions & 67 deletions include/qdtsne/gaussian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,111 +13,170 @@ namespace qdtsne {

namespace internal {

template<typename Index_, typename Float_>
void compute_gaussian_perplexity(NeighborList<Index_, Float_>& neighbors, Float_ perplexity, [[maybe_unused]] int nthreads) {
constexpr Float_ max_value = std::numeric_limits<Float_>::max();
constexpr Float_ tol = 1e-5;

const size_t num_points = neighbors.size();
/**
* The aim of this function is to convert distances into probabilities
* using a Gaussian kernel. Given the following equations:
*
* q_i = exp(-beta * (dist_i)^2)
* p_i = q_i / sum(q_i)
* entropy = -sum(p_i * log(p_i))
* = sum(beta * (dist_i)^2 * q_i) / sum(q_i) + log(sum(q_i))
*
* Where the sum is coputed over all neighbors 'i' for each obesrvatino.
* Our aim is to find 'beta' such that:
*
* entropy == target
*
* We using Newton's method with a fallback to a binary search if the former
* doesn't give sensible steps.
*
* NOTE: the QDTSNE_R_PACKAGE_TESTING macro recapitulates the gaussian kernel
* of the Rtsne package so that we can get a more precise comparison to a
* trusted reference implementation. It should not be defined in production.
*/

template<bool use_newton_ =
#ifndef QDTSNE_R_PACKAGE_TESTING
true
#else
false
#endif
, typename Index_, typename Float_>
void compute_gaussian_perplexity(NeighborList<Index_, Float_>& neighbors, Float_ perplexity, int nthreads) {
const size_t num_points = neighbors.size();
const Float_ log_perplexity = std::log(perplexity);

parallelize(nthreads, num_points, [&](int, size_t start, size_t length) -> void {
std::vector<Float_> squared_delta_dist;
std::vector<Float_> quad_delta_dist;
std::vector<Float_> output;
std::vector<Float_> prob_numerator; // i.e., the numerator of the probability.

for (size_t n = start, end = start + length; n < end; ++n) {
auto& current = neighbors[n];
const int K = current.size();
if (K) {
squared_delta_dist.resize(K);
quad_delta_dist.resize(K);

// We adjust the probabilities by subtracting the first squared
// distance from everything. This avoids problems with underflow
// when converting distances to probabilities; it otherwise has no
// effect on the entropy or even the final probabilities because it
// just scales all probabilities up/down (and they need to be
// normalized anyway, so any scaling effect just cancels out).
const Float_ first = current[0].second;
const Float_ first2 = first * first;
if (K == 0) {
continue;
}

for (int m = 1; m < K; ++m) {
Float_ dist = current[m].second;
Float_ squared_delta_dist_raw = dist * dist - first2;
squared_delta_dist[m] = squared_delta_dist_raw;
quad_delta_dist[m] = squared_delta_dist_raw * squared_delta_dist_raw;
squared_delta_dist.resize(K);
quad_delta_dist.resize(K);

// We adjust the probabilities by subtracting the first squared
// distance from everything. This avoids problems with underflow
// when converting distances to probabilities; it otherwise has no
// effect on the entropy or even the final probabilities because it
// just scales all probabilities up/down (and they need to be
// normalized anyway, so any scaling effect just cancels out).
const Float_ first = current[0].second;
const Float_ first2 = first * first;

for (int m = 1; m < K; ++m) {
Float_ dist = current[m].second;
Float_ squared_delta_dist_raw = dist * dist - first2;
squared_delta_dist[m] = squared_delta_dist_raw;
quad_delta_dist[m] = squared_delta_dist_raw * squared_delta_dist_raw;
}

auto last_squared_delta = squared_delta_dist.back();
if (last_squared_delta == 0) { // quitting early as entropy doesn't depend on beta.
for (auto& x : current) {
x.second = 1.0 / K;
}
continue;
}

Float_ beta = 1.0;
Float_ min_beta = 0, max_beta = max_value;
Float_ sum_P = 0;
output.resize(K);
output[0] = 1;
// Choosing an initial beta that matches the scale of the (squared) distances.
// The choice of numerator is largely based on trial and error to see what
// minimizes the number of iterations in some simulated data.
Float_ beta =
#ifndef QDTSNE_R_PACKAGE_TESTING
3.0 / last_squared_delta
#else
1
#endif
;

constexpr Float_ max_value = std::numeric_limits<Float_>::max();
Float_ min_beta = 0, max_beta = max_value;
Float_ sum_P = 0;
prob_numerator.resize(K);
prob_numerator[0] = 1;

constexpr int max_iter = 200;
for (int iter = 0; iter < max_iter; ++iter) {
// We skip the first value because we know that squared_delta_dist[0] = 0
// (as we subtracted 'first') and thus prob_numerator[0] = 1. We repeat this for
// all iterations from [1, K), e.g., squared_delta_dist, quad_delta_dist.
for (int m = 1; m < K; ++m) {
prob_numerator[m] = std::exp(-beta * squared_delta_dist[m]);
}

for (int iter = 0; iter < 200; ++iter) {
// We skip the first value because we know that squared_delta_dist[0] = 0
// (as we subtracted 'first') and thus output[0] = 1.
for (int m = 1; m < K; ++m) {
output[m] = std::exp(-beta * squared_delta_dist[m]);
}
sum_P = std::accumulate(prob_numerator.begin() + 1, prob_numerator.end(), static_cast<Float_>(1));
const Float_ prod = std::inner_product(squared_delta_dist.begin() + 1, squared_delta_dist.end(), prob_numerator.begin() + 1, static_cast<Float_>(0));
const Float_ entropy = beta * (prod / sum_P) + std::log(sum_P);

sum_P = std::accumulate(output.begin() + 1, output.end(), static_cast<Float_>(1));
const Float_ prod = std::inner_product(squared_delta_dist.begin() + 1, squared_delta_dist.end(), output.begin() + 1, static_cast<Float_>(0));
const Float_ entropy = beta * (prod / sum_P) + std::log(sum_P);
const Float_ diff = entropy - log_perplexity;
constexpr Float_ tol = 1e-5;
if (std::abs(diff) < tol) {
break;
}

const Float_ diff = entropy - log_perplexity;
if (std::abs(diff) < tol) {
break;
}
// Refining the search interval for a (potential) binary search
// later. We know that the entropy is monotonic decreasing with
// increasing beta, so if the difference from the target is
// positive, the current beta must be on the left of the root,
// and vice versa if the difference is negative.
if (diff > 0) {
min_beta = beta;
} else {
max_beta = beta;
}

bool nr_ok = false;
if constexpr(use_newton_) {
// Attempt a Newton-Raphson search first. Note to self: derivative was a bit
// painful but pops out nicely enough, use R's D() to prove it to yourself
// in the simple case of K = 2 where d0, d1 are the squared deltas.
// > D(expression(b * (d0 * exp(- b * d0) + d1 * exp(- b * d1)) / (exp(-b*d0) + exp(-b*d1)) + log(exp(-b*d0) + exp(-b*d1))), name="b")
bool nr_ok = false;
#ifndef QDTSNE_BETA_BINARY_SEARCH_ONLY
const Float_ prod2 = std::inner_product(quad_delta_dist.begin() + 1, quad_delta_dist.end(), output.begin() + 1, static_cast<Float_>(0)); // again, skipping first where delta^2 = 0.
const Float_ prod2 = std::inner_product(quad_delta_dist.begin() + 1, quad_delta_dist.end(), prob_numerator.begin() + 1, static_cast<Float_>(0));
const Float_ d1 = - beta / sum_P * (prod2 - prod * prod / sum_P);

if (d1) {
const Float_ alt_beta = beta - (diff / d1); // if it overflows, we should get Inf or -Inf, so the following comparison should be fine.
if (alt_beta > min_beta && alt_beta < max_beta) {
beta = alt_beta;
nr_ok = true;
}
}
#endif
}

// Otherwise do a binary search.
if (!nr_ok) {
if (diff > 0) {
min_beta = beta;
if (max_beta == max_value) {
beta *= static_cast<Float_>(2);
} else {
beta += (max_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential overflow.
}
if (!nr_ok) {
// Doing the binary search, if Newton's failed or was not requested.
if (diff > 0) {
if (max_beta == max_value) {
beta *= static_cast<Float_>(2);
} else {
max_beta = beta;
beta += (min_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential overflow.
beta += (max_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential overflow.
}
}

if (std::isinf(beta)) {
// Avoid propagation of NaNs via Inf * 0.
for (int m = 1; m < K; ++m) {
output[m] = (squared_delta_dist[m] == 0);
}
break;
} else {
beta += (min_beta - beta) / static_cast<Float_>(2); // i.e., midpoint that avoids problems with potential underflow.
}
}

// Row-normalize current row of P.
for (int m = 0; m < K; ++m) {
current[m].second = output[m] / sum_P;
if (std::isinf(beta)) {
// Avoid propagation of NaNs via Inf * 0.
for (int m = 1; m < K; ++m) {
prob_numerator[m] = (squared_delta_dist[m] == 0);
}
sum_P = std::accumulate(prob_numerator.begin(), prob_numerator.end(), static_cast<Float_>(0));
break;
}
}

for (int m = 0; m < K; ++m) {
current[m].second = prob_numerator[m] / sum_P;
}
}
});

Expand Down
2 changes: 1 addition & 1 deletion tests/R/package/src/test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "Rcpp.h"

#define QDTSNE_BETA_BINARY_SEARCH_ONLY
#define QDTSNE_R_PACKAGE_TESTING
#include "qdtsne/qdtsne.hpp"

// [[Rcpp::export(rng=false)]]
Expand Down
Loading

0 comments on commit 17295fc

Please sign in to comment.