Skip to content

Commit

Permalink
Merge pull request #4735 from reduxjs/upsert-new
Browse files Browse the repository at this point in the history
Update to new version of upsert proposal, and fix listener equality checks
  • Loading branch information
markerikson authored Nov 28, 2024
2 parents ee22bef + 05a8322 commit fdfc3b7
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 150 deletions.
10 changes: 6 additions & 4 deletions packages/toolkit/src/combineSlices.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import type {
UnionToIntersection,
WithOptionalProp,
} from './tsHelpers'
import { emplace } from './utils'
import { getOrInsertComputed } from './utils'

type SliceLike<ReducerPath extends string, State> = {
reducerPath: ReducerPath
Expand Down Expand Up @@ -324,8 +324,10 @@ const createStateProxy = <State extends object>(
state: State,
reducerMap: Partial<Record<string, Reducer>>,
) =>
emplace(stateProxyMap, state, {
insert: () =>
getOrInsertComputed(
stateProxyMap,
state,
() =>
new Proxy(state, {
get: (target, prop, receiver) => {
if (prop === ORIGINAL_STATE) return target
Expand All @@ -350,7 +352,7 @@ const createStateProxy = <State extends object>(
return result
},
}),
}) as State
) as State

const original = (state: any) => {
if (!isStateProxy(state)) {
Expand Down
40 changes: 20 additions & 20 deletions packages/toolkit/src/createSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import { createReducer } from './createReducer'
import type { ActionReducerMapBuilder, TypedActionCreator } from './mapBuilders'
import { executeReducerBuilderCallback } from './mapBuilders'
import type { Id, TypeGuard } from './tsHelpers'
import { emplace } from './utils'
import { getOrInsertComputed } from './utils'

const asyncThunkSymbol = /* @__PURE__ */ Symbol.for(
'rtk-slice-createasyncthunk',
Expand Down Expand Up @@ -769,25 +769,25 @@ export function buildCreateSlice({ creators }: BuildCreateSliceConfig = {}) {
function getSelectors(
selectState: (rootState: any) => State = selectSelf,
) {
const selectorCache = emplace(injectedSelectorCache, injected, {
insert: () => new WeakMap(),
})

return emplace(selectorCache, selectState, {
insert: () => {
const map: Record<string, Selector<any, any>> = {}
for (const [name, selector] of Object.entries(
options.selectors ?? {},
)) {
map[name] = wrapSelector(
selector,
selectState,
getInitialState,
injected,
)
}
return map
},
const selectorCache = getOrInsertComputed(
injectedSelectorCache,
injected,
() => new WeakMap(),
)

return getOrInsertComputed(selectorCache, selectState, () => {
const map: Record<string, Selector<any, any>> = {}
for (const [name, selector] of Object.entries(
options.selectors ?? {},
)) {
map[name] = wrapSelector(
selector,
selectState,
getInitialState,
injected,
)
}
return map
}) as any
}
return {
Expand Down
19 changes: 7 additions & 12 deletions packages/toolkit/src/dynamicMiddleware/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { compose } from 'redux'
import { createAction } from '../createAction'
import { isAllOf } from '../matchers'
import { nanoid } from '../nanoid'
import { emplace, find } from '../utils'
import { getOrInsertComputed } from '../utils'
import type {
AddMiddleware,
DynamicMiddleware,
Expand All @@ -23,7 +23,6 @@ const createMiddlewareEntry = <
>(
middleware: Middleware<any, State, DispatchType>,
): MiddlewareEntry<State, DispatchType> => ({
id: nanoid(),
middleware,
applied: new Map(),
})
Expand All @@ -38,7 +37,10 @@ export const createDynamicMiddleware = <
DispatchType extends Dispatch<UnknownAction> = Dispatch<UnknownAction>,
>(): DynamicMiddlewareInstance<State, DispatchType> => {
const instanceId = nanoid()
const middlewareMap = new Map<string, MiddlewareEntry<State, DispatchType>>()
const middlewareMap = new Map<
Middleware<any, State, DispatchType>,
MiddlewareEntry<State, DispatchType>
>()

const withMiddleware = Object.assign(
createAction(
Expand All @@ -58,22 +60,15 @@ export const createDynamicMiddleware = <
...middlewares: Middleware<any, State, DispatchType>[]
) {
middlewares.forEach((middleware) => {
let entry = find(
Array.from(middlewareMap.values()),
(entry) => entry.middleware === middleware,
)
if (!entry) {
entry = createMiddlewareEntry(middleware)
}
middlewareMap.set(entry.id, entry)
getOrInsertComputed(middlewareMap, middleware, createMiddlewareEntry)
})
},
{ withTypes: () => addMiddleware },
) as AddMiddleware<State, DispatchType>

const getFinalMiddleware: Middleware<{}, State, DispatchType> = (api) => {
const appliedMiddleware = Array.from(middlewareMap.values()).map((entry) =>
emplace(entry.applied, api, { insert: () => entry.middleware(api) }),
getOrInsertComputed(entry.applied, api, entry.middleware),
)
return compose(...appliedMiddleware)
}
Expand Down
1 change: 0 additions & 1 deletion packages/toolkit/src/dynamicMiddleware/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ export type MiddlewareEntry<
State = unknown,
DispatchType extends Dispatch<UnknownAction> = Dispatch<UnknownAction>,
> = {
id: string
middleware: Middleware<any, State, DispatchType>
applied: Map<
MiddlewareAPI<DispatchType, State>,
Expand Down
44 changes: 22 additions & 22 deletions packages/toolkit/src/listenerMiddleware/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import type { ThunkDispatch } from 'redux-thunk'
import { createAction } from '../createAction'
import { nanoid } from '../nanoid'

import { find } from '../utils'
import {
TaskAbortError,
listenerCancelled,
Expand Down Expand Up @@ -221,9 +220,8 @@ export const createListenerEntry: TypedCreateListenerEntry<unknown> =
(options: FallbackAddListenerOptions) => {
const { type, predicate, effect } = getListenerEntryPropsFrom(options)

const id = nanoid()
const entry: ListenerEntry<unknown> = {
id,
id: nanoid(),
effect,
type,
predicate,
Expand All @@ -238,6 +236,22 @@ export const createListenerEntry: TypedCreateListenerEntry<unknown> =
{ withTypes: () => createListenerEntry },
) as unknown as TypedCreateListenerEntry<unknown>

const findListenerEntry = (
listenerMap: Map<string, ListenerEntry>,
options: FallbackAddListenerOptions,
) => {
const { type, effect, predicate } = getListenerEntryPropsFrom(options)

return Array.from(listenerMap.values()).find((entry) => {
const matchPredicateOrType =
typeof type === 'string'
? entry.type === type
: entry.predicate === predicate

return matchPredicateOrType && entry.effect === effect
})
}

const cancelActiveListeners = (
entry: ListenerEntry<unknown, Dispatch<UnknownAction>>,
) => {
Expand Down Expand Up @@ -330,7 +344,7 @@ export const createListenerMiddleware = <
assertFunction(onError, 'onError')

const insertEntry = (entry: ListenerEntry) => {
entry.unsubscribe = () => listenerMap.delete(entry!.id)
entry.unsubscribe = () => listenerMap.delete(entry.id)

listenerMap.set(entry.id, entry)
return (cancelOptions?: UnsubscribeListenerOptions) => {
Expand All @@ -342,14 +356,9 @@ export const createListenerMiddleware = <
}

const startListening = ((options: FallbackAddListenerOptions) => {
let entry = find(
Array.from(listenerMap.values()),
(existingEntry) => existingEntry.effect === options.effect,
)

if (!entry) {
entry = createListenerEntry(options as any)
}
const entry =
findListenerEntry(listenerMap, options) ??
createListenerEntry(options as any)

return insertEntry(entry)
}) as AddListenerOverloads<any>
Expand All @@ -361,16 +370,7 @@ export const createListenerMiddleware = <
const stopListening = (
options: FallbackAddListenerOptions & UnsubscribeListenerOptions,
): boolean => {
const { type, effect, predicate } = getListenerEntryPropsFrom(options)

const entry = find(Array.from(listenerMap.values()), (entry) => {
const matchPredicateOrType =
typeof type === 'string'
? entry.type === type
: entry.predicate === predicate

return matchPredicateOrType && entry.effect === effect
})
const entry = findListenerEntry(listenerMap, options)

if (entry) {
entry.unsubscribe()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ describe('createListenerMiddleware', () => {
const testAction1 = createAction<string>('testAction1')
type TestAction1 = ReturnType<typeof testAction1>
const testAction2 = createAction<string>('testAction2')
type TestAction2 = ReturnType<typeof testAction2>
const testAction3 = createAction<string>('testAction3')

beforeAll(() => {
Expand Down Expand Up @@ -339,6 +340,27 @@ describe('createListenerMiddleware', () => {
])
})

test('subscribing with the same effect but different predicate is allowed', () => {
const effect = vi.fn((_: TestAction1 | TestAction2) => {})

startListening({
actionCreator: testAction1,
effect,
})
startListening({
actionCreator: testAction2,
effect,
})

store.dispatch(testAction1('a'))
store.dispatch(testAction2('b'))

expect(effect.mock.calls).toEqual([
[testAction1('a'), middlewareApi],
[testAction2('b'), middlewareApi],
])
})

test('unsubscribing via callback', () => {
const effect = vi.fn((_: TestAction1) => {})

Expand Down
34 changes: 27 additions & 7 deletions packages/toolkit/src/listenerMiddleware/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,13 @@ export type TypedAddListener<
OverrideStateType,
unknown,
UnknownAction
>,
OverrideExtraArgument = unknown,
>() => TypedAddListener<OverrideStateType, OverrideDispatchType, OverrideExtraArgument>
>,
OverrideExtraArgument = unknown,
>() => TypedAddListener<
OverrideStateType,
OverrideDispatchType,
OverrideExtraArgument
>
}

/**
Expand Down Expand Up @@ -641,7 +645,11 @@ export type TypedRemoveListener<
UnknownAction
>,
OverrideExtraArgument = unknown,
>() => TypedRemoveListener<OverrideStateType, OverrideDispatchType, OverrideExtraArgument>
>() => TypedRemoveListener<
OverrideStateType,
OverrideDispatchType,
OverrideExtraArgument
>
}

/**
Expand Down Expand Up @@ -701,7 +709,11 @@ export type TypedStartListening<
UnknownAction
>,
OverrideExtraArgument = unknown,
>() => TypedStartListening<OverrideStateType, OverrideDispatchType, OverrideExtraArgument>
>() => TypedStartListening<
OverrideStateType,
OverrideDispatchType,
OverrideExtraArgument
>
}

/**
Expand Down Expand Up @@ -756,7 +768,11 @@ export type TypedStopListening<
UnknownAction
>,
OverrideExtraArgument = unknown,
>() => TypedStopListening<OverrideStateType, OverrideDispatchType, OverrideExtraArgument>
>() => TypedStopListening<
OverrideStateType,
OverrideDispatchType,
OverrideExtraArgument
>
}

/**
Expand Down Expand Up @@ -813,7 +829,11 @@ export type TypedCreateListenerEntry<
UnknownAction
>,
OverrideExtraArgument = unknown,
>() => TypedStopListening<OverrideStateType, OverrideDispatchType, OverrideExtraArgument>
>() => TypedStopListening<
OverrideStateType,
OverrideDispatchType,
OverrideExtraArgument
>
}

/**
Expand Down
5 changes: 2 additions & 3 deletions packages/toolkit/src/query/core/buildInitiate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import type {
QueryDefinition,
ResultTypeFrom,
} from '../endpointDefinitions'
import { countObjectKeys, isNotNullish } from '../utils'
import { countObjectKeys, getOrInsert, isNotNullish } from '../utils'
import type { SubscriptionOptions } from './apiState'
import type { QueryResultSelectorResult } from './buildSelectors'
import type { MutationThunk, QueryThunk, QueryThunkArg } from './buildThunks'
Expand Down Expand Up @@ -391,9 +391,8 @@ You must add the middleware for RTK-Query to function correctly!`,
)

if (!runningQuery && !skippedSynchronously && !forceQueryFn) {
const running = runningQueries.get(dispatch) || {}
const running = getOrInsert(runningQueries, dispatch, {})
running[queryCacheKey] = statePromise
runningQueries.set(dispatch, running)

statePromise.then(() => {
delete running[queryCacheKey]
Expand Down
15 changes: 15 additions & 0 deletions packages/toolkit/src/query/utils/getOrInsert.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export function getOrInsert<K extends object, V>(
map: WeakMap<K, V>,
key: K,
value: V,
): V
export function getOrInsert<K, V>(map: Map<K, V>, key: K, value: V): V
export function getOrInsert<K extends object, V>(
map: Map<K, V> | WeakMap<K, V>,
key: K,
value: V,
): V {
if (map.has(key)) return map.get(key) as V

return map.set(key, value).get(key) as V
}
1 change: 1 addition & 0 deletions packages/toolkit/src/query/utils/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ export * from './isNotNullish'
export * from './isOnline'
export * from './isValidUrl'
export * from './joinUrls'
export * from './getOrInsert'
Loading

0 comments on commit fdfc3b7

Please sign in to comment.