Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inlining: automatically inline trivial values #112

Merged
merged 1 commit into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ast.fs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ open Options.Globals

type Ident(name: string) =
let mutable newName = name
let mutable inlined = newName.StartsWith("i_")

member this.Name = newName
member this.OldName = name
member this.Rename(n) = newName <- n
member this.MustBeInlined = this.Name.StartsWith("i_")
member this.ToBeInlined = inlined
member this.Inline() = inlined <- true

// Real identifiers cannot start with a digit, but the temporary ids of the rename pass are numbers.
member this.IsUniqueId = System.Char.IsDigit this.Name.[0]
Expand Down
60 changes: 57 additions & 3 deletions src/rewriter.fs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ let private stripSpaces str =
result.ToString()


let private declsNotToInline (d: Ast.DeclElt list) = d |> List.filter (fun x -> not x.name.MustBeInlined)
let private declsNotToInline (d: Ast.DeclElt list) = d |> List.filter (fun x -> not x.name.ToBeInlined)

let private bool = function
| true -> Var (Ident "true") // Int (1, "")
Expand Down Expand Up @@ -132,9 +132,9 @@ let rec private simplifyExpr env = function

| Dot(e, field) when options.canonicalFieldNames <> "" -> Dot(e, renameField field)

| Var s as e when s.MustBeInlined ->
| Var s as e ->
match env.vars.TryFind s.Name with
| Some (_, {init = Some init}) -> init |> mapExpr env
| Some (_, {name = id; init = Some init}) when id.ToBeInlined -> init |> mapExpr env
| _ -> e

// pi is acos(-1), pi/2 is acos(0)
Expand Down Expand Up @@ -165,6 +165,56 @@ let private rwTypeSpec = function
let rwType (ty: Type) =
makeType (rwTypeSpec ty.name) (Option.map stripSpaces ty.typeQ)

// Return the list of variables used in the statements, with the number of references.
let collectReferences stmtList =
let count = Dictionary<string, int>()
let collectLocalUses _ = function
| Var v as e ->
match count.TryGetValue(v.Name) with
| true, n -> count.[v.Name] <- n + 1
| false, _ -> count.[v.Name] <- 1
e
| e -> e
for expr in stmtList do
mapStmt (mapEnv collectLocalUses id) expr |> ignore
count

// Mark variables as inlinable when possible.
// For now, only mark a variable when:
// - the variable is used only once in the current block
// - the variable is not used in a sub-block (e.g. inside a loop)
// - the init value is trivial (doesn't depend on a variable)
let findInlinable block =
// Variables that are defined in this scope.
let localDefs = Dictionary<string, Ident>()
// List of expressions in the current block. Do not look in sub-blocks.
let mutable localExpr = []
for stmt: Stmt in block do
match stmt with
| Decl (_, li) ->
for def in li do
// can only inline if it has a value
match def.init with
| None -> ()
| Some init ->
localExpr <- init :: localExpr
// Inline only if the init value doesn't depend on other variables.
let deps = collectReferences [Expr init]
if deps.Count = 0 then
localDefs.[def.name.Name] <- def.name
| Expr e
| Jump (_, Some e) -> localExpr <- e :: localExpr
| Verbatim _ | Jump (_, None) | Block _ | If _| ForE _ | ForD _ | While _ | DoWhile _ -> ()

let localReferences = collectReferences (List.map Expr localExpr)
let allReferences = collectReferences block

for def in localDefs do
if not def.Value.ToBeInlined then
match localReferences.TryGetValue(def.Key), allReferences.TryGetValue(def.Key) with
| (true, 1), (true, 1) -> def.Value.Inline()
| _ -> ()

let private simplifyStmt = function
| Block [] as e -> e
| Block b ->
Expand All @@ -174,6 +224,8 @@ let private simplifyStmt = function

// Remove inner empty blocks
let b = b |> List.filter (function Block [] | Decl (_, []) -> false | _ -> true)

findInlinable b

