Skip to content

Commit

Permalink
Added Vector class
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Dec 13, 2023
1 parent 4c41fbf commit 2e9366d
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 16 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ await db.schema.createTable('items')
Insert vectors

```javascript
import pgvector from 'pgvector/kysely';
import { vector } from 'pgvector/kysely';

const newItems = [
{embedding: pgvector.toSql([1, 2, 3])},
{embedding: pgvector.toSql([4, 5, 6])}
{embedding: vector([1, 2, 3])},
{embedding: vector([4, 5, 6])}
];
await db.insertInto('items').values(newItems).execute();
```
Expand Down
7 changes: 7 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
"url": "https://github.com/pgvector/pgvector-node"
},
"exports": {
".": {
"types": "./types/index.d.ts",
"default": "./src/index.js"
},
"./drizzle-orm": {
"types": "./types/drizzle-orm/index.d.ts",
"default": "./src/drizzle-orm/index.js"
Expand Down Expand Up @@ -48,6 +52,9 @@
},
"typesVersions": {
"*": {
"*": [
"types/index.d.ts"
],
"drizzle-orm": [
"types/drizzle-orm/index.d.ts"
],
Expand Down
29 changes: 29 additions & 0 deletions src/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
const utils = require('./utils');

class Vector {
constructor(vec) {
if (Array.isArray(vec)) {
this.vec = vec;
} else {
this.vec = utils.fromSql(vec);
}
}

toArray() {
return this.vec;
}

toPostgres() {
return this.toString();
}

toString() {
return utils.toSql(this.vec);
}
}

function vector(vec) {
return new Vector(vec);
}

module.exports = {vector};
10 changes: 5 additions & 5 deletions src/kysely/index.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
const { sql } = require('kysely');
const { fromSql, toSql } = require('../utils');
const { vector } = require('..');

function l2Distance(column, value) {
return sql`${sql.ref(column)} <-> ${toSql(value)}`;
return sql`${sql.ref(column)} <-> ${vector(value)}`;
}

function maxInnerProduct(column, value) {
return sql`${sql.ref(column)} <#> ${toSql(value)}`;
return sql`${sql.ref(column)} <#> ${vector(value)}`;
}

function cosineDistance(column, value) {
return sql`${sql.ref(column)} <=> ${toSql(value)}`;
return sql`${sql.ref(column)} <=> ${vector(value)}`;
}

module.exports = {fromSql, toSql, l2Distance, maxInnerProduct, cosineDistance};
module.exports = {vector, l2Distance, maxInnerProduct, cosineDistance};
15 changes: 7 additions & 8 deletions tests/kysely/index.test.mjs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pg from 'pg';
import { Kysely, PostgresDialect, sql } from 'kysely';
import pgvector from 'pgvector/kysely';
import { l2Distance, maxInnerProduct, cosineDistance } from 'pgvector/kysely';
import { vector, l2Distance, maxInnerProduct, cosineDistance } from 'pgvector/kysely';

test('example', async () => {
const dialect = new PostgresDialect({
Expand All @@ -26,9 +25,9 @@ test('example', async () => {
.execute();

const newItems = [
{embedding: pgvector.toSql([1, 1, 1])},
{embedding: pgvector.toSql([2, 2, 2])},
{embedding: pgvector.toSql([1, 1, 2])}
{embedding: vector([1, 1, 1])},
{embedding: vector([2, 2, 2])},
{embedding: vector([1, 1, 2])}
];
await db.insertInto('kysely_items')
.values(newItems)
Expand All @@ -40,9 +39,9 @@ test('example', async () => {
.limit(5)
.execute();
expect(items.map(v => v.id)).toStrictEqual([1, 3, 2]);
expect(pgvector.fromSql(items[0].embedding)).toStrictEqual([1, 1, 1]);
expect(pgvector.fromSql(items[1].embedding)).toStrictEqual([1, 1, 2]);
expect(pgvector.fromSql(items[2].embedding)).toStrictEqual([2, 2, 2]);
expect(vector(items[0].embedding).toArray()).toStrictEqual([1, 1, 1]);
expect(vector(items[1].embedding).toArray()).toStrictEqual([1, 1, 2]);
expect(vector(items[2].embedding).toArray()).toStrictEqual([2, 2, 2]);

items = await db.selectFrom('kysely_items')
.selectAll()
Expand Down

0 comments on commit 2e9366d

Please sign in to comment.