Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add server side encryption and Content-Type support. #15

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ project/plugins/project/
# Scala-IDE specific
.scala_dependencies
.worksheet

# IntelliJ Idea specific
/.idea/
/idea_modules/
*.iml

# MacOSX specific
.DS_Store
88 changes: 88 additions & 0 deletions akka-http-aws/src/main/scala/com/bluelabs/akkaaws/AwsHeaders.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package com.bluelabs.akkaaws

import akka.http.scaladsl.model.headers.{ModeledCustomHeaderCompanion, ModeledCustomHeader}

import scala.util.parsing.json.JSON
import scala.util.{Failure, Success, Try}

object AwsHeaders {

sealed abstract class ServerSideEncryptionAlgorithm(val name: String)

object ServerSideEncryptionAlgorithm {

case object AES256 extends ServerSideEncryptionAlgorithm("AES256")

case object KMS extends ServerSideEncryptionAlgorithm("aws:kms")

def fromString(raw: String): Try[ServerSideEncryptionAlgorithm] = raw match {
case "AES256" => Success(AES256)
case "aws:kms" => Success(KMS)
case invalid => Failure(new IllegalArgumentException(s"$invalid is not a valid server side encryption algorithm."))
}
}

object `X-Amz-Server-Side-Encryption` extends ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption`] {
override def name: String = "X-Amz-Server-Side-Encryption"

override def parse(value: String): Try[`X-Amz-Server-Side-Encryption`] =
ServerSideEncryptionAlgorithm.fromString(value).map(new `X-Amz-Server-Side-Encryption`(_))
}

final case class `X-Amz-Server-Side-Encryption`(algorithm: ServerSideEncryptionAlgorithm) extends ModeledCustomHeader[`X-Amz-Server-Side-Encryption`] {

override def companion: ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption`] = `X-Amz-Server-Side-Encryption`

override def value(): String = algorithm.name

override def renderInResponses(): Boolean = true

override def renderInRequests(): Boolean = true
}

object `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id` extends ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] {
override def name: String = "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id"

override def parse(value: String): Try[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] =
Success(new `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(value))
}

final case class `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(id: String) extends ModeledCustomHeader[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] {

override def companion: ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`] = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`

override def value(): String = id

override def renderInResponses(): Boolean = true

override def renderInRequests(): Boolean = true
}

object `X-Amz-Server-Side-Encryption-Context` extends ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Context`] {
override def name: String = "X-Amz-Server-Side-Encryption-Context"

override def parse(value: String): Try[`X-Amz-Server-Side-Encryption-Context`] =
JSON.parseFull(value) match {
case Some(context: Map[String, Any]) if context.forall(_._2.isInstanceOf[String]) =>
Success(new `X-Amz-Server-Side-Encryption-Context`(context.asInstanceOf[Map[String, String]]))
case _ =>
Failure(new IllegalArgumentException("$value is not a valid AWS KMS context"))
}
}

final case class `X-Amz-Server-Side-Encryption-Context`(context: Map[String, String]) extends ModeledCustomHeader[`X-Amz-Server-Side-Encryption-Context`] {

override def companion: ModeledCustomHeaderCompanion[`X-Amz-Server-Side-Encryption-Context`] = `X-Amz-Server-Side-Encryption-Context`

override def value(): String =
context.map { case (key, value) => s""""$key":"$value"""" }.mkString("{", ",", "}")

override def renderInResponses(): Boolean = true

override def renderInRequests(): Boolean = true
}


}


Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.bluelabs.akkaaws

import akka.http.scaladsl.model.headers.RawHeader
import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.AES256
import com.bluelabs.akkaaws.AwsHeaders.{`X-Amz-Server-Side-Encryption-Context`, `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`, `X-Amz-Server-Side-Encryption`, ServerSideEncryptionAlgorithm}
import org.scalatest.{FlatSpec, Matchers}

import scala.util.{Failure, Success}


class AwsHeadersSpec extends FlatSpec with Matchers {

"ServerSideEncryptionAlgorithm" should "parse AES256" in {
ServerSideEncryptionAlgorithm.fromString("AES256") shouldBe Success(ServerSideEncryptionAlgorithm.AES256)
}

it should "parse KMS" in {
ServerSideEncryptionAlgorithm.fromString("aws:kms") shouldBe Success(ServerSideEncryptionAlgorithm.KMS)
}

it should "not parse an unsupported algorithm" in {
ServerSideEncryptionAlgorithm.fromString("Zip War AirGanon") shouldBe a[Failure[_]]
}

"`X-Amz-Server-Side-Encryption`" should "parse AES256 algorithm" in {
val `X-Amz-Server-Side-Encryption`(algorithm) = `X-Amz-Server-Side-Encryption`("AES256")
algorithm shouldBe AES256
}

it should "set the X-Amz-Server-Side-Encryption header" in {
val RawHeader(key, value) = `X-Amz-Server-Side-Encryption`("AES256")
key shouldBe "X-Amz-Server-Side-Encryption"
value shouldBe "AES256"
}

"`X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`" should "parse kms key id" in {
val `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(id) = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("myId")
id shouldBe "myId"
}

it should "set the X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id header" in {
val RawHeader(key, value) = `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`("myId")
key shouldBe "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id"
value shouldBe "myId"
}

"`X-Amz-Server-Side-Encryption-Context`" should "parse context" in {
val expectedContext = Map("foo"->"bar", "foo2"->"bar2")
val `X-Amz-Server-Side-Encryption-Context`(context) = `X-Amz-Server-Side-Encryption-Context`(expectedContext)

context shouldBe expectedContext
}

it should "set the X-Amz-Server-Side-Encryption-Context header" in {
val RawHeader(key, value) = `X-Amz-Server-Side-Encryption-Context`(Map("foo"->"bar", "foo2"->"bar2"))
key shouldBe "X-Amz-Server-Side-Encryption-Context"
value shouldBe """{"foo":"bar","foo2":"bar2"}"""
}

it should "parse the raw context" in {
val header = `X-Amz-Server-Side-Encryption-Context`.parse("""{"foo":"bar","foo2":"bar2"}""")
header shouldBe Success(`X-Amz-Server-Side-Encryption-Context`(Map("foo"->"bar", "foo2"->"bar2")))
}

it should "not parse the raw context if it is not string->string" in {
val header = `X-Amz-Server-Side-Encryption-Context`.parse("""{"foo":"bar","foo2":2}""")
header shouldBe a[Failure[_]]
}

}
48 changes: 35 additions & 13 deletions s3-stream/src/main/scala/com/bluelabs/s3stream/HttpRequests.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.bluelabs.s3stream