// Try to remove blocks by using the comma operator
let returnExp = b |> Seq.tryPick (function Jump(JumpKeyword.Return, e) -> e | _ -> None)
Expand Down Expand Up @@ -216,6 +268,8 @@ let simplify li =
li
|> reorderTopLevel
|> mapTopLevel (mapEnv simplifyExpr simplifyStmt)
// A second pass, because some variables might now be inlinable.
|> mapTopLevel (mapEnv simplifyExpr simplifyStmt)
|> List.map (function
| TLDecl (ty, li) -> TLDecl (rwType ty, declsNotToInline li)
| TLVerbatim s -> TLVerbatim (stripSpaces s)
Expand Down
86 changes: 42 additions & 44 deletions tests/real/yx_long_way_from_home.frag.expected
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ const char *yx_long_way_from_home_frag =
"m.y+=sin(m.x*2.)*.05;"
"m.y-=length(sin(m.xz*.5))*.1;"
"m.z+=sin(m.x*.5)*.5;"
"float y=.03;"
"m.z+=step(.5,mod(m.x,1.))*.3-.15;"
"m.x=mod(m.x,.5)-.25;"
"float l=t(m.xz),z=smoothstep(.1,.13,l);"
"m.y+=.1-z*y;"
"m.y-=smoothstep(.05,0.,abs(l-.16))*.004;"
"float y=t(m.xz),z=smoothstep(.1,.13,y);"
"m.y+=.1-z*.03;"
"m.y-=smoothstep(.05,0.,abs(y-.16))*.004;"
"m.y-=(1.-z)*.01*h(m.xz);"
"}"
"m.y-=smoothstep(2.,0.,length(n.xz+vec2(-1.5,3.5)))*.2;"
Expand All @@ -91,15 +90,15 @@ const char *yx_long_way_from_home_frag =
"vec3 f=cross(vec3(-1,-1,-1),v);"
"return f;"
"}"
"vec3 e(vec3 v,float m)"
"vec3 e(vec3 v,float y)"
"{"
"v=normalize(v);"
"vec3 y=normalize(p(v)),f=normalize(cross(v,y));"
"vec3 f=normalize(p(v)),m=normalize(cross(v,f));"
"vec2 n=i;"
"n.x=n.x*2.*pi;"
"n.y=pow(n.y,1./(m+1.));"
"n.y=pow(n.y,1./(y+1.));"
"float x=sqrt(1.-n.y*n.y);"
"return cos(n.x)*x*y+sin(n.x)*x*f+n.y*v;"
"return cos(n.x)*x*f+sin(n.x)*x*m+n.y*v;"
"}"
"vec3 x(vec3 v)"
"{"
Expand Down Expand Up @@ -145,39 +144,38 @@ const char *yx_long_way_from_home_frag =
"}"
"vec3 h(vec3 v,vec3 m)"
"{"
"float x=.65,z=.18;"
"vec3 y=normalize(vec3((x-.5)*2.,z*2.,-1));"
"const float n=.0001;"
"const vec3 c=vec3(1.,.6,.2)*2.;"
"vec3 r=vec3(1),o=vec3(0);"
"for(int g=0;g<10;++g)"
"vec3 x=normalize(vec3(.3,.36,-1));"
"const float y=.0001;"
"const vec3 n=vec3(1.,.6,.2)*2.;"
"vec3 z=vec3(1),c=vec3(0);"
"for(int r=0;r<10;++r)"
"{"
"vec3 a,p;"
"float t;"
"if(d(v,m,a,p,t))"
"vec3 a,o;"
"float p;"
"if(d(v,m,a,o,p))"
"{"
"float k=1.;"
"vec3 b=vec3(1);"
"float t=1.;"
"vec3 g=vec3(1);"
"if(f==1)"
"b=vec3(.7);"
"k*=k;"
"g=vec3(.7);"
"t*=t;"
"{"
"v=a+p*.002;"
"vec3 h=reflect(m,p),u=e(p,1.);"
"m=normalize(mix(h,u,k));"
"r*=b;"
"v=a+o*.002;"
"vec3 h=reflect(m,o),u=e(o,1.);"
"m=normalize(mix(h,u,t));"
"z*=g;"
"}"
"vec3 h=d(y,n);"
"float u=dot(p,h);"
"vec3 S,R;"
"float B;"
"if(u>0.&&!d(a+p*.002,h,S,R,B))"
"o+=r*u*c;"
"vec3 h=d(x,y);"
"float u=dot(o,h);"
"vec3 b,k;"
"float S;"
"if(u>0.&&!d(a+o*.002,h,b,k,S))"
"c+=z*u*n;"
"i=s(i.y);"
"}"
"else"
" if(abs(t)>.1)"
"return o+l(m)*r;"
" if(abs(p)>.1)"
"return c+l(m)*z;"
"else"
" break;"
"}"
Expand All @@ -194,26 +192,26 @@ const char *yx_long_way_from_home_frag =
"void main()"
"{"
"vec2 v=gl_FragCoord.xy/iResolution.xy-.5;"
"float m=iTime+(v.x+iResolution.x*v.y)*1.51269;"
"i=s(m);"
"float y=iTime+(v.x+iResolution.x*v.y)*1.51269;"
"i=s(y);"
"v+=(i-.5)/iResolution.xy;"
"v.x*=iResolution.x/iResolution.y;"
"const vec3 f=vec3(-4,2,3),y=vec3(0,0,0);"
"const float x=distance(f,y);"
"const vec3 m=vec3(-4,2,3),f=vec3(0,0,0);"
"const float x=distance(m,f);"
"const vec2 z=vec2(1,2)*.015;"
"vec3 c=vec3(0),r=normalize(vec3(v,2.));"
"vec2 t=d();"
"c.xy+=t*z;"
"r.xy-=t*z*r.z/x;"
"vec3 l=y-f;"
"float p=-atan(l.y,length(l.xz)),a=-atan(l.x,l.z);"
"vec3 l=f-m;"
"float p=-atan(l.y,length(l.xz)),o=-atan(l.x,l.z);"
"c.yz*=n(p);"
"r.yz*=n(p);"
"c.xz*=n(a);"
"r.xz*=n(a);"
"c+=f;"
"vec4 u=vec4(h(c,r),1);"
"gl_FragColor=!isnan(u.x)&&u.x>=0.?u:vec4(0);"
"c.xz*=n(o);"
"r.xz*=n(o);"
"c+=m;"
"vec4 a=vec4(h(c,r),1);"
"gl_FragColor=!isnan(a.x)&&a.x>=0.?a:vec4(0);"
"}";

#endif // YX_LONG_WAY_FROM_HOME_FRAG_EXPECTED_
4 changes: 2 additions & 2 deletions tests/unit/function_comma.expected
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ const char *function_comma_frag =
"}"
"float foo()"
"{"
"float a=1.2,b=2.3;"
"return min((a=1.,b+a),0.);"
"float a=1.2;"
"return min((a=1.,2.3+a),0.);"
"}"
"float bar()"
"{"
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/inline.expected
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ int vars(int arg,int arg2)
{
return arg*(arg+arg2);
}
int arithmetic2()
{
int a=2,c=a+3;
return 4*a*c;
}
8 changes: 8 additions & 0 deletions tests/unit/inline.frag
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ int vars(int arg, int arg2)
int i_c = i_a + i_b;
return i_a * i_c;
}

