Skip to content

Commit

Permalink
feat: Add Min/Max ArgumentCount API to functionvariant (#57)
Browse files Browse the repository at this point in the history
* Add Min/Max ArgumentCount API to functionvariant
  • Loading branch information
anshuldata authored Oct 1, 2024
1 parent db135fe commit 5d40def
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 0 deletions.
41 changes: 41 additions & 0 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ type FunctionVariant interface {
// argument nullability is not correctly set this function will return error
// returns (false, nil) valid input argument type and argument can't type replace parameter at argPos
MatchAt(typ types.Type, pos int) (bool, error)
// MinArgumentCount returns minimum number of arguments required for this function
MinArgumentCount() int
// MaxArgumentCount returns minimum number of arguments accepted by this function
MaxArgumentCount() int
}

func validateType(arg Argument, actual types.Type, idx int, nullHandling NullabilityHandling) (bool, error) {
Expand Down Expand Up @@ -249,6 +253,20 @@ func parseFuncName(compoundName string) (name string, args ArgumentList) {
return name, args
}

func minArgumentCount(paramTypeList ArgumentList, variadicBehavior *VariadicBehavior) int {
if variadicBehavior == nil {
return len(paramTypeList)
}
return len(paramTypeList) + variadicBehavior.Min
}

func maxArgumentCount(paramTypeList ArgumentList, variadicBehavior *VariadicBehavior) int {
if variadicBehavior == nil {
return len(paramTypeList)
}
return len(paramTypeList) + variadicBehavior.Max
}

// NewScalarFuncVariant constructs a variant with the provided name and uri
// and uses the defaults for everything else.
//
Expand Down Expand Up @@ -315,6 +333,14 @@ func (s *ScalarFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args, s.impl.Variadic)
}

func (s *ScalarFunctionVariant) MinArgumentCount() int {
return minArgumentCount(s.impl.Args, s.impl.Variadic)
}

func (s *ScalarFunctionVariant) MaxArgumentCount() int {
return maxArgumentCount(s.impl.Args, s.impl.Variadic)
}

// NewAggFuncVariant constructs a variant with the provided name and uri
// and uses the defaults for everything else.
//
Expand Down Expand Up @@ -429,6 +455,13 @@ func (s *AggregateFunctionVariant) Match(argumentTypes []types.Type) (bool, erro
func (s *AggregateFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args, s.impl.Variadic)
}
func (s *AggregateFunctionVariant) MinArgumentCount() int {
return minArgumentCount(s.impl.Args, s.impl.Variadic)
}

func (s *AggregateFunctionVariant) MaxArgumentCount() int {
return maxArgumentCount(s.impl.Args, s.impl.Variadic)
}

type WindowFunctionVariant struct {
name string
Expand Down Expand Up @@ -544,6 +577,14 @@ func (s *WindowFunctionVariant) MatchAt(typ types.Type, pos int) (bool, error) {
return matchArgumentAt(typ, pos, s.Nullability(), s.impl.Args, s.impl.Variadic)
}

func (s *WindowFunctionVariant) MinArgumentCount() int {
return minArgumentCount(s.impl.Args, s.impl.Variadic)
}

func (s *WindowFunctionVariant) MaxArgumentCount() int {
return maxArgumentCount(s.impl.Args, s.impl.Variadic)
}

// HasSyncParams This API returns if params share a leaf param name
func HasSyncParams(params []types.FuncDefArgType) bool {
// if any of the leaf parameters are same, it indicates parameters are same across parameters
Expand Down
135 changes: 135 additions & 0 deletions functions/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ scalar_functions:
match, err := fv[0].Match(tt.argTypes)
require.NoError(t, err)
require.True(t, match)
// non-variadic function, min/max argument count should be 2
require.Equal(t, 2, fv[0].MinArgumentCount())
require.Equal(t, 2, fv[0].MaxArgumentCount())

// test MatchAt
for pos, typ := range tt.argTypes {
Expand Down Expand Up @@ -953,6 +956,8 @@ scalar_functions:
// pass third argument as variadic, it should match against last argument type
argTypes := []types.Type{int64Nullable, int32Nullable, int32Nullable}
require.Len(t, fv, 1)
require.Equal(t, 3, fv[0].MinArgumentCount())
require.Equal(t, 5, fv[0].MaxArgumentCount())
match, err := fv[0].Match(argTypes)
require.NoError(t, err)
assert.True(t, match)
Expand Down Expand Up @@ -1141,3 +1146,133 @@ scalar_functions:
// even though function argument allows decimal(P, S)
assert.False(t, match)
}

func TestAggregateFuncMinMax(t *testing.T) {
const uri = "http://localhost/sample.yaml"
const defYaml = `---
aggregate_functions:
-
name: "func_nonvariadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
return: i32
-
name: "func_variadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
variadic:
min: 1
max: 3
return: i32
`

dialectYaml := `
name: test
type: sql
dependencies:
arithmetic:
http://localhost/sample.yaml
supported_types:
i32:
sql_type_name: INTEGER
aggregate_functions:
- name: arithmetic.func_nonvariadic
supported_kernels:
- i32_i32
- name: arithmetic.func_variadic
supported_kernels:
- i32_i32
`
// get substrait function registry
var c extensions.Collection
require.NoError(t, c.Load(uri, strings.NewReader(defYaml)))
funcRegistry := NewFunctionRegistry(&c)
localRegistry := getLocalFunctionRegistry(t, dialectYaml, funcRegistry)

// test non-variadic min-max
fv := localRegistry.GetAggregateFunctions(LocalFunctionName("func_nonvariadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 2, fv[0].MinArgumentCount())
require.Equal(t, 2, fv[0].MaxArgumentCount())

// test variadic min-max
fv = localRegistry.GetAggregateFunctions(LocalFunctionName("func_variadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 3, fv[0].MinArgumentCount())
require.Equal(t, 5, fv[0].MaxArgumentCount())
}

func TestWindowFuncMinMax(t *testing.T) {
const uri = "http://localhost/sample.yaml"
const defYaml = `---
window_functions:
-
name: "func_nonvariadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
return: i32
-
name: "func_variadic"
description: "Add two values."
impls:
- args:
- name: x
value: i32
- name: y
value: i32
variadic:
min: 1
max: 3
return: i32
`

dialectYaml := `
name: test
type: sql
dependencies:
arithmetic:
http://localhost/sample.yaml
supported_types:
i32:
sql_type_name: INTEGER
window_functions:
- name: arithmetic.func_nonvariadic
supported_kernels:
- i32_i32
- name: arithmetic.func_variadic
supported_kernels:
- i32_i32
`
// get substrait function registry
var c extensions.Collection
require.NoError(t, c.Load(uri, strings.NewReader(defYaml)))
funcRegistry := NewFunctionRegistry(&c)
localRegistry := getLocalFunctionRegistry(t, dialectYaml, funcRegistry)

// test non-variadic min-max
fv := localRegistry.GetWindowFunctions(LocalFunctionName("func_nonvariadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 2, fv[0].MinArgumentCount())
require.Equal(t, 2, fv[0].MaxArgumentCount())

// test variadic min-max
fv = localRegistry.GetWindowFunctions(LocalFunctionName("func_variadic"), 2)
require.Len(t, fv, 1)
require.Equal(t, 3, fv[0].MinArgumentCount())
require.Equal(t, 5, fv[0].MaxArgumentCount())
}

0 comments on commit 5d40def

Please sign in to comment.