From 0355193399dfe40c0f94fbc450f18c88b5415d94 Mon Sep 17 00:00:00 2001
From: Yiming <yiming@whimslab.io>
Date: Sat, 13 Apr 2024 16:03:37 +0800
Subject: [PATCH] fix(zmodel): clean up logic for detecting what attributes
 should be inherited from base models (#1249)

---
 packages/schema/src/utils/ast-utils.ts        | 33 ++++++++---
 packages/schema/tests/schema/abstract.test.ts | 20 +++++++
 packages/sdk/src/utils.ts                     | 56 +++++++++++--------
 .../with-delegate/issue-1243.test.ts          | 55 ++++++++++++++++++
 4 files changed, 134 insertions(+), 30 deletions(-)
 create mode 100644 tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts

diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts
index fbb9e4ae2..d33f27e71 100644
--- a/packages/schema/src/utils/ast-utils.ts
+++ b/packages/schema/src/utils/ast-utils.ts
@@ -1,6 +1,7 @@
 import {
     BinaryExpr,
     DataModel,
+    DataModelAttribute,
     DataModelField,
     Expression,
     InheritableNode,
@@ -63,14 +64,7 @@ export function mergeBaseModel(model: Model, linker: Linker) {
                 .concat(dataModel.fields);
 
             dataModel.attributes = bases
-                // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-                .flatMap((base) => base.attributes)
-                // don't inherit skip-level attributes
-                .filter((attr) => !attr.$inheritedFrom)
-                // don't inherit `@@delegate` attribute
-                .filter((attr) => attr.decl.$refText !== '@@delegate')
-                // don't inherit `@@map` attribute
-                .filter((attr) => attr.decl.$refText !== '@@map')
+                .flatMap((base) => base.attributes.filter((attr) => filterBaseAttribute(base, attr)))
                 .map((attr) => cloneAst(attr, dataModel, buildReference))
                 .concat(dataModel.attributes);
 
@@ -85,6 +79,29 @@ export function mergeBaseModel(model: Model, linker: Linker) {
     model.declarations = model.declarations.filter((x) => !(isDataModel(x) && x.isAbstract));
 }
 
+function filterBaseAttribute(base: DataModel, attr: DataModelAttribute) {
+    if (attr.$inheritedFrom) {
+        // don't inherit from skip-level base
+        return false;
+    }
+
+    // uninheritable attributes for all inheritance
+    const uninheritableAttributes = ['@@delegate', '@@map'];
+
+    // uninheritable attributes for delegate inheritance (they reference fields from the base)
+    const uninheritableFromDelegateAttributes = ['@@unique', '@@index', '@@fulltext'];
+
+    if (uninheritableAttributes.includes(attr.decl.$refText)) {
+        return false;
+    }
+
+    if (isDelegateModel(base) && uninheritableFromDelegateAttributes.includes(attr.decl.$refText)) {
+        return false;
+    }
+
+    return true;
+}
+
 // deep clone an AST, relink references, and set its container
 function cloneAst<T extends InheritableNode>(
     node: T,
diff --git a/packages/schema/tests/schema/abstract.test.ts b/packages/schema/tests/schema/abstract.test.ts
index 47d607962..6a4b69e49 100644
--- a/packages/schema/tests/schema/abstract.test.ts
+++ b/packages/schema/tests/schema/abstract.test.ts
@@ -61,4 +61,24 @@ describe('Abstract Schema Tests', () => {
           
         `);
     });
+
+    it('multiple id fields from base', async () => {
+        await loadModel(`
+        abstract model Base {
+            id1 String
+            id2 String
+            value String
+            
+            @@id([id1, id2])
+        }
+
+        model Item1 extends Base {
+            x String
+        }
+
+        model Item2 extends Base {
+            y String
+        }        
+        `);
+    });
 });
diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts
index e3ed44b99..6617983aa 100644
--- a/packages/sdk/src/utils.ts
+++ b/packages/sdk/src/utils.ts
@@ -176,39 +176,51 @@ export function isDataModelFieldReference(node: AstNode): node is ReferenceExpr
 }
 
 /**
- * Gets `@@id` fields declared at the data model level
+ * Gets `@@id` fields declared at the data model level (including search in base models)
  */
 export function getModelIdFields(model: DataModel) {
-    const idAttr = model.attributes.find((attr) => attr.decl.$refText === '@@id');
-    if (!idAttr) {
-        return [];
-    }
-    const fieldsArg = idAttr.args.find((a) => a.$resolvedParam?.name === 'fields');
-    if (!fieldsArg || !isArrayExpr(fieldsArg.value)) {
-        return [];
+    const modelsToCheck = model.$baseMerged ? [model] : [model, ...getRecursiveBases(model)];
+
+    for (const modelToCheck of modelsToCheck) {
+        const idAttr = modelToCheck.attributes.find((attr) => attr.decl.$refText === '@@id');
+        if (!idAttr) {
+            continue;
+        }
+        const fieldsArg = idAttr.args.find((a) => a.$resolvedParam?.name === 'fields');
+        if (!fieldsArg || !isArrayExpr(fieldsArg.value)) {
+            continue;
+        }
+
+        return fieldsArg.value.items
+            .filter((item): item is ReferenceExpr => isReferenceExpr(item))
+            .map((item) => resolved(item.target) as DataModelField);
     }
 
-    return fieldsArg.value.items
-        .filter((item): item is ReferenceExpr => isReferenceExpr(item))
-        .map((item) => resolved(item.target) as DataModelField);
+    return [];
 }
 
 /**
- * Gets `@@unique` fields declared at the data model level
+ * Gets `@@unique` fields declared at the data model level (including search in base models)
  */
 export function getModelUniqueFields(model: DataModel) {
-    const uniqueAttr = model.attributes.find((attr) => attr.decl.$refText === '@@unique');
-    if (!uniqueAttr) {
-        return [];
-    }
-    const fieldsArg = uniqueAttr.args.find((a) => a.$resolvedParam?.name === 'fields');
-    if (!fieldsArg || !isArrayExpr(fieldsArg.value)) {
-        return [];
+    const modelsToCheck = model.$baseMerged ? [model] : [model, ...getRecursiveBases(model)];
+
+    for (const modelToCheck of modelsToCheck) {
+        const uniqueAttr = modelToCheck.attributes.find((attr) => attr.decl.$refText === '@@unique');
+        if (!uniqueAttr) {
+            continue;
+        }
+        const fieldsArg = uniqueAttr.args.find((a) => a.$resolvedParam?.name === 'fields');
+        if (!fieldsArg || !isArrayExpr(fieldsArg.value)) {
+            continue;
+        }
+
+        return fieldsArg.value.items
+            .filter((item): item is ReferenceExpr => isReferenceExpr(item))
+            .map((item) => resolved(item.target) as DataModelField);
     }
 
-    return fieldsArg.value.items
-        .filter((item): item is ReferenceExpr => isReferenceExpr(item))
-        .map((item) => resolved(item.target) as DataModelField);
+    return [];
 }
 
 /**
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts b/tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts
new file mode 100644
index 000000000..941bc9b61
--- /dev/null
+++ b/tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts
@@ -0,0 +1,55 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('Regression for issue 1243', () => {
+    it('uninheritable fields', async () => {
+        const schema = `
+        model Base {
+            id String @id @default(cuid())
+            type String
+            foo String
+            
+            @@delegate(type)
+            @@index([foo])
+            @@map('base')
+            @@unique([foo])
+        }
+
+        model Item1 extends Base {
+            x String
+        }
+
+        model Item2 extends Base {
+            y String
+        }
+        `;
+
+        await loadSchema(schema, {
+            enhancements: ['delegate'],
+        });
+    });
+
+    it('multiple id fields', async () => {
+        const schema = `
+        model Base {
+            id1 String
+            id2 String
+            type String
+            
+            @@delegate(type)
+            @@id([id1, id2])
+        }
+
+        model Item1 extends Base {
+            x String
+        }
+
+        model Item2 extends Base {
+            y String
+        }
+        `;
+
+        await loadSchema(schema, {
+            enhancements: ['delegate'],
+        });
+    });
+});