Skip to content

Commit

Permalink
Execute a test pipeline using Dask
Browse files Browse the repository at this point in the history
Execute dask_pipeline and validate the results. Also, ensure that the
global client has not been modified by the execution as a side effect.
  • Loading branch information
kinghuang committed Aug 31, 2020
1 parent e826cdf commit 1249860
Showing 1 changed file with 54 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
import dask.dataframe as dd
import pytest
from dagster import (
    DagsterInstance,
    InputDefinition,
    ModeDefinition,
    OutputDefinition,
    ResourceDefinition,
    execute_pipeline,
    file_relative_path,
    pipeline,
    solid,
)
from dask.dataframe.utils import assert_eq
from distributed.client import _get_global_client

from dagster_dask import DataFrame, dask_resource


def create_dask_df():
    """Read the ``num.csv`` fixture into a Dask DataFrame.

    The fixture path is resolved with ``file_relative_path`` against this
    module's ``__file__``.

    Returns:
        Whatever ``dd.read_csv`` produces for the fixture path.
    """
    return dd.read_csv(file_relative_path(__file__, "num.csv"))


@solid(
input_defs=[InputDefinition(dagster_type=DataFrame, name="df")],
output_defs=[OutputDefinition(dagster_type=DataFrame, name="df")],
Expand All @@ -18,3 +35,39 @@ def passthrough(_, df):
)
def dask_pipeline():
return passthrough()


def test_dask_pipeline():
    """Execute ``dask_pipeline`` and validate its output.

    The run config requests a local Dask cluster (2 workers, 1 thread each)
    for the ``dask`` resource and feeds the ``passthrough`` solid its ``df``
    input from the ``num.csv`` fixture. After execution the test checks that:

    * the ``distributed`` global client is the very same object as before the
      run, i.e. execution did not replace it as a side effect;
    * the ``passthrough`` solid's output equals the frame returned by
      ``create_dask_df()``.
    """
    run_config = {
        "resources": {
            "dask": {
                "config": {
                    "cluster": {
                        "local": {
                            "n_workers": 2,
                            "threads_per_worker": 1,
                        },
                    },
                },
            },
        },
        "solids": {
            "passthrough": {
                "inputs": {
                    "df": {
                        "read": {
                            "csv": {
                                "path": file_relative_path(__file__, "num.csv"),
                            },
                        },
                    },
                },
            },
        },
    }

    # Capture the global client *before* running so a side effect is detectable.
    global_client = _get_global_client()
    result = execute_pipeline(
        dask_pipeline,
        run_config=run_config,
        instance=DagsterInstance.local_temp(),
    )

    # Identity, not equality: the run must leave the same client object (or
    # the same absence of one) registered as the global client.
    assert global_client is _get_global_client()
    assert assert_eq(result.result_for_solid("passthrough").output_value(), create_dask_df())

0 comments on commit 1249860

Please sign in to comment.