Relax bounding box requirements for model training #8222

Merged on Nov 28, 2024 (10 commits)
1 change: 1 addition & 0 deletions CHANGELOG.unreleased.md
@@ -17,6 +17,7 @@ For upgrade instructions, please check the [migration guide](MIGRATIONS.released
- Reading image files on datastore filesystem is now done asynchronously. [#8126](https://github.com/scalableminds/webknossos/pull/8126)
- Improved error messages for starting jobs on datasets from other organizations. [#8181](https://github.com/scalableminds/webknossos/pull/8181)
- Removed bounding box size restriction for inferral jobs for super users. [#8200](https://github.com/scalableminds/webknossos/pull/8200)
- Allowed to train an AI model using differently sized bounding boxes. We recommend all bounding boxes to have equal dimensions or to have dimensions which are multiples of the smallest bounding box. [#8222](https://github.com/scalableminds/webknossos/pull/8222)

### Fixed
- Fix performance bottleneck when deleting a lot of trees at once. [#8176](https://github.com/scalableminds/webknossos/pull/8176)
169 changes: 124 additions & 45 deletions frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx
@@ -34,11 +34,11 @@ import _ from "lodash";
import BoundingBox from "oxalis/model/bucket_data_handling/bounding_box";
import { formatVoxels } from "libs/format_utils";
import * as Utils from "libs/utils";
import { V3 } from "libs/mjs";
import type { APIAnnotation, APIDataset, ServerVolumeTracing } from "types/api_flow_types";
import type { Vector3 } from "oxalis/constants";
import type { Vector3, Vector6 } from "oxalis/constants";
import { serverVolumeToClientVolumeTracing } from "oxalis/model/reducers/volumetracing_reducer";
import { convertUserBoundingBoxesFromServerToFrontend } from "oxalis/model/reducers/reducer_helpers";
import { computeArrayFromBoundingBox } from "libs/utils";

const { TextArea } = Input;
const FormItem = Form.Item;
@@ -66,8 +66,8 @@ enum AiModelCategory {
const ExperimentalWarning = () => (
<Row style={{ display: "grid", marginBottom: 16 }}>
<Alert
message="Please note that this feature is experimental. All bounding boxes must be the same size, with equal width and height. Ensure the size is not too small (we recommend at least 10 Vx per dimension) and choose boxes that represent the data well."
type="warning"
message="Please note that this feature is experimental. All bounding boxes should have equal dimensions or have dimensions which are multiples of the smallest bounding box. Ensure the size is not too small (we recommend at least 10 Vx per dimension) and choose boxes that represent the data well."
type="info"
showIcon
/>
</Row>
@@ -217,19 +217,29 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
modelCategory: AiModelCategory.EM_NEURONS,
};

const userBoundingBoxes = annotationInfos.flatMap(({ userBoundingBoxes }) => userBoundingBoxes);
const userBoundingBoxes = annotationInfos.flatMap(({ userBoundingBoxes, annotation }) =>
userBoundingBoxes.map((box) => ({
...box,
annotationId: "id" in annotation ? annotation.id : annotation.annotationId,
})),
);

const bboxesVoxelCount = _.sum(
(userBoundingBoxes || []).map((bbox) => new BoundingBox(bbox.boundingBox).getVolume()),
);

const { areSomeAnnotationsInvalid, invalidAnnotationsReason } =
areInvalidAnnotationsIncluded(annotationInfos);
const { areSomeBBoxesInvalid, invalidBBoxesReason } =
areInvalidBoundingBoxesIncluded(userBoundingBoxes);
const invalidReasons = [invalidAnnotationsReason, invalidBBoxesReason]
.filter((reason) => reason)
.join("\n");
const { hasAnnotationErrors, errors: annotationErrors } =
checkAnnotationsForErrorsAndWarnings(annotationInfos);
const {
hasBBoxErrors,
hasBBoxWarnings,
errors: bboxErrors,
warnings: bboxWarnings,
} = checkBoundingBoxesForErrorsAndWarnings(userBoundingBoxes);
const hasErrors = hasAnnotationErrors || hasBBoxErrors;
const hasWarnings = hasBBoxWarnings;
const errors = [...annotationErrors, ...bboxErrors];
const warnings = bboxWarnings;

return (
<Form
@@ -333,16 +343,46 @@ export function TrainAiModelTab<GenericAnnotation extends APIAnnotation | Hybrid
</div>
</FormItem>
) : null}

{hasErrors
? errors.map((error) => (
<Alert
key={error}
description={error}
style={{
marginBottom: 12,
whiteSpace: "pre-line",
}}
type="error"
showIcon
/>
))
: null}
{hasWarnings
? warnings.map((warning) => (
<Alert
key={warning}
description={warning}
style={{
marginBottom: 12,
whiteSpace: "pre-line",
}}
type="warning"
showIcon
/>
))
: null}

<FormItem>
<Tooltip title={invalidReasons}>
<Tooltip title={hasErrors ? "Solve the errors displayed above before continuing." : ""}>
<Button
size="large"
type="primary"
htmlType="submit"
style={{
width: "100%",
}}
disabled={areSomeBBoxesInvalid || areSomeAnnotationsInvalid}
disabled={hasErrors}
>
Start Training
</Button>
@@ -385,16 +425,16 @@ export function CollapsibleWorkflowYamlEditor({
);
}

function areInvalidAnnotationsIncluded<T extends HybridTracing | APIAnnotation>(
function checkAnnotationsForErrorsAndWarnings<T extends HybridTracing | APIAnnotation>(
annotationsWithDatasets: Array<AnnotationInfoForAIJob<T>>,
): {
areSomeAnnotationsInvalid: boolean;
invalidAnnotationsReason: string | null;
hasAnnotationErrors: boolean;
errors: string[];
} {
if (annotationsWithDatasets.length === 0) {
return {
areSomeAnnotationsInvalid: true,
invalidAnnotationsReason: "At least one annotation must be defined.",
hasAnnotationErrors: true,
errors: ["At least one annotation must be defined."],
};
}
const annotationsWithoutBoundingBoxes = annotationsWithDatasets.filter(
@@ -407,42 +447,81 @@ function areInvalidAnnotationsIncluded<T extends HybridTracing | APIAnnotation>(
"id" in annotation ? annotation.id : annotation.annotationId,
);
return {
areSomeAnnotationsInvalid: true,
invalidAnnotationsReason: `All annotations must have at least one bounding box. Annotations without bounding boxes are: ${annotationIds.join(", ")}`,
hasAnnotationErrors: true,
errors: [
`All annotations must have at least one bounding box. Annotations without bounding boxes are:\n${annotationIds.join(", ")}`,
],
};
}
return { areSomeAnnotationsInvalid: false, invalidAnnotationsReason: null };
return { hasAnnotationErrors: false, errors: [] };
}

function areInvalidBoundingBoxesIncluded(userBoundingBoxes: UserBoundingBox[]): {
areSomeBBoxesInvalid: boolean;
invalidBBoxesReason: string | null;
function checkBoundingBoxesForErrorsAndWarnings(
userBoundingBoxes: (UserBoundingBox & { annotationId: string })[],
): {
hasBBoxErrors: boolean;
hasBBoxWarnings: boolean;
errors: string[];
warnings: string[];
} {
let hasBBoxErrors = false;
let hasBBoxWarnings = false;
const errors = [];
const warnings = [];
if (userBoundingBoxes.length === 0) {
return {
areSomeBBoxesInvalid: true,
invalidBBoxesReason: "At least one bounding box must be defined.",
};
hasBBoxErrors = true;
errors.push("At least one bounding box must be defined.");
}
const getSize = (bbox: UserBoundingBox) => V3.sub(bbox.boundingBox.max, bbox.boundingBox.min);
// Find smallest bounding box dimensions
const minDimensions = userBoundingBoxes.reduce(
(min, { boundingBox: box }) => ({
x: Math.min(min.x, box.max[0] - box.min[0]),
y: Math.min(min.y, box.max[1] - box.min[1]),
z: Math.min(min.z, box.max[2] - box.min[2]),
}),
{ x: Number.POSITIVE_INFINITY, y: Number.POSITIVE_INFINITY, z: Number.POSITIVE_INFINITY },
);
Comment on lines +476 to +483

🛠️ Refactor suggestion

Consider adding validation for zero or negative dimensions

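While the code warns about boxes smaller than 10 Vx, it does not single out zero or negative dimensions. If the smallest dimension along an axis is zero, the modulo check evaluates to `NaN`, so every box would be reported as a non-multiple.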

   const minDimensions = userBoundingBoxes.reduce(
     (min, { boundingBox: box }) => ({
+      // Validate that dimensions are positive
+      x: box.max[0] <= box.min[0] ? Infinity : Math.min(min.x, box.max[0] - box.min[0]),
+      y: box.max[1] <= box.min[1] ? Infinity : Math.min(min.y, box.max[1] - box.min[1]),
+      z: box.max[2] <= box.min[2] ? Infinity : Math.min(min.z, box.max[2] - box.min[2]),
-      x: Math.min(min.x, box.max[0] - box.min[0]),
-      y: Math.min(min.y, box.max[1] - box.min[1]),
-      z: Math.min(min.z, box.max[2] - box.min[2]),
     }),
     { x: Number.POSITIVE_INFINITY, y: Number.POSITIVE_INFINITY, z: Number.POSITIVE_INFINITY },
   );
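
One caveat with the ternary above: returning `Infinity` for a non-positive extent discards the running minimum, so a degenerate box that comes after valid ones would reset that axis. A minimal sketch of an alternative that keeps the running minimum is shown below. The simplified `Box` type and the helper name `minPositiveDimensions` are assumptions for illustration and are not part of the PR.

```typescript
// Hypothetical, simplified stand-in for the PR's bounding box shape.
type Vec3 = [number, number, number];
type Box = { min: Vec3; max: Vec3 };
type Dims = { x: number; y: number; z: number };

// Smallest positive extent per axis. Non-positive extents are skipped
// instead of resetting the running minimum to Infinity.
function minPositiveDimensions(boxes: Box[]): Dims {
  const pick = (current: number, extent: number): number =>
    extent > 0 ? Math.min(current, extent) : current;
  return boxes.reduce<Dims>(
    (acc, { min: lo, max: hi }) => ({
      x: pick(acc.x, hi[0] - lo[0]),
      y: pick(acc.y, hi[1] - lo[1]),
      z: pick(acc.z, hi[2] - lo[2]),
    }),
    { x: Number.POSITIVE_INFINITY, y: Number.POSITIVE_INFINITY, z: Number.POSITIVE_INFINITY },
  );
}

const dims = minPositiveDimensions([
  { min: [0, 0, 0], max: [32, 32, 16] },
  { min: [5, 5, 5], max: [5, 69, 37] }, // zero width along x
]);
console.log(dims); // { x: 32, y: 32, z: 16 }
```

If every box is degenerate along an axis, that axis stays at `Infinity`, so a caller would still need to handle that case before using the result as a divisor.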


const size = getSize(userBoundingBoxes[0]);
// width must equal height
if (size[0] !== size[1]) {
return {
areSomeBBoxesInvalid: true,
invalidBBoxesReason: "The bounding box width must equal its height.",
};
// Validate minimum size and multiple requirements
type BoundingBoxWithAnnotationId = { boundingBox: Vector6; name: string; annotationId: string };
const tooSmallBoxes: BoundingBoxWithAnnotationId[] = [];
const nonMultipleBoxes: BoundingBoxWithAnnotationId[] = [];
userBoundingBoxes.forEach(({ boundingBox: box, name, annotationId }) => {
const arrayBox = computeArrayFromBoundingBox(box);
const [_x, _y, _z, width, height, depth] = arrayBox;
if (width < 10 || height < 10 || depth < 10) {
tooSmallBoxes.push({ boundingBox: arrayBox, name, annotationId });
}

if (
width % minDimensions.x !== 0 ||
height % minDimensions.y !== 0 ||
depth % minDimensions.z !== 0
) {
nonMultipleBoxes.push({ boundingBox: arrayBox, name, annotationId });
}
});
Comment on lines +489 to +503

🛠️ Refactor suggestion

Consider early return for invalid dimensions

The validation loop should skip the remaining checks for any box with invalid dimensions (zero or negative) to prevent unnecessary processing.

   userBoundingBoxes.forEach(({ boundingBox: box, name, annotationId }) => {
     const arrayBox = computeArrayFromBoundingBox(box);
     const [_x, _y, _z, width, height, depth] = arrayBox;
+    // Check for invalid dimensions
+    if (width <= 0 || height <= 0 || depth <= 0) {
+      tooSmallBoxes.push({ boundingBox: arrayBox, name, annotationId });
+      return;
+    }
+
     if (width < 10 || height < 10 || depth < 10) {
       tooSmallBoxes.push({ boundingBox: arrayBox, name, annotationId });
     }

Committable suggestion skipped: line range outside the PR's diff.
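
Two details are worth keeping in mind for this loop. First, `return` inside a `forEach` callback only skips the rest of the current iteration; it does not stop the loop. Second, in JavaScript `n % 0` evaluates to `NaN`, and `NaN !== 0` is true, so a zero-sized minimum would make the multiple check flag every box. The sketch below shows a guarded check. The `isMultipleOf` helper and the plain `[width, height, depth]` tuples are assumptions for illustration, not the PR's types.

```typescript
// Illustrative only: [width, height, depth] in voxels.
type Size = [number, number, number];

// True when every extent is a positive whole multiple of the matching base extent.
// A non-positive or non-finite base yields false rather than relying on NaN.
function isMultipleOf(size: Size, base: Size): boolean {
  return size.every((extent, axis) => {
    const unit = base[axis];
    if (!Number.isFinite(unit) || unit <= 0 || extent <= 0) return false;
    return extent % unit === 0;
  });
}

console.log(isMultipleOf([64, 64, 32], [32, 32, 16])); // true
console.log(isMultipleOf([48, 64, 32], [32, 32, 16])); // false
console.log(10 % 0); // NaN
```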


const boxWithIdToString = ({ boundingBox, name, annotationId }: BoundingBoxWithAnnotationId) =>
`'${name}' of annotation ${annotationId}: ${boundingBox.join(", ")}`;

if (tooSmallBoxes.length > 0) {
hasBBoxWarnings = true;
const tooSmallBoxesStrings = tooSmallBoxes.map(boxWithIdToString);
warnings.push(
`The following bounding boxes are not at least 10 Vx in each dimension which is suboptimal for the training:\n${tooSmallBoxesStrings.join("\n")}`,
);
}
// all bounding boxes must have the same size
const areSizesIdentical = userBoundingBoxes.every((bbox) => V3.isEqual(getSize(bbox), size));
if (areSizesIdentical) {
return { areSomeBBoxesInvalid: false, invalidBBoxesReason: null };

if (nonMultipleBoxes.length > 0) {
hasBBoxWarnings = true;
const nonMultipleBoxesStrings = nonMultipleBoxes.map(boxWithIdToString);
warnings.push(
`The minimum bounding box dimensions are ${minDimensions.x} x ${minDimensions.y} x ${minDimensions.z}. The following bounding boxes have dimensions which are not a multiple of the minimum dimensions which is suboptimal for the training:\n${nonMultipleBoxesStrings.join("\n")}`,
);
}
return {
areSomeBBoxesInvalid: true,
invalidBBoxesReason: "All bounding boxes must have the same size.",
};

return { hasBBoxErrors, hasBBoxWarnings, errors, warnings };
}

function AnnotationsCsvInput({