Skip to content

Commit

Permalink
Merge pull request #29 from pyiron/polish_macro_creation
Browse files Browse the repository at this point in the history
Polish macro creation
  • Loading branch information
Tara-Lakshmipathy authored Nov 26, 2024
2 parents 93e6b4d + 3307933 commit 7dc2c10
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 53 deletions.
5 changes: 2 additions & 3 deletions js/CustomNode.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ export default memo(({ data }) => {
console.log('source: ', data.label)
model.set("commands", `source: ${data.label}`);
model.save_changes();
}

}

const renderLabel = (label) => {
return (
Expand Down Expand Up @@ -154,7 +153,7 @@ export default memo(({ data }) => {
context(data.label, index, convertedValue);
}}
style={{
width: '15px',
width: '20px',
height: '10px',
fontSize: '6px',
backgroundColor: getBackgroundColor(value, inp_type)
Expand Down
123 changes: 111 additions & 12 deletions js/widget.jsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useCallback, useState, useEffect, createContext } from 'react';
import React, { useCallback, useState, useEffect, createContext, useSelection } from 'react';
import { createRender, useModel } from "@anywidget/react";
import ELK from 'elkjs/lib/elk.bundled.js';
import {
Expand All @@ -9,6 +9,7 @@ import {
applyEdgeChanges,
applyNodeChanges,
addEdge,
useOnSelectionChange,
} from '@xyflow/react';
import '@xyflow/react/dist/style.css';

Expand Down Expand Up @@ -40,6 +41,27 @@ export const UpdateDataContext = createContext(null);

// const nodeTypes = { textUpdater: TextUpdaterNode, customNode: CustomNode };

function SelectionDisplay() {
const [selectedNodes, setSelectedNodes] = useState([]);
const [selectedEdges, setSelectedEdges] = useState([]);

// the passed handler has to be memoized, otherwise the hook will not work correctly
const onChange = useCallback(({ nodes, edges }) => {
setSelectedNodes(nodes.map((node) => node.id));
setSelectedEdges(edges.map((edge) => edge.id));
}, []);

useOnSelectionChange({
onChange,
});

return (
<div>
<p>Selected nodes: {selectedNodes.join(', ')}</p>
<p>Selected edges: {selectedEdges.join(', ')}</p>
</div>
);
}

const render = createRender(() => {
const model = useModel();
Expand All @@ -50,12 +72,16 @@ const render = createRender(() => {
const [nodes, setNodes] = useState(initialNodes);
const [edges, setEdges] = useState(initialEdges);

const selectedNodes = [];
const selectedEdges = [];

const nodeTypes = {
textUpdater: TextUpdaterNode,
customNode: CustomNode,
};

const [macroName, setMacroName] = useState('custom_macro');


const updateData = (nodeLabel, handleIndex, newValue) => {
setNodes(prevNodes =>
Expand Down Expand Up @@ -103,9 +129,31 @@ const render = createRender(() => {
const onNodesChange = useCallback(
(changes) => {
setNodes((nds) => {
const new_nodes = applyNodeChanges(changes, nds);
console.log('onNodesChange: ', changes, new_nodes)
const new_nodes = applyNodeChanges(changes, nds);
for (const i in changes) {
if (Object.hasOwn(changes[i], 'selected')) {
if (changes[i].selected){
for (const k in new_nodes){
if (new_nodes[k].id == changes[i].id) {
selectedNodes.push(new_nodes[k]);
}

}
}
else{
for (const j in selectedNodes){
if (selectedNodes[j].id == changes[i].id) {
//const index = selectedNodes[j].indexOf(changes[i].id);
selectedNodes.splice(j, 1);
}
}
}
}
}
console.log('selectedNodes:', selectedNodes);
console.log('nodes:', nodes);
model.set("nodes", JSON.stringify(new_nodes));
model.set("selected_nodes", JSON.stringify(selectedNodes));
model.save_changes();
return new_nodes;
});
Expand All @@ -115,11 +163,38 @@ const render = createRender(() => {

const onEdgesChange = useCallback(
(changes) => {
setEdges((eds) => {
const new_edges = applyEdgeChanges(changes, eds);
model.set("edges", JSON.stringify(new_edges));
model.save_changes();
return new_edges;
setEdges((eds) => {
const new_edges = applyEdgeChanges(changes, eds);
for (const i in changes) {
if (Object.hasOwn(changes[i], 'selected')) {
if (changes[i].selected){
for (const k in new_edges){
if (new_edges[k].id == changes[i].id) {
selectedEdges.push(new_edges[k]);
}
}
}
else{
for (const j in selectedEdges){
if (selectedEdges[j].id == changes[i].id) {
selectedEdges.splice(j, 1);
}
}
}
}
}
for (const n in selectedEdges){
var filterResult = new_edges.filter((edge) => edge.id === selectedEdges[n].id);
if (filterResult == []){
selectedEdges.splice(n, 1);
}
}
console.log('selectedEdges:', selectedEdges);
console.log('edges:', new_edges);
model.set("edges", JSON.stringify(new_edges));
model.set("selected_edges", JSON.stringify(selectedEdges));
model.save_changes();
return new_edges;
});
},
[setEdges],
Expand Down Expand Up @@ -173,8 +248,17 @@ const render = createRender(() => {
data: { ...node.data, forceToolbarVisible: enabled },
})),
),
);

);

const macroFunction = (userInput) => {
console.log('macro: ', userInput);
if (model) {
model.set("commands", `macro: ${userInput}`);
model.save_changes();
} else {
console.error('model is undefined');
}
}

return (
<div style={{ position: "relative", height: "800px", width: "100%" }}>
Expand All @@ -188,10 +272,25 @@ const render = createRender(() => {
onNodesDelete={onNodesDelete}
nodeTypes={nodeTypes}
fitView
style={rfStyle}>
<Background variant="dots" gap={12} size={1} />
style={rfStyle}
/*debugMode={true}*/
>
<div style={{ position: "absolute", right: "10px", top: "10px", zIndex: "4", fontSize: "12px"}}>
<label style={{display: "block"}}>Macro class name:</label>
<input
value={macroName}
onChange={(evt) => setMacroName(evt.target.value)}
/>
</div>
<Background variant="dots" gap={20} size={2} />
<MiniMap />
<Controls />
<button
style={{position: "absolute", right: "100px", top: "50px", zIndex: "4"}}
onClick={() => macroFunction(macroName)}
>
Create Macro
</button>
</ReactFlow>
</UpdateDataContext.Provider>
</div>
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"dependencies": {
"@anywidget/react": "^0.0.7",
"@xyflow/react": "^12.0.4",
"@xyflow/react": "^12.3.5",
"elkjs": "^0.9.3",
"react": "^18.3.1",
"react-dom": "^18.3.1"
Expand Down
105 changes: 105 additions & 0 deletions pyironflow/create_macro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from pyiron_workflow.type_hinting import type_hint_to_tuple
import typing

def get_import_path(obj):
module = obj.__module__ if hasattr(obj, "__module__") else obj.__class__.__module__
# name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
name = obj.__name__ if "__name__" in dir(obj) else obj.__class__.__name__
path = f"{module}.{name}"
if path == "numpy.ndarray":
path = "numpy.array"
return path



def get_input_types_from_hint(node_input: dict):

new_type = ""

for listed_type in list(type_hint_to_tuple(node_input.type_hint)):
if listed_type == None:
listed_type = type(None)
if listed_type.__name__ != "NoneType":
new_type = new_type + listed_type.__name__ + "|"

new_type = new_type[:-1]

for listed_type in list(type_hint_to_tuple(node_input.type_hint)):
if listed_type == None:
listed_type = type(None)
if listed_type.__name__ == "NoneType":
if new_type != "":
new_type = ": Optional[" + new_type + "]"

return new_type

def custom(wf = dict, name = str, root_path='../pyiron_nodes/pyiron_nodes'):

imports = list("")
var_def = ""

file = open(root_path + '/' + name + '.py', 'w')

for i, (k, v) in enumerate(wf.children.items()):
rest, n = get_import_path(v).rsplit('.', 1)
new_import = " from " + rest + " import " + n
imports.append(new_import)
list_inputs = list(v.inputs.channel_dict.keys())

for j in list(v.inputs):
if ((v.label + "__" + j.label) in list(wf.inputs.channel_dict.keys())):
if str(j) == ("NOT_DATA" or "None"):
value = "None"
elif type(j.value) == str:
value = "'" + j.value + "'"
else:
value = str(j.value)
var_def = var_def + v.label + "_" + j.label + get_input_types_from_hint(j)+ " = " + value + ", "

var_def = var_def[:-2]

count = 0
new_list = list("")
for ic, (out, inp) in enumerate(wf.graph_as_dict["edges"]["data"].keys()):
out_node, out_port = out.split('/')[2].split('.')
inp_node, inp_port = inp.split('/')[2].split('.')
new_list.append([out_node, inp_node, inp_port])


file.write(
'''from pyiron_workflow import as_function_node, as_macro_node
from typing import Optional
@as_macro_node()
def ''' + name + '''(self, ''' + var_def + '''):
''')
for j in imports:
file.write(j + "\n")

for i, (k, v) in enumerate(wf.children.items()):
rest, n = get_import_path(v).rsplit('.', 1)
file.write(" self." + v.label + " = " + n + "()\n")

for i, (k, v) in enumerate(wf.children.items()):
rest, n = get_import_path(v).rsplit('.', 1)

node_def =""

for j in list(wf.inputs.channel_dict.keys()):
node_label, input_label =j.rsplit('__', 1)
if v.label == node_label:
node_def = node_def + input_label + " = " + node_label + "_" + input_label+ ", "

for p in new_list:
if v.label == p[1]:
node_def = node_def + p[2] + " = self."+ p[0] + ", "
node_def = node_def[:-2]
file.write(" self." + v.label + ".set_input_values" + "(" + node_def + ")\n")


rest, n = list(wf.outputs.channel_dict.keys())[0].rsplit('__', 1)
file.write(" return self." + rest)
print("\nSuccessfully created macro: " + root_path + '/' + name + '.py')
file.close()

return
2 changes: 1 addition & 1 deletion pyironflow/pyironflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, wf_list=None, root_path='../pyiron_nodes/pyiron_nodes'):

self.out_log = widgets.Output(layout={'border': '1px solid black', 'width': '800px'})
self.out_widget = widgets.Output(layout={'border': '1px solid black', 'min_width': '400px'})
self.wf_widgets = [PyironFlowWidget(wf, log=self.out_log, out_widget=self.out_widget)
self.wf_widgets = [PyironFlowWidget(wf=wf, root_path=root_path, log=self.out_log, out_widget=self.out_widget)
for wf in self.workflows]
self.view_flows = self.view_flows()
self.tree_view = TreeView(root_path=root_path, flow_widget=self.wf_widgets[0], log=self.out_log)
Expand Down
Loading

0 comments on commit 7dc2c10

Please sign in to comment.