Skip to content

Commit

Permalink
Added support for sparsevec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 20, 2024
1 parent a86a6c1 commit da1e3a5
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## 0.1.5 (unreleased)

- Added support for `halfvec` type
- Added support for `halfvec` and `sparsevec` types

## 0.1.4 (2023-12-08)

Expand Down
232 changes: 232 additions & 0 deletions src/main/java/com/pgvector/PGsparsevec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package com.pgvector;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.postgresql.PGConnection;
import org.postgresql.util.ByteConverter;
import org.postgresql.util.PGBinaryObject;
import org.postgresql.util.PGobject;

/**
 * PGsparsevec class
 *
 * <p>Holds a sparse vector as its number of dimensions plus two parallel arrays:
 * {@code indices} (zero-based positions of the stored elements) and {@code values}
 * (the element at each of those positions). A {@code null} {@code indices} array
 * represents a {@code null} value.
 *
 * <p>The text representation is <code>{index:value,...}/dimensions</code> with
 * one-based indices, for example <code>{1:1.0,3:2.0,5:3.0}/6</code>.
 */
public class PGsparsevec extends PGobject implements PGBinaryObject, Serializable, Cloneable {
    // binary header: dimensions, number of stored elements, and one unused int4
    private static final int HEADER_BYTES = 12;

    private int dimensions;
    private int[] indices;
    private float[] values;

    /**
     * Constructor
     */
    public PGsparsevec() {
        type = "sparsevec";
    }

    /**
     * Constructor
     *
     * <p>Elements equal to zero are not stored. A {@code null} array produces a
     * {@code null} value, matching the list constructor.
     *
     * @param v float array
     */
    public PGsparsevec(float[] v) {
        this();
        if (v != null) {
            setDense(v);
        }
    }

    /**
     * Constructor
     *
     * <p>Elements whose {@code floatValue()} is zero are not stored. A {@code null}
     * list produces a {@code null} value.
     *
     * @param <T> number
     * @param v list of numbers
     */
    public <T extends Number> PGsparsevec(List<T> v) {
        this();
        if (v != null) {
            float[] dense = new float[v.size()];
            int i = 0;
            for (T f : v) {
                dense[i] = f.floatValue();
                i++;
            }
            setDense(dense);
        }
    }

    /**
     * Constructor
     *
     * @param s text representation of a sparse vector
     * @throws SQLException exception
     */
    public PGsparsevec(String s) throws SQLException {
        this();
        setValue(s);
    }

    /**
     * Sets the value from a text representation of a sparse vector
     *
     * <p>The expected form is <code>{index:value,...}/dimensions</code> with
     * one-based indices. An empty element list (<code>{}/3</code>) is accepted.
     * The current value is left untouched when parsing fails.
     *
     * @param s text representation of a sparse vector, or {@code null}
     * @throws SQLException if the text cannot be parsed
     */
    public void setValue(String s) throws SQLException {
        if (s == null) {
            indices = null;
            values = null;
            return;
        }

        int slash = s.indexOf('/');
        if (slash < 0) {
            throw invalidText(s, null);
        }

        String braced = s.substring(0, slash).trim();
        if (braced.length() < 2) {
            throw invalidText(s, null);
        }
        String inner = braced.substring(1, braced.length() - 1).trim();

        // "".split(",") returns a single empty string, so the empty vector needs its own case
        String[] elements = inner.isEmpty() ? new String[0] : inner.split(",");

        int[] parsedIndices = new int[elements.length];
        float[] parsedValues = new float[elements.length];
        int parsedDimensions;

        try {
            parsedDimensions = Integer.parseInt(s.substring(slash + 1).trim());

            for (int i = 0; i < elements.length; i++) {
                String[] ep = elements[i].split(":", 2);
                if (ep.length != 2) {
                    throw invalidText(s, null);
                }
                // text indices are one-based, stored indices are zero-based
                parsedIndices[i] = Integer.parseInt(ep[0].trim()) - 1;
                parsedValues[i] = Float.parseFloat(ep[1].trim());
            }
        } catch (NumberFormatException e) {
            throw invalidText(s, e);
        }

        // assign only after everything parsed so a failure cannot leave partial state
        dimensions = parsedDimensions;
        indices = parsedIndices;
        values = parsedValues;
    }

    /**
     * Returns the text representation of a sparse vector
     *
     * @return the text representation with one-based indices, or {@code null} for a null value
     */
    public String getValue() {
        if (indices == null) {
            return null;
        }

        StringBuilder sb = new StringBuilder(13 + 27 * indices.length);
        sb.append('{');

        for (int i = 0; i < indices.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(indices[i] + 1);
            sb.append(':');
            sb.append(values[i]);
        }

        sb.append('}');
        sb.append('/');
        sb.append(dimensions);
        return sb.toString();
    }

