From 1552f9685a5a51f7d6a17ebb59fa9e49d490441d Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 15 Dec 2022 11:00:32 +0100 Subject: [PATCH 01/30] just the poc, compiles without testing --- .../java/ai/djl/ndarray/BaseNDManager.java | 44 +++++ .../ndarray/gc/DynamicInvocationHandler.java | 67 +++++++ .../ai/djl/ndarray/gc/NDArrayWrapFactory.java | 51 ++++++ .../ai/djl/ndarray/gc/WeakHashMapWrapper.java | 163 ++++++++++++++++++ .../djl/ndarray/gc/WeakReferenceWrapper.java | 38 ++++ .../java/ai/djl/ndarray/gc/package-info.java | 15 ++ .../ai/djl/pytorch/engine/PtNDManager.java | 5 + .../ai/djl/pytorch/integration/gc/Main.java | 57 ++++++ .../pytorch/integration/gc/package-info.java | 14 ++ 9 files changed, 454 insertions(+) create mode 100644 api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java create mode 100644 api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java create mode 100644 api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java create mode 100644 api/src/main/java/ai/djl/ndarray/gc/WeakReferenceWrapper.java create mode 100644 api/src/main/java/ai/djl/ndarray/gc/package-info.java create mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java create mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/package-info.java diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 59b69fbe729..1f925b048e4 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -469,6 +469,50 @@ public void debugDump(int level) { } } + /** + * Prints information about this {@link NDManager} and all sub-managers to the console. + * + * @param level the level of this {@link NDManager} in the hierarchy + */ + public void debugDumpDetailed(int level) { + StringBuilder sb = new StringBuilder(100); + for (int i = 0; i < level; ++i) { + sb.append(" "); + } + sb.append("\\--- NDManager(") + .append(uid.substring(24)) + .append(", ") + .append(device) + .append(") resource count: ") + .append(resources.size()); + + System.out.println(sb); // NOPMD + for (AutoCloseable c : resources.values()) { + if (c instanceof NDManager) { + ((BaseNDManager) c).debugDumpDetailed(level + 1); + } else if (c instanceof NDArray) { + StringBuilder sb2 = new StringBuilder(100); + for (int i = 0; i < level + 1; ++i) { + sb2.append(" "); + } + sb2.append( + "\\--- NDArray(" + + ((NDArray) c).getUid() + + ", Shape" + + ((NDArray) c).getShape() + + ")"); + System.out.println(sb2); // NOPMD + } else if (c instanceof NDResource) { + StringBuilder sb2 = new StringBuilder(100); + for (int i = 0; i < level + 1; ++i) { + sb2.append(" "); + } + sb2.append("\\--- other NDResource"); + System.out.println(sb2); // NOPMD + } + } + } + NDManager getAlternativeManager() { return alternativeManager; } diff --git a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java new file mode 100644 index 00000000000..4c268f070aa --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java @@ -0,0 +1,67 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.UUID; + +/** {@code DynamicInvocationHandler} implements the {@link InvocationHandler}. */ +public class DynamicInvocationHandler implements InvocationHandler { + + private static final Logger logger = LoggerFactory.getLogger(DynamicInvocationHandler.class); + + WeakHashMapWrapper map; + UUID uuid; + + NDArrayWrapFactory gcAttacher; + + /** + * Creates a new instance of {@code DynamicInvocationHandler}. + * + * @param uuid the uuid + * @param map the map + * @param gcAttacher the gcAttacher + */ + public DynamicInvocationHandler( + UUID uuid, WeakHashMapWrapper map, NDArrayWrapFactory gcAttacher) { + this.map = map; + this.uuid = uuid; + this.gcAttacher = gcAttacher; + } + + /** {@inheritDoc} */ + @Override + public Object invoke(Object proxy, Method method, Object[] args) { + + Object result; + try { + result = method.invoke(map.get(uuid), args); + } catch (IllegalAccessException | InvocationTargetException e) { + logger.error("Error invoking method", e); + throw new RuntimeException(e); // NOPMD + } + + if (result instanceof NDArray) { + return gcAttacher.wrap((NDArray) result); + } + + return result; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java b/api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java new file mode 100644 index 00000000000..df8131a5d3d --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java @@ -0,0 +1,51 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +import java.lang.reflect.Proxy; +import java.util.UUID; + +/** {@code NDArrayWrapFactory} creates a proxy facade. */ +public class NDArrayWrapFactory { + + WeakHashMapWrapper map = new WeakHashMapWrapper<>(); + + /** + * Returns the size of the map. + * + * @return the size of the map + */ + public int mapSize() { + return map.size(); + } + + /** + * Wraps the {@link NDArray} in a proxy facade. 
+ * + * @param array the array to wrap + * @return the wrapped array + */ + public NDArray wrap(NDArray array) { + UUID uuid = UUID.randomUUID(); + map.put(uuid, array); + + DynamicInvocationHandler handler = new DynamicInvocationHandler(uuid, map, this); + return (NDArray) + Proxy.newProxyInstance( + Thread.currentThread().getContextClassLoader(), + new Class[] {NDArray.class}, + handler); + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java new file mode 100644 index 00000000000..1420759aa97 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java @@ -0,0 +1,163 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.WeakHashMap; + +/** + * {@code WeakHashMapWrapper} wraps a {@link WeakHashMap}. It has a {@link ReferenceQueue} to + * get informed by the garbage collector. On method invocations it looks for messages from the + * garbage collector and removes the corresponding entries. 
+ */ +public class WeakHashMapWrapper implements Map { + + private static final Logger logger = LoggerFactory.getLogger(WeakHashMapWrapper.class); + + private final WeakHashMap map = new WeakHashMap<>(); + private final ReferenceQueue queue = new ReferenceQueue<>(); + + private final List> weakReferenceWrapperList = new ArrayList<>(); + + private void checkQueue() { + for (Reference ref; (ref = queue.poll()) != null; ) { + synchronized (queue) { + @SuppressWarnings("unchecked") + WeakReferenceWrapper ref2 = (WeakReferenceWrapper) ref; + V value = ref2.getValue(); + if (value instanceof NDArray) { // just as one example + logger.info( + "NDArray is closed triggered by a message from the garbage collector"); + ((NDArray) value).close(); + } + } + } + } + + // implement all methods of Map interface by calling corresponding methods of + // WeakHashMap instance map + + /** {@inheritDoc} */ + @Override + public int size() { + checkQueue(); + return map.size(); + } + + /** {@inheritDoc} */ + @Override + public boolean isEmpty() { + checkQueue(); + return map.isEmpty(); + } + + /** {@inheritDoc} */ + @Override + public boolean containsKey(Object key) { + checkQueue(); + return map.containsKey(key); + } + + /** {@inheritDoc} */ + @Override + public boolean containsValue(Object value) { + checkQueue(); + return map.containsValue(value); + } + + /** {@inheritDoc} */ + @Override + public V get(Object key) { + checkQueue(); + return map.get(key); + } + + /** {@inheritDoc} */ + @Override + public V put(K key, V value) { + weakReferenceWrapperList.add(new WeakReferenceWrapper(key, value, queue)); + return map.put(key, value); + } + + /** {@inheritDoc} */ + @Override + public V remove(Object key) { + checkQueue(); + return map.remove(key); + } + + /** {@inheritDoc} */ + @Override + public void putAll(Map m) { + checkQueue(); + map.putAll(m); + } + + /** {@inheritDoc} */ + @Override + public void clear() { + checkQueue(); + map.clear(); + } + + /** {@inheritDoc} */ + @Override + public Set keySet() { + checkQueue(); + return map.keySet(); + } + + /** {@inheritDoc} */ + @Override + public Collection values() { + checkQueue(); + return map.values(); + } + + /** {@inheritDoc} */ + @Override + public Set> entrySet() { + checkQueue(); + return map.entrySet(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + checkQueue(); + return map.equals(o); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return map.hashCode(); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return map.toString(); + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakReferenceWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakReferenceWrapper.java new file mode 100644 index 00000000000..e6426b0eaac --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakReferenceWrapper.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.ndarray.gc; + +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; + +/** + * {@code WeakReferenceWrapper} extends a {@link WeakReference}. It uses an object of type K + * as the referent has an object property of type V. + */ +public class WeakReferenceWrapper extends WeakReference { + V value; + + WeakReferenceWrapper(K key, V value, ReferenceQueue queue) { + super(key, queue); + this.value = value; + } + + /** + * Returns the value. + * + * @return the value + */ + public V getValue() { + return value; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/package-info.java b/api/src/main/java/ai/djl/ndarray/gc/package-info.java new file mode 100644 index 00000000000..1bbae260202 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes that help to garbage collect orphaned resources. */ +package ai.djl.ndarray.gc; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 71f0cb57fdb..e282c27b7f1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -186,6 +186,11 @@ public final Engine getEngine() { return Engine.getEngine(PtEngine.ENGINE_NAME); } + /** Dumps debug information about the current {@link PtNDManager} and all its children. */ + public static void debugDumpFromSystemManager() { + ((BaseNDManager) PtNDManager.getSystemManager()).debugDumpDetailed(0); + } + /** The SystemManager is the root {@link PtNDManager} of which all others are children. */ private static final class SystemManager extends PtNDManager implements SystemNDManager { diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java new file mode 100644 index 00000000000..623b42cb1d3 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -0,0 +1,57 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.pytorch.integration.gc; + +import static ai.djl.pytorch.engine.PtNDManager.debugDumpFromSystemManager; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.gc.NDArrayWrapFactory; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** Some poc testing code. */ +public final class Main { + + private static final Logger logger = LoggerFactory.getLogger(Main.class); + + private Main() {} + + public static void main(String[] args) + throws IOException, TranslateException, InterruptedException { + NDArrayWrapFactory enh = new NDArrayWrapFactory(); + try (NDManager manager = NDManager.newBaseManager(); ) { + + NDArray a = enh.wrap(manager.create(new float[] {1f})); + debugDumpFromSystemManager(); + + logger.info("reference exists ..."); + logger.info("weakHashMap size: {}", enh.mapSize()); + a = null; + logger.info("no reference exists, but likely not yet garbage collected ..."); + logger.info("weakHashMap size: {}", enh.mapSize()); + + System.gc(); + TimeUnit.SECONDS.sleep(1); + + logger.info("no reference exists, and likely garbage collected ..."); + logger.info("weakHashMap size: {}", enh.mapSize()); + debugDumpFromSystemManager(); + } + } +} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/package-info.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/package-info.java new file mode 100644 index 00000000000..10391a021d0 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** The integration test for testing PyTorch specific features. 
*/ +package ai.djl.pytorch.integration.gc; From 52cbaf652e3fa74f2f92d8799acf5ba6f6e510b4 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 15 Dec 2022 13:15:28 +0100 Subject: [PATCH 02/30] moved PtNDArray interface out --- .../ndarray/gc/DynamicInvocationHandler.java | 10 +- .../ai/djl/ndarray/gc/NDArrayProxyMaker.java | 38 + .../main/java/ai/djl/util/NativeResource.java | 69 +- .../java/ai/djl/util/NativeResourceImpl.java | 74 + .../java/ai/djl/mxnet/engine/CachedOp.java | 4 +- .../java/ai/djl/mxnet/engine/MxNDArray.java | 4 +- .../djl/mxnet/engine/MxParameterServer.java | 4 +- .../main/java/ai/djl/mxnet/engine/Symbol.java | 4 +- .../paddlepaddle/engine/PaddlePredictor.java | 4 +- .../pytorch/engine/PtGradientCollector.java | 2 +- .../java/ai/djl/pytorch/engine/PtNDArray.java | 1616 ++++------------- .../ai/djl/pytorch/engine/PtNDArrayEx.java | 4 +- .../ai/djl/pytorch/engine/PtNDArrayImpl.java | 1561 ++++++++++++++++ .../djl/pytorch/engine/PtNDArrayIndexer.java | 10 +- .../pytorch/engine/PtNDArrayProxyMaker.java | 18 +- .../ai/djl/pytorch/engine/PtNDManager.java | 2 +- .../main/java/ai/djl/pytorch/jni/IValue.java | 11 +- .../java/ai/djl/pytorch/jni/JniUtils.java | 311 ++-- .../djl/pytorch/integration/IValueTest.java | 9 +- .../ai/djl/pytorch/integration/gc/Main.java | 36 +- .../ai/djl/tensorflow/engine/TfNDArray.java | 4 +- .../java/ai/djl/fasttext/jni/FtWrapper.java | 4 +- .../ai/djl/sentencepiece/SpProcessor.java | 4 +- .../tokenizers/HuggingFaceTokenizer.java | 4 +- 24 files changed, 2221 insertions(+), 1586 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java create mode 100644 api/src/main/java/ai/djl/util/NativeResourceImpl.java create mode 100644 engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java rename api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java => engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java (73%) diff --git a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java index 4c268f070aa..5947ec102c7 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java +++ b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java @@ -30,20 +30,20 @@ public class DynamicInvocationHandler implements InvocationHandler { WeakHashMapWrapper map; UUID uuid; - NDArrayWrapFactory gcAttacher; + NDArrayProxyMaker ndArrayProxyMaker; /** * Creates a new instance of {@code DynamicInvocationHandler}. 
* * @param uuid the uuid * @param map the map - * @param gcAttacher the gcAttacher + * @param gcAttacher the ndArrayProxyMaker */ public DynamicInvocationHandler( - UUID uuid, WeakHashMapWrapper map, NDArrayWrapFactory gcAttacher) { + UUID uuid, WeakHashMapWrapper map, NDArrayProxyMaker ndArrayProxyMaker) { this.map = map; this.uuid = uuid; - this.gcAttacher = gcAttacher; + this.ndArrayProxyMaker = ndArrayProxyMaker; } /** {@inheritDoc} */ @@ -59,7 +59,7 @@ public Object invoke(Object proxy, Method method, Object[] args) { } if (result instanceof NDArray) { - return gcAttacher.wrap((NDArray) result); + return ndArrayProxyMaker.wrap((NDArray) result); } return result; diff --git a/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java new file mode 100644 index 00000000000..e83a2a27f64 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +import java.lang.reflect.Proxy; +import java.util.UUID; + +/** {@code PtNDArrayProxyMaker} creates a proxy facade. */ +public interface NDArrayProxyMaker { + + + /** + * Returns the size of the map. + * + * @return the size of the map + */ + int mapSize(); + + /** + * Wraps the {@link NDArray} in a proxy facade. + * + * @param array the array to wrap + * @return the wrapped array + */ + NDArray wrap(NDArray array) ; +} diff --git a/api/src/main/java/ai/djl/util/NativeResource.java b/api/src/main/java/ai/djl/util/NativeResource.java index 65aa8f0085a..c72ba39e9ee 100644 --- a/api/src/main/java/ai/djl/util/NativeResource.java +++ b/api/src/main/java/ai/djl/util/NativeResource.java @@ -1,71 +1,12 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ package ai.djl.util; -import com.sun.jna.Pointer; +public interface NativeResource extends AutoCloseable { + boolean isReleased(); -import java.util.concurrent.atomic.AtomicReference; + T getHandle(); -/** - * {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory created in - * the different engines. 
- * - * @param the resource that could map to a native pointer or java object - */ -public abstract class NativeResource implements AutoCloseable { + String getUid(); - protected final AtomicReference handle; - private String uid; - - protected NativeResource(T handle) { - this.handle = new AtomicReference<>(handle); - uid = handle.toString(); - } - - /** - * Gets the boolean that indicates whether this resource has been released. - * - * @return whether this resource has been released - */ - public boolean isReleased() { - return handle.get() == null; - } - - /** - * Gets the {@link Pointer} to this resource. - * - * @return the {@link Pointer} to this resource - */ - public T getHandle() { - T reference = handle.get(); - if (reference == null) { - throw new IllegalStateException("Native resource has been release already."); - } - return reference; - } - - /** - * Gets the unique ID of this resource. - * - * @return the unique ID of this resource - */ - public final String getUid() { - return uid; - } - - /** {@inheritDoc} */ @Override - public void close() { - throw new UnsupportedOperationException("Not implemented."); - } + void close(); } diff --git a/api/src/main/java/ai/djl/util/NativeResourceImpl.java b/api/src/main/java/ai/djl/util/NativeResourceImpl.java new file mode 100644 index 00000000000..0ccfa351336 --- /dev/null +++ b/api/src/main/java/ai/djl/util/NativeResourceImpl.java @@ -0,0 +1,74 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util; + +import com.sun.jna.Pointer; + +import java.util.concurrent.atomic.AtomicReference; + +/** + * {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory created in + * the different engines. + * + * @param the resource that could map to a native pointer or java object + */ +public abstract class NativeResourceImpl implements NativeResource { + + protected final AtomicReference handle; + private String uid; + + protected NativeResourceImpl(T handle) { + this.handle = new AtomicReference<>(handle); + uid = handle.toString(); + } + + /** + * Gets the boolean that indicates whether this resource has been released. + * + * @return whether this resource has been released + */ + @Override + public boolean isReleased() { + return handle.get() == null; + } + + /** + * Gets the {@link Pointer} to this resource. + * + * @return the {@link Pointer} to this resource + */ + @Override + public T getHandle() { + T reference = handle.get(); + if (reference == null) { + throw new IllegalStateException("Native resource has been release already."); + } + return reference; + } + + /** + * Gets the unique ID of this resource. 
+ * + * @return the unique ID of this resource + */ + @Override + public final String getUid() { + return uid; + } + + /** {@inheritDoc} */ + @Override + public void close() { + throw new UnsupportedOperationException("Not implemented."); + } +} diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java index 62398b1868e..403f284f6ae 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java @@ -19,7 +19,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.Parameter; import ai.djl.training.ParameterStore; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import ai.djl.util.Pair; import ai.djl.util.PairList; @@ -40,7 +40,7 @@ * analyzing the input shape. It requires minimum input to do inference because most of the * information can be obtained from the model itself. */ -public class CachedOp extends NativeResource { +public class CachedOp extends NativeResourceImpl { private static final Logger logger = LoggerFactory.getLogger(CachedOp.class); diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 8f9dcbc0bdc..d5e5a7ba5a8 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -23,8 +23,8 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import com.sun.jna.Native; import com.sun.jna.Pointer; @@ -36,7 +36,7 @@ import java.util.stream.IntStream; /** {@code MxNDArray} is the MXNet implementation of {@link NDArray}. */ -public class MxNDArray extends NativeResource implements LazyNDArray { +public class MxNDArray extends NativeResourceImpl implements LazyNDArray { private String name; private Device device; diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java index 36bead164e4..bfb51646fc8 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java @@ -20,14 +20,14 @@ import ai.djl.ndarray.NDManager; import ai.djl.training.ParameterServer; import ai.djl.training.optimizer.Optimizer; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import com.sun.jna.Pointer; import java.util.Arrays; /** {@code MxParameterServer} is the MXNet implementation of {@link ParameterServer}. */ -public class MxParameterServer extends NativeResource implements ParameterServer { +public class MxParameterServer extends NativeResourceImpl implements ParameterServer { @SuppressWarnings("PMD.SingularField") // use class field to hold the OptimizerCallback which prevent it from being gc. 
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java index f724aad4bde..7ba921ec00f 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java @@ -15,7 +15,7 @@ import ai.djl.Device; import ai.djl.mxnet.jna.JnaUtils; import ai.djl.ndarray.types.Shape; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import ai.djl.util.PairList; import ai.djl.util.Utils; @@ -37,7 +37,7 @@ * @see MXNet * Symbol */ -public class Symbol extends NativeResource { +public class Symbol extends NativeResourceImpl { // private String[] argParams; // private String[] auxParams; diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java index d9f588a2bff..cde0e533e9c 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java @@ -13,10 +13,10 @@ package ai.djl.paddlepaddle.engine; import ai.djl.paddlepaddle.jni.JniUtils; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; /** PaddlePaddle C++ Predictor. */ -public class PaddlePredictor extends NativeResource { +public class PaddlePredictor extends NativeResourceImpl { PaddlePredictor(long handle) { super(handle); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index 5ef573cee68..b253e73f306 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -64,7 +64,7 @@ public void backward(NDArray target) { * higher order derivative products. Defaults to false. */ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean createGraph) { - JniUtils.backward((PtNDArray) target, (PtNDArray) grad, keepGraph, createGraph); + JniUtils.backward((PtNDArrayImpl) target, (PtNDArrayImpl) grad, keepGraph, createGraph); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index d51cc0acdd4..a67aabe5950 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1,1561 +1,569 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. 
- */ package ai.djl.pytorch.engine; import ai.djl.Device; -import ai.djl.ndarray.BaseNDManager; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; -import ai.djl.pytorch.jni.JniUtils; import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** {@code PtNDArray} is the PyTorch implementation of {@link NDArray}. */ -public class PtNDArray extends NativeResource implements NDArray { - - private String name; - private Device device; - private DataType dataType; - private Shape shape; - private SparseFormat sparseFormat; - // use Boolean object to maintain three status: null, false, true - private Boolean hasGradient; - private PtNDManager manager; - private PtNDArrayEx ptNDArrayEx; - private String[] strs; - - // keep a reference to direct buffer to avoid GC release the memory - @SuppressWarnings("PMD.UnusedPrivateField") - private ByteBuffer dataRef; - - /** - * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} - * instead). - * - * @param manager the manager to attach the new array to - * @param handle the pointer to the native PyTorch memory - */ - public PtNDArray(PtNDManager manager, long handle) { - super(handle); - this.manager = manager; - this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attachInternal(getUid(), this); - } - - /** - * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} - * instead) with the data that is hold on Java side. - * - * @param manager the manager to attach the new array to - * @param handle the pointer to the native PyTorch memory - * @param data the direct buffer of the data - */ - public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { - super(handle); - this.manager = manager; - this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attachInternal(getUid(), this); - dataRef = data; - } - - /** - * Constructs a PyTorch {@code NDArray} to hold string array with a dummy native handle - * (internal. Use {@link NDManager} instead) with the data that is hold on Java side. 
- * - * @param manager the manager to attach the new array to - * @param strs the string array - * @param shape the {@link Shape} of the {@link NDArray} - */ - public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { - super(-1L); - this.manager = manager; - this.strs = strs; - this.shape = shape; - this.dataType = DataType.STRING; - } - - /** {@inheritDoc} */ - @Override - public PtNDManager getManager() { - return manager; - } - - /** {@inheritDoc} */ - @Override - public String getName() { - return name; - } - - /** {@inheritDoc} */ - @Override - public void setName(String name) { - this.name = name; - } - - /** {@inheritDoc} */ - @Override - public DataType getDataType() { - if (dataType == null) { - dataType = JniUtils.getDataType(this); - } - return dataType; - } - - /** {@inheritDoc} */ - @Override - public Device getDevice() { - if (device == null) { - device = JniUtils.getDevice(this); - } - return device; - } - - /** {@inheritDoc} */ - @Override - public Shape getShape() { - if (shape == null) { - shape = JniUtils.getShape(this); - } - return shape; - } - - /** {@inheritDoc} */ - @Override - public SparseFormat getSparseFormat() { - if (sparseFormat == null) { - sparseFormat = JniUtils.getSparseFormat(this); - } - return sparseFormat; - } - - /** {@inheritDoc} */ - @Override - public PtNDArray toDevice(Device device, boolean copy) { - if (device.equals(getDevice()) && !copy) { - return this; - } - return JniUtils.to(this, getDataType(), device); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray toType(DataType dataType, boolean copy) { - if (dataType.equals(getDataType()) && !copy) { - return this; - } - return JniUtils.to(this, dataType, getDevice()); - } - - /** {@inheritDoc} */ - @Override - public void setRequiresGradient(boolean requiresGrad) { - JniUtils.attachGradient(this, requiresGrad); - hasGradient = requiresGrad; - } - - /** {@inheritDoc} */ - @Override - public PtNDArray getGradient() { - if (!hasGradient()) { - throw new IllegalStateException( - "No gradient attached to this NDArray, please call array.setRequiresGradient()" - + " on your NDArray or block.setInitializer() on your Block"); - } - PtNDArray res = JniUtils.getGradient(this); - // If you call getGradient() before you run the backward, - // you will get nothing in PyTorch engine. - // To align with MXNet's behavior, we will create a zeros NDArray. - // TODO should we access the grad NDArray after we close the parameter NDArray? 
- if (res == null) { - res = (PtNDArray) manager.zeros(getShape()); - } - return res; - } - - /** {@inheritDoc} */ - @Override - public boolean hasGradient() { - if (hasGradient == null) { - hasGradient = JniUtils.requiresGrad(this); - } - return hasGradient; - } - - /** {@inheritDoc} */ - @Override - public NDArray stopGradient() { - return JniUtils.detachGradient(this); - } - - /** {@inheritDoc} */ - @Override - public ByteBuffer toByteBuffer() { - return JniUtils.getByteBuffer(this); - } - - /** {@inheritDoc} */ - @Override - public String[] toStringArray(Charset charset) { - return strs; - } - - /** {@inheritDoc} */ - @Override - public void set(Buffer buffer) { - int size = Math.toIntExact(size()); - DataType type = getDataType(); - BaseNDManager.validateBuffer(buffer, type, size); - // TODO how do we handle the exception happened in the middle - dataRef = null; - if (buffer.isDirect() && buffer instanceof ByteBuffer) { - // If NDArray is on the GPU, it is native code responsibility to control the data life - // cycle - if (!getDevice().isGpu()) { - dataRef = (ByteBuffer) buffer; - } - JniUtils.set(this, (ByteBuffer) buffer); - return; - } - // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType - ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes()); - BaseNDManager.copyBuffer(buffer, buf); - - // If NDArray is on the GPU, it is native code responsibility to control the data life cycle - if (!getDevice().isGpu()) { - dataRef = buf; - } - JniUtils.set(this, buf); - } - - /** {@inheritDoc} */ - @Override - public NDArray get(NDManager manager, long... indices) { - return JniUtils.getItem(this, indices, (PtNDManager) manager); - } - - /** {@inheritDoc} */ - @Override - public NDArray gather(NDArray index, int axis) { - if (!(index instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray index is supported."); - } - return JniUtils.gather(this, (PtNDArray) index, axis); - } - - /** {@inheritDoc} */ - @Override - public NDArray gatherNd(NDArray index) { - if (!(index instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray index is supported."); - } - Shape indexShape = index.getShape(); - Shape dataShape = getShape(); - int indexingDepth = (int) indexShape.get(0); - if (indexingDepth > dataShape.dimension()) { - throw new IllegalArgumentException( - "Indexing rank " - + indexShape.get(0) - + " exceeds the data rank " - + dataShape.dimension()); - } - // Row-first order, the linear index is accumulated from z->y->x. - // For example, dataShape = (3, 2, 3), indexShape = (2, 3, 3) - // The method is: indexLinear = index[1] + index[0] * dataShape[1], row-first order - // indexLinear has shape (3, 3), is from combining the index along 0 axis. - // Each number in indexLinear is an indexing to an element in data (3, 2, ...). - // data is flattened to be (3*2, ...) which can be indexed by indexLinear. - // Finally, reshape the output to (3, 3, ...). 
Thus - // totalShape = indexShape.slice(1).addAll(dataShape.slice(indexingDepth)); - NDArray indexLinear = index.get("{}, ...", indexingDepth - 1); - long dim = 1; - for (int i = indexingDepth - 2; i > -1; i--) { - dim = dim * dataShape.get(i + 1); - indexLinear = indexLinear.addi(index.get("{}, ...", i).muli(dim)); - } - NDArray dataFlatten = this.flatten(0, indexingDepth - 1); - return dataFlatten.get(indexLinear); - } - - /** {@inheritDoc} */ - @Override - public NDArray take(NDManager manager, NDArray index) { - if (!(index instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray is supported."); - } - return JniUtils.take(this, (PtNDArray) index, (PtNDManager) manager); - } - - /** {@inheritDoc} */ - @Override - public NDArray put(NDArray index, NDArray data) { - if (!(index instanceof PtNDArray) || !(data instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray is supported."); - } - return JniUtils.put(this, (PtNDArray) index, (PtNDArray) data); - } - - /** {@inheritDoc} */ - @Override - public void copyTo(NDArray array) { - throw new UnsupportedOperationException("Not implemented"); - } - - /** {@inheritDoc} */ - @Override - public void attach(NDManager manager) { - detach(); - this.manager = (PtNDManager) manager; - manager.attachInternal(getUid(), this); - } - - /** {@inheritDoc} */ - @Override - public void returnResource(NDManager manager) { - detach(); - this.manager = (PtNDManager) manager; - manager.attachUncappedInternal(getUid(), this); - } - - /** {@inheritDoc} */ - @Override - public void tempAttach(NDManager manager) { - NDManager original = this.manager; - detach(); - this.manager = (PtNDManager) manager; - manager.tempAttachInternal(original, getUid(), this); - } - - /** {@inheritDoc} */ - @Override - public void detach() { - manager.detachInternal(getUid()); - manager = PtNDManager.getSystemManager(); - } - - /** {@inheritDoc} */ - @Override - public NDArray duplicate() { - return JniUtils.clone(this); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray booleanMask(NDArray index, int axis) { - Shape indexShape = index.getShape(); - if (indexShape.equals(getShape())) { - // Result is flattened since shape is undetermined - return JniUtils.booleanMask(this, manager.from(index)); - } else if (indexShape.equals(getShape().slice(axis))) { - // index will be broadcast by default - try (PtNDArray flattedResult = JniUtils.booleanMask(this, manager.from(index))) { - // Shape recovery - Shape remainder = getShape().slice(0, axis); - long selectedSize = flattedResult.getShape().size() / remainder.size(); - return flattedResult.reshape(remainder.addAll(new Shape(selectedSize))); - } - } else { - throw new UnsupportedOperationException( - "Not supported for shape not broadcastable " - + indexShape - + " vs " - + getShape()); - } - } - - /** {@inheritDoc} */ - @Override - public NDArray sequenceMask(NDArray sequenceLength, float value) { - throw new UnsupportedOperationException("Not implemented yet"); - } - - /** {@inheritDoc} */ - @Override - public NDArray sequenceMask(NDArray sequenceLength) { - throw new UnsupportedOperationException("Not implemented yet"); - } - /** {@inheritDoc} */ +public interface PtNDArray extends NativeResource, NDArray { @Override - public boolean contentEquals(Number number) { - return JniUtils.contentEqual(this, (PtNDArray) manager.create(number)); - } + PtNDManager getManager(); - /** {@inheritDoc} */ @Override - public boolean contentEquals(NDArray other) { - if (other == null || 
(!shapeEquals(other))) { - return false; - } - if (getDataType() != other.getDataType()) { - return false; - } - return JniUtils.contentEqual(this, manager.from(other)); - } + String getName(); - /** {@inheritDoc} */ @Override - public PtNDArray eq(Number n) { - try (NDArray number = manager.create(n)) { - return eq(number); - } - } + void setName(String name); - /** {@inheritDoc} */ @Override - public PtNDArray eq(NDArray other) { - return JniUtils.eq(this, manager.from(other)); - } + DataType getDataType(); - /** {@inheritDoc} */ @Override - public PtNDArray neq(Number n) { - try (NDArray number = manager.create(n)) { - return neq(number); - } - } + Device getDevice(); - /** {@inheritDoc} */ @Override - public PtNDArray neq(NDArray other) { - return JniUtils.neq(this, manager.from(other)); - } + Shape getShape(); - /** {@inheritDoc} */ @Override - public PtNDArray gt(Number n) { - try (NDArray number = manager.create(n)) { - return gt(number); - } - } + SparseFormat getSparseFormat(); - /** {@inheritDoc} */ @Override - public PtNDArray gt(NDArray other) { - return JniUtils.gt(this, manager.from(other)); - } + PtNDArray toDevice(Device device, boolean copy); - /** {@inheritDoc} */ @Override - public PtNDArray gte(Number n) { - try (NDArray number = manager.create(n)) { - return gte(number); - } - } + PtNDArray toType(DataType dataType, boolean copy); - /** {@inheritDoc} */ @Override - public PtNDArray gte(NDArray other) { - return JniUtils.gte(this, manager.from(other)); - } + void setRequiresGradient(boolean requiresGrad); - /** {@inheritDoc} */ @Override - public PtNDArray lt(Number n) { - try (NDArray number = manager.create(n)) { - return lt(number); - } - } + PtNDArray getGradient(); - /** {@inheritDoc} */ @Override - public PtNDArray lt(NDArray other) { - return JniUtils.lt(this, manager.from(other)); - } + boolean hasGradient(); - /** {@inheritDoc} */ @Override - public PtNDArray lte(Number n) { - try (NDArray number = manager.create(n)) { - return lte(number); - } - } + NDArray stopGradient(); - /** {@inheritDoc} */ @Override - public PtNDArray lte(NDArray other) { - return JniUtils.lte(this, manager.from(other)); - } + ByteBuffer toByteBuffer(); - /** {@inheritDoc} */ @Override - public PtNDArray add(Number n) { - try (NDArray number = manager.create(n)) { - return add(number); - } - } + String[] toStringArray(Charset charset); - /** {@inheritDoc} */ @Override - public PtNDArray add(NDArray other) { - return JniUtils.add(this, manager.from(other)); - } + void set(Buffer buffer); - /** {@inheritDoc} */ @Override - public PtNDArray sub(Number n) { - try (NDArray number = manager.create(n)) { - return sub(number); - } - } + NDArray get(NDManager manager, long... 
indices); - /** {@inheritDoc} */ @Override - public PtNDArray sub(NDArray other) { - return JniUtils.sub(this, manager.from(other)); - } + NDArray gather(NDArray index, int axis); - /** {@inheritDoc} */ @Override - public PtNDArray mul(Number n) { - try (NDArray number = manager.create(n)) { - return mul(number); - } - } + NDArray gatherNd(NDArray index); - /** {@inheritDoc} */ @Override - public PtNDArray mul(NDArray other) { - return JniUtils.mul(this, manager.from(other)); - } + NDArray take(NDManager manager, NDArray index); - /** {@inheritDoc} */ @Override - public PtNDArray div(Number n) { - try (NDArray number = manager.create(n)) { - return div(number); - } - } + NDArray put(NDArray index, NDArray data); - /** {@inheritDoc} */ @Override - public PtNDArray div(NDArray other) { - return JniUtils.div(this, manager.from(other)); - } + void copyTo(NDArray array); - /** {@inheritDoc} */ @Override - public PtNDArray mod(Number n) { - try (NDArray number = manager.create(n)) { - return mod(number); - } - } + void attach(NDManager manager); - /** {@inheritDoc} */ @Override - public PtNDArray mod(NDArray other) { - return JniUtils.remainder(this, manager.from(other)); - } + void returnResource(NDManager manager); - /** {@inheritDoc} */ @Override - public PtNDArray pow(Number n) { - try (NDArray number = manager.create(n)) { - return pow(number); - } - } + void tempAttach(NDManager manager); - /** {@inheritDoc} */ @Override - public PtNDArray pow(NDArray other) { - return JniUtils.pow(this, manager.from(other)); - } + void detach(); - /** {@inheritDoc} */ @Override - public PtNDArray addi(Number n) { - try (NDArray number = manager.create(n)) { - return addi(number); - } - } + NDArray duplicate(); - /** {@inheritDoc} */ @Override - public PtNDArray addi(NDArray other) { - JniUtils.addi(this, manager.from(other)); - return this; - } + PtNDArray booleanMask(NDArray index, int axis); - /** {@inheritDoc} */ @Override - public PtNDArray subi(Number n) { - try (NDArray number = manager.create(n)) { - return subi(number); - } - } + NDArray sequenceMask(NDArray sequenceLength, float value); - /** {@inheritDoc} */ @Override - public PtNDArray subi(NDArray other) { - JniUtils.subi(this, manager.from(other)); - return this; - } + NDArray sequenceMask(NDArray sequenceLength); - /** {@inheritDoc} */ @Override - public PtNDArray muli(Number n) { - try (NDArray number = manager.create(n)) { - return muli(number); - } - } + boolean contentEquals(Number number); - /** {@inheritDoc} */ @Override - public PtNDArray muli(NDArray other) { - JniUtils.muli(this, manager.from(other)); - return this; - } + boolean contentEquals(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray divi(Number n) { - try (NDArray number = manager.create(n)) { - return divi(number); - } - } + PtNDArray eq(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray divi(NDArray other) { - JniUtils.divi(this, manager.from(other)); - return this; - } + PtNDArray eq(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray modi(Number n) { - try (NDArray number = manager.create(n)) { - return modi(number); - } - } + PtNDArray neq(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray modi(NDArray other) { - JniUtils.remainderi(this, manager.from(other)); - return this; - } + PtNDArray neq(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray powi(Number n) { - try (NDArray number = manager.create(n)) { - return powi(number); - } - } + PtNDArray gt(Number n); - /** {@inheritDoc} */ @Override 
- public PtNDArray powi(NDArray other) { - JniUtils.powi(this, manager.from(other)); - return this; - } + PtNDArray gt(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray sign() { - return JniUtils.sign(this); - } + PtNDArray gte(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray signi() { - JniUtils.signi(this); - return this; - } + PtNDArray gte(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray maximum(Number n) { - try (NDArray number = manager.create(n)) { - return maximum(number); - } - } + PtNDArray lt(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray maximum(NDArray other) { - return JniUtils.max(this, manager.from(other)); - } + PtNDArray lt(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray minimum(Number n) { - try (NDArray number = manager.create(n)) { - return minimum(number); - } - } + PtNDArray lte(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray minimum(NDArray other) { - return JniUtils.min(this, manager.from(other)); - } + PtNDArray lte(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray all() { - try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { - return JniUtils.all(bool); - } - } + PtNDArray add(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray any() { - try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { - return JniUtils.any(bool); - } - } + PtNDArray add(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray none() { - try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { - return JniUtils.none(bool); - } - } + PtNDArray sub(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray neg() { - return JniUtils.neg(this); - } + PtNDArray sub(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray negi() { - JniUtils.negi(this); - return this; - } + PtNDArray mul(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray abs() { - return JniUtils.abs(this); - } + PtNDArray mul(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray square() { - return JniUtils.square(this); - } + PtNDArray div(Number n); - /** {@inheritDoc} */ @Override - public NDArray sqrt() { - return JniUtils.sqrt(this); - } + PtNDArray div(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray cbrt() { - return JniUtils.pow(this, (PtNDArray) manager.create(1.0 / 3)); - } + PtNDArray mod(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray floor() { - return JniUtils.floor(this); - } + PtNDArray mod(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray ceil() { - return JniUtils.ceil(this); - } + PtNDArray pow(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray round() { - return JniUtils.round(this); - } + PtNDArray pow(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray trunc() { - return JniUtils.trunc(this); - } + PtNDArray addi(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray exp() { - return JniUtils.exp(this); - } + PtNDArray addi(NDArray other); - /** {@inheritDoc} */ @Override - public NDArray gammaln() { - throw new UnsupportedOperationException("Not implemented yet."); - } + PtNDArray subi(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray log() { - return JniUtils.log(this); - } + PtNDArray subi(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray log10() { - return JniUtils.log10(this); - } + PtNDArray muli(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray log2() { - return JniUtils.log2(this); - } 
+ PtNDArray muli(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray sin() { - return JniUtils.sin(this); - } + PtNDArray divi(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray cos() { - return JniUtils.cos(this); - } + PtNDArray divi(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray tan() { - return JniUtils.tan(this); - } + PtNDArray modi(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray asin() { - return JniUtils.asin(this); - } + PtNDArray modi(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray acos() { - return JniUtils.acos(this); - } + PtNDArray powi(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray atan() { - return JniUtils.atan(this); - } + PtNDArray powi(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray sinh() { - return JniUtils.sinh(this); - } + PtNDArray sign(); - /** {@inheritDoc} */ @Override - public PtNDArray cosh() { - return JniUtils.cosh(this); - } + PtNDArray signi(); - /** {@inheritDoc} */ @Override - public PtNDArray tanh() { - return JniUtils.tanh(this); - } + PtNDArray maximum(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray asinh() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray maximum(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray acosh() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray minimum(Number n); - /** {@inheritDoc} */ @Override - public PtNDArray atanh() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray minimum(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray toDegrees() { - return mul(180.0).div(Math.PI); - } + PtNDArray all(); - /** {@inheritDoc} */ @Override - public PtNDArray toRadians() { - return mul(Math.PI).div(180.0); - } + PtNDArray any(); - /** {@inheritDoc} */ @Override - public PtNDArray max() { - return JniUtils.max(this); - } + PtNDArray none(); - /** {@inheritDoc} */ @Override - public PtNDArray max(int[] axes, boolean keepDims) { - if (axes.length > 1) { - // TODO fix this - throw new UnsupportedOperationException("Only 1 axis is support!"); - } - return JniUtils.max(this, axes[0], keepDims); - } + PtNDArray neg(); - /** {@inheritDoc} */ @Override - public PtNDArray min() { - return JniUtils.min(this); - } + PtNDArray negi(); - /** {@inheritDoc} */ @Override - public PtNDArray min(int[] axes, boolean keepDims) { - if (axes.length > 1) { - // TODO fix this - throw new UnsupportedOperationException("Only 1 axis is support!"); - } - return JniUtils.min(this, axes[0], keepDims); - } + PtNDArray abs(); - /** {@inheritDoc} */ @Override - public PtNDArray sum() { - return JniUtils.sum(this); - } + PtNDArray square(); - /** {@inheritDoc} */ @Override - public PtNDArray sum(int[] axes, boolean keepDims) { - return JniUtils.sum(this, Arrays.stream(axes).mapToLong(i -> i).toArray(), keepDims); - } + NDArray sqrt(); - /** {@inheritDoc} */ @Override - public NDArray cumProd(int axis) { - return JniUtils.cumProd(this, axis, null); - } + PtNDArray cbrt(); - /** {@inheritDoc} */ @Override - public NDArray cumProd(int axis, DataType dataType) { - return JniUtils.cumProd(this, axis, dataType); - } + PtNDArray floor(); - /** {@inheritDoc} */ @Override - public PtNDArray prod() { - return JniUtils.prod(this); - } + PtNDArray ceil(); - /** {@inheritDoc} */ @Override - public PtNDArray prod(int[] axes, boolean keepDims) { - if (axes.length > 1) { - throw new UnsupportedOperationException("Only 1 axis is 
support!"); - } - return JniUtils.prod(this, axes[0], keepDims); - } + PtNDArray round(); - /** {@inheritDoc} */ @Override - public PtNDArray mean() { - return JniUtils.mean(this); - } + PtNDArray trunc(); - /** {@inheritDoc} */ @Override - public PtNDArray mean(int[] axes, boolean keepDims) { - if (axes.length > 1) { - // TODO fix this - throw new UnsupportedOperationException("Only 1 axis is support!"); - } - return JniUtils.mean(this, axes[0], keepDims); - } + PtNDArray exp(); - /** {@inheritDoc} */ @Override - public PtNDArray normalize(double p, long dim, double eps) { - return JniUtils.normalize(this, p, dim, eps); - } + NDArray gammaln(); - /** {@inheritDoc} */ @Override - public PtNDArray rotate90(int times, int[] axes) { - if (axes.length != 2) { - throw new IllegalArgumentException("Axes must be 2"); - } - return JniUtils.rot90(this, times, axes); - } + PtNDArray log(); - /** {@inheritDoc} */ @Override - public PtNDArray trace(int offset, int axis1, int axis2) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray log10(); - /** {@inheritDoc} */ @Override - public NDList split(long sections, int axis) { - long size = getShape().get(axis) / sections; - return JniUtils.split(this, size, axis); - } + PtNDArray log2(); - /** {@inheritDoc} */ @Override - public NDList split(long[] indices, int axis) { - if (indices.length == 0) { - return new NDList(this); - } - List ptIndex = new ArrayList<>(); - ptIndex.add(indices[0]); - for (int i = 1; i < indices.length; i++) { - ptIndex.add(indices[i] - indices[i - 1]); - } - ptIndex.add(size(axis) - indices[indices.length - 1]); - return JniUtils.split(this, ptIndex.stream().mapToLong(i -> i).toArray(), axis); - } + PtNDArray sin(); - /** {@inheritDoc} */ @Override - public PtNDArray flatten() { - return JniUtils.flatten(this, 0, -1); - } + PtNDArray cos(); - /** {@inheritDoc} */ @Override - public NDArray flatten(int startDim, int endDim) { - return JniUtils.flatten(this, startDim, endDim); - } + PtNDArray tan(); - /** {@inheritDoc} */ @Override - public PtNDArray reshape(Shape shape) { - return JniUtils.reshape(this, shape.getShape()); - } + PtNDArray asin(); - /** {@inheritDoc} */ @Override - public PtNDArray expandDims(int axis) { - return JniUtils.unsqueeze(this, axis); - } + PtNDArray acos(); - /** {@inheritDoc} */ @Override - public PtNDArray squeeze() { - return JniUtils.squeeze(this); - } + PtNDArray atan(); - /** {@inheritDoc} */ @Override - public PtNDArray squeeze(int axis) { - return JniUtils.squeeze(this, axis); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray squeeze(int[] axes) { - if (isScalar()) { - if (axes.length > 1 || axes[0] != 0) { - throw new IllegalArgumentException( - "axis " + axes[0] + "is out of bounds for array of dimension 0"); - } - return (PtNDArray) duplicate(); - } - long[] shapeArr = getShape().getShape(); - List newShape = new ArrayList<>(); - Set set = - IntStream.of(axes).boxed().collect(Collectors.toCollection(HashSet::new)); - // check input - for (int axis : axes) { - if (shapeArr[axis] != 1) { - throw new IllegalArgumentException( - "cannot select an axis to squeeze out which has size not equal to one"); - } - } - for (int i = 0; i < shapeArr.length; i++) { - if (!set.contains(i)) { - newShape.add(shapeArr[i]); - } - } - return (PtNDArray) reshape(newShape.stream().mapToLong(i -> i).toArray()); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray logicalAnd(NDArray other) { - return JniUtils.logicalAnd(this, manager.from(other)); - } + PtNDArray sinh(); - /** 
{@inheritDoc} */ - @Override - public PtNDArray logicalOr(NDArray other) { - return JniUtils.logicalOr(this, manager.from(other)); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray logicalXor(NDArray other) { - return JniUtils.logicalXor(this, manager.from(other)); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray logicalNot() { - return JniUtils.logicalNot(this); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray argSort(int axis, boolean ascending) { - PtNDArray arr = JniUtils.argSort(this, axis, false); - if (ascending) { - return arr; - } - PtNDArray flip = JniUtils.flip(arr, new long[] {axis}); - arr.close(); - return flip; - } + @Override + PtNDArray cosh(); + + @Override + PtNDArray tanh(); + + @Override + PtNDArray asinh(); + + @Override + PtNDArray acosh(); + + @Override + PtNDArray atanh(); + + @Override + PtNDArray toDegrees(); + + @Override + PtNDArray toRadians(); + + @Override + PtNDArray max(); + + @Override + PtNDArray max(int[] axes, boolean keepDims); + + @Override + PtNDArray min(); + + @Override + PtNDArray min(int[] axes, boolean keepDims); + + @Override + PtNDArray sum(); + + @Override + PtNDArray sum(int[] axes, boolean keepDims); + + @Override + NDArray cumProd(int axis); + + @Override + NDArray cumProd(int axis, DataType dataType); + + @Override + PtNDArray prod(); + + @Override + PtNDArray prod(int[] axes, boolean keepDims); + + @Override + PtNDArray mean(); + + @Override + PtNDArray mean(int[] axes, boolean keepDims); + + @Override + PtNDArray normalize(double p, long dim, double eps); + + @Override + PtNDArray rotate90(int times, int[] axes); + + @Override + PtNDArray trace(int offset, int axis1, int axis2); + + @Override + NDList split(long sections, int axis); + + @Override + NDList split(long[] indices, int axis); + + @Override + PtNDArray flatten(); + + @Override + NDArray flatten(int startDim, int endDim); + + @Override + PtNDArray reshape(Shape shape); + + @Override + PtNDArray expandDims(int axis); + + @Override + PtNDArray squeeze(); + + @Override + PtNDArray squeeze(int axis); + + @Override + PtNDArray squeeze(int[] axes); + + @Override + PtNDArray logicalAnd(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray sort() { - return sort(-1); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray sort(int axis) { - return JniUtils.sort(this, axis, false); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray softmax(int axis) { - return JniUtils.softmax(this, axis, getDataType()); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray logSoftmax(int axis) { - return JniUtils.logSoftmax(this, axis, getDataType()); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray cumSum() { - // TODO: change default behavior on cumSum - if (isScalar()) { - return (PtNDArray) reshape(1); - } - if (isEmpty()) { - return (PtNDArray) reshape(0); - } - return cumSum(0); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray cumSum(int axis) { - return JniUtils.cumSum(this, axis); - } - - /** {@inheritDoc} */ - @Override - public void intern(NDArray replaced) { - PtNDArray arr = (PtNDArray) replaced; - Long oldHandle = handle.getAndSet(arr.handle.getAndSet(null)); - JniUtils.deleteNDArray(oldHandle); - // dereference old ndarray - arr.close(); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray isInfinite() { - return JniUtils.isInf(this); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray isNaN() { - return JniUtils.isNaN(this); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray 
tile(long repeats) { - // zero-dim - if (isEmpty()) { - return (PtNDArray) duplicate(); - } - // scalar - int dim = (isScalar()) ? 1 : getShape().dimension(); - long[] repeatsArray = new long[dim]; - Arrays.fill(repeatsArray, repeats); - return tile(repeatsArray); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray tile(int axis, long repeats) { - throw new UnsupportedOperationException("Not implemented"); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray tile(long[] repeats) { - return JniUtils.tile(this, repeats); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray tile(Shape desiredShape) { - throw new UnsupportedOperationException("Not implemented"); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray repeat(long repeats) { - // zero-dim - if (isEmpty()) { - return (PtNDArray) duplicate(); - } - // scalar - int dim = (isScalar()) ? 1 : getShape().dimension(); - long[] repeatsArray = new long[dim]; - Arrays.fill(repeatsArray, repeats); - return repeat(repeatsArray); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray repeat(int axis, long repeats) { - return JniUtils.repeat(this, repeats, axis); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray repeat(long[] repeats) { - PtNDArray result = this; - for (int dim = 0; dim < repeats.length; dim++) { - PtNDArray temp = result; - result = JniUtils.repeat(result, repeats[dim], dim); - if (temp != this) { - temp.close(); - } - } - return result; - } - - /** {@inheritDoc} */ - @Override - public PtNDArray repeat(Shape desiredShape) { - return repeat(repeatsToMatchShape(desiredShape)); - } - - private long[] repeatsToMatchShape(Shape desiredShape) { - Shape curShape = getShape(); - int dimension = curShape.dimension(); - if (desiredShape.dimension() > dimension) { - throw new IllegalArgumentException("The desired shape has too many dimensions"); - } - if (desiredShape.dimension() < dimension) { - int additionalDimensions = dimension - desiredShape.dimension(); - desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape); - } - long[] repeats = new long[dimension]; - for (int i = 0; i < dimension; i++) { - if (curShape.get(i) == 0 || desiredShape.get(i) % curShape.get(i) != 0) { - throw new IllegalArgumentException( - "The desired shape is not a multiple of the original shape"); - } - repeats[i] = Math.round(Math.ceil((double) desiredShape.get(i) / curShape.get(i))); - } - return repeats; - } - - /** {@inheritDoc} */ - @Override - public PtNDArray dot(NDArray other) { - int selfDim = this.getShape().dimension(); - int otherDim = other.getShape().dimension(); - if (selfDim != otherDim || selfDim > 2) { - throw new UnsupportedOperationException( - "Dimension mismatch or high dimensional dot operation is not supported. Please" - + " use .matMul instead."); - } - return JniUtils.dot(this, manager.from(other)); - } - - /** {@inheritDoc} */ - @Override - public NDArray matMul(NDArray other) { - if (isScalar() || other.isScalar()) { - throw new IllegalArgumentException("scalar is not allowed for matMul()"); - } - return JniUtils.matmul(this, manager.from(other)); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray clip(Number min, Number max) { - return JniUtils.clip(this, min, max); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray swapAxes(int axis1, int axis2) { - return JniUtils.transpose(this, axis1, axis2); - } - - /** {@inheritDoc} */ - @Override - public NDArray flip(int... 
axes) { - return JniUtils.flip(this, Arrays.stream(axes).mapToLong(ele -> (long) ele).toArray()); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray transpose() { - int dim = getShape().dimension(); - int[] reversedShape = IntStream.range(0, dim).map(i -> dim - i - 1).toArray(); - return transpose(reversedShape); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray transpose(int... axes) { - if (isScalar() && axes.length > 0) { - throw new IllegalArgumentException("axes don't match NDArray"); - } - return JniUtils.permute(this, Arrays.stream(axes).mapToLong(i -> i).toArray()); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray broadcast(Shape shape) { - return JniUtils.broadcast(this, shape); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray argMax() { - if (isEmpty()) { - throw new IllegalArgumentException("attempt to get argMax of an empty NDArray"); - } - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMax(this); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray argMax(int axis) { - // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMax(this, axis, false); - } - - /** {@inheritDoc} */ - @Override - public PtNDArray argMin() { - if (isEmpty()) { - throw new IllegalArgumentException("attempt to get argMin of an empty NDArray"); - } - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMin(this); - } + PtNDArray logicalOr(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray argMin(int axis) { - // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMin(this, axis, false); - } + PtNDArray logicalXor(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArray percentile(Number percentile) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray logicalNot(); - /** {@inheritDoc} */ @Override - public PtNDArray percentile(Number percentile, int[] axes) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray argSort(int axis, boolean ascending); - /** {@inheritDoc} */ @Override - public PtNDArray median() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray sort(); - /** {@inheritDoc} */ @Override - public PtNDArray median(int[] axes) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray sort(int axis); - /** {@inheritDoc} */ @Override - public PtNDArray toDense() { - if (!isSparse() && JniUtils.getLayout(this) != 2) { - return (PtNDArray) duplicate(); - } - return JniUtils.toDense(this); - } + PtNDArray softmax(int axis); - /** {@inheritDoc} */ @Override - public PtNDArray toSparse(SparseFormat fmt) { - if (fmt == SparseFormat.DENSE) { - throw new IllegalArgumentException("Default type is not allowed"); - } - if (fmt != SparseFormat.COO) { - throw new UnsupportedOperationException("Only COO sparse type supported for PyTorch"); - } - if (fmt == getSparseFormat()) { - return (PtNDArray) duplicate(); - } - return JniUtils.toSparse(this); - } + PtNDArray logSoftmax(int axis); - /** {@inheritDoc} */ @Override - public PtNDArray nonzero() { - return JniUtils.nonZeros(this); - } + PtNDArray cumSum(); - /** {@inheritDoc} */ @Override - public PtNDArray erfinv() { - return JniUtils.erfinv(this); - } - - /** {@inheritDoc} */ + PtNDArray cumSum(int axis); + 
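The `+` declarations above and below reduce PtNDArray to a pure interface with covariant return types; every removed body (the `-` lines) reappears unchanged in the new PtNDArrayImpl later in this patch. The motivation is visible further down as well: PtNDArrayProxyMaker hands PtNDArray.class to java.lang.reflect.Proxy, and the JDK proxy machinery can only implement interfaces, not concrete classes. A minimal sketch of that constraint, independent of this patch (the delegating handler and sample data are illustrative only):

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;

    import java.lang.reflect.Proxy;

    public final class InterfaceProxyDemo {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray real = manager.create(new float[] {1f, 2f});
                // Proxy.newProxyInstance accepts interface types only; passing a concrete class
                // (such as the former PtNDArray class) throws IllegalArgumentException.
                NDArray proxy = (NDArray) Proxy.newProxyInstance(
                        NDArray.class.getClassLoader(),
                        new Class<?>[] {NDArray.class},
                        (p, method, a) -> method.invoke(real, a));
                System.out.println(proxy.sum()); // delegated to the real array
            }
        }
    }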
+ @Override + void intern(NDArray replaced); + + @Override + PtNDArray isInfinite(); + + @Override + PtNDArray isNaN(); + + @Override + PtNDArray tile(long repeats); + + @Override + PtNDArray tile(int axis, long repeats); + + @Override + PtNDArray tile(long[] repeats); + + @Override + PtNDArray tile(Shape desiredShape); + + @Override + PtNDArray repeat(long repeats); + + @Override + PtNDArray repeat(int axis, long repeats); + + @Override + PtNDArray repeat(long[] repeats); + + @Override + PtNDArray repeat(Shape desiredShape); + + @Override + PtNDArray dot(NDArray other); + + @Override + NDArray matMul(NDArray other); + + @Override + PtNDArray clip(Number min, Number max); + + @Override + PtNDArray swapAxes(int axis1, int axis2); + + @Override + NDArray flip(int... axes); + + @Override + PtNDArray transpose(); + + @Override + PtNDArray transpose(int... axes); + + @Override + PtNDArray broadcast(Shape shape); + + @Override + PtNDArray argMax(); + + @Override + PtNDArray argMax(int axis); + + @Override + PtNDArray argMin(); + + @Override + PtNDArray argMin(int axis); + + @Override + PtNDArray percentile(Number percentile); + + @Override + PtNDArray percentile(Number percentile, int[] axes); + + @Override + PtNDArray median(); + + @Override + PtNDArray median(int[] axes); + + @Override + PtNDArray toDense(); + + @Override + PtNDArray toSparse(SparseFormat fmt); + + @Override + PtNDArray nonzero(); + + @Override + PtNDArray erfinv(); + @Override - public PtNDArray inverse() { - return JniUtils.inverse(this); - } + PtNDArray inverse(); - /** {@inheritDoc} */ @Override - public NDArray norm(boolean keepDims) { - return JniUtils.norm(this, 2, new int[] {}, keepDims); - } + NDArray norm(boolean keepDims); - /** {@inheritDoc} */ @Override - public NDArray norm(int order, int[] axes, boolean keepDims) { - return JniUtils.norm(this, order, axes, keepDims); - } + NDArray norm(int order, int[] axes, boolean keepDims); - /** {@inheritDoc} */ @Override - public NDArray oneHot(int depth) { - return JniUtils.oneHot(this, depth, DataType.FLOAT32); - } + NDArray oneHot(int depth); - /** {@inheritDoc} */ @Override - public NDArray oneHot(int depth, DataType dataType) { - return JniUtils.oneHot(this, depth, dataType); - } + NDArray oneHot(int depth, DataType dataType); - /** {@inheritDoc} */ @Override - public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) { - throw new UnsupportedOperationException("Not implemented"); - } + NDArray oneHot(int depth, float onValue, float offValue, DataType dataType); - /** {@inheritDoc} */ @Override - public NDArray batchDot(NDArray other) { - throw new UnsupportedOperationException("Not implemented"); - } + NDArray batchDot(NDArray other); - /** {@inheritDoc} */ @Override - public PtNDArrayEx getNDArrayInternal() { - return ptNDArrayEx; - } + PtNDArrayEx getNDArrayInternal(); - /** {@inheritDoc} */ @Override - public String toString() { - if (isReleased()) { - return "This array is already closed"; - } - // index operator in toDebugString is not supported for MKLDNN & Sparse layout - if (JniUtils.getLayout(this) != 0) { - try (NDArray tmp = toDense()) { - return tmp.toDebugString(); - } - } - return toDebugString(); - } + String toString(); - /** {@inheritDoc} */ @Override - public boolean equals(Object obj) { - if (obj instanceof PtNDArray) { - return contentEquals((PtNDArray) obj); - } - return false; - } + boolean equals(Object obj); - /** {@inheritDoc} */ @Override - public int hashCode() { - return 0; - } + int hashCode(); - /** {@inheritDoc} 
*/ @Override - public void close() { - Long pointer = handle.getAndSet(null); - if (pointer != null && pointer != -1) { - JniUtils.deleteNDArray(pointer); - } - manager.detachInternal(getUid()); - dataRef = null; - } + void close(); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index f5ae2cdbdd3..0fe4a80eb98 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -29,14 +29,14 @@ /** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. */ public class PtNDArrayEx implements NDArrayEx { - private PtNDArray array; + private PtNDArrayImpl array; /** * Constructs an {@code PtNDArrayEx} given a {@link NDArray}. * * @param parent the {@link NDArray} to extend */ - PtNDArrayEx(PtNDArray parent) { + PtNDArrayEx(PtNDArrayImpl parent) { this.array = parent; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java new file mode 100644 index 00000000000..0aca7f8fa55 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -0,0 +1,1561 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.engine; + +import ai.djl.Device; +import ai.djl.ndarray.BaseNDManager; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.ndarray.types.SparseFormat; +import ai.djl.pytorch.jni.JniUtils; +import ai.djl.util.NativeResourceImpl; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** {@code PtNDArray} is the PyTorch implementation of {@link NDArray}. */ +public class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray { + + private String name; + private Device device; + private DataType dataType; + private Shape shape; + private SparseFormat sparseFormat; + // use Boolean object to maintain three status: null, false, true + private Boolean hasGradient; + private PtNDManager manager; + private PtNDArrayEx ptNDArrayEx; + private String[] strs; + + // keep a reference to direct buffer to avoid GC release the memory + @SuppressWarnings("PMD.UnusedPrivateField") + private ByteBuffer dataRef; + + /** + * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} + * instead). 
+ * + * @param manager the manager to attach the new array to + * @param handle the pointer to the native PyTorch memory + */ + public PtNDArrayImpl(PtNDManager manager, long handle) { + super(handle); + this.manager = manager; + this.ptNDArrayEx = new PtNDArrayEx(this); + manager.attachInternal(getUid(), this); + } + + /** + * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} + * instead) with the data that is hold on Java side. + * + * @param manager the manager to attach the new array to + * @param handle the pointer to the native PyTorch memory + * @param data the direct buffer of the data + */ + public PtNDArrayImpl(PtNDManager manager, long handle, ByteBuffer data) { + super(handle); + this.manager = manager; + this.ptNDArrayEx = new PtNDArrayEx(this); + manager.attachInternal(getUid(), this); + dataRef = data; + } + + /** + * Constructs a PyTorch {@code NDArray} to hold string array with a dummy native handle + * (internal. Use {@link NDManager} instead) with the data that is hold on Java side. + * + * @param manager the manager to attach the new array to + * @param strs the string array + * @param shape the {@link Shape} of the {@link NDArray} + */ + public PtNDArrayImpl(PtNDManager manager, String[] strs, Shape shape) { + super(-1L); + this.manager = manager; + this.strs = strs; + this.shape = shape; + this.dataType = DataType.STRING; + } + + /** {@inheritDoc} */ + @Override + public PtNDManager getManager() { + return manager; + } + + /** {@inheritDoc} */ + @Override + public String getName() { + return name; + } + + /** {@inheritDoc} */ + @Override + public void setName(String name) { + this.name = name; + } + + /** {@inheritDoc} */ + @Override + public DataType getDataType() { + if (dataType == null) { + dataType = JniUtils.getDataType(this); + } + return dataType; + } + + /** {@inheritDoc} */ + @Override + public Device getDevice() { + if (device == null) { + device = JniUtils.getDevice(this); + } + return device; + } + + /** {@inheritDoc} */ + @Override + public Shape getShape() { + if (shape == null) { + shape = JniUtils.getShape(this); + } + return shape; + } + + /** {@inheritDoc} */ + @Override + public SparseFormat getSparseFormat() { + if (sparseFormat == null) { + sparseFormat = JniUtils.getSparseFormat(this); + } + return sparseFormat; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toDevice(Device device, boolean copy) { + if (device.equals(getDevice()) && !copy) { + return this; + } + return JniUtils.to(this, getDataType(), device); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toType(DataType dataType, boolean copy) { + if (dataType.equals(getDataType()) && !copy) { + return this; + } + return JniUtils.to(this, dataType, getDevice()); + } + + /** {@inheritDoc} */ + @Override + public void setRequiresGradient(boolean requiresGrad) { + JniUtils.attachGradient(this, requiresGrad); + hasGradient = requiresGrad; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray getGradient() { + if (!hasGradient()) { + throw new IllegalStateException( + "No gradient attached to this NDArray, please call array.setRequiresGradient()" + + " on your NDArray or block.setInitializer() on your Block"); + } + PtNDArray res = JniUtils.getGradient(this); + // If you call getGradient() before you run the backward, + // you will get nothing in PyTorch engine. + // To align with MXNet's behavior, we will create a zeros NDArray. + // TODO should we access the grad NDArray after we close the parameter NDArray? 
+ if (res == null) { + res = (PtNDArray) manager.zeros(getShape()); + } + return res; + } + + /** {@inheritDoc} */ + @Override + public boolean hasGradient() { + if (hasGradient == null) { + hasGradient = JniUtils.requiresGrad(this); + } + return hasGradient; + } + + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + return JniUtils.detachGradient(this); + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + return JniUtils.getByteBuffer(this); + } + + /** {@inheritDoc} */ + @Override + public String[] toStringArray(Charset charset) { + return strs; + } + + /** {@inheritDoc} */ + @Override + public void set(Buffer buffer) { + int size = Math.toIntExact(size()); + DataType type = getDataType(); + BaseNDManager.validateBuffer(buffer, type, size); + // TODO how do we handle the exception happened in the middle + dataRef = null; + if (buffer.isDirect() && buffer instanceof ByteBuffer) { + // If NDArray is on the GPU, it is native code responsibility to control the data life + // cycle + if (!getDevice().isGpu()) { + dataRef = (ByteBuffer) buffer; + } + JniUtils.set(this, (ByteBuffer) buffer); + return; + } + // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType + ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes()); + BaseNDManager.copyBuffer(buffer, buf); + + // If NDArray is on the GPU, it is native code responsibility to control the data life cycle + if (!getDevice().isGpu()) { + dataRef = buf; + } + JniUtils.set(this, buf); + } + + /** {@inheritDoc} */ + @Override + public NDArray get(NDManager manager, long... indices) { + return JniUtils.getItem(this, indices, (PtNDManager) manager); + } + + /** {@inheritDoc} */ + @Override + public NDArray gather(NDArray index, int axis) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray index is supported."); + } + return JniUtils.gather(this, (PtNDArray) index, axis); + } + + /** {@inheritDoc} */ + @Override + public NDArray gatherNd(NDArray index) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray index is supported."); + } + Shape indexShape = index.getShape(); + Shape dataShape = getShape(); + int indexingDepth = (int) indexShape.get(0); + if (indexingDepth > dataShape.dimension()) { + throw new IllegalArgumentException( + "Indexing rank " + + indexShape.get(0) + + " exceeds the data rank " + + dataShape.dimension()); + } + // Row-first order, the linear index is accumulated from z->y->x. + // For example, dataShape = (3, 2, 3), indexShape = (2, 3, 3) + // The method is: indexLinear = index[1] + index[0] * dataShape[1], row-first order + // indexLinear has shape (3, 3), is from combining the index along 0 axis. + // Each number in indexLinear is an indexing to an element in data (3, 2, ...). + // data is flattened to be (3*2, ...) which can be indexed by indexLinear. + // Finally, reshape the output to (3, 3, ...). 
Thus + // totalShape = indexShape.slice(1).addAll(dataShape.slice(indexingDepth)); + NDArray indexLinear = index.get("{}, ...", indexingDepth - 1); + long dim = 1; + for (int i = indexingDepth - 2; i > -1; i--) { + dim = dim * dataShape.get(i + 1); + indexLinear = indexLinear.addi(index.get("{}, ...", i).muli(dim)); + } + NDArray dataFlatten = this.flatten(0, indexingDepth - 1); + return dataFlatten.get(indexLinear); + } + + /** {@inheritDoc} */ + @Override + public NDArray take(NDManager manager, NDArray index) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.take(this, (PtNDArray) index, (PtNDManager) manager); + } + + /** {@inheritDoc} */ + @Override + public NDArray put(NDArray index, NDArray data) { + if (!(index instanceof PtNDArray) || !(data instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.put(this, (PtNDArray) index, (PtNDArray) data); + } + + /** {@inheritDoc} */ + @Override + public void copyTo(NDArray array) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public void attach(NDManager manager) { + detach(); + this.manager = (PtNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void returnResource(NDManager manager) { + detach(); + this.manager = (PtNDManager) manager; + manager.attachUncappedInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { + NDManager original = this.manager; + detach(); + this.manager = (PtNDManager) manager; + manager.tempAttachInternal(original, getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void detach() { + manager.detachInternal(getUid()); + manager = PtNDManager.getSystemManager(); + } + + /** {@inheritDoc} */ + @Override + public NDArray duplicate() { + return JniUtils.clone(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray booleanMask(NDArray index, int axis) { + Shape indexShape = index.getShape(); + if (indexShape.equals(getShape())) { + // Result is flattened since shape is undetermined + return JniUtils.booleanMask(this, manager.from(index)); + } else if (indexShape.equals(getShape().slice(axis))) { + // index will be broadcast by default + try (PtNDArray flattedResult = JniUtils.booleanMask(this, manager.from(index))) { + // Shape recovery + Shape remainder = getShape().slice(0, axis); + long selectedSize = flattedResult.getShape().size() / remainder.size(); + return flattedResult.reshape(remainder.addAll(new Shape(selectedSize))); + } + } else { + throw new UnsupportedOperationException( + "Not supported for shape not broadcastable " + + indexShape + + " vs " + + getShape()); + } + } + + /** {@inheritDoc} */ + @Override + public NDArray sequenceMask(NDArray sequenceLength, float value) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + /** {@inheritDoc} */ + @Override + public NDArray sequenceMask(NDArray sequenceLength) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + /** {@inheritDoc} */ + @Override + public boolean contentEquals(Number number) { + return JniUtils.contentEqual(this, (PtNDArray) manager.create(number)); + } + + /** {@inheritDoc} */ + @Override + public boolean contentEquals(NDArray other) { + if (other == null || (!shapeEquals(other))) { + return false; + } + if (getDataType() != other.getDataType()) { + return 
false; + } + return JniUtils.contentEqual(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray eq(Number n) { + try (NDArray number = manager.create(n)) { + return eq(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray eq(NDArray other) { + return JniUtils.eq(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray neq(Number n) { + try (NDArray number = manager.create(n)) { + return neq(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray neq(NDArray other) { + return JniUtils.neq(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gt(Number n) { + try (NDArray number = manager.create(n)) { + return gt(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gt(NDArray other) { + return JniUtils.gt(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gte(Number n) { + try (NDArray number = manager.create(n)) { + return gte(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gte(NDArray other) { + return JniUtils.gte(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lt(Number n) { + try (NDArray number = manager.create(n)) { + return lt(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lt(NDArray other) { + return JniUtils.lt(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lte(Number n) { + try (NDArray number = manager.create(n)) { + return lte(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lte(NDArray other) { + return JniUtils.lte(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray add(Number n) { + try (NDArray number = manager.create(n)) { + return add(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray add(NDArray other) { + return JniUtils.add(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sub(Number n) { + try (NDArray number = manager.create(n)) { + return sub(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sub(NDArray other) { + return JniUtils.sub(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mul(Number n) { + try (NDArray number = manager.create(n)) { + return mul(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mul(NDArray other) { + return JniUtils.mul(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray div(Number n) { + try (NDArray number = manager.create(n)) { + return div(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray div(NDArray other) { + return JniUtils.div(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mod(Number n) { + try (NDArray number = manager.create(n)) { + return mod(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mod(NDArray other) { + return JniUtils.remainder(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray pow(Number n) { + try (NDArray number = manager.create(n)) { + return pow(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray pow(NDArray other) { + return JniUtils.pow(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray addi(Number n) { + try (NDArray number = manager.create(n)) { + return addi(number); + } + } + 
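Every scalar overload in PtNDArrayImpl (add(Number), gte(Number), addi(Number), and so on) follows the same pattern: the Number is materialized as a temporary NDArray, the call is delegated to the NDArray overload, and try-with-resources closes the temporary so its native tensor is freed immediately instead of lingering until the manager closes. A short caller-side sketch of the same idiom; the sample values and the default-engine manager are illustrative, not part of this patch:

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;

    public final class ScalarOverloadDemo {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray a = manager.create(new float[] {1f, 2f, 3f});
                // a.add(2f) internally wraps 2f in a scalar NDArray, delegates to add(NDArray),
                // and closes that temporary before returning the result.
                NDArray b = a.add(2f);
                System.out.println(b);
            } // closing the manager releases a, b and any other attached native tensors
        }
    }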
+ /** {@inheritDoc} */ + @Override + public PtNDArray addi(NDArray other) { + JniUtils.addi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray subi(Number n) { + try (NDArray number = manager.create(n)) { + return subi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray subi(NDArray other) { + JniUtils.subi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray muli(Number n) { + try (NDArray number = manager.create(n)) { + return muli(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray muli(NDArray other) { + JniUtils.muli(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray divi(Number n) { + try (NDArray number = manager.create(n)) { + return divi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray divi(NDArray other) { + JniUtils.divi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray modi(Number n) { + try (NDArray number = manager.create(n)) { + return modi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray modi(NDArray other) { + JniUtils.remainderi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray powi(Number n) { + try (NDArray number = manager.create(n)) { + return powi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray powi(NDArray other) { + JniUtils.powi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sign() { + return JniUtils.sign(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray signi() { + JniUtils.signi(this); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray maximum(Number n) { + try (NDArray number = manager.create(n)) { + return maximum(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray maximum(NDArray other) { + return JniUtils.max(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray minimum(Number n) { + try (NDArray number = manager.create(n)) { + return minimum(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray minimum(NDArray other) { + return JniUtils.min(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray all() { + try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { + return JniUtils.all(bool); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray any() { + try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { + return JniUtils.any(bool); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray none() { + try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { + return JniUtils.none(bool); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray neg() { + return JniUtils.neg(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray negi() { + JniUtils.negi(this); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray abs() { + return JniUtils.abs(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray square() { + return JniUtils.square(this); + } + + /** {@inheritDoc} */ + @Override + public NDArray sqrt() { + return JniUtils.sqrt(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cbrt() { + return JniUtils.pow(this, (PtNDArray) manager.create(1.0 / 3)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray 
floor() { + return JniUtils.floor(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray ceil() { + return JniUtils.ceil(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray round() { + return JniUtils.round(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray trunc() { + return JniUtils.trunc(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray exp() { + return JniUtils.exp(this); + } + + /** {@inheritDoc} */ + @Override + public NDArray gammaln() { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray log() { + return JniUtils.log(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray log10() { + return JniUtils.log10(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray log2() { + return JniUtils.log2(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sin() { + return JniUtils.sin(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cos() { + return JniUtils.cos(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tan() { + return JniUtils.tan(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray asin() { + return JniUtils.asin(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray acos() { + return JniUtils.acos(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray atan() { + return JniUtils.atan(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sinh() { + return JniUtils.sinh(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cosh() { + return JniUtils.cosh(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tanh() { + return JniUtils.tanh(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray asinh() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray acosh() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray atanh() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toDegrees() { + return mul(180.0).div(Math.PI); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toRadians() { + return mul(Math.PI).div(180.0); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray max() { + return JniUtils.max(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray max(int[] axes, boolean keepDims) { + if (axes.length > 1) { + // TODO fix this + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.max(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray min() { + return JniUtils.min(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray min(int[] axes, boolean keepDims) { + if (axes.length > 1) { + // TODO fix this + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.min(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sum() { + return JniUtils.sum(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sum(int[] axes, boolean keepDims) { + return JniUtils.sum(this, Arrays.stream(axes).mapToLong(i -> i).toArray(), keepDims); + } + + /** {@inheritDoc} */ + @Override + public NDArray cumProd(int axis) { + return JniUtils.cumProd(this, axis, null); + } + + /** {@inheritDoc} */ + @Override + public NDArray 
cumProd(int axis, DataType dataType) { + return JniUtils.cumProd(this, axis, dataType); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray prod() { + return JniUtils.prod(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray prod(int[] axes, boolean keepDims) { + if (axes.length > 1) { + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.prod(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mean() { + return JniUtils.mean(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mean(int[] axes, boolean keepDims) { + if (axes.length > 1) { + // TODO fix this + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.mean(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray normalize(double p, long dim, double eps) { + return JniUtils.normalize(this, p, dim, eps); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray rotate90(int times, int[] axes) { + if (axes.length != 2) { + throw new IllegalArgumentException("Axes must be 2"); + } + return JniUtils.rot90(this, times, axes); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray trace(int offset, int axis1, int axis2) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public NDList split(long sections, int axis) { + long size = getShape().get(axis) / sections; + return JniUtils.split(this, size, axis); + } + + /** {@inheritDoc} */ + @Override + public NDList split(long[] indices, int axis) { + if (indices.length == 0) { + return new NDList(this); + } + List ptIndex = new ArrayList<>(); + ptIndex.add(indices[0]); + for (int i = 1; i < indices.length; i++) { + ptIndex.add(indices[i] - indices[i - 1]); + } + ptIndex.add(size(axis) - indices[indices.length - 1]); + return JniUtils.split(this, ptIndex.stream().mapToLong(i -> i).toArray(), axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray flatten() { + return JniUtils.flatten(this, 0, -1); + } + + /** {@inheritDoc} */ + @Override + public NDArray flatten(int startDim, int endDim) { + return JniUtils.flatten(this, startDim, endDim); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray reshape(Shape shape) { + return JniUtils.reshape(this, shape.getShape()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray expandDims(int axis) { + return JniUtils.unsqueeze(this, axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray squeeze() { + return JniUtils.squeeze(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray squeeze(int axis) { + return JniUtils.squeeze(this, axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray squeeze(int[] axes) { + if (isScalar()) { + if (axes.length > 1 || axes[0] != 0) { + throw new IllegalArgumentException( + "axis " + axes[0] + "is out of bounds for array of dimension 0"); + } + return (PtNDArray) duplicate(); + } + long[] shapeArr = getShape().getShape(); + List newShape = new ArrayList<>(); + Set set = + IntStream.of(axes).boxed().collect(Collectors.toCollection(HashSet::new)); + // check input + for (int axis : axes) { + if (shapeArr[axis] != 1) { + throw new IllegalArgumentException( + "cannot select an axis to squeeze out which has size not equal to one"); + } + } + for (int i = 0; i < shapeArr.length; i++) { + if (!set.contains(i)) { + newShape.add(shapeArr[i]); + } + } + return (PtNDArray) reshape(newShape.stream().mapToLong(i -> i).toArray()); + } + + /** 
{@inheritDoc} */ + @Override + public PtNDArray logicalAnd(NDArray other) { + return JniUtils.logicalAnd(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logicalOr(NDArray other) { + return JniUtils.logicalOr(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logicalXor(NDArray other) { + return JniUtils.logicalXor(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logicalNot() { + return JniUtils.logicalNot(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argSort(int axis, boolean ascending) { + PtNDArray arr = JniUtils.argSort(this, axis, false); + if (ascending) { + return arr; + } + PtNDArray flip = JniUtils.flip(arr, new long[] {axis}); + arr.close(); + return flip; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sort() { + return sort(-1); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sort(int axis) { + return JniUtils.sort(this, axis, false); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray softmax(int axis) { + return JniUtils.softmax(this, axis, getDataType()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logSoftmax(int axis) { + return JniUtils.logSoftmax(this, axis, getDataType()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cumSum() { + // TODO: change default behavior on cumSum + if (isScalar()) { + return (PtNDArray) reshape(1); + } + if (isEmpty()) { + return (PtNDArray) reshape(0); + } + return cumSum(0); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cumSum(int axis) { + return JniUtils.cumSum(this, axis); + } + + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + PtNDArrayImpl arr = (PtNDArrayImpl) replaced; + Long oldHandle = handle.getAndSet(arr.handle.getAndSet(null)); + JniUtils.deleteNDArray(oldHandle); + // dereference old ndarray + arr.close(); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray isInfinite() { + return JniUtils.isInf(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray isNaN() { + return JniUtils.isNaN(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(long repeats) { + // zero-dim + if (isEmpty()) { + return (PtNDArray) duplicate(); + } + // scalar + int dim = (isScalar()) ? 1 : getShape().dimension(); + long[] repeatsArray = new long[dim]; + Arrays.fill(repeatsArray, repeats); + return tile(repeatsArray); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(int axis, long repeats) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(long[] repeats) { + return JniUtils.tile(this, repeats); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(Shape desiredShape) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(long repeats) { + // zero-dim + if (isEmpty()) { + return (PtNDArray) duplicate(); + } + // scalar + int dim = (isScalar()) ? 
1 : getShape().dimension(); + long[] repeatsArray = new long[dim]; + Arrays.fill(repeatsArray, repeats); + return repeat(repeatsArray); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(int axis, long repeats) { + return JniUtils.repeat(this, repeats, axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(long[] repeats) { + PtNDArray result = this; + for (int dim = 0; dim < repeats.length; dim++) { + PtNDArray temp = result; + result = JniUtils.repeat(result, repeats[dim], dim); + if (temp != this) { + temp.close(); + } + } + return result; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(Shape desiredShape) { + return repeat(repeatsToMatchShape(desiredShape)); + } + + private long[] repeatsToMatchShape(Shape desiredShape) { + Shape curShape = getShape(); + int dimension = curShape.dimension(); + if (desiredShape.dimension() > dimension) { + throw new IllegalArgumentException("The desired shape has too many dimensions"); + } + if (desiredShape.dimension() < dimension) { + int additionalDimensions = dimension - desiredShape.dimension(); + desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape); + } + long[] repeats = new long[dimension]; + for (int i = 0; i < dimension; i++) { + if (curShape.get(i) == 0 || desiredShape.get(i) % curShape.get(i) != 0) { + throw new IllegalArgumentException( + "The desired shape is not a multiple of the original shape"); + } + repeats[i] = Math.round(Math.ceil((double) desiredShape.get(i) / curShape.get(i))); + } + return repeats; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray dot(NDArray other) { + int selfDim = this.getShape().dimension(); + int otherDim = other.getShape().dimension(); + if (selfDim != otherDim || selfDim > 2) { + throw new UnsupportedOperationException( + "Dimension mismatch or high dimensional dot operation is not supported. Please" + + " use .matMul instead."); + } + return JniUtils.dot(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public NDArray matMul(NDArray other) { + if (isScalar() || other.isScalar()) { + throw new IllegalArgumentException("scalar is not allowed for matMul()"); + } + return JniUtils.matmul(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray clip(Number min, Number max) { + return JniUtils.clip(this, min, max); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray swapAxes(int axis1, int axis2) { + return JniUtils.transpose(this, axis1, axis2); + } + + /** {@inheritDoc} */ + @Override + public NDArray flip(int... axes) { + return JniUtils.flip(this, Arrays.stream(axes).mapToLong(ele -> (long) ele).toArray()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray transpose() { + int dim = getShape().dimension(); + int[] reversedShape = IntStream.range(0, dim).map(i -> dim - i - 1).toArray(); + return transpose(reversedShape); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray transpose(int... 
axes) { + if (isScalar() && axes.length > 0) { + throw new IllegalArgumentException("axes don't match NDArray"); + } + return JniUtils.permute(this, Arrays.stream(axes).mapToLong(i -> i).toArray()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray broadcast(Shape shape) { + return JniUtils.broadcast(this, shape); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMax() { + if (isEmpty()) { + throw new IllegalArgumentException("attempt to get argMax of an empty NDArray"); + } + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMax(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMax(int axis) { + // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMax(this, axis, false); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMin() { + if (isEmpty()) { + throw new IllegalArgumentException("attempt to get argMin of an empty NDArray"); + } + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMin(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMin(int axis) { + // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMin(this, axis, false); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray percentile(Number percentile) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray percentile(Number percentile, int[] axes) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray median() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray median(int[] axes) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toDense() { + if (!isSparse() && JniUtils.getLayout(this) != 2) { + return (PtNDArray) duplicate(); + } + return JniUtils.toDense(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toSparse(SparseFormat fmt) { + if (fmt == SparseFormat.DENSE) { + throw new IllegalArgumentException("Default type is not allowed"); + } + if (fmt != SparseFormat.COO) { + throw new UnsupportedOperationException("Only COO sparse type supported for PyTorch"); + } + if (fmt == getSparseFormat()) { + return (PtNDArray) duplicate(); + } + return JniUtils.toSparse(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray nonzero() { + return JniUtils.nonZeros(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray erfinv() { + return JniUtils.erfinv(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray inverse() { + return JniUtils.inverse(this); + } + + /** {@inheritDoc} */ + @Override + public NDArray norm(boolean keepDims) { + return JniUtils.norm(this, 2, new int[] {}, keepDims); + } + + /** {@inheritDoc} */ + @Override + public NDArray norm(int order, int[] axes, boolean keepDims) { + return JniUtils.norm(this, order, axes, keepDims); + } + + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth) { + return JniUtils.oneHot(this, depth, DataType.FLOAT32); + } + + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth, DataType dataType) { + return JniUtils.oneHot(this, depth, dataType); + } + + /** {@inheritDoc} */ 
+ @Override + public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public NDArray batchDot(NDArray other) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArrayEx getNDArrayInternal() { + return ptNDArrayEx; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + if (isReleased()) { + return "This array is already closed"; + } + // index operator in toDebugString is not supported for MKLDNN & Sparse layout + if (JniUtils.getLayout(this) != 0) { + try (NDArray tmp = toDense()) { + return tmp.toDebugString(); + } + } + return toDebugString(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object obj) { + if (obj instanceof PtNDArray) { + return contentEquals((PtNDArray) obj); + } + return false; + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return 0; + } + + /** {@inheritDoc} */ + @Override + public void close() { + Long pointer = handle.getAndSet(null); + if (pointer != null && pointer != -1) { + JniUtils.deleteNDArray(pointer); + } + manager.detachInternal(getUid()); + dataRef = null; + } +} diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index b38a038d142..c64c9dda170 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -24,7 +24,7 @@ import java.util.Stack; -/** The {@link NDArrayIndexer} used by the {@link PtNDArray}. */ +/** The {@link NDArrayIndexer} used by the {@link PtNDArrayImpl}. 
*/ public class PtNDArrayIndexer extends NDArrayIndexer { private PtNDManager manager; @@ -70,8 +70,8 @@ public NDArray get(NDArray array, NDIndex index) { index.addAllDim(); } - if (array == null || array instanceof PtNDArray) { - return JniUtils.indexAdv((PtNDArray) array, index, manager); + if (array == null || array instanceof PtNDArrayImpl) { + return JniUtils.indexAdv((PtNDArrayImpl) array, index, manager); } else { PtNDArray arrayNew = manager.create(array.toByteBuffer(), array.getShape(), array.getDataType()); @@ -89,9 +89,9 @@ public void set(NDArray array, NDIndex index, Object data) { array.toByteBuffer(), array.getShape(), array.getDataType()); if (data instanceof Number) { - JniUtils.indexAdvPut(ptArray, index, (PtNDArray) manager.create((Number) data)); + JniUtils.indexAdvPut(ptArray, index, (PtNDArrayImpl) manager.create((Number) data)); } else if (data instanceof NDArray) { - JniUtils.indexAdvPut(ptArray, index, (PtNDArray) data); + JniUtils.indexAdvPut(ptArray, index, (PtNDArrayImpl) data); } else { throw new IllegalArgumentException( "The type of value to assign cannot be other than NDArray and Number."); diff --git a/api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java similarity index 73% rename from api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java rename to engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java index df8131a5d3d..24426ffa6f0 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/NDArrayWrapFactory.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java @@ -10,15 +10,18 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.ndarray.gc; +package ai.djl.pytorch.engine; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.gc.DynamicInvocationHandler; +import ai.djl.ndarray.gc.NDArrayProxyMaker; +import ai.djl.ndarray.gc.WeakHashMapWrapper; import java.lang.reflect.Proxy; import java.util.UUID; -/** {@code NDArrayWrapFactory} creates a proxy facade. */ -public class NDArrayWrapFactory { +/** {@code PtNDArrayProxyMaker} creates a proxy facade. */ +public class PtNDArrayProxyMaker implements NDArrayProxyMaker { WeakHashMapWrapper map = new WeakHashMapWrapper<>(); @@ -32,20 +35,19 @@ public int mapSize() { } /** - * Wraps the {@link NDArray} in a proxy facade. + * Wraps the {@link PtNDArray} in a proxy facade. 
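The wrap method here passes new Class[] {PtNDArray.class} to Proxy.newProxyInstance, and java.lang.reflect.Proxy can only generate proxies for interface types; that is why the former concrete array class is split into a PtNDArray interface and a PtNDArrayImpl implementation elsewhere in this patch series. A minimal, self-contained sketch of the same mechanism with hypothetical Greeter types (not DJL classes), where the handler forwards every interface call to the underlying object:

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

public final class ProxySketch {

    interface Greeter {
        String greet(String name);
    }

    static final class GreeterImpl implements Greeter {
        @Override
        public String greet(String name) {
            return "Hello, " + name;
        }
    }

    public static void main(String[] args) {
        Greeter target = new GreeterImpl();
        // Forward every call to the real object; the DJL handler instead looks the
        // target up in a weak map keyed by a UUID so it holds no strong reference.
        InvocationHandler handler =
                (proxy, method, methodArgs) -> method.invoke(target, methodArgs);
        Greeter proxied =
                (Greeter)
                        Proxy.newProxyInstance(
                                Greeter.class.getClassLoader(),
                                new Class<?>[] {Greeter.class},
                                handler);
        System.out.println(proxied.greet("DJL")); // prints: Hello, DJL
    }
}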
* * @param array the array to wrap * @return the wrapped array */ - public NDArray wrap(NDArray array) { + public PtNDArray wrap(NDArray array) { UUID uuid = UUID.randomUUID(); map.put(uuid, array); - DynamicInvocationHandler handler = new DynamicInvocationHandler(uuid, map, this); - return (NDArray) + return (PtNDArray) Proxy.newProxyInstance( Thread.currentThread().getContextClassLoader(), - new Class[] {NDArray.class}, + new Class[] {PtNDArray.class}, handler); } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index e282c27b7f1..24bb7b972be 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -81,7 +81,7 @@ public PtNDArray create(Buffer data, Shape shape, DataType dataType) { /** {@inheritDoc} */ @Override public NDArray create(String[] data, Charset charset, Shape shape) { - return new PtNDArray(this, data, shape); + return new PtNDArrayImpl(this, data, shape); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java index b1c34688cdb..c1804735fdc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java @@ -16,8 +16,9 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import java.util.Arrays; import java.util.Map; @@ -28,7 +29,7 @@ * *

DJL doesn't support creating nested IValue. */ -public class IValue extends NativeResource { +public class IValue extends NativeResourceImpl { IValue(long handle) { super(handle); @@ -392,7 +393,7 @@ public double[] toDoubleArray() { * @return the NDArray value of this IValue */ public PtNDArray toTensor(PtNDManager manager) { - return new PtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); + return new PtNDArrayImpl(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); } /** @@ -403,9 +404,9 @@ public PtNDArray toTensor(PtNDManager manager) { */ public PtNDArray[] toTensorArray(PtNDManager manager) { long[] handles = PyTorchLibrary.LIB.iValueToTensorList(getHandle()); - PtNDArray[] ret = new PtNDArray[handles.length]; + PtNDArray[] ret = new PtNDArrayImpl[handles.length]; for (int i = 0; i < ret.length; ++i) { - ret[i] = new PtNDArray(manager, handles[i]); + ret[i] = new PtNDArrayImpl(manager, handles[i]); } return ret; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 47fb0b724df..275eb830848 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -30,6 +30,7 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.engine.PtDeviceType; import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; import ai.djl.pytorch.engine.PtSymbolBlock; @@ -170,15 +171,15 @@ public static PtNDArray createNdFromByteBuffer( if (layout == 1 || layout == 2 || device.isGpu()) { // MKLDNN & COO & GPU device will explicitly make a copy in native code // so we don't want to hold a reference on Java side - return new PtNDArray(manager, handle); + return new PtNDArrayImpl(manager, handle); } - return new PtNDArray(manager, handle, data); + return new PtNDArrayImpl(manager, handle, data); } public static PtNDArray createEmptyNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchEmpty( shape.getShape(), @@ -191,7 +192,7 @@ public static PtNDArray createEmptyNdArray( public static PtNDArray createZerosNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchZeros( shape.getShape(), @@ -204,7 +205,7 @@ public static PtNDArray createZerosNdArray( public static PtNDArray createOnesNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchOnes( shape.getShape(), @@ -222,7 +223,7 @@ public static PtNDArray full( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchFull( shape.getShape(), @@ -234,9 +235,9 @@ public static PtNDArray full( } public static PtNDArray zerosLike( - PtNDArray array, DataType dType, Device device, SparseFormat fmt) { + PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new 
PtNDArrayImpl( array.getManager(), PyTorchLibrary.LIB.torchZerosLike( array.getHandle(), @@ -247,9 +248,9 @@ public static PtNDArray zerosLike( } public static PtNDArray onesLike( - PtNDArray array, DataType dType, Device device, SparseFormat fmt) { + PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new PtNDArrayImpl( array.getManager(), PyTorchLibrary.LIB.torchOnesLike( array.getHandle(), @@ -268,7 +269,7 @@ public static PtNDArray arange( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchArange( start, @@ -289,7 +290,7 @@ public static PtNDArray linspace( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchLinspace( start, @@ -302,7 +303,7 @@ public static PtNDArray linspace( } public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, Shape shape) { - return new PtNDArray( + return new PtNDArrayImpl( values.getManager(), PyTorchLibrary.LIB.torchSparseCoo( shape.getShape(), indices.getHandle(), values.getHandle(), false)); @@ -315,7 +316,7 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) if (!device.equals(manager.getDevice())) { manager = manager.newSubManager(device); } - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchTo( ndArray.getHandle(), @@ -324,23 +325,23 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) } public static PtNDArray toSparse(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchToSparse(ndArray.getHandle())); } public static PtNDArray toDense(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchToDense(ndArray.getHandle())); } public static PtNDArray broadcast(PtNDArray ndArray, Shape shape) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchExpand(ndArray.getHandle(), shape.getShape())); } public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSlice(ndArray.getHandle(), dim, start, stop, step)); } @@ -351,7 +352,7 @@ public static PtNDArray index( long[] maxIndices, long[] stepIndices, PtNDManager manager) { - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchIndex( ndArray.getHandle(), minIndices, maxIndices, stepIndices)); @@ -412,7 +413,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle)); } @@ -498,7 +499,7 @@ public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } @@ -507,7 +508,7 @@ public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager man if 
(index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); } @@ -515,7 +516,7 @@ public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray data) if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchPut( ndArray.getHandle(), index.getHandle(), data.getHandle())); @@ -547,20 +548,20 @@ public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } public static PtNDArray where(PtNDArray condition, PtNDArray self, PtNDArray other) { - return new PtNDArray( + return new PtNDArrayImpl( self.getManager(), PyTorchLibrary.LIB.torchWhere( condition.getHandle(), self.getHandle(), other.getHandle())); } public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray indicesNd) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchMaskedSelect(ndArray.getHandle(), indicesNd.getHandle())); } @@ -575,102 +576,102 @@ public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager m // due to significant performance gain // for commonly used data loading call if (indices.length == 1) { - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0])); } - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); } public static PtNDArray clone(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.tensorClone(ndArray.getHandle())); } public static PtNDArray reshape(PtNDArray ndArray, long[] shape) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchReshape(ndArray.getHandle(), shape)); } public static PtNDArray stack(PtNDArray[] arrays, int dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); + return new PtNDArrayImpl(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); } public static PtNDArray cat(PtNDArray[] arrays, long dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); + return new PtNDArrayImpl(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); } public static PtNDArray tile(PtNDArray ndArray, long[] repeats) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeat(ndArray.getHandle(), repeats)); } public static PtNDArray repeat(PtNDArray ndArray, long repeat, long dim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeatInterleave(ndArray.getHandle(), repeat, dim)); } public static PtNDArray softmax(PtNDArray ndArray, long dim, DataType dTpe) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), 
PyTorchLibrary.LIB.torchSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, DataType dTpe) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchLogSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray argMax(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle())); } public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argMin(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle())); } public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchArgSort(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray sort(PtNDArray ndArray, long dim, boolean descending) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSort(ndArray.getHandle(), dim, descending)); } public static PtNDArray permute(PtNDArray ndArray, long[] dims) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchPermute(ndArray.getHandle(), dims)); } public static PtNDArray flip(PtNDArray ndArray, long[] dims) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchFlip(ndArray.getHandle(), dims)); } public static PtNDArray transpose(PtNDArray ndArray, long dim1, long dim2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchTranspose(ndArray.getHandle(), dim1, dim2)); } @@ -680,7 +681,7 @@ public static boolean contentEqual(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray add(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchAdd(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -690,7 +691,7 @@ public static void addi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sub(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchSub(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -700,7 +701,7 @@ public static void subi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray mul(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchMul(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -710,7 +711,7 @@ public static void muli(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray div(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchTrueDivide(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -720,7 +721,7 @@ public static void divi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static 
PtNDArray remainder(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchRemainder(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -730,7 +731,7 @@ public static void remainderi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray pow(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchPow(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -740,7 +741,7 @@ public static void powi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sign(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSign(ndArray.getHandle())); } @@ -749,104 +750,104 @@ public static void signi(PtNDArray ndArray) { } public static PtNDArray logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalAnd(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalOr(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalOr(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalXor(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalXor(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalNot(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchLogicalNot(ndArray.getHandle())); } public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2) { if (ndArray1.getShape().dimension() == 1) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchDot(ndArray1.getHandle(), ndArray2.getHandle())); } - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchMaximum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); } public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray1.getManager(), PyTorchLibrary.LIB.torchMinimum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray min(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); } public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return new 
PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray mean(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle())); } public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes) { long[] longaxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchRot90(ndArray.getHandle(), times, longaxes)); } public static PtNDArray sum(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); } public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle(), dims, keepDim)); } @@ -856,29 +857,29 @@ public static PtNDArray cumProd(PtNDArray ndArray, long dim, DataType dataType) if (dataType != null) { dtPosition = dataType.ordinal(); } - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchCumProd(ndArray.getHandle(), dim, dtPosition)); } public static PtNDArray prod(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle())); } public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray cumSum(PtNDArray ndArray, long dim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim)); } public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNOneHot( ndArray.toType(DataType.INT64, false).getHandle(), depth)) @@ -889,7 +890,7 @@ public static NDList split(PtNDArray ndArray, long size, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(new PtNDArray(ndArray.getManager(), ptr)); + list.add(new PtNDArrayImpl(ndArray.getManager(), ptr)); } return list; } @@ -898,196 +899,196 @@ public static NDList split(PtNDArray ndArray, long[] indices, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(new PtNDArray(ndArray.getManager(), ptr)); + list.add(new PtNDArrayImpl(ndArray.getManager(), ptr)); } return list; } public static PtNDArray squeeze(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle())); } public static PtNDArray squeeze(PtNDArray ndArray, long dim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle(), dim)); } public static PtNDArray unsqueeze(PtNDArray ndArray, long dim) { - return new PtNDArray( + return new PtNDArrayImpl( 
ndArray.getManager(), PyTorchLibrary.LIB.torchUnsqueeze(ndArray.getHandle(), dim)); } public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchFlatten(ndArray.getHandle(), startDim, endDim)); } public static PtNDArray abs(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); } public static PtNDArray square(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSquare(ndArray.getHandle())); } public static PtNDArray floor(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchFloor(ndArray.getHandle())); } public static PtNDArray ceil(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchCeil(ndArray.getHandle())); } public static PtNDArray round(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchRound(ndArray.getHandle())); } public static PtNDArray trunc(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchTrunc(ndArray.getHandle())); } public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) { PtNDArray minNd = (PtNDArray) ndArray.getManager().create(min); PtNDArray maxNd = (PtNDArray) ndArray.getManager().create(max); - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchClamp( ndArray.getHandle(), minNd.getHandle(), maxNd.getHandle())); } public static PtNDArray exp(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); } public static PtNDArray log(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); } public static PtNDArray log10(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchLog10(ndArray.getHandle())); } public static PtNDArray log2(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchLog2(ndArray.getHandle())); } public static PtNDArray sin(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); } public static PtNDArray cos(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); } public static PtNDArray tan(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); } public static PtNDArray asin(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchASin(ndArray.getHandle())); } public static PtNDArray acos(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchAcos(ndArray.getHandle())); } public static PtNDArray atan(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } public static 
PtNDArray sqrt(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); } public static PtNDArray sinh(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSinh(ndArray.getHandle())); } public static PtNDArray cosh(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchCosh(ndArray.getHandle())); } public static PtNDArray tanh(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchTanh(ndArray.getHandle())); } public static PtNDArray sigmoid(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchSigmoid(ndArray.getHandle())); } public static PtNDArray all(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); } public static PtNDArray any(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); } public static PtNDArray none(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNone(ndArray.getHandle())); } public static PtNDArray eq(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return new PtNDArrayImpl( self.getManager(), PyTorchLibrary.LIB.torchEq(self.getHandle(), other.getHandle())); } public static PtNDArray neq(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return new PtNDArrayImpl( self.getManager(), PyTorchLibrary.LIB.torchNeq(self.getHandle(), other.getHandle())); } public static PtNDArray gt(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return new PtNDArrayImpl( self.getManager(), PyTorchLibrary.LIB.torchGt(self.getHandle(), other.getHandle())); } public static PtNDArray gte(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return new PtNDArrayImpl( self.getManager(), PyTorchLibrary.LIB.torchGte(self.getHandle(), other.getHandle())); } public static PtNDArray lt(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return new PtNDArrayImpl( self.getManager(), PyTorchLibrary.LIB.torchLt(self.getHandle(), other.getHandle())); } public static PtNDArray lte(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return new PtNDArrayImpl( self.getManager(), PyTorchLibrary.LIB.torchLte(self.getHandle(), other.getHandle())); } public static PtNDArray neg(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); } @@ -1096,12 +1097,12 @@ public static void negi(PtNDArray ndArray) { } public static PtNDArray isNaN(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchIsNaN(ndArray.getHandle())); } public static PtNDArray isInf(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchIsInf(ndArray.getHandle())); } @@ -1112,7 +1113,7 @@ public static PtNDArray randint( Shape size, DataType dataType, Device device) { - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchRandint( low, @@ -1126,7 +1127,7 @@ public static PtNDArray randint( public static PtNDArray randperm( PtNDManager manager, long n, 
DataType dataType, Device device) { - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchRandPerm( n, @@ -1143,7 +1144,7 @@ public static PtNDArray normal( Shape size, DataType dataType, Device device) { - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchNormal( mean, @@ -1162,7 +1163,7 @@ public static PtNDArray uniform( Shape size, DataType dataType, Device device) { - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.tensorUniform( low, @@ -1176,7 +1177,7 @@ public static PtNDArray uniform( public static PtNDArray eye( PtNDManager manager, int n, int m, DataType dataType, Device device, SparseFormat fmt) { - return new PtNDArray( + return new PtNDArrayImpl( manager, PyTorchLibrary.LIB.torchEye( n, @@ -1188,25 +1189,25 @@ public static PtNDArray eye( } public static PtNDArray erfinv(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } public static PtNDArray inverse(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); } public static PtNDArray interpolate( - PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { - return new PtNDArray( + PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNInterpolate( ndArray.getHandle(), size, mode, alignCorners)); } public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias) { - return new PtNDArray( + return new PtNDArrayImpl( input.getManager(), PyTorchLibrary.LIB.torchNNLinear( input.getHandle(), @@ -1215,44 +1216,44 @@ public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias } public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse) { - return new PtNDArray( + return new PtNDArrayImpl( input.getManager(), PyTorchLibrary.LIB.torchNNEmbedding(input.getHandle(), weight.getHandle(), sparse)); } public static PtNDArray relu(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNRelu(ndArray.getHandle())); } public static PtNDArray softPlus(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftPlus(ndArray.getHandle())); } public static PtNDArray softSign(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftSign(ndArray.getHandle())); } public static PtNDArray leakyRelu(PtNDArray ndArray, double negativeSlope) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLeakyRelu(ndArray.getHandle(), negativeSlope)); } public static PtNDArray elu(PtNDArray ndArray, double alpha) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNElu(ndArray.getHandle(), alpha)); } public static PtNDArray selu(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSelu(ndArray.getHandle())); } public static PtNDArray gelu(PtNDArray ndArray) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNGelu(ndArray.getHandle())); } @@ -1264,7 +1265,7 @@ public static PtNDArray convolution( Shape padding, Shape 
dilation, int groups) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNConvNd( ndArray.getHandle(), @@ -1285,7 +1286,7 @@ public static PtNDArray batchNorm( boolean isTraining, double momentum, double eps) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNBatchNorm( ndArray.getHandle(), @@ -1299,8 +1300,8 @@ public static PtNDArray batchNorm( } public static PtNDArray layerNorm( - PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { - return new PtNDArray( + PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLayerNorm( ndArray.getHandle(), @@ -1311,13 +1312,13 @@ public static PtNDArray layerNorm( } public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNNormalize(ndArray.getHandle(), p, dim, eps)); } public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNDropout(ndArray.getHandle(), prob, training)); } @@ -1350,7 +1351,7 @@ public static NDList rnn( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArray(manager, output)); + res.add(new PtNDArrayImpl(manager, output)); } return res; } @@ -1381,7 +1382,7 @@ public static NDList gru( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArray(manager, output)); + res.add(new PtNDArrayImpl(manager, output)); } return res; } @@ -1414,7 +1415,7 @@ public static NDList lstm( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArray(manager, output)); + res.add(new PtNDArrayImpl(manager, output)); } return res; } @@ -1426,7 +1427,7 @@ public static PtNDArray avgPool( Shape padding, boolean ceilMode, boolean countIncludePad) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAvgPool( ndArray.getHandle(), @@ -1438,8 +1439,8 @@ public static PtNDArray avgPool( } public static PtNDArray maxPool( - PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { - return new PtNDArray( + PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNMaxPool( ndArray.getHandle(), @@ -1450,25 +1451,25 @@ public static PtNDArray maxPool( } public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, Shape outputSize) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveMaxPool( ndArray.getHandle(), outputSize.getShape())); } public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, Shape outputSize) { - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveAvgPool( ndArray.getHandle(), outputSize.getShape())); } public static PtNDArray lpPool( - PtNDArray ndArray, double normType, Shape kernelSize, Shape stride, boolean ceilMode) { + PtNDArray ndArray, double normType, Shape kernelSize, Shape stride, boolean ceilMode) { if (ndArray.getShape().dimension() - 2 == 3) { throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine"); } - return new 
PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLpPool( ndArray.getHandle(), @@ -1533,7 +1534,7 @@ public static void attachGradient(PtNDArray ndArray, boolean requiresGrad) { public static PtNDArray detachGradient(PtNDArray ndArray) { // TODO: detached ndarray may not use the same manager for the attached one - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchDetachGrad(ndArray.getHandle())); } @@ -1542,11 +1543,11 @@ public static PtNDArray getGradient(PtNDArray ndArray) { if (pointer == NULL_PTR) { return null; } - return new PtNDArray(ndArray.getManager(), pointer); + return new PtNDArrayImpl(ndArray.getManager(), pointer); } public static void backward( - PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph) { + PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph) { PyTorchLibrary.LIB.torchBackward( ndArray.getHandle(), gradNd.getHandle(), keepGraph, createGraph); } @@ -1622,7 +1623,7 @@ public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) { String[] names = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle()); NDList list = new NDList(handles.length); for (int i = 0; i < handles.length; i++) { - PtNDArray array = new PtNDArray(manager, handles[i]); + PtNDArray array = new PtNDArrayImpl(manager, handles[i]); array.setName(names[i]); list.add(array); } @@ -1698,7 +1699,7 @@ public static int getLayout(PtNDArray array) { public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims) { long[] longAxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNorm(ndArray.getHandle(), ord, longAxes, keepDims)); } @@ -1707,7 +1708,7 @@ public static PtNDArray nonZeros(PtNDArray ndArray) { if (ndArray.isScalar()) { ndArray = (PtNDArray) ndArray.reshape(-1); } - return new PtNDArray( + return new PtNDArrayImpl( ndArray.getManager(), PyTorchLibrary.LIB.torchNonZeros(ndArray.getHandle())); } } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java index 286ac02001c..efad9d17e95 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; import ai.djl.pytorch.engine.PtSymbolBlock; import ai.djl.pytorch.jni.IValue; @@ -38,10 +39,10 @@ public class IValueTest { @Test public void testIValue() { try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) { - PtNDArray array1 = (PtNDArray) manager.zeros(new Shape(1)); - PtNDArray array2 = (PtNDArray) manager.ones(new Shape(1)); - PtNDArray array3 = (PtNDArray) manager.create("test"); - PtNDArray array4 = (PtNDArray) manager.create(new String[] {"test1", "test2"}); + PtNDArrayImpl array1 = (PtNDArrayImpl) manager.zeros(new Shape(1)); + PtNDArrayImpl array2 = (PtNDArrayImpl) manager.ones(new Shape(1)); + PtNDArrayImpl array3 = (PtNDArrayImpl) manager.create("test"); + PtNDArrayImpl array4 = (PtNDArrayImpl) manager.create(new String[] {"test1", "test2"}); try (IValue ivalue = 
IValue.from(array1)) { Assert.assertTrue(ivalue.isTensor()); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java index 623b42cb1d3..7b1bfc14aa5 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -16,7 +16,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.gc.NDArrayWrapFactory; +import ai.djl.pytorch.engine.PtNDArrayProxyMaker; import ai.djl.translate.TranslateException; import org.slf4j.Logger; @@ -34,24 +34,32 @@ private Main() {} public static void main(String[] args) throws IOException, TranslateException, InterruptedException { - NDArrayWrapFactory enh = new NDArrayWrapFactory(); - try (NDManager manager = NDManager.newBaseManager(); ) { + PtNDArrayProxyMaker enh = new PtNDArrayProxyMaker(); + try (NDManager baseManager = NDManager.newBaseManager(); ) { + try (NDManager subManager = baseManager.newSubManager()) { - NDArray a = enh.wrap(manager.create(new float[] {1f})); - debugDumpFromSystemManager(); + NDArray a = enh.wrap(subManager.create(new float[]{1f})); + NDArray b = enh.wrap(subManager.create(new float[]{2f})); + NDArray c = a.add(b); + debugDumpFromSystemManager(); - logger.info("reference exists ..."); - logger.info("weakHashMap size: {}", enh.mapSize()); - a = null; - logger.info("no reference exists, but likely not yet garbage collected ..."); - logger.info("weakHashMap size: {}", enh.mapSize()); + logger.info("reference exists ..."); + logger.info("weakHashMap size: {}", enh.mapSize()); + a = null; + b = null; + c = null; + logger.info("no reference exists, but likely not yet garbage collected ..."); + logger.info("weakHashMap size: {}", enh.mapSize()); - System.gc(); - TimeUnit.SECONDS.sleep(1); + System.gc(); // just for testing - do not use in production + TimeUnit.SECONDS.sleep(1); - logger.info("no reference exists, and likely garbage collected ..."); - logger.info("weakHashMap size: {}", enh.mapSize()); + logger.info("no reference exists, and likely garbage collected ..."); + logger.info("weakHashMap size: {}", enh.mapSize()); + debugDumpFromSystemManager(); + } debugDumpFromSystemManager(); } + debugDumpFromSystemManager(); } } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 0cf130c1ef9..7bac89ae583 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -23,7 +23,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; import ai.djl.tensorflow.engine.javacpp.JavacppUtils; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import ai.djl.util.Preconditions; import org.tensorflow.internal.c_api.TFE_TensorHandle; @@ -40,7 +40,7 @@ /** {@code TfNDArray} is the TensorFlow implementation of {@link NDArray}. 
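The gc.Main changes just above only demonstrate the intent: each wrapped array is registered in a weak map keyed by a UUID that its proxy holds, so once the local references a, b and c are nulled out, the keys become unreachable and the entries are eligible for collection; System.gc() is a hint used for the test only. The same effect can be reproduced with a plain WeakHashMap, independent of DJL (whether the entry is gone after a single GC hint is JVM dependent):

import java.util.UUID;
import java.util.WeakHashMap;
import java.util.concurrent.TimeUnit;

public final class WeakMapSketch {
    public static void main(String[] args) throws InterruptedException {
        WeakHashMap<UUID, float[]> map = new WeakHashMap<>();
        UUID key = UUID.randomUUID();
        map.put(key, new float[] {1f});
        System.out.println("size with a strong key reference: " + map.size()); // 1

        key = null;  // drop the only strong reference to the key
        System.gc(); // a hint, not a guarantee; fine for a demo, avoid in production
        TimeUnit.SECONDS.sleep(1);

        // WeakHashMap expunges stale entries on access once the key has been collected.
        System.out.println("size after the GC hint: " + map.size()); // typically 0
    }
}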
*/ @SuppressWarnings("PMD.UseTryWithResources") -public class TfNDArray extends NativeResource implements NDArray { +public class TfNDArray extends NativeResourceImpl implements NDArray { private Shape shape; private Device device; diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java index bad891e72da..32a02684f30 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java @@ -13,14 +13,14 @@ package ai.djl.fasttext.jni; import ai.djl.modality.Classifications; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import java.util.ArrayList; import java.util.List; /** A class containing utilities to interact with the fastText JNI layer. */ @SuppressWarnings("MissingJavadocMethod") -public final class FtWrapper extends NativeResource { +public final class FtWrapper extends NativeResourceImpl { private static RuntimeException libraryStatus; diff --git a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java index 44d05936e58..4cecc80de9e 100644 --- a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java +++ b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java @@ -14,10 +14,10 @@ import ai.djl.sentencepiece.jni.LibUtils; import ai.djl.sentencepiece.jni.SentencePieceLibrary; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; /** The processor holder for SentencePiece. */ -final class SpProcessor extends NativeResource { +final class SpProcessor extends NativeResourceImpl { private static RuntimeException libraryStatus; diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index 8881292b820..6da9cbaca6d 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -18,7 +18,7 @@ import ai.djl.modality.nlp.preprocess.Tokenizer; import ai.djl.ndarray.NDManager; import ai.djl.translate.ArgumentsUtil; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import ai.djl.util.Utils; import org.slf4j.Logger; @@ -37,7 +37,7 @@ * {@code HuggingFaceTokenizer} is a Huggingface tokenizer implementation of the {@link Tokenizer} * interface that converts sentences into token. 
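The switch from NativeResource to NativeResourceImpl in TfNDArray, FtWrapper, SpProcessor and HuggingFaceTokenizer mirrors the PtNDArray split: the handle-holding base type is apparently turned into an interface plus implementation pair, and classes that previously extended the concrete NativeResource now extend the implementation. A schematic sketch of that kind of split, with hypothetical names and a reduced surface rather than the actual DJL signatures:

import java.util.concurrent.atomic.AtomicReference;

public final class ResourceSplitSketch {

    // The interface keeps only the contract, so dynamic proxies can implement it.
    interface Resource<T> extends AutoCloseable {
        T getHandle();

        @Override
        void close();
    }

    // The implementation owns the handle, as the former concrete base class did.
    abstract static class ResourceImpl<T> implements Resource<T> {
        private final AtomicReference<T> handle;

        protected ResourceImpl(T handle) {
            this.handle = new AtomicReference<>(handle);
        }

        @Override
        public T getHandle() {
            T h = handle.get();
            if (h == null) {
                throw new IllegalStateException("Native resource has been released already.");
            }
            return h;
        }

        @Override
        public void close() {
            // A real implementation would also release the native memory here.
            handle.set(null);
        }
    }
}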
*/ -public final class HuggingFaceTokenizer extends NativeResource implements Tokenizer { +public final class HuggingFaceTokenizer extends NativeResourceImpl implements Tokenizer { private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class); From 5d67a61a1864f91184434c680298ce65664326f0 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 15 Dec 2022 14:59:55 +0100 Subject: [PATCH 03/30] creating NDArrays with or without proxy --- api/src/main/java/ai/djl/engine/Engine.java | 6 + .../java/ai/djl/ndarray/BaseNDManager.java | 18 +- .../main/java/ai/djl/ndarray/NDManager.java | 21 ++ .../java/ai/djl/pytorch/engine/PtEngine.java | 5 + .../ai/djl/pytorch/engine/PtNDArrayImpl.java | 35 +- .../ai/djl/pytorch/engine/PtNDManager.java | 42 ++- .../main/java/ai/djl/pytorch/jni/IValue.java | 6 +- .../java/ai/djl/pytorch/jni/JniUtils.java | 298 +++++++++--------- .../ai/djl/pytorch/integration/gc/Main.java | 14 +- 9 files changed, 278 insertions(+), 167 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 73d19779bd5..34bd06973f6 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -289,6 +289,12 @@ public int getGpuCount() { * * @return a new top-level {@code NDManager} */ + + + + + public abstract NDManager newBaseManager( boolean useProxies); + public abstract NDManager newBaseManager(); /** diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 1f925b048e4..fbecf084c11 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -43,6 +43,7 @@ public abstract class BaseNDManager implements NDManager { private static final Logger logger = LoggerFactory.getLogger(BaseNDManager.class); + private final boolean useProxies; protected NDManager parent; protected NDManager alternativeManager; @@ -55,8 +56,12 @@ public abstract class BaseNDManager implements NDManager { protected AtomicBoolean capped = new AtomicBoolean(false); protected BaseNDManager(NDManager parent, Device device) { + this(parent, device, false); + } + protected BaseNDManager(NDManager parent, Device device, boolean useProxies) { this.parent = parent; this.device = device == null ? 
defaultDevice() : device; + this.useProxies = useProxies; resources = new ConcurrentHashMap<>(); tempResources = new ConcurrentHashMap<>(); uid = UUID.randomUUID().toString(); @@ -78,6 +83,10 @@ public NDArray create(String[] data, Charset charset, Shape shape) { throw new UnsupportedOperationException("Not supported!"); } + public boolean isUseProxies() { + return useProxies; + } + /** {@inheritDoc} */ @Override public NDArray create(Shape shape, DataType dataType) { @@ -298,7 +307,12 @@ public NDManager getParentManager() { /** {@inheritDoc} */ @Override public NDManager newSubManager() { - return newSubManager(device); + return newSubManager(device, useProxies); + } + + @Override + public NDManager newSubManager(boolean useProxies) { + return newSubManager(device, useProxies); } /** {@inheritDoc} */ @@ -588,6 +602,8 @@ public static void copyBuffer(Buffer src, ByteBuffer target) { target.rewind(); } + public abstract NDManager newSubManager(Device device, boolean useProxies); + protected static final class TempResource { private NDResource resource; diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 26ffcd7aa4a..58c61963f2a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -15,6 +15,7 @@ import ai.djl.Device; import ai.djl.engine.Engine; import ai.djl.engine.EngineException; +import ai.djl.ndarray.gc.NDArrayProxyMaker; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.translate.Translator; @@ -105,6 +106,18 @@ */ public interface NDManager extends AutoCloseable { + + /** + * Creates a new top-level {@code NDManager}. + * + *

{@code NDManager} will inherit default {@link Device}. + * + * @return a new top-level {@code NDManager} + */ + static NDManager newBaseManager(boolean useProxies) { + return Engine.getInstance().newBaseManager(useProxies); + } + /** * Creates a new top-level {@code NDManager}. * @@ -717,6 +730,10 @@ default NDArray decode(InputStream is) throws IOException { */ NDList load(Path path); + default NDArrayProxyMaker getProxyMaker() { + throw new UnsupportedOperationException("Not supported"); + } + /** * Loads the NDArrays saved to a file. * @@ -1499,6 +1516,10 @@ default NDArray truncatedNormal( */ NDManager newSubManager(Device device); + default NDManager newSubManager(boolean useProxies) { + throw new UnsupportedOperationException("useProxies not supported here"); + } + /** * Returns the default {@link Device} of this {@code NDManager}. * diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index d0f7a101394..1a0152702eb 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -128,6 +128,11 @@ public Model newModel(String name, Device device) { return new PtModel(name, device); } + + @Override + public NDManager newBaseManager(boolean useProxies) { + return PtNDManager.getSystemManager().newSubManager(useProxies); + } /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index 0aca7f8fa55..b5489144da6 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -52,6 +52,32 @@ public class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray @SuppressWarnings("PMD.UnusedPrivateField") private ByteBuffer dataRef; + + public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { + PtNDArray instance = new PtNDArrayImpl(manager, handle); + if (manager.isUseProxies()) { + instance = manager.getProxyMaker().wrap(instance); + } + return instance; + } + + public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffer data) { + PtNDArray instance = new PtNDArrayImpl(manager, handle, data); + if (manager.isUseProxies()) { + instance = manager.getProxyMaker().wrap(instance); + } + return instance; + } + + public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape shape) { + PtNDArray instance = new PtNDArrayImpl(manager, strs, shape); + if (manager.isUseProxies()) { + instance = manager.getProxyMaker().wrap(instance); + } + return instance; + } + + /** * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} * instead). 
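The newPtNDArray factory methods above are the single construction point: they build a PtNDArrayImpl and, when the owning manager was created with useProxies enabled, hand it to the manager's proxy maker before returning. From the caller's point of view the flag decides whether engine-created arrays are plain implementation objects or reflective proxies; a small check of that, assuming the PyTorch engine is the default engine on the classpath and that proxied instances are created through java.lang.reflect.Proxy as in this patch:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

import java.lang.reflect.Proxy;

public final class ProxyFlagSketch {
    public static void main(String[] args) {
        try (NDManager plain = NDManager.newBaseManager();
                NDManager proxied = NDManager.newBaseManager(true)) {
            NDArray a = plain.create(new float[] {1f});
            NDArray b = proxied.create(new float[] {1f});
            // Without proxies the engine hands back the implementation class directly;
            // with useProxies=true it should hand back a dynamic proxy instead.
            System.out.println(Proxy.isProxyClass(a.getClass())); // expected: false
            System.out.println(Proxy.isProxyClass(b.getClass())); // expected: true
        }
    }
}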
@@ -59,7 +85,7 @@ public class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray * @param manager the manager to attach the new array to * @param handle the pointer to the native PyTorch memory */ - public PtNDArrayImpl(PtNDManager manager, long handle) { + private PtNDArrayImpl(PtNDManager manager, long handle) { super(handle); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); @@ -74,7 +100,7 @@ public PtNDArrayImpl(PtNDManager manager, long handle) { * @param handle the pointer to the native PyTorch memory * @param data the direct buffer of the data */ - public PtNDArrayImpl(PtNDManager manager, long handle, ByteBuffer data) { + private PtNDArrayImpl(PtNDManager manager, long handle, ByteBuffer data) { super(handle); this.manager = manager; this.ptNDArrayEx = new PtNDArrayEx(this); @@ -90,7 +116,7 @@ public PtNDArrayImpl(PtNDManager manager, long handle, ByteBuffer data) { * @param strs the string array * @param shape the {@link Shape} of the {@link NDArray} */ - public PtNDArrayImpl(PtNDManager manager, String[] strs, Shape shape) { + private PtNDArrayImpl(PtNDManager manager, String[] strs, Shape shape) { super(-1L); this.manager = manager; this.strs = strs; @@ -98,6 +124,9 @@ public PtNDArrayImpl(PtNDManager manager, String[] strs, Shape shape) { this.dataType = DataType.STRING; } + + + /** {@inheritDoc} */ @Override public PtNDManager getManager() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 24bb7b972be..fabc4e2ab5e 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -27,19 +27,32 @@ import java.nio.ByteOrder; import java.nio.charset.Charset; +import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; + /** {@code PtNDManager} is the PyTorch implementation of {@link NDManager}. 
*/ -public class PtNDManager extends BaseNDManager { +public class PtNDManager extends BaseNDManager { private static final PtNDManager SYSTEM_MANAGER = new SystemManager(); - private PtNDManager(NDManager parent, Device device) { - super(parent, device); + protected PtNDArrayProxyMaker proxyMaker; + + private PtNDManager(NDManager parent, Device device, boolean useProxies) { + super(parent, device, useProxies); } + static PtNDManager getSystemManager() { return SYSTEM_MANAGER; } + + + public PtNDArrayProxyMaker getProxyMaker() { + return getSystemManager().getProxyMaker(); + } + + + /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { @@ -81,7 +94,7 @@ public PtNDArray create(Buffer data, Shape shape, DataType dataType) { /** {@inheritDoc} */ @Override public NDArray create(String[] data, Charset charset, Shape shape) { - return new PtNDArrayImpl(this, data, shape); + return newPtNDArray(this, data, shape); } /** {@inheritDoc} */ @@ -172,10 +185,19 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy return JniUtils.normal(this, loc, scale, shape, dataType, device); } + /** {@inheritDoc} */ @Override public PtNDManager newSubManager(Device device) { - PtNDManager manager = new PtNDManager(this, device); + PtNDManager manager = new PtNDManager(this, device, isUseProxies()); + attachUncappedInternal(manager.uid, manager); + return manager; + } + + + @Override + public NDManager newSubManager(Device device, boolean useProxies) { + PtNDManager manager = new PtNDManager(this, device, useProxies); attachUncappedInternal(manager.uid, manager); return manager; } @@ -195,7 +217,15 @@ public static void debugDumpFromSystemManager() { private static final class SystemManager extends PtNDManager implements SystemNDManager { SystemManager() { - super(null, null); + super(null, null, false); + this.proxyMaker = new PtNDArrayProxyMaker(); + } + + + + @Override + public PtNDArrayProxyMaker getProxyMaker() { + return this.proxyMaker; } } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java index c1804735fdc..3094c6e7006 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java @@ -24,6 +24,8 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; + /** * A class represent a PyTorch {@code IValue} data. 
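Because the constructors are private, every JNI entry point that turns a native tensor handle into a Java object now has to pass through the same factory, which is why the hunks below mechanically replace new PtNDArrayImpl(...) with newPtNDArray(...). The payoff of such a single choke point is that cross-cutting behavior needs exactly one hook; the toy sketch below, with illustrative names only, adds instance counting at that one spot.

import java.util.concurrent.atomic.AtomicLong;

final class TensorFactory {
    private static final AtomicLong CREATED = new AtomicLong();

    // Every code path that materializes a tensor calls this, so wrapping,
    // counting or logging only ever needs to be added here.
    static long[] newTensor(long nativeHandle) {
        CREATED.incrementAndGet();
        return new long[] {nativeHandle}; // stand-in for the real tensor object
    }

    static long created() {
        return CREATED.get();
    }

    public static void main(String[] args) {
        newTensor(11L);
        newTensor(12L);
        System.out.println("tensors created: " + created()); // 2
    }
}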
* @@ -393,7 +395,7 @@ public double[] toDoubleArray() { * @return the NDArray value of this IValue */ public PtNDArray toTensor(PtNDManager manager) { - return new PtNDArrayImpl(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); + return newPtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); } /** @@ -406,7 +408,7 @@ public PtNDArray[] toTensorArray(PtNDManager manager) { long[] handles = PyTorchLibrary.LIB.iValueToTensorList(getHandle()); PtNDArray[] ret = new PtNDArrayImpl[handles.length]; for (int i = 0; i < ret.length; ++i) { - ret[i] = new PtNDArrayImpl(manager, handles[i]); + ret[i] = newPtNDArray(manager, handles[i]); } return ret; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 275eb830848..1ff735d00fc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -50,6 +50,8 @@ import java.util.ListIterator; import java.util.Set; +import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; + /** * A class containing utilities to interact with the PyTorch Engine's Java Native Interface (JNI) * layer. @@ -171,15 +173,15 @@ public static PtNDArray createNdFromByteBuffer( if (layout == 1 || layout == 2 || device.isGpu()) { // MKLDNN & COO & GPU device will explicitly make a copy in native code // so we don't want to hold a reference on Java side - return new PtNDArrayImpl(manager, handle); + return newPtNDArray(manager, handle); } - return new PtNDArrayImpl(manager, handle, data); + return newPtNDArray(manager, handle, data); } public static PtNDArray createEmptyNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchEmpty( shape.getShape(), @@ -192,7 +194,7 @@ public static PtNDArray createEmptyNdArray( public static PtNDArray createZerosNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchZeros( shape.getShape(), @@ -205,7 +207,7 @@ public static PtNDArray createZerosNdArray( public static PtNDArray createOnesNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchOnes( shape.getShape(), @@ -223,7 +225,7 @@ public static PtNDArray full( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchFull( shape.getShape(), @@ -237,7 +239,7 @@ public static PtNDArray full( public static PtNDArray zerosLike( PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( array.getManager(), PyTorchLibrary.LIB.torchZerosLike( array.getHandle(), @@ -250,7 +252,7 @@ public static PtNDArray zerosLike( public static PtNDArray onesLike( PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( array.getManager(), 
PyTorchLibrary.LIB.torchOnesLike( array.getHandle(), @@ -269,7 +271,7 @@ public static PtNDArray arange( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchArange( start, @@ -290,7 +292,7 @@ public static PtNDArray linspace( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchLinspace( start, @@ -303,7 +305,7 @@ public static PtNDArray linspace( } public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, Shape shape) { - return new PtNDArrayImpl( + return newPtNDArray( values.getManager(), PyTorchLibrary.LIB.torchSparseCoo( shape.getShape(), indices.getHandle(), values.getHandle(), false)); @@ -316,7 +318,7 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) if (!device.equals(manager.getDevice())) { manager = manager.newSubManager(device); } - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchTo( ndArray.getHandle(), @@ -325,23 +327,23 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) } public static PtNDArray toSparse(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToSparse(ndArray.getHandle())); } public static PtNDArray toDense(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToDense(ndArray.getHandle())); } public static PtNDArray broadcast(PtNDArray ndArray, Shape shape) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchExpand(ndArray.getHandle(), shape.getShape())); } public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSlice(ndArray.getHandle(), dim, start, stop, step)); } @@ -352,7 +354,7 @@ public static PtNDArray index( long[] maxIndices, long[] stepIndices, PtNDManager manager) { - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchIndex( ndArray.getHandle(), minIndices, maxIndices, stepIndices)); @@ -413,7 +415,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle)); } @@ -499,7 +501,7 @@ public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } @@ -508,7 +510,7 @@ public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager man if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); } @@ -516,7 +518,7 @@ public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray data) if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArrayImpl( + return 
newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPut( ndArray.getHandle(), index.getHandle(), data.getHandle())); @@ -548,20 +550,20 @@ public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } public static PtNDArray where(PtNDArray condition, PtNDArray self, PtNDArray other) { - return new PtNDArrayImpl( + return newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchWhere( condition.getHandle(), self.getHandle(), other.getHandle())); } public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray indicesNd) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMaskedSelect(ndArray.getHandle(), indicesNd.getHandle())); } @@ -576,102 +578,102 @@ public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager m // due to significant performance gain // for commonly used data loading call if (indices.length == 1) { - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0])); } - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); } public static PtNDArray clone(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.tensorClone(ndArray.getHandle())); } public static PtNDArray reshape(PtNDArray ndArray, long[] shape) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchReshape(ndArray.getHandle(), shape)); } public static PtNDArray stack(PtNDArray[] arrays, int dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return new PtNDArrayImpl(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); + return newPtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); } public static PtNDArray cat(PtNDArray[] arrays, long dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return new PtNDArrayImpl(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); + return newPtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); } public static PtNDArray tile(PtNDArray ndArray, long[] repeats) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeat(ndArray.getHandle(), repeats)); } public static PtNDArray repeat(PtNDArray ndArray, long repeat, long dim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeatInterleave(ndArray.getHandle(), repeat, dim)); } public static PtNDArray softmax(PtNDArray ndArray, long dim, DataType dTpe) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, DataType dTpe) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray argMax(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle())); 
} public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argMin(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle())); } public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgSort(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray sort(PtNDArray ndArray, long dim, boolean descending) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSort(ndArray.getHandle(), dim, descending)); } public static PtNDArray permute(PtNDArray ndArray, long[] dims) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPermute(ndArray.getHandle(), dims)); } public static PtNDArray flip(PtNDArray ndArray, long[] dims) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlip(ndArray.getHandle(), dims)); } public static PtNDArray transpose(PtNDArray ndArray, long dim1, long dim2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTranspose(ndArray.getHandle(), dim1, dim2)); } @@ -681,7 +683,7 @@ public static boolean contentEqual(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray add(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchAdd(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -691,7 +693,7 @@ public static void addi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sub(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchSub(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -701,7 +703,7 @@ public static void subi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray mul(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMul(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -711,7 +713,7 @@ public static void muli(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray div(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchTrueDivide(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -721,7 +723,7 @@ public static void divi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray remainder(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchRemainder(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -731,7 +733,7 @@ public static void remainderi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray pow(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchPow(ndArray1.getHandle(), 
ndArray2.getHandle())); } @@ -741,7 +743,7 @@ public static void powi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sign(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSign(ndArray.getHandle())); } @@ -750,104 +752,104 @@ public static void signi(PtNDArray ndArray) { } public static PtNDArray logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalAnd(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalOr(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalOr(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalXor(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalXor(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalNot(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogicalNot(ndArray.getHandle())); } public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2) { if (ndArray1.getShape().dimension() == 1) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchDot(ndArray1.getHandle(), ndArray2.getHandle())); } - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMaximum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); } public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMinimum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray min(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); } public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray mean(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle())); } public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray 
rot90(PtNDArray ndArray, int times, int[] axes) { long[] longaxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRot90(ndArray.getHandle(), times, longaxes)); } public static PtNDArray sum(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); } public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle(), dims, keepDim)); } @@ -857,29 +859,29 @@ public static PtNDArray cumProd(PtNDArray ndArray, long dim, DataType dataType) if (dataType != null) { dtPosition = dataType.ordinal(); } - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumProd(ndArray.getHandle(), dim, dtPosition)); } public static PtNDArray prod(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle())); } public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray cumSum(PtNDArray ndArray, long dim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim)); } public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNOneHot( ndArray.toType(DataType.INT64, false).getHandle(), depth)) @@ -890,7 +892,7 @@ public static NDList split(PtNDArray ndArray, long size, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(new PtNDArrayImpl(ndArray.getManager(), ptr)); + list.add(newPtNDArray(ndArray.getManager(), ptr)); } return list; } @@ -899,196 +901,196 @@ public static NDList split(PtNDArray ndArray, long[] indices, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(new PtNDArrayImpl(ndArray.getManager(), ptr)); + list.add(newPtNDArray(ndArray.getManager(), ptr)); } return list; } public static PtNDArray squeeze(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle())); } public static PtNDArray squeeze(PtNDArray ndArray, long dim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle(), dim)); } public static PtNDArray unsqueeze(PtNDArray ndArray, long dim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchUnsqueeze(ndArray.getHandle(), dim)); } public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlatten(ndArray.getHandle(), startDim, endDim)); } public static PtNDArray abs(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); } public static PtNDArray square(PtNDArray 
ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSquare(ndArray.getHandle())); } public static PtNDArray floor(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFloor(ndArray.getHandle())); } public static PtNDArray ceil(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCeil(ndArray.getHandle())); } public static PtNDArray round(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRound(ndArray.getHandle())); } public static PtNDArray trunc(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTrunc(ndArray.getHandle())); } public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) { PtNDArray minNd = (PtNDArray) ndArray.getManager().create(min); PtNDArray maxNd = (PtNDArray) ndArray.getManager().create(max); - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchClamp( ndArray.getHandle(), minNd.getHandle(), maxNd.getHandle())); } public static PtNDArray exp(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); } public static PtNDArray log(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); } public static PtNDArray log10(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog10(ndArray.getHandle())); } public static PtNDArray log2(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog2(ndArray.getHandle())); } public static PtNDArray sin(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); } public static PtNDArray cos(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); } public static PtNDArray tan(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); } public static PtNDArray asin(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchASin(ndArray.getHandle())); } public static PtNDArray acos(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAcos(ndArray.getHandle())); } public static PtNDArray atan(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } public static PtNDArray sqrt(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); } public static PtNDArray sinh(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSinh(ndArray.getHandle())); } public static PtNDArray cosh(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCosh(ndArray.getHandle())); } public static PtNDArray tanh(PtNDArray 
ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTanh(ndArray.getHandle())); } public static PtNDArray sigmoid(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSigmoid(ndArray.getHandle())); } public static PtNDArray all(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); } public static PtNDArray any(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); } public static PtNDArray none(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNone(ndArray.getHandle())); } public static PtNDArray eq(PtNDArray self, PtNDArray other) { - return new PtNDArrayImpl( + return newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchEq(self.getHandle(), other.getHandle())); } public static PtNDArray neq(PtNDArray self, PtNDArray other) { - return new PtNDArrayImpl( + return newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchNeq(self.getHandle(), other.getHandle())); } public static PtNDArray gt(PtNDArray self, PtNDArray other) { - return new PtNDArrayImpl( + return newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchGt(self.getHandle(), other.getHandle())); } public static PtNDArray gte(PtNDArray self, PtNDArray other) { - return new PtNDArrayImpl( + return newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchGte(self.getHandle(), other.getHandle())); } public static PtNDArray lt(PtNDArray self, PtNDArray other) { - return new PtNDArrayImpl( + return newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLt(self.getHandle(), other.getHandle())); } public static PtNDArray lte(PtNDArray self, PtNDArray other) { - return new PtNDArrayImpl( + return newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLte(self.getHandle(), other.getHandle())); } public static PtNDArray neg(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); } @@ -1097,12 +1099,12 @@ public static void negi(PtNDArray ndArray) { } public static PtNDArray isNaN(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsNaN(ndArray.getHandle())); } public static PtNDArray isInf(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsInf(ndArray.getHandle())); } @@ -1113,7 +1115,7 @@ public static PtNDArray randint( Shape size, DataType dataType, Device device) { - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchRandint( low, @@ -1127,7 +1129,7 @@ public static PtNDArray randint( public static PtNDArray randperm( PtNDManager manager, long n, DataType dataType, Device device) { - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchRandPerm( n, @@ -1144,7 +1146,7 @@ public static PtNDArray normal( Shape size, DataType dataType, Device device) { - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchNormal( mean, @@ -1163,7 +1165,7 @@ public static PtNDArray uniform( Shape size, DataType dataType, Device device) { - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.tensorUniform( low, @@ -1177,7 +1179,7 @@ public static 
PtNDArray uniform( public static PtNDArray eye( PtNDManager manager, int n, int m, DataType dataType, Device device, SparseFormat fmt) { - return new PtNDArrayImpl( + return newPtNDArray( manager, PyTorchLibrary.LIB.torchEye( n, @@ -1189,25 +1191,25 @@ public static PtNDArray eye( } public static PtNDArray erfinv(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } public static PtNDArray inverse(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); } public static PtNDArray interpolate( PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNInterpolate( ndArray.getHandle(), size, mode, alignCorners)); } public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias) { - return new PtNDArrayImpl( + return newPtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNLinear( input.getHandle(), @@ -1216,44 +1218,44 @@ public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias } public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse) { - return new PtNDArrayImpl( + return newPtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNEmbedding(input.getHandle(), weight.getHandle(), sparse)); } public static PtNDArray relu(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNRelu(ndArray.getHandle())); } public static PtNDArray softPlus(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftPlus(ndArray.getHandle())); } public static PtNDArray softSign(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftSign(ndArray.getHandle())); } public static PtNDArray leakyRelu(PtNDArray ndArray, double negativeSlope) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLeakyRelu(ndArray.getHandle(), negativeSlope)); } public static PtNDArray elu(PtNDArray ndArray, double alpha) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNElu(ndArray.getHandle(), alpha)); } public static PtNDArray selu(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSelu(ndArray.getHandle())); } public static PtNDArray gelu(PtNDArray ndArray) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNGelu(ndArray.getHandle())); } @@ -1265,7 +1267,7 @@ public static PtNDArray convolution( Shape padding, Shape dilation, int groups) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNConvNd( ndArray.getHandle(), @@ -1286,7 +1288,7 @@ public static PtNDArray batchNorm( boolean isTraining, double momentum, double eps) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNBatchNorm( ndArray.getHandle(), @@ -1301,7 +1303,7 @@ public static PtNDArray batchNorm( public static PtNDArray layerNorm( PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), 
PyTorchLibrary.LIB.torchNNLayerNorm( ndArray.getHandle(), @@ -1312,13 +1314,13 @@ public static PtNDArray layerNorm( } public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNNormalize(ndArray.getHandle(), p, dim, eps)); } public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNDropout(ndArray.getHandle(), prob, training)); } @@ -1351,7 +1353,7 @@ public static NDList rnn( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArrayImpl(manager, output)); + res.add(newPtNDArray(manager, output)); } return res; } @@ -1382,7 +1384,7 @@ public static NDList gru( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArrayImpl(manager, output)); + res.add(newPtNDArray(manager, output)); } return res; } @@ -1415,7 +1417,7 @@ public static NDList lstm( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArrayImpl(manager, output)); + res.add(newPtNDArray(manager, output)); } return res; } @@ -1427,7 +1429,7 @@ public static PtNDArray avgPool( Shape padding, boolean ceilMode, boolean countIncludePad) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAvgPool( ndArray.getHandle(), @@ -1440,7 +1442,7 @@ public static PtNDArray avgPool( public static PtNDArray maxPool( PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNMaxPool( ndArray.getHandle(), @@ -1451,14 +1453,14 @@ public static PtNDArray maxPool( } public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, Shape outputSize) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveMaxPool( ndArray.getHandle(), outputSize.getShape())); } public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, Shape outputSize) { - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveAvgPool( ndArray.getHandle(), outputSize.getShape())); @@ -1469,7 +1471,7 @@ public static PtNDArray lpPool( if (ndArray.getShape().dimension() - 2 == 3) { throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine"); } - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLpPool( ndArray.getHandle(), @@ -1534,7 +1536,7 @@ public static void attachGradient(PtNDArray ndArray, boolean requiresGrad) { public static PtNDArray detachGradient(PtNDArray ndArray) { // TODO: detached ndarray may not use the same manager for the attached one - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchDetachGrad(ndArray.getHandle())); } @@ -1543,7 +1545,7 @@ public static PtNDArray getGradient(PtNDArray ndArray) { if (pointer == NULL_PTR) { return null; } - return new PtNDArrayImpl(ndArray.getManager(), pointer); + return newPtNDArray(ndArray.getManager(), pointer); } public static void backward( @@ -1623,7 +1625,7 @@ public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) { String[] names = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle()); NDList list = new NDList(handles.length); for (int i = 0; i < handles.length; 
i++) { - PtNDArray array = new PtNDArrayImpl(manager, handles[i]); + PtNDArray array = newPtNDArray(manager, handles[i]); array.setName(names[i]); list.add(array); } @@ -1699,7 +1701,7 @@ public static int getLayout(PtNDArray array) { public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims) { long[] longAxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNorm(ndArray.getHandle(), ord, longAxes, keepDims)); } @@ -1708,7 +1710,7 @@ public static PtNDArray nonZeros(PtNDArray ndArray) { if (ndArray.isScalar()) { ndArray = (PtNDArray) ndArray.reshape(-1); } - return new PtNDArrayImpl( + return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNonZeros(ndArray.getHandle())); } } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java index 7b1bfc14aa5..fea3711b94b 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -34,28 +34,28 @@ private Main() {} public static void main(String[] args) throws IOException, TranslateException, InterruptedException { - PtNDArrayProxyMaker enh = new PtNDArrayProxyMaker(); - try (NDManager baseManager = NDManager.newBaseManager(); ) { + + try (NDManager baseManager = NDManager.newBaseManager(true); ) { try (NDManager subManager = baseManager.newSubManager()) { - NDArray a = enh.wrap(subManager.create(new float[]{1f})); - NDArray b = enh.wrap(subManager.create(new float[]{2f})); + NDArray a = subManager.create(new float[]{1f}); + NDArray b = subManager.create(new float[]{2f}); NDArray c = a.add(b); debugDumpFromSystemManager(); logger.info("reference exists ..."); - logger.info("weakHashMap size: {}", enh.mapSize()); + logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); a = null; b = null; c = null; logger.info("no reference exists, but likely not yet garbage collected ..."); - logger.info("weakHashMap size: {}", enh.mapSize()); + logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); System.gc(); // just for testing - do not use in production TimeUnit.SECONDS.sleep(1); logger.info("no reference exists, and likely garbage collected ..."); - logger.info("weakHashMap size: {}", enh.mapSize()); + logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); debugDumpFromSystemManager(); } debugDumpFromSystemManager(); From 396da5764cbfc700a5a4510228330d509aabe07f Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 15 Dec 2022 16:49:22 +0100 Subject: [PATCH 04/30] build without test, poc uses switch successfully --- api/src/main/java/ai/djl/engine/Engine.java | 14 +++-- .../java/ai/djl/ndarray/BaseNDManager.java | 27 ++++++--- .../main/java/ai/djl/ndarray/NDManager.java | 43 +++++++++++--- .../ndarray/gc/DynamicInvocationHandler.java | 2 +- .../ai/djl/ndarray/gc/NDArrayProxyMaker.java | 8 +-- .../main/java/ai/djl/util/NativeResource.java | 36 ++++++++++++ .../java/ai/djl/util/NativeResourceImpl.java | 24 ++------ .../passthrough/PassthroughNDManager.java | 18 ++++++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 2 +- .../djl/mxnet/engine/MxParameterServer.java | 2 +- .../java/ai/djl/pytorch/engine/PtEngine.java | 1 - .../java/ai/djl/pytorch/engine/PtNDArray.java | 1 - .../ai/djl/pytorch/engine/PtNDArrayImpl.java 
| 5 -- .../pytorch/engine/PtNDArrayProxyMaker.java | 2 +- .../ai/djl/pytorch/engine/PtNDManager.java | 15 +---- .../main/java/ai/djl/pytorch/jni/IValue.java | 4 +- .../java/ai/djl/pytorch/jni/JniUtils.java | 58 +++++++------------ .../ai/djl/pytorch/integration/gc/Main.java | 5 +- 18 files changed, 158 insertions(+), 109 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 34bd06973f6..314b7d3359c 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -287,14 +287,18 @@ public int getGpuCount() { * *
<p>
{@code NDManager} will inherit default {@link Device}. * + * @param useProxies whether to facade resources with a proxy * @return a new top-level {@code NDManager} */ + public abstract NDManager newBaseManager(boolean useProxies); - - - - public abstract NDManager newBaseManager( boolean useProxies); - + /** + * Creates a new top-level {@link NDManager}. + * + *
<p>
{@code NDManager} will inherit default {@link Device}. + * + * @return a new top-level {@code NDManager} + */ public abstract NDManager newBaseManager(); /** diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index fbecf084c11..d56aae3bf5f 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -58,6 +58,7 @@ public abstract class BaseNDManager implements NDManager { protected BaseNDManager(NDManager parent, Device device) { this(parent, device, false); } + protected BaseNDManager(NDManager parent, Device device, boolean useProxies) { this.parent = parent; this.device = device == null ? defaultDevice() : device; @@ -82,17 +83,18 @@ public final Device defaultDevice() { public NDArray create(String[] data, Charset charset, Shape shape) { throw new UnsupportedOperationException("Not supported!"); } - - public boolean isUseProxies() { - return useProxies; - } - /** {@inheritDoc} */ @Override public NDArray create(Shape shape, DataType dataType) { throw new UnsupportedOperationException("Not supported!"); } + /** {@inheritDoc} */ + @Override + public boolean isUseProxies() { + return useProxies; + } + /** {@inheritDoc} */ @Override public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) { @@ -310,11 +312,24 @@ public NDManager newSubManager() { return newSubManager(device, useProxies); } + /** {@inheritDoc} */ @Override public NDManager newSubManager(boolean useProxies) { return newSubManager(device, useProxies); } + /** {@inheritDoc} */ + @Override + public NDManager newSubManager(Device device) { + return newSubManager(device, useProxies); + } + + /** {@inheritDoc} */ + @Override + public NDManager newSubManager(Device device, boolean useProxies) { + throw new UnsupportedOperationException("Not supported!"); + } + /** {@inheritDoc} */ @Override public Device getDevice() { @@ -602,8 +617,6 @@ public static void copyBuffer(Buffer src, ByteBuffer target) { target.rewind(); } - public abstract NDManager newSubManager(Device device, boolean useProxies); - protected static final class TempResource { private NDResource resource; diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 58c61963f2a..6a20c70edf9 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -106,12 +106,12 @@ */ public interface NDManager extends AutoCloseable { - /** * Creates a new top-level {@code NDManager}. * *
<p>
{@code NDManager} will inherit default {@link Device}. * + * @param useProxies whether to facade {@link NDArray} behind a proxy * @return a new top-level {@code NDManager} */ static NDManager newBaseManager(boolean useProxies) { @@ -730,10 +730,6 @@ default NDArray decode(InputStream is) throws IOException { */ NDList load(Path path); - default NDArrayProxyMaker getProxyMaker() { - throw new UnsupportedOperationException("Not supported"); - } - /** * Loads the NDArrays saved to a file. * @@ -748,6 +744,15 @@ default NDList load(Path path, Device device) { return newSubManager(device).load(path); } + /** + * Returns the {@link NDArrayProxyMaker}. + * + * @return the {@link NDArrayProxyMaker} + */ + default NDArrayProxyMaker getProxyMaker() { + throw new UnsupportedOperationException("Not supported"); + } + /** * Sets the name for the NDManager. * @@ -762,6 +767,13 @@ default NDList load(Path path, Device device) { */ String getName(); + /** + * Returns useProxies. + * + * @return useProxies + */ + boolean isUseProxies(); + /** * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros. * @@ -1516,9 +1528,24 @@ default NDArray truncatedNormal( */ NDManager newSubManager(Device device); - default NDManager newSubManager(boolean useProxies) { - throw new UnsupportedOperationException("useProxies not supported here"); - } + /** + * Creates a child {@code NDManager} with specified boolean switch useProxies and will inherit + * default {@link Device} from this {@code NDManager}. + * + * @param useProxies the boolean switch to use proxies + * @return a child {@code NDManager} + */ + NDManager newSubManager(boolean useProxies); + + /** + * Creates a child {@code NDManager} with specified default {@link Device} and the boolean + * switch useProxies. + * + * @param device the default {@link Device} + * @param useProxies the boolean switch to use proxies + * @return a child {@code NDManager} + */ + NDManager newSubManager(Device device, boolean useProxies); /** * Returns the default {@link Device} of this {@code NDManager}. diff --git a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java index 5947ec102c7..d89f9298892 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java +++ b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java @@ -37,7 +37,7 @@ public class DynamicInvocationHandler implements InvocationHandler { * * @param uuid the uuid * @param map the map - * @param gcAttacher the ndArrayProxyMaker + * @param ndArrayProxyMaker the ndArrayProxyMaker */ public DynamicInvocationHandler( UUID uuid, WeakHashMapWrapper map, NDArrayProxyMaker ndArrayProxyMaker) { diff --git a/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java index e83a2a27f64..b6f6786a444 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java +++ b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java @@ -14,19 +14,15 @@ import ai.djl.ndarray.NDArray; -import java.lang.reflect.Proxy; -import java.util.UUID; - /** {@code PtNDArrayProxyMaker} creates a proxy facade. */ public interface NDArrayProxyMaker { - /** * Returns the size of the map. * * @return the size of the map */ - int mapSize(); + int mapSize(); /** * Wraps the {@link NDArray} in a proxy facade. 
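The accessors documented above expose the proxy bookkeeping, and the proof-of-concept Main shown earlier watches getProxyMaker().mapSize() drop once user code releases its references and a garbage collection runs. The plain-JDK sketch below illustrates the weak-reference mechanism that this kind of tracking builds on; it uses no DJL classes, and since System.gc() is only a hint, the final size is expected, not guaranteed, to be zero.

import java.util.UUID;
import java.util.WeakHashMap;
import java.util.concurrent.TimeUnit;

public final class WeakTrackingDemo {
    public static void main(String[] args) throws InterruptedException {
        WeakHashMap<Object, UUID> tracked = new WeakHashMap<>();
        Object proxy = new Object();
        tracked.put(proxy, UUID.randomUUID());
        System.out.println("size with live reference: " + tracked.size()); // 1

        proxy = null;              // drop the only strong reference to the key
        System.gc();               // a hint only; collection is not guaranteed
        TimeUnit.SECONDS.sleep(1);
        System.out.println("size after gc: " + tracked.size());            // likely 0
    }
}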
@@ -34,5 +30,5 @@ public interface NDArrayProxyMaker { * @param array the array to wrap * @return the wrapped array */ - NDArray wrap(NDArray array) ; + NDArray wrap(NDArray array); } diff --git a/api/src/main/java/ai/djl/util/NativeResource.java b/api/src/main/java/ai/djl/util/NativeResource.java index c72ba39e9ee..978554c7b4f 100644 --- a/api/src/main/java/ai/djl/util/NativeResource.java +++ b/api/src/main/java/ai/djl/util/NativeResource.java @@ -1,12 +1,48 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ package ai.djl.util; +import com.sun.jna.Pointer; + +/** + * {@code NativeResource} is an interface for {@link AutoCloseable} blocks of memory created in the + * different engines. + * + * @param the resource that could map to a native pointer or java object + */ public interface NativeResource extends AutoCloseable { + /** + * Gets the boolean that indicates whether this resource has been released. + * + * @return whether this resource has been released + */ boolean isReleased(); + /** + * Gets the {@link Pointer} to this resource. + * + * @return the {@link Pointer} to this resource + */ T getHandle(); + /** + * Gets the unique ID of this resource. + * + * @return the unique ID of this resource + */ String getUid(); + /** {@inheritDoc} */ @Override void close(); } diff --git a/api/src/main/java/ai/djl/util/NativeResourceImpl.java b/api/src/main/java/ai/djl/util/NativeResourceImpl.java index 0ccfa351336..4133c01c772 100644 --- a/api/src/main/java/ai/djl/util/NativeResourceImpl.java +++ b/api/src/main/java/ai/djl/util/NativeResourceImpl.java @@ -12,13 +12,11 @@ */ package ai.djl.util; -import com.sun.jna.Pointer; - import java.util.concurrent.atomic.AtomicReference; /** - * {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory created in - * the different engines. + * {@code NativeResourceImpl} is an internal class for {@link AutoCloseable} blocks of memory + * created in the different engines. * * @param the resource that could map to a native pointer or java object */ @@ -32,21 +30,13 @@ protected NativeResourceImpl(T handle) { uid = handle.toString(); } - /** - * Gets the boolean that indicates whether this resource has been released. - * - * @return whether this resource has been released - */ + /** {@inheritDoc} */ @Override public boolean isReleased() { return handle.get() == null; } - /** - * Gets the {@link Pointer} to this resource. - * - * @return the {@link Pointer} to this resource - */ + /** {@inheritDoc} */ @Override public T getHandle() { T reference = handle.get(); @@ -56,11 +46,7 @@ public T getHandle() { return reference; } - /** - * Gets the unique ID of this resource. 
- * - * @return the unique ID of this resource - */ + /** {@inheritDoc} */ @Override public final String getUid() { return uid; diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index bde8a89137a..eaaa710afee 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -85,6 +85,12 @@ public NDArray createCoo(Buffer data, long[][] indices, Shape shape) { throw new UnsupportedOperationException(UNSUPPORTED); } + /** {@inheritDoc} */ + @Override + public boolean isUseProxies() { + return false; + } + /** {@inheritDoc} */ @Override public NDList load(Path path) { @@ -243,6 +249,18 @@ public NDManager newSubManager(Device device) { return this; } + /** {@inheritDoc} */ + @Override + public NDManager newSubManager(boolean useProxies) { + return this; + } + + /** {@inheritDoc} */ + @Override + public NDManager newSubManager(Device device, boolean useProxies) { + return this; + } + /** {@inheritDoc} */ @Override public Device getDevice() { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index d5e5a7ba5a8..cc85079afa8 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -23,8 +23,8 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; - import ai.djl.util.NativeResourceImpl; + import com.sun.jna.Native; import com.sun.jna.Pointer; diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java index bfb51646fc8..6da9c104e53 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java @@ -20,8 +20,8 @@ import ai.djl.ndarray.NDManager; import ai.djl.training.ParameterServer; import ai.djl.training.optimizer.Optimizer; - import ai.djl.util.NativeResourceImpl; + import com.sun.jna.Pointer; import java.util.Arrays; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index 1a0152702eb..cafdbcc1d4a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -128,7 +128,6 @@ public Model newModel(String name, Device device) { return new PtModel(name, device); } - @Override public NDManager newBaseManager(boolean useProxies) { return PtNDManager.getSystemManager().newSubManager(useProxies); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index a67aabe5950..bb7ccfae4e1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -8,7 +8,6 @@ import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; import ai.djl.util.NativeResource; -import 
ai.djl.util.NativeResourceImpl; import java.nio.Buffer; import java.nio.ByteBuffer; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index b5489144da6..57ae4637d25 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -52,7 +52,6 @@ public class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray @SuppressWarnings("PMD.UnusedPrivateField") private ByteBuffer dataRef; - public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { PtNDArray instance = new PtNDArrayImpl(manager, handle); if (manager.isUseProxies()) { @@ -77,7 +76,6 @@ public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape s return instance; } - /** * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} * instead). @@ -124,9 +122,6 @@ private PtNDArrayImpl(PtNDManager manager, String[] strs, Shape shape) { this.dataType = DataType.STRING; } - - - /** {@inheritDoc} */ @Override public PtNDManager getManager() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java index 24426ffa6f0..b4e91ea0024 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java @@ -47,7 +47,7 @@ public PtNDArray wrap(NDArray array) { return (PtNDArray) Proxy.newProxyInstance( Thread.currentThread().getContextClassLoader(), - new Class[] {PtNDArray.class}, + new Class[] {PtNDArray.class}, handler); } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index fabc4e2ab5e..24856a9d49b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -12,6 +12,8 @@ */ package ai.djl.pytorch.engine; +import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; + import ai.djl.Device; import ai.djl.engine.Engine; import ai.djl.ndarray.BaseNDManager; @@ -27,10 +29,8 @@ import java.nio.ByteOrder; import java.nio.charset.Charset; -import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; - /** {@code PtNDManager} is the PyTorch implementation of {@link NDManager}. 
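PtNDArrayProxyMaker.wrap returns a JDK dynamic proxy declared against the PtNDArray interface, created with the thread context class loader, which lines up with the earlier split of NativeResource into an interface plus NativeResourceImpl: java.lang.reflect.Proxy can only implement interfaces, never extend classes. The minimal standalone demo below shows that constraint with illustrative Handle types rather than DJL ones.

import java.lang.reflect.Proxy;

interface Handle {
    String getUid();
}

final class HandleImpl implements Handle {
    @Override
    public String getUid() {
        return "uid-42";
    }
}

public final class InterfaceProxyDemo {
    public static void main(String[] args) {
        Handle real = new HandleImpl();
        // The second argument accepts interfaces only; passing a class would
        // throw IllegalArgumentException at runtime.
        Handle proxy =
                (Handle)
                        Proxy.newProxyInstance(
                                Thread.currentThread().getContextClassLoader(),
                                new Class<?>[] {Handle.class},
                                (p, m, a) -> m.invoke(real, a));
        System.out.println(proxy.getUid()); // delegates to the real object
    }
}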
*/ -public class PtNDManager extends BaseNDManager { +public class PtNDManager extends BaseNDManager { private static final PtNDManager SYSTEM_MANAGER = new SystemManager(); @@ -40,19 +40,14 @@ private PtNDManager(NDManager parent, Device device, boolean useProxies) { super(parent, device, useProxies); } - static PtNDManager getSystemManager() { return SYSTEM_MANAGER; } - - public PtNDArrayProxyMaker getProxyMaker() { return getSystemManager().getProxyMaker(); } - - /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { @@ -185,7 +180,6 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy return JniUtils.normal(this, loc, scale, shape, dataType, device); } - /** {@inheritDoc} */ @Override public PtNDManager newSubManager(Device device) { @@ -194,7 +188,6 @@ public PtNDManager newSubManager(Device device) { return manager; } - @Override public NDManager newSubManager(Device device, boolean useProxies) { PtNDManager manager = new PtNDManager(this, device, useProxies); @@ -221,8 +214,6 @@ private static final class SystemManager extends PtNDManager implements SystemND this.proxyMaker = new PtNDArrayProxyMaker(); } - - @Override public PtNDArrayProxyMaker getProxyMaker() { return this.proxyMaker; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java index 3094c6e7006..a68c1be9359 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java @@ -12,6 +12,8 @@ */ package ai.djl.pytorch.jni; +import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; + import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -24,8 +26,6 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; - /** * A class represent a PyTorch {@code IValue} data. * diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 1ff735d00fc..7eb39fbc5dd 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -12,6 +12,8 @@ */ package ai.djl.pytorch.jni; +import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; + import ai.djl.Device; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDIndex; @@ -30,7 +32,6 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.engine.PtDeviceType; import ai.djl.pytorch.engine.PtNDArray; -import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; import ai.djl.pytorch.engine.PtSymbolBlock; @@ -50,8 +51,6 @@ import java.util.ListIterator; import java.util.Set; -import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; - /** * A class containing utilities to interact with the PyTorch Engine's Java Native Interface (JNI) * layer. 
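The JniUtils hunks below only re-wrap existing newPtNDArray(...) calls; the substantive change in this file is the relocated static import above. For orientation, a minimal sketch of how the proxy-enabled manager introduced by this patch series is meant to be driven from user code (NDManager.newBaseManager(true) is the entry point added by these patches, the rest is standard DJL API; the values are illustrative only):

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;

    public final class ProxyUsageSketch {
        public static void main(String[] args) {
            // a top-level manager that hands out NDArrays wrapped in dynamic proxies
            try (NDManager baseManager = NDManager.newBaseManager(true);
                    NDManager subManager = baseManager.newSubManager()) {
                NDArray a = subManager.create(new float[] {1f});
                NDArray b = subManager.create(new float[] {2f});
                NDArray c = a.add(b); // computed through the proxied arrays
                System.out.println(c.toFloatArray()[0]); // 3.0
                // once a, b and c become unreachable, the WeakHashMapWrapper behind the
                // proxies is the intended hook for releasing the native tensors
            }
        }
    }
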
@@ -237,7 +236,7 @@ public static PtNDArray full( } public static PtNDArray zerosLike( - PtNDArray array, DataType dType, Device device, SparseFormat fmt) { + PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return newPtNDArray( array.getManager(), @@ -250,7 +249,7 @@ public static PtNDArray zerosLike( } public static PtNDArray onesLike( - PtNDArray array, DataType dType, Device device, SparseFormat fmt) { + PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); return newPtNDArray( array.getManager(), @@ -581,8 +580,7 @@ public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager m return newPtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0])); } - return newPtNDArray( - manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); + return newPtNDArray(manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); } public static PtNDArray clone(PtNDArray ndArray) { @@ -798,8 +796,7 @@ public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray max(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); } public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) { @@ -815,8 +812,7 @@ public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray min(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); } public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) { @@ -844,8 +840,7 @@ public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes) { } public static PtNDArray sum(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); } public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) { @@ -928,8 +923,7 @@ public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) { } public static PtNDArray abs(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); } public static PtNDArray square(PtNDArray ndArray) { @@ -967,13 +961,11 @@ public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) { } public static PtNDArray exp(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); } public static PtNDArray log(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); } public static PtNDArray log10(PtNDArray ndArray) { @@ -987,18 +979,15 @@ public static PtNDArray log2(PtNDArray ndArray) { } public static PtNDArray sin(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), 
PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); } public static PtNDArray cos(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); } public static PtNDArray tan(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); } public static PtNDArray asin(PtNDArray ndArray) { @@ -1042,13 +1031,11 @@ public static PtNDArray sigmoid(PtNDArray ndArray) { } public static PtNDArray all(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); } public static PtNDArray any(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); } public static PtNDArray none(PtNDArray ndArray) { @@ -1090,8 +1077,7 @@ public static PtNDArray lte(PtNDArray self, PtNDArray other) { } public static PtNDArray neg(PtNDArray ndArray) { - return newPtNDArray( - ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); + return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); } public static void negi(PtNDArray ndArray) { @@ -1201,7 +1187,7 @@ public static PtNDArray inverse(PtNDArray ndArray) { } public static PtNDArray interpolate( - PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { + PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNInterpolate( @@ -1302,7 +1288,7 @@ public static PtNDArray batchNorm( } public static PtNDArray layerNorm( - PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { + PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLayerNorm( @@ -1441,7 +1427,7 @@ public static PtNDArray avgPool( } public static PtNDArray maxPool( - PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { + PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { return newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNMaxPool( @@ -1467,7 +1453,7 @@ public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, Shape outputSize) { } public static PtNDArray lpPool( - PtNDArray ndArray, double normType, Shape kernelSize, Shape stride, boolean ceilMode) { + PtNDArray ndArray, double normType, Shape kernelSize, Shape stride, boolean ceilMode) { if (ndArray.getShape().dimension() - 2 == 3) { throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine"); } @@ -1549,7 +1535,7 @@ public static PtNDArray getGradient(PtNDArray ndArray) { } public static void backward( - PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph) { + PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph) { PyTorchLibrary.LIB.torchBackward( ndArray.getHandle(), gradNd.getHandle(), keepGraph, createGraph); } diff --git 
a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java index fea3711b94b..2b1715495b0 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -16,7 +16,6 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; -import ai.djl.pytorch.engine.PtNDArrayProxyMaker; import ai.djl.translate.TranslateException; import org.slf4j.Logger; @@ -38,8 +37,8 @@ public static void main(String[] args) try (NDManager baseManager = NDManager.newBaseManager(true); ) { try (NDManager subManager = baseManager.newSubManager()) { - NDArray a = subManager.create(new float[]{1f}); - NDArray b = subManager.create(new float[]{2f}); + NDArray a = subManager.create(new float[] {1f}); + NDArray b = subManager.create(new float[] {2f}); NDArray c = a.add(b); debugDumpFromSystemManager(); From f49c4b470dd3f72909f4ada3ca1fe5014a95b07f Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 15 Dec 2022 18:01:20 +0100 Subject: [PATCH 05/30] build without test and publishesToMavenLocal --- api/src/main/java/ai/djl/engine/Engine.java | 4 +- .../java/ai/djl/pytorch/engine/PtEngine.java | 1 + .../java/ai/djl/pytorch/engine/PtNDArray.java | 197 +++++++++++ .../ai/djl/pytorch/engine/PtNDArrayImpl.java | 82 +++-- .../pytorch/engine/PtNDArrayProxyMaker.java | 8 +- .../ai/djl/pytorch/engine/PtNDManager.java | 7 +- .../main/java/ai/djl/pytorch/jni/IValue.java | 6 +- .../java/ai/djl/pytorch/jni/JniUtils.java | 314 +++++++++--------- tools/conf/findbugs-exclude.xml | 2 +- 9 files changed, 431 insertions(+), 190 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 314b7d3359c..099f3b5493f 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -290,7 +290,9 @@ public int getGpuCount() { * @param useProxies whether to facade resources with a proxy * @return a new top-level {@code NDManager} */ - public abstract NDManager newBaseManager(boolean useProxies); + public NDManager newBaseManager(boolean useProxies) { + throw new UnsupportedOperationException("Not implemented"); + } /** * Creates a new top-level {@link NDManager}. diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index cafdbcc1d4a..7fa2636d60b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -128,6 +128,7 @@ public Model newModel(String name, Device device) { return new PtModel(name, device); } + /** {@inheritDoc} */ @Override public NDManager newBaseManager(boolean useProxies) { return PtNDManager.getSystemManager().newSubManager(useProxies); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index bb7ccfae4e1..146d88a5fad 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1,3 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ package ai.djl.pytorch.engine; import ai.djl.Device; @@ -13,556 +25,741 @@ import java.nio.ByteBuffer; import java.nio.charset.Charset; +/** {@code PtNDArray} is the interface for the PyTorch implementation of {@link NDArray}. */ public interface PtNDArray extends NativeResource, NDArray { + /** {@inheritDoc} */ @Override PtNDManager getManager(); + /** {@inheritDoc} */ @Override String getName(); + /** {@inheritDoc} */ @Override void setName(String name); + /** {@inheritDoc} */ @Override DataType getDataType(); + /** {@inheritDoc} */ @Override Device getDevice(); + /** {@inheritDoc} */ @Override Shape getShape(); + /** {@inheritDoc} */ @Override SparseFormat getSparseFormat(); + /** {@inheritDoc} */ @Override PtNDArray toDevice(Device device, boolean copy); + /** {@inheritDoc} */ @Override PtNDArray toType(DataType dataType, boolean copy); + /** {@inheritDoc} */ @Override void setRequiresGradient(boolean requiresGrad); + /** {@inheritDoc} */ @Override PtNDArray getGradient(); + /** {@inheritDoc} */ @Override boolean hasGradient(); + /** {@inheritDoc} */ @Override NDArray stopGradient(); + /** {@inheritDoc} */ @Override ByteBuffer toByteBuffer(); + /** {@inheritDoc} */ @Override String[] toStringArray(Charset charset); + /** {@inheritDoc} */ @Override void set(Buffer buffer); + /** {@inheritDoc} */ @Override NDArray get(NDManager manager, long... 
indices); + /** {@inheritDoc} */ @Override NDArray gather(NDArray index, int axis); + /** {@inheritDoc} */ @Override NDArray gatherNd(NDArray index); + /** {@inheritDoc} */ @Override NDArray take(NDManager manager, NDArray index); + /** {@inheritDoc} */ @Override NDArray put(NDArray index, NDArray data); + /** {@inheritDoc} */ @Override void copyTo(NDArray array); + /** {@inheritDoc} */ @Override void attach(NDManager manager); + /** {@inheritDoc} */ @Override void returnResource(NDManager manager); + /** {@inheritDoc} */ @Override void tempAttach(NDManager manager); + /** {@inheritDoc} */ @Override void detach(); + /** {@inheritDoc} */ @Override NDArray duplicate(); + /** {@inheritDoc} */ @Override PtNDArray booleanMask(NDArray index, int axis); + /** {@inheritDoc} */ @Override NDArray sequenceMask(NDArray sequenceLength, float value); + /** {@inheritDoc} */ @Override NDArray sequenceMask(NDArray sequenceLength); + /** {@inheritDoc} */ @Override boolean contentEquals(Number number); + /** {@inheritDoc} */ @Override boolean contentEquals(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray eq(Number n); + /** {@inheritDoc} */ @Override PtNDArray eq(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray neq(Number n); + /** {@inheritDoc} */ @Override PtNDArray neq(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray gt(Number n); + /** {@inheritDoc} */ @Override PtNDArray gt(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray gte(Number n); + /** {@inheritDoc} */ @Override PtNDArray gte(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray lt(Number n); + /** {@inheritDoc} */ @Override PtNDArray lt(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray lte(Number n); + /** {@inheritDoc} */ @Override PtNDArray lte(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray add(Number n); + /** {@inheritDoc} */ @Override PtNDArray add(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray sub(Number n); + /** {@inheritDoc} */ @Override PtNDArray sub(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray mul(Number n); + /** {@inheritDoc} */ @Override PtNDArray mul(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray div(Number n); + /** {@inheritDoc} */ @Override PtNDArray div(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray mod(Number n); + /** {@inheritDoc} */ @Override PtNDArray mod(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray pow(Number n); + /** {@inheritDoc} */ @Override PtNDArray pow(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray addi(Number n); + /** {@inheritDoc} */ @Override PtNDArray addi(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray subi(Number n); + /** {@inheritDoc} */ @Override PtNDArray subi(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray muli(Number n); + /** {@inheritDoc} */ @Override PtNDArray muli(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray divi(Number n); + /** {@inheritDoc} */ @Override PtNDArray divi(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray modi(Number n); + /** {@inheritDoc} */ @Override PtNDArray modi(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray powi(Number n); + /** {@inheritDoc} */ @Override PtNDArray powi(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray sign(); + /** {@inheritDoc} */ @Override PtNDArray signi(); + /** {@inheritDoc} */ @Override PtNDArray maximum(Number n); + /** {@inheritDoc} */ @Override PtNDArray maximum(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray minimum(Number n); + /** 
{@inheritDoc} */ @Override PtNDArray minimum(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray all(); + /** {@inheritDoc} */ @Override PtNDArray any(); + /** {@inheritDoc} */ @Override PtNDArray none(); + /** {@inheritDoc} */ @Override PtNDArray neg(); + /** {@inheritDoc} */ @Override PtNDArray negi(); + /** {@inheritDoc} */ @Override PtNDArray abs(); + /** {@inheritDoc} */ @Override PtNDArray square(); + /** {@inheritDoc} */ @Override NDArray sqrt(); + /** {@inheritDoc} */ @Override PtNDArray cbrt(); + /** {@inheritDoc} */ @Override PtNDArray floor(); + /** {@inheritDoc} */ @Override PtNDArray ceil(); + /** {@inheritDoc} */ @Override PtNDArray round(); + /** {@inheritDoc} */ @Override PtNDArray trunc(); + /** {@inheritDoc} */ @Override PtNDArray exp(); + /** {@inheritDoc} */ @Override NDArray gammaln(); + /** {@inheritDoc} */ @Override PtNDArray log(); + /** {@inheritDoc} */ @Override PtNDArray log10(); + /** {@inheritDoc} */ @Override PtNDArray log2(); + /** {@inheritDoc} */ @Override PtNDArray sin(); + /** {@inheritDoc} */ @Override PtNDArray cos(); + /** {@inheritDoc} */ @Override PtNDArray tan(); + /** {@inheritDoc} */ @Override PtNDArray asin(); + /** {@inheritDoc} */ @Override PtNDArray acos(); + /** {@inheritDoc} */ @Override PtNDArray atan(); + /** {@inheritDoc} */ @Override PtNDArray sinh(); + /** {@inheritDoc} */ @Override PtNDArray cosh(); + /** {@inheritDoc} */ @Override PtNDArray tanh(); + /** {@inheritDoc} */ @Override PtNDArray asinh(); + /** {@inheritDoc} */ @Override PtNDArray acosh(); + /** {@inheritDoc} */ @Override PtNDArray atanh(); + /** {@inheritDoc} */ @Override PtNDArray toDegrees(); + /** {@inheritDoc} */ @Override PtNDArray toRadians(); + /** {@inheritDoc} */ @Override PtNDArray max(); + /** {@inheritDoc} */ @Override PtNDArray max(int[] axes, boolean keepDims); + /** {@inheritDoc} */ @Override PtNDArray min(); + /** {@inheritDoc} */ @Override PtNDArray min(int[] axes, boolean keepDims); + /** {@inheritDoc} */ @Override PtNDArray sum(); + /** {@inheritDoc} */ @Override PtNDArray sum(int[] axes, boolean keepDims); + /** {@inheritDoc} */ @Override NDArray cumProd(int axis); + /** {@inheritDoc} */ @Override NDArray cumProd(int axis, DataType dataType); + /** {@inheritDoc} */ @Override PtNDArray prod(); + /** {@inheritDoc} */ @Override PtNDArray prod(int[] axes, boolean keepDims); + /** {@inheritDoc} */ @Override PtNDArray mean(); + /** {@inheritDoc} */ @Override PtNDArray mean(int[] axes, boolean keepDims); + /** {@inheritDoc} */ @Override PtNDArray normalize(double p, long dim, double eps); + /** {@inheritDoc} */ @Override PtNDArray rotate90(int times, int[] axes); + /** {@inheritDoc} */ @Override PtNDArray trace(int offset, int axis1, int axis2); + /** {@inheritDoc} */ @Override NDList split(long sections, int axis); + /** {@inheritDoc} */ @Override NDList split(long[] indices, int axis); + /** {@inheritDoc} */ @Override PtNDArray flatten(); + /** {@inheritDoc} */ @Override NDArray flatten(int startDim, int endDim); + /** {@inheritDoc} */ @Override PtNDArray reshape(Shape shape); + /** {@inheritDoc} */ @Override PtNDArray expandDims(int axis); + /** {@inheritDoc} */ @Override PtNDArray squeeze(); + /** {@inheritDoc} */ @Override PtNDArray squeeze(int axis); + /** {@inheritDoc} */ @Override PtNDArray squeeze(int[] axes); + /** {@inheritDoc} */ @Override PtNDArray logicalAnd(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray logicalOr(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray logicalXor(NDArray other); + /** {@inheritDoc} */ 
@Override PtNDArray logicalNot(); + /** {@inheritDoc} */ @Override PtNDArray argSort(int axis, boolean ascending); + /** {@inheritDoc} */ @Override PtNDArray sort(); + /** {@inheritDoc} */ @Override PtNDArray sort(int axis); + /** {@inheritDoc} */ @Override PtNDArray softmax(int axis); + /** {@inheritDoc} */ @Override PtNDArray logSoftmax(int axis); + /** {@inheritDoc} */ @Override PtNDArray cumSum(); + /** {@inheritDoc} */ @Override PtNDArray cumSum(int axis); + /** {@inheritDoc} */ @Override void intern(NDArray replaced); + /** {@inheritDoc} */ @Override PtNDArray isInfinite(); + /** {@inheritDoc} */ @Override PtNDArray isNaN(); + /** {@inheritDoc} */ @Override PtNDArray tile(long repeats); + /** {@inheritDoc} */ @Override PtNDArray tile(int axis, long repeats); + /** {@inheritDoc} */ @Override PtNDArray tile(long[] repeats); + /** {@inheritDoc} */ @Override PtNDArray tile(Shape desiredShape); + /** {@inheritDoc} */ @Override PtNDArray repeat(long repeats); + /** {@inheritDoc} */ @Override PtNDArray repeat(int axis, long repeats); + /** {@inheritDoc} */ @Override PtNDArray repeat(long[] repeats); + /** {@inheritDoc} */ @Override PtNDArray repeat(Shape desiredShape); + /** {@inheritDoc} */ @Override PtNDArray dot(NDArray other); + /** {@inheritDoc} */ @Override NDArray matMul(NDArray other); + /** {@inheritDoc} */ @Override PtNDArray clip(Number min, Number max); + /** {@inheritDoc} */ @Override PtNDArray swapAxes(int axis1, int axis2); + /** {@inheritDoc} */ @Override NDArray flip(int... axes); + /** {@inheritDoc} */ @Override PtNDArray transpose(); + /** {@inheritDoc} */ @Override PtNDArray transpose(int... axes); + /** {@inheritDoc} */ @Override PtNDArray broadcast(Shape shape); + /** {@inheritDoc} */ @Override PtNDArray argMax(); + /** {@inheritDoc} */ @Override PtNDArray argMax(int axis); + /** {@inheritDoc} */ @Override PtNDArray argMin(); + /** {@inheritDoc} */ @Override PtNDArray argMin(int axis); + /** {@inheritDoc} */ @Override PtNDArray percentile(Number percentile); + /** {@inheritDoc} */ @Override PtNDArray percentile(Number percentile, int[] axes); + /** {@inheritDoc} */ @Override PtNDArray median(); + /** {@inheritDoc} */ @Override PtNDArray median(int[] axes); + /** {@inheritDoc} */ @Override PtNDArray toDense(); + /** {@inheritDoc} */ @Override PtNDArray toSparse(SparseFormat fmt); + /** {@inheritDoc} */ @Override PtNDArray nonzero(); + /** {@inheritDoc} */ @Override PtNDArray erfinv(); + /** {@inheritDoc} */ @Override PtNDArray inverse(); + /** {@inheritDoc} */ @Override NDArray norm(boolean keepDims); + /** {@inheritDoc} */ @Override NDArray norm(int order, int[] axes, boolean keepDims); + /** {@inheritDoc} */ @Override NDArray oneHot(int depth); + /** {@inheritDoc} */ @Override NDArray oneHot(int depth, DataType dataType); + /** {@inheritDoc} */ @Override NDArray oneHot(int depth, float onValue, float offValue, DataType dataType); + /** {@inheritDoc} */ @Override NDArray batchDot(NDArray other); + /** {@inheritDoc} */ @Override PtNDArrayEx getNDArrayInternal(); + /** {@inheritDoc} */ @Override String toString(); + /** {@inheritDoc} */ @Override boolean equals(Object obj); + /** {@inheritDoc} */ @Override int hashCode(); + /** {@inheritDoc} */ @Override void close(); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index 57ae4637d25..1e4d065b85a 100644 --- 
a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java
+++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java
@@ -34,8 +34,8 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
-/** {@code PtNDArray} is the PyTorch implementation of {@link NDArray}. */
-public class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray {
+/** {@code PtNDArrayImpl} is the PyTorch implementation of {@link NDArray}. */
+public final class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray {
 
     private String name;
     private Device device;
@@ -52,30 +52,6 @@ public class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray
     @SuppressWarnings("PMD.UnusedPrivateField")
     private ByteBuffer dataRef;
 
-    public static PtNDArray newPtNDArray(PtNDManager manager, long handle) {
-        PtNDArray instance = new PtNDArrayImpl(manager, handle);
-        if (manager.isUseProxies()) {
-            instance = manager.getProxyMaker().wrap(instance);
-        }
-        return instance;
-    }
-
-    public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffer data) {
-        PtNDArray instance = new PtNDArrayImpl(manager, handle, data);
-        if (manager.isUseProxies()) {
-            instance = manager.getProxyMaker().wrap(instance);
-        }
-        return instance;
-    }
-
-    public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape shape) {
-        PtNDArray instance = new PtNDArrayImpl(manager, strs, shape);
-        if (manager.isUseProxies()) {
-            instance = manager.getProxyMaker().wrap(instance);
-        }
-        return instance;
-    }
-
     /**
      * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager}
      * instead).
@@ -1582,4 +1558,58 @@ public void close() {
         manager.detachInternal(getUid());
         dataRef = null;
     }
+
+    /**
+     * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager}
+     * instead). Depending on the switch {@code useProxies}, the returned {@code NDArray} is
+     * either a proxy or a direct instance.
+     *
+     * @param manager the manager to attach the new array to
+     * @param handle the pointer to the native PyTorch memory
+     * @return the new {@code NDArray}
+     */
+    public static PtNDArray newPtNDArray(PtNDManager manager, long handle) {
+        PtNDArray instance = new PtNDArrayImpl(manager, handle);
+        if (manager.isUseProxies()) {
+            instance = manager.getProxyMaker().wrap(instance);
+        }
+        return instance;
+    }
+
+    /**
+     * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager}
+     * instead) with the data that is held on the Java side. Depending on the switch
+     * {@code useProxies}, the returned {@code NDArray} is either a proxy or a direct instance.
+     *
+     * @param manager the manager to attach the new array to
+     * @param handle the pointer to the native PyTorch memory
+     * @param data the direct buffer of the data
+     * @return the new {@code NDArray}
+     */
+    public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffer data) {
+        PtNDArray instance = new PtNDArrayImpl(manager, handle, data);
+        if (manager.isUseProxies()) {
+            instance = manager.getProxyMaker().wrap(instance);
+        }
+        return instance;
+    }
+
+    /**
+     * Constructs a PyTorch {@code NDArray} that holds a string array, using a dummy native handle
+     * (internal. Use {@link NDManager} instead) with the data that is held on the Java side.
+     * Depending on the switch {@code useProxies}, the returned {@code NDArray} is either a proxy
+     * or a direct instance.
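+     *
+     * <p>Illustrative sketch only; {@code manager} stands for any {@link PtNDManager} and the
+     * values are placeholders:
+     *
+     * <pre>{@code
+     * PtNDArray arr = PtNDArrayImpl.newPtNDArray(manager, new String[] {"a", "b"}, new Shape(2));
+     * // when manager.isUseProxies() is true, arr is a java.lang.reflect.Proxy implementing
+     * // PtNDArray and registered by the PtNDArrayProxyMaker; otherwise it is the PtNDArrayImpl itself
+     * }</pre>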
+ * + * @param manager the manager to attach the new array to + * @param strs the string array + * @param shape the {@link Shape} of the {@link NDArray} + * @return the new {@code NDArray} + */ + public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape shape) { + PtNDArray instance = new PtNDArrayImpl(manager, strs, shape); + if (manager.isUseProxies()) { + instance = manager.getProxyMaker().wrap(instance); + } + return instance; + } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java index b4e91ea0024..3e5253964a5 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java @@ -25,11 +25,8 @@ public class PtNDArrayProxyMaker implements NDArrayProxyMaker { WeakHashMapWrapper map = new WeakHashMapWrapper<>(); - /** - * Returns the size of the map. - * - * @return the size of the map - */ + /** {@inheritDoc} */ + @Override public int mapSize() { return map.size(); } @@ -40,6 +37,7 @@ public int mapSize() { * @param array the array to wrap * @return the wrapped array */ + @Override public PtNDArray wrap(NDArray array) { UUID uuid = UUID.randomUUID(); map.put(uuid, array); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 24856a9d49b..1f1335c490a 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -12,8 +12,6 @@ */ package ai.djl.pytorch.engine; -import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; - import ai.djl.Device; import ai.djl.engine.Engine; import ai.djl.ndarray.BaseNDManager; @@ -44,6 +42,8 @@ static PtNDManager getSystemManager() { return SYSTEM_MANAGER; } + /** {@inheritDoc} */ + @Override public PtNDArrayProxyMaker getProxyMaker() { return getSystemManager().getProxyMaker(); } @@ -89,7 +89,7 @@ public PtNDArray create(Buffer data, Shape shape, DataType dataType) { /** {@inheritDoc} */ @Override public NDArray create(String[] data, Charset charset, Shape shape) { - return newPtNDArray(this, data, shape); + return PtNDArrayImpl.newPtNDArray(this, data, shape); } /** {@inheritDoc} */ @@ -188,6 +188,7 @@ public PtNDManager newSubManager(Device device) { return manager; } + /** {@inheritDoc} */ @Override public NDManager newSubManager(Device device, boolean useProxies) { PtNDManager manager = new PtNDManager(this, device, useProxies); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java index a68c1be9359..48b5192e24b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java @@ -12,8 +12,6 @@ */ package ai.djl.pytorch.jni; -import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; - import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; @@ -395,7 +393,7 @@ public double[] toDoubleArray() { * @return the NDArray value of this IValue */ public PtNDArray toTensor(PtNDManager manager) { - return newPtNDArray(manager, 
PyTorchLibrary.LIB.iValueToTensor(getHandle())); + return PtNDArrayImpl.newPtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); } /** @@ -408,7 +406,7 @@ public PtNDArray[] toTensorArray(PtNDManager manager) { long[] handles = PyTorchLibrary.LIB.iValueToTensorList(getHandle()); PtNDArray[] ret = new PtNDArrayImpl[handles.length]; for (int i = 0; i < ret.length; ++i) { - ret[i] = newPtNDArray(manager, handles[i]); + ret[i] = PtNDArrayImpl.newPtNDArray(manager, handles[i]); } return ret; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 7eb39fbc5dd..e068774f8df 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -12,8 +12,6 @@ */ package ai.djl.pytorch.jni; -import static ai.djl.pytorch.engine.PtNDArrayImpl.newPtNDArray; - import ai.djl.Device; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDIndex; @@ -32,6 +30,7 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.engine.PtDeviceType; import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; import ai.djl.pytorch.engine.PtSymbolBlock; @@ -172,15 +171,15 @@ public static PtNDArray createNdFromByteBuffer( if (layout == 1 || layout == 2 || device.isGpu()) { // MKLDNN & COO & GPU device will explicitly make a copy in native code // so we don't want to hold a reference on Java side - return newPtNDArray(manager, handle); + return PtNDArrayImpl.newPtNDArray(manager, handle); } - return newPtNDArray(manager, handle, data); + return PtNDArrayImpl.newPtNDArray(manager, handle, data); } public static PtNDArray createEmptyNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchEmpty( shape.getShape(), @@ -193,7 +192,7 @@ public static PtNDArray createEmptyNdArray( public static PtNDArray createZerosNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchZeros( shape.getShape(), @@ -206,7 +205,7 @@ public static PtNDArray createZerosNdArray( public static PtNDArray createOnesNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchOnes( shape.getShape(), @@ -224,7 +223,7 @@ public static PtNDArray full( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchFull( shape.getShape(), @@ -238,7 +237,7 @@ public static PtNDArray full( public static PtNDArray zerosLike( PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( array.getManager(), PyTorchLibrary.LIB.torchZerosLike( array.getHandle(), @@ -251,7 +250,7 @@ public static PtNDArray zerosLike( public static PtNDArray onesLike( PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = 
layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( array.getManager(), PyTorchLibrary.LIB.torchOnesLike( array.getHandle(), @@ -270,7 +269,7 @@ public static PtNDArray arange( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchArange( start, @@ -291,7 +290,7 @@ public static PtNDArray linspace( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchLinspace( start, @@ -304,7 +303,7 @@ public static PtNDArray linspace( } public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, Shape shape) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( values.getManager(), PyTorchLibrary.LIB.torchSparseCoo( shape.getShape(), indices.getHandle(), values.getHandle(), false)); @@ -317,7 +316,7 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) if (!device.equals(manager.getDevice())) { manager = manager.newSubManager(device); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchTo( ndArray.getHandle(), @@ -326,23 +325,23 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) } public static PtNDArray toSparse(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToSparse(ndArray.getHandle())); } public static PtNDArray toDense(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToDense(ndArray.getHandle())); } public static PtNDArray broadcast(PtNDArray ndArray, Shape shape) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchExpand(ndArray.getHandle(), shape.getShape())); } public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSlice(ndArray.getHandle(), dim, start, stop, step)); } @@ -353,7 +352,7 @@ public static PtNDArray index( long[] maxIndices, long[] stepIndices, PtNDManager manager) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchIndex( ndArray.getHandle(), minIndices, maxIndices, stepIndices)); @@ -414,7 +413,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle)); } @@ -500,7 +499,7 @@ public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } @@ -509,7 +508,7 @@ public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager man if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); } @@ -517,7 +516,7 
@@ public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray data) if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPut( ndArray.getHandle(), index.getHandle(), data.getHandle())); @@ -549,20 +548,20 @@ public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } public static PtNDArray where(PtNDArray condition, PtNDArray self, PtNDArray other) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchWhere( condition.getHandle(), self.getHandle(), other.getHandle())); } public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray indicesNd) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMaskedSelect(ndArray.getHandle(), indicesNd.getHandle())); } @@ -577,101 +576,104 @@ public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager m // due to significant performance gain // for commonly used data loading call if (indices.length == 1) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0])); } - return newPtNDArray(manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); + return PtNDArrayImpl.newPtNDArray( + manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); } public static PtNDArray clone(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.tensorClone(ndArray.getHandle())); } public static PtNDArray reshape(PtNDArray ndArray, long[] shape) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchReshape(ndArray.getHandle(), shape)); } public static PtNDArray stack(PtNDArray[] arrays, int dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return newPtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); + return PtNDArrayImpl.newPtNDArray( + arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); } public static PtNDArray cat(PtNDArray[] arrays, long dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return newPtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); + return PtNDArrayImpl.newPtNDArray( + arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); } public static PtNDArray tile(PtNDArray ndArray, long[] repeats) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeat(ndArray.getHandle(), repeats)); } public static PtNDArray repeat(PtNDArray ndArray, long repeat, long dim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeatInterleave(ndArray.getHandle(), repeat, dim)); } public static PtNDArray softmax(PtNDArray ndArray, long dim, DataType dTpe) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public 
static PtNDArray logSoftmax(PtNDArray ndArray, long dim, DataType dTpe) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray argMax(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle())); } public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argMin(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle())); } public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgSort(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray sort(PtNDArray ndArray, long dim, boolean descending) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSort(ndArray.getHandle(), dim, descending)); } public static PtNDArray permute(PtNDArray ndArray, long[] dims) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPermute(ndArray.getHandle(), dims)); } public static PtNDArray flip(PtNDArray ndArray, long[] dims) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlip(ndArray.getHandle(), dims)); } public static PtNDArray transpose(PtNDArray ndArray, long dim1, long dim2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTranspose(ndArray.getHandle(), dim1, dim2)); } @@ -681,7 +683,7 @@ public static boolean contentEqual(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray add(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchAdd(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -691,7 +693,7 @@ public static void addi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sub(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchSub(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -701,7 +703,7 @@ public static void subi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray mul(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMul(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -711,7 +713,7 @@ public static void muli(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray div(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchTrueDivide(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -721,7 +723,7 @@ public static void divi(PtNDArray ndArray1, PtNDArray 
ndArray2) { } public static PtNDArray remainder(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchRemainder(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -731,7 +733,7 @@ public static void remainderi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray pow(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchPow(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -741,7 +743,7 @@ public static void powi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sign(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSign(ndArray.getHandle())); } @@ -750,101 +752,104 @@ public static void signi(PtNDArray ndArray) { } public static PtNDArray logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalAnd(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalOr(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalOr(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalXor(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalXor(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalNot(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogicalNot(ndArray.getHandle())); } public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2) { if (ndArray1.getShape().dimension() == 1) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchDot(ndArray1.getHandle(), ndArray2.getHandle())); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMaximum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); } public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMinimum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray min(PtNDArray ndArray) { - return 
newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); } public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray mean(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle())); } public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes) { long[] longaxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRot90(ndArray.getHandle(), times, longaxes)); } public static PtNDArray sum(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); } public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle(), dims, keepDim)); } @@ -854,29 +859,29 @@ public static PtNDArray cumProd(PtNDArray ndArray, long dim, DataType dataType) if (dataType != null) { dtPosition = dataType.ordinal(); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumProd(ndArray.getHandle(), dim, dtPosition)); } public static PtNDArray prod(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle())); } public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray cumSum(PtNDArray ndArray, long dim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim)); } public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNOneHot( ndArray.toType(DataType.INT64, false).getHandle(), depth)) @@ -887,7 +892,7 @@ public static NDList split(PtNDArray ndArray, long size, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(newPtNDArray(ndArray.getManager(), ptr)); + list.add(PtNDArrayImpl.newPtNDArray(ndArray.getManager(), ptr)); } return list; } @@ -896,188 +901,197 @@ public static NDList split(PtNDArray ndArray, long[] indices, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(newPtNDArray(ndArray.getManager(), ptr)); + list.add(PtNDArrayImpl.newPtNDArray(ndArray.getManager(), ptr)); } return list; } public 
static PtNDArray squeeze(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle())); } public static PtNDArray squeeze(PtNDArray ndArray, long dim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle(), dim)); } public static PtNDArray unsqueeze(PtNDArray ndArray, long dim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchUnsqueeze(ndArray.getHandle(), dim)); } public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlatten(ndArray.getHandle(), startDim, endDim)); } public static PtNDArray abs(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); } public static PtNDArray square(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSquare(ndArray.getHandle())); } public static PtNDArray floor(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFloor(ndArray.getHandle())); } public static PtNDArray ceil(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCeil(ndArray.getHandle())); } public static PtNDArray round(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRound(ndArray.getHandle())); } public static PtNDArray trunc(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTrunc(ndArray.getHandle())); } public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) { PtNDArray minNd = (PtNDArray) ndArray.getManager().create(min); PtNDArray maxNd = (PtNDArray) ndArray.getManager().create(max); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchClamp( ndArray.getHandle(), minNd.getHandle(), maxNd.getHandle())); } public static PtNDArray exp(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); } public static PtNDArray log(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); } public static PtNDArray log10(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog10(ndArray.getHandle())); } public static PtNDArray log2(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog2(ndArray.getHandle())); } public static PtNDArray sin(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), 
PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); } public static PtNDArray cos(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); } public static PtNDArray tan(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); } public static PtNDArray asin(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchASin(ndArray.getHandle())); } public static PtNDArray acos(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAcos(ndArray.getHandle())); } public static PtNDArray atan(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } public static PtNDArray sqrt(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); } public static PtNDArray sinh(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSinh(ndArray.getHandle())); } public static PtNDArray cosh(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCosh(ndArray.getHandle())); } public static PtNDArray tanh(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTanh(ndArray.getHandle())); } public static PtNDArray sigmoid(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSigmoid(ndArray.getHandle())); } public static PtNDArray all(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); } public static PtNDArray any(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); } public static PtNDArray none(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNone(ndArray.getHandle())); } public static PtNDArray eq(PtNDArray self, PtNDArray other) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchEq(self.getHandle(), other.getHandle())); } public static PtNDArray neq(PtNDArray self, PtNDArray other) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchNeq(self.getHandle(), other.getHandle())); } public static PtNDArray gt(PtNDArray self, PtNDArray other) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchGt(self.getHandle(), other.getHandle())); } public static PtNDArray gte(PtNDArray self, PtNDArray other) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), 
PyTorchLibrary.LIB.torchGte(self.getHandle(), other.getHandle())); } public static PtNDArray lt(PtNDArray self, PtNDArray other) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLt(self.getHandle(), other.getHandle())); } public static PtNDArray lte(PtNDArray self, PtNDArray other) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLte(self.getHandle(), other.getHandle())); } public static PtNDArray neg(PtNDArray ndArray) { - return newPtNDArray(ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); + return PtNDArrayImpl.newPtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); } public static void negi(PtNDArray ndArray) { @@ -1085,12 +1099,12 @@ public static void negi(PtNDArray ndArray) { } public static PtNDArray isNaN(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsNaN(ndArray.getHandle())); } public static PtNDArray isInf(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsInf(ndArray.getHandle())); } @@ -1101,7 +1115,7 @@ public static PtNDArray randint( Shape size, DataType dataType, Device device) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchRandint( low, @@ -1115,7 +1129,7 @@ public static PtNDArray randint( public static PtNDArray randperm( PtNDManager manager, long n, DataType dataType, Device device) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchRandPerm( n, @@ -1132,7 +1146,7 @@ public static PtNDArray normal( Shape size, DataType dataType, Device device) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchNormal( mean, @@ -1151,7 +1165,7 @@ public static PtNDArray uniform( Shape size, DataType dataType, Device device) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.tensorUniform( low, @@ -1165,7 +1179,7 @@ public static PtNDArray uniform( public static PtNDArray eye( PtNDManager manager, int n, int m, DataType dataType, Device device, SparseFormat fmt) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchEye( n, @@ -1177,25 +1191,25 @@ public static PtNDArray eye( } public static PtNDArray erfinv(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } public static PtNDArray inverse(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); } public static PtNDArray interpolate( PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNInterpolate( ndArray.getHandle(), size, mode, alignCorners)); } public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNLinear( input.getHandle(), @@ -1204,44 +1218,44 @@ public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias } public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse) { - return newPtNDArray( + 
return PtNDArrayImpl.newPtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNEmbedding(input.getHandle(), weight.getHandle(), sparse)); } public static PtNDArray relu(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNRelu(ndArray.getHandle())); } public static PtNDArray softPlus(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftPlus(ndArray.getHandle())); } public static PtNDArray softSign(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftSign(ndArray.getHandle())); } public static PtNDArray leakyRelu(PtNDArray ndArray, double negativeSlope) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLeakyRelu(ndArray.getHandle(), negativeSlope)); } public static PtNDArray elu(PtNDArray ndArray, double alpha) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNElu(ndArray.getHandle(), alpha)); } public static PtNDArray selu(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSelu(ndArray.getHandle())); } public static PtNDArray gelu(PtNDArray ndArray) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNGelu(ndArray.getHandle())); } @@ -1253,7 +1267,7 @@ public static PtNDArray convolution( Shape padding, Shape dilation, int groups) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNConvNd( ndArray.getHandle(), @@ -1274,7 +1288,7 @@ public static PtNDArray batchNorm( boolean isTraining, double momentum, double eps) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNBatchNorm( ndArray.getHandle(), @@ -1289,7 +1303,7 @@ public static PtNDArray batchNorm( public static PtNDArray layerNorm( PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLayerNorm( ndArray.getHandle(), @@ -1300,13 +1314,13 @@ public static PtNDArray layerNorm( } public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNNormalize(ndArray.getHandle(), p, dim, eps)); } public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNDropout(ndArray.getHandle(), prob, training)); } @@ -1339,7 +1353,7 @@ public static NDList rnn( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(newPtNDArray(manager, output)); + res.add(PtNDArrayImpl.newPtNDArray(manager, output)); } return res; } @@ -1370,7 +1384,7 @@ public static NDList gru( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(newPtNDArray(manager, output)); + res.add(PtNDArrayImpl.newPtNDArray(manager, output)); } return res; } @@ -1403,7 +1417,7 @@ public static NDList lstm( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(newPtNDArray(manager, output)); + 
res.add(PtNDArrayImpl.newPtNDArray(manager, output)); } return res; } @@ -1415,7 +1429,7 @@ public static PtNDArray avgPool( Shape padding, boolean ceilMode, boolean countIncludePad) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAvgPool( ndArray.getHandle(), @@ -1428,7 +1442,7 @@ public static PtNDArray avgPool( public static PtNDArray maxPool( PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNMaxPool( ndArray.getHandle(), @@ -1439,14 +1453,14 @@ public static PtNDArray maxPool( } public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, Shape outputSize) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveMaxPool( ndArray.getHandle(), outputSize.getShape())); } public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, Shape outputSize) { - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveAvgPool( ndArray.getHandle(), outputSize.getShape())); @@ -1457,7 +1471,7 @@ public static PtNDArray lpPool( if (ndArray.getShape().dimension() - 2 == 3) { throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine"); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLpPool( ndArray.getHandle(), @@ -1522,7 +1536,7 @@ public static void attachGradient(PtNDArray ndArray, boolean requiresGrad) { public static PtNDArray detachGradient(PtNDArray ndArray) { // TODO: detached ndarray may not use the same manager for the attached one - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchDetachGrad(ndArray.getHandle())); } @@ -1531,7 +1545,7 @@ public static PtNDArray getGradient(PtNDArray ndArray) { if (pointer == NULL_PTR) { return null; } - return newPtNDArray(ndArray.getManager(), pointer); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), pointer); } public static void backward( @@ -1611,7 +1625,7 @@ public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) { String[] names = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle()); NDList list = new NDList(handles.length); for (int i = 0; i < handles.length; i++) { - PtNDArray array = newPtNDArray(manager, handles[i]); + PtNDArray array = PtNDArrayImpl.newPtNDArray(manager, handles[i]); array.setName(names[i]); list.add(array); } @@ -1687,7 +1701,7 @@ public static int getLayout(PtNDArray array) { public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims) { long[] longAxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNorm(ndArray.getHandle(), ord, longAxes, keepDims)); } @@ -1696,7 +1710,7 @@ public static PtNDArray nonZeros(PtNDArray ndArray) { if (ndArray.isScalar()) { ndArray = (PtNDArray) ndArray.reshape(-1); } - return newPtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNonZeros(ndArray.getHandle())); } } diff --git a/tools/conf/findbugs-exclude.xml b/tools/conf/findbugs-exclude.xml index b36584a1714..2e59f19a5ad 100644 --- a/tools/conf/findbugs-exclude.xml +++ b/tools/conf/findbugs-exclude.xml @@ -31,6 +31,6 @@ - + From 8b067359f7a2a4f76214c9fc161374feb63a9e56 Mon Sep 17 
00:00:00 2001 From: enpasos Date: Thu, 15 Dec 2022 18:47:56 +0100 Subject: [PATCH 06/30] some proxy handling fixes --- api/src/main/java/ai/djl/util/NativeResource.java | 7 +++++++ api/src/main/java/ai/djl/util/NativeResourceImpl.java | 6 ++++++ .../ai/djl/pytorch/engine/PtGradientCollector.java | 2 +- .../src/main/java/ai/djl/pytorch/engine/PtNDArray.java | 1 + .../main/java/ai/djl/pytorch/engine/PtNDArrayEx.java | 4 ++-- .../main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java | 4 ++-- .../java/ai/djl/pytorch/engine/PtNDArrayIndexer.java | 10 +++++----- .../java/ai/djl/pytorch/integration/IValueTest.java | 9 ++++----- 8 files changed, 28 insertions(+), 15 deletions(-) diff --git a/api/src/main/java/ai/djl/util/NativeResource.java b/api/src/main/java/ai/djl/util/NativeResource.java index 978554c7b4f..c70483234d1 100644 --- a/api/src/main/java/ai/djl/util/NativeResource.java +++ b/api/src/main/java/ai/djl/util/NativeResource.java @@ -42,6 +42,13 @@ public interface NativeResource extends AutoCloseable { */ String getUid(); + /** + * Gets and sets the atomic handle to null. + * + * @return the previous handle value + */ + T getAndSetHandleNull(); + /** {@inheritDoc} */ @Override void close(); diff --git a/api/src/main/java/ai/djl/util/NativeResourceImpl.java b/api/src/main/java/ai/djl/util/NativeResourceImpl.java index 4133c01c772..f907db7dd78 100644 --- a/api/src/main/java/ai/djl/util/NativeResourceImpl.java +++ b/api/src/main/java/ai/djl/util/NativeResourceImpl.java @@ -30,6 +30,12 @@ protected NativeResourceImpl(T handle) { uid = handle.toString(); } + /** {@inheritDoc} */ + @Override + public T getAndSetHandleNull() { + return handle.getAndSet(null); + } + /** {@inheritDoc} */ @Override public boolean isReleased() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index b253e73f306..5ef573cee68 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -64,7 +64,7 @@ public void backward(NDArray target) { * higher order derivative products. Defaults to false. */ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean createGraph) { - JniUtils.backward((PtNDArrayImpl) target, (PtNDArrayImpl) grad, keepGraph, createGraph); + JniUtils.backward((PtNDArray) target, (PtNDArray) grad, keepGraph, createGraph); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 146d88a5fad..34189e6f4b3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -27,6 +27,7 @@ /** {@code PtNDArray} is the interface for the PyTorch implementation of {@link NDArray}. 
*/ public interface PtNDArray extends NativeResource, NDArray { + /** {@inheritDoc} */ @Override PtNDManager getManager(); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index 0fe4a80eb98..f5ae2cdbdd3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -29,14 +29,14 @@ /** {@code PtNDArrayEx} is the PyTorch implementation of the {@link NDArrayEx}. */ public class PtNDArrayEx implements NDArrayEx { - private PtNDArrayImpl array; + private PtNDArray array; /** * Constructs an {@code PtNDArrayEx} given a {@link NDArray}. * * @param parent the {@link NDArray} to extend */ - PtNDArrayEx(PtNDArrayImpl parent) { + PtNDArrayEx(PtNDArray parent) { this.array = parent; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index 1e4d065b85a..6cc0c48aecc 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -1191,8 +1191,8 @@ public PtNDArray cumSum(int axis) { /** {@inheritDoc} */ @Override public void intern(NDArray replaced) { - PtNDArrayImpl arr = (PtNDArrayImpl) replaced; - Long oldHandle = handle.getAndSet(arr.handle.getAndSet(null)); + PtNDArray arr = (PtNDArray) replaced; + Long oldHandle = handle.getAndSet(arr.getAndSetHandleNull()); JniUtils.deleteNDArray(oldHandle); // dereference old ndarray arr.close(); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index c64c9dda170..b38a038d142 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -24,7 +24,7 @@ import java.util.Stack; -/** The {@link NDArrayIndexer} used by the {@link PtNDArrayImpl}. */ +/** The {@link NDArrayIndexer} used by the {@link PtNDArray}. 
*/ public class PtNDArrayIndexer extends NDArrayIndexer { private PtNDManager manager; @@ -70,8 +70,8 @@ public NDArray get(NDArray array, NDIndex index) { index.addAllDim(); } - if (array == null || array instanceof PtNDArrayImpl) { - return JniUtils.indexAdv((PtNDArrayImpl) array, index, manager); + if (array == null || array instanceof PtNDArray) { + return JniUtils.indexAdv((PtNDArray) array, index, manager); } else { PtNDArray arrayNew = manager.create(array.toByteBuffer(), array.getShape(), array.getDataType()); @@ -89,9 +89,9 @@ public void set(NDArray array, NDIndex index, Object data) { array.toByteBuffer(), array.getShape(), array.getDataType()); if (data instanceof Number) { - JniUtils.indexAdvPut(ptArray, index, (PtNDArrayImpl) manager.create((Number) data)); + JniUtils.indexAdvPut(ptArray, index, (PtNDArray) manager.create((Number) data)); } else if (data instanceof NDArray) { - JniUtils.indexAdvPut(ptArray, index, (PtNDArrayImpl) data); + JniUtils.indexAdvPut(ptArray, index, (PtNDArray) data); } else { throw new IllegalArgumentException( "The type of value to assign cannot be other than NDArray and Number."); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java index efad9d17e95..286ac02001c 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java @@ -18,7 +18,6 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.pytorch.engine.PtNDArray; -import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; import ai.djl.pytorch.engine.PtSymbolBlock; import ai.djl.pytorch.jni.IValue; @@ -39,10 +38,10 @@ public class IValueTest { @Test public void testIValue() { try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) { - PtNDArrayImpl array1 = (PtNDArrayImpl) manager.zeros(new Shape(1)); - PtNDArrayImpl array2 = (PtNDArrayImpl) manager.ones(new Shape(1)); - PtNDArrayImpl array3 = (PtNDArrayImpl) manager.create("test"); - PtNDArrayImpl array4 = (PtNDArrayImpl) manager.create(new String[] {"test1", "test2"}); + PtNDArray array1 = (PtNDArray) manager.zeros(new Shape(1)); + PtNDArray array2 = (PtNDArray) manager.ones(new Shape(1)); + PtNDArray array3 = (PtNDArray) manager.create("test"); + PtNDArray array4 = (PtNDArray) manager.create(new String[] {"test1", "test2"}); try (IValue ivalue = IValue.from(array1)) { Assert.assertTrue(ivalue.isTensor()); From 6fc187919679f0247906882f8dd9e51c01b87795 Mon Sep 17 00:00:00 2001 From: enpasos Date: Thu, 15 Dec 2022 23:19:04 +0100 Subject: [PATCH 07/30] removed a logging --- .../main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java index 1420759aa97..e6c8b80a1c4 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java @@ -14,9 +14,6 @@ import ai.djl.ndarray.NDArray; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; import java.util.ArrayList; @@ -33,8 +30,6 @@ */ public class WeakHashMapWrapper implements Map { - private static final Logger logger = 
LoggerFactory.getLogger(WeakHashMapWrapper.class); - private final WeakHashMap map = new WeakHashMap<>(); private final ReferenceQueue queue = new ReferenceQueue<>(); @@ -47,8 +42,6 @@ private void checkQueue() { WeakReferenceWrapper ref2 = (WeakReferenceWrapper) ref; V value = ref2.getValue(); if (value instanceof NDArray) { // just as one example - logger.info( - "NDArray is closed triggered by a message from the garbage collector"); ((NDArray) value).close(); } } From 720458633d0b713165c6196864ca4904bf56b602 Mon Sep 17 00:00:00 2001 From: enpasos Date: Fri, 16 Dec 2022 07:26:26 +0100 Subject: [PATCH 08/30] fixed double wrapping bug --- .../java/ai/djl/ndarray/gc/DynamicInvocationHandler.java | 4 ---- api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java | 5 +++++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java index d89f9298892..804d2f0f557 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java +++ b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java @@ -58,10 +58,6 @@ public Object invoke(Object proxy, Method method, Object[] args) { throw new RuntimeException(e); // NOPMD } - if (result instanceof NDArray) { - return ndArrayProxyMaker.wrap((NDArray) result); - } - return result; } } diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java index e6c8b80a1c4..aab36876c65 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java @@ -16,6 +16,7 @@ import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; +import java.lang.reflect.Proxy; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -89,6 +90,10 @@ public V get(Object key) { /** {@inheritDoc} */ @Override public V put(K key, V value) { + if (value instanceof Proxy) { + throw new IllegalArgumentException( + "Proxy is not supported to be stored as value here."); + } weakReferenceWrapperList.add(new WeakReferenceWrapper(key, value, queue)); return map.put(key, value); } From 8b52b0136007b691ef2380f2b7bcc18af8273f11 Mon Sep 17 00:00:00 2001 From: enpasos Date: Fri, 16 Dec 2022 07:38:01 +0100 Subject: [PATCH 09/30] catch exception silently if resource already closed --- .../java/ai/djl/pytorch/engine/PtGradientCollector.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index 5ef573cee68..eb05dc82b41 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -72,8 +72,12 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c public void zeroGradients() { NDManager systemManager = PtNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { - if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + try { + if (array.hasGradient()) { + array.getGradient().subi(array.getGradient()); + } + } catch (IllegalStateException e) { + // ignore if the array is already closed } } } From 335586cadc8e3269a891e707a5bbefa70dcbf24d Mon Sep 
17 00:00:00 2001 From: enpasos Date: Fri, 16 Dec 2022 08:19:13 +0100 Subject: [PATCH 10/30] methods to remove NDManager without resources --- .../java/ai/djl/ndarray/BaseNDManager.java | 19 +++++++++++++++++++ .../main/java/ai/djl/ndarray/NDManager.java | 6 ++++++ .../passthrough/PassthroughNDManager.java | 12 ++++++++++++ 3 files changed, 37 insertions(+) diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index d56aae3bf5f..42b54d73421 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -78,6 +78,25 @@ public final Device defaultDevice() { return getEngine().defaultDevice(); } + /** {@inheritDoc} */ + @Override + public void closeSubManagersWithoutResources() { + for (AutoCloseable resource : resources.values()) { + if (resource instanceof NDManager) { + NDManager subManager = (NDManager) resource; + subManager.closeIfWithoutResources(); + } + } + } + + /** {@inheritDoc} */ + @Override + public void closeIfWithoutResources() { + if (resources.isEmpty()) { + close(); + } + } + /** {@inheritDoc} */ @Override public NDArray create(String[] data, Charset charset, Shape shape) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 6a20c70edf9..3bfef125982 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -1699,6 +1699,12 @@ default void tempAttachAll(NDResource... resources) { @Override void close(); + /** Closes all subManagers that do not have any resources attached to them. */ + void closeSubManagersWithoutResources(); + + /** Closes if it does not have any resources attached. */ + void closeIfWithoutResources(); + /** * A {@link SystemNDManager} is a marker class for a base NDManager. 
* diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index eaaa710afee..f66c9ca5822 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -320,4 +320,16 @@ public Engine getEngine() { /** {@inheritDoc} */ @Override public void close() {} + + /** {@inheritDoc} */ + @Override + public void closeSubManagersWithoutResources() { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + /** {@inheritDoc} */ + @Override + public void closeIfWithoutResources() { + throw new UnsupportedOperationException(UNSUPPORTED); + } } From 320ff1dbbd541a9df583745b036f0e7ad8851dc2 Mon Sep 17 00:00:00 2001 From: enpasos Date: Fri, 16 Dec 2022 18:57:52 +0100 Subject: [PATCH 11/30] renamed the switch to garbageCollectionOn --- .../main/java/ai/djl/ndarray/BaseNDManager.java | 14 +++++++------- api/src/main/java/ai/djl/ndarray/NDManager.java | 6 +++--- .../djl/util/passthrough/PassthroughNDManager.java | 2 +- .../java/ai/djl/pytorch/engine/PtNDArrayImpl.java | 6 +++--- .../java/ai/djl/pytorch/engine/PtNDManager.java | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 42b54d73421..dc2325716d1 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -43,7 +43,7 @@ public abstract class BaseNDManager implements NDManager { private static final Logger logger = LoggerFactory.getLogger(BaseNDManager.class); - private final boolean useProxies; + private final boolean garbageCollectionOn; protected NDManager parent; protected NDManager alternativeManager; @@ -59,10 +59,10 @@ protected BaseNDManager(NDManager parent, Device device) { this(parent, device, false); } - protected BaseNDManager(NDManager parent, Device device, boolean useProxies) { + protected BaseNDManager(NDManager parent, Device device, boolean garbageCollectionOn) { this.parent = parent; this.device = device == null ? 
defaultDevice() : device; - this.useProxies = useProxies; + this.garbageCollectionOn = garbageCollectionOn; resources = new ConcurrentHashMap<>(); tempResources = new ConcurrentHashMap<>(); uid = UUID.randomUUID().toString(); @@ -110,8 +110,8 @@ public NDArray create(Shape shape, DataType dataType) { /** {@inheritDoc} */ @Override - public boolean isUseProxies() { - return useProxies; + public boolean isGarbageCollectionOn() { + return garbageCollectionOn; } /** {@inheritDoc} */ @@ -328,7 +328,7 @@ public NDManager getParentManager() { /** {@inheritDoc} */ @Override public NDManager newSubManager() { - return newSubManager(device, useProxies); + return newSubManager(device, garbageCollectionOn); } /** {@inheritDoc} */ @@ -340,7 +340,7 @@ public NDManager newSubManager(boolean useProxies) { /** {@inheritDoc} */ @Override public NDManager newSubManager(Device device) { - return newSubManager(device, useProxies); + return newSubManager(device, garbageCollectionOn); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 3bfef125982..a76250859bc 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -768,11 +768,11 @@ default NDArrayProxyMaker getProxyMaker() { String getName(); /** - * Returns useProxies. + * Returns garbageCollectionOn. * - * @return useProxies + * @return garbageCollectionOn */ - boolean isUseProxies(); + boolean isGarbageCollectionOn(); /** * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros. diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index f66c9ca5822..ad5445e6ba9 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -87,7 +87,7 @@ public NDArray createCoo(Buffer data, long[][] indices, Shape shape) { /** {@inheritDoc} */ @Override - public boolean isUseProxies() { + public boolean isGarbageCollectionOn() { return false; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index 6cc0c48aecc..de6db8d1c4b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -1570,7 +1570,7 @@ public void close() { */ public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { PtNDArray instance = new PtNDArrayImpl(manager, handle); - if (manager.isUseProxies()) { + if (manager.isGarbageCollectionOn()) { instance = manager.getProxyMaker().wrap(instance); } return instance; @@ -1588,7 +1588,7 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { */ public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffer data) { PtNDArray instance = new PtNDArrayImpl(manager, handle, data); - if (manager.isUseProxies()) { + if (manager.isGarbageCollectionOn()) { instance = manager.getProxyMaker().wrap(instance); } return instance; @@ -1607,7 +1607,7 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffe */ public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape shape) { PtNDArray instance = new PtNDArrayImpl(manager, strs, shape); - if 
(manager.isUseProxies()) { + if (manager.isGarbageCollectionOn()) { instance = manager.getProxyMaker().wrap(instance); } return instance; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 1f1335c490a..f4c584d0887 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -183,7 +183,7 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy /** {@inheritDoc} */ @Override public PtNDManager newSubManager(Device device) { - PtNDManager manager = new PtNDManager(this, device, isUseProxies()); + PtNDManager manager = new PtNDManager(this, device, isGarbageCollectionOn()); attachUncappedInternal(manager.uid, manager); return manager; } From e0da7f3cdc67fa3f518ad00d658e08fe4236c43a Mon Sep 17 00:00:00 2001 From: enpasos Date: Sat, 17 Dec 2022 08:37:21 +0100 Subject: [PATCH 12/30] add switch for gc in model --- api/src/main/java/ai/djl/Model.java | 53 +++++++++++++++++++ api/src/main/java/ai/djl/engine/Engine.java | 9 ++++ .../java/ai/djl/dlr/engine/DlrEngine.java | 6 +++ .../java/ai/djl/ml/lightgbm/LgbmEngine.java | 5 ++ .../java/ai/djl/ml/xgboost/XgbEngine.java | 6 +++ .../java/ai/djl/mxnet/engine/MxEngine.java | 6 +++ .../ai/djl/onnxruntime/engine/OrtEngine.java | 6 +++ .../ai/djl/paddlepaddle/engine/PpEngine.java | 6 +++ .../java/ai/djl/pytorch/engine/PtEngine.java | 6 +++ .../java/ai/djl/pytorch/engine/PtModel.java | 14 +++++ .../ai/djl/tensorflow/engine/TfEngine.java | 6 +++ .../ai/djl/tensorrt/engine/TrtEngine.java | 6 +++ .../ai/djl/tflite/engine/TfLiteEngine.java | 6 +++ 13 files changed, 135 insertions(+) diff --git a/api/src/main/java/ai/djl/Model.java b/api/src/main/java/ai/djl/Model.java index 0fec77426c0..be385bced37 100644 --- a/api/src/main/java/ai/djl/Model.java +++ b/api/src/main/java/ai/djl/Model.java @@ -99,6 +99,59 @@ static Model newInstance(String name, Device device, String engineName) { return Engine.getEngine(engineName).newModel(name, device); } + /** + * Creates an empty model instance. + * + * @param name the model name + * @param useGarbageCollection whether to use garbage collection + * @return a new Model instance + */ + static Model newInstance(String name, boolean useGarbageCollection) { + return newInstance(name, (Device) null, useGarbageCollection); + } + + /** + * Creates an empty model instance on the specified {@link Device}. + * + * @param name the model name + * @param useGarbageCollection whether to use garbage collection + * @param device the device to load the model onto + * @return a new model instance + */ + static Model newInstance(String name, Device device, boolean useGarbageCollection) { + return Engine.getInstance().newModel(name, device, useGarbageCollection); + } + + /** + * Creates an empty model instance on the specified {@link Device} and engine. + * + * @param name the model name + * @param useGarbageCollection whether to use garbage collection + * @param engineName the name of the engine + * @return a new model instance + */ + static Model newInstance(String name, String engineName, boolean useGarbageCollection) { + Engine engine = Engine.getEngine(engineName); + return engine.newModel(name, null, useGarbageCollection); + } + + /** + * Creates an empty model instance on the specified {@link Device} and engine. 
+ * + * @param name the model name + * @param useGarbageCollection whether to use garbage collection + * @param device the device to load the model onto + * @param engineName the name of the engine + * @return a new model instance + */ + static Model newInstance( + String name, Device device, String engineName, boolean useGarbageCollection) { + if (engineName == null || engineName.isEmpty()) { + return newInstance(name, device); + } + return Engine.getEngine(engineName).newModel(name, device, useGarbageCollection); + } + /** * Loads the model from the {@code modelPath}. * diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 099f3b5493f..d56038f904c 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -282,6 +282,15 @@ public int getGpuCount() { */ public abstract Model newModel(String name, Device device); + /** + * Constructs a new model. + * + * @param name the model name + * @param useGarbageCollection whether to use garbage collection + * @param device the device that the model will be loaded onto + * @return a new Model instance using the network defined in block + */ + public abstract Model newModel(String name, Device device, boolean useGarbageCollection); /** * Creates a new top-level {@link NDManager}. * diff --git a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java index ee8bf0d4f76..0b41b218f30 100644 --- a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java +++ b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java @@ -102,6 +102,12 @@ public Model newModel(String name, Device device) { return new DlrModel(name, newBaseManager(Device.cpu())); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java index 23d0c1adbcc..b0dec588cb3 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java @@ -101,6 +101,11 @@ public Model newModel(String name, Device device) { return new LgbmModel(name, newBaseManager(device)); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java index 6449ee5f321..59e423f1688 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java @@ -113,6 +113,12 @@ public Model newModel(String name, Device device) { return new XgbModel(name, newBaseManager(device)); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public NDManager newBaseManager() 
{ diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java index c47d8bd1e49..66eeeddefee 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java @@ -126,6 +126,12 @@ public Model newModel(String name, Device device) { return new MxModel(name, device); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 6bf773e15ec..4aa4e581931 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -124,6 +124,12 @@ public Model newModel(String name, Device device) { return new OrtModel(name, newBaseManager(device), env); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index e482491598d..86a11047205 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -98,6 +98,12 @@ public Model newModel(String name, Device device) { return new PpModel(name, device, newBaseManager(device)); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index 7fa2636d60b..88b5754e7fe 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -128,6 +128,12 @@ public Model newModel(String name, Device device) { return new PtModel(name, device); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + return new PtModel(name, device, useGarbageCollection); + } + /** {@inheritDoc} */ @Override public NDManager newBaseManager(boolean useProxies) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 0de495bc161..d643c53ae2e 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ 
b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -57,6 +57,20 @@ public class PtModel extends BaseModel { dataType = DataType.FLOAT32; } + /** + * Constructs a new Model on a given device. + * + * @param name the model name + * @param useGarbageCollection whether to use garbage collection + * @param device the device the model should be located on + */ + PtModel(String name, Device device, boolean useGarbageCollection) { + super(name); + manager = PtNDManager.getSystemManager().newSubManager(device, useGarbageCollection); + manager.setName("ptModel"); + dataType = DataType.FLOAT32; + } + /** {@inheritDoc} */ @Override public void load(Path modelPath, String prefix, Map options) diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java index bd8d10f3640..89d4587a9ea 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java @@ -84,6 +84,12 @@ public Model newModel(String name, Device device) { return new TfModel(name, device); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java index e5663355228..80c38fe5ee6 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java @@ -118,6 +118,12 @@ public Model newModel(String name, Device device) { return new TrtModel(name, newBaseManager(device)); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public TrtNDManager newBaseManager() { diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java index 02be61faa66..05b94c36d50 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java @@ -89,6 +89,12 @@ public Model newModel(String name, Device device) { return new TfLiteModel(name, newBaseManager(device)); } + /** {@inheritDoc} */ + @Override + public Model newModel(String name, Device device, boolean useGarbageCollection) { + throw new UnsupportedOperationException("Garbage collection not supported"); + } + /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { From 811bce8538f17376eb7a7ac75af5c9abf33b9b16 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 19 Dec 2022 16:33:40 +0100 Subject: [PATCH 13/30] common name for the switch --- api/src/main/java/ai/djl/engine/Engine.java | 4 +-- .../java/ai/djl/ndarray/BaseNDManager.java | 20 +++++++------- .../main/java/ai/djl/ndarray/NDManager.java | 26 +++++++++---------- .../passthrough/PassthroughNDManager.java | 6 ++--- .../java/ai/djl/pytorch/engine/PtEngine.java 
| 4 +-- .../ai/djl/pytorch/engine/PtNDArrayImpl.java | 19 +++++++------- .../ai/djl/pytorch/engine/PtNDManager.java | 10 +++---- 7 files changed, 45 insertions(+), 44 deletions(-) diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index d56038f904c..8997eff5050 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -296,10 +296,10 @@ public int getGpuCount() { * *

{@code NDManager} will inherit default {@link Device}. * - * @param useProxies whether to facade resources with a proxy + * @param useGarbageCollection whether to facade resources with a proxy * @return a new top-level {@code NDManager} */ - public NDManager newBaseManager(boolean useProxies) { + public NDManager newBaseManager(boolean useGarbageCollection) { throw new UnsupportedOperationException("Not implemented"); } diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index dc2325716d1..38c5e6ae629 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -43,7 +43,7 @@ public abstract class BaseNDManager implements NDManager { private static final Logger logger = LoggerFactory.getLogger(BaseNDManager.class); - private final boolean garbageCollectionOn; + private final boolean useGarbageCollection; protected NDManager parent; protected NDManager alternativeManager; @@ -59,10 +59,10 @@ protected BaseNDManager(NDManager parent, Device device) { this(parent, device, false); } - protected BaseNDManager(NDManager parent, Device device, boolean garbageCollectionOn) { + protected BaseNDManager(NDManager parent, Device device, boolean useGarbageCollection) { this.parent = parent; this.device = device == null ? defaultDevice() : device; - this.garbageCollectionOn = garbageCollectionOn; + this.useGarbageCollection = useGarbageCollection; resources = new ConcurrentHashMap<>(); tempResources = new ConcurrentHashMap<>(); uid = UUID.randomUUID().toString(); @@ -110,8 +110,8 @@ public NDArray create(Shape shape, DataType dataType) { /** {@inheritDoc} */ @Override - public boolean isGarbageCollectionOn() { - return garbageCollectionOn; + public boolean isUseGarbageCollection() { + return useGarbageCollection; } /** {@inheritDoc} */ @@ -328,24 +328,24 @@ public NDManager getParentManager() { /** {@inheritDoc} */ @Override public NDManager newSubManager() { - return newSubManager(device, garbageCollectionOn); + return newSubManager(device, useGarbageCollection); } /** {@inheritDoc} */ @Override - public NDManager newSubManager(boolean useProxies) { - return newSubManager(device, useProxies); + public NDManager newSubManager(boolean useGarbageCollection) { + return newSubManager(device, useGarbageCollection); } /** {@inheritDoc} */ @Override public NDManager newSubManager(Device device) { - return newSubManager(device, garbageCollectionOn); + return newSubManager(device, useGarbageCollection); } /** {@inheritDoc} */ @Override - public NDManager newSubManager(Device device, boolean useProxies) { + public NDManager newSubManager(Device device, boolean useGarbageCollection) { throw new UnsupportedOperationException("Not supported!"); } diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index a76250859bc..32cc42ffa92 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -111,11 +111,11 @@ public interface NDManager extends AutoCloseable { * *

{@code NDManager} will inherit default {@link Device}. * - * @param useProxies whether to facade {@link NDArray} behind a proxy + * @param useGarbageCollection whether to facade {@link NDArray} behind a proxy * @return a new top-level {@code NDManager} */ - static NDManager newBaseManager(boolean useProxies) { - return Engine.getInstance().newBaseManager(useProxies); + static NDManager newBaseManager(boolean useGarbageCollection) { + return Engine.getInstance().newBaseManager(useGarbageCollection); } /** @@ -768,11 +768,11 @@ default NDArrayProxyMaker getProxyMaker() { String getName(); /** - * Returns garbageCollectionOn. + * Returns useGarbageCollection. * - * @return garbageCollectionOn + * @return useGarbageCollection */ - boolean isGarbageCollectionOn(); + boolean isUseGarbageCollection(); /** * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros. @@ -1529,23 +1529,23 @@ default NDArray truncatedNormal( NDManager newSubManager(Device device); /** - * Creates a child {@code NDManager} with specified boolean switch useProxies and will inherit - * default {@link Device} from this {@code NDManager}. + * Creates a child {@code NDManager} with specified boolean switch useGarbageCollection and will + * inherit default {@link Device} from this {@code NDManager}. * - * @param useProxies the boolean switch to use proxies + * @param useGarbageCollection the boolean switch to use proxies * @return a child {@code NDManager} */ - NDManager newSubManager(boolean useProxies); + NDManager newSubManager(boolean useGarbageCollection); /** * Creates a child {@code NDManager} with specified default {@link Device} and the boolean - * switch useProxies. + * switch useGarbageCollection. * * @param device the default {@link Device} - * @param useProxies the boolean switch to use proxies + * @param useGarbageCollection the boolean switch to use proxies * @return a child {@code NDManager} */ - NDManager newSubManager(Device device, boolean useProxies); + NDManager newSubManager(Device device, boolean useGarbageCollection); /** * Returns the default {@link Device} of this {@code NDManager}. 
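For orientation, a minimal usage sketch of the garbage-collection switch wired up above. This is an illustrative example, not code from the patches: it assumes the PyTorch engine is the active engine (only PtEngine overrides Engine.newBaseManager(boolean) in this series), and the class name GcUsageSketch is made up for the sketch.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public final class GcUsageSketch {
    public static void main(String[] args) {
        // top-level manager with the garbage-collection switch turned on
        try (NDManager manager = NDManager.newBaseManager(true)) {
            // the child manager inherits the switch; it can also be passed explicitly
            try (NDManager sub = manager.newSubManager(true)) {
                NDArray a = sub.create(new float[] {1f});
                NDArray b = sub.create(new float[] {2f});
                NDArray c = a.add(b); // comes back as a proxy-wrapped array
                // once a, b and c become unreachable, the reference queue in
                // WeakHashMapWrapper closes the underlying native tensors
            }
        }
    }
}

Explicit try-with-resources still closes whatever remains attached to a manager; the garbage-collection path only adds an earlier close for proxied arrays that have become unreachable.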
diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index ad5445e6ba9..0bacb8e288e 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -87,7 +87,7 @@ public NDArray createCoo(Buffer data, long[][] indices, Shape shape) { /** {@inheritDoc} */ @Override - public boolean isGarbageCollectionOn() { + public boolean isUseGarbageCollection() { return false; } @@ -251,13 +251,13 @@ public NDManager newSubManager(Device device) { /** {@inheritDoc} */ @Override - public NDManager newSubManager(boolean useProxies) { + public NDManager newSubManager(boolean useGarbageCollection) { return this; } /** {@inheritDoc} */ @Override - public NDManager newSubManager(Device device, boolean useProxies) { + public NDManager newSubManager(Device device, boolean useGarbageCollection) { return this; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index 88b5754e7fe..1b76d8a94a6 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -136,8 +136,8 @@ public Model newModel(String name, Device device, boolean useGarbageCollection) /** {@inheritDoc} */ @Override - public NDManager newBaseManager(boolean useProxies) { - return PtNDManager.getSystemManager().newSubManager(useProxies); + public NDManager newBaseManager(boolean useGarbageCollection) { + return PtNDManager.getSystemManager().newSubManager(useGarbageCollection); } /** {@inheritDoc} */ @Override diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index de6db8d1c4b..ef5573caa5f 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -1561,8 +1561,8 @@ public void close() { /** * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} - * instead). Depending on the switch {@code useProxies}, the returned {@code NDArray} will be - * returned as a proxy or a direct instance. + * instead). Depending on the switch {@code useGarbageCollection}, the returned {@code NDArray} + * will be returned as a proxy or a direct instance. * * @param manager the manager to attach the new array to * @param handle the pointer to the native PyTorch memory @@ -1570,7 +1570,7 @@ public void close() { */ public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { PtNDArray instance = new PtNDArrayImpl(manager, handle); - if (manager.isGarbageCollectionOn()) { + if (manager.isUseGarbageCollection()) { instance = manager.getProxyMaker().wrap(instance); } return instance; @@ -1578,8 +1578,9 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { /** * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} - * instead) with the data that is hold on Java side. Depending on the switch {@code useProxies}, - * the returned {@code NDArray} will be returned as a proxy or a direct instance. + * instead) with the data that is hold on Java side. 
Depending on the switch {@code + * useGarbageCollection}, the returned {@code NDArray} will be returned as a proxy or a direct + * instance. * * @param manager the manager to attach the new array to * @param handle the pointer to the native PyTorch memory @@ -1588,7 +1589,7 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { */ public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffer data) { PtNDArray instance = new PtNDArrayImpl(manager, handle, data); - if (manager.isGarbageCollectionOn()) { + if (manager.isUseGarbageCollection()) { instance = manager.getProxyMaker().wrap(instance); } return instance; @@ -1597,8 +1598,8 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffe /** * Constructs a PyTorch {@code NDArray} to hold string array with a dummy native handle * (internal. Use {@link NDManager} instead) with the data that is hold on Java side. Depending - * on the switch {@code useProxies}, the returned {@code NDArray} will be returned as a proxy or - * a direct instance. + * on the switch {@code useGarbageCollection}, the returned {@code NDArray} will be returned as + * a proxy or a direct instance. * * @param manager the manager to attach the new array to * @param strs the string array @@ -1607,7 +1608,7 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffe */ public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape shape) { PtNDArray instance = new PtNDArrayImpl(manager, strs, shape); - if (manager.isGarbageCollectionOn()) { + if (manager.isUseGarbageCollection()) { instance = manager.getProxyMaker().wrap(instance); } return instance; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index f4c584d0887..52b960d0512 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -34,8 +34,8 @@ public class PtNDManager extends BaseNDManager { protected PtNDArrayProxyMaker proxyMaker; - private PtNDManager(NDManager parent, Device device, boolean useProxies) { - super(parent, device, useProxies); + private PtNDManager(NDManager parent, Device device, boolean useGarbageCollection) { + super(parent, device, useGarbageCollection); } static PtNDManager getSystemManager() { @@ -183,15 +183,15 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy /** {@inheritDoc} */ @Override public PtNDManager newSubManager(Device device) { - PtNDManager manager = new PtNDManager(this, device, isGarbageCollectionOn()); + PtNDManager manager = new PtNDManager(this, device, isUseGarbageCollection()); attachUncappedInternal(manager.uid, manager); return manager; } /** {@inheritDoc} */ @Override - public NDManager newSubManager(Device device, boolean useProxies) { - PtNDManager manager = new PtNDManager(this, device, useProxies); + public NDManager newSubManager(Device device, boolean useGarbageCollection) { + PtNDManager manager = new PtNDManager(this, device, useGarbageCollection); attachUncappedInternal(manager.uid, manager); return manager; } From 043592e42a1b35d228791d72fe31c20e260dd9b1 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 19 Dec 2022 17:32:59 +0100 Subject: [PATCH 14/30] variable details on debugDump --- .../java/ai/djl/pytorch/engine/PtNDManager.java | 14 +++++++++++--- 
.../java/ai/djl/pytorch/integration/gc/Main.java | 8 ++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 52b960d0512..9162bb3afc4 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -202,9 +202,17 @@ public final Engine getEngine() { return Engine.getEngine(PtEngine.ENGINE_NAME); } - /** Dumps debug information about the current {@link PtNDManager} and all its children. */ - public static void debugDumpFromSystemManager() { - ((BaseNDManager) PtNDManager.getSystemManager()).debugDumpDetailed(0); + /** + * Dumps debug information about the current {@link PtNDManager} and all its children. + * + * @param detailed whether to dump detailed information + */ + public static void debugDumpFromSystemManager(boolean detailed) { + if (detailed) { + ((BaseNDManager) PtNDManager.getSystemManager()).debugDumpDetailed(0); + } else { + ((BaseNDManager) PtNDManager.getSystemManager()).debugDump(0); + } } /** The SystemManager is the root {@link PtNDManager} of which all others are children. */ diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java index 2b1715495b0..26466071d48 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -40,7 +40,7 @@ public static void main(String[] args) NDArray a = subManager.create(new float[] {1f}); NDArray b = subManager.create(new float[] {2f}); NDArray c = a.add(b); - debugDumpFromSystemManager(); + debugDumpFromSystemManager(true); logger.info("reference exists ..."); logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); @@ -55,10 +55,10 @@ public static void main(String[] args) logger.info("no reference exists, and likely garbage collected ..."); logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); - debugDumpFromSystemManager(); + debugDumpFromSystemManager(true); } - debugDumpFromSystemManager(); + debugDumpFromSystemManager(true); } - debugDumpFromSystemManager(); + debugDumpFromSystemManager(true); } } From e79c553dd802bc5847719e2a49a5957da04190d9 Mon Sep 17 00:00:00 2001 From: enpasos Date: Wed, 21 Dec 2022 17:14:13 +0100 Subject: [PATCH 15/30] fixed a memory leak --- .../main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java index aab36876c65..fcf88f4ce8a 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java @@ -17,9 +17,8 @@ import java.lang.ref.Reference; import java.lang.ref.ReferenceQueue; import java.lang.reflect.Proxy; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; +import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.WeakHashMap; @@ -34,7 +33,7 @@ public class WeakHashMapWrapper implements Map { private final WeakHashMap map = new WeakHashMap<>(); private final ReferenceQueue 
queue = new ReferenceQueue<>(); - private final List> weakReferenceWrapperList = new ArrayList<>(); + private final Set> weakReferenceWrapperSet = new HashSet<>(); private void checkQueue() { for (Reference ref; (ref = queue.poll()) != null; ) { @@ -44,6 +43,7 @@ private void checkQueue() { V value = ref2.getValue(); if (value instanceof NDArray) { // just as one example ((NDArray) value).close(); + weakReferenceWrapperSet.remove(ref2); } } } @@ -94,7 +94,7 @@ public V put(K key, V value) { throw new IllegalArgumentException( "Proxy is not supported to be stored as value here."); } - weakReferenceWrapperList.add(new WeakReferenceWrapper(key, value, queue)); + weakReferenceWrapperSet.add(new WeakReferenceWrapper(key, value, queue)); return map.put(key, value); } From 5165335725409d254fab665d349451e6779fc935 Mon Sep 17 00:00:00 2001 From: enpasos Date: Tue, 27 Dec 2022 19:32:12 +0100 Subject: [PATCH 16/30] global switch --- api/src/main/java/ai/djl/Model.java | 53 ------------------- api/src/main/java/ai/djl/engine/Engine.java | 21 -------- .../java/ai/djl/ndarray/BaseNDManager.java | 33 +----------- .../main/java/ai/djl/ndarray/NDManager.java | 38 ------------- .../ndarray/gc/SwitchGarbageCollection.java | 36 +++++++++++++ .../passthrough/PassthroughNDManager.java | 18 ------- .../java/ai/djl/dlr/engine/DlrEngine.java | 6 --- .../java/ai/djl/ml/lightgbm/LgbmEngine.java | 5 -- .../java/ai/djl/ml/xgboost/XgbEngine.java | 6 --- .../java/ai/djl/mxnet/engine/MxEngine.java | 6 --- .../ai/djl/onnxruntime/engine/OrtEngine.java | 6 --- .../ai/djl/paddlepaddle/engine/PpEngine.java | 6 --- .../java/ai/djl/pytorch/engine/PtEngine.java | 11 ---- .../java/ai/djl/pytorch/engine/PtModel.java | 14 ----- .../ai/djl/pytorch/engine/PtNDArrayImpl.java | 7 +-- .../ai/djl/pytorch/engine/PtNDManager.java | 16 ++---- .../ai/djl/pytorch/integration/gc/Main.java | 5 +- .../ai/djl/tensorflow/engine/TfEngine.java | 6 --- .../ai/djl/tensorrt/engine/TrtEngine.java | 6 --- .../ai/djl/tflite/engine/TfLiteEngine.java | 6 --- 20 files changed, 49 insertions(+), 256 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/gc/SwitchGarbageCollection.java diff --git a/api/src/main/java/ai/djl/Model.java b/api/src/main/java/ai/djl/Model.java index be385bced37..0fec77426c0 100644 --- a/api/src/main/java/ai/djl/Model.java +++ b/api/src/main/java/ai/djl/Model.java @@ -99,59 +99,6 @@ static Model newInstance(String name, Device device, String engineName) { return Engine.getEngine(engineName).newModel(name, device); } - /** - * Creates an empty model instance. - * - * @param name the model name - * @param useGarbageCollection whether to use garbage collection - * @return a new Model instance - */ - static Model newInstance(String name, boolean useGarbageCollection) { - return newInstance(name, (Device) null, useGarbageCollection); - } - - /** - * Creates an empty model instance on the specified {@link Device}. - * - * @param name the model name - * @param useGarbageCollection whether to use garbage collection - * @param device the device to load the model onto - * @return a new model instance - */ - static Model newInstance(String name, Device device, boolean useGarbageCollection) { - return Engine.getInstance().newModel(name, device, useGarbageCollection); - } - - /** - * Creates an empty model instance on the specified {@link Device} and engine. 
- * - * @param name the model name - * @param useGarbageCollection whether to use garbage collection - * @param engineName the name of the engine - * @return a new model instance - */ - static Model newInstance(String name, String engineName, boolean useGarbageCollection) { - Engine engine = Engine.getEngine(engineName); - return engine.newModel(name, null, useGarbageCollection); - } - - /** - * Creates an empty model instance on the specified {@link Device} and engine. - * - * @param name the model name - * @param useGarbageCollection whether to use garbage collection - * @param device the device to load the model onto - * @param engineName the name of the engine - * @return a new model instance - */ - static Model newInstance( - String name, Device device, String engineName, boolean useGarbageCollection) { - if (engineName == null || engineName.isEmpty()) { - return newInstance(name, device); - } - return Engine.getEngine(engineName).newModel(name, device, useGarbageCollection); - } - /** * Loads the model from the {@code modelPath}. * diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 8997eff5050..73d19779bd5 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -282,27 +282,6 @@ public int getGpuCount() { */ public abstract Model newModel(String name, Device device); - /** - * Constructs a new model. - * - * @param name the model name - * @param useGarbageCollection whether to use garbage collection - * @param device the device that the model will be loaded onto - * @return a new Model instance using the network defined in block - */ - public abstract Model newModel(String name, Device device, boolean useGarbageCollection); - /** - * Creates a new top-level {@link NDManager}. - * - *
<p>
{@code NDManager} will inherit default {@link Device}. - * - * @param useGarbageCollection whether to facade resources with a proxy - * @return a new top-level {@code NDManager} - */ - public NDManager newBaseManager(boolean useGarbageCollection) { - throw new UnsupportedOperationException("Not implemented"); - } - /** * Creates a new top-level {@link NDManager}. * diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 38c5e6ae629..aa6588f4367 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -43,7 +43,6 @@ public abstract class BaseNDManager implements NDManager { private static final Logger logger = LoggerFactory.getLogger(BaseNDManager.class); - private final boolean useGarbageCollection; protected NDManager parent; protected NDManager alternativeManager; @@ -56,13 +55,8 @@ public abstract class BaseNDManager implements NDManager { protected AtomicBoolean capped = new AtomicBoolean(false); protected BaseNDManager(NDManager parent, Device device) { - this(parent, device, false); - } - - protected BaseNDManager(NDManager parent, Device device, boolean useGarbageCollection) { this.parent = parent; this.device = device == null ? defaultDevice() : device; - this.useGarbageCollection = useGarbageCollection; resources = new ConcurrentHashMap<>(); tempResources = new ConcurrentHashMap<>(); uid = UUID.randomUUID().toString(); @@ -102,18 +96,13 @@ public void closeIfWithoutResources() { public NDArray create(String[] data, Charset charset, Shape shape) { throw new UnsupportedOperationException("Not supported!"); } + /** {@inheritDoc} */ @Override public NDArray create(Shape shape, DataType dataType) { throw new UnsupportedOperationException("Not supported!"); } - /** {@inheritDoc} */ - @Override - public boolean isUseGarbageCollection() { - return useGarbageCollection; - } - /** {@inheritDoc} */ @Override public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) { @@ -328,25 +317,7 @@ public NDManager getParentManager() { /** {@inheritDoc} */ @Override public NDManager newSubManager() { - return newSubManager(device, useGarbageCollection); - } - - /** {@inheritDoc} */ - @Override - public NDManager newSubManager(boolean useGarbageCollection) { - return newSubManager(device, useGarbageCollection); - } - - /** {@inheritDoc} */ - @Override - public NDManager newSubManager(Device device) { - return newSubManager(device, useGarbageCollection); - } - - /** {@inheritDoc} */ - @Override - public NDManager newSubManager(Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Not supported!"); + return newSubManager(device); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 32cc42ffa92..42329be8341 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -106,18 +106,6 @@ */ public interface NDManager extends AutoCloseable { - /** - * Creates a new top-level {@code NDManager}. - * - *
<p>
{@code NDManager} will inherit default {@link Device}. - * - * @param useGarbageCollection whether to facade {@link NDArray} behind a proxy - * @return a new top-level {@code NDManager} - */ - static NDManager newBaseManager(boolean useGarbageCollection) { - return Engine.getInstance().newBaseManager(useGarbageCollection); - } - /** * Creates a new top-level {@code NDManager}. * @@ -767,13 +755,6 @@ default NDArrayProxyMaker getProxyMaker() { */ String getName(); - /** - * Returns useGarbageCollection. - * - * @return useGarbageCollection - */ - boolean isUseGarbageCollection(); - /** * Creates an instance of {@link NDArray} with specified {@link Shape} filled with zeros. * @@ -1528,25 +1509,6 @@ default NDArray truncatedNormal( */ NDManager newSubManager(Device device); - /** - * Creates a child {@code NDManager} with specified boolean switch useGarbageCollection and will - * inherit default {@link Device} from this {@code NDManager}. - * - * @param useGarbageCollection the boolean switch to use proxies - * @return a child {@code NDManager} - */ - NDManager newSubManager(boolean useGarbageCollection); - - /** - * Creates a child {@code NDManager} with specified default {@link Device} and the boolean - * switch useGarbageCollection. - * - * @param device the default {@link Device} - * @param useGarbageCollection the boolean switch to use proxies - * @return a child {@code NDManager} - */ - NDManager newSubManager(Device device, boolean useGarbageCollection); - /** * Returns the default {@link Device} of this {@code NDManager}. * diff --git a/api/src/main/java/ai/djl/ndarray/gc/SwitchGarbageCollection.java b/api/src/main/java/ai/djl/ndarray/gc/SwitchGarbageCollection.java new file mode 100644 index 00000000000..de356a0e1b2 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/SwitchGarbageCollection.java @@ -0,0 +1,36 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +/** {@code GCSwitch} acts as a switch to put garbage collection on or off. */ +public final class SwitchGarbageCollection { + + private static boolean useGarbageCollection; + + /** Hide the constructor of this utility class. */ + private SwitchGarbageCollection() {} + + /** + * Returns whether to use garbage collection to manage temporary resources. + * + * @return the useGarbageCollection + */ + public static boolean isUseGarbageCollection() { + return useGarbageCollection; + } + + /** Switches the garbage collection on. 
*/ + public static void on() { + useGarbageCollection = true; + } +} diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index 0bacb8e288e..3d1d50bfba4 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -85,12 +85,6 @@ public NDArray createCoo(Buffer data, long[][] indices, Shape shape) { throw new UnsupportedOperationException(UNSUPPORTED); } - /** {@inheritDoc} */ - @Override - public boolean isUseGarbageCollection() { - return false; - } - /** {@inheritDoc} */ @Override public NDList load(Path path) { @@ -249,18 +243,6 @@ public NDManager newSubManager(Device device) { return this; } - /** {@inheritDoc} */ - @Override - public NDManager newSubManager(boolean useGarbageCollection) { - return this; - } - - /** {@inheritDoc} */ - @Override - public NDManager newSubManager(Device device, boolean useGarbageCollection) { - return this; - } - /** {@inheritDoc} */ @Override public Device getDevice() { diff --git a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java index 0b41b218f30..ee8bf0d4f76 100644 --- a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java +++ b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrEngine.java @@ -102,12 +102,6 @@ public Model newModel(String name, Device device) { return new DlrModel(name, newBaseManager(Device.cpu())); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java index b0dec588cb3..23d0c1adbcc 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngine.java @@ -101,11 +101,6 @@ public Model newModel(String name, Device device) { return new LgbmModel(name, newBaseManager(device)); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java index 59e423f1688..6449ee5f321 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java @@ -113,12 +113,6 @@ public Model newModel(String name, Device device) { return new XgbModel(name, newBaseManager(device)); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java index 66eeeddefee..c47d8bd1e49 100644 --- 
a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngine.java @@ -126,12 +126,6 @@ public Model newModel(String name, Device device) { return new MxModel(name, device); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index 4aa4e581931..6bf773e15ec 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -124,12 +124,6 @@ public Model newModel(String name, Device device) { return new OrtModel(name, newBaseManager(device), env); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java index 86a11047205..e482491598d 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngine.java @@ -98,12 +98,6 @@ public Model newModel(String name, Device device) { return new PpModel(name, device, newBaseManager(device)); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index 1b76d8a94a6..d0f7a101394 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -128,17 +128,6 @@ public Model newModel(String name, Device device) { return new PtModel(name, device); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - return new PtModel(name, device, useGarbageCollection); - } - - /** {@inheritDoc} */ - @Override - public NDManager newBaseManager(boolean useGarbageCollection) { - return PtNDManager.getSystemManager().newSubManager(useGarbageCollection); - } /** {@inheritDoc} */ @Override public NDManager newBaseManager() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index d643c53ae2e..0de495bc161 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ 
b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -57,20 +57,6 @@ public class PtModel extends BaseModel { dataType = DataType.FLOAT32; } - /** - * Constructs a new Model on a given device. - * - * @param name the model name - * @param useGarbageCollection whether to use garbage collection - * @param device the device the model should be located on - */ - PtModel(String name, Device device, boolean useGarbageCollection) { - super(name); - manager = PtNDManager.getSystemManager().newSubManager(device, useGarbageCollection); - manager.setName("ptModel"); - dataType = DataType.FLOAT32; - } - /** {@inheritDoc} */ @Override public void load(Path modelPath, String prefix, Map options) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index ef5573caa5f..2ec1579c3b0 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.gc.SwitchGarbageCollection; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; @@ -1570,7 +1571,7 @@ public void close() { */ public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { PtNDArray instance = new PtNDArrayImpl(manager, handle); - if (manager.isUseGarbageCollection()) { + if (SwitchGarbageCollection.isUseGarbageCollection()) { instance = manager.getProxyMaker().wrap(instance); } return instance; @@ -1589,7 +1590,7 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { */ public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffer data) { PtNDArray instance = new PtNDArrayImpl(manager, handle, data); - if (manager.isUseGarbageCollection()) { + if (SwitchGarbageCollection.isUseGarbageCollection()) { instance = manager.getProxyMaker().wrap(instance); } return instance; @@ -1608,7 +1609,7 @@ public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffe */ public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape shape) { PtNDArray instance = new PtNDArrayImpl(manager, strs, shape); - if (manager.isUseGarbageCollection()) { + if (SwitchGarbageCollection.isUseGarbageCollection()) { instance = manager.getProxyMaker().wrap(instance); } return instance; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 9162bb3afc4..b9ea160e515 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -34,8 +34,8 @@ public class PtNDManager extends BaseNDManager { protected PtNDArrayProxyMaker proxyMaker; - private PtNDManager(NDManager parent, Device device, boolean useGarbageCollection) { - super(parent, device, useGarbageCollection); + private PtNDManager(NDManager parent, Device device) { + super(parent, device); } static PtNDManager getSystemManager() { @@ -183,15 +183,7 @@ public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataTy /** {@inheritDoc} */ @Override public PtNDManager newSubManager(Device device) { 
- PtNDManager manager = new PtNDManager(this, device, isUseGarbageCollection()); - attachUncappedInternal(manager.uid, manager); - return manager; - } - - /** {@inheritDoc} */ - @Override - public NDManager newSubManager(Device device, boolean useGarbageCollection) { - PtNDManager manager = new PtNDManager(this, device, useGarbageCollection); + PtNDManager manager = new PtNDManager(this, device); attachUncappedInternal(manager.uid, manager); return manager; } @@ -219,7 +211,7 @@ public static void debugDumpFromSystemManager(boolean detailed) { private static final class SystemManager extends PtNDManager implements SystemNDManager { SystemManager() { - super(null, null, false); + super(null, null); this.proxyMaker = new PtNDArrayProxyMaker(); } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java index 26466071d48..b448053312b 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -16,6 +16,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.gc.SwitchGarbageCollection; import ai.djl.translate.TranslateException; import org.slf4j.Logger; @@ -33,8 +34,8 @@ private Main() {} public static void main(String[] args) throws IOException, TranslateException, InterruptedException { - - try (NDManager baseManager = NDManager.newBaseManager(true); ) { + SwitchGarbageCollection.on(); + try (NDManager baseManager = NDManager.newBaseManager(); ) { try (NDManager subManager = baseManager.newSubManager()) { NDArray a = subManager.create(new float[] {1f}); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java index 89d4587a9ea..bd8d10f3640 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngine.java @@ -84,12 +84,6 @@ public Model newModel(String name, Device device) { return new TfModel(name, device); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java index 80c38fe5ee6..e5663355228 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngine.java @@ -118,12 +118,6 @@ public Model newModel(String name, Device device) { return new TrtModel(name, newBaseManager(device)); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public TrtNDManager newBaseManager() { diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java index 05b94c36d50..02be61faa66 100644 --- 
a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngine.java @@ -89,12 +89,6 @@ public Model newModel(String name, Device device) { return new TfLiteModel(name, newBaseManager(device)); } - /** {@inheritDoc} */ - @Override - public Model newModel(String name, Device device, boolean useGarbageCollection) { - throw new UnsupportedOperationException("Garbage collection not supported"); - } - /** {@inheritDoc} */ @Override public SymbolBlock newSymbolBlock(NDManager manager) { From 66a613a48fc0d26152be876c1689c6f97f34eaad Mon Sep 17 00:00:00 2001 From: enpasos Date: Fri, 30 Dec 2022 07:48:24 +0100 Subject: [PATCH 17/30] sync fork --- .../java/ai/djl/pytorch/engine/PtNDArray.java | 17 +++------- .../ai/djl/pytorch/engine/PtNDArrayImpl.java | 31 +++++++++++++++++++ .../java/ai/djl/pytorch/jni/JniUtils.java | 10 +++--- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index a5939f70d60..4c1574c7ac4 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -530,9 +530,7 @@ public interface PtNDArray extends NativeResource, NDArray { /** {@inheritDoc} */ @Override - public NDArray fft(long length, long axis) { - return JniUtils.fft(this, length, axis); - } + public NDArray fft(long length, long axis); /** {@inheritDoc} */ @Override @@ -542,10 +540,7 @@ public NDArray stft( boolean center, NDArray window, boolean normalize, - boolean returnComplex) { - return JniUtils.stft( - this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex); - } + boolean returnComplex); /** {@inheritDoc} */ @Override @@ -765,15 +760,11 @@ public NDArray stft( /** {@inheritDoc} */ @Override - public NDArray complex() { - return JniUtils.complex(this); - } + public NDArray complex(); /** {@inheritDoc} */ @Override - public NDArray real() { - return JniUtils.real(this); - } + public NDArray real(); /** {@inheritDoc} */ @Override diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index 2ec1579c3b0..4a134034570 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -1057,6 +1057,25 @@ public NDArray flatten(int startDim, int endDim) { return JniUtils.flatten(this, startDim, endDim); } + /** {@inheritDoc} */ + @Override + public NDArray fft(long length, long axis) { + return JniUtils.fft(this, length, axis); + } + + /** {@inheritDoc} */ + @Override + public NDArray stft( + long nFft, + long hopLength, + boolean center, + NDArray window, + boolean normalize, + boolean returnComplex) { + return JniUtils.stft( + this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex); + } + /** {@inheritDoc} */ @Override public PtNDArray reshape(Shape shape) { @@ -1513,6 +1532,18 @@ public NDArray batchDot(NDArray other) { throw new UnsupportedOperationException("Not implemented"); } + /** {@inheritDoc} */ + @Override + public NDArray complex() { + return JniUtils.complex(this); + } + + /** {@inheritDoc} */ + @Override + public 
NDArray real() { + return JniUtils.real(this); + } + /** {@inheritDoc} */ @Override public PtNDArrayEx getNDArrayInternal() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 8a49187f1aa..2d1a07d5293 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -928,7 +928,7 @@ public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) { } public static PtNDArray fft(PtNDArray ndArray, long length, long axis) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFft(ndArray.getHandle(), length, axis)); } @@ -953,7 +953,7 @@ public static PtNDArray stft( if (handle == -1) { throw new UnsupportedOperationException("real() is not supported."); } - return new PtNDArray(ndArray.getManager(), handle); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), handle); } public static PtNDArray real(PtNDArray ndArray) { @@ -961,7 +961,7 @@ public static PtNDArray real(PtNDArray ndArray) { if (handle == -1) { throw new UnsupportedOperationException("real() is not supported."); } - return new PtNDArray(ndArray.getManager(), handle); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), handle); } public static PtNDArray complex(PtNDArray ndArray) { @@ -969,7 +969,7 @@ public static PtNDArray complex(PtNDArray ndArray) { if (handle == -1) { throw new UnsupportedOperationException("complex() is not supported."); } - return new PtNDArray(ndArray.getManager(), handle); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), handle); } public static PtNDArray abs(PtNDArray ndArray) { @@ -1237,7 +1237,7 @@ public static PtNDArray eye( public static PtNDArray hannWindow( PtNDManager manager, long numPoints, boolean periodic, Device device) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchHannWindow( numPoints, From 687e307baa9efae1647811e15d537d73cdd8a966 Mon Sep 17 00:00:00 2001 From: enpasos Date: Sat, 31 Dec 2022 07:39:01 +0100 Subject: [PATCH 18/30] opened LayerNorm.Builder for inheritance --- api/src/main/java/ai/djl/nn/norm/LayerNorm.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java index 43ebd85b822..5d69284132e 100644 --- a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java @@ -178,7 +178,7 @@ public void loadMetadata(byte loadVersion, DataInputStream is) } /** The Builder to construct a {@link LayerNorm}. 
*/ - public static final class Builder { + public static class Builder { private float epsilon = 1E-5f; // private Shape normalizedShape; From 75bb5f207e07733a1eb467d58c59cfb0833fdf17 Mon Sep 17 00:00:00 2001 From: enpasos Date: Wed, 4 Jan 2023 20:08:54 +0100 Subject: [PATCH 19/30] uid-counter, getImplementation, debugCountNDArrays,getNumOfNDArraysInGCMap, zeroGrad --- .../java/ai/djl/ndarray/BaseNDManager.java | 19 +++++++++++++++ .../ndarray/gc/DynamicInvocationHandler.java | 21 +++++++++++------ .../pytorch/engine/PtGradientCollector.java | 2 +- .../java/ai/djl/pytorch/engine/PtNDArray.java | 23 +++++++++++++++++++ .../ai/djl/pytorch/engine/PtNDArrayImpl.java | 18 +++++++++++++++ .../pytorch/engine/PtNDArrayProxyMaker.java | 12 ++++++---- .../ai/djl/pytorch/engine/PtNDManager.java | 9 ++++++++ .../ai/djl/pytorch/integration/gc/Main.java | 10 +++++++- tools/conf/findbugs-exclude.xml | 4 ++++ 9 files changed, 104 insertions(+), 14 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index aa6588f4367..c7bd4d8ae59 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -532,6 +532,25 @@ public void debugDumpDetailed(int level) { } } + /** + * Returns the number of {@link NDArray} in the hierarchy of this {@link NDManager}. + * + * @return return the number of {@link NDArray} in the hierarchy of this {@link NDManager} + */ + public int debugCountNDArrays() { + int count = 0; + for (AutoCloseable c : resources.values()) { + if (c instanceof BaseNDManager) { + count += ((BaseNDManager) c).debugCountNDArrays(); + } else if (c instanceof NDArray) { + count++; + } else if (c instanceof NDList) { + count += ((NDList) c).size(); + } + } + return count; + } + NDManager getAlternativeManager() { return alternativeManager; } diff --git a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java index 804d2f0f557..b611064ed16 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java +++ b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java @@ -20,29 +20,30 @@ import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.UUID; /** {@code DynamicInvocationHandler} implements the {@link InvocationHandler}. */ public class DynamicInvocationHandler implements InvocationHandler { private static final Logger logger = LoggerFactory.getLogger(DynamicInvocationHandler.class); - WeakHashMapWrapper map; - UUID uuid; + WeakHashMapWrapper map; + String uid; NDArrayProxyMaker ndArrayProxyMaker; /** * Creates a new instance of {@code DynamicInvocationHandler}. 
* - * @param uuid the uuid + * @param uid the uid * @param map the map * @param ndArrayProxyMaker the ndArrayProxyMaker */ public DynamicInvocationHandler( - UUID uuid, WeakHashMapWrapper map, NDArrayProxyMaker ndArrayProxyMaker) { + String uid, + WeakHashMapWrapper map, + NDArrayProxyMaker ndArrayProxyMaker) { this.map = map; - this.uuid = uuid; + this.uid = uid; this.ndArrayProxyMaker = ndArrayProxyMaker; } @@ -50,9 +51,15 @@ public DynamicInvocationHandler( @Override public Object invoke(Object proxy, Method method, Object[] args) { + if ("getNumOfNDArraysInGCMap".equals(method.getName())) { + return this.map.size(); + } + if ("getImplementation".equals(method.getName())) { + return map.get(uid); + } Object result; try { - result = method.invoke(map.get(uuid), args); + result = method.invoke(map.get(uid), args); } catch (IllegalAccessException | InvocationTargetException e) { logger.error("Error invoking method", e); throw new RuntimeException(e); // NOPMD diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index eb05dc82b41..91104f78093 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -74,7 +74,7 @@ public void zeroGradients() { for (NDArray array : systemManager.getManagedArrays()) { try { if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + JniUtils.zeroGrad((PtNDArray) array); } } catch (IllegalStateException e) { // ignore if the array is already closed diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 4c1574c7ac4..348d24a2d13 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -785,4 +785,27 @@ public NDArray stft( /** {@inheritDoc} */ @Override void close(); + + /** + * Returns the number of {@code NDArray} in the map used for garbage collection triggered + * closing. + * + * @return number of {@code NDArray} in the map used for garbage collection triggered closing + */ + int getNumOfNDArraysInGCMap(); + + /** + * Returns the number of {@code NDArray} in the hierarchy of {@code NDManager}. + * + * @return number of {@code NDArray} in the hierarchy of {@code NDManager} + */ + int getNumOfNDArraysInNDManagerHierarchy(); + + /** + * Returns the raw implementation of this interface. This could be useful for debugging if the + * interface is implemented by a proxy. 
+ * + * @return the raw implementation of this interface + */ + PtNDArrayImpl getImplementation(); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index 4a134034570..4d40bd7568b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -1591,6 +1591,24 @@ public void close() { dataRef = null; } + /** {@inheritDoc} */ + @Override + public int getNumOfNDArraysInGCMap() { + throw new UnsupportedOperationException("Not supported!"); + } + + /** {@inheritDoc} */ + @Override + public int getNumOfNDArraysInNDManagerHierarchy() { + return PtNDManager.debugCountNDArraysFromSystemManager(); + } + + /** {@inheritDoc} */ + @Override + public PtNDArrayImpl getImplementation() { + return this; + } + /** * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} * instead). Depending on the switch {@code useGarbageCollection}, the returned {@code NDArray} diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java index 3e5253964a5..708689c7b11 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java @@ -18,12 +18,14 @@ import ai.djl.ndarray.gc.WeakHashMapWrapper; import java.lang.reflect.Proxy; -import java.util.UUID; +import java.util.concurrent.atomic.AtomicLong; /** {@code PtNDArrayProxyMaker} creates a proxy facade. */ public class PtNDArrayProxyMaker implements NDArrayProxyMaker { - WeakHashMapWrapper map = new WeakHashMapWrapper<>(); + WeakHashMapWrapper map = new WeakHashMapWrapper<>(); + + AtomicLong counter = new AtomicLong(0); /** {@inheritDoc} */ @Override @@ -39,9 +41,9 @@ public int mapSize() { */ @Override public PtNDArray wrap(NDArray array) { - UUID uuid = UUID.randomUUID(); - map.put(uuid, array); - DynamicInvocationHandler handler = new DynamicInvocationHandler(uuid, map, this); + String uid = array.getUid() + "-" + counter.incrementAndGet(); + map.put(uid, array); + DynamicInvocationHandler handler = new DynamicInvocationHandler(uid, map, this); return (PtNDArray) Proxy.newProxyInstance( Thread.currentThread().getContextClassLoader(), diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 8756fd713ac..f95cd533434 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -213,6 +213,15 @@ public static void debugDumpFromSystemManager(boolean detailed) { } } + /** + * Returns the number of {@link NDArray} in the hierarchy of the {@link SystemNDManager}. + * + * @return return the number of {@link NDArray} in the hierarchy of the {@link SystemNDManager} + */ + public static int debugCountNDArraysFromSystemManager() { + return ((BaseNDManager) PtNDManager.getSystemManager()).debugCountNDArrays(); + } + /** The SystemManager is the root {@link PtNDManager} of which all others are children. 
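A minimal sketch of how the new counters are meant to be read, mirroring the Main.java changes in this patch and assuming the global garbage-collection switch has been turned on (the proxy, not the raw PtNDArrayImpl, answers getNumOfNDArraysInGCMap):

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.gc.SwitchGarbageCollection;
    import ai.djl.pytorch.engine.PtNDArray;

    public final class GcCountersExample {
        public static void main(String[] args) {
            SwitchGarbageCollection.on();
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray a = manager.create(new float[] {1f});
                NDArray b = manager.create(new float[] {2f});
                PtNDArray c = (PtNDArray) a.add(b);
                // arrays currently attached anywhere below the system manager
                System.out.println(c.getNumOfNDArraysInNDManagerHierarchy());
                // entries in the weak map used for gc-triggered closing
                System.out.println(c.getNumOfNDArraysInGCMap());
                // unwrap the proxy to reach the raw implementation
                System.out.println(c.getImplementation().getClass());
            }
        }
    }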
*/ private static final class SystemManager extends PtNDManager implements SystemNDManager { diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java index b448053312b..24544230795 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.gc.SwitchGarbageCollection; +import ai.djl.pytorch.engine.PtNDArray; import ai.djl.translate.TranslateException; import org.slf4j.Logger; @@ -40,7 +41,14 @@ public static void main(String[] args) NDArray a = subManager.create(new float[] {1f}); NDArray b = subManager.create(new float[] {2f}); - NDArray c = a.add(b); + PtNDArray c = (PtNDArray) a.add(b); + logger.info( + "number of NDArrays in NDManager hierarchy {}", + c.getNumOfNDArraysInNDManagerHierarchy()); + logger.info( + "number of NDArrays in map used of gc triggered NDArray closing {}", + c.getNumOfNDArraysInGCMap()); + debugDumpFromSystemManager(true); logger.info("reference exists ..."); diff --git a/tools/conf/findbugs-exclude.xml b/tools/conf/findbugs-exclude.xml index 2e59f19a5ad..b882914b622 100644 --- a/tools/conf/findbugs-exclude.xml +++ b/tools/conf/findbugs-exclude.xml @@ -33,4 +33,8 @@ + + + + From 64b056b1a5094a61f272aabff6596fa83167d7d5 Mon Sep 17 00:00:00 2001 From: enpasos Date: Wed, 4 Jan 2023 20:35:15 +0100 Subject: [PATCH 20/30] merged --- .../java/ai/djl/pytorch/engine/PtNDArrayImpl.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java index 4d40bd7568b..50101175d47 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -402,7 +402,7 @@ public NDArray sequenceMask(NDArray sequenceLength) { /** {@inheritDoc} */ @Override public boolean contentEquals(Number number) { - return JniUtils.contentEqual(this, (PtNDArray) manager.create(number)); + return contentEquals(manager.create(number)); } /** {@inheritDoc} */ @@ -1104,11 +1104,11 @@ public PtNDArray squeeze(int axis) { @Override public PtNDArray squeeze(int[] axes) { if (isScalar()) { - if (axes.length > 1 || axes[0] != 0) { - throw new IllegalArgumentException( - "axis " + axes[0] + "is out of bounds for array of dimension 0"); + if (axes.length == 0 || (axes.length == 1 && axes[0] == 0)) { + return (PtNDArray) duplicate(); } - return (PtNDArray) duplicate(); + throw new IllegalArgumentException( + "axis " + axes[0] + " is out of bounds for array of dimension 0"); } long[] shapeArr = getShape().getShape(); List newShape = new ArrayList<>(); @@ -1568,8 +1568,8 @@ public String toString() { /** {@inheritDoc} */ @Override public boolean equals(Object obj) { - if (obj instanceof PtNDArray) { - return contentEquals((PtNDArray) obj); + if (obj instanceof NDArray) { + return contentEquals((NDArray) obj); } return false; } From 866ecf1ab19a4e48e2911f5cc3ad4134510ff492 Mon Sep 17 00:00:00 2001 From: enpasos Date: Sat, 7 Jan 2023 18:00:45 +0100 Subject: [PATCH 21/30] merge fix --- .../src/main/java/ai/djl/pytorch/jni/JniUtils.java | 4 ---- 1 
file changed, 4 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 0d4c5c311a0..bce1698c472 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -422,10 +422,6 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m } finally { PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); } - - return PtNDArrayImpl.newPtNDArray( - manager, - PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle)); } @SuppressWarnings("OptionalGetWithoutIsPresent") From c574584b6401f78f61114f08e64a065758dba704 Mon Sep 17 00:00:00 2001 From: enpasos Date: Sat, 7 Jan 2023 18:07:06 +0100 Subject: [PATCH 22/30] PtGradientCollector from master --- .../java/ai/djl/pytorch/engine/PtGradientCollector.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index 91104f78093..5ef573cee68 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -72,12 +72,8 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c public void zeroGradients() { NDManager systemManager = PtNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { - try { - if (array.hasGradient()) { - JniUtils.zeroGrad((PtNDArray) array); - } - } catch (IllegalStateException e) { - // ignore if the array is already closed + if (array.hasGradient()) { + array.getGradient().subi(array.getGradient()); } } } From 13241f932a040ac03e95cfd605377eb7026d3fa8 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 9 Jan 2023 13:45:05 +0100 Subject: [PATCH 23/30] Revert "A temporary solution to issue 2210 (#2304)" This reverts commit 7866a4b627cf3c1894ba262d338e2e1d256b3ec2. 
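Below, a minimal sketch of the gradient-zeroing pattern this revert restores, using only signatures that appear in the diff (getManagedArrays on the manager, hasGradient/getGradient/subi on the array):

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;

    public final class ZeroGradientsSketch {
        // Zero every gradient the given manager still tracks, as the restored collector does.
        static void zeroGradients(NDManager systemManager) {
            for (NDArray array : systemManager.getManagedArrays()) {
                if (array.hasGradient()) {
                    // subtracting a gradient from itself zeroes it in place
                    array.getGradient().subi(array.getGradient());
                }
            }
        }
    }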
--- .../java/ai/djl/ndarray/BaseNDManager.java | 58 ++++++++------ api/src/main/java/ai/djl/ndarray/NDArray.java | 8 ++ api/src/main/java/ai/djl/ndarray/NDList.java | 7 ++ .../main/java/ai/djl/ndarray/NDManager.java | 11 ++- .../main/java/ai/djl/ndarray/NDResource.java | 9 +++ .../ai/djl/training/GradientCollector.java | 11 +++ .../passthrough/PassthroughNDManager.java | 14 ++-- .../djl/mxnet/engine/MxGradientCollector.java | 13 ++- .../pytorch/engine/PtGradientCollector.java | 27 ++++++- .../training/TrainAirfoilWithTabNetTest.java | 1 + .../GradientCollectorIntegrationTest.java | 79 +++++++++++++++++-- 11 files changed, 194 insertions(+), 44 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 47bbb8fbdcd..c7bd4d8ae59 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -32,9 +32,12 @@ import java.nio.ShortBuffer; import java.nio.charset.Charset; import java.nio.file.Path; +import java.util.List; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** {@code BaseNDManager} is the default implementation of {@link NDManager}. */ public abstract class BaseNDManager implements NDManager { @@ -323,6 +326,30 @@ public Device getDevice() { return device; } + /** {@inheritDoc} */ + @Override + public List getManagedArrays() { + return Stream.concat( + // Main resources + resources.values().stream() + .flatMap( + r -> { + if (r instanceof NDResource) { + return ((NDResource) r) + .getResourceNDArrays().stream(); + } else if (r instanceof NDManager) { + return ((NDManager) r).getManagedArrays().stream(); + } else { + return Stream.empty(); + } + }), + + // Temp resouces + tempResources.values().stream() + .flatMap(tr -> tr.resource.getResourceNDArrays().stream())) + .collect(Collectors.toList()); + } + /** {@inheritDoc} */ @Override public String toString() { @@ -340,9 +367,6 @@ public String toString() { /** {@inheritDoc} */ @Override public synchronized void attachInternal(String resourceId, AutoCloseable resource) { - if (this instanceof SystemNDManager) { - return; - } if (capped.get()) { throw new IllegalStateException("NDManager is capped for addition of resources."); } @@ -352,9 +376,6 @@ public synchronized void attachInternal(String resourceId, AutoCloseable resourc /** {@inheritDoc} */ @Override public synchronized void attachUncappedInternal(String resourceId, AutoCloseable resource) { - if (this instanceof SystemNDManager) { - return; - } if (closed.get()) { throw new IllegalStateException("NDManager has been closed already."); } @@ -381,7 +402,8 @@ public synchronized void attachUncappedInternal(String resourceId, AutoCloseable public void tempAttachInternal( NDManager originalManager, String resourceId, NDResource resource) { if (this instanceof SystemNDManager) { - return; + throw new IllegalStateException( + "System manager cannot be temp attached because it can't be closed.."); } if (closed.get()) { throw new IllegalStateException("NDManager has been closed already."); @@ -392,9 +414,6 @@ public void tempAttachInternal( /** {@inheritDoc} */ @Override public synchronized void detachInternal(String resourceId) { - if (this instanceof SystemNDManager) { - return; - } if (closed.get()) { // This may happen in the middle of BaseNDManager.close() return; @@ -421,26 +440,13 @@ public NDList 
invoke(String operation, NDList src, PairList params) { throw new UnsupportedOperationException("Not supported!"); } - /** {@inheritDoc} */ - @Override - public void zeroGradients() { - for (AutoCloseable res : resources.values()) { - if (res instanceof NDManager) { - ((NDManager) res).zeroGradients(); - } else if (res instanceof NDArray) { - NDArray array = (NDArray) res; - if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); - } - } - } - } - /** {@inheritDoc} */ @Override public void close() { if (this instanceof SystemNDManager) { - return; + throw new IllegalStateException( + "The SystemNDManager can not be closed. It is global and lives for the duration" + + " of the process"); } if (!closed.getAndSet(true)) { for (AutoCloseable closeable : resources.values()) { diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 1adfd59ac6a..ee67c7f78ee 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -30,6 +30,8 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.LongStream; @@ -4685,6 +4687,12 @@ default NDArray countNonzero(int axis) { */ NDArray erfinv(); + /** {@inheritDoc} */ + @Override + default List getResourceNDArrays() { + return Collections.singletonList(this); + } + /** * Returns an internal representative of Native {@code NDArray}. * diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index c88c93337e5..36c73bb42a9 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.List; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; @@ -269,6 +270,12 @@ public NDManager getManager() { return head().getManager(); } + /** {@inheritDoc} */ + @Override + public List getResourceNDArrays() { + return this; + } + /** {@inheritDoc} */ @Override public void attach(NDManager manager) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 206b7fcb4f0..574cd3fb120 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -35,6 +35,7 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.nio.file.Path; +import java.util.List; /** * NDArray managers are used to create NDArrays (n-dimensional array on native engine). @@ -1544,6 +1545,13 @@ default NDArray hanningWindow(long numPoints) { */ Device getDevice(); + /** + * Returns all {@link NDArray}s managed by this manager (including recursively). + * + * @return all {@link NDArray}s managed by this manager (including recursively) + */ + List getManagedArrays(); + /** * Attaches a resource to this {@code NDManager}. * @@ -1678,9 +1686,6 @@ default void tempAttachAll(NDResource... resources) { */ Engine getEngine(); - /** Sets all the gradients within the NDManager to zero. 
*/ - void zeroGradients(); - /** {@inheritDoc} */ @Override void close(); diff --git a/api/src/main/java/ai/djl/ndarray/NDResource.java b/api/src/main/java/ai/djl/ndarray/NDResource.java index 72d709c8d7c..02d09bb77da 100644 --- a/api/src/main/java/ai/djl/ndarray/NDResource.java +++ b/api/src/main/java/ai/djl/ndarray/NDResource.java @@ -12,6 +12,8 @@ */ package ai.djl.ndarray; +import java.util.List; + /** An object which is managed by an {@link NDManager} and tracks the manager it is attached to. */ public interface NDResource extends AutoCloseable { @@ -22,6 +24,13 @@ public interface NDResource extends AutoCloseable { */ NDManager getManager(); + /** + * Returns the {@link NDArray} or {@link NDArray}s contained within this resource. + * + * @return the {@link NDArray} or {@link NDArray}s contained within this resource + */ + List getResourceNDArrays(); + /** * Attaches this {@link NDResource} to the specified {@link NDManager}. * diff --git a/api/src/main/java/ai/djl/training/GradientCollector.java b/api/src/main/java/ai/djl/training/GradientCollector.java index e1a3cd99ca1..0e16e94d82e 100644 --- a/api/src/main/java/ai/djl/training/GradientCollector.java +++ b/api/src/main/java/ai/djl/training/GradientCollector.java @@ -21,6 +21,14 @@ * performed within the try-with-resources are recorded and the variables marked. When {@link * #backward(NDArray) backward function} is called, gradients are collected w.r.t previously marked * variables. + * + *
<p>
The typical behavior is to open up a gradient collector during each batch and close it during + * the end of the batch. In this way, the gradient is reset between batches. If the gradient + * collector is left open for multiple calls to backwards, the gradients collected are accumulated + * and added together. + * + *
<p>
Due to limitations in most engines, the gradient collectors are global. This means that only + * one can be used at a time. If multiple are opened, an error will be thrown. */ public interface GradientCollector extends AutoCloseable { @@ -31,6 +39,9 @@ public interface GradientCollector extends AutoCloseable { */ void backward(NDArray target); + /** Sets all the gradients within the engine to zero. */ + void zeroGradients(); + /** {@inheritDoc} */ @Override void close(); diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index 856bc72144a..3d1d50bfba4 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -26,6 +26,8 @@ import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.file.Path; +import java.util.Collections; +import java.util.List; /** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */ public final class PassthroughNDManager implements NDManager { @@ -247,6 +249,12 @@ public Device getDevice() { return Device.cpu(); } + /** {@inheritDoc} */ + @Override + public List getManagedArrays() { + return Collections.emptyList(); + } + /** {@inheritDoc} */ @Override public void attachInternal(String resourceId, AutoCloseable resource) { @@ -291,12 +299,6 @@ public Engine getEngine() { return null; } - /** {@inheritDoc} */ - @Override - public void zeroGradients() { - throw new UnsupportedOperationException(UNSUPPORTED); - } - /** {@inheritDoc} */ @Override public void close() {} diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java index 5e1ad836969..a8a183de317 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java @@ -19,7 +19,7 @@ import ai.djl.training.GradientCollector; /** {@code MxGradientCollector} is the MXNet implementation of {@link GradientCollector}. */ -public class MxGradientCollector implements GradientCollector { +public final class MxGradientCollector implements GradientCollector { /** * Constructs an {@code MxGradientCollector} and enables training data collection for @@ -116,4 +116,15 @@ public void backward(NDArray array) { private void backward(NDArray array, boolean retainGraph) { JnaUtils.autogradBackward(new NDList(array), retainGraph ? 
1 : 0); } + + /** {@inheritDoc} */ + @Override + public void zeroGradients() { + NDManager systemManager = MxNDManager.getSystemManager(); + for (NDArray array : systemManager.getManagedArrays()) { + if (array.hasGradient()) { + array.getGradient().subi(array.getGradient()); + } + } + } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index fad7755fa10..5ef573cee68 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -14,18 +14,31 @@ package ai.djl.pytorch.engine; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.GradientCollector; +import java.util.concurrent.atomic.AtomicBoolean; + /** {@code PtGradientCollector} is the PyTorch implementation of {@link GradientCollector}. */ -public class PtGradientCollector implements GradientCollector { +public final class PtGradientCollector implements GradientCollector { private boolean gradModel; + private static AtomicBoolean isCollecting = new AtomicBoolean(); /** Constructs a new {@code PtGradientCollector} instance. */ public PtGradientCollector() { gradModel = JniUtils.isGradMode(); JniUtils.setGradMode(true); + + boolean wasCollecting = isCollecting.getAndSet(true); + if (wasCollecting) { + throw new IllegalStateException( + "A PtGradientCollector is already collecting. Only one can be collecting at a" + + " time"); + } + + zeroGradients(); } /** {@inheritDoc} */ @@ -54,12 +67,24 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c JniUtils.backward((PtNDArray) target, (PtNDArray) grad, keepGraph, createGraph); } + /** {@inheritDoc} */ + @Override + public void zeroGradients() { + NDManager systemManager = PtNDManager.getSystemManager(); + for (NDArray array : systemManager.getManagedArrays()) { + if (array.hasGradient()) { + array.getGradient().subi(array.getGradient()); + } + } + } + /** {@inheritDoc} */ @Override public void close() { if (!gradModel) { JniUtils.setGradMode(false); } + isCollecting.set(false); // TODO: do some clean up if necessary } } diff --git a/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java b/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java index 89a2dadccb1..403ca2c5003 100644 --- a/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java +++ b/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java @@ -24,6 +24,7 @@ public class TrainAirfoilWithTabNetTest { @Test public void testTrainAirfoilWithTabNet() throws TranslateException, IOException { + TestRequirements.nightly(); TestRequirements.engine("MXNet", "PyTorch"); String[] args = new String[] {"-g", "1", "-e", "20", "-b", "32"}; TrainingResult result = TrainAirfoilWithTabNet.runExample(args); diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java index d8469178dab..2ad4f5e3b50 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java @@ -82,25 
+82,90 @@ public void testAutograd() { } } + @Test + public void testZeroGradients() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray a = manager.create(0.0f); + a.setRequiresGradient(true); + + Engine engine = Engine.getEngine(TestUtils.getEngine()); + try (GradientCollector gc = engine.newGradientCollector()) { + NDArray b = a.mul(2); + + // Gradients are initially zero + Assert.assertEquals(a.getGradient().getFloat(), 0.0f); + + // Gradients are updated by backwards + gc.backward(b); + Assert.assertEquals(a.getGradient().getFloat(), 2.0f); + + // Gradients are cleared by zeroGradients + gc.zeroGradients(); + Assert.assertEquals(a.getGradient().getFloat(), 0.0f); + } + } + } + /** Tests that the gradients do not accumulate when closing the gradient collector. */ @Test public void testClearGradients() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { - NDArray variable = manager.create(0.0f); - variable.setRequiresGradient(true); + NDArray a = manager.create(0.0f); + a.setRequiresGradient(true); - Engine engine = manager.getEngine(); + Engine engine = Engine.getEngine(TestUtils.getEngine()); for (int i = 0; i < 3; i++) { - manager.zeroGradients(); try (GradientCollector gc = engine.newGradientCollector()) { - NDArray loss = variable.mul(2); - gc.backward(loss); + NDArray b = a.mul(2); + gc.backward(b); + } + Assert.assertEquals(a.getGradient().getFloat(), 2.0f); + } + } + } + + /** Tests that the gradients do accumulate within the same gradient collector. */ + @Test + public void testAccumulateGradients() { + // TODO: MXNet support for accumulating gradients does not currently work + TestRequirements.notEngine("MXNet"); + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + NDArray a = manager.create(0.0f); + a.setRequiresGradient(true); + + Engine engine = Engine.getEngine(TestUtils.getEngine()); + try (GradientCollector gc = engine.newGradientCollector()) { + for (int i = 1; i <= 3; i++) { + NDArray b = a.mul(2); + gc.backward(b); + Assert.assertEquals(a.getGradient().getFloat(), 2.0f * i); } - Assert.assertEquals(variable.getGradient().getFloat(), 2.0f); } } } + /** + * Ensures that a gradient collector does not start when one is already created because they are + * global. 
+ */ + @Test + @SuppressWarnings({"try", "PMD.UseTryWithResources"}) + public void testMultipleGradientCollectors() { + Assert.assertThrows( + () -> { + GradientCollector gc2 = null; + Engine engine = Engine.getEngine(TestUtils.getEngine()); + try (GradientCollector gc = engine.newGradientCollector()) { + gc2 = engine.newGradientCollector(); + gc2.close(); + } finally { + if (gc2 != null) { + gc2.close(); + } + } + }); + } + @Test public void testFreezeParameters() { try (Model model = Model.newInstance("model", TestUtils.getEngine())) { From 664210c505cc17aabc17dd52599432bed83a6f89 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 9 Jan 2023 13:52:24 +0100 Subject: [PATCH 24/30] fixed bug: ignore if array is already closed --- .../java/ai/djl/pytorch/engine/PtGradientCollector.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index 5ef573cee68..91104f78093 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -72,8 +72,12 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c public void zeroGradients() { NDManager systemManager = PtNDManager.getSystemManager(); for (NDArray array : systemManager.getManagedArrays()) { - if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); + try { + if (array.hasGradient()) { + JniUtils.zeroGrad((PtNDArray) array); + } + } catch (IllegalStateException e) { + // ignore if the array is already closed } } } From be300cc658df169e4e2f46ce07739365b01be958 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 9 Jan 2023 14:09:29 +0100 Subject: [PATCH 25/30] Revert "some inheritance opening needed in a particular project (#2231)" This reverts commit e3ad95f09c4c410f04e36d98e13ed1854fb9e1d3. --- .../java/ai/djl/nn/convolutional/Conv2d.java | 6 +++--- api/src/main/java/ai/djl/nn/core/Linear.java | 8 ++++---- .../main/java/ai/djl/nn/norm/LayerNorm.java | 18 +++++++++--------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java b/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java index 4f5723a1d13..c68222cc119 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java @@ -61,7 +61,7 @@ public class Conv2d extends Convolution { private static final String STRING_LAYOUT = "NCHW"; private static final int NUM_DIMENSIONS = 4; - protected Conv2d(Builder builder) { + Conv2d(Builder builder) { super(builder); } @@ -201,10 +201,10 @@ public static Builder builder() { } /** The Builder to construct a {@link Conv2d} type of {@link Block}. */ - public static class Builder extends ConvolutionBuilder { + public static final class Builder extends ConvolutionBuilder { /** Creates a builder that can build a {@link Conv2d} block. 
*/ - protected Builder() { + Builder() { stride = new Shape(1, 1); padding = new Shape(0, 0); dilation = new Shape(1, 1); diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 5524d6d10e3..21e7941aed9 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -64,7 +64,7 @@ public class Linear extends AbstractBlock { private Parameter weight; private Parameter bias; - protected Linear(Builder builder) { + Linear(Builder builder) { super(VERSION); units = builder.units; weight = @@ -202,12 +202,12 @@ public static Builder builder() { } /** The Builder to construct a {@link Linear} type of {@link Block}. */ - public static class Builder { + public static final class Builder { - protected long units; + private long units; private boolean bias = true; - protected Builder() {} + Builder() {} /** * Sets the number of output channels. diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java index 5d69284132e..b791e31d5cb 100644 --- a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java @@ -57,16 +57,16 @@ */ public class LayerNorm extends AbstractBlock { - protected float epsilon; - protected Shape normalizedShape; + private float epsilon; + private Shape normalizedShape; - protected boolean center; - protected boolean scale; - protected int[] axis; - protected Parameter gamma; - protected Parameter beta; + private boolean center; + private boolean scale; + private int[] axis; + private Parameter gamma; + private Parameter beta; - protected LayerNorm(Builder builder) { + LayerNorm(Builder builder) { epsilon = builder.epsilon; scale = builder.scale; center = builder.center; @@ -186,7 +186,7 @@ public static class Builder { private boolean center = true; private int[] axis; - protected Builder() {} + Builder() {} /** * List the axis over which the mean and variance will be calculated (alternative to From b30c95a71c654af12deed2c71e631dfe3ba49e68 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 9 Jan 2023 14:25:48 +0100 Subject: [PATCH 26/30] removed two added methods from NDManager that are not necessary --- .../java/ai/djl/ndarray/BaseNDManager.java | 19 ------------------- .../main/java/ai/djl/ndarray/NDManager.java | 6 ------ .../passthrough/PassthroughNDManager.java | 12 ------------ 3 files changed, 37 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index c7bd4d8ae59..2e6018285d7 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -72,25 +72,6 @@ public final Device defaultDevice() { return getEngine().defaultDevice(); } - /** {@inheritDoc} */ - @Override - public void closeSubManagersWithoutResources() { - for (AutoCloseable resource : resources.values()) { - if (resource instanceof NDManager) { - NDManager subManager = (NDManager) resource; - subManager.closeIfWithoutResources(); - } - } - } - - /** {@inheritDoc} */ - @Override - public void closeIfWithoutResources() { - if (resources.isEmpty()) { - close(); - } - } - /** {@inheritDoc} */ @Override public NDArray create(String[] data, Charset charset, Shape shape) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 574cd3fb120..71d00675696 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ 
b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -1690,12 +1690,6 @@ default void tempAttachAll(NDResource... resources) { @Override void close(); - /** Closes all subManagers that do not have any resources attached to them. */ - void closeSubManagersWithoutResources(); - - /** Closes if it does not have any resources attached. */ - void closeIfWithoutResources(); - /** * A {@link SystemNDManager} is a marker class for a base NDManager. * diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index 3d1d50bfba4..bde8a89137a 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -302,16 +302,4 @@ public Engine getEngine() { /** {@inheritDoc} */ @Override public void close() {} - - /** {@inheritDoc} */ - @Override - public void closeSubManagersWithoutResources() { - throw new UnsupportedOperationException(UNSUPPORTED); - } - - /** {@inheritDoc} */ - @Override - public void closeIfWithoutResources() { - throw new UnsupportedOperationException(UNSUPPORTED); - } } From a1816142c1f3d0836c1c2222624ac4f76c0ca8a6 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 9 Jan 2023 16:11:28 +0100 Subject: [PATCH 27/30] introduced threadLocal reference queues --- .../ndarray/gc/DynamicInvocationHandler.java | 13 +++++- .../ai/djl/ndarray/gc/GCRuntimeException.java | 42 +++++++++++++++++++ .../pytorch/engine/PtNDArrayProxyMaker.java | 13 +++++- 3 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/gc/GCRuntimeException.java diff --git a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java index b611064ed16..895d9761213 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java +++ b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java @@ -59,10 +59,19 @@ public Object invoke(Object proxy, Method method, Object[] args) { } Object result; try { - result = method.invoke(map.get(uid), args); + NDArray ndArray = map.get(uid); + if (ndArray == null) { + logger.error("no nDArray found for uid: {}", uid); + throw new GCRuntimeException( + "no nDArray could be found for uid: " + + uid + + ". Consider calling the methods of a particular nDArray only from" + + " one thread or do not switch on garbage collection."); + } + result = method.invoke(ndArray, args); } catch (IllegalAccessException | InvocationTargetException e) { logger.error("Error invoking method", e); - throw new RuntimeException(e); // NOPMD + throw new GCRuntimeException(e); } return result; diff --git a/api/src/main/java/ai/djl/ndarray/gc/GCRuntimeException.java b/api/src/main/java/ai/djl/ndarray/gc/GCRuntimeException.java new file mode 100644 index 00000000000..8cd30a0ba89 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/GCRuntimeException.java @@ -0,0 +1,42 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +/** + * {@code GCRuntimeException} is the exception thrown when the {@link DynamicInvocationHandler} + * fails to collect the {@link NDArray} object or call a method on the {@link NDArray} object. + */ +public class GCRuntimeException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * Creates a new instance of {@code GCRuntimeException}. + * + * @param message the message + */ + public GCRuntimeException(String message) { + super(message); + } + + /** + * Creates a new instance of {@code GCRuntimeException}. + * + * @param e the exception + */ + public GCRuntimeException(Exception e) { + super(e); + } +} diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java index 708689c7b11..6d4ad1f806c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java @@ -23,14 +23,17 @@ /** {@code PtNDArrayProxyMaker} creates a proxy facade. */ public class PtNDArrayProxyMaker implements NDArrayProxyMaker { - WeakHashMapWrapper map = new WeakHashMapWrapper<>(); + ThreadLocal> tLocalMap = new ThreadLocal<>(); AtomicLong counter = new AtomicLong(0); /** {@inheritDoc} */ @Override public int mapSize() { - return map.size(); + if (tLocalMap.get() == null) { + tLocalMap.set(new WeakHashMapWrapper<>()); + } + return tLocalMap.get().size(); } /** @@ -41,6 +44,12 @@ public int mapSize() { */ @Override public PtNDArray wrap(NDArray array) { + + if (tLocalMap.get() == null) { + tLocalMap.set(new WeakHashMapWrapper<>()); + } + WeakHashMapWrapper map = tLocalMap.get(); + String uid = array.getUid() + "-" + counter.incrementAndGet(); map.put(uid, array); DynamicInvocationHandler handler = new DynamicInvocationHandler(uid, map, this); From e30b771879f8a48bff9dd98329d7fc4474d0bbfb Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 9 Jan 2023 16:54:10 +0100 Subject: [PATCH 28/30] Revert "Revert "some inheritance opening needed in a particular project (#2231)"" This reverts commit be300cc658df169e4e2f46ce07739365b01be958. --- .../java/ai/djl/nn/convolutional/Conv2d.java | 6 +++--- api/src/main/java/ai/djl/nn/core/Linear.java | 8 ++++---- .../main/java/ai/djl/nn/norm/LayerNorm.java | 18 +++++++++--------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java b/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java index c68222cc119..4f5723a1d13 100644 --- a/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java +++ b/api/src/main/java/ai/djl/nn/convolutional/Conv2d.java @@ -61,7 +61,7 @@ public class Conv2d extends Convolution { private static final String STRING_LAYOUT = "NCHW"; private static final int NUM_DIMENSIONS = 4; - Conv2d(Builder builder) { + protected Conv2d(Builder builder) { super(builder); } @@ -201,10 +201,10 @@ public static Builder builder() { } /** The Builder to construct a {@link Conv2d} type of {@link Block}. */ - public static final class Builder extends ConvolutionBuilder { + public static class Builder extends ConvolutionBuilder { /** Creates a builder that can build a {@link Conv2d} block. 
*/ - Builder() { + protected Builder() { stride = new Shape(1, 1); padding = new Shape(0, 0); dilation = new Shape(1, 1); diff --git a/api/src/main/java/ai/djl/nn/core/Linear.java b/api/src/main/java/ai/djl/nn/core/Linear.java index 21e7941aed9..5524d6d10e3 100644 --- a/api/src/main/java/ai/djl/nn/core/Linear.java +++ b/api/src/main/java/ai/djl/nn/core/Linear.java @@ -64,7 +64,7 @@ public class Linear extends AbstractBlock { private Parameter weight; private Parameter bias; - Linear(Builder builder) { + protected Linear(Builder builder) { super(VERSION); units = builder.units; weight = @@ -202,12 +202,12 @@ public static Builder builder() { } /** The Builder to construct a {@link Linear} type of {@link Block}. */ - public static final class Builder { + public static class Builder { - private long units; + protected long units; private boolean bias = true; - Builder() {} + protected Builder() {} /** * Sets the number of output channels. diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java index b791e31d5cb..5d69284132e 100644 --- a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java @@ -57,16 +57,16 @@ */ public class LayerNorm extends AbstractBlock { - private float epsilon; - private Shape normalizedShape; + protected float epsilon; + protected Shape normalizedShape; - private boolean center; - private boolean scale; - private int[] axis; - private Parameter gamma; - private Parameter beta; + protected boolean center; + protected boolean scale; + protected int[] axis; + protected Parameter gamma; + protected Parameter beta; - LayerNorm(Builder builder) { + protected LayerNorm(Builder builder) { epsilon = builder.epsilon; scale = builder.scale; center = builder.center; @@ -186,7 +186,7 @@ public static class Builder { private boolean center = true; private int[] axis; - Builder() {} + protected Builder() {} /** * List the axis over which the mean and variance will be calculated (alternative to From 9b2f9bb61e22448c3a3d5d5b7a8af93adaa6d0f9 Mon Sep 17 00:00:00 2001 From: enpasos Date: Mon, 9 Jan 2023 16:58:56 +0100 Subject: [PATCH 29/30] here I reverted to much --- api/src/main/java/ai/djl/nn/norm/LayerNorm.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java index 5d69284132e..43ebd85b822 100644 --- a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java @@ -178,7 +178,7 @@ public void loadMetadata(byte loadVersion, DataInputStream is) } /** The Builder to construct a {@link LayerNorm}. 
*/ - public static class Builder { + public static final class Builder { private float epsilon = 1E-5f; // private Shape normalizedShape; From 6cf606eab98754a3c58091b030b283648cdb85b2 Mon Sep 17 00:00:00 2001 From: enpasos Date: Wed, 11 Jan 2023 20:34:45 +0100 Subject: [PATCH 30/30] added a method gc() to NDManager which explicitly calls checkQueue on WeakHashMapWrapper --- .../main/java/ai/djl/ndarray/NDManager.java | 7 ++ .../ai/djl/ndarray/gc/NDArrayProxyMaker.java | 5 ++ .../ai/djl/ndarray/gc/WeakHashMapWrapper.java | 5 +- .../pytorch/engine/PtNDArrayProxyMaker.java | 17 +++-- .../ai/djl/pytorch/engine/PtNDManager.java | 6 ++ .../ai/djl/pytorch/integration/gc/Main2.java | 72 +++++++++++++++++++ 6 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main2.java diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index 71d00675696..8cbbf8dd16a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -751,6 +751,13 @@ default NDArrayProxyMaker getProxyMaker() { throw new UnsupportedOperationException("Not supported"); } + /** + * Checks the referenceQueue for NDArrays that are garbage collected by Java GC and closes them. + */ + default void gc() { + throw new UnsupportedOperationException("Not supported"); + } + /** * Sets the name for the NDManager. * diff --git a/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java index b6f6786a444..f42175f0e25 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java +++ b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java @@ -24,6 +24,11 @@ public interface NDArrayProxyMaker { */ int mapSize(); + /** + * Checks the referenceQueue for NDArrays that are garbage collected by Java GC and closes them. + */ + public void gc(); + /** * Wraps the {@link NDArray} in a proxy facade. * diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java index fcf88f4ce8a..919d5c46efa 100644 --- a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java @@ -35,7 +35,10 @@ public class WeakHashMapWrapper implements Map { private final Set> weakReferenceWrapperSet = new HashSet<>(); - private void checkQueue() { + /** + * Checks the referenceQueue for NDArrays that are garbage collected by Java GC and closes them. 
+ */ + public void checkQueue() { for (Reference ref; (ref = queue.poll()) != null; ) { synchronized (queue) { @SuppressWarnings("unchecked") diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java index 6d4ad1f806c..50d65ab4377 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java @@ -30,10 +30,20 @@ public class PtNDArrayProxyMaker implements NDArrayProxyMaker { /** {@inheritDoc} */ @Override public int mapSize() { + return getLocalWeakHashMapWrapper().size(); + } + + private WeakHashMapWrapper getLocalWeakHashMapWrapper() { if (tLocalMap.get() == null) { tLocalMap.set(new WeakHashMapWrapper<>()); } - return tLocalMap.get().size(); + return tLocalMap.get(); + } + + /** {@inheritDoc} */ + @Override + public void gc() { + getLocalWeakHashMapWrapper().checkQueue(); } /** @@ -45,10 +55,7 @@ public int mapSize() { @Override public PtNDArray wrap(NDArray array) { - if (tLocalMap.get() == null) { - tLocalMap.set(new WeakHashMapWrapper<>()); - } - WeakHashMapWrapper map = tLocalMap.get(); + WeakHashMapWrapper map = getLocalWeakHashMapWrapper(); String uid = array.getUid() + "-" + counter.incrementAndGet(); map.put(uid, array); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index f95cd533434..8269ad05ff7 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -48,6 +48,12 @@ public PtNDArrayProxyMaker getProxyMaker() { return getSystemManager().getProxyMaker(); } + /** {@inheritDoc} */ + @Override + public void gc() { + getSystemManager().getProxyMaker().gc(); + } + /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main2.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main2.java new file mode 100644 index 00000000000..43d6de02510 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main2.java @@ -0,0 +1,72 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.pytorch.integration.gc; + +import static ai.djl.pytorch.engine.PtNDManager.debugDumpFromSystemManager; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.gc.SwitchGarbageCollection; +import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** Some poc testing code. */ +public final class Main2 { + + private static final Logger logger = LoggerFactory.getLogger(Main2.class); + + private Main2() {} + + public static void main(String[] args) + throws IOException, TranslateException, InterruptedException { + SwitchGarbageCollection.on(); + try (NDManager baseManager = NDManager.newBaseManager(); ) { + try (NDManager subManager = baseManager.newSubManager()) { + + NDArray a = subManager.create(new float[] {1f}); + NDArray b = subManager.create(new float[] {2f}); + PtNDArray c = (PtNDArray) a.add(b); + + debugDumpFromSystemManager(true); + + System.out.println("reference exists ..."); + baseManager.gc(); + debugDumpFromSystemManager(true); + // logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + a = null; + b = null; + c = null; + System.out.println("no reference exists, but likely not yet garbage collected ..."); + // logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + baseManager.gc(); + debugDumpFromSystemManager(true); + + System.gc(); // just for testing - do not use in production + TimeUnit.SECONDS.sleep(1); + + System.out.println("no reference exists, and likely garbage collected ..."); + // logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + baseManager.gc(); + debugDumpFromSystemManager(true); + } + debugDumpFromSystemManager(true); + } + debugDumpFromSystemManager(true); + } +}
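For context on what Main2 exercises: NDManager#gc() drains the ReferenceQueue held by the thread-local WeakHashMapWrapper, so native arrays whose proxy wrappers have been collected by the Java GC can be closed on the caller's thread. The underlying queue-draining idiom can be shown with the plain java.lang.ref API alone; the following sketch is independent of DJL and uses illustrative names:

    import java.lang.ref.Reference;
    import java.lang.ref.ReferenceQueue;
    import java.lang.ref.WeakReference;

    final class ReferenceQueueSketch {

        private ReferenceQueueSketch() {}

        public static void main(String[] args) throws InterruptedException {
            ReferenceQueue<Object> queue = new ReferenceQueue<>();
            Object payload = new Object();
            WeakReference<Object> ref = new WeakReference<>(payload, queue);

            payload = null;  // drop the last strong reference to the payload
            System.gc();     // a hint only; collection is not guaranteed
            Thread.sleep(100);

            // Enqueued references identify payloads the GC has cleared; this is
            // the point where an owner could release the associated native resource.
            for (Reference<?> r; (r = queue.poll()) != null; ) {
                System.out.println("collected, referent is now: " + r.get());
            }
            System.out.println("weak reference cleared: " + (ref.get() == null));
        }
    }

As in WeakHashMapWrapper.checkQueue(), nothing is released until the queue is actually polled, which is why Main2 calls baseManager.gc() after System.gc() and a short sleep.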