Skip to content

Commit

Permalink
♻️ refactor: refactor upload method
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Dec 21, 2024
1 parent c59c1e2 commit b24b914
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 99 deletions.
16 changes: 15 additions & 1 deletion src/database/_deprecated/models/file.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { DBModel } from '@/database/_deprecated/core/types/db';
import { DB_File, DB_FileSchema } from '@/database/_deprecated/schemas/files';
import { clientS3Storage } from '@/services/file/ClientS3';
import { nanoid } from '@/utils/uuid';

import { BaseModel } from '../core';
Expand All @@ -20,7 +21,13 @@ class _FileModel extends BaseModel<'files'> {
if (!item) return;

// arrayBuffer to url
const base64 = Buffer.from(item.data!).toString('base64');
let base64;
if (!item.data) {
const hash = (item.url as string).replace('client-s3://', '');
base64 = await this.getBase64ByFileHash(hash);
} else {
base64 = Buffer.from(item.data).toString('base64');
}

return { ...item, url: `data:${item.fileType};base64,${base64}` };
}
Expand All @@ -32,6 +39,13 @@ class _FileModel extends BaseModel<'files'> {
async clear() {
return this.table.clear();
}

private async getBase64ByFileHash(hash: string) {
const fileItem = await clientS3Storage.getObject(hash);
if (!fileItem) throw new Error('file not found');

return Buffer.from(await fileItem.arrayBuffer()).toString('base64');
}
}

export const FileModel = new _FileModel();
4 changes: 1 addition & 3 deletions src/server/routers/lambda/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ export const fileRouter = router({
}),