    /**
     * Returns the number of bytes for the binary representation
     *
     * @return the header size plus four bytes per index and per value, or 0 for a null value
     */
    public int lengthInBytes() {
        return indices == null ? 0 : HEADER_BYTES + indices.length * 4 + values.length * 4;
    }

    /**
     * Sets the value from a binary representation of a sparse vector
     *
     * <p>The layout read here is three int4 fields (dimensions, number of stored
     * elements, an unused field that must be 0), followed by that many int4 indices
     * and then that many float4 values. The current value is left untouched when
     * the input is rejected.
     *
     * @param value buffer holding the binary representation
     * @param offset position in the buffer where the representation starts
     * @throws SQLException if the header is inconsistent or the buffer is too short
     */
    public void setByteValue(byte[] value, int offset) throws SQLException {
        long available = (long) value.length - offset;
        if (available < HEADER_BYTES) {
            throw new SQLException("expected at least " + HEADER_BYTES + " bytes for sparsevec");
        }

        int dim = ByteConverter.int4(value, offset);
        int nnz = ByteConverter.int4(value, offset + 4);

        int unused = ByteConverter.int4(value, offset + 8);
        if (unused != 0) {
            throw new SQLException("expected unused to be 0");
        }

        // reject before allocating: nnz comes straight from the wire
        if (nnz < 0 || available < HEADER_BYTES + 8L * nnz) {
            throw new SQLException("unexpected number of elements for sparsevec: " + nnz);
        }

        int[] readIndices = new int[nnz];
        for (int i = 0; i < nnz; i++) {
            readIndices[i] = ByteConverter.int4(value, offset + HEADER_BYTES + i * 4);
        }

        float[] readValues = new float[nnz];
        for (int i = 0; i < nnz; i++) {
            readValues[i] = ByteConverter.float4(value, offset + HEADER_BYTES + nnz * 4 + i * 4);
        }

        dimensions = dim;
        indices = readIndices;
        values = readValues;
    }

    /**
     * Writes the binary representation of a sparse vector
     *
     * <p>Writes nothing for a null value. Otherwise writes the same layout that
     * {@link #setByteValue(byte[], int)} reads.
     *
     * @param bytes destination buffer
     * @param offset position in the buffer at which to start writing
     */
    public void toBytes(byte[] bytes, int offset) {
        if (indices == null) {
            return;
        }

        ByteConverter.int4(bytes, offset, dimensions);
        ByteConverter.int4(bytes, offset + 4, indices.length);
        ByteConverter.int4(bytes, offset + 8, 0);
        for (int i = 0; i < indices.length; i++) {
            ByteConverter.int4(bytes, offset + HEADER_BYTES + i * 4, indices[i]);
        }
        for (int i = 0; i < values.length; i++) {
            ByteConverter.float4(bytes, offset + HEADER_BYTES + indices.length * 4 + i * 4, values[i]);
        }
    }

    /**
     * Returns an array
     *
     * @return an array of length {@code dimensions} with the stored values at their
     *         positions and zero elsewhere, or {@code null} for a null value
     */
    public float[] toArray() {
        if (indices == null) {
            return null;
        }

        float[] vec = new float[dimensions];
        for (int i = 0; i < indices.length; i++) {
            vec[indices[i]] = values[i];
        }
        return vec;
    }

    /**
     * Registers the sparsevec type
     *
     * @param conn connection
     * @throws SQLException exception
     */
    public static void addSparsevecType(Connection conn) throws SQLException {
        conn.unwrap(PGConnection.class).addDataType("sparsevec", PGsparsevec.class);
    }

    // Stores the non-zero elements of a dense array as index/value pairs.
    private void setDense(float[] dense) {
        int nnz = 0;
        for (float f : dense) {
            if (f != 0) {
                nnz++;
            }
        }

        int[] denseIndices = new int[nnz];
        float[] denseValues = new float[nnz];

        int j = 0;
        for (int i = 0; i < dense.length; i++) {
            if (dense[i] != 0) {
                denseIndices[j] = i;
                denseValues[j] = dense[i];
                j++;
            }
        }

        dimensions = dense.length;
        indices = denseIndices;
        values = denseValues;
    }

    // Builds the exception reported for text that does not match the sparsevec format.
    private static SQLException invalidText(String s, Throwable cause) {
        return new SQLException("invalid text representation of sparsevec: \"" + s + "\"", cause);
    }
}
48 changes: 48 additions & 0 deletions src/test/java/com/pgvector/JDBCJavaTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,52 @@ void halfvecExample(boolean readBinary) throws SQLException {
assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
assertNull(embeddings.get(3));
}

