Skip to content

Commit

Permalink
Sanely handle current function overloading limitations in inlining (#292
Browse files Browse the repository at this point in the history
)
  • Loading branch information
eldritchconundrum authored and laurentlb committed Apr 30, 2023
1 parent dc1ae1b commit bed8a83
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 26 deletions.
12 changes: 7 additions & 5 deletions src/analyzer.fs
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ let resolve topLevel =
let resolveExpr (env: MapEnv) = function
| FunCall (Var v, args) as e ->
match env.fns.TryFind (v.Name, args.Length) with
| Some (ft, _) -> v.Declaration <- ft.fName.Declaration
| _ -> () // TODO: resolve builtin functions
| Some [(ft, _)] -> v.Declaration <- ft.fName.Declaration
| None -> () // TODO: resolve builtin functions
| _ -> () // TODO: support type-based disambiguation of user-defined function overloading
e
| Var v as e ->
match env.vars.TryFind v.Name with
Expand Down Expand Up @@ -248,13 +249,12 @@ module private FunctionInlining =
mapStmt (mapEnv collect id) block |> ignore<MapEnv * Stmt>
callSites |> Seq.toList

// This function assumes that user-defined functions are NOT overloaded
type FuncInfo = {
func: TopLevel
funcType: FunctionType
body: Stmt
name: string
callSites: CallSite list
callSites: CallSite list // calls to other user-defined functions, from inside this function.
}
let findFuncInfos code =
let functions = code |> List.choose (function
Expand Down Expand Up @@ -327,8 +327,10 @@ module private FunctionInlining =
for funcInfo in funcInfos do
let canBeRenamed = not (options.noRenamingList |> List.contains funcInfo.name) // noRenamingList includes "main"
let isExternal = options.hlsl && funcInfo.funcType.semantics <> []
if canBeRenamed && not isExternal then
let isOverloadedAmbiguously = funcInfos |> List.except [funcInfo] |> List.exists (fun f -> f.funcType.prototype = funcInfo.funcType.prototype)
if canBeRenamed && not isExternal && not isOverloadedAmbiguously then
if not funcInfo.funcType.hasOutOrInoutParams then // [F]
// Find calls to this function. This works because we checked that the function is not overloaded ambiguously.
let callSites = funcInfos |> List.collect (fun n -> n.callSites)
|> List.filter (fun callSite -> callSite.prototype = funcInfo.funcType.prototype)
if callSites.Length > 0 then // Unused function elimination is not handled here
Expand Down
5 changes: 3 additions & 2 deletions src/ast.fs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ type MapEnv = {
fExpr: MapEnv -> Expr -> Expr
fStmt: Stmt -> Stmt
vars: Map<string, Type * DeclElt>
fns: Map<(string * int), FunctionType * Stmt> // this map assumes that user-defined functions are never overloaded by more than parameter count
fns: Map<(string * int), (FunctionType * Stmt) list> // This doesn't support type-based disambiguation of user-defined function overloading
isInWritePosition: bool // used for findWrites only
}

Expand Down Expand Up @@ -245,7 +245,8 @@ let mapTopLevel env li =
let env, res = mapDecl env t
env, TLDecl res
| Function(fct, body) ->
let envWithFunction = {env with fns = env.fns.Add((fct.fName.Name, fct.args.Length), (fct, body))}
let newFns = (fct, body) :: (env.fns.TryFind(fct.prototype) |> Option.defaultValue [])
let envWithFunction = {env with fns = env.fns.Add(fct.prototype, newFns)}
let envWithFunctionAndArgs, args = foldList envWithFunction mapDecl fct.args
envWithFunction, Function({ fct with args = args }, snd (mapStmt envWithFunctionAndArgs body))
| e -> env, e)
Expand Down
9 changes: 5 additions & 4 deletions src/rewriter.fs
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,9 @@ let private simplifyVec (constr: Ident) args =
let private simplifyExpr (didInline: bool ref) env = function
| FunCall(Var v, passedArgs) as e when v.ToBeInlined ->
match env.fns.TryFind (v.Name, passedArgs.Length) with
| None -> e
| Some ({args = declArgs}, body) ->
| Some ([{args = declArgs}, body]) ->
if List.length declArgs <> List.length passedArgs then
failwithf "Cannot inline %s since it doesn't have the right number of arguments" v.Name
failwithf "Cannot inline function %s since it doesn't have the right number of arguments" v.Name
match body with
| Jump (JumpKeyword.Return, Some bodyExpr)
| Block [Jump (JumpKeyword.Return, Some bodyExpr)] ->
Expand All @@ -276,7 +275,9 @@ let private simplifyExpr (didInline: bool ref) env = function
// turned the function into a one-liner, so allow trying again on
// the next pass. (If it didn't, we'll yell next pass.)
| _ when didInline.Value -> e
| _ -> failwithf "Cannot inline %s since it consists of more than a single return" v.Name
| _ -> failwithf "Cannot inline function %s since it consists of more than a single return" v.Name
| None -> failwithf "Cannot inline function %s because it's a builtin" v.Name
| _ -> failwithf "Cannot inline function %s because type-based disambiguation of user-defined function overloading is not supported" v.Name

