diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator.go b/flyteadmin/pkg/manager/impl/validation/execution_validator.go index c9f357b525..2d852c5a97 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator.go @@ -100,13 +100,8 @@ func CheckAndFetchInputsForExecution( } executionInputMap[name] = expectedInput.GetDefault() } else { - inputType := validators.LiteralTypeForLiteral(executionInputMap[name]) - err := validators.ValidateLiteralType(inputType) - if err != nil { - return nil, errors.NewInvalidLiteralTypeError(name, err) - } - if !validators.AreTypesCastable(inputType, expectedInput.GetVar().GetType()) { - return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got %s", name, expectedInput.GetVar().GetType(), inputType) + if !validators.IsInstance(executionInputMap[name], expectedInput.GetVar().GetType()) { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got literal %s", name, expectedInput.GetVar().GetType(), executionInputMap[name]) } } } diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go index 89e97370fa..fcca3b0316 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go @@ -17,8 +17,6 @@ import ( var execConfig = testutils.GetApplicationConfigWithDefaultDomains() -const failedToValidateLiteralType = "Failed to validate literal type" - func TestValidateExecEmptyProject(t *testing.T) { request := testutils.GetExecutionRequest() request.Project = "" @@ -154,7 +152,7 @@ func TestValidateExecInputsWrongType(t *testing.T) { lpRequest.GetSpec().GetFixedInputs(), lpRequest.GetSpec().GetDefaultInputs(), ) - utils.AssertEqualWithSanitizedRegex(t, "invalid foo input wrong type. Expected simple:STRING, but got simple:INTEGER", err.Error()) + utils.AssertEqualWithSanitizedRegex(t, "invalid foo input wrong type. Expected simple:STRING, but got literal scalar: {primitive:{integer:1}}", err.Error()) } func TestValidateExecInputsExtraInputs(t *testing.T) { @@ -244,7 +242,7 @@ func TestValidateExecUnknownIDLInputs(t *testing.T) { assert.NotNil(t, err) // Expected error message - assert.Contains(t, err.Error(), failedToValidateLiteralType) + assert.Contains(t, err.Error(), "invalid foo input wrong type. Expected simple:1000, but got literal scalar:{}") } func TestValidExecutionId(t *testing.T) { diff --git a/flyteadmin/pkg/manager/impl/validation/launch_plan_validator.go b/flyteadmin/pkg/manager/impl/validation/launch_plan_validator.go index 0308faceba..0168bb066c 100644 --- a/flyteadmin/pkg/manager/impl/validation/launch_plan_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/launch_plan_validator.go @@ -156,14 +156,9 @@ func checkAndFetchExpectedInputForLaunchPlan( if !ok { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "unexpected fixed_input %s", name) } - inputType := validators.LiteralTypeForLiteral(fixedInput) - err := validators.ValidateLiteralType(inputType) - if err != nil { - return nil, errors.NewInvalidLiteralTypeError(name, err) - } - if !validators.AreTypesCastable(inputType, value.GetType()) { + if !validators.IsInstance(fixedInput, value.GetType()) { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "invalid fixed_input wrong type %s, expected %v, got %v instead", name, value.GetType(), inputType) + "invalid fixed_input wrong type %s, expected %v, got literal %v instead", name, value.GetType(), fixedInput) } } diff --git a/flyteadmin/pkg/manager/impl/validation/launch_plan_validator_test.go b/flyteadmin/pkg/manager/impl/validation/launch_plan_validator_test.go index 8dee3e3cca..5ae9101746 100644 --- a/flyteadmin/pkg/manager/impl/validation/launch_plan_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/launch_plan_validator_test.go @@ -106,7 +106,7 @@ func TestValidateLpDefaultInputsWrongType(t *testing.T) { request.Spec.DefaultInputs.Parameters["foo"].Var.Type = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}} err := ValidateLaunchPlan(context.Background(), request, testutils.GetRepoWithDefaultProject(), lpApplicationConfig, getWorkflowInterface()) - expected := "Type mismatch for Parameter foo in default_inputs has type simple:FLOAT , expected simple:STRING " + expected := "Invalid default value for variable foo in default_inputs - expected type simple:FLOAT, but got literal scalar:{primitive:{string_value:\"foo-value\"}}" utils.AssertEqualWithSanitizedRegex(t, expected, err.Error()) } @@ -207,7 +207,7 @@ func TestGetLpExpectedInvalidFixedInputType(t *testing.T) { request.GetSpec().GetFixedInputs(), request.GetSpec().GetDefaultInputs(), ) - utils.AssertEqualWithSanitizedRegex(t, "invalid fixed_input wrong type bar, expected simple:BINARY , got simple:STRING instead", err.Error()) + utils.AssertEqualWithSanitizedRegex(t, "invalid fixed_input wrong type bar, expected simple:BINARY, got literal scalar: {primitive: {string_value: \"bar-value\"}} instead", err.Error()) assert.Nil(t, actualMap) } @@ -272,7 +272,7 @@ func TestGetLpExpectedInvalidFixedInputWithUnknownIDL(t *testing.T) { assert.NotNil(t, err) // Expected error message - assert.Contains(t, err.Error(), failedToValidateLiteralType) + assert.Contains(t, err.Error(), "invalid fixed_input wrong type foo, expected simple:1000, got literal scalar:{} instead") } func TestGetLpExpectedNoFixedInput(t *testing.T) { diff --git a/flyteadmin/pkg/manager/impl/validation/signal_validator.go b/flyteadmin/pkg/manager/impl/validation/signal_validator.go index 0ba2d3b704..e7ac9a7133 100644 --- a/flyteadmin/pkg/manager/impl/validation/signal_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/signal_validator.go @@ -71,16 +71,11 @@ func ValidateSignalSetRequest(ctx context.Context, db repositoryInterfaces.Repos "failed to validate that signal [%v] exists, err: [%+v]", signalModel.SignalKey, err) } - valueType := propellervalidators.LiteralTypeForLiteral(request.GetValue()) lookupSignal, err := transformers.FromSignalModel(lookupSignalModel) if err != nil { return err } - err = propellervalidators.ValidateLiteralType(valueType) - if err != nil { - return errors.NewInvalidLiteralTypeError("", err) - } - if !propellervalidators.AreTypesCastable(lookupSignal.GetType(), valueType) { + if !propellervalidators.IsInstance(request.GetValue(), lookupSignal.GetType()) { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "requested signal value [%v] is not castable to existing signal type [%v]", request.GetValue(), lookupSignalModel.Type) diff --git a/flyteadmin/pkg/manager/impl/validation/signal_validator_test.go b/flyteadmin/pkg/manager/impl/validation/signal_validator_test.go index c78c2c366b..a9b6267ec6 100644 --- a/flyteadmin/pkg/manager/impl/validation/signal_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/signal_validator_test.go @@ -329,6 +329,6 @@ func TestValidateSignalUpdateRequest(t *testing.T) { assert.NotNil(t, err) // Expected error message - assert.Contains(t, err.Error(), failedToValidateLiteralType) + assert.Contains(t, err.Error(), "requested signal value [scalar:{}] is not castable to existing signal type") }) } diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index 03bc8f963d..5dd73793e4 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -281,16 +281,10 @@ func validateParameterMap(inputMap *core.ParameterMap, fieldName string) error { } defaultValue := defaultInput.GetDefault() if defaultValue != nil { - inputType := validators.LiteralTypeForLiteral(defaultValue) - err := validators.ValidateLiteralType(inputType) - if err != nil { - return errors.NewInvalidLiteralTypeError(name, err) - } - - if !validators.AreTypesCastable(inputType, defaultInput.GetVar().GetType()) { + if !validators.IsInstance(defaultValue, defaultInput.GetVar().GetType()) { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, - "Type mismatch for Parameter %s in %s has type %s, expected %s", name, fieldName, - defaultInput.GetVar().GetType().String(), inputType.String()) + "Invalid default value for variable %s in %s - expected type %s, but got literal %s", + name, fieldName, defaultInput.GetVar().GetType(), defaultValue) } if defaultInput.GetVar().GetType().GetSimple() == core.SimpleType_DATETIME { diff --git a/flyteadmin/pkg/manager/impl/validation/validation_test.go b/flyteadmin/pkg/manager/impl/validation/validation_test.go index 265868789e..f040fe32d2 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation_test.go +++ b/flyteadmin/pkg/manager/impl/validation/validation_test.go @@ -347,7 +347,7 @@ func TestValidateParameterMap(t *testing.T) { err := validateParameterMap(&exampleMap, fieldName) assert.Error(t, err) fmt.Println(err.Error()) - assert.Contains(t, err.Error(), failedToValidateLiteralType) + assert.Contains(t, err.Error(), "Invalid default value for variable foo in test_field_name - expected type simple:1000, but got literal scalar:{}") }) } diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index c129312e44..94679ecdfb 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -33,7 +33,6 @@ require ( go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel/trace v1.24.0 - golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 golang.org/x/sync v0.7.0 golang.org/x/time v0.5.0 google.golang.org/grpc v1.62.1 @@ -142,6 +141,7 @@ require ( go.opentelemetry.io/otel/sdk v1.24.0 // indirect go.opentelemetry.io/proto/otlp v1.1.0 // indirect golang.org/x/crypto v0.25.0 // indirect + golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect golang.org/x/net v0.27.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect golang.org/x/sys v0.22.0 // indirect diff --git a/flytepropeller/pkg/compiler/errors/compiler_errors.go b/flytepropeller/pkg/compiler/errors/compiler_errors.go index b2e3796edd..9d8dd9f935 100755 --- a/flytepropeller/pkg/compiler/errors/compiler_errors.go +++ b/flytepropeller/pkg/compiler/errors/compiler_errors.go @@ -213,6 +213,14 @@ func NewMismatchingTypesErr(nodeID, fromVar, fromType, toType string) *CompileEr ) } +func NewMismatchingInstanceErr(nodeID, toVar, toType, fromVar string) *CompileError { + return newError( + MismatchingTypes, + fmt.Sprintf("Variable [%v] expected to be of type [%v], but got [%v].", toVar, toType, fromVar), + nodeID, + ) +} + func NewMismatchingVariablesErr(nodeID, fromVar, fromType, toVar, toType string) *CompileError { return newError( MismatchingTypes, diff --git a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go index 6d7572e9f5..75b2d42c26 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go @@ -35,14 +35,8 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor continue } - inputType := validators.LiteralTypeForLiteral(inputVal) - err := validators.ValidateLiteralType(inputType) - if err != nil { - errs.Collect(errors.NewInvalidLiteralTypeErr(nodeID, inputVar, err)) - continue - } - if !validators.AreTypesCastable(inputType, v.GetType()) { - errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, common.LiteralTypeToStr(v.GetType()), common.LiteralTypeToStr(inputType))) + if !validators.IsInstance(inputVal, v.GetType()) { + errs.Collect(errors.NewMismatchingInstanceErr(nodeID, inputVar, common.LiteralTypeToStr(v.GetType()), inputVal.String())) continue } diff --git a/flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go b/flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go index d77aafec49..eadea93228 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go @@ -19,7 +19,7 @@ func TestValidateInputs_InvalidLiteralType(t *testing.T) { "input1": { Type: &core.LiteralType{ Type: &core.LiteralType_Simple{ - Simple: 1000, + Simple: core.SimpleType_INTEGER, }, }, }, @@ -42,7 +42,7 @@ func TestValidateInputs_InvalidLiteralType(t *testing.T) { idlNotFound := false var errMsg string for _, err := range errs.Errors().List() { - if err.Code() == "InvalidLiteralType" { + if err.Code() == "MismatchingTypes" { idlNotFound = true errMsg = err.Error() break @@ -50,6 +50,6 @@ func TestValidateInputs_InvalidLiteralType(t *testing.T) { } assert.True(t, idlNotFound, "Expected InvalidLiteralType error was not found in errors") - expectedContainedErrorMsg := "Failed to validate literal type" + expectedContainedErrorMsg := "Variable [input1] expected to be of type " assert.Contains(t, errMsg, expectedContainedErrorMsg) } diff --git a/flytepropeller/pkg/compiler/validators/bindings_test.go b/flytepropeller/pkg/compiler/validators/bindings_test.go index e817ba5d5c..2975b340b9 100644 --- a/flytepropeller/pkg/compiler/validators/bindings_test.go +++ b/flytepropeller/pkg/compiler/validators/bindings_test.go @@ -3,7 +3,6 @@ package validators import ( "testing" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -103,7 +102,11 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, } @@ -132,7 +135,15 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral([]interface{}{5})), + Type: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, }, }, } @@ -227,10 +238,15 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral( - map[string]interface{}{ - "xy": 5, - })), + Type: &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, }, }, } @@ -267,7 +283,11 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, }, @@ -294,7 +314,11 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, } @@ -333,7 +357,20 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(map[string]interface{}{"x": []interface{}{1, 3, 4}})), + //Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(map[string]interface{}{"x": []interface{}{1, 3, 4}})), + Type: &core.LiteralType{ + Type: &core.LiteralType_MapValueType{ + MapValueType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + }, }, }, }, @@ -368,7 +405,11 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(1)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, } @@ -400,10 +441,19 @@ func TestValidateBindings(t *testing.T) { n2.OnGetId().Return("node2") n2.OnGetMetadata().Return(&core.NodeMetadata{Name: "node2"}) n2.OnGetOutputAliases().Return(nil) - literalType := LiteralTypeForLiteral(coreutils.MustMakeLiteral(&structpb.Struct{})) + literalType := &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRUCT, + }, + } literalType.Structure = &core.TypeStructure{} - literalType.Structure.DataclassType = map[string]*core.LiteralType{"x": LiteralTypeForLiteral(coreutils.MustMakeLiteral(1))} - + literalType.Structure.DataclassType = map[string]*core.LiteralType{ + "x": &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + } n2.OnGetInterface().Return(&core.TypedInterface{ Inputs: &core.VariableMap{ Variables: map[string]*core.Variable{}, @@ -446,7 +496,11 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(1)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, } @@ -481,7 +535,11 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, }, @@ -503,7 +561,11 @@ func TestValidateBindings(t *testing.T) { vars := &core.VariableMap{ Variables: map[string]*core.Variable{ "x": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(5)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, } @@ -1066,7 +1128,11 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, }, @@ -1149,7 +1215,11 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(2)), + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, }, }, }, @@ -1239,25 +1309,20 @@ func TestValidateBindings(t *testing.T) { Outputs: &core.VariableMap{ Variables: map[string]*core.Variable{ "n2_out": { - Type: LiteralTypeForLiteral(&core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Union{ - Union: &core.Union{ - Value: coreutils.MustMakeLiteral(5), - Type: &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_INTEGER, - }, - Structure: &core.TypeStructure{ - Tag: "int1", - }, + Type: &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + { + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + Structure: &core.TypeStructure{ + Tag: "int1", }, }, }, }, }, - }), + }, }, }, }, diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index e8f0089c14..bf9047935c 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -2,10 +2,9 @@ package validators import ( "fmt" + "strings" "github.com/golang/protobuf/proto" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" "k8s.io/apimachinery/pkg/util/sets" "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" @@ -170,125 +169,244 @@ func UnionDistinctVariableMaps(m1, m2 map[string]*core.Variable) (map[string]*co return res, nil } -func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType { - var variants []*core.LiteralType - isNested := false +// ValidateLiteralType check if the literal type is valid, return error if the literal is invalid. +func ValidateLiteralType(lt *core.LiteralType) error { + if lt == nil { + err := fmt.Errorf("got unknown literal type: [%v].\n"+ + "Suggested solution: Please update all your Flyte deployment images to the latest version and try again", lt) + return err + } + if lt.GetCollectionType() != nil { + return ValidateLiteralType(lt.GetCollectionType()) + } + if lt.GetMapValueType() != nil { + return ValidateLiteralType(lt.GetMapValueType()) + } - for _, x := range innerType { - unionType := x.GetCollectionType().GetUnionType() - if unionType != nil { - isNested = true - variants = append(variants, unionType.GetVariants()...) - } else { - variants = append(variants, x) + return nil +} + +type instanceChecker interface { + isInstance(*core.Literal) bool +} + +type trivialInstanceChecker struct { + literalType *core.LiteralType +} + +func (t trivialInstanceChecker) isInstance(lit *core.Literal) bool { + if _, ok := lit.GetValue().(*core.Literal_Scalar); !ok { + return false + } + targetType := t.literalType + if targetType.GetEnumType() != nil { + // If t is an enum, it can be created from a string as Enums as just constrained String aliases + if _, ok := lit.GetScalar().GetPrimitive().GetValue().(*core.Primitive_StringValue); ok { + return true } } - unionLiteralType := &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: variants, - }, - }, + + literalType := literalTypeForScalar(lit.GetScalar()) + err := ValidateLiteralType(literalType) + if err != nil { + return false + } + return AreTypesCastable(literalType, targetType) +} + +type noneInstanceChecker struct{} + +func (t noneInstanceChecker) isInstance(lit *core.Literal) bool { + if lit == nil { + return true } + _, ok := lit.GetScalar().GetValue().(*core.Scalar_NoneType) + return ok +} - if isNested { - return &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: unionLiteralType, - }, +type collectionInstanceChecker struct { + literalType *core.LiteralType +} + +func (t collectionInstanceChecker) isInstance(lit *core.Literal) bool { + if _, ok := lit.GetValue().(*core.Literal_Collection); !ok { + return false + } + for _, x := range lit.GetCollection().GetLiterals() { + if !IsInstance(x, t.literalType.GetCollectionType()) { + return false } } + return true +} - return unionLiteralType +type mapInstanceChecker struct { + literalType *core.LiteralType } -func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { - innerType := make([]*core.LiteralType, 0, 1) - innerTypeSet := sets.NewString() - var noneType *core.LiteralType - for _, x := range literals { - otherType := LiteralTypeForLiteral(x) - otherTypeKey := otherType.String() - if _, ok := x.GetValue().(*core.Literal_Collection); ok { - if x.GetCollection().GetLiterals() == nil { - noneType = otherType - continue - } +func (t mapInstanceChecker) isInstance(lit *core.Literal) bool { + if _, ok := lit.GetValue().(*core.Literal_Map); !ok { + return false + } + for _, x := range lit.GetMap().GetLiterals() { + if !IsInstance(x, t.literalType.GetMapValueType()) { + return false } + } + return true +} + +type blobInstanceChecker struct { + literalType *core.LiteralType +} + +func (t blobInstanceChecker) isInstance(lit *core.Literal) bool { + if _, ok := lit.GetScalar().GetValue().(*core.Scalar_Blob); !ok { + return false + } + + blobType := lit.GetScalar().GetBlob().GetMetadata().GetType() + if blobType == nil { + return false + } + + // Empty blobs should match any blob. + if blobType.GetFormat() == "" || t.literalType.GetBlob().GetFormat() == "" { + return true + } + + return blobType.GetFormat() == t.literalType.GetBlob().GetFormat() +} + +type schemaInstanceChecker struct { + literalType *core.LiteralType +} + +func (t schemaInstanceChecker) isInstance(lit *core.Literal) bool { + if _, ok := lit.GetValue().(*core.Literal_Scalar); !ok { + return false + } + scalar := lit.GetScalar() - if !innerTypeSet.Has(otherTypeKey) { - innerType = append(innerType, otherType) - innerTypeSet.Insert(otherTypeKey) + switch v := scalar.GetValue().(type) { + case *core.Scalar_Schema: + return schemaCastFromSchema(scalar.GetSchema().GetType(), t.literalType.GetSchema()) + case *core.Scalar_StructuredDataset: + if v.StructuredDataset == nil || v.StructuredDataset.GetMetadata() == nil { + return true } + return schemaCastFromStructuredDataset(scalar.GetStructuredDataset().GetMetadata().GetStructuredDatasetType(), t.literalType.GetSchema()) + default: + return false } +} + +type structuredDatasetInstanceChecker struct { + literalType *core.LiteralType +} - // only add none type if there aren't other types - if len(innerType) == 0 && noneType != nil { - innerType = append(innerType, noneType) +func (t structuredDatasetInstanceChecker) isInstance(lit *core.Literal) bool { + if _, ok := lit.GetValue().(*core.Literal_Scalar); !ok { + return false } + scalar := lit.GetScalar() - if len(innerType) == 0 { - return &core.LiteralType{ - Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, + switch v := scalar.GetValue().(type) { + case *core.Scalar_NoneType: + return true + case *core.Scalar_Schema: + // Flyte Schema can only be serialized to parquet + format := t.literalType.GetStructuredDatasetType().GetFormat() + if len(format) != 0 && !strings.EqualFold(format, "parquet") { + return false } - } else if len(innerType) == 1 { - return innerType[0] - } - - // sort inner types to ensure consistent union types are generated - slices.SortFunc(innerType, func(a, b *core.LiteralType) int { - aStr := a.String() - bStr := b.String() - if aStr < bStr { - return -1 - } else if aStr > bStr { - return 1 + return structuredDatasetCastFromSchema(scalar.GetSchema().GetType(), t.literalType.GetStructuredDatasetType()) + case *core.Scalar_StructuredDataset: + if v.StructuredDataset == nil || v.StructuredDataset.GetMetadata() == nil { + return true } + return structuredDatasetCastFromStructuredDataset(scalar.GetStructuredDataset().GetMetadata().GetStructuredDatasetType(), t.literalType.GetStructuredDatasetType()) + default: + return false + } +} - return 0 - }) - return buildMultipleTypeUnion(innerType) +type unionInstanceChecker struct { + literalType *core.LiteralType } -// ValidateLiteralType check if the literal type is valid, return error if the literal is invalid. -func ValidateLiteralType(lt *core.LiteralType) error { - if lt == nil { - err := fmt.Errorf("got unknown literal type: [%v].\n"+ - "Suggested solution: Please update all your Flyte deployment images to the latest version and try again", lt) - return err - } - if lt.GetCollectionType() != nil { - return ValidateLiteralType(lt.GetCollectionType()) +func (t unionInstanceChecker) isInstance(lit *core.Literal) bool { + unionType := t.literalType.GetUnionType() + + if u := lit.GetScalar().GetUnion().GetType(); u != nil { + found := false + for _, d := range unionType.GetVariants() { + if AreTypesCastable(u, d) { + found = true + break + } + } + return found } - if lt.GetMapValueType() != nil { - return ValidateLiteralType(lt.GetMapValueType()) + + // Matches iff we can unambiguously select a variant + foundOne := false + for _, x := range unionType.GetVariants() { + if IsInstance(lit, x) { + if foundOne { + return false + } + foundOne = true + } } - return nil + return foundOne } -// LiteralTypeForLiteral gets LiteralType for literal, nil if the value of literal is unknown, or type collection/map of -// type None if the literal is a non-homogeneous type. -func LiteralTypeForLiteral(l *core.Literal) *core.LiteralType { - switch l.GetValue().(type) { - case *core.Literal_Scalar: - return literalTypeForScalar(l.GetScalar()) - case *core.Literal_Collection: - return &core.LiteralType{ - Type: &core.LiteralType_CollectionType{ - CollectionType: literalTypeForLiterals(l.GetCollection().GetLiterals()), - }, +func getInstanceChecker(t *core.LiteralType) instanceChecker { + switch t.GetType().(type) { + case *core.LiteralType_CollectionType: + return collectionInstanceChecker{ + literalType: t, } - case *core.Literal_Map: - return &core.LiteralType{ - Type: &core.LiteralType_MapValueType{ - MapValueType: literalTypeForLiterals(maps.Values(l.GetMap().GetLiterals())), - }, + case *core.LiteralType_MapValueType: + return mapInstanceChecker{ + literalType: t, + } + case *core.LiteralType_Blob: + return blobInstanceChecker{ + literalType: t, + } + case *core.LiteralType_Schema: + return schemaInstanceChecker{ + literalType: t, + } + case *core.LiteralType_UnionType: + return unionInstanceChecker{ + literalType: t, + } + case *core.LiteralType_StructuredDatasetType: + return structuredDatasetInstanceChecker{ + literalType: t, + } + default: + if isNoneType(t) { + return noneInstanceChecker{} + } + + return trivialInstanceChecker{ + literalType: t, } - case *core.Literal_OffloadedMetadata: - return l.GetOffloadedMetadata().GetInferredType() } - return nil +} + +func IsInstance(lit *core.Literal, t *core.LiteralType) bool { + instanceChecker := getInstanceChecker(t) + + if lit.GetOffloadedMetadata() != nil { + return AreTypesCastable(lit.GetOffloadedMetadata().GetInferredType(), t) + } + return instanceChecker.isInstance(lit) } func GetTagForType(x *core.LiteralType) string { diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index 0238c652e2..412c0d3356 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -3,7 +3,6 @@ package validators import ( "testing" - "github.com/golang/protobuf/proto" "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" @@ -11,10 +10,13 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" ) -func TestLiteralTypeForLiterals(t *testing.T) { +func TestIsInstance(t *testing.T) { t.Run("empty", func(t *testing.T) { - lt := literalTypeForLiterals(nil) - assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String()) + assert.True(t, IsInstance(nil, &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + })) }) t.Run("binary idl with raw binary data and no tag", func(t *testing.T) { @@ -33,8 +35,11 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, }, } - lt := LiteralTypeForLiteral(lv) - assert.Equal(t, core.SimpleType_BINARY.String(), lt.GetSimple().String()) + assert.True(t, IsInstance(lv, &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BINARY, + }, + })) }) t.Run("binary idl with messagepack input map[int]strings", func(t *testing.T) { @@ -61,8 +66,11 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, }, } - lt := LiteralTypeForLiteral(lv) - assert.Equal(t, core.SimpleType_STRUCT.String(), lt.GetSimple().String()) + assert.True(t, IsInstance(lv, &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRUCT, + }, + })) }) t.Run("binary idl with messagepack input map[float]strings", func(t *testing.T) { @@ -89,45 +97,11 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, }, } - lt := LiteralTypeForLiteral(lv) - assert.Equal(t, core.SimpleType_STRUCT.String(), lt.GetSimple().String()) - }) - - t.Run("homogeneous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ - coreutils.MustMakeLiteral(5), - coreutils.MustMakeLiteral(0), - coreutils.MustMakeLiteral(5), - }) - - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetSimple().String()) - }) - - t.Run("non-homogenous", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ - coreutils.MustMakeLiteral("hello"), - coreutils.MustMakeLiteral(5), - coreutils.MustMakeLiteral("world"), - coreutils.MustMakeLiteral(0), - coreutils.MustMakeLiteral(2), - }) - - assert.Len(t, lt.GetUnionType().GetVariants(), 2) - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().GetVariants()[0].GetSimple().String()) - assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().GetVariants()[1].GetSimple().String()) - }) - - t.Run("non-homogenous ensure ordering", func(t *testing.T) { - lt := literalTypeForLiterals([]*core.Literal{ - coreutils.MustMakeLiteral(5), - coreutils.MustMakeLiteral("world"), - coreutils.MustMakeLiteral(0), - coreutils.MustMakeLiteral(2), - }) - - assert.Len(t, lt.GetUnionType().GetVariants(), 2) - assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().GetVariants()[0].GetSimple().String()) - assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().GetVariants()[1].GetSimple().String()) + assert.True(t, IsInstance(lv, &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRUCT, + }, + })) }) t.Run("list with mixed types", func(t *testing.T) { @@ -196,8 +170,6 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, } - lt := LiteralTypeForLiteral(literals) - expectedLt := &core.LiteralType{ Type: &core.LiteralType_CollectionType{ CollectionType: &core.LiteralType{ @@ -237,7 +209,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, } - assert.True(t, proto.Equal(expectedLt, lt)) + assert.True(t, IsInstance(literals, expectedLt)) }) t.Run("nested lists with empty list", func(t *testing.T) { @@ -276,8 +248,6 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, } - lt := LiteralTypeForLiteral(literals) - expectedLt := &core.LiteralType{ Type: &core.LiteralType_CollectionType{ CollectionType: &core.LiteralType{ @@ -292,7 +262,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, } - assert.True(t, proto.Equal(expectedLt, lt)) + assert.True(t, IsInstance(literals, expectedLt)) }) t.Run("nested Lists with different types", func(t *testing.T) { @@ -374,9 +344,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, } - lt := LiteralTypeForLiteral(literals) - - assert.True(t, proto.Equal(expectedLt, lt)) + assert.True(t, IsInstance(literals, expectedLt)) }) t.Run("empty nested listed", func(t *testing.T) { @@ -408,9 +376,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, } - lt := LiteralTypeForLiteral(literals) - - assert.True(t, proto.Equal(expectedLt, lt)) + assert.True(t, IsInstance(literals, expectedLt)) }) t.Run("nested Lists with different types", func(t *testing.T) { @@ -450,8 +416,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { }, } expectedLt := inferredType - lt := LiteralTypeForLiteral(literals) - assert.True(t, proto.Equal(expectedLt, lt)) + assert.True(t, IsInstance(literals, expectedLt)) }) } diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index 51d3105a0a..6418c35270 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -16,7 +16,6 @@ import ( "github.com/flyteorg/flyte/flytepropeller/events" eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/common" @@ -200,15 +199,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu size := -1 - for key, variable := range literalMap.GetLiterals() { - literalType := validators.LiteralTypeForLiteral(variable) - err := validators.ValidateLiteralType(literalType) - if err != nil { - errMsg := fmt.Sprintf("Failed to validate literal type for [%s] with err: %s", key, err) - return handler.DoTransition(handler.TransitionTypeEphemeral, - handler.PhaseInfoFailure(idlcore.ExecutionError_USER, errors.IDLNotFoundErr, errMsg, nil), - ), nil - } + for _, variable := range literalMap.GetLiterals() { if variable.GetOffloadedMetadata() != nil { // variable will be overwritten with the contents of the offloaded data which contains the actual large literal. // We need this for the map task to be able to create the subNodeSpec @@ -219,8 +210,8 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu ), nil } } - switch literalType.GetType().(type) { - case *idlcore.LiteralType_CollectionType: + switch variable.GetValue().(type) { + case *idlcore.Literal_Collection: collectionLength := len(variable.GetCollection().GetLiterals()) if size == -1 { size = collectionLength @@ -479,6 +470,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu )), nil case v1alpha1.ArrayNodePhaseSucceeding: gatherOutputsRequests := make([]*gatherOutputsRequest, 0, len(arrayNodeState.SubNodePhases.GetItems())) + outputLiteralTypes := make(map[string]*idlcore.LiteralType) for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() { nodePhase := v1alpha1.NodePhase(nodePhaseUint64) // #nosec G115 gatherOutputsRequest := &gatherOutputsRequest{ @@ -505,6 +497,12 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu if task.CoreTask() != nil && task.CoreTask().GetInterface() != nil && task.CoreTask().GetInterface().GetOutputs() != nil { for name := range task.CoreTask().GetInterface().GetOutputs().GetVariables() { outputLiterals[name] = nilLiteral + // Extract the literal type from the task interface + outputLiteralTypes[name] = &idlcore.LiteralType{ + Type: &idlcore.LiteralType_CollectionType{ + CollectionType: task.CoreTask().GetInterface().GetOutputs().GetVariables()[name].GetType(), + }, + } } } @@ -595,7 +593,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // use the OffloadLargeLiteralKey to create {OffloadLargeLiteralKey}_offloaded_metadata.pb file in the datastore. // Update the url in the outputLiteral with the offloaded url and also update the size of the literal. offloadedOutputFile := v1alpha1.GetOutputsLiteralMetadataFile(outputLiteralKey, nCtx.NodeStatus().GetOutputDir()) - if err := common.OffloadLargeLiteral(ctx, nCtx.DataStore(), offloadedOutputFile, outputLiteral, a.literalOffloadingConfig); err != nil { + if err := common.OffloadLargeLiteral(ctx, nCtx.DataStore(), offloadedOutputFile, outputLiteral, outputLiteralTypes[outputLiteralKey], a.literalOffloadingConfig); err != nil { return handler.UnknownTransition, err } } diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go index ac0e4b45ad..91e6533b5e 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -1023,8 +1023,8 @@ func TestHandle_InvalidLiteralType(t *testing.T) { }, expectedTransitionType: handler.TransitionTypeEphemeral, expectedPhase: handler.EPhaseFailed, - expectedErrorCode: errors.IDLNotFoundErr, - expectedContainedErrorMsg: "Failed to validate literal type", + expectedErrorCode: errors.InvalidArrayLength, + expectedContainedErrorMsg: "no input array provided", }, } diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go index 403d1a6885..9ff494e54f 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer.go @@ -53,13 +53,8 @@ func GenerateTaskOutputsFromArtifact(id core.Identifier, taskInterface core.Type } expectedVarType := outputVariables[artifactData.GetName()].GetType() - inputType := validators.LiteralTypeForLiteral(artifactData.GetValue()) - err := validators.ValidateLiteralType(inputType) - if err != nil { - return nil, fmt.Errorf("failed to validate literal type for %s with err: %s", artifactData.GetName(), err) - } - if !validators.AreTypesCastable(inputType, expectedVarType) { - return nil, fmt.Errorf("unexpected artifactData: [%v] type: [%v] does not match any task output type: [%v]", artifactData.GetName(), inputType, expectedVarType) + if !validators.IsInstance(artifactData.GetValue(), expectedVarType) { + return nil, fmt.Errorf("unexpected artifactData: [%v] val: [%v] does not match any task output type: [%v]", artifactData.GetName(), artifactData.GetValue(), expectedVarType) } outputs[artifactData.GetName()] = artifactData.GetValue() diff --git a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go index 6fd6455e02..c1d27f5891 100644 --- a/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go +++ b/flytepropeller/pkg/controller/nodes/catalog/datacatalog/transformer_test.go @@ -360,14 +360,14 @@ func TestGenerateTaskOutputsFromArtifact_IDLNotFound(t *testing.T) { Data: []*datacatalog.ArtifactData{ { Name: "output1", - Value: &core.Literal{}, // This will cause LiteralTypeForLiteral to return nil + Value: &core.Literal{}, }, }, } _, err := GenerateTaskOutputsFromArtifact(taskID, taskInterface, artifact) - expectedContainedErrorMsg := "failed to validate literal type" + expectedContainedErrorMsg := "unexpected artifactData: [output1] val: [] does not match any task output type" assert.Error(t, err) assert.Contains(t, err.Error(), expectedContainedErrorMsg) } diff --git a/flytepropeller/pkg/controller/nodes/common/utils.go b/flytepropeller/pkg/controller/nodes/common/utils.go index 839be0c99f..6bb7ee554b 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils.go +++ b/flytepropeller/pkg/controller/nodes/common/utils.go @@ -12,7 +12,6 @@ import ( idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/encoding" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/handler" @@ -104,7 +103,7 @@ func ReadLargeLiteral(ctx context.Context, datastore *storage.DataStore, // OffloadLargeLiteral offloads the large literal if meets the threshold conditions func OffloadLargeLiteral(ctx context.Context, datastore *storage.DataStore, dataReference storage.DataReference, - toBeOffloaded *idlcore.Literal, literalOffloadingConfig config.LiteralOffloadingConfig) error { + toBeOffloaded *idlcore.Literal, inferredType *idlcore.LiteralType, literalOffloadingConfig config.LiteralOffloadingConfig) error { literalSizeBytes := int64(proto.Size(toBeOffloaded)) literalSizeMB := literalSizeBytes / MB // check if the literal is large @@ -118,7 +117,6 @@ func OffloadLargeLiteral(ctx context.Context, datastore *storage.DataStore, data return nil } - inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) if inferredType == nil { errString := "Failed to determine literal type for offloaded literal" logger.Errorf(ctx, errString) diff --git a/flytepropeller/pkg/controller/nodes/common/utils_test.go b/flytepropeller/pkg/controller/nodes/common/utils_test.go index bde50c8040..ac1ca45bbd 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils_test.go +++ b/flytepropeller/pkg/controller/nodes/common/utils_test.go @@ -10,7 +10,6 @@ import ( idlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" - "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" executorMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors/mocks" nodeMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces/mocks" @@ -142,8 +141,16 @@ func TestOffloadLargeLiteral(t *testing.T) { MinSizeInMBForOffloading: 0, MaxSizeInMBForOffloading: 1, } - inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) - err = OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + inferredType := &idlCore.LiteralType{ + Type: &idlCore.LiteralType_CollectionType{ + CollectionType: &idlCore.LiteralType{ + Type: &idlCore.LiteralType_Simple{ + Simple: idlCore.SimpleType_INTEGER, + }, + }, + }, + } + err = OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, inferredType, literalOffloadingConfig) assert.NoError(t, err) assert.Equal(t, "foo/bar", toBeOffloaded.GetOffloadedMetadata().GetUri()) assert.Equal(t, uint64(6), toBeOffloaded.GetOffloadedMetadata().GetSizeBytes()) @@ -173,7 +180,16 @@ func TestOffloadLargeLiteral(t *testing.T) { MinSizeInMBForOffloading: 0, MaxSizeInMBForOffloading: 1, } - err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + inferredType := &idlCore.LiteralType{ + Type: &idlCore.LiteralType_CollectionType{ + CollectionType: &idlCore.LiteralType{ + Type: &idlCore.LiteralType_Simple{ + Simple: idlCore.SimpleType_INTEGER, + }, + }, + }, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, inferredType, literalOffloadingConfig) assert.NoError(t, err) assert.Equal(t, "hash", toBeOffloaded.GetHash()) }) @@ -199,7 +215,16 @@ func TestOffloadLargeLiteral(t *testing.T) { MinSizeInMBForOffloading: 0, MaxSizeInMBForOffloading: 0, } - err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + inferredType := &idlCore.LiteralType{ + Type: &idlCore.LiteralType_CollectionType{ + CollectionType: &idlCore.LiteralType{ + Type: &idlCore.LiteralType_Simple{ + Simple: idlCore.SimpleType_INTEGER, + }, + }, + }, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, inferredType, literalOffloadingConfig) assert.Error(t, err) }) @@ -224,7 +249,16 @@ func TestOffloadLargeLiteral(t *testing.T) { MinSizeInMBForOffloading: 2, MaxSizeInMBForOffloading: 3, } - err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + inferredType := &idlCore.LiteralType{ + Type: &idlCore.LiteralType_CollectionType{ + CollectionType: &idlCore.LiteralType{ + Type: &idlCore.LiteralType_Simple{ + Simple: idlCore.SimpleType_INTEGER, + }, + }, + }, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, inferredType, literalOffloadingConfig) assert.NoError(t, err) assert.Nil(t, toBeOffloaded.GetOffloadedMetadata()) })