Skip to content

Commit

Permalink
Merge pull request #339 from junelife/compression
Browse files Browse the repository at this point in the history
Support websocket compression
  • Loading branch information
daltoniam authored Jun 24, 2017
2 parents 02718e7 + 44fdfa3 commit 7bf478c
Show file tree
Hide file tree
Showing 8 changed files with 453 additions and 28 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ It's Objective-C counter part can be found here: [Jetfire](https://github.com/ac
- Conforms to all of the base [Autobahn test suite](http://autobahn.ws/testsuite/).
- Nonblocking. Everything happens in the background, thanks to GCD.
- TLS/WSS support.
- Compression Extensions support ([RFC 7692](https://tools.ietf.org/html/rfc7692))
- Simple concise codebase at just a few hundred LOC.

## Example
Expand Down Expand Up @@ -197,6 +198,17 @@ socket.security = SSLSecurity(certs: [SSLCert(data: data)], usePublicKeys: true)
```
You load either a `Data` blob of your certificate or you can use a `SecKeyRef` if you have a public key you want to use. The `usePublicKeys` bool is whether to use the certificates for validation or the public keys. The public keys will be extracted from the certificates automatically if `usePublicKeys` is choosen.

### Compression Extensions

Compression Extensions ([RFC 7692](https://tools.ietf.org/html/rfc7692)) is supported in Starscream. Compression is enabled by default, however compression will only be used if it is supported by the server as well. You may enable or disable compression via the `.enableCompression` property:

```swift
socket = WebSocket(url: URL(string: "ws://localhost:8080/")!)
socket.enableCompression = false
```

Compression should be disabled if your application is transmitting already-compressed, random, or other uncompressable data.

### Custom Queue

A custom queue can be specified when delegate methods are called. By default `DispatchQueue.main` is used, thus making all delegate methods calls run on the main thread. It is important to note that all WebSocket processing is done on a background thread, only the delegate method calls are changed when modifying the queue. The actual processing is always on a background thread and will not pause your app.
Expand Down
177 changes: 177 additions & 0 deletions Source/Compression.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// Compression.swift
//
// Created by Joseph Ross on 7/16/14.
// Copyright © 2017 Joseph Ross.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//////////////////////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////////////////////////////////////////////////////////////////
//
// Compression implementation is implemented in conformance with RFC 7692 Compression Extensions
// for WebSocket: https://tools.ietf.org/html/rfc7692
//
//////////////////////////////////////////////////////////////////////////////////////////////////

import Foundation
import zlib

class Decompressor {
private var strm = z_stream()
private var buffer = [UInt8](repeating: 0, count: 0x2000)
private var inflateInitialized = false
private let windowBits:Int

init?(windowBits:Int) {
self.windowBits = windowBits
guard initInflate() else { return nil }
}

private func initInflate() -> Bool {
if Z_OK == inflateInit2_(&strm, -CInt(windowBits),
ZLIB_VERSION, CInt(MemoryLayout<z_stream>.size))
{
inflateInitialized = true
return true
}
return false
}

func reset() throws {
teardownInflate()
guard initInflate() else { throw NSError() }
}

func decompress(_ data: Data, finish: Bool) throws -> Data {
return try data.withUnsafeBytes { (bytes:UnsafePointer<UInt8>) -> Data in
return try decompress(bytes: bytes, count: data.count, finish: finish)
}
}

func decompress(bytes: UnsafePointer<UInt8>, count: Int, finish: Bool) throws -> Data {
var decompressed = Data()
try decompress(bytes: bytes, count: count, out: &decompressed)

if finish {
let tail:[UInt8] = [0x00, 0x00, 0xFF, 0xFF]
try decompress(bytes: tail, count: tail.count, out: &decompressed)
}

return decompressed

}

private func decompress(bytes: UnsafePointer<UInt8>, count: Int, out:inout Data) throws {
var res:CInt = 0
strm.next_in = UnsafeMutablePointer<UInt8>(mutating: bytes)
strm.avail_in = CUnsignedInt(count)

repeat {
strm.next_out = UnsafeMutablePointer<UInt8>(&buffer)
strm.avail_out = CUnsignedInt(buffer.count)

res = inflate(&strm, 0)

let byteCount = buffer.count - Int(strm.avail_out)
out.append(buffer, count: byteCount)
} while res == Z_OK && strm.avail_out == 0

guard (res == Z_OK && strm.avail_out > 0)
|| (res == Z_BUF_ERROR && Int(strm.avail_out) == buffer.count)
else {
throw NSError(domain: WebSocket.ErrorDomain, code: Int(WebSocket.InternalErrorCode.compressionError.rawValue), userInfo: nil)
}
}

private func teardownInflate() {
if inflateInitialized, Z_OK == inflateEnd(&strm) {
inflateInitialized = false
}
}

deinit {
teardownInflate()
}
}

class Compressor {
private var strm = z_stream()
private var buffer = [UInt8](repeating: 0, count: 0x2000)
private var deflateInitialized = false
private let windowBits:Int

init?(windowBits: Int) {
self.windowBits = windowBits
guard initDeflate() else { return nil }
}

private func initDeflate() -> Bool {
if Z_OK == deflateInit2_(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED,
-CInt(windowBits), 8, Z_DEFAULT_STRATEGY,
ZLIB_VERSION, CInt(MemoryLayout<z_stream>.size))
{
deflateInitialized = true
return true
}
return false
}

func reset() throws {
teardownDeflate()
guard initDeflate() else { throw NSError() }
}

func compress(_ data: Data) throws -> Data {
var compressed = Data()
var res:CInt = 0
data.withUnsafeBytes { (ptr:UnsafePointer<UInt8>) -> Void in
strm.next_in = UnsafeMutablePointer<UInt8>(mutating: ptr)
strm.avail_in = CUnsignedInt(data.count)

repeat {
strm.next_out = UnsafeMutablePointer<UInt8>(&buffer)
strm.avail_out = CUnsignedInt(buffer.count)

res = deflate(&strm, Z_SYNC_FLUSH)

let byteCount = buffer.count - Int(strm.avail_out)
compressed.append(buffer, count: byteCount)
}
while res == Z_OK && strm.avail_out == 0

}

guard res == Z_OK && strm.avail_out > 0
|| (res == Z_BUF_ERROR && Int(strm.avail_out) == buffer.count)
else {
throw NSError(domain: WebSocket.ErrorDomain, code: Int(WebSocket.InternalErrorCode.compressionError.rawValue), userInfo: nil)
}

compressed.removeLast(4)
return compressed
}

private func teardownDeflate() {
if deflateInitialized, Z_OK == deflateEnd(&strm) {
deflateInitialized = false
}
}

deinit {
teardownDeflate()
}
}

93 changes: 90 additions & 3 deletions Source/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ open class WebSocket : NSObject, StreamDelegate {
enum InternalErrorCode: UInt16 {
// 0-999 WebSocket status codes not used
case outputStreamWriteError = 1
case compressionError = 2
}

// Where the callback is executed. It defaults to the main UI thread queue.
Expand All @@ -86,13 +87,15 @@ open class WebSocket : NSObject, StreamDelegate {
let headerWSProtocolName = "Sec-WebSocket-Protocol"
let headerWSVersionName = "Sec-WebSocket-Version"
let headerWSVersionValue = "13"
let headerWSExtensionName = "Sec-WebSocket-Extensions"
let headerWSKeyName = "Sec-WebSocket-Key"
let headerOriginName = "Origin"
let headerWSAcceptName = "Sec-WebSocket-Accept"
let BUFFER_MAX = 4096
let FinMask: UInt8 = 0x80
let OpCodeMask: UInt8 = 0x0F
let RSVMask: UInt8 = 0x70
let RSV1Mask: UInt8 = 0x40
let MaskMask: UInt8 = 0x80
let PayloadLenMask: UInt8 = 0x7F
let MaxFrameSize: Int = 32
Expand Down Expand Up @@ -128,6 +131,7 @@ open class WebSocket : NSObject, StreamDelegate {
public var headers = [String: String]()
public var voipEnabled = false
public var disableSSLCertValidation = false
public var enableCompression = true
public var security: SSLTrustValidator?
public var enabledSSLCipherSuites: [SSLCipherSuite]?
public var origin: String?
Expand All @@ -139,12 +143,24 @@ open class WebSocket : NSObject, StreamDelegate {
public var currentURL: URL { return url }

// MARK: - Private

private struct CompressionState {
var supportsCompression = false
var messageNeedsDecompression = false
var serverMaxWindowBits = 15
var clientMaxWindowBits = 15
var clientNoContextTakeover = false
var serverNoContextTakeover = false
var decompressor:Decompressor? = nil
var compressor:Compressor? = nil
}

private var url: URL
private var inputStream: InputStream?
private var outputStream: OutputStream?
private var connected = false
private var isConnecting = false
private var compressionState = CompressionState()
private var writeQueue = OperationQueue()
private var readStack = [WSResponse]()
private var inputQueue = [Data]()
Expand Down Expand Up @@ -279,6 +295,10 @@ open class WebSocket : NSObject, StreamDelegate {
if let origin = origin {
addHeader(urlRequest, key: headerOriginName, val: origin)
}
if enableCompression {
let val = "permessage-deflate; client_max_window_bits; server_max_window_bits=15"
addHeader(urlRequest, key: headerWSExtensionName, val: val)
}
addHeader(urlRequest, key: headerWSHostName, val: "\(url.host!):\(port!)")
for (key, value) in headers {
addHeader(urlRequest, key: key, val: value)
Expand Down Expand Up @@ -577,6 +597,10 @@ open class WebSocket : NSObject, StreamDelegate {
}
if let cfHeaders = CFHTTPMessageCopyAllHeaderFields(response) {
let headers = cfHeaders.takeRetainedValue() as NSDictionary
if let extensionHeader = headers[headerWSExtensionName as NSString] as? String {
processExtensionHeader(extensionHeader)
}

if let acceptKey = headers[headerWSAcceptName as NSString] as? NSString {
if acceptKey.length > 0 {
return 0
Expand All @@ -586,6 +610,37 @@ open class WebSocket : NSObject, StreamDelegate {
return -1
}

/**
Parses the extension header, setting up the compression parameters.
*/
func processExtensionHeader(_ extensionHeader: String) {
let parts = extensionHeader.components(separatedBy: ";")
for p in parts {
let part = p.trimmingCharacters(in: .whitespaces)
if part == "permessage-deflate" {
compressionState.supportsCompression = true
} else if part.hasPrefix("server_max_window_bits="){
let valString = part.components(separatedBy: "=")[1]
if let val = Int(valString.trimmingCharacters(in: .whitespaces)) {
compressionState.serverMaxWindowBits = val
}
} else if part.hasPrefix("client_max_window_bits="){
let valString = part.components(separatedBy: "=")[1]
if let val = Int(valString.trimmingCharacters(in: .whitespaces)) {
compressionState.clientMaxWindowBits = val
}
} else if part == "client_no_context_takeover"{
compressionState.clientNoContextTakeover = true
} else if part == "server_no_context_takeover"{
compressionState.serverNoContextTakeover = true
}
}
if compressionState.supportsCompression {
compressionState.decompressor = Decompressor(windowBits: compressionState.serverMaxWindowBits)
compressionState.compressor = Compressor(windowBits: compressionState.clientMaxWindowBits)
}
}

/**
Read a 16 bit big endian value from a buffer
*/
Expand Down Expand Up @@ -650,7 +705,10 @@ open class WebSocket : NSObject, StreamDelegate {
let isMasked = (MaskMask & baseAddress[1])
let payloadLen = (PayloadLenMask & baseAddress[1])
var offset = 2
if (isMasked > 0 || (RSVMask & baseAddress[0]) > 0) && receivedOpcode != .pong {
if compressionState.supportsCompression && receivedOpcode != .continueFrame {
compressionState.messageNeedsDecompression = (RSV1Mask & baseAddress[0]) > 0
}
if (isMasked > 0 || (RSVMask & baseAddress[0]) > 0) && receivedOpcode != .pong && !compressionState.messageNeedsDecompression {
let errCode = CloseCode.protocolError.rawValue
doDisconnect(errorWithDetail("masked and rsv data is not currently supported", code: errCode))
writeError(errCode)
Expand Down Expand Up @@ -710,7 +768,23 @@ open class WebSocket : NSObject, StreamDelegate {
offset += size
len -= UInt64(size)
}
let data = Data(bytes: baseAddress+offset, count: Int(len))
let data: Data
if compressionState.messageNeedsDecompression, let decompressor = compressionState.decompressor {
do {
data = try decompressor.decompress(bytes: baseAddress+offset, count: Int(len), finish: isFin > 0)
if isFin > 0 && compressionState.serverNoContextTakeover{
try decompressor.reset()
}
} catch {
let closeReason = "Decompression failed: \(error)"
let closeCode = CloseCode.encoding.rawValue
doDisconnect(errorWithDetail(closeReason, code: closeCode))
writeError(closeCode)
return emptyBuffer
}
} else {
data = Data(bytes: baseAddress+offset, count: Int(len))
}

if receivedOpcode == .connectionClose {
var closeReason = "connection closed by server"
Expand Down Expand Up @@ -864,10 +938,23 @@ open class WebSocket : NSObject, StreamDelegate {
guard let s = self else { return }
guard let sOperation = operation else { return }
var offset = 2
var firstByte:UInt8 = s.FinMask | code.rawValue
var data = data
if [.textFrame, .binaryFrame].contains(code), let compressor = s.compressionState.compressor {
do {
data = try compressor.compress(data)
if s.compressionState.clientNoContextTakeover {
try compressor.reset()
}
firstByte |= s.RSV1Mask
} catch {
// TODO: report error? We can just send the uncompressed frame.
}
}
let dataLength = data.count
let frame = NSMutableData(capacity: dataLength + s.MaxFrameSize)
let buffer = UnsafeMutableRawPointer(frame!.mutableBytes).assumingMemoryBound(to: UInt8.self)
buffer[0] = s.FinMask | code.rawValue
buffer[0] = firstByte
if dataLength < 126 {
buffer[1] = CUnsignedChar(dataLength)
} else if dataLength <= Int(UInt16.max) {
Expand Down
Loading

0 comments on commit 7bf478c

Please sign in to comment.