Skip to content

Commit

Permalink
♻️ refactor: user store add an auth slice (#2214)
Browse files Browse the repository at this point in the history
* ♻️ refactor: refactor the user store with auth slice

* ♻️ refactor: separate common and sync slice

* 🧑‍💻 chore: add an isMobile selector

* ♻️ refactor: refactor the auth action and common action

* 🎨 chore: clean code
  • Loading branch information
arvinxx authored May 3, 2024
1 parent 4cb5adb commit 948b257
Show file tree
Hide file tree
Showing 23 changed files with 513 additions and 369 deletions.
4 changes: 2 additions & 2 deletions src/app/(main)/chat/(mobile)/features/SessionHeader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import SyncStatusInspector from '@/features/SyncStatusInspector';
import { featureFlagsSelectors, useServerConfigStore } from '@/store/serverConfig';
import { useSessionStore } from '@/store/session';
import { useUserStore } from '@/store/user';
import { commonSelectors } from '@/store/user/selectors';
import { userProfileSelectors } from '@/store/user/selectors';
import { mobileHeaderSticky } from '@/styles/mobileHeader';

export const useStyles = createStyles(({ css, token }) => ({
Expand All @@ -26,7 +26,7 @@ export const useStyles = createStyles(({ css, token }) => ({
const Header = memo(() => {
const [createSession] = useSessionStore((s) => [s.createSession]);
const router = useRouter();
const avatar = useUserStore(commonSelectors.userAvatar);
const avatar = useUserStore(userProfileSelectors.userAvatar);
const { showCreateSession } = useServerConfigStore(featureFlagsSelectors);

return (
Expand Down
4 changes: 2 additions & 2 deletions src/app/(main)/chat/features/ShareButton/ShareModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { Flexbox } from 'react-layout-kit';
import { FORM_STYLE } from '@/const/layoutTokens';
import { useChatStore } from '@/store/chat';
import { useUserStore } from '@/store/user';
import { commonSelectors } from '@/store/user/selectors';
import { userProfileSelectors } from '@/store/user/selectors';

import Preview from './Preview';
import { FieldType, ImageType } from './type';
Expand Down Expand Up @@ -49,7 +49,7 @@ const ShareModal = memo<ModalProps>(({ onCancel, open }) => {
const [fieldValue, setFieldValue] = useState<FieldType>(DEFAULT_FIELD_VALUE);
const [tab, setTab] = useState<Tab>(Tab.Screenshot);
const { t } = useTranslation('chat');
const avatar = useUserStore(commonSelectors.userAvatar);
const avatar = useUserStore(userProfileSelectors.userAvatar);
const [shareLoading, shareToShareGPT] = useChatStore((s) => [s.shareLoading, s.shareToShareGPT]);
const { loading, onDownload, title } = useScreenshot(fieldValue.imageType);

Expand Down
4 changes: 2 additions & 2 deletions src/features/AvatarWithUpload/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { CSSProperties, memo, useCallback } from 'react';

import { DEFAULT_USER_AVATAR_URL } from '@/const/meta';
import { useUserStore } from '@/store/user';
import { commonSelectors } from '@/store/user/selectors';
import { userProfileSelectors } from '@/store/user/selectors';
import { imageToBase64 } from '@/utils/imageToBase64';
import { createUploadImageHandler } from '@/utils/uploadFIle';

Expand Down Expand Up @@ -41,7 +41,7 @@ const AvatarWithUpload = memo<AvatarWithUploadProps>(
({ size = 40, compressSize = 256, style, id }) => {
const { styles } = useStyle();
const [avatar, updateAvatar] = useUserStore((s) => [
commonSelectors.userAvatar(s),
userProfileSelectors.userAvatar(s),
s.updateAvatar,
]);

Expand Down
8 changes: 7 additions & 1 deletion src/layout/GlobalProvider/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
import { getServerGlobalConfig } from '@/server/globalConfig';
import { ServerConfigStoreProvider } from '@/store/serverConfig';
import { getAntdLocale } from '@/utils/locale';
import { isMobileDevice } from '@/utils/responsive';

import AppTheme from './AppTheme';
import Locale from './Locale';
Expand Down Expand Up @@ -48,6 +49,7 @@ const GlobalLayout = async ({ children }: GlobalLayoutProps) => {
// get default feature flags to use with ssr
const serverFeatureFlags = getServerFeatureFlagsValue();
const serverConfig = getServerGlobalConfig();
const isMobile = isMobileDevice();
return (
<StyleRegistry>
<Locale antdLocale={antdLocale} defaultLang={defaultLang?.value}>
Expand All @@ -57,7 +59,11 @@ const GlobalLayout = async ({ children }: GlobalLayoutProps) => {
defaultPrimaryColor={primaryColor?.value as any}
>
<StoreInitialization />
<ServerConfigStoreProvider featureFlags={serverFeatureFlags} serverConfig={serverConfig}>
<ServerConfigStoreProvider
featureFlags={serverFeatureFlags}
isMobile={isMobile}
serverConfig={serverConfig}
>
{children}
</ServerConfigStoreProvider>
<DebugUI />
Expand Down
4 changes: 2 additions & 2 deletions src/services/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import { useToolStore } from '@/store/tool';
import { pluginSelectors, toolSelectors } from '@/store/tool/selectors';
import { useUserStore } from '@/store/user';
import {
commonSelectors,
modelConfigSelectors,
modelProviderSelectors,
preferenceSelectors,
userProfileSelectors,
} from '@/store/user/selectors';
import { ChatErrorType } from '@/types/fetch';
import { ChatMessage } from '@/types/message';
Expand Down Expand Up @@ -482,7 +482,7 @@ class ChatService {
...trace,
enabled: true,
tags: [tag, ...(trace?.tags || []), ...tags].filter(Boolean) as string[],
userId: commonSelectors.userId(useUserStore.getState()),
userId: userProfileSelectors.userId(useUserStore.getState()),
};
}

Expand Down
4 changes: 2 additions & 2 deletions src/store/chat/slices/message/selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { agentSelectors } from '@/store/agent/selectors';
import { useSessionStore } from '@/store/session';
import { sessionMetaSelectors } from '@/store/session/selectors';
import { useUserStore } from '@/store/user';
import { commonSelectors } from '@/store/user/selectors';
import { userProfileSelectors } from '@/store/user/selectors';
import { ChatMessage } from '@/types/message';
import { MetaData } from '@/types/meta';
import { merge } from '@/utils/merge';
Expand All @@ -20,7 +20,7 @@ const getMeta = (message: ChatMessage) => {
switch (message.role) {
case 'user': {
return {
avatar: commonSelectors.userAvatar(useUserStore.getState()) || DEFAULT_USER_AVATAR,
avatar: userProfileSelectors.userAvatar(useUserStore.getState()) || DEFAULT_USER_AVATAR,
};
}

Expand Down
5 changes: 3 additions & 2 deletions src/store/serverConfig/Provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import { Provider, createServerConfigStore } from './store';
interface GlobalStoreProviderProps {
children: ReactNode;
featureFlags?: Partial<IFeatureFlags>;
isMobile?: boolean;
serverConfig?: GlobalServerConfig;
}

export const ServerConfigStoreProvider = memo<GlobalStoreProviderProps>(
({ children, featureFlags, serverConfig }) => (
<Provider createStore={() => createServerConfigStore({ featureFlags, serverConfig })}>
({ children, featureFlags, serverConfig, isMobile }) => (
<Provider createStore={() => createServerConfigStore({ featureFlags, isMobile, serverConfig })}>
{children}
</Provider>
),
Expand Down
1 change: 1 addition & 0 deletions src/store/serverConfig/selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ export const featureFlagsSelectors = (s: ServerConfigStore) =>
export const serverConfigSelectors = {
enabledOAuthSSO: (s: ServerConfigStore) => s.serverConfig.enabledOAuthSSO,
enabledTelemetryChat: (s: ServerConfigStore) => s.serverConfig.telemetry.langfuse || false,
isMobile: (s: ServerConfigStore) => s.isMobile || false,
};
1 change: 1 addition & 0 deletions src/store/serverConfig/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const initialState: ServerConfigStore = {

export interface ServerConfigStore {
featureFlags: IFeatureFlags;
isMobile?: boolean;
serverConfig: GlobalServerConfig;
}

Expand Down
8 changes: 5 additions & 3 deletions src/store/user/initialState.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { UserCommonState, initialCommonState } from './slices/common/initialState';
import { UserAuthState, initialAuthState } from './slices/auth/initialState';
import { UserPreferenceState, initialPreferenceState } from './slices/preference/initialState';
import { UserSettingsState, initialSettingsState } from './slices/settings/initialState';
import { UserSyncState, initialSyncState } from './slices/sync/initialState';

export type UserState = UserCommonState & UserSettingsState & UserPreferenceState;
export type UserState = UserSyncState & UserSettingsState & UserPreferenceState & UserAuthState;

export const initialState: UserState = {
...initialCommonState,
...initialSyncState,
...initialSettingsState,
...initialPreferenceState,
...initialAuthState,
};
2 changes: 1 addition & 1 deletion src/store/user/selectors.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export { commonSelectors } from './slices/common/selectors';
export { userProfileSelectors } from './slices/auth/selectors';
export { preferenceSelectors } from './slices/preference/selectors';
export {
modelConfigSelectors,
Expand Down
118 changes: 118 additions & 0 deletions src/store/user/slices/auth/action.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import { act, renderHook, waitFor } from '@testing-library/react';
import { mutate } from 'swr';
import { afterEach, describe, expect, it, vi } from 'vitest';
import { withSWR } from '~test-utils';

import { userService } from '@/services/user';
import { useUserStore } from '@/store/user';
import { switchLang } from '@/utils/client/switchLang';

vi.mock('zustand/traditional');

vi.mock('@/utils/client/switchLang', () => ({
switchLang: vi.fn(),
}));

vi.mock('swr', async (importOriginal) => {
const modules = await importOriginal();
return {
...(modules as any),
mutate: vi.fn(),
};
});

afterEach(() => {
vi.restoreAllMocks();
});

describe('createAuthSlice', () => {
describe('refreshUserConfig', () => {
it('should refresh user config', async () => {
const { result } = renderHook(() => useUserStore());

await act(async () => {
await result.current.refreshUserConfig();
});

expect(mutate).toHaveBeenCalledWith(['fetchUserConfig', true]);
});
});

describe('useFetchUserConfig', () => {
it('should not fetch user config if initServer is false', async () => {
const mockUserConfig: any = undefined; // 模拟未初始化服务器的情况
vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig);

const { result } = renderHook(() => useUserStore().useFetchUserConfig(false), {
wrapper: withSWR,
});

// 因为 initServer 为 false,所以不会触发 getUserConfig 的调用
expect(userService.getUserConfig).not.toHaveBeenCalled();
// 确保状态未改变
expect(result.current.data).toBeUndefined();
});

it('should fetch user config correctly when initServer is true', async () => {
const mockUserConfig: any = {
avatar: 'new-avatar-url',
settings: {
language: 'en',
},
};
vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig);

const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), {
wrapper: withSWR,
});

// 等待 SWR 完成数据获取
await waitFor(() => expect(result.current.data).toEqual(mockUserConfig));

// 验证状态是否正确更新
expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar);
expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings);

// 验证是否正确处理了语言设置
expect(switchLang).not.toHaveBeenCalledWith('auto');
});
it('should call switch language when language is auto', async () => {
const mockUserConfig: any = {
avatar: 'new-avatar-url',
settings: {
language: 'auto',
},
};
vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(mockUserConfig);

const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), {
wrapper: withSWR,
});

// 等待 SWR 完成数据获取
await waitFor(() => expect(result.current.data).toEqual(mockUserConfig));

// 验证状态是否正确更新
expect(useUserStore.getState().avatar).toBe(mockUserConfig.avatar);
expect(useUserStore.getState().settings).toEqual(mockUserConfig.settings);

// 验证是否正确处理了语言设置
expect(switchLang).toHaveBeenCalledWith('auto');
});

it('should handle the case when user config is null', async () => {
vi.spyOn(userService, 'getUserConfig').mockResolvedValueOnce(null as any);

const { result } = renderHook(() => useUserStore().useFetchUserConfig(true), {
wrapper: withSWR,
});

// 等待 SWR 完成数据获取
await waitFor(() => expect(result.current.data).toBeNull());

// 验证状态未被错误更新
expect(useUserStore.getState().avatar).toBeUndefined();
expect(useUserStore.getState().settings).toEqual({});
});
});
});
81 changes: 81 additions & 0 deletions src/store/user/slices/auth/action.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import useSWR, { SWRResponse, mutate } from 'swr';
import { StateCreator } from 'zustand/vanilla';

import { UserConfig, userService } from '@/services/user';
import { switchLang } from '@/utils/client/switchLang';
import { setNamespace } from '@/utils/storeDebug';

import { UserStore } from '../../store';
import { settingsSelectors } from '../settings/selectors';

const n = setNamespace('auth');
const USER_CONFIG_FETCH_KEY = 'fetchUserConfig';

export interface UserAuthAction {
getUserConfig: () => void;
/**
* universal login method
*/
login: () => Promise<void>;
/**
* universal logout method
*/
logout: () => Promise<void>;
refreshUserConfig: () => Promise<void>;

useFetchUserConfig: (initServer: boolean) => SWRResponse<UserConfig | undefined>;
}

export const createAuthSlice: StateCreator<
UserStore,
[['zustand/devtools', never]],
[],
UserAuthAction
> = (set, get) => ({
getUserConfig: () => {
console.log(n('userconfig'));
},
login: async () => {
// TODO: 针对开启 next-auth 的场景,需要在这里调用登录方法
console.log(n('login'));
},
logout: async () => {
// TODO: 针对开启 next-auth 的场景,需要在这里调用登录方法
console.log(n('logout'));
},
refreshUserConfig: async () => {
await mutate([USER_CONFIG_FETCH_KEY, true]);

// when get the user config ,refresh the model provider list to the latest
get().refreshModelProviderList();
},

useFetchUserConfig: (initServer) =>
useSWR<UserConfig | undefined>(
[USER_CONFIG_FETCH_KEY, initServer],
async () => {
if (!initServer) return;
return userService.getUserConfig();
},
{
onSuccess: (data) => {
if (!data) return;

set(
{ avatar: data.avatar, settings: data.settings, userId: data.uuid },
false,
n('fetchUserConfig', data),
);

// when get the user config ,refresh the model provider list to the latest
get().refreshDefaultModelProviderList({ trigger: 'fetchUserConfig' });

const { language } = settingsSelectors.currentSettings(get());
if (language === 'auto') {
switchLang('auto');
}
},
revalidateOnFocus: false,
},
),
});
20 changes: 20 additions & 0 deletions src/store/user/slices/auth/initialState.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
export interface LobeUser {
avatar?: string;
firstName?: string | null;
fullName?: string | null;
id: string;
latestName?: string | null;
username?: string | null;
}

export interface UserAuthState {
/**
* @deprecated
*/
avatar?: string;
isSignedIn?: boolean;
user?: LobeUser;
userId?: string;
}

export const initialAuthState: UserAuthState = {};
Loading

0 comments on commit 948b257

Please sign in to comment.