Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JS client parity and broken out tests #292

Merged
merged 16 commits into from
Apr 18, 2023
3 changes: 3 additions & 0 deletions .github/workflows/chroma-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ jobs:
run: python -m pytest
- name: Integration Test
run: bin/integration-test
- name: JS Tests
run: bin/js-test

16 changes: 16 additions & 0 deletions bin/js-test
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

set -e

function cleanup {
cd ../..
docker compose -f docker-compose.test.yml down --rmi local --volumes
}

trap cleanup EXIT

docker compose -f docker-compose.test.yml up --build -d

cd clients/js
yarn
yarn test:run
jeffchuber marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 4 additions & 0 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self, settings):
self.router.add_api_route("/api/v1", self.root, methods=["GET"])
self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
self.router.add_api_route("/api/v1/version", self.version, methods=["GET"])
self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added an explicit url for this to make it easy for our OpenAPI spec to pick it up

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the python frontend is still using the "old" route... self.router.add_api_route("/api/v1", self.root, methods=["GET"]), but we could move it over to the new one as well. perhaps it makes sense to do that in this PR? Though having isolation b/t JS and python changes is nice.

self.router.add_api_route("/api/v1/persist", self.persist, methods=["POST"])
self.router.add_api_route("/api/v1/raw_sql", self.raw_sql, methods=["POST"])

Expand Down Expand Up @@ -124,6 +125,9 @@ def app(self):
def root(self):
return {"nanosecond heartbeat": self._api.heartbeat()}

def heartbeat(self):
return self.root()

def persist(self):
self._api.persist()

Expand Down
4 changes: 2 additions & 2 deletions clients/js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
"test:run": "jest --runInBand",
"test:runfull": "PORT=8001 jest --runInBand",
"test:update": "run-s db:clean db:run && jest --runInBand --updateSnapshot && run-s db:clean",
"db:clean": "cd ../.. && docker-compose -f docker-compose-js-tests.yml down --volumes",
"db:run": "cd ../.. && docker-compose -f docker-compose-js-tests.yml up --detach && sleep 5",
"db:clean": "cd ../.. && docker-compose -f docker-compose.test.yml down --volumes",
"db:run": "cd ../.. && docker-compose -f docker-compose.test.yml up --detach && sleep 5",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we removed an "extra" docker file, but didnt fix this up.

