From 3bdf99afb98849156062ad642b7c96849126adf3 Mon Sep 17 00:00:00 2001 From: Ruslan <11838981+feedmeapples@users.noreply.github.com> Date: Wed, 18 Aug 2021 07:49:41 +0300 Subject: [PATCH] Support refreshing TLS certs in background (#369) --- README.md | 3 ++ server/config/index.js | 26 ++++------ server/routes.js | 44 +++++++--------- server/temporal-client-provider.js | 63 +++++++++++++++++++++++ server/temporal-client/temporal-client.js | 6 +-- server/tls/index.js | 4 +- server/tls/tls.js | 53 ++++++++++++------- 7 files changed, 133 insertions(+), 66 deletions(-) create mode 100644 server/temporal-client-provider.js diff --git a/README.md b/README.md index 17c30a63..abad1f34 100644 --- a/README.md +++ b/README.md @@ -40,12 +40,15 @@ Optional TLS configuration variables: | TEMPORAL_TLS_CA_PATH | Certificate authority (CA) certificate for the validation of server | | | TEMPORAL_TLS_ENABLE_HOST_VERIFICATION | Enables verification of the server certificate | true | | TEMPORAL_TLS_SERVER_NAME | Target server that is used for TLS host verification | | +| TEMPORAL_TLS_REFRESH_INTERVAL | How often to refresh TLS Certs, seconds | 0 | * To enable mutual TLS, you need to specify `TEMPORAL_TLS_KEY_PATH` and `TEMPORAL_TLS_CERT_PATH`. * For server-side TLS you need to specify only `TEMPORAL_TLS_CA_PATH`. By default we will also verify your server `hostname`, matching it to `TEMPORAL_TLS_SERVER_NAME`. You can turn this off by setting `TEMPORAL_TLS_ENABLE_HOST_VERIFICATION` to `false`. +Setting `TEMPORAL_TLS_REFRESH_INTERVAL` will make the TLS certs reload every N seconds. + ### Configuring Authentication (optional) diff --git a/server/config/index.js b/server/config/index.js index 125695c9..16e96a2d 100644 --- a/server/config/index.js +++ b/server/config/index.js @@ -2,27 +2,20 @@ const { promisify } = require('util'); const { readFile, readFileSync } = require('fs'); const yaml = require('js-yaml'); -let config = undefined; const configPath = process.env.TEMPORAL_CONFIG_PATH || './server/config.yml'; const readConfigSync = () => { - if (!config) { - const cfgContents = readFileSync(configPath, { - encoding: 'utf8', - }); - config = yaml.safeLoad(cfgContents); - } - return config; + const cfgContents = readFileSync(configPath, { + encoding: 'utf8', + }); + return yaml.safeLoad(cfgContents); }; const readConfig = async () => { - if (!config) { - const cfgContents = await promisify(readFile)(configPath, { - encoding: 'utf8', - }); - config = yaml.safeLoad(cfgContents); - } - return config; + const cfgContents = await promisify(readFile)(configPath, { + encoding: 'utf8', + }); + return yaml.safeLoad(cfgContents); }; const getAuthConfig = async () => { @@ -59,7 +52,7 @@ const getTlsConfig = () => { tls = {}; } - const { ca, key, cert, server_name, verify_host } = tls; + const { ca, key, cert, server_name, verify_host, refresh_interval } = tls; return { ca, @@ -67,6 +60,7 @@ const getTlsConfig = () => { cert, serverName: server_name, verifyHost: verify_host, + refreshInterval: refresh_interval, }; }; diff --git a/server/routes.js b/server/routes.js index 8ca49c78..ab94372d 100644 --- a/server/routes.js +++ b/server/routes.js @@ -2,21 +2,15 @@ const Router = require('koa-router'), router = new Router(), moment = require('moment'), losslessJSON = require('lossless-json'), - { - TemporalClient, - WithAuthMetadata, - WithErrorConverter, - } = require('./temporal-client'), { isWriteApiPermitted } = require('./utils'), - { getAuthConfig, getRoutingConfig } = require('./config'); -authRoutes = require('./routes-auth'); - -const tClient = WithErrorConverter(WithAuthMetadata(new TemporalClient())); + { getAuthConfig, getRoutingConfig } = require('./config'), + authRoutes = require('./routes-auth'), + { getTemporalClient: tClient } = require('./temporal-client-provider'); router.use('/auth', authRoutes); router.get('/api/namespaces', async function(ctx) { - ctx.body = await tClient.listNamespaces(ctx, { + ctx.body = await tClient().listNamespaces(ctx, { pageSize: 50, nextPageToken: ctx.query.nextPageToken ? Buffer.from(ctx.query.nextPageToken, 'base64') @@ -25,7 +19,7 @@ router.get('/api/namespaces', async function(ctx) { }); router.get('/api/namespaces/:namespace', async function(ctx) { - ctx.body = await tClient.describeNamespace(ctx, { + ctx.body = await tClient().describeNamespace(ctx, { namespace: ctx.params.namespace, }); }); @@ -66,7 +60,7 @@ router.get('/api/namespaces/:namespace/workflows/list', async function(ctx) { const { namespace } = ctx.params; - ctx.body = await tClient.listWorkflows(ctx, { + ctx.body = await tClient().listWorkflows(ctx, { namespace, query: q.queryString || undefined, nextPageToken: q.nextPageToken @@ -82,7 +76,7 @@ router.get( const { namespace, workflowId, runId } = ctx.params; - ctx.body = await tClient.getHistory(ctx, { + ctx.body = await tClient().getHistory(ctx, { namespace, execution: { workflowId, runId }, nextPageToken: q.nextPageToken @@ -127,7 +121,7 @@ router.get('/api/namespaces/:namespace/workflows/archived', async function( queryString = buildQueryString(startTime, endTime, query); } - ctx.body = await tClient.archivedWorkflows(ctx, { + ctx.body = await tClient().archivedWorkflows(ctx, { namespace, nextPageToken: nextPageToken ? Buffer.from(nextPageToken, 'base64') @@ -144,7 +138,7 @@ router.get( const { namespace, workflowId, runId } = ctx.params; do { - const page = await tClient.exportHistory(ctx, { + const page = await tClient().exportHistory(ctx, { namespace, nextPageToken, execution: { workflowId, runId }, @@ -174,7 +168,7 @@ router.get( try { const { namespace, workflowId, runId } = ctx.params; - await tClient.queryWorkflow(ctx, { + await tClient().queryWorkflow(ctx, { namespace, execution: { workflowId, runId }, query: { @@ -199,7 +193,7 @@ router.post( async function(ctx) { const { namespace, workflowId, runId } = ctx.params; - ctx.body = await tClient.queryWorkflow(ctx, { + ctx.body = await tClient().queryWorkflow(ctx, { namespace, execution: { workflowId, runId }, query: { @@ -214,7 +208,7 @@ router.post( async function(ctx) { const { namespace, workflowId, runId } = ctx.params; - ctx.body = await tClient.terminateWorkflow(ctx, { + ctx.body = await tClient().terminateWorkflow(ctx, { namespace, execution: { workflowId, runId }, reason: ctx.request.body && ctx.request.body.reason, @@ -227,7 +221,7 @@ router.post( async function(ctx) { const { namespace, workflowId, runId, signal } = ctx.params; - ctx.body = await tClient.signalWorkflow(ctx, { + ctx.body = await tClient().signalWorkflow(ctx, { namespace, execution: { workflowId, runId }, signalName: signal, @@ -241,7 +235,7 @@ router.get( const { namespace, workflowId, runId } = ctx.params; try { - ctx.body = await tClient.describeWorkflow(ctx, { + ctx.body = await tClient().describeWorkflow(ctx, { namespace, execution: { workflowId, runId }, }); @@ -250,7 +244,7 @@ router.get( throw error; } - const archivedHistoryResponse = await tClient.getHistory(); + const archivedHistoryResponse = await tClient().getHistory(); const archivedHistoryEvents = mapHistoryResponse( archivedHistoryResponse.history ); @@ -299,7 +293,7 @@ router.get( const { namespace, taskQueue } = ctx.params; const descTaskQueue = async (taskQueueType) => ( - await tClient.describeTaskQueue(ctx, { + await tClient().describeTaskQueue(ctx, { namespace, taskQueue: { name: taskQueue }, taskQueueType, @@ -337,7 +331,7 @@ router.get('/api/namespaces/:namespace/task-queues/:taskQueue/', async function( ) { const { namespace, taskQueue } = ctx.params; const descTaskQueue = async (taskQueueType) => - await tClient.describeTaskQueue(ctx, { + await tClient().describeTaskQueue(ctx, { namespace, taskQueue: { name: taskQueue }, taskQueueType, @@ -351,7 +345,7 @@ router.get('/api/namespaces/:namespace/task-queues/:taskQueue/', async function( ctx.body = tq; }); -router.post('/api/web-settings/data-converter/:port', async(ctx) => { +router.post('/api/web-settings/data-converter/:port', async (ctx) => { ctx.session.dataConverter = { port: ctx.params.port }; ctx.status = 200; }); @@ -385,7 +379,7 @@ router.get('/api/me', async (ctx) => { }); router.get('/api/cluster/version-info', async (ctx) => { - const res = await tClient.getVersionInfo(ctx); + const res = await tClient().getVersionInfo(ctx); ctx.body = res; }); diff --git a/server/temporal-client-provider.js b/server/temporal-client-provider.js new file mode 100644 index 00000000..2de4abae --- /dev/null +++ b/server/temporal-client-provider.js @@ -0,0 +1,63 @@ +const logger = require('./logger'); +const { + TemporalClient, + WithAuthMetadata, + WithErrorConverter, +} = require('./temporal-client'); +const { getTlsCredentials } = require('./tls'); +const { getTlsConfig } = require('./config'); + +let refreshInterval = Number(process.env.TEMPORAL_TLS_REFRESH_INTERVAL) || 0; + +if (refreshInterval === 0) { + const tls = getTlsConfig(); + if (tls.refreshInterval) { + refreshInterval = Number(tls.refreshInterval); + } +} + +let tlsCache; +let tClient; + +loadClient(); + +if (refreshInterval !== 0) { + setInterval(() => { + try { + const tls = getTlsCredentials(); + if ( + !equal(tls.pk, tlsCache.pk) || + !equal(tls.cert, tlsCache.cert) || + !equal(tls.ca, tlsCache.ca) || + tls.serverName !== tlsCache.serverName || + tls.verifyHost !== tlsCache.verifyHost + ) { + loadClient(); + } + } catch (err) { + logger.error(err); + } + }, refreshInterval * 1000); +} + +getTemporalClient = () => tClient; + +function loadClient() { + tlsCache = getTlsCredentials(); + tClient = WithErrorConverter(WithAuthMetadata(new TemporalClient(tlsCache))); +} + +function equal(v1, v2) { + if (Buffer.isBuffer(v1)) { + if (Buffer.isBuffer(v2)) { + return Buffer.compare(v1, v2) === 0; + } + return false; + } else if (Buffer.isBuffer(v2)) { + return false; + } else { + return v1 === v2; + } +} + +module.exports = { getTemporalClient }; diff --git a/server/temporal-client/temporal-client.js b/server/temporal-client/temporal-client.js index 58af92c0..3c0b1577 100644 --- a/server/temporal-client/temporal-client.js +++ b/server/temporal-client/temporal-client.js @@ -2,7 +2,6 @@ const grpc = require('grpc'); const protoLoader = require('@grpc/proto-loader'); const bluebird = require('bluebird'); const utils = require('../utils'); -const { getCredentials } = require('../tls'); const { buildHistory, buildWorkflowExecutionRequest, @@ -10,8 +9,9 @@ const { uiTransform, cliTransform, } = require('./helpers'); +const { getGrpcCredentials } = require('../tls'); -function TemporalClient() { +function TemporalClient(tlsConfig) { const dir = process.cwd(); const protoFileName = 'service.proto'; const options = { @@ -42,7 +42,7 @@ function TemporalClient() { const packageDefinition = protoLoader.loadSync(protoFileName, options); const service = grpc.loadPackageDefinition(packageDefinition); - const { credentials: tlsCreds, options: tlsOpts } = getCredentials(); + const { credentials: tlsCreds, options: tlsOpts } = getGrpcCredentials(tlsConfig); tlsOpts['grpc.max_receive_message_length'] = Number(process.env.TEMPORAL_GRPC_MAX_MESSAGE_LENGTH) || 4 * 1024 * 1024; diff --git a/server/tls/index.js b/server/tls/index.js index b41aae7e..b6579b49 100644 --- a/server/tls/index.js +++ b/server/tls/index.js @@ -1,2 +1,2 @@ -const { getCredentials } = require('./tls'); -module.exports = { getCredentials }; +const { getTlsCredentials, getGrpcCredentials } = require('./tls'); +module.exports = { getTlsCredentials, getGrpcCredentials }; diff --git a/server/tls/tls.js b/server/tls/tls.js index e8f22f6a..d15d8a63 100644 --- a/server/tls/tls.js +++ b/server/tls/tls.js @@ -2,8 +2,8 @@ const grpc = require('grpc'); const { readCredsFromCertFiles } = require('./read-creds-from-cert-files'); const { readCredsFromConfig } = require('./read-creds-from-config'); const { compareCaseInsensitive } = require('../utils'); -const { getTlsConfig } = require('../config'); -const logger = require('../logger') +const { getTlsConfig: getTlsCredsFromConfig } = require('../config'); +const logger = require('../logger'); const keyPath = process.env.TEMPORAL_TLS_KEY_PATH; const certPath = process.env.TEMPORAL_TLS_CERT_PATH; @@ -12,33 +12,46 @@ const serverName = process.env.TEMPORAL_TLS_SERVER_NAME; const verifyHost = [true, 'true', undefined].includes( process.env.TEMPORAL_TLS_ENABLE_HOST_VERIFICATION ); -const tlsConfigFile = getTlsConfig() -function getCredentials() { + +function getGrpcCredentials(tlsCreds) { + if (!tlsCreds || (tlsCreds.pk && !tlsCreds.ca)) { + logger.log('will use insecure connection with Temporal server...'); + return { credentials: grpc.credentials.createInsecure(), options: {} }; + } else if (tlsCreds.pk) { + logger.log('will use mTLS connection with Temporal server...'); + } else if (tlsCreds.ca) { + logger.log('will use server-side TLS connection with Temporal server...'); + } + + return createSecure(tlsCreds); +} + +function getTlsCredentials() { + const tlsConfigFile = getTlsCredsFromConfig(); + + let tls = {}; if (keyPath !== undefined && certPath !== undefined) { - logger.log('establishing secure connection using TLS cert files...'); - const { pk, cert, ca } = readCredsFromCertFiles({ + tls = readCredsFromCertFiles({ keyPath, certPath, caPath, }); - return createSecure(pk, cert, ca, serverName, verifyHost); } else if (caPath !== undefined) { - logger.log('establishing server-side TLS connection using only TLS CA file...'); - const { ca } = readCredsFromCertFiles({ caPath }); - return createSecure(undefined, undefined, ca, serverName, verifyHost); + tls = readCredsFromCertFiles({ caPath }); } else if (tlsConfigFile.key) { - logger.log( - 'establishing secure connection using TLS yml configuration...' - ); - const { pk, cert, ca, serverName, verifyHost } = readCredsFromConfig(); - return createSecure(pk, cert, ca, serverName, verifyHost); - } else { - logger.log('establishing insecure connection...'); - return { credentials: grpc.credentials.createInsecure(), options: {} }; + tls = readCredsFromConfig(); } + + return { + pk: tls.pk, + cert: tls.cert, + ca: tls.ca, + serverName: tls.serverName || serverName, + verifyHost: tls.verifyHost || verifyHost, + }; } -function createSecure(pk, cert, ca, serverName, verifyHost) { +function createSecure({ pk, cert, ca, serverName, verifyHost }) { let checkServerIdentity; if (verifyHost) { checkServerIdentity = (receivedName, cert) => { @@ -62,4 +75,4 @@ function createSecure(pk, cert, ca, serverName, verifyHost) { return { credentials, options }; } -module.exports = { getCredentials }; +module.exports = { getTlsCredentials, getGrpcCredentials };