diff --git a/nemo/core/neural_types/axes.py b/nemo/core/neural_types/axes.py index dcc2e7736ff6..1b3159815a90 100644 --- a/nemo/core/neural_types/axes.py +++ b/nemo/core/neural_types/axes.py @@ -34,7 +34,7 @@ class AxisKind(AxisKindAbstract): """This Enum represents what does varying axis dimension mean. For example, does this dimension correspond to width, batch, time, etc. The "Dimension" and "Channel" kinds are the same and used to represent - a general axis. + a general axis. "Any" axis will accept any axis kind fed to it. """ Batch = 0 @@ -43,6 +43,7 @@ class AxisKind(AxisKindAbstract): Channel = 2 Width = 3 Height = 4 + Any = 5 def __str__(self): return str(self.name).lower() @@ -61,6 +62,8 @@ def from_str(label): return AxisKind.Width elif _label == "h" or _label == "height": return AxisKind.Height + elif _label == "any": + return AxisKind.Any else: raise ValueError(f"Can't create AxisKind from {label}") diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index 80bda4aa01d9..b36d0c3eba5f 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -165,7 +165,9 @@ def __compare_axes(axes_a, axes_b) -> int: for axis_a, axis_b in zip(axes_a, axes_b): kinds_a[axis_a.kind] = axis_a.size kinds_b[axis_b.kind] = axis_b.size - if ( + if axis_a.kind == AxisKind.Any: + same = True + elif ( axis_a.kind != axis_b.kind or axis_a.is_list != axis_b.is_list or (axis_a.size != axis_b.size and axis_a.size is not None) diff --git a/tests/core/test_neural_types.py b/tests/core/test_neural_types.py index e31fd08941d3..133e747db3fe 100644 --- a/tests/core/test_neural_types.py +++ b/tests/core/test_neural_types.py @@ -176,3 +176,13 @@ def test_unspecified_dimensions(self): t1 = NeuralType(('B', 'T', 'C'), SpectrogramType()) self.assertEqual(t1.compare(t0), NeuralTypeComparisonResult.SAME) self.assertEqual(t0.compare(t1), NeuralTypeComparisonResult.DIM_INCOMPATIBLE) + + def test_any_axis(self): + t0 = NeuralType(('B', 'Any', 'Any'), VoidType()) + t1 = NeuralType(('B', 'Any', 'Any'), SpectrogramType()) + t2 = NeuralType(('B', 'T', 'C'), SpectrogramType()) + self.assertEqual(t0.compare(t1), NeuralTypeComparisonResult.SAME) + self.assertEqual(t0.compare(t2), NeuralTypeComparisonResult.SAME) + self.assertEqual(t1.compare(t2), NeuralTypeComparisonResult.SAME) + self.assertEqual(t2.compare(t1), NeuralTypeComparisonResult.INCOMPATIBLE) + self.assertEqual(t1.compare(t0), NeuralTypeComparisonResult.INCOMPATIBLE)