Skip to content

Commit

Permalink
Add a workaround for sbt's classloader to find Router
Browse files Browse the repository at this point in the history
  • Loading branch information
xerial committed Mar 6, 2020
1 parent 4ea4b70 commit 73bc2d8
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 74 deletions.
8 changes: 4 additions & 4 deletions airframe-http/src/main/scala/wvlet/airframe/http/Router.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
*/
package wvlet.airframe.http

import wvlet.airframe.http.router.{ControllerRoute, Route, RouteMatch, RouteMatcher, RouterMacros}
import wvlet.airframe.http.router._
import wvlet.airframe.surface.{MethodSurface, Surface}
import wvlet.log.LogSupport

import scala.annotation.tailrec
import scala.language.higherKinds
import scala.language.experimental.macros
import scala.language.higherKinds

/**
* Router defines mappings from HTTP requests to Routes.
Expand All @@ -43,7 +43,7 @@ case class Router(
localRoutes: Seq[Route] = Seq.empty,
filterSurface: Option[Surface] = None,
filterInstance: Option[HttpFilterType] = None
) {
) extends LogSupport {
def isEmpty = this eq Router.empty

def isLeafFilter = children.isEmpty && localRoutes.isEmpty
Expand Down Expand Up @@ -134,7 +134,7 @@ case class Router(
// Add methods annotated with @Endpoint
val newRoutes =
controllerMethodSurfaces
.map(m => (m, m.findAnnotationOf[Endpoint]))
.map { m => (m, m.findAnnotationOf[Endpoint]) }
.collect {
case (m: ReflectMethodSurface, Some(endPoint)) =>
ControllerRoute(controllerSurface, endPoint.method(), prefixPath + endPoint.path(), m)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ case class DbConfig(
def withUser(user: String): DbConfig =
this.copy(user = Option(user))

def withPassword(password: String): DbConfig =
this.copy(password = Some(password))
def withPassword(password: String): DbConfig = _
this.copy(password = Some(password))

def jdbcDriverName: String = {
driver match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
*/
package wvlet.airframe.surface.reflect

import wvlet.airframe.surface.{AnyRefSurface, MethodParameter, MethodSurface, Surface}
import java.{lang => jl}

import wvlet.airframe.surface.{MethodParameter, MethodSurface, Surface}
import wvlet.log.LogSupport

