Skip to content

Commit

Permalink
Implements NDScope to automatically close NDArray in the scope (#2321)
Browse files Browse the repository at this point in the history
* Implements NDScope based on JavaCPP PointerScope

Co-authored-by: enpasos <[email protected]>
Co-authored-by: Frank Liu <[email protected]>
  • Loading branch information
3 people authored Jan 26, 2023
1 parent acfc0a6 commit b533011
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 0 deletions.
82 changes: 82 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDScope.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.IdentityHashMap;

/**
* A class that tracks {@link NDResource} objects created in the try-with-resource block and close
* them automatically when out of the block scope.
*
* <p>This class has been derived from {@code org.bytedeco.javacpp.PointerScope} by Samuel Audet
*/
public class NDScope implements AutoCloseable {

private static final ThreadLocal<Deque<NDScope>> SCOPE_STACK =
ThreadLocal.withInitial(ArrayDeque::new);

private IdentityHashMap<NDArray, NDArray> resources;

/** Constructs a new {@code NDScope} instance. */
public NDScope() {
resources = new IdentityHashMap<>();
SCOPE_STACK.get().addLast(this);
}

/**
* Registers {@link NDArray} object to this scope.
*
* @param array the {@link NDArray} object
*/
public static void register(NDArray array) {
Deque<NDScope> queue = SCOPE_STACK.get();
if (queue.isEmpty()) {
return;
}
queue.getLast().resources.put(array, array);
}

/**
* Unregisters {@link NDArray} object from this scope.
*
* @param array the {@link NDArray} object
*/
public static void unregister(NDArray array) {
Deque<NDScope> queue = SCOPE_STACK.get();
if (queue.isEmpty()) {
return;
}
queue.getLast().resources.remove(array);
}

/** {@inheritDoc} */
@Override
public void close() {
for (NDArray array : resources.keySet()) {
array.close();
}
SCOPE_STACK.get().remove(this);
}

/**
* A method that does nothing.
*
* <p>You may use it if you do not have a better way to suppress the warning of a created but
* not explicitly used scope.
*/
public void suppressNotUsedWarning() {
// do nothing
}
}
48 changes: 48 additions & 0 deletions api/src/test/java/ai/djl/ndarray/NDScopeTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray;

import org.testng.Assert;
import org.testng.annotations.Test;

public class NDScopeTest {

@Test
@SuppressWarnings("try")
public void testNDScope() {
NDArray detached;
NDArray inside;
NDArray uninvolved;
try (NDManager manager = NDManager.newBaseManager()) {
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();
try (NDScope ignore = new NDScope()) {
uninvolved = manager.create(new int[] {1});
uninvolved.close();
inside = manager.create(new int[] {1});
// not tracked by any NDScope, but still managed by NDManager
NDScope.unregister(inside);
}

detached = manager.create(new int[] {1});
detached.detach(); // detached from NDManager
NDScope.unregister(detached); // and unregistered from NDScope
}

Assert.assertFalse(inside.isReleased());
}
Assert.assertTrue(inside.isReleased());
Assert.assertFalse(detached.isReleased());
detached.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -89,6 +90,7 @@ public class MxNDArray extends NativeResource<Pointer> implements LazyNDArray {
this.manager = manager;
mxNDArrayEx = new MxNDArrayEx(this);
manager.attachInternal(getUid(), this);
NDScope.register(this);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
Expand Down Expand Up @@ -64,6 +65,7 @@ public PtNDArray(PtNDManager manager, long handle) {
this.manager = manager;
this.ptNDArrayEx = new PtNDArrayEx(this);
manager.attachInternal(getUid(), this);
NDScope.register(this);
}

/**
Expand All @@ -80,6 +82,7 @@ public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) {
this.ptNDArrayEx = new PtNDArrayEx(this);
manager.attachInternal(getUid(), this);
dataRef = data;
NDScope.register(this);
}

/**
Expand All @@ -96,6 +99,7 @@ public PtNDArray(PtNDManager manager, String[] strs, Shape shape) {
this.strs = strs;
this.shape = shape;
this.dataType = DataType.STRING;
NDScope.register(this);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -55,6 +56,7 @@ public class TfNDArray extends NativeResource<TFE_TensorHandle> implements NDArr
this.manager = manager;
manager.attachInternal(getUid(), this);
tfNDArrayEx = new TfNDArrayEx(this);
NDScope.register(this);
}

TfNDArray(TfNDManager manager, TFE_TensorHandle handle, TF_Tensor tensor) {
Expand Down

0 comments on commit b533011

Please sign in to comment.