Skip to content

Commit

Permalink
Check/synth record and tuple patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
Kmeakin committed Dec 12, 2022
1 parent 1d36cb4 commit 2f7e6a8
Show file tree
Hide file tree
Showing 4 changed files with 386 additions and 225 deletions.
298 changes: 84 additions & 214 deletions fathom/src/surface/elaboration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use crate::surface::elaboration::reporting::Message;
use crate::surface::{distillation, pretty, BinOp, FormatField, Item, Module, Pattern, Term};

mod order;
mod patterns;
mod reporting;
mod unification;

Expand Down Expand Up @@ -463,6 +464,70 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
(labels.into(), filtered_fields)
}

fn check_tuple_fields<F>(
&mut self,
range: ByteRange,
fields: &[F],
get_range: fn(&F) -> ByteRange,
expected_labels: &[StringId],
) -> Result<(), ()> {
if fields.len() == expected_labels.len() {
return Ok(());
}

let mut found_labels = Vec::with_capacity(fields.len());
let mut fields_iter = fields.iter().enumerate().peekable();
let mut expected_labels_iter = expected_labels.iter();

// use the label names from the expected labels
while let Some(((_, field), label)) =
Option::zip(fields_iter.peek(), expected_labels_iter.next())
{
found_labels.push((get_range(field), *label));
fields_iter.next();
}

// use numeric labels for excess fields
for (index, field) in fields_iter {
found_labels.push((
get_range(field),
self.interner.borrow_mut().get_tuple_label(index),
));
}

self.push_message(Message::MismatchedFieldLabels {
range,
found_labels,
expected_labels: expected_labels.to_vec(),
});
Err(())
}

fn check_record_fields<F>(
&mut self,
range: ByteRange,
fields: &[F],
get_label: impl Fn(&F) -> (ByteRange, StringId),
labels: &'arena [StringId],
) -> Result<(), ()> {
if fields.len() == labels.len()
&& fields
.iter()
.zip(labels.iter())
.all(|(field, type_label)| get_label(field).1 == *type_label)
{
return Ok(());
}

// TODO: improve handling of duplicate labels
self.push_message(Message::MismatchedFieldLabels {
range,
found_labels: fields.iter().map(get_label).collect(),
expected_labels: labels.to_vec(),
});
Err(())
}

/// Parse a source string into number, assuming an ASCII encoding.
fn parse_ascii<T>(
&mut self,
Expand Down Expand Up @@ -696,177 +761,6 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
term
}

/// Check that a pattern matches an expected type.
fn check_pattern(
&mut self,
pattern: &Pattern<ByteRange>,
expected_type: &ArcValue<'arena>,
) -> CheckedPattern {
match pattern {
Pattern::Name(range, name) => CheckedPattern::Binder(*range, *name),
Pattern::Placeholder(range) => CheckedPattern::Placeholder(*range),
Pattern::StringLiteral(range, lit) => {
let constant = match expected_type.match_prim_spine() {
Some((Prim::U8Type, [])) => self.parse_ascii(*range, *lit, Const::U8),
Some((Prim::U16Type, [])) => self.parse_ascii(*range, *lit, Const::U16),
Some((Prim::U32Type, [])) => self.parse_ascii(*range, *lit, Const::U32),
Some((Prim::U64Type, [])) => self.parse_ascii(*range, *lit, Const::U64),
// Some((Prim::Array8Type, [len, _])) => todo!(),
// Some((Prim::Array16Type, [len, _])) => todo!(),
// Some((Prim::Array32Type, [len, _])) => todo!(),
// Some((Prim::Array64Type, [len, _])) => todo!(),
Some((Prim::ReportedError, _)) => None,
_ => {
let expected_type = self.pretty_print_value(expected_type);
self.push_message(Message::StringLiteralNotSupported {
range: *range,
expected_type,
});
None
}
};

match constant {
Some(constant) => CheckedPattern::ConstLit(*range, constant),
None => CheckedPattern::ReportedError(*range),
}
}
Pattern::NumberLiteral(range, lit) => {
let constant = match expected_type.match_prim_spine() {
Some((Prim::U8Type, [])) => self.parse_number_radix(*range, *lit, Const::U8),
Some((Prim::U16Type, [])) => self.parse_number_radix(*range, *lit, Const::U16),
Some((Prim::U32Type, [])) => self.parse_number_radix(*range, *lit, Const::U32),
Some((Prim::U64Type, [])) => self.parse_number_radix(*range, *lit, Const::U64),
Some((Prim::S8Type, [])) => self.parse_number(*range, *lit, Const::S8),
Some((Prim::S16Type, [])) => self.parse_number(*range, *lit, Const::S16),
Some((Prim::S32Type, [])) => self.parse_number(*range, *lit, Const::S32),
Some((Prim::S64Type, [])) => self.parse_number(*range, *lit, Const::S64),
Some((Prim::F32Type, [])) => self.parse_number(*range, *lit, Const::F32),
Some((Prim::F64Type, [])) => self.parse_number(*range, *lit, Const::F64),
Some((Prim::ReportedError, _)) => None,
_ => {
let expected_type = self.pretty_print_value(expected_type);
self.push_message(Message::NumericLiteralNotSupported {
range: *range,
expected_type,
});
None
}
};

match constant {
Some(constant) => CheckedPattern::ConstLit(*range, constant),
None => CheckedPattern::ReportedError(*range),
}
}
Pattern::BooleanLiteral(range, boolean) => {
let constant = match expected_type.match_prim_spine() {
Some((Prim::BoolType, [])) => match *boolean {
true => Some(Const::Bool(true)),
false => Some(Const::Bool(false)),
},
_ => {
self.push_message(Message::BooleanLiteralNotSupported { range: *range });
None
}
};

match constant {
Some(constant) => CheckedPattern::ConstLit(*range, constant),
None => CheckedPattern::ReportedError(*range),
}
}
Pattern::RecordLiteral(_, _) => todo!(),
Pattern::Tuple(_, _) => todo!(),
}
}

