From efd7894181187446e32aaa3fcac37af113d6698d Mon Sep 17 00:00:00 2001 From: Joan Goyeau Date: Fri, 14 Aug 2020 15:18:11 -0700 Subject: [PATCH] Ability to relocate/shade in assembly --- build.sc | 6 +- .../com/eed3si9n/jarjarabrams/Shader.scala | 89 +++++++++++++++++++ main/src/modules/Assembly.scala | 79 ++++++++-------- main/src/modules/Jvm.scala | 21 ++--- 4 files changed, 135 insertions(+), 60 deletions(-) create mode 100644 main/src/com/eed3si9n/jarjarabrams/Shader.scala diff --git a/build.sc b/build.sc index 3cbac66b5bf..a2d33c3eea4 100755 --- a/build.sc +++ b/build.sc @@ -74,6 +74,7 @@ object Deps { val utest = ivy"com.lihaoyi::utest:0.7.4" val zinc = ivy"org.scala-sbt::zinc:1.4.0-M1" val bsp = ivy"ch.epfl.scala:bsp4j:2.0.0-M4" + val jarjarabrams = ivy"com.eed3si9n.jarjarabrams::jarjar-abrams-core:0.1.0" } trait MillPublishModule extends PublishModule{ @@ -92,7 +93,7 @@ trait MillPublishModule extends PublishModule{ ) ) - def javacOptions = Seq("-source", "1.8", "-target", "1.8") + def javacOptions = Seq("-source", "1.9", "-target", "1.9") } trait MillApiModule extends MillPublishModule with ScalaModule{ def scalaVersion = T{ "2.13.2" } @@ -167,7 +168,8 @@ object main extends MillModule { // Necessary so we can share the JNA classes throughout the build process Deps.jna, Deps.jnaPlatform, - Deps.coursier + Deps.coursier, + Deps.jarjarabrams ) def generatedSources = T { diff --git a/main/src/com/eed3si9n/jarjarabrams/Shader.scala b/main/src/com/eed3si9n/jarjarabrams/Shader.scala new file mode 100644 index 00000000000..9efaf92c784 --- /dev/null +++ b/main/src/com/eed3si9n/jarjarabrams/Shader.scala @@ -0,0 +1,89 @@ +package com.eed3si9n.jarjarabrams + +import java.io.{ByteArrayInputStream, InputStream} +import java.nio.file.{Files, Path, StandardOpenOption} +import org.pantsbuild.jarjar.{JJProcessor, _} +import org.pantsbuild.jarjar.util.EntryStruct + +object Shaderr { + def shadeDirectory( + rules: Seq[ShadeRule], + dir: Path, + mappings: Seq[(Path, String)], + verbose: Boolean + ): Unit = { + val inputStreams = mappings.filter(x => !Files.isDirectory(x._1)).map(x => Files.newInputStream(x._1) -> x._2) + val result = shadeInputStreams(rules, inputStreams, verbose) + mappings.foreach(f => Files.delete(f._1)) + result.foreach { case (inputStream, mapping) => + val out = dir.resolve(mapping) + if (!Files.exists(out.getParent)) Files.createDirectories(out.getParent) + Files.write(out, inputStream.readAllBytes(), StandardOpenOption.CREATE) + } + } + + def shadeInputStreams( + rules: Seq[ShadeRule], + mappings: Seq[(InputStream, String)], + verbose: Boolean + ): Seq[(InputStream, String)] = { + val jjrules = rules.flatMap { r => + r.shadePattern match { + case ShadePattern.Rename(patterns) => + patterns.map { case (from, to) => + val jrule = new Rule() + jrule.setPattern(from) + jrule.setResult(to) + jrule + } + case ShadePattern.Zap(patterns) => + patterns.map { pattern => + val jrule = new Zap() + jrule.setPattern(pattern) + jrule + } + case ShadePattern.Keep(patterns) => + patterns.map { pattern => + val jrule = new Keep() + jrule.setPattern(pattern) + jrule + } + case _ => Nil + } + } + + val proc = new JJProcessor(jjrules, verbose, true, null) + + /* + jarjar MisplacedClassProcessor class transforms byte[] to a class using org.objectweb.asm.ClassReader.getClassName + which always translates class names containing '.' into '/', regardless of OS platform. + We need to transform any windows file paths in order for jarjar to match them properly and not omit them. + */ + val sanitizedMappings = + mappings.map(f => if (f._2.contains('\\')) (f._1, f._2.replace('\\', '/')) else f) + val shadedInputStreams = sanitizedMappings.flatMap { f => + val entry = new EntryStruct + entry.data = f._1.readAllBytes() + entry.name = f._2 + entry.time = -1 + entry.skipTransform = false + if (proc.process(entry)) Some(new ByteArrayInputStream(entry.data) -> entry.name) + else None + } + val excludes = proc.getExcludes + shadedInputStreams.filterNot(mapping => excludes.contains(mapping._2)) + } +} + +sealed trait ShadePattern { + def inAll: ShadeRule = ShadeRule(this, Vector(ShadeTarget.inAll)) + def inProject: ShadeRule = ShadeRule(this, Vector(ShadeTarget.inProject)) + def inModuleCoordinates(moduleId: ModuleCoordinate*): ShadeRule = + ShadeRule(this, moduleId.toVector map ShadeTarget.inModuleCoordinate) +} + +object ShadePattern { + case class Rename(patterns: List[(String, String)]) extends ShadePattern + case class Zap(patterns: List[String]) extends ShadePattern + case class Keep(patterns: List[String]) extends ShadePattern +} diff --git a/main/src/modules/Assembly.scala b/main/src/modules/Assembly.scala index b2c9c44462b..28b79047322 100644 --- a/main/src/modules/Assembly.scala +++ b/main/src/modules/Assembly.scala @@ -1,12 +1,10 @@ package mill.modules +import com.eed3si9n.jarjarabrams.{ShadePattern, Shaderr} import java.io.InputStream import java.util.jar.JarFile import java.util.regex.Pattern - -import geny.Generator import mill.Agg - import scala.collection.JavaConverters._ object Assembly { @@ -32,6 +30,8 @@ object Assembly { case class Exclude(path: String) extends Rule + case class Relocate(from: String, to: String) extends Rule + object ExcludePattern { def apply(pattern: String): ExcludePattern = ExcludePattern(Pattern.compile(pattern)) } @@ -52,23 +52,19 @@ object Assembly { case Rule.ExcludePattern(pattern) => pattern.asPredicate().test(_) } - classpathIterator(inputPaths).foldLeft(Map.empty[String, GroupedEntry]) { - case (entries, entry) => - val mapping = entry.mapping - + classpathIterator(inputPaths, assemblyRules).foldLeft(Map.empty[String, GroupedEntry]) { + case (entries, (mapping, entry)) => rulesMap.get(mapping) match { case Some(_: Assembly.Rule.Exclude) => entries case Some(a: Assembly.Rule.Append) => - val newEntry = entries.getOrElse(mapping, AppendEntry(Nil, a.separator)).append(entry) + val newEntry = entries.getOrElse(mapping, AppendEntry(Seq.empty, a.separator)).append(entry) entries + (mapping -> newEntry) - case _ if excludePatterns.exists(_(mapping)) => entries case _ if appendPatterns.exists(_(mapping)) => val newEntry = entries.getOrElse(mapping, AppendEntry.empty).append(entry) entries + (mapping -> newEntry) - case _ if !entries.contains(mapping) => entries + (mapping -> WriteOnceEntry(entry)) case _ => @@ -77,52 +73,47 @@ object Assembly { } } - private def classpathIterator(inputPaths: Agg[os.Path]): Generator[AssemblyEntry] = { - Generator.from(inputPaths) + private def classpathIterator(inputPaths: Agg[os.Path], assemblyRules: Seq[Assembly.Rule]): Agg[(String, InputStream)] = { + val shadeRules = assemblyRules.collect { + case Rule.Relocate(from, to) => ShadePattern.Rename(List(from -> to)).inAll + } + + inputPaths .filter(os.exists) - .flatMap { - p => - if (os.isFile(p)) { - val jf = new JarFile(p.toIO) - Generator.from( - for(entry <- jf.entries().asScala if !entry.isDirectory) - yield JarFileEntry(entry.getName, () => jf.getInputStream(entry)) - ) - } - else { - os.walk.stream(p) - .filter(os.isFile) - .map(sub => PathEntry(sub.relativeTo(p).toString, sub)) - } + .flatMap { path => + if (os.isFile(path)) { + val jarFile = new JarFile(path.toIO) + val mappings = jarFile + .entries() + .asScala + .filter(!_.isDirectory) + .map(entry => jarFile.getInputStream(entry) -> entry.getName) + Shaderr.shadeInputStreams(shadeRules, mappings.toSeq, verbose = false) + } + else { + val pathsWithMappings = os + .walk(path) + .filter(os.isFile) + .map(subPath => os.read.inputStream(subPath) -> subPath.relativeTo(path).toString) + Shaderr.shadeInputStreams(shadeRules, pathsWithMappings, verbose = false) + } } + .map { case (inputStream, mapping) => mapping -> inputStream } } } private[modules] sealed trait GroupedEntry { - def append(entry: AssemblyEntry): GroupedEntry + def append(entry: InputStream): GroupedEntry } private[modules] object AppendEntry { val empty: AppendEntry = AppendEntry(Nil, Assembly.defaultSeparator) } -private[modules] case class AppendEntry(entries: List[AssemblyEntry], separator: String) extends GroupedEntry { - def append(entry: AssemblyEntry): GroupedEntry = copy(entries = entry :: this.entries) -} - -private[modules] case class WriteOnceEntry(entry: AssemblyEntry) extends GroupedEntry { - def append(entry: AssemblyEntry): GroupedEntry = this -} - -private[this] sealed trait AssemblyEntry { - def mapping: String - def inputStream: InputStream -} - -private[this] case class PathEntry(mapping: String, path: os.Path) extends AssemblyEntry { - def inputStream: InputStream = os.read.inputStream(path) +private[modules] case class AppendEntry(entries: Seq[InputStream], separator: String) extends GroupedEntry { + def append(entry: InputStream): GroupedEntry = copy(entries = entry +: entries) } -private[this] case class JarFileEntry(mapping: String, getIs: () => InputStream) extends AssemblyEntry { - def inputStream: InputStream = getIs() +private[modules] case class WriteOnceEntry(entry: InputStream) extends GroupedEntry { + def append(entry: InputStream): GroupedEntry = this } diff --git a/main/src/modules/Jvm.scala b/main/src/modules/Jvm.scala index e006b640595..ab13b469847 100644 --- a/main/src/modules/Jvm.scala +++ b/main/src/modules/Jvm.scala @@ -7,19 +7,16 @@ import java.nio.file.{FileSystems, Files, StandardOpenOption} import java.nio.file.attribute.PosixFilePermission import java.util.Collections import java.util.jar.{Attributes, JarEntry, JarFile, JarOutputStream, Manifest} - -import coursier.{Dependency, Fetch, Repository, Resolution} +import coursier.{Dependency, Repository, Resolution} import coursier.util.{Gather, Task} -import geny.Generator import mill.main.client.InputPumper import mill.eval.{PathRef, Result} import mill.util.Ctx import mill.api.IO import mill.api.Loose.Agg - import scala.collection.mutable import scala.collection.JavaConverters._ -import upickle.default.{macroRW, ReadWriter => RW} +import upickle.default.{ReadWriter => RW} object Jvm { /** @@ -299,21 +296,17 @@ object Jvm { val path = zipFs.getPath(mapping).toAbsolutePath val separated = if (entries.isEmpty) Nil - else - entries.head +: entries.tail.flatMap { e => - List(JarFileEntry(e.mapping, () => new ByteArrayInputStream(separator.getBytes)), e) - } - val concatenated = new SequenceInputStream( - Collections.enumeration(separated.map(_.inputStream).asJava)) + else entries.head +: entries.tail.flatMap(e => Seq(new ByteArrayInputStream(separator.getBytes), e)) + val concatenated = new SequenceInputStream(Collections.enumeration(separated.asJava)) writeEntry(path, concatenated, append = true) case (mapping, WriteOnceEntry(entry)) => val path = zipFs.getPath(mapping).toAbsolutePath - writeEntry(path, entry.inputStream, append = false) - } + writeEntry(path, entry, append = false) + } zipFs.close() - val output = ctx.dest / "out.jar" + val output = ctx.dest / "out.jar" // Prepend shell script and make it executable if (prependShellScript.isEmpty) os.move(tmp, output) else{