Skip to content

Commit

Permalink
feat: Send Canvas Image & Mask To ControlNet (#4374)
Browse files Browse the repository at this point in the history
## What type of PR is this? (check all applicable)

- [x] Feature


## Have you discussed this change with the InvokeAI team?
- [x] Yes

      
## Description

Send stuff directly from canvas to ControlNet

## Usage

- Two new buttons available on canvas Controlnet to import image and
mask.
- Click them.
  • Loading branch information
blessedcoolant authored Aug 29, 2023
2 parents a03233b + 15a927b commit ed1456e
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from './listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from './listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
Expand All @@ -41,6 +43,8 @@ import {
addImageUploadedFulfilledListener,
addImageUploadedRejectedListener,
} from './listeners/imageUploaded';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
Expand Down Expand Up @@ -80,8 +84,6 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';

export const listenerMiddleware = createListenerMiddleware();

Expand Down Expand Up @@ -137,6 +139,8 @@ addSessionReadyToInvokeListener();
// Canvas actions
addCanvasSavedToGalleryListener();
addCanvasMaskSavedToGalleryListener();
addCanvasImageToControlNetListener();
addCanvasMaskToControlNetListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';

export const addCanvasImageToControlNetListener = () => {
startAppListening({
actionCreator: canvasImageToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();

const blob = await getBaseLayerBlob(state);

if (!blob) {
log.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Saving Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}

const { autoAddBoardId } = state.gallery;

const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Canvas Sent to ControlNet & Assets' },
},
})
).unwrap();

const { image_name } = imageDTO;

dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { logger } from 'app/logging/logger';
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';

export const addCanvasMaskToControlNetListener = () => {
startAppListening({
actionCreator: canvasMaskToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();

const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);

if (!canvasBlobsAndImageData) {
return;
}

const { maskBlob } = canvasBlobsAndImageData;

if (!maskBlob) {
log.error('Problem getting mask layer blob');
dispatch(
addToast({
title: 'Problem Importing Mask',
description: 'Unable to export mask',
status: 'error',
})
);
return;
}

const { autoAddBoardId } = state.gallery;

const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Mask Sent to ControlNet & Assets' },
},
})
).unwrap();

const { image_name } = imageDTO;

dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};
9 changes: 9 additions & 0 deletions invokeai/frontend/web/src/features/canvas/store/actions.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types';

export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
Expand All @@ -20,3 +21,11 @@ export const canvasMerged = createAction('canvas/canvasMerged');
export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved'
);

export const canvasMaskToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasMaskToControlNet');

export const canvasImageToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasImageToControlNet');
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
Expand All @@ -36,6 +38,8 @@ const ControlNet = (props: ControlNetProps) => {
const { controlNetId } = controlNet;
const dispatch = useAppDispatch();

const activeTabName = useAppSelector(activeTabNameSelector);

const selector = createSelector(
stateSelector,
({ controlNet }) => {
Expand Down Expand Up @@ -108,6 +112,9 @@ const ControlNet = (props: ControlNetProps) => {
>
<ParamControlNetModel controlNet={controlNet} />
</Box>
{activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports controlNet={controlNet} />
)}
<IAIIconButton
size="sm"
tooltip="Duplicate"
Expand Down Expand Up @@ -167,6 +174,7 @@ const ControlNet = (props: ControlNetProps) => {
/>
)}
</Flex>

<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ import {
TypesafeDroppableData,
} from 'features/dnd/types';
import { memo, useCallback, useMemo, useState } from 'react';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { FaSave, FaUndo } from 'react-icons/fa';
import {
useAddImageToBoardMutation,
useChangeImageIsIntermediateMutation,
useGetImageDTOQuery,
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import {
Expand All @@ -26,11 +30,13 @@ type Props = {

const selector = createSelector(
stateSelector,
({ controlNet }) => {
({ controlNet, gallery }) => {
const { pendingControlImages } = controlNet;
const { autoAddBoardId } = gallery;

return {
pendingControlImages,
autoAddBoardId,
};
},
defaultSelectorOptions
Expand All @@ -47,7 +53,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {

const dispatch = useAppDispatch();

const { pendingControlImages } = useAppSelector(selector);
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);

const [isMouseOverImage, setIsMouseOverImage] = useState(false);

Expand All @@ -59,9 +65,26 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
processedControlImageName ?? skipToken
);

const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation();

const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]);

const handleSaveControlImage = useCallback(() => {
if (!processedControlImage) {
return;
}

changeIsIntermediate({
imageDTO: processedControlImage,
is_intermediate: false,
});

addToBoard({ imageDTO: processedControlImage, board_id: autoAddBoardId });
}, [processedControlImage, autoAddBoardId, changeIsIntermediate, addToBoard]);

const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true);
}, []);
Expand Down Expand Up @@ -122,11 +145,19 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
isDropDisabled={shouldShowProcessedImage || !isEnabled}
postUploadAction={postUploadAction}
>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <FaSave size={16} /> : undefined}
tooltip="Save Control Image"
styleOverrides={{ marginTop: 6 }}
/>
</>
</IAIDndImage>

<Box
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
canvasImageToControlNet,
canvasMaskToControlNet,
} from 'features/canvas/store/actions';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { FaImage, FaMask } from 'react-icons/fa';

type ControlNetCanvasImageImportsProps = {
controlNet: ControlNetConfig;
};

const ControlNetCanvasImageImports = (
props: ControlNetCanvasImageImportsProps
) => {
const { controlNet } = props;
const dispatch = useAppDispatch();

const handleImportImageFromCanvas = useCallback(() => {
dispatch(canvasImageToControlNet({ controlNet }));
}, [controlNet, dispatch]);

const handleImportMaskFromCanvas = useCallback(() => {
dispatch(canvasMaskToControlNet({ controlNet }));
}, [controlNet, dispatch]);

return (
<Flex
sx={{
gap: 2,
}}
>
<IAIIconButton
size="sm"
icon={<FaImage />}
tooltip="Import Image From Canvas"
aria-label="Import Image From Canvas"
onClick={handleImportImageFromCanvas}
/>
<IAIIconButton
size="sm"
icon={<FaMask />}
tooltip="Import Mask From Canvas"
aria-label="Import Mask From Canvas"
onClick={handleImportMaskFromCanvas}
/>
</Flex>
);
};

export default memo(ControlNetCanvasImageImports);

0 comments on commit ed1456e

Please sign in to comment.