Skip to content

Commit

Permalink
feat: support generics in kotlin (#828)
Browse files Browse the repository at this point in the history
  • Loading branch information
worstell authored Jan 24, 2024
1 parent 55ec4f8 commit 0efe839
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class ModuleGenerator() {
private fun buildDataClass(type: Data, namespace: String): TypeSpec {
val dataClassBuilder = TypeSpec.classBuilder(type.name)
.addModifiers(KModifier.DATA)
.addTypeVariables(type.typeParameters.map { TypeVariableName(it.name) })
.addKdoc(type.comments.joinToString("\n"))

val dataConstructorBuilder = FunSpec.constructorBuilder()
Expand Down Expand Up @@ -146,6 +147,7 @@ class ModuleGenerator() {
type.bool != null -> ClassName("kotlin", "Boolean")
type.time != null -> ClassName("java.time", "OffsetDateTime")
type.any != null -> ClassName("kotlin", "Any")
type.parameter != null -> TypeVariableName(type.parameter.name)
type.array != null -> {
val element = type.array?.element ?: throw IllegalArgumentException(
"Missing element type in kotlin array generator"
Expand All @@ -167,7 +169,13 @@ class ModuleGenerator() {

type.dataRef != null -> {
val module = if (type.dataRef.module.isEmpty()) namespace else "ftl.${type.dataRef.module}"
ClassName(module, type.dataRef.name)
ClassName(module, type.dataRef.name).let { className ->
if (type.dataRef.typeParameters.isNotEmpty()) {
className.parameterizedBy(type.dataRef.typeParameters.map { getTypeClass(it, namespace) })
} else {
className
}
}
}

type.optional != null -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,20 @@ public class TestModule()
@Test
fun `should generate all Types`() {
val decls = listOf(
Decl(
data_ = Data(
name = "ParamTestData",
typeParameters = listOf(TypeParameter(name = "T")),
fields = listOf(
Field(name = "t", type = Type(parameter = TypeParameter(name = "T"))),
)
)
),
Decl(data_ = Data(comments = listOf("Request comments"), name = "TestRequest")),
Decl(
data_ = Data(
comments = listOf("Response comments"), name = "TestResponse", fields = listOf(
comments = listOf("Response comments"), name = "TestResponse",
fields = listOf(
Field(name = "int", type = Type(int = Int())),
Field(name = "float", type = Type(float = Float())),
Field(name = "string", type = Type(string = String())),
Expand Down Expand Up @@ -73,6 +83,14 @@ public class TestModule()
Field(name = "dataRef", type = Type(dataRef = DataRef(name = "TestRequest"))),
Field(name = "externalDataRef", type = Type(dataRef = DataRef(module = "other", name = "TestRequest"))),
Field(name = "any", type = Type(any = xyz.block.ftl.v1.schema.Any())),
Field(
name = "parameterizedDataRef", type = Type(
dataRef = DataRef(
name = "ParamTestData",
typeParameters = listOf(Type(parameter = TypeParameter(name = "T")))
)
)
),
)
)
),
Expand All @@ -96,6 +114,10 @@ import kotlin.collections.ArrayList
import kotlin.collections.Map
import xyz.block.ftl.Ignore
public data class ParamTestData<T>(
public val t: T,
)
/**
* Request comments
*/
Expand All @@ -120,6 +142,7 @@ public data class TestResponse(
public val dataRef: TestRequest,
public val externalDataRef: ftl.other.TestRequest,
public val any: Any,
public val parameterizedDataRef: ParamTestData<T>,
)
@Ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall
import org.jetbrains.kotlin.resolve.source.getPsi
import org.jetbrains.kotlin.resolve.typeBinding.createTypeBindingForReturnType
import org.jetbrains.kotlin.types.KotlinType
import org.jetbrains.kotlin.types.checker.SimpleClassicTypeSystemContext.isTypeParameterTypeConstructor
import org.jetbrains.kotlin.types.getAbbreviation
import org.jetbrains.kotlin.types.isNullable
import org.jetbrains.kotlin.types.typeUtil.isAny
Expand All @@ -42,6 +43,7 @@ import kotlin.io.path.createDirectories

data class ModuleData(val comments: List<String> = emptyList(), val decls: MutableSet<Decl> = mutableSetOf())

data class blah<T>(val a: T)
// Helpers
private fun DataRef.compare(module: String, name: String): Boolean = this.name == name && this.module == module
private fun DataRef.text(): String = "${this.module}.${this.name}"
Expand Down Expand Up @@ -377,11 +379,26 @@ class SchemaExtractor(
)
}.toList(),
comments = this.comments(),
typeParameters = this.children.flatMap { (it as? KtTypeParameterList)?.parameters ?: emptyList() }.map {
TypeParameter(
name = it.name!!,
pos = getLineAndColumnInPsiFile(it.containingFile, it.textRange).toPosition(it.containingKtFile.name),
)
}.toList(),
pos = getLineAndColumnInPsiFile(this.containingFile, this.textRange).toPosition(this.containingKtFile.name),
)
}

private fun KotlinType.toSchemaType(position: Position): Type {
if (this.unwrap().constructor.isTypeParameterTypeConstructor()) {
return Type(
parameter = TypeParameter(
name = this.constructor.declarationDescriptor?.name?.asString() ?: "T",
pos = position,
)
)
}

val type = when (this.fqNameOrNull()?.asString()) {
String::class.qualifiedName -> Type(string = xyz.block.ftl.v1.schema.String())
Int::class.qualifiedName -> Type(int = xyz.block.ftl.v1.schema.Int())
Expand Down Expand Up @@ -437,6 +454,7 @@ class SchemaExtractor(
name = refName,
module = fqName.extractModuleName().takeIf { it != currentModuleName } ?: "",
pos = position,
typeParameters = this.arguments.map { it.type.toSchemaType(position) }.toList(),
)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ internal class ExtractSchemaRuleTest(private val env: KotlinCoreEnvironment) {
/**
* Request to echo a message.
*/
data class EchoRequest(val name: String, val stuff: Any)
data class EchoRequest<T>(val t: T, val name: String, val stuff: Any)
data class EchoResponse(val messages: List<EchoMessage>)
/**
Expand All @@ -51,7 +51,7 @@ internal class ExtractSchemaRuleTest(private val env: KotlinCoreEnvironment) {
@Throws(InvalidInput::class)
@Verb
@Ingress(Method.GET, "/echo")
fun echo(context: Context, req: EchoRequest): EchoResponse {
fun echo(context: Context, req: EchoRequest<String>): EchoResponse {
callTime(context)
return EchoResponse(messages = listOf(EchoMessage(message = "Hello!")))
}
Expand Down Expand Up @@ -96,22 +96,26 @@ internal class ExtractSchemaRuleTest(private val env: KotlinCoreEnvironment) {
name = "metadata",
type = Type(
map = Map(
key = xyz.block.ftl.v1.schema.Type(string = xyz.block.ftl.v1.schema.String()),
value_ = xyz.block.ftl.v1.schema.Type(
key = Type(string = xyz.block.ftl.v1.schema.String()),
value_ = Type(
dataRef = DataRef(
name = "MapValue",
)
)
)
)
)
),
),
),
),
Decl(
data_ = Data(
name = "EchoRequest",
fields = listOf(
Field(
name = "t",
type = Type(parameter = TypeParameter(name = "T"))
),
Field(
name = "name",
type = Type(string = xyz.block.ftl.v1.schema.String())
Expand All @@ -126,6 +130,9 @@ internal class ExtractSchemaRuleTest(private val env: KotlinCoreEnvironment) {
* Request to echo a message.
*/"""
),
typeParameters = listOf(
TypeParameter(name = "T")
)
),
),
Decl(
Expand All @@ -136,7 +143,7 @@ internal class ExtractSchemaRuleTest(private val env: KotlinCoreEnvironment) {
name = "messages",
type = Type(
array = Array(
element = xyz.block.ftl.v1.schema.Type(
element = Type(
dataRef = DataRef(
name = "EchoMessage",
)
Expand All @@ -158,6 +165,9 @@ internal class ExtractSchemaRuleTest(private val env: KotlinCoreEnvironment) {
request = Type(
dataRef = DataRef(
name = "EchoRequest",
typeParameters = listOf(
Type(string = xyz.block.ftl.v1.schema.String())
)
)
),
response = Type(
Expand Down

0 comments on commit 0efe839

Please sign in to comment.