Skip to content

Commit

Permalink
♻️ refactor: refactor the drizzle code style (#5058)
Browse files Browse the repository at this point in the history
* ♻️ refactor: Update drizzle code style

* ♻️ refactor: Fix some drizzle-orm/expressions import

* 💄 style: 替换为箭头函数

* Update topic.ts

---------

Co-authored-by: Arvin Xu <[email protected]>
  • Loading branch information
canisminor1990 and arvinxx authored Dec 17, 2024
1 parent 679211d commit 4057ad3
Show file tree
Hide file tree
Showing 35 changed files with 309 additions and 400 deletions.
29 changes: 11 additions & 18 deletions src/database/repositories/dataImporter/__tests__/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// @vitest-environment node
import { eq, inArray } from 'drizzle-orm';
import { eq, inArray } from 'drizzle-orm/expressions';
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
import {
agents,
agentsToSessions,
Expand All @@ -12,6 +11,7 @@ import {
topics,
users,
} from '@/database/schemas';
import { getTestDBInstance } from '@/database/server/core/dbForTest';
import { CURRENT_CONFIG_VERSION } from '@/migrations';
import { ImporterEntryData } from '@/types/importer';

Expand Down Expand Up @@ -60,8 +60,7 @@ describe('DataImporter', () => {
it('should skip existing session groups and return correct result', async () => {
await serverDB
.insert(sessionGroups)
.values({ clientId: 'group1', name: 'Existing Group', userId })
.execute();
.values({ clientId: 'group1', name: 'Existing Group', userId });

const data: ImporterEntryData = {
version: CURRENT_CONFIG_VERSION,
Expand Down Expand Up @@ -141,7 +140,7 @@ describe('DataImporter', () => {
});

it('should skip existing sessions and return correct result', async () => {
await serverDB.insert(sessions).values({ clientId: 'session1', userId }).execute();
await serverDB.insert(sessions).values({ clientId: 'session1', userId });

const data: ImporterEntryData = {
version: CURRENT_CONFIG_VERSION,
Expand Down Expand Up @@ -477,10 +476,7 @@ describe('DataImporter', () => {
});

it('should skip existing topics and return correct result', async () => {
await serverDB
.insert(topics)
.values({ clientId: 'topic1', title: 'Existing Topic', userId })
.execute();
await serverDB.insert(topics).values({ clientId: 'topic1', title: 'Existing Topic', userId });

const data: ImporterEntryData = {
version: CURRENT_CONFIG_VERSION,
Expand Down Expand Up @@ -616,15 +612,12 @@ describe('DataImporter', () => {
});

it('should skip existing messages and return correct result', async () => {
await serverDB
.insert(messages)
.values({
clientId: 'msg1',
content: 'Existing Message',
role: 'user',
userId,
})
.execute();
await serverDB.insert(messages).values({
clientId: 'msg1',
content: 'Existing Message',
role: 'user',
userId,
});

const data: ImporterEntryData = {
version: CURRENT_CONFIG_VERSION,
Expand Down
77 changes: 31 additions & 46 deletions src/database/repositories/dataImporter/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { eq, inArray, sql } from 'drizzle-orm';
import { and } from 'drizzle-orm/expressions';
import { sql } from 'drizzle-orm';
import { and, eq, inArray } from 'drizzle-orm/expressions';

import {
agents,
Expand Down Expand Up @@ -71,8 +71,7 @@ export class DataImporterRepos {
set: { updatedAt: new Date() },
target: [sessionGroups.clientId, sessionGroups.userId],
})
.returning({ clientId: sessionGroups.clientId, id: sessionGroups.id })
.execute();
.returning({ clientId: sessionGroups.clientId, id: sessionGroups.id });

sessionGroupResult.added = mapArray.length - query.length;

Expand Down Expand Up @@ -109,8 +108,7 @@ export class DataImporterRepos {
set: { updatedAt: new Date() },
target: [sessions.clientId, sessions.userId],
})
.returning({ clientId: sessions.clientId, id: sessions.id })
.execute();
.returning({ clientId: sessions.clientId, id: sessions.id });

// get the session client-server id map
sessionIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));
Expand All @@ -133,18 +131,14 @@ export class DataImporterRepos {
userId: this.userId,
})),
)
.returning({ id: agents.id })
.execute();
.returning({ id: agents.id });

await trx
.insert(agentsToSessions)
.values(
shouldInsertSessionAgents.map(({ id }, index) => ({
agentId: agentMapArray[index].id,
sessionId: sessionIdMap[id],
})),
)
.execute();
await trx.insert(agentsToSessions).values(
shouldInsertSessionAgents.map(({ id }, index) => ({
agentId: agentMapArray[index].id,
sessionId: sessionIdMap[id],
})),
);
}
}

Expand Down Expand Up @@ -178,8 +172,7 @@ export class DataImporterRepos {
set: { updatedAt: new Date() },
target: [topics.clientId, topics.userId],
})
.returning({ clientId: topics.clientId, id: topics.id })
.execute();
.returning({ clientId: topics.clientId, id: topics.id });

topicIdMap = Object.fromEntries(mapArray.map(({ clientId, id }) => [clientId, id]));

Expand Down Expand Up @@ -230,7 +223,7 @@ export class DataImporterRepos {

for (let i = 0; i < inertValues.length; i += BATCH_SIZE) {
const batch = inertValues.slice(i, i + BATCH_SIZE);
await trx.insert(messages).values(batch).execute();
await trx.insert(messages).values(batch);
}

console.timeEnd('insert messages');
Expand Down Expand Up @@ -265,7 +258,7 @@ export class DataImporterRepos {
.filter(Boolean);

if (parentIdUpdates.length > 0) {
const updateQuery = trx
await trx
.update(messages)
.set({
parentId: sql`CASE ${sql.join(parentIdUpdates)} END`,
Expand All @@ -281,42 +274,34 @@ export class DataImporterRepos {
// const SQL = updateQuery.toSQL();
// console.log('sql:', SQL.sql);
// console.log('params:', SQL.params);

await updateQuery.execute();
}
console.timeEnd('execute updates parentId');

// 4. insert message plugins
const pluginInserts = shouldInsertMessages.filter((msg) => msg.plugin);
if (pluginInserts.length > 0) {
await trx
.insert(messagePlugins)
.values(
pluginInserts.map((msg) => ({
apiName: msg.plugin?.apiName,
arguments: msg.plugin?.arguments,
id: messageIdMap[msg.id],
identifier: msg.plugin?.identifier,
state: msg.pluginState,
toolCallId: msg.tool_call_id,
type: msg.plugin?.type,
})),
)
.execute();
await trx.insert(messagePlugins).values(
pluginInserts.map((msg) => ({
apiName: msg.plugin?.apiName,
arguments: msg.plugin?.arguments,
id: messageIdMap[msg.id],
identifier: msg.plugin?.identifier,
state: msg.pluginState,
toolCallId: msg.tool_call_id,
type: msg.plugin?.type,
})),
);
}

// 5. insert message translate
const translateInserts = shouldInsertMessages.filter((msg) => msg.extra?.translate);
if (translateInserts.length > 0) {
await trx
.insert(messageTranslates)
.values(
translateInserts.map((msg) => ({
id: messageIdMap[msg.id],
...msg.extra?.translate,
})),
)
.execute();
await trx.insert(messageTranslates).values(
translateInserts.map((msg) => ({
id: messageIdMap[msg.id],
...msg.extra?.translate,
})),
);
}

// TODO: 未来需要处理 TTS 和图片的插入 (目前存在 file 的部分,不方便处理)
Expand Down
2 changes: 1 addition & 1 deletion src/database/server/models/__tests__/_test_template.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { eq } from 'drizzle-orm/expressions';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
Expand Down
2 changes: 1 addition & 1 deletion src/database/server/models/__tests__/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { eq } from 'drizzle-orm/expressions';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
Expand Down
2 changes: 1 addition & 1 deletion src/database/server/models/__tests__/asyncTask.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { eq } from 'drizzle-orm/expressions';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
Expand Down
2 changes: 1 addition & 1 deletion src/database/server/models/__tests__/chunk.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { eq } from 'drizzle-orm/expressions';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
Expand Down
2 changes: 1 addition & 1 deletion src/database/server/models/__tests__/file.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// @vitest-environment node
import { eq, inArray } from 'drizzle-orm';
import { eq, inArray } from 'drizzle-orm/expressions';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
Expand Down
3 changes: 1 addition & 2 deletions src/database/server/models/__tests__/knowledgeBase.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { and, desc } from 'drizzle-orm/expressions';
import { and, eq } from 'drizzle-orm/expressions';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
Expand Down
Loading

0 comments on commit 4057ad3

Please sign in to comment.