diff --git a/akka-http-aws/src/main/scala/com/bluelabs/akkaaws/AwsHeaders.scala b/akka-http-aws/src/main/scala/com/bluelabs/akkaaws/AwsHeaders.scala new file mode 100644 index 0000000..e16c904 --- /dev/null +++ b/akka-http-aws/src/main/scala/com/bluelabs/akkaaws/AwsHeaders.scala @@ -0,0 +1,58 @@ +package com.bluelabs.akkaaws + +import akka.http.scaladsl.model.headers.{ModeledCustomHeaderCompanion, ModeledCustomHeader} + +import scala.util.{Failure, Success, Try} + + +object AwsHeaders { + + sealed abstract class ServerSideEncryptionAlgorithm(val name: String) + object ServerSideEncryptionAlgorithm { + case object AES256 extends ServerSideEncryptionAlgorithm("AES256") + case object KMS extends ServerSideEncryptionAlgorithm("aws:kms") + + def fromString(raw: String): Try[ServerSideEncryptionAlgorithm] = raw match { + case "AES256" => Success(AES256) + case "aws:kms" => Success(KMS) + case invalid => Failure(new IllegalArgumentException(s"$invalid is not a valid server side encryption algorithm.")) + } + } + + object `X-Amz-Server-Side-Encryption` extends ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption`] { + override def name: String = "X-Amz-Server-Side-Encryption" + override def parse(value: String): Try[`X-Amz-Server-Side-Encryption`] = + ServerSideEncryptionAlgorithm.fromString(value).map(new `X-Amz-Server-Side-Encryption`(_)) + } + + final case class `X-Amz-Server-Side-Encryption`(algorithm: ServerSideEncryptionAlgorithm) extends ModeledCustomHeader[`X-Amz-Server-Side-Encryption`] { + + override def companion: ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption`] = `X-Amz-Server-Side-Encryption` + + override def value(): String = algorithm.name + + override def renderInResponses(): Boolean = true + + override def renderInRequests(): Boolean = true + } + + object `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id` extends ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] { + override def name: String = "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id" + override def parse(value: String): Try[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] = + Success(new `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(value)) + } + + final case class `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(id: String) extends ModeledCustomHeader[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] { + + override def companion: ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id` + + override def value(): String = id + + override def renderInResponses(): Boolean = true + + override def renderInRequests(): Boolean = true + } + + // TODO add `x-amz-server-side-encryption-context` header. + +} diff --git a/akka-http-aws/src/test/scala/com/bluelabs/akkaaws/AwsHeadersSpec.scala b/akka-http-aws/src/test/scala/com/bluelabs/akkaaws/AwsHeadersSpec.scala new file mode 100644 index 0000000..39536cc --- /dev/null +++ b/akka-http-aws/src/test/scala/com/bluelabs/akkaaws/AwsHeadersSpec.scala @@ -0,0 +1,47 @@ +package com.bluelabs.akkaaws + +import akka.http.scaladsl.model.headers.RawHeader +import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.AES256 +import com.bluelabs.akkaaws.AwsHeaders.{`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`, `X-Amz-Server-Side-Encryption`, ServerSideEncryptionAlgorithm} +import org.scalatest.{FlatSpec, Matchers} + +import scala.util.{Failure, Success} + + +class AwsHeadersSpec extends FlatSpec with Matchers { + + "ServerSideEncryptionAlgorithm" should "parse AES256" in { + ServerSideEncryptionAlgorithm.fromString("AES256") shouldBe Success(ServerSideEncryptionAlgorithm.AES256) + } + + it should "parse KMS" in { + ServerSideEncryptionAlgorithm.fromString("aws:kms") shouldBe Success(ServerSideEncryptionAlgorithm.KMS) + } + + it should "not parse an unsupported algorithm" in { + ServerSideEncryptionAlgorithm.fromString("Zip War AirGanon") shouldBe a[Failure[_]] + } + + "`X-Amz-Server-Side-Encryption`" should "parse AES256 algorithm" in { + val `X-Amz-Server-Side-Encryption`(algorithm) = `X-Amz-Server-Side-Encryption`("AES256") + algorithm shouldBe AES256 + } + + it should "set the X-Amz-Server-Side-Encryption header" in { + val RawHeader(key, value) = `X-Amz-Server-Side-Encryption`("AES256") + key shouldBe "X-Amz-Server-Side-Encryption" + value shouldBe "AES256" + } + + "`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`" should "parse kms key id" in { + val `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(id) = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("myId") + id shouldBe "myId" + } + + it should "set the X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id header" in { + val RawHeader(key, value) = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("myId") + key shouldBe "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id" + value shouldBe "myId" + } + +} diff --git a/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala b/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala index e6eab0b..4419a86 100644 --- a/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala +++ b/s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala @@ -1,13 +1,16 @@ package com.bluelabs.s3stream +import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.{KMS, AES256} +import scala.collection.immutable import scala.concurrent.{ExecutionContext, Future} import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._ import akka.http.scaladsl.marshalling.Marshal import akka.http.scaladsl.model._ import akka.http.scaladsl.model.Uri.Query -import akka.http.scaladsl.model.headers.Host +import akka.http.scaladsl.model.headers._ import akka.util.ByteString +import com.bluelabs.akkaaws.AwsHeaders._ object HttpRequests { @@ -17,33 +20,55 @@ object HttpRequests { .withUri(uriFn(requestUri(s3Location))) } - def initiateMultipartUploadRequest(s3Location: S3Location): HttpRequest = { + def initiateMultipartUploadRequest(s3Location: S3Location, metadata: Metadata): HttpRequest = { s3Request(s3Location, HttpMethods.POST, _.withQuery(Query("uploads"))) + .mapHeaders(_ ++ metadataHeaders(metadata)) } - + + + def metadataHeaders(metadata: Metadata): immutable.Iterable[HttpHeader] = { + `Content-Type`(metadata.contentType) :: (metadata.serverSideEncryption match { + case ServerSideEncryption.Aes256 => + List(new `X-Amz-Server-Side-Encryption`(AES256)) + case ServerSideEncryption.Kms(keyId) => + List( + new `X-Amz-Server-Side-Encryption`(KMS), + new `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(keyId) + // TODO add `x-amz-server-side-encryption-context` header. + ) + case ServerSideEncryption.None => + Nil + }) + } + def getRequest(s3Location: S3Location): HttpRequest = { s3Request(s3Location) } def uploadPartRequest(upload: MultipartUpload, partNumber: Int, payload: ByteString): HttpRequest = { s3Request(upload.s3Location, - HttpMethods.PUT, - _.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId)) + HttpMethods.PUT, + _.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId)) ).withEntity(payload) } def completeMultipartUploadRequest(upload: MultipartUpload, parts: Seq[(Int, String)])(implicit ec: ExecutionContext): Future[HttpRequest] = { val payload = - { - parts.map{case (partNumber, etag) => {partNumber}{etag}} - } - + {parts.map { case (partNumber, etag) => + + {partNumber} + + {etag} + + + }} + for { entity <- Marshal(payload).to[RequestEntity] } yield { s3Request(upload.s3Location, - HttpMethods.POST, - _.withQuery(Query("uploadId" -> upload.uploadId)) + HttpMethods.POST, + _.withQuery(Query("uploadId" -> upload.uploadId)) ).withEntity(entity) } } diff --git a/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala b/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala index e828797..bf4baa6 100644 --- a/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala +++ b/s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala @@ -20,6 +20,11 @@ import akka.stream.{Attributes, Materializer} import akka.stream.scaladsl.{Flow, Keep, Sink, Source} import akka.util.ByteString +case class Metadata( + contentType: ContentType = ContentTypes.`application/octet-stream`, + serverSideEncryption: ServerSideEncryption = ServerSideEncryption.None +) + case class S3Location(bucket: String, key: String) case class MultipartUpload(s3Location: S3Location, uploadId: String) @@ -56,18 +61,18 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic * @param chunkingParallelism * @return */ - def multipartUpload(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, chunkingParallelism: Int = 4): Sink[ByteString, Future[CompleteMultipartUploadResult]] = { + def multipartUpload(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, chunkingParallelism: Int = 4, metadata: Metadata = Metadata()): Sink[ByteString, Future[CompleteMultipartUploadResult]] = { import mat.executionContext - chunkAndRequest(s3Location, chunkSize)(chunkingParallelism) + chunkAndRequest(s3Location, chunkSize, metadata)(chunkingParallelism) .log("s3-upload-response").withAttributes(Attributes.logLevels(onElement = Logging.DebugLevel, onFailure = Logging.WarningLevel, onFinish = Logging.InfoLevel)) .toMat(completionSink(s3Location))(Keep.right) } - def initiateMultipartUpload(s3Location: S3Location): Future[MultipartUpload] = { + def initiateMultipartUpload(s3Location: S3Location, metadata: Metadata): Future[MultipartUpload] = { import mat.executionContext - val req = HttpRequests.initiateMultipartUploadRequest(s3Location) + val req = HttpRequests.initiateMultipartUploadRequest(s3Location, metadata) val response = for { signedReq <- Signer.signedRequest(req, signingKey) response <- Http().singleRequest(signedReq) @@ -99,8 +104,9 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic * @param s3Location The s3 location to which to upload to * @return */ - def initiateUpload(s3Location: S3Location): Source[(MultipartUpload, Int), NotUsed] = { - Source.single(s3Location).mapAsync(1)(initiateMultipartUpload(_)) + def initiateUpload(s3Location: S3Location, metadata: Metadata): Source[(MultipartUpload, Int), NotUsed] = { + Source.single(s3Location) + .mapAsync(1)(initiateMultipartUpload(_, metadata)) .mapConcat{case r => Stream.continually(r)} .zip(StreamUtils.counter(1)) } @@ -113,17 +119,17 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic * @param parallelism * @return */ - def createRequests(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, parallelism: Int = 4): Flow[ByteString, (HttpRequest, (MultipartUpload, Int)), NotUsed] = { + def createRequests(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, parallelism: Int = 4, metadata: Metadata = Metadata()): Flow[ByteString, (HttpRequest, (MultipartUpload, Int)), NotUsed] = { assert(chunkSize >= MIN_CHUNK_SIZE, "Chunk size must be at least 5242880B. See http://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html") - val requestInfo: Source[(MultipartUpload, Int), NotUsed] = initiateUpload(s3Location) + val requestInfo: Source[(MultipartUpload, Int), NotUsed] = initiateUpload(s3Location, metadata) Flow[ByteString] .via(new Chunker(chunkSize)) .zipWith(requestInfo){case (payload, (uploadInfo, chunkIndex)) => (HttpRequests.uploadPartRequest(uploadInfo, chunkIndex, payload), (uploadInfo, chunkIndex))} .mapAsync(parallelism){case (req, info) => Signer.signedRequest(req, signingKey).zip(Future.successful(info)) } } - def chunkAndRequest(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE)(parallelism: Int = 4): Flow[ByteString, UploadPartResponse, NotUsed] = { - createRequests(s3Location, chunkSize, parallelism) + def chunkAndRequest(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, metadata: Metadata = Metadata())(parallelism: Int = 4): Flow[ByteString, UploadPartResponse, NotUsed] = { + createRequests(s3Location, chunkSize, parallelism, metadata) .via(Http().superPool[(MultipartUpload, Int)]()) .map { case (Success(r), (upload, index)) => { diff --git a/s3-stream/src/main/scala/com/bluelabs/s3stream/ServerSideEncryption.scala b/s3-stream/src/main/scala/com/bluelabs/s3stream/ServerSideEncryption.scala new file mode 100644 index 0000000..5c985ca --- /dev/null +++ b/s3-stream/src/main/scala/com/bluelabs/s3stream/ServerSideEncryption.scala @@ -0,0 +1,14 @@ +package com.bluelabs.s3stream + + +sealed trait ServerSideEncryption + +object ServerSideEncryption { + + case object None extends ServerSideEncryption + + case object Aes256 extends ServerSideEncryption + + // TODO add context + case class Kms(keyId: String) extends ServerSideEncryption +} diff --git a/s3-stream/src/test/scala/com/bluelabs/s3stream/HttpRequestsSpec.scala b/s3-stream/src/test/scala/com/bluelabs/s3stream/HttpRequestsSpec.scala new file mode 100644 index 0000000..7fb7b77 --- /dev/null +++ b/s3-stream/src/test/scala/com/bluelabs/s3stream/HttpRequestsSpec.scala @@ -0,0 +1,38 @@ +package com.bluelabs.s3stream + +import akka.http.scaladsl.model.ContentTypes +import akka.http.scaladsl.model.headers.`Content-Type` +import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.{KMS, AES256} + +import org.scalatest.{Matchers, FlatSpec} + +class HttpRequestsSpec extends FlatSpec with Matchers { + import HttpRequests._ + import com.bluelabs.akkaaws.AwsHeaders._ + + "metadataHeaders" should "add the contentType header when contentType is default" in { + metadataHeaders(Metadata()) should contain (`Content-Type`(ContentTypes.`application/octet-stream`)) + } + + it should "add the contentType header contentType is custom" in { + val customContentType = ContentTypes.`application/json` + metadataHeaders(Metadata(customContentType)) should contain (`Content-Type`(customContentType)) + } + + it should "not add the x-amz-server-side-encryption header when the server side encryption is None" in { + metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.None)) should not contain a[`X-Amz-Server-Side-Encryption`] + } + + it should "add the x-amz-server-side-encryption with AES256 header when the server side encryption is AES256" in { + metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.Aes256)) should contain (`X-Amz-Server-Side-Encryption`(AES256)) + } + + it should "add the x-amz-server-side-encryption with KMZ header when the server side encryption is KMS" in { + metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.Kms("my-id"))) should contain (`X-Amz-Server-Side-Encryption`(KMS)) + } + + it should "add the x-amz-server-side-encryption-kms-id with KMZ header when the server side encryption is KMS" in { + metadataHeaders(Metadata(serverSideEncryption = ServerSideEncryption.Kms("my-id"))) should contain (`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("my-id")) + } + +}