Skip to content

Commit

Permalink
Merge pull request #686 from janhq/feat/add-download-state
Browse files Browse the repository at this point in the history
feat: add download state
  • Loading branch information
namchuai authored Jun 13, 2024
2 parents 2b47c8e + 5b44c0a commit 65da850
Show file tree
Hide file tree
Showing 13 changed files with 426 additions and 11,080 deletions.
11,047 changes: 0 additions & 11,047 deletions cortex-js/package-lock.json

This file was deleted.

1 change: 1 addition & 0 deletions cortex-js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"@nestjs/config": "^3.2.2",
"@nestjs/core": "^10.0.0",
"@nestjs/devtools-integration": "^0.1.6",
"@nestjs/event-emitter": "^2.0.4",
"@nestjs/mapped-types": "*",
"@nestjs/platform-express": "^10.0.0",
"@nestjs/swagger": "^7.3.1",
Expand Down
6 changes: 6 additions & 0 deletions cortex-js/src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import { env } from 'node:process';
import { SeedService } from './usecases/seed/seed.service';
import { FileManagerModule } from './infrastructure/services/file-manager/file-manager.module';
import { AppLoggerMiddleware } from './infrastructure/middlewares/app.logger.middleware';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { AppController } from './infrastructure/controllers/app.controller';
import { DownloadManagerModule } from './download-manager/download-manager.module';

@Module({
imports: [
Expand All @@ -24,6 +27,7 @@ import { AppLoggerMiddleware } from './infrastructure/middlewares/app.logger.mid
isGlobal: true,
envFilePath: env.NODE_ENV !== 'production' ? '.env.development' : '.env',
}),
EventEmitterModule.forRoot(),
DatabaseModule,
MessagesModule,
ThreadsModule,
Expand All @@ -34,7 +38,9 @@ import { AppLoggerMiddleware } from './infrastructure/middlewares/app.logger.mid
ExtensionModule,
FileManagerModule,
ModelRepositoryModule,
DownloadManagerModule,
],
controllers: [AppController],
providers: [SeedService],
})
export class AppModule implements NestModule {
Expand Down
4 changes: 4 additions & 0 deletions cortex-js/src/command.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import { KillCommand } from './infrastructure/commanders/kill.command';
import { PresetCommand } from './infrastructure/commanders/presets.command';
import { EmbeddingCommand } from './infrastructure/commanders/embeddings.command';
import { BenchmarkCommand } from './infrastructure/commanders/benchmark.command';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { DownloadManagerModule } from './download-manager/download-manager.module';

@Module({
imports: [
Expand All @@ -37,6 +39,7 @@ import { BenchmarkCommand } from './infrastructure/commanders/benchmark.command'
envFilePath:
process.env.NODE_ENV !== 'production' ? '.env.development' : '.env',
}),
EventEmitterModule.forRoot(),
DatabaseModule,
ModelsModule,
CortexModule,
Expand All @@ -46,6 +49,7 @@ import { BenchmarkCommand } from './infrastructure/commanders/benchmark.command'
AssistantsModule,
MessagesModule,
FileManagerModule,
DownloadManagerModule,
],
providers: [
CortexCommand,
Expand Down
72 changes: 72 additions & 0 deletions cortex-js/src/domain/models/download.interface.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
export class DownloadState {
/**
* The id of a particular download. Being used to prevent duplication of downloads.
*/
id: string;

/**
* For displaying purposes.
*/
title: string;

/**
* The type of download.
*/
type: DownloadType;

/**
* The status of the download.
*/
status: DownloadStatus;

/**
* Explanation of the error if the download failed.
*/
error?: string;

/**
* The actual downloads. [DownloadState] is just a group to supporting for download multiple files.
*/
children: DownloadItem[];
}

export enum DownloadStatus {
Pending = 'pending',
Downloading = 'downloading',
Error = 'error',
Downloaded = 'downloaded',
}

export class DownloadItem {
/**
* Filename of the download.
*/
id: string;

time: {
elapsed: number;
remaining: number;
};

size: {
total: number;
transferred: number;
};

checksum?: string;

status: DownloadStatus;

error?: string;

metadata?: Record<string, unknown>;
}

export interface DownloadStateEvent {
data: DownloadState[];
}

export enum DownloadType {
Model = 'model',
Miscelanous = 'miscelanous',
}
10 changes: 10 additions & 0 deletions cortex-js/src/download-manager/download-manager.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { Module } from '@nestjs/common';
import { DownloadManagerService } from './download-manager.service';
import { HttpModule } from '@nestjs/axios';

@Module({
imports: [HttpModule],
providers: [DownloadManagerService],
exports: [DownloadManagerService],
})
export class DownloadManagerModule {}
18 changes: 18 additions & 0 deletions cortex-js/src/download-manager/download-manager.service.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { Test, TestingModule } from '@nestjs/testing';
import { DownloadManagerService } from './download-manager.service';

describe('DownloadManagerService', () => {
let service: DownloadManagerService;

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [DownloadManagerService],
}).compile();

service = module.get<DownloadManagerService>(DownloadManagerService);
});

it('should be defined', () => {
expect(service).toBeDefined();
});
});
209 changes: 209 additions & 0 deletions cortex-js/src/download-manager/download-manager.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import {
DownloadItem,
DownloadState,
DownloadStatus,
DownloadType,
} from '@/domain/models/download.interface';
import { HttpService } from '@nestjs/axios';
import { Injectable } from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { createWriteStream } from 'node:fs';
import { firstValueFrom } from 'rxjs';

