diff --git a/derivingvia-extras.cabal b/derivingvia-extras.cabal index 7665e6b..6e80297 100644 --- a/derivingvia-extras.cabal +++ b/derivingvia-extras.cabal @@ -37,6 +37,8 @@ source-repository head library exposed-modules: Deriving.On + Deriving.On.Class + Deriving.On.Nth hs-source-dirs: src default-language: Haskell98 build-depends: diff --git a/src/Deriving/On.hs b/src/Deriving/On.hs index 63bb8ed..97e6b55 100644 --- a/src/Deriving/On.hs +++ b/src/Deriving/On.hs @@ -1,5 +1,6 @@ {-# Language DataKinds #-} {-# Language InstanceSigs #-} +{-# Language PolyKinds #-} {-# Language ScopedTypeVariables #-} {-# Language StandaloneKindSignatures #-} {-# Language TypeApplications #-} @@ -8,12 +9,11 @@ module Deriving.On (On(..)) where -import Data.Function (on) -import Data.Hashable (Hashable(..)) -import Data.Kind (Type) -import Data.Ord (comparing) -import GHC.Records (HasField(..)) -import GHC.TypeLits (Symbol) +import Data.Function (on) +import Data.Hashable (Hashable(..)) +import Data.Kind (Type) +import Data.Ord (comparing) +import Deriving.On.Class (OnTarget (..)) -- | With 'DerivingVia': to derive non-structural instances. Specifies -- what field to base instances on. @@ -50,17 +50,18 @@ import GHC.TypeLits (Symbol) -- >> hash alice == hash bob -- True -- @ -type On :: Type -> Symbol -> Type +type On :: forall k. Type -> k -> Type newtype a `On` field = On a -instance (HasField field a b, Eq b) => Eq (a `On` field) where +instance (OnTarget k target a b, Eq b) => Eq (On @k a target) where (==) :: a `On` field -> a `On` field -> Bool - On a1 == On a2 = ((==) `on` getField @field) a1 a2 + On a1 == On a2 = ((==) `on` getTarget @k @target) a1 a2 -instance (HasField field a b, Ord b) => Ord (a `On` field) where - compare :: a `On` field -> a `On` field -> Ordering - On a1 `compare` On a2 = comparing (getField @field) a1 a2 -instance (HasField field a b, Hashable b) => Hashable (a `On` field) where - hashWithSalt :: Int -> a `On` field -> Int - hashWithSalt salt (On a) = hashWithSalt salt (getField @field a) +instance (OnTarget k target a b, Ord b) => Ord (On @k a target) where + compare :: a `On` target -> a `On` target -> Ordering + On a1 `compare` On a2 = comparing (getTarget @k @target) a1 a2 + +instance (OnTarget k target a b, Hashable b) => Hashable (On @k a target) where + hashWithSalt :: Int -> a `On` target -> Int + hashWithSalt salt (On a) = hashWithSalt salt (getTarget @k @target a) diff --git a/src/Deriving/On/Class.hs b/src/Deriving/On/Class.hs new file mode 100644 index 0000000..bd098f0 --- /dev/null +++ b/src/Deriving/On/Class.hs @@ -0,0 +1,29 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE UndecidableInstances #-} + +module Deriving.On.Class where + +import GHC.Records (HasField (getField)) +import GHC.TypeLits (Symbol) +import Data.Kind (Type) + +class OnTarget k (target :: k) a b | k target a -> b where + getTarget :: a -> b + +instance HasField field a b => OnTarget Symbol field a b where + getTarget = getField @field + +instance (OnTarget Type t a b, OnTarget Type u a c) => OnTarget Type (t, u) a (b, c) where + getTarget a = (getTarget @_ @t a, getTarget @_ @u a) + +instance (OnTarget k t a b, OnTarget l u a c) => OnTarget (k, l) '(t, u) a (b, c) where + getTarget a = (getTarget @k @t a, getTarget @l @u a) diff --git a/src/Deriving/On/Nth.hs b/src/Deriving/On/Nth.hs new file mode 100644 index 0000000..75201c4 --- /dev/null +++ b/src/Deriving/On/Nth.hs @@ -0,0 +1,164 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE UndecidableInstances #-} + +module Deriving.On.Nth where + +import Data.Kind (Constraint, Type) +import Deriving.On.Class (OnTarget (..)) +import qualified GHC.Generics as G +import GHC.TypeLits (ErrorMessage (..), KnownNat, Nat, TypeError, type (+), type (-), type (<=?)) + +-- | With 'DerivingVia': to derive non-structural instances with @'On'. +-- @'Nth' specifies what field to base instances on based on its position in +-- a product type. +-- +-- Does not support types with multiple constructors. +-- +-- The type @'On' User (Nth 2)@ is compared and evaluated based only +-- on the third record field, the @"userID"@. This uses GHC Generics to project +-- the relevant component. +-- +-- @ +-- {-# Language DataKinds #-} +-- {-# LANGUAGE DeriveGeneric #-} +-- {-# Language DerivingVia #-} +-- {-# Language TypeOperators #-} +-- +-- import Deriving.On +-- import Deriving.On.Nth +-- import Data.Hashable +-- import GHC.Generics (Generic) +-- +-- data User = User +-- { name :: String +-- , age :: Int +-- , userID :: Integer +-- } +-- deriving Generic +-- deriving (Eq, Ord, Hashable) +-- via User `On` Nth 2 +-- @ +-- +-- @ +-- >> alice = User "Alice" 50 0xDEADBEAF +-- >> bob = User "Bob" 20 0xDEADBEAF +-- >> +-- >> alice == bob +-- True +-- >> alice <= bob +-- True +-- >> hash alice == hash bob +-- True +-- @ +type Nth :: Nat -> Type +data Nth n + +instance + ( G.Generic a, + GHasNth n (G.Rep a) a, + v ~ Value n (G.Rep a) + ) => + OnTarget Type (Nth n) a v + where + getTarget = gGetNth @n @_ @a . G.from + +-- Generically get a value from a product type + +class (KnownNat n) => GHasNth n t originalTypeForErrorReporting where + type Value n t :: Type + gGetNth :: t x -> Value n t + +instance (KnownNat n, GHasNth n constructors original) => GHasNth n (G.D1 meta constructors) original where + type Value n (G.D1 meta constructors) = Value n constructors + gGetNth (G.M1 c) = gGetNth @n @_ @original c + +instance + ( KnownNat n, + v ~ (), + TypeError ('Text "Nth does not work on sum types like" ':$$: 'ShowType original) + ) => + GHasNth n (l G.:+: r) original + where + type Value n (l G.:+: r) = TypeError ('Text "Nth does not work on sum types") + gGetNth = error "Nth does not work on sum types" + +instance + ( KnownNat n, + countFields ~ SelectorSize selectors, + FailIf + (countFields <=? n) + ( 'Text "Specified index Nth " ':<>: 'ShowType n ':<>: 'Text " is too large for type" + ':$$: 'ShowType original + ':$$: 'Text "because it only has " ':<>: 'ShowType countFields ':<>: 'Text " fields." + ), + (GHasNth n selectors original) + ) => + GHasNth n (G.C1 meta selectors) original + where + type Value n (G.C1 meta selectors) = Value n selectors + gGetNth (G.M1 c) = gGetNth @n @_ @original c + +instance + (KnownNat n) => + GHasNth n (G.S1 meta (G.K1 metaK v)) original + where + type Value n (G.S1 meta (G.K1 metaK v)) = v + gGetNth (G.M1 (G.K1 v)) = v + +instance + ( KnownNat n, + GetNowOrLater (PositiveNat n) selectorL selectorR original + ) => + GHasNth n (selectorL G.:*: selectorR) original + where + type Value n (selectorL G.:*: selectorR) = SValue (PositiveNat n) selectorL selectorR + gGetNth (l G.:*: r) = getNowOrLater @_ @(PositiveNat n) @_ @_ @original l r + +type GetNowOrLater :: Maybe Nat -> (k -> Type) -> (k -> Type) -> Type -> Constraint +class GetNowOrLater positiveN now later originalForErrorReporting where + type SValue positiveN now later :: Type + getNowOrLater :: now x -> later x -> SValue positiveN now later + +instance GetNowOrLater 'Nothing (G.S1 meta (G.K1 metaK v)) later originalForErrorReporting where + type SValue 'Nothing (G.S1 meta (G.K1 metaK v)) later = v + getNowOrLater (G.M1 (G.K1 x)) _ = x + +instance + (GHasNth n later originalForErrorReporting) => + GetNowOrLater ('Just n) (G.S1 meta t) later originalForErrorReporting + where + type SValue ('Just n) (G.S1 meta t) later = Value n later + getNowOrLater _ later = gGetNth @n @_ @originalForErrorReporting later + +-- Helper type families + +type family PositiveNat n :: Maybe Nat where + PositiveNat 0 = 'Nothing + PositiveNat n = 'Just (n - 1) + +type family SelectorSize t :: Nat where + SelectorSize (G.S1 _ _) = 1 + SelectorSize (l G.:*: r) = SelectorSize l + SelectorSize r + +type FailIf :: Bool -> ErrorMessage -> Constraint +type family FailIf cond errorMsg where + FailIf 'True errorMsg = () ~ TypeError errorMsg + FailIf _ _ = ()