Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[java] Fix double close #19133

Merged
merged 1 commit into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/**
* A container for a map returned by {@link OrtSession#run(Map)}.
Expand All @@ -16,6 +17,7 @@
* values: String, Long, Float, Double.
*/
public class OnnxMap implements OnnxValue {
private static final Logger logger = Logger.getLogger(OnnxMap.class.getName());

static {
try {
Expand Down Expand Up @@ -107,6 +109,8 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) {

private final OnnxMapValueType valueType;

private boolean closed;

/**
* Constructs an OnnxMap containing a reference to the native map along with the type information.
*
Expand All @@ -122,6 +126,7 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) {
this.info = info;
this.stringKeys = info.keyType == OnnxJavaType.STRING;
this.valueType = OnnxMapValueType.mapFromOnnxJavaType(info.valueType);
this.closed = false;
}

/**
Expand All @@ -146,6 +151,7 @@ public OnnxValueType getType() {
*/
@Override
public Map<? extends Object, ? extends Object> getValue() throws OrtException {
checkClosed();
Object[] keys = getMapKeys();
Object[] values = getMapValues();
HashMap<Object, Object> map = new HashMap<>(OrtUtil.capacityFromSize(keys.length));
Expand Down Expand Up @@ -222,10 +228,27 @@ public String toString() {
return "ONNXMap(size=" + size() + ",info=" + info.toString() + ")";
}

@Override
public synchronized boolean isClosed() {
return closed;
}

/** Closes this map, releasing the native memory backing it and it's elements. */
@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed map.");
}
}

/** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */
protected void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OnnxValue");
}
}

private native String[] getStringKeys(long apiHandle, long nativeHandle, long allocatorHandle)
Expand Down
27 changes: 25 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxSequence.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

