Skip to content

Commit

Permalink
Inlining: automatically inline trivial values
Browse files Browse the repository at this point in the history
This is a conservative first step. Inline variables only if there value is a constant (it doesn't depend on another variable).
  • Loading branch information
laurentlb committed Jun 7, 2021
1 parent ab05b85 commit d718a80
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 85 deletions.
5 changes: 4 additions & 1 deletion src/ast.fs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ open Options.Globals

type Ident(name: string) =
let mutable newName = name
let mutable inlined = newName.StartsWith("i_")

member this.Name = newName
member this.OldName = name
member this.Rename(n) = newName <- n
member this.MustBeInlined = this.Name.StartsWith("i_")
member this.ToBeInlined = inlined
member this.Inline() = inlined <- true

// Real identifiers cannot start with a digit, but the temporary ids of the rename pass are numbers.
member this.IsUniqueId = System.Char.IsDigit this.Name.[0]
Expand Down
60 changes: 57 additions & 3 deletions src/rewriter.fs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ let private stripSpaces str =
result.ToString()


let private declsNotToInline (d: Ast.DeclElt list) = d |> List.filter (fun x -> not x.name.MustBeInlined)
let private declsNotToInline (d: Ast.DeclElt list) = d |> List.filter (fun x -> not x.name.ToBeInlined)

let private bool = function
| true -> Var (Ident "true") // Int (1, "")
Expand Down Expand Up @@ -132,9 +132,9 @@ let rec private simplifyExpr env = function

| Dot(e, field) when options.canonicalFieldNames <> "" -> Dot(e, renameField field)

| Var s as e when s.MustBeInlined ->
| Var s as e ->
match env.vars.TryFind s.Name with
| Some (_, {init = Some init}) -> init |> mapExpr env
| Some (_, {name = id; init = Some init}) when id.ToBeInlined -> init |> mapExpr env
| _ -> e

// pi is acos(-1), pi/2 is acos(0)
Expand Down Expand Up @@ -165,6 +165,56 @@ let private rwTypeSpec = function
let rwType (ty: Type) =
makeType (rwTypeSpec ty.name) (Option.map stripSpaces ty.typeQ)

// Return the list of variables used in the statements, with the number of references.
let collectReferences stmtList =
let count = Dictionary<string, int>()
let collectLocalUses _ = function
| Var v as e ->
match count.TryGetValue(v.Name) with
| true, n -> count.[v.Name] <- n + 1
| false, _ -> count.[v.Name] <- 1
e
| e -> e
for expr in stmtList do
mapStmt (mapEnv collectLocalUses id) expr |> ignore
count

// Mark variables as inlinable when possible.
// For now, only mark a variable when:
// - the variable is used only once in the current block
// - the variable is not used in a sub-block (e.g. inside a loop)
// - the init value is trivial (doesn't depend on a variable)
let findInlinable block =
// Variables that are defined in this scope.
let localDefs = Dictionary<string, Ident>()
// List of expressions in the current block. Do not look in sub-blocks.
let mutable localExpr = []
for stmt: Stmt in block do
match stmt with
| Decl (_, li) ->
for def in li do
// can only inline if it has a value
match def.init with
| None -> ()
| Some init ->
localExpr <- init :: localExpr
// Inline only if the init value doesn't depend on other variables.
let deps = collectReferences [Expr init]
if deps.Count = 0 then
localDefs.[def.name.Name] <- def.name
| Expr e
| Jump (_, Some e) -> localExpr <- e :: localExpr
| Verbatim _ | Jump (_, None) | Block _ | If _| ForE _ | ForD _ | While _ | DoWhile _ -> ()

let localReferences = collectReferences (List.map Expr localExpr)
let allReferences = collectReferences block

for def in localDefs do
if not def.Value.ToBeInlined then
match localReferences.TryGetValue(def.Key), allReferences.TryGetValue(def.Key) with
| (true, 1), (true, 1) -> def.Value.Inline()
| _ -> ()

let private simplifyStmt = function
| Block [] as e -> e
| Block b ->
Expand All @@ -174,6 +224,8 @@ let private simplifyStmt = function

// Remove inner empty blocks
let b = b |> List.filter (function Block [] | Decl (_, []) -> false | _ -> true)

findInlinable b

// Try to remove blocks by using the comma operator
let returnExp = b |> Seq.tryPick (function Jump(JumpKeyword.Return, e) -> e | _ -> None)
Expand Down Expand Up @@ -216,6 +268,8 @@ let simplify li =
li
|> reorderTopLevel
|> mapTopLevel (mapEnv simplifyExpr simplifyStmt)
// A second pass, because some variables might now be inlinable.
|> mapTopLevel (mapEnv simplifyExpr simplifyStmt)
|> List.map (function
| TLDecl (ty, li) -> TLDecl (rwType ty, declsNotToInline li)
| TLVerbatim s -> TLVerbatim (stripSpaces s)
Expand Down
86 changes: 42 additions & 44 deletions tests/real/yx_long_way_from_home.frag.expected
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ const char *yx_long_way_from_home_frag =
"m.y+=sin(m.x*2.)*.05;"
"m.y-=length(sin(m.xz*.5))*.1;"
"m.z+=sin(m.x*.5)*.5;"
"float y=.03;"
"m.z+=step(.5,mod(m.x,1.))*.3-.15;"
"m.x=mod(m.x,.5)-.25;"
"float l=t(m.xz),z=smoothstep(.1,.13,l);"
"m.y+=.1-z*y;"
"m.y-=smoothstep(.05,0.,abs(l-.16))*.004;"
"float y=t(m.xz),z=smoothstep(.1,.13,y);"
"m.y+=.1-z*.03;"
"m.y-=smoothstep(.05,0.,abs(y-.16))*.004;"
"m.y-=(1.-z)*.01*h(m.xz);"
"}"
"m.y-=smoothstep(2.,0.,length(n.xz+vec2(-1.5,3.5)))*.2;"
Expand All @@ -91,15 +90,15 @@ const char *yx_long_way_from_home_frag =
"vec3 f=cross(vec3(-1,-1,-1),v);"
"return f;"
"}"
"vec3 e(vec3 v,float m)"
"vec3 e(vec3 v,float y)"
"{"
"v=normalize(v);"
"vec3 y=normalize(p(v)),f=normalize(cross(v,y));"
"vec3 f=normalize(p(v)),m=normalize(cross(v,f));"
"vec2 n=i;"
"n.x=n.x*2.*pi;"
"n.y=pow(n.y,1./(m+1.));"
"n.y=pow(n.y,1./(y+1.));"
"float x=sqrt(1.-n.y*n.y);"
"return cos(n.x)*x*y+sin(n.x)*x*f+n.y*v;"
"return cos(n.x)*x*f+sin(n.x)*x*m+n.y*v;"
"}"
"vec3 x(vec3 v)"
"{"
Expand Down Expand Up @@ -145,39 +144,38 @@ const char *yx_long_way_from_home_frag =
"}"
"vec3 h(vec3 v,vec3 m)"
"{"
"float x=.65,z=.18;"
"vec3 y=normalize(vec3((x-.5)*2.,z*2.,-1));"
"const float n=.0001;"
"const vec3 c=vec3(1.,.6,.2)*2.;"
"vec3 r=vec3(1),o=vec3(0);"
"for(int g=0;g<10;++g)"
"vec3 x=normalize(vec3(.3,.36,-1));"
"const float y=.0001;"
"const vec3 n=vec3(1.,.6,.2)*2.;"
"vec3 z=vec3(1),c=vec3(0);"
"for(int r=0;r<10;++r)"
"{"
"vec3 a,p;"
"float t;"
"if(d(v,m,a,p,t))"
"vec3 a,o;"
"float p;"
"if(d(v,m,a,o,p))"
"{"
"float k=1.;"
"vec3 b=vec3(1);"
"float t=1.;"
"vec3 g=vec3(1);"
"if(f==1)"
"b=vec3(.7);"
"k*=k;"
"g=vec3(.7);"
"t*=t;"
"{"
"v=a+p*.002;"
"vec3 h=reflect(m,p),u=e(p,1.);"
"m=normalize(mix(h,u,k));"
"r*=b;"
"v=a+o*.002;"
"vec3 h=reflect(m,o),u=e(o,1.);"
"m=normalize(mix(h,u,t));"
"z*=g;"
"}"
"vec3 h=d(y,n);"
"float u=dot(p,h);"
"vec3 S,R;"
"float B;"
"if(u>0.&&!d(a+p*.002,h,S,R,B))"
"o+=r*u*c;"
"vec3 h=d(x,y);"
"float u=dot(o,h);"
"vec3 b,k;"
"float S;"
"if(u>0.&&!d(a+o*.002,h,b,k,S))"
"c+=z*u*n;"
"i=s(i.y);"
"}"
"else"
" if(abs(t)>.1)"
"return o+l(m)*r;"
" if(abs(p)>.1)"
"return c+l(m)*z;"
"else"
" break;"
"}"
Expand All @@ -194,26 +192,26 @@ const char *yx_long_way_from_home_frag =
"void main()"
"{"
"vec2 v=gl_FragCoord.xy/iResolution.xy-.5;"
"float m=iTime+(v.x+iResolution.x*v.y)*1.51269;"
"i=s(m);"
"float y=iTime+(v.x+iResolution.x*v.y)*1.51269;"
"i=s(y);"
"v+=(i-.5)/iResolution.xy;"
"v.x*=iResolution.x/iResolution.y;"
"const vec3 f=vec3(-4,2,3),y=vec3(0,0,0);"
"const float x=distance(f,y);"
"const vec3 m=vec3(-4,2,3),f=vec3(0,0,0);"
"const float x=distance(m,f);"
"const vec2 z=vec2(1,2)*.015;"
"vec3 c=vec3(0),r=normalize(vec3(v,2.));"
"vec2 t=d();"
"c.xy+=t*z;"
"r.xy-=t*z*r.z/x;"
"vec3 l=y-f;"
"float p=-atan(l.y,length(l.xz)),a=-atan(l.x,l.z);"
"vec3 l=f-m;"
"float p=-atan(l.y,length(l.xz)),o=-atan(l.x,l.z);"
"c.yz*=n(p);"
"r.yz*=n(p);"
"c.xz*=n(a);"
"r.xz*=n(a);"
"c+=f;"
"vec4 u=vec4(h(c,r),1);"
"gl_FragColor=!isnan(u.x)&&u.x>=0.?u:vec4(0);"
"c.xz*=n(o);"
"r.xz*=n(o);"
"c+=m;"
"vec4 a=vec4(h(c,r),1);"
"gl_FragColor=!isnan(a.x)&&a.x>=0.?a:vec4(0);"
"}";

#endif // YX_LONG_WAY_FROM_HOME_FRAG_EXPECTED_
4 changes: 2 additions & 2 deletions tests/unit/function_comma.expected
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ const char *function_comma_frag =
"}"
"float foo()"
"{"
"float a=1.2,b=2.3;"
"return min((a=1.,b+a),0.);"
"float a=1.2;"
"return min((a=1.,2.3+a),0.);"
"}"
"float bar()"
"{"
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/inline.expected
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ int vars(int arg,int arg2)
{
return arg*(arg+arg2);
}
int arithmetic2()
{
int a=2,c=a+3;
return 4*a*c;
}
8 changes: 8 additions & 0 deletions tests/unit/inline.frag
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ int vars(int arg, int arg2)
int i_c = i_a + i_b;
return i_a * i_c;
}

int arithmetic2()
{
int a = 2;
int b = 3;
int c = a + b;
return 4 * a * c;
}
46 changes: 21 additions & 25 deletions tests/unit/inout.expected
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,20 @@ in vec3 c,v;
out vec4 o;
void main()
{
vec3 n=normalize(v),f=normalize(c),u=vec3(.1,.2,.3),z=vec3(.5,.5,.5);
float x=1.5;
vec3 p=texture(e,reflect(-n,f)).xyz,d=texture(e,refract(-n,f,1./x)).xyz,s=mix(u*d,p,.1);
o=vec4(s,1.);
vec3 l=normalize(v),u=normalize(c),f=vec3(.1,.2,.3),z=vec3(.5,.5,.5),x=texture(e,reflect(-l,u)).xyz,p=texture(e,refract(-l,u,1./1.5)).xyz,d=mix(f*p,x,.1);
o=vec4(d,1.);
}
vec3 r(vec3 z,vec3 n,vec3 C)
vec3 r(vec3 z,vec3 l,vec3 s)
{
float y=1.-clamp(dot(n,C),0.,1.);
return y*y*y*y*y*(1.-z)+z;
float C=1.-clamp(dot(l,s),0.,1.);
return C*C*C*C*C*(1.-z)+z;
}
vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b)
vec3 r(vec3 l,vec3 y,vec3 u,vec3 f,vec3 z,float w)
{
vec3 C=normalize(n+w);
float Z=1.+2048.*(1.-b)*(1.-b);
vec3 Y=u,X=vec3(pow(clamp(dot(C,f),0.,1.),Z)*(Z+4.)/8.),W=r(z,n,C);
return mix(Y,X,W);
vec3 s=normalize(l+y);
float b=1.+2048.*(1.-w)*(1.-w);
vec3 Z=f,Y=vec3(pow(clamp(dot(s,u),0.,1.),b)*(b+4.)/8.),X=r(z,l,s);
return mix(Z,Y,X);
}

