diff --git a/packages/toolkit/src/createDraftSafeSelector.ts b/packages/toolkit/src/createDraftSafeSelector.ts index b1858af2e7..bd58ee406a 100644 --- a/packages/toolkit/src/createDraftSafeSelector.ts +++ b/packages/toolkit/src/createDraftSafeSelector.ts @@ -1,5 +1,23 @@ -import { current, isDraft } from 'immer' import { createSelector } from 'reselect' +import type { ImmutableHelpers } from './tsHelpers' +import { immutableHelpers } from './immer' + +export type BuildCreateDraftSafeSelectorConfiguration = Pick< + ImmutableHelpers, + 'isDraft' | 'current' +> + +export function buildCreateDraftSafeSelector({ + isDraft, + current, +}: BuildCreateDraftSafeSelectorConfiguration): typeof createSelector { + return function createDraftSafeSelector(...args: unknown[]) { + const selector = (createSelector as any)(...args) + const wrappedSelector = (value: unknown, ...rest: unknown[]) => + selector(isDraft(value) ? current(value) : value, ...rest) + return wrappedSelector as any + } +} /** * "Draft-Safe" version of `reselect`'s `createSelector`: @@ -8,11 +26,5 @@ import { createSelector } from 'reselect' * that might be possibly outdated if the draft has been modified since. * @public */ -export const createDraftSafeSelector: typeof createSelector = ( - ...args: unknown[] -) => { - const selector = (createSelector as any)(...args) - const wrappedSelector = (value: unknown, ...rest: unknown[]) => - selector(isDraft(value) ? current(value) : value, ...rest) - return wrappedSelector as any -} +export const createDraftSafeSelector: typeof createSelector = + buildCreateDraftSafeSelector(immutableHelpers) diff --git a/packages/toolkit/src/createReducer.ts b/packages/toolkit/src/createReducer.ts index 3690f274f7..5feb0c01d4 100644 --- a/packages/toolkit/src/createReducer.ts +++ b/packages/toolkit/src/createReducer.ts @@ -1,10 +1,9 @@ import type { Draft } from 'immer' -import { produce as createNextState, isDraft, isDraftable } from 'immer' import type { AnyAction, Action, Reducer } from 'redux' import type { ActionReducerMapBuilder } from './mapBuilders' import { executeReducerBuilderCallback } from './mapBuilders' -import type { NoInfer } from './tsHelpers' -import { makeFreezeDraftable } from './utils' +import type { ImmutableHelpers, NoInfer } from './tsHelpers' +import { immutableHelpers } from './immer' /** * Defines a mapping from action types to corresponding action object shapes. @@ -153,25 +152,17 @@ const reducer = createReducer( ): ReducerWithInitialState } -export interface BuildCreateReducerConfiguration { - createNextState: ( - base: Base, - recipe: (draft: Draft) => void | Base | Draft - ) => Base - isDraft(value: any): boolean - isDraftable(value: any): boolean -} +export type BuildCreateReducerConfiguration = Pick< + ImmutableHelpers, + 'createNextState' | 'isDraft' | 'isDraftable' | 'freeze' +> export function buildCreateReducer({ createNextState, isDraft, isDraftable, + freeze, }: BuildCreateReducerConfiguration): CreateReducer { - const freezeDraftable = makeFreezeDraftable({ - createNextState, - isDraft, - isDraftable, - }) return function createReducer>( initialState: S | (() => S), mapOrBuilderCallback: (builder: ActionReducerMapBuilder) => void @@ -190,9 +181,9 @@ export function buildCreateReducer({ // Ensure the initial state gets frozen either way (if draftable) let getInitialState: () => S if (isStateFunction(initialState)) { - getInitialState = () => freezeDraftable(initialState()) + getInitialState = () => freeze(initialState(), true) } else { - const frozenInitialState = freezeDraftable(initialState) + const frozenInitialState = freeze(initialState, true) getInitialState = () => frozenInitialState } @@ -256,8 +247,4 @@ export function buildCreateReducer({ } } -export const createReducer = buildCreateReducer({ - createNextState, - isDraft, - isDraftable, -}) +export const createReducer = buildCreateReducer(immutableHelpers) diff --git a/packages/toolkit/src/createSlice.ts b/packages/toolkit/src/createSlice.ts index d2ad2b8368..3e2f9902c8 100644 --- a/packages/toolkit/src/createSlice.ts +++ b/packages/toolkit/src/createSlice.ts @@ -1,5 +1,4 @@ import type { Reducer } from 'redux' -import { produce as createNextState, isDraft, isDraftable } from 'immer' import type { ActionCreatorWithoutPayload, PayloadAction, @@ -18,7 +17,7 @@ import { buildCreateReducer } from './createReducer' import type { ActionReducerMapBuilder } from './mapBuilders' import { executeReducerBuilderCallback } from './mapBuilders' import type { NoInfer } from './tsHelpers' -import { makeFreezeDraftable } from './utils' +import { immutableHelpers } from './immer' let hasWarnedAboutObjectNotation = false @@ -283,7 +282,7 @@ export function buildCreateSlice( configuration: BuildCreateSliceConfiguration ): CreateSlice { const createReducer = buildCreateReducer(configuration) - const freezeDraftable = makeFreezeDraftable(configuration) + const { freeze } = configuration return function createSlice< State, @@ -311,7 +310,7 @@ export function buildCreateSlice( const initialState = typeof options.initialState == 'function' ? options.initialState - : freezeDraftable(options.initialState) + : freeze(options.initialState) const reducers = options.reducers || {} @@ -394,8 +393,4 @@ export function buildCreateSlice( } } -export const createSlice = buildCreateSlice({ - createNextState, - isDraft, - isDraftable, -}) +export const createSlice = buildCreateSlice(immutableHelpers) diff --git a/packages/toolkit/src/entities/create_adapter.ts b/packages/toolkit/src/entities/create_adapter.ts index 0d0c77e9d1..31c263362b 100644 --- a/packages/toolkit/src/entities/create_adapter.ts +++ b/packages/toolkit/src/entities/create_adapter.ts @@ -5,39 +5,56 @@ import type { EntityAdapter, } from './models' import { createInitialStateFactory } from './entity_state' -import { createSelectorsFactory } from './state_selectors' -import { createSortedStateAdapter } from './sorted_state_adapter' -import { createUnsortedStateAdapter } from './unsorted_state_adapter' +import { buildCreateSelectorsFactory } from './state_selectors' +import { buildCreateSortedStateAdapter } from './sorted_state_adapter' +import { buildCreateUnsortedStateAdapter } from './unsorted_state_adapter' +import type { BuildCreateDraftSafeSelectorConfiguration } from '..' +import type { BuildStateOperatorConfiguration } from './state_adapter' +import { immutableHelpers } from '../immer' -/** - * - * @param options - * - * @public - */ -export function createEntityAdapter( - options: { +export interface BuildCreateEntityAdapterConfiguration + extends BuildCreateDraftSafeSelectorConfiguration, + BuildStateOperatorConfiguration {} + +export type CreateEntityAdapter = { + (options?: { selectId?: IdSelector sortComparer?: false | Comparer - } = {} -): EntityAdapter { - const { selectId, sortComparer }: EntityDefinition = { - sortComparer: false, - selectId: (instance: any) => instance.id, - ...options, - } + }): EntityAdapter +} - const stateFactory = createInitialStateFactory() - const selectorsFactory = createSelectorsFactory() - const stateAdapter = sortComparer - ? createSortedStateAdapter(selectId, sortComparer) - : createUnsortedStateAdapter(selectId) +export function buildCreateEntityAdapter( + config: BuildCreateEntityAdapterConfiguration +): CreateEntityAdapter { + const createSelectorsFactory = buildCreateSelectorsFactory(config) + const createUnsortedStateAdapter = buildCreateUnsortedStateAdapter(config) + const createSortedStateAdapter = buildCreateSortedStateAdapter(config) + return function createEntityAdapter( + options: { + selectId?: IdSelector + sortComparer?: false | Comparer + } = {} + ): EntityAdapter { + const { selectId, sortComparer }: EntityDefinition = { + sortComparer: false, + selectId: (instance: any) => instance.id, + ...options, + } - return { - selectId, - sortComparer, - ...stateFactory, - ...selectorsFactory, - ...stateAdapter, + const stateFactory = createInitialStateFactory() + const selectorsFactory = createSelectorsFactory() + const stateAdapter = sortComparer + ? createSortedStateAdapter(selectId, sortComparer) + : createUnsortedStateAdapter(selectId) + + return { + selectId, + sortComparer, + ...stateFactory, + ...selectorsFactory, + ...stateAdapter, + } } } + +export const createEntityAdapter = buildCreateEntityAdapter(immutableHelpers) diff --git a/packages/toolkit/src/entities/sorted_state_adapter.ts b/packages/toolkit/src/entities/sorted_state_adapter.ts index 4f1ed7a372..bc389ab69c 100644 --- a/packages/toolkit/src/entities/sorted_state_adapter.ts +++ b/packages/toolkit/src/entities/sorted_state_adapter.ts @@ -6,163 +6,170 @@ import type { Update, EntityId, } from './models' -import { createStateOperator } from './state_adapter' -import { createUnsortedStateAdapter } from './unsorted_state_adapter' +import type { BuildStateOperatorConfiguration } from './state_adapter' +import { buildCreateStateOperator } from './state_adapter' +import { buildCreateUnsortedStateAdapter } from './unsorted_state_adapter' import { selectIdValue, ensureEntitiesArray, splitAddedUpdatedEntities, } from './utils' -export function createSortedStateAdapter( - selectId: IdSelector, - sort: Comparer -): EntityStateAdapter { - type R = EntityState - - const { removeOne, removeMany, removeAll } = - createUnsortedStateAdapter(selectId) - - function addOneMutably(entity: T, state: R): void { - return addManyMutably([entity], state) - } +export function buildCreateSortedStateAdapter( + config: BuildStateOperatorConfiguration +) { + const createUnsortedStateAdapter = buildCreateUnsortedStateAdapter(config) + const createStateOperator = buildCreateStateOperator(config) + return function createSortedStateAdapter( + selectId: IdSelector, + sort: Comparer + ): EntityStateAdapter { + type R = EntityState + + const { removeOne, removeMany, removeAll } = + createUnsortedStateAdapter(selectId) + + function addOneMutably(entity: T, state: R): void { + return addManyMutably([entity], state) + } - function addManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) + function addManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) - const models = newEntities.filter( - (model) => !(selectIdValue(model, selectId) in state.entities) - ) + const models = newEntities.filter( + (model) => !(selectIdValue(model, selectId) in state.entities) + ) - if (models.length !== 0) { - merge(models, state) + if (models.length !== 0) { + merge(models, state) + } } - } - - function setOneMutably(entity: T, state: R): void { - return setManyMutably([entity], state) - } - function setManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) - if (newEntities.length !== 0) { - merge(newEntities, state) + function setOneMutably(entity: T, state: R): void { + return setManyMutably([entity], state) } - } - function setAllMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) - state.entities = {} - state.ids = [] + function setManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) + if (newEntities.length !== 0) { + merge(newEntities, state) + } + } - addManyMutably(newEntities, state) - } + function setAllMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) + state.entities = {} + state.ids = [] - function updateOneMutably(update: Update, state: R): void { - return updateManyMutably([update], state) - } + addManyMutably(newEntities, state) + } - function updateManyMutably( - updates: ReadonlyArray>, - state: R - ): void { - let appliedUpdates = false + function updateOneMutably(update: Update, state: R): void { + return updateManyMutably([update], state) + } - for (let update of updates) { - const entity = state.entities[update.id] - if (!entity) { - continue + function updateManyMutably( + updates: ReadonlyArray>, + state: R + ): void { + let appliedUpdates = false + + for (let update of updates) { + const entity = state.entities[update.id] + if (!entity) { + continue + } + + appliedUpdates = true + + Object.assign(entity, update.changes) + const newId = selectId(entity) + if (update.id !== newId) { + delete state.entities[update.id] + state.entities[newId] = entity + } } - appliedUpdates = true - - Object.assign(entity, update.changes) - const newId = selectId(entity) - if (update.id !== newId) { - delete state.entities[update.id] - state.entities[newId] = entity + if (appliedUpdates) { + resortEntities(state) } } - if (appliedUpdates) { - resortEntities(state) + function upsertOneMutably(entity: T, state: R): void { + return upsertManyMutably([entity], state) } - } - - function upsertOneMutably(entity: T, state: R): void { - return upsertManyMutably([entity], state) - } - function upsertManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - const [added, updated] = splitAddedUpdatedEntities( - newEntities, - selectId, - state - ) - - updateManyMutably(updated, state) - addManyMutably(added, state) - } - - function areArraysEqual(a: readonly unknown[], b: readonly unknown[]) { - if (a.length !== b.length) { - return false + function upsertManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + const [added, updated] = splitAddedUpdatedEntities( + newEntities, + selectId, + state + ) + + updateManyMutably(updated, state) + addManyMutably(added, state) } - for (let i = 0; i < a.length && i < b.length; i++) { - if (a[i] === b[i]) { - continue + function areArraysEqual(a: readonly unknown[], b: readonly unknown[]) { + if (a.length !== b.length) { + return false + } + + for (let i = 0; i < a.length && i < b.length; i++) { + if (a[i] === b[i]) { + continue + } + return false } - return false + return true } - return true - } - function merge(models: readonly T[], state: R): void { - // Insert/overwrite all new/updated - models.forEach((model) => { - state.entities[selectId(model)] = model - }) + function merge(models: readonly T[], state: R): void { + // Insert/overwrite all new/updated + models.forEach((model) => { + state.entities[selectId(model)] = model + }) - resortEntities(state) - } + resortEntities(state) + } - function resortEntities(state: R) { - const allEntities = Object.values(state.entities) as T[] - allEntities.sort(sort) + function resortEntities(state: R) { + const allEntities = Object.values(state.entities) as T[] + allEntities.sort(sort) - const newSortedIds = allEntities.map(selectId) - const { ids } = state + const newSortedIds = allEntities.map(selectId) + const { ids } = state - if (!areArraysEqual(ids, newSortedIds)) { - state.ids = newSortedIds + if (!areArraysEqual(ids, newSortedIds)) { + state.ids = newSortedIds + } } - } - return { - removeOne, - removeMany, - removeAll, - addOne: createStateOperator(addOneMutably), - updateOne: createStateOperator(updateOneMutably), - upsertOne: createStateOperator(upsertOneMutably), - setOne: createStateOperator(setOneMutably), - setMany: createStateOperator(setManyMutably), - setAll: createStateOperator(setAllMutably), - addMany: createStateOperator(addManyMutably), - updateMany: createStateOperator(updateManyMutably), - upsertMany: createStateOperator(upsertManyMutably), + return { + removeOne, + removeMany, + removeAll, + addOne: createStateOperator(addOneMutably), + updateOne: createStateOperator(updateOneMutably), + upsertOne: createStateOperator(upsertOneMutably), + setOne: createStateOperator(setOneMutably), + setMany: createStateOperator(setManyMutably), + setAll: createStateOperator(setAllMutably), + addMany: createStateOperator(addManyMutably), + updateMany: createStateOperator(updateManyMutably), + upsertMany: createStateOperator(upsertManyMutably), + } } } diff --git a/packages/toolkit/src/entities/state_adapter.ts b/packages/toolkit/src/entities/state_adapter.ts index 220abae40a..d65f35b68a 100644 --- a/packages/toolkit/src/entities/state_adapter.ts +++ b/packages/toolkit/src/entities/state_adapter.ts @@ -1,57 +1,72 @@ -import { produce as createNextState, isDraft } from 'immer' import type { EntityState, PreventAny } from './models' import type { PayloadAction } from '../createAction' import { isFSA } from '../createAction' +import type { ImmutableHelpers } from '../tsHelpers' import { IsAny } from '../tsHelpers' -export function createSingleArgumentStateOperator( - mutator: (state: EntityState) => void +export type BuildStateOperatorConfiguration = Pick< + ImmutableHelpers, + 'isDraft' | 'createNextState' +> + +export function buildCreateSingleArgumentStateOperator( + config: BuildStateOperatorConfiguration ) { - const operator = createStateOperator((_: undefined, state: EntityState) => - mutator(state) - ) + const createStateOperator = buildCreateStateOperator(config) + return function createSingleArgumentStateOperator( + mutator: (state: EntityState) => void + ) { + const operator = createStateOperator( + (_: undefined, state: EntityState) => mutator(state) + ) - return function operation>( - state: PreventAny - ): S { - return operator(state as S, undefined) + return function operation>( + state: PreventAny + ): S { + return operator(state as S, undefined) + } } } -export function createStateOperator( - mutator: (arg: R, state: EntityState) => void -) { - return function operation>( - state: S, - arg: R | PayloadAction - ): S { - function isPayloadActionArgument( +export function buildCreateStateOperator({ + isDraft, + createNextState, +}: BuildStateOperatorConfiguration) { + return function createStateOperator( + mutator: (arg: R, state: EntityState) => void + ) { + return function operation>( + state: S, arg: R | PayloadAction - ): arg is PayloadAction { - return isFSA(arg) - } + ): S { + function isPayloadActionArgument( + arg: R | PayloadAction + ): arg is PayloadAction { + return isFSA(arg) + } - const runMutator = (draft: EntityState) => { - if (isPayloadActionArgument(arg)) { - mutator(arg.payload, draft) - } else { - mutator(arg, draft) + const runMutator = (draft: EntityState) => { + if (isPayloadActionArgument(arg)) { + mutator(arg.payload, draft) + } else { + mutator(arg, draft) + } } - } - if (isDraft(state)) { - // we must already be inside a `createNextState` call, likely because - // this is being wrapped in `createReducer` or `createSlice`. - // It's safe to just pass the draft to the mutator. - runMutator(state) + if (isDraft(state)) { + // we must already be inside a `createNextState` call, likely because + // this is being wrapped in `createReducer` or `createSlice`. + // It's safe to just pass the draft to the mutator. + runMutator(state) - // since it's a draft, we'll just return it - return state - } else { - // @ts-ignore createNextState() produces an Immutable> rather - // than an Immutable, and TypeScript cannot find out how to reconcile - // these two types. - return createNextState(state, runMutator) + // since it's a draft, we'll just return it + return state + } else { + // @ts-ignore createNextState() produces an Immutable> rather + // than an Immutable, and TypeScript cannot find out how to reconcile + // these two types. + return createNextState(state, runMutator) + } } } } diff --git a/packages/toolkit/src/entities/state_selectors.ts b/packages/toolkit/src/entities/state_selectors.ts index 46f59d3d9e..c7018e4a47 100644 --- a/packages/toolkit/src/entities/state_selectors.ts +++ b/packages/toolkit/src/entities/state_selectors.ts @@ -1,5 +1,6 @@ import type { Selector } from 'reselect' -import { createDraftSafeSelector } from '../createDraftSafeSelector' +import type { BuildCreateDraftSafeSelectorConfiguration } from '../createDraftSafeSelector' +import { buildCreateDraftSafeSelector } from '../createDraftSafeSelector' import type { EntityState, EntitySelectors, @@ -7,61 +8,70 @@ import type { EntityId, } from './models' -export function createSelectorsFactory() { - function getSelectors(): EntitySelectors> - function getSelectors( - selectState: (state: V) => EntityState - ): EntitySelectors - function getSelectors( - selectState?: (state: V) => EntityState - ): EntitySelectors { - const selectIds = (state: EntityState) => state.ids +export function buildCreateSelectorsFactory( + config: BuildCreateDraftSafeSelectorConfiguration +) { + const createDraftSafeSelector = buildCreateDraftSafeSelector(config) - const selectEntities = (state: EntityState) => state.entities + return function createSelectorsFactory() { + function getSelectors(): EntitySelectors> + function getSelectors( + selectState: (state: V) => EntityState + ): EntitySelectors + function getSelectors( + selectState?: (state: V) => EntityState + ): EntitySelectors { + const selectIds = (state: EntityState) => state.ids - const selectAll = createDraftSafeSelector( - selectIds, - selectEntities, - (ids, entities): T[] => ids.map((id) => entities[id]!) - ) + const selectEntities = (state: EntityState) => state.entities - const selectId = (_: unknown, id: EntityId) => id + const selectAll = createDraftSafeSelector( + selectIds, + selectEntities, + (ids, entities): T[] => ids.map((id) => entities[id]!) + ) - const selectById = (entities: Dictionary, id: EntityId) => entities[id] + const selectId = (_: unknown, id: EntityId) => id - const selectTotal = createDraftSafeSelector(selectIds, (ids) => ids.length) + const selectById = (entities: Dictionary, id: EntityId) => entities[id] - if (!selectState) { - return { + const selectTotal = createDraftSafeSelector( selectIds, - selectEntities, - selectAll, - selectTotal, - selectById: createDraftSafeSelector( + (ids) => ids.length + ) + + if (!selectState) { + return { + selectIds, selectEntities, + selectAll, + selectTotal, + selectById: createDraftSafeSelector( + selectEntities, + selectId, + selectById + ), + } + } + + const selectGlobalizedEntities = createDraftSafeSelector( + selectState as Selector>, + selectEntities + ) + + return { + selectIds: createDraftSafeSelector(selectState, selectIds), + selectEntities: selectGlobalizedEntities, + selectAll: createDraftSafeSelector(selectState, selectAll), + selectTotal: createDraftSafeSelector(selectState, selectTotal), + selectById: createDraftSafeSelector( + selectGlobalizedEntities, selectId, selectById ), } } - const selectGlobalizedEntities = createDraftSafeSelector( - selectState as Selector>, - selectEntities - ) - - return { - selectIds: createDraftSafeSelector(selectState, selectIds), - selectEntities: selectGlobalizedEntities, - selectAll: createDraftSafeSelector(selectState, selectAll), - selectTotal: createDraftSafeSelector(selectState, selectTotal), - selectById: createDraftSafeSelector( - selectGlobalizedEntities, - selectId, - selectById - ), - } + return { getSelectors } } - - return { getSelectors } } diff --git a/packages/toolkit/src/entities/unsorted_state_adapter.ts b/packages/toolkit/src/entities/unsorted_state_adapter.ts index 9113580ba8..b0c6cad2ef 100644 --- a/packages/toolkit/src/entities/unsorted_state_adapter.ts +++ b/packages/toolkit/src/entities/unsorted_state_adapter.ts @@ -5,9 +5,10 @@ import type { Update, EntityId, } from './models' +import type { BuildStateOperatorConfiguration } from './state_adapter' import { - createStateOperator, - createSingleArgumentStateOperator, + buildCreateSingleArgumentStateOperator, + buildCreateStateOperator, } from './state_adapter' import { selectIdValue, @@ -15,184 +16,191 @@ import { splitAddedUpdatedEntities, } from './utils' -export function createUnsortedStateAdapter( - selectId: IdSelector -): EntityStateAdapter { - type R = EntityState +export function buildCreateUnsortedStateAdapter( + config: BuildStateOperatorConfiguration +) { + const createSingleArgumentStateOperator = + buildCreateSingleArgumentStateOperator(config) + const createStateOperator = buildCreateStateOperator(config) + return function createUnsortedStateAdapter( + selectId: IdSelector + ): EntityStateAdapter { + type R = EntityState - function addOneMutably(entity: T, state: R): void { - const key = selectIdValue(entity, selectId) + function addOneMutably(entity: T, state: R): void { + const key = selectIdValue(entity, selectId) - if (key in state.entities) { - return - } + if (key in state.entities) { + return + } - state.ids.push(key) - state.entities[key] = entity - } + state.ids.push(key) + state.entities[key] = entity + } - function addManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) + function addManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) - for (const entity of newEntities) { - addOneMutably(entity, state) + for (const entity of newEntities) { + addOneMutably(entity, state) + } } - } - function setOneMutably(entity: T, state: R): void { - const key = selectIdValue(entity, selectId) - if (!(key in state.entities)) { - state.ids.push(key) + function setOneMutably(entity: T, state: R): void { + const key = selectIdValue(entity, selectId) + if (!(key in state.entities)) { + state.ids.push(key) + } + state.entities[key] = entity } - state.entities[key] = entity - } - function setManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) - for (const entity of newEntities) { - setOneMutably(entity, state) + function setManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) + for (const entity of newEntities) { + setOneMutably(entity, state) + } } - } - function setAllMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - newEntities = ensureEntitiesArray(newEntities) + function setAllMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + newEntities = ensureEntitiesArray(newEntities) - state.ids = [] - state.entities = {} + state.ids = [] + state.entities = {} - addManyMutably(newEntities, state) - } + addManyMutably(newEntities, state) + } - function removeOneMutably(key: EntityId, state: R): void { - return removeManyMutably([key], state) - } + function removeOneMutably(key: EntityId, state: R): void { + return removeManyMutably([key], state) + } - function removeManyMutably(keys: readonly EntityId[], state: R): void { - let didMutate = false + function removeManyMutably(keys: readonly EntityId[], state: R): void { + let didMutate = false - keys.forEach((key) => { - if (key in state.entities) { - delete state.entities[key] - didMutate = true - } - }) + keys.forEach((key) => { + if (key in state.entities) { + delete state.entities[key] + didMutate = true + } + }) - if (didMutate) { - state.ids = state.ids.filter((id) => id in state.entities) + if (didMutate) { + state.ids = state.ids.filter((id) => id in state.entities) + } } - } - - function removeAllMutably(state: R): void { - Object.assign(state, { - ids: [], - entities: {}, - }) - } - function takeNewKey( - keys: { [id: string]: EntityId }, - update: Update, - state: R - ): boolean { - const original = state.entities[update.id] - const updated: T = Object.assign({}, original, update.changes) - const newKey = selectIdValue(updated, selectId) - const hasNewKey = newKey !== update.id - - if (hasNewKey) { - keys[update.id] = newKey - delete state.entities[update.id] + function removeAllMutably(state: R): void { + Object.assign(state, { + ids: [], + entities: {}, + }) } - state.entities[newKey] = updated + function takeNewKey( + keys: { [id: string]: EntityId }, + update: Update, + state: R + ): boolean { + const original = state.entities[update.id] + const updated: T = Object.assign({}, original, update.changes) + const newKey = selectIdValue(updated, selectId) + const hasNewKey = newKey !== update.id + + if (hasNewKey) { + keys[update.id] = newKey + delete state.entities[update.id] + } - return hasNewKey - } + state.entities[newKey] = updated - function updateOneMutably(update: Update, state: R): void { - return updateManyMutably([update], state) - } + return hasNewKey + } - function updateManyMutably( - updates: ReadonlyArray>, - state: R - ): void { - const newKeys: { [id: string]: EntityId } = {} - - const updatesPerEntity: { [id: string]: Update } = {} - - updates.forEach((update) => { - // Only apply updates to entities that currently exist - if (update.id in state.entities) { - // If there are multiple updates to one entity, merge them together - updatesPerEntity[update.id] = { - id: update.id, - // Spreads ignore falsy values, so this works even if there isn't - // an existing update already at this key - changes: { - ...(updatesPerEntity[update.id] - ? updatesPerEntity[update.id].changes - : null), - ...update.changes, - }, + function updateOneMutably(update: Update, state: R): void { + return updateManyMutably([update], state) + } + + function updateManyMutably( + updates: ReadonlyArray>, + state: R + ): void { + const newKeys: { [id: string]: EntityId } = {} + + const updatesPerEntity: { [id: string]: Update } = {} + + updates.forEach((update) => { + // Only apply updates to entities that currently exist + if (update.id in state.entities) { + // If there are multiple updates to one entity, merge them together + updatesPerEntity[update.id] = { + id: update.id, + // Spreads ignore falsy values, so this works even if there isn't + // an existing update already at this key + changes: { + ...(updatesPerEntity[update.id] + ? updatesPerEntity[update.id].changes + : null), + ...update.changes, + }, + } } - } - }) + }) - updates = Object.values(updatesPerEntity) + updates = Object.values(updatesPerEntity) - const didMutateEntities = updates.length > 0 + const didMutateEntities = updates.length > 0 - if (didMutateEntities) { - const didMutateIds = - updates.filter((update) => takeNewKey(newKeys, update, state)).length > - 0 + if (didMutateEntities) { + const didMutateIds = + updates.filter((update) => takeNewKey(newKeys, update, state)) + .length > 0 - if (didMutateIds) { - state.ids = Object.keys(state.entities) + if (didMutateIds) { + state.ids = Object.keys(state.entities) + } } } - } - function upsertOneMutably(entity: T, state: R): void { - return upsertManyMutably([entity], state) - } + function upsertOneMutably(entity: T, state: R): void { + return upsertManyMutably([entity], state) + } - function upsertManyMutably( - newEntities: readonly T[] | Record, - state: R - ): void { - const [added, updated] = splitAddedUpdatedEntities( - newEntities, - selectId, - state - ) - - updateManyMutably(updated, state) - addManyMutably(added, state) - } + function upsertManyMutably( + newEntities: readonly T[] | Record, + state: R + ): void { + const [added, updated] = splitAddedUpdatedEntities( + newEntities, + selectId, + state + ) + + updateManyMutably(updated, state) + addManyMutably(added, state) + } - return { - removeAll: createSingleArgumentStateOperator(removeAllMutably), - addOne: createStateOperator(addOneMutably), - addMany: createStateOperator(addManyMutably), - setOne: createStateOperator(setOneMutably), - setMany: createStateOperator(setManyMutably), - setAll: createStateOperator(setAllMutably), - updateOne: createStateOperator(updateOneMutably), - updateMany: createStateOperator(updateManyMutably), - upsertOne: createStateOperator(upsertOneMutably), - upsertMany: createStateOperator(upsertManyMutably), - removeOne: createStateOperator(removeOneMutably), - removeMany: createStateOperator(removeManyMutably), + return { + removeAll: createSingleArgumentStateOperator(removeAllMutably), + addOne: createStateOperator(addOneMutably), + addMany: createStateOperator(addManyMutably), + setOne: createStateOperator(setOneMutably), + setMany: createStateOperator(setManyMutably), + setAll: createStateOperator(setAllMutably), + updateOne: createStateOperator(updateOneMutably), + updateMany: createStateOperator(updateManyMutably), + upsertOne: createStateOperator(upsertOneMutably), + upsertMany: createStateOperator(upsertManyMutably), + removeOne: createStateOperator(removeOneMutably), + removeMany: createStateOperator(removeManyMutably), + } } } diff --git a/packages/toolkit/src/immer.ts b/packages/toolkit/src/immer.ts new file mode 100644 index 0000000000..81a87da99a --- /dev/null +++ b/packages/toolkit/src/immer.ts @@ -0,0 +1,22 @@ +import { + applyPatches, + current, + freeze, + isDraft, + isDraftable, + original, + produce, + produceWithPatches, +} from 'immer' +import { defineImmutableHelpers } from './tsHelpers' + +export const immutableHelpers = defineImmutableHelpers({ + createNextState: produce, + createWithPatches: produceWithPatches, + applyPatches, + isDraft, + isDraftable, + original, + current, + freeze, +}) diff --git a/packages/toolkit/src/index.ts b/packages/toolkit/src/index.ts index ad505f3d54..04978139c8 100644 --- a/packages/toolkit/src/index.ts +++ b/packages/toolkit/src/index.ts @@ -14,7 +14,11 @@ export type { OutputSelector, ParametricSelector, } from 'reselect' -export { createDraftSafeSelector } from './createDraftSafeSelector' +export type { BuildCreateDraftSafeSelectorConfiguration } from './createDraftSafeSelector' +export { + buildCreateDraftSafeSelector, + createDraftSafeSelector, +} from './createDraftSafeSelector' export type { ThunkAction, ThunkDispatch, ThunkMiddleware } from 'redux-thunk' export { @@ -104,7 +108,10 @@ export type { } from './mapBuilders' export { MiddlewareArray } from './utils' -export { createEntityAdapter } from './entities/create_adapter' +export { + buildCreateEntityAdapter, + createEntityAdapter, +} from './entities/create_adapter' export type { Dictionary, EntityState, @@ -190,3 +197,7 @@ export { autoBatchEnhancer, } from './autoBatchEnhancer' export type { AutoBatchOptions } from './autoBatchEnhancer' + +export type { ImmutableHelpers } from './tsHelpers' +export { defineImmutableHelpers } from './tsHelpers' +export { immutableHelpers as immerImmutableHelpers } from './immer' diff --git a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts index 549ade7084..fb76288cef 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts @@ -5,8 +5,7 @@ import type { QuerySubstateIdentifier, Subscribers, } from '../apiState' -import { produceWithPatches } from 'immer' -import type { AnyAction } from '@reduxjs/toolkit'; +import type { AnyAction } from '@reduxjs/toolkit' import { createSlice, PayloadAction } from '@reduxjs/toolkit' // Copied from https://github.com/feross/queue-microtask @@ -30,7 +29,12 @@ const queueMicrotaskShim = export const buildBatchedActionsHandler: InternalHandlerBuilder< [actionShouldContinue: boolean, subscriptionExists: boolean] -> = ({ api, queryThunk, internalState }) => { +> = ({ + api, + queryThunk, + internalState, + immutableHelpers: { createWithPatches }, +}) => { const subscriptionsPrefix = `${api.reducerPath}/subscriptions` let previousSubscriptions: SubscriptionState = @@ -125,7 +129,7 @@ export const buildBatchedActionsHandler: InternalHandlerBuilder< JSON.stringify(internalState.currentSubscriptions) ) // Figure out a smaller diff between original and current - const [, patches] = produceWithPatches( + const [, patches] = createWithPatches( previousSubscriptions, () => newSubscriptions ) diff --git a/packages/toolkit/src/query/core/buildMiddleware/index.ts b/packages/toolkit/src/query/core/buildMiddleware/index.ts index 810839e333..f27304c9cb 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/index.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/index.ts @@ -12,6 +12,7 @@ import { buildInvalidationByTagsHandler } from './invalidationByTags' import { buildPollingHandler } from './polling' import type { BuildMiddlewareInput, + BuildSubMiddlewareInput, InternalHandlerBuilder, InternalMiddlewareState, } from './types' @@ -63,7 +64,7 @@ export function buildMiddleware< currentSubscriptions: {}, } - const builderArgs = { + const builderArgs: BuildSubMiddlewareInput = { ...(input as any as BuildMiddlewareInput< EndpointDefinitions, string, diff --git a/packages/toolkit/src/query/core/buildMiddleware/types.ts b/packages/toolkit/src/query/core/buildMiddleware/types.ts index 20e23a4ac8..1b8561979f 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/types.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/types.ts @@ -5,6 +5,7 @@ import type { Middleware, MiddlewareAPI, ThunkDispatch, + ImmutableHelpers, } from '@reduxjs/toolkit' import type { Api, ApiContext } from '../../apiTypes' @@ -43,6 +44,7 @@ export interface BuildMiddlewareInput< mutationThunk: MutationThunk api: Api assertTagType: AssertTagTypes + immutableHelpers: Pick } export type SubMiddlewareApi = MiddlewareAPI< diff --git a/packages/toolkit/src/query/core/buildSelectors.ts b/packages/toolkit/src/query/core/buildSelectors.ts index a49da366c5..9067a1e2b1 100644 --- a/packages/toolkit/src/query/core/buildSelectors.ts +++ b/packages/toolkit/src/query/core/buildSelectors.ts @@ -1,4 +1,5 @@ -import { createNextState, createSelector } from '@reduxjs/toolkit' +import type { ImmutableHelpers } from '@reduxjs/toolkit' +import { createSelector } from '@reduxjs/toolkit' import type { MutationSubState, QuerySubState, @@ -103,30 +104,28 @@ export type MutationResultSelectorResult< Definition extends MutationDefinition > = MutationSubState & RequestStatusFlags -const initialSubState: QuerySubState = { - status: QueryStatus.uninitialized as const, -} - -// abuse immer to freeze default states -const defaultQuerySubState = /* @__PURE__ */ createNextState( - initialSubState, - () => {} -) -const defaultMutationSubState = /* @__PURE__ */ createNextState( - initialSubState as MutationSubState, - () => {} -) - export function buildSelectors< Definitions extends EndpointDefinitions, ReducerPath extends string >({ serializeQueryArgs, reducerPath, + immutableHelpers: { freeze }, }: { serializeQueryArgs: InternalSerializeQueryArgs reducerPath: ReducerPath + immutableHelpers: Pick }) { + const initialSubState: QuerySubState = { + status: QueryStatus.uninitialized as const, + } + + const defaultQuerySubState = freeze(initialSubState, true) + const defaultMutationSubState = freeze( + initialSubState as MutationSubState, + true + ) + type RootState = _RootState const selectSkippedQuery = (state: RootState) => defaultQuerySubState diff --git a/packages/toolkit/src/query/core/buildSlice.ts b/packages/toolkit/src/query/core/buildSlice.ts index 3343b6dc3d..578de663ad 100644 --- a/packages/toolkit/src/query/core/buildSlice.ts +++ b/packages/toolkit/src/query/core/buildSlice.ts @@ -1,13 +1,16 @@ -import type { AnyAction, PayloadAction } from '@reduxjs/toolkit' +import type { + PayloadAction, + BuildCreateSliceConfiguration, + ImmutableHelpers, +} from '@reduxjs/toolkit' import { combineReducers, createAction, - createSlice, isAnyOf, isFulfilled, isRejectedWithValue, - createNextState, prepareAutoBatched, + buildCreateSlice, } from '@reduxjs/toolkit' import type { CombinedState as CombinedQueryState, @@ -32,7 +35,6 @@ import type { QueryDefinition, } from '../endpointDefinitions' import type { Patch } from 'immer' -import { applyPatches } from 'immer' import { onFocus, onFocusLost, onOffline, onOnline } from './setupListeners' import { isDocumentVisible, @@ -99,6 +101,8 @@ export function buildSlice({ }, assertTagType, config, + immutableHelpers, + immutableHelpers: { applyPatches, createNextState }, }: { reducerPath: string queryThunk: QueryThunk @@ -109,7 +113,11 @@ export function buildSlice({ ConfigState, 'online' | 'focused' | 'middlewareRegistered' > + immutableHelpers: Pick & + BuildCreateSliceConfiguration }) { + const createSlice = buildCreateSlice(immutableHelpers) + const resetApiState = createAction(`${reducerPath}/resetApiState`) const querySlice = createSlice({ name: `${reducerPath}/queries`, diff --git a/packages/toolkit/src/query/core/buildThunks.ts b/packages/toolkit/src/query/core/buildThunks.ts index 458b9edd44..42f52aa7e6 100644 --- a/packages/toolkit/src/query/core/buildThunks.ts +++ b/packages/toolkit/src/query/core/buildThunks.ts @@ -23,7 +23,11 @@ import type { } from '../endpointDefinitions' import { isQueryDefinition } from '../endpointDefinitions' import { calculateProvidedBy } from '../endpointDefinitions' -import type { AsyncThunkPayloadCreator, Draft } from '@reduxjs/toolkit' +import type { + AsyncThunkPayloadCreator, + Draft, + ImmutableHelpers, +} from '@reduxjs/toolkit' import { isAllOf, isFulfilled, @@ -32,7 +36,6 @@ import { isRejectedWithValue, } from '@reduxjs/toolkit' import type { Patch } from 'immer' -import { isDraftable, produceWithPatches } from 'immer' import type { AnyAction, ThunkAction, @@ -222,12 +225,14 @@ export function buildThunks< context: { endpointDefinitions }, serializeQueryArgs, api, + immutableHelpers: { isDraftable, createWithPatches }, }: { baseQuery: BaseQuery reducerPath: ReducerPath context: ApiContext serializeQueryArgs: InternalSerializeQueryArgs api: Api + immutableHelpers: Pick }) { type State = RootState @@ -264,7 +269,7 @@ export function buildThunks< } if ('data' in currentState) { if (isDraftable(currentState.data)) { - const [, patches, inversePatches] = produceWithPatches( + const [, patches, inversePatches] = createWithPatches( currentState.data, updateRecipe ) diff --git a/packages/toolkit/src/query/core/module.ts b/packages/toolkit/src/query/core/module.ts index 2cb9ac76e3..2078af0720 100644 --- a/packages/toolkit/src/query/core/module.ts +++ b/packages/toolkit/src/query/core/module.ts @@ -10,11 +10,14 @@ import { buildThunks } from './buildThunks' import type { ActionCreatorWithPayload, AnyAction, + BuildCreateSliceConfiguration, Middleware, Reducer, ThunkAction, ThunkDispatch, + ImmutableHelpers, } from '@reduxjs/toolkit' +import { immerImmutableHelpers } from '@reduxjs/toolkit' import type { EndpointDefinitions, QueryArgFrom, @@ -71,7 +74,8 @@ export type CoreModule = | ReferenceQueryLifecycle | ReferenceCacheCollection -export interface ThunkWithReturnValue extends ThunkAction {} +export interface ThunkWithReturnValue + extends ThunkAction {} declare module '../apiTypes' { export interface ApiModules< @@ -450,6 +454,14 @@ export type ListenerActions = { export type InternalActions = SliceActions & ListenerActions +interface CoreModuleOptions { + immutableHelpers?: BuildCreateSliceConfiguration & + Pick< + ImmutableHelpers, + 'createWithPatches' | 'applyPatches' | 'isDraftable' | 'freeze' + > +} + /** * Creates a module containing the basic redux logic for use with `buildCreateApi`. * @@ -458,7 +470,9 @@ export type InternalActions = SliceActions & ListenerActions * const createBaseApi = buildCreateApi(coreModule()); * ``` */ -export const coreModule = (): Module => ({ +export const coreModule = ({ + immutableHelpers = immerImmutableHelpers, +}: CoreModuleOptions = {}): Module => ({ name: coreModuleName, init( api, @@ -518,6 +532,7 @@ export const coreModule = (): Module => ({ context, api, serializeQueryArgs, + immutableHelpers, }) const { reducer, actions: sliceActions } = buildSlice({ @@ -533,6 +548,7 @@ export const coreModule = (): Module => ({ keepUnusedDataFor, reducerPath, }, + immutableHelpers, }) safeAssign(api.util, { @@ -551,6 +567,7 @@ export const coreModule = (): Module => ({ mutationThunk, api, assertTagType, + immutableHelpers, }) safeAssign(api.util, middlewareActions) @@ -560,6 +577,7 @@ export const coreModule = (): Module => ({ buildSelectors({ serializeQueryArgs: serializeQueryArgs as any, reducerPath, + immutableHelpers, }) safeAssign(api.util, { selectInvalidatedBy }) diff --git a/packages/toolkit/src/tsHelpers.ts b/packages/toolkit/src/tsHelpers.ts index 3077037731..40c44904bd 100644 --- a/packages/toolkit/src/tsHelpers.ts +++ b/packages/toolkit/src/tsHelpers.ts @@ -1,6 +1,58 @@ import type { Middleware, StoreEnhancer } from 'redux' +import type { Draft, Patch, applyPatches } from 'immer' import type { MiddlewareArray } from './utils' +export interface ImmutableHelpers { + /** + * Function that receives a base object, and a recipe which is called with a draft that the recipe is allowed to mutate. + * The recipe can return a new state which will replace the existing state, or it can not return (in which case the existing draft is used) + * Returns an immutably modified version of the input object. + */ + createNextState: ( + base: Base, + recipe: (draft: Draft) => void | Base | Draft + ) => Base + /** + * Function that receives a base object, and a recipe which is called with a draft that the recipe is allowed to mutate. + * The recipe can return a new state which will replace the existing state, or it can not return (in which case the existing draft is used) + * Returns a tuple of an immutably modified version of the input object, an array of patches describing the changes made, and an array of inverse patches. + */ + createWithPatches: ( + base: Base, + recipe: (draft: Draft) => void | Base | Draft + ) => readonly [Base, Patch[], Patch[]] + /** + * Receives a base object and an array of patches describing changes to apply. + * Returns an immutably modified version of the base object with changes applied. + */ + applyPatches: typeof applyPatches + /** + * Indicates whether the value passed is a draft, meaning it's safe to mutate. + */ + isDraft(value: any): boolean + /** + * Indicates whether the value passed is possible to turn into a mutable draft. + */ + isDraftable(value: any): boolean + /** + * Receives a draft and returns its base object. + */ + original(value: T): T | undefined + /** + * Receives a draft and returns an object with any changes to date immutably applied. + */ + current(value: T): T + /** + * Receives an object and freezes it, causing runtime errors if mutation is attempted after. + */ + freeze(obj: T, deep?: boolean): T +} + +/** + * Define a config object indicating utilities for RTK packages to use for immutable operations. + */ +export const defineImmutableHelpers = (helpers: ImmutableHelpers) => helpers + /** * return True if T is `any`, otherwise return False * taken from https://github.com/joonhocho/tsdef