Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

atom.INTERNAL_onInit hook to support sync effects #2801

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/vanilla/atom.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import type { AtomState, PrdOrDevStore as Store } from './store'

type Getter = <Value>(atom: Atom<Value>) => Value

type Setter = <Value, Args extends unknown[], Result>(
Expand Down Expand Up @@ -47,6 +49,11 @@ export interface Atom<Value> {
* @private
*/
debugPrivate?: boolean
/**
* Fires after atom is referenced by the store for the first time
* For internal use only and subject to change without notice.
*/
INTERNAL_onInit?: (store: Store, atomState: AtomState) => void
}

export interface WritableAtom<Value, Args extends unknown[], Result>
Expand Down
98 changes: 62 additions & 36 deletions src/vanilla/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,28 @@ const isPromiseLike = (
* The mounted state of an atom is freed once it is no longer mounted.
*/
type Mounted = {
/** Set of listeners to notify when the atom value changes. */
readonly l: Set<() => void>
/** Count of listeners to notify when the atom value changes. */
l: number
/** Set of mounted atoms that the atom depends on. */
readonly d: Set<AnyAtom>
/** Set of mounted atoms that depends on the atom. */
readonly t: Set<AnyAtom>
/** Function to run when the atom is unmounted. */
u?: (batch: Batch) => void
u?: BatchListener
}

/**
* Mutable atom state,
* tracked for both mounted and unmounted atoms in a store.
*/
type AtomState<Value = AnyValue> = {
export type AtomState<Value = AnyValue> = {
/**
* Map of atoms that the atom depends on.
* The map value is the epoch number of the dependency.
*/
readonly d: Map<AnyAtom, number>
/** Set of priority listeners to run when the atom value changes. */
readonly l: Set<readonly [listener: BatchListener, priority: BatchPriority]>
/**
* Set of atoms with pending promise that depend on the atom.
*
Expand Down Expand Up @@ -163,17 +165,19 @@ const addDependency = <Value>(
// Batch
//

type BatchListener = (batch: Batch) => void

type BatchPriority = 'H' | 'M' | 'L'

type Batch = Readonly<{
/** Atom dependents map */
D: Map<AnyAtom, Set<AnyAtom>>
/** High priority functions */
H: Set<() => void>
H: Set<BatchListener>
/** Medium priority functions */
M: Set<() => void>
M: Set<BatchListener>
/** Low priority functions */
L: Set<() => void>
L: Set<BatchListener>
}>

const createBatch = (): Batch => ({
Expand All @@ -185,8 +189,8 @@ const createBatch = (): Batch => ({

const addBatchFunc = (
batch: Batch,
fn: BatchListener,
priority: BatchPriority,
fn: () => void,
) => {
batch[priority].add(fn)
}
Expand All @@ -198,9 +202,12 @@ const registerBatchAtom = (
) => {
if (!batch.D.has(atom)) {
batch.D.set(atom, new Set())
addBatchFunc(batch, 'M', () => {
atomState.m?.l.forEach((listener) => addBatchFunc(batch, 'M', listener))
})
const scheduleListeners = () => {
for (const [listener, priority] of atomState.l) {
addBatchFunc(batch, listener, priority)
}
}
addBatchFunc(batch, scheduleListeners, 'H')
}
}

Expand All @@ -221,9 +228,9 @@ const getBatchAtomDependents = (batch: Batch, atom: AnyAtom) =>
const flushBatch = (batch: Batch) => {
let error: AnyError
let hasError = false
const call = (fn: () => void) => {
const call = (fn: BatchListener) => {
try {
fn()
fn(batch)
} catch (e) {
if (!hasError) {
error = e
Expand All @@ -245,9 +252,17 @@ const flushBatch = (batch: Batch) => {
}
}

type AtomOnInit = <Value>(
atom: Atom<Value>,
atomState: AtomState<Value>,
) => void

// internal & unstable type
type StoreArgs = readonly [
getAtomState: <Value>(atom: Atom<Value>) => AtomState<Value>,
getAtomState: <Value>(
atom: Atom<Value>,
atomOnInit?: AtomOnInit | undefined,
) => AtomState<Value>,
atomRead: <Value>(
atom: Atom<Value>,
...params: Parameters<Atom<Value>['read']>
Expand All @@ -260,6 +275,7 @@ type StoreArgs = readonly [
atom: WritableAtom<Value, Args, Result>,
setAtom: (...args: Args) => Result,
) => OnUnmount | void,
createAtomOnInit: (store: Store) => AtomOnInit,
]

// for debugging purpose only
Expand All @@ -284,9 +300,12 @@ type Store = {
export type INTERNAL_DevStoreRev4 = DevStoreRev4
export type INTERNAL_PrdStore = Store

const buildStore = (
...[getAtomState, atomRead, atomWrite, atomOnMount]: StoreArgs
): Store => {
const buildStore = (...storeArgs: StoreArgs): Store => {
const [_getAtomState, atomRead, atomWrite, atomOnMount, createAtomOnInit] =
storeArgs
const getAtomState = <Value>(atom: Atom<Value>) =>
_getAtomState(atom, createAtomOnInit(store))

const setAtomStateValueOrPromise = (
atom: AnyAtom,
atomState: AtomState,
Expand Down Expand Up @@ -503,7 +522,7 @@ const buildStore = (

// Step 2: use the topSortedReversed atom list to recompute all affected atoms
// Track what's changed, so that we can short circuit when possible
addBatchFunc(batch, 'H', () => {
const finishRecompute = () => {
const changedAtoms = new Set<AnyAtom>([atom])
for (let i = topSortedReversed.length - 1; i >= 0; --i) {
const [a, aState, prevEpochNumber] = topSortedReversed[i]!
Expand All @@ -524,7 +543,8 @@ const buildStore = (
}
delete aState.x
}
})
}
addBatchFunc(batch, finishRecompute, 'H')
}

const writeAtomState = <Value, Args extends unknown[], Result>(
Expand Down Expand Up @@ -621,8 +641,8 @@ const buildStore = (
}
// mount self
atomState.m = {
l: new Set(),
d: new Set(atomState.d.keys()),
l: 0,
t: new Set(),
}
if (isActuallyWritableAtom(atom)) {
Expand All @@ -645,14 +665,15 @@ const buildStore = (
isSync = false
}
}
addBatchFunc(batch, 'L', () => {
const processOnMount = () => {
const onUnmount = createInvocationContext(batch, () =>
atomOnMount(atom, (...args) => setAtom(...args)),
)
if (onUnmount) {
mounted.u = (batch) => createInvocationContext(batch, onUnmount)
}
})
}
addBatchFunc(batch, processOnMount, 'L')
}
}
return atomState.m
Expand All @@ -665,13 +686,13 @@ const buildStore = (
): Mounted | undefined => {
if (
atomState.m &&
!atomState.m.l.size &&
!atomState.m.l &&
!Array.from(atomState.m.t).some((a) => getAtomState(a).m?.d.has(atom))
) {
// unmount self
const onUnmount = atomState.m.u
if (onUnmount) {
addBatchFunc(batch, 'L', () => onUnmount(batch))
addBatchFunc(batch, onUnmount, 'L')
}
delete atomState.m
// unmount dependencies
Expand All @@ -688,19 +709,21 @@ const buildStore = (
const batch = createBatch()
const atomState = getAtomState(atom)
const mounted = mountAtom(batch, atom, atomState)
const listeners = mounted.l
listeners.add(listener)
const priorityListener = [() => listener(), 'M'] as const
++mounted.l
atomState.l.add(priorityListener)
flushBatch(batch)
return () => {
listeners.delete(listener)
const batch = createBatch()
--mounted.l
atomState.l.delete(priorityListener)
unmountAtom(batch, atom, atomState)
flushBatch(batch)
}
}

const unstable_derive = (fn: (...args: StoreArgs) => StoreArgs) =>
buildStore(...fn(getAtomState, atomRead, atomWrite, atomOnMount))
const unstable_derive: Store['unstable_derive'] = (fn) =>
buildStore(...fn(...storeArgs))

const store: Store = {
get: readAtom,
Expand All @@ -717,13 +740,13 @@ const deriveDevStoreRev4 = (store: Store): Store & DevStoreRev4 => {
let savedGetAtomState: StoreArgs[0]
let inRestoreAtom = 0
const derivedStore = store.unstable_derive(
(getAtomState, atomRead, atomWrite, atomOnMount) => {
savedGetAtomState = getAtomState
(getAtomState, atomRead, atomWrite, atomOnMount, createAtomOnInit) => {
savedGetAtomState = (a) => getAtomState(a, createAtomOnInit(derivedStore))
return [
(atom) => {
(atom, atomOnInit) => {
let proxyAtomState = proxyAtomStateMap.get(atom)
if (!proxyAtomState) {
const atomState = getAtomState(atom)
const atomState = getAtomState(atom, atomOnInit)
proxyAtomState = new Proxy(atomState, {
set(target, prop, value) {
if (prop === 'm') {
Expand All @@ -750,6 +773,7 @@ const deriveDevStoreRev4 = (store: Store): Store & DevStoreRev4 => {
return atomWrite(atom, getter, setter, ...args)
},
atomOnMount,
createAtomOnInit,
]
},
)
Expand Down Expand Up @@ -789,18 +813,19 @@ const deriveDevStoreRev4 = (store: Store): Store & DevStoreRev4 => {
return Object.assign(derivedStore, devStore)
}

type PrdOrDevStore = Store | (Store & DevStoreRev4)
export type PrdOrDevStore = Store | (Store & DevStoreRev4)

export const createStore = (): PrdOrDevStore => {
const atomStateMap = new WeakMap()
const getAtomState = <Value>(atom: Atom<Value>) => {
const getAtomState = <Value>(atom: Atom<Value>, atomOnInit?: AtomOnInit) => {
if (import.meta.env?.MODE !== 'production' && !atom) {
throw new Error('Atom is undefined or null')
}
let atomState = atomStateMap.get(atom) as AtomState<Value> | undefined
if (!atomState) {
atomState = { d: new Map(), p: new Set(), n: 0 }
atomState = { d: new Map(), l: new Set(), p: new Set(), n: 0 }
atomStateMap.set(atom, atomState)
atomOnInit?.(atom, atomState)
}
return atomState
}
Expand All @@ -809,6 +834,7 @@ export const createStore = (): PrdOrDevStore => {
(atom, ...params) => atom.read(...params),
(atom, ...params) => atom.write(...params),
(atom, ...params) => atom.onMount?.(...params),
(store) => (atom, atomState) => atom.INTERNAL_onInit?.(store, atomState),
)
if (import.meta.env?.MODE !== 'production') {
return deriveDevStoreRev4(store)
Expand Down
16 changes: 16 additions & 0 deletions tests/setup.ts
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
import '@testing-library/jest-dom/vitest'
import { expect, vi } from 'vitest'

type MockFunction = ReturnType<typeof vi.fn>

expect.extend({
toHaveBeenCalledBefore(received: MockFunction, expected: MockFunction) {
const pass =
received.mock.invocationCallOrder[0]! <
expected.mock.invocationCallOrder[0]!
return {
pass,
message: () =>
`expected ${received} to have been called before ${expected}`,
}
},
})
Loading
Loading