/**
 * Runs the sparsevec example without changing the connection's prepare threshold.
 */
@Test
void testSparsevecReadText() throws SQLException {
    final boolean readBinary = false;
    sparsevecExample(readBinary);
}

/**
 * Runs the sparsevec example with the connection's prepare threshold set to -1.
 */
@Test
void testSparsevecReadBinary() throws SQLException {
    final boolean readBinary = true;
    sparsevecExample(readBinary);
}

/**
 * Inserts sparse vectors into a fresh {@code jdbc_items} table, reads them back
 * ordered by distance to {@code {1, 1, 1}}, and checks ids and embeddings.
 *
 * <p>Every JDBC resource is opened in try-with-resources so the connection,
 * statements and result set are released even when a query or assertion fails.
 *
 * @param readBinary when true, sets the connection's prepare threshold to -1
 *        (presumably so the driver reads results in binary form; confirm against
 *        the PostgreSQL JDBC documentation)
 * @throws SQLException if any database operation fails
 */
void sparsevecExample(boolean readBinary) throws SQLException {
    List<Long> ids = new ArrayList<>();
    List<PGsparsevec> embeddings = new ArrayList<>();

    try (Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test")) {
        if (readBinary) {
            conn.unwrap(PGConnection.class).setPrepareThreshold(-1);
        }

        try (Statement setupStmt = conn.createStatement()) {
            setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
            setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");
        }

        PGsparsevec.addSparsevecType(conn);

        try (Statement createStmt = conn.createStatement()) {
            createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding sparsevec(3))");
        }

        try (PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)")) {
            insertStmt.setObject(1, new PGsparsevec(new float[] {1, 1, 1}));
            insertStmt.setObject(2, new PGsparsevec(new float[] {2, 2, 2}));
            insertStmt.setObject(3, new PGsparsevec(new float[] {1, 1, 2}));
            insertStmt.setObject(4, null);
            insertStmt.executeUpdate();
        }

        try (PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <-> ? LIMIT 5")) {
            neighborStmt.setObject(1, new PGsparsevec(new float[] {1, 1, 1}));
            try (ResultSet rs = neighborStmt.executeQuery()) {
                while (rs.next()) {
                    ids.add(rs.getLong("id"));
                    embeddings.add((PGsparsevec) rs.getObject("embedding"));
                }
            }
        }
    }

    // the collected rows are plain Java objects, so they can be checked after the connection is closed
    assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray());
    assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray());
    assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray());
    assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
    assertNull(embeddings.get(3));
}
}
43 changes: 43 additions & 0 deletions src/test/java/com/pgvector/PGsparsevecTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.pgvector;

import java.sql.SQLException;
import java.util.Arrays;
import com.pgvector.PGsparsevec;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Unit tests for the PGsparsevec constructors and its text representation.
 */
public class PGsparsevecTest {
    /** A dense array with zeros survives the round trip through the sparse form. */
    @Test
    void testArrayConstructor() {
        float[] dense = {1, 0, 2, 0, 3, 0};
        float[] roundTripped = new PGsparsevec(dense).toArray();
        assertArrayEquals(new float[] {1, 0, 2, 0, 3, 0}, roundTripped);
    }

    /** The text form expands to the dense array it describes. */
    @Test
    void testStringConstructor() throws SQLException {
        float[] expected = {1, 0, 2, 0, 3, 0};
        PGsparsevec parsed = new PGsparsevec("{1:1,3:2,5:3}/6");
        assertArrayEquals(expected, parsed.toArray());
    }

    /** A list of Float is accepted by the generic list constructor. */
    @Test
    void testFloatListConstructor() {
        PGsparsevec fromFloats = new PGsparsevec(Arrays.asList(1f, 2f, 3f));
        float[] expected = {1, 2, 3};
        assertArrayEquals(expected, fromFloats.toArray());
    }

    /** A list of Double is accepted by the generic list constructor. */
    @Test
    void testDoubleListConstructor() {
        PGsparsevec fromDoubles = new PGsparsevec(Arrays.asList(1.0, 2.0, 3.0));
        float[] expected = {1, 2, 3};
        assertArrayEquals(expected, fromDoubles.toArray());
    }

    /** The text form lists only the non-zero elements, with one-based indices. */
    @Test
    void testGetValue() {
        float[] dense = {1, 0, 2, 0, 3, 0};
        String text = new PGsparsevec(dense).getValue();
        assertEquals("{1:1.0,3:2.0,5:3.0}/6", text);
    }
}

0 comments on commit da1e3a5

Please sign in to comment.