diff --git a/CHANGELOG.md b/CHANGELOG.md
index a81b999..9a81498 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,6 @@
 ## 0.1.5 (unreleased)
 
-- Added support for `halfvec` and `sparsevec` types
+- Added support for `halfvec`, `bit`, and `sparsevec` types
 
 ## 0.1.4 (2023-12-08)
 
diff --git a/src/main/java/com/pgvector/PGbit.java b/src/main/java/com/pgvector/PGbit.java
new file mode 100644
index 0000000..b61bc49
--- /dev/null
+++ b/src/main/java/com/pgvector/PGbit.java
@@ -0,0 +1,167 @@
+package com.pgvector;
+
+import java.io.Serializable;
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import org.postgresql.PGConnection;
+import org.postgresql.util.ByteConverter;
+import org.postgresql.util.PGBinaryObject;
+import org.postgresql.util.PGobject;
+
+/**
+ * PGbit class
+ *
+ * <p>Stores a bit string as a bit count plus packed bytes. Bit {@code i} lives in
+ * byte {@code i / 8} at position {@code 7 - (i % 8)}, so the first bit of the
+ * string is the most significant bit of the first byte.
+ */
+public class PGbit extends PGobject implements PGBinaryObject, Serializable, Cloneable {
+    private int length;
+    private byte[] data;
+
+    /**
+     * Constructor
+     */
+    public PGbit() {
+        type = "bit";
+    }
+
+    /**
+     * Constructor
+     *
+     * @param v boolean array; each {@code true} element becomes a 1 bit
+     */
+    public PGbit(boolean[] v) {
+        this();
+        length = v.length;
+        data = new byte[(length + 7) / 8];
+        for (int i = 0; i < length; i++) {
+            if (v[i]) {
+                data[i / 8] |= 1 << (7 - (i % 8));
+            }
+        }
+    }
+
+    /**
+     * Constructor
+     *
+     * @param s text representation of a bit string
+     * @throws SQLException exception
+     */
+    public PGbit(String s) throws SQLException {
+        this();
+        setValue(s);
+    }
+
+    /**
+     * Sets the value from a text representation of a bit string
+     *
+     * @param s string of {@code '0'} and {@code '1'} characters, or null to clear the value
+     * @throws SQLException if {@code s} contains any other character
+     */
+    @Override
+    public void setValue(String s) throws SQLException {
+        if (s == null) {
+            // Clear the bit count as well so it cannot go stale
+            length = 0;
+            data = null;
+            return;
+        }
+
+        // Pack into a local array so a rejected string leaves the current value untouched
+        int bits = s.length();
+        byte[] packed = new byte[(bits + 7) / 8];
+        for (int i = 0; i < bits; i++) {
+            char c = s.charAt(i);
+            if (c == '1') {
+                packed[i / 8] |= 1 << (7 - (i % 8));
+            } else if (c != '0') {
+                throw new SQLException("Invalid bit string character '" + c + "' at index " + i);
+            }
+        }
+        length = bits;
+        data = packed;
+    }
+
+    /**
+     * Returns the text representation of a bit string
+     *
+     * @return string of {@code '0'} and {@code '1'} characters, or null if no value is set
+     */
+    @Override
+    public String getValue() {
+        if (data == null) {
+            return null;
+        }
+
+        StringBuilder sb = new StringBuilder(length);
+        for (int i = 0; i < length; i++) {
+            boolean set = ((data[i / 8] >> (7 - (i % 8))) & 1) == 1;
+            sb.append(set ? '1' : '0');
+        }
+        return sb.toString();
+    }
+
+    /**
+     * Returns the number of bytes for the binary representation
+     *
+     * @return 4 bytes for the bit count plus the packed data, or 0 if no value is set
+     */
+    @Override
+    public int lengthInBytes() {
+        return data == null ? 0 : 4 + data.length;
+    }
+
+    /**
+     * Sets the value from a binary representation of a bit string
+     *
+     * @param value buffer holding a 4-byte bit count followed by the packed bits
+     * @param offset index in {@code value} where the representation starts
+     * @throws SQLException if the bit count is negative or the buffer is too short
+     */
+    @Override
+    public void setByteValue(byte[] value, int offset) throws SQLException {
+        if (offset < 0 || value.length - offset < 4) {
+            throw new SQLException("Binary bit string is too short to hold a bit count");
+        }
+
+        int bits = ByteConverter.int4(value, offset);
+        // Widen before rounding up so a huge bit count cannot overflow
+        int numBytes = (int) ((bits + 7L) / 8);
+        if (bits < 0 || value.length - offset - 4 < numBytes) {
+            throw new SQLException("Invalid binary bit string with bit count " + bits);
+        }
+
+        length = bits;
+        data = Arrays.copyOfRange(value, offset + 4, offset + 4 + numBytes);
+    }
+
+    /**
+     * Writes the binary representation of a bit string
+     *
+     * @param bytes destination buffer
+     * @param offset index in {@code bytes} where writing starts
+     */
+    @Override
+    public void toBytes(byte[] bytes, int offset) {
+        if (data == null) {
+            return;
+        }
+
+        ByteConverter.int4(bytes, offset, length);
+        System.arraycopy(data, 0, bytes, offset + 4, data.length);
+    }
+
+    /**
+     * Registers the bit type
+     *
+     * @param conn connection
+     * @throws SQLException exception
+     */
+    public static void addBitType(Connection conn) throws SQLException {
+        conn.unwrap(PGConnection.class).addDataType("bit", PGbit.class);
+    }
+}
diff --git a/src/test/java/com/pgvector/JDBCJavaTest.java b/src/test/java/com/pgvector/JDBCJavaTest.java
index 70a8f20..9ae2fda 100644
--- a/src/test/java/com/pgvector/JDBCJavaTest.java
+++ b/src/test/java/com/pgvector/JDBCJavaTest.java
@@ -13,6 +13,7 @@
 import org.junit.jupiter.api.Test;
 
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNull;
 
 public class JDBCJavaTest {
@@ -117,6 +118,59 @@ void halfvecExample(boolean readBinary) throws SQLException {
         assertNull(embeddings.get(3));
     }
 
+    @Test
+    void testBitReadText() throws SQLException {
+        bitExample(false);
+    }
+
+    @Test
+    void testBitReadBinary() throws SQLException {
+        bitExample(true);
+    }
+
+    // Inserts bit(9) rows, reads them back ordered by <~> distance to a query value,
+    // and checks ids and values; readBinary sets the prepare threshold to -1 first.
+    void bitExample(boolean readBinary) throws SQLException {
+        // try-with-resources closes the connection even when a statement or assertion fails
+        try (Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test")) {
+            if (readBinary) {
+                conn.unwrap(PGConnection.class).setPrepareThreshold(-1);
+            }
+
+            Statement stmt = conn.createStatement();
+            stmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
+            stmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");
+
+            PGbit.addBitType(conn);
+
+            stmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding bit(9))");
+
+            PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)");
+            insertStmt.setObject(1, new PGbit(new boolean[] {false, false, false, false, false, false, false, false, false}));
+            insertStmt.setObject(2, new PGbit(new boolean[] {false, true, false, true, false, false, false, false, true}));
+            insertStmt.setObject(3, new PGbit(new boolean[] {false, true, true, true, false, false, false, false, true}));
+            insertStmt.setObject(4, null);
+            insertStmt.executeUpdate();
+
+            PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <~> ? LIMIT 5");
+            neighborStmt.setObject(1, new PGbit(new boolean[] {false, true, false, true, false, false, false, false, true}));
+            ResultSet rs = neighborStmt.executeQuery();
+            List<Long> ids = new ArrayList<>();
+            List<PGbit> embeddings = new ArrayList<>();
+            while (rs.next()) {
+                ids.add(rs.getLong("id"));
+                embeddings.add((PGbit) rs.getObject("embedding"));
+            }
+            assertArrayEquals(new Long[] {2L, 3L, 1L, 4L}, ids.toArray());
+            assertEquals("010100001", embeddings.get(0).getValue());
+            assertEquals("011100001", embeddings.get(1).getValue());
+            assertEquals("000000000", embeddings.get(2).getValue());
+            assertNull(embeddings.get(3));
+
+            stmt.executeUpdate("CREATE INDEX ON jdbc_items USING ivfflat (embedding bit_hamming_ops) WITH (lists = 100)");
+        }
+    }
+
     @Test
     void testSparsevecReadText() throws SQLException {
         sparsevecExample(false);
diff --git a/src/test/java/com/pgvector/PGbitTest.java b/src/test/java/com/pgvector/PGbitTest.java
new file mode 100644
index 0000000..8056273
--- /dev/null
+++ b/src/test/java/com/pgvector/PGbitTest.java
@@ -0,0 +1,45 @@
+package com.pgvector;
+
+import java.sql.SQLException;
+import java.util.Arrays;
+import com.pgvector.PGbit;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Unit tests for the PGbit text and binary conversions; no database is needed.
+ */
+public class PGbitTest {
+    @Test
+    void testArrayConstructor() {
+        PGbit bit = new PGbit(new boolean[] {true, false, true});
+        assertEquals("101", bit.getValue());
+    }
+
+    @Test
+    void testStringConstructor() throws SQLException {
+        PGbit bit = new PGbit("101");
+        assertEquals("101", bit.getValue());
+    }
+
+    @Test
+    void testMoreThanOneByte() throws SQLException {
+        // Nine bits spill into a second byte
+        PGbit bit = new PGbit("010100001");
+        assertEquals("010100001", bit.getValue());
+        assertEquals(6, bit.lengthInBytes());
+    }
+
+    @Test
+    void testBinaryRoundTrip() throws SQLException {
+        PGbit bit = new PGbit("010100001");
+        byte[] bytes = new byte[bit.lengthInBytes()];
+        bit.toBytes(bytes, 0);
+
+        PGbit copy = new PGbit();
+        copy.setByteValue(bytes, 0);
+        assertEquals("010100001", copy.getValue());
+    }
+}