Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove a race condition from the WebSocket upgrade #217

Merged
merged 1 commit into from
Jul 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Remove a race condition from the WebSocket upgrade
An upgrade to WebSocket invokes three handlers in an order. First, the
`shouldUpgrade` handler supplied by the user is run to decide if an upgrade
must go through. This handler also supplies additional headers to be sent in
the response. Next, after the WebSocket upgrader is done upgrading the
pipeline, the `completionHandler` is called. We remove Kitura-NIO's
`HTTPRequestHandler` here. Next, the `upgradePipelineHandler` is invoked. This
handler allows us to add all the WebSocket related `ChannelHandler`s.

For an undocumented reason, we saved the `ChannelHandlerContext` received by
the `completionHandler` in the `HTTPServer` and later used it upgrade the
pipeline in `upgradePipelineHandler`. This can easily lead to a race condition
where we saved the `ChannelHandlerContext` for a connection, into the
`HTTPServer` but before it could be used in `upgradePipelineHandler`, it was
overwritten by the upgrade happening on another connection. Consequently, we
never upgraded the pipeline of the former connection. This could lead to
different kinds of failures.

The `upgradePipelineHandler` has `Channel` as one of its parameters. Hence
there is no need to store the `ChannelHandlerContext` for use in this closure.
Consequently, we have to use a `Channel` to initialize an `HTTPServerRequest`,
which, in turn, is modified to accept a `Channel` instead of a
`ChannelHandlerContext`.
  • Loading branch information
Pushkar Kulkarni committed Jul 22, 2019
commit 09b88f886f250360408f65bad108789bd910802d
2 changes: 1 addition & 1 deletion Sources/KituraNet/HTTP/HTTPRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle

switch request {
case .head(let header):
serverRequest = HTTPServerRequest(ctx: context, requestHead: header, enableSSL: enableSSLVerification)
serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification)
self.clientRequestedKeepAlive = header.isKeepAlive
case .body(var buffer):
guard let serverRequest = serverRequest else {
Expand Down
20 changes: 8 additions & 12 deletions Sources/KituraNet/HTTP/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ public class HTTPServer: Server {
/// The event loop group on which the HTTP handler runs
private let eventLoopGroup: MultiThreadedEventLoopGroup

private var ctx: ChannelHandlerContext?

/**
Creates an HTTP server object.

Expand Down Expand Up @@ -193,19 +191,18 @@ public class HTTPServer: Server {
}

/// Creates upgrade request and adds WebSocket handler to pipeline
private func upgradeHandler(webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture<Void> {
guard let ctx = self.ctx else { fatalError("The channel was probably closed during a protocol upgrade.") }
return ctx.eventLoop.submit {
let request = HTTPServerRequest(ctx: ctx, requestHead: request, enableSSL: false)
private func upgradeHandler(channel: Channel, webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture<Void> {
return channel.eventLoop.submit {
let request = HTTPServerRequest(channel: channel, requestHead: request, enableSSL: false)
return webSocketHandlerFactory.handler(for: request)
}.flatMap { (handler: ChannelHandler) -> EventLoopFuture<Void> in
return ctx.channel.pipeline.addHandler(handler).flatMap {
return channel.pipeline.addHandler(handler).flatMap {
if let _extensions = request.headers["Sec-WebSocket-Extensions"].first {
let handlers = webSocketHandlerFactory.extensionHandlers(header: _extensions)
return ctx.channel.pipeline.addHandlers(handlers, position: .before(handler))
return channel.pipeline.addHandlers(handlers, position: .before(handler))
} else {
// No extensions. We must return success.
return ctx.channel.eventLoop.makeSucceededFuture(())
return channel.eventLoop.makeSucceededFuture(())
}
}
}
Expand All @@ -222,7 +219,7 @@ public class HTTPServer: Server {

private func generateUpgradePipelineHandler(_ webSocketHandlerFactory: ProtocolHandlerFactory) -> UpgradePipelineHandlerFunction {
return { (channel: Channel, request: HTTPRequestHead) in
return self.upgradeHandler(webSocketHandlerFactory: webSocketHandlerFactory, request: request)
return self.upgradeHandler(channel: channel, webSocketHandlerFactory: webSocketHandlerFactory, request: request)
}
}

Expand Down Expand Up @@ -304,8 +301,7 @@ public class HTTPServer: Server {
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: allowPortReuse ? 1 : 0)
.childChannelInitializer { channel in
let httpHandler = HTTPRequestHandler(for: self)
let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { ctx in
self.ctx = ctx
let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { _ in
_ = channel.pipeline.removeHandler(httpHandler)
})
return channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: config, withErrorHandling: true).flatMap {
Expand Down
14 changes: 7 additions & 7 deletions Sources/KituraNet/HTTP/HTTPServerRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public class HTTPServerRequest: ServerRequest {
*/
public var method: String

private let ctx: ChannelHandlerContext
private let channel: Channel

private var enableSSL: Bool = false

Expand Down Expand Up @@ -208,20 +208,20 @@ public class HTTPServerRequest: ServerRequest {
}
}

init(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead, enableSSL: Bool) {
init(channel: Channel, requestHead: HTTPRequestHead, enableSSL: Bool) {
// An HTTPServerRequest may be created only on the EventLoop assigned to handle
// the connection on which the HTTP request arrived.
assert(ctx.eventLoop.inEventLoop)
self.ctx = ctx
assert(channel.eventLoop.inEventLoop)
self.channel = channel
self.headers = HeadersContainer(with: requestHead.headers)
self.method = requestHead.method.rawValue
self.httpVersionMajor = UInt16(requestHead.version.major)
self.httpVersionMinor = UInt16(requestHead.version.minor)
self.rawURLString = requestHead.uri
self.enableSSL = enableSSL
self.localAddressHost = HTTPServerRequest.host(socketAddress: ctx.localAddress)
self.localAddressPort = ctx.localAddress?.port ?? 0
self.remoteAddress = HTTPServerRequest.host(socketAddress: ctx.remoteAddress)
self.localAddressHost = HTTPServerRequest.host(socketAddress: channel.localAddress)
self.localAddressPort = channel.localAddress?.port ?? 0
self.remoteAddress = HTTPServerRequest.host(socketAddress: channel.remoteAddress)
}

var buffer: BufferList?
Expand Down