From 1a410f0a2235d6eae02bc4a006c50ee49683d273 Mon Sep 17 00:00:00 2001 From: Joan Goyeau Date: Thu, 27 Aug 2020 04:30:03 -0700 Subject: [PATCH] Ability to relocate/shade in assembly (#947) This gives the ability to relocate/shade in assembly on a global level. This is a good first step but we might want later to provide the ability to shade only for a given lib. This is currently not possible because at the time we get the class path we lost the information of the dependency tree. Resolves https://github.com/lihaoyi/mill/issues/355 Example of usage: override def assemblyRules = Assembly.defaultRules ++ Seq(Assembly.Rule.Relocate("shapeless.**", "com.netflix.data.playback.shade.shapless.@1")) Pull request: https://github.com/lihaoyi/mill/pull/947 --- build.sc | 4 +- docs/pages/2 - Configuring Mill.md | 5 +- main/src/modules/Assembly.scala | 108 +++++++++--------- main/src/modules/Jvm.scala | 67 +++++------ .../hello-world-deps/core/src/Main.scala | 13 +++ .../resources/hello-world/core/src/Main.scala | 1 - scalalib/test/src/HelloWorldTests.scala | 40 ++++++- 7 files changed, 138 insertions(+), 100 deletions(-) create mode 100644 scalalib/test/resources/hello-world-deps/core/src/Main.scala diff --git a/build.sc b/build.sc index 3cbac66b5bf..a5854e46c7a 100755 --- a/build.sc +++ b/build.sc @@ -74,6 +74,7 @@ object Deps { val utest = ivy"com.lihaoyi::utest:0.7.4" val zinc = ivy"org.scala-sbt::zinc:1.4.0-M1" val bsp = ivy"ch.epfl.scala:bsp4j:2.0.0-M4" + val jarjarabrams = ivy"com.eed3si9n.jarjarabrams::jarjar-abrams-core:0.3.0" } trait MillPublishModule extends PublishModule{ @@ -167,7 +168,8 @@ object main extends MillModule { // Necessary so we can share the JNA classes throughout the build process Deps.jna, Deps.jnaPlatform, - Deps.coursier + Deps.coursier, + Deps.jarjarabrams ) def generatedSources = T { diff --git a/docs/pages/2 - Configuring Mill.md b/docs/pages/2 - Configuring Mill.md index 082c0aa9f3d..143657080b9 100644 --- a/docs/pages/2 - Configuring Mill.md +++ b/docs/pages/2 - Configuring Mill.md @@ -513,7 +513,7 @@ compilation output, but if there is more than one or the main class comes from some library you can explicitly specify which one to use. This also adds the main class to your `foo.jar` and `foo.assembly` jars. -## Merge/exclude files from assembly +## Merge/exclude/relocate files from assembly When you make a runnable jar of your project with `assembly` command, you may want to exclude some files from a final jar (like signature files, and manifest files from library jars), @@ -532,7 +532,8 @@ object foo extends ScalaModule { def assemblyRules = Seq( Rule.Append("application.conf"), // all application.conf files will be concatenated into single file Rule.AppendPattern(".*\\.conf"), // all *.conf files will be concatenated into single file - Rule.ExcludePattern("*.temp") // all *.temp files will be excluded from a final jar + Rule.ExcludePattern("*.temp"), // all *.temp files will be excluded from a final jar + Rule.Relocate("shapeless.**", "shade.shapless.@1") // the `shapeless` package will be shaded under the `shade` package ) } ``` diff --git a/main/src/modules/Assembly.scala b/main/src/modules/Assembly.scala index b2c9c44462b..7e1b9794607 100644 --- a/main/src/modules/Assembly.scala +++ b/main/src/modules/Assembly.scala @@ -1,13 +1,13 @@ package mill.modules -import java.io.InputStream +import com.eed3si9n.jarjarabrams.{ShadePattern, Shader} +import java.io.{ByteArrayInputStream, InputStream} import java.util.jar.JarFile import java.util.regex.Pattern - -import geny.Generator import mill.Agg - +import os.Generator import scala.collection.JavaConverters._ +import scala.tools.nsc.io.Streamable object Assembly { @@ -32,13 +32,18 @@ object Assembly { case class Exclude(path: String) extends Rule + case class Relocate(from: String, to: String) extends Rule + object ExcludePattern { def apply(pattern: String): ExcludePattern = ExcludePattern(Pattern.compile(pattern)) } case class ExcludePattern(pattern: Pattern) extends Rule } - def groupAssemblyEntries(inputPaths: Agg[os.Path], assemblyRules: Seq[Assembly.Rule]): Map[String, GroupedEntry] = { + def groupAssemblyEntries( + mappings: Generator[(String, UnopenedInputStream)], + assemblyRules: Seq[Assembly.Rule] + ): Map[String, GroupedEntry] = { val rulesMap = assemblyRules.collect { case r@Rule.Append(path, _) => path -> r case r@Rule.Exclude(path) => path -> r @@ -52,23 +57,19 @@ object Assembly { case Rule.ExcludePattern(pattern) => pattern.asPredicate().test(_) } - classpathIterator(inputPaths).foldLeft(Map.empty[String, GroupedEntry]) { - case (entries, entry) => - val mapping = entry.mapping - + mappings.foldLeft(Map.empty[String, GroupedEntry]) { + case (entries, (mapping, entry)) => rulesMap.get(mapping) match { case Some(_: Assembly.Rule.Exclude) => entries case Some(a: Assembly.Rule.Append) => val newEntry = entries.getOrElse(mapping, AppendEntry(Nil, a.separator)).append(entry) entries + (mapping -> newEntry) - case _ if excludePatterns.exists(_(mapping)) => entries case _ if appendPatterns.exists(_(mapping)) => val newEntry = entries.getOrElse(mapping, AppendEntry.empty).append(entry) entries + (mapping -> newEntry) - case _ if !entries.contains(mapping) => entries + (mapping -> WriteOnceEntry(entry)) case _ => @@ -77,52 +78,51 @@ object Assembly { } } - private def classpathIterator(inputPaths: Agg[os.Path]): Generator[AssemblyEntry] = { - Generator.from(inputPaths) - .filter(os.exists) - .flatMap { - p => - if (os.isFile(p)) { - val jf = new JarFile(p.toIO) - Generator.from( - for(entry <- jf.entries().asScala if !entry.isDirectory) - yield JarFileEntry(entry.getName, () => jf.getInputStream(entry)) - ) - } - else { - os.walk.stream(p) - .filter(os.isFile) - .map(sub => PathEntry(sub.relativeTo(p).toString, sub)) + def loadShadedClasspath( + inputPaths: Agg[os.Path], + assemblyRules: Seq[Assembly.Rule] + ): Generator[(String, UnopenedInputStream)] = { + val shadeRules = assemblyRules.collect { + case Rule.Relocate(from, to) => ShadePattern.Rename(List(from -> to)).inAll + } + val shader = + if (shadeRules.isEmpty) (name: String, inputStream: UnopenedInputStream) => Some(name -> inputStream) + else { + val shader = Shader.bytecodeShader(shadeRules, verbose = false) + (name: String, inputStream: UnopenedInputStream) => + shader(Streamable.bytes(inputStream()), name).map { + case (bytes, name) => + name -> (() => new ByteArrayInputStream(bytes) { override def close(): Unit = inputStream().close() }) } } - } -} - -private[modules] sealed trait GroupedEntry { - def append(entry: AssemblyEntry): GroupedEntry -} - -private[modules] object AppendEntry { - val empty: AppendEntry = AppendEntry(Nil, Assembly.defaultSeparator) -} -private[modules] case class AppendEntry(entries: List[AssemblyEntry], separator: String) extends GroupedEntry { - def append(entry: AssemblyEntry): GroupedEntry = copy(entries = entry :: this.entries) -} - -private[modules] case class WriteOnceEntry(entry: AssemblyEntry) extends GroupedEntry { - def append(entry: AssemblyEntry): GroupedEntry = this -} - -private[this] sealed trait AssemblyEntry { - def mapping: String - def inputStream: InputStream -} + Generator.from(inputPaths).filter(os.exists).flatMap { path => + if (os.isFile(path)) { + val jarFile = new JarFile(path.toIO) + Generator.from(jarFile.entries().asScala.filterNot(_.isDirectory)) + .flatMap(entry => shader(entry.getName, () => jarFile.getInputStream(entry))) + } + else { + os.walk + .stream(path) + .filter(os.isFile) + .flatMap(subPath => shader(subPath.relativeTo(path).toString, () => os.read.inputStream(subPath))) + } + } + } -private[this] case class PathEntry(mapping: String, path: os.Path) extends AssemblyEntry { - def inputStream: InputStream = os.read.inputStream(path) -} + type UnopenedInputStream = () => InputStream -private[this] case class JarFileEntry(mapping: String, getIs: () => InputStream) extends AssemblyEntry { - def inputStream: InputStream = getIs() + private[modules] sealed trait GroupedEntry { + def append(entry: UnopenedInputStream): GroupedEntry + } + private[modules] object AppendEntry { + val empty: AppendEntry = AppendEntry(Nil, defaultSeparator) + } + private[modules] case class AppendEntry(inputStreams: Seq[UnopenedInputStream], separator: String) extends GroupedEntry { + def append(inputStream: UnopenedInputStream): GroupedEntry = copy(inputStreams = inputStreams :+ inputStream) + } + private[modules] case class WriteOnceEntry(inputStream: UnopenedInputStream) extends GroupedEntry { + def append(entry: UnopenedInputStream): GroupedEntry = this + } } diff --git a/main/src/modules/Jvm.scala b/main/src/modules/Jvm.scala index e006b640595..f2ce892cb80 100644 --- a/main/src/modules/Jvm.scala +++ b/main/src/modules/Jvm.scala @@ -5,21 +5,19 @@ import java.lang.reflect.Modifier import java.net.URI import java.nio.file.{FileSystems, Files, StandardOpenOption} import java.nio.file.attribute.PosixFilePermission -import java.util.Collections import java.util.jar.{Attributes, JarEntry, JarFile, JarOutputStream, Manifest} - -import coursier.{Dependency, Fetch, Repository, Resolution} +import coursier.{Dependency, Repository, Resolution} import coursier.util.{Gather, Task} -import geny.Generator +import java.util.Collections import mill.main.client.InputPumper import mill.eval.{PathRef, Result} import mill.util.Ctx import mill.api.IO import mill.api.Loose.Agg - +import mill.modules.Assembly.{AppendEntry, WriteOnceEntry} import scala.collection.mutable -import scala.collection.JavaConverters._ -import upickle.default.{macroRW, ReadWriter => RW} +import scala.jdk.CollectionConverters._ +import upickle.default.{ReadWriter => RW} object Jvm { /** @@ -220,7 +218,6 @@ object Jvm { * be provided for the jar. An optional filter function may also be provided to * selectively include/exclude specific files. * @param inputPaths - `Agg` of `os.Path`s containing files to be included in the jar - * @param mainClass - optional main class for the jar * @param fileFilter - optional file filter to select files to be included. * Given a `os.Path` (from inputPaths) and a `os.RelPath` for the individual file, * return true if the file is to be included in the jar. @@ -270,18 +267,14 @@ object Jvm { base: Option[os.Path] = None, assemblyRules: Seq[Assembly.Rule] = Assembly.defaultRules) (implicit ctx: Ctx.Dest with Ctx.Log): PathRef = { - val tmp = ctx.dest / "out-tmp.jar" val baseUri = "jar:" + tmp.toIO.getCanonicalFile.toURI.toASCIIString - val hm = new java.util.HashMap[String, String]() - - base match{ - case Some(b) => os.copy(b, tmp) - case None => hm.put("create", "true") + val hm = base.fold(Map("create" -> "true")) { b => + os.copy(b, tmp) + Map.empty } - - val zipFs = FileSystems.newFileSystem(URI.create(baseUri), hm) + val zipFs = FileSystems.newFileSystem(URI.create(baseUri), hm.asJava) val manifestPath = zipFs.getPath(JarFile.MANIFEST_NAME) Files.createDirectories(manifestPath.getParent) @@ -293,30 +286,26 @@ object Jvm { manifest.build.write(manifestOut) manifestOut.close() - Assembly.groupAssemblyEntries(inputPaths, assemblyRules).view - .foreach { - case (mapping, AppendEntry(entries, separator)) => - val path = zipFs.getPath(mapping).toAbsolutePath - val separated = - if (entries.isEmpty) Nil - else - entries.head +: entries.tail.flatMap { e => - List(JarFileEntry(e.mapping, () => new ByteArrayInputStream(separator.getBytes)), e) - } - val concatenated = new SequenceInputStream( - Collections.enumeration(separated.map(_.inputStream).asJava)) - writeEntry(path, concatenated, append = true) - case (mapping, WriteOnceEntry(entry)) => - val path = zipFs.getPath(mapping).toAbsolutePath - writeEntry(path, entry.inputStream, append = false) - } - + val mappings = Assembly.loadShadedClasspath(inputPaths, assemblyRules) + Assembly.groupAssemblyEntries(mappings, assemblyRules).foreach { + case (mapping, entry) => + val path = zipFs.getPath(mapping).toAbsolutePath + entry match { + case entry: AppendEntry => + val separated = entry.inputStreams + .flatMap(inputStream => Seq(new ByteArrayInputStream(entry.separator.getBytes), inputStream())) + .drop(1) + val concatenated = new SequenceInputStream(Collections.enumeration(separated.asJava)) + writeEntry(path, concatenated, append = true) + case entry: WriteOnceEntry => writeEntry(path, entry.inputStream(), append = false) + } + } zipFs.close() - val output = ctx.dest / "out.jar" + val output = ctx.dest / "out.jar" // Prepend shell script and make it executable if (prependShellScript.isEmpty) os.move(tmp, output) - else{ + else { val lineSep = if (!prependShellScript.endsWith("\n")) "\n\r\n" else "" os.write(output, prependShellScript + lineSep) os.write.append(output, os.read.inputStream(tmp)) @@ -335,16 +324,16 @@ object Jvm { PathRef(output) } - private def writeEntry(p: java.nio.file.Path, is: InputStream, append: Boolean): Unit = { + private def writeEntry(p: java.nio.file.Path, inputStream: InputStream, append: Boolean): Unit = { if (p.getParent != null) Files.createDirectories(p.getParent) val options = if(append) Seq(StandardOpenOption.APPEND, StandardOpenOption.CREATE) else Seq(StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.CREATE) val outputStream = java.nio.file.Files.newOutputStream(p, options:_*) - IO.stream(is, outputStream) + IO.stream(inputStream, outputStream) outputStream.close() - is.close() + inputStream.close() } def universalScript(shellCommands: String, diff --git a/scalalib/test/resources/hello-world-deps/core/src/Main.scala b/scalalib/test/resources/hello-world-deps/core/src/Main.scala new file mode 100644 index 00000000000..4ac7cc3e330 --- /dev/null +++ b/scalalib/test/resources/hello-world-deps/core/src/Main.scala @@ -0,0 +1,13 @@ +import akka.http.scaladsl.model.{ContentTypes, HttpEntity} +import akka.http.scaladsl.server.Directives._ + +object Main extends App { + val route = + path("hello") { + get { + complete(HttpEntity(ContentTypes.`text/html(UTF-8)`, "

Say hello to akka-http

")) + } + } + + println(route) +} diff --git a/scalalib/test/resources/hello-world/core/src/Main.scala b/scalalib/test/resources/hello-world/core/src/Main.scala index d86f4da8732..63f587e4d4f 100644 --- a/scalalib/test/resources/hello-world/core/src/Main.scala +++ b/scalalib/test/resources/hello-world/core/src/Main.scala @@ -1,4 +1,3 @@ -import scala.collection._ import java.nio.file.{Files, Paths} import java.sql.Date import java.time.LocalDate diff --git a/scalalib/test/src/HelloWorldTests.scala b/scalalib/test/src/HelloWorldTests.scala index e91b24b8be3..4b509465955 100644 --- a/scalalib/test/src/HelloWorldTests.scala +++ b/scalalib/test/src/HelloWorldTests.scala @@ -87,6 +87,14 @@ object HelloWorldTests extends TestSuite { } } + object HelloWorldAkkaHttpRelocate extends HelloBase { + object core extends HelloWorldModuleWithMain { + def ivyDeps = akkaHttpDeps + + def assemblyRules = Seq(Assembly.Rule.Relocate("akka.**", "shaded.akka.@1")) + } + } + object HelloWorldAkkaHttpNoRules extends HelloBase { object core extends HelloWorldModuleWithMain { def ivyDeps = akkaHttpDeps @@ -684,9 +692,7 @@ object HelloWorldTests extends TestSuite { referenceContent.contains("Akka Stream Reference Config File"), // our application config is present too referenceContent.contains("My application Reference Config File"), - referenceContent.contains( - """akka.http.client.user-agent-header="hello-world-client"""" - ) + referenceContent.contains("""akka.http.client.user-agent-header="hello-world-client"""") ) } @@ -767,6 +773,34 @@ object HelloWorldTests extends TestSuite { resourcePath = helloWorldMultiResourcePath ) + def checkRelocate[M <: TestUtil.BaseModule](module: M, + target: Target[PathRef], + resourcePath: os.Path = resourcePath + ) = + workspaceTest(module, resourcePath) { eval => + val Right((result, _)) = eval.apply(target) + + val jarFile = new JarFile(result.path.toIO) + + assert(!jarEntries(jarFile).contains("akka/http/scaladsl/model/HttpEntity.class")) + assert(jarEntries(jarFile).contains("shaded/akka/http/scaladsl/model/HttpEntity.class")) + } + + 'relocate - { + 'withDeps - checkRelocate( + HelloWorldAkkaHttpRelocate, + HelloWorldAkkaHttpRelocate.core.assembly + ) + + 'run - workspaceTest( + HelloWorldAkkaHttpRelocate, + resourcePath = os.pwd / 'scalalib / 'test / 'resources / "hello-world-deps" + ) { eval => + val Right((_, evalCount)) = eval.apply(HelloWorldAkkaHttpRelocate.core.runMain("Main")) + assert(evalCount > 0) + } + } + 'writeDownstreamWhenNoRule - { 'withDeps - workspaceTest(HelloWorldAkkaHttpNoRules) { eval => val Right((result, _)) = eval.apply(HelloWorldAkkaHttpNoRules.core.assembly)