diff --git a/core/src/main/scala-2/caliban/Scala3Annotations.scala b/core/src/main/scala-2/caliban/Scala3Annotations.scala new file mode 100644 index 000000000..5384ff442 --- /dev/null +++ b/core/src/main/scala-2/caliban/Scala3Annotations.scala @@ -0,0 +1,10 @@ +package caliban + +import scala.annotation.StaticAnnotation + +/** + * Stubs for annotations that exist in Scala 3 but not in Scala 2 + */ +private[caliban] object Scala3Annotations { + final class static extends StaticAnnotation +} diff --git a/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala b/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala index 8866de900..34324a550 100644 --- a/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala +++ b/core/src/main/scala-2/caliban/schema/ArgBuilderDerivation.scala @@ -10,6 +10,7 @@ import scala.collection.compat._ import scala.language.experimental.macros trait CommonArgBuilderDerivation { + import caliban.syntax._ type Typeclass[T] = ArgBuilder[T] @@ -53,7 +54,7 @@ trait CommonArgBuilderDerivation { ctx.constructMonadic { p => val idx = p.index val (label, default) = params(idx) - val field = fields.getOrElse(label, null) + val field = fields.getOrElseNull(label) if (field ne null) p.typeclass.build(field) else default } } diff --git a/core/src/main/scala-2/caliban/syntax.scala b/core/src/main/scala-2/caliban/syntax.scala new file mode 100644 index 000000000..b1e968a5f --- /dev/null +++ b/core/src/main/scala-2/caliban/syntax.scala @@ -0,0 +1,15 @@ +package caliban + +import scala.collection.mutable + +private[caliban] object syntax { + val NullFn: () => AnyRef = () => null + + implicit class EnrichedImmutableMapOps[K, V <: AnyRef](private val self: Map[K, V]) extends AnyVal { + def getOrElseNull(key: K): V = self.getOrElse(key, NullFn()).asInstanceOf[V] + } + + implicit class EnrichedHashMapOps[K, V <: AnyRef](private val self: mutable.HashMap[K, V]) extends AnyVal { + def getOrElseNull(key: K): V = self.getOrElse(key, NullFn()).asInstanceOf[V] + } +} diff --git a/core/src/main/scala-3/caliban/Scala3Annotations.scala b/core/src/main/scala-3/caliban/Scala3Annotations.scala new file mode 100644 index 000000000..cf479701d --- /dev/null +++ b/core/src/main/scala-3/caliban/Scala3Annotations.scala @@ -0,0 +1,10 @@ +package caliban + +import scala.annotation + +/** + * Proxies for annotations that exist in Scala 3 but not in Scala 2 + */ +private[caliban] object Scala3Annotations { + type static = annotation.static +} diff --git a/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala b/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala index e3b59b6bf..1160acfdc 100644 --- a/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala +++ b/core/src/main/scala-3/caliban/schema/ArgBuilderDerivation.scala @@ -13,6 +13,8 @@ import scala.deriving.Mirror import scala.util.NotGiven trait CommonArgBuilderDerivation { + import caliban.syntax.* + transparent inline def recurseSum[P, Label <: Tuple, A <: Tuple]( inline values: List[(String, List[Any], ArgBuilder[Any])] = Nil ): List[(String, List[Any], ArgBuilder[Any])] = @@ -150,7 +152,7 @@ trait CommonArgBuilderDerivation { val arr = Array.ofDim[Any](l) while (i < l) { val (label, default, builder) = params(i) - val field = fields.getOrElse(label, null) + val field = fields.getOrElseNull(label) val value = if (field ne null) builder.build(field) else default value match { case Right(v) => arr(i) = v diff --git a/core/src/main/scala-3/caliban/schema/SumSchema.scala b/core/src/main/scala-3/caliban/schema/SumSchema.scala index 620f257a1..3254173a8 100644 --- a/core/src/main/scala-3/caliban/schema/SumSchema.scala +++ b/core/src/main/scala-3/caliban/schema/SumSchema.scala @@ -20,7 +20,7 @@ final private class SumSchema[R, A]( val (m, s) = _members ( m.sortBy(_._1), - s.toVector, + s.toArray, s.map(s0 => SchemaUtils.isEmptyUnionObject(s0.toType_())).toArray[Boolean] ) } diff --git a/core/src/main/scala-3/caliban/syntax.scala b/core/src/main/scala-3/caliban/syntax.scala new file mode 100644 index 000000000..5f0c0e95b --- /dev/null +++ b/core/src/main/scala-3/caliban/syntax.scala @@ -0,0 +1,20 @@ +package caliban + +import scala.annotation.static + +import scala.collection.mutable + +private[caliban] object syntax { + @static val NullFn: () => AnyRef = () => null + + extension [K, V <: AnyRef](inline map: Map[K, V]) { + transparent inline def getOrElseNull(key: K): V = map.getOrElse(key, NullFn()).asInstanceOf[V] + } + + extension [K, V <: AnyRef](inline map: mutable.HashMap[K, V]) { + transparent inline def getOrElseNull(key: K): V = map.getOrElse(key, NullFn()).asInstanceOf[V] + } +} + +// Required for @static fields +private final class syntax private diff --git a/core/src/main/scala/caliban/execution/Field.scala b/core/src/main/scala/caliban/execution/Field.scala index c70753d1b..f8e03d705 100644 --- a/core/src/main/scala/caliban/execution/Field.scala +++ b/core/src/main/scala/caliban/execution/Field.scala @@ -43,24 +43,28 @@ case class Field( ) { self => lazy val locationInfo: LocationInfo = _locationInfo() - private[caliban] val aliasedName: String = alias.getOrElse(name) + private[caliban] val aliasedName: String = + if (alias.isEmpty) name else alias.get private[caliban] lazy val allFieldsUniqueNameAndCondition: Boolean = { - def inner: Boolean = { - val set = new mutable.HashSet[String] + def inner(fields: List[Field]): Boolean = { val headCondition = fields.head._condition - val _ = set.add(fields.head.aliasedName) - - var remaining = fields.tail - var result = true - while ((remaining ne Nil) && result) { - val f = remaining.head - result = set.add(f.aliasedName) && f._condition == headCondition - remaining = remaining.tail + + val seen = new mutable.HashSet[String] + seen.add(fields.head.aliasedName) + + var rem = fields.tail + while (rem ne Nil) { + val f = rem.head + val continue = seen.add(f.aliasedName) && f._condition == headCondition + if (!continue) return false + rem = rem.tail } - result + true } - fields.isEmpty || fields.tail.isEmpty || inner + + val fields0 = fields + fields0.isEmpty || fields0.tail.isEmpty || inner(fields0) } def combine(other: Field): Field = @@ -153,7 +157,10 @@ object Field { case F(alias, name, arguments, directives, selectionSet, index) => val selected = innerType.getFieldOrNull(name) - val schemaDirectives = if (selected eq null) Nil else selected.directives.getOrElse(Nil) + val schemaDirectives = + if ((selected eq null) || selected.directives.isEmpty) Nil + else selected.directives.get + val resolvedDirectives = (directives ::: schemaDirectives).map(resolveDirectiveVariables(variableValues, variableDefinitionsMap)) diff --git a/core/src/main/scala/caliban/introspection/adt/__Type.scala b/core/src/main/scala/caliban/introspection/adt/__Type.scala index d74be7f6a..e242dc982 100644 --- a/core/src/main/scala/caliban/introspection/adt/__Type.scala +++ b/core/src/main/scala/caliban/introspection/adt/__Type.scala @@ -24,6 +24,8 @@ case class __Type( @GQLExcluded origin: Option[String] = None, isOneOf: Option[Boolean] = None ) { self => + import caliban.syntax._ + final override lazy val hashCode: Int = super.hashCode() private[caliban] lazy val typeNameRepr: String = DocumentRenderer.renderTypeName(this) @@ -144,7 +146,7 @@ case class __Type( } private[caliban] def getFieldOrNull(name: String): __Field = - allFieldsMap.getOrElse(name, null) + allFieldsMap.getOrElseNull(name) lazy val innerType: __Type = Types.innerType(this) diff --git a/core/src/main/scala/caliban/schema/ObjectFieldResolver.scala b/core/src/main/scala/caliban/schema/ObjectFieldResolver.scala index 51dcedcf8..d6cc1363a 100644 --- a/core/src/main/scala/caliban/schema/ObjectFieldResolver.scala +++ b/core/src/main/scala/caliban/schema/ObjectFieldResolver.scala @@ -1,5 +1,6 @@ package caliban.schema +import caliban.Scala3Annotations.static import caliban.schema.Step.{ NullStep, ObjectStep } import scala.collection.compat._ @@ -12,15 +13,16 @@ final private class ObjectFieldResolver[R, A] private ( import ObjectFieldResolver._ private def getFieldStep(value: A): String => Step[R] = - fields.getOrElse(_, nullStepFn)(value) + fields.getOrElse(_, NullStepFn0())(value) def resolve(value: A): Step[R] = ObjectStep(name, getFieldStep(value)) } private object ObjectFieldResolver { + @static private val NullStepFn: Any => Step[Any] = _ => NullStep + @static private val NullStepFn0: () => Any => Step[Any] = () => NullStepFn + def apply[R, A](objectName: String, fields: Iterable[(String, A => Step[R])]): ObjectFieldResolver[R, A] = // NOTE: mutable.HashMap is about twice as fast than immutable.HashMap for .get new ObjectFieldResolver(objectName, mutable.HashMap.from(fields)) - - private val nullStepFn: Any => Step[Any] = _ => NullStep } diff --git a/core/src/main/scala/caliban/validation/Validator.scala b/core/src/main/scala/caliban/validation/Validator.scala index bd9c2f07b..2d472e9b2 100644 --- a/core/src/main/scala/caliban/validation/Validator.scala +++ b/core/src/main/scala/caliban/validation/Validator.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ListBuffer object Validator extends SchemaValidator { import ValidationOps._ + import caliban.syntax._ /** * A QueryValidation is a pure program that can access a Context, fail with a ValidationError or succeed with Unit. @@ -349,12 +350,7 @@ object Validator extends SchemaValidator { ) val v2 = - if ( - t match { - case Some(t0) => t0._isOneOfInput - case _ => false - } - ) + if (t.exists(_._isOneOfInput)) Some( failWhen(v.variableType.nullable)( s"Variable '${v.name}' cannot be nullable.", @@ -415,7 +411,7 @@ object Validator extends SchemaValidator { val descendantSpreads = collectFragmentSpreads(selectionSets) val cycleDetected = descendantSpreads.exists { s => visited.contains(s.name) || { - val f = context.fragments.getOrElse(s.name, null) + val f = context.fragments.getOrElseNull(s.name) (f ne null) && detectCycles(context, f, visited + s.name, checked) } } @@ -472,7 +468,7 @@ object Validator extends SchemaValidator { case f: Field => validateField(context, f, currentType) case FragmentSpread(name, _) => - context.fragments.getOrElse(name, null) match { + context.fragments.getOrElseNull(name) match { case null => failValidation( s"Fragment spread '$name' is not defined.", @@ -497,7 +493,7 @@ object Validator extends SchemaValidator { typeCondition: Option[NamedType], selectionSet: List[Selection] )(implicit v: ValidatedFragments): Either[ValidationError, Unit] = - typeCondition.fold(currentType)(t => context.rootType.types.getOrElse(t.name, null)) match { + typeCondition.fold(currentType)(t => context.rootType.types.getOrElseNull(t.name)) match { case null => val typeConditionName = typeCondition.fold("?")(_.name) failValidation( @@ -559,7 +555,7 @@ object Validator extends SchemaValidator { val providedArgs = field.arguments val v1 = validateAllNonEmpty(fieldArgsNonNull.flatMap { arg => - val arg0 = field.arguments.getOrElse(arg.name, null) + val arg0 = field.arguments.getOrElseNull(arg.name) val opt1 = (arg.defaultValue, arg0) match { case (None, null) | (None, NullValue) => Some( @@ -653,7 +649,7 @@ object Validator extends SchemaValidator { ) ) case VariableValue(variableName) => - context.variableDefinitions.getOrElse(variableName, null) match { + context.variableDefinitions.getOrElseNull(variableName) match { case null => failValidation( s"Variable '$variableName' is not defined.",