diff --git a/pkg/codegen/codegen.go b/pkg/codegen/codegen.go index 7496e2b..8d70bb4 100644 --- a/pkg/codegen/codegen.go +++ b/pkg/codegen/codegen.go @@ -51,7 +51,7 @@ type Connection struct { Downstream []string } -type HttpOperator struct { +type HttpTask struct { TaskID string ConnectionId string Name string @@ -61,10 +61,20 @@ type HttpOperator struct { Upstream []string } +type PythonTask struct { + TaskID string + Name string + Data interface{} + Downstream []string + Upstream []string +} + type GenData struct { - DagDef Dag - Connections []Connection - Tasks []HttpOperator + DagDef Dag + Connections []Connection + PythonImports []string + Tasks []HttpTask + PythonTask []PythonTask } func CheckDeps(deps []string) bool { @@ -159,6 +169,9 @@ func CreateDagGen(g GenData, directory string) (string, error) { for _, task := range data.Tasks { taskIDMap[task.TaskID] = TransformTaskID(task.TaskID) } + for _, task := range data.PythonTask { + taskIDMap[task.TaskID] = TransformTaskID(task.TaskID) + } t := template.New("dag").Funcs( template.FuncMap{ diff --git a/pkg/codegen/dag.tpl b/pkg/codegen/dag.tpl index dc8e681..becc290 100644 --- a/pkg/codegen/dag.tpl +++ b/pkg/codegen/dag.tpl @@ -61,6 +61,11 @@ def create_http_connection(custom_conn_config, session=None): session.commit() return f'Connection {custom_conn_config["ConnectionID"]} successful!' +# ##################### IMPORT SPECIFIC PYTHON FUNCTIONS ########################## +{{range .PythonImports}} +from custom_functions.{{ . }} import {{ . 
}} +{{ end }} + # ##################### ESTABLISH DB/REDIS CONNECTIONS ####################### {{range $conn := .Connections}} {{ transformTaskID $conn.ConnectionID }} = PythonOperator( @@ -98,6 +103,19 @@ task_id_map = { {{- end}} ){{ end }} +# ##################### CUSTOM PYTHON OPERATOR ########################## +{{range .PythonTask}} +{{ transformTaskID .TaskID }}_data = {{mapToPythonDict .Data}} +{{ transformTaskID .TaskID }} = PythonOperator( + task_id='{{ transformTaskID .TaskID }}', + python_callable={{ .Name }}, + op_args=[ + {{ transformTaskID .TaskID }}_data + ], + dag=dag, +) +{{ end }} + # ##################### DIRECTED ACYLIC GRAPH DEFINITION ########################## {{range $conn := .Connections}} {{- if checkDeps $conn.Downstream }} diff --git a/test/codegen/codegen_test.go b/test/codegen/codegen_test.go index d284381..fddd3d2 100644 --- a/test/codegen/codegen_test.go +++ b/test/codegen/codegen_test.go @@ -30,7 +30,7 @@ func TestCreateDagObject(t *testing.T) { } func TestToMapValid(t *testing.T) { - testStruct := codegen.HttpOperator{ + testStruct := codegen.HttpTask{ TaskID: "test_task_id", ConnectionId: "test_connection_id", Name: "test_name", @@ -52,6 +52,25 @@ func TestToMapValid(t *testing.T) { // TODO: Test the invalid case for ToMap function. I'm not sure how to hit the edge case yet. 
+func TestToMapValidWithPythonTask(t *testing.T) { // ToMap must expose all five PythonTask fields under their Go field names.
+	testStruct := codegen.PythonTask{
+		TaskID:     "test_python_task_id",
+		Name:       "test_name",
+		Data:       map[string]interface{}{"key": "value"},
+		Downstream: []string{"downstream_task_1", "downstream_task_2"},
+		Upstream:   []string{"upstream_task_1"},
+	}
+	res, err := codegen.ToMap(testStruct)
+	if err != nil {
+		t.Fatal(err) // stop here: the assertions below say nothing useful once ToMap has failed
+	}
+	assert.Equal(t, "test_name", res["Name"]) // testify argument order is (t, expected, actual)
+	assert.Equal(t, "test_python_task_id", res["TaskID"])
+	assert.Equal(t, map[string]interface{}{"key": "value"}, res["Data"])
+	assert.ElementsMatch(t, []interface{}{"downstream_task_1", "downstream_task_2"}, res["Downstream"]) // string slices are expected back as []interface{}
+	assert.ElementsMatch(t, []interface{}{"upstream_task_1"}, res["Upstream"])
+}
+
 func TestMapToPythonDict(t *testing.T) { expectedPythonDict := ` {