int arithmetic2()
{
int a = 2;
int b = 3;
int c = a + b;
return 4 * a * c;
}
46 changes: 21 additions & 25 deletions tests/unit/inout.expected
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,20 @@ in vec3 c,v;
out vec4 o;
void main()
{
vec3 n=normalize(v),f=normalize(c),u=vec3(.1,.2,.3),z=vec3(.5,.5,.5);
float x=1.5;
vec3 p=texture(e,reflect(-n,f)).xyz,d=texture(e,refract(-n,f,1./x)).xyz,s=mix(u*d,p,.1);
o=vec4(s,1.);
vec3 l=normalize(v),u=normalize(c),f=vec3(.1,.2,.3),z=vec3(.5,.5,.5),x=texture(e,reflect(-l,u)).xyz,p=texture(e,refract(-l,u,1./1.5)).xyz,d=mix(f*p,x,.1);
o=vec4(d,1.);
}
vec3 r(vec3 z,vec3 n,vec3 C)
vec3 r(vec3 z,vec3 l,vec3 s)
{
float y=1.-clamp(dot(n,C),0.,1.);
return y*y*y*y*y*(1.-z)+z;
float C=1.-clamp(dot(l,s),0.,1.);
return C*C*C*C*C*(1.-z)+z;
}
vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b)
vec3 r(vec3 l,vec3 y,vec3 u,vec3 f,vec3 z,float w)
{
vec3 C=normalize(n+w);
float Z=1.+2048.*(1.-b)*(1.-b);
vec3 Y=u,X=vec3(pow(clamp(dot(C,f),0.,1.),Z)*(Z+4.)/8.),W=r(z,n,C);
return mix(Y,X,W);
vec3 s=normalize(l+y);
float b=1.+2048.*(1.-w)*(1.-w);
vec3 Z=f,Y=vec3(pow(clamp(dot(s,u),0.,1.),b)*(b+4.)/8.),X=r(z,l,s);
return mix(Z,Y,X);
}

