Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add key information to collapsed nodes #2656

Merged
merged 2 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
check_naming_conventions,
check_schema_types,
)
from .node_data import IteratorInputInfo, IteratorOutputInfo, NodeData
from .node_data import IteratorInputInfo, IteratorOutputInfo, KeyInfo, NodeData
from .output import BaseOutput
from .settings import Setting
from .types import FeatureId, InputId, NodeId, NodeKind, OutputId, RunFn
Expand Down Expand Up @@ -113,6 +113,7 @@ def register(
iterator_inputs: list[IteratorInputInfo] | IteratorInputInfo | None = None,
iterator_outputs: list[IteratorOutputInfo] | IteratorOutputInfo | None = None,
node_context: bool = False,
key_info: KeyInfo | None = None,
):
if not isinstance(description, str):
description = "\n\n".join(description)
Expand Down Expand Up @@ -181,6 +182,7 @@ def inner_wrapper(wrapped_func: T) -> T:
outputs=p_outputs,
iterator_inputs=iterator_inputs,
iterator_outputs=iterator_outputs,
key_info=key_info,
side_effects=side_effects,
deprecated=deprecated,
node_context=node_context,
Expand Down
19 changes: 19 additions & 0 deletions backend/src/api/node_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import navi

Expand Down Expand Up @@ -50,6 +51,22 @@ def to_dict(self):
}


class KeyInfo:
def __init__(self, data: dict[str, Any]) -> None:
self._data = data

@staticmethod
def enum(enum_input: InputId | int) -> KeyInfo:
return KeyInfo({"kind": "enum", "enum": enum_input})

@staticmethod
def type(expression: navi.ExpressionJson) -> KeyInfo:
return KeyInfo({"kind": "type", "expression": expression})

def to_dict(self):
return self._data


@dataclass(frozen=True)
class NodeData:
schema_id: str
Expand All @@ -66,6 +83,8 @@ class NodeData:
iterator_inputs: list[IteratorInputInfo]
iterator_outputs: list[IteratorOutputInfo]

key_info: KeyInfo | None

side_effects: bool
deprecated: bool
node_context: bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sanic.log import logger
from spandrel import ImageModelDescriptor, ModelTiling

