Skip to content

Commit

Permalink
implement caching and reuse of embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
philippotto committed May 8, 2023
1 parent 14c40d9 commit d3c4fe9
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ class BoundingBox {
return min[0] <= x && x < max[0] && min[1] <= y && y < max[1] && min[2] <= z && z < max[2];
}

/**
 * Checks whether `other` lies completely within this bounding box.
 *
 * `other` counts as contained when clipping it against this box
 * (via `intersectedWith`) yields a box that equals `other` itself.
 */
containsBoundingBox(other: BoundingBox) {
  const overlap = this.intersectedWith(other);
  return other.equals(overlap);
}

/**
 * Compares this bounding box with `other` by value: both the `min`
 * and the `max` corner have to match according to `V3.equals`.
 */
equals(other: BoundingBox) {
  const hasSameMin = V3.equals(this.min, other.min);
  // The max corners are only compared when the min corners already match.
  return hasSameMin && V3.equals(this.max, other.max);
}

intersectedWith(other: BoundingBox): BoundingBox {
const newMin = V3.max(this.min, other.min);
const uncheckedMax = V3.min(this.max, other.max);
Expand Down
176 changes: 98 additions & 78 deletions frontend/javascripts/oxalis/model/sagas/quick_select_saga.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
TypedArrayWithoutBigInt,
Vector2,
Vector3,
Viewport,
} from "oxalis/constants";
import PriorityQueue from "js-priority-queue";
import ErrorHandling from "libs/error_handling";
Expand Down Expand Up @@ -71,19 +72,53 @@ const TOAST_KEY = "QUICKSELECT_PREVIEW_MESSAGE";
// Used to determine the mean intensity.
const CENTER_RECT_SIZE_PERCENTAGE = 1 / 10;

let _embedding: Float32Array | null = null;
// let _hardcodedEmbedding: Float32Array | null = null;

const useHardcodedEmbedding = false;

const embeddingCache = [];
type CacheEntry = { embedding: Float32Array; bbox: BoundingBox };
const MAXIMUM_CACHE_SIZE = 5;
// Sorted from most recently to least recently used.
let embeddingCache: Array<CacheEntry> = [];

