Skip to content

Commit

Permalink
Colored compatibility hints when dragging wires
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Oct 31, 2023
1 parent 89ba4d9 commit 77a63bc
Show file tree
Hide file tree
Showing 20 changed files with 160 additions and 21 deletions.
8 changes: 6 additions & 2 deletions packages/app/src/components/LoopControllerNodePorts.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import { type FC, type MouseEvent } from 'react';
import { useNodeIO } from '../hooks/useGetNodeIO.js';
import { useStableCallback } from '../hooks/useStableCallback.js';
import { Port } from './Port.js';
import { type WireDef } from './WireLayer';
import { useDependsOnPlugins } from '../hooks/useDependsOnPlugins';
import { type DraggingWireDef } from '../state/graphBuilder';

export type NodePortsProps = {
node: ChartNode;
connections: NodeConnection[];
zoomedOut?: boolean;
draggingWire: WireDef | undefined;
draggingWire: DraggingWireDef | undefined;
closestPortToDraggingWire: { nodeId: NodeId; portId: PortId } | undefined;
onWireStartDrag?: (
event: MouseEvent<HTMLElement>,
Expand Down Expand Up @@ -115,6 +115,7 @@ export const LoopControllerNodePorts: FC<NodePortsProps> = ({
closest={
closestPortToDraggingWire?.nodeId === node.id && closestPortToDraggingWire.portId === input.id
}
draggingDataType={draggingWire?.dataType}
definition={input}
onMouseDown={handlePortMouseDown}
onMouseUp={handlePortMouseUp}
Expand All @@ -141,6 +142,7 @@ export const LoopControllerNodePorts: FC<NodePortsProps> = ({
closestPortToDraggingWire?.nodeId === node.id && closestPortToDraggingWire.portId === output.id
}
definition={output}
draggingDataType={draggingWire?.dataType}
onMouseDown={handlePortMouseDown}
onMouseUp={handlePortMouseUp}
onMouseOver={onPortMouseOver}
Expand Down Expand Up @@ -174,6 +176,7 @@ export const LoopControllerNodePorts: FC<NodePortsProps> = ({
closestPortToDraggingWire?.nodeId === node.id && closestPortToDraggingWire.portId === input.id
}
definition={input}
draggingDataType={draggingWire?.dataType}
onMouseDown={handlePortMouseDown}
onMouseUp={handlePortMouseUp}
onMouseOver={onPortMouseOver}
Expand Down Expand Up @@ -203,6 +206,7 @@ export const LoopControllerNodePorts: FC<NodePortsProps> = ({
closestPortToDraggingWire?.nodeId === node.id && closestPortToDraggingWire.portId === output.id
}
definition={output}
draggingDataType={draggingWire?.dataType}
onMouseDown={handlePortMouseDown}
onMouseUp={handlePortMouseUp}
onMouseOver={onPortMouseOver}
Expand Down
10 changes: 4 additions & 6 deletions packages/app/src/components/NodePorts.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@ import { useNodeIO } from '../hooks/useGetNodeIO.js';
import { useStableCallback } from '../hooks/useStableCallback.js';
import { Port } from './Port.js';
import { ErrorBoundary } from 'react-error-boundary';
import { type WireDef } from './WireLayer';
import { useDependsOnPlugins } from '../hooks/useDependsOnPlugins';
import { LoopControllerNodePorts } from './LoopControllerNodePorts';
import { type DraggingWireDef } from '../state/graphBuilder';

export type NodePortsProps = {
node: ChartNode;
connections: NodeConnection[];
zoomedOut?: boolean;
draggingWire: WireDef | undefined;
draggingDataType?: DataType;
draggingWire: DraggingWireDef | undefined;
closestPortToDraggingWire: { nodeId: NodeId; portId: PortId } | undefined;
onWireStartDrag?: (
event: MouseEvent<HTMLElement>,
Expand Down Expand Up @@ -58,7 +57,6 @@ export const NodePorts: FC<NodePortsProps> = ({
node,
connections,
draggingWire,
draggingDataType,
closestPortToDraggingWire,
onWireStartDrag,
onWireEndDrag,
Expand Down Expand Up @@ -97,7 +95,7 @@ export const NodePorts: FC<NodePortsProps> = ({
canDragTo={draggingWire ? !draggingWire.startPortIsInput : false}
closest={closestPortToDraggingWire?.nodeId === node.id && closestPortToDraggingWire.portId === input.id}
definition={input}
draggingDataType={draggingDataType}
draggingDataType={draggingWire?.dataType}
onMouseDown={handlePortMouseDown}
onMouseUp={handlePortMouseUp}
onMouseOver={onPortMouseOver}
Expand All @@ -121,7 +119,7 @@ export const NodePorts: FC<NodePortsProps> = ({
canDragTo={draggingWire ? draggingWire.startPortIsInput : false}
closest={closestPortToDraggingWire?.nodeId === node.id && closestPortToDraggingWire.portId === output.id}
definition={output}
draggingDataType={draggingDataType}
draggingDataType={draggingWire?.dataType}
onMouseDown={handlePortMouseDown}
onMouseUp={handlePortMouseUp}
onMouseOver={onPortMouseOver}
Expand Down
35 changes: 29 additions & 6 deletions packages/app/src/components/Port.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import {
type PortId,
type NodeOutputDefinition,
type DataType,
isDataTypeAccepted,
canBeCoercedAny,
} from '@ironclad/rivet-core';
import { type FC, useRef, type MouseEvent, memo } from 'react';
import { type FC, useRef, type MouseEvent, memo, useMemo } from 'react';
import clsx from 'clsx';
import { useStableCallback } from '../hooks/useStableCallback';

Expand All @@ -18,7 +20,7 @@ export const Port: FC<{
canDragTo: boolean;
closest: boolean;
definition: NodeInputDefinition | NodeOutputDefinition;
draggingDataType?: DataType;
draggingDataType?: DataType | Readonly<DataType[]>;
onMouseDown?: (event: MouseEvent<HTMLDivElement>, port: PortId, isInput: boolean) => void;
onMouseUp?: (event: MouseEvent<HTMLDivElement>, port: PortId) => void;
onMouseOver?: (
Expand Down Expand Up @@ -67,13 +69,34 @@ export const Port: FC<{
onMouseOut?.(event, nodeId, input, id, definition);
});

const accepted = useMemo(() => {
if (!draggingDataType || !input) {
return '';
}

if (isDataTypeAccepted(draggingDataType, definition.dataType)) {
return 'compatible';
}

// We almost always coerce so default it to true for now...
if ((definition as NodeInputDefinition).coerced ?? true) {
return canBeCoercedAny(draggingDataType, definition.dataType) ? 'coerced' : 'incompatible';
}

return 'incompatible';
}, [draggingDataType, definition.dataType, (definition as NodeInputDefinition).coerced, input]);

return (
<div
key={id}
className={clsx('port', {
connected,
closest,
})}
className={clsx(
'port',
{
connected,
closest,
},
accepted,
)}
>
<div
ref={ref}
Expand Down
12 changes: 12 additions & 0 deletions packages/app/src/components/nodeStyles.ts
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,18 @@ export const nodeStyles = css`
border: 2px solid var(--primary-dark);
}
.port.compatible .port-circle {
border: 2px solid var(--success);
}
.port.coerced .port-circle {
border: 2px solid var(--warning);
}
.port.incompatible .port-circle {
border: 2px solid var(--error);
}
.port.connected .port-label {
color: var(--primary-text);
}
Expand Down
14 changes: 10 additions & 4 deletions packages/app/src/hooks/useDraggingWire.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,23 @@ export const useDraggingWire = (onConnectionsChanged: (connections: NodeConnecti
newConnections.splice(existingConnectionIndex, 1);
onConnectionsChanged(newConnections);

setDraggingWire({
startNodeId: connections[existingConnectionIndex]!.outputNodeId,
startPortId: connections[existingConnectionIndex]!.outputId,
const { outputId, outputNodeId } = connections[existingConnectionIndex]!;

const def = ioByNode[outputNodeId]!.outputDefinitions.find((o) => o.id === outputId)!;

setDraggingWire({
startNodeId: outputNodeId,
startPortId: outputId,
startPortIsInput: false,
dataType: def.dataType,
});
return;
}
return;
}
setDraggingWire({ startNodeId, startPortId, startPortIsInput: isInput });

const def = ioByNode[startNodeId]!.outputDefinitions.find((o) => o.id === startPortId)!;
setDraggingWire({ startNodeId, startPortId, startPortIsInput: isInput, dataType: def.dataType });
},
[connections, onConnectionsChanged, setDraggingWire],
);
Expand Down
5 changes: 4 additions & 1 deletion packages/app/src/state/graphBuilder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
NodeImpl,
type NodeInputDefinition,
type PortId,
type DataType,
} from '@ironclad/rivet-core';
import { recoilPersist } from 'recoil-persist';
import { type WireDef } from '../components/WireLayer.js';
Expand Down Expand Up @@ -51,7 +52,9 @@ export const sidebarOpenState = atom<boolean>({
default: true,
});

export const draggingWireState = atom<WireDef | undefined>({
export type DraggingWireDef = WireDef & { readonly dataType: DataType | Readonly<DataType[]> };

export const draggingWireState = atom<DraggingWireDef | undefined>({
key: 'draggingWire',
default: undefined,
});
Expand Down
3 changes: 3 additions & 0 deletions packages/core/src/model/NodeBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ export type NodeInputDefinition = {

/** User-facing description of the port. */
description?: string;

/** Will the input value attempt to be coerced into the desired type? */
coerced?: boolean;
};

/** Represents an output definition of a node. */
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/AudioNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export class AudioNodeImpl extends NodeImpl<AudioNode> {
id: 'data' as PortId,
title: 'Data',
dataType: 'string',
coerced: false,
});
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/ExtractJsonNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export class ExtractJsonNodeImpl extends NodeImpl<ExtractJsonNode> {
title: 'Input',
dataType: 'string',
required: true,
coerced: false,
},
];
}
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/ExtractObjectPathNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export class ExtractObjectPathNodeImpl extends NodeImpl<ExtractObjectPathNode> {
title: 'Path',
dataType: 'string',
required: true,
coerced: false,
});
}

Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/model/nodes/ExtractRegexNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export class ExtractRegexNodeImpl extends NodeImpl<ExtractRegexNode> {
title: 'Input',
dataType: 'string',
required: true,
coerced: false,
},
];

Expand All @@ -60,6 +61,7 @@ export class ExtractRegexNodeImpl extends NodeImpl<ExtractRegexNode> {
title: 'Regex',
dataType: 'string',
required: false,
coerced: false,
});
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/ExtractYamlNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ export class ExtractYamlNodeImpl extends NodeImpl<ExtractYamlNode> {
title: 'Input',
dataType: 'string',
required: true,
coerced: false,
},
];

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/ImageNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export class ImageNodeImpl extends NodeImpl<ImageNode> {
id: 'data' as PortId,
title: 'Data',
dataType: 'string',
coerced: false,
});
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/PopNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export class PopNodeImpl extends NodeImpl<PopNode> {
dataType: 'any[]',
id: 'array' as PortId,
title: 'Array',
coerced: false,
},
];
}
Expand Down
5 changes: 5 additions & 0 deletions packages/core/src/model/nodes/ReadDirectoryNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export class ReadDirectoryNodeImpl extends NodeImpl<ReadDirectoryNode> {
title: 'Path',
dataType: 'string',
required: true,
coerced: false,
});
}

Expand All @@ -78,6 +79,7 @@ export class ReadDirectoryNodeImpl extends NodeImpl<ReadDirectoryNode> {
title: 'Recursive',
dataType: 'boolean',
required: true,
coerced: false,
});
}

Expand All @@ -87,6 +89,7 @@ export class ReadDirectoryNodeImpl extends NodeImpl<ReadDirectoryNode> {
title: 'Include Directories',
dataType: 'boolean',
required: true,
coerced: false,
});
}

Expand All @@ -96,6 +99,7 @@ export class ReadDirectoryNodeImpl extends NodeImpl<ReadDirectoryNode> {
title: 'Filter Globs',
dataType: 'string[]',
required: true,
coerced: false,
});
}

Expand All @@ -105,6 +109,7 @@ export class ReadDirectoryNodeImpl extends NodeImpl<ReadDirectoryNode> {
title: 'Relative',
dataType: 'boolean',
required: true,
coerced: false,
});
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/ReadFileNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export class ReadFileNodeImpl extends NodeImpl<ReadFileNode> {
id: 'path' as PortId,
title: 'Path',
dataType: 'string',
coerced: false,
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,6 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl<ChatAnthropicNode> = {
: undefined
: data.stop;

const functions = expectTypeOptional(inputs['functions' as PortId], 'gpt-function[]');

const { messages } = getChatAnthropicNodeMessages(inputs);

let prompt = messages.reduce((acc, message) => {
Expand Down
40 changes: 40 additions & 0 deletions packages/core/src/utils/coerceType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,43 @@ function coerceToObject(value: DataValue | undefined): object | undefined {

return value.value; // Whatever, consider anything an object
}

export function canBeCoercedAny(from: DataType | Readonly<DataType[]>, to: DataType | Readonly<DataType[]>) {
for (const fromType of Array.isArray(from) ? from : [from]) {
for (const toType of Array.isArray(to) ? to : [to]) {
if (canBeCoerced(fromType, toType)) {
return true;
}
}
}
return false;
}

// TODO hard to keep in sync with coerceType
export function canBeCoerced(from: DataType, to: DataType) {
if (to === 'any' || from === 'any') {
return true;
}

if (isArrayDataType(to) && isArrayDataType(from)) {
return canBeCoerced(getScalarTypeOf(from), getScalarTypeOf(to));
}

if (isArrayDataType(to) && !isArrayDataType(from)) {
return canBeCoerced(from, getScalarTypeOf(to));
}

if (isArrayDataType(from) && !isArrayDataType(to)) {
return to === 'string' || to === 'object';
}

if (to === 'gpt-function') {
return from === 'object';
}

if (to === 'audio' || to === 'binary' || to === 'image') {
return false;
}

return true;
}
Loading

0 comments on commit 77a63bc

Please sign in to comment.