Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derivation object composition #333

Merged
merged 1 commit into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 112 additions & 81 deletions modules/core/src/main/scala/derevo/Derevo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ class Derevo(val c: blackbox.Context) {
import c.universe._

type Newtype = NewtypeP[Tree]
type NameAndTypes = NameAndTypesP[c.Type]
type NameAndTypes = NameAndTypesP[Tree, c.Type]

val CompositeSymbol = typeOf[composite].typeSymbol
val DelegatingSymbol = typeOf[delegating].typeSymbol
val PhantomSymbol = typeOf[phantom].typeSymbol
val PassTypeArgsSymbol = typeOf[PassTypeArgs].typeSymbol
Expand All @@ -19,6 +20,25 @@ class Derevo(val c: blackbox.Context) {
val IsSpecificDerivation = isInstanceDef[SpecificDerivation[Any, Any, Any]]()
val IsDerivation = isInstanceDef[Derivation[Any]]()

object IsCompositeDerivation {
def unapply(objType: Type): Option[List[(Type, Type, Type, Int, Boolean, Tree)]] =
objType.typeSymbol.annotations.map(_.tree).collectFirst {
case q"new $comp(..$args)" if comp.symbol == CompositeSymbol =>
args.map(arg => arg -> c.typecheck(extractObj(arg)).tpe).flatMap {
case (_, IsCompositeDerivation(subs)) => subs
case (arg, IsSpecificDerivation(f, t, nt, d)) => List((f, t, nt, d, true, arg))
case (arg, IsDerivation(f, t, nt, d)) => List((f, t, nt, d, true, arg))
case (arg, argTpe @ HKDerivation(f, t, nt, d)) =>
argTpe match {
case ParamRequire(fr, _, _, _) => List((fr, t, nt, d, true, arg))
case _ => List((f, t, nt, d, false, arg))
}

case (arg, _) => abort(s"$arg seems not extending InstanceDef traits")
}
}
}

val HKDerivation = new DerivationList(
isInstanceDef[DerivationKN1[Any]](1),
isInstanceDef[DerivationKN2[Any]](2),
Expand Down Expand Up @@ -164,14 +184,14 @@ class Derevo(val c: blackbox.Context) {
}
c.prefix.tree match {
case q"new $_(..${instances})" =>
instances.map(buildInstance(_, cls, newType))
instances.flatMap(buildInstance(_, cls, newType))
case _ =>
c.error(c.prefix.tree.pos, s"FIXME: Could not match annotation tree `${c.prefix.tree}'")
Nil
}
}

private def buildInstance(tree: Tree, impl: ImplDef, newType: Option[Newtype]): Tree = {
private def buildInstance(tree: Tree, impl: ImplDef, newType: Option[Newtype]): List[Tree] = {
val typRef = impl match {
case cls: ClassDef => tq"${impl.name.toTypeName}"
case obj: ModuleDef =>
Expand All @@ -181,88 +201,80 @@ class Derevo(val c: blackbox.Context) {
}
}

val (mode, call) = tree match {
case q"$obj.$method(..$args)" => (nameAndTypes(obj), tree)

case q"$obj(..$args)" => (nameAndTypes(obj), tree)

case q"$obj" =>
val call = newType.fold(q"$obj.instance")(t => q"$obj.newtype[${t.underlying}].instance")
(nameAndTypes(obj), call)
}

val tn = TermName(mode.name)
val allTparams = impl match {
case cls: ClassDef => cls.tparams
case obj: ModuleDef => Nil
}
nameAndTypes(tree, newType).map { mode =>
val tn = TermName(mode.name)
val allTparams = impl match {
case cls: ClassDef => cls.tparams
case obj: ModuleDef => Nil
}

val tparams = allTparams.dropRight(mode.drop)
val pparams = allTparams.takeRight(mode.drop)
val tparams = allTparams.dropRight(mode.drop)
val pparams = allTparams.takeRight(mode.drop)

val tps = tparams.map(_.name)
def appTyp = tq"$typRef[..$tps]"
def allTnames = allTparams.map(_.name)
def lamTyp = tq"({ type Lam[..$pparams] = $typRef[..$allTnames] })#Lam"
val outTyp = if (pparams.isEmpty) appTyp else lamTyp
val tps = tparams.map(_.name)
def appTyp = tq"$typRef[..$tps]"
def allTnames = allTparams.map(_.name)
def lamTyp = tq"({ type Lam[..$pparams] = $typRef[..$allTnames] })#Lam"
val outTyp = if (pparams.isEmpty) appTyp else lamTyp

val resT = mkAppliedType(mode.to, outTyp)
val resT = mkAppliedType(mode.to, outTyp)

val callWithT = if (mode.passArgs) q"$call[$outTyp]" else call
val callWithT = if (mode.passArgs) q"${mode.call}[$outTyp]" else mode.call

def fixFirstTypeParam = {
val nothingT = c.typeOf[Nothing]
def fixFirstTypeParam = {
val nothingT = c.typeOf[Nothing]

c.typecheck(call, silent = true, withMacrosDisabled = true) match {
case q"$method[$nothing, ..$remainingTpes](..$args)" if nothing.tpe == nothingT =>
q"$method[$outTyp, ..$remainingTpes](..$args)"
case q"$method[$nothing, ..$remainingTpes]" if nothing.tpe == nothingT =>
q"$method[$outTyp, ..$remainingTpes]"
case _ => tree
c.typecheck(mode.call, silent = true, withMacrosDisabled = true) match {
case q"$method[$nothing, ..$remainingTpes](..$args)" if nothing.tpe == nothingT =>
q"$method[$outTyp, ..$remainingTpes](..$args)"
case q"$method[$nothing, ..$remainingTpes]" if nothing.tpe == nothingT =>
q"$method[$outTyp, ..$remainingTpes]"
case _ => tree
}
}
}

if (allTparams.isEmpty || allTparams.length <= mode.drop) {
if (mode.keepRefinements) {
q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit val $tn = $fixFirstTypeParam
"""
if (allTparams.isEmpty || allTparams.length <= mode.drop) {
if (mode.keepRefinements) {
q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit val $tn = $fixFirstTypeParam
"""
} else {
val resTc = if (newType.isDefined) mode.newtype else mode.to
val resT = mkAppliedType(resTc, tq"$typRef")

q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit val $tn: $resT = $callWithT
"""
}
} else {
val resTc = if (newType.isDefined) mode.newtype else mode.to
val resT = mkAppliedType(resTc, tq"$typRef")

q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit val $tn: $resT = $callWithT
"""
}
} else {

val implicits =
if (mode.cascade)
tparams.flatMap { tparam =>
val phantom = tparam.mods.annotations.exists { t => c.typecheck(t).tpe.typeSymbol == PhantomSymbol }
if (phantom) None
else {
val name = TermName(c.freshName("ev"))
val typ = tparam.name
val reqT = mkAppliedType(mode.from, tq"$typ")
Some(q"val $name: $reqT")
val implicits =
if (mode.cascade)
tparams.flatMap { tparam =>
val phantom = tparam.mods.annotations.exists { t => c.typecheck(t).tpe.typeSymbol == PhantomSymbol }
if (phantom) None
else {
val name = TermName(c.freshName("ev"))
val typ = tparam.name
val reqT = mkAppliedType(mode.from, tq"$typ")
Some(q"val $name: $reqT")
}
}
}
else Nil

if (mode.keepRefinements) {
q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit def $tn[..$tparams](implicit ..$implicits) = $fixFirstTypeParam
"""
} else {
q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit def $tn[..$tparams](implicit ..$implicits): $resT = $callWithT
"""
else Nil

if (mode.keepRefinements) {
q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit def $tn[..$tparams](implicit ..$implicits) = $fixFirstTypeParam
"""
} else {
q"""
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
implicit def $tn[..$tparams](implicit ..$implicits): $resT = $callWithT
"""
}
}
}
}
Expand All @@ -283,19 +295,37 @@ class Derevo(val c: blackbox.Context) {
tq"$tc[$arg]"
}

private def nameAndTypes(obj: Tree): NameAndTypes = {
private def extractCall(tree: Tree, newType: Option[Newtype]): Tree = tree match {
case q"$obj.$method(..$args)" => tree
case q"$obj(..$args)" => tree
case q"$obj" => newType.fold(q"$obj.instance")(t => q"$obj.newtype[${t.underlying}].instance")
}

private def extractObj(tree: Tree): Tree = tree match {
case q"$obj.$method(..$args)" => obj
case q"$obj(..$args)" => obj
case q"$obj" => obj
}

private def nameAndTypes(tree: Tree, newType: Option[Newtype]): List[NameAndTypes] = {
val obj = extractObj(tree)
val mangledName = obj.toString.replaceAll("[^\\w]", "_")
val name = c.freshName(mangledName)

val objTyp = c.typecheck(obj).tpe
val call = extractCall(tree, newType)

val nt = objTyp match {
case IsSpecificDerivation(f, t, nt, d) => new NameAndTypes(name, f, t, nt, d, cascade = true)
case IsDerivation(f, t, nt, d) => new NameAndTypes(name, f, t, nt, d, cascade = true)
case IsCompositeDerivation(subs) =>
subs.map { case (f, t, nt, d, cascade, tree) =>
new NameAndTypes(extractCall(tree, newType), c.freshName(mangledName), f, t, nt, d, cascade)
}
case IsSpecificDerivation(f, t, nt, d) => List(new NameAndTypes(call, name, f, t, nt, d, cascade = true))
case IsDerivation(f, t, nt, d) => List(new NameAndTypes(call, name, f, t, nt, d, cascade = true))
case HKDerivation(f, t, nt, d) =>
objTyp match {
case ParamRequire(fr, _, _, _) => new NameAndTypes(name, fr, t, nt, d, cascade = true)
case _ => new NameAndTypes(name, f, t, nt, d, cascade = false)
case ParamRequire(fr, _, _, _) => List(new NameAndTypes(call, name, fr, t, nt, d, cascade = true))
case _ => List(new NameAndTypes(call, name, f, t, nt, d, cascade = false))
}

case _ => abort(s"$obj seems not extending InstanceDef traits")
Expand All @@ -311,7 +341,7 @@ class Derevo(val c: blackbox.Context) {
case _ => false
}

nt.copy(passArgs = passArgs, keepRefinements = keepRefinements)
nt.map(_.copy(passArgs = passArgs, keepRefinements = keepRefinements))
}

trait DerivationMatcher {
Expand Down Expand Up @@ -357,7 +387,8 @@ object Derevo {
private[Derevo] final case class NewtypeCls[tree](underlying: tree) extends NewtypeP[tree]
private[Derevo] final case class NewtypeMod[tree](underlying: tree, res: tree) extends NewtypeP[tree]

private[Derevo] final case class NameAndTypesP[typ](
private[Derevo] final case class NameAndTypesP[tree, typ](
call: tree,
name: String,
from: typ,
to: typ,
Expand Down
4 changes: 4 additions & 0 deletions modules/core/src/main/scala/derevo/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package derevo {
def macroTransform(annottees: Any*): Any = macro Derevo.deriveMacro
}

class composite(instances: Any*) extends StaticAnnotation

/** */
trait PassTypeArgs
trait KeepRefinements
Expand All @@ -26,6 +28,8 @@ package derevo {
trait DerivationKN17[TC[alg[btr[_, _], _, _]]] extends InstanceDef
trait SpecificDerivation[FromTC[_], ToTC[_], NT[_]] extends InstanceDef

class CompositeDerivation extends InstanceDef

}

package object derevo {
Expand Down
23 changes: 23 additions & 0 deletions modules/core/src/test/scala/derevo/CompositionSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package derevo

trait Part1[T]
trait Part2[T]

object p1 extends Derivation[Part1] {
def instance[T]: Part1[T] = new Part1[T] {}
}

object p2 extends Derivation[Part2] {
def instance[T]: Part2[T] = new Part2[T] {}
}

@composite(p1, p2)
object p1AndP2 extends CompositeDerivation

@derive(p1AndP2)
case class HasP1AndP2()

object RefinementTest {
val part1: Part1[HasP1AndP2] = implicitly
val part2: Part2[HasP1AndP2] = implicitly
}