diff --git a/src/knex/index.js b/src/knex/index.js index c48369c..c60f6bd 100644 --- a/src/knex/index.js +++ b/src/knex/index.js @@ -1,5 +1,6 @@ const knex = require('knex'); const { fromSql, toSql, sqlType } = require('../utils'); +const { vector } = require('..'); knex.SchemaBuilder.extend('enableExtension', function(name) { return this.raw('CREATE EXTENSION IF NOT EXISTS ??', [name]); @@ -22,4 +23,4 @@ knex.QueryBuilder.extend('cosineDistance', function(column, value) { return this.client.raw('?? <=> ?', [column, toSql(value)]); }); -module.exports = {fromSql, toSql}; +module.exports = {vector, fromSql, toSql}; diff --git a/src/objection/index.js b/src/objection/index.js index b75144e..1488648 100644 --- a/src/objection/index.js +++ b/src/objection/index.js @@ -1,5 +1,4 @@ -require('../knex'); -const { vector } = require('..'); +const { vector } = require('../knex'); const { raw } = require('objection'); function l2Distance(column, value) { diff --git a/tests/knex/index.test.mjs b/tests/knex/index.test.mjs index c93e963..ef7f6f2 100644 --- a/tests/knex/index.test.mjs +++ b/tests/knex/index.test.mjs @@ -1,5 +1,5 @@ import Knex from 'knex'; -import pgvector from 'pgvector/knex'; +import { vector } from 'pgvector/knex'; test('example', async () => { const knex = Knex({ @@ -15,9 +15,9 @@ test('example', async () => { }); const newItems = [ - {embedding: pgvector.toSql([1, 1, 1])}, - {embedding: pgvector.toSql([2, 2, 2])}, - {embedding: pgvector.toSql([1, 1, 2])} + {embedding: vector([1, 1, 1])}, + {embedding: vector([2, 2, 2])}, + {embedding: vector([1, 1, 2])} ]; await knex('knex_items').insert(newItems); @@ -26,9 +26,9 @@ test('example', async () => { .orderBy(knex.l2Distance('embedding', [1, 1, 1])) .limit(5); expect(items.map(v => v.id)).toStrictEqual([1, 3, 2]); - expect(pgvector.fromSql(items[0].embedding)).toStrictEqual([1, 1, 1]); - expect(pgvector.fromSql(items[1].embedding)).toStrictEqual([1, 1, 2]); - expect(pgvector.fromSql(items[2].embedding)).toStrictEqual([2, 2, 2]); + expect(vector(items[0].embedding).toArray()).toStrictEqual([1, 1, 1]); + expect(vector(items[1].embedding).toArray()).toStrictEqual([1, 1, 2]); + expect(vector(items[2].embedding).toArray()).toStrictEqual([2, 2, 2]); // max inner product items = await knex('knex_items')