Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): CLIP search integration #1939

Merged
merged 22 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 37 additions & 22 deletions machine-learning/src/main.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,58 @@
import os
from flask import Flask, request
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
from PIL import Image


# Flask application object. NOTE(review): `server` is assigned a second time
# further down in this file, before any route is registered.
server = Flask(__name__)


# Image-classification pipeline, constructed at import time with a fixed
# model name.
classifier = pipeline(
task="image-classification",
model="microsoft/resnet-50"
)

# Object-detection pipeline, constructed at import time with a fixed
# model name.
detector = pipeline(
task="object-detection",
model="hustvl/yolos-tiny"
)


# Environment resolver
# Each setting is read from the process environment once, at import time,
# and falls back to the default given here when the variable is unset.
is_dev = os.getenv('NODE_ENV') == 'development'
# When the variable is set the value is a string; the default is the int 3003.
server_port = os.getenv('MACHINE_LEARNING_PORT', 3003)
server_host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0')

# Model identifiers used by the endpoints below; each can be overridden
# through its own environment variable.
classification_model = os.getenv('MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny')
clip_image_model = os.getenv('MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
clip_text_model = os.getenv('MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')

# Loaded models, keyed by "<model>|<task>" (task is "None" when omitted).
_model_cache = {}


def _get_model(model, task=None):
    """Return the model object for ``model``/``task``, loading it on first use.

    Args:
        model: Model identifier string passed to the loader.
        task: Pipeline task name. When truthy the model is built with
            ``pipeline(model=..., task=...)``; otherwise it is built with
            ``SentenceTransformer(model)``.

    Returns:
        The cached object for this model/task pair; later calls with the
        same arguments return the same object without loading again.
    """
    cache_key = '|'.join((model, str(task)))
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    if task:
        loaded = pipeline(model=model, task=task)
    else:
        loaded = SentenceTransformer(model)
    _model_cache[cache_key] = loaded
    return loaded

server = Flask(__name__)

@server.route("/ping")
def ping():
    """Answer requests to ``/ping`` with the fixed string ``pong``."""
    return "pong"


@server.route("/object-detection/detect-object", methods=['POST'])
def object_detection():
    """Run object detection on the image named in the request body.

    Expects a JSON body with a ``thumbnailPath`` key and passes that path,
    together with the configured object-detection model, to ``run_engine``.

    Returns:
        A ``(result, 200)`` tuple where ``result`` is whatever ``run_engine``
        produces.

    Raises:
        KeyError: if the JSON body has no ``thumbnailPath`` key.
    """
    # Use the configured (and cached) model rather than a hard-coded pipeline,
    # and answer 200 like the other inference endpoints in this file.
    model = _get_model(object_model, 'object-detection')
    assetPath = request.json['thumbnailPath']
    return run_engine(model, assetPath), 200

@server.route("/image-classifier/tag-image", methods=['POST'])
def image_classification():
    """Classify the image named in the request body.

    Expects a JSON body with a ``thumbnailPath`` key. The configured
    image-classification model is fetched first, then the path is read and
    both are passed to ``run_engine``.

    Returns:
        A ``(result, 200)`` tuple where ``result`` is whatever ``run_engine``
        produces.
    """
    tagger = _get_model(classification_model, 'image-classification')
    thumbnail_path = request.json['thumbnailPath']
    tags = run_engine(tagger, thumbnail_path)
    return tags, 200

@server.route("/sentence-transformer/encode-image", methods=['POST'])
def clip_encode_image():
    """Encode the image named in the request body with the CLIP image model.

    Expects a JSON body with a ``thumbnailPath`` key. The file at that path
    is opened with PIL and passed to the model's ``encode`` method.

    Returns:
        A ``(embedding, 200)`` tuple where ``embedding`` is the encoder
        output converted with ``.tolist()``.

    Raises:
        KeyError: if the JSON body has no ``thumbnailPath`` key.
    """
    model = _get_model(clip_image_model)
    assetPath = request.json['thumbnailPath']
    # Open the image in a context manager so the file handle is released as
    # soon as the encoder has consumed the pixels.
    with Image.open(assetPath) as image:
        embedding = model.encode(image)
    return embedding.tolist(), 200

@server.route("/sentence-transformer/encode-text", methods=['POST'])
def clip_encode_text():
    """Encode the text in the request body with the CLIP text model.

    Expects a JSON body with a ``text`` key.

    Returns:
        A ``(embedding, 200)`` tuple where ``embedding`` is the encoder
        output converted with ``.tolist()``.
    """
    encoder = _get_model(clip_text_model)
    query = request.json['text']
    vector = encoder.encode(query)
    return vector.tolist(), 200

def run_engine(engine, path):
result = []
Expand All @@ -55,4 +70,4 @@ def run_engine(engine, path):


if __name__ == "__main__":
    # Start the server once, binding to the configured host and port instead
    # of a hard-coded address. Debug mode follows NODE_ENV via is_dev.
    server.run(debug=is_dev, host=server_host, port=server_port)
32 changes: 3 additions & 29 deletions mobile/openapi/doc/SearchApi.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

95 changes: 3 additions & 92 deletions mobile/openapi/lib/api/search_api.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion mobile/openapi/test/search_api_test.dart

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions server/apps/immich/src/api-v1/album/album.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ describe('Album service', () => {

expect(result.id).toEqual(albumEntity.id);
expect(result.albumName).toEqual(albumEntity.albumName);
expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: albumEntity } });
expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [albumEntity.id] } });
});

