Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use gateway url provided in Ready event for resuming #666

Merged
merged 9 commits into from
Oct 4, 2022
5 changes: 3 additions & 2 deletions core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -11568,14 +11568,15 @@ public abstract class dev/kord/core/event/gateway/GatewayEvent : dev/kord/core/e
}

public final class dev/kord/core/event/gateway/ReadyEvent : dev/kord/core/event/gateway/GatewayEvent, dev/kord/core/entity/Strategizable {
public fun <init> (ILjava/util/Set;Ldev/kord/core/entity/User;Ljava/lang/String;Ldev/kord/core/Kord;ILjava/lang/Object;Ldev/kord/core/supplier/EntitySupplier;)V
public synthetic fun <init> (ILjava/util/Set;Ldev/kord/core/entity/User;Ljava/lang/String;Ldev/kord/core/Kord;ILjava/lang/Object;Ldev/kord/core/supplier/EntitySupplier;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (ILjava/util/Set;Ldev/kord/core/entity/User;Ljava/lang/String;Ljava/lang/String;Ldev/kord/core/Kord;ILjava/lang/Object;Ldev/kord/core/supplier/EntitySupplier;)V
public synthetic fun <init> (ILjava/util/Set;Ldev/kord/core/entity/User;Ljava/lang/String;Ljava/lang/String;Ldev/kord/core/Kord;ILjava/lang/Object;Ldev/kord/core/supplier/EntitySupplier;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun getCustomContext ()Ljava/lang/Object;
public final fun getGatewayVersion ()I
public final fun getGuildIds ()Ljava/util/Set;
public final fun getGuilds ()Ljava/util/Set;
public final fun getGuilds (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun getKord ()Ldev/kord/core/Kord;
public final fun getResumeGatewayUrl ()Ljava/lang/String;
public final fun getSelf ()Ldev/kord/core/entity/User;
public final fun getSessionId ()Ljava/lang/String;
public fun getShard ()I
Expand Down
8 changes: 4 additions & 4 deletions core/src/main/kotlin/event/gateway/Events.kt
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ public class ReadyEvent(
public val guildIds: Set<Snowflake>,
public val self: User,
public val sessionId: String,
public val resumeGatewayUrl: String,
override val kord: Kord,
override val shard: Int,
override val customContext: Any?,
Expand All @@ -153,11 +154,10 @@ public class ReadyEvent(
public suspend fun getGuilds(): Flow<Guild> = supplier.guilds.filter { it.id in guildIds }

override fun withStrategy(strategy: EntitySupplyStrategy<*>): ReadyEvent =
ReadyEvent(gatewayVersion, guildIds, self, sessionId, kord, shard, customContext, strategy.supply(kord))
ReadyEvent(gatewayVersion, guildIds, self, sessionId, resumeGatewayUrl, kord, shard, customContext, strategy.supply(kord))

override fun toString(): String {
return "ReadyEvent(gatewayVersion=$gatewayVersion, guildIds=$guildIds, self=$self, sessionId='$sessionId', kord=$kord, shard=$shard, supplier=$supplier)"
}
override fun toString(): String = "ReadyEvent(gatewayVersion=$gatewayVersion, guildIds=$guildIds, self=$self, " +
"sessionId='$sessionId', resumeGatewayUrl=$resumeGatewayUrl, kord=$kord, shard=$shard, supplier=$supplier)"
}

public class ResumedEvent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ internal class LifeCycleEventHandler : BaseGatewayEventHandler() {
guilds,
User(self, kord),
sessionId,
resumeGatewayUrl,
kord,
shard,
context?.get(),
Expand Down
18 changes: 10 additions & 8 deletions gateway/api/gateway.api
Original file line number Diff line number Diff line change
Expand Up @@ -1506,27 +1506,29 @@ public final class dev/kord/gateway/Ready : dev/kord/gateway/DispatchEvent {

public final class dev/kord/gateway/ReadyData {
public static final field Companion Ldev/kord/gateway/ReadyData$Companion;
public synthetic fun <init> (IILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;Lkotlinx/serialization/internal/SerializationConstructorMarker;)V
public fun <init> (ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;)V
public synthetic fun <init> (ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (IILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;Lkotlinx/serialization/internal/SerializationConstructorMarker;)V
public fun <init> (ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;)V
public synthetic fun <init> (ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()I
public final fun component10 ()Ldev/kord/common/entity/optional/Optional;
public final fun component10 ()Ljava/util/List;
public final fun component11 ()Ldev/kord/common/entity/optional/Optional;
public final fun component2 ()Ldev/kord/common/entity/DiscordUser;
public final fun component3 ()Ljava/util/List;
public final fun component4 ()Ljava/util/List;
public final fun component5 ()Ljava/lang/String;
public final fun component6 ()Ldev/kord/common/entity/optional/Optional;
public final fun component6 ()Ljava/lang/String;
public final fun component7 ()Ldev/kord/common/entity/optional/Optional;
public final fun component8 ()Ldev/kord/common/entity/optional/Optional;
public final fun component9 ()Ljava/util/List;
public final fun copy (ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;)Ldev/kord/gateway/ReadyData;
public static synthetic fun copy$default (Ldev/kord/gateway/ReadyData;ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;ILjava/lang/Object;)Ldev/kord/gateway/ReadyData;
public final fun component9 ()Ldev/kord/common/entity/optional/Optional;
public final fun copy (ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;)Ldev/kord/gateway/ReadyData;
public static synthetic fun copy$default (Ldev/kord/gateway/ReadyData;ILdev/kord/common/entity/DiscordUser;Ljava/util/List;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ldev/kord/common/entity/optional/Optional;Ljava/util/List;Ldev/kord/common/entity/optional/Optional;ILjava/lang/Object;)Ldev/kord/gateway/ReadyData;
public fun equals (Ljava/lang/Object;)Z
public final fun getApplication ()Ldev/kord/common/entity/optional/Optional;
public final fun getGeoOrderedRtcRegions ()Ldev/kord/common/entity/optional/Optional;
public final fun getGuildHashes ()Ldev/kord/common/entity/optional/Optional;
public final fun getGuilds ()Ljava/util/List;
public final fun getPrivateChannels ()Ljava/util/List;
public final fun getResumeGatewayUrl ()Ljava/lang/String;
public final fun getSessionId ()Ljava/lang/String;
public final fun getShard ()Ldev/kord/common/entity/optional/Optional;
public final fun getTraces ()Ljava/util/List;
Expand Down
22 changes: 11 additions & 11 deletions gateway/src/main/kotlin/DefaultGateway.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import dev.kord.gateway.GatewayCloseCode.*
import dev.kord.gateway.handler.*
import dev.kord.gateway.retry.Retry
import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.http.*
Expand Down Expand Up @@ -65,7 +64,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

override val coroutineContext: CoroutineContext = SupervisorJob() + data.dispatcher

private val compression: Boolean = URLBuilder(data.url).parameters.contains("compress", "zlib-stream")
private val compression: Boolean

private val _ping = MutableStateFlow<Duration?>(null)
override val ping: StateFlow<Duration?> get() = _ping
Expand All @@ -88,9 +87,13 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
private val stateMutex = Mutex()

init {
val initialUrl = Url(data.url)
compression = initialUrl.parameters.contains("compress", "zlib-stream")

val sequence = Sequence()
SequenceHandler(events, sequence)
handshakeHandler = HandshakeHandler(events, ::trySend, sequence, data.identifyRateLimiter, data.reconnectRetry)
handshakeHandler =
HandshakeHandler(events, initialUrl, ::trySend, sequence, data.identifyRateLimiter, data.reconnectRetry)
HeartbeatHandler(events, ::trySend, { restart(Close.ZombieConnection) }, { _ping.value = it }, sequence)
ReconnectHandler(events) { restart(Close.Reconnecting) }
InvalidSessionHandler(events) { restart(it) }
Expand All @@ -102,7 +105,9 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {

while (data.reconnectRetry.hasNext && state.value is State.Running) {
try {
socket = webSocket(data.url)
val url = handshakeHandler.gatewayUrl
defaultGatewayLogger.trace { "opening gateway connection to $url" }
socket = data.client.webSocketSession { url(url) }
/**
* https://discord.com/developers/docs/topics/gateway#transport-compression
*
Expand Down Expand Up @@ -146,8 +151,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}

private suspend fun resetState(configuration: GatewayConfiguration) = stateMutex.withLock {
@Suppress("UNUSED_VARIABLE")
val exhaustive = when (state.value) { //exhaustive state checking
Comment on lines -149 to -150
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-exhaustive when statements over sealed types are an error now, so this is no longer needed

when (state.value) {
is State.Running -> throw IllegalStateException(gatewayRunningError)
State.Detached -> throw IllegalStateException(gatewayDetachedError)
State.Stopped -> Unit
Expand Down Expand Up @@ -209,7 +213,7 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
when {
!discordReason.retry -> {
state.update { State.Stopped }
throw IllegalStateException("Gateway closed: ${reason.code} ${reason.message}")
throw IllegalStateException("Gateway closed: ${reason.code} ${reason.message}")
}
discordReason.resetSession -> {
setStopped()
Expand All @@ -230,10 +234,6 @@ public class DefaultGateway(private val data: DefaultGatewayData) : Gateway {
}
}

private suspend fun webSocket(url: String) = data.client.webSocketSession {
url(url)
}

override suspend fun stop() {
check(state.value !is State.Detached) { "The resources of this gateway are detached, create another one" }
data.eventFlow.emit(Close.UserClose)
Expand Down
2 changes: 2 additions & 0 deletions gateway/src/main/kotlin/Event.kt
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,8 @@ public data class ReadyData(
val guilds: List<DiscordUnavailableGuild>,
@SerialName("session_id")
val sessionId: String,
@SerialName("resume_gateway_url")
val resumeGatewayUrl: String,
@SerialName("geo_ordered_rtc_regions")
val geoOrderedRtcRegions: Optional<JsonElement?> = Optional.Missing(),
@SerialName("guild_hashes")
Expand Down
40 changes: 23 additions & 17 deletions gateway/src/main/kotlin/handler/HandshakeHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import dev.kord.common.ratelimit.RateLimiter
import dev.kord.common.ratelimit.consume
import dev.kord.gateway.*
import dev.kord.gateway.retry.Retry
import kotlinx.atomicfu.AtomicRef
import io.ktor.http.*
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.update
import kotlinx.coroutines.flow.Flow

internal class HandshakeHandler(
flow: Flow<Event>,
private val initialUrl: Url,
private val send: suspend (Command) -> Unit,
private val sequence: Sequence,
private val identifyRateLimiter: RateLimiter,
Expand All @@ -19,31 +19,37 @@ internal class HandshakeHandler(

lateinit var configuration: GatewayConfiguration

private val session: AtomicRef<String?> = atomic(null)
// see https://discord.com/developers/docs/topics/gateway#resuming
private class ResumeContext(val sessionId: String, val resumeUrl: Url)

private val identify
get() = configuration.identify
private val resumeContext = atomic<ResumeContext?>(initial = null)
val gatewayUrl get() = resumeContext.value?.resumeUrl ?: initialUrl

private val resume
get() = Resume(configuration.token, session.value!!, sequence.value ?: 0)

private val sessionStart get() = session.value == null

override fun start() {
on<Ready> { event ->
session.update { event.data.sessionId }
private val resumeOrIdentify
get() = when (val sessionId = resumeContext.value?.sessionId) {
null -> configuration.identify
else -> Resume(configuration.token, sessionId, sequence.value ?: 0)
Comment on lines +28 to +31
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resumeOrIdentify combines the previous properties sessionStart, resume and identify so that the session id is only queried once

}

override fun start() {
on<Hello> {
reconnectRetry.reset() //connected and read without problems, resetting retry counter
reconnectRetry.reset() // connected and read without problems, resetting retry counter
identifyRateLimiter.consume {
if (sessionStart) send(identify)
else send(resume)
send(resumeOrIdentify)
}
lukellmann marked this conversation as resolved.
Show resolved Hide resolved
}

on<Ready> { event ->
// keep custom query params
val resumeUrl = URLBuilder(event.data.resumeGatewayUrl)
.apply { parameters.appendMissing(initialUrl.parameters) }
.build()

resumeContext.value = ResumeContext(event.data.sessionId, resumeUrl)
}

on<Close.SessionReset> {
session.update { null }
resumeContext.value = null
}
lukellmann marked this conversation as resolved.
Show resolved Hide resolved
}
}
1 change: 1 addition & 0 deletions gateway/src/test/kotlin/json/SerializationTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class SerializationTest {
}
privateChannels shouldBe listOf()
sessionId shouldBe "12345"
resumeGatewayUrl shouldBe "wss://us-east1-b.gateway.discord.gg"
with(shard.value!!) {
index.shouldBe(0)
count.shouldBe(5)
Expand Down
1 change: 1 addition & 0 deletions gateway/src/test/resources/json/event/ready.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
}
],
"session_id": "12345",
"resume_gateway_url": "wss://us-east1-b.gateway.discord.gg",
"_trace": [
"test"
],
Expand Down