/// Synthesize the type of a pattern.
fn synth_pattern(
&mut self,
pattern: &Pattern<ByteRange>,
) -> (CheckedPattern, ArcValue<'arena>) {
match pattern {
Pattern::Name(range, name) => {
let source = MetaSource::NamedPatternType(*range, *name);
let r#type = self.push_unsolved_type(source);
(CheckedPattern::Binder(*range, *name), r#type)
}
Pattern::Placeholder(range) => {
let source = MetaSource::PlaceholderPatternType(*range);
let r#type = self.push_unsolved_type(source);
(CheckedPattern::Placeholder(*range), r#type)
}
Pattern::StringLiteral(range, _) => {
self.push_message(Message::AmbiguousStringLiteral { range: *range });
let source = MetaSource::ReportedErrorType(*range);
let r#type = self.push_unsolved_type(source);
(CheckedPattern::ReportedError(*range), r#type)
}
Pattern::NumberLiteral(range, _) => {
self.push_message(Message::AmbiguousNumericLiteral { range: *range });
let source = MetaSource::ReportedErrorType(*range);
let r#type = self.push_unsolved_type(source);
(CheckedPattern::ReportedError(*range), r#type)
}
Pattern::BooleanLiteral(range, val) => {
let r#const = Const::Bool(*val);
let r#type = self.bool_type.clone();
(CheckedPattern::ConstLit(*range, r#const), r#type)
}
Pattern::RecordLiteral(_, _) => todo!(),
Pattern::Tuple(_, _) => todo!(),
}
}

/// Check that the type of an annotated pattern matches an expected type.
fn check_ann_pattern(
&mut self,
pattern: &Pattern<ByteRange>,
r#type: Option<&Term<'_, ByteRange>>,
expected_type: &ArcValue<'arena>,
) -> CheckedPattern {
match r#type {
None => self.check_pattern(pattern, expected_type),
Some(r#type) => {
let range = r#type.range();
let r#type = self.check(r#type, &self.universe.clone());
let r#type = self.eval_env().eval(&r#type);

match self.unification_context().unify(&r#type, expected_type) {
Ok(()) => self.check_pattern(pattern, &r#type),
Err(error) => {
let lhs = self.pretty_print_value(&r#type);
let rhs = self.pretty_print_value(expected_type);
self.push_message(Message::FailedToUnify {
range,
lhs,
rhs,
error,
});
CheckedPattern::ReportedError(range)
}
}
}
}
}

/// Synthesize the type of an annotated pattern.
fn synth_ann_pattern(
&mut self,
pattern: &Pattern<ByteRange>,
r#type: Option<&Term<'_, ByteRange>>,
) -> (CheckedPattern, ArcValue<'arena>) {
match r#type {
None => self.synth_pattern(pattern),
Some(r#type) => {
let r#type = self.check(r#type, &self.universe.clone());
let type_value = self.eval_env().eval(&r#type);
(self.check_pattern(pattern, &type_value), type_value)
}
}
}

/// Push a local definition onto the context.
/// The supplied `pattern` is expected to be irrefutable.
fn push_local_def(
Expand All @@ -886,6 +780,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
None
}
CheckedPattern::ReportedError(_) => None,
CheckedPattern::RecordLit(_, _, _) => todo!(),
};

self.local_env.push_def(name, expr, r#type);
Expand All @@ -911,6 +806,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
None
}
CheckedPattern::ReportedError(_) => None,
CheckedPattern::RecordLit(_, _, _) => todo!(),
};

let expr = self.local_env.push_param(name, r#type);
Expand Down Expand Up @@ -970,18 +866,10 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
self.check_fun_lit(*range, patterns, body_expr, &expected_type)
}
(Term::RecordLiteral(range, expr_fields), Value::RecordType(labels, types)) => {
// TODO: improve handling of duplicate labels
if expr_fields.len() != labels.len()
|| Iterator::zip(expr_fields.iter(), labels.iter())
.any(|(expr_field, type_label)| expr_field.label.1 != *type_label)
if self
.check_record_fields(*range, expr_fields, |field| field.label, labels)
.is_err()
{
self.push_message(Message::MismatchedFieldLabels {
range: *range,
expr_labels: (expr_fields.iter())
.map(|expr_field| expr_field.label)
.collect(),
type_labels: labels.to_vec(),
});
return core::Term::Prim(range.into(), Prim::ReportedError);
}

Expand Down Expand Up @@ -1045,33 +933,11 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
core::Term::FormatRecord(range.into(), labels, formats)
}
(Term::Tuple(range, elem_exprs), Value::RecordType(labels, types)) => {
if elem_exprs.len() != labels.len() {
let mut expr_labels = Vec::with_capacity(elem_exprs.len());
let mut elem_exprs = elem_exprs.iter().enumerate().peekable();
let mut label_iter = labels.iter();

// use the label names from the expected type
while let Some(((_, elem_expr), label)) =
Option::zip(elem_exprs.peek(), label_iter.next())
{
expr_labels.push((elem_expr.range(), *label));
elem_exprs.next();
}

// use numeric labels for excess elems
for (index, elem_expr) in elem_exprs {
expr_labels.push((
elem_expr.range(),
self.interner.borrow_mut().get_tuple_label(index),
));
}

self.push_message(Message::MismatchedFieldLabels {
range: *range,
expr_labels,
type_labels: labels.to_vec(),
});
return core::Term::Prim(range.into(), Prim::ReportedError);
if self
.check_tuple_fields(*range, elem_exprs, |expr| expr.range(), labels)
.is_err()
{
return core::Term::error(range.into());
}

let mut types = types.clone();
Expand Down Expand Up @@ -2027,6 +1893,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
self.elab_match_unreachable(match_info, equations);
core::Term::Prim(range.into(), Prim::ReportedError)
}
CheckedPattern::RecordLit(_, _, _) => todo!(),
}
}
None => self.elab_match_absurd(is_reachable, match_info),
Expand Down Expand Up @@ -2116,6 +1983,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
default_branch = (None, self.scope.to_scope(default_expr) as &_);
self.local_env.pop();
}
CheckedPattern::RecordLit(_, _, _) => todo!(),
};

// A default pattern was found, check any unreachable patterns.
Expand Down Expand Up @@ -2196,15 +2064,17 @@ impl_from_str_radix!(u64);

/// Simple patterns that have had some initial elaboration performed on them
#[derive(Debug)]
enum CheckedPattern {
/// Pattern that binds local variable
Binder(ByteRange, StringId),
enum CheckedPattern<'arena> {
/// Error sentinel
ReportedError(ByteRange),
/// Placeholder patterns that match everything
Placeholder(ByteRange),
/// Pattern that binds local variable
Binder(ByteRange, StringId),
/// Constant literals
ConstLit(ByteRange, Const),
/// Error sentinel
ReportedError(ByteRange),
/// Record literals
RecordLit(ByteRange, &'arena [StringId], &'arena [Self]),
}

/// Scrutinee of a match expression
Expand Down
Loading

0 comments on commit 2f7e6a8

Please sign in to comment.