Skip to content

Commit

Permalink
Prints a warning if a component A's output channel is used as another…
Browse files Browse the repository at this point in the history
… component B's input, but component A is not included in the components of the pipeline.

PiperOrigin-RevId: 436831915
  • Loading branch information
tfx-copybara committed Mar 23, 2022
1 parent 2056a05 commit ae1ff73
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tfx/orchestration/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,11 @@ def _set_components(self, components: List[base_node.BaseNode]) -> None:
if upstream_node:
component.add_upstream_node(upstream_node)
upstream_node.add_downstream_node(component)
else:
warnings.warn(
f'Node {component.id} depends on the output of node {node_id}'
f', but {node_id} is not included in the components of '
'pipeline. Did you forget to add it?')

layers = topsort.topsorted_layers(
list(deduped_components),
Expand Down
24 changes: 24 additions & 0 deletions tfx/orchestration/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,30 @@ def testPipelineWithNode(self):
metadata_connection_config=self._metadata_connection_config)
self.assertEqual(1, len(my_pipeline.components))

def testPipelineWarnMissingNode(self):
channel_one = types.Channel(type=_ArtifactTypeOne)
channel_two = types.Channel(type=_ArtifactTypeTwo)
component_a = _make_fake_component_instance('component_a', _OutputTypeA,
{'a': channel_one}, {})
component_b = _make_fake_component_instance(
name='component_b',
output_type=_OutputTypeB,
inputs={'a': component_a.outputs['output']},
outputs={'b': channel_two})

warn_text = (
'Node component_b depends on the output of node component_a, '
'but component_a is not included in the components of pipeline. '
'Did you forget to add it?')
with self.assertWarnsRegex(UserWarning, warn_text):
pipeline.Pipeline(
pipeline_name='name',
pipeline_root='root',
components=[
component_b,
],
metadata_connection_config=self._metadata_connection_config)

def testPipelineWithLoop(self):
channel_one = types.Channel(type=_ArtifactTypeOne)
channel_two = types.Channel(type=_ArtifactTypeTwo)
Expand Down

0 comments on commit ae1ff73

Please sign in to comment.