Skip to content

Commit

Permalink
tls: asynchronous SNICallback
Browse files Browse the repository at this point in the history
Make ClientHelloParser handle SNI extension, and extend `_tls_wrap.js`
to support loading SNI Context from both hello, and resumed session.

fix nodejs#5967
  • Loading branch information
indutny committed Aug 6, 2013
1 parent 8e28193 commit 048e0e7
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 36 deletions.
7 changes: 4 additions & 3 deletions doc/api/tls.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ automatically set as a listener for the [secureConnection][] event. The
- `NPNProtocols`: An array or `Buffer` of possible NPN protocols. (Protocols
should be ordered by their priority).

- `SNICallback`: A function that will be called if client supports SNI TLS
extension. Only one argument will be passed to it: `servername`. And
`SNICallback` should return SecureContext instance.
- `SNICallback(servername, cb)`: A function that will be called if client
supports SNI TLS extension. Two argument will be passed to it: `servername`,
and `cb`. `SNICallback` should invoke `cb(null, ctx)`, where `ctx` is a
SecureContext instance.
(You can use `crypto.createCredentials(...).context` to get proper
SecureContext). If `SNICallback` wasn't provided - default callback with
high-level API will be used (see below).
Expand Down
63 changes: 54 additions & 9 deletions lib/_tls_wrap.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,67 @@ function onhandshakedone() {

function onclienthello(hello) {
var self = this,
once = false;
onceSession = false,
onceSNI = false;

function callback(err, session) {
if (once)
return self.destroy(new Error('TLS session callback was called twice'));
once = true;
if (onceSession)
return self.destroy(new Error('TLS session callback was called 2 times'));
onceSession = true;

if (err)
return self.destroy(err);

self.ssl.loadSession(session);
// NOTE: That we have disabled OpenSSL's internal session storage in
// `node_crypto.cc` and hence its safe to rely on getting servername only
// from clienthello or this place.
var ret = self.ssl.loadSession(session);

// Servername came from SSL session
// NOTE: TLS Session ticket doesn't include servername information
//
// Another note, From RFC3546:
//
// If, on the other hand, the older
// session is resumed, then the server MUST ignore extensions appearing
// in the client hello, and send a server hello containing no
// extensions; in this case the extension functionality negotiated
// during the original session initiation is applied to the resumed
// session.
//
// Therefore we should account session loading when dealing with servername
if (ret && ret.servername) {
self._SNICallback(ret.servername, onSNIResult);
} else if (hello.servername && self._SNICallback) {
self._SNICallback(hello.servername, onSNIResult);
} else {
self.ssl.endParser();
}
}

function onSNIResult(err, context) {
if (onceSNI)
return self.destroy(new Error('TLS SNI callback was called 2 times'));
onceSNI = true;

if (err)
return self.destroy(err);

if (context)
self.ssl.sni_context = context;

self.ssl.endParser();
}

if (hello.sessionId.length <= 0 ||
hello.tlsTicket ||
this.server &&
!this.server.emit('resumeSession', hello.sessionId, callback)) {
callback(null, null);
// Invoke SNI callback, since we've no session to resume
if (hello.servername && this._SNICallback)
this._SNICallback(hello.servername, onSNIResult);
else
this.ssl.endParser();
}
}

Expand Down Expand Up @@ -94,6 +137,7 @@ function TLSSocket(socket, options) {
this._tlsOptions = options;
this._secureEstablished = false;
this._controlReleased = false;
this._SNICallback = null;
this.ssl = null;
this.servername = null;
this.npnProtocol = null;
Expand Down Expand Up @@ -176,7 +220,8 @@ TLSSocket.prototype._init = function() {
(options.SNICallback !== SNICallback ||
options.server._contexts.length)) {
assert(typeof options.SNICallback === 'function');
this.ssl.onsniselect = options.SNICallback;
this._SNICallback = options.SNICallback;
this.ssl.enableHelloParser();
}

if (process.features.tls_npn && options.NPNProtocols)
Expand Down Expand Up @@ -499,7 +544,7 @@ Server.prototype.addContext = function(servername, credentials) {
this._contexts.push([re, crypto.createCredentials(credentials).context]);
};

function SNICallback(servername) {
function SNICallback(servername, callback) {
var ctx;

this._contexts.some(function(elem) {
Expand All @@ -509,7 +554,7 @@ function SNICallback(servername) {
}
});

return ctx;
callback(null, ctx);
}

Server.prototype.SNICallback = SNICallback;
Expand Down
2 changes: 2 additions & 0 deletions src/node_crypto_clienthello-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ inline void ClientHelloParser::Reset() {
session_id_ = NULL;
tls_ticket_size_ = -1;
tls_ticket_ = NULL;
servername_size_ = 0;
servername_ = NULL;
}

inline void ClientHelloParser::Start(ClientHelloParser::OnHelloCb onhello_cb,
Expand Down
25 changes: 25 additions & 0 deletions src/node_crypto_clienthello.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ void ClientHelloParser::ParseHeader(const uint8_t* data, size_t avail) {
hello.session_id_ = session_id_;
hello.session_size_ = session_size_;
hello.has_ticket_ = tls_ticket_ != NULL && tls_ticket_size_ != 0;
hello.servername_ = servername_;
hello.servername_size_ = servername_size_;
onhello_cb_(cb_arg_, hello);
}

Expand All @@ -134,6 +136,29 @@ void ClientHelloParser::ParseExtension(ClientHelloParser::ExtensionType type,
// That's because we're heavily relying on OpenSSL to solve any problem with
// incoming data.
switch (type) {
case kServerName:
{
if (len < 2)
return;
uint16_t server_names_len = (data[0] << 8) + data[1];
if (server_names_len + 2 > len)
return;
for (size_t offset = 2; offset < 2 + server_names_len; ) {
if (offset + 3 > len)
return;
uint8_t name_type = data[offset];
if (name_type != kServernameHostname)
return;
uint16_t name_len = (data[offset + 1] << 8) + data[offset + 2];
offset += 3;
if (offset + name_len > len)
return;
servername_ = data + offset;
servername_size_ = name_len;
offset += name_len;
}
}
break;
case kTLSSessionTicket:
tls_ticket_size_ = len;
tls_ticket_ = data + len;
Expand Down
8 changes: 8 additions & 0 deletions src/node_crypto_clienthello.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,15 @@ class ClientHelloParser {
inline uint8_t session_size() const { return session_size_; }
inline const uint8_t* session_id() const { return session_id_; }
inline bool has_ticket() const { return has_ticket_; }
inline uint8_t servername_size() const { return servername_size_; }
inline const uint8_t* servername() const { return servername_; }

private:
uint8_t session_size_;
const uint8_t* session_id_;
bool has_ticket_;
uint8_t servername_size_;
const uint8_t* servername_;

friend class ClientHelloParser;
};
Expand All @@ -71,6 +75,7 @@ class ClientHelloParser {
static const uint8_t kSSL2HeaderMask = 0x3f;
static const size_t kMaxTLSFrameLen = 16 * 1024 + 5;
static const size_t kMaxSSLExFrameLen = 32 * 1024;
static const uint8_t kServernameHostname = 0;

enum ParseState {
kWaiting,
Expand All @@ -93,6 +98,7 @@ class ClientHelloParser {
};

enum ExtensionType {
kServerName = 0,
kTLSSessionTicket = 35
};

Expand All @@ -115,6 +121,8 @@ class ClientHelloParser {
size_t extension_offset_;
uint8_t session_size_;
const uint8_t* session_id_;
uint16_t servername_size_;
const uint8_t* servername_;
uint16_t tls_ticket_size_;
const uint8_t* tls_ticket_;
};
Expand Down
77 changes: 57 additions & 20 deletions src/tls_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ using v8::Value;

static Cached<String> onread_sym;
static Cached<String> onerror_sym;
static Cached<String> onsniselect_sym;
static Cached<String> onhandshakestart_sym;
static Cached<String> onhandshakedone_sym;
static Cached<String> onclienthello_sym;
Expand All @@ -67,6 +66,8 @@ static Cached<String> version_sym;
static Cached<String> ext_key_usage_sym;
static Cached<String> sessionid_sym;
static Cached<String> tls_ticket_sym;
static Cached<String> servername_sym;
static Cached<String> sni_context_sym;

static Persistent<Function> tlsWrap;

Expand Down Expand Up @@ -174,7 +175,6 @@ TLSCallbacks::~TLSCallbacks() {
#endif // OPENSSL_NPN_NEGOTIATED

#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
servername_.Dispose();
sni_context_.Dispose();
#endif // SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
}
Expand Down Expand Up @@ -640,7 +640,6 @@ void TLSCallbacks::DoRead(uv_stream_t* handle,

// Parse ClientHello first
if (!hello_.IsEnded()) {
assert(session_callbacks_);
size_t avail = 0;
uint8_t* data = reinterpret_cast<uint8_t*>(enc_in->Peek(&avail));
assert(avail == 0 || data != NULL);
Expand Down Expand Up @@ -770,6 +769,16 @@ void TLSCallbacks::EnableSessionCallbacks(
UNWRAP(TLSCallbacks);

wrap->session_callbacks_ = true;
EnableHelloParser(args);
}


void TLSCallbacks::EnableHelloParser(
const FunctionCallbackInfo<Value>& args) {
HandleScope scope(node_isolate);

UNWRAP(TLSCallbacks);

wrap->hello_.Start(OnClientHello, OnClientHelloParseEnd, wrap);
}

Expand All @@ -785,6 +794,14 @@ void TLSCallbacks::OnClientHello(void* arg,
reinterpret_cast<const char*>(hello.session_id()),
hello.session_size());
hello_obj->Set(sessionid_sym, buff);
if (hello.servername() == NULL) {
hello_obj->Set(servername_sym, String::Empty(node_isolate));
} else {
Local<String> servername = String::New(
reinterpret_cast<const char*>(hello.servername()),
hello.servername_size());
hello_obj->Set(servername_sym, servername);
}
hello_obj->Set(tls_ticket_sym, Boolean::New(hello.has_ticket()));

Handle<Value> argv[1] = { hello_obj };
Expand Down Expand Up @@ -999,7 +1016,23 @@ void TLSCallbacks::LoadSession(const FunctionCallbackInfo<Value>& args) {
if (wrap->next_sess_ != NULL)
SSL_SESSION_free(wrap->next_sess_);
wrap->next_sess_ = sess;

Local<Object> info = Object::New();
#ifndef OPENSSL_NO_TLSEXT
if (sess->tlsext_hostname == NULL) {
info->Set(servername_sym, False(node_isolate));
} else {
info->Set(servername_sym, String::New(sess->tlsext_hostname));
}
#endif
args.GetReturnValue().Set(info);
}
}

void TLSCallbacks::EndParser(const FunctionCallbackInfo<Value>& args) {
HandleScope scope(node_isolate);

UNWRAP(TLSCallbacks);

wrap->hello_.End();
}
Expand Down Expand Up @@ -1143,8 +1176,10 @@ void TLSCallbacks::GetServername(const FunctionCallbackInfo<Value>& args) {

UNWRAP(TLSCallbacks);

if (wrap->kind_ == kTLSServer && !wrap->servername_.IsEmpty()) {
args.GetReturnValue().Set(wrap->servername_);
const char* servername = SSL_get_servername(wrap->ssl_,
TLSEXT_NAMETYPE_host_name);
if (servername != NULL) {
args.GetReturnValue().Set(String::New(servername));
} else {
args.GetReturnValue().Set(false);
}
Expand Down Expand Up @@ -1179,25 +1214,22 @@ int TLSCallbacks::SelectSNIContextCallback(SSL* s, int* ad, void* arg) {

const char* servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);

if (servername) {
p->servername_.Reset(node_isolate, String::New(servername));

if (servername != NULL) {
// Call the SNI callback and use its return value as context
Local<Object> object = p->object();
if (object->Has(onsniselect_sym)) {
p->sni_context_.Dispose();
Local<Value> ctx;
if (object->Has(sni_context_sym)) {
ctx = object->Get(sni_context_sym);
}

Local<Value> arg = PersistentToLocal(node_isolate, p->servername_);
Handle<Value> ret = MakeCallback(object, onsniselect_sym, 1, &arg);
if (ctx.IsEmpty() || ctx->IsUndefined())
return SSL_TLSEXT_ERR_NOACK;

// If ret is SecureContext
if (ret->IsUndefined())
return SSL_TLSEXT_ERR_NOACK;
p->sni_context_.Dispose();
p->sni_context_.Reset(node_isolate, ctx);

p->sni_context_.Reset(node_isolate, ret);
SecureContext* sc = ObjectWrap::Unwrap<SecureContext>(ret.As<Object>());
SSL_set_SSL_CTX(s, sc->ctx_);
}
SecureContext* sc = ObjectWrap::Unwrap<SecureContext>(ctx.As<Object>());
SSL_set_SSL_CTX(s, sc->ctx_);
}

return SSL_TLSEXT_ERR_OK;
Expand All @@ -1219,13 +1251,17 @@ void TLSCallbacks::Initialize(Handle<Object> target) {
NODE_SET_PROTOTYPE_METHOD(t, "getSession", GetSession);
NODE_SET_PROTOTYPE_METHOD(t, "setSession", SetSession);
NODE_SET_PROTOTYPE_METHOD(t, "loadSession", LoadSession);
NODE_SET_PROTOTYPE_METHOD(t, "endParser", EndParser);
NODE_SET_PROTOTYPE_METHOD(t, "getCurrentCipher", GetCurrentCipher);
NODE_SET_PROTOTYPE_METHOD(t, "verifyError", VerifyError);
NODE_SET_PROTOTYPE_METHOD(t, "setVerifyMode", SetVerifyMode);
NODE_SET_PROTOTYPE_METHOD(t, "isSessionReused", IsSessionReused);
NODE_SET_PROTOTYPE_METHOD(t,
"enableSessionCallbacks",
EnableSessionCallbacks);
NODE_SET_PROTOTYPE_METHOD(t,
"enableHelloParser",
EnableHelloParser);

#ifdef OPENSSL_NPN_NEGOTIATED
NODE_SET_PROTOTYPE_METHOD(t, "getNegotiatedProtocol", GetNegotiatedProto);
Expand All @@ -1240,7 +1276,6 @@ void TLSCallbacks::Initialize(Handle<Object> target) {
tlsWrap.Reset(node_isolate, t->GetFunction());

onread_sym = String::New("onread");
onsniselect_sym = String::New("onsniselect");
onerror_sym = String::New("onerror");
onhandshakestart_sym = String::New("onhandshakestart");
onhandshakedone_sym = String::New("onhandshakedone");
Expand All @@ -1260,6 +1295,8 @@ void TLSCallbacks::Initialize(Handle<Object> target) {
ext_key_usage_sym = String::New("ext_key_usage");
sessionid_sym = String::New("sessionId");
tls_ticket_sym = String::New("tlsTicket");
servername_sym = String::New("servername");
sni_context_sym = String::New("sni_context");
}

} // namespace node
Expand Down
Loading

0 comments on commit 048e0e7

Please sign in to comment.