Skip to content

Commit

Permalink
Ability to relocate/shade in assembly
Browse files Browse the repository at this point in the history
  • Loading branch information
joan38 committed Aug 24, 2020
1 parent a967d1a commit 59ec2f4
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 110 deletions.
4 changes: 3 additions & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ object Deps {
val utest = ivy"com.lihaoyi::utest:0.7.4"
val zinc = ivy"org.scala-sbt::zinc:1.4.0-M1"
val bsp = ivy"ch.epfl.scala:bsp4j:2.0.0-M4"
val jarjarabrams = ivy"com.eed3si9n.jarjarabrams::jarjar-abrams-core:0.3.0"
}

trait MillPublishModule extends PublishModule{
Expand Down Expand Up @@ -167,7 +168,8 @@ object main extends MillModule {
// Necessary so we can share the JNA classes throughout the build process
Deps.jna,
Deps.jnaPlatform,
Deps.coursier
Deps.coursier,
Deps.jarjarabrams
)

def generatedSources = T {
Expand Down
5 changes: 3 additions & 2 deletions docs/pages/2 - Configuring Mill.md
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ compilation output, but if there is more than one or the main class comes from
some library you can explicitly specify which one to use. This also adds the
main class to your `foo.jar` and `foo.assembly` jars.

## Merge/exclude files from assembly
## Merge/exclude/relocate files from assembly

When you make a runnable jar of your project with `assembly` command,
you may want to exclude some files from a final jar (like signature files, and manifest files from library jars),
Expand All @@ -532,7 +532,8 @@ object foo extends ScalaModule {
def assemblyRules = Seq(
Rule.Append("application.conf"), // all application.conf files will be concatenated into single file
Rule.AppendPattern(".*\\.conf"), // all *.conf files will be concatenated into single file
Rule.ExcludePattern("*.temp") // all *.temp files will be excluded from a final jar
Rule.ExcludePattern("*.temp"), // all *.temp files will be excluded from a final jar
Rule.Relocate("shapeless.**", "shade.shapless.@1") // the `shapeless` package will be shaded under the `shade` package
)
}
```
Expand Down
110 changes: 51 additions & 59 deletions main/src/modules/Assembly.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package mill.modules

import java.io.InputStream
import com.eed3si9n.jarjarabrams.{ShadePattern, Shader}
import java.io.{ByteArrayInputStream, InputStream, SequenceInputStream}
import java.util.jar.JarFile
import java.util.regex.Pattern

import geny.Generator
import mill.Agg

import os.Generator
import scala.collection.JavaConverters._
import scala.tools.nsc.io.Streamable

object Assembly {

Expand All @@ -32,13 +32,20 @@ object Assembly {

case class Exclude(path: String) extends Rule

case class Relocate(from: String, to: String) extends Rule

object ExcludePattern {
def apply(pattern: String): ExcludePattern = ExcludePattern(Pattern.compile(pattern))
}
case class ExcludePattern(pattern: Pattern) extends Rule
}

def groupAssemblyEntries(inputPaths: Agg[os.Path], assemblyRules: Seq[Assembly.Rule]): Map[String, GroupedEntry] = {
type UnopenedInputStream = () => InputStream

def groupAssemblyEntries(
mappings: Generator[(String, UnopenedInputStream)],
assemblyRules: Seq[Assembly.Rule]
): Map[String, UnopenedInputStream] = {
val rulesMap = assemblyRules.collect {
case r@Rule.Append(path, _) => path -> r
case r@Rule.Exclude(path) => path -> r
Expand All @@ -52,77 +59,62 @@ object Assembly {
case Rule.ExcludePattern(pattern) => pattern.asPredicate().test(_)
}

classpathIterator(inputPaths).foldLeft(Map.empty[String, GroupedEntry]) {
case (entries, entry) =>
val mapping = entry.mapping

mappings.foldLeft(Map.empty[String, UnopenedInputStream]) {
case (entries, (mapping, entry)) =>
rulesMap.get(mapping) match {
case Some(_: Assembly.Rule.Exclude) =>
entries
case Some(a: Assembly.Rule.Append) =>
val newEntry = entries.getOrElse(mapping, AppendEntry(Nil, a.separator)).append(entry)
val newEntry =
entries.get(mapping).fold(entry)(previous => () => concatEntries(previous(), a.separator, entry()))
entries + (mapping -> newEntry)

case _ if excludePatterns.exists(_(mapping)) =>
entries
case _ if appendPatterns.exists(_(mapping)) =>
val newEntry = entries.getOrElse(mapping, AppendEntry.empty).append(entry)
val newEntry =
entries.get(mapping).fold(entry)(previous => () => concatEntries(previous(), defaultSeparator, entry()))
entries + (mapping -> newEntry)

case _ if !entries.contains(mapping) =>
entries + (mapping -> WriteOnceEntry(entry))
entries + (mapping -> entry)
case _ =>
entries
}
}
}

private def classpathIterator(inputPaths: Agg[os.Path]): Generator[AssemblyEntry] = {
Generator.from(inputPaths)
.filter(os.exists)
.flatMap {
p =>
if (os.isFile(p)) {
val jf = new JarFile(p.toIO)
Generator.from(
for(entry <- jf.entries().asScala if !entry.isDirectory)
yield JarFileEntry(entry.getName, () => jf.getInputStream(entry))
)
}
else {
os.walk.stream(p)
.filter(os.isFile)
.map(sub => PathEntry(sub.relativeTo(p).toString, sub))
private def concatEntries(first: InputStream, separator: String, second: InputStream) =
new SequenceInputStream(new SequenceInputStream(first, new ByteArrayInputStream(separator.getBytes)), second)

def loadShadedClasspath(
inputPaths: Agg[os.Path],
assemblyRules: Seq[Assembly.Rule]
): Generator[(String, UnopenedInputStream)] = {
val shadeRules = assemblyRules.collect {
case Rule.Relocate(from, to) => ShadePattern.Rename(List(from -> to)).inAll
}
val shader =
if (shadeRules.isEmpty) (name: String, inputStream: UnopenedInputStream) => Some(name -> inputStream)
else {
val shader = Shader.bytecodeShader(shadeRules, verbose = false)
(name: String, inputStream: UnopenedInputStream) =>
shader(Streamable.bytes(inputStream()), name).map {
case (bytes, name) =>
name -> (() => new ByteArrayInputStream(bytes) { override def close(): Unit = inputStream().close() })
}
}
}
}

private[modules] sealed trait GroupedEntry {
def append(entry: AssemblyEntry): GroupedEntry
}

private[modules] object AppendEntry {
val empty: AppendEntry = AppendEntry(Nil, Assembly.defaultSeparator)
}

private[modules] case class AppendEntry(entries: List[AssemblyEntry], separator: String) extends GroupedEntry {
def append(entry: AssemblyEntry): GroupedEntry = copy(entries = entry :: this.entries)
}

private[modules] case class WriteOnceEntry(entry: AssemblyEntry) extends GroupedEntry {
def append(entry: AssemblyEntry): GroupedEntry = this
}

private[this] sealed trait AssemblyEntry {
def mapping: String
def inputStream: InputStream
}

private[this] case class PathEntry(mapping: String, path: os.Path) extends AssemblyEntry {
def inputStream: InputStream = os.read.inputStream(path)
}

private[this] case class JarFileEntry(mapping: String, getIs: () => InputStream) extends AssemblyEntry {
def inputStream: InputStream = getIs()
Generator.from(inputPaths).filter(os.exists).flatMap { path =>
if (os.isFile(path)) {
val jarFile = new JarFile(path.toIO)
Generator.from(jarFile.entries().asScala.filterNot(_.isDirectory))
.flatMap(entry => shader(entry.getName, () => jarFile.getInputStream(entry)))
}
else {
os.walk
.stream(path)
.filter(os.isFile)
.flatMap(subPath => shader(subPath.relativeTo(path).toString, () => os.read.inputStream(subPath)))
}
}
}
}
58 changes: 16 additions & 42 deletions main/src/modules/Jvm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,17 @@ import java.lang.reflect.Modifier
import java.net.URI
import java.nio.file.{FileSystems, Files, StandardOpenOption}
import java.nio.file.attribute.PosixFilePermission
import java.util.Collections
import java.util.jar.{Attributes, JarEntry, JarFile, JarOutputStream, Manifest}

import coursier.{Dependency, Fetch, Repository, Resolution}
import coursier.{Dependency, Repository, Resolution}
import coursier.util.{Gather, Task}
import geny.Generator
import mill.main.client.InputPumper
import mill.eval.{PathRef, Result}
import mill.util.Ctx
import mill.api.IO
import mill.api.Loose.Agg

import scala.collection.mutable
import scala.collection.JavaConverters._
import upickle.default.{macroRW, ReadWriter => RW}
import scala.jdk.CollectionConverters._
import upickle.default.{ReadWriter => RW}

object Jvm {
/**
Expand Down Expand Up @@ -220,7 +216,6 @@ object Jvm {
* be provided for the jar. An optional filter function may also be provided to
* selectively include/exclude specific files.
* @param inputPaths - `Agg` of `os.Path`s containing files to be included in the jar
* @param mainClass - optional main class for the jar
* @param fileFilter - optional file filter to select files to be included.
* Given a `os.Path` (from inputPaths) and a `os.RelPath` for the individual file,
* return true if the file is to be included in the jar.
Expand Down Expand Up @@ -267,19 +262,12 @@ object Jvm {
def createAssembly(inputPaths: Agg[os.Path],
manifest: JarManifest = JarManifest.Default,
prependShellScript: String = "",
base: Option[os.Path] = None,
assemblyRules: Seq[Assembly.Rule] = Assembly.defaultRules)
(implicit ctx: Ctx.Dest with Ctx.Log): PathRef = {

val tmp = ctx.dest / "out-tmp.jar"

val baseUri = "jar:" + tmp.toIO.getCanonicalFile.toURI.toASCIIString
val hm = new java.util.HashMap[String, String]()

base match{
case Some(b) => os.copy(b, tmp)
case None => hm.put("create", "true")
}
val hm = Map("create" -> "true").asJava

val zipFs = FileSystems.newFileSystem(URI.create(baseUri), hm)

Expand All @@ -293,30 +281,18 @@ object Jvm {
manifest.build.write(manifestOut)
manifestOut.close()

Assembly.groupAssemblyEntries(inputPaths, assemblyRules).view
.foreach {
case (mapping, AppendEntry(entries, separator)) =>
val path = zipFs.getPath(mapping).toAbsolutePath
val separated =
if (entries.isEmpty) Nil
else
entries.head +: entries.tail.flatMap { e =>
List(JarFileEntry(e.mapping, () => new ByteArrayInputStream(separator.getBytes)), e)
}
val concatenated = new SequenceInputStream(
Collections.enumeration(separated.map(_.inputStream).asJava))
writeEntry(path, concatenated, append = true)
case (mapping, WriteOnceEntry(entry)) =>
val path = zipFs.getPath(mapping).toAbsolutePath
writeEntry(path, entry.inputStream, append = false)
}

val mappings = Assembly.loadShadedClasspath(inputPaths, assemblyRules)
Assembly.groupAssemblyEntries(mappings, assemblyRules).foreach {
case (mapping, inputStream) =>
val path = zipFs.getPath(mapping).toAbsolutePath
writeEntry(path, inputStream())
}
zipFs.close()
val output = ctx.dest / "out.jar"

val output = ctx.dest / "out.jar"
// Prepend shell script and make it executable
if (prependShellScript.isEmpty) os.move(tmp, output)
else{
else {
val lineSep = if (!prependShellScript.endsWith("\n")) "\n\r\n" else ""
os.write(output, prependShellScript + lineSep)
os.write.append(output, os.read.inputStream(tmp))
Expand All @@ -335,16 +311,14 @@ object Jvm {
PathRef(output)
}

private def writeEntry(p: java.nio.file.Path, is: InputStream, append: Boolean): Unit = {
private def writeEntry(p: java.nio.file.Path, inputStream: InputStream): Unit = {
if (p.getParent != null) Files.createDirectories(p.getParent)
val options =
if(append) Seq(StandardOpenOption.APPEND, StandardOpenOption.CREATE)
else Seq(StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.CREATE)
val options = Seq(StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.CREATE)

val outputStream = java.nio.file.Files.newOutputStream(p, options:_*)
IO.stream(is, outputStream)
IO.stream(inputStream, outputStream)
outputStream.close()
is.close()
inputStream.close()
}

def universalScript(shellCommands: String,
Expand Down
3 changes: 1 addition & 2 deletions scalalib/src/JavaModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,9 @@ trait JavaModule extends mill.Module
*/
def assembly = T{
createAssembly(
Agg.from(localClasspath().map(_.path)),
Agg.from(localClasspath().map(_.path)) ++ upstreamAssemblyClasspath().map(_.path),
manifest(),
prependShellScript(),
Some(upstreamAssembly().path),
assemblyRules
)
}
Expand Down
13 changes: 13 additions & 0 deletions scalalib/test/resources/hello-world-deps/core/src/Main.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import akka.http.scaladsl.model.{ContentTypes, HttpEntity}
import akka.http.scaladsl.server.Directives._

object Main extends App {
val route =
path("hello") {
get {
complete(HttpEntity(ContentTypes.`text/html(UTF-8)`, "<h1>Say hello to akka-http</h1>"))
}
}

println(route)
}
1 change: 0 additions & 1 deletion scalalib/test/resources/hello-world/core/src/Main.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import scala.collection._
import java.nio.file.{Files, Paths}
import java.sql.Date
import java.time.LocalDate
Expand Down
40 changes: 37 additions & 3 deletions scalalib/test/src/HelloWorldTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ object HelloWorldTests extends TestSuite {
}
}

object HelloWorldAkkaHttpRelocate extends HelloBase {
object core extends HelloWorldModuleWithMain {
def ivyDeps = akkaHttpDeps

def assemblyRules = Seq(Assembly.Rule.Relocate("akka.**", "shaded.akka.@1"))
}
}

object HelloWorldAkkaHttpNoRules extends HelloBase {
object core extends HelloWorldModuleWithMain {
def ivyDeps = akkaHttpDeps
Expand Down Expand Up @@ -684,9 +692,7 @@ object HelloWorldTests extends TestSuite {
referenceContent.contains("Akka Stream Reference Config File"),
// our application config is present too
referenceContent.contains("My application Reference Config File"),
referenceContent.contains(
"""akka.http.client.user-agent-header="hello-world-client""""
)
referenceContent.contains("""akka.http.client.user-agent-header="hello-world-client"""")
)
}

Expand Down Expand Up @@ -767,6 +773,34 @@ object HelloWorldTests extends TestSuite {
resourcePath = helloWorldMultiResourcePath
)

def checkRelocate[M <: TestUtil.BaseModule](module: M,
target: Target[PathRef],
resourcePath: os.Path = resourcePath
) =
workspaceTest(module, resourcePath) { eval =>
val Right((result, _)) = eval.apply(target)

val jarFile = new JarFile(result.path.toIO)

assert(!jarEntries(jarFile).contains("akka/http/scaladsl/model/HttpEntity.class"))
assert(jarEntries(jarFile).contains("shaded/akka/http/scaladsl/model/HttpEntity.class"))
}

'relocate - {
'withDeps - checkRelocate(
HelloWorldAkkaHttpRelocate,
HelloWorldAkkaHttpRelocate.core.assembly
)

'run - workspaceTest(
HelloWorldAkkaHttpRelocate,
resourcePath = os.pwd / 'scalalib / 'test / 'resources / "hello-world-deps"
) { eval =>
val Right((_, evalCount)) = eval.apply(HelloWorldAkkaHttpRelocate.core.runMain("Main"))
assert(evalCount > 0)
}
}

'writeDownstreamWhenNoRule - {
'withDeps - workspaceTest(HelloWorldAkkaHttpNoRules) { eval =>
val Right((result, _)) = eval.apply(HelloWorldAkkaHttpNoRules.core.assembly)
Expand Down

0 comments on commit 59ec2f4

Please sign in to comment.