it('gets list of albums for auth user', async () => {
Expand Down Expand Up @@ -316,7 +316,7 @@ describe('Album service', () => {
albumName: updatedAlbumName,
albumThumbnailAssetId: updatedAlbumThumbnailAssetId,
});
expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: updatedAlbum } });
expect(jobMock.queue).toHaveBeenCalledWith({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [updatedAlbum.id] } });
});

it('prevents updating a not owned album (shared with auth user)', async () => {
Expand Down
6 changes: 3 additions & 3 deletions server/apps/immich/src/api-v1/album/album.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export class AlbumService {

/**
 * Creates an album for the authenticated user and queues it for search indexing.
 *
 * @param authUser The authenticated user; their id is passed to the repository.
 * @param createAlbumDto The album creation payload.
 * @returns The created album mapped to its response DTO.
 */
async create(authUser: AuthUserDto, createAlbumDto: CreateAlbumDto): Promise<AlbumResponseDto> {
  const albumEntity = await this.albumRepository.create(authUser.id, createAlbumDto);
  // Queue a single index job that carries only the album id, matching the
  // `{ ids: [...] }` payload used by the other search jobs in this service.
  await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [albumEntity.id] } });
  return mapAlbum(albumEntity);
}

Expand Down Expand Up @@ -107,7 +107,7 @@ export class AlbumService {
}

await this.albumRepository.delete(album);
await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ALBUM, data: { id: albumId } });
await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ALBUM, data: { ids: [albumId] } });
}

async removeUserFromAlbum(authUser: AuthUserDto, albumId: string, userId: string | 'me'): Promise<void> {
Expand Down Expand Up @@ -171,7 +171,7 @@ export class AlbumService {

const updatedAlbum = await this.albumRepository.updateAlbum(album, updateAlbumDto);

await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { album: updatedAlbum } });
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ALBUM, data: { ids: [updatedAlbum.id] } });

return mapAlbum(updatedAlbum);
}
Expand Down
4 changes: 2 additions & 2 deletions server/apps/immich/src/api-v1/asset/asset.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,8 @@ describe('AssetService', () => {
]);

expect(jobMock.queue.mock.calls).toEqual([
[{ name: JobName.SEARCH_REMOVE_ASSET, data: { id: 'asset1' } }],
[{ name: JobName.SEARCH_REMOVE_ASSET, data: { id: 'asset2' } }],
[{ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: ['asset1'] } }],
[{ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: ['asset2'] } }],
[
{
name: JobName.DELETE_FILES,
Expand Down
6 changes: 3 additions & 3 deletions server/apps/immich/src/api-v1/asset/asset.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ export class AssetService {

const updatedAsset = await this._assetRepository.update(authUser.id, asset, dto);

await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { asset: updatedAsset } });
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [assetId] } });

return mapAsset(updatedAsset);
}
Expand Down Expand Up @@ -251,8 +251,8 @@ export class AssetService {
res.header('Cache-Control', 'none');
Logger.error(`Cannot create read stream for asset ${asset.id}`, 'getAssetThumbnail');
throw new InternalServerErrorException(
e,
`Cannot read thumbnail file for asset ${asset.id} - contact your administrator`,
{ cause: e as Error },
);
}
}
Expand Down Expand Up @@ -427,7 +427,7 @@ export class AssetService {

try {
await this._assetRepository.remove(asset);
await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ASSET, data: { id } });
await this.jobRepository.queue({ name: JobName.SEARCH_REMOVE_ASSET, data: { ids: [id] } });

result.push({ id, status: DeleteAssetStatusEnum.SUCCESS });
deleteQueue.push(asset.originalPath, asset.webpPath, asset.resizePath, asset.encodedVideoPath);
Expand Down
1 change: 1 addition & 0 deletions server/apps/immich/src/api-v1/job/job.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ export class JobService {
for (const asset of assets) {
await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } });
await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } });
await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } });
}
return assets.length;
}
Expand Down
2 changes: 1 addition & 1 deletion server/apps/immich/src/controllers/search.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export class SearchController {
@Get()
async search(
@GetAuthUser() authUser: AuthUserDto,
@Query(new ValidationPipe({ transform: true })) dto: SearchDto,
@Query(new ValidationPipe({ transform: true })) dto: SearchDto | any,
bo0tzz marked this conversation as resolved.
Show resolved Hide resolved
): Promise<SearchResponseDto> {
return this.searchService.search(authUser, dto);
}
Expand Down
Loading