Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Ability to limit request size and connection count #221

Merged
merged 18 commits into from
Oct 8, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ let package = Package(
dependencies: []),
.target(
name: "KituraNet",
dependencies: ["NIO", "NIOFoundationCompat", "NIOHTTP1", "NIOSSL", "SSLService", "LoggerAPI", "NIOWebSocket", "CLinuxHelpers", "NIOExtras"]),
dependencies: ["NIO", "NIOFoundationCompat", "NIOHTTP1", "NIOSSL", "SSLService", "LoggerAPI", "NIOWebSocket", "CLinuxHelpers", "NIOConcurrencyHelpers", "NIOExtras"]),
.testTarget(
name: "KituraNetTests",
dependencies: ["KituraNet"])
Expand Down
5 changes: 3 additions & 2 deletions Sources/KituraNet/HTTP/HTTP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ public class HTTP {
let server = HTTP.createServer()
````
*/
public static func createServer() -> HTTPServer {
return HTTPServer()
public static func createServer(serverConfig: HTTPServerConfiguration = .default) -> HTTPServer {
let serverConfig = serverConfig
return HTTPServer(serverConfig: serverConfig)
}

/**
Expand Down
68 changes: 66 additions & 2 deletions Sources/KituraNet/HTTP/HTTPRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle
/// The HTTPServer instance on which this handler is installed
var server: HTTPServer

var requestSize: Int = 0

/// The serverRequest related to this handler instance
var serverRequest: HTTPServerRequest?

Expand Down Expand Up @@ -66,7 +68,6 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle
self.enableSSLVerification = true
}
}

public typealias InboundIn = HTTPServerRequestPart
public typealias OutboundOut = HTTPServerResponsePart

Expand All @@ -76,12 +77,32 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle
// If an upgrade to WebSocket fails, both `errorCaught` and `channelRead` are triggered.
// We'd want to return the error via `errorCaught`.
if errorResponseSent { return }

switch request {
case .head(let header):
serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification)
if let requestSizeLimit = server.serverConfig.requestSizeLimit,
let contentLength = header.headers["Content-Length"].first,
let contentLengthValue = Int(contentLength) {
if contentLengthValue > requestSizeLimit {
sendStatus(context: context)
context.close()
}
}
let headerSize = getHeaderSize(of: header)
if let requestSizeLimit = server.serverConfig.requestSizeLimit {
if headerSize > requestSizeLimit {
sendStatus(context: context)
}
}
serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification)
self.clientRequestedKeepAlive = header.isKeepAlive
case .body(var buffer):
requestSize += buffer.readableBytes
if let requestSizeLimit = server.serverConfig.requestSizeLimit {
if requestSize > requestSizeLimit {
sendStatus(context: context)
}
}
guard let serverRequest = serverRequest else {
Log.error("No ServerRequest available")
return
Expand All @@ -91,7 +112,23 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle
} else {
serverRequest.buffer!.byteBuffer.writeBuffer(&buffer)
}

case .end:
requestSize = 0
server.connectionCount.add(1)
if let connectionLimit = server.serverConfig.connectionLimit {
if server.connectionCount.load() > connectionLimit {
let statusCode = HTTPStatusCode.serviceUnavailable.rawValue
let statusDescription = HTTP.statusCodes[statusCode] ?? ""
do {
serverResponse = HTTPServerResponse(channel: context.channel, handler: self)
errorResponseSent = true
try serverResponse?.end(with: .serviceUnavailable, message: statusDescription)
} catch {
Log.error("Failed to send error response")
}
}
}
serverResponse = HTTPServerResponse(channel: context.channel, handler: self)
//Make sure we use the latest delegate registered with the server
DispatchQueue.global().async {
Expand Down Expand Up @@ -152,4 +189,31 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle
func updateKeepAliveState() {
keepAliveState.decrement()
}

func channelInactive(context: ChannelHandlerContext, httpServer: HTTPServer) {
httpServer.connectionCount.sub(1)
}

func getHeaderSize(of header: HTTPRequestHead) -> Int {
djones6 marked this conversation as resolved.
Show resolved Hide resolved
var headerSize = 0
headerSize += header.uri.cString(using: .utf8)?.count ?? 0
headerSize += header.version.description.cString(using: .utf8)?.count ?? 0
headerSize += header.method.rawValue.cString(using: .utf8)?.count ?? 0
for headers in header.headers {
headerSize += headers.name.cString(using: .utf8)?.count ?? 0
headerSize += headers.value.cString(using: .utf8)?.count ?? 0
}
return headerSize
}

func sendStatus(context: ChannelHandlerContext) {
let statusDescription = HTTP.statusCodes[HTTPStatusCode.requestTooLong.rawValue] ?? ""
do {
serverResponse = HTTPServerResponse(channel: context.channel, handler: self)
errorResponseSent = true
try serverResponse?.end(with: .requestTooLong, message: statusDescription)
} catch {
Log.error("Failed to send error response")
}
}
}
24 changes: 20 additions & 4 deletions Sources/KituraNet/HTTP/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import SSLService
import LoggerAPI
import NIOWebSocket
import CLinuxHelpers
import Foundation
import NIOExtras
import NIOConcurrencyHelpers

#if os(Linux)
import Glibc
Expand Down Expand Up @@ -127,22 +129,35 @@ public class HTTPServer: Server {

var quiescingHelper: ServerQuiescingHelper?

private var ctx: ChannelHandlerContext?
djones6 marked this conversation as resolved.
Show resolved Hide resolved

/// server configuration
public var serverConfig: HTTPServerConfiguration
djones6 marked this conversation as resolved.
Show resolved Hide resolved

//counter for no of connections
var connectionCount = Atomic(value: 0)

// The data to be written as a part of the response.
//private var buffer: ByteBuffer
djones6 marked this conversation as resolved.
Show resolved Hide resolved

/**
Creates an HTTP server object.

### Usage Example: ###
````swift
let server = HTTPServer()
let config =HTTPServerConfiguration(requestSize: 1000, coonectionLimit: 100)
let server = HTTPServer(serverconfig: config)
server.listen(on: 8080)
````
*/
public init() {
public init(serverConfig: HTTPServerConfiguration = .default) {
djones6 marked this conversation as resolved.
Show resolved Hide resolved
#if os(Linux)
let numberOfCores = Int(linux_sched_getaffinity())
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: numberOfCores > 0 ? numberOfCores : System.coreCount)
#else
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount)
#endif
self.serverConfig = serverConfig
}

/**
Expand Down Expand Up @@ -309,15 +324,16 @@ public class HTTPServer: Server {
}
.childChannelInitializer { channel in
let httpHandler = HTTPRequestHandler(for: self)
let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { _ in
let config: HTTPUpgradeConfiguration = (upgraders: upgraders, completionHandler: { ctx in
self.ctx = ctx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this new property ctx ever get read? I can't see reference to it elsewhere in this PR

_ = channel.pipeline.removeHandler(httpHandler)
})
return channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: config, withErrorHandling: true).flatMap {
if let nioSSLServerHandler = self.createNIOSSLServerHandler() {
_ = channel.pipeline.addHandler(nioSSLServerHandler, position: .first)
}
return channel.pipeline.addHandler(httpHandler)
}
}
djones6 marked this conversation as resolved.
Show resolved Hide resolved
}

let listenerDescription: String
Expand Down
43 changes: 43 additions & 0 deletions Sources/KituraNet/HTTP/HTTPServerConfiguration.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright IBM Corporation 2019
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import Foundation

public struct HTTPServerConfiguration {
djones6 marked this conversation as resolved.
Show resolved Hide resolved
/// Defines the maximum size of an incoming request, in bytes. If requests are received that are larger
/// than this limit, they will be rejected and the connection will be closed.

public let requestSizeLimit: Int?

/// Defines the maximum number of concurrent connections that a server should accept. Clients attempting
/// to connect when this limit has been reached will be rejected.
public let connectionLimit: Int?

public static var `default` = HTTPServerConfiguration(requestSizeLimit: 1024 * 1024, connectionLimit: 1024)


/// Create an `HTTPServerConfiguration` to determine the behaviour of a `Server`.
///
/// - parameter requestSizeLimit: The maximum size of an incoming request. Defaults to `IncomingSocketOptions.defaultRequestSizeLimit`.
/// - parameter connectionLimit: The maximum number of concurrent connections. Defaults to `IncomingSocketOptions.defaultConnectionLimit`.

public init(requestSizeLimit: Int?,connectionLimit: Int?)
{
self.requestSizeLimit = requestSizeLimit
self.connectionLimit = connectionLimit
}

}
1 change: 1 addition & 0 deletions Sources/KituraNet/HTTP/RequestsizeHandler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

djones6 marked this conversation as resolved.
Show resolved Hide resolved
28 changes: 28 additions & 0 deletions Tests/KituraNetTests/ClientE2ETests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ClientE2ETests: KituraNetTest {
("testQueryParameters", testQueryParameters),
("testRedirect", testRedirect),
("testPercentEncodedQuery", testPercentEncodedQuery),
("testRequestSize",testRequestSize),
]
}

Expand All @@ -52,6 +53,33 @@ class ClientE2ETests: KituraNetTest {

let delegate = TestServerDelegate()

func testRequestSize() {
performServerTest(serverConfig: HTTPServerConfiguration(requestSizeLimit: 10000, connectionLimit: 100),delegate, useSSL: false, asyncTasks: { expectation in
let payload = "[" + contentTypesString + "," + contentTypesString + contentTypesString + "," + contentTypesString + "]"
self.performRequest("post", path: "/largepost", callback: {response in
XCTAssertEqual(response?.statusCode, HTTPStatusCode.requestTooLong)
do {
let expectedResult = "Request Entity Too Large"
var data = Data()
let count = try response?.readAllData(into: &data)
XCTAssertEqual(count, expectedResult.count, "Result should have been \(expectedResult.count) bytes, was \(String(describing: count)) bytes")
let postValue = String(data: data, encoding: .utf8)
if let postValue = postValue {
print("postvalue:", postValue)
XCTAssertEqual(postValue, expectedResult)
} else {
XCTFail("postValue's value wasn't an UTF8 string")
}
} catch {
XCTFail("Failed reading the body of the response")
}
expectation.fulfill()
}) {request in
request.write(from: payload)
}
})
}

func testHeadRequests() {
performServerTest(delegate) { expectation in
self.performRequest("head", path: "/headtest", callback: {response in
Expand Down
93 changes: 93 additions & 0 deletions Tests/KituraNetTests/ConnectionLimitTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import Foundation
import Dispatch
import NIO
import XCTest
import KituraNet
import NIOHTTP1
import NIOWebSocket
import LoggerAPI

class ConnectionLimitTests: KituraNetTest {
static var allTests: [(String, (ConnectionLimitTests) -> () throws -> Void)] {
return [
("testConnectionLimit", testConnectionLimit),
]
}

override func setUp() {
doSetUp()
}

override func tearDown() {
doTearDown()
}
private func sendRequest(request: HTTPRequestHead, on channel: Channel) {
channel.write(NIOAny(HTTPClientRequestPart.head(request)), promise: nil)
try! channel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(nil))).wait()
}

func establishConnection(expectation: XCTestExpectation, responseHandler: HTTPResponseHandler) {
var channel: Channel
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
let bootstrap = ClientBootstrap(group: group)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelInitializer { channel in
channel.pipeline.addHTTPClientHandlers().flatMap {_ in
channel.pipeline.addHandler(responseHandler)
}
}
do {
try channel = bootstrap.connect(host: "localhost", port: self.port).wait()
let request = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: .GET, uri: "/")
self.sendRequest(request: request, on: channel)
} catch let e {
XCTFail("Connection is not established.")
}
}

func testConnectionLimit() {
djones6 marked this conversation as resolved.
Show resolved Hide resolved
let delegate = TestConnectionLimitDelegate()
performServerTest(serverConfig: HTTPServerConfiguration(requestSizeLimit: 10000, connectionLimit: 1), delegate, socketType: .tcp, useSSL: false, asyncTasks: { expectation in
let payload = "Hello, World!"
var payloadBuffer = ByteBufferAllocator().buffer(capacity: 1024)
payloadBuffer.writeString(payload)
_ = self.establishConnection(expectation: expectation, responseHandler: HTTPResponseHandler(expectedStatus:HTTPResponseStatus.ok, expectation: expectation))
}, { expectation in
let payload = "Hello, World!"
var payloadBuffer = ByteBufferAllocator().buffer(capacity: 1024)
payloadBuffer.writeString(payload)
_ = self.establishConnection(expectation: expectation, responseHandler: HTTPResponseHandler(expectedStatus:HTTPResponseStatus.serviceUnavailable, expectation: expectation))
})
}
}

class TestConnectionLimitDelegate: ServerDelegate {
func handle(request: ServerRequest, response: ServerResponse) {
do {
try response.end()
} catch {
XCTFail("Error while writing response")
}
}
}

class HTTPResponseHandler: ChannelInboundHandler {
let expectedStatus: HTTPResponseStatus
let expectation: XCTestExpectation
init(expectedStatus: HTTPResponseStatus, expectation: XCTestExpectation) {
self.expectedStatus = expectedStatus
self.expectation = expectation
}
typealias InboundIn = HTTPClientResponsePart
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = self.unwrapInboundIn(data)
switch response {
case .head(let header):
let status = header.status
XCTAssertEqual(status, expectedStatus)
expectation.fulfill()
default: do {
}
}
}
}
Loading