diff --git a/src/Condition.js b/src/Condition.js index 9d03437cf..0d9e38f46 100644 --- a/src/Condition.js +++ b/src/Condition.js @@ -7,6 +7,10 @@ function callOrConcat(schema) { return base => base.concat(schema); } +function makeIsFn(refs, predicate) { + return refs.length < 2 ? predicate : (...values) => values.every(predicate); +} + class Conditional { constructor(refs, options) { let { is, then, otherwise } = options; @@ -26,10 +30,15 @@ class Conditional { 'either `then:` or `otherwise:` is required for `when()` conditions', ); - let isFn = - typeof is === 'function' - ? is - : (...values) => values.every(value => value === is); + let isFn; + + if (typeof is === 'function') { + isFn = is; + } else if (isSchema(is)) { + isFn = makeIsFn(this.refs, value => is.isValidSync(value)); + } else { + isFn = makeIsFn(this.refs, value => value === is); + } this.fn = function(...values) { let currentSchema = values.pop(); diff --git a/test/mixed.js b/test/mixed.js index c01285727..7fa9eaaef 100644 --- a/test/mixed.js +++ b/test/mixed.js @@ -750,6 +750,22 @@ describe('Mixed Types ', () => { await inst.validate(-1).should.be.fulfilled(); }); + it('should handle conditionals with schema as condition', async function() { + let inst = object({ + flag: mixed(), + prop: number().when('flag', { + is: bool(), + then: number().min(5), + }), + }); + + await inst.validate({ flag: 'hello', prop: 4 }).should.be.fulfilled(); + + await inst + .validate({ flag: true, prop: 4 }) + .should.be.rejectedWith(ValidationError, /must be greater than/); + }); + it('should use label in error message', async function() { let label = 'Label'; let inst = object({