import com.bluelabs.akkaaws.AwsHeaders.ServerSideEncryptionAlgorithm.{KMS, AES256}

import scala.collection.immutable
import scala.concurrent.{ExecutionContext, Future}
import akka.http.scaladsl.marshallers.xml.ScalaXmlSupport._
import akka.http.scaladsl.marshalling.Marshal
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.Uri.Query
import akka.http.scaladsl.model.headers.Host
import akka.http.scaladsl.model.headers._
import akka.util.ByteString
import com.bluelabs.akkaaws.AwsHeaders._

object HttpRequests {

Expand All @@ -17,33 +20,52 @@ object HttpRequests {
.withUri(uriFn(requestUri(s3Location)))
}

def initiateMultipartUploadRequest(s3Location: S3Location): HttpRequest = {
def initiateMultipartUploadRequest(s3Location: S3Location, metadata: Metadata): HttpRequest = {
s3Request(s3Location, HttpMethods.POST, _.withQuery(Query("uploads")))
.withEntity(metadata.contentType, ByteString.empty)
.mapHeaders(_ ++ metadataHeaders(metadata))
}



def metadataHeaders(metadata: Metadata): immutable.Iterable[HttpHeader] = {
metadata.serverSideEncryption match {
case ServerSideEncryption.Aes256 =>
List(new `X-Amz-Server-Side-Encryption`(AES256))
case ServerSideEncryption.Kms(keyId, context) =>
List(
new `X-Amz-Server-Side-Encryption`(KMS),
new `X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id`(keyId)
) ++ (if(context.isEmpty) {
List.empty
} else {
List(`X-Amz-Server-Side-Encryption-Context`(context))
})
case ServerSideEncryption.None =>
Nil
}
}

def getRequest(s3Location: S3Location): HttpRequest = {
s3Request(s3Location)
}

def uploadPartRequest(upload: MultipartUpload, partNumber: Int, payload: ByteString): HttpRequest = {
def uploadPartRequest(upload: MultipartUpload, partNumber: Int, payload: ByteString, metadata: Metadata): HttpRequest = {
s3Request(upload.s3Location,
HttpMethods.PUT,
_.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId))
).withEntity(payload)
HttpMethods.PUT,
_.withQuery(Query("partNumber" -> partNumber.toString, "uploadId" -> upload.uploadId))
).withEntity(metadata.contentType, payload).mapHeaders(_ ++ metadataHeaders(metadata))
}

def completeMultipartUploadRequest(upload: MultipartUpload, parts: Seq[(Int, String)])(implicit ec: ExecutionContext): Future[HttpRequest] = {
val payload = <CompleteMultipartUpload>
{
parts.map{case (partNumber, etag) => <Part><PartNumber>{partNumber}</PartNumber><ETag>{etag}</ETag></Part>}
}
</CompleteMultipartUpload>
{parts.map { case (partNumber, etag) => <Part><PartNumber>{partNumber}</PartNumber><ETag>{etag}</ETag></Part>}}
</CompleteMultipartUpload>
for {
entity <- Marshal(payload).to[RequestEntity]
} yield {
s3Request(upload.s3Location,
HttpMethods.POST,
_.withQuery(Query("uploadId" -> upload.uploadId))
HttpMethods.POST,
_.withQuery(Query("uploadId" -> upload.uploadId))
).withEntity(entity)
}
}
Expand Down
28 changes: 17 additions & 11 deletions s3-stream/src/main/scala/com/bluelabs/s3stream/S3Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ import akka.stream.{Attributes, Materializer}
import akka.stream.scaladsl.{Flow, Keep, Sink, Source}
import akka.util.ByteString

