diff --git a/akka-http-aws/src/main/scala/com/bluelabs/akkaaws/AwsHeaders.scala b/akka-http-aws/src/main/scala/com/bluelabs/akkaaws/AwsHeaders.scala
new file mode 100644
index 0000000..e16c904
--- /dev/null
+++ b/akka-http-aws/src/main/scala/com/bluelabs/akkaaws/AwsHeaders.scala
@@ -0,0 +1,58 @@
+package com.bluelabs.akkaaws
+
+import akka.http.scaladsl.model.headers.{ModeledCustomHeaderCompanion, ModeledCustomHeader}
+
+import scala.util.{Failure, Success, Try}
+
+
+object AwsHeaders {
+
+ sealed abstract class ServerSideEncryptionAlgorithm(val name: String)
+ object ServerSideEncryptionAlgorithm {
+ case object AES256 extends ServerSideEncryptionAlgorithm("AES256")
+ case object KMS extends ServerSideEncryptionAlgorithm("aws:kms")
+
+ def fromString(raw: String): Try[ServerSideEncryptionAlgorithm] = raw match {
+ case "AES256" => Success(AES256)
+ case "aws:kms" => Success(KMS)
+ case invalid => Failure(new IllegalArgumentException(s"$invalid is not a valid server side encryption algorithm."))
+ }
+ }
+
+ object `X-Amz-Server-Side-Encryption` extends ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption`] {
+ override def name: String = "X-Amz-Server-Side-Encryption"
+ override def parse(value: String): Try[`X-Amz-Server-Side-Encryption`] =
+ ServerSideEncryptionAlgorithm.fromString(value).map(new `X-Amz-Server-Side-Encryption`(_))
+ }
+
+ final case class `X-Amz-Server-Side-Encryption`(algorithm: ServerSideEncryptionAlgorithm) extends ModeledCustomHeader[`X-Amz-Server-Side-Encryption`] {
+
+ override def companion: ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption`] = `X-Amz-Server-Side-Encryption`
+
+ override def value(): String = algorithm.name
+
+ override def renderInResponses(): Boolean = true
+
+ override def renderInRequests(): Boolean = true
+ }
+
+ object `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id` extends ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] {
+ override def name: String = "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id"
+ override def parse(value: String): Try[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] =
+ Success(new `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(value))
+ }
+
+ final case class `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(id: String) extends ModeledCustomHeader[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] {
+
+ override def companion: ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`
+
+ override def value(): String = id
+
+ override def renderInResponses(): Boolean = true
+
+ override def renderInRequests(): Boolean = true
+ }
+
+ // TODO add `x-amz-server-side-encryption-context` header.
+
+}
diff --git a/akka-http-aws/src/test/scala/com/bluelabs/akkaaws/AwsHeadersSpec.scala b/akka-http-aws/src/test/scala/com/bluelabs/akkaaws/AwsHeadersSpec.scala
new file mode 100644
index 0000000..39536cc
--- /dev/null
+++ b/akka-http-aws/src/test/scala/com/bluelabs/akkaaws/AwsHeadersSpec.scala
@@ -0,0 +1,47 @@
+package com.bluelabs.akkaaws
+
+import akka.http.scaladsl.model.headers.RawHeader
+import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.AES256
+import com.bluelabs.akkaaws.AwsHeaders.{`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`, `X-Amz-Server-Side-Encryption`, ServerSideEncryptionAlgorithm}
+import org.scalatest.{FlatSpec, Matchers}
+
+import scala.util.{Failure, Success}
+
+
+class AwsHeadersSpec extends FlatSpec with Matchers {
+
+ "ServerSideEncryptionAlgorithm" should "parse AES256" in {
+ ServerSideEncryptionAlgorithm.fromString("AES256") shouldBe Success(ServerSideEncryptionAlgorithm.AES256)
+ }
+
+ it should "parse KMS" in {
+ ServerSideEncryptionAlgorithm.fromString("aws:kms") shouldBe Success(ServerSideEncryptionAlgorithm.KMS)
+ }
+
+ it should "not parse an unsupported algorithm" in {
+ ServerSideEncryptionAlgorithm.fromString("Zip War AirGanon") shouldBe a[Failure[_]]
+ }
+
+ "`X-Amz-Server-Side-Encryption`" should "parse AES256 algorithm" in {
+ val `X-Amz-Server-Side-Encryption`(algorithm) = `X-Amz-Server-Side-Encryption`("AES256")
+ algorithm shouldBe AES256
+ }
+
+ it should "set the X-Amz-Server-Side-Encryption header" in {
+ val RawHeader(key, value) = `X-Amz-Server-Side-Encryption`("AES256")
+ key shouldBe "X-Amz-Server-Side-Encryption"
+ value shouldBe "AES256"
+ }
+
+ "`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`" should "parse kms key id" in {
+ val `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(id) = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("myId")
+ id shouldBe "myId"
+ }
+
+ it should "set the X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id header" in {
+ val RawHeader(key, value) = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("myId")
+ key shouldBe "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id"
+ value shouldBe "myId"
+ }
+
+}
diff --git a/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala b/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala
index e6eab0b..c896a8f 100644
--- a/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala
+++ b/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala
@@ -1,13 +1,16 @@
package com.bluelabs.s3stream
+import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.{KMS, AES256}
+import scala.collection.immutable
import scala.concurrent.{ExecutionContext, Future}
import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._
import akka.http.scaladsl.marshalling.Marshal
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.Uri.Query
-import akka.http.scaladsl.model.headers.Host
+import akka.http.scaladsl.model.headers._
import akka.util.ByteString
+import com.bluelabs.akkaaws.AwsHeaders._
object HttpRequests {
@@ -17,33 +20,49 @@ object HttpRequests {
.withUri(uriFn(requestUri(s3Location)))
}
- def initiateMultipartUploadRequest(s3Location: S3Location): HttpRequest = {
+ def initiateMultipartUploadRequest(s3Location: S3Location, metadata: Metadata): HttpRequest = {
s3Request(s3Location, HttpMethods.POST, _.withQuery(Query("uploads")))
+ .withEntity(metadata.contentType, ByteString.empty)
+ .mapHeaders(_ ++ metadataHeaders(metadata))
}
-
+
+
+ def metadataHeaders(metadata: Metadata): immutable.Iterable[HttpHeader] = {
+ metadata.serverSideEncryption match {
+ case ServerSideEncryption.Aes256 =>
+ List(new `X-Amz-Server-Side-Encryption`(AES256))
+ case ServerSideEncryption.Kms(keyId) =>
+ List(
+ new `X-Amz-Server-Side-Encryption`(KMS),
+ new `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(keyId)
+ // TODO add `x-amz-server-side-encryption-context` header.
+ )
+ case ServerSideEncryption.None =>
+ Nil
+ }
+ }
+
def getRequest(s3Location: S3Location): HttpRequest = {
s3Request(s3Location)
}
- def uploadPartRequest(upload: MultipartUpload, partNumber: Int, payload: ByteString): HttpRequest = {
+ def uploadPartRequest(upload: MultipartUpload, partNumber: Int, payload: ByteString, metadata: Metadata): HttpRequest = {
s3Request(upload.s3Location,
- HttpMethods.PUT,
- _.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId))
- ).withEntity(payload)
+ HttpMethods.PUT,
+ _.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId))
+ ).withEntity(metadata.contentType, payload).mapHeaders(_ ++ metadataHeaders(metadata))
}
def completeMultipartUploadRequest(upload: MultipartUpload, parts: Seq[(Int, String)])(implicit ec: ExecutionContext): Future[HttpRequest] = {
val payload =
- {
- parts.map{case (partNumber, etag) => {partNumber}{etag}}
- }
-
+ {parts.map { case (partNumber, etag) => {partNumber}{etag}}}
+
for {
entity <- Marshal(payload).to[RequestEntity]
} yield {
s3Request(upload.s3Location,
- HttpMethods.POST,
- _.withQuery(Query("uploadId" -> upload.uploadId))
+ HttpMethods.POST,
+ _.withQuery(Query("uploadId" -> upload.uploadId))
).withEntity(entity)
}
}
diff --git a/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala b/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala
index e828797..bb33b26 100644
--- a/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala
+++ b/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala
@@ -20,6 +20,11 @@ import akka.stream.{Attributes, Materializer}
import akka.stream.scaladsl.{Flow, Keep, Sink, Source}
import akka.util.ByteString
+case class Metadata(
+ contentType: ContentType = ContentTypes.`application/octet-stream`,
+ serverSideEncryption: ServerSideEncryption = ServerSideEncryption.None
+)
+
case class S3Location(bucket: String, key: String)
case class MultipartUpload(s3Location: S3Location, uploadId: String)
@@ -56,18 +61,18 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic
* @param chunkingParallelism
* @return
*/
- def multipartUpload(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, chunkingParallelism: Int = 4): Sink[ByteString, Future[CompleteMultipartUploadResult]] = {
+ def multipartUpload(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, chunkingParallelism: Int = 4, metadata: Metadata = Metadata()): Sink[ByteString, Future[CompleteMultipartUploadResult]] = {
import mat.executionContext
- chunkAndRequest(s3Location, chunkSize)(chunkingParallelism)
+ chunkAndRequest(s3Location, chunkSize, metadata)(chunkingParallelism)
.log("s3-upload-response").withAttributes(Attributes.logLevels(onElement = Logging.DebugLevel, onFailure = Logging.WarningLevel, onFinish = Logging.InfoLevel))
.toMat(completionSink(s3Location))(Keep.right)
}
- def initiateMultipartUpload(s3Location: S3Location): Future[MultipartUpload] = {
+ def initiateMultipartUpload(s3Location: S3Location, metadata: Metadata): Future[MultipartUpload] = {
import mat.executionContext
- val req = HttpRequests.initiateMultipartUploadRequest(s3Location)
+ val req = HttpRequests.initiateMultipartUploadRequest(s3Location, metadata)
val response = for {
signedReq <- Signer.signedRequest(req, signingKey)
response <- Http().singleRequest(signedReq)
@@ -99,8 +104,9 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic
* @param s3Location The s3 location to which to upload to
* @return
*/
- def initiateUpload(s3Location: S3Location): Source[(MultipartUpload, Int), NotUsed] = {
- Source.single(s3Location).mapAsync(1)(initiateMultipartUpload(_))
+ def initiateUpload(s3Location: S3Location, metadata: Metadata): Source[(MultipartUpload, Int), NotUsed] = {
+ Source.single(s3Location)
+ .mapAsync(1)(initiateMultipartUpload(_, metadata))
.mapConcat{case r => Stream.continually(r)}
.zip(StreamUtils.counter(1))
}
@@ -113,17 +119,17 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic
* @param parallelism
* @return
*/
- def createRequests(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, parallelism: Int = 4): Flow[ByteString, (HttpRequest, (MultipartUpload, Int)), NotUsed] = {
+ def createRequests(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, parallelism: Int = 4, metadata: Metadata = Metadata()): Flow[ByteString, (HttpRequest, (MultipartUpload, Int)), NotUsed] = {
assert(chunkSize >= MIN_CHUNK_SIZE, "Chunk size must be at least 5242880B. See http://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html")
- val requestInfo: Source[(MultipartUpload, Int), NotUsed] = initiateUpload(s3Location)
+ val requestInfo: Source[(MultipartUpload, Int), NotUsed] = initiateUpload(s3Location, metadata)
Flow[ByteString]
.via(new Chunker(chunkSize))
- .zipWith(requestInfo){case (payload, (uploadInfo, chunkIndex)) => (HttpRequests.uploadPartRequest(uploadInfo, chunkIndex, payload), (uploadInfo, chunkIndex))}
+ .zipWith(requestInfo){case (payload, (uploadInfo, chunkIndex)) => (HttpRequests.uploadPartRequest(uploadInfo, chunkIndex, payload, metadata), (uploadInfo, chunkIndex))}
.mapAsync(parallelism){case (req, info) => Signer.signedRequest(req, signingKey).zip(Future.successful(info)) }
}
- def chunkAndRequest(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE)(parallelism: Int = 4): Flow[ByteString, UploadPartResponse, NotUsed] = {
- createRequests(s3Location, chunkSize, parallelism)
+ def chunkAndRequest(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, metadata: Metadata = Metadata())(parallelism: Int = 4): Flow[ByteString, UploadPartResponse, NotUsed] = {
+ createRequests(s3Location, chunkSize, parallelism, metadata)
.via(Http().superPool[(MultipartUpload, Int)]())
.map {
case (Success(r), (upload, index)) => {
diff --git a/s3-stream/src/main/scala/com/bluelabs/s3stream/ServerSideEncryption.scala b/s3-stream/src/main/scala/com/bluelabs/s3stream/ServerSideEncryption.scala
new file mode 100644
index 0000000..5c985ca
--- /dev/null
+++ b/s3-stream/src/main/scala/com/bluelabs/s3stream/ServerSideEncryption.scala
@@ -0,0 +1,14 @@
+package com.bluelabs.s3stream
+
+
+sealed trait ServerSideEncryption
+
+object ServerSideEncryption {
+
+ case object None extends ServerSideEncryption
+
+ case object Aes256 extends ServerSideEncryption
+
+ // TODO add context
+ case class Kms(keyId: String) extends ServerSideEncryption
+}
diff --git a/s3-stream/src/test/scala/com/bluelabs/s3stream/HttpRequestsSpec.scala b/s3-stream/src/test/scala/com/bluelabs/s3stream/HttpRequestsSpec.scala
new file mode 100644
index 0000000..7fb7b77
--- /dev/null
+++ b/s3-stream/src/test/scala/com/bluelabs/s3stream/HttpRequestsSpec.scala
@@ -0,0 +1,38 @@
+package com.bluelabs.s3stream
+
+import akka.http.scaladsl.model.ContentTypes
+import akka.http.scaladsl.model.headers.`Content-Type`
+import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.{KMS, AES256}
+
+import org.scalatest.{Matchers, FlatSpec}
+
+class HttpRequestsSpec extends FlatSpec with Matchers {
+ import HttpRequests._
+ import com.bluelabs.akkaaws.AwsHeaders._
+
+ "metadataHeaders" should "add the contentType header when contentType is default" in {
+ metadataHeaders(Metadata()) should contain (`Content-Type`(ContentTypes.`application/octet-stream`))
+ }
+
+ it should "add the contentType header contentType is custom" in {
+ val customContentType = ContentTypes.`application/json`
+ metadataHeaders(Metadata(customContentType)) should contain (`Content-Type`(customContentType))
+ }
+
+ it should "not add the x-amz-server-side-encryption header when the server side encryption is None" in {
+ metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.None)) should not contain a[`X-Amz-Server-Side-Encryption`]
+ }
+
+ it should "add the x-amz-server-side-encryption with AES256 header when the server side encryption is AES256" in {
+ metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.Aes256)) should contain (`X-Amz-Server-Side-Encryption`(AES256))
+ }
+
+ it should "add the x-amz-server-side-encryption with KMZ header when the server side encryption is KMS" in {
+ metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.Kms("my-id"))) should contain (`X-Amz-Server-Side-Encryption`(KMS))
+ }
+
+ it should "add the x-amz-server-side-encryption-kms-id with KMZ header when the server side encryption is KMS" in {
+ metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.Kms("my-id"))) should contain (`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("my-id"))
+ }
+
+}