Skip to content

Commit

Permalink
Switched to float[] for halfvec
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent a3d2b0d commit ebf503b
Showing 1 changed file with 9 additions and 67 deletions.
76 changes: 9 additions & 67 deletions src/main/java/com/pgvector/PGhalfvec.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
import java.util.List;
import java.util.Objects;
import org.postgresql.PGConnection;
import org.postgresql.util.ByteConverter;
import org.postgresql.util.PGBinaryObject;
import org.postgresql.util.PGobject;

/**
* PGhalfvec class
*/
public class PGhalfvec extends PGobject implements PGBinaryObject, Serializable, Cloneable {
private short[] vec;
public class PGhalfvec extends PGobject implements Serializable, Cloneable {
private float[] vec;

/**
* Constructor
Expand All @@ -31,14 +29,7 @@ public PGhalfvec() {
*/
public PGhalfvec(float[] v) {
this();
if (v == null) {
vec = null;
} else {
vec = new short[v.length];
for (int i = 0; i < v.length; i++) {
vec[i] = Float.floatToFloat16(v[i]);
}
}
vec = v;
}

/**
Expand All @@ -52,10 +43,10 @@ public <T extends Number> PGhalfvec(List<T> v) {
if (Objects.isNull(v)) {
vec = null;
} else {
vec = new short[v.size()];
vec = new float[v.size()];
int i = 0;
for (T f : v) {
vec[i++] = Float.floatToFloat16(f.floatValue());
vec[i++] = f.floatValue();
}
}
}
Expand All @@ -79,9 +70,9 @@ public void setValue(String s) throws SQLException {
vec = null;
} else {
String[] sp = s.substring(1, s.length() - 1).split(",");
vec = new short[sp.length];
vec = new float[sp.length];
for (int i = 0; i < sp.length; i++) {
vec[i] = Float.floatToFloat16(Float.parseFloat(sp[i]));
vec[i] = Float.parseFloat(sp[i]);
}
}
}
Expand All @@ -93,52 +84,7 @@ public String getValue() {
if (vec == null) {
return null;
} else {
float[] fvec = new float[vec.length];
for (int i = 0; i < vec.length; i++) {
fvec[i] = Float.float16ToFloat(vec[i]);
}
return Arrays.toString(fvec).replace(" ", "");
}
}

/**
* Returns the number of bytes for the binary representation
*/
public int lengthInBytes() {
return vec == null ? 0 : 4 + vec.length * 2;
}

/**
* Sets the value from a binary representation of a half vector
*/
public void setByteValue(byte[] value, int offset) throws SQLException {
int dim = ByteConverter.int2(value, offset);

int unused = ByteConverter.int2(value, offset + 2);
if (unused != 0) {
throw new SQLException("expected unused to be 0");
}

vec = new short[dim];
for (int i = 0; i < dim; i++) {
vec[i] = ByteConverter.int2(value, offset + 4 + i * 2);
}
}

/**
* Writes the binary representation of a half vector
*/
public void toBytes(byte[] bytes, int offset) {
if (vec == null) {
return;
}

// server will error on overflow due to unconsumed buffer
// could set to Short.MAX_VALUE for friendlier error message
ByteConverter.int2(bytes, offset, vec.length);
ByteConverter.int2(bytes, offset + 2, 0);
for (int i = 0; i < vec.length; i++) {
ByteConverter.int2(bytes, offset + 4 + i * 2, vec[i]);
return Arrays.toString(vec).replace(" ", "");
}
}

Expand All @@ -148,11 +94,7 @@ public void toBytes(byte[] bytes, int offset) {
* @return an array
*/
public float[] toArray() {
float[] v = new float[vec.length];
for (int i = 0; i < vec.length; i++) {
v[i] = Float.float16ToFloat(vec[i]);
}
return v;
return vec;
}

/**
Expand Down

0 comments on commit ebf503b

Please sign in to comment.