Skip to content

Commit

Permalink
Simplify tapir schema -> openapi schema conversion (#3584)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Mar 8, 2024
1 parent 937a96c commit 02978cd
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 122 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SchemasForEndpoints(
* A tuple: the first element can be used to create the components section in the docs. The second can be used to resolve (possible)
* top-level references from parameters / bodies.
*/
def apply(): (ListMap[SchemaId, ASchema], Schemas) = {
def apply(): (ListMap[SchemaId, ASchema], TSchemaToASchema) = {
val keyedCombinedSchemas: Iterable[KeyedSchema] = ToKeyedSchemas.uniqueCombined(
es.flatMap(e =>
forInput(e.securityInput) ++ forInput(e.input) ++ forOutput(e.errorOutput) ++ forOutput(e.output)
Expand All @@ -29,12 +29,11 @@ class SchemasForEndpoints(
val toSchemaReference = new ToSchemaReference(keysToIds, keyedCombinedSchemas.toMap)
val tschemaToASchema = new TSchemaToASchema(toSchemaReference, markOptionsAsNullable)

val keysToSchemas: ListMap[SchemaKey, ASchema] = keyedCombinedSchemas.map(td => (td._1, tschemaToASchema(td._2))).toListMap
val keysToSchemas: ListMap[SchemaKey, ASchema] =
keyedCombinedSchemas.map(td => (td._1, tschemaToASchema(td._2, allowReference = false))).toListMap
val schemaIds: Map[SchemaKey, (SchemaId, ASchema)] = keysToSchemas.map { case (k, v) => k -> ((keysToIds(k), v)) }

val schemas = new Schemas(tschemaToASchema, toSchemaReference, markOptionsAsNullable)

(schemaIds.values.toListMap, schemas)
(schemaIds.values.toListMap, tschemaToASchema)
}

private def forInput(input: EndpointInput[_]): List[KeyedSchema] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,71 +7,76 @@ import sttp.tapir.docs.apispec.DocsExtensionAttribute.RichSchema
import sttp.tapir.docs.apispec.schema.TSchemaToASchema.{tDefaultToADefault, tExampleToAExample}
import sttp.tapir.docs.apispec.{DocsExtensions, exampleValue}
import sttp.tapir.internal._
import sttp.tapir.{Validator, Schema => TSchema, SchemaType => TSchemaType}
import sttp.tapir.{Codec, Validator, Schema => TSchema, SchemaType => TSchemaType}

/** Converts a tapir schema to an OpenAPI/AsyncAPI schema, using `toSchemaReference` to resolve nested references. */
private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, markOptionsAsNullable: Boolean) {
def apply[T](schema: TSchema[T], isOptionElement: Boolean = false): ASchema = {
/** Converts a tapir schema to an OpenAPI/AsyncAPI schema, using `toSchemaReference` to resolve references. */
private[docs] class TSchemaToASchema(toSchemaReference: ToSchemaReference, markOptionsAsNullable: Boolean) {

def apply[T](codec: Codec[T, _, _]): ASchema = apply(codec.schema, allowReference = true)

/** @param allowReference
* Can a reference schema be generated, if this is a named schema - should be `false` for top-level component definitions (otherwise
* the definitions are infinitely recursive)
*/
def apply[T](schema: TSchema[T], allowReference: Boolean, isOptionElement: Boolean = false): ASchema = {
val nullable = markOptionsAsNullable && isOptionElement
val result = schema.schemaType match {
case TSchemaType.SInteger() => ASchema(SchemaType.Integer)
case TSchemaType.SNumber() => ASchema(SchemaType.Number)
case TSchemaType.SBoolean() => ASchema(SchemaType.Boolean)
case TSchemaType.SString() => ASchema(SchemaType.String)
case p @ TSchemaType.SProduct(fields) =>
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields)
)
case TSchemaType.SArray(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
ASchema(SchemaType.Array).copy(items = Some(toSchemaReference.map(nested, name)))
case TSchemaType.SArray(el) => ASchema(SchemaType.Array).copy(items = Some(apply(el)))
case opt @ TSchemaType.SOption(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
// #3288: in case there are multiple different customisations of the nested schema, we need to propagate the
// metadata to properly customise the reference. These are also propagated in ToKeyedSchemas when computing
// the initial list of schemas.
val propagated = propagateMetadataForOption(schema, opt).element
val ref = toSchemaReference.map(propagated, name)
if (!markOptionsAsNullable) ref else ref.copy(nullable = Some(true))
case TSchemaType.SOption(el) => apply(el, isOptionElement = true)
case TSchemaType.SBinary() => ASchema(SchemaType.String).copy(format = SchemaFormat.Binary)
case TSchemaType.SDate() => ASchema(SchemaType.String).copy(format = SchemaFormat.Date)
case TSchemaType.SDateTime() => ASchema(SchemaType.String).copy(format = SchemaFormat.DateTime)
case TSchemaType.SRef(fullName) => toSchemaReference.mapDirect(fullName)
case TSchemaType.SCoproduct(schemas, d) =>
ASchema.oneOf(
schemas
.filterNot(_.hidden)
.map {
case nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _) => toSchemaReference.map(nested, name)
case t => apply(t)
}
.sortBy {
case schema if schema.$ref.isDefined => schema.$ref.get
case schema => schema.`type`.collect { case t: BasicSchemaType => t.value }.getOrElse("") + schema.toString
},
d.map(tDiscriminatorToADiscriminator)
)
case p @ TSchemaType.SOpenProduct(fields, valueSchema) =>
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields),
additionalProperties = Some(valueSchema.name match {
case Some(name) => toSchemaReference.map(valueSchema, name)
case _ => apply(valueSchema)
}).filterNot(_ => valueSchema.hidden)
)
}

val primitiveValidators = schema.validator.asPrimitiveValidators
val schemaIsWholeNumber = schema.schemaType match {
case TSchemaType.SInteger() => true
case _ => false
val result = schema.name match {
case Some(name) if allowReference => toSchemaReference.map(schema, name)
case _ =>
schema.schemaType match {
case TSchemaType.SInteger() => ASchema(SchemaType.Integer)
case TSchemaType.SNumber() => ASchema(SchemaType.Number)
case TSchemaType.SBoolean() => ASchema(SchemaType.Boolean)
case TSchemaType.SString() => ASchema(SchemaType.String)
case p @ TSchemaType.SProduct(fields) =>
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields)
)
case TSchemaType.SArray(el) => ASchema(SchemaType.Array).copy(items = Some(apply(el, allowReference = true)))
case opt @ TSchemaType.SOption(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
// #3288: in case there are multiple different customisations of the nested schema, we need to propagate the
// metadata to properly customise the reference. These are also propagated in ToKeyedSchemas when computing
// the initial list of schemas.
val propagated = propagateMetadataForOption(schema, opt).element
val ref = toSchemaReference.map(propagated, name)
if (!markOptionsAsNullable) ref else ref.copy(nullable = Some(true))
case TSchemaType.SOption(el) => apply(el, allowReference = true, isOptionElement = true)
case TSchemaType.SBinary() => ASchema(SchemaType.String).copy(format = SchemaFormat.Binary)
case TSchemaType.SDate() => ASchema(SchemaType.String).copy(format = SchemaFormat.Date)
case TSchemaType.SDateTime() => ASchema(SchemaType.String).copy(format = SchemaFormat.DateTime)
case TSchemaType.SRef(fullName) => toSchemaReference.mapDirect(fullName)
case TSchemaType.SCoproduct(schemas, d) =>
ASchema.oneOf(
schemas
.filterNot(_.hidden)
.map(apply(_, allowReference = true))
.sortBy {
case schema if schema.$ref.isDefined => schema.$ref.get
case schema => schema.`type`.collect { case t: BasicSchemaType => t.value }.getOrElse("") + schema.toString
},
d.map(tDiscriminatorToADiscriminator)
)
case p @ TSchemaType.SOpenProduct(fields, valueSchema) =>
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields),
additionalProperties = Some(apply(valueSchema, allowReference = true)).filterNot(_ => valueSchema.hidden)
)
}
}

if (result.$ref.isEmpty) {
// only customising non-reference schemas; references might get enriched with some meta-data if there
// are multiple different customisations of the referenced schema in ToSchemaReference (#1203)

val primitiveValidators = schema.validator.asPrimitiveValidators
val schemaIsWholeNumber = schema.schemaType match {
case TSchemaType.SInteger() => true
case _ => false
}

var s = result
s = if (nullable) s.copy(nullable = Some(true)) else s
s = addMetadata(s, schema)
Expand All @@ -84,12 +89,7 @@ private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, mar
private def extractProperties[T](fields: List[TSchemaType.SProductField[T]]) = {
fields
.filterNot(_.schema.hidden)
.map { f =>
f.schema.name match {
case Some(name) => f.name.encodedName -> toSchemaReference.map(f.schema, name)
case None => f.name.encodedName -> apply(f.schema)
}
}
.map(f => f.name.encodedName -> apply(f.schema, allowReference = true))
.toListMap
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ object TapirSchemaToJsonSchema {
val keysToIds = calculateUniqueIds(keyedSchemas.map(_._1), (key: SchemaKey) => schemaName(key.name))
val toSchemaReference = new ToSchemaReference(keysToIds, keyedSchemas.toMap, refRoot = "#/$defs/")
val tschemaToASchema = new TSchemaToASchema(toSchemaReference, markOptionsAsNullable)
val keysToSchemas = keyedSchemas.map(td => (td._1, tschemaToASchema(td._2))).toListMap
val keysToSchemas = keyedSchemas.map(td => (td._1, tschemaToASchema(td._2, allowReference = false))).toListMap
val schemaIds = keysToSchemas.map { case (k, v) => k -> ((keysToIds(k), v)) }

val nestedKeyedSchemas = schemaIds.values
val rootApiSpecSchemaOrRef: ASchema = tschemaToASchema(schema)
val rootApiSpecSchemaOrRef: ASchema = tschemaToASchema(schema, allowReference = false)

val defsList: ListMap[SchemaId, ASchema] =
nestedKeyedSchemas.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import sttp.model.Method
import sttp.tapir.EndpointOutput.WebSocketBodyWrapper
import sttp.tapir.docs.apispec.DocsExtensionAttribute.{RichEndpointIOInfo, RichEndpointInfo}
import sttp.tapir.docs.apispec.{DocsExtensions, namedPathComponents}
import sttp.tapir.docs.apispec.schema.Schemas
import sttp.tapir.docs.apispec.schema.TSchemaToASchema
import sttp.tapir.internal.{IterableToListMap, RichEndpoint}
import sttp.tapir.{AnyEndpoint, Codec, CodecFormat, EndpointIO, EndpointInput}

import scala.collection.immutable.ListMap

private[asyncapi] class EndpointToAsyncAPIWebSocketChannel(
schemas: Schemas,
tschemaToASchema: TSchemaToASchema,
codecToMessageKey: Map[Codec[_, _, _ <: CodecFormat], MessageKey],
options: AsyncAPIDocsOptions
) {
Expand Down Expand Up @@ -49,7 +49,7 @@ private[asyncapi] class EndpointToAsyncAPIWebSocketChannel(
codec: Codec[_, _, _ <: CodecFormat],
info: EndpointIO.Info[_]
): ((String, Codec[_, _, _ <: CodecFormat]), ASchema) = {
val schemaRef = schemas(codec)
val schemaRef = tschemaToASchema(codec)
schemaRef match {
case schema if schema.$ref.isEmpty =>
val schemaWithDescription = if (schema.description.isEmpty) schemaRef.copy(description = info.description) else schemaRef
Expand All @@ -63,7 +63,7 @@ private[asyncapi] class EndpointToAsyncAPIWebSocketChannel(

private def parameters(inputs: Vector[EndpointInput.Basic[_]]): ListMap[String, ReferenceOr[Parameter]] = {
inputs.collect { case EndpointInput.PathCapture(Some(name), codec, info) =>
name -> Right(Parameter(info.description, Some(schemas(codec)), None, DocsExtensions.fromIterable(info.docsExtensions)))
name -> Right(Parameter(info.description, Some(tschemaToASchema(codec)), None, DocsExtensions.fromIterable(info.docsExtensions)))
}.toListMap
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import sttp.apispec.asyncapi.{Message, SingleMessage}
import sttp.model.MediaType
import sttp.tapir.EndpointOutput.WebSocketBodyWrapper
import sttp.tapir.Schema.SName
import sttp.tapir.docs.apispec.schema.{Schemas, ToKeyedSchemas, calculateUniqueIds}
import sttp.tapir.docs.apispec.schema.{TSchemaToASchema, ToKeyedSchemas, calculateUniqueIds}
import sttp.tapir.internal.IterableToListMap
import sttp.tapir.{Codec, CodecFormat, EndpointIO, WebSocketBodyOutput, Schema => TSchema}
import sttp.ws.WebSocketFrame

import scala.collection.immutable.ListMap

private[asyncapi] class MessagesForEndpoints(schemas: Schemas, schemaName: SName => String) {
private[asyncapi] class MessagesForEndpoints(tschemaToASchema: TSchemaToASchema, schemaName: SName => String) {
private type CodecData = Either[(SName, MediaType), TSchema[_]]

private case class CodecWithInfo[T](codec: Codec[WebSocketFrame, T, _ <: CodecFormat], info: EndpointIO.Info[T])
Expand Down Expand Up @@ -42,7 +42,7 @@ private[asyncapi] class MessagesForEndpoints(schemas: Schemas, schemaName: SName
val convertedExamples = ExampleConverter.convertExamples(ci.codec, ci.info.examples)
SingleMessage(
None,
Some(Right(schemas(ci.codec))),
Some(Right(tschemaToASchema(ci.codec))),
None,
None,
Some(ci.codec.format.mediaType.toString()),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package sttp.tapir.docs.openapi

import sttp.apispec.openapi.{MediaType => OMediaType}
import sttp.tapir.docs.apispec.schema.Schemas
import sttp.tapir.docs.apispec.schema.TSchemaToASchema
import sttp.tapir.{CodecFormat, _}

import scala.collection.immutable.ListMap

private[openapi] class CodecToMediaType(schemas: Schemas) {
private[openapi] class CodecToMediaType(tschemaToASchema: TSchemaToASchema) {
def apply[T, CF <: CodecFormat](
o: Codec[_, T, CF],
examples: List[EndpointIO.Example[T]],
Expand All @@ -19,7 +19,7 @@ private[openapi] class CodecToMediaType(schemas: Schemas) {

ListMap(
forcedContentType.getOrElse(o.format.mediaType.noCharset.toString) -> OMediaType(
Some(schemas(o)),
Some(tschemaToASchema(o)),
allExamples.singleExample,
allExamples.multipleExamples
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ private[openapi] object EndpointToOpenAPIDocs {
): OpenAPI = {
val es2 = es.filter(e => findWebSocket(e).isEmpty).map(nameAllPathCapturesInEndpoint)
val additionalOutputs = es2.flatMap(e => options.defaultDecodeFailureOutput(e.input)).toSet.toList
val (idToSchema, schemas) = new SchemasForEndpoints(es2, options.schemaName, options.markOptionsAsNullable, additionalOutputs).apply()
val (idToSchema, tschemaToASchema) =
new SchemasForEndpoints(es2, options.schemaName, options.markOptionsAsNullable, additionalOutputs).apply()
val securitySchemes = SecuritySchemesForEndpoints(es2, apiKeyAuthTypeName = "apiKey")
val pathCreator = new EndpointToOpenAPIPaths(schemas, securitySchemes, options)
val pathCreator = new EndpointToOpenAPIPaths(tschemaToASchema, securitySchemes, options)
val componentsCreator = new EndpointToOpenAPIComponents(idToSchema, securitySchemes)

val base = apiToOpenApi(api, componentsCreator, docsExtensions)
Expand Down
Loading

0 comments on commit 02978cd

Please sign in to comment.