Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for partial file download from S3 #264 #265

Merged
merged 2 commits into from
Apr 20, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/src/main/paradox/s3.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ Scala
Java
: @@snip (../../../../s3/src/test/java/akka/stream/alpakka/s3/javadsl/S3ClientTest.java) { #download }

To download only a range of a file's data, use the overloaded method that
additionally takes a `ByteRange` as an argument.

Scala
: @@snip (../../../../s3/src/test/scala/akka/stream/alpakka/s3/scaladsl/S3SourceSpec.scala) { #rangedDownload }

Java
: @@snip (../../../../s3/src/test/java/akka/stream/alpakka/s3/javadsl/S3ClientTest.java) { #rangedDownload }

### Running the example code

The code in this guide is part of runnable tests of this project. You are welcome to edit the code and run it in sbt.
Expand Down
14 changes: 10 additions & 4 deletions s3/src/main/scala/akka/stream/alpakka/s3/impl/S3Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import akka.NotUsed
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.ByteRange
import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
import akka.stream.Materializer
import akka.stream.alpakka.s3.acl.CannedAcl
Expand Down Expand Up @@ -59,13 +60,18 @@ private[alpakka] final class S3Stream(credentials: AWSCredentials,
val MinChunkSize = 5242880 //in bytes
val signingKey = SigningKey(credentials, CredentialScope(LocalDate.now(), region, "s3"))

def download(s3Location: S3Location): Source[ByteString, NotUsed] = {
def download(s3Location: S3Location, range: Option[ByteRange] = None): Source[ByteString, NotUsed] = {
import mat.executionContext
Source.fromFuture(request(s3Location).flatMap(entityForSuccess).map(_.dataBytes)).flatMapConcat(identity)
Source.fromFuture(request(s3Location, range).flatMap(entityForSuccess).map(_.dataBytes)).flatMapConcat(identity)
}

def request(s3Location: S3Location): Future[HttpResponse] =
signAndGet(getDownloadRequest(s3Location, region))
/**
 * Builds the download request for `s3Location`, optionally restricts it to a byte range,
 * and passes it to `signAndGet`.
 *
 * @param s3Location  location of the object to fetch
 * @param rangeOption when defined, a `Range` header carrying this byte range is added to the request;
 *                    when `None`, the request is sent unchanged
 * @return the future response produced by `signAndGet`
 */
def request(s3Location: S3Location, rangeOption: Option[ByteRange] = None): Future[HttpResponse] = {
  val downloadRequest = getDownloadRequest(s3Location, region)
  // Append the Range header instead of calling `withHeaders`: `withHeaders` replaces the whole
  // header list, which would discard any header `getDownloadRequest` already put on the request.
  val requestToSend = rangeOption.fold(downloadRequest) { range =>
    downloadRequest.mapHeaders(_ :+ headers.Range(range))
  }
  signAndGet(requestToSend)
}

/**
* Uploads a stream of ByteStrings to a specified location as a multipart upload.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import java.util.concurrent.CompletionStage
import akka.NotUsed
import akka.actor.ActorSystem
import akka.http.impl.model.JavaUri
import akka.http.javadsl.model.headers.ByteRange
import akka.http.javadsl.model.{ContentType, HttpResponse, Uri}
import akka.http.scaladsl.model.{ContentTypes, ContentType => ScalaContentType}
import akka.http.scaladsl.model.headers.{ByteRange => ScalaByteRange}
import akka.stream.Materializer
import akka.stream.alpakka.s3.auth.AWSCredentials
import akka.stream.alpakka.s3.impl.{CompleteMultipartUploadResult, MetaHeaders, S3Location, S3Stream}
Expand All @@ -34,6 +36,11 @@ final class S3Client(credentials: AWSCredentials, region: String, system: ActorS
def download(bucket: String, key: String): Source[ByteString, NotUsed] =
impl.download(S3Location(bucket, key)).asJava

/**
 * Java API: downloads only the given byte range of the object stored under `key` in `bucket`.
 *
 * The Java DSL `ByteRange` is cast to its Scala DSL counterpart before being handed to the
 * underlying implementation.
 *
 * @param bucket name of the bucket
 * @param key    key of the object inside the bucket
 * @param range  byte range to request
 * @return a Java DSL source of the downloaded bytes
 */
def download(bucket: String, key: String, range: ByteRange): Source[ByteString, NotUsed] =
  impl.download(S3Location(bucket, key), Some(range.asInstanceOf[ScalaByteRange])).asJava

def multipartUpload(bucket: String,
key: String,
contentType: ContentType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package akka.stream.alpakka.s3.scaladsl
import akka.NotUsed
import akka.actor.ActorSystem
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.ByteRange
import akka.stream.Materializer
import akka.stream.alpakka.s3.S3Settings
import akka.stream.alpakka.s3.acl.CannedAcl
Expand Down Expand Up @@ -42,6 +43,9 @@ final class S3Client(credentials: AWSCredentials, region: String)(implicit syste

def download(bucket: String, key: String): Source[ByteString, NotUsed] = impl.download(S3Location(bucket, key))

/**
 * Downloads only the given byte range of the object stored under `key` in `bucket`.
 *
 * @param bucket name of the bucket
 * @param key    key of the object inside the bucket
 * @param range  byte range to request
 * @return a source of the downloaded bytes
 */
def download(bucket: String, key: String, range: ByteRange): Source[ByteString, NotUsed] = {
  val location = S3Location(bucket, key)
  impl.download(location, Some(range))
}

def multipartUpload(bucket: String,
key: String,
contentType: ContentType = ContentTypes.`application/octet-stream`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
package akka.stream.alpakka.s3.javadsl;

import akka.NotUsed;
import akka.actor.ActorSystem;
import akka.http.javadsl.model.Uri;
import akka.http.javadsl.model.headers.ByteRange;
import akka.stream.ActorMaterializer;
import akka.stream.Materializer;
import akka.stream.alpakka.s3.auth.AWSCredentials;
Expand All @@ -16,10 +16,12 @@
import akka.util.ByteString;
import org.junit.Test;

import java.util.Arrays;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class S3ClientTest extends S3WireMockBase {

Expand Down Expand Up @@ -63,4 +65,22 @@ public void download() throws Exception {

assertEquals(body(), result);
}

/**
 * Verifies that a ranged download yields exactly the bytes served by the ranged-download stub.
 */
@Test
public void rangedDownload() throws Exception {

  mockRangedDownload();

  //#rangedDownload
  final Source<ByteString, NotUsed> source = client.download(bucket(), bucketKey(),
          ByteRange.createSlice(bytesRangeStart(), bytesRangeEnd()));
  //#rangedDownload

  final CompletionStage<byte[]> resultCompletionStage =
          source.map(ByteString::toArray).runWith(Sink.head(), materializer);

  byte[] result = resultCompletionStage.toCompletableFuture().get(5, TimeUnit.SECONDS);

  // assertArrayEquals reports the first differing element on failure, whereas
  // assertTrue(Arrays.equals(...)) only reports a bare AssertionError.
  // Fully qualified so that no additional static import is required.
  org.junit.Assert.assertArrayEquals(rangeOfBody(), result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package akka.stream.alpakka.s3.scaladsl

import akka.NotUsed
import akka.http.scaladsl.model.headers.ByteRange
import akka.stream.alpakka.s3.S3Exception
import akka.stream.scaladsl.{Sink, Source}
import akka.util.ByteString
Expand All @@ -17,14 +18,28 @@ class S3SourceSpec extends S3WireMockBase with S3ClientIntegrationSpec {
mockDownload()

//#download
val s3Source: Source[ByteString, NotUsed] = s3Client.download("testBucket", "testKey")
val s3Source: Source[ByteString, NotUsed] = s3Client.download(bucket, bucketKey)
//#download

val result: Future[String] = s3Source.map(_.utf8String).runWith(Sink.head)

result.futureValue shouldBe body
}

it should "download a range of file's bytes from S3 if bytes range given" in {

  // Register the stub that answers only requests carrying the expected Range header.
  mockRangedDownload()

  //#rangedDownload
  val s3Source: Source[ByteString, NotUsed] =
    s3Client.download(bucket, bucketKey, ByteRange(bytesRangeStart, bytesRangeEnd))
  //#rangedDownload

  // Only the first emitted chunk is inspected.
  val downloadedBytes: Future[Array[Byte]] =
    s3Source
      .map(_.toArray)
      .runWith(Sink.head)

  downloadedBytes.futureValue shouldBe rangeOfBody
}

it should "fail if request returns 404" in {

mock404s()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.github.tomakehurst.wiremock.client.WireMock._
import com.github.tomakehurst.wiremock.core.WireMockConfiguration._
import com.typesafe.config.ConfigFactory
import S3WireMockBase._
import com.github.tomakehurst.wiremock.matching.EqualToPattern

abstract class S3WireMockBase(_system: ActorSystem, _wireMockServer: WireMockServer) extends TestKit(_system) {

Expand Down Expand Up @@ -39,15 +40,30 @@ abstract class S3WireMockBase(_system: ActorSystem, _wireMockServer: WireMockSer
val uploadId = "VXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA"
val etag = "5b27a21a97fcf8a7004dd1d906e7a5ba"
val url = s"http://testbucket.s3.amazonaws.com/testKey"
val (bytesRangeStart, bytesRangeEnd) = (2, 10)
val rangeOfBody = body.getBytes.slice(bytesRangeStart, bytesRangeEnd + 1)

def mockDownload(): Unit =
mock
.register(
get(urlEqualTo("/testKey")).willReturn(
get(urlEqualTo(s"/$bucketKey")).willReturn(
aResponse().withStatus(200).withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""").withBody(body)
)
)

/**
 * Registers a stub for GET `/$bucketKey` that matches only when the `Range` header equals
 * `bytes=<bytesRangeStart>-<bytesRangeEnd>`, and answers with `rangeOfBody` as the response body.
 */
def mockRangedDownload(): Unit = {
  val expectedRangeHeader = new EqualToPattern(s"bytes=$bytesRangeStart-$bytesRangeEnd")

  val rangedResponse = aResponse()
    .withStatus(200)
    .withHeader("ETag", """"fba9dede5f27731c9771645a39863328"""")
    .withBody(rangeOfBody)

  val rangedGet = get(urlEqualTo(s"/$bucketKey"))
    .withHeader("Range", expectedRangeHeader)
    .willReturn(rangedResponse)

  mock.register(rangedGet)
}

def mockUpload(): Unit = {
mock
.register(
Expand Down Expand Up @@ -120,8 +136,8 @@ private object S3WireMockBase {
val s = (Thread.currentThread.getStackTrace map (_.getClassName) drop 1)
.dropWhile(_ matches "(java.lang.Thread|.*WireMockBase.?$)")
val reduced = s.lastIndexWhere(_ == clazz.getName) match {
case -1 s
case z s drop (z + 1)
case -1 => s
case z => s drop (z + 1)
}
reduced.head.replaceFirst(""".*\.""", "").replaceAll("[^a-zA-Z_0-9]", "_")
}
Expand Down