Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-22.2: sql: fix ambigious udf overload with null arguments #109193

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/sql/logictest/testdata/logic_test/enums
Original file line number Diff line number Diff line change
Expand Up @@ -1751,7 +1751,7 @@ $$;
query TT
SELECT "🙏"('😊'), "🙏"(NULL:::"Emoji 😉")
----
NULL NULL
NULL mixed

statement ok
CREATE DATABASE "DB➕➕";
Expand Down
32 changes: 32 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/udf
Original file line number Diff line number Diff line change
Expand Up @@ -2839,3 +2839,35 @@ statement error unknown function: public\.LOWERCASE_HINT_ERROR_EXPLICIT_SCHEMA_F
SELECT public."LOWERCASE_HINT_ERROR_EXPLICIT_SCHEMA_FN"();

subtest end

subtest r88374

statement ok
CREATE FUNCTION f88374(i INT2) RETURNS INT STRICT LANGUAGE SQL AS 'SELECT 2';

statement ok
CREATE FUNCTION f88374(i TEXT) RETURNS INT STRICT LANGUAGE SQL AS 'SELECT 2';

statement error pgcode 42725 ambiguous call
SELECT f88374(NULL);

statement ok
CREATE TABLE t88374 (a INT, b INT);

statement ok
INSERT INTO t88374 VALUES (1, NULL);

statement ok
CREATE FUNCTION g88374 (i INT) RETURNS INT CALLED ON NULL INPUT LANGUAGE SQL AS $$ SELECT a FROM t88374 WHERE b IS NOT DISTINCT FROM i $$;

query I
SELECT g88374(NULL);
----
1

query I
SELECT g88374(NULL::INT);
----
1

subtest end
6 changes: 5 additions & 1 deletion pkg/sql/sem/tree/type_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,7 @@ func (expr *FuncExpr) TypeCheck(
return nil, pgerror.Wrapf(err, pgcode.InvalidParameterValue, "%s()", def.Name)
}

var hasUDFOverload bool
var calledOnNullInputFns []overloadImpl
var notCalledOnNullInputFns []overloadImpl
for _, f := range fns {
Expand All @@ -1084,6 +1085,9 @@ func (expr *FuncExpr) TypeCheck(
} else {
notCalledOnNullInputFns = append(notCalledOnNullInputFns, f)
}
if f.(QualifiedOverload).IsUDF {
hasUDFOverload = true
}
}

// If the function is an aggregate that does not accept null arguments and we
Expand Down Expand Up @@ -1126,7 +1130,7 @@ func (expr *FuncExpr) TypeCheck(
// NULL arguments, the function isn't a generator or aggregate builtin, and
// NULL is given as an argument.
if len(fns) > 0 && len(calledOnNullInputFns) == 0 && funcCls != GeneratorClass &&
funcCls != AggregateClass {
funcCls != AggregateClass && !hasUDFOverload {
for _, expr := range typedSubExprs {
if expr.ResolvedType().Family() == types.UnknownFamily {
return DNull, nil
Expand Down