Skip to content

Commit

Permalink
Addressing multiple review comments:
Browse files Browse the repository at this point in the history
- rework factory method getSerializerByStreamContent
- added test case for stream of unsupported content
- removed checked exceptions from the Serializer interface method signatures; failures are now reported as unchecked runtime exceptions
- simplify license header in new classes

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jan 13, 2022
1 parent 13181a1 commit 8e946f5
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ public float[] getValue() {
return vector;
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
} catch (ClassNotFoundException e) {
throw new RuntimeException((e));
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.util;
Expand Down Expand Up @@ -43,8 +37,6 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
vectorList.add(vector);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
docIdList.add(doc);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.util;
Expand All @@ -22,20 +16,28 @@
*/
public class KNNVectorAsArraySerializer implements KNNVectorSerializer {
@Override
public byte[] floatToByteArray(float[] input) throws Exception {
/**
 * Serializes an array of floats with standard Java object serialization.
 *
 * @param input array that will be converted
 * @return bytes produced by writing {@code input} through an {@link ObjectOutputStream}
 * @throws RuntimeException wrapping the underlying {@link IOException} if writing fails
 */
public byte[] floatToByteArray(float[] input) {
    try (ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
         ObjectOutputStream objectStream = new ObjectOutputStream(byteStream)) {
        objectStream.writeObject(input);
        // ObjectOutputStream buffers internally; flush before taking the snapshot so the
        // result does not depend on when the implementation decides to drain that buffer.
        objectStream.flush();
        return byteStream.toByteArray();
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) throws IOException, ClassNotFoundException {
final ObjectInputStream objectStream = new ObjectInputStream(byteStream);
final float[] vector = (float[]) objectStream.readObject();
return vector;
/**
 * Reads a single Java-serialized object from the stream and returns it cast to a float array.
 *
 * @param byteStream stream positioned at the start of the serialized object
 * @return the deserialized array of floats
 * @throws RuntimeException wrapping the {@link IOException} or {@link ClassNotFoundException}
 *         raised while opening the object stream or reading the object
 */
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
    try {
        // The cast is unchecked against the payload: a stream holding anything other than
        // a float[] surfaces as a ClassCastException from this line.
        return (float[]) new ObjectInputStream(byteStream).readObject();
    } catch (IOException | ClassNotFoundException e) {
        throw new RuntimeException(e);
    }
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.util;
Expand All @@ -33,9 +27,13 @@ public byte[] floatToByteArray(float[] input) {

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
if (byteStream == null || byteStream.available() % BYTES_IN_FLOAT != 0) {
throw new IllegalArgumentException("Byte stream cannot be deserialized to array of floats");
}
final byte[] vectorAsByteArray = new byte[byteStream.available()];
byteStream.read(vectorAsByteArray, 0, byteStream.available());
final float[] vector = new float[vectorAsByteArray.length / BYTES_IN_FLOAT];
final int sizeOfFloatArray = vectorAsByteArray.length / BYTES_IN_FLOAT;
final float[] vector = new float[sizeOfFloatArray];
ByteBuffer.wrap(vectorAsByteArray).asFloatBuffer().get(vector);
return vector;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.util;
Expand All @@ -22,16 +16,13 @@ public interface KNNVectorSerializer {
* Serializes array of floats to array of bytes
* @param input array that will be converted
* @return array of bytes that contains serialized input array
* @throws Exception
*/
byte[] floatToByteArray(float[] input) throws Exception;
byte[] floatToByteArray(float[] input);

/**
* Deserializes all bytes from the stream to array of floats
* @param byteStream stream of bytes that will be used for deserialization to array of floats
* @return array of floats deserialized from the stream
* @throws IOException
* @throws ClassNotFoundException
*/
float[] byteToFloatArray(ByteArrayInputStream byteStream) throws IOException, ClassNotFoundException;
float[] byteToFloatArray(ByteArrayInputStream byteStream);
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.util;
Expand All @@ -33,6 +27,7 @@ COLLECTION_OF_FLOATS, new KNNVectorAsCollectionOfFloatsSerializer()

private static final int ARRAY_HEADER_OFFSET = 27;
private static final int BYTES_IN_FLOAT = 4;
private static final int BITS_IN_ONE_BYTE = 8;

/**
* Array represents first 6 bytes of the byte stream header as per Java serialization protocol described in details
Expand Down Expand Up @@ -61,23 +56,30 @@ public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayIn
}

private static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
//check size, if the length is long enough for header and length is header + some number of floats
if (byteStream.available() < ARRAY_HEADER_OFFSET ||
(byteStream.available() - ARRAY_HEADER_OFFSET) % BYTES_IN_FLOAT != 0) {
return COLLECTION_OF_FLOATS;
int numberOfAvailableBytesInStream = byteStream.available();
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) {
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
}
final byte[] byteArray = new byte[SERIALIZATION_PROTOCOL_HEADER_PREFIX.length];
byteStream.read(byteArray, 0, SERIALIZATION_PROTOCOL_HEADER_PREFIX.length);
byteStream.reset();
//checking if stream protocol grammar in header is valid for serialized array
if (Arrays.equals(SERIALIZATION_PROTOCOL_HEADER_PREFIX, byteArray)) {
return ARRAY;
int numberOfAvailableBytesAfterHeader = numberOfAvailableBytesInStream - ARRAY_HEADER_OFFSET;
return getSerializerOrThrowError(numberOfAvailableBytesAfterHeader, ARRAY);
}
return COLLECTION_OF_FLOATS;
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
}

/**
 * Returns the supplied serialization mode when the byte count is a whole multiple of
 * {@code BYTES_IN_FLOAT}; otherwise rejects the stream.
 *
 * @param numberOfRemainingBytes number of bytes expected to hold float values
 * @param serializationMode mode to return when the length check passes
 * @return {@code serializationMode}, unchanged
 * @throws IllegalArgumentException if {@code numberOfRemainingBytes} is not divisible by {@code BYTES_IN_FLOAT}
 */
private static SerializationMode getSerializerOrThrowError(int numberOfRemainingBytes, final SerializationMode serializationMode) {
    final boolean holdsWholeFloats = numberOfRemainingBytes % BYTES_IN_FLOAT == 0;
    if (!holdsWholeFloats) {
        throw new IllegalArgumentException(
            String.format("Byte stream cannot be deserialized to array of floats due to invalid length %d", numberOfRemainingBytes)
        );
    }
    return serializationMode;
}

private static byte highByte(short shortValue) {
return (byte) (shortValue>>8);
return (byte) (shortValue>> BITS_IN_ONE_BYTE);
}

private static byte lowByte(short shortValue) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.util;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec;
Expand All @@ -27,10 +21,25 @@ public class KNNVectorSerializerTests extends KNNTestCase {

Random random = new Random();

public void testVectorSerializerFactory() {
public void testVectorSerializerFactory() throws Exception {
//check that default serializer can work with array of floats
//setup
final float[] vector = getArrayOfRandomFloats(20);
final ByteArrayOutputStream bas = new ByteArrayOutputStream();
final DataOutputStream ds = new DataOutputStream(bas);
for (float f : vector)
ds.writeFloat(f);
final byte[] vectorAsCollectionOfFloats = bas.toByteArray();
final ByteArrayInputStream bais = new ByteArrayInputStream(vectorAsCollectionOfFloats);
bais.reset();

final KNNVectorSerializer defaultSerializer = KNNVectorSerializerFactory.getDefaultSerializer();
assertNotNull(defaultSerializer);

final float[] actualDeserializedVector = defaultSerializer.byteToFloatArray(bais);
assertNotNull(actualDeserializedVector);
assertArrayEquals(vector, actualDeserializedVector, 0.1f);

final KNNVectorSerializer arraySerializer =
KNNVectorSerializerFactory.getSerializerBySerializationMode(SerializationMode.ARRAY);
assertNotNull(arraySerializer);
Expand All @@ -40,10 +49,23 @@ public void testVectorSerializerFactory() {
assertNotNull(collectionOfFloatsSerializer);
}


/**
 * Checks that the serializer factory refuses a stream whose content is not a sequence of floats.
 */
public void testVectorSerializerFactory_throwExceptionForStreamWithUnsupportedDataType() throws Exception {
    // chars are not a supported payload for the serializer factory; the lookup is expected to fail
    final char[] unsupportedContent = {'a', 'b', 'c'};
    final ByteArrayOutputStream rawBytes = new ByteArrayOutputStream();
    final DataOutputStream charWriter = new DataOutputStream(rawBytes);
    for (int i = 0; i < unsupportedContent.length; i++) {
        charWriter.writeChar(unsupportedContent[i]);
    }
    final ByteArrayInputStream bais = new ByteArrayInputStream(rawBytes.toByteArray());
    bais.reset();

    expectThrows(RuntimeException.class, () -> KNNVectorSerializerFactory.getSerializerByStreamContent(bais));
}

public void testVectorAsArraySerializer() throws Exception {
int arrayLength = 20;
float[] vector = new float[arrayLength];
IntStream.range(0, arrayLength).forEach(index -> vector[index] = random.nextFloat());
final float[] vector = getArrayOfRandomFloats(20);

final ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
final ObjectOutputStream objectStream = new ObjectOutputStream(byteStream);
Expand All @@ -70,9 +92,7 @@ public void testVectorAsArraySerializer() throws Exception {

public void testVectorAsCollectionOfFloatsSerializer() throws Exception {
//setup
int arrayLength = 20;
float[] vector = new float[arrayLength];
IntStream.range(0, arrayLength).forEach(index -> vector[index] = random.nextFloat());
final float[] vector = getArrayOfRandomFloats(20);

final ByteArrayOutputStream bas = new ByteArrayOutputStream();
final DataOutputStream ds = new DataOutputStream(bas);
Expand All @@ -97,4 +117,10 @@ public void testVectorAsCollectionOfFloatsSerializer() throws Exception {
assertNotNull(actualDeserializedVector);
assertArrayEquals(vector, actualDeserializedVector, 0.1f);
}

/**
 * Builds a float array of the requested length, filling each slot in index order
 * with the next value drawn from this test's {@code random} field.
 *
 * @param arrayLength number of elements to generate
 * @return newly allocated array of random floats
 */
private float[] getArrayOfRandomFloats(int arrayLength) {
    final float[] randomFloats = new float[arrayLength];
    for (int i = 0; i < randomFloats.length; i++) {
        randomFloats[i] = random.nextFloat();
    }
    return randomFloats;
}
}

0 comments on commit 8e946f5

Please sign in to comment.