// tests/unit/inout2.frag
Expand All @@ -31,25 +29,23 @@ vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b)

uniform samplerCube e;
uniform float t;
uniform vec3 a,m,l,i;
uniform vec3 m,i,a,n;
in vec3 c,v;
out vec4 o;
vec3 r(vec3 z,vec3 n,vec3 C)
vec3 r(vec3 z,vec3 l,vec3 s)
{
float y=1.-clamp(dot(n,C),0.,1.);
return y*y*y*y*y*(1.-z)+z;
float C=1.-clamp(dot(l,s),0.,1.);
return C*C*C*C*C*(1.-z)+z;
}
void main()
{
vec3 n=normalize(v),f=normalize(c),u=m,z=i;
float V=.5;
vec3 s=l+mix(u*a,a,V);
o=vec4(s,1.);
vec3 l=normalize(v),u=normalize(c),f=i,z=n,d=a+mix(f*m,m,.5);
o=vec4(d,1.);
}
vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b)
vec3 r(vec3 l,vec3 y,vec3 u,vec3 f,vec3 z,float w)
{
vec3 C=normalize(n+w);
float Z=1.+2048.*(1.-b)*(1.-b);
vec3 Y=u,X=vec3(pow(clamp(dot(C,f),0.,1.),Z)*(Z+4.)/8.),W=r(z,n,C);
return mix(Y,X,W);
vec3 s=normalize(l+y);
float b=1.+2048.*(1.-w)*(1.-w);
vec3 Z=f,Y=vec3(pow(clamp(dot(s,u),0.,1.),b)*(b+4.)/8.),X=r(z,l,s);
return mix(Z,Y,X);
}
4 changes: 2 additions & 2 deletions tests/unit/macros.expected
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ const char *macros_frag =
"#define p$\n"
"int t()"
"{"
"int t=1,r=2,u=3,Z=4,Y=5,X=6,W=7,V=8,U=9,T=10,S=11,R=12;"
"return t+Y+R;"
"int t=2,r=3,u=4,Z=6,Y=7,X=8,W=9,V=10,U=11;"
"return 18;"
"}";