case class Metadata(
contentType: ContentType = ContentTypes.`application/octet-stream`,
serverSideEncryption: ServerSideEncryption = ServerSideEncryption.None
)

case class S3Location(bucket: String, key: String)

case class MultipartUpload(s3Location: S3Location, uploadId: String)
Expand Down Expand Up @@ -56,18 +61,18 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic
* @param chunkingParallelism
* @return
*/
def multipartUpload(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, chunkingParallelism: Int = 4): Sink[ByteString, Future[CompleteMultipartUploadResult]] = {
def multipartUpload(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, chunkingParallelism: Int = 4, metadata: Metadata = Metadata()): Sink[ByteString, Future[CompleteMultipartUploadResult]] = {
import mat.executionContext

chunkAndRequest(s3Location, chunkSize)(chunkingParallelism)
chunkAndRequest(s3Location, chunkSize, metadata)(chunkingParallelism)
.log("s3-upload-response").withAttributes(Attributes.logLevels(onElement = Logging.DebugLevel, onFailure = Logging.WarningLevel, onFinish = Logging.InfoLevel))
.toMat(completionSink(s3Location))(Keep.right)
}

def initiateMultipartUpload(s3Location: S3Location): Future[MultipartUpload] = {
def initiateMultipartUpload(s3Location: S3Location, metadata: Metadata): Future[MultipartUpload] = {
import mat.executionContext

val req = HttpRequests.initiateMultipartUploadRequest(s3Location)
val req = HttpRequests.initiateMultipartUploadRequest(s3Location, metadata)
val response = for {
signedReq <- Signer.signedRequest(req, signingKey)
response <- Http().singleRequest(signedReq)
Expand Down Expand Up @@ -99,8 +104,9 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic
* @param s3Location The s3 location to which to upload to
* @return
*/
def initiateUpload(s3Location: S3Location): Source[(MultipartUpload, Int), NotUsed] = {
Source.single(s3Location).mapAsync(1)(initiateMultipartUpload(_))
def initiateUpload(s3Location: S3Location, metadata: Metadata): Source[(MultipartUpload, Int), NotUsed] = {
Source.single(s3Location)
.mapAsync(1)(initiateMultipartUpload(_, metadata))
.mapConcat{case r => Stream.continually(r)}
.zip(StreamUtils.counter(1))
}
Expand All @@ -113,17 +119,17 @@ class S3Stream(credentials: AWSCredentials, region: String = "us-east-1")(implic
* @param parallelism
* @return
*/
def createRequests(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, parallelism: Int = 4): Flow[ByteString, (HttpRequest, (MultipartUpload, Int)), NotUsed] = {
def createRequests(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, parallelism: Int = 4, metadata: Metadata = Metadata()): Flow[ByteString, (HttpRequest, (MultipartUpload, Int)), NotUsed] = {
assert(chunkSize >= MIN_CHUNK_SIZE, "Chunk size must be at least 5242880B. See http://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html")
val requestInfo: Source[(MultipartUpload, Int), NotUsed] = initiateUpload(s3Location)
val requestInfo: Source[(MultipartUpload, Int), NotUsed] = initiateUpload(s3Location, metadata)
Flow[ByteString]
.via(new Chunker(chunkSize))
.zipWith(requestInfo){case (payload, (uploadInfo, chunkIndex)) => (HttpRequests.uploadPartRequest(uploadInfo, chunkIndex, payload), (uploadInfo, chunkIndex))}
.zipWith(requestInfo){case (payload, (uploadInfo, chunkIndex)) => (HttpRequests.uploadPartRequest(uploadInfo, chunkIndex, payload, metadata), (uploadInfo, chunkIndex))}
.mapAsync(parallelism){case (req, info) => Signer.signedRequest(req, signingKey).zip(Future.successful(info)) }
}

def chunkAndRequest(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE)(parallelism: Int = 4): Flow[ByteString, UploadPartResponse, NotUsed] = {
createRequests(s3Location, chunkSize, parallelism)
def chunkAndRequest(s3Location: S3Location, chunkSize: Int = MIN_CHUNK_SIZE, metadata: Metadata = Metadata())(parallelism: Int = 4): Flow[ByteString, UploadPartResponse, NotUsed] = {
createRequests(s3Location, chunkSize, parallelism, metadata)
.via(Http().superPool[(MultipartUpload, Int)]())
.map {
case (Success(r), (upload, index)) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.bluelabs.s3stream


sealed trait ServerSideEncryption

object ServerSideEncryption {

case object None extends ServerSideEncryption

case object Aes256 extends ServerSideEncryption

case class Kms(keyId: String, context: Map[String, String] = Map.empty) extends ServerSideEncryption
}
Loading