From 2281caa1e1c9490800ce2aaf847da145381a2f36 Mon Sep 17 00:00:00 2001 From: robin labat Date: Tue, 15 Nov 2022 09:38:01 +0100 Subject: [PATCH] #1171 support for refine, superRefine, transform and lazy in discriminatedUnion (#1290) * #1171 * fix tests * add superRefine in tests * add support for lazy * fix typings * fixe typings for asserted lazy * fix * clean console.log from debug * Clean up discriminatedUnion * Fix deno test Co-authored-by: Colin McDonnell --- .../lib/__tests__/discriminatedUnions.test.ts | 11 +- deno/lib/types.ts | 124 ++++++++++-------- src/__tests__/discriminatedUnions.test.ts | 11 +- src/types.ts | 124 ++++++++++-------- 4 files changed, 154 insertions(+), 116 deletions(-) diff --git a/deno/lib/__tests__/discriminatedUnions.test.ts b/deno/lib/__tests__/discriminatedUnions.test.ts index a5f8f8996..9b9919b39 100644 --- a/deno/lib/__tests__/discriminatedUnions.test.ts +++ b/deno/lib/__tests__/discriminatedUnions.test.ts @@ -25,6 +25,9 @@ test("valid - discriminator value of various primitive types", () => { z.object({ type: z.literal(null), val: z.literal(7) }), z.object({ type: z.literal("undefined"), val: z.literal(8) }), z.object({ type: z.literal(undefined), val: z.literal(9) }), + z.object({ type: z.literal("transform"), val: z.literal(10) }), + z.object({ type: z.literal("refine"), val: z.literal(11) }), + z.object({ type: z.literal("superRefine"), val: z.literal(12) }), ]); expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 }); @@ -126,9 +129,7 @@ test("wrong schema - missing discriminator", () => { ]); throw new Error(); } catch (e: any) { - expect(e.message).toEqual( - "The discriminator value could not be extracted from all the provided schemas" - ); + expect(e.message.includes("could not be extracted")).toBe(true); } }); @@ -140,9 +141,7 @@ test("wrong schema - duplicate discriminator values", () => { ]); throw new Error(); } catch (e: any) { - expect(e.message).toEqual( - "Some of the discriminator values are not unique" - ); + expect(e.message.includes("has duplicate value")).toEqual(true); } }); diff --git a/deno/lib/types.ts b/deno/lib/types.ts index 753197429..04019e58c 100644 --- a/deno/lib/types.ts +++ b/deno/lib/types.ts @@ -2297,33 +2297,46 @@ export class ZodUnion extends ZodType< ///////////////////////////////////////////////////// ///////////////////////////////////////////////////// -export type ZodDiscriminatedUnionOption< - Discriminator extends string, - DiscriminatorValue extends Primitive -> = ZodObject< - { [key in Discriminator]: ZodLiteral } & ZodRawShape, - any, - any ->; +const getDiscriminator = ( + type: T +): Primitive[] | null => { + if (type instanceof ZodLazy) { + return getDiscriminator(type.schema); + } else if (type instanceof ZodEffects) { + return getDiscriminator(type.innerType()); + } else if (type instanceof ZodLiteral) { + return [type.value]; + } else if (type instanceof ZodEnum) { + return type.options; + } else if (type instanceof ZodUndefined) { + return [undefined]; + } else if (type instanceof ZodNull) { + return [null]; + } else { + return null; + } +}; + +export type ZodDiscriminatedUnionOption = + ZodObject<{ [key in Discriminator]: ZodTypeAny } & ZodRawShape, any, any>; export interface ZodDiscriminatedUnionDef< Discriminator extends string, - DiscriminatorValue extends Primitive, - Option extends ZodDiscriminatedUnionOption + Options extends ZodDiscriminatedUnionOption[] > extends ZodTypeDef { discriminator: Discriminator; - options: Map; + options: Options; + optionsMap: Map>; typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion; } export class ZodDiscriminatedUnion< Discriminator extends string, - DiscriminatorValue extends Primitive, - Option extends ZodDiscriminatedUnionOption + Options extends ZodDiscriminatedUnionOption[] > extends ZodType< - Option["_output"], - ZodDiscriminatedUnionDef, - Option["_input"] + output, + ZodDiscriminatedUnionDef, + input > { _parse(input: ParseInput): ParseReturnType { const { ctx } = this._processInputParams(input); @@ -2338,13 +2351,13 @@ export class ZodDiscriminatedUnion< } const discriminator = this.discriminator; - const discriminatorValue: DiscriminatorValue = ctx.data[discriminator]; - const option = this.options.get(discriminatorValue); + const discriminatorValue: string = ctx.data[discriminator]; + const option = this.optionsMap.get(discriminatorValue); if (!option) { addIssueToContext(ctx, { code: ZodIssueCode.invalid_union_discriminator, - options: this.validDiscriminatorValues, + options: Array.from(this.optionsMap.keys()), path: [discriminator], }); return INVALID; @@ -2355,13 +2368,13 @@ export class ZodDiscriminatedUnion< data: ctx.data, path: ctx.path, parent: ctx, - }); + }) as any; } else { return option._parseSync({ data: ctx.data, path: ctx.path, parent: ctx, - }); + }) as any; } } @@ -2369,14 +2382,14 @@ export class ZodDiscriminatedUnion< return this._def.discriminator; } - get validDiscriminatorValues() { - return Array.from(this.options.keys()); - } - get options() { return this._def.options; } + get optionsMap() { + return this._def.optionsMap; + } + /** * The constructor of the discriminated union schema. Its behaviour is very similar to that of the normal z.union() constructor. * However, it only allows a union of objects, all of which need to share a discriminator property. This property must @@ -2387,44 +2400,45 @@ export class ZodDiscriminatedUnion< */ static create< Discriminator extends string, - DiscriminatorValue extends Primitive, Types extends [ - ZodDiscriminatedUnionOption, - ZodDiscriminatedUnionOption, - ...ZodDiscriminatedUnionOption[] + ZodDiscriminatedUnionOption, + ...ZodDiscriminatedUnionOption[] ] >( discriminator: Discriminator, - types: Types, + options: Types, params?: RawCreateParams - ): ZodDiscriminatedUnion { + ): ZodDiscriminatedUnion { // Get all the valid discriminator values - const options: Map = new Map(); - - try { - types.forEach((type) => { - const discriminatorValue = type.shape[discriminator].value; - options.set(discriminatorValue, type); - }); - } catch (e) { - throw new Error( - "The discriminator value could not be extracted from all the provided schemas" - ); - } - - // Assert that all the discriminator values are unique - if (options.size !== types.length) { - throw new Error("Some of the discriminator values are not unique"); + const optionsMap: Map = new Map(); + + // try { + for (const type of options) { + const discriminatorValues = getDiscriminator(type.shape[discriminator]); + if (!discriminatorValues) { + throw new Error( + `A discriminator value for key \`${discriminator}\`could not be extracted from all schema options` + ); + } + for (const value of discriminatorValues) { + if (optionsMap.has(value)) { + throw new Error( + `Discriminator property ${discriminator} has duplicate value ${value}` + ); + } + optionsMap.set(value, type); + } } return new ZodDiscriminatedUnion< Discriminator, - DiscriminatorValue, - Types[number] + // DiscriminatorValue, + Types >({ typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion, discriminator, options, + optionsMap, ...processCreateParams(params), }); } @@ -3570,13 +3584,19 @@ export interface ZodEffectsDef export class ZodEffects< T extends ZodTypeAny, - Output = T["_output"], - Input = T["_input"] + Output = output, + Input = input > extends ZodType, Input> { innerType() { return this._def.schema; } + sourceType(): T { + return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects + ? (this._def.schema as unknown as ZodEffects).sourceType() + : (this._def.schema as T); + } + _parse(input: ParseInput): ParseReturnType { const { status, ctx } = this._processInputParams(input); @@ -4161,7 +4181,7 @@ export type ZodFirstPartySchemaTypes = | ZodArray | ZodObject | ZodUnion - | ZodDiscriminatedUnion + | ZodDiscriminatedUnion | ZodIntersection | ZodTuple | ZodRecord diff --git a/src/__tests__/discriminatedUnions.test.ts b/src/__tests__/discriminatedUnions.test.ts index 3d26a14d9..6f4615aa5 100644 --- a/src/__tests__/discriminatedUnions.test.ts +++ b/src/__tests__/discriminatedUnions.test.ts @@ -24,6 +24,9 @@ test("valid - discriminator value of various primitive types", () => { z.object({ type: z.literal(null), val: z.literal(7) }), z.object({ type: z.literal("undefined"), val: z.literal(8) }), z.object({ type: z.literal(undefined), val: z.literal(9) }), + z.object({ type: z.literal("transform"), val: z.literal(10) }), + z.object({ type: z.literal("refine"), val: z.literal(11) }), + z.object({ type: z.literal("superRefine"), val: z.literal(12) }), ]); expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 }); @@ -125,9 +128,7 @@ test("wrong schema - missing discriminator", () => { ]); throw new Error(); } catch (e: any) { - expect(e.message).toEqual( - "The discriminator value could not be extracted from all the provided schemas" - ); + expect(e.message.includes("could not be extracted")).toBe(true); } }); @@ -139,9 +140,7 @@ test("wrong schema - duplicate discriminator values", () => { ]); throw new Error(); } catch (e: any) { - expect(e.message).toEqual( - "Some of the discriminator values are not unique" - ); + expect(e.message.includes("has duplicate value")).toEqual(true); } }); diff --git a/src/types.ts b/src/types.ts index ea0cf9d1c..b97ae450f 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2297,33 +2297,46 @@ export class ZodUnion extends ZodType< ///////////////////////////////////////////////////// ///////////////////////////////////////////////////// -export type ZodDiscriminatedUnionOption< - Discriminator extends string, - DiscriminatorValue extends Primitive -> = ZodObject< - { [key in Discriminator]: ZodLiteral } & ZodRawShape, - any, - any ->; +const getDiscriminator = ( + type: T +): Primitive[] | null => { + if (type instanceof ZodLazy) { + return getDiscriminator(type.schema); + } else if (type instanceof ZodEffects) { + return getDiscriminator(type.innerType()); + } else if (type instanceof ZodLiteral) { + return [type.value]; + } else if (type instanceof ZodEnum) { + return type.options; + } else if (type instanceof ZodUndefined) { + return [undefined]; + } else if (type instanceof ZodNull) { + return [null]; + } else { + return null; + } +}; + +export type ZodDiscriminatedUnionOption = + ZodObject<{ [key in Discriminator]: ZodTypeAny } & ZodRawShape, any, any>; export interface ZodDiscriminatedUnionDef< Discriminator extends string, - DiscriminatorValue extends Primitive, - Option extends ZodDiscriminatedUnionOption + Options extends ZodDiscriminatedUnionOption[] > extends ZodTypeDef { discriminator: Discriminator; - options: Map; + options: Options; + optionsMap: Map>; typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion; } export class ZodDiscriminatedUnion< Discriminator extends string, - DiscriminatorValue extends Primitive, - Option extends ZodDiscriminatedUnionOption + Options extends ZodDiscriminatedUnionOption[] > extends ZodType< - Option["_output"], - ZodDiscriminatedUnionDef, - Option["_input"] + output, + ZodDiscriminatedUnionDef, + input > { _parse(input: ParseInput): ParseReturnType { const { ctx } = this._processInputParams(input); @@ -2338,13 +2351,13 @@ export class ZodDiscriminatedUnion< } const discriminator = this.discriminator; - const discriminatorValue: DiscriminatorValue = ctx.data[discriminator]; - const option = this.options.get(discriminatorValue); + const discriminatorValue: string = ctx.data[discriminator]; + const option = this.optionsMap.get(discriminatorValue); if (!option) { addIssueToContext(ctx, { code: ZodIssueCode.invalid_union_discriminator, - options: this.validDiscriminatorValues, + options: Array.from(this.optionsMap.keys()), path: [discriminator], }); return INVALID; @@ -2355,13 +2368,13 @@ export class ZodDiscriminatedUnion< data: ctx.data, path: ctx.path, parent: ctx, - }); + }) as any; } else { return option._parseSync({ data: ctx.data, path: ctx.path, parent: ctx, - }); + }) as any; } } @@ -2369,14 +2382,14 @@ export class ZodDiscriminatedUnion< return this._def.discriminator; } - get validDiscriminatorValues() { - return Array.from(this.options.keys()); - } - get options() { return this._def.options; } + get optionsMap() { + return this._def.optionsMap; + } + /** * The constructor of the discriminated union schema. Its behaviour is very similar to that of the normal z.union() constructor. * However, it only allows a union of objects, all of which need to share a discriminator property. This property must @@ -2387,44 +2400,45 @@ export class ZodDiscriminatedUnion< */ static create< Discriminator extends string, - DiscriminatorValue extends Primitive, Types extends [ - ZodDiscriminatedUnionOption, - ZodDiscriminatedUnionOption, - ...ZodDiscriminatedUnionOption[] + ZodDiscriminatedUnionOption, + ...ZodDiscriminatedUnionOption[] ] >( discriminator: Discriminator, - types: Types, + options: Types, params?: RawCreateParams - ): ZodDiscriminatedUnion { + ): ZodDiscriminatedUnion { // Get all the valid discriminator values - const options: Map = new Map(); - - try { - types.forEach((type) => { - const discriminatorValue = type.shape[discriminator].value; - options.set(discriminatorValue, type); - }); - } catch (e) { - throw new Error( - "The discriminator value could not be extracted from all the provided schemas" - ); - } - - // Assert that all the discriminator values are unique - if (options.size !== types.length) { - throw new Error("Some of the discriminator values are not unique"); + const optionsMap: Map = new Map(); + + // try { + for (const type of options) { + const discriminatorValues = getDiscriminator(type.shape[discriminator]); + if (!discriminatorValues) { + throw new Error( + `A discriminator value for key \`${discriminator}\`could not be extracted from all schema options` + ); + } + for (const value of discriminatorValues) { + if (optionsMap.has(value)) { + throw new Error( + `Discriminator property ${discriminator} has duplicate value ${value}` + ); + } + optionsMap.set(value, type); + } } return new ZodDiscriminatedUnion< Discriminator, - DiscriminatorValue, - Types[number] + // DiscriminatorValue, + Types >({ typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion, discriminator, options, + optionsMap, ...processCreateParams(params), }); } @@ -3570,13 +3584,19 @@ export interface ZodEffectsDef export class ZodEffects< T extends ZodTypeAny, - Output = T["_output"], - Input = T["_input"] + Output = output, + Input = input > extends ZodType, Input> { innerType() { return this._def.schema; } + sourceType(): T { + return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects + ? (this._def.schema as unknown as ZodEffects).sourceType() + : (this._def.schema as T); + } + _parse(input: ParseInput): ParseReturnType { const { status, ctx } = this._processInputParams(input); @@ -4161,7 +4181,7 @@ export type ZodFirstPartySchemaTypes = | ZodArray | ZodObject | ZodUnion - | ZodDiscriminatedUnion + | ZodDiscriminatedUnion | ZodIntersection | ZodTuple | ZodRecord