diff --git a/web/src/pages/flow/categorize-form/hooks.ts b/web/src/pages/flow/categorize-form/hooks.ts index 7c6544a087..80df0cf53e 100644 --- a/web/src/pages/flow/categorize-form/hooks.ts +++ b/web/src/pages/flow/categorize-form/hooks.ts @@ -78,9 +78,11 @@ const buildCategorizeObjectFromList = (list: Array) => { export const useHandleFormValuesChange = ({ onValuesChange, form, - node, + nodeId, }: IOperatorForm) => { const edges = useGraphStore((state) => state.edges); + const getNode = useGraphStore((state) => state.getNode); + const node = getNode(nodeId); const handleValuesChange = useCallback( (changedValues: any, values: any) => { @@ -94,12 +96,13 @@ export const useHandleFormValuesChange = ({ ); useEffect(() => { + const items = buildCategorizeListFromObject( + get(node, 'data.form.category_description', {}), + edges, + node, + ); form?.setFieldsValue({ - items: buildCategorizeListFromObject( - get(node, 'data.form.category_description', {}), - edges, - node, - ), + items, }); }, [form, node, edges]); @@ -107,19 +110,29 @@ export const useHandleFormValuesChange = ({ }; export const useHandleToSelectChange = (nodeId?: string) => { - const { addEdge } = useGraphStore((state) => state); + const { addEdge, deleteEdgeBySourceAndSourceHandle } = useGraphStore( + (state) => state, + ); const handleSelectChange = useCallback( (name?: string) => (value?: string) => { - if (nodeId && value && name) { - addEdge({ - source: nodeId, - target: value, - sourceHandle: name, - targetHandle: null, - }); + if (nodeId && name) { + if (value) { + addEdge({ + source: nodeId, + target: value, + sourceHandle: name, + targetHandle: null, + }); + } else { + // clear selected value + deleteEdgeBySourceAndSourceHandle({ + source: nodeId, + sourceHandle: name, + }); + } } }, - [addEdge, nodeId], + [addEdge, nodeId, deleteEdgeBySourceAndSourceHandle], ); return { handleSelectChange }; diff --git a/web/src/pages/flow/categorize-form/index.tsx b/web/src/pages/flow/categorize-form/index.tsx index 5f1192a696..de0fc5a359 100644 --- a/web/src/pages/flow/categorize-form/index.tsx +++ b/web/src/pages/flow/categorize-form/index.tsx @@ -10,7 +10,7 @@ const CategorizeForm = ({ form, onValuesChange, node }: IOperatorForm) => { const { t } = useTranslate('flow'); const { handleValuesChange } = useHandleFormValuesChange({ form, - node, + nodeId: node?.id, onValuesChange, }); useSetLlmSetting(form); diff --git a/web/src/pages/flow/constant.tsx b/web/src/pages/flow/constant.tsx index cfe8a0ce07..12579bd508 100644 --- a/web/src/pages/flow/constant.tsx +++ b/web/src/pages/flow/constant.tsx @@ -101,13 +101,7 @@ export const CategorizeAnchorPointPositions = [ // key is the source of the edge, value is the target of the edge // no connection lines are allowed between key and value export const RestrictedUpstreamMap = { - [Operator.Begin]: [ - Operator.Begin, - Operator.Answer, - Operator.Categorize, - Operator.Generate, - Operator.Retrieval, - ], + [Operator.Begin]: [], [Operator.Categorize]: [Operator.Begin, Operator.Categorize, Operator.Answer], [Operator.Answer]: [], [Operator.Retrieval]: [], diff --git a/web/src/pages/flow/interface.ts b/web/src/pages/flow/interface.ts index 73eca14d14..55042647e1 100644 --- a/web/src/pages/flow/interface.ts +++ b/web/src/pages/flow/interface.ts @@ -10,6 +10,7 @@ export interface IOperatorForm { onValuesChange?(changedValues: any, values: any): void; form?: FormInstance; node?: Node; + nodeId?: string; } export interface IBeginForm { diff --git a/web/src/pages/flow/mock.tsx b/web/src/pages/flow/mock.tsx index af0129ba5c..0f9a4ad579 100644 --- a/web/src/pages/flow/mock.tsx +++ b/web/src/pages/flow/mock.tsx @@ -37,7 +37,7 @@ export const dsl = { graph: { nodes: [ { - id: 'begin', + id: 'Begin', type: 'beginNode', position: { x: 50, diff --git a/web/src/pages/flow/store.ts b/web/src/pages/flow/store.ts index 8d2edf0f4b..f20cc31344 100644 --- a/web/src/pages/flow/store.ts +++ b/web/src/pages/flow/store.ts @@ -35,7 +35,7 @@ export type RFState = { updateNodeForm: (nodeId: string, values: any) => void; onSelectionChange: OnSelectionChangeFunc; addNode: (nodes: Node) => void; - getNode: (id: string) => Node | undefined; + getNode: (id?: string) => Node | undefined; addEdge: (connection: Connection) => void; getEdge: (id: string) => Edge | undefined; deletePreviousEdgeOfClassificationNode: (connection: Connection) => void; @@ -43,7 +43,7 @@ export type RFState = { deleteEdge: () => void; deleteEdgeById: (id: string) => void; deleteNodeById: (id: string) => void; - deleteEdgeBySourceAndTarget: (source: string, target: string) => void; + deleteEdgeBySourceAndSourceHandle: (connection: Partial) => void; findNodeByName: (operatorName: Operator) => Node | undefined; updateMutableNodeFormItem: (id: string, field: string, value: any) => void; }; @@ -87,7 +87,7 @@ const useGraphStore = create()( addNode: (node: Node) => { set({ nodes: get().nodes.concat(node) }); }, - getNode: (id: string) => { + getNode: (id?: string) => { return get().nodes.find((x) => x.id === id); }, addEdge: (connection: Connection) => { @@ -150,12 +150,17 @@ const useGraphStore = create()( edges: edges.filter((edge) => edge.id !== id), }); }, - deleteEdgeBySourceAndTarget: (source: string, target: string) => { + deleteEdgeBySourceAndSourceHandle: ({ + source, + sourceHandle, + }: Partial) => { const { edges } = get(); + const nextEdges = edges.filter( + (edge) => + edge.source !== source || edge.sourceHandle !== sourceHandle, + ); set({ - edges: edges.filter( - (edge) => edge.target !== target && edge.source !== source, - ), + edges: nextEdges, }); }, deleteNodeById: (id: string) => { diff --git a/web/src/pages/flow/utils.ts b/web/src/pages/flow/utils.ts index 404f4d1447..b2a6916168 100644 --- a/web/src/pages/flow/utils.ts +++ b/web/src/pages/flow/utils.ts @@ -173,7 +173,8 @@ export const getOperatorTypeFromId = (id: string | null) => { // restricted lines cannot be connected successfully. export const isValidConnection = (connection: Connection) => { - return RestrictedUpstreamMap[ + const ret = RestrictedUpstreamMap[ getOperatorTypeFromId(connection.source) as Operator ]?.every((x) => x !== getOperatorTypeFromId(connection.target)); + return ret; };