Skip to content

Commit

Permalink
Added support for binary format for halfvec
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent 54b8f98 commit b9a6cf7
Showing 1 changed file with 67 additions and 9 deletions.
76 changes: 67 additions & 9 deletions src/main/java/com/pgvector/PGhalfvec.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import java.util.List;
import java.util.Objects;
import org.postgresql.PGConnection;
import org.postgresql.util.ByteConverter;
import org.postgresql.util.PGBinaryObject;
import org.postgresql.util.PGobject;

/**
* PGhalfvec class
*/
public class PGhalfvec extends PGobject implements Serializable, Cloneable {
private float[] vec;
public class PGhalfvec extends PGobject implements PGBinaryObject, Serializable, Cloneable {
private short[] vec;

/**
* Constructor
Expand All @@ -29,7 +31,14 @@ public PGhalfvec() {
*/
public PGhalfvec(float[] v) {
this();
vec = v;
if (v == null) {
vec = null;
} else {
vec = new short[v.length];
for (int i = 0; i < v.length; i++) {
vec[i] = Float.floatToFloat16(v[i]);
}
}
}

/**
Expand All @@ -43,10 +52,10 @@ public <T extends Number> PGhalfvec(List<T> v) {
if (Objects.isNull(v)) {
vec = null;
} else {
vec = new float[v.size()];
vec = new short[v.size()];
int i = 0;
for (T f : v) {
vec[i++] = f.floatValue();
vec[i++] = Float.floatToFloat16(f.floatValue());
}
}
}
Expand All @@ -70,9 +79,9 @@ public void setValue(String s) throws SQLException {
vec = null;
} else {
String[] sp = s.substring(1, s.length() - 1).split(",");
vec = new float[sp.length];
vec = new short[sp.length];
for (int i = 0; i < sp.length; i++) {
vec[i] = Float.parseFloat(sp[i]);
vec[i] = Float.floatToFloat16(Float.parseFloat(sp[i]));
}
}
}
Expand All @@ -84,7 +93,52 @@ public String getValue() {
if (vec == null) {
return null;
} else {
return Arrays.toString(vec).replace(" ", "");
float[] fvec = new float[vec.length];
for (int i = 0; i < vec.length; i++) {
fvec[i] = Float.float16ToFloat(vec[i]);
}
return Arrays.toString(fvec).replace(" ", "");
}
}

/**
* Returns the number of bytes for the binary representation
*/
public int lengthInBytes() {
return vec == null ? 0 : 4 + vec.length * 2;
}

/**
* Sets the value from a binary representation of a half vector
*/
public void setByteValue(byte[] value, int offset) throws SQLException {
int dim = ByteConverter.int2(value, offset);

int unused = ByteConverter.int2(value, offset + 2);
if (unused != 0) {
throw new SQLException("expected unused to be 0");
}

vec = new short[dim];
for (int i = 0; i < dim; i++) {
vec[i] = ByteConverter.int2(value, offset + 4 + i * 2);
}
}

/**
* Writes the binary representation of a half vector
*/
public void toBytes(byte[] bytes, int offset) {
if (vec == null) {
return;
}

// server will error on overflow due to unconsumed buffer
// could set to Short.MAX_VALUE for friendlier error message
ByteConverter.int2(bytes, offset, vec.length);
ByteConverter.int2(bytes, offset + 2, 0);
for (int i = 0; i < vec.length; i++) {
ByteConverter.int2(bytes, offset + 4 + i * 2, vec[i]);
}
}

Expand All @@ -94,7 +148,11 @@ public String getValue() {
* @return an array
*/
public float[] toArray() {
return vec;
float[] v = new float[vec.length];
for (int i = 0; i < vec.length; i++) {
v[i] = Float.float16ToFloat(vec[i]);
}
return v;
}

/**
Expand Down

0 comments on commit b9a6cf7

Please sign in to comment.