
Fun/Trouble with dependent types #709

Closed
tscholak opened this issue Dec 27, 2021 · 6 comments
tscholak commented Dec 27, 2021

Consider:

' Dependent typing test

data T =
  T1
  T2

def f (t:T) : Type =
  case t of
    T1 -> Float
    T2 -> Int

:p f T1
> Float32
:p f T2
> Int32

def g (t:T) : (f t) =
  case t of
    T1 -> 1.0
    T2 -> 2

I'm being presented with:

Compiler bug!
Please report this at github.com/google-research/dex-lang/issues

Lookup of Name GenName "unred" 0 failed
CallStack (from HasCallStack):
  error, called at src/lib/Env.hs:152:14 in dex-0.1.0.0-inplace:Env
  !, called at src/lib/Type.hs:1112:17 in dex-0.1.0.0-inplace:Type

Thus this ticket 😃
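For comparison, the same program does typecheck in GHC if the type-level case in `f` is written as a closed type family and the value-level case in `g` dispatches on a singleton. This is a minimal Haskell sketch (the names `F`, `ST`, `ST1`, `ST2` are hypothetical, not part of the Dex code above); matching on the singleton constructor refines `t`, so each branch may return a different element of `F t`:

```haskell
{-# LANGUAGE DataKinds, GADTs, TypeFamilies #-}

data T = T1 | T2

-- The Dex `f : T -> Type` becomes a closed type family.
type family F (t :: T) where
  F 'T1 = Float
  F 'T2 = Int

-- Singleton connecting the runtime tag to the type-level index.
data ST (t :: T) where
  ST1 :: ST 'T1
  ST2 :: ST 'T2

-- The Dex `g (t:T) : (f t)`: matching ST1 refines t ~ 'T1,
-- so the branch may return a Float; likewise ST2 and Int.
g :: ST t -> F t
g ST1 = 1.0
g ST2 = 2

main :: IO ()
main = do
  print (g ST1)  -- 1.0
  print (g ST2)  -- 2
```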

@tscholak

also:

def h (t:T) (v:f t) : Int =
  0

Compiler bug!
Please report this at github.com/google-research/dex-lang/issues

Lookup of Name GenName "unred" 0 failed
CallStack (from HasCallStack):
  error, called at src/lib/Env.hs:152:14 in dex-0.1.0.0-inplace:Env
  !, called at src/lib/Type.hs:1112:17 in dex-0.1.0.0-inplace:Type


dougalm commented Dec 27, 2021

Thanks for the bug reports! There are two things going on here: (1) there's a scope-escape bug in the compiler, and (2) even with that fixed, we can't handle programs like these.

The compiler is riddled with these sorts of naming bugs, which is why we embarked on an adventure to prevent them by encoding name scopes at the type level in Haskell; see this branch: https://github.com/google-research/dex-lang/tree/safe-names-dev. We're almost ready to merge it; I just need to finish AD and tie up some loose ends.

If you try your programs on the safe-names-dev branch, here's what you get:

data T =
  T1
  T2

def f (t:T) : Type =
  case t of
    T1 -> Float
    T2 -> Int

:p f T1
> Float32
:p f T2
> Int32

def g (t:T) : (f t) =
  case t of
    T1 -> 1.0
    T2 -> 2
> Type error:Can't reduce type expression: (f t)
>
> def g (t:T) : (f t) =
>                ^^^

def h (t:T) (v:f t) : Int =
  0
> Type error:Can't reduce type expression: (f t)
>
> def h (t:T) (v:f t) : Int =
>                ^^^

And that's why we can't handle programs like these. We allow applications in types at the surface level, but they have to be trivially reducible. This lets us use aliases like Fin : Int -> Type, but we can't do anything fancier. We want to improve that, though. At the very least, we want to allow unreduced applications in types for things like functors.
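As a point of comparison, an unreduced application in a signature is exactly what GHC's stuck type families permit. A minimal Haskell sketch (`F`, `ST`, and `h` are hypothetical names mirroring the Dex `f` and `h` above, not anything from this thread): the application `F t` stays unreduced for an unknown `t`, and GHC accepts it anyway.

```haskell
{-# LANGUAGE DataKinds, GADTs, TypeFamilies #-}

data T = T1 | T2

type family F (t :: T) where
  F 'T1 = Float
  F 'T2 = Int

data ST (t :: T) where
  ST1 :: ST 'T1
  ST2 :: ST 'T2

-- GHC accepts the stuck application `F t` in this signature;
-- the analogous Dex `def h (t:T) (v:f t) : Int` is rejected.
h :: ST t -> F t -> Int
h _ _ = 0
```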

@tscholak

Thanks for the prompt reply, @dougalm! It's a bit of a bummer that this doesn't work. I've got the following use case; perhaps you have an idea for an alternative encoding:

' Layer normalization layer

data HasBias =
  WithBias
  WithoutBias

def LayerNorm (hasBias:HasBias) (n:Type) : Type =
  case hasBias of
    WithBias ->
      (n=>Float & n=>Float & Float)
    WithoutBias ->
      (n=>Float & Float)

:p LayerNorm WithBias (Fin 10)
> (((Fin 10) => Float32) & (((Fin 10) => Float32) & Float32))
:p LayerNorm WithoutBias (Fin 10)
> (((Fin 10) => Float32) & Float32)

def layerNorm {hasBias n} : LayerNorm hasBias n -> n=>Float -> n=>Float =
  size' = IToF (size n)
  case hasBias of
    WithBias ->
      \(weight, bias, eps) i.
        mean' = mean i
        std = sqrt (sum for k. sq (i.k - mean') / size' + eps)
        norm = for k. (i.k - mean') / std
        for k. weight.k * norm.k + bias.k
    WithoutBias ->
      \(weight, eps) i.
        mean' = mean i
        std = sqrt (sum for k. sq i.k / size' + eps)
        norm = i / std
        for k. weight.k * norm.k
> Type error:Can't reduce type expression: (LayerNorm hasBias n)
>
> def layerNorm {hasBias n} : LayerNorm hasBias n -> n=>Float -> n=>Float =


dougalm commented Dec 27, 2021

Great example! I can't think of a way to encode this in Dex today other than to use an ordinary ADT for LayerNorm. In Haskell you could at least use a GADT:

data LayerNorm (hasBias::HasBias) n where
  LayerNormWithBias    :: (n=>Float & n=>Float & Float) -> LayerNorm WithBias n
  LayerNormWithoutBias :: (n=>Float & Float)            -> LayerNorm WithoutBias n

But Dex's current ADTs are only halfway to being GADTs. You can have dependence between arguments but each branch still produces the same type LayerNorm hasBias n.
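That GADT does compile in GHC as written, modulo the Dex table types. A minimal self-contained sketch (not from the thread; `n=>Float` replaced by a plain list and `numParams` a hypothetical consumer) shows the refinement in action: matching a constructor fixes the `hasBias` index, so each branch can unpack a different parameter layout.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}

data HasBias = WithBias | WithoutBias

-- The table type `n=>Float` is approximated here by [Float].
data LayerNorm (hasBias :: HasBias) where
  LayerNormWithBias    :: ([Float], [Float], Float) -> LayerNorm 'WithBias
  LayerNormWithoutBias :: ([Float], Float)          -> LayerNorm 'WithoutBias

-- Matching a constructor refines hasBias, so each branch
-- can work with its own parameter layout.
numParams :: LayerNorm hasBias -> Int
numParams (LayerNormWithBias (w, b, _))  = length w + length b + 1
numParams (LayerNormWithoutBias (w, _)) = length w + 1
```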

But even if we fixed that, GADTs don't feel like the perfect solution here. For one thing, you'd like to have some guarantee that all of this case analysis is just compile time stuff and at run time you're working with functions that are fully specialized to the particular configuration of your neural net.

So we should think about what it would take to support your example as written. It'll require changing the way we represent types internally because we'll have to allow applications and/or case expressions in types. And it'll require changing the way we reason about type equality, because we'll have to consider an environment that gets refined as you go under a branch of a case expression and learn that, e.g., hasBias is now known to be WithBias. And there are surely other consequences I'm not thinking of. It'll definitely make the implementation more complicated, and it may hurt compile-time performance, but it might still be worth it. We should at least think through what it would take. Do you mind opening a feature request?

@tscholak

I agree that GADTs don't feel right here, because I've done that before: https://github.com/hasktorch/hasktorch/blob/9560d149b06af17e2d4e73d1e116afd2b0baff86/experimental/gradually-typed/src/Torch/GraduallyTyped/NN/Normalization.hs#L42

I'll open a feature request 😃


apaszke commented Jan 3, 2022

This is superseded by #710, so closing this!
