Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
roblabat committed Sep 6, 2022
1 parent 6011b67 commit 77b79ef
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 6 deletions.
9 changes: 9 additions & 0 deletions deno/lib/__tests__/discriminatedUnions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ test("valid - discriminator value of various primitive types", () => {
z.object({ type: z.literal(null), val: z.literal(7) }),
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
z.object({ type: z.literal(undefined), val: z.literal(9) }),
z
.object({ type: z.literal("transform"), val: z.literal(10) })
.transform((val) => ({
val,
})),
]);

expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
Expand Down Expand Up @@ -57,6 +62,10 @@ test("valid - discriminator value of various primitive types", () => {
type: undefined,
val: 9,
});
// console.log("test");
// expect(schema.parse({ type: "transform", val: 10 })).toEqual({
// val: 10,
// });
});

test("invalid - null", () => {
Expand Down
35 changes: 32 additions & 3 deletions deno/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2137,7 +2137,7 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
/////////////////////////////////////////////////////
/////////////////////////////////////////////////////

export type ZodDiscriminatedUnionOption<
export type ZodDiscriminatedUnionOptionBase<
Discriminator extends string,
DiscriminatorValue extends Primitive
> = ZodObject<
Expand All @@ -2146,6 +2146,15 @@ export type ZodDiscriminatedUnionOption<
any
>;

export type ZodDiscriminatedUnionOption<
Discriminator extends string,
DiscriminatorValue extends Primitive
> =
| ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
| ZodEffects<
ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
>;

export interface ZodDiscriminatedUnionDef<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Expand Down Expand Up @@ -2243,8 +2252,22 @@ export class ZodDiscriminatedUnion<

try {
types.forEach((type) => {
const discriminatorValue = type.shape[discriminator].value;
options.set(discriminatorValue, type);
if (type._def.typeName === ZodFirstPartyTypeKind.ZodObject) {
const discriminatorValue = (
type as ZodDiscriminatedUnionOptionBase<
Discriminator,
DiscriminatorValue
>
).shape[discriminator].value;
options.set(discriminatorValue, type);
} else if (type._def.typeName === ZodFirstPartyTypeKind.ZodEffects) {
const discriminatorValue = (
type as ZodEffects<
ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
>
).sourceType().shape[discriminator].value;
options.set(discriminatorValue, type);
}
});
} catch (e) {
throw new Error(
Expand Down Expand Up @@ -3417,6 +3440,12 @@ export class ZodEffects<
return this._def.schema;
}

sourceType(): T {
return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects
? (this._def.schema as unknown as ZodEffects<T>).sourceType()
: (this._def.schema as T);
}

_parse(input: ParseInput): ParseReturnType<this["_output"]> {
const { status, ctx } = this._processInputParams(input);

Expand Down
9 changes: 9 additions & 0 deletions src/__tests__/discriminatedUnions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ test("valid - discriminator value of various primitive types", () => {
z.object({ type: z.literal(null), val: z.literal(7) }),
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
z.object({ type: z.literal(undefined), val: z.literal(9) }),
z
.object({ type: z.literal("transform"), val: z.literal(10) })
.transform((val) => ({
val,
})),
]);

expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
Expand Down Expand Up @@ -56,6 +61,10 @@ test("valid - discriminator value of various primitive types", () => {
type: undefined,
val: 9,
});
// console.log("test");
// expect(schema.parse({ type: "transform", val: 10 })).toEqual({
// val: 10,
// });
});

test("invalid - null", () => {
Expand Down
35 changes: 32 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2137,7 +2137,7 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
/////////////////////////////////////////////////////
/////////////////////////////////////////////////////

export type ZodDiscriminatedUnionOption<
export type ZodDiscriminatedUnionOptionBase<
Discriminator extends string,
DiscriminatorValue extends Primitive
> = ZodObject<
Expand All @@ -2146,6 +2146,15 @@ export type ZodDiscriminatedUnionOption<
any
>;

export type ZodDiscriminatedUnionOption<
Discriminator extends string,
DiscriminatorValue extends Primitive
> =
| ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
| ZodEffects<
ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
>;

export interface ZodDiscriminatedUnionDef<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Expand Down Expand Up @@ -2243,8 +2252,22 @@ export class ZodDiscriminatedUnion<

try {
types.forEach((type) => {
const discriminatorValue = type.shape[discriminator].value;
options.set(discriminatorValue, type);
if (type._def.typeName === ZodFirstPartyTypeKind.ZodObject) {
const discriminatorValue = (
type as ZodDiscriminatedUnionOptionBase<
Discriminator,
DiscriminatorValue
>
).shape[discriminator].value;
options.set(discriminatorValue, type);
} else if (type._def.typeName === ZodFirstPartyTypeKind.ZodEffects) {
const discriminatorValue = (
type as ZodEffects<
ZodDiscriminatedUnionOptionBase<Discriminator, DiscriminatorValue>
>
).sourceType().shape[discriminator].value;
options.set(discriminatorValue, type);
}
});
} catch (e) {
throw new Error(
Expand Down Expand Up @@ -3417,6 +3440,12 @@ export class ZodEffects<
return this._def.schema;
}

sourceType(): T {
return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects
? (this._def.schema as unknown as ZodEffects<T>).sourceType()
: (this._def.schema as T);
}

_parse(input: ParseInput): ParseReturnType<this["_output"]> {
const { status, ctx } = this._processInputParams(input);

Expand Down

0 comments on commit 77b79ef

Please sign in to comment.