diff --git a/src/__tests__/repository-components.test.ts b/src/__tests__/repository-components.test.ts new file mode 100644 index 0000000..e8e013b --- /dev/null +++ b/src/__tests__/repository-components.test.ts @@ -0,0 +1,122 @@ +import * as semanticLayer from "../index.js"; + +import { assert, it } from "vitest"; + +const customersModel = semanticLayer + .model() + .withName("customers") + .fromTable("Customer") + .withDimension("customer_id", { + type: "number", + primaryKey: true, + sql: ({ model, sql }) => sql`${model.column("CustomerId")}`, + }) + .withDimension("first_name", { + type: "string", + sql: ({ model }) => model.column("FirstName"), + }); + +const invoicesModel = semanticLayer + .model() + .withName("invoices") + .fromTable("Invoice") + .withDimension("invoice_id", { + type: "number", + primaryKey: true, + sql: ({ model }) => model.column("InvoiceId"), + }) + .withDimension("customer_id", { + type: "number", + sql: ({ model }) => model.column("CustomerId"), + }); + +const invoiceLinesModel = semanticLayer + .model() + .withName("invoice_lines") + .fromTable("InvoiceLine") + .withDimension("invoice_line_id", { + type: "number", + primaryKey: true, + sql: ({ model }) => model.column("InvoiceLineId"), + }) + .withDimension("invoice_id", { + type: "number", + sql: ({ model }) => model.column("InvoiceId"), + }) + .withDimension("track_id", { + type: "number", + sql: ({ model }) => model.column("TrackId"), + }); + +const tracksModel = semanticLayer + .model() + .withName("tracks") + .fromTable("Track") + .withDimension("track_id", { + type: "number", + primaryKey: true, + sql: ({ model }) => model.column("TrackId"), + }); + +it("will correctly check if all models are connected when no joins exists", () => { + const repository = semanticLayer + .repository() + .withModel(customersModel) + .withModel(invoicesModel) + .withModel(invoiceLinesModel) + .withModel(tracksModel); + + assert.throws(() => { + repository.build("postgresql"); + }, "All models in a repository must be connected."); +}); + +it("will correctly check if all models are connected when only some models are connected (1)", () => { + const repository = semanticLayer + .repository() + .withModel(customersModel) + .withModel(invoicesModel) + .withModel(invoiceLinesModel) + .withModel(tracksModel) + .joinOneToMany( + "customers", + "invoices", + ({ sql, models }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")}`, + ); + + assert.throws(() => { + repository.build("postgresql"); + }, "All models in a repository must be connected."); +}); + +it("will correctly check if all models are connected when only some models are connected (2)", () => { + const repository = semanticLayer + .repository() + .withModel(customersModel) + .withModel(invoicesModel) + .withModel(invoiceLinesModel) + .withModel(tracksModel) + .joinOneToMany( + "customers", + "invoices", + ({ sql, models }) => + sql`${models.customers.dimension( + "customer_id", + )} = ${models.invoices.dimension("customer_id")}`, + ) + .joinManyToOne( + "invoice_lines", + "tracks", + ({ sql, models }) => + sql`${models.invoice_lines.dimension( + "track_id", + )} = ${models.tracks.dimension("track_id")}`, + ); + + assert.throws(() => { + repository.build("postgresql"); + }, "All models in a repository must be connected."); +}); diff --git a/src/lib/repository.ts b/src/lib/repository.ts index 0fb470a..920049e 100644 --- a/src/lib/repository.ts +++ b/src/lib/repository.ts @@ -55,6 +55,22 @@ export type ModelWithMatchingContext = [C] extends [ export type AnyRepository = Repository; +function isRepositoryProperlyConnected( + modelNames: string[], + components: string[][], +) { + if (modelNames.length === 1) { + return true; + } + + if (components.length === 1) { + const firstComponent = components[0]!; + return firstComponent.length === modelNames.length; + } + + return false; +} + export class Repository< TContext, TModelNames extends string = never, @@ -456,6 +472,16 @@ export class Repository< build>( dialectName: N, ) { + const repositoryGraphComponents = graphlib.alg.components(this.graph); + + invariant( + isRepositoryProperlyConnected( + Object.keys(this.models), + repositoryGraphComponents, + ), + "All models in a repository must be connected.", + ); + const dialect = AvailableDialects[dialectName]; return new QueryBuilder< TContext,