diff --git a/Sources/KituraNet/HTTP/HTTPServerResponse.swift b/Sources/KituraNet/HTTP/HTTPServerResponse.swift index a77543f7..a2ec4e3b 100644 --- a/Sources/KituraNet/HTTP/HTTPServerResponse.swift +++ b/Sources/KituraNet/HTTP/HTTPServerResponse.swift @@ -52,11 +52,15 @@ public class HTTPServerResponse: ServerResponse { private var httpVersion: HTTPVersion /// The data to be written as a part of the response. - private var buffer: ByteBuffer? + private var buffer: ByteBuffer + + /// Initial size of the response buffer (inherited from KituraNet) + private static let bufferSize = 2000 init(channel: Channel, handler: HTTPRequestHandler) { self.channel = channel self.handler = handler + self.buffer = channel.allocator.buffer(capacity: HTTPServerResponse.bufferSize) let httpVersionMajor = handler.serverRequest?.httpVersionMajor ?? 0 let httpVersionMinor = handler.serverRequest?.httpVersionMinor ?? 0 self.httpVersion = HTTPVersion(major: httpVersionMajor, minor: httpVersionMinor) @@ -71,11 +75,8 @@ public class HTTPServerResponse: ServerResponse { fatalError("No channel available to write.") } - if buffer == nil { - execute(on: channel.eventLoop) { - self.buffer = channel.allocator.buffer(capacity: string.utf8.count) - self.buffer!.write(string: string) - } + execute(on: channel.eventLoop) { + self.buffer.write(string: string) } } @@ -87,11 +88,8 @@ public class HTTPServerResponse: ServerResponse { fatalError("No channel available to write.") } - if buffer == nil { - execute(on: channel.eventLoop) { - self.buffer = channel.allocator.buffer(capacity: data.count) - self.buffer!.write(bytes: data) - } + execute(on: channel.eventLoop) { + self.buffer.write(bytes: data) } } @@ -146,7 +144,7 @@ public class HTTPServerResponse: ServerResponse { } let response = HTTPResponseHead(version: httpVersion, status: status, headers: headers.httpHeaders()) channel.write(handler.wrapOutboundOut(.head(response)), promise: nil) - if let buffer = buffer { + if buffer.readableBytes > 0 { channel.write(handler.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) } channel.writeAndFlush(handler.wrapOutboundOut(.end(nil)), promise: nil) @@ -184,8 +182,8 @@ public class HTTPServerResponse: ServerResponse { headers["Connection"] = ["Close"] let response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: status, headers: headers.httpHeaders()) channel.write(handler.wrapOutboundOut(.head(response)), promise: nil) - if withBody && buffer != nil { - channel.write(handler.wrapOutboundOut(.body(.byteBuffer(buffer!))), promise: nil) + if withBody && buffer.readableBytes > 0 { + channel.write(handler.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) } channel.writeAndFlush(handler.wrapOutboundOut(.end(nil)), promise: nil) handler.updateKeepAliveState() @@ -205,9 +203,7 @@ public class HTTPServerResponse: ServerResponse { /// Reset this response object back to its initial state public func reset() { status = HTTPStatusCode.OK.rawValue - if buffer != nil { - buffer!.clear() - } + buffer.clear() headers.removeAll() headers["Date"] = [SPIUtils.httpDate()] } diff --git a/Tests/KituraNetTests/HTTPResponseTests.swift b/Tests/KituraNetTests/HTTPResponseTests.swift index e8d9da0d..adc09b1c 100644 --- a/Tests/KituraNetTests/HTTPResponseTests.swift +++ b/Tests/KituraNetTests/HTTPResponseTests.swift @@ -22,7 +22,9 @@ import XCTest class HTTPResponseTests: KituraNetTest { static var allTests : [(String, (HTTPResponseTests) -> () throws -> Void)] { return [ - ("testContentTypeHeaders", testContentTypeHeaders) + ("testContentTypeHeaders", testContentTypeHeaders), + ("testHeadersContainerHTTPHeaders", testHeadersContainerHTTPHeaders), + ("testMultipleWritesToResponse", testMultipleWritesToResponse), ] } @@ -81,4 +83,38 @@ class HTTPResponseTests: KituraNetTest { headers.removeAll() XCTAssertFalse(headers.httpHeaders().contains(name: "foo")) } + + func testMultipleWritesToResponse() { + performServerTest(WriteTwiceServerDelegate(), useSSL: false) { expectation in + self.performRequest("get", path: "/writetwice", callback: { response in + XCTAssertEqual(response?.statusCode, HTTPStatusCode.OK, "Status code wasn't .Ok was \(String(describing: response?.statusCode))") + do { + var data = Data() + _ = try response?.readAllData(into: &data) + let receivedString = String(data: data as Data, encoding: .utf8) ?? "" + XCTAssertEqual("Hello, World!", receivedString, "The string received \(receivedString) is not Hello, World!") + } catch { + XCTFail("Error: \(error)") + } + expectation.fulfill() + }) + } + } +} + +class WriteTwiceServerDelegate: ServerDelegate { + func handle(request: ServerRequest, response: ServerResponse) { + do { + response.statusCode = .OK + response.headers["Content-Type"] = ["text/plain"] + let helloData = "Hello, ".data(using: .utf8)! + let worldData = "World!".data(using: .utf8)! + response.headers["Content-Length"] = ["13"] + try response.write(from: helloData) + try response.write(from: worldData) + try response.end() + } catch { + print("Could not send a response") + } + } }