Skip to content

Commit

Permalink
[MOREL-31] In from expression, allow in to assign to pattern
Browse files Browse the repository at this point in the history
Rows that do not match the pattern are skipped.
  • Loading branch information
julianhyde committed Apr 23, 2020
1 parent 9829483 commit b1f5955
Show file tree
Hide file tree
Showing 13 changed files with 298 additions and 53 deletions.
2 changes: 1 addition & 1 deletion docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ In Standard ML but not in Morel:
<i>match</i> &rarr; <i>matchItem</i> [ '<b>|</b>' <i>matchItem</i> ]*
match
<i>matchItem</i> &rarr; <i>pat</i> <b>=&gt;</b> <i>exp</i>
<i>fromSource</i> &rarr; <i>id</i> <b>in</b> <i>exp</i>
<i>fromSource</i> &rarr; <i>pat</i> <b>in</b> <i>exp</i>
<i>fromFilter</i> &rarr; <b>where</b> <i>exp</i> filter clause
<i>fromGroup</i> &rarr; <b>group</b> <i>groupKey<sub>1</sub></i> <b>,</b> ... <b>,</b> <i>groupKey<sub>g</sub></i>
[ <b>compute</b> <i>agg<sub>1</sub></i> <b>,</b> ... <b>,</b> <i>agg<sub>a</sub></i> ]
Expand Down
15 changes: 11 additions & 4 deletions src/main/java/net/hydromatic/morel/ast/Ast.java
Original file line number Diff line number Diff line change
Expand Up @@ -1251,19 +1251,26 @@ public Case copy(Exp e, java.util.List<Match> matchList) {

/** From expression. */
public static class From extends Exp {
public final Map<Id, Exp> sources;
public final Map<Pat, Exp> sources;
public final ImmutableList<FromStep> steps;
public final Exp yieldExp;
/** The expression in the yield clause, or the default yield expression
* if not specified; never null. */
public final Exp yieldExpOrDefault;

From(Pos pos, ImmutableMap<Id, Exp> sources, ImmutableList<FromStep> steps,
From(Pos pos, ImmutableMap<Pat, Exp> sources, ImmutableList<FromStep> steps,
Exp yieldExp) {
super(pos, Op.FROM);
this.sources = Objects.requireNonNull(sources);
this.steps = Objects.requireNonNull(steps);
Set<Id> fields = sources.keySet();
final Set<Id> firstFields = new HashSet<>();
sources.keySet().forEach(pat ->
pat.visit(p -> {
if (p instanceof IdPat) {
firstFields.add(ast.id(Pos.ZERO, ((IdPat) p).name));
}
}));
Set<Id> fields = firstFields;
for (FromStep step : steps) {
if (step instanceof Group) {
final Group group = (Group) step;
Expand Down Expand Up @@ -1320,7 +1327,7 @@ public Exp accept(Shuttle shuttle) {

/** Creates a copy of this {@code From} with given contents,
* or {@code this} if the contents are the same. */
public From copy(Map<Ast.Id, Ast.Exp> sources,
public From copy(Map<Ast.Pat, Ast.Exp> sources,
java.util.List<FromStep> steps, Ast.Exp yieldExp) {
return this.sources.equals(sources)
&& this.steps.equals(steps)
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/hydromatic/morel/ast/AstBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ public Ast.Case caseOf(Pos pos, Ast.Exp e,
return new Ast.Case(pos, e, ImmutableList.copyOf(matchList));
}

public Ast.From from(Pos pos, Map<Ast.Id, Ast.Exp> sources,
public Ast.From from(Pos pos, Map<Ast.Pat, Ast.Exp> sources,
List<Ast.FromStep> steps, Ast.Exp yieldExp) {
return new Ast.From(pos, ImmutableMap.copyOf(sources),
ImmutableList.copyOf(steps), yieldExp);
Expand Down
52 changes: 30 additions & 22 deletions src/main/java/net/hydromatic/morel/compile/Compiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,18 @@ public Code compile(Environment env, Ast.Exp expression) {

case FROM:
final Ast.From from = (Ast.From) expression;
final Map<Ast.Id, Code> sourceCodes = new LinkedHashMap<>();
final Map<Ast.Pat, Code> sourceCodes = new LinkedHashMap<>();
final List<Binding> bindings = new ArrayList<>();
for (Map.Entry<Ast.Id, Ast.Exp> idExp : from.sources.entrySet()) {
final Code expCode = compile(env.bindAll(bindings), idExp.getValue());
final Ast.Id id = idExp.getKey();
sourceCodes.put(id, expCode);
bindings.add(Binding.of(id.name, typeMap.getType(id)));
for (Map.Entry<Ast.Pat, Ast.Exp> patExp : from.sources.entrySet()) {
final Code expCode = compile(env.bindAll(bindings), patExp.getValue());
final Ast.Pat pat = patExp.getKey();
sourceCodes.put(pat, expCode);
pat.visit(p -> {
if (p instanceof Ast.IdPat) {
final Ast.IdPat idPat = (Ast.IdPat) p;
bindings.add(Binding.of(idPat.name, typeMap.getType(pat)));
}
});
}
Supplier<Codes.RowSink> rowSinkFactory =
createRowSinkFactory(env, ImmutableList.copyOf(bindings), from.steps,
Expand Down Expand Up @@ -491,25 +496,28 @@ private void flatten(Map<Ast.Pat, Ast.Exp> matches,
* @return Code for match
*/
private Code compileMatchList(Environment env,
Iterable<Ast.Match> matchList) {
final ImmutableList.Builder<Pair<Ast.Pat, Code>> patCodeBuilder =
ImmutableList.builder();
for (Ast.Match match : matchList) {
final Environment[] envHolder = {env};
match.pat.visit(pat -> {
if (pat instanceof Ast.IdPat) {
final Type paramType = typeMap.getType(pat);
envHolder[0] = envHolder[0].bind(((Ast.IdPat) pat).name,
paramType, Unit.INSTANCE);
}
});
final Code code = compile(envHolder[0], match.e);
patCodeBuilder.add(Pair.of(expandRecordPattern(match.pat), code));
}
final ImmutableList<Pair<Ast.Pat, Code>> patCodes = patCodeBuilder.build();
List<Ast.Match> matchList) {
@SuppressWarnings("UnstableApiUsage")
final ImmutableList<Pair<Ast.Pat, Code>> patCodes =
matchList.stream()
.map(match -> compileMatch(env, match))
.collect(ImmutableList.toImmutableList());
return evalEnv -> new Closure(evalEnv, patCodes);
}

private Pair<Ast.Pat, Code> compileMatch(Environment env, Ast.Match match) {
final Environment[] envHolder = {env};
match.pat.visit(pat -> {
if (pat instanceof Ast.IdPat) {
final Type paramType = typeMap.getType(pat);
envHolder[0] = envHolder[0].bind(((Ast.IdPat) pat).name,
paramType, Unit.INSTANCE);
}
});
final Code code = compile(envHolder[0], match.e);
return Pair.of(expandRecordPattern(match.pat), code);
}

/** Expands a pattern if it is a record pattern that has an ellipsis
* or if the arguments are not in the same order as the labels in the type. */
private Ast.Pat expandRecordPattern(Ast.Pat pat) {
Expand Down
19 changes: 12 additions & 7 deletions src/main/java/net/hydromatic/morel/compile/TypeResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -258,18 +258,23 @@ private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) {
final Ast.From from = (Ast.From) node;
env2 = env;
final Map<Ast.Id, Unifier.Variable> fieldVars = new LinkedHashMap<>();
final Map<Ast.Id, Ast.Exp> fromSources = new LinkedHashMap<>();
for (Map.Entry<Ast.Id, Ast.Exp> source : from.sources.entrySet()) {
final Ast.Id id = source.getKey();
final Map<Ast.Pat, Ast.Exp> fromSources = new LinkedHashMap<>();
for (Map.Entry<Ast.Pat, Ast.Exp> source : from.sources.entrySet()) {
final Ast.Pat pat = source.getKey();
final Ast.Exp exp = source.getValue();
final Unifier.Variable v5 = unifier.variable();
final Unifier.Variable v6 = unifier.variable();
final Ast.Exp exp2 = deduceType(env2, exp, v5);
fromSources.put(id, exp2);
final Map<Ast.IdPat, Unifier.Term> termMap1 = new HashMap<>();
final Ast.Pat pat2 =
deducePatType(env2, pat, termMap1, null, v6);
fromSources.put(pat2, exp2);
reg(exp, v5, unifier.apply(LIST_TY_CON, v6));
reg(id, null, v6);
env2 = env2.bind(id.name, v6);
fieldVars.put(id, v6);
for (Map.Entry<Ast.IdPat, Unifier.Term> e : termMap1.entrySet()) {
env2 = env2.bind(e.getKey().name, e.getValue());
fieldVars.put(ast.id(Pos.ZERO, e.getKey().name),
(Unifier.Variable) e.getValue());
}
}
final List<Ast.FromStep> fromSteps = new ArrayList<>();
for (Ast.FromStep step : from.steps) {
Expand Down
28 changes: 15 additions & 13 deletions src/main/java/net/hydromatic/morel/eval/Codes.java
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ public static Applicable tyCon(Type dataType, String name) {
return (env, arg) -> ImmutableList.of(name, arg);
}

public static Code from(Map<Ast.Id, Code> sources,
public static Code from(Map<Ast.Pat, Code> sources,
Supplier<RowSink> rowSinkFactory) {
if (sources.size() == 0) {
return env -> {
Expand All @@ -322,11 +322,11 @@ public static Code from(Map<Ast.Id, Code> sources,
return rowSink.result(env);
};
}
final ImmutableList<Ast.Id> ids = ImmutableList.copyOf(sources.keySet());
final ImmutableList<Ast.Pat> pats = ImmutableList.copyOf(sources.keySet());
final ImmutableList<Code> codes = ImmutableList.copyOf(sources.values());
return env -> {
final RowSink rowSink = rowSinkFactory.get();
final Looper looper = new Looper(ids, codes, env, rowSink);
final Looper looper = new Looper(pats, codes, env, rowSink);
looper.loop(0);
return rowSink.result(env);
};
Expand Down Expand Up @@ -1049,12 +1049,12 @@ private static class Looper {
private final ImmutableList<Code> codes;
private final RowSink rowSink;

Looper(ImmutableList<Ast.Id> ids, ImmutableList<Code> codes, EvalEnv env,
Looper(ImmutableList<Ast.Pat> pats, ImmutableList<Code> codes, EvalEnv env,
RowSink rowSink) {
this.codes = codes;
this.rowSink = rowSink;
for (Ast.Id id : ids) {
final MutableEvalEnv mutableEnv = env.bindMutable(id.name);
for (Ast.Pat pat : pats) {
final MutableEvalEnv mutableEnv = env.bindMutablePat(pat);
mutableEvalEnvs.add(mutableEnv);
env = mutableEnv;
iterables.add(null);
Expand All @@ -1071,16 +1071,18 @@ void loop(int i) {
final int next = i + 1;
if (next == iterables.size()) {
for (Object o : iterable) {
mutableEvalEnv.set(o);
rowSink.accept(mutableEvalEnv);
if (mutableEvalEnv.setOpt(o)) {
rowSink.accept(mutableEvalEnv);
}
}
} else {
for (Object o : iterable) {
mutableEvalEnv.set(o);
//noinspection unchecked
iterables.set(next, (Iterable<Object>)
codes.get(next).eval(mutableEvalEnvs.get(next)));
loop(next);
if (mutableEvalEnv.setOpt(o)) {
//noinspection unchecked
iterables.set(next, (Iterable<Object>)
codes.get(next).eval(mutableEvalEnvs.get(next)));
loop(next);
}
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/net/hydromatic/morel/eval/EvalEnv.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
*/
package net.hydromatic.morel.eval;

import net.hydromatic.morel.ast.Ast;
import net.hydromatic.morel.compile.Environment;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -47,6 +49,22 @@ default MutableEvalEnv bindMutable(String name) {
return new EvalEnvs.MutableSubEvalEnv(this, name);
}

/** Creates an evaluation environment that has the same content as this one,
* plus mutable slots for each name in a pattern. */
default MutableEvalEnv bindMutablePat(Ast.Pat pat) {
if (pat instanceof Ast.IdPat) {
// Pattern is simple; use a simple implementation.
return bindMutable(((Ast.IdPat) pat).name);
}
final List<String> names = new ArrayList<>();
pat.visit(p -> {
if (p instanceof Ast.IdPat) {
names.add(((Ast.IdPat) p).name);
}
});
return new EvalEnvs.MutablePatSubEvalEnv(this, pat, names);
}

/** Creates an evaluation environment that has the same content as this one,
* plus a mutable slot or slots.
*
Expand Down
117 changes: 117 additions & 0 deletions src/main/java/net/hydromatic/morel/eval/EvalEnvs.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import net.hydromatic.morel.ast.Ast;
import net.hydromatic.morel.util.Pair;

import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -117,6 +121,119 @@ public Object getOpt(String name) {
}
}

/** Evaluation environment that binds several slots based on a pattern. */
static class MutablePatSubEvalEnv extends MutableArraySubEvalEnv {
private final Ast.Pat pat;
private int slot;

MutablePatSubEvalEnv(EvalEnv parentEnv, Ast.Pat pat, List<String> names) {
super(parentEnv, names);
this.pat = pat;
this.values = new Object[names.size()];
assert !(pat instanceof Ast.IdPat);
}

@Override public void set(Object value) {
if (!setOpt(value)) {
// If this error happens, perhaps your code should be calling "setOpt"
// and handling a false result appropriately.
throw new AssertionError("bind failed");
}
}

@Override public boolean setOpt(Object value) {
slot = 0;
return bindRecurse(pat, value);
}

boolean bindRecurse(Ast.Pat pat, Object argValue) {
final List<Object> listValue;
final Ast.LiteralPat literalPat;
switch (pat.op) {
case ID_PAT:
this.values[slot++] = argValue;
return true;

case WILDCARD_PAT:
return true;

case BOOL_LITERAL_PAT:
case CHAR_LITERAL_PAT:
case STRING_LITERAL_PAT:
literalPat = (Ast.LiteralPat) pat;
return literalPat.value.equals(argValue);

case INT_LITERAL_PAT:
literalPat = (Ast.LiteralPat) pat;
return ((BigDecimal) literalPat.value).intValue() == (Integer) argValue;

case REAL_LITERAL_PAT:
literalPat = (Ast.LiteralPat) pat;
return ((BigDecimal) literalPat.value).doubleValue() == (Double) argValue;

case TUPLE_PAT:
final Ast.TuplePat tuplePat = (Ast.TuplePat) pat;
listValue = (List) argValue;
for (Pair<Ast.Pat, Object> pair : Pair.zip(tuplePat.args, listValue)) {
if (!bindRecurse(pair.left, pair.right)) {
return false;
}
}
return true;

case RECORD_PAT:
final Ast.RecordPat recordPat = (Ast.RecordPat) pat;
listValue = (List) argValue;
for (Pair<Ast.Pat, Object> pair
: Pair.zip(recordPat.args.values(), listValue)) {
if (!bindRecurse(pair.left, pair.right)) {
return false;
}
}
return true;

case LIST_PAT:
final Ast.ListPat listPat = (Ast.ListPat) pat;
listValue = (List) argValue;
if (listValue.size() != listPat.args.size()) {
return false;
}
for (Pair<Ast.Pat, Object> pair : Pair.zip(listPat.args, listValue)) {
if (!bindRecurse(pair.left, pair.right)) {
return false;
}
}
return true;

case CONS_PAT:
final Ast.InfixPat infixPat = (Ast.InfixPat) pat;
@SuppressWarnings("unchecked") final List<Object> consValue =
(List) argValue;
if (consValue.isEmpty()) {
return false;
}
final Object head = consValue.get(0);
final List<Object> tail = consValue.subList(1, consValue.size());
return bindRecurse(infixPat.p0, head)
&& bindRecurse(infixPat.p1, tail);

case CON0_PAT:
final Ast.Con0Pat con0Pat = (Ast.Con0Pat) pat;
final List con0Value = (List) argValue;
return con0Value.get(0).equals(con0Pat.tyCon.name);

case CON_PAT:
final Ast.ConPat conPat = (Ast.ConPat) pat;
final List conValue = (List) argValue;
return conValue.get(0).equals(conPat.tyCon.name)
&& bindRecurse(conPat.pat, conValue.get(1));

default:
throw new AssertionError("cannot compile " + pat.op + ": " + pat);
}
}
}

/** Evaluation environment that reads from a map. */
static class MapEvalEnv implements EvalEnv {
final Map<String, Object> valueMap;
Expand Down
Loading

0 comments on commit b1f5955

Please sign in to comment.