Skip to content

Commit

Permalink
Fixed typing and broadcasting issue for sequence types (#2972)
Browse files Browse the repository at this point in the history
* Fixed typing and broadcasting issue for sequence types

* Fixed errors
  • Loading branch information
RunDevelopment authored Jul 1, 2024
1 parent 307109c commit a1ce9d5
Show file tree
Hide file tree
Showing 16 changed files with 222 additions and 142 deletions.
4 changes: 2 additions & 2 deletions backend/src/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import Dict, Literal, TypedDict, Union

from api import ErrorValue, InputId, NodeId, OutputId
from api import ErrorValue, InputId, IterOutputId, NodeId, OutputId

# General events

Expand Down Expand Up @@ -89,7 +89,7 @@ class NodeBroadcastData(TypedDict):
nodeId: NodeId
data: dict[OutputId, object]
types: dict[OutputId, object]
expectedLength: int | None
sequenceTypes: dict[IterOutputId, object] | None


class NodeBroadcastEvent(TypedDict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ def list_glob(directory: Path, globexpr: str, ext_filter: list[str]) -> list[Pat
DirectoryOutput("Directory", output_type="Input0"),
TextOutput("Subdirectory Path"),
TextOutput("Name"),
NumberOutput(
"Index",
output_type="if Input4 { min(uint, Input5 - 1) } else { uint }",
),
NumberOutput("Index", output_type="min(uint, max(0, IterOutput0.length - 1))"),
],
iterator_outputs=IteratorOutputInfo(
outputs=[0, 2, 3, 4],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
outputs=[
ImageOutput("Frame", channels=3),
NumberOutput(
"Frame Index",
output_type="if Input1 { min(uint, Input2 - 1) } else { uint }",
"Index",
output_type="min(uint, max(0, IterOutput0.length - 1))",
).with_docs("A counter that starts at 0 and increments by 1 for each frame."),
DirectoryOutput("Video Directory", of_input=0),
FileNameOutput("Name", of_input=0),
Expand All @@ -58,6 +58,7 @@
outputs=[0, 1], length_type="if Input1 { min(uint, Input2) } else { uint }"
),
node_context=True,
side_effects=True,
kind="generator",
)
def load_video_node(
Expand Down
31 changes: 26 additions & 5 deletions backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@

from sanic.log import logger

import navi
from api import (
BaseInput,
BaseOutput,
Collector,
ExecutionOptions,
Generator,
InputId,
IteratorOutputInfo,
IterOutputId,
Lazy,
NodeContext,
NodeData,
Expand Down Expand Up @@ -143,7 +146,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut
assert isinstance(
raw_output, Generator
), "Expected the output to be a generator"
return GeneratorOutput(generator=raw_output, partial_output=partial)
return GeneratorOutput(
info=generator_output,
generator=raw_output,
partial_output=partial,
)

assert l > len(generator_output.outputs)
assert isinstance(raw_output, (tuple, list))
Expand All @@ -159,7 +166,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut
if o.id not in generator_output.outputs:
partial[i] = o.enforce(rest.pop(0))

return GeneratorOutput(generator=iterator, partial_output=partial)
return GeneratorOutput(
info=generator_output,
generator=iterator,
partial_output=partial,
)


def run_node(
Expand Down Expand Up @@ -321,6 +332,7 @@ class RegularOutput:

@dataclass(frozen=True)
class GeneratorOutput:
info: IteratorOutputInfo
generator: Generator
partial_output: Output

Expand Down Expand Up @@ -618,7 +630,13 @@ def get_lazy_evaluation_time():
self.__send_node_finish(node, execution_time)
elif isinstance(output, GeneratorOutput):
await self.__send_node_broadcast(
node, output.partial_output, output.generator.expected_length
node,
output.partial_output,
output_sequence_types={
output.info.id: navi.named(
"Sequence", {"length": output.generator.expected_length}
)
},
)
# TODO: execution time

Expand Down Expand Up @@ -1010,7 +1028,10 @@ def __send_node_progress_done(self, node: Node, length: int):
)

async def __send_node_broadcast(
self, node: Node, output: Output, expected_length: int | None = None
self,
node: Node,
output: Output,
output_sequence_types: dict[IterOutputId, object] | None = None,
):
def compute_broadcast_data():
if self.progress.aborted:
Expand All @@ -1032,7 +1053,7 @@ async def send_broadcast():
"nodeId": node.id,
"data": data,
"types": types,
"expectedLength": expected_length,
"sequenceTypes": output_sequence_types,
},
}
)
Expand Down
3 changes: 2 additions & 1 deletion src/common/Backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
FeatureState,
InputId,
InputValue,
IterOutputTypes,
NodeSchema,
OutputData,
OutputTypes,
Expand Down Expand Up @@ -341,7 +342,7 @@ export interface BackendEventMap {
nodeId: string;
data: OutputData;
types: OutputTypes;
expectedLength?: number | null;
sequenceTypes?: IterOutputTypes | null;
};
'backend-status': {
message: string;
Expand Down
9 changes: 5 additions & 4 deletions src/common/common-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ export interface Size {
export type SchemaId = string & { readonly __schemaId: never };
export type InputId = number & { readonly __inputId: never };
export type OutputId = number & { readonly __outputId: never };
export type IteratorInputId = number & { readonly __iteratorInputId: never };
export type IteratorOutputId = number & { readonly __iteratorOutputId: never };
export type IterInputId = number & { readonly __iteratorInputId: never };
export type IterOutputId = number & { readonly __iteratorOutputId: never };
export type GroupId = number & { readonly __groupId: never };
export type PackageId = string & { readonly __packageId: never };
export type FeatureId = string & { readonly __featureId: never };
Expand Down Expand Up @@ -279,14 +279,15 @@ export type InputHeight = Readonly<Record<InputId, number>>;
export type OutputData = Readonly<Record<OutputId, unknown>>;
export type OutputHeight = Readonly<Record<OutputId, number>>;
export type OutputTypes = Readonly<Partial<Record<OutputId, ExpressionJson | null>>>;
export type IterOutputTypes = Readonly<Partial<Record<IterOutputId, ExpressionJson | null>>>;

export interface IteratorInputInfo {
readonly id: IteratorInputId;
readonly id: IterInputId;
readonly inputs: readonly InputId[];
readonly sequenceType: ExpressionJson;
}
export interface IteratorOutputInfo {
readonly id: IteratorOutputId;
readonly id: IterOutputId;
readonly outputs: readonly OutputId[];
readonly sequenceType: ExpressionJson;
}
Expand Down
47 changes: 32 additions & 15 deletions src/common/nodes/TypeState.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { EvaluationError, NonNeverType, Type, isSameType } from '@chainner/navi';
import { EdgeData, InputId, NodeData, OutputId, SchemaId } from '../common-types';
import { EdgeData, InputId, IterOutputId, NodeData, OutputId, SchemaId } from '../common-types';
import { log } from '../log';
import { PassthroughMap } from '../PassthroughMap';
import {
Expand All @@ -22,23 +22,38 @@ const assignmentErrorEquals = (
isSameType(a.inputType, b.inputType)
);
};
const mapEqual = <K, V extends NonNullable<unknown>>(
a: ReadonlyMap<K, V>,
b: ReadonlyMap<K, V>,
eq: (a: V, b: V) => boolean
): boolean => {
if (a.size !== b.size) return false;
for (const [key, value] of a) {
const otherValue = b.get(key);
if (otherValue === undefined || !eq(value, otherValue)) return false;
}
return true;
};
const arrayEqual = <T>(
a: ReadonlyArray<T>,
b: ReadonlyArray<T>,
eq: (a: T, b: T) => boolean
): boolean => {
if (a.length !== b.length) return false;
for (let i = 0; i < a.length; i += 1) {
if (!eq(a[i], b[i])) return false;
}
return true;
};
const instanceEqual = (a: FunctionInstance, b: FunctionInstance): boolean => {
if (a.definition !== b.definition) return false;

for (const [key, value] of a.inputs) {
const otherValue = b.inputs.get(key);
if (!otherValue || !isSameType(value, otherValue)) return false;
}

for (const [key, value] of a.outputs) {
const otherValue = b.outputs.get(key);
if (!otherValue || !isSameType(value, otherValue)) return false;
}
if (!mapEqual(a.inputs, b.inputs, isSameType)) return false;
if (!mapEqual(a.inputSequence, b.inputSequence, isSameType)) return false;
if (!mapEqual(a.outputs, b.outputs, isSameType)) return false;
if (!mapEqual(a.outputSequence, b.outputSequence, isSameType)) return false;

if (a.inputErrors.length !== b.inputErrors.length) return false;
for (let i = 0; i < a.inputErrors.length; i += 1) {
if (!assignmentErrorEquals(a.inputErrors[i], b.inputErrors[i])) return false;
}
if (!arrayEqual(a.inputErrors, b.inputErrors, assignmentErrorEquals)) return false;

return true;
};
Expand All @@ -65,7 +80,8 @@ export class TypeState {
static create(
nodesMap: ReadonlyMap<string, Node<NodeData>>,
rawEdges: readonly Edge<EdgeData>[],
outputNarrowing: ReadonlyMap<string, ReadonlyMap<OutputId | 'length', Type>>,
outputNarrowing: ReadonlyMap<string, ReadonlyMap<OutputId, Type>>,
sequenceOutputNarrowing: ReadonlyMap<string, ReadonlyMap<IterOutputId, Type>>,
functionDefinitions: ReadonlyMap<SchemaId, FunctionDefinition>,
passthrough?: PassthroughMap,
previousTypeState?: TypeState
Expand Down Expand Up @@ -127,6 +143,7 @@ export class TypeState {
return undefined;
},
outputNarrowing.get(n.id),
sequenceOutputNarrowing.get(n.id),
passthroughInfo
);
} catch (error) {
Expand Down
45 changes: 13 additions & 32 deletions src/common/types/function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ import {
Input,
InputId,
InputSchemaValue,
IteratorInputId,
IterInputId,
IterOutputId,
IteratorInputInfo,
IteratorOutputId,
IteratorOutputInfo,
NodeSchema,
Output,
Expand Down Expand Up @@ -56,9 +56,9 @@ const getParamRefs = <R extends ParamRef>(
};

export const getInputParamName = (inputId: InputId) => `Input${inputId}` as const;
export const getIterInputParamName = (id: IteratorInputId) => `IterInput${id}` as const;
export const getIterInputParamName = (id: IterInputId) => `IterInput${id}` as const;
export const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const;
export const getIterOutputParamName = (id: IteratorOutputId) => `IterOutput${id}` as const;
export const getIterOutputParamName = (id: IterOutputId) => `IterOutput${id}` as const;

interface BaseDesc<P extends GenericParam> {
readonly type: P;
Expand Down Expand Up @@ -591,7 +591,8 @@ export class FunctionInstance {
static fromPartialInputs(
definition: FunctionDefinition,
partialInputs: (inputId: InputId) => NonNeverType | undefined,
outputNarrowing: ReadonlyMap<OutputId | 'length', Type> = EMPTY_MAP,
outputNarrowing: ReadonlyMap<OutputId, Type> = EMPTY_MAP,
sequenceOutputNarrowing: ReadonlyMap<IterOutputId, Type> = EMPTY_MAP,
passthrough?: PassthroughInfo
): FunctionInstance {
const inputErrors: FunctionInputAssignmentError[] = [];
Expand Down Expand Up @@ -689,11 +690,12 @@ export class FunctionInstance {
type = item.default;
}

if (item.type === 'Output') {
const narrowing = outputNarrowing.get(item.output.id);
if (narrowing) {
type = intersect(narrowing, type);
}
const narrowing =
item.type === 'Output'
? outputNarrowing.get(item.output.id)
: sequenceOutputNarrowing.get(item.iterOutput.id);
if (narrowing) {
type = intersect(narrowing, type);
}

if (type.type === 'never') {
Expand All @@ -708,28 +710,7 @@ export class FunctionInstance {
scope.assignParameter(getOutputParamName(item.output.id), type);
} else {
for (const id of item.iterOutput.outputs) {
if (item.iterOutput.outputs.includes(id)) {
const predeterminedLength = outputNarrowing.get('length');
if (predeterminedLength && predeterminedLength.type !== 'never') {
const sequenceType = evaluate(
fromJson(
`Sequence { length: ${predeterminedLength.toString()} }`
),
scope
);
if (sequenceType.type !== 'never') {
outputLengths.set(id, sequenceType);
}
} else {
const lengthType = evaluate(
fromJson(item.iterOutput.sequenceType),
scope
);
if (lengthType.type !== 'never') {
outputLengths.set(id, lengthType);
}
}
}
outputLengths.set(id, type);
}
}
scope.assignParameter(item.param, type);
Expand Down
11 changes: 9 additions & 2 deletions src/main/cli/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { SchemaMap } from '../../common/SchemaMap';
import { ChainnerSettings } from '../../common/settings/settings';
import { FunctionDefinition } from '../../common/types/function';
import { ProgressController, ProgressMonitor, ProgressToken } from '../../common/ui/progress';
import { assertNever, delay } from '../../common/util';
import { EMPTY_MAP, assertNever, delay } from '../../common/util';
import { RunArguments } from '../arguments';
import { BackendProcess } from '../backend/process';
import { setupBackend } from '../backend/setup';
Expand Down Expand Up @@ -143,7 +143,14 @@ const ensureStaticCorrectness = (
}

const byId = new Map(nodes.map((n) => [n.id, n]));
const typeState = TypeState.create(byId, edges, new Map(), functionDefinitions, passthrough);
const typeState = TypeState.create(
byId,
edges,
EMPTY_MAP,
EMPTY_MAP,
functionDefinitions,
passthrough
);
const chainLineage = new ChainLineage(schemata, nodes, edges);

const invalidNodes = nodes.flatMap((node) => {
Expand Down
1 change: 1 addition & 0 deletions src/renderer/components/NodeDocumentation/NodeExample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ export const NodeExample = memo(({ selectedSchema }: NodeExampleProps) => {
new Map([[nodeId, node]]),
EMPTY_ARRAY,
EMPTY_MAP,
EMPTY_MAP,
functionDefinitions,
PassthroughMap.EMPTY
);
Expand Down
Loading

0 comments on commit a1ce9d5

Please sign in to comment.