Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed typing and broadcasting issue for sequence types #2972

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/src/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import Dict, Literal, TypedDict, Union

from api import ErrorValue, InputId, NodeId, OutputId
from api import ErrorValue, InputId, IterOutputId, NodeId, OutputId

# General events

Expand Down Expand Up @@ -89,7 +89,7 @@ class NodeBroadcastData(TypedDict):
nodeId: NodeId
data: dict[OutputId, object]
types: dict[OutputId, object]
expectedLength: int | None
sequenceTypes: dict[IterOutputId, object] | None


class NodeBroadcastEvent(TypedDict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ def list_glob(directory: Path, globexpr: str, ext_filter: list[str]) -> list[Pat
DirectoryOutput("Directory", output_type="Input0"),
TextOutput("Subdirectory Path"),
TextOutput("Name"),
NumberOutput(
"Index",
output_type="if Input4 { min(uint, Input5 - 1) } else { uint }",
),
NumberOutput("Index", output_type="min(uint, max(0, IterOutput0.length - 1))"),
],
iterator_outputs=IteratorOutputInfo(
outputs=[0, 2, 3, 4],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
outputs=[
ImageOutput("Frame", channels=3),
NumberOutput(
"Frame Index",
output_type="if Input1 { min(uint, Input2 - 1) } else { uint }",
"Index",
output_type="min(uint, max(0, IterOutput0.length - 1))",
).with_docs("A counter that starts at 0 and increments by 1 for each frame."),
DirectoryOutput("Video Directory", of_input=0),
FileNameOutput("Name", of_input=0),
Expand All @@ -58,6 +58,7 @@
outputs=[0, 1], length_type="if Input1 { min(uint, Input2) } else { uint }"
),
node_context=True,
side_effects=True,
kind="generator",
)
def load_video_node(
Expand Down
31 changes: 26 additions & 5 deletions backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@

from sanic.log import logger

import navi
from api import (
BaseInput,
BaseOutput,
Collector,
ExecutionOptions,
Generator,
InputId,
IteratorOutputInfo,
IterOutputId,
Lazy,
NodeContext,
NodeData,
Expand Down Expand Up @@ -143,7 +146,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut
assert isinstance(
raw_output, Generator
), "Expected the output to be a generator"
return GeneratorOutput(generator=raw_output, partial_output=partial)
return GeneratorOutput(
info=generator_output,
generator=raw_output,
partial_output=partial,
)

assert l > len(generator_output.outputs)
assert isinstance(raw_output, (tuple, list))
Expand All @@ -159,7 +166,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut
if o.id not in generator_output.outputs:
partial[i] = o.enforce(rest.pop(0))

return GeneratorOutput(generator=iterator, partial_output=partial)
return GeneratorOutput(
info=generator_output,
generator=iterator,
partial_output=partial,
)


def run_node(
Expand Down Expand Up @@ -321,6 +332,7 @@ class RegularOutput:

@dataclass(frozen=True)
class GeneratorOutput:
info: IteratorOutputInfo
generator: Generator
partial_output: Output

Expand Down Expand Up @@ -618,7 +630,13 @@ def get_lazy_evaluation_time():
self.__send_node_finish(node, execution_time)
elif isinstance(output, GeneratorOutput):
await self.__send_node_broadcast(
node, output.partial_output, output.generator.expected_length
node,
output.partial_output,
output_sequence_types={
output.info.id: navi.named(
"Sequence", {"length": output.generator.expected_length}
)
},
)
# TODO: execution time

Expand Down Expand Up @@ -1010,7 +1028,10 @@ def __send_node_progress_done(self, node: Node, length: int):
)

async def __send_node_broadcast(
self, node: Node, output: Output, expected_length: int | None = None
self,
node: Node,
output: Output,
output_sequence_types: dict[IterOutputId, object] | None = None,
):
def compute_broadcast_data():
if self.progress.aborted:
Expand All @@ -1032,7 +1053,7 @@ async def send_broadcast():
"nodeId": node.id,
"data": data,
"types": types,
"expectedLength": expected_length,
"sequenceTypes": output_sequence_types,
},
}
)
Expand Down
3 changes: 2 additions & 1 deletion src/common/Backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
FeatureState,
InputId,
InputValue,
IterOutputTypes,
NodeSchema,
OutputData,
OutputTypes,
Expand Down Expand Up @@ -341,7 +342,7 @@ export interface BackendEventMap {
nodeId: string;
data: OutputData;
types: OutputTypes;
expectedLength?: number | null;
sequenceTypes?: IterOutputTypes | null;
};
'backend-status': {
message: string;
Expand Down
9 changes: 5 additions & 4 deletions src/common/common-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ export interface Size {
export type SchemaId = string & { readonly __schemaId: never };
export type InputId = number & { readonly __inputId: never };
export type OutputId = number & { readonly __outputId: never };
export type IteratorInputId = number & { readonly __iteratorInputId: never };
export type IteratorOutputId = number & { readonly __iteratorOutputId: never };
export type IterInputId = number & { readonly __iteratorInputId: never };
export type IterOutputId = number & { readonly __iteratorOutputId: never };
export type GroupId = number & { readonly __groupId: never };
export type PackageId = string & { readonly __packageId: never };
export type FeatureId = string & { readonly __featureId: never };
Expand Down Expand Up @@ -279,14 +279,15 @@ export type InputHeight = Readonly<Record<InputId, number>>;
export type OutputData = Readonly<Record<OutputId, unknown>>;
export type OutputHeight = Readonly<Record<OutputId, number>>;
export type OutputTypes = Readonly<Partial<Record<OutputId, ExpressionJson | null>>>;
export type IterOutputTypes = Readonly<Partial<Record<IterOutputId, ExpressionJson | null>>>;

export interface IteratorInputInfo {
readonly id: IteratorInputId;
readonly id: IterInputId;
readonly inputs: readonly InputId[];
readonly sequenceType: ExpressionJson;
}
export interface IteratorOutputInfo {
readonly id: IteratorOutputId;
readonly id: IterOutputId;
readonly outputs: readonly OutputId[];
readonly sequenceType: ExpressionJson;
}
Expand Down
47 changes: 32 additions & 15 deletions src/common/nodes/TypeState.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { EvaluationError, NonNeverType, Type, isSameType } from '@chainner/navi';
import { EdgeData, InputId, NodeData, OutputId, SchemaId } from '../common-types';
import { EdgeData, InputId, IterOutputId, NodeData, OutputId, SchemaId } from '../common-types';
import { log } from '../log';
import { PassthroughMap } from '../PassthroughMap';
import {
Expand All @@ -22,23 +22,38 @@ const assignmentErrorEquals = (
isSameType(a.inputType, b.inputType)
);
};
const mapEqual = <K, V extends NonNullable<unknown>>(
a: ReadonlyMap<K, V>,
b: ReadonlyMap<K, V>,
eq: (a: V, b: V) => boolean
): boolean => {
if (a.size !== b.size) return false;
for (const [key, value] of a) {
const otherValue = b.get(key);
if (otherValue === undefined || !eq(value, otherValue)) return false;
}
return true;
};
const arrayEqual = <T>(
a: ReadonlyArray<T>,
b: ReadonlyArray<T>,
eq: (a: T, b: T) => boolean
): boolean => {
if (a.length !== b.length) return false;
for (let i = 0; i < a.length; i += 1) {
if (!eq(a[i], b[i])) return false;
}
return true;
};
const instanceEqual = (a: FunctionInstance, b: FunctionInstance): boolean => {
if (a.definition !== b.definition) return false;

for (const [key, value] of a.inputs) {
const otherValue = b.inputs.get(key);
if (!otherValue || !isSameType(value, otherValue)) return false;
}

for (const [key, value] of a.outputs) {
const otherValue = b.outputs.get(key);
if (!otherValue || !isSameType(value, otherValue)) return false;
}
if (!mapEqual(a.inputs, b.inputs, isSameType)) return false;
if (!mapEqual(a.inputSequence, b.inputSequence, isSameType)) return false;
if (!mapEqual(a.outputs, b.outputs, isSameType)) return false;
if (!mapEqual(a.outputSequence, b.outputSequence, isSameType)) return false;

if (a.inputErrors.length !== b.inputErrors.length) return false;
for (let i = 0; i < a.inputErrors.length; i += 1) {
if (!assignmentErrorEquals(a.inputErrors[i], b.inputErrors[i])) return false;
}
if (!arrayEqual(a.inputErrors, b.inputErrors, assignmentErrorEquals)) return false;

return true;
};
Expand All @@ -65,7 +80,8 @@ export class TypeState {
static create(
nodesMap: ReadonlyMap<string, Node<NodeData>>,
rawEdges: readonly Edge<EdgeData>[],
outputNarrowing: ReadonlyMap<string, ReadonlyMap<OutputId | 'length', Type>>,
outputNarrowing: ReadonlyMap<string, ReadonlyMap<OutputId, Type>>,
sequenceOutputNarrowing: ReadonlyMap<string, ReadonlyMap<IterOutputId, Type>>,
functionDefinitions: ReadonlyMap<SchemaId, FunctionDefinition>,
passthrough?: PassthroughMap,
previousTypeState?: TypeState
Expand Down Expand Up @@ -127,6 +143,7 @@ export class TypeState {
return undefined;
},
outputNarrowing.get(n.id),
sequenceOutputNarrowing.get(n.id),
passthroughInfo
);
} catch (error) {
Expand Down
45 changes: 13 additions & 32 deletions src/common/types/function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ import {
Input,
InputId,
InputSchemaValue,
IteratorInputId,
IterInputId,
IterOutputId,
IteratorInputInfo,
IteratorOutputId,
IteratorOutputInfo,
NodeSchema,
Output,
Expand Down Expand Up @@ -56,9 +56,9 @@ const getParamRefs = <R extends ParamRef>(
};

export const getInputParamName = (inputId: InputId) => `Input${inputId}` as const;
export const getIterInputParamName = (id: IteratorInputId) => `IterInput${id}` as const;
export const getIterInputParamName = (id: IterInputId) => `IterInput${id}` as const;
export const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const;
export const getIterOutputParamName = (id: IteratorOutputId) => `IterOutput${id}` as const;
export const getIterOutputParamName = (id: IterOutputId) => `IterOutput${id}` as const;

interface BaseDesc<P extends GenericParam> {
readonly type: P;
Expand Down Expand Up @@ -591,7 +591,8 @@ export class FunctionInstance {
static fromPartialInputs(
definition: FunctionDefinition,
partialInputs: (inputId: InputId) => NonNeverType | undefined,
outputNarrowing: ReadonlyMap<OutputId | 'length', Type> = EMPTY_MAP,
outputNarrowing: ReadonlyMap<OutputId, Type> = EMPTY_MAP,
sequenceOutputNarrowing: ReadonlyMap<IterOutputId, Type> = EMPTY_MAP,
passthrough?: PassthroughInfo
): FunctionInstance {
const inputErrors: FunctionInputAssignmentError[] = [];
Expand Down Expand Up @@ -689,11 +690,12 @@ export class FunctionInstance {
type = item.default;
}

if (item.type === 'Output') {
const narrowing = outputNarrowing.get(item.output.id);
if (narrowing) {
type = intersect(narrowing, type);
}
const narrowing =
item.type === 'Output'
? outputNarrowing.get(item.output.id)
: sequenceOutputNarrowing.get(item.iterOutput.id);
if (narrowing) {
type = intersect(narrowing, type);
}

if (type.type === 'never') {
Expand All @@ -708,28 +710,7 @@ export class FunctionInstance {
scope.assignParameter(getOutputParamName(item.output.id), type);
} else {
for (const id of item.iterOutput.outputs) {
if (item.iterOutput.outputs.includes(id)) {
const predeterminedLength = outputNarrowing.get('length');
if (predeterminedLength && predeterminedLength.type !== 'never') {
const sequenceType = evaluate(
fromJson(
`Sequence { length: ${predeterminedLength.toString()} }`
),
scope
);
if (sequenceType.type !== 'never') {
outputLengths.set(id, sequenceType);
}
} else {
const lengthType = evaluate(
fromJson(item.iterOutput.sequenceType),
scope
);
if (lengthType.type !== 'never') {
outputLengths.set(id, lengthType);
}
}
}
outputLengths.set(id, type);
}
}
scope.assignParameter(item.param, type);
Expand Down
11 changes: 9 additions & 2 deletions src/main/cli/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { SchemaMap } from '../../common/SchemaMap';
import { ChainnerSettings } from '../../common/settings/settings';
import { FunctionDefinition } from '../../common/types/function';
import { ProgressController, ProgressMonitor, ProgressToken } from '../../common/ui/progress';
import { assertNever, delay } from '../../common/util';
import { EMPTY_MAP, assertNever, delay } from '../../common/util';
import { RunArguments } from '../arguments';
import { BackendProcess } from '../backend/process';
import { setupBackend } from '../backend/setup';
Expand Down Expand Up @@ -143,7 +143,14 @@ const ensureStaticCorrectness = (
}

const byId = new Map(nodes.map((n) => [n.id, n]));
const typeState = TypeState.create(byId, edges, new Map(), functionDefinitions, passthrough);
const typeState = TypeState.create(
byId,
edges,
EMPTY_MAP,
EMPTY_MAP,
functionDefinitions,
passthrough
);
const chainLineage = new ChainLineage(schemata, nodes, edges);

const invalidNodes = nodes.flatMap((node) => {
Expand Down
1 change: 1 addition & 0 deletions src/renderer/components/NodeDocumentation/NodeExample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ export const NodeExample = memo(({ selectedSchema }: NodeExampleProps) => {
new Map([[nodeId, node]]),
EMPTY_ARRAY,
EMPTY_MAP,
EMPTY_MAP,
functionDefinitions,
PassthroughMap.EMPTY
);
Expand Down
Loading
Loading