From 6ec38f102e738508c42612aa38f8e83ed2806d00 Mon Sep 17 00:00:00 2001 From: Nathan Rajlich Date: Tue, 7 Nov 2023 20:19:47 -0800 Subject: [PATCH] Use mbedtls to support Socket TLS `secureTransport: 'on'` --- .changeset/forty-scissors-jump.md | 5 + Makefile | 2 +- packages/runtime/src/$.ts | 19 +- packages/runtime/src/internal.ts | 2 + packages/runtime/src/switch.ts | 3 +- packages/runtime/src/tcp.ts | 96 ++++---- packages/runtime/src/utils.ts | 12 +- source/main.c | 2 + source/tcp.c | 9 +- source/tls.c | 356 ++++++++++++++++++++++++++++++ source/tls.h | 43 ++++ source/types.h | 14 +- 12 files changed, 505 insertions(+), 58 deletions(-) create mode 100644 .changeset/forty-scissors-jump.md create mode 100644 source/tls.c create mode 100644 source/tls.h diff --git a/.changeset/forty-scissors-jump.md b/.changeset/forty-scissors-jump.md new file mode 100644 index 00000000..eb3d7593 --- /dev/null +++ b/.changeset/forty-scissors-jump.md @@ -0,0 +1,5 @@ +--- +'nxjs-runtime': patch +--- + +Use mbedtls to support Socket TLS `secureTransport: 'on'` diff --git a/Makefile b/Makefile index 7659ea71..08d23f49 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ CXXFLAGS := $(CFLAGS) -fno-rtti -fno-exceptions ASFLAGS := -g $(ARCH) LDFLAGS = -specs=$(DEVKITPRO)/libnx/switch.specs -g $(ARCH) -Wl,-Map,$(notdir $*.map) -LIBS := -pthread `freetype-config --libs` `aarch64-none-elf-pkg-config cairo --libs` -lturbojpeg -lwebp -lquickjs -lm3 -lm +LIBS := -pthread -lmbedtls -lmbedx509 -lmbedcrypto `freetype-config --libs` `aarch64-none-elf-pkg-config cairo --libs` -lturbojpeg -lwebp -lquickjs -lm3 -lm #--------------------------------------------------------------------------------- # list of directories containing libraries, this must be the top level containing diff --git a/packages/runtime/src/$.ts b/packages/runtime/src/$.ts index 0f27a3c8..298cc367 100644 --- a/packages/runtime/src/$.ts +++ b/packages/runtime/src/$.ts @@ -1,6 +1,6 @@ import type { NetworkInfo } from './types'; import type { Callback } from './internal'; -import type { Server } from './tcp'; +import type { Server, TlsContextOpaque } from './tcp'; import type { MemoryDescriptor, Memory } from './wasm'; import type { VirtualKeyboard } from './navigator/virtual-keyboard'; @@ -52,6 +52,23 @@ export interface Init { onAccept: (fd: number) => void ): Server; + // tls.c + tlsHandshake( + cb: Callback, + fd: number, + hostname: string + ): void; + tlsWrite( + cb: Callback, + ctx: TlsContextOpaque, + data: ArrayBuffer + ): void; + tlsRead( + cb: Callback, + ctx: TlsContextOpaque, + buffer: ArrayBuffer + ): void; + // wasm.c wasmCallFunc(f: any, ...args: unknown[]): unknown; wasmMemNew(descriptor: MemoryDescriptor): Memory; diff --git a/packages/runtime/src/internal.ts b/packages/runtime/src/internal.ts index aee967a9..f2ae097b 100644 --- a/packages/runtime/src/internal.ts +++ b/packages/runtime/src/internal.ts @@ -3,6 +3,8 @@ import type { SocketOptions } from './types'; export const INTERNAL_SYMBOL = Symbol('Internal'); +export type Opaque = { __type: T }; + export type Callback = (err: Error | null, result: T) => void; export type CallbackReturnType = T extends ( diff --git a/packages/runtime/src/switch.ts b/packages/runtime/src/switch.ts index daa913b6..e929829a 100644 --- a/packages/runtime/src/switch.ts +++ b/packages/runtime/src/switch.ts @@ -1,7 +1,7 @@ import { $ } from './$'; import { Canvas, CanvasRenderingContext2D, ctxInternal } from './canvas'; import { FontFaceSet } from './polyfills/font'; -import { type Callback, INTERNAL_SYMBOL } from './internal'; +import { type Callback, INTERNAL_SYMBOL, type Opaque } from './internal'; import { inspect } from './inspect'; import { bufferSourceToArrayBuffer, toPromise } from './utils'; import { setTimeout, clearTimeout } from './timers'; @@ -16,7 +16,6 @@ import type { SocketOptions, } from './types'; -export type Opaque = { __type: T }; export type CanvasRenderingContext2DState = Opaque<'CanvasRenderingContext2DState'>; export type FontFaceState = Opaque<'FontFaceState'>; diff --git a/packages/runtime/src/tcp.ts b/packages/runtime/src/tcp.ts index 8a398917..9083e3d1 100644 --- a/packages/runtime/src/tcp.ts +++ b/packages/runtime/src/tcp.ts @@ -9,23 +9,14 @@ import { def, toPromise, } from './utils'; -import type { - BufferSource, - SecureTransportKind, - SocketAddress, - SocketInfo, -} from './types'; -import { INTERNAL_SYMBOL, type SocketOptionsInternal } from './internal'; - -interface SocketInternal { - fd: number; - opened: Deferred; - closed: Deferred; - secureTransport: SecureTransportKind; - allowHalfOpen: boolean; -} +import type { BufferSource, SocketAddress, SocketInfo } from './types'; +import { + INTERNAL_SYMBOL, + Opaque, + type SocketOptionsInternal, +} from './internal'; -const socketInternal = new WeakMap(); +export type TlsContextOpaque = Opaque<'TlsContext'>; export function parseAddress(address: string): SocketAddress { const firstColon = address.indexOf(':'); @@ -35,24 +26,6 @@ export function parseAddress(address: string): SocketAddress { }; } -export function read(fd: number, buffer: BufferSource) { - const ab = bufferSourceToArrayBuffer(buffer); - return toPromise($.read, fd, ab); -} - -export function write(fd: number, data: string | BufferSource) { - const d = typeof data === 'string' ? encoder.encode(data) : data; - const ab = bufferSourceToArrayBuffer(d); - return toPromise($.write, fd, ab); -} - -/** - * Creates a TCP connection specified by the `hostname` - * and `port`. - * - * @param opts Object containing the `port` number and `hostname` (defaults to `127.0.0.1`) to connect to. - * @returns Promise that is fulfilled once the connection has been successfully established. - */ export async function connect(opts: SocketAddress) { const { hostname = '127.0.0.1', port } = opts; const [ip] = await resolve(hostname); @@ -62,6 +35,39 @@ export async function connect(opts: SocketAddress) { return toPromise($.connect, ip, port); } +function read(fd: number, buffer: BufferSource) { + const ab = bufferSourceToArrayBuffer(buffer); + return toPromise($.read, fd, ab); +} + +function write(fd: number, data: BufferSource) { + const ab = bufferSourceToArrayBuffer(data); + return toPromise($.write, fd, ab); +} + +function tlsHandshake(fd: number, hostname: string) { + return toPromise($.tlsHandshake, fd, hostname); +} + +function tlsRead(ctx: TlsContextOpaque, buffer: BufferSource) { + const ab = bufferSourceToArrayBuffer(buffer); + return toPromise($.tlsRead, ctx, ab); +} + +function tlsWrite(ctx: TlsContextOpaque, data: BufferSource) { + const ab = bufferSourceToArrayBuffer(data); + return toPromise($.tlsWrite, ctx, ab); +} + +interface SocketInternal { + fd: number; + tls?: TlsContextOpaque; + opened: Deferred; + closed: Deferred; +} + +const socketInternal = new WeakMap(); + /** * The `Socket` class represents a TCP connection, from which you can * read and write data. A socket begins in a _connected_ state (if the @@ -72,7 +78,6 @@ export async function connect(opts: SocketAddress) { export class Socket { readonly readable: ReadableStream; readonly writable: WritableStream; - readonly opened: Promise; readonly closed: Promise; @@ -92,8 +97,6 @@ export class Socket { fd: -1, opened: new Deferred(), closed: new Deferred(), - secureTransport, - allowHalfOpen, }; socketInternal.set(this, i); this.opened = i.opened.promise; @@ -105,7 +108,10 @@ export class Socket { if (i.opened.pending) { await socket.opened; } - const bytesRead = await read(i.fd, readBuffer); + const bytesRead = await (i.tls + ? tlsRead(i.tls, readBuffer) + : read(i.fd, readBuffer)); + //console.log('read %d bytes', bytesRead); if (bytesRead === 0) { controller.close(); if (!allowHalfOpen) { @@ -113,21 +119,29 @@ export class Socket { } return; } - controller.enqueue(new Uint8Array(readBuffer, 0, bytesRead)); + //controller.enqueue(new Uint8Array(readBuffer, 0, bytesRead)); + controller.enqueue(new Uint8Array(readBuffer.slice(0, bytesRead))); }, }); this.writable = new WritableStream({ - async write(chunk, controller) { + async write(chunk) { if (i.opened.pending) { await socket.opened; } - await write(i.fd, chunk); + const n = await (i.tls ? tlsWrite(i.tls, chunk) : write(i.fd, chunk)); + //console.log('Wrote %d bytes', n); }, }); connect(address) .then((fd) => { i.fd = fd; + if (secureTransport === 'on') { + return tlsHandshake(fd, address.hostname); + } + }) + .then((tls) => { + i.tls = tls; i.opened.resolve({ localAddress: '', remoteAddress: '', diff --git a/packages/runtime/src/utils.ts b/packages/runtime/src/utils.ts index 323f13b6..dce5b87c 100644 --- a/packages/runtime/src/utils.ts +++ b/packages/runtime/src/utils.ts @@ -45,10 +45,14 @@ export function toPromise< Func extends (cb: Callback, ...args: any[]) => any >(fn: Func, ...args: CallbackArguments) { return new Promise>((resolve, reject) => { - fn((err, result) => { - if (err) return reject(err); - resolve(result); - }, ...args); + try { + fn((err, result) => { + if (err) return reject(err); + resolve(result); + }, ...args); + } catch (err) { + reject(err); + } }); } diff --git a/source/main.c b/source/main.c index 3f7b60b3..6cf04187 100644 --- a/source/main.c +++ b/source/main.c @@ -22,6 +22,7 @@ #include "wasm.h" #include "image.h" #include "tcp.h" +#include "tls.h" #include "poll.h" #define LOG_FILENAME "nxjs-debug.log" @@ -478,6 +479,7 @@ int main(int argc, char *argv[]) nx_init_dns(ctx, init_obj); nx_init_nifm(ctx, init_obj); nx_init_tcp(ctx, init_obj); + nx_init_tls(ctx, init_obj); nx_init_swkbd(ctx, init_obj); nx_init_wasm(ctx, init_obj); JS_SetPropertyStr(ctx, global_obj, "$", init_obj); diff --git a/source/tcp.c b/source/tcp.c index 784f97b8..b1125fa3 100644 --- a/source/tcp.c +++ b/source/tcp.c @@ -6,13 +6,6 @@ #include "poll.h" #include "error.h" -typedef struct -{ - JSContext *context; - JSValue callback; - JSValue buffer; -} nx_js_callback_t; - void nx_on_connect(nx_poll_t *p, nx_connect_t *req) { nx_js_callback_t *req_cb = (nx_js_callback_t *)req->opaque; @@ -42,8 +35,8 @@ void nx_on_connect(nx_poll_t *p, nx_connect_t *req) JSValue nx_js_tcp_connect(JSContext *ctx, JSValueConst this_val, int argc, JSValueConst *argv) { - const char *ip = JS_ToCString(ctx, argv[1]); int port; + const char *ip = JS_ToCString(ctx, argv[1]); if (!ip || JS_ToInt32(ctx, &port, argv[2])) { JS_ThrowTypeError(ctx, "invalid input"); diff --git a/source/tls.c b/source/tls.c new file mode 100644 index 00000000..a189c223 --- /dev/null +++ b/source/tls.c @@ -0,0 +1,356 @@ +#include +#include +#include +#include +#include "tls.h" +#include "poll.h" +#include "error.h" + +static JSClassID nx_tls_context_class_id; + +static nx_tls_context_t *nx_tls_context_get(JSContext *ctx, JSValueConst obj) +{ + return JS_GetOpaque2(ctx, obj, nx_tls_context_class_id); +} + +static void finalizer_tls_context(JSRuntime *rt, JSValue val) +{ + fprintf(stderr, "finalizer_tls_context\n"); + nx_tls_context_t *data = JS_GetOpaque(val, nx_tls_context_class_id); + if (data) + { + mbedtls_net_free(&data->server_fd); + mbedtls_ssl_free(&data->ssl); + mbedtls_ssl_config_free(&data->conf); + js_free_rt(rt, data); + } +} + +void nx_tls_on_connect(nx_poll_t *p, nx_tls_connect_t *req) +{ + nx_js_callback_t *req_cb = (nx_js_callback_t *)req->opaque; + JSValue args[] = {JS_UNDEFINED, JS_UNDEFINED}; + + if (req->err) + { + /* Error during TLS handshake */ + char error_buf[100]; + mbedtls_strerror(req->err, error_buf, 100); + args[0] = JS_NewError(req_cb->context); + JS_SetPropertyStr(req_cb->context, args[0], "message", JS_NewString(req_cb->context, error_buf)); + } + else + { + /* Handshake complete */ + args[1] = req_cb->buffer; + } + + JSValue ret_val = JS_Call(req_cb->context, req_cb->callback, JS_NULL, 2, args); + JS_FreeValue(req_cb->context, req_cb->buffer); + JS_FreeValue(req_cb->context, req_cb->callback); + if (JS_IsException(ret_val)) + { + nx_emit_error_event(req_cb->context); + } + JS_FreeValue(req_cb->context, ret_val); + free(req_cb); + free(req); +} + +void nx_tls_do_handshake(nx_poll_t *p, nx_watcher_t *watcher, int revents) +{ + nx_tls_connect_t *req = (nx_tls_connect_t *)watcher; + int err = mbedtls_ssl_handshake(&req->data->ssl); + if (err == MBEDTLS_ERR_SSL_WANT_READ || err == MBEDTLS_ERR_SSL_WANT_WRITE) + { + // Handshake not completed, wait for more events + return; + } + nx_remove_watcher(p, watcher); + req->err = err; + req->callback(p, req); +} + +JSValue nx_tls_handshake(JSContext *ctx, JSValueConst this_val, int argc, JSValueConst *argv) +{ + int ret; + nx_context_t *nx_ctx = JS_GetContextOpaque(ctx); + + if (!nx_ctx->mbedtls_initialized) + { + mbedtls_entropy_init(&nx_ctx->entropy); + mbedtls_ctr_drbg_init(&nx_ctx->ctr_drbg); + + const char *pers = "client"; + + // Seed the RNG + if ((ret = mbedtls_ctr_drbg_seed(&nx_ctx->ctr_drbg, mbedtls_entropy_func, &nx_ctx->entropy, (const unsigned char *)pers, strlen(pers))) != 0) + { + char error_buf[100]; + mbedtls_strerror(ret, error_buf, 100); + JS_ThrowTypeError(ctx, "Failed seeding RNG: %s", error_buf); + return JS_EXCEPTION; + } + + nx_ctx->mbedtls_initialized = true; + } + + int fd; + const char *hostname = JS_ToCString(ctx, argv[2]); + if (!hostname || JS_ToInt32(ctx, &fd, argv[1])) + { + JS_ThrowTypeError(ctx, "invalid input"); + return JS_EXCEPTION; + } + + JSValue obj = JS_NewObjectClass(ctx, nx_tls_context_class_id); + nx_tls_context_t *data = js_mallocz(ctx, sizeof(nx_tls_context_t)); + if (!data) + { + JS_ThrowOutOfMemory(ctx); + return JS_EXCEPTION; + } + JS_SetOpaque(obj, data); + + data->server_fd.fd = fd; + mbedtls_ssl_init(&data->ssl); + mbedtls_ssl_config_init(&data->conf); + + // Setup the SSL/TLS structure and set the hostname for Server Name Indication (SNI) + if ((ret = mbedtls_ssl_config_defaults(&data->conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) + { + char error_buf[100]; + mbedtls_strerror(ret, error_buf, 100); + JS_ThrowTypeError(ctx, "Failed setting SSL config defaults: %s", error_buf); + return JS_EXCEPTION; + } + mbedtls_ssl_conf_authmode(&data->conf, MBEDTLS_SSL_VERIFY_NONE); + mbedtls_ssl_conf_rng(&data->conf, mbedtls_ctr_drbg_random, &nx_ctx->ctr_drbg); + if ((ret = mbedtls_ssl_set_hostname(&data->ssl, hostname)) != 0) + { + char error_buf[100]; + mbedtls_strerror(ret, error_buf, 100); + JS_ThrowTypeError(ctx, "Failed setting hostname: %s", error_buf); + return JS_EXCEPTION; + } + mbedtls_ssl_set_bio(&data->ssl, &data->server_fd, mbedtls_net_send, mbedtls_net_recv, NULL); + if ((ret = mbedtls_ssl_setup(&data->ssl, &data->conf)) != 0) + { + char error_buf[100]; + mbedtls_strerror(ret, error_buf, 100); + JS_ThrowTypeError(ctx, "Failed setting up SSL: %s", error_buf); + return JS_EXCEPTION; + } + + JS_FreeCString(ctx, hostname); + + nx_tls_connect_t *req = malloc(sizeof(nx_tls_connect_t)); + nx_js_callback_t *req_cb = malloc(sizeof(nx_js_callback_t)); + req_cb->context = ctx; + req_cb->callback = JS_DupValue(ctx, argv[0]); + req_cb->buffer = JS_DupValue(ctx, obj); + req->opaque = req_cb; + req->data = data; + req->watcher_callback = nx_tls_do_handshake; + req->callback = nx_tls_on_connect; + req->fd = fd; + req->events = POLLIN | POLLOUT | POLLERR; + + nx_add_watcher(&nx_ctx->poll, (nx_watcher_t *)req); + nx_tls_do_handshake(&nx_ctx->poll, (nx_watcher_t *)req, 0); + + return JS_UNDEFINED; +} + +void nx_tls_do_read(nx_poll_t *p, nx_watcher_t *watcher, int revents) +{ + nx_tls_read_t *req = (nx_tls_read_t *)watcher; + nx_js_callback_t *req_cb = (nx_js_callback_t *)req->opaque; + + int ret = mbedtls_ssl_read(&req->data->ssl, req->buffer, req->buffer_size); + if (ret == MBEDTLS_ERR_SSL_WANT_READ) + { + return; + } + + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + ret = 0; + } + + nx_remove_watcher(p, watcher); + + JSContext *ctx = req_cb->context; + JS_FreeValue(ctx, req_cb->buffer); + + JSValue args[] = {JS_UNDEFINED, JS_UNDEFINED}; + + if (ret < 0) + { + /* Error during read. */ + char error_buf[100]; + mbedtls_strerror(ret, error_buf, 100); + args[0] = JS_NewError(ctx); + JS_SetPropertyStr(ctx, args[0], "message", JS_NewString(ctx, error_buf)); + } + else + { + args[1] = JS_NewInt32(ctx, ret); + } + + JSValue ret_val = JS_Call(ctx, req_cb->callback, JS_NULL, 2, args); + JS_FreeValue(ctx, args[0]); + JS_FreeValue(ctx, args[1]); + JS_FreeValue(ctx, req_cb->callback); + if (JS_IsException(ret_val)) + { + nx_emit_error_event(ctx); + } + JS_FreeValue(ctx, ret_val); + free(req_cb); + free(req); +} + +JSValue nx_tls_read(JSContext *ctx, JSValueConst this_val, int argc, JSValueConst *argv) +{ + nx_context_t *nx_ctx = JS_GetContextOpaque(ctx); + size_t buffer_size; + + nx_tls_context_t *data = nx_tls_context_get(ctx, argv[1]); + if (!data) + return JS_EXCEPTION; + + uint8_t *buffer = JS_GetArrayBuffer(ctx, &buffer_size, argv[2]); + if (!buffer) + return JS_EXCEPTION; + + JSValue buffer_value = JS_DupValue(ctx, argv[2]); + + nx_tls_read_t *req = malloc(sizeof(nx_tls_read_t)); + nx_js_callback_t *req_cb = malloc(sizeof(nx_js_callback_t)); + req_cb->context = ctx; + req_cb->callback = JS_DupValue(ctx, argv[0]); + req_cb->buffer = buffer_value; + req->fd = data->server_fd.fd; + req->events = POLLIN | POLLERR; + req->err = 0; + req->watcher_callback = nx_tls_do_read; + req->opaque = req_cb; + req->data = data; + req->buffer = buffer; + req->buffer_size = buffer_size; + + nx_add_watcher(&nx_ctx->poll, (nx_watcher_t *)req); + nx_tls_do_read(&nx_ctx->poll, (nx_watcher_t *)req, 0); + + return JS_UNDEFINED; +} + +void nx_tls_do_write(nx_poll_t *p, nx_watcher_t *watcher, int revents) +{ + nx_tls_write_t *req = (nx_tls_write_t *)watcher; + nx_js_callback_t *req_cb = (nx_js_callback_t *)req->opaque; + + int ret = mbedtls_ssl_write(&req->data->ssl, req->buffer, req->buffer_size); + + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) + { + return; + } + + JSContext *ctx = req_cb->context; + JSValue args[] = {JS_UNDEFINED, JS_UNDEFINED}; + + if (ret < 0) + { + /* Error during write */ + char error_buf[100]; + mbedtls_strerror(ret, error_buf, 100); + args[0] = JS_NewError(ctx); + JS_SetPropertyStr(ctx, args[0], "message", JS_NewString(ctx, error_buf)); + } + else + { + req->bytes_written += ret; + if (req->bytes_written < req->buffer_size) + { + // Not all data was written, need to wait before trying again + return; + } + + args[1] = JS_NewInt32(ctx, ret); + } + + // If we got to here then all the data was written + nx_remove_watcher(p, watcher); + + JS_FreeValue(ctx, req_cb->buffer); + + JSValue ret_val = JS_Call(ctx, req_cb->callback, JS_NULL, 2, args); + JS_FreeValue(ctx, args[0]); + JS_FreeValue(ctx, args[1]); + JS_FreeValue(ctx, req_cb->callback); + if (JS_IsException(ret_val)) + { + nx_emit_error_event(ctx); + } + JS_FreeValue(ctx, ret_val); + free(req_cb); + free(req); +} + +JSValue nx_tls_write(JSContext *ctx, JSValueConst this_val, int argc, JSValueConst *argv) +{ + nx_context_t *nx_ctx = JS_GetContextOpaque(ctx); + size_t buffer_size; + + nx_tls_context_t *data = nx_tls_context_get(ctx, argv[1]); + if (!data) + return JS_EXCEPTION; + + uint8_t *buffer = JS_GetArrayBuffer(ctx, &buffer_size, argv[2]); + if (!buffer) + return JS_EXCEPTION; + + JSValue buffer_value = JS_DupValue(ctx, argv[2]); + + nx_tls_write_t *req = malloc(sizeof(nx_tls_write_t)); + nx_js_callback_t *req_cb = malloc(sizeof(nx_js_callback_t)); + req_cb->context = ctx; + req_cb->callback = JS_DupValue(ctx, argv[0]); + req_cb->buffer = buffer_value; + req->fd = data->server_fd.fd; + req->events = POLLOUT | POLLERR; + req->err = 0; + req->watcher_callback = nx_tls_do_write; + req->opaque = req_cb; + req->data = data; + req->buffer = buffer; + req->buffer_size = buffer_size; + req->bytes_written = 0; + + nx_add_watcher(&nx_ctx->poll, (nx_watcher_t *)req); + nx_tls_do_write(&nx_ctx->poll, (nx_watcher_t *)req, 0); + + return JS_UNDEFINED; +} + +static const JSCFunctionListEntry function_list[] = { + JS_CFUNC_DEF("tlsHandshake", 0, nx_tls_handshake), + JS_CFUNC_DEF("tlsRead", 0, nx_tls_read), + JS_CFUNC_DEF("tlsWrite", 0, nx_tls_write), +}; + +void nx_init_tls(JSContext *ctx, JSValueConst init_obj) +{ + + JSRuntime *rt = JS_GetRuntime(ctx); + + JS_NewClassID(&nx_tls_context_class_id); + JSClassDef nx_tls_context_class = { + "TlsContext", + .finalizer = finalizer_tls_context, + }; + JS_NewClass(rt, nx_tls_context_class_id, &nx_tls_context_class); + + JS_SetPropertyFunctionList(ctx, init_obj, function_list, countof(function_list)); +} diff --git a/source/tls.h b/source/tls.h new file mode 100644 index 00000000..955e7db4 --- /dev/null +++ b/source/tls.h @@ -0,0 +1,43 @@ +#pragma once +#include +#include +#include "types.h" + +typedef struct +{ + mbedtls_net_context server_fd; + mbedtls_ssl_context ssl; + mbedtls_ssl_config conf; +} nx_tls_context_t; + +typedef struct nx_tls_connect_s nx_tls_connect_t; +typedef struct nx_tls_read_s nx_tls_read_t; +typedef struct nx_tls_write_s nx_tls_write_t; + +typedef void (*nx_tls_connect_cb)(nx_poll_t *p, nx_tls_connect_t *req); + +struct nx_tls_connect_s +{ + NX_WATCHER_FIELDS + nx_tls_context_t *data; + nx_tls_connect_cb callback; +}; + +struct nx_tls_read_s +{ + NX_WATCHER_FIELDS + nx_tls_context_t *data; + unsigned char *buffer; + size_t buffer_size; +}; + +struct nx_tls_write_s +{ + NX_WATCHER_FIELDS + nx_tls_context_t *data; + const uint8_t *buffer; + size_t buffer_size; + size_t bytes_written; +}; + +void nx_init_tls(JSContext *ctx, JSValueConst init_obj); diff --git a/source/types.h b/source/types.h index 458c6ff0..0c9b23a6 100644 --- a/source/types.h +++ b/source/types.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include "thpool.h" #include "poll.h" @@ -32,7 +34,12 @@ #define NX_DEF_FUNC(THISARG, NAME, FN, LENGTH) (JS_DefinePropertyValueStr(ctx, THISARG, NAME, JS_NewCFunction(ctx, FN, NAME, LENGTH), JS_PROP_C_W)) -typedef int BOOL; +typedef struct +{ + JSContext *context; + JSValue callback; + JSValue buffer; +} nx_js_callback_t; typedef struct nx_work_s nx_work_t; typedef void (*nx_work_cb)(nx_work_t *req); @@ -61,6 +68,11 @@ typedef struct IM3Environment wasm_env; JSValue onerror_handler; JSValue unhandled_rejection_handler; + + // mbedtls structures shared by all TLS connections + bool mbedtls_initialized; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; } nx_context_t; inline nx_context_t *nx_get_context(JSContext *ctx)