-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
324 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
package com.pgvector; | ||
|
||
import java.io.Serializable; | ||
import java.sql.Connection; | ||
import java.sql.SQLException; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Objects; | ||
import org.postgresql.PGConnection; | ||
import org.postgresql.util.ByteConverter; | ||
import org.postgresql.util.PGBinaryObject; | ||
import org.postgresql.util.PGobject; | ||
|
||
/** | ||
* PGsparsevec class | ||
*/ | ||
public class PGsparsevec extends PGobject implements PGBinaryObject, Serializable, Cloneable { | ||
private int dimensions; | ||
private int[] indices; | ||
private float[] values; | ||
|
||
/** | ||
* Constructor | ||
*/ | ||
public PGsparsevec() { | ||
type = "sparsevec"; | ||
} | ||
|
||
/** | ||
* Constructor | ||
* | ||
* @param v float array | ||
*/ | ||
public PGsparsevec(float[] v) { | ||
this(); | ||
|
||
int nnz = 0; | ||
for (int i = 0; i < v.length; i++) { | ||
if (v[i] != 0) { | ||
nnz++; | ||
} | ||
} | ||
|
||
dimensions = v.length; | ||
indices = new int[nnz]; | ||
values = new float[nnz]; | ||
|
||
int j = 0; | ||
for (int i = 0; i < v.length; i++) { | ||
if (v[i] != 0) { | ||
indices[j] = i; | ||
values[j] = v[i]; | ||
j++; | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Constructor | ||
* | ||
* @param <T> number | ||
* @param v list of numbers | ||
*/ | ||
public <T extends Number> PGsparsevec(List<T> v) { | ||
this(); | ||
if (Objects.isNull(v)) { | ||
indices = null; | ||
} else { | ||
int nnz = 0; | ||
for (T f : v) { | ||
if (f.floatValue() != 0) { | ||
nnz++; | ||
} | ||
} | ||
|
||
dimensions = v.size(); | ||
indices = new int[nnz]; | ||
values = new float[nnz]; | ||
|
||
int i = 0; | ||
int j = 0; | ||
for (T f : v) { | ||
float fv = f.floatValue(); | ||
if (fv != 0) { | ||
indices[j] = i; | ||
values[j] = fv; | ||
j++; | ||
} | ||
i++; | ||
} | ||
|
||
} | ||
} | ||
|
||
/** | ||
* Constructor | ||
* | ||
* @param s text representation of a sparse vector | ||
* @throws SQLException exception | ||
*/ | ||
public PGsparsevec(String s) throws SQLException { | ||
this(); | ||
setValue(s); | ||
} | ||
|
||
/** | ||
* Sets the value from a text representation of a sparse vector | ||
*/ | ||
public void setValue(String s) throws SQLException { | ||
if (s == null) { | ||
indices = null; | ||
} else { | ||
String[] sp = s.split("/", 2); | ||
String[] elements = sp[0].substring(1, sp[0].length() - 1).split(","); | ||
|
||
dimensions = Integer.parseInt(sp[1]); | ||
indices = new int[elements.length]; | ||
values = new float[elements.length]; | ||
|
||
for (int i = 0; i < elements.length; i++) | ||
{ | ||
String[] ep = elements[i].split(":", 2); | ||
indices[i] = Integer.parseInt(ep[0]) - 1; | ||
values[i] = Float.parseFloat(ep[1]); | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Returns the text representation of a sparse vector | ||
*/ | ||
public String getValue() { | ||
if (indices == null) { | ||
return null; | ||
} else { | ||
StringBuilder sb = new StringBuilder(13 + 27 * indices.length); | ||
sb.append('{'); | ||
|
||
for (int i = 0; i < indices.length; i++) { | ||
if (i > 0) { | ||
sb.append(','); | ||
} | ||
sb.append(indices[i] + 1); | ||
sb.append(':'); | ||
sb.append(values[i]); | ||
} | ||
|
||
sb.append('}'); | ||
sb.append('/'); | ||
sb.append(dimensions); | ||
return sb.toString(); | ||
} | ||
} | ||
|
||
/** | ||
* Returns the number of bytes for the binary representation | ||
*/ | ||
public int lengthInBytes() { | ||
return indices == null ? 0 : 12 + indices.length * 4 + values.length * 4; | ||
} | ||
|
||
/** | ||
* Sets the value from a binary representation of a sparse vector | ||
*/ | ||
public void setByteValue(byte[] value, int offset) throws SQLException { | ||
dimensions = ByteConverter.int4(value, offset); | ||
int nnz = ByteConverter.int4(value, offset + 4); | ||
|
||
int unused = ByteConverter.int4(value, offset + 8); | ||
if (unused != 0) { | ||
throw new SQLException("expected unused to be 0"); | ||
} | ||
|
||
indices = new int[nnz]; | ||
for (int i = 0; i < nnz; i++) { | ||
indices[i] = ByteConverter.int4(value, offset + 12 + i * 4); | ||
} | ||
|
||
values = new float[nnz]; | ||
for (int i = 0; i < nnz; i++) { | ||
values[i] = ByteConverter.float4(value, offset + 12 + nnz * 4 + i * 4); | ||
} | ||
} | ||
|
||
/** | ||
* Writes the binary representation of a sparse vector | ||
*/ | ||
public void toBytes(byte[] bytes, int offset) { | ||
if (indices == null) { | ||
return; | ||
} | ||
|
||
// server will error on overflow due to unconsumed buffer | ||
// could set to Integer.MAX_VALUE for friendlier error message | ||
ByteConverter.int4(bytes, offset, dimensions); | ||
ByteConverter.int4(bytes, offset + 4, indices.length); | ||
ByteConverter.int4(bytes, offset + 8, 0); | ||
for (int i = 0; i < indices.length; i++) { | ||
ByteConverter.int4(bytes, offset + 12 + i * 4, indices[i]); | ||
} | ||
for (int i = 0; i < values.length; i++) { | ||
ByteConverter.float4(bytes, offset + 12 + indices.length * 4 + i * 4, values[i]); | ||
} | ||
} | ||
|
||
/** | ||
* Returns an array | ||
* | ||
* @return an array | ||
*/ | ||
public float[] toArray() { | ||
if (indices == null) { | ||
return null; | ||
} | ||
|
||
float[] vec = new float[dimensions]; | ||
for (int i = 0; i < indices.length; i++) { | ||
vec[indices[i]] = values[i]; | ||
} | ||
return vec; | ||
} | ||
|
||
/** | ||
* Registers the sparsevec type | ||
* | ||
* @param conn connection | ||
* @throws SQLException exception | ||
*/ | ||
public static void addSparsevecType(Connection conn) throws SQLException { | ||
conn.unwrap(PGConnection.class).addDataType("sparsevec", PGsparsevec.class); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package com.pgvector; | ||
|
||
import java.sql.SQLException; | ||
import java.util.Arrays; | ||
import com.pgvector.PGsparsevec; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertArrayEquals; | ||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
public class PGsparsevecTest { | ||
@Test | ||
void testArrayConstructor() { | ||
PGsparsevec vec = new PGsparsevec(new float[] {1, 0, 2, 0, 3, 0}); | ||
assertArrayEquals(new float[] {1, 0, 2, 0, 3, 0}, vec.toArray()); | ||
} | ||
|
||
@Test | ||
void testStringConstructor() throws SQLException { | ||
PGsparsevec vec = new PGsparsevec("{1:1,3:2,5:3}/6"); | ||
assertArrayEquals(new float[] {1, 0, 2, 0, 3, 0}, vec.toArray()); | ||
} | ||
|
||
@Test | ||
void testFloatListConstructor() { | ||
Float[] a = new Float[] {Float.valueOf(1), Float.valueOf(2), Float.valueOf(3)}; | ||
PGsparsevec vec = new PGsparsevec(Arrays.asList(a)); | ||
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray()); | ||
} | ||
|
||
@Test | ||
void testDoubleListConstructor() { | ||
Double[] a = new Double[] {Double.valueOf(1), Double.valueOf(2), Double.valueOf(3)}; | ||
PGsparsevec vec = new PGsparsevec(Arrays.asList(a)); | ||
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray()); | ||
} | ||
|
||
@Test | ||
void testGetValue() { | ||
PGsparsevec vec = new PGsparsevec(new float[] {1, 0, 2, 0, 3, 0}); | ||
assertEquals("{1:1.0,3:2.0,5:3.0}/6", vec.getValue()); | ||
} | ||
} |