From a1ce9d58e6fdba31feed8f2c3bca4e660ea3d763 Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Mon, 1 Jul 2024 17:19:56 +0200 Subject: [PATCH] Fixed typing and broadcasting issue for sequence types (#2972) * Fixed typing and broadcasting issue for sequence types * Fixed errors --- backend/src/events.py | 4 +- .../image/batch_processing/load_images.py | 5 +- .../image/video_frames/load_video.py | 5 +- backend/src/process.py | 31 +++++-- src/common/Backend.ts | 3 +- src/common/common-types.ts | 9 +- src/common/nodes/TypeState.ts | 47 +++++++---- src/common/types/function.ts | 45 +++------- src/main/cli/run.ts | 11 ++- .../NodeDocumentation/NodeExample.tsx | 1 + src/renderer/components/node/NodeOutputs.tsx | 17 +++- src/renderer/contexts/ExecutionContext.tsx | 31 ++++--- src/renderer/contexts/GlobalNodeState.tsx | 82 ++++++------------- src/renderer/helpers/nodeState.ts | 4 +- src/renderer/hooks/useOutputDataStore.ts | 9 +- src/renderer/hooks/useTypeMap.ts | 60 ++++++++++++++ 16 files changed, 222 insertions(+), 142 deletions(-) create mode 100644 src/renderer/hooks/useTypeMap.ts diff --git a/backend/src/events.py b/backend/src/events.py index ee177508ca..5f329e8c18 100644 --- a/backend/src/events.py +++ b/backend/src/events.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Dict, Literal, TypedDict, Union -from api import ErrorValue, InputId, NodeId, OutputId +from api import ErrorValue, InputId, IterOutputId, NodeId, OutputId # General events @@ -89,7 +89,7 @@ class NodeBroadcastData(TypedDict): nodeId: NodeId data: dict[OutputId, object] types: dict[OutputId, object] - expectedLength: int | None + sequenceTypes: dict[IterOutputId, object] | None class NodeBroadcastEvent(TypedDict): diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py index 39dec27ba3..3c4c000164 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py @@ -94,10 +94,7 @@ def list_glob(directory: Path, globexpr: str, ext_filter: list[str]) -> list[Pat DirectoryOutput("Directory", output_type="Input0"), TextOutput("Subdirectory Path"), TextOutput("Name"), - NumberOutput( - "Index", - output_type="if Input4 { min(uint, Input5 - 1) } else { uint }", - ), + NumberOutput("Index", output_type="min(uint, max(0, IterOutput0.length - 1))"), ], iterator_outputs=IteratorOutputInfo( outputs=[0, 2, 3, 4], diff --git a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py index b546d545fb..337a27bf98 100644 --- a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py +++ b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py @@ -46,8 +46,8 @@ outputs=[ ImageOutput("Frame", channels=3), NumberOutput( - "Frame Index", - output_type="if Input1 { min(uint, Input2 - 1) } else { uint }", + "Index", + output_type="min(uint, max(0, IterOutput0.length - 1))", ).with_docs("A counter that starts at 0 and increments by 1 for each frame."), DirectoryOutput("Video Directory", of_input=0), FileNameOutput("Name", of_input=0), @@ -58,6 +58,7 @@ outputs=[0, 1], length_type="if Input1 { min(uint, Input2) } else { uint }" ), node_context=True, + side_effects=True, kind="generator", ) def load_video_node( diff --git a/backend/src/process.py b/backend/src/process.py index d5cd8a0f53..ae323e41a6 100644 --- a/backend/src/process.py +++ b/backend/src/process.py @@ -13,6 +13,7 @@ from sanic.log import logger +import navi from api import ( BaseInput, BaseOutput, @@ -20,6 +21,8 @@ ExecutionOptions, Generator, InputId, + IteratorOutputInfo, + IterOutputId, Lazy, NodeContext, NodeData, @@ -143,7 +146,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut assert isinstance( raw_output, Generator ), "Expected the output to be a generator" - return GeneratorOutput(generator=raw_output, partial_output=partial) + return GeneratorOutput( + info=generator_output, + generator=raw_output, + partial_output=partial, + ) assert l > len(generator_output.outputs) assert isinstance(raw_output, (tuple, list)) @@ -159,7 +166,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut if o.id not in generator_output.outputs: partial[i] = o.enforce(rest.pop(0)) - return GeneratorOutput(generator=iterator, partial_output=partial) + return GeneratorOutput( + info=generator_output, + generator=iterator, + partial_output=partial, + ) def run_node( @@ -321,6 +332,7 @@ class RegularOutput: @dataclass(frozen=True) class GeneratorOutput: + info: IteratorOutputInfo generator: Generator partial_output: Output @@ -618,7 +630,13 @@ def get_lazy_evaluation_time(): self.__send_node_finish(node, execution_time) elif isinstance(output, GeneratorOutput): await self.__send_node_broadcast( - node, output.partial_output, output.generator.expected_length + node, + output.partial_output, + output_sequence_types={ + output.info.id: navi.named( + "Sequence", {"length": output.generator.expected_length} + ) + }, ) # TODO: execution time @@ -1010,7 +1028,10 @@ def __send_node_progress_done(self, node: Node, length: int): ) async def __send_node_broadcast( - self, node: Node, output: Output, expected_length: int | None = None + self, + node: Node, + output: Output, + output_sequence_types: dict[IterOutputId, object] | None = None, ): def compute_broadcast_data(): if self.progress.aborted: @@ -1032,7 +1053,7 @@ async def send_broadcast(): "nodeId": node.id, "data": data, "types": types, - "expectedLength": expected_length, + "sequenceTypes": output_sequence_types, }, } ) diff --git a/src/common/Backend.ts b/src/common/Backend.ts index 0e412e405c..45b61d7b7b 100644 --- a/src/common/Backend.ts +++ b/src/common/Backend.ts @@ -6,6 +6,7 @@ import { FeatureState, InputId, InputValue, + IterOutputTypes, NodeSchema, OutputData, OutputTypes, @@ -341,7 +342,7 @@ export interface BackendEventMap { nodeId: string; data: OutputData; types: OutputTypes; - expectedLength?: number | null; + sequenceTypes?: IterOutputTypes | null; }; 'backend-status': { message: string; diff --git a/src/common/common-types.ts b/src/common/common-types.ts index ea18447aba..4322c24f81 100644 --- a/src/common/common-types.ts +++ b/src/common/common-types.ts @@ -16,8 +16,8 @@ export interface Size { export type SchemaId = string & { readonly __schemaId: never }; export type InputId = number & { readonly __inputId: never }; export type OutputId = number & { readonly __outputId: never }; -export type IteratorInputId = number & { readonly __iteratorInputId: never }; -export type IteratorOutputId = number & { readonly __iteratorOutputId: never }; +export type IterInputId = number & { readonly __iteratorInputId: never }; +export type IterOutputId = number & { readonly __iteratorOutputId: never }; export type GroupId = number & { readonly __groupId: never }; export type PackageId = string & { readonly __packageId: never }; export type FeatureId = string & { readonly __featureId: never }; @@ -279,14 +279,15 @@ export type InputHeight = Readonly>; export type OutputData = Readonly>; export type OutputHeight = Readonly>; export type OutputTypes = Readonly>>; +export type IterOutputTypes = Readonly>>; export interface IteratorInputInfo { - readonly id: IteratorInputId; + readonly id: IterInputId; readonly inputs: readonly InputId[]; readonly sequenceType: ExpressionJson; } export interface IteratorOutputInfo { - readonly id: IteratorOutputId; + readonly id: IterOutputId; readonly outputs: readonly OutputId[]; readonly sequenceType: ExpressionJson; } diff --git a/src/common/nodes/TypeState.ts b/src/common/nodes/TypeState.ts index 96dcb21a12..06612f0d4c 100644 --- a/src/common/nodes/TypeState.ts +++ b/src/common/nodes/TypeState.ts @@ -1,5 +1,5 @@ import { EvaluationError, NonNeverType, Type, isSameType } from '@chainner/navi'; -import { EdgeData, InputId, NodeData, OutputId, SchemaId } from '../common-types'; +import { EdgeData, InputId, IterOutputId, NodeData, OutputId, SchemaId } from '../common-types'; import { log } from '../log'; import { PassthroughMap } from '../PassthroughMap'; import { @@ -22,23 +22,38 @@ const assignmentErrorEquals = ( isSameType(a.inputType, b.inputType) ); }; +const mapEqual = >( + a: ReadonlyMap, + b: ReadonlyMap, + eq: (a: V, b: V) => boolean +): boolean => { + if (a.size !== b.size) return false; + for (const [key, value] of a) { + const otherValue = b.get(key); + if (otherValue === undefined || !eq(value, otherValue)) return false; + } + return true; +}; +const arrayEqual = ( + a: ReadonlyArray, + b: ReadonlyArray, + eq: (a: T, b: T) => boolean +): boolean => { + if (a.length !== b.length) return false; + for (let i = 0; i < a.length; i += 1) { + if (!eq(a[i], b[i])) return false; + } + return true; +}; const instanceEqual = (a: FunctionInstance, b: FunctionInstance): boolean => { if (a.definition !== b.definition) return false; - for (const [key, value] of a.inputs) { - const otherValue = b.inputs.get(key); - if (!otherValue || !isSameType(value, otherValue)) return false; - } - - for (const [key, value] of a.outputs) { - const otherValue = b.outputs.get(key); - if (!otherValue || !isSameType(value, otherValue)) return false; - } + if (!mapEqual(a.inputs, b.inputs, isSameType)) return false; + if (!mapEqual(a.inputSequence, b.inputSequence, isSameType)) return false; + if (!mapEqual(a.outputs, b.outputs, isSameType)) return false; + if (!mapEqual(a.outputSequence, b.outputSequence, isSameType)) return false; - if (a.inputErrors.length !== b.inputErrors.length) return false; - for (let i = 0; i < a.inputErrors.length; i += 1) { - if (!assignmentErrorEquals(a.inputErrors[i], b.inputErrors[i])) return false; - } + if (!arrayEqual(a.inputErrors, b.inputErrors, assignmentErrorEquals)) return false; return true; }; @@ -65,7 +80,8 @@ export class TypeState { static create( nodesMap: ReadonlyMap>, rawEdges: readonly Edge[], - outputNarrowing: ReadonlyMap>, + outputNarrowing: ReadonlyMap>, + sequenceOutputNarrowing: ReadonlyMap>, functionDefinitions: ReadonlyMap, passthrough?: PassthroughMap, previousTypeState?: TypeState @@ -127,6 +143,7 @@ export class TypeState { return undefined; }, outputNarrowing.get(n.id), + sequenceOutputNarrowing.get(n.id), passthroughInfo ); } catch (error) { diff --git a/src/common/types/function.ts b/src/common/types/function.ts index 50afe0769d..4c2dc9a082 100644 --- a/src/common/types/function.ts +++ b/src/common/types/function.ts @@ -19,9 +19,9 @@ import { Input, InputId, InputSchemaValue, - IteratorInputId, + IterInputId, + IterOutputId, IteratorInputInfo, - IteratorOutputId, IteratorOutputInfo, NodeSchema, Output, @@ -56,9 +56,9 @@ const getParamRefs = ( }; export const getInputParamName = (inputId: InputId) => `Input${inputId}` as const; -export const getIterInputParamName = (id: IteratorInputId) => `IterInput${id}` as const; +export const getIterInputParamName = (id: IterInputId) => `IterInput${id}` as const; export const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const; -export const getIterOutputParamName = (id: IteratorOutputId) => `IterOutput${id}` as const; +export const getIterOutputParamName = (id: IterOutputId) => `IterOutput${id}` as const; interface BaseDesc

{ readonly type: P; @@ -591,7 +591,8 @@ export class FunctionInstance { static fromPartialInputs( definition: FunctionDefinition, partialInputs: (inputId: InputId) => NonNeverType | undefined, - outputNarrowing: ReadonlyMap = EMPTY_MAP, + outputNarrowing: ReadonlyMap = EMPTY_MAP, + sequenceOutputNarrowing: ReadonlyMap = EMPTY_MAP, passthrough?: PassthroughInfo ): FunctionInstance { const inputErrors: FunctionInputAssignmentError[] = []; @@ -689,11 +690,12 @@ export class FunctionInstance { type = item.default; } - if (item.type === 'Output') { - const narrowing = outputNarrowing.get(item.output.id); - if (narrowing) { - type = intersect(narrowing, type); - } + const narrowing = + item.type === 'Output' + ? outputNarrowing.get(item.output.id) + : sequenceOutputNarrowing.get(item.iterOutput.id); + if (narrowing) { + type = intersect(narrowing, type); } if (type.type === 'never') { @@ -708,28 +710,7 @@ export class FunctionInstance { scope.assignParameter(getOutputParamName(item.output.id), type); } else { for (const id of item.iterOutput.outputs) { - if (item.iterOutput.outputs.includes(id)) { - const predeterminedLength = outputNarrowing.get('length'); - if (predeterminedLength && predeterminedLength.type !== 'never') { - const sequenceType = evaluate( - fromJson( - `Sequence { length: ${predeterminedLength.toString()} }` - ), - scope - ); - if (sequenceType.type !== 'never') { - outputLengths.set(id, sequenceType); - } - } else { - const lengthType = evaluate( - fromJson(item.iterOutput.sequenceType), - scope - ); - if (lengthType.type !== 'never') { - outputLengths.set(id, lengthType); - } - } - } + outputLengths.set(id, type); } } scope.assignParameter(item.param, type); diff --git a/src/main/cli/run.ts b/src/main/cli/run.ts index 4b77a05dde..f33789159c 100644 --- a/src/main/cli/run.ts +++ b/src/main/cli/run.ts @@ -16,7 +16,7 @@ import { SchemaMap } from '../../common/SchemaMap'; import { ChainnerSettings } from '../../common/settings/settings'; import { FunctionDefinition } from '../../common/types/function'; import { ProgressController, ProgressMonitor, ProgressToken } from '../../common/ui/progress'; -import { assertNever, delay } from '../../common/util'; +import { EMPTY_MAP, assertNever, delay } from '../../common/util'; import { RunArguments } from '../arguments'; import { BackendProcess } from '../backend/process'; import { setupBackend } from '../backend/setup'; @@ -143,7 +143,14 @@ const ensureStaticCorrectness = ( } const byId = new Map(nodes.map((n) => [n.id, n])); - const typeState = TypeState.create(byId, edges, new Map(), functionDefinitions, passthrough); + const typeState = TypeState.create( + byId, + edges, + EMPTY_MAP, + EMPTY_MAP, + functionDefinitions, + passthrough + ); const chainLineage = new ChainLineage(schemata, nodes, edges); const invalidNodes = nodes.flatMap((node) => { diff --git a/src/renderer/components/NodeDocumentation/NodeExample.tsx b/src/renderer/components/NodeDocumentation/NodeExample.tsx index 0be7a3a336..906e36a700 100644 --- a/src/renderer/components/NodeDocumentation/NodeExample.tsx +++ b/src/renderer/components/NodeDocumentation/NodeExample.tsx @@ -119,6 +119,7 @@ export const NodeExample = memo(({ selectedSchema }: NodeExampleProps) => { new Map([[nodeId, node]]), EMPTY_ARRAY, EMPTY_MAP, + EMPTY_MAP, functionDefinitions, PassthroughMap.EMPTY ); diff --git a/src/renderer/components/node/NodeOutputs.tsx b/src/renderer/components/node/NodeOutputs.tsx index 63de6ad3bb..ef168c95ff 100644 --- a/src/renderer/components/node/NodeOutputs.tsx +++ b/src/renderer/components/node/NodeOutputs.tsx @@ -58,7 +58,7 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { } = nodeState; const { functionDefinitions } = useContext(BackendContext); - const { setManualOutputType } = useContext(GlobalContext); + const { setManualOutputType, setManualSequenceOutputType } = useContext(GlobalContext); const outputDataEntry = useContextSelector(GlobalVolatileContext, (c) => c.outputDataMap.get(id) ); @@ -80,6 +80,7 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { ); const currentTypes = stale ? undefined : outputDataEntry?.types; + const currentSequenceTypes = stale ? undefined : outputDataEntry?.sequenceTypes; const { isAutomatic } = useAutomaticFeatures(id, schemaId); @@ -89,8 +90,20 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { const type = evalExpression(currentTypes?.[output.id]); setManualOutputType(id, output.id, type); } + for (const iterOutput of schema.iteratorOutputs) { + const type = evalExpression(currentSequenceTypes?.[iterOutput.id]); + setManualSequenceOutputType(id, iterOutput.id, type); + } } - }, [id, currentTypes, schema, setManualOutputType, isAutomatic]); + }, [ + id, + currentTypes, + currentSequenceTypes, + schema, + setManualOutputType, + setManualSequenceOutputType, + isAutomatic, + ]); const isCollapsed = useIsCollapsedNode(); if (isCollapsed) { diff --git a/src/renderer/contexts/ExecutionContext.tsx b/src/renderer/contexts/ExecutionContext.tsx index 1c77f6cdb0..828e522be9 100644 --- a/src/renderer/contexts/ExecutionContext.tsx +++ b/src/renderer/contexts/ExecutionContext.tsx @@ -167,7 +167,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> outputDataActions, getInputHash, setManualOutputType, - clearManualOutputTypes, + clearManualTypes, } = useContext(GlobalContext); const { schemata, @@ -262,7 +262,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> let broadcastData; let types; let progress; - let expectedLength; + let sequenceTypes; for (const { type, data } of events) { if (type === 'node-start') { @@ -273,7 +273,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> } else if (type === 'node-broadcast') { broadcastData = data.data; types = data.types; - expectedLength = data.expectedLength; + sequenceTypes = data.sequenceTypes ?? undefined; } else { progress = data; } @@ -283,14 +283,26 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> setNodeStatus(executionStatus, [nodeId]); } - if (executionTime !== undefined || broadcastData !== undefined || types !== undefined) { + if ( + executionTime !== undefined || + broadcastData !== undefined || + types !== undefined || + sequenceTypes !== undefined + ) { // TODO: This is incorrect. The inputs of the node might have changed since // the chain started running. However, sending the then current input hashes // of the chain to the backend along with the rest of its data and then making // the backend send us those hashes is incorrect too because of iterators, I // think. const inputHash = getInputHash(nodeId); - outputDataActions.set(nodeId, executionTime, inputHash, broadcastData, types); + outputDataActions.set( + nodeId, + executionTime, + inputHash, + broadcastData, + types, + sequenceTypes + ); } if (progress) { @@ -315,11 +327,6 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> } } } - - if (expectedLength) { - const type = evaluate(fromJson(expectedLength), getChainnerScope()); - setManualOutputType(nodeId, 'length', type); - } }; const nodeEventBacklog = useEventBacklog({ process: (events: NodeEvents[]) => { @@ -499,7 +506,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> nodeEventBacklog.processAll(); clearNodeStatusMap(); setStatus(ExecutionStatus.READY); - clearManualOutputTypes(iteratorNodeIds); + clearManualTypes(iteratorNodeIds); } }, [ getNodes, @@ -516,7 +523,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> packageSettings, clearNodeStatusMap, nodeEventBacklog, - clearManualOutputTypes, + clearManualTypes, ]); const resume = useCallback(async () => { diff --git a/src/renderer/contexts/GlobalNodeState.tsx b/src/renderer/contexts/GlobalNodeState.tsx index fec1cb899c..42230dd74d 100644 --- a/src/renderer/contexts/GlobalNodeState.tsx +++ b/src/renderer/contexts/GlobalNodeState.tsx @@ -1,4 +1,4 @@ -import { Expression, Type, evaluate } from '@chainner/navi'; +import { Expression } from '@chainner/navi'; import { dirname, parse } from 'path'; import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { @@ -18,6 +18,7 @@ import { InputId, InputKind, InputValue, + IterOutputId, Mutable, NodeData, OutputId, @@ -83,6 +84,7 @@ import { } from '../hooks/useOutputDataStore'; import { getSessionStorageOrDefault, useSessionStorage } from '../hooks/useSessionStorage'; import { useSettings } from '../hooks/useSettings'; +import { useTypeMap } from '../hooks/useTypeMap'; import { ipcRenderer } from '../safeIpc'; import { AlertBoxContext, AlertType } from './AlertBoxContext'; import { BackendContext } from './BackendContext'; @@ -137,12 +139,13 @@ interface Global { setZoom: SetState; exportViewportScreenshot: () => void; exportViewportScreenshotToClipboard: () => void; - setManualOutputType: ( + setManualOutputType: (nodeId: string, outputId: OutputId, type: Expression | undefined) => void; + setManualSequenceOutputType: ( nodeId: string, - outputId: OutputId | 'length', + iterOutputId: IterOutputId, type: Expression | undefined ) => void; - clearManualOutputTypes: (nodes: Iterable) => void; + clearManualTypes: (nodes: Iterable) => void; typeStateRef: Readonly>; chainLineageRef: Readonly>; outputDataActions: OutputDataActions; @@ -207,56 +210,22 @@ export const GlobalProvider = memo( [addEdgeChanges] ); - const [manualOutputTypes, setManualOutputTypes] = useState(() => ({ - map: new Map>(), - })); - const setManualOutputType = useCallback( - ( - nodeId: string, - outputId: OutputId | 'length', - expression: Expression | undefined - ): void => { - const getType = () => { - if (expression === undefined) { - return undefined; - } - - try { - return evaluate(expression, scope); - } catch (error) { - log.error(error); - return undefined; - } - }; - - setManualOutputTypes(({ map }) => { - let inner = map.get(nodeId); - const type = getType(); - if (type) { - if (!inner) { - inner = new Map(); - map.set(nodeId, inner); - } - - inner.set(outputId, type); - } else { - inner?.delete(outputId); - } - return { map }; - }); - }, - [setManualOutputTypes, scope] - ); - const clearManualOutputTypes = useCallback( - (nodes: Iterable): void => { - setManualOutputTypes(({ map }) => { - for (const nodeId of nodes) { - map.delete(nodeId); - } - return { map }; - }); + const [manualOutputTypes, setManualOutputType, clearManualOutputTypes] = useTypeMap< + string, + OutputId + >(scope); + const [ + manualSequenceOutputTypes, + setManualSequenceOutputType, + clearManualSequenceOutputTypes, + ] = useTypeMap(scope); + + const clearManualTypes = useCallback( + (nodes: Iterable) => { + clearManualOutputTypes(nodes); + clearManualSequenceOutputTypes(nodes); }, - [setManualOutputTypes] + [clearManualOutputTypes, clearManualSequenceOutputTypes] ); const [typeState, setTypeState] = useState(TypeState.empty); @@ -270,7 +239,7 @@ export const GlobalProvider = memo( // remove manual overrides of nodes that no longer exist if (manualOutputTypes.map.size > 0) { const ids = [...manualOutputTypes.map.keys()]; - for (const id of ids.filter((key) => !nodeMap.has(key) && key !== 'length')) { + for (const id of ids.filter((key) => !nodeMap.has(key))) { // use interior mutability to not cause updates manualOutputTypes.map.delete(id); } @@ -280,6 +249,7 @@ export const GlobalProvider = memo( nodeMap, getEdges(), manualOutputTypes.map, + manualSequenceOutputTypes.map, functionDefinitions, passthrough, typeStateRef.current @@ -296,6 +266,7 @@ export const GlobalProvider = memo( nodeChanges, edgeChanges, manualOutputTypes, + manualSequenceOutputTypes, functionDefinitions, schemata, passthrough, @@ -1329,7 +1300,8 @@ export const GlobalProvider = memo( exportViewportScreenshot, exportViewportScreenshotToClipboard, setManualOutputType, - clearManualOutputTypes, + setManualSequenceOutputType, + clearManualTypes, typeStateRef, chainLineageRef, outputDataActions, diff --git a/src/renderer/helpers/nodeState.ts b/src/renderer/helpers/nodeState.ts index daac07f6f7..db67f743af 100644 --- a/src/renderer/helpers/nodeState.ts +++ b/src/renderer/helpers/nodeState.ts @@ -1,4 +1,3 @@ -import { NonNeverType } from '@chainner/navi'; import { useMemo } from 'react'; import { useContext, useContextSelector } from 'use-context-selector'; import { @@ -25,14 +24,13 @@ import { useMemoObject } from '../hooks/useMemo'; export interface TypeInfo { readonly instance: FunctionInstance | undefined; readonly connectedInputs: ReadonlySet; - readonly iteratedInputLengths?: ReadonlyMap; - readonly iteratedOutputLengths?: ReadonlyMap; } const useTypeInfo = (id: string): TypeInfo => { const instance = useContextSelector(GlobalVolatileContext, (c) => c.typeState.functions.get(id) ); + const connectedInputsString = useContextSelector(GlobalVolatileContext, (c) => { const connected = c.typeState.edges.byTarget.get(id); return IdSet.from(connected?.map((connection) => connection.inputId) ?? EMPTY_ARRAY); diff --git a/src/renderer/hooks/useOutputDataStore.ts b/src/renderer/hooks/useOutputDataStore.ts index 81b95cb2aa..f55ac9d3b1 100644 --- a/src/renderer/hooks/useOutputDataStore.ts +++ b/src/renderer/hooks/useOutputDataStore.ts @@ -1,6 +1,6 @@ import isDeepEqual from 'fast-deep-equal/react'; import { useCallback, useState } from 'react'; -import { OutputData, OutputTypes } from '../../common/common-types'; +import { IterOutputTypes, OutputData, OutputTypes } from '../../common/common-types'; import { EMPTY_MAP } from '../../common/util'; import { useMemoObject } from './useMemo'; @@ -9,6 +9,7 @@ export interface OutputDataEntry { lastExecutionTime: number | undefined; data: OutputData | undefined; types: OutputTypes | undefined; + sequenceTypes: IterOutputTypes | undefined; } export interface OutputDataActions { @@ -17,7 +18,8 @@ export interface OutputDataActions { executionTime: number | undefined, nodeInputHash: string, data: OutputData | undefined, - types: OutputTypes | undefined + types: OutputTypes | undefined, + sequenceTypes: IterOutputTypes | undefined ): void; delete(nodeId: string): void; clear(): void; @@ -28,7 +30,7 @@ export const useOutputDataStore = () => { const actions: OutputDataActions = { set: useCallback( - (nodeId, executionTime, inputHash, data, types) => { + (nodeId, executionTime, inputHash, data, types, sequenceTypes) => { setMap((prev) => { const existingEntry = prev.get(nodeId); @@ -36,6 +38,7 @@ export const useOutputDataStore = () => { const entry: OutputDataEntry = { data: useExisting ? existingEntry.data : data, types: useExisting ? existingEntry.types : types, + sequenceTypes: useExisting ? existingEntry.sequenceTypes : sequenceTypes, inputHash: useExisting ? existingEntry.inputHash : inputHash, lastExecutionTime: executionTime ?? existingEntry?.lastExecutionTime, }; diff --git a/src/renderer/hooks/useTypeMap.ts b/src/renderer/hooks/useTypeMap.ts new file mode 100644 index 0000000000..8cd6cec2ca --- /dev/null +++ b/src/renderer/hooks/useTypeMap.ts @@ -0,0 +1,60 @@ +import { Expression, Scope, Type, evaluate } from '@chainner/navi'; +import { useCallback, useState } from 'react'; +import { log } from '../../common/log'; + +/** + * A map of types that can be used as either a ref-like object or a state-like value. + */ +export const useTypeMap = (scope: Scope) => { + const [types, setTypes] = useState(() => ({ + map: new Map>(), + })); + + const setType = useCallback( + (nodeId: N, outputId: I, expression: Expression | undefined): void => { + const getType = () => { + if (expression === undefined) { + return undefined; + } + + try { + return evaluate(expression, scope); + } catch (error) { + log.error(error); + return undefined; + } + }; + + setTypes(({ map }) => { + let inner = map.get(nodeId); + const type = getType(); + if (type) { + if (!inner) { + inner = new Map(); + map.set(nodeId, inner); + } + + inner.set(outputId, type); + } else { + inner?.delete(outputId); + } + return { map }; + }); + }, + [setTypes, scope] + ); + + const clear = useCallback( + (nodes: Iterable): void => { + setTypes(({ map }) => { + for (const nodeId of nodes) { + map.delete(nodeId); + } + return { map }; + }); + }, + [setTypes] + ); + + return [types, setType, clear] as const; +};