"clean": "rimraf dist",
"build": "run-s clean build:*",
"build:main": "tsc -p tsconfig.json",
Expand Down
120 changes: 120 additions & 0 deletions clients/js/src/generated/api/default-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,36 @@ export const DefaultApiAxiosParamCreator = function (configuration?: Configurati
options: localVarRequestOptions,
};
},
/**
*
* @summary Heartbeat
* @param {*} [options] Override http request option.
* @throws {RequiredError}
*/
heartbeat: async (options: AxiosRequestConfig = {}): Promise<RequestArgs> => {
const localVarPath = `/api/v1/heartbeat`;
// use dummy base URL string because the URL constructor only accepts absolute URLs.
const localVarUrlObj = new URL(localVarPath, DUMMY_BASE_URL);
let baseOptions;
if (configuration) {
baseOptions = configuration.baseOptions;
}

const localVarRequestOptions = { method: 'GET', ...baseOptions, ...options};
const localVarHeaderParameter = {} as any;
const localVarQueryParameter = {} as any;



setSearchParams(localVarUrlObj, localVarQueryParameter);
let headersFromBaseOptions = baseOptions && baseOptions.headers ? baseOptions.headers : {};
localVarRequestOptions.headers = {...localVarHeaderParameter, ...headersFromBaseOptions, ...options.headers};

return {
url: toPathString(localVarUrlObj),
options: localVarRequestOptions,
};
},
/**
*
* @summary List Collections
Expand Down Expand Up @@ -607,6 +637,36 @@ export const DefaultApiAxiosParamCreator = function (configuration?: Configurati
localVarRequestOptions.headers = {...localVarHeaderParameter, ...headersFromBaseOptions, ...options.headers};
localVarRequestOptions.data = serializeDataIfNeeded(updateCollection, localVarRequestOptions, configuration)

return {
url: toPathString(localVarUrlObj),
options: localVarRequestOptions,
};
},
/**
*
* @summary Version
* @param {*} [options] Override http request option.
* @throws {RequiredError}
*/
version: async (options: AxiosRequestConfig = {}): Promise<RequestArgs> => {
const localVarPath = `/api/v1/version`;
// use dummy base URL string because the URL constructor only accepts absolute URLs.
const localVarUrlObj = new URL(localVarPath, DUMMY_BASE_URL);
let baseOptions;
if (configuration) {
baseOptions = configuration.baseOptions;
}

const localVarRequestOptions = { method: 'GET', ...baseOptions, ...options};
const localVarHeaderParameter = {} as any;
const localVarQueryParameter = {} as any;



setSearchParams(localVarUrlObj, localVarQueryParameter);
let headersFromBaseOptions = baseOptions && baseOptions.headers ? baseOptions.headers : {};
localVarRequestOptions.headers = {...localVarHeaderParameter, ...headersFromBaseOptions, ...options.headers};

return {
url: toPathString(localVarUrlObj),
options: localVarRequestOptions,
Expand Down Expand Up @@ -725,6 +785,16 @@ export const DefaultApiFp = function(configuration?: Configuration) {
const localVarAxiosArgs = await localVarAxiosParamCreator.getNearestNeighbors(collectionName, queryEmbedding, options);
return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration);
},
/**
*
* @summary Heartbeat
* @param {*} [options] Override http request option.
* @throws {RequiredError}
*/
async heartbeat(options?: AxiosRequestConfig): Promise<(axios?: AxiosInstance, basePath?: string) => AxiosPromise<any>> {
const localVarAxiosArgs = await localVarAxiosParamCreator.heartbeat(options);
return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration);
},
/**
*
* @summary List Collections
Expand Down Expand Up @@ -800,6 +870,16 @@ export const DefaultApiFp = function(configuration?: Configuration) {
const localVarAxiosArgs = await localVarAxiosParamCreator.updateCollection(collectionName, updateCollection, options);
return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration);
},
/**
*
* @summary Version
* @param {*} [options] Override http request option.
* @throws {RequiredError}
*/
async version(options?: AxiosRequestConfig): Promise<(axios?: AxiosInstance, basePath?: string) => AxiosPromise<any>> {
const localVarAxiosArgs = await localVarAxiosParamCreator.version(options);
return createRequestFunction(localVarAxiosArgs, globalAxios, BASE_PATH, configuration);
},
}
};

Expand Down Expand Up @@ -904,6 +984,15 @@ export const DefaultApiFactory = function (configuration?: Configuration, basePa
getNearestNeighbors(collectionName: any, queryEmbedding: QueryEmbedding, options?: any): AxiosPromise<any> {
return localVarFp.getNearestNeighbors(collectionName, queryEmbedding, options).then((request) => request(axios, basePath));
},
/**
*
* @summary Heartbeat
* @param {*} [options] Override http request option.
* @throws {RequiredError}
*/
heartbeat(options?: any): AxiosPromise<any> {
return localVarFp.heartbeat(options).then((request) => request(axios, basePath));
},
/**
*
* @summary List Collections
Expand Down Expand Up @@ -972,6 +1061,15 @@ export const DefaultApiFactory = function (configuration?: Configuration, basePa
updateCollection(collectionName: any, updateCollection: UpdateCollection, options?: any): AxiosPromise<any> {
return localVarFp.updateCollection(collectionName, updateCollection, options).then((request) => request(axios, basePath));
},
/**
*
* @summary Version
* @param {*} [options] Override http request option.
* @throws {RequiredError}
*/
version(options?: any): AxiosPromise<any> {
return localVarFp.version(options).then((request) => request(axios, basePath));
},
};
};

Expand Down Expand Up @@ -1300,6 +1398,17 @@ export class DefaultApi extends BaseAPI {
return DefaultApiFp(this.configuration).getNearestNeighbors(requestParameters.collectionName, requestParameters.queryEmbedding, options).then((request) => request(this.axios, this.basePath));
}

/**
*
* @summary Heartbeat
* @param {*} [options] Override http request option.
* @throws {RequiredError}
* @memberof DefaultApi
*/
public heartbeat(options?: AxiosRequestConfig) {
return DefaultApiFp(this.configuration).heartbeat(options).then((request) => request(this.axios, this.basePath));
}

