diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index d012ca2853a..a41dc065507 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -1,35 +1,41 @@ import { logger } from 'app/logging/logger'; import { withResultAsync } from 'common/util/result'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; -import type { - CanvasControlLayerState, - ControlNetConfig, - Rect, - T2IAdapterConfig, -} from 'features/controlLayers/store/types'; +import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types'; +import { getControlLayerWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; import { serializeError } from 'serialize-error'; -import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; +import type { ImageDTO, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; const log = logger('system'); +type AddControlNetsArg = { + manager: CanvasManager; + entities: CanvasControlLayerState[]; + g: Graph; + rect: Rect; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + type AddControlNetsResult = { addedControlNets: number; }; -export const addControlNets = async ( - manager: CanvasManager, - layers: CanvasControlLayerState[], - g: Graph, - rect: Rect, - collector: Invocation<'collect'>, - base: BaseModelType -): Promise => { - const validControlLayers = layers - .filter((layer) => layer.isEnabled) - .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) - .filter((layer) => layer.controlAdapter.type === 'controlnet'); +export const addControlNets = async ({ + manager, + entities, + g, + rect, + collector, + model, +}: AddControlNetsArg): Promise => { + const validControlLayers = entities + .filter((entity) => entity.isEnabled) + .filter((entity) => entity.controlAdapter.type === 'controlnet') + .filter((entity) => getControlLayerWarnings(entity, model).length === 0); const result: AddControlNetsResult = { addedControlNets: 0, @@ -54,22 +60,31 @@ export const addControlNets = async ( return result; }; +type AddT2IAdaptersArg = { + manager: CanvasManager; + entities: CanvasControlLayerState[]; + g: Graph; + rect: Rect; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + type AddT2IAdaptersResult = { addedT2IAdapters: number; }; -export const addT2IAdapters = async ( - manager: CanvasManager, - layers: CanvasControlLayerState[], - g: Graph, - rect: Rect, - collector: Invocation<'collect'>, - base: BaseModelType -): Promise => { - const validControlLayers = layers - .filter((layer) => layer.isEnabled) - .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) - .filter((layer) => layer.controlAdapter.type === 't2i_adapter'); +export const addT2IAdapters = async ({ + manager, + entities, + g, + rect, + collector, + model, +}: AddT2IAdaptersArg): Promise => { + const validControlLayers = entities + .filter((entity) => entity.isEnabled) + .filter((entity) => entity.controlAdapter.type === 't2i_adapter') + .filter((entity) => getControlLayerWarnings(entity, model).length === 0); const result: AddT2IAdaptersResult = { addedT2IAdapters: 0, @@ -145,11 +160,3 @@ const addT2IAdapterToGraph = ( g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item'); }; - -const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => { - // Must be have a model - const hasModel = Boolean(controlAdapter.model); - // Model must match the current base model - const modelMatchesBase = controlAdapter.model?.base === base; - return hasModel && modelMatchesBase; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts index 81b98f3ef57..0a3a43a0188 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts @@ -1,19 +1,23 @@ import type { CanvasReferenceImageState } from 'features/controlLayers/store/types'; +import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; -import type { BaseModelType, Invocation } from 'services/api/types'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; +import type { Invocation } from 'services/api/types'; import { assert } from 'tsafe'; type AddIPAdaptersResult = { addedIPAdapters: number; }; -export const addIPAdapters = ( - ipAdapters: CanvasReferenceImageState[], - g: Graph, - collector: Invocation<'collect'>, - base: BaseModelType -): AddIPAdaptersResult => { - const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base)); +type AddIPAdaptersArg = { + entities: CanvasReferenceImageState[]; + g: Graph; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + +export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => { + const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0); const result: AddIPAdaptersResult = { addedIPAdapters: 0, @@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item'); }; - -const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => { - // Must be have a model that matches the current base and must have a control image - const hasModel = Boolean(ipAdapter.model); - const modelMatchesBase = ipAdapter.model?.base === base; - const hasImage = Boolean(ipAdapter.image); - return isEnabled && hasModel && modelMatchesBase && hasImage; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index f8d310ee7e0..885954e9955 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -199,14 +199,14 @@ export const buildFLUXGraph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -217,7 +217,12 @@ export const buildFLUXGraph = async ( type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, + g, + collector: ipAdapterCollect, + model: modelConfig, + }); const regionsResult = await addRegions({ manager, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 58b195ed61d..75222270076 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -227,14 +227,14 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -245,14 +245,14 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('t2i_adapter_collector'), }); - const t2iAdapterResult = await addT2IAdapters( + const t2iAdapterResult = await addT2IAdapters({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - t2iAdapterCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: t2iAdapterCollector, + model: modelConfig, + }); if (t2iAdapterResult.addedT2IAdapters > 0) { g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); } else { @@ -263,7 +263,12 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, + g, + collector: ipAdapterCollect, + model: modelConfig, + }); const regionsResult = await addRegions({ manager, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 8aff599743e..9357a291b4b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -232,14 +232,14 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -250,14 +250,14 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('t2i_adapter_collector'), }); - const t2iAdapterResult = await addT2IAdapters( + const t2iAdapterResult = await addT2IAdapters({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - t2iAdapterCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: t2iAdapterCollector, + model: modelConfig, + }); if (t2iAdapterResult.addedT2IAdapters > 0) { g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); } else { @@ -268,7 +268,12 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, + g, + collector: ipAdapterCollect, + model: modelConfig, + }); const regionsResult = await addRegions({ manager,