from api import NodeContext
from api import KeyInfo, NodeContext
from nodes.groups import Condition, if_group
from nodes.impl.pytorch.auto_split import pytorch_auto_split
from nodes.impl.upscale.auto_split_tiles import (
Expand Down Expand Up @@ -215,6 +215,23 @@ def estimate():
assume_normalized=True, # pytorch_auto_split already does clipping internally
)
],
key_info=KeyInfo.type(
"""
let model = Input0;
let useCustomScale = Input4;
let customScale = Input5;

let singleUpscale = convenientUpscale(model, img);

let scale = if bool::and(useCustomScale, model.scale >= 2, model.inputChannels == model.outputChannels) {
customScale
} else {
model.scale
};

string::concat(toString(scale), "x")
"""
),
node_context=True,
)
def upscale_image_node(
Expand Down
2 changes: 2 additions & 0 deletions backend/src/packages/chaiNNer_standard/image/io/save_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from PIL import Image
from sanic.log import logger

from api import KeyInfo
from nodes.groups import Condition, if_enum_group, if_group
from nodes.impl.dds.format import (
BC7_FORMATS,
Expand Down Expand Up @@ -211,6 +212,7 @@ class TiffColorDepth(Enum):
),
],
outputs=[],
key_info=KeyInfo.enum(4),
side_effects=True,
limited_to_8bpc="Image will be saved with 8 bits/channel by default. Some formats support higher bit depths.",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from api import KeyInfo
from nodes.groups import if_enum_group
from nodes.impl.color.color import Color
from nodes.impl.image_utils import BorderType, create_border
Expand Down Expand Up @@ -95,6 +96,7 @@ class BorderMode(Enum):
assume_normalized=True,
)
],
key_info=KeyInfo.enum(3),
)
def pad_node(
img: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from api import KeyInfo
from nodes.groups import if_enum_group
from nodes.properties.inputs import EnumInput, ImageInput, NumberInput
from nodes.properties.outputs import ImageOutput
Expand Down Expand Up @@ -77,6 +78,7 @@ class CropMode(Enum):
"The cropped area would result in an image with no width or no height."
)
],
key_info=KeyInfo.enum(1),
)
def crop_node(
img: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from api import KeyInfo
from nodes.groups import Condition, if_enum_group, if_group
from nodes.impl.resize import ResizeFilter, resize
from nodes.properties.inputs import (
Expand Down Expand Up @@ -89,6 +90,20 @@ class ImageResizeMode(Enum):
assume_normalized=True,
)
],
key_info=KeyInfo.type(
"""
let mode = Input1;

let scale = Input2;
let width = Input3;
let height = Input4;

match mode {
ImageResizeMode::Percentage => string::concat(toString(scale), "%"),
ImageResizeMode::Absolute => string::concat(toString(width), "x", toString(height)),
}
"""
),
)
def resize_node(
img: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from enum import Enum

from api import KeyInfo
from nodes.groups import if_enum_group
from nodes.impl.color.color import Color
from nodes.properties.inputs import EnumInput, SliderInput
Expand All @@ -22,7 +23,9 @@ class ColorType(Enum):
description="Create a new color value from individual channels.",
icon="MdColorLens",
inputs=[
EnumInput(ColorType, "Color Type", ColorType.RGBA, preferred_style="tabs"),
EnumInput(
ColorType, "Color Type", ColorType.RGBA, preferred_style="tabs"
).with_id(0),
if_enum_group(0, ColorType.GRAY)(
SliderInput(
"Luma",
Expand Down Expand Up @@ -96,6 +99,7 @@ class ColorType(Enum):
"""
)
],
key_info=KeyInfo.enum(0),
)
def color_from_node(
color_type: ColorType,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from api import KeyInfo
from nodes.properties.inputs import NumberInput
from nodes.properties.outputs import NumberOutput

Expand All @@ -24,6 +25,7 @@
outputs=[
NumberOutput("Number", output_type="Input0").suggest(),
],
key_info=KeyInfo.type("""toString(Input0)"""),
)
def number_node(number: float) -> float:
return number
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from api import KeyInfo
from nodes.properties.inputs import SliderInput
from nodes.properties.outputs import NumberOutput

Expand All @@ -25,6 +26,7 @@
outputs=[
NumberOutput("Percent", output_type="Input0"),
],
key_info=KeyInfo.type("""string::concat(toString(Input0), "%")"""),
)
def percent_node(number: int) -> int:
return number
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from enum import Enum

from api import KeyInfo
from nodes.groups import optional_list_group
from nodes.properties.inputs import AnyInput, EnumInput
from nodes.properties.outputs import BaseOutput
Expand Down Expand Up @@ -29,7 +30,7 @@ class ValueIndex(Enum):
description="Allows you to pass in multiple inputs and then change which one passes through to the output.",
icon="BsShuffle",
inputs=[
EnumInput(ValueIndex),
EnumInput(ValueIndex).with_id(0),
AnyInput(label="Value A"),
AnyInput(label="Value B"),
optional_list_group(
Expand Down Expand Up @@ -64,6 +65,7 @@ class ValueIndex(Enum):
label="Value",
).with_never_reason("The selected value should have a connection.")
],
key_info=KeyInfo.enum(0),
see_also=["chainner:utility:pass_through"],
)
def switch_node(selection: ValueIndex, *args: object | None) -> object:
Expand Down
1 change: 1 addition & 0 deletions backend/src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ async def nodes(_request: Request):
],
"iteratorInputs": [x.to_dict() for x in node.iterator_inputs],
"iteratorOutputs": [x.to_dict() for x in node.iterator_outputs],
"keyInfo": node.key_info.to_dict() if node.key_info else None,
"description": node.description,
"seeAlso": node.see_also,
"icon": node.icon,
Expand Down
11 changes: 11 additions & 0 deletions src/common/common-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ export interface IteratorOutputInfo {
readonly lengthType: ExpressionJson;
}

export type KeyInfo = EnumKeyInfo | TypeKeyInfo;
export interface EnumKeyInfo {
readonly kind: 'enum';
readonly enum: InputId;
}
export interface TypeKeyInfo {
readonly kind: 'type';
readonly expression: ExpressionJson;
}

export interface NodeSchema {
readonly name: string;
readonly category: CategoryId;
Expand All @@ -299,6 +309,7 @@ export interface NodeSchema {
readonly groupLayout: readonly (InputId | Group)[];
readonly iteratorInputs: readonly IteratorInputInfo[];
readonly iteratorOutputs: readonly IteratorOutputInfo[];
readonly keyInfo?: KeyInfo | null;
readonly schemaId: SchemaId;
readonly hasSideEffects: boolean;
readonly deprecated: boolean;
Expand Down
92 changes: 92 additions & 0 deletions src/common/nodes/keyInfo.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import {
ParameterDefinition,
Scope,
ScopeBuilder,
StringType,
evaluate,
isStringLiteral,
isSubsetOf,
} from '@chainner/navi';
import { InputData, KeyInfo, NodeSchema, OfKind } from '../common-types';
import {
FunctionDefinition,
FunctionInstance,
getInputParamName,
getOutputParamName,
} from '../types/function';
import { fromJson } from '../types/json';
import { lazyKeyed } from '../util';

const getKeyInfoScopeTemplate = lazyKeyed((definition: FunctionDefinition): Scope => {
const builder = new ScopeBuilder('key info', definition.scope);

// assign inputs and outputs
definition.inputDefaults.forEach((input, inputId) => {
builder.add(new ParameterDefinition(getInputParamName(inputId), input));
});
definition.outputDefaults.forEach((output, outputId) => {
builder.add(new ParameterDefinition(getOutputParamName(outputId), output));
});

return builder.createScope();
});
const getKeyInfoScope = (instance: FunctionInstance): Scope => {
const scope = getKeyInfoScopeTemplate(instance.definition);

// assign inputs and outputs
instance.inputs.forEach((input, inputId) => {
scope.assignParameter(getInputParamName(inputId), input);
});
instance.outputs.forEach((output, outputId) => {
scope.assignParameter(getOutputParamName(outputId), output);
});

return scope;
};

const accessors: {
[kind in KeyInfo['kind']]: (
info: OfKind<KeyInfo, kind>,
node: NodeSchema,
inputData: InputData,
types: FunctionInstance | undefined
) => string | undefined;
} = {
enum: (info, node, inputData) => {
const input = node.inputs.find((i) => i.id === info.enum);
if (!input) throw new Error(`Input ${info.enum} not found`);
if (input.kind !== 'dropdown') throw new Error(`Input ${info.enum} is not a dropdown`);

const value = inputData[input.id];
const option = input.options.find((o) => o.value === value);
return option?.option;
},
type: (info, node, inputData, types) => {
if (!types) return undefined;

const expression = fromJson(info.expression);
const scope = getKeyInfoScope(types);
const result = evaluate(expression, scope);

if (isStringLiteral(result)) return result.value;

// check that the expression actually evaluates to a string
if (!isSubsetOf(result, StringType.instance)) {
throw new Error(
`Key info expression must evaluate to a string, but got ${result.toString()}`
);
}

return undefined;
},
};

export const getKeyInfo = (
node: NodeSchema,
inputData: InputData,
types: FunctionInstance | undefined
): string | undefined => {
const { keyInfo } = node;
if (!keyInfo) return undefined;
return accessors[keyInfo.kind](keyInfo as never, node, inputData, types);
};
4 changes: 2 additions & 2 deletions src/common/types/function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ const getParamRefs = <P extends 'Input' | 'Output'>(
return refs;
};

const getInputParamName = (inputId: InputId) => `Input${inputId}` as const;
const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const;
export const getInputParamName = (inputId: InputId) => `Input${inputId}` as const;
export const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const;

interface InputInfo {
expression: Expression;
Expand Down
Loading
Loading