/**
*
* @summary List Collections
Expand Down Expand Up @@ -1379,4 +1488,15 @@ export class DefaultApi extends BaseAPI {
public updateCollection(requestParameters: DefaultApiUpdateCollectionRequest, options?: AxiosRequestConfig) {
return DefaultApiFp(this.configuration).updateCollection(requestParameters.collectionName, requestParameters.updateCollection, options).then((request) => request(this.axios, this.basePath));
}

/**
*
* @summary Version
* @param {*} [options] Override http request option.
* @throws {RequiredError}
* @memberof DefaultApi
*/
public version(options?: AxiosRequestConfig) {
return DefaultApiFp(this.configuration).version(options).then((request) => request(this.axios, this.basePath));
}
}
35 changes: 35 additions & 0 deletions clients/js/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { QueryEmbeddingIncludeEnum } from "./generated";
import { DefaultApi } from "./generated/api";
import { Configuration } from "./generated/configuration";

Expand Down Expand Up @@ -209,6 +210,8 @@ export class Collection {
n_results: number = 10,
where?: object,
query_text?: string | string[],
where_document?: object, // {"$contains":"search_string"}
include?: QueryEmbeddingIncludeEnum[], // ["metadata", "document"]
) {
if ((query_embeddings === undefined) && (query_text === undefined)) {
throw new Error(
Expand All @@ -234,6 +237,8 @@ export class Collection {
query_embeddings: query_embeddingsArray,
where,
n_results,
where_document: where_document,
include: include
},
}).then(function (response) {
return response.data;
Expand Down Expand Up @@ -286,6 +291,18 @@ export class ChromaClient {
return await this.api.reset();
}

// version
public async version() {
const response = await this.api.version();
return response.data;
}

// heartbeat
public async heartbeat() {
const response = await this.api.heartbeat();
return response.data["nanosecond heartbeat"];
}

public async createCollection(name: string, metadata?: object, embeddingFunction?: CallableFunction) {
const newCollection = await this.api.createCollection({
createCollection: { name, metadata },
Expand All @@ -302,6 +319,24 @@ export class ChromaClient {
return new Collection(name, this.api, embeddingFunction);
}

// get or create collection
public async getOrCreateCollection(name: string, metadata?: object, embeddingFunction?: CallableFunction) {
const newCollection = await this.api.createCollection({
createCollection: { name, metadata, get_or_create: true },

}).then(function (response) {
return response.data;
}).catch(function ({ response }) {
return response.data;
});

if (newCollection.error) {
throw new Error(newCollection.error);
}

return new Collection(name, this.api, embeddingFunction);
}

public async listCollections() {
const response = await this.api.listCollections();
return response.data;
Expand Down
47 changes: 47 additions & 0 deletions clients/js/test/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,53 @@ test('it should create the client connection', async () => {
expect(chroma).toBeInstanceOf(ChromaClient)
})

test('it should get the version', async () => {
const version = await chroma.version()
expect(version).toBeDefined()
expect(version).toMatch(/^[0-9]+\.[0-9]+\.[0-9]+$/)
})

test('it should get the heartbeat', async () => {
const heartbeat = await chroma.heartbeat()
expect(heartbeat).toBeDefined()
expect(heartbeat).toBeGreaterThan(0)
})

test('it should get or create a collection', async () => {
await chroma.reset()
const collection = await chroma.createCollection('test')

const collection2 = await chroma.getOrCreateCollection('test')
expect(collection2).toBeDefined()
expect(collection2).toHaveProperty('name')
expect(collection2.name).toBe('test')

const collection3 = await chroma.getOrCreateCollection('test3')
expect(collection3).toBeDefined()
expect(collection3).toHaveProperty('name')
expect(collection3.name).toBe('test3')
})

// test includes on query
test('it should query a collection', async () => {
await chroma.reset()
const collection = await chroma.createCollection('test')
const ids = ['test1', 'test2', 'test3']
const embeddings = [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
]
const metadata = [
{ test: 'test1' },
{ test: 'test2' },
{ test: 'test3' }
]
// probably add documents here as well so i can try where_document here too
await collection.add(ids, embeddings, metadata)
// then query asking for different includes
})

test('it should reset the database', async () => {
await chroma.reset()
let collections = await chroma.listCollections()
Expand Down