From e73721241610a85d184d3748e86b00a2d4bd2783 Mon Sep 17 00:00:00 2001 From: Michael Rittmeister Date: Tue, 27 Apr 2021 08:30:01 +0200 Subject: [PATCH] Port to Kotlin 1.5 (#268) * Port dependencies to Kotlin 1.5 - Convert AbstractRateLimiter.AbstractRequestToken to a static rather than an inner class due to a compiler bug - Downgrade kx.ser-json to 1.0.0 to avoid a compiler bug - Bump other Kotlin dependencies to latest fixup! Port dependencies to Kotlin 1.5 - Convert AbstractRateLimiter.AbstractRequestToken to a static rather than an inner class due to a compiler bug - Downgrade kx.ser-json to 1.0.0 to avoid a compiler bug - Bump other Kotlin dependencies to latest * Replace deprecated kotlin.time APIs * Replace more deprecated APIs & inline classes * Replace deprecated usage of time API in tests * Possibly fix test Co-authored-by: Hope <34831095+HopeBaron@users.noreply.github.com> --- build.gradle.kts | 1 + buildSrc/src/main/kotlin/Dependencies.kt | 10 +-- .../kotlin/ratelimit/BucketRateLimiter.kt | 2 +- .../kotlin/ratelimit/BucketRateLimiterTest.kt | 5 +- .../channel/MessageChannelBehavior.kt | 8 +- .../main/kotlin/builder/kord/KordBuilder.kt | 9 ++- .../kotlin/event/guild/InviteCreateEvent.kt | 2 +- core/src/main/kotlin/gateway/MasterGateway.kt | 4 +- .../kotlin/performance/KordEventDropTest.kt | 76 +++++++++---------- .../src/main/kotlin/DefaultGatewayBuilder.kt | 7 +- gateway/src/main/kotlin/retry/LinearRetry.kt | 11 +-- .../test/kotlin/gateway/DefaultGatewayTest.kt | 4 +- gateway/src/test/kotlin/json/CommandTest.kt | 3 +- .../kotlin/ratelimit/AbstractRateLimiter.kt | 46 ++++++----- .../ratelimit/ExclusionRequestRateLimiter.kt | 13 ++-- .../ratelimit/ParallelRequestRateLimiter.kt | 6 +- .../kotlin/ratelimit/RequestRateLimiter.kt | 12 ++- .../AbstractRequestRateLimiterTest.kt | 7 +- 18 files changed, 119 insertions(+), 107 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 95179e1b92d..a57cbff30fa 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -5,6 +5,7 @@ import org.apache.commons.codec.binary.Base64 buildscript { repositories { jcenter() + mavenCentral() maven(url = "https://plugins.gradle.org/m2/") } dependencies { diff --git a/buildSrc/src/main/kotlin/Dependencies.kt b/buildSrc/src/main/kotlin/Dependencies.kt index eea2ceef887..db8f0d25d0a 100644 --- a/buildSrc/src/main/kotlin/Dependencies.kt +++ b/buildSrc/src/main/kotlin/Dependencies.kt @@ -1,10 +1,10 @@ object Versions { - const val kotlin = "1.5.0-RC" - const val kotlinxSerialization = "1.1.0" - const val ktor = "1.5.2" - const val kotlinxCoroutines = "1.4.2" + const val kotlin = "1.5.0" + const val kotlinxSerialization = "1.0.0" + const val ktor = "1.5.3" + const val kotlinxCoroutines = "1.4.3" const val kotlinLogging = "2.0.4" - const val atomicFu = "0.15.1" + const val atomicFu = "0.15.2" const val binaryCompatibilityValidator = "0.4.0" //test deps diff --git a/common/src/main/kotlin/ratelimit/BucketRateLimiter.kt b/common/src/main/kotlin/ratelimit/BucketRateLimiter.kt index 407b071a1b4..e613456b7dc 100644 --- a/common/src/main/kotlin/ratelimit/BucketRateLimiter.kt +++ b/common/src/main/kotlin/ratelimit/BucketRateLimiter.kt @@ -38,7 +38,7 @@ class BucketRateLimiter( private fun resetState() { count = 0 - nextInterval = clock.millis() + refillInterval.inMilliseconds.toLong() + nextInterval = clock.millis() + refillInterval.inWholeMilliseconds } private suspend fun delayUntilNextInterval() { diff --git a/common/src/test/kotlin/ratelimit/BucketRateLimiterTest.kt b/common/src/test/kotlin/ratelimit/BucketRateLimiterTest.kt index da970531a54..ba5407996ae 100644 --- a/common/src/test/kotlin/ratelimit/BucketRateLimiterTest.kt +++ b/common/src/test/kotlin/ratelimit/BucketRateLimiterTest.kt @@ -9,6 +9,7 @@ import java.time.ZoneOffset import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.asserter +import kotlin.time.Duration import kotlin.time.ExperimentalTime import kotlin.time.milliseconds @@ -16,7 +17,7 @@ import kotlin.time.milliseconds @ExperimentalCoroutinesApi class BucketRateLimiterTest { - val interval = 1_000_000.milliseconds + val interval = Duration.milliseconds(1_000_000) val instant = Instant.now() val clock = Clock.fixed(instant, ZoneOffset.UTC) lateinit var rateLimiter: BucketRateLimiter @@ -38,7 +39,7 @@ class BucketRateLimiterTest { rateLimiter.consume() rateLimiter.consume() - asserter.assertTrue("expected timeout of ${interval.inMilliseconds.toLong()} ms but was $currentTime ms", interval.toLongMilliseconds() == currentTime) + asserter.assertTrue("expected timeout of ${interval.inWholeMilliseconds} ms but was $currentTime ms", interval.inWholeMilliseconds == currentTime) } } diff --git a/core/src/main/kotlin/behavior/channel/MessageChannelBehavior.kt b/core/src/main/kotlin/behavior/channel/MessageChannelBehavior.kt index df3e98301f7..21fbb951c22 100644 --- a/core/src/main/kotlin/behavior/channel/MessageChannelBehavior.kt +++ b/core/src/main/kotlin/behavior/channel/MessageChannelBehavior.kt @@ -24,8 +24,8 @@ import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.coroutines.coroutineContext +import kotlin.time.Duration import kotlin.time.TimeMark -import kotlin.time.seconds /** * The behavior of a Discord channel that can use messages. @@ -199,7 +199,7 @@ interface MessageChannelBehavior : ChannelBehavior, Strategizable { suspend fun typeUntil(mark: TimeMark) { while (mark.hasNotPassedNow()) { type() - delay(8.seconds.toLongMilliseconds()) //bracing ourselves for some network delays + delay(Duration.seconds(8).inWholeMilliseconds) //bracing ourselves for some network delays } } @@ -212,7 +212,7 @@ interface MessageChannelBehavior : ChannelBehavior, Strategizable { suspend fun typeUntil(instant: Instant) { while (instant.isBefore(Instant.now())) { type() - delay(8.seconds.toLongMilliseconds()) //bracing ourselves for some network delays + delay(Duration.seconds(8).inWholeMilliseconds) //bracing ourselves for some network delays } } @@ -297,7 +297,7 @@ suspend inline fun T.withTyping(block: T.() -> Unit kord.launch(context = coroutineContext) { while (typing) { type() - delay(8.seconds.toLongMilliseconds()) + delay(Duration.seconds(8).inWholeMilliseconds) } } diff --git a/core/src/main/kotlin/builder/kord/KordBuilder.kt b/core/src/main/kotlin/builder/kord/KordBuilder.kt index 3e3d987cb86..2f6edd0053e 100644 --- a/core/src/main/kotlin/builder/kord/KordBuilder.kt +++ b/core/src/main/kotlin/builder/kord/KordBuilder.kt @@ -40,18 +40,19 @@ import kotlin.concurrent.thread import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract +import kotlin.time.Duration import kotlin.time.seconds operator fun DefaultGateway.Companion.invoke( resources: ClientResources, - retry: Retry = LinearRetry(2.seconds, 60.seconds, 10) + retry: Retry = LinearRetry(Duration.seconds(2), Duration.seconds(60), 10) ): DefaultGateway { return DefaultGateway { url = "wss://gateway.discord.gg/" client = resources.httpClient reconnectRetry = retry - sendRateLimiter = BucketRateLimiter(120, 60.seconds) - identifyRateLimiter = BucketRateLimiter(1, 5.seconds) + sendRateLimiter = BucketRateLimiter(120, Duration.seconds(60)) + identifyRateLimiter = BucketRateLimiter(1, Duration.seconds(5)) } } @@ -63,7 +64,7 @@ class KordBuilder(val token: String) { private var shardsBuilder: (recommended: Int) -> Shards = { Shards(it) } private var gatewayBuilder: (resources: ClientResources, shards: List) -> List = { resources, shards -> - val rateLimiter = BucketRateLimiter(1, 5.seconds) + val rateLimiter = BucketRateLimiter(1, Duration.seconds(5)) shards.map { DefaultGateway { client = resources.httpClient diff --git a/core/src/main/kotlin/event/guild/InviteCreateEvent.kt b/core/src/main/kotlin/event/guild/InviteCreateEvent.kt index 9d333f007c1..5023577aaca 100644 --- a/core/src/main/kotlin/event/guild/InviteCreateEvent.kt +++ b/core/src/main/kotlin/event/guild/InviteCreateEvent.kt @@ -87,7 +87,7 @@ class InviteCreateEvent( /** * How long the invite is valid for (in seconds). */ - val maxAge: Duration get() = data.maxAge.seconds + val maxAge: Duration get() = Duration.seconds(data.maxAge) /** * The maximum number of times the invite can be used. diff --git a/core/src/main/kotlin/gateway/MasterGateway.kt b/core/src/main/kotlin/gateway/MasterGateway.kt index 72b2ce1afe5..09c7fa93e11 100644 --- a/core/src/main/kotlin/gateway/MasterGateway.kt +++ b/core/src/main/kotlin/gateway/MasterGateway.kt @@ -25,10 +25,10 @@ class MasterGateway( */ val averagePing get(): Duration? { - val pings = gateways.values.mapNotNull { it.ping.value?.inMicroseconds } + val pings = gateways.values.mapNotNull { it.ping.value?.inWholeMilliseconds } if (pings.isEmpty()) return null - return pings.average().microseconds + return Duration.microseconds(pings.average()) } diff --git a/core/src/test/kotlin/performance/KordEventDropTest.kt b/core/src/test/kotlin/performance/KordEventDropTest.kt index 39542aa3970..7656aa002e6 100644 --- a/core/src/test/kotlin/performance/KordEventDropTest.kt +++ b/core/src/test/kotlin/performance/KordEventDropTest.kt @@ -14,9 +14,9 @@ import dev.kord.rest.request.KtorRequestHandler import dev.kord.rest.service.RestClient import io.ktor.client.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.BroadcastChannel -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.flow.* +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow import java.time.Clock import java.util.concurrent.CountDownLatch import java.util.concurrent.atomic.AtomicInteger @@ -25,7 +25,6 @@ import kotlin.coroutines.EmptyCoroutineContext import kotlin.test.Test import kotlin.test.assertEquals import kotlin.time.Duration -import kotlin.time.minutes class KordEventDropTest { @@ -47,13 +46,13 @@ class KordEventDropTest { } val kord = Kord( - resources = ClientResources("token", Shards(1), HttpClient(), EntitySupplyStrategy.cache, Intents.none), - cache = DataCache.none(), - MasterGateway(mapOf(0 to SpammyGateway)), - RestClient(KtorRequestHandler("token", clock = Clock.systemUTC())), - Snowflake("420"), - MutableSharedFlow(extraBufferCapacity = Int.MAX_VALUE), - Dispatchers.Default + resources = ClientResources("token", Shards(1), HttpClient(), EntitySupplyStrategy.cache, Intents.none), + cache = DataCache.none(), + MasterGateway(mapOf(0 to SpammyGateway)), + RestClient(KtorRequestHandler("token", clock = Clock.systemUTC())), + Snowflake("420"), + MutableSharedFlow(extraBufferCapacity = Int.MAX_VALUE), + Dispatchers.Default ) @Test @@ -61,32 +60,33 @@ class KordEventDropTest { val amount = 1_000 val event = GuildCreate( - DiscordGuild( - Snowflake("1337"), - "discord guild", - afkTimeout = 0, - defaultMessageNotifications = DefaultMessageNotificationLevel.AllMessages, - emojis = emptyList(), - explicitContentFilter = ExplicitContentFilter.AllMembers, - features = emptyList(), - mfaLevel = MFALevel.Elevated, - ownerId = Snowflake("123"), - preferredLocale = "en", - description = "A not really real guild", - premiumTier = PremiumTier.None, - region = "idk", - roles = emptyList(), - verificationLevel = VerificationLevel.High, - icon = null, - afkChannelId = null, - applicationId = null, - systemChannelFlags = SystemChannelFlags(0), - systemChannelId = null, - rulesChannelId = null, - vanityUrlCode = null, - banner = null, - publicUpdatesChannelId = null - ), 0) + DiscordGuild( + Snowflake("1337"), + "discord guild", + afkTimeout = 0, + defaultMessageNotifications = DefaultMessageNotificationLevel.AllMessages, + emojis = emptyList(), + explicitContentFilter = ExplicitContentFilter.AllMembers, + features = emptyList(), + mfaLevel = MFALevel.Elevated, + ownerId = Snowflake("123"), + preferredLocale = "en", + description = "A not really real guild", + premiumTier = PremiumTier.None, + region = "idk", + roles = emptyList(), + verificationLevel = VerificationLevel.High, + icon = null, + afkChannelId = null, + applicationId = null, + systemChannelFlags = SystemChannelFlags(0), + systemChannelId = null, + rulesChannelId = null, + vanityUrlCode = null, + banner = null, + publicUpdatesChannelId = null + ), 0 + ) val counter = AtomicInteger(0) val countdown = CountDownLatch(amount) @@ -99,7 +99,7 @@ class KordEventDropTest { SpammyGateway.events.emit(event) } - withTimeout(1.minutes) { + withTimeout(Duration.minutes(1).inWholeMilliseconds) { countdown.await() } assertEquals(amount, counter.get()) diff --git a/gateway/src/main/kotlin/DefaultGatewayBuilder.kt b/gateway/src/main/kotlin/DefaultGatewayBuilder.kt index f4d521a0533..5305a78640e 100644 --- a/gateway/src/main/kotlin/DefaultGatewayBuilder.kt +++ b/gateway/src/main/kotlin/DefaultGatewayBuilder.kt @@ -15,6 +15,7 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ObsoleteCoroutinesApi import kotlinx.coroutines.flow.MutableSharedFlow +import kotlin.time.Duration import kotlin.time.seconds class DefaultGatewayBuilder { @@ -32,9 +33,9 @@ class DefaultGatewayBuilder { install(WebSockets) install(JsonFeature) } - val retry = reconnectRetry ?: LinearRetry(2.seconds, 20.seconds, 10) - val sendRateLimiter = sendRateLimiter ?: BucketRateLimiter(120, 60.seconds) - val identifyRateLimiter = identifyRateLimiter ?: BucketRateLimiter(1, 5.seconds) + val retry = reconnectRetry ?: LinearRetry(Duration.seconds(2), Duration.seconds(20), 10) + val sendRateLimiter = sendRateLimiter ?: BucketRateLimiter(120, Duration.seconds(60)) + val identifyRateLimiter = identifyRateLimiter ?: BucketRateLimiter(1, Duration.seconds(5)) client.requestPipeline.intercept(HttpRequestPipeline.Render) { // CIO adds this header even if no extensions are used, which causes it to be empty diff --git a/gateway/src/main/kotlin/retry/LinearRetry.kt b/gateway/src/main/kotlin/retry/LinearRetry.kt index bcbfd080b50..dcf08181681 100644 --- a/gateway/src/main/kotlin/retry/LinearRetry.kt +++ b/gateway/src/main/kotlin/retry/LinearRetry.kt @@ -22,15 +22,12 @@ class LinearRetry constructor( private val maxTries: Int ) : Retry { - constructor(firstBackoffMillis: Long, maxBackoffMillis: Long, maxTries: Int) : - this(firstBackoffMillis.milliseconds, maxBackoffMillis.milliseconds, maxTries) - init { - require(firstBackoff.isPositive()) { "firstBackoff needs to be positive but was ${firstBackoff.toLongMilliseconds()} ms" } - require(maxBackoff.isPositive()) { "maxBackoff needs to be positive but was ${maxBackoff.toLongMilliseconds()} ms" } + require(firstBackoff.isPositive()) { "firstBackoff needs to be positive but was ${firstBackoff.inWholeMilliseconds} ms" } + require(maxBackoff.isPositive()) { "maxBackoff needs to be positive but was ${maxBackoff.inWholeMilliseconds} ms" } require( maxBackoff.minus(firstBackoff).isPositive() - ) { "maxBackoff ${maxBackoff.toLongMilliseconds()} ms needs to be bigger than firstBackoff ${firstBackoff.toLongMilliseconds()} ms" } + ) { "maxBackoff ${maxBackoff.inWholeMilliseconds} ms needs to be bigger than firstBackoff ${firstBackoff.inWholeMilliseconds} ms" } require(maxTries > 0) { "maxTries needs to be positive but was $maxTries" } } @@ -47,7 +44,7 @@ class LinearRetry constructor( if (!hasNext) error("max retries exceeded") tries.incrementAndGet() - var diff = (maxBackoff - firstBackoff).toLongMilliseconds() / maxTries + var diff = (maxBackoff - firstBackoff).inWholeMilliseconds / maxTries diff *= tries.value linearRetryLogger.trace { "retry attempt ${tries.value}, delaying for $diff ms" } delay(diff) diff --git a/gateway/src/test/kotlin/gateway/DefaultGatewayTest.kt b/gateway/src/test/kotlin/gateway/DefaultGatewayTest.kt index 2f1476500c4..d5b4ed6c261 100644 --- a/gateway/src/test/kotlin/gateway/DefaultGatewayTest.kt +++ b/gateway/src/test/kotlin/gateway/DefaultGatewayTest.kt @@ -22,7 +22,7 @@ import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import java.time.Duration import kotlin.time.ExperimentalTime -import kotlin.time.seconds +import kotlin.time.Duration as KDuration import kotlin.time.toKotlinDuration @FlowPreview @@ -44,7 +44,7 @@ class DefaultGatewayTest { } } - reconnectRetry = LinearRetry(2.seconds, 20.seconds, 10) + reconnectRetry = LinearRetry(KDuration.seconds(2), KDuration.seconds(20), 10) sendRateLimiter = BucketRateLimiter(120, Duration.ofSeconds(60).toKotlinDuration()) } diff --git a/gateway/src/test/kotlin/json/CommandTest.kt b/gateway/src/test/kotlin/json/CommandTest.kt index 69c736f6e57..167bc76d80b 100644 --- a/gateway/src/test/kotlin/json/CommandTest.kt +++ b/gateway/src/test/kotlin/json/CommandTest.kt @@ -13,6 +13,7 @@ import dev.kord.gateway.* import kotlinx.serialization.json.* import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test +import java.util.* private val json = Json { encodeDefaults = false } @@ -112,7 +113,7 @@ class CommandTest { put("d", buildJsonObject { put("since", since) put("activities", null as String?) - put("status", status.value.toLowerCase()) + put("status", status.value.lowercase(Locale.getDefault())) put("afk", afk) }) }) diff --git a/rest/src/main/kotlin/ratelimit/AbstractRateLimiter.kt b/rest/src/main/kotlin/ratelimit/AbstractRateLimiter.kt index 4e612be6b4f..41421dfc642 100644 --- a/rest/src/main/kotlin/ratelimit/AbstractRateLimiter.kt +++ b/rest/src/main/kotlin/ratelimit/AbstractRateLimiter.kt @@ -11,15 +11,15 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.sync.Mutex import mu.KLogger import java.time.Clock +import kotlin.time.Duration as KDuration import java.time.Duration import java.util.concurrent.ConcurrentHashMap -import kotlin.time.minutes abstract class AbstractRateLimiter internal constructor(val clock: Clock) : RequestRateLimiter { internal abstract val logger: KLogger - internal val autoBanRateLimiter = BucketRateLimiter(25000, 10.minutes) + internal val autoBanRateLimiter = BucketRateLimiter(25000, KDuration.minutes(10)) internal val globalSuspensionPoint = atomic(Reset(clock.instant())) internal val buckets = ConcurrentHashMap() internal val routeBuckets = ConcurrentHashMap>() @@ -45,7 +45,8 @@ abstract class AbstractRateLimiter internal constructor(val clock: Clock) : Requ internal abstract fun newToken(request: Request<*, *>, buckets: List): RequestToken - internal abstract inner class AbstractRequestToken( + internal abstract class AbstractRequestToken( + val rateLimiter: AbstractRateLimiter, val identity: RequestIdentifier, val requestBuckets: List ) : RequestToken { @@ -54,29 +55,32 @@ abstract class AbstractRateLimiter internal constructor(val clock: Clock) : Requ override val completed: Boolean get() = completableDeferred.isCompleted - open override suspend fun complete(response: RequestResponse) { - response.bucketKey?.let { key -> - if (identity.addBucket(key)) { - logger.trace { "[DISCOVERED]:[BUCKET]:Bucket ${response.bucketKey?.value} discovered for $identity" } - buckets[key] = key.bucket - } - } + override suspend fun complete(response: RequestResponse) { + with(rateLimiter) { + val key = response.bucketKey + if (key != null) { + if (identity.addBucket(key)) { - when (response) { - is RequestResponse.GlobalRateLimit -> { - logger.trace { "[RATE LIMIT]:[GLOBAL]:exhausted until ${response.reset.value}" } - globalSuspensionPoint.update { response.reset } + logger.trace { "[DISCOVERED]:[BUCKET]:Bucket discovered for" } + buckets[key] = key.bucket + } } - is RequestResponse.BucketRateLimit -> { - logger.trace { "[RATE LIMIT]:[BUCKET]:Bucket ${response.bucketKey.value} was exhausted until ${response.reset.value}" } - response.bucketKey.bucket.updateReset(response.reset) + + when (response) { + is RequestResponse.GlobalRateLimit -> { + logger.trace { "[RATE LIMIT]:[GLOBAL]:exhausted until ${response.reset.value}" } + globalSuspensionPoint.update { response.reset } + } + is RequestResponse.BucketRateLimit -> { + logger.trace { "[RATE LIMIT]:[BUCKET]:Bucket ${response.bucketKey.value} was exhausted until ${response.reset.value}" } + response.bucketKey.bucket.updateReset(response.reset) + } } - } - completableDeferred.complete(Unit) - requestBuckets.forEach { it.unlock() } + completableDeferred.complete(Unit) + requestBuckets.forEach { it.unlock() } + } } - } internal inner class Bucket(val id: BucketKey) { diff --git a/rest/src/main/kotlin/ratelimit/ExclusionRequestRateLimiter.kt b/rest/src/main/kotlin/ratelimit/ExclusionRequestRateLimiter.kt index 1a0d1331622..d8b5dacf8ac 100644 --- a/rest/src/main/kotlin/ratelimit/ExclusionRequestRateLimiter.kt +++ b/rest/src/main/kotlin/ratelimit/ExclusionRequestRateLimiter.kt @@ -3,13 +3,10 @@ package dev.kord.rest.ratelimit import dev.kord.rest.request.Request import dev.kord.rest.request.RequestIdentifier import dev.kord.rest.request.identifier -import kotlinx.coroutines.delay import kotlinx.coroutines.sync.Mutex import mu.KLogger import mu.KotlinLogging import java.time.Clock -import java.time.Duration -import java.time.Instant private val requestLogger = KotlinLogging.logger {} @@ -31,11 +28,15 @@ class ExclusionRequestRateLimiter(clock: Clock = Clock.systemUTC()) : AbstractRa } override fun newToken(request: Request<*, *>, buckets: List): RequestToken { - return ExclusionRequestToken(request.identifier, buckets) + return ExclusionRequestToken(this, request.identifier, buckets) } - private inner class ExclusionRequestToken(identity: RequestIdentifier, requestBuckets: List) : - AbstractRequestToken(identity, requestBuckets) { + private inner class ExclusionRequestToken( + rateLimiter: ExclusionRequestRateLimiter, + identity: RequestIdentifier, + requestBuckets: List + ) : + AbstractRequestToken(rateLimiter, identity, requestBuckets) { override suspend fun complete(response: RequestResponse) { super.complete(response) diff --git a/rest/src/main/kotlin/ratelimit/ParallelRequestRateLimiter.kt b/rest/src/main/kotlin/ratelimit/ParallelRequestRateLimiter.kt index 54e5f61b760..2774f82dfdc 100644 --- a/rest/src/main/kotlin/ratelimit/ParallelRequestRateLimiter.kt +++ b/rest/src/main/kotlin/ratelimit/ParallelRequestRateLimiter.kt @@ -30,9 +30,9 @@ class ParallelRequestRateLimiter(clock: Clock = Clock.systemUTC()) : AbstractRat get() = parallelLogger override fun newToken(request: Request<*, *>, buckets: List): RequestToken = - ParallelRequestToken(request.identifier, buckets) + ParallelRequestToken(this, request.identifier, buckets) - private inner class ParallelRequestToken(identity: RequestIdentifier, requestBuckets: List) : - AbstractRequestToken(identity, requestBuckets) + private inner class ParallelRequestToken(rateLimiter: ParallelRequestRateLimiter, identity: RequestIdentifier, requestBuckets: List) : + AbstractRequestToken(rateLimiter, identity, requestBuckets) } \ No newline at end of file diff --git a/rest/src/main/kotlin/ratelimit/RequestRateLimiter.kt b/rest/src/main/kotlin/ratelimit/RequestRateLimiter.kt index 106f5ee951a..9e7dc8a424f 100644 --- a/rest/src/main/kotlin/ratelimit/RequestRateLimiter.kt +++ b/rest/src/main/kotlin/ratelimit/RequestRateLimiter.kt @@ -54,26 +54,30 @@ data class RateLimit(val total: Total, val remaining: Remaining) { companion object } -inline class Total(val value: Long) { +@JvmInline +value class Total(val value: Long) { companion object } -inline class Remaining(val value: Long) { +@JvmInline +value class Remaining(val value: Long) { companion object } /** * The unique identifier of this bucket. */ -inline class BucketKey(val value: String) { +@JvmInline +value class BucketKey(val value: String) { companion object } /** * The [instant][value] when the current bucket gets reset. */ -inline class Reset(val value: Instant) { +@JvmInline +value class Reset(val value: Instant) { companion object } diff --git a/rest/src/test/kotlin/ratelimit/AbstractRequestRateLimiterTest.kt b/rest/src/test/kotlin/ratelimit/AbstractRequestRateLimiterTest.kt index 8d860a982ba..e6460977827 100644 --- a/rest/src/test/kotlin/ratelimit/AbstractRequestRateLimiterTest.kt +++ b/rest/src/test/kotlin/ratelimit/AbstractRequestRateLimiterTest.kt @@ -14,6 +14,7 @@ import java.time.ZoneOffset import kotlin.IllegalStateException import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.time.Duration import kotlin.time.ExperimentalTime import kotlin.time.seconds import kotlin.time.toJavaDuration @@ -24,7 +25,7 @@ abstract class AbstractRequestRateLimiterTest { abstract fun newRequestRateLimiter(clock: Clock) : RequestRateLimiter - private val timeout = 1000.seconds + private val timeout = Duration.seconds(1000) private val instant = Instant.EPOCH private val RateLimit.Companion.exhausted get() = RateLimit(Total(5), Remaining(0)) @@ -74,7 +75,7 @@ abstract class AbstractRequestRateLimiterTest { rateLimiter.sendRequest(clock, 1, rateLimit = RateLimit.exhausted) rateLimiter.sendRequest(clock, 1, rateLimit = RateLimit(Total(5), Remaining(5))) - assertEquals(timeout.inMilliseconds.toLong(), currentTime) + assertEquals(timeout.inWholeMilliseconds, currentTime) } @Test @@ -86,7 +87,7 @@ abstract class AbstractRequestRateLimiterTest { rateLimiter.sendRequest(clock, 2, 1 , rateLimit = RateLimit(Total(5), Remaining(5))) //discovery rateLimiter.sendRequest(clock, 2, 1, rateLimit = RateLimit(Total(5), Remaining(5))) - assertEquals(timeout.inMilliseconds.toLong(), currentTime) + assertEquals(timeout.inWholeMilliseconds, currentTime) } @Test