Skip to content

Commit

Permalink
Merge pull request #11254 from Automattic/gh-10663
Browse files Browse the repository at this point in the history
Casting for `$expr` in queries
  • Loading branch information
vkarpov15 authored Jan 27, 2022
2 parents beea8c4 + 94671bc commit fc3c1ce
Show file tree
Hide file tree
Showing 3 changed files with 390 additions and 3 deletions.
5 changes: 2 additions & 3 deletions lib/cast.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
const CastError = require('./error/cast');
const StrictModeError = require('./error/strict');
const Types = require('./schema/index');
const cast$expr = require('./helpers/query/cast$expr');
const castTextSearch = require('./schema/operators/text');
const get = require('./helpers/get');
const getConstructorName = require('./helpers/getConstructorName');
Expand Down Expand Up @@ -87,9 +88,7 @@ module.exports = function cast(schema, obj, options, context) {

continue;
} else if (path === '$expr') {
if (typeof val !== 'object' || val == null) {
throw new Error('`$expr` must be an object');
}
val = cast$expr(val, schema);
continue;
} else if (path === '$elemMatch') {
val = cast(schema, val, options, context);
Expand Down
284 changes: 284 additions & 0 deletions lib/helpers/query/cast$expr.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
'use strict';

const CastError = require('../../error/cast');
const StrictModeError = require('../../error/strict');
const castNumber = require('../../cast/number');

const booleanComparison = new Set(['$and', '$or', '$not']);
const comparisonOperator = new Set(['$cmp', '$eq', '$lt', '$lte', '$gt', '$gte']);
const arithmeticOperatorArray = new Set([
// avoid casting '$add' or '$subtract', because expressions can be either number or date,
// and we don't have a good way of inferring which arguments should be numbers and which should
// be dates.
'$multiply',
'$divide',
'$log',
'$mod',
'$trunc',
'$avg',
'$max',
'$min',
'$stdDevPop',
'$stdDevSamp',
'$sum'
]);
const arithmeticOperatorNumber = new Set([
'$abs',
'$exp',
'$ceil',
'$floor',
'$ln',
'$log10',
'$round',
'$sqrt',
'$sin',
'$cos',
'$tan',
'$asin',
'$acos',
'$atan',
'$atan2',
'$asinh',
'$acosh',
'$atanh',
'$sinh',
'$cosh',
'$tanh',
'$degreesToRadians',
'$radiansToDegrees'
]);
const arrayElementOperators = new Set([
'$arrayElemAt',
'$first',
'$last'
]);
const dateOperators = new Set([
'$year',
'$month',
'$week',
'$dayOfMonth',
'$dayOfYear',
'$hour',
'$minute',
'$second',
'$isoDayOfWeek',
'$isoWeekYear',
'$isoWeek',
'$millisecond'
]);

module.exports = function cast$expr(val, schema, strictQuery) {
if (typeof val !== 'object' || val == null) {
throw new Error('`$expr` must be an object');
}

return _castExpression(val, schema, strictQuery);
};

function _castExpression(val, schema, strictQuery) {
if (isPath(val)) {
// Assume path
return val;
}

if (val.$cond != null) {
if (Array.isArray(val.$cond)) {
val.$cond = val.$cond.map(expr => _castExpression(expr, schema, strictQuery));
} else {
val.$cond.if = _castExpression(val.$cond.if, schema, strictQuery);
val.$cond.then = _castExpression(val.$cond.then, schema, strictQuery);
val.$cond.else = _castExpression(val.$cond.else, schema, strictQuery);
}
} else if (val.$ifNull != null) {
val.$ifNull.map(v => _castExpression(v, schema, strictQuery));
} else if (val.$switch != null) {
val.branches.map(v => _castExpression(v, schema, strictQuery));
val.default = _castExpression(val.default, schema, strictQuery);
}

const keys = Object.keys(val);
for (const key of keys) {
if (booleanComparison.has(key)) {
val[key] = val[key].map(v => _castExpression(v, schema, strictQuery));
} else if (comparisonOperator.has(key)) {
val[key] = castComparison(val[key], schema, strictQuery);
} else if (arithmeticOperatorArray.has(key)) {
val[key] = castArithmetic(val[key], schema, strictQuery);
} else if (arithmeticOperatorNumber.has(key)) {
val[key] = castNumberOperator(val[key], schema, strictQuery);
}
}

if (val.$in) {
val.$in = castIn(val.$in, schema, strictQuery);
}
if (val.$size) {
val.$size = castNumberOperator(val.$size, schema, strictQuery);
}

_omitUndefined(val);

return val;
}

function _omitUndefined(val) {
const keys = Object.keys(val);
for (const key of keys) {
if (val[key] === void 0) {
delete val[key];
}
}
}

// { $op: <number> }
function castNumberOperator(val) {
if (!isLiteral(val)) {
return val;
}

try {
return castNumber(val);
} catch (err) {
throw new CastError('Number', val);
}
}

function castIn(val, schema, strictQuery) {
let search = val[0];
let path = val[1];
if (!isPath(path)) {
return val;
}

path = path.slice(1);
const schematype = schema.path(path);
if (schematype == null) {
if (strictQuery === false) {
return val;
} else if (strictQuery === 'throw') {
throw new StrictModeError('$in');
}

return void 0;
}

if (!schematype.$isMongooseArray) {
throw new Error('Path must be an array for $in');
}

if (schematype.$isMongooseDocumentArray) {
search = schematype.$embeddedSchemaType.cast(search);
} else {
search = schematype.caster.cast(search);
}
return [search, val[1]];
}

// { $op: [<number>, <number>] }
function castArithmetic(val) {
if (!Array.isArray(val)) {
if (!isLiteral(val)) {
return val;
}
try {
return castNumber(val);
} catch (err) {
throw new CastError('Number', val);
}
}

return val.map(v => {
if (!isLiteral(v)) {
return v;
}
try {
return castNumber(v);
} catch (err) {
throw new CastError('Number', v);
}
});
}

// { $op: [expression, expression] }
function castComparison(val, schema, strictQuery) {
if (!Array.isArray(val) || val.length !== 2) {
throw new Error('Comparison operator must be an array of length 2');
}

val[0] = _castExpression(val[0], schema, strictQuery);
const lhs = val[0];

if (isLiteral(val[1])) {
let path = null;
let schematype = null;
let caster = null;
if (isPath(lhs)) {
path = lhs.slice(1);
schematype = schema.path(path);
} else if (typeof lhs === 'object' && lhs != null) {
for (const key of Object.keys(lhs)) {
if (dateOperators.has(key) && isPath(lhs[key])) {
path = lhs[key].slice(1) + '.' + key;
caster = castNumber;
} else if (arrayElementOperators.has(key) && isPath(lhs[key])) {
path = lhs[key].slice(1) + '.' + key;
schematype = schema.path(lhs[key].slice(1));
if (schematype != null) {
if (schematype.$isMongooseDocumentArray) {
schematype = schematype.$embeddedSchemaType;
} else if (schematype.$isMongooseArray) {
schematype = schematype.caster;
}
}
}
}
}

const is$literal = typeof val[1] === 'object' && val[1] != null && val[1].$literal != null;
if (schematype != null) {
if (is$literal) {
val[1] = { $literal: schematype.cast(val[1].$literal) };
} else {
val[1] = schematype.cast(val[1]);
}
} else if (caster != null) {
if (is$literal) {
try {
val[1] = { $literal: caster(val[1].$literal) };
} catch (err) {
throw new CastError(caster.name.replace(/^cast/, ''), val[1], path + '.$literal');
}
} else {
try {
val[1] = caster(val[1]);
} catch (err) {
throw new CastError(caster.name.replace(/^cast/, ''), val[1], path);
}
}
} else if (path != null && strictQuery === true) {
return void 0;
} else if (path != null && strictQuery === 'throw') {
throw new StrictModeError(path);
}
} else {
val[1] = _castExpression(val[1]);
}

return val;
}

function isPath(val) {
return typeof val === 'string' && val.startsWith('$');
}

function isLiteral(val) {
if (typeof val === 'string' && val.startsWith('$')) {
return false;
}
if (typeof val === 'object' && val != null && Object.keys(val).find(key => key.startsWith('$'))) {
// The `$literal` expression can make an object a literal
// https://docs.mongodb.com/manual/reference/operator/aggregation/literal/#mongodb-expression-exp.-literal
return val.$literal != null;
}
return true;
}
Loading

0 comments on commit fc3c1ce

Please sign in to comment.