async function getEmbedding(dataset: APIDataset, boundingBox: BoundingBox, mag: Vector3) {
try {
// todo: use caching
_embedding = null;
if (_embedding == null) {
_embedding = new Float32Array(
async function getEmbedding(
dataset: APIDataset,
boundingBox: BoundingBox,
mag: Vector3,
activeViewport: OrthoView,
): Promise<CacheEntry> {
const matchingCacheEntry = embeddingCache.find((entry) =>
entry.bbox.containsBoundingBox(boundingBox),
);
if (matchingCacheEntry) {
// Move entry to the front.
embeddingCache = [
matchingCacheEntry,
...embeddingCache.filter((el) => el != matchingCacheEntry).slice(0, MAXIMUM_CACHE_SIZE - 1),
];
console.log("Use", matchingCacheEntry, "from cache.");
return matchingCacheEntry;
} else {
try {
const embeddingCenter = V3.round(boundingBox.getCenter());
const embeddingTopLeft = V3.sub(embeddingCenter, [512, 512, 0]);
const embeddingBottomRight = [
embeddingTopLeft[0] + 1024,
embeddingTopLeft[1] + 1024,
embeddingTopLeft[2],
] as Vector3;
const embeddingBoxMag1 = new BoundingBox({
min: V3.floor(V3.min(embeddingTopLeft, embeddingBottomRight)),
max: V3.floor(
V3.add(
V3.max(embeddingTopLeft, embeddingBottomRight),
Dimensions.transDim([0, 0, 1], activeViewport),
),
),
});
console.log("Load new embedding for ", embeddingBoxMag1);

const embedding = new Float32Array(
(await fetch(
`/api/datasets/${dataset.owningOrganization}/${dataset.name}/layers/color/segmentAnythingEmbedding`,
{
Expand All @@ -93,26 +128,19 @@ async function getEmbedding(dataset: APIDataset, boundingBox: BoundingBox, mag:
},
body: JSON.stringify({
mag,
boundingBox: boundingBox.asServerBoundingBox(),
boundingBox: embeddingBoxMag1.asServerBoundingBox(),
}),
},
).then((res) => res.arrayBuffer())) as ArrayBuffer,
);

// _hardcodedEmbedding = new Float32Array(
// (await fetch("/dist/paper_l4_embedding.bin").then((res) =>
// res.arrayBuffer(),
// )) as ArrayBuffer,
// );
const newEntry = { embedding, bbox: embeddingBoxMag1 };
embeddingCache.unshift(newEntry);
return newEntry;
} catch (exception) {
console.error(exception);
throw new Error("Could not load embedding. See console for details.");
}
// console.log("_hardcodedEmbedding", _hardcodedEmbedding.slice(0, 20));
// if (useHardcodedEmbedding) {
// return _hardcodedEmbedding;
// }
return _embedding;
} catch (exception) {
console.error(exception);
throw new Error("Could not load embedding. See console for details.");
}
}

Expand All @@ -127,7 +155,14 @@ async function getSession() {
return session;
}

async function inferFromEmbedding(embedding: Float32Array, topLeft: Vector3, bottomRight: Vector3) {
async function inferFromEmbedding(
embedding: Float32Array,
embeddingBoxInTargetMag: BoundingBox,
userBoxInTargetMag: BoundingBox,
) {
const topLeft = V3.sub(userBoxInTargetMag.min, embeddingBoxInTargetMag.min);
const bottomRight = V3.sub(userBoxInTargetMag.max, embeddingBoxInTargetMag.min);

const ort_session = await getSession();
const onnx_coord = useHardcodedEmbedding
? new Float32Array([topLeft[1], topLeft[0], bottomRight[1], bottomRight[0]])
Expand Down Expand Up @@ -158,7 +193,23 @@ async function inferFromEmbedding(embedding: Float32Array, topLeft: Vector3, bot
console.log("thresholded_mask", thresholded_mask);

// @ts-ignore
return new Uint8Array(thresholded_mask);
const thresholdFieldData = new Uint8Array(thresholded_mask);

const size = embeddingBoxInTargetMag.getSize();
const stride = [1, size[0], size[0] * size[1]];
console.log("stride", stride);
let thresholdField = ndarray(thresholdFieldData, size, stride);

if (useHardcodedEmbedding) {
thresholdField = thresholdField.transpose(1, 0, 2);
}

thresholdField = thresholdField
// a.lo(x,y) => a[x:, y:]
.lo(topLeft[0], topLeft[1], 0)
// a.hi(x,y) => a[:x, :y]
.hi(userBoxInTargetMag.getSize()[0], userBoxInTargetMag.getSize()[1], 1);
return thresholdField;
}

export default function* listenToQuickSelect(): Saga<void> {
Expand Down Expand Up @@ -226,36 +277,14 @@ function* performQuickSelect(action: ComputeQuickSelectForRectAction): Saga<void

const { startPosition, endPosition } = action;

const embeddingCenter = V3.round(V3.scale(V3.add(startPosition, endPosition), 0.5));
const embeddingTopLeft = V3.sub(embeddingCenter, [512, 512, 0]);

const embeddingBottomRight = [
embeddingTopLeft[0] + 1024,
embeddingTopLeft[1] + 1024,
embeddingTopLeft[2],
] as Vector3;

const relativeTopLeft = V3.sub(startPosition, embeddingTopLeft);
const relativeBottomRight = V3.sub(endPosition, embeddingTopLeft);

const userBoxMag1 = new BoundingBox({
min: V3.floor(V3.min(startPosition, endPosition)),
max: V3.floor(
V3.add(V3.max(startPosition, endPosition), Dimensions.transDim([0, 0, 1], activeViewport)),
),
});

// const layerBBox = yield* select((state) => getLayerBoundingBox(state.dataset, colorLayer.name));
const embeddingBoxMag1 = new BoundingBox({
min: V3.floor(V3.min(embeddingTopLeft, embeddingBottomRight)),
max: V3.floor(
V3.add(
V3.max(embeddingTopLeft, embeddingBottomRight),
Dimensions.transDim([0, 0, 1], activeViewport),
),
),
}); //.intersectedWith(layerBBox);

// const layerBBox = yield* select((state) => getLayerBoundingBox(state.dataset, colorLayer.name));
// Ensure that the third dimension is inclusive (otherwise, the center of the passed
// coordinates wouldn't be exactly on the W plane on which the user started this action).
const inclusiveMaxW = map3((el, idx) => (idx === thirdDim ? el - 1 : el), userBoxMag1.max);
Expand Down Expand Up @@ -292,14 +321,6 @@ function* performQuickSelect(action: ComputeQuickSelectForRectAction): Saga<void
);
const labeledResolution = resolutionInfo.getResolutionByIndexOrThrow(labeledZoomStep);

const embeddingBoxInTargetMag = embeddingBoxMag1.fromMag1ToMag(labeledResolution);
const userBoxInTargetMag = userBoxMag1.fromMag1ToMag(labeledResolution);

if (embeddingBoxInTargetMag.getVolume() === 0) {
Toast.warning("The drawn rectangular had a width or height of zero.");
return;
}

// const inputDataRaw = yield* call(
// [api.data, api.data.getDataForBoundingBox],
// colorLayer.name,
Expand All @@ -314,14 +335,6 @@ function* performQuickSelect(action: ComputeQuickSelectForRectAction): Saga<void
// embeddingBottomRight,
// );

const size = embeddingBoxInTargetMag.getSize();
if (size.some((el) => el !== 1 && el !== 1024)) {
throw new Error("Incorrectly sized window");
}

const stride = [1, size[0], size[0] * size[1]];
console.log("stride", stride);

// if (inputDataRaw instanceof BigUint64Array) {
// throw new Error("Color input layer must not be 64-bit.");
// }
Expand All @@ -348,28 +361,35 @@ function* performQuickSelect(action: ComputeQuickSelectForRectAction): Saga<void
// later processed to a binary mask.

console.time("getEmbedding");
const embedding = yield* call(getEmbedding, dataset, embeddingBoxMag1, [1, 1, 1]);
const { embedding, bbox: embeddingBoxMag1 } = yield* call(
getEmbedding,
dataset,
userBoxMag1,
[1, 1, 1],
activeViewport,
);
console.timeEnd("getEmbedding");

const embeddingBoxInTargetMag = embeddingBoxMag1.fromMag1ToMag(labeledResolution);
const userBoxInTargetMag = userBoxMag1.fromMag1ToMag(labeledResolution);

if (embeddingBoxInTargetMag.getVolume() === 0) {
Toast.warning("The drawn rectangular had a width or height of zero.");
return;
}
const size = embeddingBoxInTargetMag.getSize();
if (size.some((el) => el !== 1 && el !== 1024)) {
throw new Error("Incorrectly sized window");
}

console.time("infer");
let thresholdFieldData = yield* call(
let thresholdField = yield* call(
inferFromEmbedding,
embedding,
relativeTopLeft,
relativeBottomRight,
embeddingBoxInTargetMag,
userBoxInTargetMag,
);
console.timeEnd("infer");
let thresholdField = ndarray(thresholdFieldData, size, stride);

if (useHardcodedEmbedding) {
thresholdField = thresholdField.transpose(1, 0, 2);
}

thresholdField = thresholdField
// a.lo(x,y) => a[x:, y:]
.lo(relativeTopLeft[0], relativeTopLeft[1], 0)
// a.hi(x,y) => a[:x, :y]
.hi(userBoxInTargetMag.getSize()[0], userBoxInTargetMag.getSize()[1], 1);

// if (initialDetectDarkSegment) {
// thresholdField = darkThresholdField;
Expand Down

0 comments on commit d3c4fe9

Please sign in to comment.