Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 committed Apr 21, 2024
1 parent 5f7ec29 commit 6edfc89
Show file tree
Hide file tree
Showing 11 changed files with 631 additions and 61 deletions.
5 changes: 4 additions & 1 deletion packages/plugins/swr/tests/test-model-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ export const modelMeta: ModelMeta = {
ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true },
},
},
uniqueConstraints: {},
uniqueConstraints: {
user: { id: { name: 'id', fields: ['id'] } },
post: { id: { name: 'id', fields: ['id'] } },
},
deleteCascade: {
user: ['Post'],
},
Expand Down
5 changes: 4 additions & 1 deletion packages/plugins/tanstack-query/tests/test-model-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ export const modelMeta: ModelMeta = {
ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true },
},
},
uniqueConstraints: {},
uniqueConstraints: {
user: { id: { name: 'id', fields: ['id'] } },
post: { id: { name: 'id', fields: ['id'] } },
},
deleteCascade: {
user: ['Post'],
},
Expand Down
20 changes: 9 additions & 11 deletions packages/runtime/src/cross/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { lowerCaseFirst } from 'lower-case-first';
import { ModelMeta } from '.';
import { ModelMeta, requireField } from '.';

/**
* Gets field names in a data model entity, filtering out internal fields.
Expand Down Expand Up @@ -47,17 +47,15 @@ export function zip<T1, T2>(x: Enumerable<T1>, y: Enumerable<T2>): Array<[T1, T2
}

export function getIdFields(modelMeta: ModelMeta, model: string, throwIfNotFound = false) {
let fields = modelMeta.fields[lowerCaseFirst(model)];
if (!fields) {
const uniqueConstraints = modelMeta.uniqueConstraints[lowerCaseFirst(model)] ?? {};

const entries = Object.values(uniqueConstraints);
if (entries.length === 0) {
if (throwIfNotFound) {
throw new Error(`Unable to load fields for ${model}`);
} else {
fields = {};
throw new Error(`Model ${model} does not have any id field`);
}
return [];
}
const result = Object.values(fields).filter((f) => f.isId);
if (result.length === 0 && throwIfNotFound) {
throw new Error(`model ${model} does not have an id field`);
}
return result;

return entries[0].fields.map((f) => requireField(modelMeta, model, f));
}
99 changes: 81 additions & 18 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
});

// return only the ids of the top-level entity
const ids = this.utils.getEntityIds(this.model, result);
const ids = this.utils.getEntityIds(model, result);
return { result: ids, postWriteChecks: [...postCreateChecks.values()] };
}

Expand Down Expand Up @@ -792,8 +792,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

// proceed with the create and collect post-create checks
const { postWriteChecks: checks } = await this.doCreate(model, { data: createData }, db);
const { postWriteChecks: checks, result } = await this.doCreate(model, { data: createData }, db);
postWriteChecks.push(...checks);

return result;
};

const _createMany = async (
Expand Down Expand Up @@ -881,18 +883,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
// check pre-update guard
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);

// handles the case where id fields are updated
const postUpdateIds = this.utils.clone(existing);
for (const key of Object.keys(existing)) {
const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key];
if (
typeof updateValue === 'string' ||
typeof updateValue === 'number' ||
typeof updateValue === 'bigint'
) {
postUpdateIds[key] = updateValue;
}
}
// handle the case where id fields are updated
const _args: any = args;
const updatePayload = _args.data && typeof _args.data === 'object' ? _args.data : _args;
const postUpdateIds = this.calculatePostUpdateIds(model, existing, updatePayload);

// register post-update check
await _registerPostUpdateCheck(model, existing, postUpdateIds);
Expand Down Expand Up @@ -984,10 +978,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
// update case

// check pre-update guard
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);
await this.utils.checkPolicyForUnique(model, existing, 'update', db, args);

// handle the case where id fields are updated
const postUpdateIds = this.calculatePostUpdateIds(model, existing, args.update);

// register post-update check
await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter);
await _registerPostUpdateCheck(model, existing, postUpdateIds);

// convert upsert to update
const convertedUpdate = {
Expand Down Expand Up @@ -1021,9 +1018,22 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
if (existing) {
// connect
await _connectDisconnect(model, args.where, context);
return true;
} else {
// create
await _create(model, args.create, context);
const created = await _create(model, args.create, context);

const upperContext = context.nestingPath[context.nestingPath.length - 2];
if (upperContext?.where && context.field) {
// check if the where clause of the upper context references the id
// of the connected entity, if so, we need to update it
this.overrideForeignKeyFields(upperContext.model, upperContext.where, context.field, created);
}

// remove the payload from the parent
this.removeFromParent(context.parent, 'connectOrCreate', args);

return false;
}
},

Expand Down Expand Up @@ -1093,6 +1103,52 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
return { result, postWriteChecks };
}

// calculate id fields used for post-update check given an update payload
private calculatePostUpdateIds(_model: string, currentIds: any, updatePayload: any) {
const result = this.utils.clone(currentIds);
for (const key of Object.keys(currentIds)) {
const updateValue = updatePayload[key];
if (typeof updateValue === 'string' || typeof updateValue === 'number' || typeof updateValue === 'bigint') {
result[key] = updateValue;
}
}
return result;
}

// updates foreign key fields inside `payload` based on relation id fields in `newIds`
private overrideForeignKeyFields(
model: string,
payload: any,
relation: FieldInfo,
newIds: Record<string, unknown>
) {
if (!relation.foreignKeyMapping || Object.keys(relation.foreignKeyMapping).length === 0) {
return;
}

// override foreign key values
for (const [id, fk] of Object.entries(relation.foreignKeyMapping)) {
if (payload[fk] !== undefined && newIds[id] !== undefined) {
payload[fk] = newIds[id];
}
}

// deal with compound id fields
const uniqueConstraints = this.utils.getUniqueConstraints(model);
for (const [name, constraint] of Object.entries(uniqueConstraints)) {
if (constraint.fields.length > 1) {
const target = payload[name];
if (target) {
for (const [id, fk] of Object.entries(relation.foreignKeyMapping)) {
if (target[fk] !== undefined && newIds[id] !== undefined) {
target[fk] = newIds[id];
}
}
}
}
}
}

// Validates the given update payload against Zod schema if any
private validateUpdateInputSchema(model: string, data: any) {
const schema = this.utils.getZodSchema(model, 'update');
Expand Down Expand Up @@ -1228,7 +1284,14 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

if (existing) {
// update case
const { result, postWriteChecks } = await this.doUpdate({ where: existing, data: update, ...rest }, tx);
const { result, postWriteChecks } = await this.doUpdate(
{
where: this.utils.composeCompoundUniqueField(this.model, existing),
data: update,
...rest,
},
tx
);
await this.runPostWriteChecks(postWriteChecks, tx);
return this.utils.readBack(tx, this.model, 'update', args, result);
} else {
Expand Down
30 changes: 30 additions & 0 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,27 @@ export class PolicyUtil {
}
}

composeCompoundUniqueField(model: string, fieldData: any) {
const uniqueConstraints = this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)];
if (!uniqueConstraints) {
return fieldData;
}

// e.g.: { a: '1', b: '1' } => { a_b: { a: '1', b: '1' } }
const result: any = this.clone(fieldData);
for (const [name, constraint] of Object.entries(uniqueConstraints)) {
if (constraint.fields.length > 1 && constraint.fields.every((f) => fieldData[f] !== undefined)) {
// multi-field unique constraint, compose it
result[name] = constraint.fields.reduce<any>(
(prev, field) => ({ ...prev, [field]: fieldData[field] }),
{}
);
constraint.fields.forEach((f) => delete result[f]);
}
}
return result;
}

/**
* Gets unique constraints for the given model.
*/
Expand Down Expand Up @@ -642,6 +663,15 @@ export class PolicyUtil {
// preserve the original structure
currQuery[currField.backLink] = { ...visitWhere };
}

