Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds performance optimizations for ExprCallDynamic #1388

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,11 @@ internal class Compiler(
}
}

@OptIn(FnExperimental::class, PartiQLValueExperimental::class)
@OptIn(FnExperimental::class)
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
val candidates = node.candidates.map { candidate ->
val candidates = Array(node.candidates.size) {
val candidate = node.candidates[it]
val fn = symbols.getFn(candidate.fn)
val coercions = candidate.coercions.toTypedArray()
ExprCallDynamic.Candidate(fn, coercions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@ import org.partiql.value.PartiQLValueType
*/
@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
internal class ExprCallDynamic(
private val candidates: List<Candidate>,
candidates: Array<Candidate>,
private val args: Array<Operator.Expr>
) : Operator.Expr {

private val candidateIndex = CandidateIndex.All(candidates)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I say create the index in the Compiler and pass to the ExprCallDynamic. In this scenario, we're passing an arg into the constructor just to be passed to another constructor.


override fun eval(env: Environment): PartiQLValue {
val actualArgs = args.map { it.eval(env) }.toTypedArray()
candidates.forEach { candidate ->
if (candidate.matches(actualArgs)) {
return candidate.eval(actualArgs, env)
}
val actualTypes = actualArgs.map { it.type }
candidateIndex.get(actualTypes)?.let {
return it.eval(actualArgs, env)
}
val errorString = buildString {
val argString = actualArgs.joinToString(", ")
append("Could not dynamically find function for arguments $argString in $candidates.")
append("Could not dynamically find function (${candidateIndex.name}) for arguments $argString.")
}
throw TypeCheckException(errorString)
}
Expand All @@ -47,13 +48,11 @@ internal class ExprCallDynamic(
*
* @see ExprCallDynamic
*/
internal class Candidate(
data class Candidate(
val fn: Fn,
val coercions: Array<Ref.Cast?>
) {

private val signatureParameters = fn.signature.parameters.map { it.type }.toTypedArray()

fun eval(originalArgs: Array<PartiQLValue>, env: Environment): PartiQLValue {
val args = originalArgs.mapIndexed { i, arg ->
when (val c = coercions[i]) {
Expand All @@ -63,32 +62,156 @@ internal class ExprCallDynamic(
}.toTypedArray()
return fn.invoke(args)
}
}

private sealed interface CandidateIndex {

public fun get(args: List<PartiQLValueType>): Candidate?
johnedquinn marked this conversation as resolved.
Show resolved Hide resolved

/**
* Preserves the original ordering of the passed-in candidates while making it faster to look up matching
* functions. Utilizes both [Direct] and [Indirect].
*
* Say a user passes in the following ordered candidates:
* [
* foo(int16, int16) -> int16,
* foo(int32, int32) -> int32,
* foo(int64, int64) -> int64,
* foo(string, string) -> string,
* foo(struct, struct) -> struct,
* foo(numeric, numeric) -> numeric,
* foo(int64, dynamic) -> dynamic,
* foo(struct, dynamic) -> dynamic,
* foo(bool, bool) -> bool
* ]
*
* With the above candidates, the [CandidateIndex.All] will maintain the original ordering by utilizing:
* - [CandidateIndex.Direct] to match hashable runtime types
* - [CandidateIndex.Indirect] to match the dynamic type
*
* For the above example, the internal representation of [CandidateIndex.All] is a list of
* [CandidateIndex.Direct] and [CandidateIndex.Indirect] that looks like:
* ALL listOf(
* DIRECT hashMap(
* [int16, int16] --> foo(int16, int16) -> int16,
* [int32, int32] --> foo(int32, int32) -> int32,
* [int64, int64] --> foo(int64, int64) -> int64,
* [string, string] --> foo(string, string) -> string,
* [struct, struct] --> foo(struct, struct) -> struct,
* [numeric, numeric] --> foo(numeric, numeric) -> numeric
* ),
* INDIRECT listOf(
* foo(int64, dynamic) -> dynamic,
* foo(struct, dynamic) -> dynamic
* ),
* DIRECT hashMap(
* [bool, bool] --> foo(bool, bool) -> bool
* )
* )
*
* @param candidates
*/
class All(
candidates: Array<Candidate>,
) : CandidateIndex {

private val lookups: List<CandidateIndex>
internal val name: String = candidates.first().fn.signature.name

internal fun matches(inputs: Array<PartiQLValue>): Boolean {
for (i in inputs.indices) {
val inputType = inputs[i].type
val parameterType = signatureParameters[i]
val c = coercions[i]
when (c) {
// coercion might be null if one of the following is true
// Function parameter is ANY,
// Input type is null
// input type is the same as function parameter
null -> {
if (!(inputType == parameterType || inputType == PartiQLValueType.NULL || parameterType == PartiQLValueType.ANY)) {
return false
init {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious if we see any notable CPU utilization jump from this init block since there is more work when an ExprCallDynamic is initialized. Wondering if there are any situations in which this construction of the CandidateIndex results in worse evaluation performance/utilization (e.g. if candidates is just a list of two elements)? I assume this change is better in the general case in which there are more than just a few candidate elements though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call out -- in the future, I'd say it'd be a good idea for the Compiler to dictate what version of ExprCallDynamic to create (the indexing O(1) vs iteration O(n)) depending on the number of candidates (and our own testing). As for init by itself, it is only invoked at compilation.

For now, we know dynamic is quite expensive for our customers.

val lookupsMutable = mutableListOf<CandidateIndex>()
val accumulator = mutableListOf<Pair<List<PartiQLValueType>, Candidate>>()

// Indicates that we are currently processing dynamic candidates that accept ANY.
var activelyProcessingAny = true

candidates.forEach { candidate ->
johnedquinn marked this conversation as resolved.
Show resolved Hide resolved
// Gather the input types to the dynamic invocation
val lookupTypes = candidate.coercions.mapIndexed { index, cast ->
when (cast) {
null -> candidate.fn.signature.parameters[index].type
else -> cast.input
}
}
else -> {
// checking the input type is expected by the coercion
if (inputType != c.input) return false
// checking the result is expected by the function signature
// this branch should never be reached, but leave it here for clarity
if (c.target != parameterType) error("Internal Error: Cast Target does not match Function Parameter")
val parametersIncludeAny = lookupTypes.any { it == PartiQLValueType.ANY }
// A way to simplify logic further below. If it's empty, add something and set the processing type.
if (accumulator.isEmpty()) {
activelyProcessingAny = parametersIncludeAny
accumulator.add(lookupTypes to candidate)
return@forEach
}
when (parametersIncludeAny) {
true -> when (activelyProcessingAny) {
true -> accumulator.add(lookupTypes to candidate)
false -> {
activelyProcessingAny = true
lookupsMutable.add(Direct.of(accumulator.toList()))
accumulator.clear()
accumulator.add(lookupTypes to candidate)
}
}
false -> when (activelyProcessingAny) {
false -> accumulator.add(lookupTypes to candidate)
true -> {
activelyProcessingAny = false
lookupsMutable.add(Indirect(accumulator.toList()))
accumulator.clear()
accumulator.add(lookupTypes to candidate)
}
}
}
}
// Flush the final group of candidates: a group is only submitted inside the loop when the kind switches, so the last group is still pending here
when (accumulator.isEmpty()) {
true -> { /* Do nothing! */ }
false -> when (activelyProcessingAny) {
true -> lookupsMutable.add(Indirect(accumulator.toList()))
false -> lookupsMutable.add(Direct.of(accumulator.toList()))
}
}
this.lookups = lookupsMutable
}

/**
 * Consults each nested lookup in order and returns the first candidate any of them yields for [args].
 *
 * @param args the runtime types of the actual arguments.
 * @return the first non-null candidate produced by a nested lookup, or null when none of them matches.
 */
override fun get(args: List<PartiQLValueType>): Candidate? {
    for (lookup in lookups) {
        val hit = lookup.get(args)
        if (hit != null) {
            return hit
        }
    }
    return null
}
}

/**
* An O(1) structure to quickly find directly matching dynamic candidates. This is specifically used for runtime
* types that can be matched directly (e.g. int32, int64, etc.). This does NOT include [PartiQLValueType.ANY].
*/
data class Direct private constructor(val directCandidates: HashMap<List<PartiQLValueType>, Candidate>) : CandidateIndex {

companion object {
/**
 * Builds a [Direct] index from an ordered list of (parameter types, candidate) pairs.
 *
 * The list order expresses candidate precedence. When two entries share the same parameter types, the
 * FIRST one is kept so that a hash lookup returns the same candidate an in-order scan would have found.
 * (A plain `putAll` would let the last duplicate overwrite the earlier, higher-precedence one.)
 *
 * @param candidates ordered pairs of lookup types and the candidate they resolve to.
 * @return a [Direct] index backed by a hash map keyed on the lookup types.
 */
internal fun of(candidates: List<Pair<List<PartiQLValueType>, Candidate>>): Direct {
    val candidateMap = java.util.HashMap<List<PartiQLValueType>, Candidate>()
    for ((types, candidate) in candidates) {
        // Keep the earliest entry for a given type list; later duplicates have lower precedence.
        if (!candidateMap.containsKey(types)) {
            candidateMap[types] = candidate
        }
    }
    return Direct(candidateMap)
}
}

/**
 * Looks up the candidate registered under exactly the given argument types.
 *
 * @param args the runtime types of the actual arguments, used as the map key.
 * @return the candidate stored for [args], or null when the map has no such key.
 */
override fun get(args: List<PartiQLValueType>): Candidate? = directCandidates.get(args)
johnedquinn marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Holds all candidates that expect a [PartiQLValueType.ANY] on input. This maintains the original
* precedence order.
*/
data class Indirect(private val candidates: List<Pair<List<PartiQLValueType>, Candidate>>) : CandidateIndex {
/**
 * Scans the candidates in their stored order and returns the first one whose parameter types accept [args].
 *
 * A position is accepted when the argument type equals the parameter type, or when the parameter type is
 * [PartiQLValueType.ANY].
 *
 * @param args the runtime types of the actual arguments.
 * @return the first accepting candidate, or null when no candidate accepts every position.
 */
override fun get(args: List<PartiQLValueType>): Candidate? {
    val firstMatch = candidates.firstOrNull { (types, _) ->
        args.indices.all { position ->
            args[position] == types[position] || types[position] == PartiQLValueType.ANY
        }
    }
    return firstMatch?.second
}
return true
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.partiql.value.Int64Value
import org.partiql.value.Int8Value
import org.partiql.value.IntValue
import org.partiql.value.ListValue
import org.partiql.value.NullValue
import org.partiql.value.NumericValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
Expand All @@ -30,7 +31,13 @@ import org.partiql.value.StringValue
import org.partiql.value.SymbolValue
import org.partiql.value.TextValue
import org.partiql.value.bagValue
import org.partiql.value.binaryValue
import org.partiql.value.blobValue
import org.partiql.value.boolValue
import org.partiql.value.byteValue
import org.partiql.value.charValue
import org.partiql.value.clobValue
import org.partiql.value.dateValue
import org.partiql.value.decimalValue
import org.partiql.value.float32Value
import org.partiql.value.float64Value
Expand All @@ -40,9 +47,13 @@ import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import org.partiql.value.listValue
import org.partiql.value.missingValue
import org.partiql.value.sexpValue
import org.partiql.value.stringValue
import org.partiql.value.structValue
import org.partiql.value.symbolValue
import org.partiql.value.timeValue
import org.partiql.value.timestampValue
import java.math.BigDecimal
import java.math.BigInteger

Expand Down Expand Up @@ -79,14 +90,48 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E
PartiQLValueType.LIST -> castFromCollection(arg as ListValue<*>, cast.target)
PartiQLValueType.SEXP -> castFromCollection(arg as SexpValue<*>, cast.target)
PartiQLValueType.STRUCT -> TODO("CAST FROM STRUCT not yet implemented")
PartiQLValueType.NULL -> error("cast from NULL should be handled by Typer")
PartiQLValueType.NULL -> castFromNull(arg as NullValue, cast.target)
PartiQLValueType.MISSING -> error("cast from MISSING should be handled by Typer")
}
} catch (e: DataException) {
throw TypeCheckException()
}
}

/**
 * Casts a [NullValue] to the target type [t].
 *
 * - [PartiQLValueType.ANY] and [PartiQLValueType.NULL] return the incoming [value] unchanged.
 * - [PartiQLValueType.MISSING] returns `missingValue()`.
 * - [PartiQLValueType.INTERVAL] is not supported and hits a `TODO`, which throws [NotImplementedError].
 * - Every other target invokes that type's value factory with a null payload.
 *
 * @param value the null value being cast.
 * @param t the target type of the cast.
 * @return the value produced for the target type as described above.
 */
@OptIn(PartiQLValueExperimental::class)
private fun castFromNull(value: NullValue, t: PartiQLValueType): PartiQLValue = when (t) {
    // Targets that keep the incoming value as-is.
    PartiQLValueType.ANY, PartiQLValueType.NULL -> value
    PartiQLValueType.MISSING -> missingValue() // TODO: Is this allowed?

    // Boolean and numeric targets.
    PartiQLValueType.BOOL -> boolValue(null)
    PartiQLValueType.INT8 -> int8Value(null)
    PartiQLValueType.INT16 -> int16Value(null)
    PartiQLValueType.INT32 -> int32Value(null)
    PartiQLValueType.INT64 -> int64Value(null)
    PartiQLValueType.INT -> intValue(null)
    PartiQLValueType.DECIMAL, PartiQLValueType.DECIMAL_ARBITRARY -> decimalValue(null)
    PartiQLValueType.FLOAT32 -> float32Value(null)
    PartiQLValueType.FLOAT64 -> float64Value(null)

    // Text targets.
    PartiQLValueType.CHAR -> charValue(null)
    PartiQLValueType.STRING -> stringValue(null)
    PartiQLValueType.SYMBOL -> symbolValue(null)

    // Binary and large-object targets.
    PartiQLValueType.BINARY -> binaryValue(null)
    PartiQLValueType.BYTE -> byteValue(null)
    PartiQLValueType.BLOB -> blobValue(null)
    PartiQLValueType.CLOB -> clobValue(null)

    // Date/time targets.
    PartiQLValueType.DATE -> dateValue(null)
    PartiQLValueType.TIME -> timeValue(null)
    PartiQLValueType.TIMESTAMP -> timestampValue(null)
    PartiQLValueType.INTERVAL -> TODO("Not yet supported")

    // Collection and struct targets.
    PartiQLValueType.BAG -> bagValue<PartiQLValue>(null)
    PartiQLValueType.LIST -> listValue<PartiQLValue>(null)
    PartiQLValueType.SEXP -> sexpValue<PartiQLValue>(null)
    PartiQLValueType.STRUCT -> structValue<PartiQLValue>(null)
}

@OptIn(PartiQLValueExperimental::class)
private fun castFromBool(value: BoolValue, t: PartiQLValueType): PartiQLValue {
val v = value.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import org.partiql.value.structValue
import java.io.ByteArrayOutputStream
import java.math.BigDecimal
import java.math.BigInteger
import kotlin.test.assertEquals
import kotlin.test.assertNotNull

/**
Expand Down Expand Up @@ -1253,10 +1252,12 @@ class PartiQLEngineDefaultTest {

internal fun assert() {
val permissiveResult = run(mode = PartiQLEngine.Mode.PERMISSIVE)
assertEquals(expectedPermissive, permissiveResult, comparisonString(expectedPermissive, permissiveResult))
assert(expectedPermissive == permissiveResult.first) {
comparisonString(expectedPermissive, permissiveResult.first, permissiveResult.second)
}
var error: Throwable? = null
try {
when (val result = run(mode = PartiQLEngine.Mode.STRICT)) {
when (val result = run(mode = PartiQLEngine.Mode.STRICT).first) {
is CollectionValue<*> -> result.toList()
else -> result
}
Expand All @@ -1266,7 +1267,7 @@ class PartiQLEngineDefaultTest {
assertNotNull(error)
}

private fun run(mode: PartiQLEngine.Mode): PartiQLValue {
private fun run(mode: PartiQLEngine.Mode): Pair<PartiQLValue, PartiQLPlan> {
val statement = parser.parse(input).root
val catalog = MemoryCatalog.PartiQL().name("memory").build()
val connector = MemoryConnector(catalog)
Expand All @@ -1283,17 +1284,18 @@ class PartiQLEngineDefaultTest {
val plan = planner.plan(statement, session)
val prepared = engine.prepare(plan.plan, PartiQLEngine.Session(mapOf("memory" to connector), mode = mode))
when (val result = engine.execute(prepared)) {
is PartiQLResult.Value -> return result.value
is PartiQLResult.Value -> return result.value to plan.plan
is PartiQLResult.Error -> throw result.cause
}
}

@OptIn(PartiQLValueExperimental::class)
private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue): String {
private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: PartiQLPlan): String {
val expectedBuffer = ByteArrayOutputStream()
val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer)
expectedWriter.append(expected)
return buildString {
PlanPrinter.append(this, plan)
appendLine("Expected : $expectedBuffer")
expectedBuffer.reset()
expectedWriter.append(actual)
Expand Down Expand Up @@ -1444,6 +1446,7 @@ class PartiQLEngineDefaultTest {
).assert()

@Test
@Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.")
// TODO: Add to conformance tests
fun wildCard() =
SuccessTestCase(
Expand Down Expand Up @@ -1487,6 +1490,7 @@ class PartiQLEngineDefaultTest {
).assert()

@Test
@Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.")
// TODO: add to conformance tests
// Note that the existing pipeline produced identical result when supplying with
// SELECT VALUE v2.name FROM e as v0, v0.books as v1, unpivot v1.authors as v2;
Expand Down
Loading
Loading