#endif // MACROS_EXPECTED_
16 changes: 8 additions & 8 deletions tests/unit/many_variables.expected
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
"int t(float t,float o,float l,float f,float a,float n,float i,float r,float u,float e,float Z,float Y)"
"{"
"float X=t,W=o,V=l,U=f,T=a,S=n;"
"int R=1,Q=2,P=3,O=4,N=5,M=6,L=7,K=8;"
"float J=i,I=r,H=u,G=e,F=Z,E=Y;"
"int D=1,C=2,B=3,A=4,z=5,y=6,x=7,w=8;"
"float v=0.;"
"int s=1,q=2,p=3,m=4,k=5,j=6,h=7,g=8;"
"float d=0.;"
"int c=1,b=2,at=3,ab=4,ac=5,ad=6,ag=7,ah=8;"
"return K+w+g+ah;"
"int R=1,Q=2,P=3,O=4,N=5,M=6,L=7;"
"float K=i,J=r,I=u,H=e,G=Z,F=Y;"
"int E=1,D=2,C=3,B=4,A=5,z=6,y=7;"
"float x=0.;"
"int w=1,v=2,s=3,q=4,p=5,m=6,k=7;"
"float j=0.;"
"int h=1,g=2,d=3,c=4,b=5,at=6,ab=7;"
"return 32;"
"}",

0 comments on commit d718a80

Please sign in to comment.