diff --git a/src/ast.fs b/src/ast.fs index 56aaf990..3b8a2920 100644 --- a/src/ast.fs +++ b/src/ast.fs @@ -91,6 +91,10 @@ and Type = { not (Set.intersect (set this.typeQ) (set ["out"; "inout"])).IsEmpty member this.IsExternal = List.exists (fun s -> Set.contains s Builtin.externalQualifiers) this.typeQ + member this.isScalar = + match this.name with + | TypeName n -> Builtin.builtinScalarTypes.Contains n + | _ -> false override t.ToString() = let name = match t.name with | TypeName n -> n diff --git a/src/builtin.fs b/src/builtin.fs index 704a209f..688486b1 100644 --- a/src/builtin.fs +++ b/src/builtin.fs @@ -10,18 +10,24 @@ let keywords = System.Collections.Generic.HashSet<_>([ "const"; "uniform"; "buffer"; "shared"; "attribute"; "varying" ]) -let builtinTypes = set([ - yield! [ "void"; "bool"; "int"; "uint"; "float"; "double" ] +let builtinScalarTypes = set [ + "bool"; "int"; "uint"; "float"; "double" +] +let builtinVectorTypes = set([ for p in [""; "d"; "b"; "i"; "u"] do for n in ["2"; "3"; "4"] do yield p+"vec"+n +]) +let builtinMatrixTypes = set([ for p in [""; "d"] do for n in ["2"; "3"; "4"] do yield p+"mat"+n for c in ["2"; "3"; "4"] do for r in ["2"; "3"; "4"] do yield p+"mat"+c+"x"+r - ]) +]) + +let builtinTypes = set [ "void" ] + builtinScalarTypes + builtinVectorTypes + builtinMatrixTypes; let implicitConversions = // (from, to) [ diff --git a/src/rewriter.fs b/src/rewriter.fs index f97e791a..67290717 100644 --- a/src/rewriter.fs +++ b/src/rewriter.fs @@ -218,8 +218,13 @@ module private RewriterImpl = // x=...+x -> x+=... // Works only if the operator is commutative. * is not commutative with vectors and matrices. - | FunCall(Op "=", [Var x; FunCall(Op ("+"|"&"|"^"|"|" as op), [e; Var y])]) - when x.Name = y.Name -> + | FunCall(Op "=", [Var x; FunCall(Op ("+"|"*"|"&"|"^"|"|" as op), [e; Var y])]) + when x.Name = y.Name + && match x.VarDecl with + // * is commutative when at least one operand is scalar + | Some d -> op <> "*" || d.ty.isScalar + | _ -> false + -> FunCall(Op (op + "="), [Var x; e]) // Unsafe when x contains NaN or Inf values. diff --git a/tests/real/audio-flight-v2.frag.expected b/tests/real/audio-flight-v2.frag.expected index a957f514..b31e758a 100644 --- a/tests/real/audio-flight-v2.frag.expected +++ b/tests/real/audio-flight-v2.frag.expected @@ -188,7 +188,7 @@ vec2 marcher(vec3 ro,vec3 rd) } vec3 normal(vec3 p,float t) { - t=MINDIST*t; + t*=MINDIST; vec2 h=vec2(1,-1)*.5773; return normalize(h.xyy*map(p+h.xyy*t,0.).x+h.yyx*map(p+h.yyx*t,0.).x+h.yxy*map(p+h.yxy*t,0.).x+h.xxx*map(p+h.xxx*t,0.).x); } diff --git a/tests/unit/inline.no.expected b/tests/unit/inline.no.expected index 64dfc981..6400c5ce 100644 --- a/tests/unit/inline.no.expected +++ b/tests/unit/inline.no.expected @@ -4,7 +4,7 @@ float result; void main() { float x=.5; - x=.6*x*x; + x*=.6*x; result=x; } int arithmetic()