diff --git a/docs/reference.md b/docs/reference.md index e72a084c..40d835b1 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -145,7 +145,7 @@ In Standard ML but not in Morel: matchmatchItem [ '|' matchItem ]* match matchItempat => exp -fromSourceid in exp +fromSourcepat in exp fromFilterwhere exp filter clause fromGroupgroup groupKey1 , ... , groupKeyg [ compute agg1 , ... , agga ] diff --git a/src/main/java/net/hydromatic/morel/ast/Ast.java b/src/main/java/net/hydromatic/morel/ast/Ast.java index 2564756a..9eda2e1e 100644 --- a/src/main/java/net/hydromatic/morel/ast/Ast.java +++ b/src/main/java/net/hydromatic/morel/ast/Ast.java @@ -1251,19 +1251,26 @@ public Case copy(Exp e, java.util.List matchList) { /** From expression. */ public static class From extends Exp { - public final Map sources; + public final Map sources; public final ImmutableList steps; public final Exp yieldExp; /** The expression in the yield clause, or the default yield expression * if not specified; never null. */ public final Exp yieldExpOrDefault; - From(Pos pos, ImmutableMap sources, ImmutableList steps, + From(Pos pos, ImmutableMap sources, ImmutableList steps, Exp yieldExp) { super(pos, Op.FROM); this.sources = Objects.requireNonNull(sources); this.steps = Objects.requireNonNull(steps); - Set fields = sources.keySet(); + final Set firstFields = new HashSet<>(); + sources.keySet().forEach(pat -> + pat.visit(p -> { + if (p instanceof IdPat) { + firstFields.add(ast.id(Pos.ZERO, ((IdPat) p).name)); + } + })); + Set fields = firstFields; for (FromStep step : steps) { if (step instanceof Group) { final Group group = (Group) step; @@ -1320,7 +1327,7 @@ public Exp accept(Shuttle shuttle) { /** Creates a copy of this {@code From} with given contents, * or {@code this} if the contents are the same. */ - public From copy(Map sources, + public From copy(Map sources, java.util.List steps, Ast.Exp yieldExp) { return this.sources.equals(sources) && this.steps.equals(steps) diff --git a/src/main/java/net/hydromatic/morel/ast/AstBuilder.java b/src/main/java/net/hydromatic/morel/ast/AstBuilder.java index 81ffa5e4..bcb42206 100644 --- a/src/main/java/net/hydromatic/morel/ast/AstBuilder.java +++ b/src/main/java/net/hydromatic/morel/ast/AstBuilder.java @@ -288,7 +288,7 @@ public Ast.Case caseOf(Pos pos, Ast.Exp e, return new Ast.Case(pos, e, ImmutableList.copyOf(matchList)); } - public Ast.From from(Pos pos, Map sources, + public Ast.From from(Pos pos, Map sources, List steps, Ast.Exp yieldExp) { return new Ast.From(pos, ImmutableMap.copyOf(sources), ImmutableList.copyOf(steps), yieldExp); diff --git a/src/main/java/net/hydromatic/morel/compile/Compiler.java b/src/main/java/net/hydromatic/morel/compile/Compiler.java index 6e43cac4..5403d424 100644 --- a/src/main/java/net/hydromatic/morel/compile/Compiler.java +++ b/src/main/java/net/hydromatic/morel/compile/Compiler.java @@ -172,13 +172,18 @@ public Code compile(Environment env, Ast.Exp expression) { case FROM: final Ast.From from = (Ast.From) expression; - final Map sourceCodes = new LinkedHashMap<>(); + final Map sourceCodes = new LinkedHashMap<>(); final List bindings = new ArrayList<>(); - for (Map.Entry idExp : from.sources.entrySet()) { - final Code expCode = compile(env.bindAll(bindings), idExp.getValue()); - final Ast.Id id = idExp.getKey(); - sourceCodes.put(id, expCode); - bindings.add(Binding.of(id.name, typeMap.getType(id))); + for (Map.Entry patExp : from.sources.entrySet()) { + final Code expCode = compile(env.bindAll(bindings), patExp.getValue()); + final Ast.Pat pat = patExp.getKey(); + sourceCodes.put(pat, expCode); + pat.visit(p -> { + if (p instanceof Ast.IdPat) { + final Ast.IdPat idPat = (Ast.IdPat) p; + bindings.add(Binding.of(idPat.name, typeMap.getType(pat))); + } + }); } Supplier rowSinkFactory = createRowSinkFactory(env, ImmutableList.copyOf(bindings), from.steps, @@ -491,25 +496,28 @@ private void flatten(Map matches, * @return Code for match */ private Code compileMatchList(Environment env, - Iterable matchList) { - final ImmutableList.Builder> patCodeBuilder = - ImmutableList.builder(); - for (Ast.Match match : matchList) { - final Environment[] envHolder = {env}; - match.pat.visit(pat -> { - if (pat instanceof Ast.IdPat) { - final Type paramType = typeMap.getType(pat); - envHolder[0] = envHolder[0].bind(((Ast.IdPat) pat).name, - paramType, Unit.INSTANCE); - } - }); - final Code code = compile(envHolder[0], match.e); - patCodeBuilder.add(Pair.of(expandRecordPattern(match.pat), code)); - } - final ImmutableList> patCodes = patCodeBuilder.build(); + List matchList) { + @SuppressWarnings("UnstableApiUsage") + final ImmutableList> patCodes = + matchList.stream() + .map(match -> compileMatch(env, match)) + .collect(ImmutableList.toImmutableList()); return evalEnv -> new Closure(evalEnv, patCodes); } + private Pair compileMatch(Environment env, Ast.Match match) { + final Environment[] envHolder = {env}; + match.pat.visit(pat -> { + if (pat instanceof Ast.IdPat) { + final Type paramType = typeMap.getType(pat); + envHolder[0] = envHolder[0].bind(((Ast.IdPat) pat).name, + paramType, Unit.INSTANCE); + } + }); + final Code code = compile(envHolder[0], match.e); + return Pair.of(expandRecordPattern(match.pat), code); + } + /** Expands a pattern if it is a record pattern that has an ellipsis * or if the arguments are not in the same order as the labels in the type. */ private Ast.Pat expandRecordPattern(Ast.Pat pat) { diff --git a/src/main/java/net/hydromatic/morel/compile/TypeResolver.java b/src/main/java/net/hydromatic/morel/compile/TypeResolver.java index 0ebde331..fe6f012f 100644 --- a/src/main/java/net/hydromatic/morel/compile/TypeResolver.java +++ b/src/main/java/net/hydromatic/morel/compile/TypeResolver.java @@ -258,18 +258,23 @@ private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) { final Ast.From from = (Ast.From) node; env2 = env; final Map fieldVars = new LinkedHashMap<>(); - final Map fromSources = new LinkedHashMap<>(); - for (Map.Entry source : from.sources.entrySet()) { - final Ast.Id id = source.getKey(); + final Map fromSources = new LinkedHashMap<>(); + for (Map.Entry source : from.sources.entrySet()) { + final Ast.Pat pat = source.getKey(); final Ast.Exp exp = source.getValue(); final Unifier.Variable v5 = unifier.variable(); final Unifier.Variable v6 = unifier.variable(); final Ast.Exp exp2 = deduceType(env2, exp, v5); - fromSources.put(id, exp2); + final Map termMap1 = new HashMap<>(); + final Ast.Pat pat2 = + deducePatType(env2, pat, termMap1, null, v6); + fromSources.put(pat2, exp2); reg(exp, v5, unifier.apply(LIST_TY_CON, v6)); - reg(id, null, v6); - env2 = env2.bind(id.name, v6); - fieldVars.put(id, v6); + for (Map.Entry e : termMap1.entrySet()) { + env2 = env2.bind(e.getKey().name, e.getValue()); + fieldVars.put(ast.id(Pos.ZERO, e.getKey().name), + (Unifier.Variable) e.getValue()); + } } final List fromSteps = new ArrayList<>(); for (Ast.FromStep step : from.steps) { diff --git a/src/main/java/net/hydromatic/morel/eval/Codes.java b/src/main/java/net/hydromatic/morel/eval/Codes.java index 31ecd6cb..649be6f5 100644 --- a/src/main/java/net/hydromatic/morel/eval/Codes.java +++ b/src/main/java/net/hydromatic/morel/eval/Codes.java @@ -313,7 +313,7 @@ public static Applicable tyCon(Type dataType, String name) { return (env, arg) -> ImmutableList.of(name, arg); } - public static Code from(Map sources, + public static Code from(Map sources, Supplier rowSinkFactory) { if (sources.size() == 0) { return env -> { @@ -322,11 +322,11 @@ public static Code from(Map sources, return rowSink.result(env); }; } - final ImmutableList ids = ImmutableList.copyOf(sources.keySet()); + final ImmutableList pats = ImmutableList.copyOf(sources.keySet()); final ImmutableList codes = ImmutableList.copyOf(sources.values()); return env -> { final RowSink rowSink = rowSinkFactory.get(); - final Looper looper = new Looper(ids, codes, env, rowSink); + final Looper looper = new Looper(pats, codes, env, rowSink); looper.loop(0); return rowSink.result(env); }; @@ -1049,12 +1049,12 @@ private static class Looper { private final ImmutableList codes; private final RowSink rowSink; - Looper(ImmutableList ids, ImmutableList codes, EvalEnv env, + Looper(ImmutableList pats, ImmutableList codes, EvalEnv env, RowSink rowSink) { this.codes = codes; this.rowSink = rowSink; - for (Ast.Id id : ids) { - final MutableEvalEnv mutableEnv = env.bindMutable(id.name); + for (Ast.Pat pat : pats) { + final MutableEvalEnv mutableEnv = env.bindMutablePat(pat); mutableEvalEnvs.add(mutableEnv); env = mutableEnv; iterables.add(null); @@ -1071,16 +1071,18 @@ void loop(int i) { final int next = i + 1; if (next == iterables.size()) { for (Object o : iterable) { - mutableEvalEnv.set(o); - rowSink.accept(mutableEvalEnv); + if (mutableEvalEnv.setOpt(o)) { + rowSink.accept(mutableEvalEnv); + } } } else { for (Object o : iterable) { - mutableEvalEnv.set(o); - //noinspection unchecked - iterables.set(next, (Iterable) - codes.get(next).eval(mutableEvalEnvs.get(next))); - loop(next); + if (mutableEvalEnv.setOpt(o)) { + //noinspection unchecked + iterables.set(next, (Iterable) + codes.get(next).eval(mutableEvalEnvs.get(next))); + loop(next); + } } } } diff --git a/src/main/java/net/hydromatic/morel/eval/EvalEnv.java b/src/main/java/net/hydromatic/morel/eval/EvalEnv.java index fcaf5c9d..4c086ce8 100644 --- a/src/main/java/net/hydromatic/morel/eval/EvalEnv.java +++ b/src/main/java/net/hydromatic/morel/eval/EvalEnv.java @@ -18,8 +18,10 @@ */ package net.hydromatic.morel.eval; +import net.hydromatic.morel.ast.Ast; import net.hydromatic.morel.compile.Environment; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,6 +49,22 @@ default MutableEvalEnv bindMutable(String name) { return new EvalEnvs.MutableSubEvalEnv(this, name); } + /** Creates an evaluation environment that has the same content as this one, + * plus mutable slots for each name in a pattern. */ + default MutableEvalEnv bindMutablePat(Ast.Pat pat) { + if (pat instanceof Ast.IdPat) { + // Pattern is simple; use a simple implementation. + return bindMutable(((Ast.IdPat) pat).name); + } + final List names = new ArrayList<>(); + pat.visit(p -> { + if (p instanceof Ast.IdPat) { + names.add(((Ast.IdPat) p).name); + } + }); + return new EvalEnvs.MutablePatSubEvalEnv(this, pat, names); + } + /** Creates an evaluation environment that has the same content as this one, * plus a mutable slot or slots. * diff --git a/src/main/java/net/hydromatic/morel/eval/EvalEnvs.java b/src/main/java/net/hydromatic/morel/eval/EvalEnvs.java index 8b7d58e3..7be02417 100644 --- a/src/main/java/net/hydromatic/morel/eval/EvalEnvs.java +++ b/src/main/java/net/hydromatic/morel/eval/EvalEnvs.java @@ -21,6 +21,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import net.hydromatic.morel.ast.Ast; +import net.hydromatic.morel.util.Pair; + +import java.math.BigDecimal; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; @@ -117,6 +121,119 @@ public Object getOpt(String name) { } } + /** Evaluation environment that binds several slots based on a pattern. */ + static class MutablePatSubEvalEnv extends MutableArraySubEvalEnv { + private final Ast.Pat pat; + private int slot; + + MutablePatSubEvalEnv(EvalEnv parentEnv, Ast.Pat pat, List names) { + super(parentEnv, names); + this.pat = pat; + this.values = new Object[names.size()]; + assert !(pat instanceof Ast.IdPat); + } + + @Override public void set(Object value) { + if (!setOpt(value)) { + // If this error happens, perhaps your code should be calling "setOpt" + // and handling a false result appropriately. + throw new AssertionError("bind failed"); + } + } + + @Override public boolean setOpt(Object value) { + slot = 0; + return bindRecurse(pat, value); + } + + boolean bindRecurse(Ast.Pat pat, Object argValue) { + final List listValue; + final Ast.LiteralPat literalPat; + switch (pat.op) { + case ID_PAT: + this.values[slot++] = argValue; + return true; + + case WILDCARD_PAT: + return true; + + case BOOL_LITERAL_PAT: + case CHAR_LITERAL_PAT: + case STRING_LITERAL_PAT: + literalPat = (Ast.LiteralPat) pat; + return literalPat.value.equals(argValue); + + case INT_LITERAL_PAT: + literalPat = (Ast.LiteralPat) pat; + return ((BigDecimal) literalPat.value).intValue() == (Integer) argValue; + + case REAL_LITERAL_PAT: + literalPat = (Ast.LiteralPat) pat; + return ((BigDecimal) literalPat.value).doubleValue() == (Double) argValue; + + case TUPLE_PAT: + final Ast.TuplePat tuplePat = (Ast.TuplePat) pat; + listValue = (List) argValue; + for (Pair pair : Pair.zip(tuplePat.args, listValue)) { + if (!bindRecurse(pair.left, pair.right)) { + return false; + } + } + return true; + + case RECORD_PAT: + final Ast.RecordPat recordPat = (Ast.RecordPat) pat; + listValue = (List) argValue; + for (Pair pair + : Pair.zip(recordPat.args.values(), listValue)) { + if (!bindRecurse(pair.left, pair.right)) { + return false; + } + } + return true; + + case LIST_PAT: + final Ast.ListPat listPat = (Ast.ListPat) pat; + listValue = (List) argValue; + if (listValue.size() != listPat.args.size()) { + return false; + } + for (Pair pair : Pair.zip(listPat.args, listValue)) { + if (!bindRecurse(pair.left, pair.right)) { + return false; + } + } + return true; + + case CONS_PAT: + final Ast.InfixPat infixPat = (Ast.InfixPat) pat; + @SuppressWarnings("unchecked") final List consValue = + (List) argValue; + if (consValue.isEmpty()) { + return false; + } + final Object head = consValue.get(0); + final List tail = consValue.subList(1, consValue.size()); + return bindRecurse(infixPat.p0, head) + && bindRecurse(infixPat.p1, tail); + + case CON0_PAT: + final Ast.Con0Pat con0Pat = (Ast.Con0Pat) pat; + final List con0Value = (List) argValue; + return con0Value.get(0).equals(con0Pat.tyCon.name); + + case CON_PAT: + final Ast.ConPat conPat = (Ast.ConPat) pat; + final List conValue = (List) argValue; + return conValue.get(0).equals(conPat.tyCon.name) + && bindRecurse(conPat.pat, conValue.get(1)); + + default: + throw new AssertionError("cannot compile " + pat.op + ": " + pat); + } + } + } + /** Evaluation environment that reads from a map. */ static class MapEvalEnv implements EvalEnv { final Map valueMap; diff --git a/src/main/java/net/hydromatic/morel/eval/MutableEvalEnv.java b/src/main/java/net/hydromatic/morel/eval/MutableEvalEnv.java index 355422f8..c60c9268 100644 --- a/src/main/java/net/hydromatic/morel/eval/MutableEvalEnv.java +++ b/src/main/java/net/hydromatic/morel/eval/MutableEvalEnv.java @@ -20,7 +20,20 @@ /** An evaluation environment whose last entry is mutable. */ public interface MutableEvalEnv extends EvalEnv { + /** Puts a value into this environment. */ void set(Object value); + + /** Puts a value into this environment in a way that may not succeed. + * + *

For example, if this environment is based on the pattern (x, 2) + * then (1, 2) will succeed and will bind x to 1, but (3, 4) will fail. + * + *

The default implementation calls {@link #set} and always succeeds. + */ + default boolean setOpt(Object value) { + set(value); + return true; + } } // End MutableEvalEnv.java diff --git a/src/main/javacc/MorelParser.jj b/src/main/javacc/MorelParser.jj index 3d8d64e9..3ed1a390 100644 --- a/src/main/javacc/MorelParser.jj +++ b/src/main/javacc/MorelParser.jj @@ -303,7 +303,7 @@ Exp from() : { final Span span; Span stepSpan; - final Map sources = new LinkedHashMap<>(); + final Map sources = new LinkedHashMap<>(); final List steps = new ArrayList<>(); Exp filterExp; Exp yieldExp = null; @@ -351,14 +351,14 @@ Exp from() : } } -void fromSource(Map sources) : +void fromSource(Map sources) : { final Exp exp; - final Id id; + final Pat pat; } { - id = identifier() exp = expression() { - sources.put(id, exp); + pat = pat() exp = expression() { + sources.put(pat, exp); } } diff --git a/src/test/java/net/hydromatic/morel/MainTest.java b/src/test/java/net/hydromatic/morel/MainTest.java index 4142852f..31a24f89 100644 --- a/src/test/java/net/hydromatic/morel/MainTest.java +++ b/src/test/java/net/hydromatic/morel/MainTest.java @@ -1469,6 +1469,15 @@ private Matcher throwsA(Class clazz, .assertEvalIter(equalsOrdered(list())); } + @Test public void testFromPattern() { + final String ml = "from (x, y) in [(1,2),(3,4),(3,0)] group sum = x + y"; + final String expected = "from (x, y) in [(1, 2), (3, 4), (3, 0)] " + + "group sum = x + y"; + ml(ml).assertParse(expected) + .assertType(is("int list")) + .assertEvalIter(equalsUnordered(3, 7)); + } + /** Tests a program that uses an external collection from the "scott" JDBC * database. */ @Test public void testScott() { diff --git a/src/test/resources/script/relational.sml b/src/test/resources/script/relational.sml index 5163ce4d..b96e90c9 100644 --- a/src/test/resources/script/relational.sml +++ b/src/test/resources/script/relational.sml @@ -429,6 +429,29 @@ from group one = 1 compute two = sum of 2, three = sum of 3 yield {c1 = one, c5 = two + three}; +(*) Patterns left of 'in' +fun sumPairs pairs = + from (left, right) in pairs + yield left + right; +sumPairs []; +sumPairs [(1, 2), (3, 4)]; + +(*) Skip rows that do not match the pattern +from (left, 2) in [(1, 2), (3, 4), (5, 2)] + yield left; + +fun listHeads lists = + from hd :: tl in lists + yield hd + 1; +listHeads []; +listHeads [[1, 2], [3], [4, 5, 6]]; + +fun listFields lists = + from {a = x, b = y} in lists + yield x + 1; +listFields []; +listFields [{a = 1, b = 2}, {a = 3, b = 0}, {a = 4, b = 5}]; + (*) Temporary functions let fun abbrev s = diff --git a/src/test/resources/script/relational.sml.out b/src/test/resources/script/relational.sml.out index a47ef875..18d47bf3 100644 --- a/src/test/resources/script/relational.sml.out +++ b/src/test/resources/script/relational.sml.out @@ -688,6 +688,49 @@ from val it = [{c1=1,c5=5}] : {c1:int, c5:int} list +(*) Patterns left of 'in' +fun sumPairs pairs = + from (left, right) in pairs + yield left + right; +val sumPairs = fn : (int * int) list -> int list + +sumPairs []; +val it = [] : int list + +sumPairs [(1, 2), (3, 4)]; +val it = [3,7] : int list + + +(*) Skip rows that do not match the pattern +from (left, 2) in [(1, 2), (3, 4), (5, 2)] + yield left; +val it = [1,5] : int list + + +fun listHeads lists = + from hd :: tl in lists + yield hd + 1; +val listHeads = fn : int list list -> int list + +listHeads []; +val it = [] : int list + +listHeads [[1, 2], [3], [4, 5, 6]]; +val it = [2,4,5] : int list + + +fun listFields lists = + from {a = x, b = y} in lists + yield x + 1; +val listFields = fn : {a:int, b:'a} list -> int list + +listFields []; +val it = [] : int list + +listFields [{a = 1, b = 2}, {a = 3, b = 0}, {a = 4, b = 5}]; +val it = [2,4,5] : int list + + (*) Temporary functions let fun abbrev s =