/**
* A sequence of {@link OnnxValue}s all of the same type.
Expand All @@ -24,6 +25,7 @@
* </ul>
*/
public class OnnxSequence implements OnnxValue {
private static final Logger logger = Logger.getLogger(OnnxSequence.class.getName());

static {
try {
Expand All @@ -40,6 +42,8 @@ public class OnnxSequence implements OnnxValue {

private final SequenceInfo info;

private boolean closed;

/**
* Creates the wrapper object for a native sequence.
*
Expand All @@ -53,6 +57,7 @@ public class OnnxSequence implements OnnxValue {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
this.closed = false;
}

@Override
Expand All @@ -76,6 +81,7 @@ public OnnxValueType getType() {
*/
@Override
public List<? extends OnnxValue> getValue() throws OrtException {
checkClosed();
if (info.sequenceOfMaps) {
OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return Collections.unmodifiableList(Arrays.asList(maps));
Expand Down Expand Up @@ -110,10 +116,27 @@ public String toString() {
return "OnnxSequence(info=" + info.toString() + ")";
}

@Override
public synchronized boolean isClosed() {
return closed;
}

/** Closes this sequence, releasing the native memory backing it and it's elements. */
@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed sequence.");
}
}

/** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */
protected void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OnnxValue");
}
}

private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle)
Expand Down
18 changes: 16 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;
import java.util.logging.Logger;

/**
* A Java object wrapping an OnnxSparseTensor.
Expand All @@ -22,6 +23,7 @@
* different static inner class representing each type.
*/
public final class OnnxSparseTensor extends OnnxTensorLike {
private static final Logger logger = Logger.getLogger(OnnxSparseTensor.class.getName());
private final SparseTensorType sparseTensorType;

// Held to prevent deallocation while used in native code.
Expand Down Expand Up @@ -198,6 +200,7 @@ public OnnxValueType getType() {

@Override
public SparseTensor<? extends Buffer> getValue() throws OrtException {
checkClosed();
Buffer buffer = getValuesBuffer();
long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
switch (sparseTensorType) {
Expand Down Expand Up @@ -234,8 +237,13 @@ public SparseTensor<? extends Buffer> getValue() throws OrtException {
}

@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed OnnxSparseTensor.");
}
}

/**
Expand All @@ -257,6 +265,7 @@ public SparseTensorType getSparseTensorType() {
* @return The indices.
*/
public Buffer getIndicesBuffer() {
checkClosed();
switch (sparseTensorType) {
case COO:
case CSRC:
Expand Down Expand Up @@ -295,6 +304,7 @@ public Buffer getIndicesBuffer() {
* @return The inner indices.
*/
public LongBuffer getInnerIndicesBuffer() {
checkClosed();
if (sparseTensorType == SparseTensorType.CSRC) {
LongBuffer buf =
getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
Expand All @@ -320,6 +330,7 @@ public LongBuffer getInnerIndicesBuffer() {
* @return The data buffer.
*/
public Buffer getValuesBuffer() {
checkClosed();
ByteBuffer buffer =
getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder());
switch (info.type) {
Expand Down Expand Up @@ -396,6 +407,7 @@ public Buffer getValuesBuffer() {
* @return The indices shape.
*/
public long[] getIndicesShape() {
checkClosed();
return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
}

Expand All @@ -405,6 +417,7 @@ public long[] getIndicesShape() {
* @return The indices shape.
*/
public long[] getInnerIndicesShape() {
checkClosed();
if (sparseTensorType == SparseTensorType.CSRC) {
return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
} else {
Expand All @@ -420,6 +433,7 @@ public long[] getInnerIndicesShape() {
* @return The values shape.
*/
public long[] getValuesShape() {
checkClosed();
return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle);
}

Expand Down
24 changes: 19 additions & 5 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Optional;
import java.util.logging.Logger;

/**
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
* returned as outputs.
*/
public class OnnxTensor extends OnnxTensorLike {
private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName());

/**
* This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does
Expand Down Expand Up @@ -97,6 +99,7 @@ public OnnxValueType getType() {
*/
@Override
public Object getValue() throws OrtException {
checkClosed();
if (info.isScalar()) {
switch (info.type) {
case FLOAT:
Expand Down Expand Up @@ -144,16 +147,21 @@ public Object getValue() throws OrtException {

@Override
public String toString() {
return "OnnxTensor(info=" + info.toString() + ")";
return "OnnxTensor(info=" + info.toString() + ",closed=" + closed + ")";
}

/**
* Closes the tensor, releasing it's underlying memory (if it's not backed by an NIO buffer). If
* it is backed by a buffer then the memory is released when the buffer is GC'd.
* Closes the tensor, releasing its underlying memory (if it's not backed by an NIO buffer). If it
* is backed by a buffer then the memory is released when the buffer is GC'd.
*/
@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed tensor.");
}
}

/**
Expand All @@ -165,6 +173,7 @@ public void close() {
* @return A ByteBuffer copy of the OnnxTensor.
*/
public ByteBuffer getByteBuffer() {
checkClosed();
if (info.type != OnnxJavaType.STRING) {
ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle);
ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
Expand All @@ -183,6 +192,7 @@ public ByteBuffer getByteBuffer() {
* @return A FloatBuffer copy of the OnnxTensor.
*/
public FloatBuffer getFloatBuffer() {
checkClosed();
if (info.type == OnnxJavaType.FLOAT) {
// if it's fp32 use the efficient copy.
FloatBuffer buffer = getBuffer().asFloatBuffer();
Expand Down Expand Up @@ -212,6 +222,7 @@ public FloatBuffer getFloatBuffer() {
* @return A DoubleBuffer copy of the OnnxTensor.
*/
public DoubleBuffer getDoubleBuffer() {
checkClosed();
if (info.type == OnnxJavaType.DOUBLE) {
DoubleBuffer buffer = getBuffer().asDoubleBuffer();
DoubleBuffer output = DoubleBuffer.allocate(buffer.capacity());
Expand All @@ -230,6 +241,7 @@ public DoubleBuffer getDoubleBuffer() {
* @return A ShortBuffer copy of the OnnxTensor.
*/
public ShortBuffer getShortBuffer() {
checkClosed();
if ((info.type == OnnxJavaType.INT16)
|| (info.type == OnnxJavaType.FLOAT16)
|| (info.type == OnnxJavaType.BFLOAT16)) {
Expand All @@ -250,6 +262,7 @@ public ShortBuffer getShortBuffer() {
* @return An IntBuffer copy of the OnnxTensor.
*/
public IntBuffer getIntBuffer() {
checkClosed();
if (info.type == OnnxJavaType.INT32) {
IntBuffer buffer = getBuffer().asIntBuffer();
IntBuffer output = IntBuffer.allocate(buffer.capacity());
Expand All @@ -268,6 +281,7 @@ public IntBuffer getIntBuffer() {
* @return A LongBuffer copy of the OnnxTensor.
*/
public LongBuffer getLongBuffer() {
checkClosed();
if (info.type == OnnxJavaType.INT64) {
LongBuffer buffer = getBuffer().asLongBuffer();
LongBuffer output = LongBuffer.allocate(buffer.capacity());
Expand Down
16 changes: 16 additions & 0 deletions java/src/main/java/ai/onnxruntime/OnnxTensorLike.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ public abstract class OnnxTensorLike implements OnnxValue {
/** The size and shape information for this tensor. */
protected final TensorInfo info;

/** Is this value closed? */
protected boolean closed;

/**
* Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor).
*
Expand All @@ -39,6 +42,7 @@ public abstract class OnnxTensorLike implements OnnxValue {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
this.closed = false;
}

/**
Expand All @@ -59,4 +63,16 @@ long getNativeHandle() {
public TensorInfo getInfo() {
return info;
}

@Override
public synchronized boolean isClosed() {
return closed;
}

/** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */
protected void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OnnxValue");
}
}
}
9 changes: 8 additions & 1 deletion java/src/main/java/ai/onnxruntime/OnnxValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ public enum OnnxValueType {
*/
public ValueInfo getInfo();

/** Closes the OnnxValue, freeing it's native memory. */
/**
* Checks if this value is closed (i.e., the native object has been released).
*
* @return True if the value is closed and the native object has been released.
*/
public boolean isClosed();

/** Closes the OnnxValue, freeing its native memory. */
@Override
public void close();

Expand Down
Loading
Loading