Skip to content

Commit

Permalink
Truncation also works during serialization.
Browse files Browse the repository at this point in the history
  • Loading branch information
LTLA committed Oct 14, 2024
1 parent 0cec4fc commit 9404439
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
24 changes: 18 additions & 6 deletions js/findNearestNeighbors.js
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,16 @@ export class FindNearestNeighborsResults {
}

/**
* @param {object} [options={}] - Optional parameters.
* @param {?number} [options.truncate=null] - Maximum number of neighbors to count for each cell.
* If `null` or greater than the number of available neighbors, all neighbors are counted.
* @return {number} The total number of neighbors across all cells.
* This is usually the product of the number of neighbors and the number of cells.
*/
size() {
return this.#results.size();
size(options = {}) {
const { truncate = null, ...others } = options;
utils.checkOtherOptions(others);
return this.#results.size(FindNearestNeighborsResults.#numberToTruncate(truncate));
}

/**
Expand All @@ -141,6 +146,10 @@ export class FindNearestNeighborsResults {
return this.#results;
}

static #numberToTruncate(truncate) {
return (truncate === null ? -1 : truncate);
}

/**
* @param {object} [options={}] - Optional parameters.
* @param {?Int32WasmArray} [options.runs=null] - A Wasm-allocated array of length equal to `numberOfCells()`,
Expand All @@ -149,6 +158,8 @@ export class FindNearestNeighborsResults {
* to be used to store the indices of the neighbors of each cell.
* @param {?Float64WasmArray} [options.distances=null] - A Wasm-allocated array of length equal to `size()`,
* to be used to store the distances to the neighbors of each cell.
* @param {?number} [options.truncate=null] - Number of nearest neighbors to serialize for each cell.
* If `null` or greater than the number of available neighbors, all neighbors are used.
*
* @return {object}
* An object is returned with the `runs`, `indices` and `distances` keys, each with an appropriate TypedArray as the value.
Expand All @@ -159,14 +170,15 @@ export class FindNearestNeighborsResults {
* If only some of the arguments are non-`null`, an error is raised.
*/
serialize(options = {}) {
const { runs = null, indices = null, distances = null, ...others } = options;
const { runs = null, indices = null, distances = null, truncate = null, ...others } = options;
utils.checkOtherOptions(others);

var copy = (runs === null) + (indices === null) + (distances === null);
if (copy != 3 && copy != 0) {
throw new Error("either all or none of 'runs', 'indices' and 'distances' can be 'null'");
}

let nkeep = FindNearestNeighborsResults.#numberToTruncate(truncate);
var output;

if (copy === 3) {
Expand All @@ -176,10 +188,10 @@ export class FindNearestNeighborsResults {

try {
run_data = utils.createInt32WasmArray(this.numberOfCells());
let s = this.size();
let s = this.#results.size(nkeep);
ind_data = utils.createInt32WasmArray(s);
dist_data = utils.createFloat64WasmArray(s);
this.#results.serialize(run_data.offset, ind_data.offset, dist_data.offset);
this.#results.serialize(run_data.offset, ind_data.offset, dist_data.offset, nkeep);

output = {
"runs": run_data.slice(),
Expand All @@ -193,7 +205,7 @@ export class FindNearestNeighborsResults {
}

} else {
this.#results.serialize(runs.offset, indices.offset, distances.offset);
this.#results.serialize(runs.offset, indices.offset, distances.offset, nkeep);
output = {
"runs": runs.array(),
"indices": indices.array(),
Expand Down
15 changes: 10 additions & 5 deletions src/NeighborIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define NEIGHBOR_INDEX_H

#include <memory>
#include <algorithm>
#include <vector>

#include "knncolle/knncolle.hpp"
Expand Down Expand Up @@ -44,10 +45,11 @@ struct NeighborResults {
}

public:
size_t size() const {
size_t size(int32_t truncate) const {
size_t out = 0;
size_t long_truncate = truncate;
for (const auto& current : neighbors) {
out += current.size();
out += std::min(long_truncate, current.size());
}
return out;
}
Expand All @@ -56,16 +58,19 @@ struct NeighborResults {
return neighbors.size();
}

void serialize(uintptr_t runs, uintptr_t indices, uintptr_t distances) const {
void serialize(uintptr_t runs, uintptr_t indices, uintptr_t distances, int32_t truncate) const {
auto rptr = reinterpret_cast<int32_t*>(runs);
auto iptr = reinterpret_cast<int32_t*>(indices);
auto dptr = reinterpret_cast<double*>(distances);

size_t long_truncate = truncate;
for (const auto& current : neighbors) {
*rptr = current.size();
size_t nkeep = std::min(long_truncate, current.size());
*rptr = nkeep;
++rptr;

for (const auto& x : current) {
for (int32_t i = 0; i < nkeep; ++i) {
const auto& x = current[i];
*iptr = x.first;
*dptr = x.second;
++iptr;
Expand Down
6 changes: 6 additions & 0 deletions tests/findNearestNeighbors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,10 @@ test("neighbor search can be truncated", () => {
expect(tdump.indices[2]).toEqual(dump.indices[5]);
expect(tdump.indices[5]).toEqual(dump.indices[11]);
expect(tdump.indices[51]).toEqual(dump.indices[126]);

// Checking we get the same results with truncated serialization.
var tdump2 = res.serialize({ truncate: 2 });
expect(tdump2.runs).toEqual(tdump.runs);
expect(tdump2.indices).toEqual(tdump.indices);
expect(tdump2.distances).toEqual(tdump.distances);
})

0 comments on commit 9404439

Please sign in to comment.