createFile: fileProcedure
.input(
UploadFileSchema.omit({ data: true, saveMode: true, url: true }).extend({ url: z.string() }),
)
.input(UploadFileSchema.omit({ url: true }).extend({ url: z.string() }))
.mutation(async ({ ctx, input }) => {
const { isExist } = await ctx.fileModel.checkHash(input.hash!);

Expand Down
115 changes: 115 additions & 0 deletions src/services/file/ClientS3/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import { createStore, del, get, set } from 'idb-keyval';
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { BrowserS3Storage } from './index';

// Mock idb-keyval
vi.mock('idb-keyval', () => ({
createStore: vi.fn(),
set: vi.fn(),
get: vi.fn(),
del: vi.fn(),
}));

let storage: BrowserS3Storage;
let mockStore = {};

beforeEach(() => {
// Reset all mocks before each test
vi.clearAllMocks();
mockStore = {};
(createStore as any).mockReturnValue(mockStore);
storage = new BrowserS3Storage();
});

describe('BrowserS3Storage', () => {
describe('constructor', () => {
it('should create store when in browser environment', () => {
expect(createStore).toHaveBeenCalledWith('lobechat-local-s3', 'objects');
});
});

describe('putObject', () => {
it('should successfully put a file object', async () => {
const mockFile = new File(['test content'], 'test.txt', { type: 'text/plain' });
const mockArrayBuffer = new ArrayBuffer(8);
vi.spyOn(mockFile, 'arrayBuffer').mockResolvedValue(mockArrayBuffer);
(set as any).mockResolvedValue(undefined);

await storage.putObject('1-test-key', mockFile);

expect(set).toHaveBeenCalledWith(
'1-test-key',
{
data: mockArrayBuffer,
name: 'test.txt',
type: 'text/plain',
},
mockStore,
);
});

it('should throw error when put operation fails', async () => {
const mockFile = new File(['test content'], 'test.txt', { type: 'text/plain' });
const mockError = new Error('Storage error');
(set as any).mockRejectedValue(mockError);

await expect(storage.putObject('test-key', mockFile)).rejects.toThrow(
'Failed to put file test.txt: Storage error',
);
});
});

describe('getObject', () => {
it('should successfully get a file object', async () => {
const mockData = {
data: new ArrayBuffer(8),
name: 'test.txt',
type: 'text/plain',
};
(get as any).mockResolvedValue(mockData);

const result = await storage.getObject('test-key');

expect(result).toBeInstanceOf(File);
expect(result?.name).toBe('test.txt');
expect(result?.type).toBe('text/plain');
});

it('should return undefined when file not found', async () => {
(get as any).mockResolvedValue(undefined);

const result = await storage.getObject('test-key');

expect(result).toBeUndefined();
});

it('should throw error when get operation fails', async () => {
const mockError = new Error('Storage error');
(get as any).mockRejectedValue(mockError);

await expect(storage.getObject('test-key')).rejects.toThrow(
'Failed to get object (key=test-key): Storage error',
);
});
});

describe('deleteObject', () => {
it('should successfully delete a file object', async () => {
(del as any).mockResolvedValue(undefined);

await storage.deleteObject('test-key2');

expect(del).toHaveBeenCalledWith('test-key2', {});
});

it('should throw error when delete operation fails', async () => {
const mockError = new Error('Storage error');
(del as any).mockRejectedValue(mockError);

await expect(storage.deleteObject('test-key')).rejects.toThrow(
'Failed to delete object (key=test-key): Storage error',
);
});
});
});
58 changes: 58 additions & 0 deletions src/services/file/ClientS3/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { createStore, del, get, set } from 'idb-keyval';

const BROWSER_S3_DB_NAME = 'lobechat-local-s3';

export class BrowserS3Storage {
private store;

constructor() {
// skip server-side rendering
if (typeof window === 'undefined') return;

this.store = createStore(BROWSER_S3_DB_NAME, 'objects');
}

/**
* 上传文件
* @param key 文件 hash
* @param file File 对象
*/
async putObject(key: string, file: File): Promise<void> {
try {
const data = await file.arrayBuffer();
await set(key, { data, name: file.name, type: file.type }, this.store);
} catch (e) {
throw new Error(`Failed to put file ${file.name}: ${(e as Error).message}`);
}
}

/**
* 获取文件
* @param key 文件 hash
* @returns File 对象
*/
async getObject(key: string): Promise<File | undefined> {
try {
const res = await get<{ data: ArrayBuffer; name: string; type: string }>(key, this.store);
if (!res) return;

return new File([res.data], res!.name, { type: res?.type });
} catch (e) {
throw new Error(`Failed to get object (key=${key}): ${(e as Error).message}`);
}
}

/**
* 删除文件
* @param key 文件 hash
*/
async deleteObject(key: string): Promise<void> {
try {
await del(key, this.store);
} catch (e) {
throw new Error(`Failed to delete object (key=${key}): ${(e as Error).message}`);
}
}
}

export const clientS3Storage = new BrowserS3Storage();
35 changes: 29 additions & 6 deletions src/services/file/client.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
import { FileModel } from '@/database/_deprecated/models/file';
import { DB_File } from '@/database/_deprecated/schemas/files';
import { FileItem } from '@/types/files';
import { clientS3Storage } from '@/services/file/ClientS3';
import { FileItem, UploadFileParams } from '@/types/files';

import { IFileService } from './type';

export class ClientService implements IFileService {
async createFile(file: DB_File) {
async createFile(file: UploadFileParams) {
// save to local storage
// we may want to save to a remote server later
const res = await FileModel.create(file);
// arrayBuffer to url
const base64 = Buffer.from(file.data!).toString('base64');
const res = await FileModel.create({
createdAt: Date.now(),
data: undefined,
fileHash: file.hash,
fileType: file.fileType,
metadata: file.metadata,
name: file.name,
saveMode: 'url',
size: file.size,
url: file.url,
} as any);

// get file to base64 url
const base64 = await this.getBase64ByFileHash(file.hash!);

return {
id: res.id,
url: `data:${file.fileType};base64,${base64}`,
};
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
async checkFileHash(_hash: string) {
return { isExist: false, metadata: {} };
}

async getFile(id: string): Promise<FileItem> {
const item = await FileModel.findById(id);
if (!item) {
Expand Down Expand Up @@ -49,4 +65,11 @@ export class ClientService implements IFileService {
async removeAllFiles() {
return FileModel.clear();
}

private async getBase64ByFileHash(hash: string) {
const fileItem = await clientS3Storage.getObject(hash);
if (!fileItem) throw new Error('file not found');

return Buffer.from(await fileItem.arrayBuffer()).toString('base64');
}
}
24 changes: 8 additions & 16 deletions src/services/upload.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { fileEnv } from '@/config/file';
import { edgeClient } from '@/libs/trpc/client';
import { API_ENDPOINTS } from '@/services/_url';
import { FileMetadata, UploadFileParams } from '@/types/files';
import { clientS3Storage } from '@/services/file/ClientS3';
import { FileMetadata } from '@/types/files';
import { FileUploadState, FileUploadStatus } from '@/types/files/upload';
import { uuid } from '@/utils/uuid';

Expand Down Expand Up @@ -66,23 +67,14 @@ class UploadService {
return result;
};

uploadToClientDB = async (params: UploadFileParams, file: File) => {
const { FileModel } = await import('@/database/_deprecated/models/file');
const fileArrayBuffer = await file.arrayBuffer();

// save to local storage
// we may want to save to a remote server later
const res = await FileModel.create({
createdAt: Date.now(),
...params,
data: fileArrayBuffer,
});
// arrayBuffer to url
const base64 = Buffer.from(fileArrayBuffer).toString('base64');
uploadToClientS3 = async (hash: string, file: File): Promise<FileMetadata> => {
await clientS3Storage.putObject(hash, file);

return {
id: res.id,
url: `data:${params.fileType};base64,${base64}`,
date: (Date.now() / 1000 / 60 / 60).toFixed(0),
dirname: '',
filename: file.name,
path: `client-s3://${hash}`,
};
};

Expand Down
17 changes: 13 additions & 4 deletions src/store/chat/slices/builtinTool/action.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { act, renderHook } from '@testing-library/react';
import { describe, expect, it, vi } from 'vitest';

import { fileService } from '@/services/file';
import { ClientService } from '@/services/file/client';
import { messageService } from '@/services/message';
import { imageGenerationService } from '@/services/textToImage';
import { uploadService } from '@/services/upload';
import { chatSelectors } from '@/store/chat/selectors';
Expand Down Expand Up @@ -39,17 +41,23 @@ describe('chatToolSlice', () => {
vi.spyOn(uploadService, 'getImageFileByUrlWithCORS').mockResolvedValue(
new File(['1'], 'file.png', { type: 'image/png' }),
);
vi.spyOn(uploadService, 'uploadToClientDB').mockResolvedValue({} as any);
vi.spyOn(fileService, 'createFile').mockResolvedValue({ id: mockId, url: '' });
vi.spyOn(uploadService, 'uploadToClientS3').mockResolvedValue({} as any);
vi.spyOn(ClientService.prototype, 'createFile').mockResolvedValue({
id: mockId,
url: '',
});
vi.spyOn(result.current, 'toggleDallEImageLoading');
vi.spyOn(ClientService.prototype, 'checkFileHash').mockImplementation(async () => ({
isExist: false,
metadata: {},
}));

await act(async () => {
await result.current.generateImageFromPrompts(prompts, messageId);
});
// For each prompt, loading is toggled on and then off
expect(imageGenerationService.generateImage).toHaveBeenCalledTimes(prompts.length);
expect(uploadService.uploadToClientDB).toHaveBeenCalledTimes(prompts.length);

expect(uploadService.uploadToClientS3).toHaveBeenCalledTimes(prompts.length);
expect(result.current.toggleDallEImageLoading).toHaveBeenCalledTimes(prompts.length * 2);
});
});
Expand All @@ -75,6 +83,7 @@ describe('chatToolSlice', () => {
content: initialMessageContent,
}) as ChatMessage,
);
vi.spyOn(messageService, 'updateMessage').mockResolvedValueOnce(undefined);

await act(async () => {
await result.current.updateImageItem(messageId, updateFunction);
Expand Down
Loading

0 comments on commit b24b914

Please sign in to comment.