Skip to content

Commit

Permalink
feat(ui): use regional guidance validation utils in graph builders
Browse files Browse the repository at this point in the history
  • Loading branch information
psychedelicious committed Nov 29, 2024
1 parent 3905c97 commit df0c7d7
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
CanvasRegionalGuidanceState,
IPAdapterConfig,
Rect,
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';

const log = logger('system');
Expand All @@ -23,19 +20,12 @@ type AddedRegionResult = {
addedIPAdapters: number;
};

const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
const isEnabled = rg.isEnabled;
const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
return isEnabled && (hasTextPrompt || hasIPAdapter);
};

type AddRegionsArg = {
manager: CanvasManager;
regions: CanvasRegionalGuidanceState[];
g: Graph;
bbox: Rect;
base: BaseModelType;
model: ParameterModel;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
posCondCollect: Invocation<'collect'>;
Expand All @@ -49,7 +39,7 @@ type AddRegionsArg = {
* @param regions Array of regions to add
* @param g The graph to add the layers to
* @param bbox The bounding box
* @param base The base model type
* @param model The main model
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector
Expand All @@ -63,17 +53,23 @@ export const addRegions = async ({
regions,
g,
bbox,
base,
model,
posCond,
negCond,
posCondCollect,
negCondCollect,
ipAdapterCollect,
}: AddRegionsArg): Promise<AddedRegionResult[]> => {
const isSDXL = base === 'sdxl';
const isFLUX = base === 'flux';
const isSDXL = model.base === 'sdxl';
const isFLUX = model.base === 'flux';

const validRegions = regions.filter((rg) => {
if (!rg.isEnabled) {
return false;
}
return getRegionalGuidanceWarnings(rg, model).length === 0;
});

const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const results: AddedRegionResult[] = [];

for (const region of validRegions) {
Expand Down Expand Up @@ -275,11 +271,7 @@ export const addRegions = async ({
}
}

const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
isValidIPAdapter(ipAdapter, base)
);

for (const { id, ipAdapter } of validRGIPAdapters) {
for (const { id, ipAdapter } of region.referenceImages) {
assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');

result.addedIPAdapters++;
Expand Down Expand Up @@ -313,11 +305,3 @@ export const addRegions = async ({

return results;
};

const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
// Must be a model that matches the current base and must have a control image
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === base;
const hasImage = Boolean(ipAdapter.image);
return hasModel && modelMatchesBase && hasImage;
};
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ export const buildFLUXGraph = async (
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
base: modelConfig.base,
model: modelConfig,
posCond,
negCond: null,
posCondCollect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ export const buildSD1Graph = async (
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
base: modelConfig.base,
model: modelConfig,
posCond,
negCond,
posCondCollect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ export const buildSDXLGraph = async (
regions: canvas.regionalGuidance.entities,
g,
bbox: canvas.bbox.rect,
base: modelConfig.base,
model: modelConfig,
posCond,
negCond,
posCondCollect,
Expand Down

0 comments on commit df0c7d7

Please sign in to comment.