diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index 742270a9ae..e9a7909809 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -15,6 +15,7 @@ from flytekit.loggers import logger TIME_PARTITION_KWARG = "time_partition" +MAX_PARTITIONS = 10 class InputsBase(object): @@ -337,6 +338,10 @@ def __init__( self._partitions = Partitions(p) self._partitions.set_reference_artifact(self) + # Enforce the MAX_PARTITIONS cap on the artifact's partition keys; the message is derived from the constant so the two cannot drift apart. + if self.partition_keys and len(self.partition_keys) > MAX_PARTITIONS: + raise ValueError(f"There is a hard limit of {MAX_PARTITIONS} partition keys per artifact currently.") + def __call__(self, *args, **kwargs) -> ArtifactIDSpecification: """ This __call__ should only ever happen in the context of a task or workflow's output, to be diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index ea3734f8aa..1580832426 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -583,3 +583,9 @@ def test_tp_math(): assert tp2.other == datetime.timedelta(days=1) assert tp2.granularity == Granularity.HOUR assert tp2 is not tp + + +def test_lims(): + # test an artifact with 11 partition keys + with pytest.raises(ValueError): + Artifact(name="test artifact", time_partitioned=True, partition_keys=[f"key_{i}" for i in range(11)])