From b1d2e252622651b82a4ddcbd6934e2ea20579cdf Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Fri, 3 May 2024 09:52:15 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20user=20store?= =?UTF-8?q?=20add=20an=20auth=20slice=20(#2214)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ refactor: refactor the user store with auth slice * ♻️ refactor: separate common and sync slice * 🧑‍💻 chore: add an isMobile selector * ♻️ refactor: refactor the auth action and common action * 🎨 chore: clean code --- .../chat/(mobile)/features/SessionHeader.tsx | 4 +- .../chat/features/ShareButton/ShareModal.tsx | 4 +- src/features/AvatarWithUpload/index.tsx | 4 +- src/layout/GlobalProvider/index.tsx | 8 +- src/services/chat.ts | 4 +- src/store/chat/slices/message/selectors.ts | 4 +- src/store/serverConfig/Provider.tsx | 5 +- src/store/serverConfig/selectors.ts | 1 + src/store/serverConfig/store.ts | 1 + src/store/user/initialState.ts | 8 +- src/store/user/selectors.ts | 2 +- src/store/user/slices/auth/action.test.ts | 118 +++++++++ src/store/user/slices/auth/action.ts | 81 +++++++ src/store/user/slices/auth/initialState.ts | 20 ++ src/store/user/slices/auth/selectors.ts | 6 + src/store/user/slices/common/action.test.ts | 223 ------------------ src/store/user/slices/common/action.ts | 115 +-------- src/store/user/slices/common/selectors.ts | 6 - .../user/slices/settings/initialState.ts | 2 - src/store/user/slices/sync/action.test.ts | 150 ++++++++++++ src/store/user/slices/sync/action.ts | 94 ++++++++ .../slices/{common => sync}/initialState.ts | 9 +- src/store/user/store.ts | 13 +- 23 files changed, 513 insertions(+), 369 deletions(-) create mode 100644 src/store/user/slices/auth/action.test.ts create mode 100644 src/store/user/slices/auth/action.ts create mode 100644 src/store/user/slices/auth/initialState.ts create mode 100644 src/store/user/slices/auth/selectors.ts delete mode 100644 src/store/user/slices/common/selectors.ts create mode 100644 src/store/user/slices/sync/action.test.ts create mode 100644 src/store/user/slices/sync/action.ts rename src/store/user/slices/{common => sync}/initialState.ts (61%) diff --git a/src/app/(main)/chat/(mobile)/features/SessionHeader.tsx b/src/app/(main)/chat/(mobile)/features/SessionHeader.tsx index 50348b5075ec1..42fd68338cb62 100644 --- a/src/app/(main)/chat/(mobile)/features/SessionHeader.tsx +++ b/src/app/(main)/chat/(mobile)/features/SessionHeader.tsx @@ -10,7 +10,7 @@ import SyncStatusInspector from '@/features/SyncStatusInspector'; import { featureFlagsSelectors, useServerConfigStore } from '@/store/serverConfig'; import { useSessionStore } from '@/store/session'; import { useUserStore } from '@/store/user'; -import { commonSelectors } from '@/store/user/selectors'; +import { userProfileSelectors } from '@/store/user/selectors'; import { mobileHeaderSticky } from '@/styles/mobileHeader'; export const useStyles = createStyles(({ css, token }) => ({ @@ -26,7 +26,7 @@ export const useStyles = createStyles(({ css, token }) => ({ const Header = memo(() => { const [createSession] = useSessionStore((s) => [s.createSession]); const router = useRouter(); - const avatar = useUserStore(commonSelectors.userAvatar); + const avatar = useUserStore(userProfileSelectors.userAvatar); const { showCreateSession } = useServerConfigStore(featureFlagsSelectors); return ( diff --git a/src/app/(main)/chat/features/ShareButton/ShareModal.tsx b/src/app/(main)/chat/features/ShareButton/ShareModal.tsx index 2b692eaf55d54..ff3a59d754983 100644 --- a/src/app/(main)/chat/features/ShareButton/ShareModal.tsx +++ b/src/app/(main)/chat/features/ShareButton/ShareModal.tsx @@ -7,7 +7,7 @@ import { Flexbox } from 'react-layout-kit'; import { FORM_STYLE } from '@/const/layoutTokens'; import { useChatStore } from '@/store/chat'; import { useUserStore } from '@/store/user'; -import { commonSelectors } from '@/store/user/selectors'; +import { userProfileSelectors } from '@/store/user/selectors'; import Preview from './Preview'; import { FieldType, ImageType } from './type'; @@ -49,7 +49,7 @@ const ShareModal = memo(({ onCancel, open }) => { const [fieldValue, setFieldValue] = useState(DEFAULT_FIELD_VALUE); const [tab, setTab] = useState(Tab.Screenshot); const { t } = useTranslation('chat'); - const avatar = useUserStore(commonSelectors.userAvatar); + const avatar = useUserStore(userProfileSelectors.userAvatar); const [shareLoading, shareToShareGPT] = useChatStore((s) => [s.shareLoading, s.shareToShareGPT]); const { loading, onDownload, title } = useScreenshot(fieldValue.imageType); diff --git a/src/features/AvatarWithUpload/index.tsx b/src/features/AvatarWithUpload/index.tsx index 3e1a740339ca9..2e0275e0c167e 100644 --- a/src/features/AvatarWithUpload/index.tsx +++ b/src/features/AvatarWithUpload/index.tsx @@ -7,7 +7,7 @@ import { CSSProperties, memo, useCallback } from 'react'; import { DEFAULT_USER_AVATAR_URL } from '@/const/meta'; import { useUserStore } from '@/store/user'; -import { commonSelectors } from '@/store/user/selectors'; +import { userProfileSelectors } from '@/store/user/selectors'; import { imageToBase64 } from '@/utils/imageToBase64'; import { createUploadImageHandler } from '@/utils/uploadFIle'; @@ -41,7 +41,7 @@ const AvatarWithUpload = memo( ({ size = 40, compressSize = 256, style, id }) => { const { styles } = useStyle(); const [avatar, updateAvatar] = useUserStore((s) => [ - commonSelectors.userAvatar(s), + userProfileSelectors.userAvatar(s), s.updateAvatar, ]); diff --git a/src/layout/GlobalProvider/index.tsx b/src/layout/GlobalProvider/index.tsx index aa49feeba2383..6aeff65ec6957 100644 --- a/src/layout/GlobalProvider/index.tsx +++ b/src/layout/GlobalProvider/index.tsx @@ -13,6 +13,7 @@ import { import { getServerGlobalConfig } from '@/server/globalConfig'; import { ServerConfigStoreProvider } from '@/store/serverConfig'; import { getAntdLocale } from '@/utils/locale'; +import { isMobileDevice } from '@/utils/responsive'; import AppTheme from './AppTheme'; import Locale from './Locale'; @@ -48,6 +49,7 @@ const GlobalLayout = async ({ children }: GlobalLayoutProps) => { // get default feature flags to use with ssr const serverFeatureFlags = getServerFeatureFlagsValue(); const serverConfig = getServerGlobalConfig(); + const isMobile = isMobileDevice(); return ( @@ -57,7 +59,11 @@ const GlobalLayout = async ({ children }: GlobalLayoutProps) => { defaultPrimaryColor={primaryColor?.value as any} > - + {children} diff --git a/src/services/chat.ts b/src/services/chat.ts index 8a5905c3135fb..2198a063155dc 100644 --- a/src/services/chat.ts +++ b/src/services/chat.ts @@ -15,10 +15,10 @@ import { useToolStore } from '@/store/tool'; import { pluginSelectors, toolSelectors } from '@/store/tool/selectors'; import { useUserStore } from '@/store/user'; import { - commonSelectors, modelConfigSelectors, modelProviderSelectors, preferenceSelectors, + userProfileSelectors, } from '@/store/user/selectors'; import { ChatErrorType } from '@/types/fetch'; import { ChatMessage } from '@/types/message'; @@ -482,7 +482,7 @@ class ChatService { ...trace, enabled: true, tags: [tag, ...(trace?.tags || []), ...tags].filter(Boolean) as string[], - userId: commonSelectors.userId(useUserStore.getState()), + userId: userProfileSelectors.userId(useUserStore.getState()), }; } diff --git a/src/store/chat/slices/message/selectors.ts b/src/store/chat/slices/message/selectors.ts index 494ad67a25981..1e2e0b07093da 100644 --- a/src/store/chat/slices/message/selectors.ts +++ b/src/store/chat/slices/message/selectors.ts @@ -8,7 +8,7 @@ import { agentSelectors } from '@/store/agent/selectors'; import { useSessionStore } from '@/store/session'; import { sessionMetaSelectors } from '@/store/session/selectors'; import { useUserStore } from '@/store/user'; -import { commonSelectors } from '@/store/user/selectors'; +import { userProfileSelectors } from '@/store/user/selectors'; import { ChatMessage } from '@/types/message'; import { MetaData } from '@/types/meta'; import { merge } from '@/utils/merge'; @@ -20,7 +20,7 @@ const getMeta = (message: ChatMessage) => { switch (message.role) { case 'user': { return { - avatar: commonSelectors.userAvatar(useUserStore.getState()) || DEFAULT_USER_AVATAR, + avatar: userProfileSelectors.userAvatar(useUserStore.getState()) || DEFAULT_USER_AVATAR, }; } diff --git a/src/store/serverConfig/Provider.tsx b/src/store/serverConfig/Provider.tsx index ed009b1936016..00e30ea399f2c 100644 --- a/src/store/serverConfig/Provider.tsx +++ b/src/store/serverConfig/Provider.tsx @@ -10,12 +10,13 @@ import { Provider, createServerConfigStore } from './store'; interface GlobalStoreProviderProps { children: ReactNode; featureFlags?: Partial; + isMobile?: boolean; serverConfig?: GlobalServerConfig; } export const ServerConfigStoreProvider = memo( - ({ children, featureFlags, serverConfig }) => ( - createServerConfigStore({ featureFlags, serverConfig })}> + ({ children, featureFlags, serverConfig, isMobile }) => ( + createServerConfigStore({ featureFlags, isMobile, serverConfig })}> {children} ), diff --git a/src/store/serverConfig/selectors.ts b/src/store/serverConfig/selectors.ts index 15a9fff692a07..8d2e3130ae3b4 100644 --- a/src/store/serverConfig/selectors.ts +++ b/src/store/serverConfig/selectors.ts @@ -8,4 +8,5 @@ export const featureFlagsSelectors = (s: ServerConfigStore) => export const serverConfigSelectors = { enabledOAuthSSO: (s: ServerConfigStore) => s.serverConfig.enabledOAuthSSO, enabledTelemetryChat: (s: ServerConfigStore) => s.serverConfig.telemetry.langfuse || false, + isMobile: (s: ServerConfigStore) => s.isMobile || false, }; diff --git a/src/store/serverConfig/store.ts b/src/store/serverConfig/store.ts index 7f467a5ba88e6..52eaffc623214 100644 --- a/src/store/serverConfig/store.ts +++ b/src/store/serverConfig/store.ts @@ -20,6 +20,7 @@ const initialState: ServerConfigStore = { export interface ServerConfigStore { featureFlags: IFeatureFlags; + isMobile?: boolean; serverConfig: GlobalServerConfig; } diff --git a/src/store/user/initialState.ts b/src/store/user/initialState.ts index 5f8c9f3be99d7..7ed7a6ff9a835 100644 --- a/src/store/user/initialState.ts +++ b/src/store/user/initialState.ts @@ -1,11 +1,13 @@ -import { UserCommonState, initialCommonState } from './slices/common/initialState'; +import { UserAuthState, initialAuthState } from './slices/auth/initialState'; import { UserPreferenceState, initialPreferenceState } from './slices/preference/initialState'; import { UserSettingsState, initialSettingsState } from './slices/settings/initialState'; +import { UserSyncState, initialSyncState } from './slices/sync/initialState'; -export type UserState = UserCommonState & UserSettingsState & UserPreferenceState; +export type UserState = UserSyncState & UserSettingsState & UserPreferenceState & UserAuthState; export const initialState: UserState = { - ...initialCommonState, + ...initialSyncState, ...initialSettingsState, ...initialPreferenceState, + ...initialAuthState, }; diff --git a/src/store/user/selectors.ts b/src/store/user/selectors.ts index eac9eb7b7cefd..7be6758e5c92f 100644 --- a/src/store/user/selectors.ts +++ b/src/store/user/selectors.ts @@ -1,4 +1,4 @@ -export { commonSelectors } from './slices/common/selectors'; +export { userProfileSelectors } from './slices/auth/selectors'; export { preferenceSelectors } from './slices/preference/selectors'; export { modelConfigSelectors, diff --git a/src/store/user/slices/auth/action.test.ts b/src/store/user/slices/auth/action.test.ts new file mode 100644 index 0000000000000..6e0add0cb3d5a --- /dev/null +++ b/src/store/user/slices/auth/action.test.ts @@ -0,0 +1,118 @@ +import { act, renderHook, waitFor } from '@testing-library/react'; +import { mutate } from 'swr'; +import { afterEach, describe, expect, it, vi } from 'vitest'; +import { withSWR } from '~test-utils'; + +import { userService } from '@/services/user'; +import { useUserStore } from '@/store/user'; +import { switchLang } from '@/utils/client/switchLang'; + +vi.mock('zustand/traditional'); + +vi.mock('@/utils/client/switchLang', () => ({ + switchLang: vi.fn(), +})); + +vi.mock('swr', async (importOriginal) => { + const modules = await importOriginal(); + return { + ...(modules as any), + mutate: vi.fn(), + }; +}); + +afterEach(() => { + vi.restoreAllMocks(); +}); + +describe('createAuthSlice', () => { + describe('refreshUserConfig', () => { + it('should refresh user config', async () => { + const { result } = renderHook(() => useUserStore()); + + await act(async () => { + await result.current.refreshUserConfig(); + }); + + expect(mutate).toHaveBeenCalledWith(['fetchUserConfig', true]); + }); + }); + + describe('useFetchUserConfig', () => { + it('should not fetch user config if initServer is false', async () => { + const mockUserConfig: any = undefined; // 模拟未初始化服务器的情况 + vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); + + const { result } = renderHook(() => useUserStore().useFetchUserConfig(false), { + wrapper: withSWR, + }); + + // 因为 initServer 为 false,所以不会触发 getUserConfig 的调用 + expect(userService.getUserConfig).not.toHaveBeenCalled(); + // 确保状态未改变 + expect(result.current.data).toBeUndefined(); + }); + + it('should fetch user config correctly when initServer is true', async () => { + const mockUserConfig: any = { + avatar: 'new-avatar-url', + settings: { + language: 'en', + }, + }; + vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); + + const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { + wrapper: withSWR, + }); + + // 等待 SWR 完成数据获取 + await waitFor(() => expect(result.current.data).toEqual(mockUserConfig)); + + // 验证状态是否正确更新 + expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar); + expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings); + + // 验证是否正确处理了语言设置 + expect(switchLang).not.toHaveBeenCalledWith('auto'); + }); + it('should call switch language when language is auto', async () => { + const mockUserConfig: any = { + avatar: 'new-avatar-url', + settings: { + language: 'auto', + }, + }; + vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); + + const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { + wrapper: withSWR, + }); + + // 等待 SWR 完成数据获取 + await waitFor(() => expect(result.current.data).toEqual(mockUserConfig)); + + // 验证状态是否正确更新 + expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar); + expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings); + + // 验证是否正确处理了语言设置 + expect(switchLang).toHaveBeenCalledWith('auto'); + }); + + it('should handle the case when user config is null', async () => { + vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(null as any); + + const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { + wrapper: withSWR, + }); + + // 等待 SWR 完成数据获取 + await waitFor(() => expect(result.current.data).toBeNull()); + + // 验证状态未被错误更新 + expect(useUserStore.getState().avatar).toBeUndefined(); + expect(useUserStore.getState().settings).toEqual({}); + }); + }); +}); diff --git a/src/store/user/slices/auth/action.ts b/src/store/user/slices/auth/action.ts new file mode 100644 index 0000000000000..ebd458d442637 --- /dev/null +++ b/src/store/user/slices/auth/action.ts @@ -0,0 +1,81 @@ +import useSWR, { SWRResponse, mutate } from 'swr'; +import { StateCreator } from 'zustand/vanilla'; + +import { UserConfig, userService } from '@/services/user'; +import { switchLang } from '@/utils/client/switchLang'; +import { setNamespace } from '@/utils/storeDebug'; + +import { UserStore } from '../../store'; +import { settingsSelectors } from '../settings/selectors'; + +const n = setNamespace('auth'); +const USER_CONFIG_FETCH_KEY = 'fetchUserConfig'; + +export interface UserAuthAction { + getUserConfig: () => void; + /** + * universal login method + */ + login: () => Promise; + /** + * universal logout method + */ + logout: () => Promise; + refreshUserConfig: () => Promise; + + useFetchUserConfig: (initServer: boolean) => SWRResponse; +} + +export const createAuthSlice: StateCreator< + UserStore, + [['zustand/devtools', never]], + [], + UserAuthAction +> = (set, get) => ({ + getUserConfig: () => { + console.log(n('userconfig')); + }, + login: async () => { + // TODO: 针对开启 next-auth 的场景,需要在这里调用登录方法 + console.log(n('login')); + }, + logout: async () => { + // TODO: 针对开启 next-auth 的场景,需要在这里调用登录方法 + console.log(n('logout')); + }, + refreshUserConfig: async () => { + await mutate([USER_CONFIG_FETCH_KEY, true]); + + // when get the user config ,refresh the model provider list to the latest + get().refreshModelProviderList(); + }, + + useFetchUserConfig: (initServer) => + useSWR( + [USER_CONFIG_FETCH_KEY, initServer], + async () => { + if (!initServer) return; + return userService.getUserConfig(); + }, + { + onSuccess: (data) => { + if (!data) return; + + set( + { avatar: data.avatar, settings: data.settings, userId: data.uuid }, + false, + n('fetchUserConfig', data), + ); + + // when get the user config ,refresh the model provider list to the latest + get().refreshDefaultModelProviderList({ trigger: 'fetchUserConfig' }); + + const { language } = settingsSelectors.currentSettings(get()); + if (language === 'auto') { + switchLang('auto'); + } + }, + revalidateOnFocus: false, + }, + ), +}); diff --git a/src/store/user/slices/auth/initialState.ts b/src/store/user/slices/auth/initialState.ts new file mode 100644 index 0000000000000..a9baa570decfa --- /dev/null +++ b/src/store/user/slices/auth/initialState.ts @@ -0,0 +1,20 @@ +export interface LobeUser { + avatar?: string; + firstName?: string | null; + fullName?: string | null; + id: string; + latestName?: string | null; + username?: string | null; +} + +export interface UserAuthState { + /** + * @deprecated + */ + avatar?: string; + isSignedIn?: boolean; + user?: LobeUser; + userId?: string; +} + +export const initialAuthState: UserAuthState = {}; diff --git a/src/store/user/slices/auth/selectors.ts b/src/store/user/slices/auth/selectors.ts new file mode 100644 index 0000000000000..b0c37ab8fa2f1 --- /dev/null +++ b/src/store/user/slices/auth/selectors.ts @@ -0,0 +1,6 @@ +import { UserStore } from '@/store/user'; + +export const userProfileSelectors = { + userAvatar: (s: UserStore): string => s.avatar || '', + userId: (s: UserStore) => s.userId, +}; diff --git a/src/store/user/slices/common/action.test.ts b/src/store/user/slices/common/action.test.ts index 454a257dfba0b..d8f3463f4a1f8 100644 --- a/src/store/user/slices/common/action.test.ts +++ b/src/store/user/slices/common/action.test.ts @@ -8,17 +8,10 @@ import { messageService } from '@/services/message'; import { userService } from '@/services/user'; import { useUserStore } from '@/store/user'; import { preferenceSelectors } from '@/store/user/selectors'; -import { commonSelectors } from '@/store/user/slices/common/selectors'; -import { syncSettingsSelectors } from '@/store/user/slices/settings/selectors'; import { GlobalServerConfig } from '@/types/serverConfig'; -import { switchLang } from '@/utils/client/switchLang'; vi.mock('zustand/traditional'); -vi.mock('@/utils/client/switchLang', () => ({ - switchLang: vi.fn(), -})); - vi.mock('swr', async (importOriginal) => { const modules = await importOriginal(); return { @@ -32,18 +25,6 @@ afterEach(() => { }); describe('createCommonSlice', () => { - describe('refreshUserConfig', () => { - it('should refresh user config', async () => { - const { result } = renderHook(() => useUserStore()); - - await act(async () => { - await result.current.refreshUserConfig(); - }); - - expect(mutate).toHaveBeenCalledWith(['fetchUserConfig', true]); - }); - }); - describe('updateAvatar', () => { it('should update avatar', async () => { const { result } = renderHook(() => useUserStore()); @@ -76,167 +57,6 @@ describe('createCommonSlice', () => { }); }); - describe('useFetchUserConfig', () => { - it('should not fetch user config if initServer is false', async () => { - const mockUserConfig: any = undefined; // 模拟未初始化服务器的情况 - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(false), { - wrapper: withSWR, - }); - - // 因为 initServer 为 false,所以不会触发 getUserConfig 的调用 - expect(userService.getUserConfig).not.toHaveBeenCalled(); - // 确保状态未改变 - expect(result.current.data).toBeUndefined(); - }); - - it('should fetch user config correctly when initServer is true', async () => { - const mockUserConfig: any = { - avatar: 'new-avatar-url', - settings: { - language: 'en', - }, - }; - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { - wrapper: withSWR, - }); - - // 等待 SWR 完成数据获取 - await waitFor(() => expect(result.current.data).toEqual(mockUserConfig)); - - // 验证状态是否正确更新 - expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar); - expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings); - - // 验证是否正确处理了语言设置 - expect(switchLang).not.toHaveBeenCalledWith('auto'); - }); - it('should call switch language when language is auto', async () => { - const mockUserConfig: any = { - avatar: 'new-avatar-url', - settings: { - language: 'auto', - }, - }; - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { - wrapper: withSWR, - }); - - // 等待 SWR 完成数据获取 - await waitFor(() => expect(result.current.data).toEqual(mockUserConfig)); - - // 验证状态是否正确更新 - expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar); - expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings); - - // 验证是否正确处理了语言设置 - expect(switchLang).toHaveBeenCalledWith('auto'); - }); - - it('should handle the case when user config is null', async () => { - vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(null as any); - - const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), { - wrapper: withSWR, - }); - - // 等待 SWR 完成数据获取 - await waitFor(() => expect(result.current.data).toBeNull()); - - // 验证状态未被错误更新 - expect(useUserStore.getState().avatar).toBeUndefined(); - expect(useUserStore.getState().settings).toEqual({}); - }); - }); - - describe('refreshConnection', () => { - it('should not call triggerEnableSync when userId is empty', async () => { - const { result } = renderHook(() => useUserStore()); - const onEvent = vi.fn(); - - vi.spyOn(commonSelectors, 'userId').mockReturnValueOnce(undefined); - const triggerEnableSyncSpy = vi.spyOn(result.current, 'triggerEnableSync'); - - await act(async () => { - await result.current.refreshConnection(onEvent); - }); - - expect(triggerEnableSyncSpy).not.toHaveBeenCalled(); - }); - - it('should call triggerEnableSync when userId exists', async () => { - const { result } = renderHook(() => useUserStore()); - const onEvent = vi.fn(); - const userId = 'user-id'; - - vi.spyOn(commonSelectors, 'userId').mockReturnValueOnce(userId); - const triggerEnableSyncSpy = vi.spyOn(result.current, 'triggerEnableSync'); - - await act(async () => { - await result.current.refreshConnection(onEvent); - }); - - expect(triggerEnableSyncSpy).toHaveBeenCalledWith(userId, onEvent); - }); - }); - - describe('triggerEnableSync', () => { - it('should return false when sync.channelName is empty', async () => { - const { result } = renderHook(() => useUserStore()); - const userId = 'user-id'; - const onEvent = vi.fn(); - - vi.spyOn(syncSettingsSelectors, 'webrtcConfig').mockReturnValueOnce({ - channelName: '', - enabled: true, - }); - - const data = await act(async () => { - return result.current.triggerEnableSync(userId, onEvent); - }); - - expect(data).toBe(false); - }); - - it('should call globalService.enabledSync when sync.channelName exists', async () => { - const userId = 'user-id'; - const onEvent = vi.fn(); - const channelName = 'channel-name'; - const channelPassword = 'channel-password'; - const deviceName = 'device-name'; - const signaling = 'signaling'; - - vi.spyOn(syncSettingsSelectors, 'webrtcConfig').mockReturnValueOnce({ - channelName, - channelPassword, - signaling, - enabled: true, - }); - vi.spyOn(syncSettingsSelectors, 'deviceName').mockReturnValueOnce(deviceName); - const enabledSyncSpy = vi.spyOn(globalService, 'enabledSync').mockResolvedValueOnce(true); - const { result } = renderHook(() => useUserStore()); - - const data = await act(async () => { - return result.current.triggerEnableSync(userId, onEvent); - }); - - expect(enabledSyncSpy).toHaveBeenCalledWith({ - channel: { name: channelName, password: channelPassword }, - onAwarenessChange: expect.any(Function), - onSyncEvent: onEvent, - onSyncStatusChange: expect.any(Function), - signaling, - user: expect.objectContaining({ id: userId, name: deviceName }), - }); - expect(data).toBe(true); - }); - }); - describe('useCheckTrace', () => { it('should return false when shouldFetch is false', async () => { const { result } = renderHook(() => useUserStore().useCheckTrace(false), { @@ -270,47 +90,4 @@ describe('createCommonSlice', () => { expect(messageCountToCheckTraceSpy).toHaveBeenCalled(); }); }); - - describe('useEnabledSync', () => { - it('should return false when userId is empty', async () => { - const { result } = renderHook(() => useUserStore().useEnabledSync(true, undefined, vi.fn()), { - wrapper: withSWR, - }); - - await waitFor(() => expect(result.current.data).toBe(false)); - }); - - it('should call globalService.disableSync when userEnableSync is false', async () => { - const disableSyncSpy = vi.spyOn(globalService, 'disableSync').mockResolvedValueOnce(false); - - const { result } = renderHook( - () => useUserStore().useEnabledSync(false, 'user-id', vi.fn()), - { wrapper: withSWR }, - ); - - await waitFor(() => expect(result.current.data).toBeUndefined()); - expect(disableSyncSpy).toHaveBeenCalled(); - }); - - it('should call triggerEnableSync when userEnableSync and userId exist', async () => { - const userId = 'user-id'; - const onEvent = vi.fn(); - const triggerEnableSyncSpy = vi.fn().mockResolvedValueOnce(true); - - const { result } = renderHook(() => useUserStore()); - - // replace triggerEnableSync as a mock - result.current.triggerEnableSync = triggerEnableSyncSpy; - - const { result: swrResult } = renderHook( - () => result.current.useEnabledSync(true, userId, onEvent), - { - wrapper: withSWR, - }, - ); - - await waitFor(() => expect(swrResult.current.data).toBe(true)); - expect(triggerEnableSyncSpy).toHaveBeenCalledWith(userId, onEvent); - }); - }); }); diff --git a/src/store/user/slices/common/action.ts b/src/store/user/slices/common/action.ts index 2e493246a919f..d281d6fc5c93a 100644 --- a/src/store/user/slices/common/action.ts +++ b/src/store/user/slices/common/action.ts @@ -1,22 +1,17 @@ -import useSWR, { SWRResponse, mutate } from 'swr'; +import useSWR, { SWRResponse } from 'swr'; import { DeepPartial } from 'utility-types'; import type { StateCreator } from 'zustand/vanilla'; import { globalService } from '@/services/global'; import { messageService } from '@/services/message'; -import { UserConfig, userService } from '@/services/user'; +import { userService } from '@/services/user'; import type { UserStore } from '@/store/user'; import type { GlobalServerConfig } from '@/types/serverConfig'; import type { GlobalSettings } from '@/types/settings'; -import { OnSyncEvent, PeerSyncStatus } from '@/types/sync'; -import { switchLang } from '@/utils/client/switchLang'; import { merge } from '@/utils/merge'; -import { browserInfo } from '@/utils/platform'; import { setNamespace } from '@/utils/storeDebug'; import { preferenceSelectors } from '../preference/selectors'; -import { settingsSelectors, syncSettingsSelectors } from '../settings/selectors'; -import { commonSelectors } from './selectors'; const n = setNamespace('common'); @@ -24,79 +19,22 @@ const n = setNamespace('common'); * 设置操作 */ export interface CommonAction { - refreshConnection: (onEvent: OnSyncEvent) => Promise; - refreshUserConfig: () => Promise; - triggerEnableSync: (userId: string, onEvent: OnSyncEvent) => Promise; updateAvatar: (avatar: string) => Promise; useCheckTrace: (shouldFetch: boolean) => SWRResponse; - useEnabledSync: ( - userEnableSync: boolean, - userId: string | undefined, - onEvent: OnSyncEvent, - ) => SWRResponse; useFetchServerConfig: () => SWRResponse; - useFetchUserConfig: (initServer: boolean) => SWRResponse; } -const USER_CONFIG_FETCH_KEY = 'fetchUserConfig'; - export const createCommonSlice: StateCreator< UserStore, [['zustand/devtools', never]], [], CommonAction > = (set, get) => ({ - refreshConnection: async (onEvent) => { - const userId = commonSelectors.userId(get()); - - if (!userId) return; - - await get().triggerEnableSync(userId, onEvent); - }, - - refreshUserConfig: async () => { - await mutate([USER_CONFIG_FETCH_KEY, true]); - - // when get the user config ,refresh the model provider list to the latest - get().refreshModelProviderList(); - }, - - triggerEnableSync: async (userId: string, onEvent: OnSyncEvent) => { - // double-check the sync ability - // if there is no channelName, don't start sync - const sync = syncSettingsSelectors.webrtcConfig(get()); - if (!sync.channelName) return false; - - const name = syncSettingsSelectors.deviceName(get()); - - const defaultUserName = `My ${browserInfo.browser} (${browserInfo.os})`; - - set({ syncStatus: PeerSyncStatus.Connecting }); - return globalService.enabledSync({ - channel: { - name: sync.channelName, - password: sync.channelPassword, - }, - onAwarenessChange(state) { - set({ syncAwareness: state }); - }, - onSyncEvent: onEvent, - onSyncStatusChange: (status) => { - set({ syncStatus: status }); - }, - signaling: sync.signaling, - user: { - id: userId, - // if user don't set the name, use default name - name: name || defaultUserName, - ...browserInfo, - }, - }); - }, updateAvatar: async (avatar) => { await userService.updateAvatar(avatar); await get().refreshUserConfig(); }, + useCheckTrace: (shouldFetch) => useSWR( ['checkTrace', shouldFetch], @@ -115,25 +53,6 @@ export const createCommonSlice: StateCreator< }, ), - useEnabledSync: (userEnableSync, userId, onEvent) => - useSWR( - ['enableSync', userEnableSync, userId], - async () => { - // if user don't enable sync or no userId ,don't start sync - if (!userId) return false; - - // if user don't enable sync, stop sync - if (!userEnableSync) return globalService.disableSync(); - - return get().triggerEnableSync(userId, onEvent); - }, - { - onSuccess: (syncEnabled) => { - set({ syncEnabled }, false, n('useEnabledSync')); - }, - revalidateOnFocus: false, - }, - ), useFetchServerConfig: () => useSWR('fetchGlobalConfig', globalService.getGlobalConfig, { onSuccess: (data) => { @@ -152,32 +71,4 @@ export const createCommonSlice: StateCreator< }, revalidateOnFocus: false, }), - useFetchUserConfig: (initServer) => - useSWR( - [USER_CONFIG_FETCH_KEY, initServer], - async () => { - if (!initServer) return; - return userService.getUserConfig(); - }, - { - onSuccess: (data) => { - if (!data) return; - - set( - { avatar: data.avatar, settings: data.settings, userId: data.uuid }, - false, - n('fetchUserConfig', data), - ); - - // when get the user config ,refresh the model provider list to the latest - get().refreshDefaultModelProviderList({ trigger: 'fetchUserConfig' }); - - const { language } = settingsSelectors.currentSettings(get()); - if (language === 'auto') { - switchLang('auto'); - } - }, - revalidateOnFocus: false, - }, - ), }); diff --git a/src/store/user/slices/common/selectors.ts b/src/store/user/slices/common/selectors.ts deleted file mode 100644 index 0947e2719b42d..0000000000000 --- a/src/store/user/slices/common/selectors.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { UserStore } from '@/store/user'; - -export const commonSelectors = { - userAvatar: (s: UserStore) => s.avatar || '', - userId: (s: UserStore) => s.userId, -}; diff --git a/src/store/user/slices/settings/initialState.ts b/src/store/user/slices/settings/initialState.ts index acfe8a0ef3d48..7e47d1735e52f 100644 --- a/src/store/user/slices/settings/initialState.ts +++ b/src/store/user/slices/settings/initialState.ts @@ -7,14 +7,12 @@ import { GlobalServerConfig } from '@/types/serverConfig'; import { GlobalSettings } from '@/types/settings'; export interface UserSettingsState { - avatar?: string; defaultModelProviderList: ModelProviderCard[]; defaultSettings: GlobalSettings; editingCustomCardModel?: { id: string; provider: string } | undefined; modelProviderList: ModelProviderCard[]; serverConfig: GlobalServerConfig; settings: DeepPartial; - userId?: string; } export const initialSettingsState: UserSettingsState = { diff --git a/src/store/user/slices/sync/action.test.ts b/src/store/user/slices/sync/action.test.ts new file mode 100644 index 0000000000000..2ee13918e356c --- /dev/null +++ b/src/store/user/slices/sync/action.test.ts @@ -0,0 +1,150 @@ +import { act, renderHook, waitFor } from '@testing-library/react'; +import { afterEach, describe, expect, it, vi } from 'vitest'; +import { withSWR } from '~test-utils'; + +import { globalService } from '@/services/global'; +import { useUserStore } from '@/store/user'; +import { userProfileSelectors } from '@/store/user/slices/auth/selectors'; +import { syncSettingsSelectors } from '@/store/user/slices/settings/selectors'; + +vi.mock('zustand/traditional'); + +vi.mock('swr', async (importOriginal) => { + const modules = await importOriginal(); + return { + ...(modules as any), + mutate: vi.fn(), + }; +}); + +afterEach(() => { + vi.restoreAllMocks(); +}); + +describe('createSyncSlice', () => { + describe('refreshConnection', () => { + it('should not call triggerEnableSync when userId is empty', async () => { + const { result } = renderHook(() => useUserStore()); + const onEvent = vi.fn(); + + vi.spyOn(userProfileSelectors, 'userId').mockReturnValueOnce(undefined as any); + const triggerEnableSyncSpy = vi.spyOn(result.current, 'triggerEnableSync'); + + await act(async () => { + await result.current.refreshConnection(onEvent); + }); + + expect(triggerEnableSyncSpy).not.toHaveBeenCalled(); + }); + + it('should call triggerEnableSync when userId exists', async () => { + const { result } = renderHook(() => useUserStore()); + const onEvent = vi.fn(); + const userId = 'user-id'; + + vi.spyOn(userProfileSelectors, 'userId').mockReturnValueOnce(userId); + const triggerEnableSyncSpy = vi.spyOn(result.current, 'triggerEnableSync'); + + await act(async () => { + await result.current.refreshConnection(onEvent); + }); + + expect(triggerEnableSyncSpy).toHaveBeenCalledWith(userId, onEvent); + }); + }); + + describe('triggerEnableSync', () => { + it('should return false when sync.channelName is empty', async () => { + const { result } = renderHook(() => useUserStore()); + const userId = 'user-id'; + const onEvent = vi.fn(); + + vi.spyOn(syncSettingsSelectors, 'webrtcConfig').mockReturnValueOnce({ + channelName: '', + enabled: true, + }); + + const data = await act(async () => { + return result.current.triggerEnableSync(userId, onEvent); + }); + + expect(data).toBe(false); + }); + + it('should call globalService.enabledSync when sync.channelName exists', async () => { + const userId = 'user-id'; + const onEvent = vi.fn(); + const channelName = 'channel-name'; + const channelPassword = 'channel-password'; + const deviceName = 'device-name'; + const signaling = 'signaling'; + + vi.spyOn(syncSettingsSelectors, 'webrtcConfig').mockReturnValueOnce({ + channelName, + channelPassword, + signaling, + enabled: true, + }); + vi.spyOn(syncSettingsSelectors, 'deviceName').mockReturnValueOnce(deviceName); + const enabledSyncSpy = vi.spyOn(globalService, 'enabledSync').mockResolvedValueOnce(true); + const { result } = renderHook(() => useUserStore()); + + const data = await act(async () => { + return result.current.triggerEnableSync(userId, onEvent); + }); + + expect(enabledSyncSpy).toHaveBeenCalledWith({ + channel: { name: channelName, password: channelPassword }, + onAwarenessChange: expect.any(Function), + onSyncEvent: onEvent, + onSyncStatusChange: expect.any(Function), + signaling, + user: expect.objectContaining({ id: userId, name: deviceName }), + }); + expect(data).toBe(true); + }); + }); + + describe('useEnabledSync', () => { + it('should return false when userId is empty', async () => { + const { result } = renderHook(() => useUserStore().useEnabledSync(true, undefined, vi.fn()), { + wrapper: withSWR, + }); + + await waitFor(() => expect(result.current.data).toBe(false)); + }); + + it('should call globalService.disableSync when userEnableSync is false', async () => { + const disableSyncSpy = vi.spyOn(globalService, 'disableSync').mockResolvedValueOnce(false); + + const { result } = renderHook( + () => useUserStore().useEnabledSync(false, 'user-id', vi.fn()), + { wrapper: withSWR }, + ); + + await waitFor(() => expect(result.current.data).toBeUndefined()); + expect(disableSyncSpy).toHaveBeenCalled(); + }); + + it('should call triggerEnableSync when userEnableSync and userId exist', async () => { + const userId = 'user-id'; + const onEvent = vi.fn(); + const triggerEnableSyncSpy = vi.fn().mockResolvedValueOnce(true); + + const { result } = renderHook(() => useUserStore()); + + // replace triggerEnableSync as a mock + result.current.triggerEnableSync = triggerEnableSyncSpy; + + const { result: swrResult } = renderHook( + () => result.current.useEnabledSync(true, userId, onEvent), + { + wrapper: withSWR, + }, + ); + + await waitFor(() => expect(swrResult.current.data).toBe(true)); + expect(triggerEnableSyncSpy).toHaveBeenCalledWith(userId, onEvent); + }); + }); +}); diff --git a/src/store/user/slices/sync/action.ts b/src/store/user/slices/sync/action.ts new file mode 100644 index 0000000000000..e3fa0896642a0 --- /dev/null +++ b/src/store/user/slices/sync/action.ts @@ -0,0 +1,94 @@ +import useSWR, { SWRResponse } from 'swr'; +import type { StateCreator } from 'zustand/vanilla'; + +import { globalService } from '@/services/global'; +import type { UserStore } from '@/store/user'; +import { OnSyncEvent, PeerSyncStatus } from '@/types/sync'; +import { browserInfo } from '@/utils/platform'; +import { setNamespace } from '@/utils/storeDebug'; + +import { userProfileSelectors } from '../auth/selectors'; +import { syncSettingsSelectors } from '../settings/selectors'; + +const n = setNamespace('sync'); + +/** + * 设置操作 + */ +export interface SyncAction { + refreshConnection: (onEvent: OnSyncEvent) => Promise; + triggerEnableSync: (userId: string, onEvent: OnSyncEvent) => Promise; + useEnabledSync: ( + userEnableSync: boolean, + userId: string | undefined, + onEvent: OnSyncEvent, + ) => SWRResponse; +} + +export const createSyncSlice: StateCreator< + UserStore, + [['zustand/devtools', never]], + [], + SyncAction +> = (set, get) => ({ + refreshConnection: async (onEvent) => { + const userId = userProfileSelectors.userId(get()); + + if (!userId) return; + + await get().triggerEnableSync(userId, onEvent); + }, + + triggerEnableSync: async (userId: string, onEvent: OnSyncEvent) => { + // double-check the sync ability + // if there is no channelName, don't start sync + const sync = syncSettingsSelectors.webrtcConfig(get()); + if (!sync.channelName) return false; + + const name = syncSettingsSelectors.deviceName(get()); + + const defaultUserName = `My ${browserInfo.browser} (${browserInfo.os})`; + + set({ syncStatus: PeerSyncStatus.Connecting }); + return globalService.enabledSync({ + channel: { + name: sync.channelName, + password: sync.channelPassword, + }, + onAwarenessChange(state) { + set({ syncAwareness: state }); + }, + onSyncEvent: onEvent, + onSyncStatusChange: (status) => { + set({ syncStatus: status }); + }, + signaling: sync.signaling, + user: { + id: userId, + // if user don't set the name, use default name + name: name || defaultUserName, + ...browserInfo, + }, + }); + }, + + useEnabledSync: (userEnableSync, userId, onEvent) => + useSWR( + ['enableSync', userEnableSync, userId], + async () => { + // if user don't enable sync or no userId ,don't start sync + if (!userId) return false; + + // if user don't enable sync, stop sync + if (!userEnableSync) return globalService.disableSync(); + + return get().triggerEnableSync(userId, onEvent); + }, + { + onSuccess: (syncEnabled) => { + set({ syncEnabled }, false, n('useEnabledSync')); + }, + revalidateOnFocus: false, + }, + ), +}); diff --git a/src/store/user/slices/common/initialState.ts b/src/store/user/slices/sync/initialState.ts similarity index 61% rename from src/store/user/slices/common/initialState.ts rename to src/store/user/slices/sync/initialState.ts index a88f6c7649ec9..9345992102a65 100644 --- a/src/store/user/slices/common/initialState.ts +++ b/src/store/user/slices/sync/initialState.ts @@ -1,17 +1,12 @@ import { PeerSyncStatus, SyncAwarenessState } from '@/types/sync'; -export interface Guide { - // Topic 引导 - topic?: boolean; -} - -export interface UserCommonState { +export interface UserSyncState { syncAwareness: SyncAwarenessState[]; syncEnabled: boolean; syncStatus: PeerSyncStatus; } -export const initialCommonState: UserCommonState = { +export const initialSyncState: UserSyncState = { syncAwareness: [], syncEnabled: false, syncStatus: PeerSyncStatus.Disabled, diff --git a/src/store/user/store.ts b/src/store/user/store.ts index 0e705c6865678..f1a1e5f895b56 100644 --- a/src/store/user/store.ts +++ b/src/store/user/store.ts @@ -6,19 +6,28 @@ import { StateCreator } from 'zustand/vanilla'; import { isDev } from '@/utils/env'; import { type UserState, initialState } from './initialState'; +import { type UserAuthAction, createAuthSlice } from './slices/auth/action'; import { type CommonAction, createCommonSlice } from './slices/common/action'; import { type PreferenceAction, createPreferenceSlice } from './slices/preference/action'; import { type SettingsAction, createSettingsSlice } from './slices/settings/actions'; +import { type SyncAction, createSyncSlice } from './slices/sync/action'; // =============== 聚合 createStoreFn ============ // -export type UserStore = CommonAction & UserState & SettingsAction & PreferenceAction; +export type UserStore = SyncAction & + UserState & + SettingsAction & + PreferenceAction & + UserAuthAction & + CommonAction; const createStore: StateCreator = (...parameters) => ({ ...initialState, - ...createCommonSlice(...parameters), + ...createSyncSlice(...parameters), ...createSettingsSlice(...parameters), ...createPreferenceSlice(...parameters), + ...createAuthSlice(...parameters), + ...createCommonSlice(...parameters), }); // =============== 实装 useStore ============ //