if (forMutationPayload && currQuery[currField.backLink]) {
// reconstruct compound unique field
currQuery[currField.backLink] = this.composeCompoundUniqueField(
backLinkField.type,
currQuery[currField.backLink]
);
}

currQuery = currQuery[currField.backLink];
}
currField = field;
Expand Down
45 changes: 37 additions & 8 deletions packages/sdk/src/model-meta-generator.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
ArrayExpr,
DataModel,
DataModelAttribute,
DataModelField,
isArrayExpr,
isBooleanLiteral,
Expand Down Expand Up @@ -239,10 +240,7 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] {
function getUniqueConstraints(model: DataModel) {
const constraints: Array<{ name: string; fields: string[] }> = [];

// model-level constraints
for (const attr of model.attributes.filter(
(attr) => attr.decl.ref?.name === '@@unique' || attr.decl.ref?.name === '@@id'
)) {
const extractConstraint = (attr: DataModelAttribute) => {
const argsMap = getAttributeArgs(attr);
if (argsMap.fields) {
const fieldNames = (argsMap.fields as ArrayExpr).items.map(
Expand All @@ -253,14 +251,45 @@ function getUniqueConstraints(model: DataModel) {
// default constraint name is fields concatenated with underscores
constraintName = fieldNames.join('_');
}
constraints.push({ name: constraintName, fields: fieldNames });
return { name: constraintName, fields: fieldNames };
} else {
return undefined;
}
};

const addConstraint = (constraint: { name: string; fields: string[] }) => {
if (!constraints.some((c) => c.name === constraint.name)) {
constraints.push(constraint);
}
};

// field-level @id first
for (const field of model.fields) {
if (hasAttribute(field, '@id')) {
addConstraint({ name: field.name, fields: [field.name] });
}
}

// field-level constraints
// then model-level @@id
for (const attr of model.attributes.filter((attr) => attr.decl.ref?.name === '@@id')) {
const constraint = extractConstraint(attr);
if (constraint) {
addConstraint(constraint);
}
}

// then field-level @unique
for (const field of model.fields) {
if (hasAttribute(field, '@id') || hasAttribute(field, '@unique')) {
constraints.push({ name: field.name, fields: [field.name] });
if (hasAttribute(field, '@unique')) {
addConstraint({ name: field.name, fields: [field.name] });
}
}

// then model-level @@unique
for (const attr of model.attributes.filter((attr) => attr.decl.ref?.name === '@@unique')) {
const constraint = extractConstraint(attr);
if (constraint) {
addConstraint(constraint);
}
}

Expand Down
Loading

0 comments on commit 6edfc89

Please sign in to comment.