diff --git a/lib/cast.js b/lib/cast.js index cb99ed6a490..22f26292f44 100644 --- a/lib/cast.js +++ b/lib/cast.js @@ -7,6 +7,7 @@ const CastError = require('./error/cast'); const StrictModeError = require('./error/strict'); const Types = require('./schema/index'); +const cast$expr = require('./helpers/query/cast$expr'); const castTextSearch = require('./schema/operators/text'); const get = require('./helpers/get'); const getConstructorName = require('./helpers/getConstructorName'); @@ -87,9 +88,7 @@ module.exports = function cast(schema, obj, options, context) { continue; } else if (path === '$expr') { - if (typeof val !== 'object' || val == null) { - throw new Error('`$expr` must be an object'); - } + val = cast$expr(val, schema); continue; } else if (path === '$elemMatch') { val = cast(schema, val, options, context); diff --git a/lib/helpers/query/cast$expr.js b/lib/helpers/query/cast$expr.js new file mode 100644 index 00000000000..39316cd644a --- /dev/null +++ b/lib/helpers/query/cast$expr.js @@ -0,0 +1,284 @@ +'use strict'; + +const CastError = require('../../error/cast'); +const StrictModeError = require('../../error/strict'); +const castNumber = require('../../cast/number'); + +const booleanComparison = new Set(['$and', '$or', '$not']); +const comparisonOperator = new Set(['$cmp', '$eq', '$lt', '$lte', '$gt', '$gte']); +const arithmeticOperatorArray = new Set([ + // avoid casting '$add' or '$subtract', because expressions can be either number or date, + // and we don't have a good way of inferring which arguments should be numbers and which should + // be dates. + '$multiply', + '$divide', + '$log', + '$mod', + '$trunc', + '$avg', + '$max', + '$min', + '$stdDevPop', + '$stdDevSamp', + '$sum' +]); +const arithmeticOperatorNumber = new Set([ + '$abs', + '$exp', + '$ceil', + '$floor', + '$ln', + '$log10', + '$round', + '$sqrt', + '$sin', + '$cos', + '$tan', + '$asin', + '$acos', + '$atan', + '$atan2', + '$asinh', + '$acosh', + '$atanh', + '$sinh', + '$cosh', + '$tanh', + '$degreesToRadians', + '$radiansToDegrees' +]); +const arrayElementOperators = new Set([ + '$arrayElemAt', + '$first', + '$last' +]); +const dateOperators = new Set([ + '$year', + '$month', + '$week', + '$dayOfMonth', + '$dayOfYear', + '$hour', + '$minute', + '$second', + '$isoDayOfWeek', + '$isoWeekYear', + '$isoWeek', + '$millisecond' +]); + +module.exports = function cast$expr(val, schema, strictQuery) { + if (typeof val !== 'object' || val == null) { + throw new Error('`$expr` must be an object'); + } + + return _castExpression(val, schema, strictQuery); +}; + +function _castExpression(val, schema, strictQuery) { + if (isPath(val)) { + // Assume path + return val; + } + + if (val.$cond != null) { + if (Array.isArray(val.$cond)) { + val.$cond = val.$cond.map(expr => _castExpression(expr, schema, strictQuery)); + } else { + val.$cond.if = _castExpression(val.$cond.if, schema, strictQuery); + val.$cond.then = _castExpression(val.$cond.then, schema, strictQuery); + val.$cond.else = _castExpression(val.$cond.else, schema, strictQuery); + } + } else if (val.$ifNull != null) { + val.$ifNull.map(v => _castExpression(v, schema, strictQuery)); + } else if (val.$switch != null) { + val.branches.map(v => _castExpression(v, schema, strictQuery)); + val.default = _castExpression(val.default, schema, strictQuery); + } + + const keys = Object.keys(val); + for (const key of keys) { + if (booleanComparison.has(key)) { + val[key] = val[key].map(v => _castExpression(v, schema, strictQuery)); + } else if (comparisonOperator.has(key)) { + val[key] = castComparison(val[key], schema, strictQuery); + } else if (arithmeticOperatorArray.has(key)) { + val[key] = castArithmetic(val[key], schema, strictQuery); + } else if (arithmeticOperatorNumber.has(key)) { + val[key] = castNumberOperator(val[key], schema, strictQuery); + } + } + + if (val.$in) { + val.$in = castIn(val.$in, schema, strictQuery); + } + if (val.$size) { + val.$size = castNumberOperator(val.$size, schema, strictQuery); + } + + _omitUndefined(val); + + return val; +} + +function _omitUndefined(val) { + const keys = Object.keys(val); + for (const key of keys) { + if (val[key] === void 0) { + delete val[key]; + } + } +} + +// { $op: } +function castNumberOperator(val) { + if (!isLiteral(val)) { + return val; + } + + try { + return castNumber(val); + } catch (err) { + throw new CastError('Number', val); + } +} + +function castIn(val, schema, strictQuery) { + let search = val[0]; + let path = val[1]; + if (!isPath(path)) { + return val; + } + + path = path.slice(1); + const schematype = schema.path(path); + if (schematype == null) { + if (strictQuery === false) { + return val; + } else if (strictQuery === 'throw') { + throw new StrictModeError('$in'); + } + + return void 0; + } + + if (!schematype.$isMongooseArray) { + throw new Error('Path must be an array for $in'); + } + + if (schematype.$isMongooseDocumentArray) { + search = schematype.$embeddedSchemaType.cast(search); + } else { + search = schematype.caster.cast(search); + } + return [search, val[1]]; +} + +// { $op: [, ] } +function castArithmetic(val) { + if (!Array.isArray(val)) { + if (!isLiteral(val)) { + return val; + } + try { + return castNumber(val); + } catch (err) { + throw new CastError('Number', val); + } + } + + return val.map(v => { + if (!isLiteral(v)) { + return v; + } + try { + return castNumber(v); + } catch (err) { + throw new CastError('Number', v); + } + }); +} + +// { $op: [expression, expression] } +function castComparison(val, schema, strictQuery) { + if (!Array.isArray(val) || val.length !== 2) { + throw new Error('Comparison operator must be an array of length 2'); + } + + val[0] = _castExpression(val[0], schema, strictQuery); + const lhs = val[0]; + + if (isLiteral(val[1])) { + let path = null; + let schematype = null; + let caster = null; + if (isPath(lhs)) { + path = lhs.slice(1); + schematype = schema.path(path); + } else if (typeof lhs === 'object' && lhs != null) { + for (const key of Object.keys(lhs)) { + if (dateOperators.has(key) && isPath(lhs[key])) { + path = lhs[key].slice(1) + '.' + key; + caster = castNumber; + } else if (arrayElementOperators.has(key) && isPath(lhs[key])) { + path = lhs[key].slice(1) + '.' + key; + schematype = schema.path(lhs[key].slice(1)); + if (schematype != null) { + if (schematype.$isMongooseDocumentArray) { + schematype = schematype.$embeddedSchemaType; + } else if (schematype.$isMongooseArray) { + schematype = schematype.caster; + } + } + } + } + } + + const is$literal = typeof val[1] === 'object' && val[1] != null && val[1].$literal != null; + if (schematype != null) { + if (is$literal) { + val[1] = { $literal: schematype.cast(val[1].$literal) }; + } else { + val[1] = schematype.cast(val[1]); + } + } else if (caster != null) { + if (is$literal) { + try { + val[1] = { $literal: caster(val[1].$literal) }; + } catch (err) { + throw new CastError(caster.name.replace(/^cast/, ''), val[1], path + '.$literal'); + } + } else { + try { + val[1] = caster(val[1]); + } catch (err) { + throw new CastError(caster.name.replace(/^cast/, ''), val[1], path); + } + } + } else if (path != null && strictQuery === true) { + return void 0; + } else if (path != null && strictQuery === 'throw') { + throw new StrictModeError(path); + } + } else { + val[1] = _castExpression(val[1]); + } + + return val; +} + +function isPath(val) { + return typeof val === 'string' && val.startsWith('$'); +} + +function isLiteral(val) { + if (typeof val === 'string' && val.startsWith('$')) { + return false; + } + if (typeof val === 'object' && val != null && Object.keys(val).find(key => key.startsWith('$'))) { + // The `$literal` expression can make an object a literal + // https://docs.mongodb.com/manual/reference/operator/aggregation/literal/#mongodb-expression-exp.-literal + return val.$literal != null; + } + return true; +} \ No newline at end of file diff --git a/test/helpers/query.cast$expr.test.js b/test/helpers/query.cast$expr.test.js new file mode 100644 index 00000000000..e7c64406574 --- /dev/null +++ b/test/helpers/query.cast$expr.test.js @@ -0,0 +1,104 @@ +'use strict'; + +const { Schema } = require('../common').mongoose; +const assert = require('assert'); +const cast$expr = require('../../lib/helpers/query/cast$expr'); + +describe('castexpr', function() { + it('casts comparisons', function() { + const testSchema = new Schema({ date: Date, spent: Number, budget: Number, nums: [Number] }); + + let res = cast$expr({ $eq: ['$date', '2021-06-01'] }, testSchema); + assert.deepEqual(res, { $eq: ['$date', new Date('2021-06-01')] }); + + res = cast$expr({ $eq: [{ $year: '$date' }, 2021] }, testSchema); + assert.deepStrictEqual(res, { $eq: [{ $year: '$date' }, 2021] }); + + res = cast$expr({ $eq: [{ $year: '$date' }, '2021'] }, testSchema); + assert.deepStrictEqual(res, { $eq: [{ $year: '$date' }, 2021] }); + + res = cast$expr({ $eq: [{ $year: '$date' }, { $literal: '2021' }] }, testSchema); + assert.deepStrictEqual(res, { $eq: [{ $year: '$date' }, { $literal: 2021 }] }); + + res = cast$expr({ $eq: [{ $year: '$date' }, { $literal: '2021' }] }, testSchema); + assert.deepStrictEqual(res, { $eq: [{ $year: '$date' }, { $literal: 2021 }] }); + + res = cast$expr({ $gt: ['$spent', '$budget'] }, testSchema); + assert.deepStrictEqual(res, { $gt: ['$spent', '$budget'] }); + + res = cast$expr({ $gt: [{ $last: '$nums' }, '42'] }, testSchema); + assert.deepStrictEqual(res, { $gt: [{ $last: '$nums' }, 42] }); + }); + + it('casts conditions', function() { + const testSchema = new Schema({ price: Number, qty: Number }); + + let discountedPrice = { + $cond: { + if: { $gte: ['$qty', { $floor: '100' }] }, + then: { $multiply: ['$price', '0.5'] }, + else: { $multiply: ['$price', '0.75'] } + } + }; + let res = cast$expr({ $lt: [discountedPrice, 5] }, testSchema); + assert.deepStrictEqual(res, { + $lt: [ + { + $cond: { + if: { $gte: ['$qty', { $floor: 100 }] }, + then: { $multiply: ['$price', 0.5] }, + else: { $multiply: ['$price', 0.75] } + } + }, + 5 + ] + }); + + discountedPrice = { + $cond: { + if: { $and: [{ $gte: ['$qty', { $floor: '100' }] }] }, + then: { $multiply: ['$price', '0.5'] }, + else: { $multiply: ['$price', '0.75'] } + } + }; + res = cast$expr({ $lt: [discountedPrice, 5] }, testSchema); + assert.deepStrictEqual(res, { + $lt: [ + { + $cond: { + if: { $and: [{ $gte: ['$qty', { $floor: 100 }] }] }, + then: { $multiply: ['$price', 0.5] }, + else: { $multiply: ['$price', 0.75] } + } + }, + 5 + ] + }); + }); + + it('casts boolean expressions', function() { + const testSchema = new Schema({ date: Date, spent: Number, budget: Number }); + + const res = cast$expr({ $and: [{ $eq: [{ $year: '$date' }, '2021'] }] }, testSchema); + assert.deepStrictEqual(res, { $and: [{ $eq: [{ $year: '$date' }, 2021] }] }); + }); + + it('cast errors', function() { + const testSchema = new Schema({ date: Date, spent: Number, budget: Number }); + + assert.throws(() => { + cast$expr({ $eq: [{ $year: '$date' }, 'not a number'] }, testSchema); + }, /Cast to Number failed/); + }); + + it('casts $in', function() { + const testSchema = new Schema({ nums: [Number], docs: [new Schema({ prop: Number }, { _id: false })] }); + + let res = cast$expr({ $in: ['42', '$nums'] }, testSchema); + assert.deepStrictEqual(res, { $in: [42, '$nums'] }); + + res = cast$expr({ $in: [{ prop: '42' }, '$docs'] }, testSchema); + res.$in[0] = res.$in[0].toBSON(); // So `deepStrictEqual()` doesn't complain about subdoc internals + assert.deepStrictEqual(res, { $in: [{ prop: 42 }, '$docs'] }); + }); +}); \ No newline at end of file