Skip to content

Commit

Permalink
#1171 support for refine, superRefine, transform and lazy in discrimi…
Browse files Browse the repository at this point in the history
…natedUnion (#1290)

* #1171

* fix tests

* add superRefine in tests

* add support for lazy

* fix typings

* fixe typings for asserted lazy

* fix

* clean console.log from debug

* Clean up discriminatedUnion

* Fix deno test

Co-authored-by: Colin McDonnell <[email protected]>
  • Loading branch information
roblabat and Colin McDonnell authored Nov 15, 2022
1 parent 22ac512 commit 2281caa
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 116 deletions.
11 changes: 5 additions & 6 deletions deno/lib/__tests__/discriminatedUnions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ test("valid - discriminator value of various primitive types", () => {
z.object({ type: z.literal(null), val: z.literal(7) }),
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
z.object({ type: z.literal(undefined), val: z.literal(9) }),
z.object({ type: z.literal("transform"), val: z.literal(10) }),
z.object({ type: z.literal("refine"), val: z.literal(11) }),
z.object({ type: z.literal("superRefine"), val: z.literal(12) }),
]);

expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
Expand Down Expand Up @@ -126,9 +129,7 @@ test("wrong schema - missing discriminator", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"The discriminator value could not be extracted from all the provided schemas"
);
expect(e.message.includes("could not be extracted")).toBe(true);
}
});

Expand All @@ -140,9 +141,7 @@ test("wrong schema - duplicate discriminator values", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"Some of the discriminator values are not unique"
);
expect(e.message.includes("has duplicate value")).toEqual(true);
}
});

Expand Down
124 changes: 72 additions & 52 deletions deno/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2297,33 +2297,46 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
/////////////////////////////////////////////////////
/////////////////////////////////////////////////////

export type ZodDiscriminatedUnionOption<
Discriminator extends string,
DiscriminatorValue extends Primitive
> = ZodObject<
{ [key in Discriminator]: ZodLiteral<DiscriminatorValue> } & ZodRawShape,
any,
any
>;
const getDiscriminator = <T extends ZodTypeAny>(
type: T
): Primitive[] | null => {
if (type instanceof ZodLazy) {
return getDiscriminator(type.schema);
} else if (type instanceof ZodEffects) {
return getDiscriminator(type.innerType());
} else if (type instanceof ZodLiteral) {
return [type.value];
} else if (type instanceof ZodEnum) {
return type.options;
} else if (type instanceof ZodUndefined) {
return [undefined];
} else if (type instanceof ZodNull) {
return [null];
} else {
return null;
}
};

export type ZodDiscriminatedUnionOption<Discriminator extends string> =
ZodObject<{ [key in Discriminator]: ZodTypeAny } & ZodRawShape, any, any>;

export interface ZodDiscriminatedUnionDef<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
Options extends ZodDiscriminatedUnionOption<any>[]
> extends ZodTypeDef {
discriminator: Discriminator;
options: Map<DiscriminatorValue, Option>;
options: Options;
optionsMap: Map<Primitive, ZodDiscriminatedUnionOption<any>>;
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion;
}

export class ZodDiscriminatedUnion<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
Options extends ZodDiscriminatedUnionOption<Discriminator>[]
> extends ZodType<
Option["_output"],
ZodDiscriminatedUnionDef<Discriminator, DiscriminatorValue, Option>,
Option["_input"]
output<Options[number]>,
ZodDiscriminatedUnionDef<Discriminator, Options>,
input<Options[number]>
> {
_parse(input: ParseInput): ParseReturnType<this["_output"]> {
const { ctx } = this._processInputParams(input);
Expand All @@ -2338,13 +2351,13 @@ export class ZodDiscriminatedUnion<
}

const discriminator = this.discriminator;
const discriminatorValue: DiscriminatorValue = ctx.data[discriminator];
const option = this.options.get(discriminatorValue);
const discriminatorValue: string = ctx.data[discriminator];
const option = this.optionsMap.get(discriminatorValue);

if (!option) {
addIssueToContext(ctx, {
code: ZodIssueCode.invalid_union_discriminator,
options: this.validDiscriminatorValues,
options: Array.from(this.optionsMap.keys()),
path: [discriminator],
});
return INVALID;
Expand All @@ -2355,28 +2368,28 @@ export class ZodDiscriminatedUnion<
data: ctx.data,
path: ctx.path,
parent: ctx,
});
}) as any;
} else {
return option._parseSync({
data: ctx.data,
path: ctx.path,
parent: ctx,
});
}) as any;
}
}