// tests/unit/inout2.frag
Expand All @@ -31,25 +29,23 @@ vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b)

uniform samplerCube e;
uniform float t;
uniform vec3 a,m,l,i;
uniform vec3 m,i,a,n;
in vec3 c,v;
out vec4 o;
vec3 r(vec3 z,vec3 n,vec3 C)
vec3 r(vec3 z,vec3 l,vec3 s)
{
float y=1.-clamp(dot(n,C),0.,1.);
return y*y*y*y*y*(1.-z)+z;
float C=1.-clamp(dot(l,s),0.,1.);
return C*C*C*C*C*(1.-z)+z;
}
void main()
{
vec3 n=normalize(v),f=normalize(c),u=m,z=i;
float V=.5;
vec3 s=l+mix(u*a,a,V);
o=vec4(s,1.);
vec3 l=normalize(v),u=normalize(c),f=i,z=n,d=a+mix(f*m,m,.5);
o=vec4(d,1.);
}
vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b)
vec3 r(vec3 l,vec3 y,vec3 u,vec3 f,vec3 z,float w)
{
vec3 C=normalize(n+w);
float Z=1.+2048.*(1.-b)*(1.-b);
vec3 Y=u,X=vec3(pow(clamp(dot(C,f),0.,1.),Z)*(Z+4.)/8.),W=r(z,n,C);
return mix(Y,X,W);
vec3 s=normalize(l+y);
float b=1.+2048.*(1.-w)*(1.-w);
vec3 Z=f,Y=vec3(pow(clamp(dot(s,u),0.,1.),b)*(b+4.)/8.),X=r(z,l,s);
return mix(Z,Y,X);
}
4 changes: 2 additions & 2 deletions tests/unit/macros.expected
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ const char *macros_frag =
"#define p$\n"
"int t()"
"{"
"int t=1,r=2,u=3,Z=4,Y=5,X=6,W=7,V=8,U=9,T=10,S=11,R=12;"
"return t+Y+R;"
"int t=2,r=3,u=4,Z=6,Y=7,X=8,W=9,V=10,U=11;"
"return 18;"
"}";

#endif // MACROS_EXPECTED_
16 changes: 8 additions & 8 deletions tests/unit/many_variables.expected
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
"int t(float t,float o,float l,float f,float a,float n,float i,float r,float u,float e,float Z,float Y)"
"{"
"float X=t,W=o,V=l,U=f,T=a,S=n;"
"int R=1,Q=2,P=3,O=4,N=5,M=6,L=7,K=8;"
"float J=i,I=r,H=u,G=e,F=Z,E=Y;"
"int D=1,C=2,B=3,A=4,z=5,y=6,x=7,w=8;"
"float v=0.;"
"int s=1,q=2,p=3,m=4,k=5,j=6,h=7,g=8;"
"float d=0.;"
"int c=1,b=2,at=3,ab=4,ac=5,ad=6,ag=7,ah=8;"
"return K+w+g+ah;"
"int R=1,Q=2,P=3,O=4,N=5,M=6,L=7;"
"float K=i,J=r,I=u,H=e,G=Z,F=Y;"
"int E=1,D=2,C=3,B=4,A=5,z=6,y=7;"
"float x=0.;"
"int w=1,v=2,s=3,q=4,p=5,m=6,k=7;"
"float j=0.;"
"int h=1,g=2,d=3,c=4,b=5,at=6,ab=7;"
"return 32;"
"}",