Skip to content

Commit

Permalink
refactor: on repository build validate that repository models are pro…
Browse files Browse the repository at this point in the history
…perly connected
  • Loading branch information
retro committed Sep 4, 2024
1 parent 75f2cf2 commit 4592b9b
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
122 changes: 122 additions & 0 deletions src/__tests__/repository-components.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import * as semanticLayer from "../index.js";

import { assert, it } from "vitest";

const customersModel = semanticLayer
.model()
.withName("customers")
.fromTable("Customer")
.withDimension("customer_id", {
type: "number",
primaryKey: true,
sql: ({ model, sql }) => sql`${model.column("CustomerId")}`,
})
.withDimension("first_name", {
type: "string",
sql: ({ model }) => model.column("FirstName"),
});

const invoicesModel = semanticLayer
.model()
.withName("invoices")
.fromTable("Invoice")
.withDimension("invoice_id", {
type: "number",
primaryKey: true,
sql: ({ model }) => model.column("InvoiceId"),
})
.withDimension("customer_id", {
type: "number",
sql: ({ model }) => model.column("CustomerId"),
});

const invoiceLinesModel = semanticLayer
.model()
.withName("invoice_lines")
.fromTable("InvoiceLine")
.withDimension("invoice_line_id", {
type: "number",
primaryKey: true,
sql: ({ model }) => model.column("InvoiceLineId"),
})
.withDimension("invoice_id", {
type: "number",
sql: ({ model }) => model.column("InvoiceId"),
})
.withDimension("track_id", {
type: "number",
sql: ({ model }) => model.column("TrackId"),
});

const tracksModel = semanticLayer
.model()
.withName("tracks")
.fromTable("Track")
.withDimension("track_id", {
type: "number",
primaryKey: true,
sql: ({ model }) => model.column("TrackId"),
});

it("will correctly check if all models are connected when no joins exists", () => {
const repository = semanticLayer
.repository()
.withModel(customersModel)
.withModel(invoicesModel)
.withModel(invoiceLinesModel)
.withModel(tracksModel);

assert.throws(() => {
repository.build("postgresql");
}, "All models in a repository must be connected.");
});

it("will correctly check if all models are connected when only some models are connected (1)", () => {
const repository = semanticLayer
.repository()
.withModel(customersModel)
.withModel(invoicesModel)
.withModel(invoiceLinesModel)
.withModel(tracksModel)
.joinOneToMany(
"customers",
"invoices",
({ sql, models }) =>
sql`${models.customers.dimension(
"customer_id",
)} = ${models.invoices.dimension("customer_id")}`,
);

assert.throws(() => {
repository.build("postgresql");
}, "All models in a repository must be connected.");
});

it("will correctly check if all models are connected when only some models are connected (2)", () => {
const repository = semanticLayer
.repository()
.withModel(customersModel)
.withModel(invoicesModel)
.withModel(invoiceLinesModel)
.withModel(tracksModel)
.joinOneToMany(
"customers",
"invoices",
({ sql, models }) =>
sql`${models.customers.dimension(
"customer_id",
)} = ${models.invoices.dimension("customer_id")}`,
)
.joinManyToOne(
"invoice_lines",
"tracks",
({ sql, models }) =>
sql`${models.invoice_lines.dimension(
"track_id",
)} = ${models.tracks.dimension("track_id")}`,
);

assert.throws(() => {
repository.build("postgresql");
}, "All models in a repository must be connected.");
});
26 changes: 26 additions & 0 deletions src/lib/repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,22 @@ export type ModelWithMatchingContext<C, T extends AnyModel> = [C] extends [

export type AnyRepository = Repository<any, any, any, any, any, any>;

function isRepositoryProperlyConnected(
modelNames: string[],
components: string[][],
) {
if (modelNames.length === 1) {
return true;
}

if (components.length === 1) {
const firstComponent = components[0]!;
return firstComponent.length === modelNames.length;
}

return false;
}

export class Repository<
TContext,
TModelNames extends string = never,
Expand Down Expand Up @@ -456,6 +472,16 @@ export class Repository<
build<N extends AvailableDialectsNames, P = DialectParamsReturnType<N>>(
dialectName: N,
) {
const repositoryGraphComponents = graphlib.alg.components(this.graph);

invariant(
isRepositoryProperlyConnected(
Object.keys(this.models),
repositoryGraphComponents,
),
"All models in a repository must be connected.",
);

const dialect = AvailableDialects[dialectName];
return new QueryBuilder<
TContext,
Expand Down

0 comments on commit 4592b9b

Please sign in to comment.