From 4abf77140297ba65579eb215605b9e5f2ea9c679 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=BD=D0=B5=D0=B2=D0=B8=D0=B4=D0=B8=D0=BC=D0=BA=D0=B0?= Date: Thu, 18 Nov 2021 23:01:10 +0300 Subject: [PATCH] Derivation object composition --- .../core/src/main/scala/derevo/Derevo.scala | 193 ++++++++++-------- .../core/src/main/scala/derevo/package.scala | 4 + .../test/scala/derevo/CompositionSuite.scala | 23 +++ 3 files changed, 139 insertions(+), 81 deletions(-) create mode 100644 modules/core/src/test/scala/derevo/CompositionSuite.scala diff --git a/modules/core/src/main/scala/derevo/Derevo.scala b/modules/core/src/main/scala/derevo/Derevo.scala index 9e8a6efd..1cfecba0 100644 --- a/modules/core/src/main/scala/derevo/Derevo.scala +++ b/modules/core/src/main/scala/derevo/Derevo.scala @@ -7,8 +7,9 @@ class Derevo(val c: blackbox.Context) { import c.universe._ type Newtype = NewtypeP[Tree] - type NameAndTypes = NameAndTypesP[c.Type] + type NameAndTypes = NameAndTypesP[Tree, c.Type] + val CompositeSymbol = typeOf[composite].typeSymbol val DelegatingSymbol = typeOf[delegating].typeSymbol val PhantomSymbol = typeOf[phantom].typeSymbol val PassTypeArgsSymbol = typeOf[PassTypeArgs].typeSymbol @@ -19,6 +20,25 @@ class Derevo(val c: blackbox.Context) { val IsSpecificDerivation = isInstanceDef[SpecificDerivation[Any, Any, Any]]() val IsDerivation = isInstanceDef[Derivation[Any]]() + object IsCompositeDerivation { + def unapply(objType: Type): Option[List[(Type, Type, Type, Int, Boolean, Tree)]] = + objType.typeSymbol.annotations.map(_.tree).collectFirst { + case q"new $comp(..$args)" if comp.symbol == CompositeSymbol => + args.map(arg => arg -> c.typecheck(extractObj(arg)).tpe).flatMap { + case (_, IsCompositeDerivation(subs)) => subs + case (arg, IsSpecificDerivation(f, t, nt, d)) => List((f, t, nt, d, true, arg)) + case (arg, IsDerivation(f, t, nt, d)) => List((f, t, nt, d, true, arg)) + case (arg, argTpe @ HKDerivation(f, t, nt, d)) => + argTpe match { + case ParamRequire(fr, _, _, _) => List((fr, t, nt, d, true, arg)) + case _ => List((f, t, nt, d, false, arg)) + } + + case (arg, _) => abort(s"$arg seems not extending InstanceDef traits") + } + } + } + val HKDerivation = new DerivationList( isInstanceDef[DerivationKN1[Any]](1), isInstanceDef[DerivationKN2[Any]](2), @@ -164,14 +184,14 @@ class Derevo(val c: blackbox.Context) { } c.prefix.tree match { case q"new $_(..${instances})" => - instances.map(buildInstance(_, cls, newType)) + instances.flatMap(buildInstance(_, cls, newType)) case _ => c.error(c.prefix.tree.pos, s"FIXME: Could not match annotation tree `${c.prefix.tree}'") Nil } } - private def buildInstance(tree: Tree, impl: ImplDef, newType: Option[Newtype]): Tree = { + private def buildInstance(tree: Tree, impl: ImplDef, newType: Option[Newtype]): List[Tree] = { val typRef = impl match { case cls: ClassDef => tq"${impl.name.toTypeName}" case obj: ModuleDef => @@ -181,88 +201,80 @@ class Derevo(val c: blackbox.Context) { } } - val (mode, call) = tree match { - case q"$obj.$method(..$args)" => (nameAndTypes(obj), tree) - - case q"$obj(..$args)" => (nameAndTypes(obj), tree) - - case q"$obj" => - val call = newType.fold(q"$obj.instance")(t => q"$obj.newtype[${t.underlying}].instance") - (nameAndTypes(obj), call) - } - - val tn = TermName(mode.name) - val allTparams = impl match { - case cls: ClassDef => cls.tparams - case obj: ModuleDef => Nil - } + nameAndTypes(tree, newType).map { mode => + val tn = TermName(mode.name) + val allTparams = impl match { + case cls: ClassDef => cls.tparams + case obj: ModuleDef => Nil + } - val tparams = allTparams.dropRight(mode.drop) - val pparams = allTparams.takeRight(mode.drop) + val tparams = allTparams.dropRight(mode.drop) + val pparams = allTparams.takeRight(mode.drop) - val tps = tparams.map(_.name) - def appTyp = tq"$typRef[..$tps]" - def allTnames = allTparams.map(_.name) - def lamTyp = tq"({ type Lam[..$pparams] = $typRef[..$allTnames] })#Lam" - val outTyp = if (pparams.isEmpty) appTyp else lamTyp + val tps = tparams.map(_.name) + def appTyp = tq"$typRef[..$tps]" + def allTnames = allTparams.map(_.name) + def lamTyp = tq"({ type Lam[..$pparams] = $typRef[..$allTnames] })#Lam" + val outTyp = if (pparams.isEmpty) appTyp else lamTyp - val resT = mkAppliedType(mode.to, outTyp) + val resT = mkAppliedType(mode.to, outTyp) - val callWithT = if (mode.passArgs) q"$call[$outTyp]" else call + val callWithT = if (mode.passArgs) q"${mode.call}[$outTyp]" else mode.call - def fixFirstTypeParam = { - val nothingT = c.typeOf[Nothing] + def fixFirstTypeParam = { + val nothingT = c.typeOf[Nothing] - c.typecheck(call, silent = true, withMacrosDisabled = true) match { - case q"$method[$nothing, ..$remainingTpes](..$args)" if nothing.tpe == nothingT => - q"$method[$outTyp, ..$remainingTpes](..$args)" - case q"$method[$nothing, ..$remainingTpes]" if nothing.tpe == nothingT => - q"$method[$outTyp, ..$remainingTpes]" - case _ => tree + c.typecheck(mode.call, silent = true, withMacrosDisabled = true) match { + case q"$method[$nothing, ..$remainingTpes](..$args)" if nothing.tpe == nothingT => + q"$method[$outTyp, ..$remainingTpes](..$args)" + case q"$method[$nothing, ..$remainingTpes]" if nothing.tpe == nothingT => + q"$method[$outTyp, ..$remainingTpes]" + case _ => tree + } } - } - if (allTparams.isEmpty || allTparams.length <= mode.drop) { - if (mode.keepRefinements) { - q""" - @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) - implicit val $tn = $fixFirstTypeParam - """ + if (allTparams.isEmpty || allTparams.length <= mode.drop) { + if (mode.keepRefinements) { + q""" + @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) + implicit val $tn = $fixFirstTypeParam + """ + } else { + val resTc = if (newType.isDefined) mode.newtype else mode.to + val resT = mkAppliedType(resTc, tq"$typRef") + + q""" + @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) + implicit val $tn: $resT = $callWithT + """ + } } else { - val resTc = if (newType.isDefined) mode.newtype else mode.to - val resT = mkAppliedType(resTc, tq"$typRef") - q""" - @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) - implicit val $tn: $resT = $callWithT - """ - } - } else { - - val implicits = - if (mode.cascade) - tparams.flatMap { tparam => - val phantom = tparam.mods.annotations.exists { t => c.typecheck(t).tpe.typeSymbol == PhantomSymbol } - if (phantom) None - else { - val name = TermName(c.freshName("ev")) - val typ = tparam.name - val reqT = mkAppliedType(mode.from, tq"$typ") - Some(q"val $name: $reqT") + val implicits = + if (mode.cascade) + tparams.flatMap { tparam => + val phantom = tparam.mods.annotations.exists { t => c.typecheck(t).tpe.typeSymbol == PhantomSymbol } + if (phantom) None + else { + val name = TermName(c.freshName("ev")) + val typ = tparam.name + val reqT = mkAppliedType(mode.from, tq"$typ") + Some(q"val $name: $reqT") + } } - } - else Nil - - if (mode.keepRefinements) { - q""" - @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) - implicit def $tn[..$tparams](implicit ..$implicits) = $fixFirstTypeParam - """ - } else { - q""" - @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) - implicit def $tn[..$tparams](implicit ..$implicits): $resT = $callWithT - """ + else Nil + + if (mode.keepRefinements) { + q""" + @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) + implicit def $tn[..$tparams](implicit ..$implicits) = $fixFirstTypeParam + """ + } else { + q""" + @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all")) + implicit def $tn[..$tparams](implicit ..$implicits): $resT = $callWithT + """ + } } } } @@ -283,19 +295,37 @@ class Derevo(val c: blackbox.Context) { tq"$tc[$arg]" } - private def nameAndTypes(obj: Tree): NameAndTypes = { + private def extractCall(tree: Tree, newType: Option[Newtype]): Tree = tree match { + case q"$obj.$method(..$args)" => tree + case q"$obj(..$args)" => tree + case q"$obj" => newType.fold(q"$obj.instance")(t => q"$obj.newtype[${t.underlying}].instance") + } + + private def extractObj(tree: Tree): Tree = tree match { + case q"$obj.$method(..$args)" => obj + case q"$obj(..$args)" => obj + case q"$obj" => obj + } + + private def nameAndTypes(tree: Tree, newType: Option[Newtype]): List[NameAndTypes] = { + val obj = extractObj(tree) val mangledName = obj.toString.replaceAll("[^\\w]", "_") val name = c.freshName(mangledName) val objTyp = c.typecheck(obj).tpe + val call = extractCall(tree, newType) val nt = objTyp match { - case IsSpecificDerivation(f, t, nt, d) => new NameAndTypes(name, f, t, nt, d, cascade = true) - case IsDerivation(f, t, nt, d) => new NameAndTypes(name, f, t, nt, d, cascade = true) + case IsCompositeDerivation(subs) => + subs.map { case (f, t, nt, d, cascade, tree) => + new NameAndTypes(extractCall(tree, newType), c.freshName(mangledName), f, t, nt, d, cascade) + } + case IsSpecificDerivation(f, t, nt, d) => List(new NameAndTypes(call, name, f, t, nt, d, cascade = true)) + case IsDerivation(f, t, nt, d) => List(new NameAndTypes(call, name, f, t, nt, d, cascade = true)) case HKDerivation(f, t, nt, d) => objTyp match { - case ParamRequire(fr, _, _, _) => new NameAndTypes(name, fr, t, nt, d, cascade = true) - case _ => new NameAndTypes(name, f, t, nt, d, cascade = false) + case ParamRequire(fr, _, _, _) => List(new NameAndTypes(call, name, fr, t, nt, d, cascade = true)) + case _ => List(new NameAndTypes(call, name, f, t, nt, d, cascade = false)) } case _ => abort(s"$obj seems not extending InstanceDef traits") @@ -311,7 +341,7 @@ class Derevo(val c: blackbox.Context) { case _ => false } - nt.copy(passArgs = passArgs, keepRefinements = keepRefinements) + nt.map(_.copy(passArgs = passArgs, keepRefinements = keepRefinements)) } trait DerivationMatcher { @@ -357,7 +387,8 @@ object Derevo { private[Derevo] final case class NewtypeCls[tree](underlying: tree) extends NewtypeP[tree] private[Derevo] final case class NewtypeMod[tree](underlying: tree, res: tree) extends NewtypeP[tree] - private[Derevo] final case class NameAndTypesP[typ]( + private[Derevo] final case class NameAndTypesP[tree, typ]( + call: tree, name: String, from: typ, to: typ, diff --git a/modules/core/src/main/scala/derevo/package.scala b/modules/core/src/main/scala/derevo/package.scala index cffc4b1a..ab787e2f 100644 --- a/modules/core/src/main/scala/derevo/package.scala +++ b/modules/core/src/main/scala/derevo/package.scala @@ -5,6 +5,8 @@ package derevo { def macroTransform(annottees: Any*): Any = macro Derevo.deriveMacro } + class composite(instances: Any*) extends StaticAnnotation + /** */ trait PassTypeArgs trait KeepRefinements @@ -26,6 +28,8 @@ package derevo { trait DerivationKN17[TC[alg[btr[_, _], _, _]]] extends InstanceDef trait SpecificDerivation[FromTC[_], ToTC[_], NT[_]] extends InstanceDef + class CompositeDerivation extends InstanceDef + } package object derevo { diff --git a/modules/core/src/test/scala/derevo/CompositionSuite.scala b/modules/core/src/test/scala/derevo/CompositionSuite.scala new file mode 100644 index 00000000..47c67875 --- /dev/null +++ b/modules/core/src/test/scala/derevo/CompositionSuite.scala @@ -0,0 +1,23 @@ +package derevo + +trait Part1[T] +trait Part2[T] + +object p1 extends Derivation[Part1] { + def instance[T]: Part1[T] = new Part1[T] {} +} + +object p2 extends Derivation[Part2] { + def instance[T]: Part2[T] = new Part2[T] {} +} + +@composite(p1, p2) +object p1AndP2 extends CompositeDerivation + +@derive(p1AndP2) +case class HasP1AndP2() + +object RefinementTest { + val part1: Part1[HasP1AndP2] = implicitly + val part2: Part2[HasP1AndP2] = implicitly +}