get discriminator() {
return this._def.discriminator;
}

get validDiscriminatorValues() {
return Array.from(this.options.keys());
}

get options() {
return this._def.options;
}

get optionsMap() {
return this._def.optionsMap;
}

/**
* The constructor of the discriminated union schema. Its behaviour is very similar to that of the normal z.union() constructor.
* However, it only allows a union of objects, all of which need to share a discriminator property. This property must
Expand All @@ -2387,44 +2400,45 @@ export class ZodDiscriminatedUnion<
*/
static create<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Types extends [
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
...ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>[]
ZodDiscriminatedUnionOption<Discriminator>,
...ZodDiscriminatedUnionOption<Discriminator>[]
]
>(
discriminator: Discriminator,
types: Types,
options: Types,
params?: RawCreateParams
): ZodDiscriminatedUnion<Discriminator, DiscriminatorValue, Types[number]> {
): ZodDiscriminatedUnion<Discriminator, Types> {
// Get all the valid discriminator values
const options: Map<DiscriminatorValue, Types[number]> = new Map();

try {
types.forEach((type) => {
const discriminatorValue = type.shape[discriminator].value;
options.set(discriminatorValue, type);
});
} catch (e) {
throw new Error(
"The discriminator value could not be extracted from all the provided schemas"
);
}

// Assert that all the discriminator values are unique
if (options.size !== types.length) {
throw new Error("Some of the discriminator values are not unique");
const optionsMap: Map<Primitive, Types[number]> = new Map();

// try {
for (const type of options) {
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
if (!discriminatorValues) {
throw new Error(
`A discriminator value for key \`${discriminator}\`could not be extracted from all schema options`
);
}
for (const value of discriminatorValues) {
if (optionsMap.has(value)) {
throw new Error(
`Discriminator property ${discriminator} has duplicate value ${value}`
);
}
optionsMap.set(value, type);
}
}

return new ZodDiscriminatedUnion<
Discriminator,
DiscriminatorValue,
Types[number]
// DiscriminatorValue,
Types
>({
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion,
discriminator,
options,
optionsMap,
...processCreateParams(params),
});
}
Expand Down Expand Up @@ -3570,13 +3584,19 @@ export interface ZodEffectsDef<T extends ZodTypeAny = ZodTypeAny>

export class ZodEffects<
T extends ZodTypeAny,
Output = T["_output"],
Input = T["_input"]
Output = output<T>,
Input = input<T>
> extends ZodType<Output, ZodEffectsDef<T>, Input> {
innerType() {
return this._def.schema;
}

sourceType(): T {
return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects
? (this._def.schema as unknown as ZodEffects<T>).sourceType()
: (this._def.schema as T);
}

_parse(input: ParseInput): ParseReturnType<this["_output"]> {
const { status, ctx } = this._processInputParams(input);

Expand Down Expand Up @@ -4161,7 +4181,7 @@ export type ZodFirstPartySchemaTypes =
| ZodArray<any, any>
| ZodObject<any, any, any, any, any>
| ZodUnion<any>
| ZodDiscriminatedUnion<any, any, any>
| ZodDiscriminatedUnion<any, any>
| ZodIntersection<any, any>
| ZodTuple<any, any>
| ZodRecord<any, any>
Expand Down
11 changes: 5 additions & 6 deletions src/__tests__/discriminatedUnions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ test("valid - discriminator value of various primitive types", () => {
z.object({ type: z.literal(null), val: z.literal(7) }),
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
z.object({ type: z.literal(undefined), val: z.literal(9) }),
z.object({ type: z.literal("transform"), val: z.literal(10) }),
z.object({ type: z.literal("refine"), val: z.literal(11) }),
z.object({ type: z.literal("superRefine"), val: z.literal(12) }),
]);

expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
Expand Down Expand Up @@ -125,9 +128,7 @@ test("wrong schema - missing discriminator", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"The discriminator value could not be extracted from all the provided schemas"
);
expect(e.message.includes("could not be extracted")).toBe(true);
}
});

Expand All @@ -139,9 +140,7 @@ test("wrong schema - duplicate discriminator values", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"Some of the discriminator values are not unique"
);
expect(e.message.includes("has duplicate value")).toEqual(true);
}
});

Expand Down
Loading

0 comments on commit 2281caa

Please sign in to comment.