@Injectable()
export class DownloadManagerService {
private allDownloadStates: DownloadState[] = [];
private abortControllers: Record<string, Record<string, AbortController>> =
{};

constructor(
private readonly httpService: HttpService,
private readonly eventEmitter: EventEmitter2,
) {
// start emitting download state each 500ms
setInterval(() => {
this.eventEmitter.emit('download.event', this.allDownloadStates);
}, 500);
}

async abortDownload(downloadId: string) {
if (!this.abortControllers[downloadId]) {
return;
}
Object.keys(this.abortControllers[downloadId]).forEach((destination) => {
this.abortControllers[downloadId][destination].abort();
});
delete this.abortControllers[downloadId];
this.allDownloadStates = this.allDownloadStates.filter(
(downloadState) => downloadState.id !== downloadId,
);
}

async submitDownloadRequest(
downloadId: string,
title: string,
downloadType: DownloadType,
urlToDestination: Record<string, string>,
) {
if (
this.allDownloadStates.find(
(downloadState) => downloadState.id === downloadId,
)
) {
return;
}

const downloadItems: DownloadItem[] = Object.keys(urlToDestination).map(
(url) => {
const destination = urlToDestination[url];
const downloadItem: DownloadItem = {
id: destination,
time: {
elapsed: 0,
remaining: 0,
},
size: {
total: 0,
transferred: 0,
},
status: DownloadStatus.Downloading,
};

return downloadItem;
},
);

const downloadState: DownloadState = {
id: downloadId,
title: title,
type: downloadType,
status: DownloadStatus.Downloading,
children: downloadItems,
};

this.allDownloadStates.push(downloadState);
this.abortControllers[downloadId] = {};

Object.keys(urlToDestination).forEach((url) => {
const destination = urlToDestination[url];
this.downloadFile(downloadId, url, destination);
});
}

private async downloadFile(
downloadId: string,
url: string,
destination: string,
) {
const controller = new AbortController();
// adding to abort controllers
this.abortControllers[downloadId][destination] = controller;

const response = await firstValueFrom(
this.httpService.get(url, {
responseType: 'stream',
signal: controller.signal,
}),
);

// check if response is success
if (!response) {
throw new Error('Failed to download model');
}

const writer = createWriteStream(destination);
const totalBytes = response.headers['content-length'];

// update download state
const currentDownloadState = this.allDownloadStates.find(
(downloadState) => downloadState.id === downloadId,
);
if (!currentDownloadState) {
return;
}
const downloadItem = currentDownloadState?.children.find(
(downloadItem) => downloadItem.id === destination,
);
if (downloadItem) {
downloadItem.size.total = totalBytes;
}

let transferredBytes = 0;

writer.on('finish', () => {
// delete the abort controller
delete this.abortControllers[downloadId][destination];
const currentDownloadState = this.allDownloadStates.find(
(downloadState) => downloadState.id === downloadId,
);
if (!currentDownloadState) {
return;
}

// update current child status to downloaded, find by destination as id
const downloadItem = currentDownloadState?.children.find(
(downloadItem) => downloadItem.id === destination,
);
if (downloadItem) {
downloadItem.status = DownloadStatus.Downloaded;
}

const allChildrenDownloaded = currentDownloadState?.children.every(
(downloadItem) => downloadItem.status === DownloadStatus.Downloaded,
);

if (allChildrenDownloaded) {
delete this.abortControllers[downloadId];
currentDownloadState.status = DownloadStatus.Downloaded;
// remove download state if all children is downloaded
this.allDownloadStates = this.allDownloadStates.filter(
(downloadState) => downloadState.id !== downloadId,
);
}
});

writer.on('error', (error) => {
delete this.abortControllers[downloadId][destination];
const currentDownloadState = this.allDownloadStates.find(
(downloadState) => downloadState.id === downloadId,
);
if (!currentDownloadState) {
return;
}

const downloadItem = currentDownloadState?.children.find(
(downloadItem) => downloadItem.id === destination,
);
if (downloadItem) {
downloadItem.status = DownloadStatus.Error;
downloadItem.error = error.message;
}

currentDownloadState.status = DownloadStatus.Error;
currentDownloadState.error = error.message;

// remove download state if all children is downloaded
this.allDownloadStates = this.allDownloadStates.filter(
(downloadState) => downloadState.id !== downloadId,
);
});

response.data.on('data', (chunk: any) => {
transferredBytes += chunk.length;

const currentDownloadState = this.allDownloadStates.find(
(downloadState) => downloadState.id === downloadId,
);
if (!currentDownloadState) return;

const downloadItem = currentDownloadState?.children.find(
(downloadItem) => downloadItem.id === destination,
);
if (downloadItem) {
downloadItem.size.transferred = transferredBytes;
}
});

response.data.pipe(writer);
}
}
19 changes: 19 additions & 0 deletions cortex-js/src/infrastructure/controllers/app.controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import {
DownloadState,
DownloadStateEvent,
} from '@/domain/models/download.interface';
import { Controller, Sse } from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { Observable, fromEvent, map } from 'rxjs';

@Controller('app')
export class AppController {
constructor(private readonly eventEmitter: EventEmitter2) {}

@Sse('download')
downloadEvent(): Observable<DownloadStateEvent> {
return fromEvent(this.eventEmitter, 'download.event').pipe(
map((downloadState: DownloadState[]) => ({ data: downloadState })),
);
}
}
Loading

0 comments on commit 65da850

Please sign in to comment.