Skip to content

Commit

Permalink
feat: add skip population of individual joins (#8992)
Browse files Browse the repository at this point in the history
### What?
Adds new `sanitizeJoinQuery` that does following:
- Validates `where` for each join
- Executes access for the joined collection
- Combines join's `where` with the access result and join's default
`where`
- Moves default join's `where` handling to `sanitizeJoinQuery`

Adds ability to skip population of an individual join via
`joins[schemaPath] = false`. This is also used internally if
`executeAccess` returned `false`.
```ts
payload.find({
  collection: categoriesSlug,
  where: {
    id: { equals: category.id },
  },
  joins: {
    relatedPosts: false,
  },
})
```
  • Loading branch information
r1tsuu authored Nov 4, 2024
1 parent a7c22a3 commit 4266ab6
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 87 deletions.
6 changes: 5 additions & 1 deletion packages/db-mongodb/src/utilities/buildJoinAggregation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ export const buildJoinAggregation = async ({
continue
}

if (joins?.[join.schemaPath] === false) {
continue
}

const {
limit: limitJoin = join.field.defaultLimit ?? 10,
sort: sortJoin = join.field.defaultSort || collectionConfig.defaultSort,
Expand All @@ -83,7 +87,7 @@ export const buildJoinAggregation = async ({
const $match = await joinModel.buildQuery({
locale,
payload: adapter.payload,
where: combineQueries(whereJoin, join.field?.where ?? {}),
where: whereJoin,
})

const pipeline: Exclude<PipelineStage, PipelineStage.Merge | PipelineStage.Out>[] = [
Expand Down
10 changes: 8 additions & 2 deletions packages/drizzle/src/find/traverseFields.ts
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,17 @@ export const traverseFields = ({
break
}

const joinSchemaPath = `${path.replaceAll('_', '.')}${field.name}`

if (joinQuery[joinSchemaPath] === false) {
break
}

const {
limit: limitArg = field.defaultLimit ?? 10,
sort = field.defaultSort,
where,
} = joinQuery[`${path.replaceAll('_', '.')}${field.name}`] || {}
} = joinQuery[joinSchemaPath] || {}
let limit = limitArg

if (limit !== 0) {
Expand All @@ -442,7 +448,7 @@ export const traverseFields = ({
locale,
sort,
tableName: joinCollectionTableName,
where: combineQueries(where, field?.where ?? {}),
where,
})

let subQueryWhere = buildQueryResult.where
Expand Down
24 changes: 15 additions & 9 deletions packages/next/src/routes/rest/utilities/sanitizeJoinParams.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,27 @@ import { isNumber } from 'payload/shared'
export const sanitizeJoinParams = (
joins:
| {
[schemaPath: string]: {
limit?: unknown
sort?: string
where?: unknown
}
[schemaPath: string]:
| {
limit?: unknown
sort?: string
where?: unknown
}
| false
}
| false = {},
): JoinQuery => {
const joinQuery = {}

Object.keys(joins).forEach((schemaPath) => {
joinQuery[schemaPath] = {
limit: isNumber(joins[schemaPath]?.limit) ? Number(joins[schemaPath].limit) : undefined,
sort: joins[schemaPath]?.sort ? joins[schemaPath].sort : undefined,
where: joins[schemaPath]?.where ? joins[schemaPath].where : undefined,
if (joins[schemaPath] === 'false' || joins[schemaPath] === false) {
joinQuery[schemaPath] = false
} else {
joinQuery[schemaPath] = {
limit: isNumber(joins[schemaPath]?.limit) ? Number(joins[schemaPath].limit) : undefined,
sort: joins[schemaPath]?.sort ? joins[schemaPath].sort : undefined,
where: joins[schemaPath]?.where ? joins[schemaPath].where : undefined,
}
}
})

Expand Down
13 changes: 10 additions & 3 deletions packages/payload/src/collections/operations/find.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import type {
import executeAccess from '../../auth/executeAccess.js'
import { combineQueries } from '../../database/combineQueries.js'
import { validateQueryPaths } from '../../database/queryValidation/validateQueryPaths.js'
import { sanitizeJoinQuery } from '../../database/sanitizeJoinQuery.js'
import { afterRead } from '../../fields/hooks/afterRead/index.js'
import { killTransaction } from '../../utilities/killTransaction.js'
import { buildVersionCollectionFields } from '../../versions/buildCollectionFields.js'
Expand Down Expand Up @@ -129,6 +130,13 @@ export const findOperation = async <

let fullWhere = combineQueries(where, accessResult)

const sanitizedJoins = await sanitizeJoinQuery({
collectionConfig,
joins,
overrideAccess,
req,
})

if (collectionConfig.versions?.drafts && draftsEnabled) {
fullWhere = appendVersionToQueryKey(fullWhere)

Expand All @@ -142,7 +150,7 @@ export const findOperation = async <

result = await payload.db.queryDrafts<DataFromCollectionSlug<TSlug>>({
collection: collectionConfig.slug,
joins: req.payloadAPI === 'GraphQL' ? false : joins,
joins: req.payloadAPI === 'GraphQL' ? false : sanitizedJoins,
limit: sanitizedLimit,
locale,
page: sanitizedPage,
Expand All @@ -155,15 +163,14 @@ export const findOperation = async <
} else {
await validateQueryPaths({
collectionConfig,
joins,
overrideAccess,
req,
where,
})

result = await payload.db.find<DataFromCollectionSlug<TSlug>>({
collection: collectionConfig.slug,
joins: req.payloadAPI === 'GraphQL' ? false : joins,
joins: req.payloadAPI === 'GraphQL' ? false : sanitizedJoins,
limit: sanitizedLimit,
locale,
page: sanitizedPage,
Expand Down
11 changes: 9 additions & 2 deletions packages/payload/src/collections/operations/findByID.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type {

import executeAccess from '../../auth/executeAccess.js'
import { combineQueries } from '../../database/combineQueries.js'
import { sanitizeJoinQuery } from '../../database/sanitizeJoinQuery.js'
import { NotFound } from '../../errors/index.js'
import { afterRead } from '../../fields/hooks/afterRead/index.js'
import { validateQueryPaths } from '../../index.js'
Expand Down Expand Up @@ -94,9 +95,16 @@ export const findByIDOperation = async <

const where = combineQueries({ id: { equals: id } }, accessResult)

const sanitizedJoins = await sanitizeJoinQuery({
collectionConfig,
joins,
overrideAccess,
req,
})

const findOneArgs: FindOneArgs = {
collection: collectionConfig.slug,
joins: req.payloadAPI === 'GraphQL' ? false : joins,
joins: req.payloadAPI === 'GraphQL' ? false : sanitizedJoins,
locale,
req: {
transactionID: req.transactionID,
Expand All @@ -107,7 +115,6 @@ export const findByIDOperation = async <

await validateQueryPaths({
collectionConfig,
joins,
overrideAccess,
req,
where,
Expand Down
16 changes: 3 additions & 13 deletions packages/payload/src/database/getLocalizedPaths.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,10 @@ export async function getLocalizedPaths({
}

case 'relationship':
case 'upload':
case 'join': {
case 'upload': {
// If this is a polymorphic relation,
// We only support querying directly (no nested querying)
if (matchedField.type !== 'join' && Array.isArray(matchedField.relationTo)) {
if (typeof matchedField.relationTo !== 'string') {
const lastSegmentIsValid =
['relationTo', 'value'].includes(pathSegments[pathSegments.length - 1]) ||
pathSegments.length === 1 ||
Expand All @@ -130,16 +129,7 @@ export async function getLocalizedPaths({
.join('.')

if (nestedPathToQuery) {
let slug: string
if (matchedField.type === 'join') {
slug = matchedField.collection
} else if (
// condition is only for type assertion
!Array.isArray(matchedField.relationTo)
) {
slug = matchedField.relationTo
}
const relatedCollection = payload.collections[slug].config
const relatedCollection = payload.collections[matchedField.relationTo].config

const remainingPaths = await getLocalizedPaths({
collectionSlug: relatedCollection.slug,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { SanitizedCollectionConfig } from '../../collections/config/types.js'
import type { Field, FieldAffectingData } from '../../fields/config/types.js'
import type { SanitizedGlobalConfig } from '../../globals/config/types.js'
import type { JoinQuery, Operator, PayloadRequest, Where, WhereField } from '../../types/index.js'
import type { Operator, PayloadRequest, Where, WhereField } from '../../types/index.js'
import type { EntityPolicies } from './types.js'

import { QueryError } from '../../errors/QueryError.js'
Expand All @@ -12,7 +12,6 @@ import { validateSearchParam } from './validateSearchParams.js'

type Args = {
errors?: { path: string }[]
joins?: JoinQuery
overrideAccess: boolean
policies?: EntityPolicies
req: PayloadRequest
Expand Down Expand Up @@ -42,67 +41,26 @@ const flattenWhere = (query: Where): WhereField[] =>
return [...flattenedConstraints, { [key]: val }]
}, [])

/**
* Iterates over the `where` object and to validate the field paths are correct and that the user has access to the fields
*/
export async function validateQueryPaths({
collectionConfig,
errors = [],
globalConfig,
joins,
overrideAccess,
policies = {
collections: {},
globals: {},
},
req,
versionFields,
where: whereArg,
where,
}: Args): Promise<void> {
let where = whereArg
const fields = flattenFields(
versionFields || (globalConfig || collectionConfig).fields,
) as FieldAffectingData[]
const promises = []

// Validate the user has access to configured join fields
if (collectionConfig?.joins) {
Object.entries(collectionConfig.joins).forEach(([collectionSlug, collectionJoins]) => {
collectionJoins.forEach((join) => {
if (join.field.where) {
promises.push(
validateQueryPaths({
collectionConfig: req.payload.config.collections.find(
(config) => config.slug === collectionSlug,
),
errors,
overrideAccess,
policies,
req,
where: join.field.where,
}),
)
}
})
})
}

if (joins) {
where = { ...whereArg }
// concat schemaPath of joins to the join.where to be passed for validation
Object.entries(joins).forEach(([schemaPath, { where: whereJoin }]) => {
if (whereJoin) {
Object.entries(whereJoin).forEach(([path, constraint]) => {
// merge the paths together to be handled the same way as relationships
where[`${schemaPath}.${path}`] = constraint
})
}
})
}

if (typeof where === 'object') {
const whereFields = flattenWhere(where)
// We need to determine if the whereKey is an AND, OR, or a schema path
const promises = []
void whereFields.map((constraint) => {
void Object.keys(constraint).map((path) => {
void Object.entries(constraint[path]).map(([operator, val]) => {
Expand All @@ -126,7 +84,6 @@ export async function validateQueryPaths({
})
})
})

await Promise.all(promises)
if (errors.length > 0) {
throw new QueryError(errors)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ export async function validateSearchParam({
// Remove top collection and reverse array
// to work backwards from top
const pathsToQuery = paths.slice(1).reverse()
let joinCollectionSlug
if (field.type === 'join') {
joinCollectionSlug = field.collection
}

pathsToQuery.forEach(
({ collectionSlug: pathCollectionSlug, path: subPath }, pathToQueryIndex) => {
Expand All @@ -162,8 +158,7 @@ export async function validateSearchParam({
if (pathToQueryIndex === 0) {
promises.push(
validateQueryPaths({
collectionConfig:
req.payload.collections[joinCollectionSlug ?? pathCollectionSlug].config,
collectionConfig: req.payload.collections[pathCollectionSlug].config,
errors,
globalConfig: undefined,
overrideAccess,
Expand Down
Loading

0 comments on commit 4266ab6

Please sign in to comment.