Skip to content

Commit

Permalink
Merge pull request #37 from pyiron/automatic_positioning
Browse files Browse the repository at this point in the history
Automatic positioning
  • Loading branch information
Tara-Lakshmipathy authored Jan 9, 2025
2 parents 8d541c0 + c7c8ee9 commit 06721fe
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 7 deletions.
112 changes: 112 additions & 0 deletions js/useElkLayout.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import { useEffect } from 'react';
import ELK from 'elkjs/lib/elk.bundled.js';
import { useNodesInitialized, useReactFlow } from '@xyflow/react';
// import { saveAs } from 'file-saver';


// elk layouting options can be found here:
// https://www.eclipse.org/elk/reference/algorithms/org-eclipse-elk-layered.html


// uses elkjs to give each node a layouted position
export const getLayoutedNodes2 = async (nodes, edges) => {
const layoutOptions = {
'elk.algorithm': 'layered',
'elk.direction': 'RIGHT',
'elk.layered.spacing.edgeNodeBetweenLayers': '40',
'elk.spacing.nodeNode': '40',
'elk.layered.nodePlacement.strategy': 'SIMPLE',
};

console.log("nodes layout: ", nodes);
console.log("edges layout: ", edges);
const elk = new ELK();

const graph = {
id: 'root',
layoutOptions,
children: nodes.map((n) => {
const targetPorts = n.data.target_labels.map((label) => ({
id: `${n.id}_in_${label}`,

// ⚠️ it's important to let elk know on which side the port is
// in this example targets are on the left (WEST) and sources on the right (EAST)
properties: {
side: 'WEST',
},
}));

const sourcePorts = n.data.source_labels.map((label) => ({
id: `${n.id}_out_${label}`,
properties: {
side: 'EAST',
},
}));

return {
id: n.id,
width: n.style.width_unitless ?? 240,
height: n.style.height_unitless ?? 100,
// ⚠️ we need to tell elk that the ports are fixed, in order to reduce edge crossings
properties: {
'org.eclipse.elk.portConstraints': 'FIXED_ORDER',
// 'org.eclipse.elk.layered.portSortingStrategy': 'UP_DOWN',
},
// we are also passing the id, so we can also handle edges without a sourceHandle or targetHandle option
ports: [ ...targetPorts.reverse(), ...sourcePorts.reverse()],
};
}),
edges: edges.map((e) => ({
id: e.id,
sources: [`${e.source}_out_${e.sourceHandle}`],
targets: [`${e.target}_in_${e.targetHandle}`],
})),
};

// const blob = new Blob([JSON.stringify(graph)], {type: "text/plain;charset=utf-8"});
// saveAs(blob, 'output.json');

console.log("Graph: ", graph);
const layoutedGraph = await elk.layout(graph);
console.log("layoutedGraph: ", layoutedGraph);

const layoutedNodes = nodes.map((node) => {
const layoutedNode = layoutedGraph.children?.find(
(lgNode) => lgNode.id === node.id,
);

return {
...node,
position: {
x: layoutedNode?.x ?? 0,
y: layoutedNode?.y ?? 0,
},
};
});
console.log("layoutedGraphNodes: ", layoutedNodes);

return layoutedNodes;
};

export default function useLayoutNodes() {
const nodesInitialized = useNodesInitialized();
const { getNodes, getEdges, setNodes, fitView } = useReactFlow();

useEffect(() => {
if (nodesInitialized) {
const layoutNodes = async () => {
const layoutedNodes = await getLayoutedNodes(
getNodes(),
getEdges(),
);

setNodes(layoutedNodes);
setTimeout(() => fitView(), 0);
};

layoutNodes();
}
}, [nodesInitialized, getNodes, getEdges, setNodes, fitView]);

return null;
}
17 changes: 17 additions & 0 deletions js/widget.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import '@xyflow/react/dist/style.css';

import TextUpdaterNode from './TextUpdaterNode.jsx';
import CustomNode from './CustomNode.jsx';
import {getLayoutedNodes2} from './useElkLayout';

import './text-updater-node.css';

Expand Down Expand Up @@ -80,6 +81,16 @@ const render = createRender(() => {
customNode: CustomNode,
};

const layoutNodes = async () => {
const layoutedNodes = await getLayoutedNodes2(nodes, edges);
setNodes(layoutedNodes);
// setTimeout(() => fitView(), 0);
};

useEffect(() => {
layoutNodes();
}, [setNodes]);

const [macroName, setMacroName] = useState('custom_macro');

const [currentDateTime, setCurrentDateTime] = useState(() => {
Expand Down Expand Up @@ -394,6 +405,12 @@ const render = createRender(() => {
>
Delete Save File
</button>
<button
style={{position: "absolute", right: "130px", bottom: "170px", zIndex: "4"}}
onClick={layoutNodes}
>
Reset Layout
</button>
</ReactFlow>
</UpdateDataContext.Provider>
</div>
Expand Down
30 changes: 23 additions & 7 deletions pyironflow/wf_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,26 @@ def get_node_types(node_io):
return node_io_types


def get_node_position(node, id_num, node_width=240, y0=100, x_spacing=20):
def get_node_position(node, max_x, node_width=240, y0=100, x_spacing=20):
if 'position' in dir(node):
x, y = node.position
# if isinstance(x, str):
# x, y = 0, 0
else:
x = id_num * (node_width + x_spacing)
x = max_x + node_width + x_spacing
y = y0

return {'x': x, 'y': y}


def get_node_dict(node, id_num, key=None):
def get_node_dict(node, max_x, key=None):
node_width = 240
n_inputs = len(list(node.inputs.channel_dict.keys()))
n_outputs = len(list(node.outputs.channel_dict.keys()))
if n_outputs > n_inputs:
node_height = 30 + (16*n_outputs) + 10
else:
node_height = 30 + (16*n_inputs) + 10
label = node.label
if (node.label != key) and (key is not None):
label = f'{node.label}: {key}'
Expand All @@ -118,22 +124,32 @@ def get_node_dict(node, id_num, key=None):
'source_values': get_node_values(node.outputs.channel_dict),
'source_types': get_node_types(node.outputs),
},
'position': get_node_position(node, id_num),
'position': get_node_position(node, max_x),
'type': 'customNode',
'style': {'border': '1px black solid',
'padding': 5,
'background': get_color(node=node, theme='light'),
'borderRadius': '10px',
'width': f'{node_width}px'},
'width': f'{node_width}px',
'width_unitless': node_width,
'height': f'{node_height}px',
'height_unitless': node_height},
'targetPosition': 'left',
'sourcePosition': 'right'
}


def get_nodes(wf):
nodes = []
x_coords = []
max_x = 0
for i, (k, v) in enumerate(wf.children.items()):
if 'position' in dir(v):
x_coords.append(v.position[0])
if len(x_coords) > 0:
max_x = max(x_coords)
for i, (k, v) in enumerate(wf.children.items()):
nodes.append(get_node_dict(v, id_num=i, key=k))
nodes.append(get_node_dict(v, max_x, key=k))
return nodes


Expand Down Expand Up @@ -178,7 +194,7 @@ def get_input_types_from_hint(node_input: dict):
if listed_type.__name__ != "NoneType":
new_type = new_type + listed_type.__name__ + "|"

new_type = new_type[:-1]
new_type = new_type[:-1]

for listed_type in list(type_hint_to_tuple(node_input.type_hint)):
if listed_type == None:
Expand Down

0 comments on commit 06721fe

Please sign in to comment.