| FunCall(Op _, _) as op -> simplifyOperator env op
| FunCall(Var constr, args) when constr.Name = "vec2" || constr.Name = "vec3" || constr.Name = "vec4" ->
Expand Down
1 change: 1 addition & 0 deletions tests/commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
-o tests/real/oscars_chair.frag.expected tests/real/oscars_chair.frag
-o tests/real/the_real_party_is_in_your_pocket.frag.expected tests/real/the_real_party_is_in_your_pocket.frag
--no-remove-unused --no-inlining -o tests/unit/function_overload.expected tests/unit/function_overload.frag
-o tests/unit/overload.expected tests/unit/overload.frag
--no-remove-unused -o tests/unit/externals.expected tests/unit/externals.frag
--no-remove-unused -o tests/unit/qualifiers.expected tests/unit/qualifiers.frag
--no-remove-unused -o tests/unit/macros.expected --no-inlining tests/unit/macros.frag
Expand Down
62 changes: 62 additions & 0 deletions tests/unit/overload.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Generated with (https://github.com/laurentlb/Shader_Minifier/)
#ifndef OVERLOAD_EXPECTED_
# define OVERLOAD_EXPECTED_
# define VAR_ZH "t"
# define VAR__V "x"
# define VAR_a_ "i"
# define VAR_stopinlining "v"

const char *overload_frag =
"#version 330\n"
"uniform sampler2D x;"
"uniform vec2 i;"
"in vec2 t;"
"out float v;"
"float h()"
"{"
"return v+.1;"
"}"
"float f()"
"{"
"return v+.2;"
"}"
"float h(int t)"
"{"
"return v+.3;"
"}"
"float f(float t)"
"{"
"return v+.4;"
"}"
"float f(bool t)"
"{"
"return v+.5;"
"}"
"float f(int t,int f)"
"{"
"return v+.6;"
"}"
"float f(int t,int f,int x,int h)"
"{"
"return v+.7;"
"}"
"float f(sampler2D t,float v)"
"{"
"return texelFetch(t,ivec2(255.*v)%256,0).x;"
"}"
"float f(sampler2D t,vec2 v)"
"{"
"return texelFetch(t,ivec2(255.*v)%256,0).x;"
"}"
"float f(sampler2D t,vec3 v)"
"{"
"float f=texelFetch(t,ivec2(255.*v.yz)%256,0).x;"
"return texelFetch(t,ivec2(255.*v.x,255.*f)%256,0).x;"
"}"
"void main()"
"{"
"float v=0.,r=f(true)+f(0,1)-f(0,1,2,3),C=f(x,i*t);"
"gl_FragColor=vec4(h()+2.*h(0),2.*f(1.2)-f(),r,C+v++);"
"}";

#endif // OVERLOAD_EXPECTED_
37 changes: 37 additions & 0 deletions tests/unit/overload.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#version 330
uniform sampler2D _V;
uniform vec2 a_;
in vec2 ZH;
out float stopinlining;

float f1() { return stopinlining+0.1; }
float f2(void) { return stopinlining+0.2; }
float f1(int x) { return stopinlining+0.3; }
float f2(float y) { return stopinlining+0.4; }
float f3(bool b) { return stopinlining+0.5; }
float f3(int a, int b) { return stopinlining+0.6; }
float f3(int a, int b, int c, int d) { return stopinlining+0.7; }

float hashTex(sampler2D _V, float p)
{
return texelFetch(_V, ivec2(255. * p) % 256, 0).r;
}
float hashTex(sampler2D _V, vec2 p)
{
return texelFetch(_V, ivec2(255. * p) % 256, 0).r;
}
float hashTex(sampler2D _V, vec3 p)
{
float h = texelFetch(_V, ivec2(255. * p.yz) % 256, 0).r;
return texelFetch(_V, ivec2(255. * p.x, 255. * h) % 256, 0).r;
}

void main()
{
float stopinlining = 0.;
float a = f1() + 2. * f1(0);
float b = 2. * f2(1.2) - f2();
float c = f3(true) + f3(0, 1) - f3(0, 1, 2, 3);
float d = hashTex(_V, a_ * ZH);
gl_FragColor=vec4(a,b,c,d+stopinlining++);
}
15 changes: 0 additions & 15 deletions tests/unit/overload.frag.disabled

This file was deleted.

0 comments on commit bed8a83

Please sign in to comment.