Skip to content

Commit

Permalink
Added support for halfvec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent 45d8b38 commit 7c78945
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.1.5 (unreleased)

- Added support for `halfvec` type

## 0.1.4 (2023-12-08)

- Added `List` constructor
Expand Down
164 changes: 164 additions & 0 deletions src/main/java/com/pgvector/PGhalfvec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package com.pgvector;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.postgresql.PGConnection;
import org.postgresql.util.ByteConverter;
import org.postgresql.util.PGBinaryObject;
import org.postgresql.util.PGobject;

/**
* PGvector class
*/
public class PGhalfvec extends PGobject implements PGBinaryObject, Serializable, Cloneable {
private short[] vec;

/**
* Constructor
*/
public PGhalfvec() {
type = "halfvec";
}

/**
* Constructor
*
* @param v short array
*/
public PGhalfvec(float[] v) {
this();
if (v == null) {
vec = null;
} else {
vec = new short[v.length];
for (int i = 0; i < v.length; i++) {
vec[i] = Float.floatToFloat16(v[i]);
}
}
}

/**
* Constructor
*
* @param <T> number
* @param v list of numbers
*/
public <T extends Number> PGhalfvec(List<T> v) {
this();
if (Objects.isNull(v)) {
vec = null;
} else {
vec = new short[v.size()];
int i = 0;
for (T f : v) {
vec[i++] = Float.floatToFloat16(f.floatValue());
}
}
}

/**
* Constructor
*
* @param s text representation of a vector
* @throws SQLException exception
*/
public PGhalfvec(String s) throws SQLException {
this();
setValue(s);
}

/**
* Sets the value from a text representation of a vector
*/
public void setValue(String s) throws SQLException {
if (s == null) {
vec = null;
} else {
String[] sp = s.substring(1, s.length() - 1).split(",");
vec = new short[sp.length];
for (int i = 0; i < sp.length; i++) {
vec[i] = Float.floatToFloat16(Float.parseFloat(sp[i]));
}
}
}

/**
* Returns the text representation of a vector
*/
public String getValue() {
if (vec == null) {
return null;
} else {
// TODO convert
return Arrays.toString(vec).replace(" ", "");
}
}

/**
* Returns the number of bytes for the binary representation
*/
public int lengthInBytes() {
return vec == null ? 0 : 4 + vec.length * 2;
}

/**
* Sets the value from a binary representation of a vector
*/
public void setByteValue(byte[] value, int offset) throws SQLException {
int dim = ByteConverter.int2(value, offset);

int unused = ByteConverter.int2(value, offset + 2);
if (unused != 0) {
throw new SQLException("expected unused to be 0");
}

vec = new short[dim];
for (int i = 0; i < dim; i++) {
vec[i] = ByteConverter.int2(value, offset + 4 + i * 2);
}
}

/**
* Writes the binary representation of a vector
*/
public void toBytes(byte[] bytes, int offset) {
if (vec == null) {
return;
}

// server will error on overflow due to unconsumed buffer
// could set to Short.MAX_VALUE for friendlier error message
ByteConverter.int2(bytes, offset, vec.length);
ByteConverter.int2(bytes, offset + 2, 0);
for (int i = 0; i < vec.length; i++) {
ByteConverter.int2(bytes, offset + 4 + i * 2, vec[i]);
}
}

/**
* Returns an array
*
* @return an array
*/
public float[] toArray() {
float[] v = new float[vec.length];
for (int i = 0; i < vec.length; i++) {
v[i] = Float.float16ToFloat(vec[i]);
}
return v;
}

/**
* Registers the halfvec type
*
* @param conn connection
* @throws SQLException exception
*/
public static void addHalfvecType(Connection conn) throws SQLException {
conn.unwrap(PGConnection.class).addDataType("halfvec", PGhalfvec.class);
}
}
36 changes: 36 additions & 0 deletions src/test/java/com/pgvector/JDBCJavaTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,40 @@ void example(boolean readBinary) throws SQLException {

conn.close();
}

@Test
void testHalfvec() throws SQLException {
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test");

Statement setupStmt = conn.createStatement();
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");

PGhalfvec.addHalfvecType(conn);

Statement createStmt = conn.createStatement();
createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding halfvec(3))");

PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)");
insertStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1}));
insertStmt.setObject(2, new PGhalfvec(new float[] {2, 2, 2}));
insertStmt.setObject(3, new PGhalfvec(new float[] {1, 1, 2}));
insertStmt.setObject(4, null);
insertStmt.executeUpdate();

PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <-> ? LIMIT 5");
neighborStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1}));
ResultSet rs = neighborStmt.executeQuery();
List<Long> ids = new ArrayList<>();
List<PGhalfvec> embeddings = new ArrayList<>();
while (rs.next()) {
ids.add(rs.getLong("id"));
embeddings.add((PGhalfvec) rs.getObject("embedding"));
}
assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray());
assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray());
assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray());
assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
assertNull(embeddings.get(3));
}
}

0 comments on commit 7c78945

Please sign in to comment.