diff --git a/cur-lib/cur/curnel/racket-impl/stxutils.rkt b/cur-lib/cur/curnel/racket-impl/stxutils.rkt index 714f1276..f95146ef 100644 --- a/cur-lib/cur/curnel/racket-impl/stxutils.rkt +++ b/cur-lib/cur/curnel/racket-impl/stxutils.rkt @@ -8,6 +8,7 @@ syntax/id-set syntax/parse racket/syntax + syntax/stx syntax/parse/experimental/reflect) (provide (all-defined-out)) @@ -45,6 +46,51 @@ (datum->syntax syn (map (lambda (e) (subst v x e bvs)) (attribute e)))] [_ syn])) +(define (datum=? e1 e2) (equal? (syntax->datum e1) (syntax->datum e2))) + +(define (stx=? e1 e2 [id=? free-identifier=?]) + (cond + [(and (identifier? e1) (identifier? e2)) + (id=? e1 e2)] + [(and (number? (syntax-e e1)) (number? (syntax-e e2))) + (= (syntax-e e1) (syntax-e e2))] + [(and (stx-pair? e1) (stx-pair? e2)) + (and + ; short-circuit on length, for performance + (= (length (syntax->list e1)) (length (syntax->list e2))) + (andmap (λ (x y) (stx=? x y id=?)) (syntax->list e1) (syntax->list e2)))] + [else + (syntax-parse (list e1 e2) ; α equiv + ;; XXX: Matches on underlying lambda name... this is breaking abstractions + [(((~datum typed-λ) [x1:id (~datum :) ty1] b1) + ((~datum typed-λ) [x2:id (~datum :) ty2] b2)) + (and (stx=? #'ty1 #'ty2 id=?) + (stx=? #'b1 (subst #'x1 #'x2 #'b2) id=?))])])) + +;; returns e if e \in stx and (datum=? e0 e), else #f +;; (needed by ntac to workaround some scoping issues) +(define (find-in e0 stx) + (syntax-parse stx + [e #:when (stx=? #'e e0 datum=?) #'e] + [(e ...) + (for/first ([e (syntax->list #'(e ...))] + #:when (find-in e0 e)) + (find-in e0 e))] + [_ #f])) + +(define (subst-term v e0 syn [bvs (immutable-free-id-set)]) + (syntax-parse syn + [e + #:when (and (stx=? #'e e0) + (or (not (identifier? #'e)) + (not (free-id-set-member? bvs #'e)))) + v] + [((~and (~datum λ) lam) (z:id : ty) e) + #`(lam (z : #,(subst-term v e0 #'ty bvs)) #,(subst-term v e0 #'e (free-id-set-add bvs #'z)))] + [(e ...) + (datum->syntax syn (map (λ (e1) (subst-term v e0 e1 bvs)) (attribute e)))] + [_ syn])) + ;; takes a list of values and a list of identifiers, in dependency order, and substitutes them into syn. ;; TODO PERF: reverse (define (subst* v-ls x-ls syn) @@ -88,3 +134,6 @@ (lambda ([x #f]) (set! n (add1 n)) (format-id x "~a~a" (or x 'x) n #:source x)))) + +;; remove id v from lst +(define (remove-id v lst) (remove v lst free-identifier=?))