Skip to content

Commit

Permalink
fix: nested uploads (#1167)
Browse files Browse the repository at this point in the history
* wip upload fix

* fix lists

* make sure we update variables in place

* fix assertion

* remove println
  • Loading branch information
frekw authored Nov 24, 2021
1 parent c13b3d9 commit b4a3d36
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 13 deletions.
32 changes: 22 additions & 10 deletions core/src/main/scala/caliban/uploads/Upload.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package caliban.uploads
import caliban.InputValue.ListValue
import caliban.Value.{ NullValue, StringValue }
import caliban.{ GraphQLRequest, InputValue }
import scala.annotation.tailrec
import zio.stream.{ ZSink, ZStream }
import zio.{ Chunk, RIO, UIO, URIO, ZIO }

Expand Down Expand Up @@ -45,14 +46,12 @@ case class GraphQLUploadRequest(
def remap: GraphQLRequest =
request.copy(
variables = request.variables.map { vars =>
val files = fileMap.flatMap {
case (name, Left("variables") :: Left(key) :: path) => vars.get(key).map(loop(_, path, name)).map(key -> _)
case _ => None
}

vars ++ files.groupBy(_._1).map {
case (key, value :: Nil) => (key, value._2)
case (key, values) => (key, ListValue(values.map(_._2)))
fileMap.foldLeft(vars) { case (acc, (name, rest)) =>
val value = rest match {
case Left("variables") :: Left(key) :: path => acc.get(key).map(loop(_, path, name)).map(key -> _)
case _ => None
}
value.fold(acc)(v => acc + v)
}
}
)
Expand All @@ -66,13 +65,14 @@ case class GraphQLUploadRequest(
case Some(Left(key)) =>
value match {
case InputValue.ObjectValue(fields) =>
fields.get(key).fold[InputValue](NullValue)(loop(_, path.drop(1), name))
val v = fields.get(key).fold[InputValue](NullValue)(loop(_, path.drop(1), name))
InputValue.ObjectValue(fields + (key -> v))
case _ => NullValue
}
case Some(Right(idx)) =>
value match {
case InputValue.ListValue(values) =>
values.lift(idx).fold[InputValue](NullValue)(loop(_, path.drop(1), name))
InputValue.ListValue(replaceAt(values, idx)(loop(_, path.drop(1), name)))
case _ => NullValue
}
case None =>
Expand All @@ -81,4 +81,16 @@ case class GraphQLUploadRequest(
StringValue(name)
}

private def replaceAt[A](xs: List[A], idx: Int)(f: A => A): List[A] = {
@tailrec
def loop[A](xs: List[A], idx: Int, acc: List[A], f: A => A): List[A] =
(xs, idx) match {
case (x :: xs, 0) => (f(x) :: acc).reverse ++ xs
case (Nil, _) => acc.reverse
case (x :: xs, idx) => loop(xs, idx - 1, x :: acc, f)
}

loop(xs, idx, List(), f)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,35 @@ object TapirAdapterSpec {
)
}
),
runUpload.map(runUpload =>
testM("test http upload endpoint for extra fields") {
val query =
"""{ "query": "mutation ($uploadedDocuments: [UploadedDocumentInput!]!) { uploadFilesWithExtraFields(uploadedDocuments: $uploadedDocuments) { someField1, someField2} }", "variables": { "uploadedDocuments": [{"file": null, "someField1": 1, "someField2": 2}, {"file": null, "someField1": 3}] }}"""

val parts =
List(
Part("operations", query.getBytes, contentType = Some(MediaType.ApplicationJson)),
Part(
"map",
"""{ "0": ["variables.uploadedDocuments.0.file"], "1": ["variables.uploadedDocuments.1.file"]}""".getBytes
),
Part("0", """image""".getBytes, contentType = Some(MediaType.ImagePng)).fileName("a.png"),
Part("1", """text""".getBytes, contentType = Some(MediaType.TextPlain)).fileName("a.txt")
)

val io =
for {
res <- send(runUpload((parts, null)))
response <- ZIO.fromEither(res.body).orElseFail(new Throwable("Failed to parse result"))
} yield response.data.toString

assertM(io)(
equalTo(
"""{"uploadFilesWithExtraFields":[{"someField1":1,"someField2":2},{"someField1":3,"someField2":null}]}"""
)
)
}
),
runWS.map(runWS =>
testM("test ws endpoint") {
val io =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ object TestApi extends GenericSchema[TestService with Uploads] {
case class File(hash: String, filename: String, mimetype: String)
case class UploadFileArgs(file: Upload)
case class UploadFilesArgs(files: List[Upload])
case class UploadedDocument(file: Upload, someField1: Int, someField2: Option[Int])
case class UploadWithExtraFields(uploadedDocuments: List[UploadedDocument])
case class SomeFieldOutput(someField1: Int, someField2: Option[Int])

case class Queries(
@GQLDescription("Return all characters from a given origin")
Expand All @@ -32,7 +35,8 @@ object TestApi extends GenericSchema[TestService with Uploads] {
case class Mutations(
deleteCharacter: CharacterArgs => URIO[TestService, Boolean],
uploadFile: UploadFileArgs => ZIO[Uploads, Throwable, File],
uploadFiles: UploadFilesArgs => ZIO[Uploads, Throwable, List[File]]
uploadFiles: UploadFilesArgs => ZIO[Uploads, Throwable, List[File]],
uploadFilesWithExtraFields: UploadWithExtraFields => ZIO[Uploads, Throwable, List[SomeFieldOutput]]
)
case class Subscriptions(characterDeleted: ZStream[TestService, Nothing, String])

Expand All @@ -51,7 +55,8 @@ object TestApi extends GenericSchema[TestService with Uploads] {
Mutations(
args => TestService.deleteCharacter(args.name),
args => TestService.uploadFile(args.file),
args => TestService.uploadFiles(args.files)
args => TestService.uploadFiles(args.files),
args => TestService.uploadFilesWithOtherFields(args.uploadedDocuments)
),
Subscriptions(TestService.deletedEvents)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package caliban.interop.tapir

import caliban.interop.tapir.TestApi.File
import caliban.interop.tapir.TestApi.{ File, SomeFieldOutput, UploadedDocument }
import caliban.interop.tapir.TestData._
import caliban.uploads.{ Upload, Uploads }
import zio.stream.ZStream
Expand Down Expand Up @@ -59,6 +59,15 @@ object TestService {
)
)

def uploadFilesWithOtherFields(
uploadedDocuments: List[UploadedDocument]
): ZIO[Uploads, Throwable, List[SomeFieldOutput]] =
ZIO.succeed(
for {
document <- uploadedDocuments
} yield SomeFieldOutput(document.someField1, document.someField2)
)

def make(initial: List[Character]): ZLayer[Any, Nothing, TestService] =
(for {
characters <- Ref.make(initial)
Expand Down

0 comments on commit b4a3d36

Please sign in to comment.