Skip to content

Commit

Permalink
Merge pull request #7 from flash-oss/preserve-hooks-context
Browse files Browse the repository at this point in the history
Preserve hooks context
  • Loading branch information
koresar authored May 18, 2024
2 parents 169dda0 + d8cb868 commit dd2c479
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 23 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"@grpc/grpc-js": "^1.1.7",
"@grpc/proto-loader": "^0.7.5",
"bullmq": "^3.10.1",
"cls-hooked": "^4.2.2",
"eslint": "^8.55.0",
"express": "^4.18.2",
"lambda-local": "^1.7.3",
Expand Down
65 changes: 45 additions & 20 deletions src/server/Allserver.js
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ module.exports = require("stampit")({
await this.transport.prepareIntrospectionReply(ctx);
},

/**
* This method does not throw.
* The `ctx.procedure` is the function to call.
* @param ctx
* @return {Promise<void>}
* @private
*/
async _callProcedure(ctx) {
if (!isFunction(ctx.procedure)) {
ctx.result = {
Expand Down Expand Up @@ -91,16 +98,29 @@ module.exports = require("stampit")({
}
},

async _callMiddlewares(ctx, middlewareType) {
if (!this[middlewareType]) return;

const middlewares = [].concat(this[middlewareType]).filter(isFunction);
for (const middleware of middlewares) {
try {
const result = await middleware.call(this, ctx);
async _callMiddlewares(ctx, middlewareType, next) {
const runMiddlewares = async (middlewares) => {
if (!middlewares?.length) {
// no middlewares to run
if (next) return await next();
return;
}
const middleware = middlewares[0];
async function handleMiddlewareResult(result) {
if (result !== undefined) {
ctx.result = result;
break;
// Do not call any more middlewares
} else {
await runMiddlewares(middlewares.slice(1));
}
}
try {
if (middleware.length > 1) {
// This middleware accepts more than one argument
await middleware.call(this, ctx, handleMiddlewareResult);
} else {
const result = await middleware.call(this, ctx);
await handleMiddlewareResult(result);
}
} catch (err) {
const code = err.code || "ALLSERVER_MIDDLEWARE_ERROR";
Expand All @@ -111,9 +131,14 @@ module.exports = require("stampit")({
code,
message: `'${err.message}' error in '${middlewareType}' middleware`,
};
return;
// Do not call any more middlewares
if (next) return await next();
}
}
};

const middlewares = [].concat(this[middlewareType]).filter(isFunction);

return await runMiddlewares(middlewares);
},

async handleCall(ctx) {
Expand All @@ -127,18 +152,18 @@ module.exports = require("stampit")({
if (!ctx.arg._) ctx.arg._ = {};
if (!ctx.arg._.procedureName) ctx.arg._.procedureName = ctx.procedureName;

await this._callMiddlewares(ctx, "before");

if (!ctx.result) {
if (ctx.isIntrospection) {
await this._introspect(ctx);
} else {
await this._callProcedure(ctx);
await this._callMiddlewares(ctx, "before", async () => {
if (!ctx.result) {
if (ctx.isIntrospection) {
await this._introspect(ctx);
} else {
await this._callProcedure(ctx);
}
}
}

// Warning! This call might overwrite an existing result.
await this._callMiddlewares(ctx, "after");
// Warning! This call might overwrite an existing result.
await this._callMiddlewares(ctx, "after");
});

return this.transport.reply(ctx);
},
Expand Down
173 changes: 170 additions & 3 deletions test/server/Allserver.test.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const assert = require("assert");
const cls = require("cls-hooked");

const VoidTransport = require("stampit")({
methods: {
Expand Down Expand Up @@ -293,6 +294,7 @@ describe("Allserver", () => {

it("should handle exceptions from 'before'", async () => {
let logged = false;
let lastMiddlewareCalled = false;
const server = Allserver({
logger: {
error(err, code) {
Expand All @@ -301,15 +303,21 @@ describe("Allserver", () => {
logged = true;
},
},
before() {
throw new Error("Handle me please");
},
before: [
() => {
throw new Error("Handle me please");
},
() => {
lastMiddlewareCalled = true;
},
],
});

let ctx = { void: { proc: "testMethod" } };
await server.handleCall(ctx);

assert(logged);
assert.strictEqual(lastMiddlewareCalled, false);
assert.deepStrictEqual(ctx.result, {
success: false,
code: "ALLSERVER_MIDDLEWARE_ERROR",
Expand Down Expand Up @@ -369,6 +377,99 @@ describe("Allserver", () => {

assert(replied);
});

it("should not call procedure if 'before' throws", async () => {
let replied = false;
let procCalled = false;
const MockedTransport = VoidTransport.methods({
async reply() {
replied = true;
},
});
const server = Allserver({
logger: { error() {} },
transport: MockedTransport(),
before() {
throw new Error("Handle me please");
},
procedures: {
testMethod() {
procCalled = true;
assert.fail("should not be called");
},
},
});

let ctx = { void: { proc: "testMethod" } };
await server.handleCall(ctx);

assert(!procCalled);
assert(replied);
});

it("should preserve async_hooks context in 'before'", async () => {
const cls = require("cls-hooked");
const spaceName = "allserver";
const session = cls.getNamespace(spaceName) || cls.createNamespace(spaceName);
function getTraceId() {
if (session?.active) {
return session.get("traceId") || "";
}

return "";
}
function setTraceIdAndRunFunction(traceId, func, ...args) {
return new Promise((resolve, reject) => {
session.run(async () => {
session.set("traceId", traceId);

try {
const result = await func(...args);
resolve(result);
} catch (err) {
reject(err);
}
});
});
}
let called = [];
const server = Allserver({
before: [
(ctx, next) => {
const traceId = ctx.arg._.traceId;
if (traceId) {
setTraceIdAndRunFunction(traceId, next);
} else {
next();
}
},
() => {
assert.strictEqual(getTraceId(), "my-random-trace-id");
called.push(1);
return undefined;
},
() => {
assert.strictEqual(getTraceId(), "my-random-trace-id");
called.push(2);
return { success: false, code: "BAD_AUTH_OR_SOMETHING", message: "Bad auth or something" };
},
() => {
called.push(3);
assert.fail("should not be called");
},
],
});

let ctx = { void: { proc: "testMethod" }, arg: { _: { traceId: "my-random-trace-id" } } };
await server.handleCall(ctx);

assert.deepStrictEqual(called, [1, 2]);
assert.deepStrictEqual(ctx.result, {
success: false,
code: "BAD_AUTH_OR_SOMETHING",
message: "Bad auth or something",
});
});
});

describe("'after'", () => {
Expand Down Expand Up @@ -480,6 +581,72 @@ describe("Allserver", () => {

assert(replied);
});

it("should preserve async_hooks context in 'after'", async () => {
const cls = require("cls-hooked");
const spaceName = "allserver";
const session = cls.getNamespace(spaceName) || cls.createNamespace(spaceName);
function getTraceId() {
if (session?.active) {
return session.get("traceId") || "";
}

return "";
}
function setTraceIdAndRunFunction(traceId, func, ...args) {
return new Promise((resolve, reject) => {
session.run(async () => {
session.set("traceId", traceId);

try {
const result = await func(...args);
resolve(result);
} catch (err) {
reject(err);
}
});
});
}
let called = [];
const server = Allserver({
before: [
(ctx, next) => {
const traceId = ctx.arg._.traceId;
if (traceId) {
setTraceIdAndRunFunction(traceId, next);
} else {
next();
}
},
],
procedures: {
testMethod() {
assert.strictEqual(getTraceId(), "my-random-trace-id");
called.push("testMethod");
},
},
after: [
() => {
assert.strictEqual(getTraceId(), "my-random-trace-id");
called.push(1);
},
() => {
assert.strictEqual(getTraceId(), "my-random-trace-id");
called.push(2);
},
],
});

let ctx = { void: { proc: "testMethod" }, arg: { _: { traceId: "my-random-trace-id" } } };
await server.handleCall(ctx);

assert.deepStrictEqual(called, ["testMethod", 1, 2]);
assert.deepStrictEqual(ctx.result, {
success: true,
code: "SUCCESS",
message: "Success",
});
});
});

describe("'before'+'after'", () => {
Expand Down

0 comments on commit dd2c479

Please sign in to comment.