Skip to content

Commit

Permalink
fix: allow variadic variants to evaluate (#55)
Browse files Browse the repository at this point in the history
* fix: allow variadic variants to evaluate

* add some tests
  • Loading branch information
zeroshade authored Sep 18, 2024
1 parent 1a800d9 commit 18c1a41
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 34 deletions.
97 changes: 64 additions & 33 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,43 +36,74 @@ type FunctionVariant interface {
MatchAt(typ types.Type, pos int) (bool, error)
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, actualTypes []types.Type) (types.Type, error) {
func validateType(arg Argument, actual types.Type, idx int, nullHandling NullabilityHandling) (bool, error) {
allNonNull := true
switch p := arg.(type) {
case EnumArg:
if actual != nil {
return allNonNull, fmt.Errorf("%w: arg #%d (%s) should be an enum",
substraitgo.ErrInvalidType, idx, p.Name)
}
case ValueArg:
if actual == nil {
return allNonNull, fmt.Errorf("%w: arg #%d should be of type %s",
substraitgo.ErrInvalidType, idx, p.toTypeString())
}

isNullable := actual.GetNullability() != types.NullabilityRequired
if isNullable {
allNonNull = false
}

if nullHandling == DiscreteNullability {
if t, ok := p.Value.Expr.(*parser.Type); ok {
if isNullable != t.Optional() {
return allNonNull, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
substraitgo.ErrInvalidType, idx)
}
} else {
return allNonNull, substraitgo.ErrNotImplemented
}
}
case TypeArg:
return allNonNull, substraitgo.ErrNotImplemented
}

return allNonNull, nil
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, variadic *VariadicBehavior, actualTypes []types.Type) (types.Type, error) {
if len(paramTypeList) != len(actualTypes) {
return nil, fmt.Errorf("%w: mismatch in number of arguments provided. got %d, expected %d",
substraitgo.ErrInvalidExpr, len(actualTypes), len(paramTypeList))
if variadic == nil {
return nil, fmt.Errorf("%w: mismatch in number of arguments provided. got %d, expected %d",
substraitgo.ErrInvalidExpr, len(actualTypes), len(paramTypeList))
}

if !variadic.IsValidArgumentCount(len(actualTypes) - len(paramTypeList) - 1) {
return nil, fmt.Errorf("%w: mismatch in number of arguments provided, invalid number of variadic params. got %d total",
substraitgo.ErrInvalidExpr, len(actualTypes))
}
}

allNonNull := true
for i, p := range paramTypeList {
switch p := p.(type) {
case EnumArg:
if actualTypes[i] != nil {
return nil, fmt.Errorf("%w: arg #%d (%s) should be an enum",
substraitgo.ErrInvalidType, i, p.Name)
}
case ValueArg:
if actualTypes[i] == nil {
return nil, fmt.Errorf("%w: arg #%d should be of type %s",
substraitgo.ErrInvalidType, i, p.toTypeString())
}

isNullable := actualTypes[i].GetNullability() != types.NullabilityRequired
if isNullable {
allNonNull = false
}
nonNull, err := validateType(p, actualTypes[i], i, nullHandling)
if err != nil {
return nil, err
}
allNonNull = allNonNull && nonNull
}

if nullHandling == DiscreteNullability {
if t, ok := p.Value.Expr.(*parser.Type); ok {
if isNullable != t.Optional() {
return nil, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
substraitgo.ErrInvalidType, i)
}
} else {
return nil, substraitgo.ErrNotImplemented
}
// validate varidic argument consistency
if variadic != nil && len(actualTypes) > len(paramTypeList) && variadic.ParameterConsistency == ConsistentParams {
nparams := len(paramTypeList)
lastParam := paramTypeList[nparams-1]
for i, actual := range actualTypes[nparams:] {
nonNull, err := validateType(lastParam, actual, nparams+i, nullHandling)
if err != nil {
return nil, err
}
case TypeArg:
return nil, substraitgo.ErrNotImplemented
allNonNull = allNonNull && nonNull
}
}

Expand Down Expand Up @@ -267,7 +298,7 @@ func (s *ScalarFunctionVariant) SessionDependent() bool { return s.imp
func (s *ScalarFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *ScalarFunctionVariant) URI() string { return s.uri }
func (s *ScalarFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
}
func (s *ScalarFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand Down Expand Up @@ -375,7 +406,7 @@ func (s *AggregateFunctionVariant) SessionDependent() bool { return s.
func (s *AggregateFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *AggregateFunctionVariant) URI() string { return s.uri }
func (s *AggregateFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
}
func (s *AggregateFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand Down Expand Up @@ -488,7 +519,7 @@ func (s *WindowFunctionVariant) SessionDependent() bool { return s.imp
func (s *WindowFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *WindowFunctionVariant) URI() string { return s.uri }
func (s *WindowFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
}
func (s *WindowFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand Down
49 changes: 48 additions & 1 deletion extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,54 @@ func TestEvaluateTypeExpression(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extensions.EvaluateTypeExpression(tt.nulls, tt.ret, tt.extArgs, tt.args)
result, err := extensions.EvaluateTypeExpression(tt.nulls, tt.ret, tt.extArgs, nil, tt.args)
if tt.err == "" {
assert.NoError(t, err)
assert.Truef(t, tt.expected.Equals(result), "expected: %s\ngot: %s", tt.expected, result)
} else {
assert.EqualError(t, err, tt.err)
}
})
}
}

func TestVariantWithVariadic(t *testing.T) {
var (
p, _ = parser.New()
i64Null, _ = p.ParseString("i64?")
i64NonNull, _ = p.ParseString("i64")
// strNull, _ = p.ParseString("string?")
)

tests := []struct {
name string
nulls extensions.NullabilityHandling
ret parser.TypeExpression
extArgs extensions.ArgumentList
args []types.Type
expected types.Type
variadic extensions.VariadicBehavior
err string
}{
{"basic", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
&types.Int64Type{Nullability: types.NullabilityNullable}},
&types.Int64Type{Nullability: types.NullabilityNullable},
extensions.VariadicBehavior{
Min: 0, ParameterConsistency: extensions.ConsistentParams}, ""},
{"bad arg count", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
&types.Int64Type{Nullability: types.NullabilityNullable}},
nil, extensions.VariadicBehavior{
Min: 2, ParameterConsistency: extensions.ConsistentParams},
"invalid expression: mismatch in number of arguments provided, invalid number of variadic params. got 2 total"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extensions.EvaluateTypeExpression(tt.nulls, tt.ret, tt.extArgs, &tt.variadic, tt.args)
if tt.err == "" {
assert.NoError(t, err)
assert.Truef(t, tt.expected.Equals(result), "expected: %s\ngot: %s", tt.expected, result)
Expand Down

0 comments on commit 18c1a41

Please sign in to comment.