From ad89b3e36401ef1abbe5d7625bd2089d0554f2c7 Mon Sep 17 00:00:00 2001 From: ddalvi Date: Wed, 18 Sep 2024 17:40:27 -0400 Subject: [PATCH] Add tests for disabling default caching var and flag Signed-off-by: ddalvi --- sdk/python/kfp/cli/cli_test.py | 104 +++++++++++++++++------ sdk/python/kfp/cli/compile_.py | 1 + sdk/python/kfp/compiler/compiler_test.py | 53 ++++++++++++ sdk/python/kfp/dsl/base_component.py | 3 +- sdk/python/kfp/dsl/pipeline_context.py | 7 +- 5 files changed, 139 insertions(+), 29 deletions(-) diff --git a/sdk/python/kfp/cli/cli_test.py b/sdk/python/kfp/cli/cli_test.py index 361db73a14e4..03d7d693c11c 100644 --- a/sdk/python/kfp/cli/cli_test.py +++ b/sdk/python/kfp/cli/cli_test.py @@ -27,6 +27,7 @@ from click import testing from kfp.cli import cli from kfp.cli import compile_ +import yaml class TestCliNounAliases(unittest.TestCase): @@ -166,34 +167,87 @@ def test_deprecation_warning(self): res.stdout.decode('utf-8')) -info_dict = cli.cli.to_info_dict(ctx=click.Context(cli.cli)) -commands_dict = { - command: list(body.get('commands', {}).keys()) - for command, body in info_dict['commands'].items() -} -noun_verb_list = [ - (noun, verb) for noun, verbs in commands_dict.items() for verb in verbs -] +class TestKfpDslCompile(unittest.TestCase): + def invoke(self, args): + starting_args = ['dsl', 'compile'] + args = starting_args + args + runner = testing.CliRunner() + return runner.invoke( + cli=cli.cli, args=args, catch_exceptions=False, obj={}) -class TestSmokeTestAllCommandsWithHelp(parameterized.TestCase): + def create_pipeline_file(self): + pipeline_code = b""" +from kfp import dsl + +@dsl.component +def my_component(): + pass + +@dsl.pipeline(name="tiny-pipeline") +def my_pipeline(): + my_component_task = my_component() +""" + temp_pipeline = tempfile.NamedTemporaryFile(suffix='.py', delete=False) + temp_pipeline.write(pipeline_code) + temp_pipeline.flush() + return temp_pipeline + + def load_output_yaml(self, output_file): + with open(output_file, 'r') as f: + return yaml.safe_load(f) + + def test_compile_with_caching_flag_enabled(self): + temp_pipeline = self.create_pipeline_file() + output_file = 'test_output.yaml' + + result = self.invoke( + ['--py', temp_pipeline.name, '--output', output_file]) + self.assertEqual(result.exit_code, 0) - @classmethod - def setUpClass(cls): - cls.runner = testing.CliRunner() - - cls.vals = [('run', 'list')] - - @parameterized.parameters(*noun_verb_list) - def test(self, noun: str, verb: str): - with mock.patch('kfp.cli.cli.client.Client'): - result = self.runner.invoke( - args=[noun, verb, '--help'], - cli=cli.cli, - catch_exceptions=False, - obj={}) - self.assertTrue(result.output.startswith('Usage: ')) - self.assertEqual(result.exit_code, 0) + output_data = self.load_output_yaml(output_file) + self.assertIn('root', output_data) + self.assertIn('tasks', output_data['root']['dag']) + for task in output_data['root']['dag']['tasks'].values(): + self.assertIn('cachingOptions', task) + caching_options = task['cachingOptions'] + self.assertEqual(caching_options.get('enableCache'), True) + + def test_compile_with_caching_flag_disabled(self): + temp_pipeline = self.create_pipeline_file() + output_file = 'test_output.yaml' + + result = self.invoke([ + '--py', temp_pipeline.name, '--output', output_file, + '--disable-execution-caching-by-default' + ]) + self.assertEqual(result.exit_code, 0) + + output_data = self.load_output_yaml(output_file) + self.assertIn('root', output_data) + self.assertIn('tasks', output_data['root']['dag']) + for task in output_data['root']['dag']['tasks'].values(): + self.assertIn('cachingOptions', task) + caching_options = task['cachingOptions'] + self.assertEqual(caching_options, {}) + + def test_compile_with_caching_disabled_env_var(self): + temp_pipeline = self.create_pipeline_file() + output_file = 'test_output.yaml' + + os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true' + result = self.invoke( + ['--py', temp_pipeline.name, '--output', output_file]) + self.assertEqual(result.exit_code, 0) + del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] + + output_data = self.load_output_yaml(output_file) + self.assertIn('root', output_data) + self.assertIn('tasks', output_data['root']['dag']) + for task in output_data['root']['dag']['tasks'].values(): + self.assertIn('cachingOptions', task) + caching_options = task['cachingOptions'] + self.assertEqual(caching_options, {}) if __name__ == '__main__': diff --git a/sdk/python/kfp/cli/compile_.py b/sdk/python/kfp/cli/compile_.py index f4b2c3b4570f..e1fc28a83285 100644 --- a/sdk/python/kfp/cli/compile_.py +++ b/sdk/python/kfp/cli/compile_.py @@ -26,6 +26,7 @@ from kfp.dsl import graph_component from kfp.dsl.pipeline_context import Pipeline + def is_pipeline_func(func: Callable) -> bool: """Checks if a function is a pipeline function. diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 7f0cfd4b98a3..16ebe9e655d9 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -910,6 +910,59 @@ def my_pipeline() -> NamedTuple('Outputs', [ task = print_and_return(text='Hello') +class TestCompilePipelineCaching(unittest.TestCase): + + def test_compile_pipeline_with_caching_enabled(self): + """Test pipeline compilation with caching enabled.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name='tiny-pipeline') + def my_pipeline(): + my_task = my_component() + my_task.set_caching_options(True) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec['cachingOptions'] + + self.assertTrue(caching_options['enableCache']) + + def test_compile_pipeline_with_caching_disabled(self): + """Test pipeline compilation with caching disabled.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name='tiny-pipeline') + def my_pipeline(): + my_task = my_component() + my_task.set_caching_options(False) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec.get('cachingOptions', {}) + + self.assertEqual(caching_options, {}) + + class V2NamespaceAliasTest(unittest.TestCase): """Test that imports of both modules and objects are aliased (e.g. all import path variants work).""" diff --git a/sdk/python/kfp/dsl/base_component.py b/sdk/python/kfp/dsl/base_component.py index 089a11116379..2682321417d1 100644 --- a/sdk/python/kfp/dsl/base_component.py +++ b/sdk/python/kfp/dsl/base_component.py @@ -103,7 +103,8 @@ def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask: args=task_inputs, execute_locally=pipeline_context.Pipeline.get_default_pipeline() is None, - execution_caching_default=pipeline_context.Pipeline.get_execution_caching_default(), + execution_caching_default=pipeline_context.Pipeline + .get_execution_caching_default(), ) @property diff --git a/sdk/python/kfp/dsl/pipeline_context.py b/sdk/python/kfp/dsl/pipeline_context.py index f9a45d8676c1..4d0bbbaa840e 100644 --- a/sdk/python/kfp/dsl/pipeline_context.py +++ b/sdk/python/kfp/dsl/pipeline_context.py @@ -14,6 +14,7 @@ """Definition for Pipeline.""" import functools +import os from typing import Callable, Optional from kfp.dsl import component_factory @@ -21,8 +22,6 @@ from kfp.dsl import tasks_group from kfp.dsl import utils -import os - def pipeline(func: Optional[Callable] = None, *, @@ -107,7 +106,9 @@ def get_default_pipeline(): # or the env var KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT. # align with click's treatment of env vars for boolean flags. # per click doc, "1", "true", "t", "yes", "y", and "on" are all converted to True - _execution_caching_default = not str(os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower() in {"1", "true", "t", "yes", "y", "on"} + _execution_caching_default = not str( + os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower( + ) in {'1', 'true', 't', 'yes', 'y', 'on'} @staticmethod def get_execution_caching_default():