diff --git a/CHANGELOG.md b/CHANGELOG.md index 735ca988be..b65db8fad7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ Thank you to all who have contributed! ## [Unreleased] ### Added +- Adds top-level IR node creation functions. +- Adds `componentN` functions (destructuring) to IR nodes via Kotlin data classes +- Adds public `tag` field to IR nodes for associating metadata ### Changed @@ -36,12 +39,14 @@ Thank you to all who have contributed! ### Fixed ### Removed +- [Breaking] Removed IR factory in favor of static top-level functions. Change `Ast.foo()` + to `foo()` ### Security ### Contributors Thank you to all who have contributed! -- @ +- @rchowell ## [0.13.2-alpha] - 2023-09-29 diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinGenerator.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinGenerator.kt index c219d5b801..51a52093b4 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinGenerator.kt +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinGenerator.kt @@ -2,16 +2,14 @@ package org.partiql.sprout.generator.target.kotlin import com.squareup.kotlinpoet.AnnotationSpec import com.squareup.kotlinpoet.ClassName -import com.squareup.kotlinpoet.CodeBlock -import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier import com.squareup.kotlinpoet.ParameterSpec import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec -import com.squareup.kotlinpoet.asTypeName import net.pearx.kasechange.toCamelCase import org.partiql.sprout.generator.Generator import org.partiql.sprout.generator.target.kotlin.poems.KotlinBuilderPoem +import org.partiql.sprout.generator.target.kotlin.poems.KotlinFactoryPoem import org.partiql.sprout.generator.target.kotlin.poems.KotlinJacksonPoem import org.partiql.sprout.generator.target.kotlin.poems.KotlinListenerPoem import org.partiql.sprout.generator.target.kotlin.poems.KotlinUtilsPoem @@ -21,7 +19,6 @@ import org.partiql.sprout.generator.target.kotlin.spec.KotlinNodeSpec import org.partiql.sprout.generator.target.kotlin.spec.KotlinUniverseSpec import org.partiql.sprout.model.TypeDef import org.partiql.sprout.model.TypeProp -import org.partiql.sprout.model.TypeRef import org.partiql.sprout.model.Universe /** @@ -29,6 +26,9 @@ import org.partiql.sprout.model.Universe */ class KotlinGenerator(private val options: KotlinOptions) : Generator { + // @JvmField + private val jvmField = AnnotationSpec.builder(JvmField::class).build() + override fun generate(universe: Universe): KotlinResult { // --- Initialize an empty symbol table(?) 
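The CHANGELOG entries above summarize the API change that the generator edits below implement: the `Ast` factory object is removed, IR nodes become Kotlin data classes, and node creation moves to static top-level functions. A minimal caller-side migration sketch, assuming the `exprLit`/`identifierSymbol` functions and `Identifier.CaseSensitivity` values that appear in the updated tests later in this diff; the destructuring and `copy` call rely on the generated data-class members, and the property names and order shown here are illustrative:

```kotlin
import org.partiql.ast.Identifier
import org.partiql.ast.exprLit
import org.partiql.ast.identifierSymbol
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.nullValue

@OptIn(PartiQLValueExperimental::class)
fun migrationSketch() {
    // Before: Ast.exprLit(nullValue()); after: a static top-level function
    val lit = exprLit(nullValue())

    // Data classes bring componentN (destructuring) and copy for free
    val id = identifierSymbol("x", Identifier.CaseSensitivity.INSENSITIVE)
    val (symbol, caseSensitivity) = id        // property order is illustrative
    val renamed = id.copy(symbol = symbol.uppercase())

    // Nodes also expose a public `tag` field for associating metadata;
    // ToLegacyAst below keys its metas map by node.tag instead of node._id.
    println("$lit $caseSensitivity $renamed")
}
```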
@@ -39,6 +39,7 @@ class KotlinGenerator(private val options: KotlinOptions) : Generator KotlinVisitorPoem(symbols) + "factory" -> KotlinFactoryPoem(symbols) "builder" -> KotlinBuilderPoem(symbols) "listener" -> KotlinListenerPoem(symbols) "jackson" -> KotlinJacksonPoem(symbols) @@ -51,12 +52,19 @@ class KotlinGenerator(private val options: KotlinOptions) : Generator = types.mapNotNull { it.generate(symbols) }.map { - it.builder.addSuperinterface(symbols.base) + it.builder.superclass(symbols.base) it } @@ -108,34 +116,41 @@ class KotlinGenerator(private val options: KotlinOptions) : Generator.enums(symbols: KotlinSymbols) = filterIsInstance().map { it.generate(symbols) } - - // TODO generate hashCode, equals, componentN so we can have OPEN internal implementations - private fun KotlinNodeSpec.Product.addDataClassMethods(symbols: KotlinSymbols, ref: TypeRef.Path) { - impl.addModifiers(KModifier.INTERNAL, KModifier.OPEN) - addEqualsMethod() - addHashCodeMethod() - addToStringMethod(symbols, ref) - val args = listOf("_id") + props.map { it.name } - val copy = FunSpec.builder("copy").addModifiers(KModifier.ABSTRACT).returns(clazz) - val copyImpl = FunSpec.builder("copy") - .addModifiers(KModifier.OVERRIDE) - .returns(clazz) - .addStatement("return %T(${args.joinToString()})", implClazz) - props.forEach { - val para = ParameterSpec.builder(it.name, it.type).build() - copy.addParameter(para.toBuilder().defaultValue("this.${it.name}").build()) - copyImpl.addParameter(para) - } - builder.addFunction(copy.build()) - impl.addFunction(copyImpl.build()) - } - - /** - * Adds `equals` method to the core abstract class - */ - private fun KotlinNodeSpec.Product.addEqualsMethod() { - val equalsFunctionBodyBuilder = CodeBlock.builder().let { body -> - body.addStatement("if (this === other) return true") - body.addStatement("if (other !is %T) return false", this.clazz) - this.props.forEach { prop -> - body.addStatement("if (%N != other.%N) return false", prop.name, prop.name) - } - body.addStatement("return true") - } - builder.addFunction( - FunSpec.builder("equals").addModifiers(KModifier.OVERRIDE).returns(Boolean::class) - .addParameter(ParameterSpec.builder("other", Any::class.asTypeName().copy(nullable = true)).build()) - .addCode(equalsFunctionBodyBuilder.build()) - .build() - ) - } - - /** - * Adds `hashCode` method to the core abstract class - */ - private fun KotlinNodeSpec.Product.addHashCodeMethod() { - val hashcodeBodyBuilder = CodeBlock.builder().let { body -> - when (this.props.size) { - 0 -> body.addStatement("return 0") - 1 -> body.addStatement("return %N.hashCode()", this.props.first().name) - else -> { - body.addStatement("var result = %N.hashCode()", this.props.first().name) - this.props.subList(1, this.props.size).forEach { prop -> - body.addStatement("result = 31 * result + %N.hashCode()", prop.name) - } - body.addStatement("return result") - } - } - body - } - builder.addFunction( - FunSpec.builder("hashCode") - .addModifiers(KModifier.OVERRIDE) - .returns(Int::class) - .addCode(hashcodeBodyBuilder.build()) - .build() - ) - } - - private fun enumToStringSpec(base: String): FunSpec { - val bodyBuilder = CodeBlock.builder().let { body -> - val str = "$base::\${super.toString()}" - body.addStatement("return %P", str) - body - } - return FunSpec.builder("toString") - .addModifiers(KModifier.OVERRIDE) - .returns(String::class) - .addCode(bodyBuilder.build()) - .build() - } - - /** - * Adds `toString` method to the core abstract class. 
We write it in Ion syntax, however, it is NOT a contract - * and therefore subject to failure. - * - * Notably, the following don't format to Ion: - * - Maps - * - Imported Types - * - Escape Characters - */ - private fun KotlinNodeSpec.Product.addToStringMethod(symbols: KotlinSymbols, ref: TypeRef.Path) { - val annotation = symbols.pascal(ref) - val thiz = this - val bodyBuilder = CodeBlock.builder().let { body -> - val returnString = buildString { - append("$annotation::{") - thiz.props.forEach { prop -> - if (String::class.asTypeName() == prop.type) { - append("${prop.name}: \"\$${prop.name}\",") - } else { - append("${prop.name}: \$${prop.name},") - } - } - append("}") - } - body.addStatement("return %P", returnString) - body - } - builder.addFunction( - FunSpec.builder("toString") - .addModifiers(KModifier.OVERRIDE) - .returns(String::class) - .addCode(bodyBuilder.build()) - .build() - ) - } } diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinSymbols.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinSymbols.kt index 91235f6e91..b9d91c5971 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinSymbols.kt +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/KotlinSymbols.kt @@ -6,16 +6,13 @@ import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.DOUBLE import com.squareup.kotlinpoet.FLOAT import com.squareup.kotlinpoet.INT -import com.squareup.kotlinpoet.KModifier import com.squareup.kotlinpoet.LIST import com.squareup.kotlinpoet.LONG import com.squareup.kotlinpoet.MAP import com.squareup.kotlinpoet.MUTABLE_LIST import com.squareup.kotlinpoet.MUTABLE_MAP import com.squareup.kotlinpoet.MUTABLE_SET -import com.squareup.kotlinpoet.ParameterSpec import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy -import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.SET import com.squareup.kotlinpoet.STRING import com.squareup.kotlinpoet.TypeName @@ -54,16 +51,6 @@ class KotlinSymbols private constructor( */ val base: ClassName = ClassName(rootPackage, "${rootId}Node") - /** - * Id Property for interfaces and classes - */ - val idProp = PropertySpec.builder("_id", String::class).addModifiers(KModifier.OVERRIDE).initializer("_id").build() - - /** - * Id Parameter for internal constructors - */ - val idPara = ParameterSpec.builder("_id", String::class).build() - /** * Memoize converting a TypeRef.Path to a camel case identifier to be used as method/function names */ diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinBuilderPoem.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinBuilderPoem.kt index dd60461c8e..d20f79d0fc 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinBuilderPoem.kt +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinBuilderPoem.kt @@ -4,14 +4,12 @@ import com.squareup.kotlinpoet.AnnotationSpec import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.FunSpec -import com.squareup.kotlinpoet.KModifier import com.squareup.kotlinpoet.LambdaTypeName import com.squareup.kotlinpoet.ParameterSpec import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.TypeVariableName import com.squareup.kotlinpoet.asTypeName -import 
com.squareup.kotlinpoet.buildCodeBlock import net.pearx.kasechange.toCamelCase import net.pearx.kasechange.toPascalCase import org.partiql.sprout.generator.target.kotlin.KotlinPoem @@ -31,40 +29,6 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { private val builderPackageName = "${symbols.rootPackage}.builder" - private val idProviderType = LambdaTypeName.get(returnType = String::class.asTypeName()) - private val idProvider = PropertySpec.builder("_id", idProviderType).build() - - // Abstract factory which can be used by DSL blocks - private val factoryName = "${symbols.rootId}Factory" - private val factoryClass = ClassName(builderPackageName, factoryName) - private val factory = TypeSpec.interfaceBuilder(factoryClass) - .addProperty(idProvider) - - private val baseFactoryName = "${symbols.rootId}FactoryImpl" - private val baseFactoryClass = ClassName(builderPackageName, baseFactoryName) - private val baseFactory = TypeSpec.classBuilder(baseFactoryClass) - .addSuperinterface(factoryClass) - .addModifiers(KModifier.OPEN) - .addProperty( - idProvider.toBuilder() - .addModifiers(KModifier.OVERRIDE) - .initializer( - "{ %P }", - buildCodeBlock { - // universe-${"%08x".format(Random.nextInt())} - add("${symbols.rootId}-\${%S.format(%T.nextInt())}", "%08x", ClassName("kotlin.random", "Random")) - } - ) - .build() - ) - - private val factoryParamDefault = ParameterSpec.builder("factory", factoryClass) - .defaultValue("%T.DEFAULT", factoryClass) - .build() - - // Assume there's a .kt file in the package root containing the default builder - private val factoryDefault = ClassName(symbols.rootPackage, symbols.rootId) - // Java style builders, used by the DSL private val buildersName = "${symbols.rootId}Builders" private val buildersFile = FileSpec.builder(builderPackageName, buildersName) @@ -79,13 +43,6 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { private val dslName = "${symbols.rootId}Builder" private val dslClass = ClassName(builderPackageName, dslName) private val dslSpec = TypeSpec.classBuilder(dslClass) - .addProperty( - PropertySpec.builder("factory", factoryClass) - .addModifiers(KModifier.PRIVATE) - .initializer("factory") - .build() - ) - .primaryConstructor(FunSpec.constructorBuilder().addParameter(factoryParamDefault).build()) // T : FooNode private val boundedT = TypeVariableName("T", symbols.base) @@ -93,7 +50,6 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { // Static top-level entry point for DSL private val dslFunc = FunSpec.builder(symbols.rootId.toCamelCase()) .addTypeVariable(boundedT) - .addParameter(factoryParamDefault) .addParameter( ParameterSpec.builder( "block", @@ -103,23 +59,7 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { ) ).build() ) - .addStatement("return %T(factory).block()", dslClass) - .build() - - // Static companion object entry point for factory, similar to PIG "build" - private val factoryFunc = FunSpec.builder("create") - .addAnnotation(Annotations.jvmStatic) - .addTypeVariable(boundedT) - .addParameter( - ParameterSpec.builder( - "block", - LambdaTypeName.get( - receiver = factoryClass, - returnType = boundedT, - ) - ).build() - ) - .addStatement("return %T.DEFAULT.block()", factoryClass) + .addStatement("return %T().block()", dslClass) .build() override fun apply(universe: KotlinUniverseSpec) { @@ -128,17 +68,9 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { KotlinPackageSpec( name = builderPackageName, files = 
mutableListOf( - // Factory Interface - FileSpec.builder(builderPackageName, factoryName) - .addType(factory.addType(factoryCompanion()).build()) - .build(), - // Factory Base - FileSpec.builder(builderPackageName, baseFactoryName) - .addType(baseFactory.build()) - .build(), // Java Builders buildersFile.build(), - // DSL + // Kotlin DSL FileSpec.builder(builderPackageName, dslName) .addAnnotation(suppressUnused) .addFunction(dslFunc) @@ -150,35 +82,10 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { } override fun apply(node: KotlinNodeSpec.Product) { - val function = FunSpec.builder(symbols.camel(node.product.ref)) - .apply { - node.props.forEach { - addParameter(it.name, it.type) - } - } - .returns(node.clazz) - .build() - // interface - factory.addFunction(function.toBuilder().addModifiers(KModifier.ABSTRACT).build()) - // impl - baseFactory.addFunction( - function.toBuilder() - .addModifiers(KModifier.OVERRIDE) - .returns(node.clazz) - .apply { - val args = listOf("_id()") + node.props.map { - // add as function parameter - it.name - } - // Inject identifier `node(id(), props...)` - addStatement("return %T(${args.joinToString()})", node.implClazz) - } - .build() - ) // DSL Receiver and Function - val (builder, func) = node.builderToFunc() + val (builder, funcDsl) = node.builderToFunc() buildersFile.addType(builder) - dslSpec.addFunction(func) + dslSpec.addFunction(funcDsl) super.apply(node) } @@ -197,10 +104,10 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { // DSL Function val funcDsl = FunSpec.builder(builderName).returns(clazz) - funcDsl.addStatement("val builder = %T(${ props.joinToString { it.name }})", builderType) + funcDsl.addStatement("val builder = %T(${props.joinToString { it.name }})", builderType) - // Java builder `build(factory: Factory = DEFAULT): T` - val funcBuild = FunSpec.builder("build").addParameter(factoryParamDefault).returns(clazz) + // Java builder `build(): T` + val funcBuild = FunSpec.builder("build").returns(clazz) val args = mutableListOf() // Add builder function to node interface @@ -249,6 +156,7 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { ) // Add parameter to `build(factory: Factory =)` me + // This would be a nice place for friendly error messages rather the NPE val assertion = if (!it.ref.nullable && default == "null") "!!" 
else "" args += "$name = $name$assertion" } @@ -267,31 +175,16 @@ class KotlinBuilderPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { ) // End of factory.foo call - funcBuild.addStatement("return factory.$builderName(${args.joinToString()})") + funcBuild.addStatement("return %T(${args.joinToString()})", clazz) // Finalize Java builder - builder.addFunction( - FunSpec.builder("build") - .returns(clazz) - .addStatement("return build(%T.DEFAULT)", factoryClass) - .build() - ) builder.addFunction(funcBuild.build()) builder.primaryConstructor(builderConstructor.build()) // Finalize DSL function funcDsl.addStatement("builder.block()") - funcDsl.addStatement("return builder.build(factory)") + funcDsl.addStatement("return builder.build()") return Pair(builder.build(), funcDsl.build()) } - - private fun factoryCompanion() = TypeSpec.companionObjectBuilder() - .addProperty( - PropertySpec.builder("DEFAULT", factoryClass) - .initializer("%T()", baseFactoryClass) - .build() - ) - .addFunction(factoryFunc) - .build() } diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinFactoryPoem.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinFactoryPoem.kt new file mode 100644 index 0000000000..9e6bb4ba51 --- /dev/null +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinFactoryPoem.kt @@ -0,0 +1,46 @@ +package org.partiql.sprout.generator.target.kotlin.poems + +import com.squareup.kotlinpoet.AnnotationSpec +import com.squareup.kotlinpoet.ClassName +import com.squareup.kotlinpoet.FileSpec +import com.squareup.kotlinpoet.FunSpec +import org.partiql.sprout.generator.target.kotlin.KotlinPoem +import org.partiql.sprout.generator.target.kotlin.KotlinSymbols +import org.partiql.sprout.generator.target.kotlin.spec.KotlinNodeSpec +import org.partiql.sprout.generator.target.kotlin.spec.KotlinUniverseSpec + +/** + * Poem which creates a DSL for instantiation + */ +class KotlinFactoryPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { + + override val id: String = "factory" + + // file@JvmName("...") + private val jvmName = AnnotationSpec.builder(ClassName("kotlin.jvm", "JvmName")) + + override fun apply(universe: KotlinUniverseSpec) { + super.apply(universe) + val factory = FileSpec.builder(symbols.rootPackage, symbols.rootId) + .addAnnotation(jvmName.addMember("%S", symbols.rootId).build()) + .apply { + universe.forEachNode { + // add all product creation functions + if (it is KotlinNodeSpec.Product) addFunction(it.factoryMethod()) + } + } + .build() + universe.files.add(factory) + } + + private fun KotlinNodeSpec.Product.factoryMethod() = FunSpec.builder(symbols.camel(product.ref)) + .returns(clazz) + .apply { + val args = props.map { + addParameter(it.name, it.type) + it.name + } + addStatement("return %T(${args.joinToString()})", clazz) + } + .build() +} diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinUtilsPoem.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinUtilsPoem.kt index 9a5cf15e54..ec444dd2b7 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinUtilsPoem.kt +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinUtilsPoem.kt @@ -10,10 +10,8 @@ import com.squareup.kotlinpoet.LIST import com.squareup.kotlinpoet.LambdaTypeName import com.squareup.kotlinpoet.ParameterSpec import 
com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy -import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.SET import com.squareup.kotlinpoet.TypeSpec -import net.pearx.kasechange.toCamelCase import org.partiql.sprout.generator.target.kotlin.KotlinPoem import org.partiql.sprout.generator.target.kotlin.KotlinSymbols import org.partiql.sprout.generator.target.kotlin.spec.KotlinNodeSpec @@ -43,15 +41,12 @@ class KotlinUtilsPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { .build() // Not taking a dep on builder or visitor poems, as this is temporary - private val factoryClass = ClassName("${symbols.rootPackage}.builder", "${symbols.rootId}Factory") private val visitorBaseClass = ClassName("${symbols.rootPackage}.visitor", "${symbols.rootId}BaseVisitor") .parameterizedBy(symbols.base, Parameters.C) private val rewriterPackageName = "${symbols.rootPackage}.util" private val rewriterName = "${symbols.rootId}Rewriter" - private val factory = symbols.rootId.toCamelCase() - /** * Defines the open `children` property and the abstract`accept` method on the base node */ @@ -61,13 +56,6 @@ class KotlinUtilsPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { .addModifiers(KModifier.ABSTRACT) .addTypeVariable(Parameters.C) .apply { - // open val foo: FooFactory = FooFactory.DEFAULT - addProperty( - PropertySpec.builder(factory, factoryClass) - .addModifiers(KModifier.OPEN) - .initializer("%T.DEFAULT", factoryClass) - .build() - ) // override fun defaultReturn(node: PlanNode, ctx: C) = node addFunction( FunSpec.builder("defaultReturn") @@ -114,7 +102,6 @@ class KotlinUtilsPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { private fun KotlinNodeSpec.Product.rewriter(): FunSpec { val visit = product.ref.visitMethodName() - val constructor = symbols.camel(product.ref) return FunSpec.builder(visit) .addModifiers(KModifier.OVERRIDE) .addParameter(ParameterSpec("node", clazz)) @@ -175,7 +162,7 @@ class KotlinUtilsPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { } val condition = names.joinToString(" || ") { "$it !== node.$it" } beginControlFlow("return if ($condition)") - addStatement("$factory.$constructor(${names.joinToString(", ")})") + addStatement("%T(${names.joinToString(", ")})", clazz) nextControlFlow("else") addStatement("node") endControlFlow() diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt index 88ae4f166c..de8aa6fbb3 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/poems/KotlinVisitorPoem.kt @@ -80,7 +80,7 @@ class KotlinVisitorPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { override fun apply(node: KotlinNodeSpec.Product) { val kids = node.kids() if (kids != null) { - node.impl.addProperty( + node.builder.addProperty( children.toBuilder() .addModifiers(KModifier.OVERRIDE) .delegate( @@ -93,14 +93,14 @@ class KotlinVisitorPoem(symbols: KotlinSymbols) : KotlinPoem(symbols) { .build() ) } else { - node.impl.addProperty( + node.builder.addProperty( children.toBuilder() .addModifiers(KModifier.OVERRIDE) .initializer("emptyList()") .build() ) } - node.impl.addFunction( + node.builder.addFunction( accept.toBuilder() .addModifiers(KModifier.OVERRIDE) .addStatement("return visitor.%L(this, ctx)", node.product.ref.visitMethodName()) diff --git 
a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinNodeSpec.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinNodeSpec.kt index def9c3055b..3abbb5de00 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinNodeSpec.kt +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinNodeSpec.kt @@ -70,26 +70,19 @@ sealed class KotlinNodeSpec( class Product( val product: TypeDef.Product, val props: List, - val implClazz: ClassName, - val impl: TypeSpec.Builder, override val nodes: List, clazz: ClassName, ext: MutableList = mutableListOf(), ) : KotlinNodeSpec( def = product, clazz = clazz, - builder = TypeSpec.classBuilder(clazz).addModifiers(KModifier.ABSTRACT), + builder = TypeSpec.classBuilder(clazz).addModifiers(KModifier.DATA), companion = TypeSpec.companionObjectBuilder(), ext = ext, ) { val constructor = FunSpec.constructorBuilder() override val children: List = nodes - - fun buildImpl(): TypeSpec { - impl.primaryConstructor(constructor.build()) - return impl.build() - } } /** @@ -104,7 +97,7 @@ sealed class KotlinNodeSpec( ) : KotlinNodeSpec( def = sum, clazz = clazz, - builder = TypeSpec.interfaceBuilder(clazz).addModifiers(KModifier.SEALED), + builder = TypeSpec.classBuilder(clazz).addModifiers(KModifier.SEALED), companion = TypeSpec.companionObjectBuilder(), ext = ext, ) { diff --git a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinUniverseSpec.kt b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinUniverseSpec.kt index be5c3d431f..1d99ce2ed1 100644 --- a/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinUniverseSpec.kt +++ b/lib/sprout/src/main/kotlin/org/partiql/sprout/generator/target/kotlin/spec/KotlinUniverseSpec.kt @@ -42,7 +42,8 @@ class KotlinUniverseSpec( * Build the Kotlin files * * - * ├── Types.kt + * ├── .kt + * ├── Nodes.kt * ├── ... * ├── builder * │ └── Builder.kt @@ -55,8 +56,8 @@ class KotlinUniverseSpec( val files = mutableListOf() val specs = nodes.map { it.build() } - // /Types.kt - files += with(FileSpec.builder(root, "Types")) { + // /Nodes.kt + files += with(FileSpec.builder(root, "Nodes")) { addType(base.build()) specs.forEach { addType(it) } types.forEach { addType(it) } @@ -69,16 +70,6 @@ class KotlinUniverseSpec( // //... 
files += packages.flatMap { it.files } - // /impl/Types.kt - files += with(FileSpec.builder("$root.impl", "Types")) { - forEachNode { - if (it is KotlinNodeSpec.Product) { - addType(it.buildImpl()) - } - } - build() - } - return files } diff --git a/partiql-ast/build.gradle.kts b/partiql-ast/build.gradle.kts index 0814ffaa75..c5e94c9ce1 100644 --- a/partiql-ast/build.gradle.kts +++ b/partiql-ast/build.gradle.kts @@ -60,6 +60,7 @@ val generate = tasks.register("generate") { "-o", "$buildDir/generated-src", "-p", "org.partiql.ast", "-u", "Ast", + "--poems", "factory", "--poems", "visitor", "--poems", "builder", "--poems", "util", diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt deleted file mode 100644 index fa3434d1ab..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/Ast.kt +++ /dev/null @@ -1,28 +0,0 @@ -package org.partiql.ast - -import org.partiql.ast.builder.AstFactoryImpl -import org.partiql.ast.sql.SqlBlock -import org.partiql.ast.sql.SqlDialect -import org.partiql.ast.sql.SqlLayout -import org.partiql.ast.sql.sql - -/** - * Singleton instance of the default factory; also accessible via `AstFactory.DEFAULT`. - */ -object Ast : AstBaseFactory() - -/** - * AstBaseFactory can be used to create a factory which extends from the factory provided by AstFactory.DEFAULT. - */ -public abstract class AstBaseFactory : AstFactoryImpl() { - // internal default overrides here -} - -/** - * Pretty-print this [AstNode] as SQL text with the given [SqlLayout] - */ -@JvmOverloads -public fun AstNode.sql( - layout: SqlLayout = SqlLayout.DEFAULT, - dialect: SqlDialect = SqlDialect.PARTIQL, -): String = accept(dialect, SqlBlock.Nil).sql(layout) diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt index 5974d9be2c..da20736eed 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -88,7 +88,7 @@ private class AstTranslator(val metas: Map) : AstBaseVisi node: AstNode, block: PartiqlAst.Builder.(metas: MetaContainer) -> T, ): T { - val metas = metas[node._id] ?: emptyMetaContainer() + val metas = metas[node.tag] ?: emptyMetaContainer() return pig.block(metas) } diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/impl/.gitkeep b/partiql-ast/src/main/kotlin/org/partiql/ast/impl/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt index d360f59e33..b9065fe8fe 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/Sql.kt @@ -1,5 +1,16 @@ package org.partiql.ast.sql +import org.partiql.ast.AstNode + +/** + * Pretty-print this [AstNode] as SQL text with the given [SqlLayout] + */ +@JvmOverloads +public fun AstNode.sql( + layout: SqlLayout = SqlLayout.DEFAULT, + dialect: SqlDialect = SqlDialect.PARTIQL, +): String = accept(dialect, SqlBlock.Nil).sql(layout) + // a <> b <-> a concat b internal infix fun SqlBlock.concat(rhs: SqlBlock): SqlBlock = link(this, rhs) diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt index 53186817bb..59f2082860 100644 --- a/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt +++ 
b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt @@ -17,7 +17,6 @@ import org.junit.jupiter.api.parallel.Execution import org.junit.jupiter.api.parallel.ExecutionMode import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource -import org.partiql.ast.Ast import org.partiql.ast.AstNode import org.partiql.ast.Expr import org.partiql.ast.From @@ -26,8 +25,9 @@ import org.partiql.ast.Identifier import org.partiql.ast.SetQuantifier import org.partiql.ast.Sort import org.partiql.ast.builder.AstBuilder -import org.partiql.ast.builder.AstFactory import org.partiql.ast.builder.ast +import org.partiql.ast.exprLit +import org.partiql.ast.identifierSymbol import org.partiql.lang.domains.PartiqlAst import org.partiql.value.PartiQLValueExperimental import org.partiql.value.blobValue @@ -113,20 +113,20 @@ class ToLegacyAstTest { companion object { private fun expect(expected: String, block: AstBuilder.() -> AstNode): Case { - val i = ast(AstFactory.DEFAULT, block) + val i = ast(block) val e = PartiqlAst.transform(loadSingleElement(expected)) return Case.Translate(i, e) } private fun fail(message: String, block: AstBuilder.() -> AstNode): Case { - val i = ast(AstFactory.DEFAULT, block) + val i = ast(block) return Case.Fail(i, message) } - private val NULL = Ast.exprLit(nullValue()) + private val NULL = exprLit(nullValue()) // Shortcut to construct a "legacy-compatible" simple identifier - private fun id(name: String) = Ast.identifierSymbol(name, Identifier.CaseSensitivity.INSENSITIVE) + private fun id(name: String) = identifierSymbol(name, Identifier.CaseSensitivity.INSENSITIVE) @JvmStatic fun literals() = listOf( diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt index 7f9b320ac8..15f59df536 100644 --- a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt @@ -13,7 +13,6 @@ import org.junit.jupiter.api.parallel.Execution import org.junit.jupiter.api.parallel.ExecutionMode import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource -import org.partiql.ast.Ast import org.partiql.ast.AstNode import org.partiql.ast.DatetimeField import org.partiql.ast.Expr @@ -24,9 +23,8 @@ import org.partiql.ast.SetOp import org.partiql.ast.SetQuantifier import org.partiql.ast.Sort import org.partiql.ast.builder.AstBuilder -import org.partiql.ast.builder.AstFactory import org.partiql.ast.builder.ast -import org.partiql.ast.sql +import org.partiql.ast.exprLit import org.partiql.value.PartiQLValueExperimental import org.partiql.value.boolValue import org.partiql.value.decimalValue @@ -168,7 +166,7 @@ class SqlDialectTest { companion object { - private val NULL = Ast.exprLit(nullValue()) + private val NULL = exprLit(nullValue()) @JvmStatic fun types() = listOf( @@ -1593,12 +1591,12 @@ class SqlDialectTest { ) private fun expect(expected: String, block: AstBuilder.() -> AstNode): Case { - val i = ast(AstFactory.DEFAULT, block) + val i = ast(block) return Case.Success(i, expected) } private fun fail(message: String, block: AstBuilder.() -> AstNode): Case { - val i = ast(AstFactory.DEFAULT, block) + val i = ast(block) return Case.Fail(i, message) } diff --git a/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLParserBenchmark.kt b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLParserBenchmark.kt new file mode 100644 
index 0000000000..250d6269cd --- /dev/null +++ b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLParserBenchmark.kt @@ -0,0 +1,1342 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.jmh.benchmarks + +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.BenchmarkMode +import org.openjdk.jmh.annotations.Fork +import org.openjdk.jmh.annotations.Measurement +import org.openjdk.jmh.annotations.Mode +import org.openjdk.jmh.annotations.OutputTimeUnit +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import org.openjdk.jmh.annotations.Warmup +import org.openjdk.jmh.infra.Blackhole +import org.partiql.jmh.utils.FORK_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.MEASUREMENT_TIME_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_ITERATION_VALUE_RECOMMENDED +import org.partiql.jmh.utils.WARMUP_TIME_VALUE_RECOMMENDED +import org.partiql.parser.PartiQLParserBuilder +import org.partiql.parser.PartiQLParserException +import java.util.concurrent.TimeUnit + +// TODO: If https://github.com/benchmark-action/github-action-benchmark/issues/141 gets fixed, we can move to using +// parameterized tests. This file intentionally uses the same prefix `parse` and `parseFail` for each benchmark. It +// expects that the parameter `name` will be used in the future, so it adds `Name` to prefix each argument. This will +// potentially make it easier to transition to the continuous benchmarking framework if parameterized benchmarks +// are supported. + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +internal open class PartiQLParserBenchmark { + + companion object { + private const val FORK_VALUE: Int = FORK_VALUE_RECOMMENDED + private const val MEASUREMENT_ITERATION_VALUE: Int = MEASUREMENT_ITERATION_VALUE_RECOMMENDED + private const val MEASUREMENT_TIME_VALUE: Int = MEASUREMENT_TIME_VALUE_RECOMMENDED + private const val WARMUP_ITERATION_VALUE: Int = WARMUP_ITERATION_VALUE_RECOMMENDED + private const val WARMUP_TIME_VALUE: Int = WARMUP_TIME_VALUE_RECOMMENDED + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameQuerySimple(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::querySimple.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameNestedParen(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::nestedParen.name]!!) 
+ blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameSomeJoins(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::someJoins.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameSeveralJoins(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::severalJoins.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameSomeSelect(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::someSelect.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameSeveralSelect(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::severalSelect.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameSomeProjections(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::someProjections.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameSeveralProjections(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::severalProjections.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameQueryFunc(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::queryFunc.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameQueryFuncInProjection(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::queryFuncInProjection.name]!!) 
+ blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameQueryList(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::queryList.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameQuery15OrsAndLikes(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::query15OrsAndLikes.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameQuery30Plus(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::query30Plus.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameQueryNestedSelect(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::queryNestedSelect.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameGraphPattern(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::graphPattern.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameGraphPreFilters(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::graphPreFilters.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameManyJoins(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::manyJoins.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameTimeZone(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::timeZone.name]!!) 
+ blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameCaseWhenThen(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::caseWhenThen.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameSimpleInsert(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::simpleInsert.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameExceptUnionIntersectSixty(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::exceptUnionIntersectSixty.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameExec20Expressions(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::exec20Expressions.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameFromLet(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::fromLet.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameGroupLimit(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::groupLimit.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNamePivot(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::pivot.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameLongFromSourceOrderBy(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::longFromSourceOrderBy.name]!!) 
+ blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameNestedAggregates(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::nestedAggregates.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameComplexQuery(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::complexQuery.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameComplexQuery01(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::complexQuery01.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameComplexQuery02(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::complexQuery02.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameVeryLongQuery(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::veryLongQuery.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseNameVeryLongQuery01(state: MyState, blackhole: Blackhole) { + val result = state.parser.parse(state.queries[state::veryLongQuery01.name]!!) + blackhole.consume(result) + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameQuerySimple(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::querySimple.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameNestedParen(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::nestedParen.name]!!) 
+ blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameSomeJoins(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::someJoins.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameSeveralJoins(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::severalJoins.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameSomeSelect(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::someSelect.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameSeveralSelect(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::severalSelect.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameSomeProjections(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::someProjections.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameSeveralProjections(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::severalProjections.name]!!) 
+ blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameQueryFunc(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::queryFunc.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameQueryFuncInProjection(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::queryFuncInProjection.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameQueryList(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::queryList.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameQuery15OrsAndLikes(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::query15OrsAndLikes.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameQuery30Plus(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::query30Plus.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameQueryNestedSelect(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::queryNestedSelect.name]!!) 
+ blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameGraphPattern(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::graphPattern.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameGraphPreFilters(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::graphPreFilters.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameManyJoins(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::manyJoins.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameTimeZone(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::timeZone.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameCaseWhenThen(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::caseWhenThen.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameSimpleInsert(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::simpleInsert.name]!!) 
+ blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameExceptUnionIntersectSixty(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::exceptUnionIntersectSixty.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameExec20Expressions(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::exec20Expressions.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameFromLet(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::fromLet.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameGroupLimit(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::groupLimit.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNamePivot(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::pivot.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameLongFromSourceOrderBy(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::longFromSourceOrderBy.name]!!) 
+ blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameNestedAggregates(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::nestedAggregates.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameComplexQuery(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::complexQuery.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameComplexQuery01(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::complexQuery01.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameComplexQuery02(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::complexQuery02.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameVeryLongQuery(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::veryLongQuery.name]!!) + blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @Benchmark + @Fork(value = FORK_VALUE) + @Measurement(iterations = MEASUREMENT_ITERATION_VALUE, time = MEASUREMENT_TIME_VALUE) + @Warmup(iterations = WARMUP_ITERATION_VALUE, time = WARMUP_TIME_VALUE) + @Suppress("UNUSED") + fun parseFailNameVeryLongQuery01(state: MyState, blackhole: Blackhole) { + try { + val result = state.parser.parse(state.queriesFail[state::veryLongQuery01.name]!!) 
+ blackhole.consume(result) + throw RuntimeException() + } catch (ex: PartiQLParserException) { + blackhole.consume(ex) + } + } + + @State(Scope.Thread) + open class MyState { + + val parser = PartiQLParserBuilder.standard().build() + + val query15OrsAndLikes = """ + SELECT * + FROM hr.employees as emp + WHERE lower(emp.name) LIKE '%bob smith%' + OR lower(emp.name) LIKE '%gage swanson%' + OR lower(emp.name) LIKE '%riley perry%' + OR lower(emp.name) LIKE '%sandra woodward%' + OR lower(emp.name) LIKE '%abagail oconnell%' + OR lower(emp.name) LIKE '%amari duke%' + OR lower(emp.name) LIKE '%elisha wyatt%' + OR lower(emp.name) LIKE '%aryanna hess%' + OR lower(emp.name) LIKE '%bryanna jones%' + OR lower(emp.name) LIKE '%trace gilmore%' + OR lower(emp.name) LIKE '%antwan stevenson%' + OR lower(emp.name) LIKE '%julianna callahan%' + OR lower(emp.name) LIKE '%jaelynn trevino%' + OR lower(emp.name) LIKE '%kadence bates%' + OR lower(emp.name) LIKE '%jakobe townsend%' + """ + + val query30Plus = """ + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + """ + + val querySimple = """ + SELECT a FROM t + """ + + val queryNestedSelect = """ + SELECT + ( + SELECT a AS p + FROM ( + SELECT VALUE b + FROM some_table + WHERE 3 = 4 + ) AS some_wrapped_table + WHERE id = 3 + ) AS projectionQuery + FROM ( + SELECT everything + FROM ( + SELECT * + FROM someSourceTable AS t + LET 5 + t.b AS x + WHERE x = 2 + GROUP BY t.a AS k + GROUP AS g + ORDER BY t.d + ) AS someTable + ) + LET (SELECT a FROM smallTable) AS letVariable + WHERE letVariable > 4 + GROUP BY t.a AS groupKey + """ + + val queryList = """ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29] + """ + + val queryFunc = """ + f(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29) + """ + + val queryFuncInProjection = """ + SELECT + f(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29) + FROM t + """ + + val someJoins = """ + SELECT a + FROM a, b, c + """ + + val severalJoins = """ + SELECT a + FROM a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p + """ + + val someProjections = """ + SELECT a, b, c + FROM t + """ + + val severalProjections = """ + SELECT a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p + FROM t + """ + + val someSelect = """ + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + """ + + val severalSelect = """ + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + + (SELECT a FROM t) + """ + + val nestedParen = """ + ((((((((((((((((((((((((((((((0)))))))))))))))))))))))))))))) + """ + + val graphPreFilters = """ + SELECT u as banCandidate + FROM g + MATCH (p:Post Where p.isFlagged = true) <-[:createdPost]- (u:Usr WHERE u.isBanned = false AND u.karma < 20) -[:createdComment]->(c:Comment WHERE c.isFlagged = true) + WHERE p.title LIKE '%considered harmful%' + """.trimIndent() + + val graphPattern = """ + SELECT the_a.name AS src, the_b.name AS dest + FROM my_graph MATCH (the_a:a) -[the_y:y]-> (the_b:b) + WHERE the_y.score > 10 + """.trimIndent() + + val manyJoins = """ + SELECT x FROM a INNER CROSS JOIN b CROSS JOIN c LEFT JOIN d ON e RIGHT OUTER CROSS JOIN f OUTER JOIN g ON h + """ + + val timeZone = "TIME WITH TIME ZONE '23:59:59.123456789+18:00'" + + val caseWhenThen = "CASE 
WHEN name = 'zoe' THEN 1 WHEN name > 'kumo' THEN 2 ELSE 0 END" + + val simpleInsert = """ + INSERT INTO foo VALUE 1 AT bar RETURNING MODIFIED OLD bar, MODIFIED NEW bar, ALL NEW * + """ + + val exceptUnionIntersectSixty = """ + a EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + EXCEPT a INTERSECT a UNION a + """ + + val exec20Expressions = """ + EXEC + a + b, + a, + b, + c, + d, + 123, + "aaaaa", + 'aaaaa', + @ident, + 1 + 1, + 2 + 2, + a, + a, + a, + a, + a, + a, + a, + a + """ + + val fromLet = + "SELECT C.region, MAX(nameLength) AS maxLen FROM C LET char_length(C.name) AS nameLength GROUP BY C.region" + + val groupLimit = + "SELECT g FROM `[{foo: 1, bar: 10}, {foo: 1, bar: 11}]` AS f GROUP BY f.foo GROUP AS g LIMIT 1" + + val pivot = """ + PIVOT foo.a AT foo.b + FROM <<{'a': 1, 'b':'I'}, {'a': 2, 'b':'II'}, {'a': 3, 'b':'III'}>> AS foo + LIMIT 1 OFFSET 1 + """.trimIndent() + + val longFromSourceOrderBy = """ + SELECT * + FROM [{'a': {'a': 5}}, {'a': {'a': 'b'}}, {'a': {'a': true}}, {'a': {'a': []}}, {'a': {'a': {}}}, {'a': {'a': <<>>}}, {'a': {'a': `{{}}`}}, {'a': {'a': null}}] + ORDER BY a DESC + """.trimIndent() + + val nestedAggregates = """ + SELECT + i2 AS outerKey, + g2 AS outerGroupAs, + MIN(innerQuery.innerSum) AS outerMin, + ( + SELECT VALUE SUM(i2) + FROM << 0, 1 >> + ) AS projListSubQuery + FROM ( + SELECT + i, + g, + SUM(col1) AS innerSum + FROM simple_1_col_1_group_2 AS innerFromSource + GROUP BY col1 AS i GROUP AS g + ) AS innerQuery + GROUP BY innerQuery.i AS i2, innerQuery.g AS g2 + """.trimIndent() + + val complexQuery = """ + 1 + ( + SELECT a, b, c + FROM [ + { 'a': 1} + ] AS t + LET x AS y + WHERE y > 2 AND y > 3 AND y > 4 + GROUP BY t.a, t.b AS b, t.c AS c + GROUP AS d + ORDER BY x + LIMIT 1 + 22222222222222222 + OFFSET x + y + z + a + b + c + ) + ( + CAST( + '45678920irufji332r94832fhedjcd2wqbxucri3' + AS INT + ) + ) + [ + 1, 2, 3, 4, 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 + ] - ((((((((((2)))))))))) + ( + SELECT VALUE { 'a': a } FROM t WHERE t.a > 3 + ) + """.trimIndent() + + val complexQuery01 = """ + SELECT + DATE_FORMAT(co.order_date, '%Y-%m') AS order_month, + DATE_FORMAT(co.order_date, '%Y-%m-%d') AS order_day, + COUNT(DISTINCT co.order_id) AS num_orders, + COUNT(ol.book_id) AS num_books, + SUM(ol.price) AS total_price + FROM cust_order co + INNER JOIN order_line ol ON co.order_id = ol.order_id + GROUP BY + DATE_FORMAT(co.order_date, '%Y-%m'), + DATE_FORMAT(co.order_date, '%Y-%m-%d') + ORDER BY co.order_date ASC; + """.trimIndent() + + val complexQuery02 = """ + SELECT + c.calendar_date, + c.calendar_year, + c.calendar_month, + c.calendar_dayName, + COUNT(DISTINCT sub.order_id) AS num_orders, + COUNT(sub.book_id) AS num_books, + SUM(sub.price) AS total_price, + SUM(COUNT(sub.book_id)) AS running_total_num_books, + LAG(COUNT(sub.book_id), 7) AS prev_books + FROM calendar_days c + LEFT JOIN ( + SELECT + co.order_date, + co.order_id, + ol.book_id, + ol.price + FROM cust_order co + INNER JOIN order_line ol ON co.order_id = ol.order_id 
+ ) sub ON c.calendar_date = sub.order_date + GROUP BY c.calendar_date, c.calendar_year, c.calendar_month, c.calendar_dayname + ORDER BY c.calendar_date ASC; + """.trimIndent() + + val veryLongQuery = """ + SELECT + e.employee_id AS "Employee#", e.first_name || '' || e.last_name AS "Name", e.email AS "Email", + e.phone_number AS "Phone", TO_CHAR(e.hire_date, 'MM/DD/YYYY') AS "Hire Date", + TO_CHAR(e.salary, 'L99G999D99', 'NLS_NUMERIC_CHARACTERS=''.,''NLS_CURRENCY=''${'$'}''') AS "Salary", + e.commission_pct AS "Comission%", + 'works as' || j.job_title || 'in' || d.department_name || ' department (manager: ' + || dm.first_name || '' || dm.last_name || ')andimmediatesupervisor:' || m.first_name || '' || m.last_name AS "CurrentJob", + TO_CHAR(j.min_salary, 'L99G999D99', 'NLS_NUMERIC_CHARACTERS=''.,''NLS_CURRENCY=''${'$'}''') || '-' || + TO_CHAR(j.max_salary, 'L99G999D99', 'NLS_NUMERIC_CHARACTERS=''.,''NLS_CURRENCY=''${'$'}''') AS "CurrentSalary", + l.street_address || ',' || l.postal_code || ',' || l.city || ',' || l.state_province || ',' + || c.country_name || '(' || r.region_name || ')' AS "Location", + jh.job_id AS "HistoryJobID", + 'worked from' || TO_CHAR(jh.start_date, 'MM/DD/YYYY') || 'to' || TO_CHAR(jh.end_date, 'MM/DD/YYYY') || + 'as' || jj.job_title || 'in' || dd.department_name || 'department' AS "HistoryJobTitle" + FROM employees e + JOIN jobs j + ON e.job_id = j.job_id + LEFT JOIN employees m + ON e.manager_id = m.employee_id + LEFT JOIN departments d + ON d.department_id = e.department_id + LEFT JOIN employees dm + ON d.manager_id = dm.employee_id + LEFT JOIN locations l + ON d.location_id = l.location_id + LEFT JOIN countries c + ON l.country_id = c.country_id + LEFT JOIN regions r + ON c.region_id = r.region_id + LEFT JOIN job_history jh + ON e.employee_id = jh.employee_id + LEFT JOIN jobs jj + ON jj.job_id = jh.job_id + LEFT JOIN departments dd + ON dd.department_id = jh.department_id + + ORDER BY e.employee_id; + """.trimIndent() + + val veryLongQuery01 = """ + SELECT + id as feedId, + (IF(groupId > 0, groupId, IF(friendId > 0, friendId, userId))) as wallOwnerId, + (IF(groupId > 0 or friendId > 0, userId, NULL)) as guestWriterId, + (IF(groupId > 0 or friendId > 0, userId, NULL)) as guestWriterType, + case + when type = 2 then 1 + when type = 1 then IF(media_count = 1, 2, 4) + when type = 5 then IF(media_count = 1, IF(albumName = 'Audio Feeds', 5, 6), 7) + when type = 6 then IF(media_count = 1, IF(albumName = 'Video Feeds', 8, 9), 10) + end as contentType, + albumId, + albumName, + addTime, + IF(validity > 0,IF((validity - updateTime) / 86400000 > 1,(validity - updateTime) / 86400000, 1),0) as validity, + updateTime, + status, + location, + latitude as locationLat, + longitude as locationLon, + sharedFeedId as parentFeedId, + case + when privacy = 2 or privacy = 10 then 15 + when privacy = 3 then 25 + else 1 + end as privacy, + pagefeedcategoryid, + case + when lastSharedFeedId = 2 then 10 + when lastSharedFeedId = 3 then 15 + when lastSharedFeedId = 4 then 25 + when lastSharedFeedId = 5 then 20 + when lastSharedFeedId = 6 then 99 + else 1 + end as wallOwnerType, + (ISNULL(latitude) or latitude = 9999.0 or ISNULL(longitude) or longitude = 9999.0) as latlongexists, + (SELECT concat('[',GROUP_CONCAT(moodId),']') FROM feedactivities WHERE newsFeedId = newsfeed.id) as feelings, + (SELECT concat('[',GROUP_CONCAT(userId),']') FROM feedtags WHERE newsFeedId = newsfeed.id) as withTag, + (SELECT concat('{',GROUP_CONCAT(pos,':', friendId),'}') FROM statustags WHERE newsFeedId = 
newsfeed.id) as textTag, + albumType, + defaultCategoryType, + linkType,linkTitle,linkURL,linkDesc,linkImageURL,linkDomain, -- Link Content + title,description,shortDescription,newsUrl,externalUrlOption, -- Additional Content + url, height, width, thumnail_url, thumnail_height, thumbnail_width, duration, artist -- Media + FROM + (newsfeed LEFT JOIN + ( + SELECT + case + when (mediaalbums.media_type = 1 and album_name = 'AudioFeeds') + or (mediaalbums.media_type = 2 and album_name = 'VideoFeeds') + then -1 * mediaalbums.user_id else mediaalbums.id + end as albumId, + album_name as albumName, + newsFeedId, + (NULL) as height, + (NULL) as width, + media_thumbnail_url as thumnail_url, + max(thumb_image_height) as thumnail_height, + max(thumb_image_width) as thumbnail_width, + max(media_duration) as duration, + case + when mediaalbums.media_type = 1 and album_name = 'AudioFeeds' + then 4 + when mediaalbums.media_type = 2 and album_name = 'VideoFeeds' + then 5 else 8 + end as albumType, + count(mediacontents.id) as media_count, + media_artist as artist + FROM + (mediaalbums INNER JOIN mediacontents ON mediaalbums.id = mediacontents.album_id) + INNER JOIN newsfeedmediacontents + ON newsfeedmediacontents.contentId = mediacontents.id group by newsfeedid + UNION + SELECT + -1 * userId as albumId, + newsFeedId,imageUrl as url, + max(imageHeight) as height, + max(imageWidth) as width, + (NULL) as thumnail_url, + (NULL) as thumnail_height, + (NULL) as thumbnail_width, + (NULL) as duration, + case + when albumId = 'default' then 1 + when albumId = 'profileimages' then 2 + when albumId = 'coverimages' then 3 + end as albumType, + count(imageid) as media_count, + (NULL) as artist + FROM userimages + INNER JOIN newsfeedimages on userimages.id = newsfeedimages.imageId + group by newsfeedid + ) album + ON newsfeed.id = album.newsfeedId + ) + LEFT JOIN + ( + select newsPortalFeedId as feedid, + title,description,shortDescription,newsUrl,externalUrlOption, newsPortalCategoryId as pagefeedcategoryid, + (15) as defaultCategoryType from newsportalFeedInfo + UNION + select businessPageFeedId as feedid,title,description,shortDescription,newsUrl,externalUrlOption, + businessPageCategoryId as pagefeedcategoryid,(25) as defaultCategoryType from businessPageFeedInfo + UNION + select newsfeedId as feedid,(NULL) as title,description,(NULL) as shortDescription,(NULL) as newsUrl, + (NULL) as externalUrlOption, categoryMappingId as pagefeedcategoryid, + (20) as defaultCategoryType from mediaPageFeedInfo + ) page + ON newsfeed.id = page.feedId WHERE privacy != 10 + """.trimIndent() + + val queries = mapOf( + ::querySimple.name to querySimple, + ::nestedParen.name to nestedParen, + ::someJoins.name to someJoins, + ::severalJoins.name to severalJoins, + ::someSelect.name to someSelect, + ::severalSelect.name to severalSelect, + ::someProjections.name to someProjections, + ::severalProjections.name to severalProjections, + ::queryFunc.name to queryFunc, + ::queryFuncInProjection.name to queryFuncInProjection, + ::queryList.name to queryList, + ::query15OrsAndLikes.name to query15OrsAndLikes, + ::query30Plus.name to query30Plus, + ::queryNestedSelect.name to queryNestedSelect, + ::graphPattern.name to graphPattern, + ::graphPreFilters.name to graphPreFilters, + ::manyJoins.name to manyJoins, + ::timeZone.name to timeZone, + ::caseWhenThen.name to caseWhenThen, + ::simpleInsert.name to simpleInsert, + ::exceptUnionIntersectSixty.name to exceptUnionIntersectSixty, + ::exec20Expressions.name to exec20Expressions, + ::fromLet.name 
to fromLet, + ::groupLimit.name to groupLimit, + ::pivot.name to pivot, + ::longFromSourceOrderBy.name to longFromSourceOrderBy, + ::nestedAggregates.name to nestedAggregates, + ::complexQuery.name to complexQuery, + ::complexQuery01.name to complexQuery01, + ::complexQuery02.name to complexQuery02, + ::veryLongQuery.name to veryLongQuery, + ::veryLongQuery01.name to veryLongQuery01, + ) + + val queriesFail = queries.map { (name, query) -> + val splitQuery = query.split("\\s".toRegex()).toMutableList() + val index = (splitQuery.lastIndex * .50).toInt() + splitQuery.add(index, ";") + name to splitQuery.joinToString(separator = " ") + }.toMap() + } +} diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt index 8fd3b18d8b..d79a20f0d4 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToPlan.kt @@ -13,8 +13,8 @@ import org.partiql.lang.eval.visitors.SelectStarVisitorTransform import org.partiql.lang.planner.transforms.plan.RelConverter import org.partiql.lang.planner.transforms.plan.RexConverter import org.partiql.plan.PartiQLPlan -import org.partiql.plan.Plan import org.partiql.plan.Rex +import org.partiql.plan.partiQLPlan /** * Translate the PIG AST to an implementation of the PartiQL Plan Representation. @@ -30,7 +30,7 @@ object AstToPlan { unsupported(ast) } val root = transform(ast.expr) - return Plan.partiQLPlan( + return partiQLPlan( version = PartiQLPlan.Version.PARTIQL_V0, root = root, ) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt index a909ce76f3..44bf9753fc 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/PlanTyper.kt @@ -27,8 +27,6 @@ import org.partiql.lang.eval.builtins.SCALAR_BUILTINS_DEFAULT import org.partiql.lang.planner.PlanningProblemDetails import org.partiql.lang.planner.transforms.PlannerSession import org.partiql.lang.planner.transforms.impl.Metadata -import org.partiql.lang.planner.transforms.plan.PlanTyper.MinimumTolerance.FULL -import org.partiql.lang.planner.transforms.plan.PlanTyper.MinimumTolerance.PARTIAL import org.partiql.lang.planner.transforms.plan.PlanUtils.addType import org.partiql.lang.planner.transforms.plan.PlanUtils.grabType import org.partiql.lang.types.FunctionSignature @@ -42,12 +40,14 @@ import org.partiql.plan.Binding import org.partiql.plan.Case import org.partiql.plan.ExcludeExpr import org.partiql.plan.ExcludeStep -import org.partiql.plan.Plan import org.partiql.plan.PlanNode import org.partiql.plan.Property import org.partiql.plan.Rel import org.partiql.plan.Rex import org.partiql.plan.Step +import org.partiql.plan.attribute +import org.partiql.plan.binding +import org.partiql.plan.rexId import org.partiql.plan.util.PlanRewriter import org.partiql.spi.BindingCase import org.partiql.spi.BindingName @@ -327,20 +327,20 @@ internal object PlanTyper : PlanRewriter() { val fromExprType = value.grabType() ?: handleMissingType(ctx) val valueType = getUnpivotValueType(fromExprType) - val typeEnv = mutableListOf(Plan.attribute(asSymbolicName, valueType)) + val typeEnv = mutableListOf(attribute(asSymbolicName, valueType)) from.at?.let { val valueHasMissing = 
StaticTypeUtils.getTypeDomain(valueType).contains(ExprValueType.MISSING) val valueOnlyHasMissing = valueHasMissing && StaticTypeUtils.getTypeDomain(valueType).size == 1 when { valueOnlyHasMissing -> { - typeEnv.add(Plan.attribute(it, StaticType.MISSING)) + typeEnv.add(attribute(it, StaticType.MISSING)) } valueHasMissing -> { - typeEnv.add(Plan.attribute(it, StaticType.STRING.asOptional())) + typeEnv.add(attribute(it, StaticType.STRING.asOptional())) } else -> { - typeEnv.add(Plan.attribute(it, StaticType.STRING)) + typeEnv.add(attribute(it, StaticType.STRING)) } } } @@ -357,8 +357,8 @@ internal object PlanTyper : PlanRewriter() { override fun visitRelAggregate(node: Rel.Aggregate, ctx: Context): PlanNode { val input = visitRel(node.input, ctx) - val calls = node.calls.map { Plan.binding(it.name, typeRex(it.value, input, ctx)) } - val groups = node.groups.map { Plan.binding(it.name, typeRex(it.value, input, ctx)) } + val calls = node.calls.map { binding(it.name, typeRex(it.value, input, ctx)) } + val groups = node.groups.map { binding(it.name, typeRex(it.value, input, ctx)) } return node.copy( calls = calls, groups = groups, @@ -377,12 +377,12 @@ internal object PlanTyper : PlanRewriter() { when (val structType = type as? StructType) { null -> { handleIncompatibleDataTypeForExprError(StaticType.STRUCT, type, ctx) - listOf(Plan.attribute(binding.name, type)) + listOf(attribute(binding.name, type)) } - else -> structType.fields.map { entry -> Plan.attribute(entry.key, entry.value) } + else -> structType.fields.map { entry -> attribute(entry.key, entry.value) } } } - false -> listOf(Plan.attribute(binding.name, type)) + false -> listOf(attribute(binding.name, type)) } } return node.copy( @@ -418,7 +418,7 @@ internal object PlanTyper : PlanRewriter() { is Rex.Query.Collection -> when (value.constructor) { null -> value.rel else -> { - val typeEnv = listOf(Plan.attribute(asSymbolicName, sourceType)) + val typeEnv = listOf(attribute(asSymbolicName, sourceType)) node.copy( value = value, common = node.common.copy( @@ -428,7 +428,7 @@ internal object PlanTyper : PlanRewriter() { } } else -> { - val typeEnv = listOf(Plan.attribute(asSymbolicName, sourceType)) + val typeEnv = listOf(attribute(asSymbolicName, sourceType)) node.copy( value = value, common = node.common.copy( @@ -1265,7 +1265,7 @@ internal object PlanTyper : PlanRewriter() { rexCaseToBindingCase(node.case) ) - private fun List.toAttributes(ctx: Context) = this.map { Plan.attribute(it.name, it.grabType() ?: handleMissingType(ctx)) } + private fun List.toAttributes(ctx: Context) = this.map { attribute(it.name, it.grabType() ?: handleMissingType(ctx)) } private fun inferConcatOp(leftType: SingleType, rightType: SingleType): SingleType { fun checkUnconstrainedText(type: SingleType) = type is SymbolType || type is StringType && type.lengthConstraint is StringType.StringLengthConstraint.Unconstrained @@ -1688,7 +1688,7 @@ internal object PlanTyper : PlanRewriter() { ElementType.SYMBOL, ElementType.STRING -> { val stringValue = value.value.asAnyElement().stringValueOrNull stringValue?.let { str -> - Plan.rexId(str, it.case, Rex.Id.Qualifier.UNQUALIFIED, null) + rexId(str, it.case, Rex.Id.Qualifier.UNQUALIFIED, null) } } else -> null diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt index c0d746929f..26ecb6c870 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt 
+++ b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RelConverter.kt @@ -9,10 +9,33 @@ import org.partiql.plan.Binding import org.partiql.plan.Case import org.partiql.plan.ExcludeExpr import org.partiql.plan.ExcludeStep -import org.partiql.plan.Plan import org.partiql.plan.Rel import org.partiql.plan.Rex import org.partiql.plan.SortSpec +import org.partiql.plan.binding +import org.partiql.plan.common +import org.partiql.plan.excludeExpr +import org.partiql.plan.excludeStepCollectionIndex +import org.partiql.plan.excludeStepCollectionWildcard +import org.partiql.plan.excludeStepTupleAttr +import org.partiql.plan.excludeStepTupleWildcard +import org.partiql.plan.field +import org.partiql.plan.relAggregate +import org.partiql.plan.relExclude +import org.partiql.plan.relFetch +import org.partiql.plan.relFilter +import org.partiql.plan.relJoin +import org.partiql.plan.relProject +import org.partiql.plan.relScan +import org.partiql.plan.relSort +import org.partiql.plan.relUnpivot +import org.partiql.plan.rexAgg +import org.partiql.plan.rexId +import org.partiql.plan.rexLit +import org.partiql.plan.rexQueryCollection +import org.partiql.plan.rexQueryScalarPivot +import org.partiql.plan.rexTuple +import org.partiql.plan.sortSpec import org.partiql.types.StaticType /** @@ -23,7 +46,7 @@ internal class RelConverter { /** * As of now, the COMMON property of relation operators is under development, so just use empty for now */ - private val empty = Plan.common( + private val empty = common( typeEnv = emptyList(), properties = emptySet(), metas = emptyMap() @@ -40,7 +63,7 @@ internal class RelConverter { val rex = when (val projection = select.project) { // PIVOT ... FROM is PartiqlAst.Projection.ProjectPivot -> { - Plan.rexQueryScalarPivot( + rexQueryScalarPivot( rel = rel, value = RexConverter.convert(projection.value), at = RexConverter.convert(projection.key), @@ -49,7 +72,7 @@ internal class RelConverter { } // SELECT VALUE ... FROM is PartiqlAst.Projection.ProjectValue -> { - Plan.rexQueryCollection( + rexQueryCollection( rel = rel, constructor = RexConverter.convert(projection.value), type = null @@ -57,7 +80,7 @@ internal class RelConverter { } // SELECT ... 
FROM else -> { - Plan.rexQueryCollection( + rexQueryCollection( rel = rel, constructor = null, type = null @@ -108,7 +131,7 @@ internal class RelConverter { null -> input else -> { val exprs = excludeOp.exprs.map { convertExcludeExpr(it) } - Plan.relExclude( + relExclude( common = empty, input = input, exprs = exprs, @@ -120,15 +143,15 @@ internal class RelConverter { val root = excludeExpr.root.name.text val case = convertCase(excludeExpr.root.case) val steps = excludeExpr.steps.map { convertExcludeSteps(it) } - return Plan.excludeExpr(root, case, steps) + return excludeExpr(root, case, steps) } private fun convertExcludeSteps(excludeStep: PartiqlAst.ExcludeStep): ExcludeStep { return when (excludeStep) { - is PartiqlAst.ExcludeStep.ExcludeCollectionWildcard -> Plan.excludeStepCollectionWildcard() - is PartiqlAst.ExcludeStep.ExcludeTupleWildcard -> Plan.excludeStepTupleWildcard() - is PartiqlAst.ExcludeStep.ExcludeTupleAttr -> Plan.excludeStepTupleAttr(excludeStep.attr.name.text, convertCase(excludeStep.attr.case)) - is PartiqlAst.ExcludeStep.ExcludeCollectionIndex -> Plan.excludeStepCollectionIndex(excludeStep.index.value.toInt()) + is PartiqlAst.ExcludeStep.ExcludeCollectionWildcard -> excludeStepCollectionWildcard() + is PartiqlAst.ExcludeStep.ExcludeTupleWildcard -> excludeStepTupleWildcard() + is PartiqlAst.ExcludeStep.ExcludeTupleAttr -> excludeStepTupleAttr(excludeStep.attr.name.text, convertCase(excludeStep.attr.case)) + is PartiqlAst.ExcludeStep.ExcludeCollectionIndex -> excludeStepCollectionIndex(excludeStep.index.value.toInt()) } } @@ -148,7 +171,7 @@ internal class RelConverter { val lhs = convertFrom(join.left) val rhs = convertFrom(join.right) val condition = if (join.predicate != null) RexConverter.convert(join.predicate!!) else null - return Plan.relJoin( + return relJoin( common = empty, lhs = lhs, rhs = rhs, @@ -165,7 +188,7 @@ internal class RelConverter { /** * Appends [Rel.Scan] which takes no input relational expression */ - private fun convertScan(scan: PartiqlAst.FromSource.Scan) = Plan.relScan( + private fun convertScan(scan: PartiqlAst.FromSource.Scan) = relScan( common = empty, value = when (val expr = scan.expr) { is PartiqlAst.Expr.Select -> convert(expr) @@ -179,7 +202,7 @@ internal class RelConverter { /** * Appends [Rel.Unpivot] to range over attribute value pairs */ - private fun convertUnpivot(scan: PartiqlAst.FromSource.Unpivot) = Plan.relUnpivot( + private fun convertUnpivot(scan: PartiqlAst.FromSource.Unpivot) = relUnpivot( common = empty, value = RexConverter.convert(scan.expr), alias = scan.asAlias?.text, @@ -192,7 +215,7 @@ internal class RelConverter { */ private fun convertWhere(input: Rel, expr: PartiqlAst.Expr?): Rel = when (expr) { null -> input - else -> Plan.relFilter( + else -> relFilter( common = empty, input = input, condition = RexConverter.convert(expr) @@ -239,7 +262,7 @@ internal class RelConverter { } } - val rel = Plan.relAggregate( + val rel = relAggregate( common = empty, input = input, calls = calls, @@ -266,7 +289,7 @@ internal class RelConverter { */ private fun convertHaving(input: Rel, expr: PartiqlAst.Expr?): Rel = when (expr) { null -> input - else -> Plan.relFilter( + else -> relFilter( common = empty, input = input, condition = RexConverter.convert(expr) @@ -278,7 +301,7 @@ internal class RelConverter { */ private fun convertOrderBy(input: Rel, orderBy: PartiqlAst.OrderBy?) 
= when (orderBy) { null -> input - else -> Plan.relSort( + else -> relSort( common = empty, input = input, specs = orderBy.sortSpecs.map { convertSortSpec(it) } @@ -300,7 +323,7 @@ internal class RelConverter { if (offset != null) error("offset without limit") return input } - return Plan.relFetch( + return relFetch( common = empty, input = input, limit = RexConverter.convert(limit), @@ -315,7 +338,7 @@ internal class RelConverter { * @param projection * @return */ - private fun convertProjectList(input: Rel, projection: PartiqlAst.Projection.ProjectList) = Plan.relProject( + private fun convertProjectList(input: Rel, projection: PartiqlAst.Projection.ProjectList) = relProject( common = empty, input = input, bindings = projection.projectItems.bindings() @@ -328,7 +351,7 @@ internal class RelConverter { * - ASC NULLS LAST (default) * - DESC NULLS FIRST (default for DESC) */ - private fun convertSortSpec(sortSpec: PartiqlAst.SortSpec) = Plan.sortSpec( + private fun convertSortSpec(sortSpec: PartiqlAst.SortSpec) = sortSpec( value = RexConverter.convert(sortSpec.expr), dir = when (sortSpec.orderingSpec) { is PartiqlAst.OrderingSpec.Desc -> SortSpec.Dir.DESC @@ -354,16 +377,16 @@ internal class RelConverter { */ private fun convertGroupAs(name: String, from: PartiqlAst.FromSource): Binding { val fields = from.bindings().map { n -> - Plan.field( - name = Plan.rexLit(ionString(n), StaticType.STRING), - value = Plan.rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = StaticType.STRUCT) + field( + name = rexLit(ionString(n), StaticType.STRING), + value = rexId(n, Case.SENSITIVE, Rex.Id.Qualifier.UNQUALIFIED, type = StaticType.STRUCT) ) } - return Plan.binding( + return binding( name = name, - value = Plan.rexAgg( + value = rexAgg( id = "group_as", - args = listOf(Plan.rexTuple(fields, StaticType.STRUCT)), + args = listOf(rexTuple(fields, StaticType.STRUCT)), modifier = Rex.Agg.Modifier.ALL, type = StaticType.STRUCT ) @@ -488,7 +511,7 @@ internal class RelConverter { /** * Binding helper */ - private fun binding(name: String, expr: PartiqlAst.Expr) = Plan.binding( + private fun binding(name: String, expr: PartiqlAst.Expr) = binding( name = name, value = RexConverter.convert(expr) ) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt index 633e19c491..fa177fe2ec 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/plan/RexConverter.kt @@ -11,9 +11,27 @@ import org.partiql.lang.eval.errorContextFrom import org.partiql.lang.planner.transforms.AstToPlan import org.partiql.lang.planner.transforms.plan.PlanTyper.isProjectAll import org.partiql.plan.Case -import org.partiql.plan.Plan import org.partiql.plan.Rel import org.partiql.plan.Rex +import org.partiql.plan.argType +import org.partiql.plan.argValue +import org.partiql.plan.branch +import org.partiql.plan.field +import org.partiql.plan.rexAgg +import org.partiql.plan.rexBinary +import org.partiql.plan.rexCall +import org.partiql.plan.rexCollectionArray +import org.partiql.plan.rexCollectionBag +import org.partiql.plan.rexId +import org.partiql.plan.rexLit +import org.partiql.plan.rexPath +import org.partiql.plan.rexQueryScalarSubquery +import org.partiql.plan.rexSwitch +import org.partiql.plan.rexTuple +import org.partiql.plan.rexUnary +import org.partiql.plan.stepKey +import 
org.partiql.plan.stepUnpivot +import org.partiql.plan.stepWildcard import org.partiql.types.StaticType import java.util.Locale @@ -57,11 +75,11 @@ internal object RexConverter : PartiqlAst.VisitorFold() { private fun convert(vararg nodes: PartiqlAst.Expr) = nodes.map { convert(it) } private fun arg(name: String, node: PartiqlAst.PartiqlAstNode) = when (node) { - is PartiqlAst.Expr -> Plan.argValue( + is PartiqlAst.Expr -> argValue( name = name, value = convert(node), ) - is PartiqlAst.Type -> Plan.argType( + is PartiqlAst.Type -> argType( name = name, type = TypeConverter.convert(node) ) @@ -100,12 +118,12 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkMetas(node: MetaContainer, ctx: Ctx) = AstToPlan.unsupported(ctx.node) override fun walkExprMissing(node: PartiqlAst.Expr.Missing, ctx: Ctx) = visit(node) { - Plan.rexLit(ionNull(), StaticType.MISSING) + rexLit(ionNull(), StaticType.MISSING) } override fun walkExprLit(node: PartiqlAst.Expr.Lit, ctx: Ctx) = visit(node) { val ionType = node.value.type.toIonType() - Plan.rexLit( + rexLit( value = node.value, type = TypeConverter.convert(ionType) ) @@ -121,7 +139,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { internal = false ) } - Plan.rexCall( + rexCall( id = functionName, args = emptyList(), type = null @@ -129,7 +147,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprId(node: PartiqlAst.Expr.Id, ctx: Ctx) = visit(node) { - Plan.rexId( + rexId( name = node.name.text, case = convertCase(node.case), qualifier = when (node.qualifier) { @@ -141,16 +159,16 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprPath(node: PartiqlAst.Expr.Path, ctx: Ctx) = visit(node) { - Plan.rexPath( + rexPath( root = convert(node.root), steps = node.steps.map { when (it) { - is PartiqlAst.PathStep.PathExpr -> Plan.stepKey( + is PartiqlAst.PathStep.PathExpr -> stepKey( value = convert(it.index), case = convertCase(it.case) ) - is PartiqlAst.PathStep.PathUnpivot -> Plan.stepUnpivot() - is PartiqlAst.PathStep.PathWildcard -> Plan.stepWildcard() + is PartiqlAst.PathStep.PathUnpivot -> stepUnpivot() + is PartiqlAst.PathStep.PathWildcard -> stepWildcard() } }, type = null, @@ -158,7 +176,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprNot(node: PartiqlAst.Expr.Not, ctx: Ctx) = visit(node) { - Plan.rexUnary( + rexUnary( value = convert(node.expr), op = Rex.Unary.Op.NOT, type = StaticType.BOOL, @@ -166,7 +184,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprPos(node: PartiqlAst.Expr.Pos, ctx: Ctx) = visit(node) { - Plan.rexUnary( + rexUnary( value = convert(node.expr), op = Rex.Unary.Op.POS, type = StaticType.NUMERIC, @@ -174,7 +192,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprNeg(node: PartiqlAst.Expr.Neg, ctx: Ctx) = visit(node) { - Plan.rexUnary( + rexUnary( value = convert(node.expr), op = Rex.Unary.Op.NEG, type = StaticType.NUMERIC, @@ -182,7 +200,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprPlus(node: PartiqlAst.Expr.Plus, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.PLUS, @@ -191,7 +209,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprMinus(node: PartiqlAst.Expr.Minus, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = 
convert(node.operands[1]), op = Rex.Binary.Op.MINUS, @@ -200,7 +218,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprTimes(node: PartiqlAst.Expr.Times, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.TIMES, @@ -209,7 +227,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprDivide(node: PartiqlAst.Expr.Divide, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.DIV, @@ -218,7 +236,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprModulo(node: PartiqlAst.Expr.Modulo, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.MODULO, @@ -227,7 +245,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprBitwiseAnd(node: PartiqlAst.Expr.BitwiseAnd, accumulator: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.BITWISE_AND, @@ -238,7 +256,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprConcat(node: PartiqlAst.Expr.Concat, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.CONCAT, @@ -247,7 +265,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprAnd(node: PartiqlAst.Expr.And, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.AND, @@ -256,7 +274,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprOr(node: PartiqlAst.Expr.Or, ctx: Ctx) = visit(node) { - Plan.rexBinary( + rexBinary( lhs = convert(node.operands[0]), rhs = convert(node.operands[1]), op = Rex.Binary.Op.OR, @@ -266,7 +284,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkExprEq(node: PartiqlAst.Expr.Eq, ctx: Ctx) = visit(node) { val (lhs, rhs) = walkComparisonOperands(node.operands) - Plan.rexBinary( + rexBinary( lhs = lhs, rhs = rhs, op = Rex.Binary.Op.EQ, @@ -276,7 +294,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkExprNe(node: PartiqlAst.Expr.Ne, ctx: Ctx) = visit(node) { val (lhs, rhs) = walkComparisonOperands(node.operands) - Plan.rexBinary( + rexBinary( lhs = lhs, rhs = rhs, op = Rex.Binary.Op.NEQ, @@ -286,7 +304,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkExprGt(node: PartiqlAst.Expr.Gt, ctx: Ctx) = visit(node) { val (lhs, rhs) = walkComparisonOperands(node.operands) - Plan.rexBinary( + rexBinary( lhs = lhs, rhs = rhs, op = Rex.Binary.Op.GT, @@ -296,7 +314,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkExprGte(node: PartiqlAst.Expr.Gte, ctx: Ctx) = visit(node) { val (lhs, rhs) = walkComparisonOperands(node.operands) - Plan.rexBinary( + rexBinary( lhs = lhs, rhs = rhs, op = Rex.Binary.Op.GTE, @@ -306,7 +324,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkExprLt(node: PartiqlAst.Expr.Lt, ctx: Ctx) = visit(node) { val (lhs, rhs) = walkComparisonOperands(node.operands) - Plan.rexBinary( + rexBinary( lhs = lhs, rhs = rhs, op = Rex.Binary.Op.LT, @@ -316,7 +334,7 @@ internal object RexConverter : 
PartiqlAst.VisitorFold() { override fun walkExprLte(node: PartiqlAst.Expr.Lte, ctx: Ctx) = visit(node) { val (lhs, rhs) = walkComparisonOperands(node.operands) - Plan.rexBinary( + rexBinary( lhs = lhs, rhs = rhs, op = Rex.Binary.Op.LTE, @@ -369,7 +387,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { if (relProject.bindings.any { it.value.isProjectAll() }) { error("Unimplemented feature: coercion of SELECT *.") } - Plan.rexCollectionArray( + rexCollectionArray( relProject.bindings.map { it.value }, type = StaticType.LIST ) @@ -379,7 +397,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkExprLike(node: PartiqlAst.Expr.Like, ctx: Ctx) = visit(node) { when (val escape = node.escape) { - null -> Plan.rexCall( + null -> rexCall( id = Constants.like, args = args( "value" to node.value, @@ -387,7 +405,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { ), type = StaticType.BOOL, ) - else -> Plan.rexCall( + else -> rexCall( id = Constants.likeEscape, args = args( "value" to node.value, @@ -400,7 +418,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprBetween(node: PartiqlAst.Expr.Between, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = Constants.between, args = args("value" to node.value, "from" to node.from, "to" to node.to), type = StaticType.BOOL, @@ -415,20 +433,20 @@ internal object RexConverter : PartiqlAst.VisitorFold() { val potentialSubqueryRex = convert(node.operands[1]) val potentialSubquery = coercePotentialSubquery(potentialSubqueryRex) val rhs = (potentialSubquery as? Rex.Query.Scalar.Subquery)?.query ?: potentialSubquery - Plan.rexCall( + rexCall( id = Constants.inCollection, args = listOf( - Plan.argValue("lhs", lhs), - Plan.argValue("rhs", rhs), + argValue("lhs", lhs), + argValue("rhs", rhs), ), type = StaticType.BOOL, ) } override fun walkExprStruct(node: PartiqlAst.Expr.Struct, ctx: Ctx) = visit(node) { - Plan.rexTuple( + rexTuple( fields = node.fields.map { - Plan.field( + field( name = convert(it.first), value = convert(it.second) ) @@ -438,28 +456,28 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprBag(node: PartiqlAst.Expr.Bag, ctx: Ctx) = visit(node) { - Plan.rexCollectionBag( + rexCollectionBag( values = convert(node.values), type = StaticType.BAG, ) } override fun walkExprList(node: PartiqlAst.Expr.List, ctx: Ctx) = visit(node) { - Plan.rexCollectionArray( + rexCollectionArray( values = convert(node.values), type = StaticType.LIST, ) } override fun walkExprSexp(node: PartiqlAst.Expr.Sexp, accumulator: Ctx) = visit(node) { - Plan.rexCollectionArray( + rexCollectionArray( values = convert(node.values), type = StaticType.LIST, ) } override fun walkExprCall(node: PartiqlAst.Expr.Call, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = node.funcName.text, args = args(*node.args.toTypedArray()), type = null, @@ -467,7 +485,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprCallAgg(node: PartiqlAst.Expr.CallAgg, ctx: Ctx) = visit(node) { - Plan.rexAgg( + rexAgg( id = node.funcName.text, args = listOf(convert(node.arg)), modifier = when (node.setq) { @@ -479,7 +497,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprIsType(node: PartiqlAst.Expr.IsType, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = Constants.isType, args = args("value" to node.value, "type" to node.type), type = StaticType.BOOL, @@ -487,10 +505,10 @@ internal object RexConverter : 
PartiqlAst.VisitorFold() { } override fun walkExprSimpleCase(node: PartiqlAst.Expr.SimpleCase, ctx: Ctx) = visit(node) { - Plan.rexSwitch( + rexSwitch( match = convert(node.expr), branches = node.cases.pairs.map { - Plan.branch( + branch( condition = convert(it.first), value = convert(it.second), ) @@ -501,10 +519,10 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprSearchedCase(node: PartiqlAst.Expr.SearchedCase, ctx: Ctx) = visit(node) { - Plan.rexSwitch( + rexSwitch( match = null, branches = node.cases.pairs.map { - Plan.branch( + branch( condition = convert(it.first), value = convert(it.second), ) @@ -542,7 +560,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { is PartiqlAst.BagOpType.OuterExcept -> Constants.outerSetExcept } } - Plan.rexCall( + rexCall( id = op, args = args("lhs" to node.operands[0], "rhs" to node.operands[1]), type = StaticType.BAG, @@ -550,7 +568,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprCast(node: PartiqlAst.Expr.Cast, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = Constants.cast, args = args("value" to node.value, "type" to node.asType), type = TypeConverter.convert(node.asType), @@ -558,7 +576,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprCanCast(node: PartiqlAst.Expr.CanCast, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = Constants.canCast, args = args("value" to node.value, "type" to node.asType), type = StaticType.BOOL, @@ -566,7 +584,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprCanLosslessCast(node: PartiqlAst.Expr.CanLosslessCast, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = Constants.canLosslessCast, args = args("value" to node.value, "type" to node.asType), type = StaticType.BOOL, @@ -574,7 +592,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprNullIf(node: PartiqlAst.Expr.NullIf, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = Constants.nullIf, args = args(node.expr1, node.expr2), type = StaticType.BOOL, @@ -582,7 +600,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { } override fun walkExprCoalesce(node: PartiqlAst.Expr.Coalesce, ctx: Ctx) = visit(node) { - Plan.rexCall( + rexCall( id = Constants.coalesce, args = args(node.args), type = null, @@ -591,7 +609,7 @@ internal object RexConverter : PartiqlAst.VisitorFold() { override fun walkExprSelect(node: PartiqlAst.Expr.Select, ctx: Ctx) = visit(node) { when (val query = RelConverter.convert(node)) { - is Rex.Query.Collection -> Plan.rexQueryScalarSubquery(query, null) + is Rex.Query.Collection -> rexQueryScalarSubquery(query, null) is Rex.Query.Scalar -> query } } diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt index bb0c33a28f..5342e19636 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt @@ -32,7 +32,6 @@ import org.antlr.v4.runtime.TokenStream import org.antlr.v4.runtime.atn.PredictionMode import org.antlr.v4.runtime.misc.ParseCancellationException import org.antlr.v4.runtime.tree.TerminalNode -import org.partiql.ast.Ast import org.partiql.ast.AstNode import org.partiql.ast.DatetimeField import org.partiql.ast.Exclude @@ -52,7 +51,155 @@ import org.partiql.ast.Sort import 
org.partiql.ast.Statement import org.partiql.ast.TableDefinition import org.partiql.ast.Type -import org.partiql.ast.builder.AstFactory +import org.partiql.ast.exclude +import org.partiql.ast.excludeExcludeExpr +import org.partiql.ast.excludeStepExcludeCollectionIndex +import org.partiql.ast.excludeStepExcludeCollectionWildcard +import org.partiql.ast.excludeStepExcludeTupleAttr +import org.partiql.ast.excludeStepExcludeTupleWildcard +import org.partiql.ast.exprAgg +import org.partiql.ast.exprBagOp +import org.partiql.ast.exprBetween +import org.partiql.ast.exprBinary +import org.partiql.ast.exprCall +import org.partiql.ast.exprCanCast +import org.partiql.ast.exprCanLosslessCast +import org.partiql.ast.exprCase +import org.partiql.ast.exprCaseBranch +import org.partiql.ast.exprCast +import org.partiql.ast.exprCoalesce +import org.partiql.ast.exprCollection +import org.partiql.ast.exprDateAdd +import org.partiql.ast.exprDateDiff +import org.partiql.ast.exprExtract +import org.partiql.ast.exprInCollection +import org.partiql.ast.exprIon +import org.partiql.ast.exprIsType +import org.partiql.ast.exprLike +import org.partiql.ast.exprLit +import org.partiql.ast.exprMatch +import org.partiql.ast.exprNullIf +import org.partiql.ast.exprOverlay +import org.partiql.ast.exprParameter +import org.partiql.ast.exprPath +import org.partiql.ast.exprPathStepIndex +import org.partiql.ast.exprPathStepSymbol +import org.partiql.ast.exprPathStepUnpivot +import org.partiql.ast.exprPathStepWildcard +import org.partiql.ast.exprPosition +import org.partiql.ast.exprSFW +import org.partiql.ast.exprSessionAttribute +import org.partiql.ast.exprStruct +import org.partiql.ast.exprStructField +import org.partiql.ast.exprSubstring +import org.partiql.ast.exprTrim +import org.partiql.ast.exprUnary +import org.partiql.ast.exprVar +import org.partiql.ast.exprWindow +import org.partiql.ast.exprWindowOver +import org.partiql.ast.fromJoin +import org.partiql.ast.fromValue +import org.partiql.ast.graphMatch +import org.partiql.ast.graphMatchLabelConj +import org.partiql.ast.graphMatchLabelDisj +import org.partiql.ast.graphMatchLabelName +import org.partiql.ast.graphMatchLabelNegation +import org.partiql.ast.graphMatchLabelWildcard +import org.partiql.ast.graphMatchPattern +import org.partiql.ast.graphMatchPatternPartEdge +import org.partiql.ast.graphMatchPatternPartNode +import org.partiql.ast.graphMatchPatternPartPattern +import org.partiql.ast.graphMatchQuantifier +import org.partiql.ast.graphMatchSelectorAllShortest +import org.partiql.ast.graphMatchSelectorAny +import org.partiql.ast.graphMatchSelectorAnyK +import org.partiql.ast.graphMatchSelectorAnyShortest +import org.partiql.ast.graphMatchSelectorShortestK +import org.partiql.ast.graphMatchSelectorShortestKGroup +import org.partiql.ast.groupBy +import org.partiql.ast.groupByKey +import org.partiql.ast.identifierSymbol +import org.partiql.ast.let +import org.partiql.ast.letBinding +import org.partiql.ast.onConflict +import org.partiql.ast.onConflictActionDoNothing +import org.partiql.ast.onConflictActionDoReplace +import org.partiql.ast.onConflictActionDoUpdate +import org.partiql.ast.onConflictTargetConstraint +import org.partiql.ast.onConflictTargetSymbols +import org.partiql.ast.orderBy +import org.partiql.ast.path +import org.partiql.ast.pathStepIndex +import org.partiql.ast.pathStepSymbol +import org.partiql.ast.returning +import org.partiql.ast.returningColumn +import org.partiql.ast.returningColumnValueExpression +import org.partiql.ast.returningColumnValueWildcard 
+import org.partiql.ast.selectPivot +import org.partiql.ast.selectProject +import org.partiql.ast.selectProjectItemAll +import org.partiql.ast.selectProjectItemExpression +import org.partiql.ast.selectStar +import org.partiql.ast.selectValue +import org.partiql.ast.setOp +import org.partiql.ast.sort +import org.partiql.ast.statementDDLCreateIndex +import org.partiql.ast.statementDDLCreateTable +import org.partiql.ast.statementDDLDropIndex +import org.partiql.ast.statementDDLDropTable +import org.partiql.ast.statementDMLBatchLegacy +import org.partiql.ast.statementDMLBatchLegacyOpDelete +import org.partiql.ast.statementDMLBatchLegacyOpInsert +import org.partiql.ast.statementDMLBatchLegacyOpInsertLegacy +import org.partiql.ast.statementDMLBatchLegacyOpRemove +import org.partiql.ast.statementDMLBatchLegacyOpSet +import org.partiql.ast.statementDMLDelete +import org.partiql.ast.statementDMLDeleteTarget +import org.partiql.ast.statementDMLInsert +import org.partiql.ast.statementDMLInsertLegacy +import org.partiql.ast.statementDMLRemove +import org.partiql.ast.statementDMLReplace +import org.partiql.ast.statementDMLUpdate +import org.partiql.ast.statementDMLUpdateAssignment +import org.partiql.ast.statementDMLUpsert +import org.partiql.ast.statementExec +import org.partiql.ast.statementExplain +import org.partiql.ast.statementExplainTargetDomain +import org.partiql.ast.statementQuery +import org.partiql.ast.tableDefinition +import org.partiql.ast.tableDefinitionColumn +import org.partiql.ast.tableDefinitionColumnConstraint +import org.partiql.ast.tableDefinitionColumnConstraintBodyNotNull +import org.partiql.ast.tableDefinitionColumnConstraintBodyNullable +import org.partiql.ast.typeAny +import org.partiql.ast.typeBag +import org.partiql.ast.typeBlob +import org.partiql.ast.typeBool +import org.partiql.ast.typeChar +import org.partiql.ast.typeClob +import org.partiql.ast.typeCustom +import org.partiql.ast.typeDate +import org.partiql.ast.typeDecimal +import org.partiql.ast.typeFloat32 +import org.partiql.ast.typeFloat64 +import org.partiql.ast.typeInt +import org.partiql.ast.typeInt2 +import org.partiql.ast.typeInt4 +import org.partiql.ast.typeInt8 +import org.partiql.ast.typeList +import org.partiql.ast.typeMissing +import org.partiql.ast.typeNullType +import org.partiql.ast.typeNumeric +import org.partiql.ast.typeReal +import org.partiql.ast.typeSexp +import org.partiql.ast.typeString +import org.partiql.ast.typeStruct +import org.partiql.ast.typeSymbol +import org.partiql.ast.typeTime +import org.partiql.ast.typeTimeWithTz +import org.partiql.ast.typeTimestamp +import org.partiql.ast.typeTuple +import org.partiql.ast.typeVarchar import org.partiql.parser.PartiQLLexerException import org.partiql.parser.PartiQLParser import org.partiql.parser.PartiQLParserException @@ -171,7 +318,7 @@ internal class PartiQLParserDefault : PartiQLParser { line: Int, charPositionInLine: Int, msg: String, - e: RecognitionException? + e: RecognitionException?, ) { if (offendingSymbol is Token) { val token = offendingSymbol.text @@ -208,7 +355,7 @@ internal class PartiQLParserDefault : PartiQLParser { line: Int, charPositionInLine: Int, msg: String, - e: RecognitionException? 
+ e: RecognitionException?, ) { if (offendingSymbol is Token) { val rule = e?.ctx?.toString(rules) ?: "UNKNOWN" @@ -274,9 +421,6 @@ internal class PartiQLParserDefault : PartiQLParser { private val parameters: Map = mapOf(), ) : PartiQLBaseVisitor() { - // Use default factory - private val factory = Ast - companion object { private val rules = GeneratedParser.ruleNames.asList() @@ -287,7 +431,7 @@ internal class PartiQLParserDefault : PartiQLParser { fun translate( source: String, tokens: CountingTokenStream, - tree: GeneratedParser.RootContext + tree: GeneratedParser.RootContext, ): PartiQLParser.Result { val locations = SourceLocations.Mutable() val visitor = Visitor(locations, tokens.parameterIndexes) @@ -342,10 +486,10 @@ internal class PartiQLParserDefault : PartiQLParser { /** * Each visit attaches source locations from the given parse tree node; constructs nodes via the factory. */ - private inline fun translate(ctx: ParserRuleContext, block: AstFactory.() -> T): T { - val node = factory.block() + private inline fun translate(ctx: ParserRuleContext, block: () -> T): T { + val node = block() if (ctx.start != null) { - locations[node._id] = SourceLocation( + locations[node.tag] = SourceLocation( line = ctx.start.line, offset = ctx.start.charPositionInLine + 1, length = (ctx.stop?.stopIndex ?: ctx.start.stopIndex) - ctx.start.startIndex + 1, @@ -922,27 +1066,31 @@ internal class PartiQLParserDefault : PartiQLParser { excludeStepExcludeTupleAttr(identifier) } - override fun visitExcludeExprCollectionIndex(ctx: GeneratedParser.ExcludeExprCollectionIndexContext) = translate(ctx) { - val index = ctx.index.text.toInt() - excludeStepExcludeCollectionIndex(index) - } + override fun visitExcludeExprCollectionIndex(ctx: GeneratedParser.ExcludeExprCollectionIndexContext) = + translate(ctx) { + val index = ctx.index.text.toInt() + excludeStepExcludeCollectionIndex(index) + } - override fun visitExcludeExprCollectionAttr(ctx: GeneratedParser.ExcludeExprCollectionAttrContext) = translate(ctx) { - val attr = ctx.attr.getStringValue() - val identifier = identifierSymbol( - attr, - Identifier.CaseSensitivity.SENSITIVE, - ) - excludeStepExcludeTupleAttr(identifier) - } + override fun visitExcludeExprCollectionAttr(ctx: GeneratedParser.ExcludeExprCollectionAttrContext) = + translate(ctx) { + val attr = ctx.attr.getStringValue() + val identifier = identifierSymbol( + attr, + Identifier.CaseSensitivity.SENSITIVE, + ) + excludeStepExcludeTupleAttr(identifier) + } - override fun visitExcludeExprCollectionWildcard(ctx: org.partiql.parser.antlr.PartiQLParser.ExcludeExprCollectionWildcardContext) = translate(ctx) { - excludeStepExcludeCollectionWildcard() - } + override fun visitExcludeExprCollectionWildcard(ctx: org.partiql.parser.antlr.PartiQLParser.ExcludeExprCollectionWildcardContext) = + translate(ctx) { + excludeStepExcludeCollectionWildcard() + } - override fun visitExcludeExprTupleWildcard(ctx: org.partiql.parser.antlr.PartiQLParser.ExcludeExprTupleWildcardContext) = translate(ctx) { - excludeStepExcludeTupleWildcard() - } + override fun visitExcludeExprTupleWildcard(ctx: org.partiql.parser.antlr.PartiQLParser.ExcludeExprTupleWildcardContext) = + translate(ctx) { + excludeStepExcludeTupleWildcard() + } /** * @@ -1333,7 +1481,7 @@ internal class PartiQLParserDefault : PartiQLParser { private fun convertBinaryExpr(lhs: ParserRuleContext, rhs: ParserRuleContext, op: Expr.Binary.Op): Expr { val l = visit(lhs) as Expr val r = visit(rhs) as Expr - return factory.exprBinary(op, l, r) + return exprBinary(op, 
l, r) } private fun convertBinaryOp(token: Token) = when (token.type) { @@ -1450,8 +1598,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitParameter(ctx: GeneratedParser.ParameterContext) = translate(ctx) { val index = parameters[ctx.QUESTION_MARK().symbol.tokenIndex] ?: throw error( - ctx, - "Unable to find index of parameter." + ctx, "Unable to find index of parameter." ) exprParameter(index) } @@ -1521,9 +1668,10 @@ internal class PartiQLParserDefault : PartiQLParser { exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER) } - override fun visitExprTermCurrentDate(ctx: org.partiql.parser.antlr.PartiQLParser.ExprTermCurrentDateContext) = translate(ctx) { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE) - } + override fun visitExprTermCurrentDate(ctx: org.partiql.parser.antlr.PartiQLParser.ExprTermCurrentDateContext) = + translate(ctx) { + exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE) + } /** * @@ -1926,7 +2074,8 @@ internal class PartiQLParserDefault : PartiQLParser { else -> ctx.map { visit(it) as T } } - private inline fun visitOrNull(ctx: ParserRuleContext?): T? = ctx?.let { it.accept(this) as T } + private inline fun visitOrNull(ctx: ParserRuleContext?): T? = + ctx?.let { it.accept(this) as T } private inline fun visitAs(ctx: ParserRuleContext): T = visit(ctx) as T @@ -1955,7 +2104,7 @@ internal class PartiQLParserDefault : PartiQLParser { */ private fun getTimeStringAndPrecision( stringNode: TerminalNode, - integerNode: TerminalNode? + integerNode: TerminalNode?, ): Pair { val timeString = stringNode.getStringValue() val precision = when (integerNode) { @@ -2036,7 +2185,7 @@ internal class PartiQLParserDefault : PartiQLParser { selectProjectItemAll(path.root) } path.steps.last() is Expr.Path.Step.Unpivot -> { - selectProjectItemAll(factory.exprPath(path.root, steps)) + selectProjectItemAll(exprPath(path.root, steps)) } else -> { selectProjectItemExpression(path, alias) @@ -2054,7 +2203,7 @@ internal class PartiQLParserDefault : PartiQLParser { else -> throw error(this, "Unsupported token for grabbing string value.") } - private fun String.toIdentifier(): Identifier.Symbol = factory.identifierSymbol( + private fun String.toIdentifier(): Identifier.Symbol = identifierSymbol( symbol = this, caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE, ) diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/impl/PartiQLParserSessionAttributeTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/impl/PartiQLParserSessionAttributeTests.kt index 9b07bc47c3..e1d632ac47 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/impl/PartiQLParserSessionAttributeTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/impl/PartiQLParserSessionAttributeTests.kt @@ -3,7 +3,10 @@ package org.partiql.parser.impl import org.junit.jupiter.api.Test import org.partiql.ast.AstNode import org.partiql.ast.Expr -import org.partiql.ast.builder.AstFactory +import org.partiql.ast.exprBinary +import org.partiql.ast.exprLit +import org.partiql.ast.exprSessionAttribute +import org.partiql.ast.statementQuery import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int64Value import kotlin.test.assertEquals @@ -13,9 +16,7 @@ class PartiQLParserSessionAttributeTests { private val parser = PartiQLParserDefault() - private fun query(body: AstFactory.() -> Expr) = AstFactory.create { - statementQuery(this.body()) - } + private inline fun query(body: () -> Expr) = statementQuery(body()) @Test fun 
currentUserUpperCase() = assertExpression( diff --git a/partiql-plan/build.gradle.kts b/partiql-plan/build.gradle.kts index f5e789f784..9ef2e1fd6a 100644 --- a/partiql-plan/build.gradle.kts +++ b/partiql-plan/build.gradle.kts @@ -47,6 +47,7 @@ val generate = tasks.register("generate") { "-o", "$buildDir/generated-src", "-p", "org.partiql.plan", "-u", "Plan", + "--poems", "factory", "--poems", "visitor", "--poems", "builder", "--poems", "util", diff --git a/partiql-plan/src/main/kotlin/org/partiql/plan/Plan.kt b/partiql-plan/src/main/kotlin/org/partiql/plan/Plan.kt deleted file mode 100644 index 9dfcd6a9e1..0000000000 --- a/partiql-plan/src/main/kotlin/org/partiql/plan/Plan.kt +++ /dev/null @@ -1,15 +0,0 @@ -package org.partiql.plan - -import org.partiql.plan.builder.PlanFactoryImpl - -/** - * Singleton instance of the default factory. Also accessible via `PlanFactory.DEFAULT`. - */ -object Plan : PlanBaseFactory() - -/** - * PlanBaseFactory can be used to create a factory which extends from the factory provided by PlanFactory.DEFAULT. - */ -public abstract class PlanBaseFactory : PlanFactoryImpl() { - // internal default overrides here -} diff --git a/test/sprout-tests/build.gradle.kts b/test/sprout-tests/build.gradle.kts index aa748f374a..fd96fb8f69 100644 --- a/test/sprout-tests/build.gradle.kts +++ b/test/sprout-tests/build.gradle.kts @@ -30,6 +30,7 @@ val generate = tasks.register("generate") { "-o", "$buildDir/generated-src", "-p", "org.partiql.sprout.tests.example", "-u", "Example", + "--poems", "factory", "--poems", "visitor", "--poems", "builder", "--poems", "util", diff --git a/test/sprout-tests/src/test/kotlin/org/partiql/sprout/tests/example/EqualityTests.kt b/test/sprout-tests/src/test/kotlin/org/partiql/sprout/tests/example/EqualityTests.kt index f8624cf5df..5140f26945 100644 --- a/test/sprout-tests/src/test/kotlin/org/partiql/sprout/tests/example/EqualityTests.kt +++ b/test/sprout-tests/src/test/kotlin/org/partiql/sprout/tests/example/EqualityTests.kt @@ -18,7 +18,6 @@ import com.amazon.ionelement.api.ionInt import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ArgumentsSource import org.partiql.sprout.tests.ArgumentsProviderBase -import org.partiql.sprout.tests.example.builder.ExampleFactoryImpl import kotlin.test.assertEquals import kotlin.test.assertNotEquals @@ -46,65 +45,65 @@ class EqualityTests { } class EqualArgumentsProvider : ArgumentsProviderBase() { - private val factory = ExampleFactoryImpl() + override fun getParameters(): List = listOf( - TestCase(factory.exprEmpty(), factory.exprEmpty()), - TestCase(factory.exprIon(ionInt(1)), factory.exprIon(ionInt(1))), + TestCase(exprEmpty(), exprEmpty()), + TestCase(exprIon(ionInt(1)), exprIon(ionInt(1))), TestCase( - factory.identifierQualified( - factory.identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), + identifierQualified( + identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), emptyList() ), - factory.identifierQualified( - factory.identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), + identifierQualified( + identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), emptyList() ) ), TestCase( - factory.identifierQualified( - factory.identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), + identifierQualified( + identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), listOf( - factory.identifierSymbol("world", Identifier.CaseSensitivity.SENSITIVE), - factory.identifierSymbol("yeah", 
Identifier.CaseSensitivity.INSENSITIVE), - factory.identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("world", Identifier.CaseSensitivity.SENSITIVE), + identifierSymbol("yeah", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), ) ), - factory.identifierQualified( - factory.identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), + identifierQualified( + identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), listOf( - factory.identifierSymbol("world", Identifier.CaseSensitivity.SENSITIVE), - factory.identifierSymbol("yeah", Identifier.CaseSensitivity.INSENSITIVE), - factory.identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("world", Identifier.CaseSensitivity.SENSITIVE), + identifierSymbol("yeah", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), ) ) ), TestCase( - factory.statementQuery(factory.exprEmpty()), - factory.statementQuery(factory.exprEmpty()) + statementQuery(exprEmpty()), + statementQuery(exprEmpty()) ), // Tests deep equality of LISTS TestCase( - factory.exprNested( + exprNested( itemsList = listOf( listOf( - factory.exprEmpty(), - factory.exprIon(ionInt(1)) + exprEmpty(), + exprIon(ionInt(1)) ), listOf( - factory.exprIon(ionInt(3)) + exprIon(ionInt(3)) ) ), itemsSet = emptySet(), itemsMap = emptyMap() ), - factory.exprNested( + exprNested( itemsList = listOf( listOf( - factory.exprEmpty(), - factory.exprIon(ionInt(1)) + exprEmpty(), + exprIon(ionInt(1)) ), listOf( - factory.exprIon(ionInt(3)) + exprIon(ionInt(3)) ) ), itemsSet = emptySet(), @@ -113,50 +112,50 @@ class EqualityTests { ), // Tests deep equality of SETS TestCase( - first = factory.exprNested( + first = exprNested( itemsList = emptyList(), itemsSet = setOf( setOf(), - setOf(factory.exprEmpty()), - setOf(factory.exprIon(ionInt(1)), factory.exprIon(ionInt(2))) + setOf(exprEmpty()), + setOf(exprIon(ionInt(1)), exprIon(ionInt(2))) ), itemsMap = emptyMap() ), - second = factory.exprNested( + second = exprNested( itemsList = emptyList(), itemsSet = setOf( setOf(), - setOf(factory.exprEmpty()), - setOf(factory.exprIon(ionInt(1)), factory.exprIon(ionInt(2))) + setOf(exprEmpty()), + setOf(exprIon(ionInt(1)), exprIon(ionInt(2))) ), itemsMap = emptyMap() ), ), // Tests deep equality of MAPS TestCase( - first = factory.exprNested( + first = exprNested( itemsList = emptyList(), itemsSet = emptySet(), itemsMap = mapOf( "hello" to mapOf( - "world" to factory.exprEmpty(), - "!" to factory.exprIon(ionInt(1)) + "world" to exprEmpty(), + "!" to exprIon(ionInt(1)) ), "goodbye" to mapOf( - "friend" to factory.exprIon(ionInt(2)) + "friend" to exprIon(ionInt(2)) ) ) ), - second = factory.exprNested( + second = exprNested( itemsList = emptyList(), itemsSet = emptySet(), itemsMap = mapOf( "hello" to mapOf( - "world" to factory.exprEmpty(), - "!" to factory.exprIon(ionInt(1)) + "world" to exprEmpty(), + "!" 
to exprIon(ionInt(1)) ), "goodbye" to mapOf( - "friend" to factory.exprIon(ionInt(2)) + "friend" to exprIon(ionInt(2)) ) ) ), @@ -170,55 +169,55 @@ class EqualityTests { } class NotEqualArgumentsProvider : ArgumentsProviderBase() { - private val factory = ExampleFactoryImpl() + override fun getParameters(): List = listOf( - TestCase(factory.exprEmpty(), factory.exprIon(ionInt(1))), - TestCase(factory.exprIon(ionInt(1)), factory.exprIon(ionInt(2))), + TestCase(exprEmpty(), exprIon(ionInt(1))), + TestCase(exprIon(ionInt(1)), exprIon(ionInt(2))), TestCase( - factory.identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), - factory.identifierSymbol("hello", Identifier.CaseSensitivity.SENSITIVE) + identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("hello", Identifier.CaseSensitivity.SENSITIVE) ), TestCase( - factory.identifierQualified( - factory.identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), + identifierQualified( + identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), listOf( - factory.identifierSymbol("world", Identifier.CaseSensitivity.SENSITIVE), - factory.identifierSymbol("yeah", Identifier.CaseSensitivity.INSENSITIVE), - factory.identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("world", Identifier.CaseSensitivity.SENSITIVE), + identifierSymbol("yeah", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), ) ), - factory.identifierQualified( - factory.identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), + identifierQualified( + identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE), listOf( - factory.identifierSymbol("NOT_WORLD", Identifier.CaseSensitivity.SENSITIVE), - factory.identifierSymbol("yeah", Identifier.CaseSensitivity.INSENSITIVE), - factory.identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("NOT_WORLD", Identifier.CaseSensitivity.SENSITIVE), + identifierSymbol("yeah", Identifier.CaseSensitivity.INSENSITIVE), + identifierSymbol("foliage", Identifier.CaseSensitivity.INSENSITIVE), ) ) ), // Tests deep equality of LISTS TestCase( - factory.exprNested( + exprNested( itemsList = listOf( listOf( - factory.exprEmpty(), - factory.exprIon(ionInt(1)) + exprEmpty(), + exprIon(ionInt(1)) ), listOf( - factory.exprIon(ionInt(3)) + exprIon(ionInt(3)) ) ), itemsSet = emptySet(), itemsMap = emptyMap() ), - factory.exprNested( + exprNested( itemsList = listOf( listOf( - factory.exprEmpty(), - factory.exprIon(ionInt(2)) + exprEmpty(), + exprIon(ionInt(2)) ), listOf( - factory.exprIon(ionInt(3)) + exprIon(ionInt(3)) ) ), itemsSet = emptySet(), @@ -227,50 +226,50 @@ class EqualityTests { ), // Tests deep equality of SETS TestCase( - first = factory.exprNested( + first = exprNested( itemsList = emptyList(), itemsSet = setOf( setOf(), - setOf(factory.exprEmpty()), - setOf(factory.exprIon(ionInt(1)), factory.exprIon(ionInt(2))) + setOf(exprEmpty()), + setOf(exprIon(ionInt(1)), exprIon(ionInt(2))) ), itemsMap = emptyMap() ), - second = factory.exprNested( + second = exprNested( itemsList = emptyList(), itemsSet = setOf( setOf(), - setOf(factory.exprEmpty()), - setOf(factory.exprIon(ionInt(1)), factory.exprIon(ionInt(3))) + setOf(exprEmpty()), + setOf(exprIon(ionInt(1)), exprIon(ionInt(3))) ), itemsMap = emptyMap() ), ), // Tests deep equality of MAPS TestCase( - first = factory.exprNested( + first = exprNested( itemsList = emptyList(), itemsSet = emptySet(), itemsMap = 
mapOf( "hello" to mapOf( - "world" to factory.exprEmpty(), - "!" to factory.exprIon(ionInt(1)) + "world" to exprEmpty(), + "!" to exprIon(ionInt(1)) ), "goodbye" to mapOf( - "friend" to factory.exprIon(ionInt(2)) + "friend" to exprIon(ionInt(2)) ) ) ), - second = factory.exprNested( + second = exprNested( itemsList = emptyList(), itemsSet = emptySet(), itemsMap = mapOf( "hello" to mapOf( - "world" to factory.exprEmpty(), - "!" to factory.exprIon(ionInt(1)) + "world" to exprEmpty(), + "!" to exprIon(ionInt(1)) ), "goodbye" to mapOf( - "friend" to factory.exprIon(ionInt(3)) + "friend" to exprIon(ionInt(3)) ) ) ), diff --git a/test/sprout-tests/src/test/kotlin/org/partiql/sprout/tests/example/ToStringTests.kt b/test/sprout-tests/src/test/kotlin/org/partiql/sprout/tests/example/ToStringTests.kt deleted file mode 100644 index 7d9f93bcf6..0000000000 --- a/test/sprout-tests/src/test/kotlin/org/partiql/sprout/tests/example/ToStringTests.kt +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.sprout.tests.example - -import com.amazon.ionelement.api.createIonElementLoader -import org.junit.jupiter.api.Test -import org.partiql.sprout.tests.example.builder.ExampleFactoryImpl -import kotlin.test.assertEquals - -/** - * While toString isn't a contract, here are some tests for making sure at least some things work. 
- * - * Notably, the following definitely won't get properly converted to Ion: - * - Maps - * - Imported Types - * - Escape Characters - */ -class ToStringTests { - private val factory = ExampleFactoryImpl() - private val loader = createIonElementLoader() - - @Test - fun simpleProductAndEnum() { - val product = factory.identifierSymbol( - symbol = "helloworld!", - caseSensitivity = Identifier.CaseSensitivity.SENSITIVE - ) - val expected = loader.loadSingleElement("IdentifierSymbol::{ symbol: \"helloworld!\", caseSensitivity: IdentifierCaseSensitivity::SENSITIVE }") - val actual = loader.loadSingleElement(product.toString()) - assertEquals(expected, actual) - } - - @Test - fun emptyProduct() { - val product = factory.exprEmpty() - val expected = loader.loadSingleElement("ExprEmpty::{ }") - val actual = loader.loadSingleElement(product.toString()) - assertEquals(expected, actual) - } - - @Test - fun list() { - val product = factory.identifierQualified( - root = factory.identifierSymbol( - symbol = "hello", - caseSensitivity = Identifier.CaseSensitivity.INSENSITIVE - ), - steps = listOf( - factory.identifierSymbol( - symbol = "world", - caseSensitivity = Identifier.CaseSensitivity.SENSITIVE - ), - ) - ) - val expectedString = """ - IdentifierQualified::{ - root: IdentifierSymbol::{ symbol: "hello", caseSensitivity: IdentifierCaseSensitivity::INSENSITIVE }, - steps: [ - IdentifierSymbol::{ symbol: "world", caseSensitivity: IdentifierCaseSensitivity::SENSITIVE }, - ] - } - """.trimIndent() - val expected = loader.loadSingleElement(expectedString) - val actual = loader.loadSingleElement(product.toString()) - assertEquals(expected, actual) - } -}
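The call-site migration implied by the parser hunks above is mechanical: the `factory` (or `Ast`) receiver goes away and the generated top-level creation functions in `org.partiql.ast` are called with the same arguments. A minimal before/after sketch, using only node constructors that already appear in this diff; the wrapper function and variable names below are purely illustrative, not part of the change:

```kotlin
import org.partiql.ast.Expr
import org.partiql.ast.Identifier
import org.partiql.ast.exprBinary
import org.partiql.ast.identifierSymbol

// Before: nodes were created through the factory instance, e.g.
//   factory.identifierSymbol(name, Identifier.CaseSensitivity.INSENSITIVE)
//   factory.exprBinary(op, l, r)
// After: the same constructors are plain top-level functions.
fun symbol(name: String): Identifier.Symbol =
    identifierSymbol(name, Identifier.CaseSensitivity.INSENSITIVE)

fun binary(op: Expr.Binary.Op, l: Expr, r: Expr): Expr =
    exprBinary(op, l, r)
```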
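The reworked `translate` helper keys `SourceLocations` by the node's public `tag` rather than the old `_id`. Other metadata can ride along the same way; below is a small sketch of a tag-keyed side table. It assumes only what the hunk shows, namely that every `AstNode` exposes a `tag` usable as a map key (its concrete type is not visible here), and the `SideTable` class itself is hypothetical:

```kotlin
import org.partiql.ast.AstNode

// Hypothetical side table associating arbitrary metadata with AST nodes via `tag`.
// `Any?` is used as the key type only because the tag's concrete type is not shown
// in this hunk; substitute the real type in production code.
class SideTable<V> {
    private val entries = mutableMapOf<Any?, V>()
    operator fun set(node: AstNode, value: V) { entries[node.tag] = value }
    operator fun get(node: AstNode): V? = entries[node.tag]
}
```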
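In `PartiQLParserSessionAttributeTests`, the old `AstFactory.create { ... }` wrapper becomes a one-line helper around `statementQuery`. A sketch of how an expected tree is now written; the actual expected tree for `currentUserUpperCase` is not visible in this hunk, so the body here is illustrative:

```kotlin
import org.partiql.ast.Expr
import org.partiql.ast.exprSessionAttribute
import org.partiql.ast.statementQuery

// The helper from the updated test: wrap an expression into a query statement.
private inline fun query(body: () -> Expr) = statementQuery(body())

// Illustrative expected AST for a query such as "CURRENT_USER".
val expectedCurrentUser = query {
    exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER)
}
```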
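`EqualityTests` now builds both sides of each case with the top-level functions of the generated example universe and relies on the generated structural `equals`/`hashCode`. A condensed sketch, assuming the generated code lives in `org.partiql.sprout.tests.example` as configured by the `-p` option in the build file above:

```kotlin
import com.amazon.ionelement.api.ionInt
import org.partiql.sprout.tests.example.Identifier
import org.partiql.sprout.tests.example.exprIon
import org.partiql.sprout.tests.example.identifierSymbol
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals

fun main() {
    // Independently constructed nodes with the same content compare equal...
    assertEquals(exprIon(ionInt(1)), exprIon(ionInt(1)))
    // ...while a differing property (here, case sensitivity) makes them unequal.
    assertNotEquals(
        identifierSymbol("hello", Identifier.CaseSensitivity.INSENSITIVE),
        identifierSymbol("hello", Identifier.CaseSensitivity.SENSITIVE),
    )
}
```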