Skip to content

Commit

Permalink
Updated toSql and fromSql to support sparse vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jun 27, 2024
1 parent f8610a2 commit f44f749
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 54 deletions.
10 changes: 5 additions & 5 deletions src/knex/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const knex = require('knex');
const { fromSql, toSql, anyToSql, vectorType, halfvecType, sparsevecType } = require('../utils');
const { fromSql, toSql, vectorType, halfvecType, sparsevecType } = require('../utils');

knex.SchemaBuilder.extend('enableExtension', function (name) {
return this.raw('CREATE EXTENSION IF NOT EXISTS ??', [name]);
Expand All @@ -21,19 +21,19 @@ knex.TableBuilder.extend('sparsevec', function (name, options) {
});

knex.QueryBuilder.extend('l2Distance', function (column, value) {
return this.client.raw('?? <-> ?', [column, anyToSql(value)]);
return this.client.raw('?? <-> ?', [column, toSql(value)]);
});

knex.QueryBuilder.extend('maxInnerProduct', function (column, value) {
return this.client.raw('?? <#> ?', [column, anyToSql(value)]);
return this.client.raw('?? <#> ?', [column, toSql(value)]);
});

knex.QueryBuilder.extend('cosineDistance', function (column, value) {
return this.client.raw('?? <=> ?', [column, anyToSql(value)]);
return this.client.raw('?? <=> ?', [column, toSql(value)]);
});

knex.QueryBuilder.extend('l1Distance', function (column, value) {
return this.client.raw('?? <+> ?', [column, anyToSql(value)]);
return this.client.raw('?? <+> ?', [column, toSql(value)]);
});

knex.QueryBuilder.extend('hammingDistance', function (column, value) {
Expand Down
9 changes: 4 additions & 5 deletions src/kysely/index.js
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
const { sql } = require('kysely');
const { fromSql, toSql } = require('..');
const { anyToSql } = require('../utils');

function l2Distance(column, value) {
return sql`${sql.ref(column)} <-> ${anyToSql(value)}`;
return sql`${sql.ref(column)} <-> ${toSql(value)}`;
}

function maxInnerProduct(column, value) {
return sql`${sql.ref(column)} <#> ${anyToSql(value)}`;
return sql`${sql.ref(column)} <#> ${toSql(value)}`;
}

function cosineDistance(column, value) {
return sql`${sql.ref(column)} <=> ${anyToSql(value)}`;
return sql`${sql.ref(column)} <=> ${toSql(value)}`;
}

function l1Distance(column, value) {
return sql`${sql.ref(column)} <+> ${anyToSql(value)}`;
return sql`${sql.ref(column)} <+> ${toSql(value)}`;
}

function hammingDistance(column, value) {
Expand Down
6 changes: 3 additions & 3 deletions src/mikro-orm/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ const { BitType } = require('./bit');
const { HalfvecType } = require('./halfvec');
const { SparsevecType } = require('./sparsevec');
const { VectorType } = require('./vector');
const { anyToSql } = require('../utils');
const { toSql } = require('../utils');

function distance(op, column, value, em, binary) {
if (raw) {
return raw(`?? ${op} ?`, [column, binary ? value : anyToSql(value)]);
return raw(`?? ${op} ?`, [column, binary ? value : toSql(value)]);
} else {
return em.raw(`?? ${op} ?`, [column, binary ? value : anyToSql(value)]);
return em.raw(`?? ${op} ?`, [column, binary ? value : toSql(value)]);
}
}

Expand Down
9 changes: 4 additions & 5 deletions src/objection/index.js
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
const { fromSql, toSql } = require('../knex');
const { raw } = require('objection');
const { anyToSql } = require('../utils');

function l2Distance(column, value) {
return raw('?? <-> ?', [column, anyToSql(value)]);
return raw('?? <-> ?', [column, toSql(value)]);
}

function maxInnerProduct(column, value) {
return raw('?? <#> ?', [column, anyToSql(value)]);
return raw('?? <#> ?', [column, toSql(value)]);
}

function cosineDistance(column, value) {
return raw('?? <=> ?', [column, anyToSql(value)]);
return raw('?? <=> ?', [column, toSql(value)]);
}

function l1Distance(column, value) {
return raw('?? <+> ?', [column, anyToSql(value)]);
return raw('?? <+> ?', [column, toSql(value)]);
}

function hammingDistance(column, value) {
Expand Down
4 changes: 2 additions & 2 deletions src/sequelize/index.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const { anyToSql } = require('../utils');
const { toSql } = require('../utils');
const { Utils } = require('sequelize');
const { registerHalfvec } = require('./halfvec');
const { registerSparsevec } = require('./sparsevec');
Expand All @@ -12,7 +12,7 @@ function registerType(Sequelize) {

function distance(op, column, value, sequelize, binary) {
const quotedColumn = column instanceof Utils.Literal ? column.val : sequelize.dialect.queryGenerator.quoteIdentifier(column);
const escapedValue = sequelize.escape(binary ? value : anyToSql(value));
const escapedValue = sequelize.escape(binary ? value : toSql(value));
return sequelize.literal(`${quotedColumn} ${op} ${escapedValue}`);
}

Expand Down
32 changes: 20 additions & 12 deletions src/utils/index.js
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
const util = require('node:util');
const { SparseVector } = require('./sparse-vector');

function fromSql(value) {
function vectorFromSql(value) {
if (value === null) {
return null;
}
return value.substring(1, value.length - 1).split(',').map((v) => parseFloat(v));
}

function toSql(value) {
function vectorToSql(value) {
if (Array.isArray(value)) {
return JSON.stringify(value);
}
return value;
}

const vectorFromSql = fromSql;
const vectorToSql = toSql;

const halfvecFromSql = fromSql;
const halfvecToSql = toSql;
const halfvecFromSql = vectorFromSql;
const halfvecToSql = vectorToSql;

function sparsevecFromSql(value) {
if (value === null) {
Expand All @@ -30,17 +27,29 @@ function sparsevecFromSql(value) {

function sparsevecToSql(value) {
if (value instanceof SparseVector) {
return value.toSql();
return value.toPostgres();
}
return value;
}

function anyToSql(value) {
function fromSql(value) {
if (value === null) {
return null;
} else if (value[0] == '{') {
return sparsevecFromSql(value);
} else if (value[0] == '[') {
return vectorFromSql(value);
} else {
throw new Error('invalid text representation');
}
}

function toSql(value) {
if (Array.isArray(value)) {
return toSql(value);
return vectorToSql(value);
}
if (value instanceof SparseVector) {
return value.toSql();
return sparsevecToSql(value);
}
return value;
}
Expand Down Expand Up @@ -85,7 +94,6 @@ module.exports = {
halfvecToSql,
sparsevecFromSql,
sparsevecToSql,
anyToSql,
sqlType,
vectorType,
halfvecType,
Expand Down
12 changes: 4 additions & 8 deletions src/utils/sparse-vector.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,13 @@ class SparseVector {
}

toString() {
const values = this.values;
const elements = this.indices.map((index, i) => util.format('%i:%f', index + 1, values[i])).join(',');
return util.format('{%s}/%d', elements, this.dimensions);
}

toSql() {
return this.toString();
return this.toPostgres();
}

toPostgres() {
return this.toSql();
const values = this.values;
const elements = this.indices.map((index, i) => util.format('%i:%f', index + 1, values[i])).join(',');
return util.format('{%s}/%d', elements, this.dimensions);
}

toArray() {
Expand Down
14 changes: 7 additions & 7 deletions tests/prisma/index.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,18 @@ test('sparsevec', async () => {

// TODO use create when possible (field is not available in the generated client)
// https://www.prisma.io/docs/concepts/components/prisma-schema/features-without-psl-equivalent#unsupported-field-types
const embedding1 = (new SparseVector([1, 1, 1])).toSql();
const embedding2 = (new SparseVector([2, 2, 2])).toSql();
const embedding3 = (new SparseVector([1, 1, 2])).toSql();
const embedding1 = pgvector.toSql(new SparseVector([1, 1, 1]));
const embedding2 = pgvector.toSql(new SparseVector([2, 2, 2]));
const embedding3 = pgvector.toSql(new SparseVector([1, 1, 2]));
await prisma.$executeRaw`INSERT INTO prisma_items (sparse_embedding) VALUES (${embedding1}::sparsevec), (${embedding2}::sparsevec), (${embedding3}::sparsevec)`;

// TODO use raw orderBy when available
// https://github.com/prisma/prisma/issues/5848
const embedding = (new SparseVector([1, 1, 1])).toSql();
const embedding = pgvector.toSql(new SparseVector([1, 1, 1]));
const items = await prisma.$queryRaw`SELECT id, sparse_embedding::text FROM prisma_items ORDER BY sparse_embedding <-> ${embedding}::sparsevec LIMIT 5`;
expect((new SparseVector(items[0].sparse_embedding)).toArray()).toStrictEqual([1, 1, 1]);
expect((new SparseVector(items[1].sparse_embedding)).toArray()).toStrictEqual([1, 1, 2]);
expect((new SparseVector(items[2].sparse_embedding)).toArray()).toStrictEqual([2, 2, 2]);
expect(pgvector.fromSql(items[0].sparse_embedding).toArray()).toStrictEqual([1, 1, 1]);
expect(pgvector.fromSql(items[1].sparse_embedding).toArray()).toStrictEqual([1, 1, 2]);
expect(pgvector.fromSql(items[2].sparse_embedding).toArray()).toStrictEqual([2, 2, 2]);
});

beforeEach(async () => {
Expand Down
8 changes: 4 additions & 4 deletions tests/slonik/index.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ test('example', async () => {
const binaryEmbedding1 = '000';
const binaryEmbedding2 = '101';
const binaryEmbedding3 = '111';
const sparseEmbedding1 = (new SparseVector([1, 1, 1])).toSql();
const sparseEmbedding2 = (new SparseVector([2, 2, 2])).toSql();
const sparseEmbedding3 = (new SparseVector([1, 1, 2])).toSql();
const sparseEmbedding1 = pgvector.toSql(new SparseVector([1, 1, 1]));
const sparseEmbedding2 = pgvector.toSql(new SparseVector([2, 2, 2]));
const sparseEmbedding3 = pgvector.toSql(new SparseVector([1, 1, 2]));
await pool.query(sql.unsafe`INSERT INTO slonik_items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES (${embedding1}, ${halfEmbedding1}, ${binaryEmbedding1}, ${sparseEmbedding1}), (${embedding2}, ${halfEmbedding2}, ${binaryEmbedding2}, ${sparseEmbedding2}), (${embedding3}, ${halfEmbedding3}, ${binaryEmbedding3}, ${sparseEmbedding3})`);

const embedding = pgvector.toSql([1, 1, 1]);
Expand All @@ -29,7 +29,7 @@ test('example', async () => {
expect(pgvector.fromSql(items.rows[0].embedding)).toStrictEqual([1, 1, 1]);
expect(pgvector.fromSql(items.rows[0].half_embedding)).toStrictEqual([1, 1, 1]);
expect(items.rows[0].binary_embedding).toStrictEqual('000');
expect((new SparseVector(items.rows[0].sparse_embedding)).toArray()).toStrictEqual([1, 1, 1]);
expect(pgvector.fromSql(items.rows[0].sparse_embedding).toArray()).toStrictEqual([1, 1, 1]);

await pool.query(sql.unsafe`CREATE INDEX ON slonik_items USING hnsw (embedding vector_l2_ops)`);

Expand Down
3 changes: 3 additions & 0 deletions tests/utils/index.test.mjs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import pgvector from 'pgvector/utils';
import { SparseVector } from 'pgvector/utils';

test('fromSql', () => {
expect(pgvector.fromSql('[1,2,3]')).toStrictEqual([1, 2, 3]);
expect(pgvector.fromSql('{1:1,2:2,3:3}/3').toArray()).toStrictEqual([1, 2, 3]);
expect(pgvector.fromSql(null)).toBeNull();
});

test('toSql', () => {
expect(pgvector.toSql([1, 2, 3])).toEqual('[1,2,3]');
expect(pgvector.toSql(new SparseVector([1, 2, 3]))).toEqual('{1:1,2:2,3:3}/3');
expect(pgvector.toSql(null)).toBeNull();
});

Expand Down
6 changes: 3 additions & 3 deletions tests/utils/sparse-vector.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test('fromSql', () => {

test('fromDense', () => {
const vec = new SparseVector([1, 0, 2, 0, 3, 0]);
expect(vec.toSql()).toStrictEqual('{1:1,3:2,5:3}/6');
expect(vec.toPostgres()).toStrictEqual('{1:1,3:2,5:3}/6');
expect(vec.dimensions).toStrictEqual(6);
expect(vec.indices).toStrictEqual([0, 2, 4]);
expect(vec.values).toStrictEqual([1, 2, 3]);
Expand All @@ -28,7 +28,7 @@ test('fromMap', () => {
expect(vec.values).toStrictEqual([2, 3, 1]);
});

test('toSql', () => {
test('toPostgres', () => {
const vec = new SparseVector([1.23456789]);
expect(vec.toSql()).toStrictEqual('{1:1.23456789}/1');
expect(vec.toPostgres()).toStrictEqual('{1:1.23456789}/1');
});

0 comments on commit f44f749

Please sign in to comment.