Skip to content

Commit

Permalink
refactor: restructure AST and everything else :)
Browse files Browse the repository at this point in the history
  • Loading branch information
therain7 committed Nov 3, 2024
1 parent 8d8fef3 commit 945bb72
Show file tree
Hide file tree
Showing 25 changed files with 960 additions and 592 deletions.
180 changes: 87 additions & 93 deletions lib/ast/LAst.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,109 +7,103 @@
[@@@ocaml.text "/*"]

open! Base

(** List containing at least 1 element *)
type 'a list1 = 'a * 'a list [@@deriving show {with_path= false}]

(** List containing at least 2 elements *)
type 'a list2 = 'a * 'a * 'a list [@@deriving show {with_path= false}]

(** Identifiers **)
type ident = Id of string [@@deriving show {with_path= false}]

type constant =
| ConstInt of int (** Integer such as [25] *)
| ConstChar of char (** Character such as ['c'] *)
| ConstString of string
(** Constant string such as ["constant"] or [{|other constant|}] *)
[@@deriving show {with_path= false}]

(* ======= Types ======= *)
type ty =
| TyVar of ident (** A type variable such as ['a] *)
| TyArr of ty * ty (** [T1 -> T2] *)
| TyTuple of ty list2 (** [T1 * ... * Tn] *)
| TyCon of ident * ty list
(** [TyCon(tconstr, l)] represents:
open LMisc

module Const = struct
type t =
| Int of int (** Integer such as [25] *)
| Char of char (** Character such as ['c'] *)
| String of string
(** Constant string such as ["constant"] or [{|other constant|}] *)
[@@deriving show {with_path= false}]
end

module Ty = struct
type t =
| Var of Id.t (** A type variable such as ['a] *)
| Arr of t * t (** [T1 -> T2] *)
| Tuple of t List2.t (** [T1 * ... * Tn] *)
| Con of Id.t * t list
(** [Con(tconstr, l)] represents:
- [tconstr] when [l=[]]
- [T tconstr] when [l=[T]]
- [(T1, ..., Tn) tconstr] when [l=[T1, ..., Tn]]
*)
[@@deriving show {with_path= false}]

(* ======= Patterns ======= *)
type pattern =
| PatAny (** The pattern [_] *)
| PatVar of ident (** A variable pattern such as [x] *)
| PatConst of constant (** Patterns such as [1], ['a'], ["hello"], [1.5] *)
| PatTuple of pattern list2 (** [(P1, ..., Pn)] *)
| PatOr of pattern * pattern (** [P1 | P2] *)
| PatConstruct of ident * pattern option
(** [PatConstruct(C, arg)] represents:
[@@deriving show {with_path= false}]
end

module Pat = struct
type t =
| Any (** The pattern [_] *)
| Var of Id.t (** A variable pattern such as [x] *)
| Const of Const.t (** Patterns such as [1], ['a'], ["hello"], [1.5] *)
| Tuple of t List2.t (** [(P1, ..., Pn)] *)
| Or of t * t (** [P1 | P2] *)
| Construct of Id.t * t option
(** [Construct(C, arg)] represents:
- [C] when [arg] is [None]
- [C P] when [arg] is [Some P]
*)
| PatConstraint of pattern * ty (** [(P : T)] *)
[@@deriving show {with_path= false}]

(* ======= Expressions ======= *)
type rec_flag =
| Recursive (** Recursive value binding *)
| Nonrecursive (** Nonrecursive value binding *)
[@@deriving show {with_path= false}]

type value_binding = {pat: pattern; expr: expression}
[@@deriving show {with_path= false}]

(** Pattern matching case *)
and case = {left: pattern; right: expression}
[@@deriving show {with_path= false}]

and expression =
| ExpIdent of ident (** Identifiers such as [x], [fact] *)
| ExpConst of constant
(** Expression constant such as [1], ['a'], ["hello"], [1.5] *)
| ExpLet of rec_flag * value_binding list1 * expression
(** [ExpLet(flag, [(P1,E1) ; ... ; (Pn,En)], E)] represents:
- [let P1 = E1 and ... and Pn = EN in E] when [flag] is [Nonrecursive]
- [let rec P1 = E1 and ... and Pn = EN in E] when [flag] is [Recursive]
| Constraint of t * Ty.t (** [(P : T)] *)
[@@deriving show {with_path= false}]
end

module Expr = struct
type rec_flag = Rec | Nonrec [@@deriving show {with_path= false}]

type value_binding = {pat: Pat.t; expr: t}
[@@deriving show {with_path= false}]

(** Pattern matching case *)
and case = {left: Pat.t; right: t} [@@deriving show {with_path= false}]

and t =
| Id of Id.t (** Identifiers such as [x], [fact] *)
| Const of Const.t
(** Expression constant such as [1], ['a'], ["hello"], [1.5] *)
| Let of rec_flag * value_binding List1.t * t
(** [Let(flag, [(P1,E1) ; ... ; (Pn,En)], E)] represents:
- [let P1 = E1 and ... and Pn = EN in E] when [flag] is [Nonrec]
- [let rec P1 = E1 and ... and Pn = EN in E] when [flag] is [Rec]
*)
| ExpFun of pattern list1 * expression (** [fun P1 ... Pn -> E] *)
| ExpFunction of case list1 (** [function C1 | ... | Cn] *)
| ExpApply of expression * expression (** [E1 E2] *)
| ExpMatch of expression * case list1
(** [match E with P1 -> E1 | ... | Pn -> En] *)
| ExpTuple of expression list2 (** [(E1, ..., En)] *)
| ExpConstruct of ident * expression option
(** [ExpConstruct(C, exp)] represents:
| Fun of Pat.t List1.t * t (** [fun P1 ... Pn -> E] *)
| Function of case List1.t (** [function C1 | ... | Cn] *)
| Apply of t * t (** [E1 E2] *)
| Match of t * case List1.t (** [match E with P1 -> E1 | ... | Pn -> En] *)
| Tuple of t List2.t (** [(E1, ..., En)] *)
| Construct of Id.t * t option
(** [Construct(C, exp)] represents:
- [C] when [exp] is [None]
- [C E] when [exp] is [Some E]
- [C (E1, ..., En)] when [exp] is [Some (ExpTuple[E1,...,En])]
- [C (E1, ..., En)] when [exp] is [Some (Tuple[E1,...,En])]
*)
| ExpIf of expression * expression * expression option
(** [if E1 then E2 else E3] *)
| ExpSeq of expression list2 (** [E1; E2] *)
| ExpConstraint of expression * ty (** [(E : T)] *)
[@@deriving show {with_path= false}]

(* ======= Module structure ======= *)

(** Constructor declaration. E.g. [A of string] *)
type constructor_decl = {id: ident; arg: ty option}
[@@deriving show {with_path= false}]

(** Variant type declaration *)
type type_decl = {id: ident; params: ident list; variants: constructor_decl list}
[@@deriving show {with_path= false}]

type structure_item =
| StrEval of expression (** [E] *)
| StrType of type_decl (** [type ('a, 'b) ab = A of T1 | B of T2 ...] *)
| StrLet of rec_flag * value_binding list1
(** [StrLet(flag, [(P1, E1) ; ... ; (Pn, En)])] represents:
- [let P1 = E1 and ... and Pn = EN] when [flag] is [Nonrecursive]
- [let rec P1 = E1 and ... and Pn = EN ] when [flag] is [Recursive]
| If of t * t * t option (** [if E1 then E2 else E3] *)
| Seq of t List2.t (** [E1; E2] *)
| Constraint of t * Ty.t (** [(E : T)] *)
[@@deriving show {with_path= false}]
end

module StrItem = struct
type rec_flag = Expr.rec_flag [@@deriving show {with_path= false}]
type value_binding = Expr.value_binding [@@deriving show {with_path= false}]

(** Constructor declaration. E.g. [A of string] *)
type construct_decl = {id: Id.t; arg: Ty.t option}
[@@deriving show {with_path= false}]

(** Variant type declaration *)
type type_decl = {id: Id.t; params: Id.t list; variants: construct_decl list}
[@@deriving show {with_path= false}]

type t =
| Eval of Expr.t (** [E] *)
| Type of type_decl (** [type ('a, 'b) ab = A of T1 | B of T2 ...] *)
| Let of rec_flag * value_binding List1.t
(** [Let(flag, [(P1, E1) ; ... ; (Pn, En)])] represents:
- [let P1 = E1 and ... and Pn = EN] when [flag] is [Nonrec]
- [let rec P1 = E1 and ... and Pn = EN ] when [flag] is [Rec]
*)
[@@deriving show {with_path= false}]
[@@deriving show {with_path= false}]
end

type structure = structure_item list [@@deriving show {with_path= false}]
type structure = StrItem.t list [@@deriving show {with_path= false}]
2 changes: 1 addition & 1 deletion lib/ast/dune
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
(library
(name LAst)
(public_name NeML.Ast)
(libraries base)
(libraries base LMisc)
(preprocess
(pps ppx_deriving.show))
(instrumentation
Expand Down
36 changes: 36 additions & 0 deletions lib/misc/LMisc.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[@@@ocaml.text "/*"]

(** Copyright 2024, Andrei, PavlushaSource *)

(** SPDX-License-Identifier: MIT *)

[@@@ocaml.text "/*"]

open! Base

(** Identifiers *)
module Id = struct
type t = I of string [@@deriving show {with_path= false}]
end

(** List containing at least 1 element *)
module List1 = struct
type 'a t = 'a * 'a list [@@deriving show {with_path= false}]

let of_list_exn : 'a list -> 'a t = function
| hd :: tl ->
(hd, tl)
| [] ->
raise (Invalid_argument "empty list")
end

(** List containing at least 2 elements *)
module List2 = struct
type 'a t = 'a * 'a * 'a list [@@deriving show {with_path= false}]

let of_list_exn : 'a list -> 'a t = function
| fst :: snd :: tl ->
(fst, snd, tl)
| _ :: [] | [] ->
raise (Invalid_argument "not enough elements")
end
8 changes: 8 additions & 0 deletions lib/misc/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
(library
(name LMisc)
(public_name NeML.Misc)
(libraries base)
(preprocess
(pps ppx_deriving.show))
(instrumentation
(backend bisect_ppx)))
2 changes: 1 addition & 1 deletion lib/parse/LParse.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
open! Base
open Angstrom

let parse s = parse_string ~consume:All Str.pstr s |> Result.ok
let parse s = parse_string ~consume:All PStr.p s |> Result.ok
Loading

0 comments on commit 945bb72

Please sign in to comment.