Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Support union and none type in flyteidl (#401)
Browse files Browse the repository at this point in the history
* add support for Union Scalar

Signed-off-by: Yubo Wang <[email protected]>

* support union type and literals

Signed-off-by: Yubo Wang <[email protected]>

* change union type extraction

Signed-off-by: Yubo Wang <[email protected]>

---------

Signed-off-by: Yubo Wang <[email protected]>
Co-authored-by: Yubo Wang <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored and eapolinario committed May 16, 2023
1 parent c815a9c commit 3284f61
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 3 deletions.
11 changes: 11 additions & 0 deletions clients/go/coreutils/extract_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) {
return scalarValue.Schema.Uri, nil
case *core.Scalar_Generic:
return scalarValue.Generic, nil
case *core.Scalar_StructuredDataset:
return scalarValue.StructuredDataset.Uri, nil
case *core.Scalar_Union:
// extract the value of the union but not the actual union object
extractedVal, err := ExtractFromLiteral(scalarValue.Union.Value)
if err != nil {
return nil, err
}
return extractedVal, nil
case *core.Scalar_NoneType:
return nil, nil
default:
return nil, fmt.Errorf("unsupported literal scalar type %T", scalarValue)
}
Expand Down
63 changes: 62 additions & 1 deletion clients/go/coreutils/extract_literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestFetchLiteral(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, p.GetScalar())
_, err = ExtractFromLiteral(p)
assert.NotNil(t, err)
assert.Nil(t, err)
})

t.Run("Generic", func(t *testing.T) {
Expand Down Expand Up @@ -176,4 +176,65 @@ func TestFetchLiteral(t *testing.T) {
assert.Equal(t, val.Kind, extractedStructValue.Fields[key].Kind)
}
})

t.Run("Structured dataset", func(t *testing.T) {
literalVal := "s3://blah/blah/blah"
var dataSetColumns []*core.StructuredDatasetType_DatasetColumn
dataSetColumns = append(dataSetColumns, &core.StructuredDatasetType_DatasetColumn{
Name: "Price",
LiteralType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_FLOAT,
},
},
})
var literalType = &core.LiteralType{Type: &core.LiteralType_StructuredDatasetType{StructuredDatasetType: &core.StructuredDatasetType{
Columns: dataSetColumns,
Format: "testFormat",
}}}

lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union", func(t *testing.T) {
literalVal := int64(1)
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union with None", func(t *testing.T) {
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, nil)

assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Nil(t, extractedLiteralVal)
})
}
48 changes: 48 additions & 0 deletions clients/go/coreutils/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,28 @@ func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) {
return MakeLiteralForType(typ, nil)
case *core.LiteralType_Schema:
return MakeLiteralForType(typ, nil)
case *core.LiteralType_UnionType:
if len(t.UnionType.Variants) == 0 {
return nil, errors.Errorf("Union type must have at least one variant")
}
// For union types, we just return the default for the first variant
val, err := MakeDefaultLiteralForType(t.UnionType.Variants[0])
if err != nil {
return nil, errors.Errorf("Failed to create default literal for first union type variant [%v]", t.UnionType.Variants[0])
}
res := &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Type: t.UnionType.Variants[0],
Value: val,
},
},
},
},
}
return res, nil
}

return nil, fmt.Errorf("failed to convert to a known Literal. Input Type [%v] not supported", typ.String())
Expand Down Expand Up @@ -565,6 +587,32 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro
}
return MakePrimitiveLiteral(newV)

case *core.LiteralType_UnionType:
// Try different types in the variants, return the first one matched
found := false
for _, subType := range newT.UnionType.Variants {
lv, err := MakeLiteralForType(subType, v)
if err == nil {
l = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Value: lv,
Type: subType,
},
},
},
},
}
found = true
break
}
}
if !found {
return nil, fmt.Errorf("incorrect union value [%s], supported values %+v", v, newT.UnionType.Variants)
}

default:
return nil, fmt.Errorf("unsupported type %s", t.String())
}
Expand Down
87 changes: 87 additions & 0 deletions clients/go/coreutils/literals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,23 @@ func TestMakeDefaultLiteralForType(t *testing.T) {
Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "x"}}}}}}
assert.Equal(t, expected, l)
})

t.Run("union", func(t *testing.T) {
l, err := MakeDefaultLiteralForType(
&core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
},
)
assert.NoError(t, err)
assert.Equal(t, "*core.Union", reflect.TypeOf(l.GetScalar().GetUnion()).String())
})
}

func TestMustMakeDefaultLiteralForType(t *testing.T) {
Expand Down Expand Up @@ -675,4 +692,74 @@ func TestMakeLiteralForType(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "", l.GetScalar().GetPrimitive().GetStringValue())
})

t.Run("Structured Data Set", func(t *testing.T) {
var dataSetColumns []*core.StructuredDatasetType_DatasetColumn
dataSetColumns = append(dataSetColumns, &core.StructuredDatasetType_DatasetColumn{
Name: "Price",
LiteralType: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: core.SimpleType_FLOAT,
},
},
})
var literalType = &core.LiteralType{Type: &core.LiteralType_StructuredDatasetType{StructuredDatasetType: &core.StructuredDatasetType{
Columns: dataSetColumns,
Format: "testFormat",
}}}

expectedLV := &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{
Value: &core.Scalar_StructuredDataset{
StructuredDataset: &core.StructuredDataset{
Uri: "s3://blah/blah/blah",
Metadata: &core.StructuredDatasetMetadata{
StructuredDatasetType: &core.StructuredDatasetType{
Columns: dataSetColumns,
Format: "testFormat",
},
},
},
},
}}}
lv, err := MakeLiteralForType(literalType, "s3://blah/blah/blah")
assert.NoError(t, err)

assert.Equal(t, expectedLV, lv)

expectedVal, err := ExtractFromLiteral(expectedLV)
assert.NoError(t, err)
actualVal, err := ExtractFromLiteral(lv)
assert.NoError(t, err)
assert.Equal(t, expectedVal, actualVal)
})

t.Run("Union", func(t *testing.T) {
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
}
expectedLV := &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_FloatValue{FloatValue: 0.1}}}}}},
},
},
}}}
lv, err := MakeLiteralForType(literalType, float64(0.1))
assert.NoError(t, err)
assert.Equal(t, expectedLV, lv)
expectedVal, err := ExtractFromLiteral(expectedLV)
assert.NoError(t, err)
actualVal, err := ExtractFromLiteral(lv)
assert.NoError(t, err)
assert.Equal(t, expectedVal, actualVal)
})
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ require (
k8s.io/klog/v2 v2.5.0 // indirect
)

// These 2 versions were wrongly published.
// These 2 versions were wrongly published.
retract (
v1.4.0
v1.4.2
v1.4.0
)

0 comments on commit 3284f61

Please sign in to comment.