import scala.util.Try
Expand All @@ -26,20 +26,14 @@ import scala.util.Try
case class ReflectMethodSurface(mod: Int, owner: Surface, name: String, returnType: Surface, args: Seq[MethodParameter])
extends MethodSurface
with LogSupport {
private def findActualMethod(cls: Class[_]): Option[jl.reflect.Method] = {
// For `symbol-based method names`, we need to encode Scala method names into the bytecode format used in class files.
val rawMethodName = scala.reflect.NameTransformer.encode(name)
Try(cls.getDeclaredMethod(rawMethodName, args.map(_.surface.rawType): _*)).toOption
}

private lazy val method: Option[jl.reflect.Method] = findActualMethod(owner.rawType)
private lazy val method: Option[jl.reflect.Method] = ReflectMethodSurface.findMethod(owner.rawType, this)

def getMethod: Option[jl.reflect.Method] = method

override def call(obj: Any, x: Any*): Any = {
val targetMethod: Option[jl.reflect.Method] = method.orElse {
// RefinedTypes may have new methods which cannot be found from the owner
Option(obj).flatMap(objRef => findActualMethod(objRef.getClass))
Option(obj).flatMap(objRef => ReflectMethodSurface.findMethod(objRef.getClass, this))
}

targetMethod match {
Expand All @@ -58,3 +52,13 @@ case class ReflectMethodSurface(mod: Int, owner: Surface, name: String, returnTy
}
}
}

object ReflectMethodSurface {

def findMethod(owner: Class[_], m: MethodSurface): Option[jl.reflect.Method] = {
// For `symbol-based method names`, we need to encode Scala method names into the bytecode format used in class files.
val rawMethodName = scala.reflect.NameTransformer.encode(m.name)
Try(owner.getDeclaredMethod(rawMethodName, m.args.map(_.surface.rawType): _*)).toOption
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,14 @@ object ReflectSurfaceFactory extends LogSupport {
apply(tpe)
}
def ofClass(cls: Class[_]): Surface = {
val tpe = scala.reflect.runtime.currentMirror.classSymbol(cls).toType
ofType(tpe)
val cs = scala.reflect.runtime.currentMirror.classSymbol(cls)
val tpe = cs.toType
ofType(tpe) match {
// Workaround for sbt's layered class loader, which cannot find the original classes using the reflect mirror
case Alias(_, _, AnyRefSurface) if cs.isTrait =>
new GenericSurface(cls)
case other => other
}
}

private def getPrimaryConstructorOf(cls: Class[_]): Option[Constructor[_]] = {
Expand Down Expand Up @@ -134,15 +140,15 @@ object ReflectSurfaceFactory extends LogSupport {

def methodsOf[A: ru.WeakTypeTag]: Seq[MethodSurface] = methodsOfType(implicitly[ru.WeakTypeTag[A]].tpe)

def methodsOfType(tpe: ru.Type): Seq[MethodSurface] = {
def methodsOfType(tpe: ru.Type, cls: Option[Class[_]] = None): Seq[MethodSurface] = {
methodSurfaceCache.getOrElseUpdate(fullTypeNameOf(tpe), {
new SurfaceFinder().createMethodSurfaceOf(tpe)
new SurfaceFinder().createMethodSurfaceOf(tpe, cls)
})
}

def methodsOfClass(cls: Class[_]): Seq[MethodSurface] = {
val tpe = scala.reflect.runtime.currentMirror.classSymbol(cls).toType
methodsOfType(tpe)
methodsOfType(tpe, Some(cls))
}

private[surface] def mirror = ru.runtimeMirror(Thread.currentThread.getContextClassLoader)
Expand Down Expand Up @@ -215,7 +221,7 @@ object ReflectSurfaceFactory extends LogSupport {
m.owner == t.typeSymbol || t.baseClasses.filter(nonObject).exists(_ == m.owner)
}

def createMethodSurfaceOf(targetType: ru.Type): Seq[MethodSurface] = {
def createMethodSurfaceOf(targetType: ru.Type, cls: Option[Class[_]] = None): Seq[MethodSurface] = {
val name = fullTypeNameOf(targetType)
if (methodSurfaceCache.contains(name)) {
methodSurfaceCache(name)
Expand All @@ -236,7 +242,7 @@ object ReflectSurfaceFactory extends LogSupport {
for (m <- localMethods) {
try {
val mod = modifierBitMaskOf(m)
val owner = surfaceOf(targetType)
val owner = cls.map(ofClass(_)).getOrElse(surfaceOf(targetType))
val name = m.name.decodedName.toString
val ret = surfaceOf(m.returnType)
val args = methodParametersOf(targetType, m)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package wvlet.airframe.surface.reflect

import wvlet.airframe.surface.{Surface, SurfaceSpec}
import wvlet.airframe.surface.{SurfaceSpec, secret}

import scala.concurrent.Future
import scala.util.Try
Expand All @@ -34,6 +34,11 @@ object RuntimeExamples {
case class E(a: A)

case class F(p0: Int = 10)

trait TraitOnly {
@secret
def methodInTrait: Unit = {}
}
}

/**
Expand All @@ -54,6 +59,18 @@ class RuntimeSurfaceTest extends SurfaceSpec {
assert(b.isPrimitive == false)
}

def `resolve trait type`: Unit = {
val s = ReflectSurfaceFactory.ofClass(classOf[TraitOnly])
s.isAlias shouldBe false
s.rawType shouldBe classOf[TraitOnly]

val m = ReflectSurfaceFactory.methodsOfClass(classOf[TraitOnly]).head
m.owner.isAlias shouldNotBe true
m.owner.rawType shouldBe classOf[TraitOnly]

m.findAnnotationOf[secret] shouldBe defined
}

def `Find surface from Class[_]` : Unit = {
checkPrimitive(RuntimeSurface.of[Boolean], "Boolean")
checkPrimitive(RuntimeSurface.of[Byte], "Byte")
Expand Down
3 changes: 1 addition & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -963,14 +963,13 @@ lazy val sbtAirframe =
name := "sbt-airframe",
description := "sbt plugin for helping programming with Airframe",
scalaVersion := SCALA_2_12,
crossScalaVersions := Seq(SCALA_2_12),
crossSbtVersions := Vector("1.2.8"),
scriptedLaunchOpts := {
scriptedLaunchOpts.value ++
Seq("-Xmx1024M", "-Dplugin.version=" + version.value)
},
scriptedDependencies := {
// Publish all dependencies for runnign scripted tests
// Publish all dependencies for running the scripted tests
scriptedDependencies.value
publishLocal.all(ScopeFilter(inDependencies(http))).value
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,14 @@ import java.net.URLClassLoader

import sbt.Keys._
import sbt._
import wvlet.airframe.surface.SurfaceFactory
import wvlet.airframe.surface.reflect.{ReflectMethodSurface, ReflectSurfaceFactory}
import wvlet.airframe.sbt.http.HttpPlugin.AirframeHttpKeys
import wvlet.log.LogSupport
import wvlet.log.io.Resource

import scala.util.{Failure, Success, Try}

/**
*
*/
object AirframePlugin extends AutoPlugin with LogSupport {

trait AirframeHttpKeys {
val airframeHttpPackages = settingKey[Seq[String]]("The list of package names containing Airframe HTTP interfaces")
val airframeHttpGenerateClient = taskKey[Seq[File]]("Generate the client code")

}

object autoImport extends AirframeHttpKeys
import autoImport._

Expand All @@ -42,49 +32,18 @@ object AirframePlugin extends AutoPlugin with LogSupport {

override def projectSettings = Seq(
airframeHttpPackages := Seq(),
airframeHttpGenerateClient := {
wvlet.airframe.log.init
airframeHttpRouter := {
val files = (sources in Compile).value
val baseDirs = (sourceDirectories in Compile).value
val classDir = (classDirectory in Runtime).value
val classLoader = new URLClassLoader(Array(classDir.toURI.toURL), getClass.getClassLoader)
findHttpInterface(baseDirs, files, classLoader)
val router = HttpPlugin.buildRouter(baseDirs, files, classLoader)
info(router)
router
},
airframeHttpGenerateClient := {
val router = airframeHttpRouter.value
Seq.empty
}
)

def findHttpInterface(sourceDirs: Seq[File], files: Seq[File], classLoader: ClassLoader): Unit = {
def relativise(f: File): Option[File] = {
sourceDirs.collectFirst { case dir if f.relativeTo(dir).isDefined => f.relativeTo(dir).get }
}
val lst = for (f <- files; r <- relativise(f)) yield r

val classes = lst
.map { f => f.getPath }
.filter(_.endsWith(".scala"))
.map(_.stripSuffix(".scala").replaceAll("/", "."))
.map { clsName =>
info(clsName)
Try(classLoader.loadClass(clsName)) match {
case x if x.isSuccess => x
case f @ Failure(e) =>
warn(e)
f
}
}
.collect {
case Success(cls) =>
info(cls)
cls
}

for (cl <- classes) yield {
val s = ReflectSurfaceFactory.ofClass(cl)
val m = ReflectSurfaceFactory.methodsOfClass(cl)

info(s)
info(m)
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.sbt.http
import sbt.{File, settingKey, taskKey, _}
import wvlet.airframe.http.{Endpoint, Router}
import wvlet.airframe.surface.{MethodSurface, Surface}
import wvlet.log.LogSupport

import scala.util.{Failure, Success, Try}

/**
*
*/
object HttpPlugin extends LogSupport {
wvlet.airframe.log.init

trait AirframeHttpKeys {
val airframeHttpPackages = settingKey[Seq[String]]("A list of package names containing Airframe HTTP interfaces")
val airframeHttpGenerateClient = taskKey[Seq[File]]("Generate the client code")
val airframeHttpRouter = taskKey[Router]("Airframe Router")
}

case class HttpInterface(surface: Surface, endpoints: Seq[HttpEndpoint])
case class HttpEndpoint(endpoint: Endpoint, method: MethodSurface)

/**
* Find Airframe HTTP interfaces and build a Router object
* @param sourceDirs
* @param files
* @param classLoader
*/
def buildRouter(sourceDirs: Seq[File], files: Seq[File], classLoader: ClassLoader): Router = {
def relativise(f: File): Option[File] = {
sourceDirs.collectFirst { case dir if f.relativeTo(dir).isDefined => f.relativeTo(dir).get }
}
val lst = for (f <- files; r <- relativise(f)) yield r

val classes = lst
.map { f => f.getPath }
.filter(_.endsWith(".scala"))
.map(_.stripSuffix(".scala").replaceAll("/", "."))
.map { clsName =>
trace(s"Searching endpoints in ${clsName}")
Try(classLoader.loadClass(clsName)) match {
case x if x.isSuccess => x
case f @ Failure(e) =>
f
}
}
.collect {
case Success(cls) =>
cls
}

var router = Router.empty
for (cl <- classes) yield {
import wvlet.airframe.surface.reflect._
val s = ReflectSurfaceFactory.ofClass(cl)
val methods = ReflectSurfaceFactory.methodsOfClass(cl)
if (methods.exists(_.findAnnotationOf[Endpoint].isDefined)) {
info(s"Adding ${s.fullName} to Router")
router = router.addInternal(s, methods)
}
}
router
}

}

0 comments on commit 73bc2d8

Please sign in to comment.