diff --git a/app/packages/looker/src/lookers/abstract.ts b/app/packages/looker/src/lookers/abstract.ts index 7d18e8bddb..3eca774de8 100644 --- a/app/packages/looker/src/lookers/abstract.ts +++ b/app/packages/looker/src/lookers/abstract.ts @@ -293,6 +293,11 @@ export abstract class AbstractLooker< return; } + if (this.state.destroyed && this.sampleOverlays) { + // close all current overlays + this.pluckedOverlays.forEach((overlay) => overlay.cleanup?.()); + } + if ( !this.state.windowBBox || this.state.destroyed || diff --git a/app/packages/looker/src/lookers/frame-reader.ts b/app/packages/looker/src/lookers/frame-reader.ts index a85472e7e4..ce0489dc43 100644 --- a/app/packages/looker/src/lookers/frame-reader.ts +++ b/app/packages/looker/src/lookers/frame-reader.ts @@ -52,7 +52,15 @@ interface AcquireReaderOptions { export const { acquireReader, clearReader } = (() => { const createCache = (removeFrame: RemoveFrame) => { return new LRUCache({ - dispose: (_, key) => removeFrame(key), + dispose: (frame, key) => { + const overlays = frame.overlays; + + for (let i = 0; i < overlays.length; i++) { + overlays[i].cleanup?.(); + } + + removeFrame(key); + }, max: MAX_FRAME_STREAM_SIZE, maxSize: MAX_FRAME_STREAM_SIZE_BYTES, noDisposeOnSet: true, diff --git a/app/packages/looker/src/overlays/base.ts b/app/packages/looker/src/overlays/base.ts index fd817ecf9d..a3ec867766 100644 --- a/app/packages/looker/src/overlays/base.ts +++ b/app/packages/looker/src/overlays/base.ts @@ -3,6 +3,7 @@ */ import { getCls, sizeBytesEstimate } from "@fiftyone/utilities"; +import { OverlayMask } from "../numpy"; import type { BaseState, Coordinates, NONFINITE } from "../state"; import { getLabelColor, shouldShowLabelTag } from "./util"; @@ -39,6 +40,11 @@ export interface SelectData { frameNumber?: number; } +export type LabelMask = { + bitmap?: ImageBitmap; + data?: OverlayMask; +}; + export interface RegularLabel extends BaseLabel { _id?: string; label?: string; @@ -67,6 +73,7 @@ export interface Overlay> { getPoints(state: Readonly): Coordinates[]; getSelectData(state: Readonly): SelectData; getSizeBytes(): number; + cleanup?(): void; } export abstract class CoordinateOverlay< diff --git a/app/packages/looker/src/overlays/detection.ts b/app/packages/looker/src/overlays/detection.ts index 4930771692..ec6d45086f 100644 --- a/app/packages/looker/src/overlays/detection.ts +++ b/app/packages/looker/src/overlays/detection.ts @@ -4,17 +4,19 @@ import { NONFINITES } from "@fiftyone/utilities"; import { INFO_COLOR } from "../constants"; -import { OverlayMask } from "../numpy"; import { BaseState, BoundingBox, Coordinates, NONFINITE } from "../state"; import { distanceFromLineSegment } from "../util"; -import { CONTAINS, CoordinateOverlay, PointInfo, RegularLabel } from "./base"; +import { + CONTAINS, + CoordinateOverlay, + LabelMask, + PointInfo, + RegularLabel, +} from "./base"; import { t } from "./util"; export interface DetectionLabel extends RegularLabel { - mask?: { - data: OverlayMask; - image: ArrayBuffer; - }; + mask?: LabelMask; bounding_box: BoundingBox; // valid for 3D bounding boxes @@ -27,10 +29,8 @@ export interface DetectionLabel extends RegularLabel { export default class DetectionOverlay< State extends BaseState > extends CoordinateOverlay { - private imageData: ImageData; private is3D: boolean; private labelBoundingBox: BoundingBox; - private canvas: HTMLCanvasElement; constructor(field, label) { super(field, label); @@ -40,32 +40,6 @@ export default class DetectionOverlay< } else { this.is3D = false; } - - if (this.label.mask) { - const [height, width] = this.label.mask.data.shape; - - if (!height || !width) { - return; - } - - this.canvas = document.createElement("canvas"); - this.canvas.width = width; - this.canvas.height = height; - this.imageData = new ImageData( - new Uint8ClampedArray(this.label.mask.image), - width, - height - ); - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.mask.data.shape[1], - this.label.mask.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); - } } containsPoint(state: Readonly): CONTAINS { @@ -169,7 +143,7 @@ export default class DetectionOverlay< } private drawMask(ctx: CanvasRenderingContext2D, state: Readonly) { - if (!this.canvas) { + if (!this.label.mask?.bitmap) { return; } @@ -177,8 +151,9 @@ export default class DetectionOverlay< const [x, y] = t(state, tlx, tly); const tmp = ctx.globalAlpha; ctx.globalAlpha = state.options.alpha; + ctx.imageSmoothingEnabled = false; ctx.drawImage( - this.canvas, + this.label.mask.bitmap, x, y, w * state.canvasBBox[2], @@ -285,6 +260,13 @@ export default class DetectionOverlay< const oh = state.strokeWidth / state.canvasBBox[3]; return [(bx - ow) * w, (by - oh) * h, (bw + ow * 2) * w, (bh + oh * 2) * h]; } + + public cleanup(): void { + if (this.label.mask?.bitmap) { + this.label.mask?.bitmap.close(); + this.label.mask.bitmap = null; + } + } } export const getDetectionPoints = (labels: DetectionLabel[]): Coordinates[] => { diff --git a/app/packages/looker/src/overlays/heatmap.ts b/app/packages/looker/src/overlays/heatmap.ts index c53a3ad971..e8e8817643 100644 --- a/app/packages/looker/src/overlays/heatmap.ts +++ b/app/packages/looker/src/overlays/heatmap.ts @@ -6,14 +6,16 @@ import { getColor, getRGBA, getRGBAColor, + sizeBytesEstimate, } from "@fiftyone/utilities"; -import { ARRAY_TYPES, OverlayMask, TypedArray } from "../numpy"; +import { ARRAY_TYPES, TypedArray } from "../numpy"; import { BaseState, Coordinates } from "../state"; import { isFloatArray } from "../util"; import { clampedIndex } from "../worker/painter"; import { BaseLabel, CONTAINS, + LabelMask, Overlay, PointInfo, SelectData, @@ -21,13 +23,8 @@ import { } from "./base"; import { strokeCanvasRect, t } from "./util"; -interface HeatMap { - data: OverlayMask; - image: ArrayBuffer; -} - interface HeatmapLabel extends BaseLabel { - map?: HeatMap; + map?: LabelMask; range?: [number, number]; } @@ -45,8 +42,6 @@ export default class HeatmapOverlay private label: HeatmapLabel; private targets?: TypedArray; private readonly range: [number, number]; - private canvas: HTMLCanvasElement; - private imageData: ImageData; constructor(field: string, label: HeatmapLabel) { this.field = field; @@ -68,25 +63,6 @@ export default class HeatmapOverlay if (!width || !height) { return; } - - this.canvas = document.createElement("canvas"); - this.canvas.width = width; - this.canvas.height = height; - - this.imageData = new ImageData( - new Uint8ClampedArray(this.label.map.image), - width, - height - ); - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.map.data.shape[1], - this.label.map.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); } containsPoint(state: Readonly): CONTAINS { @@ -101,22 +77,12 @@ export default class HeatmapOverlay } draw(ctx: CanvasRenderingContext2D, state: Readonly): void { - if (this.imageData) { - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.map.data.shape[1], - this.label.map.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); - + if (this.label.map?.bitmap) { const [tlx, tly] = t(state, 0, 0); const [brx, bry] = t(state, 1, 1); const tmp = ctx.globalAlpha; ctx.globalAlpha = state.options.alpha; - ctx.drawImage(this.canvas, tlx, tly, brx - tlx, bry - tly); + ctx.drawImage(this.label.map.bitmap, tlx, tly, brx - tlx, bry - tly); ctx.globalAlpha = tmp; } @@ -235,6 +201,16 @@ export default class HeatmapOverlay return this.targets[index]; } + + getSizeBytes(): number { + return sizeBytesEstimate(this.label); + } + + public cleanup(): void { + if (this.label.map?.bitmap) { + this.label.map?.bitmap.close(); + } + } } export const getHeatmapPoints = (labels: HeatmapLabel[]): Coordinates[] => { diff --git a/app/packages/looker/src/overlays/segmentation.ts b/app/packages/looker/src/overlays/segmentation.ts index c55a8b5ef5..a4cb098254 100644 --- a/app/packages/looker/src/overlays/segmentation.ts +++ b/app/packages/looker/src/overlays/segmentation.ts @@ -2,12 +2,13 @@ * Copyright 2017-2024, Voxel51, Inc. */ -import { getColor } from "@fiftyone/utilities"; -import { ARRAY_TYPES, OverlayMask, TypedArray } from "../numpy"; +import { getColor, sizeBytesEstimate } from "@fiftyone/utilities"; +import { ARRAY_TYPES, TypedArray } from "../numpy"; import { BaseState, Coordinates, MaskTargets } from "../state"; import { BaseLabel, CONTAINS, + LabelMask, Overlay, PointInfo, SelectData, @@ -16,10 +17,7 @@ import { import { isRgbMaskTargets, strokeCanvasRect, t } from "./util"; interface SegmentationLabel extends BaseLabel { - mask?: { - data: OverlayMask; - image: ArrayBuffer; - }; + mask?: LabelMask; } interface SegmentationInfo extends BaseLabel { @@ -34,8 +32,6 @@ export default class SegmentationOverlay readonly field: string; private label: SegmentationLabel; private targets?: TypedArray; - private canvas: HTMLCanvasElement; - private imageData: ImageData; private isRgbMaskTargets = false; @@ -53,6 +49,7 @@ export default class SegmentationOverlay if (!this.label.mask) { return; } + const [height, width] = this.label.mask.data.shape; if (!height || !width) { @@ -62,25 +59,6 @@ export default class SegmentationOverlay this.targets = new ARRAY_TYPES[this.label.mask.data.arrayType]( this.label.mask.data.buffer ); - - this.canvas = document.createElement("canvas"); - this.canvas.width = width; - this.canvas.height = height; - - this.imageData = new ImageData( - new Uint8ClampedArray(this.label.mask.image), - width, - height - ); - const maskCtx = this.canvas.getContext("2d"); - maskCtx.imageSmoothingEnabled = false; - maskCtx.clearRect( - 0, - 0, - this.label.mask.data.shape[1], - this.label.mask.data.shape[0] - ); - maskCtx.putImageData(this.imageData, 0, 0); } containsPoint(state: Readonly): CONTAINS { @@ -99,12 +77,12 @@ export default class SegmentationOverlay return; } - if (this.imageData) { + if (this.label.mask?.bitmap) { const [tlx, tly] = t(state, 0, 0); const [brx, bry] = t(state, 1, 1); const tmp = ctx.globalAlpha; ctx.globalAlpha = state.options.alpha; - ctx.drawImage(this.canvas, tlx, tly, brx - tlx, bry - tly); + ctx.drawImage(this.label.mask.bitmap, tlx, tly, brx - tlx, bry - tly); ctx.globalAlpha = tmp; } @@ -278,6 +256,16 @@ export default class SegmentationOverlay } return this.targets[index]; } + + getSizeBytes(): number { + return sizeBytesEstimate(this.label); + } + + public cleanup(): void { + if (this.label.mask?.bitmap) { + this.label.mask?.bitmap.close(); + } + } } export const getSegmentationPoints = ( diff --git a/app/packages/looker/src/worker/decorated-fetch.test.ts b/app/packages/looker/src/worker/decorated-fetch.test.ts index 67ed853200..52fa49d21b 100644 --- a/app/packages/looker/src/worker/decorated-fetch.test.ts +++ b/app/packages/looker/src/worker/decorated-fetch.test.ts @@ -15,7 +15,7 @@ describe("fetchWithLinearBackoff", () => { expect(response).toBe(mockResponse); expect(global.fetch).toHaveBeenCalledTimes(1); - expect(global.fetch).toHaveBeenCalledWith("http://fiftyone.ai"); + expect(global.fetch).toHaveBeenCalledWith("http://fiftyone.ai", {}); }); it("should retry when fetch fails and eventually succeed", async () => { @@ -35,7 +35,14 @@ describe("fetchWithLinearBackoff", () => { global.fetch = vi.fn().mockRejectedValue(new Error("Network Error")); await expect( - fetchWithLinearBackoff("http://fiftyone.ai", 3, 10) + fetchWithLinearBackoff( + "http://fiftyone.ai", + {}, + { + retries: 3, + delay: 10, + } + ) ).rejects.toThrowError(new RegExp("Max retries for fetch reached")); expect(global.fetch).toHaveBeenCalledTimes(3); @@ -46,7 +53,14 @@ describe("fetchWithLinearBackoff", () => { global.fetch = vi.fn().mockResolvedValue(mockResponse); await expect( - fetchWithLinearBackoff("http://fiftyone.ai", 5, 10) + fetchWithLinearBackoff( + "http://fiftyone.ai", + {}, + { + retries: 5, + delay: 10, + } + ) ).rejects.toThrow("HTTP error: 500"); expect(global.fetch).toHaveBeenCalledTimes(5); @@ -57,7 +71,14 @@ describe("fetchWithLinearBackoff", () => { global.fetch = vi.fn().mockResolvedValue(mockResponse); await expect( - fetchWithLinearBackoff("http://fiftyone.ai", 5, 10) + fetchWithLinearBackoff( + "http://fiftyone.ai", + {}, + { + retries: 5, + delay: 10, + } + ) ).rejects.toThrow("Non-retryable HTTP error: 404"); expect(global.fetch).toHaveBeenCalledTimes(1); @@ -73,7 +94,11 @@ describe("fetchWithLinearBackoff", () => { vi.useFakeTimers(); - const fetchPromise = fetchWithLinearBackoff("http://fiftyone.ai", 5, 10); + const fetchPromise = fetchWithLinearBackoff( + "http://fiftyone.ai", + {}, + { retries: 5, delay: 10 } + ); // advance timers to simulate delays // after first delay diff --git a/app/packages/looker/src/worker/decorated-fetch.ts b/app/packages/looker/src/worker/decorated-fetch.ts index c77059d551..d01f0b48b2 100644 --- a/app/packages/looker/src/worker/decorated-fetch.ts +++ b/app/packages/looker/src/worker/decorated-fetch.ts @@ -3,6 +3,10 @@ const DEFAULT_BASE_DELAY = 200; // list of HTTP status codes that are client errors (4xx) and should not be retried const NON_RETRYABLE_STATUS_CODES = [400, 401, 403, 404, 405, 422]; +export interface RetryOptions { + retries: number; + delay: number; +} class NonRetryableError extends Error { constructor(message: string) { super(message); @@ -12,12 +16,15 @@ class NonRetryableError extends Error { export const fetchWithLinearBackoff = async ( url: string, - retries = DEFAULT_MAX_RETRIES, - delay = DEFAULT_BASE_DELAY + opts: RequestInit = {}, + retry: RetryOptions = { + retries: DEFAULT_MAX_RETRIES, + delay: DEFAULT_BASE_DELAY, + } ) => { - for (let i = 0; i < retries; i++) { + for (let i = 0; i < retry.retries; i++) { try { - const response = await fetch(url); + const response = await fetch(url, opts); if (response.ok) { return response; } else { @@ -35,8 +42,10 @@ export const fetchWithLinearBackoff = async ( // immediately throw throw e; } - if (i < retries - 1) { - await new Promise((resolve) => setTimeout(resolve, delay * (i + 1))); + if (i < retry.retries - 1) { + await new Promise((resolve) => + setTimeout(resolve, retry.delay * (i + 1)) + ); } else { // max retries reached throw new Error( diff --git a/app/packages/looker/src/worker/deserializer.ts b/app/packages/looker/src/worker/deserializer.ts index 02a7b03867..363522b01f 100644 --- a/app/packages/looker/src/worker/deserializer.ts +++ b/app/packages/looker/src/worker/deserializer.ts @@ -25,7 +25,6 @@ export const DeserializerFactory = { image: new ArrayBuffer(width * height * 4), }; buffers.push(data.buffer); - buffers.push(label.mask.image); } }, Detections: (labels, buffers) => { @@ -47,7 +46,6 @@ export const DeserializerFactory = { }; buffers.push(data.buffer); - buffers.push(label.map.image); } }, Segmentation: (label, buffers) => { @@ -63,7 +61,6 @@ export const DeserializerFactory = { }; buffers.push(data.buffer); - buffers.push(label.mask.image); } }, }; diff --git a/app/packages/looker/src/worker/disk-overlay-decoder.test.ts b/app/packages/looker/src/worker/disk-overlay-decoder.test.ts new file mode 100644 index 0000000000..dc7ea31fdf --- /dev/null +++ b/app/packages/looker/src/worker/disk-overlay-decoder.test.ts @@ -0,0 +1,205 @@ +import { getSampleSrc } from "@fiftyone/state/src/recoil/utils"; +import { DETECTIONS, HEATMAP } from "@fiftyone/utilities"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Coloring, CustomizeColor } from ".."; +import { LabelMask } from "../overlays/base"; +import type { Colorscale } from "../state"; +import { decodeWithCanvas } from "./canvas-decoder"; +import { decodeOverlayOnDisk, IntermediateMask } from "./disk-overlay-decoder"; +import { enqueueFetch } from "./pooled-fetch"; + +vi.mock("@fiftyone/state/src/recoil/utils", () => ({ + getSampleSrc: vi.fn(), +})); + +vi.mock("./pooled-fetch", () => ({ + enqueueFetch: vi.fn(), +})); + +vi.mock("./canvas-decoder", () => ({ + decodeWithCanvas: vi.fn(), +})); + +const COLORING = {} as Coloring; +const COLOR_SCALE = {} as Colorscale; +const CUSTOMIZE_COLOR_SETTING: CustomizeColor[] = []; +const SOURCES = {}; + +type MaskUnion = (IntermediateMask & LabelMask) | null; + +describe("decodeOverlayOnDisk", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should return early if label already has overlay field (not on disk)", async () => { + const field = "testField"; + const label = { mask: {}, mask_path: "shouldBeIgnored" }; + const cls = "Segmentation"; + const maskPathDecodingPromises: Promise[] = []; + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(label.mask).toBeDefined(); + expect(enqueueFetch).not.toHaveBeenCalled(); + }); + + it("should fetch and decode overlay when label has overlay path field", async () => { + const field = "testField"; + const label = { mask_path: "/path/to/mask", mask: null as MaskUnion }; + const cls = "Segmentation"; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl = "http://example.com/path/to/mask"; + const mockBlob = new Blob(["mock data"], { type: "image/png" }); + const overlayMask = { shape: [100, 200] }; + + vi.mocked(getSampleSrc).mockReturnValue(sampleSrcUrl); + vi.mocked(enqueueFetch).mockResolvedValue({ + blob: () => Promise.resolve(mockBlob), + } as Response); + vi.mocked(decodeWithCanvas).mockResolvedValue(overlayMask); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(getSampleSrc).toHaveBeenCalledWith("/path/to/mask"); + expect(enqueueFetch).toHaveBeenCalledWith({ + url: sampleSrcUrl, + options: { priority: "low" }, + }); + expect(decodeWithCanvas).toHaveBeenCalledWith(mockBlob); + expect(label.mask).toBeDefined(); + expect(label.mask.data).toBe(overlayMask); + expect(label.mask.image).toBeInstanceOf(ArrayBuffer); + expect(label.mask.image.byteLength).toBe(100 * 200 * 4); + }); + + it("should handle HEATMAP class", async () => { + const field = "testField"; + const label = { map_path: "/path/to/map", map: null as MaskUnion }; + const cls = HEATMAP; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl = "http://example.com/path/to/map"; + const mockBlob = new Blob(["mock data"], { type: "image/png" }); + const overlayMask = { shape: [100, 200] }; + + vi.mocked(getSampleSrc).mockReturnValue(sampleSrcUrl); + vi.mocked(decodeWithCanvas).mockResolvedValue(overlayMask); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(getSampleSrc).toHaveBeenCalledWith("/path/to/map"); + expect(enqueueFetch).toHaveBeenCalledWith({ + url: sampleSrcUrl, + options: { priority: "low" }, + }); + expect(decodeWithCanvas).toHaveBeenCalledWith(mockBlob); + expect(label.map).toBeDefined(); + expect(label.map.data).toBe(overlayMask); + expect(label.map.image).toBeInstanceOf(ArrayBuffer); + expect(label.map.image.byteLength).toBe(100 * 200 * 4); + }); + + it("should handle DETECTIONS class and process detections recursively", async () => { + const field = "testField"; + const label = { + detections: [ + { mask_path: "/path/to/mask1", mask: null as MaskUnion }, + { mask_path: "/path/to/mask2", mask: null as MaskUnion }, + ], + }; + const cls = DETECTIONS; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl1 = "http://example.com/path/to/mask1"; + const sampleSrcUrl2 = "http://example.com/path/to/mask2"; + const overlayMask1 = { shape: [50, 50] }; + const overlayMask2 = { shape: [60, 60] }; + + vi.mocked(getSampleSrc) + .mockReturnValueOnce(sampleSrcUrl1) + .mockReturnValueOnce(sampleSrcUrl2); + vi.mocked(decodeWithCanvas) + .mockResolvedValueOnce(overlayMask1) + .mockResolvedValueOnce(overlayMask2); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + await Promise.all(maskPathDecodingPromises); + + expect(getSampleSrc).toHaveBeenNthCalledWith(1, "/path/to/mask1"); + expect(getSampleSrc).toHaveBeenNthCalledWith(2, "/path/to/mask2"); + expect(label.detections[0].mask).toBeDefined(); + expect(label.detections[0].mask.data).toBe(overlayMask1); + expect(label.detections[1].mask).toBeDefined(); + expect(label.detections[1].mask.data).toBe(overlayMask2); + }); + + it("should return early if fetch (with retry) fails", async () => { + const field = "testField"; + const label = { mask_path: "/path/to/mask", mask: null as MaskUnion }; + const cls = "Segmentation"; + const maskPathDecodingPromises: Promise[] = []; + + const sampleSrcUrl = "http://example.com/path/to/mask"; + + vi.mocked(getSampleSrc).mockReturnValue(sampleSrcUrl); + vi.mocked(enqueueFetch).mockRejectedValue(new Error("Fetch failed")); + + await decodeOverlayOnDisk( + field, + label, + COLORING, + CUSTOMIZE_COLOR_SETTING, + COLOR_SCALE, + SOURCES, + cls, + maskPathDecodingPromises + ); + + expect(getSampleSrc).toHaveBeenCalledWith("/path/to/mask"); + expect(enqueueFetch).toHaveBeenCalledWith({ + url: sampleSrcUrl, + options: { priority: "low" }, + }); + expect(decodeWithCanvas).not.toHaveBeenCalled(); + expect(label.mask).toBeNull(); + }); +}); diff --git a/app/packages/looker/src/worker/disk-overlay-decoder.ts b/app/packages/looker/src/worker/disk-overlay-decoder.ts new file mode 100644 index 0000000000..8730f74bf0 --- /dev/null +++ b/app/packages/looker/src/worker/disk-overlay-decoder.ts @@ -0,0 +1,121 @@ +import { getSampleSrc } from "@fiftyone/state/src/recoil/utils"; +import { DETECTION, DETECTIONS } from "@fiftyone/utilities"; +import { Coloring, CustomizeColor } from ".."; +import { OverlayMask } from "../numpy"; +import { Colorscale } from "../state"; +import { decodeWithCanvas } from "./canvas-decoder"; +import { enqueueFetch } from "./pooled-fetch"; +import { getOverlayFieldFromCls } from "./shared"; + +export type IntermediateMask = { + data: OverlayMask; + image: ArrayBuffer; +}; + +/** + * Some label types (example: segmentation, heatmap) can have their overlay data stored on-disk, + * we want to impute the relevant mask property of these labels from what's stored in the disk + */ +export const decodeOverlayOnDisk = async ( + field: string, + label: Record, + coloring: Coloring, + customizeColorSetting: CustomizeColor[], + colorscale: Colorscale, + sources: { [path: string]: string }, + cls: string, + maskPathDecodingPromises: Promise[] = [], + maskTargetsBuffers: ArrayBuffer[] = [] +) => { + // handle all list types here + if (cls === DETECTIONS) { + const promises: Promise[] = []; + for (const detection of label.detections) { + promises.push( + decodeOverlayOnDisk( + field, + detection, + coloring, + customizeColorSetting, + colorscale, + {}, + DETECTION, + maskPathDecodingPromises, + maskTargetsBuffers + ) + ); + } + maskPathDecodingPromises.push(...promises); + } + + const overlayFields = getOverlayFieldFromCls(cls); + const overlayPathField = overlayFields.disk; + const overlayField = overlayFields.canonical; + + if (Boolean(label[overlayField]) || !Object.hasOwn(label, overlayPathField)) { + // it's possible we're just re-coloring, in which case re-init mask image and set bitmap to null + if ( + label[overlayField] && + label[overlayField].bitmap && + !label[overlayField].image + ) { + const height = label[overlayField].bitmap.height; + const width = label[overlayField].bitmap.width; + label[overlayField].image = new ArrayBuffer(height * width * 4); + label[overlayField].bitmap.close(); + label[overlayField].bitmap = null; + } + // nothing to be done + return; + } + + // convert absolute file path to a URL that we can "fetch" from + const overlayImageUrl = getSampleSrc( + sources[`${field}.${overlayPathField}`] || label[overlayPathField] + ); + const urlTokens = overlayImageUrl.split("?"); + + let baseUrl = overlayImageUrl; + + // remove query params if not local URL + if (!urlTokens.at(1)?.startsWith("filepath=")) { + baseUrl = overlayImageUrl.split("?")[0]; + } + + let overlayImageBlob: Blob; + try { + const overlayImageFetchResponse = await enqueueFetch({ + url: baseUrl, + options: { priority: "low" }, + }); + overlayImageBlob = await overlayImageFetchResponse.blob(); + } catch (e) { + console.error(e); + // skip decoding if fetch fails altogether + return; + } + + let overlayMask: OverlayMask; + + try { + overlayMask = await decodeWithCanvas(overlayImageBlob); + } catch (e) { + console.error(e); + return; + } + + const [overlayHeight, overlayWidth] = overlayMask.shape; + + // set the `mask` property for this label + // we need to do this because we need raw image pixel data + // to iterate through and paint it with the color + // defined by the user for this particular label + label[overlayField] = { + data: overlayMask, + image: new ArrayBuffer(overlayWidth * overlayHeight * 4), + } as IntermediateMask; + + // no need to transfer image's buffer + //since we'll be constructing ImageBitmap and transfering that + maskTargetsBuffers.push(overlayMask.buffer); +}; diff --git a/app/packages/looker/src/worker/index.ts b/app/packages/looker/src/worker/index.ts index 21859407e2..dcf0b2e79b 100644 --- a/app/packages/looker/src/worker/index.ts +++ b/app/packages/looker/src/worker/index.ts @@ -2,14 +2,12 @@ * Copyright 2017-2024, Voxel51, Inc. */ -import { getSampleSrc } from "@fiftyone/state/src/recoil/utils"; import { DENSE_LABELS, DETECTION, DETECTIONS, DYNAMIC_EMBEDDED_DOCUMENT, EMBEDDED_DOCUMENT, - HEATMAP, LABEL_LIST, Schema, Stage, @@ -29,11 +27,10 @@ import { LabelTagColor, Sample, } from "../state"; -import { decodeWithCanvas } from "./canvas-decoder"; -import { fetchWithLinearBackoff } from "./decorated-fetch"; import { DeserializerFactory } from "./deserializer"; +import { decodeOverlayOnDisk } from "./disk-overlay-decoder"; import { PainterFactory } from "./painter"; -import { mapId } from "./shared"; +import { getOverlayFieldFromCls, mapId } from "./shared"; import { process3DLabels } from "./threed-label-processor"; interface ResolveColor { @@ -97,89 +94,15 @@ const painterFactory = PainterFactory(requestColor); const ALL_VALID_LABELS = new Set(VALID_LABEL_TYPES); /** - * Some label types (example: segmentation, heatmap) can have their overlay data stored on-disk, - * we want to impute the relevant mask property of these labels from what's stored in the disk + * This function processes labels in a recursive manner. It follows the following steps: + * 1. Deserialize masks. Accumulate promises. + * 2. Await mask path decoding to finish. + * 3. Start painting overlays. Accumulate promises. + * 4. Await overlay painting to finish. + * 5. Start bitmap generation. Accumulate promises. + * 6. Await bitmap generation to finish. + * 7. Transfer bitmaps and mask targets array buffers back to the main thread. */ -const imputeOverlayFromPath = async ( - field: string, - label: Record, - coloring: Coloring, - customizeColorSetting: CustomizeColor[], - colorscale: Colorscale, - buffers: ArrayBuffer[], - sources: { [path: string]: string }, - cls: string, - maskPathDecodingPromises: Promise[] = [] -) => { - // handle all list types here - if (cls === DETECTIONS) { - const promises: Promise[] = []; - for (const detection of label.detections) { - promises.push( - imputeOverlayFromPath( - field, - detection, - coloring, - customizeColorSetting, - colorscale, - buffers, - {}, - DETECTION - ) - ); - } - maskPathDecodingPromises.push(...promises); - } - - // overlay path is in `map_path` property for heatmap, or else, it's in `mask_path` property (for segmentation or detection) - const overlayPathField = cls === HEATMAP ? "map_path" : "mask_path"; - const overlayField = overlayPathField === "map_path" ? "map" : "mask"; - - if ( - Object.hasOwn(label, overlayField) || - !Object.hasOwn(label, overlayPathField) - ) { - // nothing to be done - return; - } - - // convert absolute file path to a URL that we can "fetch" from - const overlayImageUrl = getSampleSrc( - sources[`${field}.${overlayPathField}`] || label[overlayPathField] - ); - const urlTokens = overlayImageUrl.split("?"); - - let baseUrl = overlayImageUrl; - - // remove query params if not local URL - if (!urlTokens.at(1)?.startsWith("filepath=")) { - baseUrl = overlayImageUrl.split("?")[0]; - } - - let overlayImageBlob: Blob; - try { - const overlayImageFetchResponse = await fetchWithLinearBackoff(baseUrl); - overlayImageBlob = await overlayImageFetchResponse.blob(); - } catch (e) { - console.error(e); - // skip decoding if fetch fails altogether - return; - } - - const overlayMask = await decodeWithCanvas(overlayImageBlob); - const [overlayHeight, overlayWidth] = overlayMask.shape; - - // set the `mask` property for this label - label[overlayField] = { - data: overlayMask, - image: new ArrayBuffer(overlayWidth * overlayHeight * 4), - }; - - // transfer buffers - buffers.push(overlayMask.buffer); - buffers.push(label[overlayField].image); -}; - const processLabels = async ( sample: ProcessSample["sample"], coloring: ProcessSample["coloring"], @@ -190,13 +113,13 @@ const processLabels = async ( labelTagColors: ProcessSample["labelTagColors"], selectedLabelTags: ProcessSample["selectedLabelTags"], schema: Schema -): Promise => { - const buffers: ArrayBuffer[] = []; - const painterPromises = []; +): Promise<[Promise[], ArrayBuffer[]]> => { + const maskPathDecodingPromises: Promise[] = []; + const painterPromises: Promise[] = []; + const bitmapPromises: Promise[] = []; + const maskTargetsBuffers: ArrayBuffer[] = []; - const maskPathDecodingPromises = []; - - // mask deserialization / mask_path decoding loop + // mask deserialization / on-disk overlay decoding loop for (const field in sample) { let labels = sample[field]; if (!Array.isArray(labels)) { @@ -204,6 +127,10 @@ const processLabels = async ( } const cls = getCls(`${prefix ? prefix : ""}${field}`, schema); + if (!cls) { + continue; + } + for (const label of labels) { if (!label) { continue; @@ -211,37 +138,39 @@ const processLabels = async ( if (DENSE_LABELS.has(cls)) { maskPathDecodingPromises.push( - imputeOverlayFromPath( + decodeOverlayOnDisk( `${prefix || ""}${field}`, label, coloring, customizeColorSetting, colorscale, - buffers, sources, cls, - maskPathDecodingPromises + maskPathDecodingPromises, + maskTargetsBuffers ) ); } if (cls in DeserializerFactory) { - DeserializerFactory[cls](label, buffers); + DeserializerFactory[cls](label, maskTargetsBuffers); } if ([EMBEDDED_DOCUMENT, DYNAMIC_EMBEDDED_DOCUMENT].includes(cls)) { - const moreBuffers = await processLabels( - label, - coloring, - `${prefix ? prefix : ""}${field}.`, - sources, - customizeColorSetting, - colorscale, - labelTagColors, - selectedLabelTags, - schema - ); - buffers.push(...moreBuffers); + const [moreBitmapPromises, moreMaskTargetsBuffers] = + await processLabels( + label, + coloring, + `${prefix ? prefix : ""}${field}.`, + sources, + customizeColorSetting, + colorscale, + labelTagColors, + selectedLabelTags, + schema + ); + bitmapPromises.push(...moreBitmapPromises); + maskTargetsBuffers.push(...moreMaskTargetsBuffers); } if (ALL_VALID_LABELS.has(cls)) { @@ -261,11 +190,17 @@ const processLabels = async ( // overlay painting loop for (const field in sample) { let labels = sample[field]; + if (!Array.isArray(labels)) { labels = [labels]; } + const cls = getCls(`${prefix ? prefix : ""}${field}`, schema); + if (!cls) { + continue; + } + for (const label of labels) { if (!label) { continue; @@ -286,7 +221,72 @@ const processLabels = async ( } } - return Promise.all(painterPromises).then(() => buffers); + await Promise.allSettled(painterPromises); + + // bitmap generation loop + for (const field in sample) { + let labels = sample[field]; + + if (!Array.isArray(labels)) { + labels = [labels]; + } + + const cls = getCls(`${prefix ? prefix : ""}${field}`, schema); + + if (!cls) { + continue; + } + + for (const label of labels) { + if (!label) { + continue; + } + + collectBitmapPromises(label, cls, bitmapPromises); + } + } + + return [bitmapPromises, maskTargetsBuffers]; +}; + +const collectBitmapPromises = (label, cls, bitmapPromises) => { + if (cls === DETECTIONS) { + label?.detections?.forEach((detection) => + collectBitmapPromises(detection, DETECTION, bitmapPromises) + ); + return; + } + + const overlayFields = getOverlayFieldFromCls(cls); + const overlayField = overlayFields.canonical; + + if (label[overlayField]) { + const [height, width] = label[overlayField].data.shape; + + if (!height || !width) { + label[overlayField].image = null; + return; + } + + const imageData = new ImageData( + new Uint8ClampedArray(label[overlayField].image), + width, + height + ); + + // set raw image to null - will be garbage collected + // we don't need it anymore since we copied to ImageData + label[overlayField].image = null; + + bitmapPromises.push( + new Promise((resolve) => { + createImageBitmap(imageData).then((imageBitmap) => { + label[overlayField].bitmap = imageBitmap; + resolve(imageBitmap); + }); + }) + ); + } }; /** GLOBALS */ @@ -316,7 +316,7 @@ export interface ProcessSample { type ProcessSampleMethod = ReaderMethod & ProcessSample; -const processSample = ({ +const processSample = async ({ sample, uuid, coloring, @@ -329,48 +329,68 @@ const processSample = ({ }: ProcessSample) => { mapId(sample); - let bufferPromises = []; + const imageBitmapPromises: Promise[] = []; + let maskTargetsBuffers: ArrayBuffer[] = []; if (sample?._media_type === "point-cloud" || sample?._media_type === "3d") { process3DLabels(schema, sample); } else { - bufferPromises = [ - processLabels( - sample, - coloring, - null, - sources, - customizeColorSetting, - colorscale, - labelTagColors, - selectedLabelTags, - schema - ), - ]; + const [bitmapPromises, moreMaskTargetsBuffers] = await processLabels( + sample, + coloring, + null, + sources, + customizeColorSetting, + colorscale, + labelTagColors, + selectedLabelTags, + schema + ); + + if (bitmapPromises.length !== 0) { + imageBitmapPromises.push(...bitmapPromises); + } + + if (moreMaskTargetsBuffers.length !== 0) { + maskTargetsBuffers.push(...moreMaskTargetsBuffers); + } } - if (sample.frames && sample.frames.length) { - bufferPromises = [ - ...bufferPromises, - ...sample.frames - .map((frame) => - processLabels( - frame, - coloring, - "frames.", - sources, - customizeColorSetting, - colorscale, - labelTagColors, - selectedLabelTags, - schema - ) + // this usually only applies to thumbnail frame + // other frames are processed in the stream (see `getSendChunk`) + if (sample.frames?.length) { + const allFramePromises: ReturnType[] = []; + for (const frame of sample.frames) { + allFramePromises.push( + processLabels( + frame, + coloring, + "frames.", + sources, + customizeColorSetting, + colorscale, + labelTagColors, + selectedLabelTags, + schema ) - .flat(), - ]; + ); + } + const framePromisesResolved = await Promise.all(allFramePromises); + for (const [bitmapPromises, buffers] of framePromisesResolved) { + if (bitmapPromises.length !== 0) { + imageBitmapPromises.push(...bitmapPromises); + } + + if (buffers.length !== 0) { + maskTargetsBuffers.push(...buffers); + } + } } - Promise.all(bufferPromises).then((buffers) => { + Promise.all(imageBitmapPromises).then((bitmaps) => { + const flatBitmaps = bitmaps.flat() ?? []; + const flatMaskTargetsBuffers = maskTargetsBuffers.flat() ?? []; + const transferables = [...flatBitmaps, ...flatMaskTargetsBuffers]; postMessage( { method: "processSample", @@ -383,7 +403,7 @@ const processSample = ({ selectedLabelTags, }, // @ts-ignore - buffers.flat() + transferables ); }); }; @@ -503,9 +523,9 @@ const createReader = ({ const getSendChunk = (uuid: string) => - ({ value }: { done: boolean; value?: FrameChunkResponse }) => { + async ({ value }: { done: boolean; value?: FrameChunkResponse }) => { if (value) { - Promise.all( + const allLabelsPromiseResults = await Promise.allSettled( value.frames.map((frame) => processLabels( frame, @@ -519,18 +539,35 @@ const getSendChunk = value.schema ) ) - ).then((buffers) => { - postMessage( - { - method: "frameChunk", - frames: value.frames, - range: value.range, - uuid, - }, - // @ts-ignore - buffers.flat() - ); - }); + ); + + const allLabelsResults = allLabelsPromiseResults + .filter((result) => result.status === "fulfilled") + .map((result) => result.value); + + const allBuffers = allLabelsResults.map((result) => result[1]).flat(); + + const allBitmapsPromises = allLabelsResults + .map((result) => result[0]) + .flat(); + + const bitmapPromiseResults = ( + await Promise.allSettled(allBitmapsPromises) + ) + .map((result) => (result.status === "fulfilled" ? result.value : [])) + .flat(); + + const transferables = [...bitmapPromiseResults, ...allBuffers]; + postMessage( + { + method: "frameChunk", + frames: value.frames, + range: value.range, + uuid, + }, + // @ts-ignore + transferables + ); } }; diff --git a/app/packages/looker/src/worker/pooled-fetch.test.ts b/app/packages/looker/src/worker/pooled-fetch.test.ts new file mode 100644 index 0000000000..6804f3260a --- /dev/null +++ b/app/packages/looker/src/worker/pooled-fetch.test.ts @@ -0,0 +1,110 @@ +import { beforeEach, describe, expect, it, Mock, vi } from "vitest"; +import { enqueueFetch } from "./pooled-fetch"; + +const MAX_CONCURRENT_REQUESTS = 100; + +// helper function to create a deferred promise +function createDeferredPromise() { + let resolve: (value: T | PromiseLike) => void; + let reject: (reason?: any) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve: resolve!, reject: reject! }; +} + +describe("enqueueFetch", () => { + let mockedFetch: Mock; + + beforeEach(() => { + vi.resetAllMocks(); + mockedFetch = vi.fn(); + global.fetch = mockedFetch; + }); + + it("should return response when fetch succeeds", async () => { + const mockResponse = new Response("OK", { status: 200 }); + mockedFetch.mockResolvedValue(mockResponse); + + const response = await enqueueFetch({ url: "https://fiftyone.ai" }); + expect(response).toBe(mockResponse); + }); + + it("should process multiple requests in order", async () => { + const mockResponse1 = new Response("First", { status: 200 }); + const mockResponse2 = new Response("Second", { status: 200 }); + + const deferred1 = createDeferredPromise(); + const deferred2 = createDeferredPromise(); + + mockedFetch + .mockImplementationOnce(() => deferred1.promise) + .mockImplementationOnce(() => deferred2.promise); + + const promise1 = enqueueFetch({ url: "https://fiftyone.ai/1" }); + const promise2 = enqueueFetch({ url: "https://fiftyone.ai/2" }); + + deferred1.resolve(mockResponse1); + + const response1 = await promise1; + expect(response1).toBe(mockResponse1); + + deferred2.resolve(mockResponse2); + + const response2 = await promise2; + expect(response2).toBe(mockResponse2); + }); + + it("should not exceed MAX_CONCURRENT_REQUESTS", async () => { + const numRequests = MAX_CONCURRENT_REQUESTS + 50; + const deferredPromises = []; + + for (let i = 0; i < numRequests; i++) { + const deferred = createDeferredPromise(); + deferredPromises.push(deferred); + mockedFetch.mockImplementationOnce(() => deferred.promise); + enqueueFetch({ url: `https://fiftyone.ai/${i}` }); + } + + // at this point, fetch should have been called MAX_CONCURRENT_REQUESTS times + expect(mockedFetch).toHaveBeenCalledTimes(MAX_CONCURRENT_REQUESTS); + + // resolve all deferred promises + deferredPromises.forEach((deferred, index) => { + deferred.resolve(new Response(`Response ${index}`, { status: 200 })); + }); + + // wait for all promises to resolve + await Promise.all(deferredPromises.map((dp) => dp.promise)); + + // all requests should have been processed + expect(mockedFetch).toHaveBeenCalledTimes(numRequests); + }); + + it("should reject immediately on non-retryable error", async () => { + const mockResponse = new Response("Not Found", { status: 404 }); + mockedFetch.mockResolvedValue(mockResponse); + + await expect(enqueueFetch({ url: "https://fiftyone.ai" })).rejects.toThrow( + "Non-retryable HTTP error: 404" + ); + }); + + it("should retry on retryable errors up to MAX_RETRIES times", async () => { + const MAX_RETRIES = 3; + mockedFetch.mockRejectedValue(new Error("Network Error")); + + await expect( + enqueueFetch({ + url: "https://fiftyone.ai", + retryOptions: { + retries: MAX_RETRIES, + delay: 50, + }, + }) + ).rejects.toThrow("Max retries for fetch reached"); + + expect(mockedFetch).toHaveBeenCalledTimes(MAX_RETRIES); + }); +}); diff --git a/app/packages/looker/src/worker/pooled-fetch.ts b/app/packages/looker/src/worker/pooled-fetch.ts new file mode 100644 index 0000000000..a23e1cb739 --- /dev/null +++ b/app/packages/looker/src/worker/pooled-fetch.ts @@ -0,0 +1,46 @@ +import { fetchWithLinearBackoff, RetryOptions } from "./decorated-fetch"; + +interface QueueItem { + request: { + url: string; + options?: RequestInit; + retryOptions?: RetryOptions; + }; + resolve: (value: Response | PromiseLike) => void; + reject: (reason?: any) => void; +} + +// note: arbitrary number that seems to work well +const MAX_CONCURRENT_REQUESTS = 100; + +let activeRequests = 0; +const requestQueue: QueueItem[] = []; + +export const enqueueFetch = ( + request: QueueItem["request"] +): Promise => { + return new Promise((resolve, reject) => { + requestQueue.push({ request, resolve, reject }); + processFetchQueue(); + }); +}; + +const processFetchQueue = () => { + if (activeRequests >= MAX_CONCURRENT_REQUESTS || requestQueue.length === 0) { + return; + } + + const { request, resolve, reject } = requestQueue.shift(); + activeRequests++; + + fetchWithLinearBackoff(request.url, request.options, request.retryOptions) + .then((response) => { + activeRequests--; + resolve(response); + processFetchQueue(); + }) + .catch((error) => { + activeRequests--; + reject(error); + }); +}; diff --git a/app/packages/looker/src/worker/shared.ts b/app/packages/looker/src/worker/shared.ts index adfda58d29..ec383b7536 100644 --- a/app/packages/looker/src/worker/shared.ts +++ b/app/packages/looker/src/worker/shared.ts @@ -1,3 +1,5 @@ +import { HEATMAP } from "@fiftyone/utilities"; + /** * Map the _id field to id */ @@ -8,3 +10,12 @@ export const mapId = (obj) => { } return obj; }; + +export const getOverlayFieldFromCls = (cls: string) => { + switch (cls) { + case HEATMAP: + return { canonical: "map", disk: "map_path" }; + default: + return { canonical: "mask", disk: "mask_path" }; + } +};