From 9a45e28f0b9c540f7ae530082e3d0a6db7fb3d7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fredrik=20W=C3=A4rnsberg?= Date: Thu, 14 Oct 2021 13:01:25 +0200 Subject: [PATCH] zio-cached --- .../validation/FragmentValidator.scala | 147 ++++++++++++------ 1 file changed, 101 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/caliban/validation/FragmentValidator.scala b/core/src/main/scala/caliban/validation/FragmentValidator.scala index ff843a1db..65d2ce6d1 100644 --- a/core/src/main/scala/caliban/validation/FragmentValidator.scala +++ b/core/src/main/scala/caliban/validation/FragmentValidator.scala @@ -6,6 +6,9 @@ import caliban.parsing.adt.Selection import zio.{ IO, UIO } import Utils._ import Utils.syntax._ +import Function._ +import zio.Ref +import zio.ZIO object FragmentValidator { def findConflictsWithinSelectionSet( @@ -19,41 +22,91 @@ object FragmentValidator { selectionSet ) - val conflicts = sameResponseShapeByName(context, parentType, selectionSet) ++ - sameForCommonParentsByName(context, parentType, selectionSet) - - conflicts match { - case head :: _ => - IO.fail(ValidationError(head, "")) - case _ => IO.unit - } + (for { + shapeCache <- Ref.make[Map[Iterable[Selection], Iterable[String]]](Map.empty) + parentsCache <- Ref.make[Map[Iterable[Selection], Iterable[String]]](Map.empty) + groupsCache <- Ref.make[Map[Set[SelectedField], Iterable[Set[SelectedField]]]](Map.empty) + responseShape = sameResponseShapeByName(shapeCache, context, parentType, selectionSet) + commonParents = sameForCommonParentsByName(parentsCache, groupsCache, context, parentType, selectionSet) + conflicts <- responseShape zip commonParents + all = conflicts._1 ++ conflicts._2 + _ <- IO.fromOption(all.headOption).flip.mapError(ValidationError(_, "")) + } yield ()) } - def sameResponseShapeByName(context: Context, parentType: __Type, set: Iterable[Selection]): Iterable[String] = { - val fields = FieldMap(context, parentType, set) - fields.flatMap { case (name, values) => - cross(values).flatMap { pair => - val (f1, f2) = pair - if (doTypesConflict(f1.fieldDef.`type`(), f2.fieldDef.`type`())) { - List( - s"$name has conflicting types: ${f1.parentType.name.getOrElse("")}.${f1.fieldDef.name} and ${f2.parentType.name - .getOrElse("")}.${f2.fieldDef.name}. Try using an alias." - ) - } else - sameResponseShapeByName(context, parentType, f1.selection.selectionSet ++ f2.selection.selectionSet) - } + def cached[A, R, E, B](cache: Ref[Map[A, B]], a: A)(f: => ZIO[R, E, B]): ZIO[R, E, B] = + for { + v <- cache.get + res <- v.get(a) match { + case None => + for { + res <- f + _ <- cache.update(_ + (a -> res)) + } yield res + case Some(v) => ZIO.succeed(v) + } + } yield res + + def sameResponseShapeByName( + shapeCache: Ref[Map[Iterable[Selection], Iterable[String]]], + context: Context, + parentType: __Type, + set: Iterable[Selection] + ): UIO[Iterable[String]] = + cached(shapeCache, set) { + val fields = FieldMap(context, parentType, set) + UIO + .collect(fields.toIterable)({ case (name, values) => + UIO + .collect(cross(values)) { pair => + val (f1, f2) = pair + if (doTypesConflict(f1.fieldDef.`type`(), f2.fieldDef.`type`())) { + UIO.succeed( + List( + s"$name has conflicting types: ${f1.parentType.name.getOrElse("")}.${f1.fieldDef.name} and ${f2.parentType.name + .getOrElse("")}.${f2.fieldDef.name}. Try using an alias." + ) + ) + } else + sameResponseShapeByName( + shapeCache, + context, + parentType, + f1.selection.selectionSet ++ f2.selection.selectionSet + ) + } + .map(_.flatten) + }) + .map(_.flatten) } - } - def sameForCommonParentsByName(context: Context, parentType: __Type, set: Iterable[Selection]): Iterable[String] = { - val fields = FieldMap(context, parentType, set) - fields.flatMap({ case (name, fields) => - groupByCommonParents(context, parentType, fields).flatMap { group => - val merged = group.flatMap(_.selection.selectionSet) - requireSameNameAndArguments(group) ++ sameForCommonParentsByName(context, parentType, merged) - } - }) - } + def sameForCommonParentsByName( + parentsCache: Ref[Map[Iterable[Selection], Iterable[String]]], + groupsCache: Ref[Map[Set[SelectedField], Iterable[Set[SelectedField]]]], + context: Context, + parentType: __Type, + set: Iterable[Selection] + ): UIO[Iterable[String]] = + cached(parentsCache, set) { + val fields = FieldMap(context, parentType, set) + UIO + .collect(fields.toIterable) { case (name, fields) => + groupByCommonParents(groupsCache, context, parentType, fields).flatMap { grouped => + UIO.collect(grouped) { group => + val merged = group.flatMap(_.selection.selectionSet) + (UIO(requireSameNameAndArguments(group)) <*> sameForCommonParentsByName( + parentsCache, + groupsCache, + context, + parentType, + merged + )).map { case (a, b) => a ++ b } + } + }.map(_.flatten) + } + .map(_.flatten) + } + // } def doTypesConflict(t1: __Type, t2: __Type): Boolean = if (isNonNull(t1)) @@ -85,26 +138,28 @@ object FragmentValidator { } def groupByCommonParents( + groupsCache: Ref[Map[Set[SelectedField], Iterable[Set[SelectedField]]]], context: Context, parentType: __Type, fields: Set[SelectedField] - ): Iterable[Set[SelectedField]] = { - val abstractGroup = fields.collect({ - case field if !isConcrete(field.parentType) => field - }) - - val concreteGroups = fields - .collect({ - case f if isConcrete(f.parentType) && f.parentType.name.isDefined => (f.parentType.name.get, f) + ): UIO[Iterable[Set[SelectedField]]] = + cached(groupsCache, fields) { + val abstractGroup = fields.collect({ + case field if !isConcrete(field.parentType) => field }) - .foldLeft(Map.empty[String, Set[SelectedField]]) { case (acc, (name, field)) => - val value = acc.get(name).map(_ + field).getOrElse(Set(field)) - acc + (name -> value) - } - if (concreteGroups.size < 1) List(fields) - else concreteGroups.values.map(_ ++ abstractGroup) - } + val concreteGroups = fields + .collect({ + case f if isConcrete(f.parentType) && f.parentType.name.isDefined => (f.parentType.name.get, f) + }) + .foldLeft(Map.empty[String, Set[SelectedField]]) { case (acc, (name, field)) => + val value = acc.get(name).map(_ + field).getOrElse(Set(field)) + acc + (name -> value) + } + + if (concreteGroups.size < 1) UIO(List(fields)) + else UIO(concreteGroups.values.map(_ ++ abstractGroup)) + } def failValidation[T](msg: String, explanatoryText: String): IO[ValidationError, T] = IO.fail(ValidationError(msg, explanatoryText))