From 4266ab65afd7905330d965b8f77ac80307c64fb7 Mon Sep 17 00:00:00 2001 From: Sasha <64744993+r1tsuu@users.noreply.github.com> Date: Mon, 4 Nov 2024 22:43:15 +0200 Subject: [PATCH] feat: add skip population of individual joins (#8992) ### What? Adds new `sanitizeJoinQuery` that does following: - Validates `where` for each join - Executes access for the joined collection - Combines join's `where` with the access result and join's default `where` - Moves default join's `where` handling to `sanitizeJoinQuery` Adds ability to skip population of an individual join via `joins[schemaPath] = false`. This is also used internally if `executeAccess` returned `false`. ```ts payload.find({ collection: categoriesSlug, where: { id: { equals: category.id }, }, joins: { relatedPosts: false, }, }) ``` --- .../src/utilities/buildJoinAggregation.ts | 6 +- packages/drizzle/src/find/traverseFields.ts | 10 +- .../rest/utilities/sanitizeJoinParams.ts | 24 +++-- .../src/collections/operations/find.ts | 13 ++- .../src/collections/operations/findByID.ts | 11 ++- .../payload/src/database/getLocalizedPaths.ts | 16 +--- .../queryValidation/validateQueryPaths.ts | 49 +--------- .../queryValidation/validateSearchParams.ts | 7 +- .../payload/src/database/sanitizeJoinQuery.ts | 92 +++++++++++++++++++ packages/payload/src/types/index.ts | 12 ++- test/joins/int.spec.ts | 47 ++++++++++ 11 files changed, 200 insertions(+), 87 deletions(-) create mode 100644 packages/payload/src/database/sanitizeJoinQuery.ts diff --git a/packages/db-mongodb/src/utilities/buildJoinAggregation.ts b/packages/db-mongodb/src/utilities/buildJoinAggregation.ts index 9851736810f..5cec7a5291e 100644 --- a/packages/db-mongodb/src/utilities/buildJoinAggregation.ts +++ b/packages/db-mongodb/src/utilities/buildJoinAggregation.ts @@ -64,6 +64,10 @@ export const buildJoinAggregation = async ({ continue } + if (joins?.[join.schemaPath] === false) { + continue + } + const { limit: limitJoin = join.field.defaultLimit ?? 10, sort: sortJoin = join.field.defaultSort || collectionConfig.defaultSort, @@ -83,7 +87,7 @@ export const buildJoinAggregation = async ({ const $match = await joinModel.buildQuery({ locale, payload: adapter.payload, - where: combineQueries(whereJoin, join.field?.where ?? {}), + where: whereJoin, }) const pipeline: Exclude[] = [ diff --git a/packages/drizzle/src/find/traverseFields.ts b/packages/drizzle/src/find/traverseFields.ts index 43b4c67e536..c5a35efa384 100644 --- a/packages/drizzle/src/find/traverseFields.ts +++ b/packages/drizzle/src/find/traverseFields.ts @@ -417,11 +417,17 @@ export const traverseFields = ({ break } + const joinSchemaPath = `${path.replaceAll('_', '.')}${field.name}` + + if (joinQuery[joinSchemaPath] === false) { + break + } + const { limit: limitArg = field.defaultLimit ?? 10, sort = field.defaultSort, where, - } = joinQuery[`${path.replaceAll('_', '.')}${field.name}`] || {} + } = joinQuery[joinSchemaPath] || {} let limit = limitArg if (limit !== 0) { @@ -442,7 +448,7 @@ export const traverseFields = ({ locale, sort, tableName: joinCollectionTableName, - where: combineQueries(where, field?.where ?? {}), + where, }) let subQueryWhere = buildQueryResult.where diff --git a/packages/next/src/routes/rest/utilities/sanitizeJoinParams.ts b/packages/next/src/routes/rest/utilities/sanitizeJoinParams.ts index 148404bb315..cdbd843c3d1 100644 --- a/packages/next/src/routes/rest/utilities/sanitizeJoinParams.ts +++ b/packages/next/src/routes/rest/utilities/sanitizeJoinParams.ts @@ -9,21 +9,27 @@ import { isNumber } from 'payload/shared' export const sanitizeJoinParams = ( joins: | { - [schemaPath: string]: { - limit?: unknown - sort?: string - where?: unknown - } + [schemaPath: string]: + | { + limit?: unknown + sort?: string + where?: unknown + } + | false } | false = {}, ): JoinQuery => { const joinQuery = {} Object.keys(joins).forEach((schemaPath) => { - joinQuery[schemaPath] = { - limit: isNumber(joins[schemaPath]?.limit) ? Number(joins[schemaPath].limit) : undefined, - sort: joins[schemaPath]?.sort ? joins[schemaPath].sort : undefined, - where: joins[schemaPath]?.where ? joins[schemaPath].where : undefined, + if (joins[schemaPath] === 'false' || joins[schemaPath] === false) { + joinQuery[schemaPath] = false + } else { + joinQuery[schemaPath] = { + limit: isNumber(joins[schemaPath]?.limit) ? Number(joins[schemaPath].limit) : undefined, + sort: joins[schemaPath]?.sort ? joins[schemaPath].sort : undefined, + where: joins[schemaPath]?.where ? joins[schemaPath].where : undefined, + } } }) diff --git a/packages/payload/src/collections/operations/find.ts b/packages/payload/src/collections/operations/find.ts index 718b6f8066b..2d3521e874b 100644 --- a/packages/payload/src/collections/operations/find.ts +++ b/packages/payload/src/collections/operations/find.ts @@ -17,6 +17,7 @@ import type { import executeAccess from '../../auth/executeAccess.js' import { combineQueries } from '../../database/combineQueries.js' import { validateQueryPaths } from '../../database/queryValidation/validateQueryPaths.js' +import { sanitizeJoinQuery } from '../../database/sanitizeJoinQuery.js' import { afterRead } from '../../fields/hooks/afterRead/index.js' import { killTransaction } from '../../utilities/killTransaction.js' import { buildVersionCollectionFields } from '../../versions/buildCollectionFields.js' @@ -129,6 +130,13 @@ export const findOperation = async < let fullWhere = combineQueries(where, accessResult) + const sanitizedJoins = await sanitizeJoinQuery({ + collectionConfig, + joins, + overrideAccess, + req, + }) + if (collectionConfig.versions?.drafts && draftsEnabled) { fullWhere = appendVersionToQueryKey(fullWhere) @@ -142,7 +150,7 @@ export const findOperation = async < result = await payload.db.queryDrafts>({ collection: collectionConfig.slug, - joins: req.payloadAPI === 'GraphQL' ? false : joins, + joins: req.payloadAPI === 'GraphQL' ? false : sanitizedJoins, limit: sanitizedLimit, locale, page: sanitizedPage, @@ -155,7 +163,6 @@ export const findOperation = async < } else { await validateQueryPaths({ collectionConfig, - joins, overrideAccess, req, where, @@ -163,7 +170,7 @@ export const findOperation = async < result = await payload.db.find>({ collection: collectionConfig.slug, - joins: req.payloadAPI === 'GraphQL' ? false : joins, + joins: req.payloadAPI === 'GraphQL' ? false : sanitizedJoins, limit: sanitizedLimit, locale, page: sanitizedPage, diff --git a/packages/payload/src/collections/operations/findByID.ts b/packages/payload/src/collections/operations/findByID.ts index f13b5a18069..28e61ae7103 100644 --- a/packages/payload/src/collections/operations/findByID.ts +++ b/packages/payload/src/collections/operations/findByID.ts @@ -14,6 +14,7 @@ import type { import executeAccess from '../../auth/executeAccess.js' import { combineQueries } from '../../database/combineQueries.js' +import { sanitizeJoinQuery } from '../../database/sanitizeJoinQuery.js' import { NotFound } from '../../errors/index.js' import { afterRead } from '../../fields/hooks/afterRead/index.js' import { validateQueryPaths } from '../../index.js' @@ -94,9 +95,16 @@ export const findByIDOperation = async < const where = combineQueries({ id: { equals: id } }, accessResult) + const sanitizedJoins = await sanitizeJoinQuery({ + collectionConfig, + joins, + overrideAccess, + req, + }) + const findOneArgs: FindOneArgs = { collection: collectionConfig.slug, - joins: req.payloadAPI === 'GraphQL' ? false : joins, + joins: req.payloadAPI === 'GraphQL' ? false : sanitizedJoins, locale, req: { transactionID: req.transactionID, @@ -107,7 +115,6 @@ export const findByIDOperation = async < await validateQueryPaths({ collectionConfig, - joins, overrideAccess, req, where, diff --git a/packages/payload/src/database/getLocalizedPaths.ts b/packages/payload/src/database/getLocalizedPaths.ts index 73960bf3653..07e731ef791 100644 --- a/packages/payload/src/database/getLocalizedPaths.ts +++ b/packages/payload/src/database/getLocalizedPaths.ts @@ -103,11 +103,10 @@ export async function getLocalizedPaths({ } case 'relationship': - case 'upload': - case 'join': { + case 'upload': { // If this is a polymorphic relation, // We only support querying directly (no nested querying) - if (matchedField.type !== 'join' && Array.isArray(matchedField.relationTo)) { + if (typeof matchedField.relationTo !== 'string') { const lastSegmentIsValid = ['relationTo', 'value'].includes(pathSegments[pathSegments.length - 1]) || pathSegments.length === 1 || @@ -130,16 +129,7 @@ export async function getLocalizedPaths({ .join('.') if (nestedPathToQuery) { - let slug: string - if (matchedField.type === 'join') { - slug = matchedField.collection - } else if ( - // condition is only for type assertion - !Array.isArray(matchedField.relationTo) - ) { - slug = matchedField.relationTo - } - const relatedCollection = payload.collections[slug].config + const relatedCollection = payload.collections[matchedField.relationTo].config const remainingPaths = await getLocalizedPaths({ collectionSlug: relatedCollection.slug, diff --git a/packages/payload/src/database/queryValidation/validateQueryPaths.ts b/packages/payload/src/database/queryValidation/validateQueryPaths.ts index dbeb40cb7a9..1639fb4256c 100644 --- a/packages/payload/src/database/queryValidation/validateQueryPaths.ts +++ b/packages/payload/src/database/queryValidation/validateQueryPaths.ts @@ -1,7 +1,7 @@ import type { SanitizedCollectionConfig } from '../../collections/config/types.js' import type { Field, FieldAffectingData } from '../../fields/config/types.js' import type { SanitizedGlobalConfig } from '../../globals/config/types.js' -import type { JoinQuery, Operator, PayloadRequest, Where, WhereField } from '../../types/index.js' +import type { Operator, PayloadRequest, Where, WhereField } from '../../types/index.js' import type { EntityPolicies } from './types.js' import { QueryError } from '../../errors/QueryError.js' @@ -12,7 +12,6 @@ import { validateSearchParam } from './validateSearchParams.js' type Args = { errors?: { path: string }[] - joins?: JoinQuery overrideAccess: boolean policies?: EntityPolicies req: PayloadRequest @@ -42,14 +41,10 @@ const flattenWhere = (query: Where): WhereField[] => return [...flattenedConstraints, { [key]: val }] }, []) -/** - * Iterates over the `where` object and to validate the field paths are correct and that the user has access to the fields - */ export async function validateQueryPaths({ collectionConfig, errors = [], globalConfig, - joins, overrideAccess, policies = { collections: {}, @@ -57,52 +52,15 @@ export async function validateQueryPaths({ }, req, versionFields, - where: whereArg, + where, }: Args): Promise { - let where = whereArg const fields = flattenFields( versionFields || (globalConfig || collectionConfig).fields, ) as FieldAffectingData[] - const promises = [] - - // Validate the user has access to configured join fields - if (collectionConfig?.joins) { - Object.entries(collectionConfig.joins).forEach(([collectionSlug, collectionJoins]) => { - collectionJoins.forEach((join) => { - if (join.field.where) { - promises.push( - validateQueryPaths({ - collectionConfig: req.payload.config.collections.find( - (config) => config.slug === collectionSlug, - ), - errors, - overrideAccess, - policies, - req, - where: join.field.where, - }), - ) - } - }) - }) - } - - if (joins) { - where = { ...whereArg } - // concat schemaPath of joins to the join.where to be passed for validation - Object.entries(joins).forEach(([schemaPath, { where: whereJoin }]) => { - if (whereJoin) { - Object.entries(whereJoin).forEach(([path, constraint]) => { - // merge the paths together to be handled the same way as relationships - where[`${schemaPath}.${path}`] = constraint - }) - } - }) - } - if (typeof where === 'object') { const whereFields = flattenWhere(where) // We need to determine if the whereKey is an AND, OR, or a schema path + const promises = [] void whereFields.map((constraint) => { void Object.keys(constraint).map((path) => { void Object.entries(constraint[path]).map(([operator, val]) => { @@ -126,7 +84,6 @@ export async function validateQueryPaths({ }) }) }) - await Promise.all(promises) if (errors.length > 0) { throw new QueryError(errors) diff --git a/packages/payload/src/database/queryValidation/validateSearchParams.ts b/packages/payload/src/database/queryValidation/validateSearchParams.ts index 986403b855e..8a781a84c8f 100644 --- a/packages/payload/src/database/queryValidation/validateSearchParams.ts +++ b/packages/payload/src/database/queryValidation/validateSearchParams.ts @@ -150,10 +150,6 @@ export async function validateSearchParam({ // Remove top collection and reverse array // to work backwards from top const pathsToQuery = paths.slice(1).reverse() - let joinCollectionSlug - if (field.type === 'join') { - joinCollectionSlug = field.collection - } pathsToQuery.forEach( ({ collectionSlug: pathCollectionSlug, path: subPath }, pathToQueryIndex) => { @@ -162,8 +158,7 @@ export async function validateSearchParam({ if (pathToQueryIndex === 0) { promises.push( validateQueryPaths({ - collectionConfig: - req.payload.collections[joinCollectionSlug ?? pathCollectionSlug].config, + collectionConfig: req.payload.collections[pathCollectionSlug].config, errors, globalConfig: undefined, overrideAccess, diff --git a/packages/payload/src/database/sanitizeJoinQuery.ts b/packages/payload/src/database/sanitizeJoinQuery.ts new file mode 100644 index 00000000000..e3f8725bf19 --- /dev/null +++ b/packages/payload/src/database/sanitizeJoinQuery.ts @@ -0,0 +1,92 @@ +import type { SanitizedCollectionConfig } from '../collections/config/types.js' +import type { JoinQuery, PayloadRequest } from '../types/index.js' + +import executeAccess from '../auth/executeAccess.js' +import { QueryError } from '../errors/QueryError.js' +import { combineQueries } from './combineQueries.js' +import { validateQueryPaths } from './queryValidation/validateQueryPaths.js' + +type Args = { + collectionConfig: SanitizedCollectionConfig + joins?: JoinQuery + overrideAccess: boolean + req: PayloadRequest +} + +/** + * * Validates `where` for each join + * * Combines the access result for joined collection + * * Combines the default join's `where` + */ +export const sanitizeJoinQuery = async ({ + collectionConfig, + joins: joinsQuery, + overrideAccess, + req, +}: Args) => { + if (joinsQuery === false) { + return false + } + + if (!joinsQuery) { + joinsQuery = {} + } + + const errors: { path: string }[] = [] + const promises: Promise[] = [] + + for (const collectionSlug in collectionConfig.joins) { + for (const { field, schemaPath } of collectionConfig.joins[collectionSlug]) { + if (joinsQuery[schemaPath] === false) { + continue + } + + const joinCollectionConfig = req.payload.collections[collectionSlug].config + + const accessResult = !overrideAccess + ? await executeAccess({ disableErrors: true, req }, joinCollectionConfig.access.read) + : true + + if (accessResult === false) { + joinsQuery[schemaPath] = false + continue + } + + if (!joinsQuery[schemaPath]) { + joinsQuery[schemaPath] = {} + } + + const joinQuery = joinsQuery[schemaPath] + + if (!joinQuery.where) { + joinQuery.where = {} + } + + if (field.where) { + joinQuery.where = combineQueries(joinQuery.where, field.where) + } + + if (typeof accessResult === 'object') { + joinQuery.where = combineQueries(joinQuery.where, accessResult) + } + + promises.push( + validateQueryPaths({ + collectionConfig: joinCollectionConfig, + errors, + overrideAccess, + req, + where: joinQuery.where, + }), + ) + } + } + + await Promise.all(promises) + + if (errors.length > 0) { + throw new QueryError(errors) + } + + return joinsQuery +} diff --git a/packages/payload/src/types/index.ts b/packages/payload/src/types/index.ts index 192aa06eeb9..f472b88fd7b 100644 --- a/packages/payload/src/types/index.ts +++ b/packages/payload/src/types/index.ts @@ -127,11 +127,13 @@ export type Sort = Array | string */ export type JoinQuery = | { - [schemaPath: string]: { - limit?: number - sort?: string - where?: Where - } + [schemaPath: string]: + | { + limit?: number + sort?: string + where?: Where + } + | false } | false diff --git a/test/joins/int.spec.ts b/test/joins/int.spec.ts index e9f1bc41f8a..762b14a4391 100644 --- a/test/joins/int.spec.ts +++ b/test/joins/int.spec.ts @@ -784,6 +784,53 @@ describe('Joins Field', () => { expect((categoryWithJoins.singulars.docs[0] as Singular).id).toBe(singular.id) }) + + it('local API should not populate individual join by providing schemaPath=false', async () => { + const { + docs: [res], + } = await payload.find({ + collection: categoriesSlug, + where: { + id: { equals: category.id }, + }, + joins: { + relatedPosts: false, + }, + }) + + // removed from the result + expect(res.relatedPosts).toBeUndefined() + + expect(res.hasManyPosts.docs).toBeDefined() + expect(res.hasManyPostsLocalized.docs).toBeDefined() + expect(res.group.relatedPosts.docs).toBeDefined() + expect(res.group.camelCasePosts.docs).toBeDefined() + }) + + it('rEST API should not populate individual join by providing schemaPath=false', async () => { + const { + docs: [res], + } = await restClient + .GET(`/${categoriesSlug}`, { + query: { + where: { + id: { equals: category.id }, + }, + joins: { + relatedPosts: false, + }, + }, + }) + .then((res) => res.json()) + + // removed from the result + expect(res.relatedPosts).toBeUndefined() + + expect(res.hasManyPosts.docs).toBeDefined() + expect(res.hasManyPostsLocalized.docs).toBeDefined() + expect(res.group.relatedPosts.docs).toBeDefined() + expect(res.group.camelCasePosts.docs).toBeDefined() + }) }) async function createPost(overrides?: Partial, locale?: Config['locale']) {