diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index abbc3c31605..2e6018285d7 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -32,9 +32,12 @@ import java.nio.ShortBuffer; import java.nio.charset.Charset; import java.nio.file.Path; +import java.util.List; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** {@code BaseNDManager} is the default implementation of {@link NDManager}. */ public abstract class BaseNDManager implements NDManager { @@ -304,6 +307,30 @@ public Device getDevice() { return device; } + /** {@inheritDoc} */ + @Override + public List getManagedArrays() { + return Stream.concat( + // Main resources + resources.values().stream() + .flatMap( + r -> { + if (r instanceof NDResource) { + return ((NDResource) r) + .getResourceNDArrays().stream(); + } else if (r instanceof NDManager) { + return ((NDManager) r).getManagedArrays().stream(); + } else { + return Stream.empty(); + } + }), + + // Temp resouces + tempResources.values().stream() + .flatMap(tr -> tr.resource.getResourceNDArrays().stream())) + .collect(Collectors.toList()); + } + /** {@inheritDoc} */ @Override public String toString() { @@ -321,9 +348,6 @@ public String toString() { /** {@inheritDoc} */ @Override public synchronized void attachInternal(String resourceId, AutoCloseable resource) { - if (this instanceof SystemNDManager) { - return; - } if (capped.get()) { throw new IllegalStateException("NDManager is capped for addition of resources."); } @@ -333,9 +357,6 @@ public synchronized void attachInternal(String resourceId, AutoCloseable resourc /** {@inheritDoc} */ @Override public synchronized void attachUncappedInternal(String resourceId, AutoCloseable resource) { - if (this instanceof SystemNDManager) { - return; - } if (closed.get()) { throw new IllegalStateException("NDManager has been closed already."); } @@ -362,7 +383,8 @@ public synchronized void attachUncappedInternal(String resourceId, AutoCloseable public void tempAttachInternal( NDManager originalManager, String resourceId, NDResource resource) { if (this instanceof SystemNDManager) { - return; + throw new IllegalStateException( + "System manager cannot be temp attached because it can't be closed.."); } if (closed.get()) { throw new IllegalStateException("NDManager has been closed already."); @@ -373,9 +395,6 @@ public void tempAttachInternal( /** {@inheritDoc} */ @Override public synchronized void detachInternal(String resourceId) { - if (this instanceof SystemNDManager) { - return; - } if (closed.get()) { // This may happen in the middle of BaseNDManager.close() return; @@ -402,26 +421,13 @@ public NDList invoke(String operation, NDList src, PairList params) { throw new UnsupportedOperationException("Not supported!"); } - /** {@inheritDoc} */ - @Override - public void zeroGradients() { - for (AutoCloseable res : resources.values()) { - if (res instanceof NDManager) { - ((NDManager) res).zeroGradients(); - } else if (res instanceof NDArray) { - NDArray array = (NDArray) res; - if (array.hasGradient()) { - array.getGradient().subi(array.getGradient()); - } - } - } - } - /** {@inheritDoc} */ @Override public void close() { if (this instanceof SystemNDManager) { - return; + throw new IllegalStateException( + "The SystemNDManager can not be closed. 
It is global and lives for the duration" + + " of the process"); } if (!closed.getAndSet(true)) { for (AutoCloseable closeable : resources.values()) { @@ -463,6 +469,69 @@ public void debugDump(int level) { } } + /** + * Prints information about this {@link NDManager} and all sub-managers to the console. + * + * @param level the level of this {@link NDManager} in the hierarchy + */ + public void debugDumpDetailed(int level) { + StringBuilder sb = new StringBuilder(100); + for (int i = 0; i < level; ++i) { + sb.append(" "); + } + sb.append("\\--- NDManager(") + .append(uid.substring(24)) + .append(", ") + .append(device) + .append(") resource count: ") + .append(resources.size()); + + System.out.println(sb); // NOPMD + for (AutoCloseable c : resources.values()) { + if (c instanceof NDManager) { + ((BaseNDManager) c).debugDumpDetailed(level + 1); + } else if (c instanceof NDArray) { + StringBuilder sb2 = new StringBuilder(100); + for (int i = 0; i < level + 1; ++i) { + sb2.append(" "); + } + sb2.append( + "\\--- NDArray(" + + ((NDArray) c).getUid() + + ", Shape" + + ((NDArray) c).getShape() + + ")"); + System.out.println(sb2); // NOPMD + } else if (c instanceof NDResource) { + StringBuilder sb2 = new StringBuilder(100); + for (int i = 0; i < level + 1; ++i) { + sb2.append(" "); + } + sb2.append("\\--- other NDResource"); + System.out.println(sb2); // NOPMD + } + } + } + + /** + * Returns the number of {@link NDArray} in the hierarchy of this {@link NDManager}. + * + * @return return the number of {@link NDArray} in the hierarchy of this {@link NDManager} + */ + public int debugCountNDArrays() { + int count = 0; + for (AutoCloseable c : resources.values()) { + if (c instanceof BaseNDManager) { + count += ((BaseNDManager) c).debugCountNDArrays(); + } else if (c instanceof NDArray) { + count++; + } else if (c instanceof NDList) { + count += ((NDList) c).size(); + } + } + return count; + } + NDManager getAlternativeManager() { return alternativeManager; } diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 1adfd59ac6a..ee67c7f78ee 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -30,6 +30,8 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.LongStream; @@ -4685,6 +4687,12 @@ default NDArray countNonzero(int axis) { */ NDArray erfinv(); + /** {@inheritDoc} */ + @Override + default List getResourceNDArrays() { + return Collections.singletonList(this); + } + /** * Returns an internal representative of Native {@code NDArray}. 
* diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index c88c93337e5..36c73bb42a9 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.List; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; @@ -269,6 +270,12 @@ public NDManager getManager() { return head().getManager(); } + /** {@inheritDoc} */ + @Override + public List getResourceNDArrays() { + return this; + } + /** {@inheritDoc} */ @Override public void attach(NDManager manager) { diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index d61ba438763..8cbbf8dd16a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -15,6 +15,7 @@ import ai.djl.Device; import ai.djl.engine.Engine; import ai.djl.engine.EngineException; +import ai.djl.ndarray.gc.NDArrayProxyMaker; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.translate.Translator; @@ -34,6 +35,7 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.nio.file.Path; +import java.util.List; /** * NDArray managers are used to create NDArrays (n-dimensional array on native engine). @@ -740,6 +742,22 @@ default NDList load(Path path, Device device) { return newSubManager(device).load(path); } + /** + * Returns the {@link NDArrayProxyMaker}. + * + * @return the {@link NDArrayProxyMaker} + */ + default NDArrayProxyMaker getProxyMaker() { + throw new UnsupportedOperationException("Not supported"); + } + + /** + * Checks the referenceQueue for NDArrays that are garbage collected by Java GC and closes them. + */ + default void gc() { + throw new UnsupportedOperationException("Not supported"); + } + /** * Sets the name for the NDManager. * @@ -1534,6 +1552,13 @@ default NDArray hanningWindow(long numPoints) { */ Device getDevice(); + /** + * Returns all {@link NDArray}s managed by this manager (including recursively). + * + * @return all {@link NDArray}s managed by this manager (including recursively) + */ + List getManagedArrays(); + /** * Attaches a resource to this {@code NDManager}. * @@ -1668,9 +1693,6 @@ default void tempAttachAll(NDResource... resources) { */ Engine getEngine(); - /** Sets all the gradients within the NDManager to zero. */ - void zeroGradients(); - /** {@inheritDoc} */ @Override void close(); diff --git a/api/src/main/java/ai/djl/ndarray/NDResource.java b/api/src/main/java/ai/djl/ndarray/NDResource.java index 72d709c8d7c..02d09bb77da 100644 --- a/api/src/main/java/ai/djl/ndarray/NDResource.java +++ b/api/src/main/java/ai/djl/ndarray/NDResource.java @@ -12,6 +12,8 @@ */ package ai.djl.ndarray; +import java.util.List; + /** An object which is managed by an {@link NDManager} and tracks the manager it is attached to. */ public interface NDResource extends AutoCloseable { @@ -22,6 +24,13 @@ public interface NDResource extends AutoCloseable { */ NDManager getManager(); + /** + * Returns the {@link NDArray} or {@link NDArray}s contained within this resource. + * + * @return the {@link NDArray} or {@link NDArray}s contained within this resource + */ + List getResourceNDArrays(); + /** * Attaches this {@link NDResource} to the specified {@link NDManager}. 
* diff --git a/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java new file mode 100644 index 00000000000..895d9761213 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/DynamicInvocationHandler.java @@ -0,0 +1,79 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +/** {@code DynamicInvocationHandler} implements the {@link InvocationHandler}. */ +public class DynamicInvocationHandler implements InvocationHandler { + + private static final Logger logger = LoggerFactory.getLogger(DynamicInvocationHandler.class); + + WeakHashMapWrapper map; + String uid; + + NDArrayProxyMaker ndArrayProxyMaker; + + /** + * Creates a new instance of {@code DynamicInvocationHandler}. + * + * @param uid the uid + * @param map the map + * @param ndArrayProxyMaker the ndArrayProxyMaker + */ + public DynamicInvocationHandler( + String uid, + WeakHashMapWrapper map, + NDArrayProxyMaker ndArrayProxyMaker) { + this.map = map; + this.uid = uid; + this.ndArrayProxyMaker = ndArrayProxyMaker; + } + + /** {@inheritDoc} */ + @Override + public Object invoke(Object proxy, Method method, Object[] args) { + + if ("getNumOfNDArraysInGCMap".equals(method.getName())) { + return this.map.size(); + } + if ("getImplementation".equals(method.getName())) { + return map.get(uid); + } + Object result; + try { + NDArray ndArray = map.get(uid); + if (ndArray == null) { + logger.error("no nDArray found for uid: {}", uid); + throw new GCRuntimeException( + "no nDArray could be found for uid: " + + uid + + ". Consider calling the methods of a particular nDArray only from" + + " one thread or do not switch on garbage collection."); + } + result = method.invoke(ndArray, args); + } catch (IllegalAccessException | InvocationTargetException e) { + logger.error("Error invoking method", e); + throw new GCRuntimeException(e); + } + + return result; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/GCRuntimeException.java b/api/src/main/java/ai/djl/ndarray/gc/GCRuntimeException.java new file mode 100644 index 00000000000..8cd30a0ba89 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/GCRuntimeException.java @@ -0,0 +1,42 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +/** + * {@code GCRuntimeException} is the exception thrown when the {@link DynamicInvocationHandler} + * fails to collect the {@link NDArray} object or call a method on the {@link NDArray} object. + */ +public class GCRuntimeException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * Creates a new instance of {@code GCRuntimeException}. + * + * @param message the message + */ + public GCRuntimeException(String message) { + super(message); + } + + /** + * Creates a new instance of {@code GCRuntimeException}. + * + * @param e the exception + */ + public GCRuntimeException(Exception e) { + super(e); + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java new file mode 100644 index 00000000000..f42175f0e25 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/NDArrayProxyMaker.java @@ -0,0 +1,39 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +/** {@code PtNDArrayProxyMaker} creates a proxy facade. */ +public interface NDArrayProxyMaker { + + /** + * Returns the size of the map. + * + * @return the size of the map + */ + int mapSize(); + + /** + * Checks the referenceQueue for NDArrays that are garbage collected by Java GC and closes them. + */ + public void gc(); + + /** + * Wraps the {@link NDArray} in a proxy facade. + * + * @param array the array to wrap + * @return the wrapped array + */ + NDArray wrap(NDArray array); +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/SwitchGarbageCollection.java b/api/src/main/java/ai/djl/ndarray/gc/SwitchGarbageCollection.java new file mode 100644 index 00000000000..de356a0e1b2 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/SwitchGarbageCollection.java @@ -0,0 +1,36 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +/** {@code GCSwitch} acts as a switch to put garbage collection on or off. */ +public final class SwitchGarbageCollection { + + private static boolean useGarbageCollection; + + /** Hide the constructor of this utility class. */ + private SwitchGarbageCollection() {} + + /** + * Returns whether to use garbage collection to manage temporary resources. 
+ * + * @return the useGarbageCollection + */ + public static boolean isUseGarbageCollection() { + return useGarbageCollection; + } + + /** Switches the garbage collection on. */ + public static void on() { + useGarbageCollection = true; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java new file mode 100644 index 00000000000..919d5c46efa --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakHashMapWrapper.java @@ -0,0 +1,164 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import ai.djl.ndarray.NDArray; + +import java.lang.ref.Reference; +import java.lang.ref.ReferenceQueue; +import java.lang.reflect.Proxy; +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.WeakHashMap; + +/** + * {@code WeakHashMapWrapper} wraps a {@link WeakHashMap}. It has a {@link ReferenceQueue} to + * get informed by the garbage collector. On method invocations it looks for messages from the + * garbage collector and removes the corresponding entries. + */ +public class WeakHashMapWrapper implements Map { + + private final WeakHashMap map = new WeakHashMap<>(); + private final ReferenceQueue queue = new ReferenceQueue<>(); + + private final Set> weakReferenceWrapperSet = new HashSet<>(); + + /** + * Checks the referenceQueue for NDArrays that are garbage collected by Java GC and closes them. 
+ */ + public void checkQueue() { + for (Reference ref; (ref = queue.poll()) != null; ) { + synchronized (queue) { + @SuppressWarnings("unchecked") + WeakReferenceWrapper ref2 = (WeakReferenceWrapper) ref; + V value = ref2.getValue(); + if (value instanceof NDArray) { // just as one example + ((NDArray) value).close(); + weakReferenceWrapperSet.remove(ref2); + } + } + } + } + + // implement all methods of Map interface by calling corresponding methods of + // WeakHashMap instance map + + /** {@inheritDoc} */ + @Override + public int size() { + checkQueue(); + return map.size(); + } + + /** {@inheritDoc} */ + @Override + public boolean isEmpty() { + checkQueue(); + return map.isEmpty(); + } + + /** {@inheritDoc} */ + @Override + public boolean containsKey(Object key) { + checkQueue(); + return map.containsKey(key); + } + + /** {@inheritDoc} */ + @Override + public boolean containsValue(Object value) { + checkQueue(); + return map.containsValue(value); + } + + /** {@inheritDoc} */ + @Override + public V get(Object key) { + checkQueue(); + return map.get(key); + } + + /** {@inheritDoc} */ + @Override + public V put(K key, V value) { + if (value instanceof Proxy) { + throw new IllegalArgumentException( + "Proxy is not supported to be stored as value here."); + } + weakReferenceWrapperSet.add(new WeakReferenceWrapper(key, value, queue)); + return map.put(key, value); + } + + /** {@inheritDoc} */ + @Override + public V remove(Object key) { + checkQueue(); + return map.remove(key); + } + + /** {@inheritDoc} */ + @Override + public void putAll(Map m) { + checkQueue(); + map.putAll(m); + } + + /** {@inheritDoc} */ + @Override + public void clear() { + checkQueue(); + map.clear(); + } + + /** {@inheritDoc} */ + @Override + public Set keySet() { + checkQueue(); + return map.keySet(); + } + + /** {@inheritDoc} */ + @Override + public Collection values() { + checkQueue(); + return map.values(); + } + + /** {@inheritDoc} */ + @Override + public Set> entrySet() { + checkQueue(); + return map.entrySet(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + checkQueue(); + return map.equals(o); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return map.hashCode(); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return map.toString(); + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/WeakReferenceWrapper.java b/api/src/main/java/ai/djl/ndarray/gc/WeakReferenceWrapper.java new file mode 100644 index 00000000000..e6426b0eaac --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/WeakReferenceWrapper.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.gc; + +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; + +/** + * {@code WeakReferenceWrapper} extends a {@link WeakReference}. It uses an object of type K + * as the referent has an object property of type V. 
+ */ +public class WeakReferenceWrapper extends WeakReference { + V value; + + WeakReferenceWrapper(K key, V value, ReferenceQueue queue) { + super(key, queue); + this.value = value; + } + + /** + * Returns the value. + * + * @return the value + */ + public V getValue() { + return value; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/gc/package-info.java b/api/src/main/java/ai/djl/ndarray/gc/package-info.java new file mode 100644 index 00000000000..1bbae260202 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/gc/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes that help to garbage collect orphaned resources. */ +package ai.djl.ndarray.gc; diff --git a/api/src/main/java/ai/djl/training/GradientCollector.java b/api/src/main/java/ai/djl/training/GradientCollector.java index e1a3cd99ca1..0e16e94d82e 100644 --- a/api/src/main/java/ai/djl/training/GradientCollector.java +++ b/api/src/main/java/ai/djl/training/GradientCollector.java @@ -21,6 +21,14 @@ * performed within the try-with-resources are recorded and the variables marked. When {@link * #backward(NDArray) backward function} is called, gradients are collected w.r.t previously marked * variables. + * + *

The typical behavior is to open up a gradient collector during each batch and close it during + * the end of the batch. In this way, the gradient is reset between batches. If the gradient + * collector is left open for multiple calls to backwards, the gradients collected are accumulated + * and added together. + * + *

Due to limitations in most engines, the gradient collectors are global. This means that only + * one can be used at a time. If multiple are opened, an error will be thrown. */ public interface GradientCollector extends AutoCloseable { @@ -31,6 +39,9 @@ public interface GradientCollector extends AutoCloseable { */ void backward(NDArray target); + /** Sets all the gradients within the engine to zero. */ + void zeroGradients(); + /** {@inheritDoc} */ @Override void close(); diff --git a/api/src/main/java/ai/djl/util/NativeResource.java b/api/src/main/java/ai/djl/util/NativeResource.java index 65aa8f0085a..c70483234d1 100644 --- a/api/src/main/java/ai/djl/util/NativeResource.java +++ b/api/src/main/java/ai/djl/util/NativeResource.java @@ -14,58 +14,42 @@ import com.sun.jna.Pointer; -import java.util.concurrent.atomic.AtomicReference; - /** - * {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory created in - * the different engines. + * {@code NativeResource} is an interface for {@link AutoCloseable} blocks of memory created in the + * different engines. * * @param the resource that could map to a native pointer or java object */ -public abstract class NativeResource implements AutoCloseable { - - protected final AtomicReference handle; - private String uid; - - protected NativeResource(T handle) { - this.handle = new AtomicReference<>(handle); - uid = handle.toString(); - } - +public interface NativeResource extends AutoCloseable { /** * Gets the boolean that indicates whether this resource has been released. * * @return whether this resource has been released */ - public boolean isReleased() { - return handle.get() == null; - } + boolean isReleased(); /** * Gets the {@link Pointer} to this resource. * * @return the {@link Pointer} to this resource */ - public T getHandle() { - T reference = handle.get(); - if (reference == null) { - throw new IllegalStateException("Native resource has been release already."); - } - return reference; - } + T getHandle(); /** * Gets the unique ID of this resource. * * @return the unique ID of this resource */ - public final String getUid() { - return uid; - } + String getUid(); + + /** + * Gets and sets the atomic handle to null. + * + * @return the previous handle value + */ + T getAndSetHandleNull(); /** {@inheritDoc} */ @Override - public void close() { - throw new UnsupportedOperationException("Not implemented."); - } + void close(); } diff --git a/api/src/main/java/ai/djl/util/NativeResourceImpl.java b/api/src/main/java/ai/djl/util/NativeResourceImpl.java new file mode 100644 index 00000000000..f907db7dd78 --- /dev/null +++ b/api/src/main/java/ai/djl/util/NativeResourceImpl.java @@ -0,0 +1,66 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util; + +import java.util.concurrent.atomic.AtomicReference; + +/** + * {@code NativeResourceImpl} is an internal class for {@link AutoCloseable} blocks of memory + * created in the different engines. 
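// Editor's illustrative sketch (not part of this patch): the per-batch pattern described in the
// GradientCollector Javadoc above. It assumes the existing Engine.getInstance().newGradientCollector()
// factory; in the PyTorch engine this patch also zeroes any existing gradients when the collector
// opens, and only one collector may be open at a time.
try (NDManager manager = NDManager.newBaseManager()) {
    NDArray x = manager.create(new float[] {1f, 2f, 3f});
    x.setRequiresGradient(true);
    try (GradientCollector collector = Engine.getInstance().newGradientCollector()) {
        NDArray loss = x.mul(2f).sum();
        collector.backward(loss);   // gradients collected w.r.t. x
    }                               // close the collector at the end of each batch
    NDArray grad = x.getGradient(); // d(sum(2x))/dx == [2, 2, 2]
}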
+ * + * @param the resource that could map to a native pointer or java object + */ +public abstract class NativeResourceImpl implements NativeResource { + + protected final AtomicReference handle; + private String uid; + + protected NativeResourceImpl(T handle) { + this.handle = new AtomicReference<>(handle); + uid = handle.toString(); + } + + /** {@inheritDoc} */ + @Override + public T getAndSetHandleNull() { + return handle.getAndSet(null); + } + + /** {@inheritDoc} */ + @Override + public boolean isReleased() { + return handle.get() == null; + } + + /** {@inheritDoc} */ + @Override + public T getHandle() { + T reference = handle.get(); + if (reference == null) { + throw new IllegalStateException("Native resource has been release already."); + } + return reference; + } + + /** {@inheritDoc} */ + @Override + public final String getUid() { + return uid; + } + + /** {@inheritDoc} */ + @Override + public void close() { + throw new UnsupportedOperationException("Not implemented."); + } +} diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java index 7b35184af7f..bde8a89137a 100644 --- a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -26,6 +26,8 @@ import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.file.Path; +import java.util.Collections; +import java.util.List; /** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */ public final class PassthroughNDManager implements NDManager { @@ -247,6 +249,12 @@ public Device getDevice() { return Device.cpu(); } + /** {@inheritDoc} */ + @Override + public List getManagedArrays() { + return Collections.emptyList(); + } + /** {@inheritDoc} */ @Override public void attachInternal(String resourceId, AutoCloseable resource) { @@ -291,12 +299,6 @@ public Engine getEngine() { return null; } - /** {@inheritDoc} */ - @Override - public void zeroGradients() { - throw new UnsupportedOperationException(UNSUPPORTED); - } - /** {@inheritDoc} */ @Override public void close() {} diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java index 62398b1868e..403f284f6ae 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/CachedOp.java @@ -19,7 +19,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.Parameter; import ai.djl.training.ParameterStore; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import ai.djl.util.Pair; import ai.djl.util.PairList; @@ -40,7 +40,7 @@ * analyzing the input shape. It requires minimum input to do inference because most of the * information can be obtained from the model itself. 
*/ -public class CachedOp extends NativeResource { +public class CachedOp extends NativeResourceImpl { private static final Logger logger = LoggerFactory.getLogger(CachedOp.class); diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java index 5e1ad836969..a8a183de317 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java @@ -19,7 +19,7 @@ import ai.djl.training.GradientCollector; /** {@code MxGradientCollector} is the MXNet implementation of {@link GradientCollector}. */ -public class MxGradientCollector implements GradientCollector { +public final class MxGradientCollector implements GradientCollector { /** * Constructs an {@code MxGradientCollector} and enables training data collection for @@ -116,4 +116,15 @@ public void backward(NDArray array) { private void backward(NDArray array, boolean retainGraph) { JnaUtils.autogradBackward(new NDList(array), retainGraph ? 1 : 0); } + + /** {@inheritDoc} */ + @Override + public void zeroGradients() { + NDManager systemManager = MxNDManager.getSystemManager(); + for (NDArray array : systemManager.getManagedArrays()) { + if (array.hasGradient()) { + array.getGradient().subi(array.getGradient()); + } + } + } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index b3563f06314..6ad89bee883 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -23,7 +23,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import com.sun.jna.Native; import com.sun.jna.Pointer; @@ -36,7 +36,7 @@ import java.util.stream.IntStream; /** {@code MxNDArray} is the MXNet implementation of {@link NDArray}. */ -public class MxNDArray extends NativeResource implements LazyNDArray { +public class MxNDArray extends NativeResourceImpl implements LazyNDArray { private String name; private Device device; diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java index 36bead164e4..6da9c104e53 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxParameterServer.java @@ -20,14 +20,14 @@ import ai.djl.ndarray.NDManager; import ai.djl.training.ParameterServer; import ai.djl.training.optimizer.Optimizer; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import com.sun.jna.Pointer; import java.util.Arrays; /** {@code MxParameterServer} is the MXNet implementation of {@link ParameterServer}. */ -public class MxParameterServer extends NativeResource implements ParameterServer { +public class MxParameterServer extends NativeResourceImpl implements ParameterServer { @SuppressWarnings("PMD.SingularField") // use class field to hold the OptimizerCallback which prevent it from being gc. 
diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java index f724aad4bde..7ba921ec00f 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/Symbol.java @@ -15,7 +15,7 @@ import ai.djl.Device; import ai.djl.mxnet.jna.JnaUtils; import ai.djl.ndarray.types.Shape; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import ai.djl.util.PairList; import ai.djl.util.Utils; @@ -37,7 +37,7 @@ * @see MXNet * Symbol */ -public class Symbol extends NativeResource { +public class Symbol extends NativeResourceImpl { // private String[] argParams; // private String[] auxParams; diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java index d9f588a2bff..cde0e533e9c 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PaddlePredictor.java @@ -13,10 +13,10 @@ package ai.djl.paddlepaddle.engine; import ai.djl.paddlepaddle.jni.JniUtils; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; /** PaddlePaddle C++ Predictor. */ -public class PaddlePredictor extends NativeResource { +public class PaddlePredictor extends NativeResourceImpl { PaddlePredictor(long handle) { super(handle); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java index fad7755fa10..91104f78093 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java @@ -14,18 +14,31 @@ package ai.djl.pytorch.engine; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; import ai.djl.pytorch.jni.JniUtils; import ai.djl.training.GradientCollector; +import java.util.concurrent.atomic.AtomicBoolean; + /** {@code PtGradientCollector} is the PyTorch implementation of {@link GradientCollector}. */ -public class PtGradientCollector implements GradientCollector { +public final class PtGradientCollector implements GradientCollector { private boolean gradModel; + private static AtomicBoolean isCollecting = new AtomicBoolean(); /** Constructs a new {@code PtGradientCollector} instance. */ public PtGradientCollector() { gradModel = JniUtils.isGradMode(); JniUtils.setGradMode(true); + + boolean wasCollecting = isCollecting.getAndSet(true); + if (wasCollecting) { + throw new IllegalStateException( + "A PtGradientCollector is already collecting. 
Only one can be collecting at a" + + " time"); + } + + zeroGradients(); } /** {@inheritDoc} */ @@ -54,12 +67,28 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c JniUtils.backward((PtNDArray) target, (PtNDArray) grad, keepGraph, createGraph); } + /** {@inheritDoc} */ + @Override + public void zeroGradients() { + NDManager systemManager = PtNDManager.getSystemManager(); + for (NDArray array : systemManager.getManagedArrays()) { + try { + if (array.hasGradient()) { + JniUtils.zeroGrad((PtNDArray) array); + } + } catch (IllegalStateException e) { + // ignore if the array is already closed + } + } + } + /** {@inheritDoc} */ @Override public void close() { if (!gradModel) { JniUtils.setGradMode(false); } + isCollecting.set(false); // TODO: do some clean up if necessary } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 29148ec4463..348d24a2d13 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -13,1054 +13,524 @@ package ai.djl.pytorch.engine; import ai.djl.Device; -import ai.djl.ndarray.BaseNDManager; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; -import ai.djl.pytorch.jni.JniUtils; import ai.djl.util.NativeResource; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** {@code PtNDArray} is the PyTorch implementation of {@link NDArray}. */ -public class PtNDArray extends NativeResource implements NDArray { - - private String name; - private Device device; - private DataType dataType; - private Shape shape; - private SparseFormat sparseFormat; - // use Boolean object to maintain three status: null, false, true - private Boolean hasGradient; - private PtNDManager manager; - private PtNDArrayEx ptNDArrayEx; - private String[] strs; - - // keep a reference to direct buffer to avoid GC release the memory - @SuppressWarnings("PMD.UnusedPrivateField") - private ByteBuffer dataRef; - /** - * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} - * instead). - * - * @param manager the manager to attach the new array to - * @param handle the pointer to the native PyTorch memory - */ - public PtNDArray(PtNDManager manager, long handle) { - super(handle); - this.manager = manager; - this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attachInternal(getUid(), this); - } - - /** - * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} - * instead) with the data that is hold on Java side. 
- * - * @param manager the manager to attach the new array to - * @param handle the pointer to the native PyTorch memory - * @param data the direct buffer of the data - */ - public PtNDArray(PtNDManager manager, long handle, ByteBuffer data) { - super(handle); - this.manager = manager; - this.ptNDArrayEx = new PtNDArrayEx(this); - manager.attachInternal(getUid(), this); - dataRef = data; - } - - /** - * Constructs a PyTorch {@code NDArray} to hold string array with a dummy native handle - * (internal. Use {@link NDManager} instead) with the data that is hold on Java side. - * - * @param manager the manager to attach the new array to - * @param strs the string array - * @param shape the {@link Shape} of the {@link NDArray} - */ - public PtNDArray(PtNDManager manager, String[] strs, Shape shape) { - super(-1L); - this.manager = manager; - this.strs = strs; - this.shape = shape; - this.dataType = DataType.STRING; - } +/** {@code PtNDArray} is the interface for the PyTorch implementation of {@link NDArray}. */ +public interface PtNDArray extends NativeResource, NDArray { /** {@inheritDoc} */ @Override - public PtNDManager getManager() { - return manager; - } + PtNDManager getManager(); /** {@inheritDoc} */ @Override - public String getName() { - return name; - } + String getName(); /** {@inheritDoc} */ @Override - public void setName(String name) { - this.name = name; - } + void setName(String name); /** {@inheritDoc} */ @Override - public DataType getDataType() { - if (dataType == null) { - dataType = JniUtils.getDataType(this); - } - return dataType; - } + DataType getDataType(); /** {@inheritDoc} */ @Override - public Device getDevice() { - if (device == null) { - device = JniUtils.getDevice(this); - } - return device; - } + Device getDevice(); /** {@inheritDoc} */ @Override - public Shape getShape() { - if (shape == null) { - shape = JniUtils.getShape(this); - } - return shape; - } + Shape getShape(); /** {@inheritDoc} */ @Override - public SparseFormat getSparseFormat() { - if (sparseFormat == null) { - sparseFormat = JniUtils.getSparseFormat(this); - } - return sparseFormat; - } + SparseFormat getSparseFormat(); /** {@inheritDoc} */ @Override - public PtNDArray toDevice(Device device, boolean copy) { - if (device.equals(getDevice()) && !copy) { - return this; - } - return JniUtils.to(this, getDataType(), device); - } + PtNDArray toDevice(Device device, boolean copy); /** {@inheritDoc} */ @Override - public PtNDArray toType(DataType dataType, boolean copy) { - if (dataType.equals(getDataType()) && !copy) { - return this; - } - return JniUtils.to(this, dataType, getDevice()); - } + PtNDArray toType(DataType dataType, boolean copy); /** {@inheritDoc} */ @Override - public void setRequiresGradient(boolean requiresGrad) { - JniUtils.attachGradient(this, requiresGrad); - hasGradient = requiresGrad; - } + void setRequiresGradient(boolean requiresGrad); /** {@inheritDoc} */ @Override - public PtNDArray getGradient() { - if (!hasGradient()) { - throw new IllegalStateException( - "No gradient attached to this NDArray, please call array.setRequiresGradient()" - + " on your NDArray or block.setInitializer() on your Block"); - } - PtNDArray res = JniUtils.getGradient(this); - // If you call getGradient() before you run the backward, - // you will get nothing in PyTorch engine. - // To align with MXNet's behavior, we will create a zeros NDArray. - // TODO should we access the grad NDArray after we close the parameter NDArray? 
- if (res == null) { - res = (PtNDArray) manager.zeros(getShape()); - } - return res; - } + PtNDArray getGradient(); /** {@inheritDoc} */ @Override - public boolean hasGradient() { - if (hasGradient == null) { - hasGradient = JniUtils.requiresGrad(this); - } - return hasGradient; - } + boolean hasGradient(); /** {@inheritDoc} */ @Override - public NDArray stopGradient() { - return JniUtils.detachGradient(this); - } + NDArray stopGradient(); /** {@inheritDoc} */ @Override - public ByteBuffer toByteBuffer() { - return JniUtils.getByteBuffer(this); - } + ByteBuffer toByteBuffer(); /** {@inheritDoc} */ @Override - public String[] toStringArray(Charset charset) { - return strs; - } + String[] toStringArray(Charset charset); /** {@inheritDoc} */ @Override - public void set(Buffer buffer) { - int size = Math.toIntExact(size()); - DataType type = getDataType(); - BaseNDManager.validateBuffer(buffer, type, size); - // TODO how do we handle the exception happened in the middle - dataRef = null; - if (buffer.isDirect() && buffer instanceof ByteBuffer) { - // If NDArray is on the GPU, it is native code responsibility to control the data life - // cycle - if (!getDevice().isGpu()) { - dataRef = (ByteBuffer) buffer; - } - JniUtils.set(this, (ByteBuffer) buffer); - return; - } - // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType - ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes()); - BaseNDManager.copyBuffer(buffer, buf); - - // If NDArray is on the GPU, it is native code responsibility to control the data life cycle - if (!getDevice().isGpu()) { - dataRef = buf; - } - JniUtils.set(this, buf); - } + void set(Buffer buffer); /** {@inheritDoc} */ @Override - public NDArray get(NDManager manager, long... indices) { - return JniUtils.getItem(this, indices, (PtNDManager) manager); - } + NDArray get(NDManager manager, long... indices); /** {@inheritDoc} */ @Override - public NDArray gather(NDArray index, int axis) { - if (!(index instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray index is supported."); - } - return JniUtils.gather(this, (PtNDArray) index, axis); - } + NDArray gather(NDArray index, int axis); /** {@inheritDoc} */ @Override - public NDArray gatherNd(NDArray index) { - if (!(index instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray index is supported."); - } - Shape indexShape = index.getShape(); - Shape dataShape = getShape(); - int indexingDepth = (int) indexShape.get(0); - if (indexingDepth > dataShape.dimension()) { - throw new IllegalArgumentException( - "Indexing rank " - + indexShape.get(0) - + " exceeds the data rank " - + dataShape.dimension()); - } - // Row-first order, the linear index is accumulated from z->y->x. - // For example, dataShape = (3, 2, 3), indexShape = (2, 3, 3) - // The method is: indexLinear = index[1] + index[0] * dataShape[1], row-first order - // indexLinear has shape (3, 3), is from combining the index along 0 axis. - // Each number in indexLinear is an indexing to an element in data (3, 2, ...). - // data is flattened to be (3*2, ...) which can be indexed by indexLinear. - // Finally, reshape the output to (3, 3, ...). 
Thus - // totalShape = indexShape.slice(1).addAll(dataShape.slice(indexingDepth)); - NDArray indexLinear = index.get("{}, ...", indexingDepth - 1); - long dim = 1; - for (int i = indexingDepth - 2; i > -1; i--) { - dim = dim * dataShape.get(i + 1); - indexLinear = indexLinear.addi(index.get("{}, ...", i).muli(dim)); - } - NDArray dataFlatten = this.flatten(0, indexingDepth - 1); - return dataFlatten.get(indexLinear); - } + NDArray gatherNd(NDArray index); /** {@inheritDoc} */ @Override - public NDArray take(NDManager manager, NDArray index) { - if (!(index instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray is supported."); - } - return JniUtils.take(this, (PtNDArray) index, (PtNDManager) manager); - } + NDArray take(NDManager manager, NDArray index); /** {@inheritDoc} */ @Override - public NDArray put(NDArray index, NDArray data) { - if (!(index instanceof PtNDArray) || !(data instanceof PtNDArray)) { - throw new IllegalArgumentException("Only PtNDArray is supported."); - } - return JniUtils.put(this, (PtNDArray) index, (PtNDArray) data); - } + NDArray put(NDArray index, NDArray data); /** {@inheritDoc} */ @Override - public void copyTo(NDArray array) { - throw new UnsupportedOperationException("Not implemented"); - } + void copyTo(NDArray array); /** {@inheritDoc} */ @Override - public void attach(NDManager manager) { - detach(); - this.manager = (PtNDManager) manager; - manager.attachInternal(getUid(), this); - } + void attach(NDManager manager); /** {@inheritDoc} */ @Override - public void returnResource(NDManager manager) { - detach(); - this.manager = (PtNDManager) manager; - manager.attachUncappedInternal(getUid(), this); - } + void returnResource(NDManager manager); /** {@inheritDoc} */ @Override - public void tempAttach(NDManager manager) { - NDManager original = this.manager; - detach(); - this.manager = (PtNDManager) manager; - manager.tempAttachInternal(original, getUid(), this); - } + void tempAttach(NDManager manager); /** {@inheritDoc} */ @Override - public void detach() { - manager.detachInternal(getUid()); - manager = PtNDManager.getSystemManager(); - } + void detach(); /** {@inheritDoc} */ @Override - public NDArray duplicate() { - return JniUtils.clone(this); - } + NDArray duplicate(); /** {@inheritDoc} */ @Override - public PtNDArray booleanMask(NDArray index, int axis) { - Shape indexShape = index.getShape(); - if (indexShape.equals(getShape())) { - // Result is flattened since shape is undetermined - return JniUtils.booleanMask(this, manager.from(index)); - } else if (indexShape.equals(getShape().slice(axis))) { - // index will be broadcast by default - try (PtNDArray flattedResult = JniUtils.booleanMask(this, manager.from(index))) { - // Shape recovery - Shape remainder = getShape().slice(0, axis); - long selectedSize = flattedResult.getShape().size() / remainder.size(); - return flattedResult.reshape(remainder.addAll(new Shape(selectedSize))); - } - } else { - throw new UnsupportedOperationException( - "Not supported for shape not broadcastable " - + indexShape - + " vs " - + getShape()); - } - } + PtNDArray booleanMask(NDArray index, int axis); /** {@inheritDoc} */ @Override - public NDArray sequenceMask(NDArray sequenceLength, float value) { - throw new UnsupportedOperationException("Not implemented yet"); - } + NDArray sequenceMask(NDArray sequenceLength, float value); /** {@inheritDoc} */ @Override - public NDArray sequenceMask(NDArray sequenceLength) { - throw new UnsupportedOperationException("Not implemented yet"); - } + NDArray 
sequenceMask(NDArray sequenceLength); /** {@inheritDoc} */ @Override - public boolean contentEquals(Number number) { - return contentEquals(manager.create(number)); - } + boolean contentEquals(Number number); /** {@inheritDoc} */ @Override - public boolean contentEquals(NDArray other) { - if (other == null || (!shapeEquals(other))) { - return false; - } - if (getDataType() != other.getDataType()) { - return false; - } - return JniUtils.contentEqual(this, manager.from(other)); - } + boolean contentEquals(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray eq(Number n) { - try (NDArray number = manager.create(n)) { - return eq(number); - } - } + PtNDArray eq(Number n); /** {@inheritDoc} */ @Override - public PtNDArray eq(NDArray other) { - return JniUtils.eq(this, manager.from(other)); - } + PtNDArray eq(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray neq(Number n) { - try (NDArray number = manager.create(n)) { - return neq(number); - } - } + PtNDArray neq(Number n); /** {@inheritDoc} */ @Override - public PtNDArray neq(NDArray other) { - return JniUtils.neq(this, manager.from(other)); - } + PtNDArray neq(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray gt(Number n) { - try (NDArray number = manager.create(n)) { - return gt(number); - } - } + PtNDArray gt(Number n); /** {@inheritDoc} */ @Override - public PtNDArray gt(NDArray other) { - return JniUtils.gt(this, manager.from(other)); - } + PtNDArray gt(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray gte(Number n) { - try (NDArray number = manager.create(n)) { - return gte(number); - } - } + PtNDArray gte(Number n); /** {@inheritDoc} */ @Override - public PtNDArray gte(NDArray other) { - return JniUtils.gte(this, manager.from(other)); - } + PtNDArray gte(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray lt(Number n) { - try (NDArray number = manager.create(n)) { - return lt(number); - } - } + PtNDArray lt(Number n); /** {@inheritDoc} */ @Override - public PtNDArray lt(NDArray other) { - return JniUtils.lt(this, manager.from(other)); - } + PtNDArray lt(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray lte(Number n) { - try (NDArray number = manager.create(n)) { - return lte(number); - } - } + PtNDArray lte(Number n); /** {@inheritDoc} */ @Override - public PtNDArray lte(NDArray other) { - return JniUtils.lte(this, manager.from(other)); - } + PtNDArray lte(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray add(Number n) { - try (NDArray number = manager.create(n)) { - return add(number); - } - } + PtNDArray add(Number n); /** {@inheritDoc} */ @Override - public PtNDArray add(NDArray other) { - return JniUtils.add(this, manager.from(other)); - } + PtNDArray add(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray sub(Number n) { - try (NDArray number = manager.create(n)) { - return sub(number); - } - } + PtNDArray sub(Number n); /** {@inheritDoc} */ @Override - public PtNDArray sub(NDArray other) { - return JniUtils.sub(this, manager.from(other)); - } + PtNDArray sub(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray mul(Number n) { - try (NDArray number = manager.create(n)) { - return mul(number); - } - } + PtNDArray mul(Number n); /** {@inheritDoc} */ @Override - public PtNDArray mul(NDArray other) { - return JniUtils.mul(this, manager.from(other)); - } + PtNDArray mul(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray div(Number n) { - try (NDArray number = manager.create(n)) { - return 
div(number); - } - } + PtNDArray div(Number n); /** {@inheritDoc} */ @Override - public PtNDArray div(NDArray other) { - return JniUtils.div(this, manager.from(other)); - } + PtNDArray div(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray mod(Number n) { - try (NDArray number = manager.create(n)) { - return mod(number); - } - } + PtNDArray mod(Number n); /** {@inheritDoc} */ @Override - public PtNDArray mod(NDArray other) { - return JniUtils.remainder(this, manager.from(other)); - } + PtNDArray mod(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray pow(Number n) { - try (NDArray number = manager.create(n)) { - return pow(number); - } - } + PtNDArray pow(Number n); /** {@inheritDoc} */ @Override - public PtNDArray pow(NDArray other) { - return JniUtils.pow(this, manager.from(other)); - } + PtNDArray pow(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray addi(Number n) { - try (NDArray number = manager.create(n)) { - return addi(number); - } - } + PtNDArray addi(Number n); /** {@inheritDoc} */ @Override - public PtNDArray addi(NDArray other) { - JniUtils.addi(this, manager.from(other)); - return this; - } + PtNDArray addi(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray subi(Number n) { - try (NDArray number = manager.create(n)) { - return subi(number); - } - } + PtNDArray subi(Number n); /** {@inheritDoc} */ @Override - public PtNDArray subi(NDArray other) { - JniUtils.subi(this, manager.from(other)); - return this; - } + PtNDArray subi(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray muli(Number n) { - try (NDArray number = manager.create(n)) { - return muli(number); - } - } + PtNDArray muli(Number n); /** {@inheritDoc} */ @Override - public PtNDArray muli(NDArray other) { - JniUtils.muli(this, manager.from(other)); - return this; - } + PtNDArray muli(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray divi(Number n) { - try (NDArray number = manager.create(n)) { - return divi(number); - } - } + PtNDArray divi(Number n); /** {@inheritDoc} */ @Override - public PtNDArray divi(NDArray other) { - JniUtils.divi(this, manager.from(other)); - return this; - } + PtNDArray divi(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray modi(Number n) { - try (NDArray number = manager.create(n)) { - return modi(number); - } - } + PtNDArray modi(Number n); /** {@inheritDoc} */ @Override - public PtNDArray modi(NDArray other) { - JniUtils.remainderi(this, manager.from(other)); - return this; - } + PtNDArray modi(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray powi(Number n) { - try (NDArray number = manager.create(n)) { - return powi(number); - } - } + PtNDArray powi(Number n); /** {@inheritDoc} */ @Override - public PtNDArray powi(NDArray other) { - JniUtils.powi(this, manager.from(other)); - return this; - } + PtNDArray powi(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray sign() { - return JniUtils.sign(this); - } + PtNDArray sign(); /** {@inheritDoc} */ @Override - public PtNDArray signi() { - JniUtils.signi(this); - return this; - } + PtNDArray signi(); /** {@inheritDoc} */ @Override - public PtNDArray maximum(Number n) { - try (NDArray number = manager.create(n)) { - return maximum(number); - } - } + PtNDArray maximum(Number n); /** {@inheritDoc} */ @Override - public PtNDArray maximum(NDArray other) { - return JniUtils.max(this, manager.from(other)); - } + PtNDArray maximum(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray minimum(Number n) { - 
try (NDArray number = manager.create(n)) { - return minimum(number); - } - } + PtNDArray minimum(Number n); /** {@inheritDoc} */ @Override - public PtNDArray minimum(NDArray other) { - return JniUtils.min(this, manager.from(other)); - } + PtNDArray minimum(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray all() { - try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { - return JniUtils.all(bool); - } - } + PtNDArray all(); /** {@inheritDoc} */ @Override - public PtNDArray any() { - try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { - return JniUtils.any(bool); - } - } + PtNDArray any(); /** {@inheritDoc} */ @Override - public PtNDArray none() { - try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { - return JniUtils.none(bool); - } - } + PtNDArray none(); /** {@inheritDoc} */ @Override - public PtNDArray neg() { - return JniUtils.neg(this); - } + PtNDArray neg(); /** {@inheritDoc} */ @Override - public PtNDArray negi() { - JniUtils.negi(this); - return this; - } + PtNDArray negi(); /** {@inheritDoc} */ @Override - public PtNDArray abs() { - return JniUtils.abs(this); - } + PtNDArray abs(); /** {@inheritDoc} */ @Override - public PtNDArray square() { - return JniUtils.square(this); - } + PtNDArray square(); /** {@inheritDoc} */ @Override - public NDArray sqrt() { - return JniUtils.sqrt(this); - } + NDArray sqrt(); /** {@inheritDoc} */ @Override - public PtNDArray cbrt() { - return JniUtils.pow(this, (PtNDArray) manager.create(1.0 / 3)); - } + PtNDArray cbrt(); /** {@inheritDoc} */ @Override - public PtNDArray floor() { - return JniUtils.floor(this); - } + PtNDArray floor(); /** {@inheritDoc} */ @Override - public PtNDArray ceil() { - return JniUtils.ceil(this); - } + PtNDArray ceil(); /** {@inheritDoc} */ @Override - public PtNDArray round() { - return JniUtils.round(this); - } + PtNDArray round(); /** {@inheritDoc} */ @Override - public PtNDArray trunc() { - return JniUtils.trunc(this); - } + PtNDArray trunc(); /** {@inheritDoc} */ @Override - public PtNDArray exp() { - return JniUtils.exp(this); - } + PtNDArray exp(); /** {@inheritDoc} */ @Override - public NDArray gammaln() { - throw new UnsupportedOperationException("Not implemented yet."); - } + NDArray gammaln(); /** {@inheritDoc} */ @Override - public PtNDArray log() { - return JniUtils.log(this); - } + PtNDArray log(); /** {@inheritDoc} */ @Override - public PtNDArray log10() { - return JniUtils.log10(this); - } + PtNDArray log10(); /** {@inheritDoc} */ @Override - public PtNDArray log2() { - return JniUtils.log2(this); - } + PtNDArray log2(); /** {@inheritDoc} */ @Override - public PtNDArray sin() { - return JniUtils.sin(this); - } + PtNDArray sin(); /** {@inheritDoc} */ @Override - public PtNDArray cos() { - return JniUtils.cos(this); - } + PtNDArray cos(); /** {@inheritDoc} */ @Override - public PtNDArray tan() { - return JniUtils.tan(this); - } + PtNDArray tan(); /** {@inheritDoc} */ @Override - public PtNDArray asin() { - return JniUtils.asin(this); - } + PtNDArray asin(); /** {@inheritDoc} */ @Override - public PtNDArray acos() { - return JniUtils.acos(this); - } + PtNDArray acos(); /** {@inheritDoc} */ @Override - public PtNDArray atan() { - return JniUtils.atan(this); - } + PtNDArray atan(); /** {@inheritDoc} */ @Override - public PtNDArray sinh() { - return JniUtils.sinh(this); - } + PtNDArray sinh(); /** {@inheritDoc} */ @Override - public PtNDArray cosh() { - return JniUtils.cosh(this); - } + PtNDArray cosh(); /** {@inheritDoc} */ @Override - public PtNDArray tanh() { - return 
JniUtils.tanh(this); - } + PtNDArray tanh(); /** {@inheritDoc} */ @Override - public PtNDArray asinh() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray asinh(); /** {@inheritDoc} */ @Override - public PtNDArray acosh() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray acosh(); /** {@inheritDoc} */ @Override - public PtNDArray atanh() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray atanh(); /** {@inheritDoc} */ @Override - public PtNDArray toDegrees() { - return mul(180.0).div(Math.PI); - } + PtNDArray toDegrees(); /** {@inheritDoc} */ @Override - public PtNDArray toRadians() { - return mul(Math.PI).div(180.0); - } + PtNDArray toRadians(); /** {@inheritDoc} */ @Override - public PtNDArray max() { - return JniUtils.max(this); - } + PtNDArray max(); /** {@inheritDoc} */ @Override - public PtNDArray max(int[] axes, boolean keepDims) { - if (axes.length > 1) { - // TODO fix this - throw new UnsupportedOperationException("Only 1 axis is support!"); - } - return JniUtils.max(this, axes[0], keepDims); - } + PtNDArray max(int[] axes, boolean keepDims); /** {@inheritDoc} */ @Override - public PtNDArray min() { - return JniUtils.min(this); - } + PtNDArray min(); /** {@inheritDoc} */ @Override - public PtNDArray min(int[] axes, boolean keepDims) { - if (axes.length > 1) { - // TODO fix this - throw new UnsupportedOperationException("Only 1 axis is support!"); - } - return JniUtils.min(this, axes[0], keepDims); - } + PtNDArray min(int[] axes, boolean keepDims); /** {@inheritDoc} */ @Override - public PtNDArray sum() { - return JniUtils.sum(this); - } + PtNDArray sum(); /** {@inheritDoc} */ @Override - public PtNDArray sum(int[] axes, boolean keepDims) { - return JniUtils.sum(this, Arrays.stream(axes).mapToLong(i -> i).toArray(), keepDims); - } + PtNDArray sum(int[] axes, boolean keepDims); /** {@inheritDoc} */ @Override - public NDArray cumProd(int axis) { - return JniUtils.cumProd(this, axis, null); - } + NDArray cumProd(int axis); /** {@inheritDoc} */ @Override - public NDArray cumProd(int axis, DataType dataType) { - return JniUtils.cumProd(this, axis, dataType); - } + NDArray cumProd(int axis, DataType dataType); /** {@inheritDoc} */ @Override - public PtNDArray prod() { - return JniUtils.prod(this); - } + PtNDArray prod(); /** {@inheritDoc} */ @Override - public PtNDArray prod(int[] axes, boolean keepDims) { - if (axes.length > 1) { - throw new UnsupportedOperationException("Only 1 axis is support!"); - } - return JniUtils.prod(this, axes[0], keepDims); - } + PtNDArray prod(int[] axes, boolean keepDims); /** {@inheritDoc} */ @Override - public PtNDArray mean() { - return JniUtils.mean(this); - } + PtNDArray mean(); /** {@inheritDoc} */ @Override - public PtNDArray mean(int[] axes, boolean keepDims) { - if (axes.length > 1) { - // TODO fix this - throw new UnsupportedOperationException("Only 1 axis is support!"); - } - return JniUtils.mean(this, axes[0], keepDims); - } + PtNDArray mean(int[] axes, boolean keepDims); /** {@inheritDoc} */ @Override - public PtNDArray normalize(double p, long dim, double eps) { - return JniUtils.normalize(this, p, dim, eps); - } + PtNDArray normalize(double p, long dim, double eps); /** {@inheritDoc} */ @Override - public PtNDArray rotate90(int times, int[] axes) { - if (axes.length != 2) { - throw new IllegalArgumentException("Axes must be 2"); - } - return JniUtils.rot90(this, times, axes); - } + PtNDArray rotate90(int times, int[] axes); /** {@inheritDoc} */ @Override - 
public PtNDArray trace(int offset, int axis1, int axis2) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray trace(int offset, int axis1, int axis2); /** {@inheritDoc} */ @Override - public NDList split(long sections, int axis) { - long size = getShape().get(axis) / sections; - return JniUtils.split(this, size, axis); - } + NDList split(long sections, int axis); /** {@inheritDoc} */ @Override - public NDList split(long[] indices, int axis) { - if (indices.length == 0) { - return new NDList(this); - } - List ptIndex = new ArrayList<>(); - ptIndex.add(indices[0]); - for (int i = 1; i < indices.length; i++) { - ptIndex.add(indices[i] - indices[i - 1]); - } - ptIndex.add(size(axis) - indices[indices.length - 1]); - return JniUtils.split(this, ptIndex.stream().mapToLong(i -> i).toArray(), axis); - } + NDList split(long[] indices, int axis); /** {@inheritDoc} */ @Override - public PtNDArray flatten() { - return JniUtils.flatten(this, 0, -1); - } + PtNDArray flatten(); /** {@inheritDoc} */ @Override - public NDArray flatten(int startDim, int endDim) { - return JniUtils.flatten(this, startDim, endDim); - } + NDArray flatten(int startDim, int endDim); /** {@inheritDoc} */ @Override - public NDArray fft(long length, long axis) { - return JniUtils.fft(this, length, axis); - } + public NDArray fft(long length, long axis); /** {@inheritDoc} */ @Override @@ -1070,523 +540,272 @@ public NDArray stft( boolean center, NDArray window, boolean normalize, - boolean returnComplex) { - return JniUtils.stft( - this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex); - } + boolean returnComplex); /** {@inheritDoc} */ @Override - public PtNDArray reshape(Shape shape) { - return JniUtils.reshape(this, shape.getShape()); - } + PtNDArray reshape(Shape shape); /** {@inheritDoc} */ @Override - public PtNDArray expandDims(int axis) { - return JniUtils.unsqueeze(this, axis); - } + PtNDArray expandDims(int axis); /** {@inheritDoc} */ @Override - public PtNDArray squeeze() { - return JniUtils.squeeze(this); - } + PtNDArray squeeze(); /** {@inheritDoc} */ @Override - public PtNDArray squeeze(int axis) { - return JniUtils.squeeze(this, axis); - } + PtNDArray squeeze(int axis); /** {@inheritDoc} */ @Override - public PtNDArray squeeze(int[] axes) { - if (isScalar()) { - if (axes.length == 0 || (axes.length == 1 && axes[0] == 0)) { - return (PtNDArray) duplicate(); - } - throw new IllegalArgumentException( - "axis " + axes[0] + " is out of bounds for array of dimension 0"); - } - long[] shapeArr = getShape().getShape(); - List newShape = new ArrayList<>(); - Set set = - IntStream.of(axes).boxed().collect(Collectors.toCollection(HashSet::new)); - // check input - for (int axis : axes) { - if (shapeArr[axis] != 1) { - throw new IllegalArgumentException( - "cannot select an axis to squeeze out which has size not equal to one"); - } - } - for (int i = 0; i < shapeArr.length; i++) { - if (!set.contains(i)) { - newShape.add(shapeArr[i]); - } - } - return (PtNDArray) reshape(newShape.stream().mapToLong(i -> i).toArray()); - } + PtNDArray squeeze(int[] axes); /** {@inheritDoc} */ @Override - public PtNDArray logicalAnd(NDArray other) { - return JniUtils.logicalAnd(this, manager.from(other)); - } + PtNDArray logicalAnd(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray logicalOr(NDArray other) { - return JniUtils.logicalOr(this, manager.from(other)); - } + PtNDArray logicalOr(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray logicalXor(NDArray other) { - 
return JniUtils.logicalXor(this, manager.from(other)); - } + PtNDArray logicalXor(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray logicalNot() { - return JniUtils.logicalNot(this); - } + PtNDArray logicalNot(); /** {@inheritDoc} */ @Override - public PtNDArray argSort(int axis, boolean ascending) { - PtNDArray arr = JniUtils.argSort(this, axis, false); - if (ascending) { - return arr; - } - PtNDArray flip = JniUtils.flip(arr, new long[] {axis}); - arr.close(); - return flip; - } + PtNDArray argSort(int axis, boolean ascending); /** {@inheritDoc} */ @Override - public PtNDArray sort() { - return sort(-1); - } + PtNDArray sort(); /** {@inheritDoc} */ @Override - public PtNDArray sort(int axis) { - return JniUtils.sort(this, axis, false); - } + PtNDArray sort(int axis); /** {@inheritDoc} */ @Override - public PtNDArray softmax(int axis) { - return JniUtils.softmax(this, axis, getDataType()); - } + PtNDArray softmax(int axis); /** {@inheritDoc} */ @Override - public PtNDArray logSoftmax(int axis) { - return JniUtils.logSoftmax(this, axis, getDataType()); - } + PtNDArray logSoftmax(int axis); /** {@inheritDoc} */ @Override - public PtNDArray cumSum() { - // TODO: change default behavior on cumSum - if (isScalar()) { - return (PtNDArray) reshape(1); - } - if (isEmpty()) { - return (PtNDArray) reshape(0); - } - return cumSum(0); - } + PtNDArray cumSum(); /** {@inheritDoc} */ @Override - public PtNDArray cumSum(int axis) { - return JniUtils.cumSum(this, axis); - } + PtNDArray cumSum(int axis); /** {@inheritDoc} */ @Override - public void intern(NDArray replaced) { - PtNDArray arr = (PtNDArray) replaced; - Long oldHandle = handle.getAndSet(arr.handle.getAndSet(null)); - JniUtils.deleteNDArray(oldHandle); - // dereference old ndarray - arr.close(); - } + void intern(NDArray replaced); /** {@inheritDoc} */ @Override - public PtNDArray isInfinite() { - return JniUtils.isInf(this); - } + PtNDArray isInfinite(); /** {@inheritDoc} */ @Override - public PtNDArray isNaN() { - return JniUtils.isNaN(this); - } + PtNDArray isNaN(); /** {@inheritDoc} */ @Override - public PtNDArray tile(long repeats) { - // zero-dim - if (isEmpty()) { - return (PtNDArray) duplicate(); - } - // scalar - int dim = (isScalar()) ? 1 : getShape().dimension(); - long[] repeatsArray = new long[dim]; - Arrays.fill(repeatsArray, repeats); - return tile(repeatsArray); - } + PtNDArray tile(long repeats); /** {@inheritDoc} */ @Override - public PtNDArray tile(int axis, long repeats) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray tile(int axis, long repeats); /** {@inheritDoc} */ @Override - public PtNDArray tile(long[] repeats) { - return JniUtils.tile(this, repeats); - } + PtNDArray tile(long[] repeats); /** {@inheritDoc} */ @Override - public PtNDArray tile(Shape desiredShape) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray tile(Shape desiredShape); /** {@inheritDoc} */ @Override - public PtNDArray repeat(long repeats) { - // zero-dim - if (isEmpty()) { - return (PtNDArray) duplicate(); - } - // scalar - int dim = (isScalar()) ? 
1 : getShape().dimension(); - long[] repeatsArray = new long[dim]; - Arrays.fill(repeatsArray, repeats); - return repeat(repeatsArray); - } + PtNDArray repeat(long repeats); /** {@inheritDoc} */ @Override - public PtNDArray repeat(int axis, long repeats) { - return JniUtils.repeat(this, repeats, axis); - } + PtNDArray repeat(int axis, long repeats); /** {@inheritDoc} */ @Override - public PtNDArray repeat(long[] repeats) { - PtNDArray result = this; - for (int dim = 0; dim < repeats.length; dim++) { - PtNDArray temp = result; - result = JniUtils.repeat(result, repeats[dim], dim); - if (temp != this) { - temp.close(); - } - } - return result; - } + PtNDArray repeat(long[] repeats); /** {@inheritDoc} */ @Override - public PtNDArray repeat(Shape desiredShape) { - return repeat(repeatsToMatchShape(desiredShape)); - } - - private long[] repeatsToMatchShape(Shape desiredShape) { - Shape curShape = getShape(); - int dimension = curShape.dimension(); - if (desiredShape.dimension() > dimension) { - throw new IllegalArgumentException("The desired shape has too many dimensions"); - } - if (desiredShape.dimension() < dimension) { - int additionalDimensions = dimension - desiredShape.dimension(); - desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape); - } - long[] repeats = new long[dimension]; - for (int i = 0; i < dimension; i++) { - if (curShape.get(i) == 0 || desiredShape.get(i) % curShape.get(i) != 0) { - throw new IllegalArgumentException( - "The desired shape is not a multiple of the original shape"); - } - repeats[i] = Math.round(Math.ceil((double) desiredShape.get(i) / curShape.get(i))); - } - return repeats; - } + PtNDArray repeat(Shape desiredShape); /** {@inheritDoc} */ @Override - public PtNDArray dot(NDArray other) { - int selfDim = this.getShape().dimension(); - int otherDim = other.getShape().dimension(); - if (selfDim != otherDim || selfDim > 2) { - throw new UnsupportedOperationException( - "Dimension mismatch or high dimensional dot operation is not supported. Please" - + " use .matMul instead."); - } - return JniUtils.dot(this, manager.from(other)); - } + PtNDArray dot(NDArray other); /** {@inheritDoc} */ @Override - public NDArray matMul(NDArray other) { - if (isScalar() || other.isScalar()) { - throw new IllegalArgumentException("scalar is not allowed for matMul()"); - } - return JniUtils.matmul(this, manager.from(other)); - } + NDArray matMul(NDArray other); /** {@inheritDoc} */ @Override - public PtNDArray clip(Number min, Number max) { - return JniUtils.clip(this, min, max); - } + PtNDArray clip(Number min, Number max); /** {@inheritDoc} */ @Override - public PtNDArray swapAxes(int axis1, int axis2) { - return JniUtils.transpose(this, axis1, axis2); - } + PtNDArray swapAxes(int axis1, int axis2); /** {@inheritDoc} */ @Override - public NDArray flip(int... axes) { - return JniUtils.flip(this, Arrays.stream(axes).mapToLong(ele -> (long) ele).toArray()); - } + NDArray flip(int... axes); /** {@inheritDoc} */ @Override - public PtNDArray transpose() { - int dim = getShape().dimension(); - int[] reversedShape = IntStream.range(0, dim).map(i -> dim - i - 1).toArray(); - return transpose(reversedShape); - } + PtNDArray transpose(); /** {@inheritDoc} */ @Override - public PtNDArray transpose(int... axes) { - if (isScalar() && axes.length > 0) { - throw new IllegalArgumentException("axes don't match NDArray"); - } - return JniUtils.permute(this, Arrays.stream(axes).mapToLong(i -> i).toArray()); - } + PtNDArray transpose(int... 
axes); /** {@inheritDoc} */ @Override - public PtNDArray broadcast(Shape shape) { - return JniUtils.broadcast(this, shape); - } + PtNDArray broadcast(Shape shape); /** {@inheritDoc} */ @Override - public PtNDArray argMax() { - if (isEmpty()) { - throw new IllegalArgumentException("attempt to get argMax of an empty NDArray"); - } - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMax(this); - } + PtNDArray argMax(); /** {@inheritDoc} */ @Override - public PtNDArray argMax(int axis) { - // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMax(this, axis, false); - } + PtNDArray argMax(int axis); /** {@inheritDoc} */ @Override - public PtNDArray argMin() { - if (isEmpty()) { - throw new IllegalArgumentException("attempt to get argMin of an empty NDArray"); - } - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMin(this); - } + PtNDArray argMin(); /** {@inheritDoc} */ @Override - public PtNDArray argMin(int axis) { - // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 - if (isScalar()) { - return (PtNDArray) manager.create(0L); - } - return JniUtils.argMin(this, axis, false); - } + PtNDArray argMin(int axis); /** {@inheritDoc} */ @Override - public PtNDArray percentile(Number percentile) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray percentile(Number percentile); /** {@inheritDoc} */ @Override - public PtNDArray percentile(Number percentile, int[] axes) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray percentile(Number percentile, int[] axes); /** {@inheritDoc} */ @Override - public PtNDArray median() { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray median(); /** {@inheritDoc} */ @Override - public PtNDArray median(int[] axes) { - throw new UnsupportedOperationException("Not implemented"); - } + PtNDArray median(int[] axes); /** {@inheritDoc} */ @Override - public PtNDArray toDense() { - if (!isSparse() && JniUtils.getLayout(this) != 2) { - return (PtNDArray) duplicate(); - } - return JniUtils.toDense(this); - } + PtNDArray toDense(); /** {@inheritDoc} */ @Override - public PtNDArray toSparse(SparseFormat fmt) { - if (fmt == SparseFormat.DENSE) { - throw new IllegalArgumentException("Default type is not allowed"); - } - if (fmt != SparseFormat.COO) { - throw new UnsupportedOperationException("Only COO sparse type supported for PyTorch"); - } - if (fmt == getSparseFormat()) { - return (PtNDArray) duplicate(); - } - return JniUtils.toSparse(this); - } + PtNDArray toSparse(SparseFormat fmt); /** {@inheritDoc} */ @Override - public PtNDArray nonzero() { - return JniUtils.nonZeros(this); - } + PtNDArray nonzero(); /** {@inheritDoc} */ @Override - public PtNDArray erfinv() { - return JniUtils.erfinv(this); - } + PtNDArray erfinv(); /** {@inheritDoc} */ @Override - public PtNDArray inverse() { - return JniUtils.inverse(this); - } + PtNDArray inverse(); /** {@inheritDoc} */ @Override - public NDArray norm(boolean keepDims) { - return JniUtils.norm(this, 2, new int[] {}, keepDims); - } + NDArray norm(boolean keepDims); /** {@inheritDoc} */ @Override - public NDArray norm(int order, int[] axes, boolean keepDims) { - return JniUtils.norm(this, order, axes, keepDims); - } + NDArray norm(int order, int[] axes, boolean keepDims); /** {@inheritDoc} */ @Override - public NDArray oneHot(int depth) { - return 
JniUtils.oneHot(this, depth, DataType.FLOAT32); - } + NDArray oneHot(int depth); /** {@inheritDoc} */ @Override - public NDArray oneHot(int depth, DataType dataType) { - return JniUtils.oneHot(this, depth, dataType); - } + NDArray oneHot(int depth, DataType dataType); /** {@inheritDoc} */ @Override - public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) { - throw new UnsupportedOperationException("Not implemented"); - } + NDArray oneHot(int depth, float onValue, float offValue, DataType dataType); /** {@inheritDoc} */ @Override - public NDArray batchDot(NDArray other) { - throw new UnsupportedOperationException("Not implemented"); - } + NDArray batchDot(NDArray other); /** {@inheritDoc} */ @Override - public NDArray complex() { - return JniUtils.complex(this); - } + public NDArray complex(); /** {@inheritDoc} */ @Override - public NDArray real() { - return JniUtils.real(this); - } + public NDArray real(); /** {@inheritDoc} */ @Override - public PtNDArrayEx getNDArrayInternal() { - return ptNDArrayEx; - } + PtNDArrayEx getNDArrayInternal(); /** {@inheritDoc} */ @Override - public String toString() { - if (isReleased()) { - return "This array is already closed"; - } - // index operator in toDebugString is not supported for MKLDNN & Sparse layout - if (JniUtils.getLayout(this) != 0) { - try (NDArray tmp = toDense()) { - return tmp.toDebugString(); - } - } - return toDebugString(); - } + String toString(); /** {@inheritDoc} */ @Override - public boolean equals(Object obj) { - if (obj instanceof NDArray) { - return contentEquals((NDArray) obj); - } - return false; - } + boolean equals(Object obj); /** {@inheritDoc} */ @Override - public int hashCode() { - return 0; - } + int hashCode(); /** {@inheritDoc} */ @Override - public void close() { - Long pointer = handle.getAndSet(null); - if (pointer != null && pointer != -1) { - JniUtils.deleteNDArray(pointer); - } - manager.detachInternal(getUid()); - dataRef = null; - } + void close(); + + /** + * Returns the number of {@code NDArray} in the map used for garbage collection triggered + * closing. + * + * @return number of {@code NDArray} in the map used for garbage collection triggered closing + */ + int getNumOfNDArraysInGCMap(); + + /** + * Returns the number of {@code NDArray} in the hierarchy of {@code NDManager}. + * + * @return number of {@code NDArray} in the hierarchy of {@code NDManager} + */ + int getNumOfNDArraysInNDManagerHierarchy(); + + /** + * Returns the raw implementation of this interface. This could be useful for debugging if the + * interface is implemented by a proxy. + * + * @return the raw implementation of this interface + */ + PtNDArrayImpl getImplementation(); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java new file mode 100644 index 00000000000..50101175d47 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayImpl.java @@ -0,0 +1,1666 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.engine; + +import ai.djl.Device; +import ai.djl.ndarray.BaseNDManager; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.gc.SwitchGarbageCollection; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.ndarray.types.SparseFormat; +import ai.djl.pytorch.jni.JniUtils; +import ai.djl.util.NativeResourceImpl; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** {@code PtNDArrayImpl} is the PyTorch implementation of {@link NDArray}. */ +public final class PtNDArrayImpl extends NativeResourceImpl implements PtNDArray { + + private String name; + private Device device; + private DataType dataType; + private Shape shape; + private SparseFormat sparseFormat; + // use Boolean object to maintain three status: null, false, true + private Boolean hasGradient; + private PtNDManager manager; + private PtNDArrayEx ptNDArrayEx; + private String[] strs; + + // keep a reference to direct buffer to avoid GC release the memory + @SuppressWarnings("PMD.UnusedPrivateField") + private ByteBuffer dataRef; + + /** + * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} + * instead). + * + * @param manager the manager to attach the new array to + * @param handle the pointer to the native PyTorch memory + */ + private PtNDArrayImpl(PtNDManager manager, long handle) { + super(handle); + this.manager = manager; + this.ptNDArrayEx = new PtNDArrayEx(this); + manager.attachInternal(getUid(), this); + } + + /** + * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} + * instead) with the data that is hold on Java side. + * + * @param manager the manager to attach the new array to + * @param handle the pointer to the native PyTorch memory + * @param data the direct buffer of the data + */ + private PtNDArrayImpl(PtNDManager manager, long handle, ByteBuffer data) { + super(handle); + this.manager = manager; + this.ptNDArrayEx = new PtNDArrayEx(this); + manager.attachInternal(getUid(), this); + dataRef = data; + } + + /** + * Constructs a PyTorch {@code NDArray} to hold string array with a dummy native handle + * (internal. Use {@link NDManager} instead) with the data that is hold on Java side. 
+ * + * @param manager the manager to attach the new array to + * @param strs the string array + * @param shape the {@link Shape} of the {@link NDArray} + */ + private PtNDArrayImpl(PtNDManager manager, String[] strs, Shape shape) { + super(-1L); + this.manager = manager; + this.strs = strs; + this.shape = shape; + this.dataType = DataType.STRING; + } + + /** {@inheritDoc} */ + @Override + public PtNDManager getManager() { + return manager; + } + + /** {@inheritDoc} */ + @Override + public String getName() { + return name; + } + + /** {@inheritDoc} */ + @Override + public void setName(String name) { + this.name = name; + } + + /** {@inheritDoc} */ + @Override + public DataType getDataType() { + if (dataType == null) { + dataType = JniUtils.getDataType(this); + } + return dataType; + } + + /** {@inheritDoc} */ + @Override + public Device getDevice() { + if (device == null) { + device = JniUtils.getDevice(this); + } + return device; + } + + /** {@inheritDoc} */ + @Override + public Shape getShape() { + if (shape == null) { + shape = JniUtils.getShape(this); + } + return shape; + } + + /** {@inheritDoc} */ + @Override + public SparseFormat getSparseFormat() { + if (sparseFormat == null) { + sparseFormat = JniUtils.getSparseFormat(this); + } + return sparseFormat; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toDevice(Device device, boolean copy) { + if (device.equals(getDevice()) && !copy) { + return this; + } + return JniUtils.to(this, getDataType(), device); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toType(DataType dataType, boolean copy) { + if (dataType.equals(getDataType()) && !copy) { + return this; + } + return JniUtils.to(this, dataType, getDevice()); + } + + /** {@inheritDoc} */ + @Override + public void setRequiresGradient(boolean requiresGrad) { + JniUtils.attachGradient(this, requiresGrad); + hasGradient = requiresGrad; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray getGradient() { + if (!hasGradient()) { + throw new IllegalStateException( + "No gradient attached to this NDArray, please call array.setRequiresGradient()" + + " on your NDArray or block.setInitializer() on your Block"); + } + PtNDArray res = JniUtils.getGradient(this); + // If you call getGradient() before you run the backward, + // you will get nothing in PyTorch engine. + // To align with MXNet's behavior, we will create a zeros NDArray. + // TODO should we access the grad NDArray after we close the parameter NDArray? 
+ if (res == null) { + res = (PtNDArray) manager.zeros(getShape()); + } + return res; + } + + /** {@inheritDoc} */ + @Override + public boolean hasGradient() { + if (hasGradient == null) { + hasGradient = JniUtils.requiresGrad(this); + } + return hasGradient; + } + + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + return JniUtils.detachGradient(this); + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + return JniUtils.getByteBuffer(this); + } + + /** {@inheritDoc} */ + @Override + public String[] toStringArray(Charset charset) { + return strs; + } + + /** {@inheritDoc} */ + @Override + public void set(Buffer buffer) { + int size = Math.toIntExact(size()); + DataType type = getDataType(); + BaseNDManager.validateBuffer(buffer, type, size); + // TODO how do we handle the exception happened in the middle + dataRef = null; + if (buffer.isDirect() && buffer instanceof ByteBuffer) { + // If NDArray is on the GPU, it is native code responsibility to control the data life + // cycle + if (!getDevice().isGpu()) { + dataRef = (ByteBuffer) buffer; + } + JniUtils.set(this, (ByteBuffer) buffer); + return; + } + // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType + ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes()); + BaseNDManager.copyBuffer(buffer, buf); + + // If NDArray is on the GPU, it is native code responsibility to control the data life cycle + if (!getDevice().isGpu()) { + dataRef = buf; + } + JniUtils.set(this, buf); + } + + /** {@inheritDoc} */ + @Override + public NDArray get(NDManager manager, long... indices) { + return JniUtils.getItem(this, indices, (PtNDManager) manager); + } + + /** {@inheritDoc} */ + @Override + public NDArray gather(NDArray index, int axis) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray index is supported."); + } + return JniUtils.gather(this, (PtNDArray) index, axis); + } + + /** {@inheritDoc} */ + @Override + public NDArray gatherNd(NDArray index) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray index is supported."); + } + Shape indexShape = index.getShape(); + Shape dataShape = getShape(); + int indexingDepth = (int) indexShape.get(0); + if (indexingDepth > dataShape.dimension()) { + throw new IllegalArgumentException( + "Indexing rank " + + indexShape.get(0) + + " exceeds the data rank " + + dataShape.dimension()); + } + // Row-first order, the linear index is accumulated from z->y->x. + // For example, dataShape = (3, 2, 3), indexShape = (2, 3, 3) + // The method is: indexLinear = index[1] + index[0] * dataShape[1], row-first order + // indexLinear has shape (3, 3), is from combining the index along 0 axis. + // Each number in indexLinear is an indexing to an element in data (3, 2, ...). + // data is flattened to be (3*2, ...) which can be indexed by indexLinear. + // Finally, reshape the output to (3, 3, ...). 
Thus + // totalShape = indexShape.slice(1).addAll(dataShape.slice(indexingDepth)); + NDArray indexLinear = index.get("{}, ...", indexingDepth - 1); + long dim = 1; + for (int i = indexingDepth - 2; i > -1; i--) { + dim = dim * dataShape.get(i + 1); + indexLinear = indexLinear.addi(index.get("{}, ...", i).muli(dim)); + } + NDArray dataFlatten = this.flatten(0, indexingDepth - 1); + return dataFlatten.get(indexLinear); + } + + /** {@inheritDoc} */ + @Override + public NDArray take(NDManager manager, NDArray index) { + if (!(index instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.take(this, (PtNDArray) index, (PtNDManager) manager); + } + + /** {@inheritDoc} */ + @Override + public NDArray put(NDArray index, NDArray data) { + if (!(index instanceof PtNDArray) || !(data instanceof PtNDArray)) { + throw new IllegalArgumentException("Only PtNDArray is supported."); + } + return JniUtils.put(this, (PtNDArray) index, (PtNDArray) data); + } + + /** {@inheritDoc} */ + @Override + public void copyTo(NDArray array) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public void attach(NDManager manager) { + detach(); + this.manager = (PtNDManager) manager; + manager.attachInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void returnResource(NDManager manager) { + detach(); + this.manager = (PtNDManager) manager; + manager.attachUncappedInternal(getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void tempAttach(NDManager manager) { + NDManager original = this.manager; + detach(); + this.manager = (PtNDManager) manager; + manager.tempAttachInternal(original, getUid(), this); + } + + /** {@inheritDoc} */ + @Override + public void detach() { + manager.detachInternal(getUid()); + manager = PtNDManager.getSystemManager(); + } + + /** {@inheritDoc} */ + @Override + public NDArray duplicate() { + return JniUtils.clone(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray booleanMask(NDArray index, int axis) { + Shape indexShape = index.getShape(); + if (indexShape.equals(getShape())) { + // Result is flattened since shape is undetermined + return JniUtils.booleanMask(this, manager.from(index)); + } else if (indexShape.equals(getShape().slice(axis))) { + // index will be broadcast by default + try (PtNDArray flattedResult = JniUtils.booleanMask(this, manager.from(index))) { + // Shape recovery + Shape remainder = getShape().slice(0, axis); + long selectedSize = flattedResult.getShape().size() / remainder.size(); + return flattedResult.reshape(remainder.addAll(new Shape(selectedSize))); + } + } else { + throw new UnsupportedOperationException( + "Not supported for shape not broadcastable " + + indexShape + + " vs " + + getShape()); + } + } + + /** {@inheritDoc} */ + @Override + public NDArray sequenceMask(NDArray sequenceLength, float value) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + /** {@inheritDoc} */ + @Override + public NDArray sequenceMask(NDArray sequenceLength) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + /** {@inheritDoc} */ + @Override + public boolean contentEquals(Number number) { + return contentEquals(manager.create(number)); + } + + /** {@inheritDoc} */ + @Override + public boolean contentEquals(NDArray other) { + if (other == null || (!shapeEquals(other))) { + return false; + } + if (getDataType() != other.getDataType()) { + return false; + } + return 
JniUtils.contentEqual(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray eq(Number n) { + try (NDArray number = manager.create(n)) { + return eq(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray eq(NDArray other) { + return JniUtils.eq(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray neq(Number n) { + try (NDArray number = manager.create(n)) { + return neq(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray neq(NDArray other) { + return JniUtils.neq(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gt(Number n) { + try (NDArray number = manager.create(n)) { + return gt(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gt(NDArray other) { + return JniUtils.gt(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gte(Number n) { + try (NDArray number = manager.create(n)) { + return gte(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray gte(NDArray other) { + return JniUtils.gte(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lt(Number n) { + try (NDArray number = manager.create(n)) { + return lt(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lt(NDArray other) { + return JniUtils.lt(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lte(Number n) { + try (NDArray number = manager.create(n)) { + return lte(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray lte(NDArray other) { + return JniUtils.lte(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray add(Number n) { + try (NDArray number = manager.create(n)) { + return add(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray add(NDArray other) { + return JniUtils.add(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sub(Number n) { + try (NDArray number = manager.create(n)) { + return sub(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sub(NDArray other) { + return JniUtils.sub(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mul(Number n) { + try (NDArray number = manager.create(n)) { + return mul(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mul(NDArray other) { + return JniUtils.mul(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray div(Number n) { + try (NDArray number = manager.create(n)) { + return div(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray div(NDArray other) { + return JniUtils.div(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mod(Number n) { + try (NDArray number = manager.create(n)) { + return mod(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mod(NDArray other) { + return JniUtils.remainder(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray pow(Number n) { + try (NDArray number = manager.create(n)) { + return pow(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray pow(NDArray other) { + return JniUtils.pow(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray addi(Number n) { + try (NDArray number = manager.create(n)) { + return addi(number); + } + } + + /** {@inheritDoc} 
*/ + @Override + public PtNDArray addi(NDArray other) { + JniUtils.addi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray subi(Number n) { + try (NDArray number = manager.create(n)) { + return subi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray subi(NDArray other) { + JniUtils.subi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray muli(Number n) { + try (NDArray number = manager.create(n)) { + return muli(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray muli(NDArray other) { + JniUtils.muli(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray divi(Number n) { + try (NDArray number = manager.create(n)) { + return divi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray divi(NDArray other) { + JniUtils.divi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray modi(Number n) { + try (NDArray number = manager.create(n)) { + return modi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray modi(NDArray other) { + JniUtils.remainderi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray powi(Number n) { + try (NDArray number = manager.create(n)) { + return powi(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray powi(NDArray other) { + JniUtils.powi(this, manager.from(other)); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sign() { + return JniUtils.sign(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray signi() { + JniUtils.signi(this); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray maximum(Number n) { + try (NDArray number = manager.create(n)) { + return maximum(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray maximum(NDArray other) { + return JniUtils.max(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray minimum(Number n) { + try (NDArray number = manager.create(n)) { + return minimum(number); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray minimum(NDArray other) { + return JniUtils.min(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray all() { + try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { + return JniUtils.all(bool); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray any() { + try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { + return JniUtils.any(bool); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray none() { + try (PtNDArray bool = toType(DataType.BOOLEAN, true)) { + return JniUtils.none(bool); + } + } + + /** {@inheritDoc} */ + @Override + public PtNDArray neg() { + return JniUtils.neg(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray negi() { + JniUtils.negi(this); + return this; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray abs() { + return JniUtils.abs(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray square() { + return JniUtils.square(this); + } + + /** {@inheritDoc} */ + @Override + public NDArray sqrt() { + return JniUtils.sqrt(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cbrt() { + return JniUtils.pow(this, (PtNDArray) manager.create(1.0 / 3)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray floor() { + return 
JniUtils.floor(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray ceil() { + return JniUtils.ceil(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray round() { + return JniUtils.round(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray trunc() { + return JniUtils.trunc(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray exp() { + return JniUtils.exp(this); + } + + /** {@inheritDoc} */ + @Override + public NDArray gammaln() { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray log() { + return JniUtils.log(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray log10() { + return JniUtils.log10(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray log2() { + return JniUtils.log2(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sin() { + return JniUtils.sin(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cos() { + return JniUtils.cos(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tan() { + return JniUtils.tan(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray asin() { + return JniUtils.asin(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray acos() { + return JniUtils.acos(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray atan() { + return JniUtils.atan(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sinh() { + return JniUtils.sinh(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cosh() { + return JniUtils.cosh(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tanh() { + return JniUtils.tanh(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray asinh() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray acosh() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray atanh() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toDegrees() { + return mul(180.0).div(Math.PI); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toRadians() { + return mul(Math.PI).div(180.0); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray max() { + return JniUtils.max(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray max(int[] axes, boolean keepDims) { + if (axes.length > 1) { + // TODO fix this + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.max(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray min() { + return JniUtils.min(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray min(int[] axes, boolean keepDims) { + if (axes.length > 1) { + // TODO fix this + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.min(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sum() { + return JniUtils.sum(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sum(int[] axes, boolean keepDims) { + return JniUtils.sum(this, Arrays.stream(axes).mapToLong(i -> i).toArray(), keepDims); + } + + /** {@inheritDoc} */ + @Override + public NDArray cumProd(int axis) { + return JniUtils.cumProd(this, axis, null); + } + + /** {@inheritDoc} */ + @Override + public NDArray cumProd(int axis, 
DataType dataType) { + return JniUtils.cumProd(this, axis, dataType); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray prod() { + return JniUtils.prod(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray prod(int[] axes, boolean keepDims) { + if (axes.length > 1) { + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.prod(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mean() { + return JniUtils.mean(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray mean(int[] axes, boolean keepDims) { + if (axes.length > 1) { + // TODO fix this + throw new UnsupportedOperationException("Only 1 axis is support!"); + } + return JniUtils.mean(this, axes[0], keepDims); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray normalize(double p, long dim, double eps) { + return JniUtils.normalize(this, p, dim, eps); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray rotate90(int times, int[] axes) { + if (axes.length != 2) { + throw new IllegalArgumentException("Axes must be 2"); + } + return JniUtils.rot90(this, times, axes); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray trace(int offset, int axis1, int axis2) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public NDList split(long sections, int axis) { + long size = getShape().get(axis) / sections; + return JniUtils.split(this, size, axis); + } + + /** {@inheritDoc} */ + @Override + public NDList split(long[] indices, int axis) { + if (indices.length == 0) { + return new NDList(this); + } + List ptIndex = new ArrayList<>(); + ptIndex.add(indices[0]); + for (int i = 1; i < indices.length; i++) { + ptIndex.add(indices[i] - indices[i - 1]); + } + ptIndex.add(size(axis) - indices[indices.length - 1]); + return JniUtils.split(this, ptIndex.stream().mapToLong(i -> i).toArray(), axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray flatten() { + return JniUtils.flatten(this, 0, -1); + } + + /** {@inheritDoc} */ + @Override + public NDArray flatten(int startDim, int endDim) { + return JniUtils.flatten(this, startDim, endDim); + } + + /** {@inheritDoc} */ + @Override + public NDArray fft(long length, long axis) { + return JniUtils.fft(this, length, axis); + } + + /** {@inheritDoc} */ + @Override + public NDArray stft( + long nFft, + long hopLength, + boolean center, + NDArray window, + boolean normalize, + boolean returnComplex) { + return JniUtils.stft( + this, nFft, hopLength, (PtNDArray) window, center, normalize, returnComplex); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray reshape(Shape shape) { + return JniUtils.reshape(this, shape.getShape()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray expandDims(int axis) { + return JniUtils.unsqueeze(this, axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray squeeze() { + return JniUtils.squeeze(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray squeeze(int axis) { + return JniUtils.squeeze(this, axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray squeeze(int[] axes) { + if (isScalar()) { + if (axes.length == 0 || (axes.length == 1 && axes[0] == 0)) { + return (PtNDArray) duplicate(); + } + throw new IllegalArgumentException( + "axis " + axes[0] + " is out of bounds for array of dimension 0"); + } + long[] shapeArr = getShape().getShape(); + List newShape = new ArrayList<>(); + Set set = + 
IntStream.of(axes).boxed().collect(Collectors.toCollection(HashSet::new)); + // check input + for (int axis : axes) { + if (shapeArr[axis] != 1) { + throw new IllegalArgumentException( + "cannot select an axis to squeeze out which has size not equal to one"); + } + } + for (int i = 0; i < shapeArr.length; i++) { + if (!set.contains(i)) { + newShape.add(shapeArr[i]); + } + } + return (PtNDArray) reshape(newShape.stream().mapToLong(i -> i).toArray()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logicalAnd(NDArray other) { + return JniUtils.logicalAnd(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logicalOr(NDArray other) { + return JniUtils.logicalOr(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logicalXor(NDArray other) { + return JniUtils.logicalXor(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logicalNot() { + return JniUtils.logicalNot(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argSort(int axis, boolean ascending) { + PtNDArray arr = JniUtils.argSort(this, axis, false); + if (ascending) { + return arr; + } + PtNDArray flip = JniUtils.flip(arr, new long[] {axis}); + arr.close(); + return flip; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sort() { + return sort(-1); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray sort(int axis) { + return JniUtils.sort(this, axis, false); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray softmax(int axis) { + return JniUtils.softmax(this, axis, getDataType()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray logSoftmax(int axis) { + return JniUtils.logSoftmax(this, axis, getDataType()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cumSum() { + // TODO: change default behavior on cumSum + if (isScalar()) { + return (PtNDArray) reshape(1); + } + if (isEmpty()) { + return (PtNDArray) reshape(0); + } + return cumSum(0); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray cumSum(int axis) { + return JniUtils.cumSum(this, axis); + } + + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + PtNDArray arr = (PtNDArray) replaced; + Long oldHandle = handle.getAndSet(arr.getAndSetHandleNull()); + JniUtils.deleteNDArray(oldHandle); + // dereference old ndarray + arr.close(); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray isInfinite() { + return JniUtils.isInf(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray isNaN() { + return JniUtils.isNaN(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(long repeats) { + // zero-dim + if (isEmpty()) { + return (PtNDArray) duplicate(); + } + // scalar + int dim = (isScalar()) ? 1 : getShape().dimension(); + long[] repeatsArray = new long[dim]; + Arrays.fill(repeatsArray, repeats); + return tile(repeatsArray); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(int axis, long repeats) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(long[] repeats) { + return JniUtils.tile(this, repeats); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray tile(Shape desiredShape) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(long repeats) { + // zero-dim + if (isEmpty()) { + return (PtNDArray) duplicate(); + } + // scalar + int dim = (isScalar()) ? 
1 : getShape().dimension(); + long[] repeatsArray = new long[dim]; + Arrays.fill(repeatsArray, repeats); + return repeat(repeatsArray); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(int axis, long repeats) { + return JniUtils.repeat(this, repeats, axis); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(long[] repeats) { + PtNDArray result = this; + for (int dim = 0; dim < repeats.length; dim++) { + PtNDArray temp = result; + result = JniUtils.repeat(result, repeats[dim], dim); + if (temp != this) { + temp.close(); + } + } + return result; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray repeat(Shape desiredShape) { + return repeat(repeatsToMatchShape(desiredShape)); + } + + private long[] repeatsToMatchShape(Shape desiredShape) { + Shape curShape = getShape(); + int dimension = curShape.dimension(); + if (desiredShape.dimension() > dimension) { + throw new IllegalArgumentException("The desired shape has too many dimensions"); + } + if (desiredShape.dimension() < dimension) { + int additionalDimensions = dimension - desiredShape.dimension(); + desiredShape = curShape.slice(0, additionalDimensions).addAll(desiredShape); + } + long[] repeats = new long[dimension]; + for (int i = 0; i < dimension; i++) { + if (curShape.get(i) == 0 || desiredShape.get(i) % curShape.get(i) != 0) { + throw new IllegalArgumentException( + "The desired shape is not a multiple of the original shape"); + } + repeats[i] = Math.round(Math.ceil((double) desiredShape.get(i) / curShape.get(i))); + } + return repeats; + } + + /** {@inheritDoc} */ + @Override + public PtNDArray dot(NDArray other) { + int selfDim = this.getShape().dimension(); + int otherDim = other.getShape().dimension(); + if (selfDim != otherDim || selfDim > 2) { + throw new UnsupportedOperationException( + "Dimension mismatch or high dimensional dot operation is not supported. Please" + + " use .matMul instead."); + } + return JniUtils.dot(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public NDArray matMul(NDArray other) { + if (isScalar() || other.isScalar()) { + throw new IllegalArgumentException("scalar is not allowed for matMul()"); + } + return JniUtils.matmul(this, manager.from(other)); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray clip(Number min, Number max) { + return JniUtils.clip(this, min, max); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray swapAxes(int axis1, int axis2) { + return JniUtils.transpose(this, axis1, axis2); + } + + /** {@inheritDoc} */ + @Override + public NDArray flip(int... axes) { + return JniUtils.flip(this, Arrays.stream(axes).mapToLong(ele -> (long) ele).toArray()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray transpose() { + int dim = getShape().dimension(); + int[] reversedShape = IntStream.range(0, dim).map(i -> dim - i - 1).toArray(); + return transpose(reversedShape); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray transpose(int... 
axes) { + if (isScalar() && axes.length > 0) { + throw new IllegalArgumentException("axes don't match NDArray"); + } + return JniUtils.permute(this, Arrays.stream(axes).mapToLong(i -> i).toArray()); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray broadcast(Shape shape) { + return JniUtils.broadcast(this, shape); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMax() { + if (isEmpty()) { + throw new IllegalArgumentException("attempt to get argMax of an empty NDArray"); + } + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMax(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMax(int axis) { + // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMax(this, axis, false); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMin() { + if (isEmpty()) { + throw new IllegalArgumentException("attempt to get argMin of an empty NDArray"); + } + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMin(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray argMin(int axis) { + // TODO pytorch bug: https://github.com/pytorch/pytorch/issues/37084 + if (isScalar()) { + return (PtNDArray) manager.create(0L); + } + return JniUtils.argMin(this, axis, false); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray percentile(Number percentile) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray percentile(Number percentile, int[] axes) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray median() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray median(int[] axes) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toDense() { + if (!isSparse() && JniUtils.getLayout(this) != 2) { + return (PtNDArray) duplicate(); + } + return JniUtils.toDense(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray toSparse(SparseFormat fmt) { + if (fmt == SparseFormat.DENSE) { + throw new IllegalArgumentException("Default type is not allowed"); + } + if (fmt != SparseFormat.COO) { + throw new UnsupportedOperationException("Only COO sparse type supported for PyTorch"); + } + if (fmt == getSparseFormat()) { + return (PtNDArray) duplicate(); + } + return JniUtils.toSparse(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray nonzero() { + return JniUtils.nonZeros(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray erfinv() { + return JniUtils.erfinv(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArray inverse() { + return JniUtils.inverse(this); + } + + /** {@inheritDoc} */ + @Override + public NDArray norm(boolean keepDims) { + return JniUtils.norm(this, 2, new int[] {}, keepDims); + } + + /** {@inheritDoc} */ + @Override + public NDArray norm(int order, int[] axes, boolean keepDims) { + return JniUtils.norm(this, order, axes, keepDims); + } + + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth) { + return JniUtils.oneHot(this, depth, DataType.FLOAT32); + } + + /** {@inheritDoc} */ + @Override + public NDArray oneHot(int depth, DataType dataType) { + return JniUtils.oneHot(this, depth, dataType); + } + + /** {@inheritDoc} */ 
+ @Override + public NDArray oneHot(int depth, float onValue, float offValue, DataType dataType) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public NDArray batchDot(NDArray other) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public NDArray complex() { + return JniUtils.complex(this); + } + + /** {@inheritDoc} */ + @Override + public NDArray real() { + return JniUtils.real(this); + } + + /** {@inheritDoc} */ + @Override + public PtNDArrayEx getNDArrayInternal() { + return ptNDArrayEx; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + if (isReleased()) { + return "This array is already closed"; + } + // index operator in toDebugString is not supported for MKLDNN & Sparse layout + if (JniUtils.getLayout(this) != 0) { + try (NDArray tmp = toDense()) { + return tmp.toDebugString(); + } + } + return toDebugString(); + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object obj) { + if (obj instanceof NDArray) { + return contentEquals((NDArray) obj); + } + return false; + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return 0; + } + + /** {@inheritDoc} */ + @Override + public void close() { + Long pointer = handle.getAndSet(null); + if (pointer != null && pointer != -1) { + JniUtils.deleteNDArray(pointer); + } + manager.detachInternal(getUid()); + dataRef = null; + } + + /** {@inheritDoc} */ + @Override + public int getNumOfNDArraysInGCMap() { + throw new UnsupportedOperationException("Not supported!"); + } + + /** {@inheritDoc} */ + @Override + public int getNumOfNDArraysInNDManagerHierarchy() { + return PtNDManager.debugCountNDArraysFromSystemManager(); + } + + /** {@inheritDoc} */ + @Override + public PtNDArrayImpl getImplementation() { + return this; + } + + /** + * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} + * instead). Depending on the switch {@code useGarbageCollection}, the returned {@code NDArray} + * will be returned as a proxy or a direct instance. + * + * @param manager the manager to attach the new array to + * @param handle the pointer to the native PyTorch memory + * @return the new {@code NDArray} + */ + public static PtNDArray newPtNDArray(PtNDManager manager, long handle) { + PtNDArray instance = new PtNDArrayImpl(manager, handle); + if (SwitchGarbageCollection.isUseGarbageCollection()) { + instance = manager.getProxyMaker().wrap(instance); + } + return instance; + } + + /** + * Constructs a PyTorch {@code NDArray} from a native handle (internal. Use {@link NDManager} + * instead) with the data that is hold on Java side. Depending on the switch {@code + * useGarbageCollection}, the returned {@code NDArray} will be returned as a proxy or a direct + * instance. + * + * @param manager the manager to attach the new array to + * @param handle the pointer to the native PyTorch memory + * @param data the direct buffer of the data + * @return the new {@code NDArray} + */ + public static PtNDArray newPtNDArray(PtNDManager manager, long handle, ByteBuffer data) { + PtNDArray instance = new PtNDArrayImpl(manager, handle, data); + if (SwitchGarbageCollection.isUseGarbageCollection()) { + instance = manager.getProxyMaker().wrap(instance); + } + return instance; + } + + /** + * Constructs a PyTorch {@code NDArray} to hold string array with a dummy native handle + * (internal. Use {@link NDManager} instead) with the data that is hold on Java side. 
Depending + * on the switch {@code useGarbageCollection}, the returned {@code NDArray} will be returned as + * a proxy or a direct instance. + * + * @param manager the manager to attach the new array to + * @param strs the string array + * @param shape the {@link Shape} of the {@link NDArray} + * @return the new {@code NDArray} + */ + public static PtNDArray newPtNDArray(PtNDManager manager, String[] strs, Shape shape) { + PtNDArray instance = new PtNDArrayImpl(manager, strs, shape); + if (SwitchGarbageCollection.isUseGarbageCollection()) { + instance = manager.getProxyMaker().wrap(instance); + } + return instance; + } +} diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java new file mode 100644 index 00000000000..50d65ab4377 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayProxyMaker.java @@ -0,0 +1,69 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.engine; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.gc.DynamicInvocationHandler; +import ai.djl.ndarray.gc.NDArrayProxyMaker; +import ai.djl.ndarray.gc.WeakHashMapWrapper; + +import java.lang.reflect.Proxy; +import java.util.concurrent.atomic.AtomicLong; + +/** {@code PtNDArrayProxyMaker} creates a proxy facade. */ +public class PtNDArrayProxyMaker implements NDArrayProxyMaker { + + ThreadLocal> tLocalMap = new ThreadLocal<>(); + + AtomicLong counter = new AtomicLong(0); + + /** {@inheritDoc} */ + @Override + public int mapSize() { + return getLocalWeakHashMapWrapper().size(); + } + + private WeakHashMapWrapper getLocalWeakHashMapWrapper() { + if (tLocalMap.get() == null) { + tLocalMap.set(new WeakHashMapWrapper<>()); + } + return tLocalMap.get(); + } + + /** {@inheritDoc} */ + @Override + public void gc() { + getLocalWeakHashMapWrapper().checkQueue(); + } + + /** + * Wraps the {@link PtNDArray} in a proxy facade. 
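+ * <p>The array is registered in a thread-local {@code WeakHashMapWrapper} under a freshly
+ * generated uid, and the returned proxy forwards every call to it through a
+ * {@code DynamicInvocationHandler}. Once the proxy is no longer strongly referenced, a later
+ * {@link #gc()} call can drain the map's reference queue and release the underlying native
+ * resource.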
+ * + * @param array the array to wrap + * @return the wrapped array + */ + @Override + public PtNDArray wrap(NDArray array) { + + WeakHashMapWrapper map = getLocalWeakHashMapWrapper(); + + String uid = array.getUid() + "-" + counter.incrementAndGet(); + map.put(uid, array); + DynamicInvocationHandler handler = new DynamicInvocationHandler(uid, map, this); + return (PtNDArray) + Proxy.newProxyInstance( + Thread.currentThread().getContextClassLoader(), + new Class[] {PtNDArray.class}, + handler); + } +} diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 0188acea3f8..8269ad05ff7 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -32,6 +32,8 @@ public class PtNDManager extends BaseNDManager { private static final PtNDManager SYSTEM_MANAGER = new SystemManager(); + protected PtNDArrayProxyMaker proxyMaker; + private PtNDManager(NDManager parent, Device device) { super(parent, device); } @@ -40,6 +42,18 @@ static PtNDManager getSystemManager() { return SYSTEM_MANAGER; } + /** {@inheritDoc} */ + @Override + public PtNDArrayProxyMaker getProxyMaker() { + return getSystemManager().getProxyMaker(); + } + + /** {@inheritDoc} */ + @Override + public void gc() { + getSystemManager().getProxyMaker().gc(); + } + /** {@inheritDoc} */ @Override public ByteBuffer allocateDirect(int capacity) { @@ -81,7 +95,7 @@ public PtNDArray create(Buffer data, Shape shape, DataType dataType) { /** {@inheritDoc} */ @Override public NDArray create(String[] data, Charset charset, Shape shape) { - return new PtNDArray(this, data, shape); + return PtNDArrayImpl.newPtNDArray(this, data, shape); } /** {@inheritDoc} */ @@ -192,11 +206,39 @@ public final Engine getEngine() { return Engine.getEngine(PtEngine.ENGINE_NAME); } + /** + * Dumps debug information about the current {@link PtNDManager} and all its children. + * + * @param detailed whether to dump detailed information + */ + public static void debugDumpFromSystemManager(boolean detailed) { + if (detailed) { + ((BaseNDManager) PtNDManager.getSystemManager()).debugDumpDetailed(0); + } else { + ((BaseNDManager) PtNDManager.getSystemManager()).debugDump(0); + } + } + + /** + * Returns the number of {@link NDArray} in the hierarchy of the {@link SystemNDManager}. + * + * @return return the number of {@link NDArray} in the hierarchy of the {@link SystemNDManager} + */ + public static int debugCountNDArraysFromSystemManager() { + return ((BaseNDManager) PtNDManager.getSystemManager()).debugCountNDArrays(); + } + /** The SystemManager is the root {@link PtNDManager} of which all others are children. 
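 * It also holds the shared {@link PtNDArrayProxyMaker} that all child managers delegate to when
 * garbage-collection-based closing of {@code NDArray}s is enabled.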
*/ private static final class SystemManager extends PtNDManager implements SystemNDManager { SystemManager() { super(null, null); + this.proxyMaker = new PtNDArrayProxyMaker(); + } + + @Override + public PtNDArrayProxyMaker getProxyMaker() { + return this.proxyMaker; } } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java index b1c34688cdb..48b5192e24b 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValue.java @@ -16,8 +16,9 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import java.util.Arrays; import java.util.Map; @@ -28,7 +29,7 @@ * *
<p>
DJL doesn't support creating nested IValue. */ -public class IValue extends NativeResource { +public class IValue extends NativeResourceImpl { IValue(long handle) { super(handle); @@ -392,7 +393,7 @@ public double[] toDoubleArray() { * @return the NDArray value of this IValue */ public PtNDArray toTensor(PtNDManager manager) { - return new PtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); + return PtNDArrayImpl.newPtNDArray(manager, PyTorchLibrary.LIB.iValueToTensor(getHandle())); } /** @@ -403,9 +404,9 @@ public PtNDArray toTensor(PtNDManager manager) { */ public PtNDArray[] toTensorArray(PtNDManager manager) { long[] handles = PyTorchLibrary.LIB.iValueToTensorList(getHandle()); - PtNDArray[] ret = new PtNDArray[handles.length]; + PtNDArray[] ret = new PtNDArrayImpl[handles.length]; for (int i = 0; i < ret.length; ++i) { - ret[i] = new PtNDArray(manager, handles[i]); + ret[i] = PtNDArrayImpl.newPtNDArray(manager, handles[i]); } return ret; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 9c3a12f7c9e..dd6bcb335eb 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -30,6 +30,7 @@ import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.engine.PtDeviceType; import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.pytorch.engine.PtNDArrayImpl; import ai.djl.pytorch.engine.PtNDManager; import ai.djl.pytorch.engine.PtSymbolBlock; @@ -170,9 +171,9 @@ public static PtNDArray createNdFromByteBuffer( if (layout == 1 || layout == 2 || device.isGpu()) { // MKLDNN & COO & GPU device will explicitly make a copy in native code // so we don't want to hold a reference on Java side - return new PtNDArray(manager, handle); + return PtNDArrayImpl.newPtNDArray(manager, handle); } - return new PtNDArray(manager, handle, data); + return PtNDArrayImpl.newPtNDArray(manager, handle, data); } public static void emptyCudaCache() { @@ -182,7 +183,7 @@ public static void emptyCudaCache() { public static PtNDArray createEmptyNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchEmpty( shape.getShape(), @@ -195,7 +196,7 @@ public static PtNDArray createEmptyNdArray( public static PtNDArray createZerosNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchZeros( shape.getShape(), @@ -208,7 +209,7 @@ public static PtNDArray createZerosNdArray( public static PtNDArray createOnesNdArray( PtNDManager manager, Shape shape, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchOnes( shape.getShape(), @@ -226,7 +227,7 @@ public static PtNDArray full( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchFull( shape.getShape(), @@ -240,7 +241,7 @@ public static PtNDArray full( public static PtNDArray zerosLike( PtNDArray array, DataType dType, Device device, SparseFormat 
fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( array.getManager(), PyTorchLibrary.LIB.torchZerosLike( array.getHandle(), @@ -253,7 +254,7 @@ public static PtNDArray zerosLike( public static PtNDArray onesLike( PtNDArray array, DataType dType, Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( array.getManager(), PyTorchLibrary.LIB.torchOnesLike( array.getHandle(), @@ -272,7 +273,7 @@ public static PtNDArray arange( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchArange( start, @@ -293,7 +294,7 @@ public static PtNDArray linspace( Device device, SparseFormat fmt) { int layoutVal = layoutMapper(fmt, device); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchLinspace( start, @@ -306,7 +307,7 @@ public static PtNDArray linspace( } public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, Shape shape) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( values.getManager(), PyTorchLibrary.LIB.torchSparseCoo( shape.getShape(), indices.getHandle(), values.getHandle(), false)); @@ -319,7 +320,7 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) if (!device.equals(manager.getDevice())) { manager = manager.newSubManager(device); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchTo( ndArray.getHandle(), @@ -328,23 +329,23 @@ public static PtNDArray to(PtNDArray ndArray, DataType dataType, Device device) } public static PtNDArray toSparse(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToSparse(ndArray.getHandle())); } public static PtNDArray toDense(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchToDense(ndArray.getHandle())); } public static PtNDArray broadcast(PtNDArray ndArray, Shape shape) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchExpand(ndArray.getHandle(), shape.getShape())); } public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSlice(ndArray.getHandle(), dim, start, stop, step)); } @@ -355,7 +356,7 @@ public static PtNDArray index( long[] maxIndices, long[] stepIndices, PtNDManager manager) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchIndex( ndArray.getHandle(), minIndices, maxIndices, stepIndices)); @@ -421,7 +422,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index, PtNDManager m PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); } long ret = PyTorchLibrary.LIB.torchIndexAdvGet(ndArray.getHandle(), torchIndexHandle); - return new PtNDArray(manager, ret); + return PtNDArrayImpl.newPtNDArray(manager, ret); } finally { PyTorchLibrary.LIB.torchDeleteIndex(torchIndexHandle); } @@ -517,7 +518,7 @@ public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return 
PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } @@ -526,7 +527,7 @@ public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager man if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); } @@ -534,7 +535,7 @@ public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray data) if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPut( ndArray.getHandle(), index.getHandle(), data.getHandle())); @@ -566,20 +567,20 @@ public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false)); } public static PtNDArray where(PtNDArray condition, PtNDArray self, PtNDArray other) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchWhere( condition.getHandle(), self.getHandle(), other.getHandle())); } public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray indicesNd) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMaskedSelect(ndArray.getHandle(), indicesNd.getHandle())); } @@ -594,102 +595,104 @@ public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager m // due to significant performance gain // for commonly used data loading call if (indices.length == 1) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices[0])); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchGetItem(ndArray.getHandle(), indices)); } public static PtNDArray clone(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.tensorClone(ndArray.getHandle())); } public static PtNDArray reshape(PtNDArray ndArray, long[] shape) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchReshape(ndArray.getHandle(), shape)); } public static PtNDArray stack(PtNDArray[] arrays, int dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); + return PtNDArrayImpl.newPtNDArray( + arrays[0].getManager(), PyTorchLibrary.LIB.torchStack(pointers, dim)); } public static PtNDArray cat(PtNDArray[] arrays, long dim) { long[] pointers = Arrays.stream(arrays).mapToLong(PtNDArray::getHandle).toArray(); - return new PtNDArray(arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); + return PtNDArrayImpl.newPtNDArray( + arrays[0].getManager(), PyTorchLibrary.LIB.torchCat(pointers, dim)); } public static PtNDArray tile(PtNDArray ndArray, long[] repeats) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeat(ndArray.getHandle(), repeats)); } public static PtNDArray repeat(PtNDArray ndArray, long 
repeat, long dim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRepeatInterleave(ndArray.getHandle(), repeat, dim)); } public static PtNDArray softmax(PtNDArray ndArray, long dim, DataType dTpe) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, DataType dTpe) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogSoftmax(ndArray.getHandle(), dim, dTpe.ordinal())); } public static PtNDArray argMax(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle())); } public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argMin(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle())); } public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchArgSort(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray sort(PtNDArray ndArray, long dim, boolean descending) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSort(ndArray.getHandle(), dim, descending)); } public static PtNDArray permute(PtNDArray ndArray, long[] dims) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchPermute(ndArray.getHandle(), dims)); } public static PtNDArray flip(PtNDArray ndArray, long[] dims) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlip(ndArray.getHandle(), dims)); } public static PtNDArray transpose(PtNDArray ndArray, long dim1, long dim2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTranspose(ndArray.getHandle(), dim1, dim2)); } @@ -699,7 +702,7 @@ public static boolean contentEqual(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray add(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchAdd(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -709,7 +712,7 @@ public static void addi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sub(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchSub(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -719,7 +722,7 @@ public static void subi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray mul(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), 
PyTorchLibrary.LIB.torchMul(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -729,7 +732,7 @@ public static void muli(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray div(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchTrueDivide(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -739,7 +742,7 @@ public static void divi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray remainder(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchRemainder(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -749,7 +752,7 @@ public static void remainderi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray pow(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchPow(ndArray1.getHandle(), ndArray2.getHandle())); } @@ -759,7 +762,7 @@ public static void powi(PtNDArray ndArray1, PtNDArray ndArray2) { } public static PtNDArray sign(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSign(ndArray.getHandle())); } @@ -768,104 +771,104 @@ public static void signi(PtNDArray ndArray) { } public static PtNDArray logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalAnd(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalOr(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalOr(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalXor(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchLogicalXor(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray logicalNot(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLogicalNot(ndArray.getHandle())); } public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2) { if (ndArray1.getShape().dimension() == 1) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchDot(ndArray1.getHandle(), ndArray2.getHandle())); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMatmul(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMaximum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray max(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle())); } public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return 
PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMax(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray1.getManager(), PyTorchLibrary.LIB.torchMinimum(ndArray1.getHandle(), ndArray2.getHandle())); } public static PtNDArray min(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle())); } public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMin(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray mean(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle())); } public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchMean(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes) { long[] longaxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRot90(ndArray.getHandle(), times, longaxes)); } public static PtNDArray sum(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle())); } public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSum(ndArray.getHandle(), dims, keepDim)); } @@ -875,29 +878,29 @@ public static PtNDArray cumProd(PtNDArray ndArray, long dim, DataType dataType) if (dataType != null) { dtPosition = dataType.ordinal(); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumProd(ndArray.getHandle(), dim, dtPosition)); } public static PtNDArray prod(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle())); } public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchProd(ndArray.getHandle(), dim, keepDim)); } public static PtNDArray cumSum(PtNDArray ndArray, long dim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCumSum(ndArray.getHandle(), dim)); } public static PtNDArray oneHot(PtNDArray ndArray, int depth, DataType dataType) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNOneHot( ndArray.toType(DataType.INT64, false).getHandle(), depth)) @@ -908,7 +911,7 @@ public static NDList split(PtNDArray ndArray, long size, long axis) { long[] ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), size, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(new PtNDArray(ndArray.getManager(), ptr)); + list.add(PtNDArrayImpl.newPtNDArray(ndArray.getManager(), ptr)); } return list; } @@ -917,34 +920,34 @@ public static NDList split(PtNDArray ndArray, long[] indices, long axis) { long[] 
ndPtrs = PyTorchLibrary.LIB.torchSplit(ndArray.getHandle(), indices, axis); NDList list = new NDList(); for (long ptr : ndPtrs) { - list.add(new PtNDArray(ndArray.getManager(), ptr)); + list.add(PtNDArrayImpl.newPtNDArray(ndArray.getManager(), ptr)); } return list; } public static PtNDArray squeeze(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle())); } public static PtNDArray squeeze(PtNDArray ndArray, long dim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqueeze(ndArray.getHandle(), dim)); } public static PtNDArray unsqueeze(PtNDArray ndArray, long dim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchUnsqueeze(ndArray.getHandle(), dim)); } public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFlatten(ndArray.getHandle(), startDim, endDim)); } public static PtNDArray fft(PtNDArray ndArray, long length, long axis) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFft(ndArray.getHandle(), length, axis)); } @@ -969,7 +972,7 @@ public static PtNDArray stft( if (handle == -1) { throw new UnsupportedOperationException("real() is not supported."); } - return new PtNDArray(ndArray.getManager(), handle); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), handle); } public static PtNDArray real(PtNDArray ndArray) { @@ -977,7 +980,7 @@ public static PtNDArray real(PtNDArray ndArray) { if (handle == -1) { throw new UnsupportedOperationException("real() is not supported."); } - return new PtNDArray(ndArray.getManager(), handle); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), handle); } public static PtNDArray complex(PtNDArray ndArray) { @@ -985,173 +988,173 @@ public static PtNDArray complex(PtNDArray ndArray) { if (handle == -1) { throw new UnsupportedOperationException("complex() is not supported."); } - return new PtNDArray(ndArray.getManager(), handle); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), handle); } public static PtNDArray abs(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAbs(ndArray.getHandle())); } public static PtNDArray square(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSquare(ndArray.getHandle())); } public static PtNDArray floor(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchFloor(ndArray.getHandle())); } public static PtNDArray ceil(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCeil(ndArray.getHandle())); } public static PtNDArray round(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchRound(ndArray.getHandle())); } public static PtNDArray trunc(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTrunc(ndArray.getHandle())); } public static PtNDArray clip(PtNDArray ndArray, Number min, Number max) { PtNDArray minNd = (PtNDArray) 
ndArray.getManager().create(min); PtNDArray maxNd = (PtNDArray) ndArray.getManager().create(max); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchClamp( ndArray.getHandle(), minNd.getHandle(), maxNd.getHandle())); } public static PtNDArray exp(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchExp(ndArray.getHandle())); } public static PtNDArray log(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog(ndArray.getHandle())); } public static PtNDArray log10(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog10(ndArray.getHandle())); } public static PtNDArray log2(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchLog2(ndArray.getHandle())); } public static PtNDArray sin(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSin(ndArray.getHandle())); } public static PtNDArray cos(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCos(ndArray.getHandle())); } public static PtNDArray tan(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTan(ndArray.getHandle())); } public static PtNDArray asin(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchASin(ndArray.getHandle())); } public static PtNDArray acos(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAcos(ndArray.getHandle())); } public static PtNDArray atan(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } public static PtNDArray sqrt(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); } public static PtNDArray sinh(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSinh(ndArray.getHandle())); } public static PtNDArray cosh(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchCosh(ndArray.getHandle())); } public static PtNDArray tanh(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTanh(ndArray.getHandle())); } public static PtNDArray sigmoid(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSigmoid(ndArray.getHandle())); } public static PtNDArray all(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAll(ndArray.getHandle())); } public static PtNDArray any(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchAny(ndArray.getHandle())); } public static PtNDArray none(PtNDArray ndArray) { - return new PtNDArray( + return 
PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNone(ndArray.getHandle())); } public static PtNDArray eq(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchEq(self.getHandle(), other.getHandle())); } public static PtNDArray neq(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchNeq(self.getHandle(), other.getHandle())); } public static PtNDArray gt(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchGt(self.getHandle(), other.getHandle())); } public static PtNDArray gte(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchGte(self.getHandle(), other.getHandle())); } public static PtNDArray lt(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLt(self.getHandle(), other.getHandle())); } public static PtNDArray lte(PtNDArray self, PtNDArray other) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( self.getManager(), PyTorchLibrary.LIB.torchLte(self.getHandle(), other.getHandle())); } public static PtNDArray neg(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNeg(ndArray.getHandle())); } @@ -1160,12 +1163,12 @@ public static void negi(PtNDArray ndArray) { } public static PtNDArray isNaN(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsNaN(ndArray.getHandle())); } public static PtNDArray isInf(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchIsInf(ndArray.getHandle())); } @@ -1176,7 +1179,7 @@ public static PtNDArray randint( Shape size, DataType dataType, Device device) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchRandint( low, @@ -1190,7 +1193,7 @@ public static PtNDArray randint( public static PtNDArray randperm( PtNDManager manager, long n, DataType dataType, Device device) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchRandPerm( n, @@ -1207,7 +1210,7 @@ public static PtNDArray normal( Shape size, DataType dataType, Device device) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchNormal( mean, @@ -1226,7 +1229,7 @@ public static PtNDArray uniform( Shape size, DataType dataType, Device device) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.tensorUniform( low, @@ -1240,7 +1243,7 @@ public static PtNDArray uniform( public static PtNDArray eye( PtNDManager manager, int n, int m, DataType dataType, Device device, SparseFormat fmt) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchEye( n, @@ -1253,7 +1256,7 @@ public static PtNDArray eye( public static PtNDArray hannWindow( PtNDManager manager, long numPoints, boolean periodic, Device device) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( manager, PyTorchLibrary.LIB.torchHannWindow( numPoints, @@ -1262,25 +1265,25 @@ public static PtNDArray hannWindow( } public static PtNDArray erfinv(PtNDArray 
ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } public static PtNDArray inverse(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); } public static PtNDArray interpolate( PtNDArray ndArray, long[] size, int mode, boolean alignCorners) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNInterpolate( ndArray.getHandle(), size, mode, alignCorners)); } public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNLinear( input.getHandle(), @@ -1289,44 +1292,44 @@ public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias } public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( input.getManager(), PyTorchLibrary.LIB.torchNNEmbedding(input.getHandle(), weight.getHandle(), sparse)); } public static PtNDArray relu(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNRelu(ndArray.getHandle())); } public static PtNDArray softPlus(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftPlus(ndArray.getHandle())); } public static PtNDArray softSign(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSoftSign(ndArray.getHandle())); } public static PtNDArray leakyRelu(PtNDArray ndArray, double negativeSlope) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLeakyRelu(ndArray.getHandle(), negativeSlope)); } public static PtNDArray elu(PtNDArray ndArray, double alpha) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNElu(ndArray.getHandle(), alpha)); } public static PtNDArray selu(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNSelu(ndArray.getHandle())); } public static PtNDArray gelu(PtNDArray ndArray) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNGelu(ndArray.getHandle())); } @@ -1338,7 +1341,7 @@ public static PtNDArray convolution( Shape padding, Shape dilation, int groups) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNConvNd( ndArray.getHandle(), @@ -1359,7 +1362,7 @@ public static PtNDArray batchNorm( boolean isTraining, double momentum, double eps) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNBatchNorm( ndArray.getHandle(), @@ -1374,7 +1377,7 @@ public static PtNDArray batchNorm( public static PtNDArray layerNorm( PtNDArray ndArray, Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLayerNorm( ndArray.getHandle(), @@ -1385,13 +1388,13 @@ public static PtNDArray layerNorm( } public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, 
double eps) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNNormalize(ndArray.getHandle(), p, dim, eps)); } public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNDropout(ndArray.getHandle(), prob, training)); } @@ -1424,7 +1427,7 @@ public static NDList rnn( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArray(manager, output)); + res.add(PtNDArrayImpl.newPtNDArray(manager, output)); } return res; } @@ -1455,7 +1458,7 @@ public static NDList gru( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArray(manager, output)); + res.add(PtNDArrayImpl.newPtNDArray(manager, output)); } return res; } @@ -1488,7 +1491,7 @@ public static NDList lstm( batchFirst); NDList res = new NDList(); for (long output : outputs) { - res.add(new PtNDArray(manager, output)); + res.add(PtNDArrayImpl.newPtNDArray(manager, output)); } return res; } @@ -1500,7 +1503,7 @@ public static PtNDArray avgPool( Shape padding, boolean ceilMode, boolean countIncludePad) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAvgPool( ndArray.getHandle(), @@ -1513,7 +1516,7 @@ public static PtNDArray avgPool( public static PtNDArray maxPool( PtNDArray ndArray, Shape kernelSize, Shape stride, Shape padding, boolean ceilMode) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNMaxPool( ndArray.getHandle(), @@ -1524,14 +1527,14 @@ public static PtNDArray maxPool( } public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, Shape outputSize) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveMaxPool( ndArray.getHandle(), outputSize.getShape())); } public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, Shape outputSize) { - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNAdaptiveAvgPool( ndArray.getHandle(), outputSize.getShape())); @@ -1542,7 +1545,7 @@ public static PtNDArray lpPool( if (ndArray.getShape().dimension() - 2 == 3) { throw new UnsupportedOperationException("3D lpPool is not supported in PyTorch engine"); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNNLpPool( ndArray.getHandle(), @@ -1607,7 +1610,7 @@ public static void attachGradient(PtNDArray ndArray, boolean requiresGrad) { public static PtNDArray detachGradient(PtNDArray ndArray) { // TODO: detached ndarray may not use the same manager for the attached one - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchDetachGrad(ndArray.getHandle())); } @@ -1616,7 +1619,7 @@ public static PtNDArray getGradient(PtNDArray ndArray) { if (pointer == NULL_PTR) { return null; } - return new PtNDArray(ndArray.getManager(), pointer); + return PtNDArrayImpl.newPtNDArray(ndArray.getManager(), pointer); } public static void backward( @@ -1696,7 +1699,7 @@ public static NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager) { String[] names = PyTorchLibrary.LIB.moduleGetParamNames(block.getHandle()); NDList list = new NDList(handles.length); for (int i = 0; i < handles.length; i++) { - PtNDArray array = new PtNDArray(manager, handles[i]); + 
PtNDArray array = PtNDArrayImpl.newPtNDArray(manager, handles[i]); array.setName(names[i]); list.add(array); } @@ -1776,7 +1779,7 @@ public static int getLayout(PtNDArray array) { public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims) { long[] longAxes = Arrays.stream(axes).mapToLong(i -> i).toArray(); - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNorm(ndArray.getHandle(), ord, longAxes, keepDims)); } @@ -1785,7 +1788,7 @@ public static PtNDArray nonZeros(PtNDArray ndArray) { if (ndArray.isScalar()) { ndArray = (PtNDArray) ndArray.reshape(-1); } - return new PtNDArray( + return PtNDArrayImpl.newPtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchNonZeros(ndArray.getHandle())); } } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java new file mode 100644 index 00000000000..24544230795 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main.java @@ -0,0 +1,73 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration.gc; + +import static ai.djl.pytorch.engine.PtNDManager.debugDumpFromSystemManager; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.gc.SwitchGarbageCollection; +import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** Some poc testing code. 
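+ * It switches garbage-collection-based closing on, creates a few arrays in a sub-manager, drops
+ * all references to them, forces a GC, and logs the NDManager hierarchy together with the size
+ * of the proxy maker's weak map before and after collection.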
*/ +public final class Main { + + private static final Logger logger = LoggerFactory.getLogger(Main.class); + + private Main() {} + + public static void main(String[] args) + throws IOException, TranslateException, InterruptedException { + SwitchGarbageCollection.on(); + try (NDManager baseManager = NDManager.newBaseManager(); ) { + try (NDManager subManager = baseManager.newSubManager()) { + + NDArray a = subManager.create(new float[] {1f}); + NDArray b = subManager.create(new float[] {2f}); + PtNDArray c = (PtNDArray) a.add(b); + logger.info( + "number of NDArrays in NDManager hierarchy {}", + c.getNumOfNDArraysInNDManagerHierarchy()); + logger.info( + "number of NDArrays in map used of gc triggered NDArray closing {}", + c.getNumOfNDArraysInGCMap()); + + debugDumpFromSystemManager(true); + + logger.info("reference exists ..."); + logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + a = null; + b = null; + c = null; + logger.info("no reference exists, but likely not yet garbage collected ..."); + logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + + System.gc(); // just for testing - do not use in production + TimeUnit.SECONDS.sleep(1); + + logger.info("no reference exists, and likely garbage collected ..."); + logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + debugDumpFromSystemManager(true); + } + debugDumpFromSystemManager(true); + } + debugDumpFromSystemManager(true); + } +} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main2.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main2.java new file mode 100644 index 00000000000..43d6de02510 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/Main2.java @@ -0,0 +1,72 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.integration.gc; + +import static ai.djl.pytorch.engine.PtNDManager.debugDumpFromSystemManager; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.gc.SwitchGarbageCollection; +import ai.djl.pytorch.engine.PtNDArray; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** Some poc testing code. 
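+ * Unlike {@code Main}, it additionally calls {@code baseManager.gc()} after dropping the
+ * references, so that the proxy maker's reference queue is drained and collected arrays can be
+ * removed from the manager hierarchy shown in the debug dump.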
*/ +public final class Main2 { + + private static final Logger logger = LoggerFactory.getLogger(Main2.class); + + private Main2() {} + + public static void main(String[] args) + throws IOException, TranslateException, InterruptedException { + SwitchGarbageCollection.on(); + try (NDManager baseManager = NDManager.newBaseManager(); ) { + try (NDManager subManager = baseManager.newSubManager()) { + + NDArray a = subManager.create(new float[] {1f}); + NDArray b = subManager.create(new float[] {2f}); + PtNDArray c = (PtNDArray) a.add(b); + + debugDumpFromSystemManager(true); + + System.out.println("reference exists ..."); + baseManager.gc(); + debugDumpFromSystemManager(true); + // logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + a = null; + b = null; + c = null; + System.out.println("no reference exists, but likely not yet garbage collected ..."); + // logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + baseManager.gc(); + debugDumpFromSystemManager(true); + + System.gc(); // just for testing - do not use in production + TimeUnit.SECONDS.sleep(1); + + System.out.println("no reference exists, and likely garbage collected ..."); + // logger.info("weakHashMap size: {}", baseManager.getProxyMaker().mapSize()); + baseManager.gc(); + debugDumpFromSystemManager(true); + } + debugDumpFromSystemManager(true); + } + debugDumpFromSystemManager(true); + } +} diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/package-info.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/package-info.java new file mode 100644 index 00000000000..10391a021d0 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/gc/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** The integration test for testing PyTorch specific features. */ +package ai.djl.pytorch.integration.gc; diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 7265cf534d3..ad4dc53ea0b 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -23,7 +23,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; import ai.djl.tensorflow.engine.javacpp.JavacppUtils; -import ai.djl.util.NativeResource; +import ai.djl.util.NativeResourceImpl; import ai.djl.util.Preconditions; import org.tensorflow.internal.c_api.TFE_TensorHandle; @@ -40,7 +40,7 @@ /** {@code TfNDArray} is the TensorFlow implementation of {@link NDArray}. 
diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java
index 7265cf534d3..ad4dc53ea0b 100644
--- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java
+++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java
@@ -23,7 +23,7 @@
 import ai.djl.ndarray.types.Shape;
 import ai.djl.ndarray.types.SparseFormat;
 import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
-import ai.djl.util.NativeResource;
+import ai.djl.util.NativeResourceImpl;
 import ai.djl.util.Preconditions;
 
 import org.tensorflow.internal.c_api.TFE_TensorHandle;
@@ -40,7 +40,7 @@
 
 /** {@code TfNDArray} is the TensorFlow implementation of {@link NDArray}. */
 @SuppressWarnings("PMD.UseTryWithResources")
-public class TfNDArray extends NativeResource implements NDArray {
+public class TfNDArray extends NativeResourceImpl implements NDArray {
 
     private Shape shape;
     private Device device;
diff --git a/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java b/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java
index 89a2dadccb1..403ca2c5003 100644
--- a/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java
+++ b/examples/src/test/java/ai/djl/examples/training/TrainAirfoilWithTabNetTest.java
@@ -24,6 +24,7 @@ public class TrainAirfoilWithTabNetTest {
 
     @Test
     public void testTrainAirfoilWithTabNet() throws TranslateException, IOException {
+        TestRequirements.nightly();
         TestRequirements.engine("MXNet", "PyTorch");
         String[] args = new String[] {"-g", "1", "-e", "20", "-b", "32"};
         TrainingResult result = TrainAirfoilWithTabNet.runExample(args);
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java
index bad891e72da..32a02684f30 100644
--- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java
@@ -13,14 +13,14 @@
 package ai.djl.fasttext.jni;
 
 import ai.djl.modality.Classifications;
-import ai.djl.util.NativeResource;
+import ai.djl.util.NativeResourceImpl;
 
 import java.util.ArrayList;
 import java.util.List;
 
 /** A class containing utilities to interact with the fastText JNI layer. */
 @SuppressWarnings("MissingJavadocMethod")
-public final class FtWrapper extends NativeResource {
+public final class FtWrapper extends NativeResourceImpl {
 
     private static RuntimeException libraryStatus;
diff --git a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java
index 44d05936e58..4cecc80de9e 100644
--- a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java
+++ b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpProcessor.java
@@ -14,10 +14,10 @@
 
 import ai.djl.sentencepiece.jni.LibUtils;
 import ai.djl.sentencepiece.jni.SentencePieceLibrary;
-import ai.djl.util.NativeResource;
+import ai.djl.util.NativeResourceImpl;
 
 /** The processor holder for SentencePiece. */
-final class SpProcessor extends NativeResource {
+final class SpProcessor extends NativeResourceImpl {
 
     private static RuntimeException libraryStatus;
diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
index 8881292b820..6da9cbaca6d 100644
--- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
+++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java
@@ -18,7 +18,7 @@
 import ai.djl.modality.nlp.preprocess.Tokenizer;
 import ai.djl.ndarray.NDManager;
 import ai.djl.translate.ArgumentsUtil;
-import ai.djl.util.NativeResource;
+import ai.djl.util.NativeResourceImpl;
 import ai.djl.util.Utils;
 
 import org.slf4j.Logger;
@@ -37,7 +37,7 @@
  * {@code HuggingFaceTokenizer} is a Huggingface tokenizer implementation of the {@link Tokenizer}
  * interface that converts sentences into token.
  */
-public final class HuggingFaceTokenizer extends NativeResource implements Tokenizer {
+public final class HuggingFaceTokenizer extends NativeResourceImpl implements Tokenizer {
 
     private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class);
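All of the class-declaration changes above (TfNDArray, FtWrapper, SpProcessor, HuggingFaceTokenizer) are the same mechanical rename: code that previously extended the NativeResource base class now extends NativeResourceImpl, which implies the original class has been split into an interface plus a reusable implementation. The actual ai.djl.util types are not part of these hunks, so the sketch below only illustrates what such an interface/implementation split typically looks like; the names SimpleNativeResource and SimpleNativeResourceImpl and their members are assumptions, not DJL's API.

import java.util.concurrent.atomic.AtomicReference;

/** Illustrative contract for an object backed by a native handle (hypothetical, not ai.djl.util.NativeResource). */
interface SimpleNativeResource<T> extends AutoCloseable {

    /** Returns the native handle, failing if the resource was already released. */
    T getHandle();

    /** Returns a stable identifier for debugging and bookkeeping. */
    String getUid();

    /** Releases the native handle; no checked exception so callers can close freely. */
    @Override
    void close();
}

/** Illustrative reusable base class that handle-owning classes could extend instead of re-implementing the interface. */
abstract class SimpleNativeResourceImpl<T> implements SimpleNativeResource<T> {

    protected final AtomicReference<T> handle;
    private final String uid;

    protected SimpleNativeResourceImpl(T handle) {
        this.handle = new AtomicReference<>(handle);
        this.uid = String.valueOf(handle);
    }

    @Override
    public T getHandle() {
        T h = handle.get();
        if (h == null) {
            throw new IllegalStateException("Native resource has been released already.");
        }
        return h;
    }

    @Override
    public String getUid() {
        return uid;
    }
}

Separating the contract from the base class lets types that already have a different superclass, or that are interfaces themselves, still take part in the same resource-tracking machinery.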
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java
index d8469178dab..2ad4f5e3b50 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/GradientCollectorIntegrationTest.java
@@ -82,25 +82,90 @@ public void testAutograd() {
         }
     }
 
+    @Test
+    public void testZeroGradients() {
+        try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
+            NDArray a = manager.create(0.0f);
+            a.setRequiresGradient(true);
+
+            Engine engine = Engine.getEngine(TestUtils.getEngine());
+            try (GradientCollector gc = engine.newGradientCollector()) {
+                NDArray b = a.mul(2);
+
+                // Gradients are initially zero
+                Assert.assertEquals(a.getGradient().getFloat(), 0.0f);
+
+                // Gradients are updated by backwards
+                gc.backward(b);
+                Assert.assertEquals(a.getGradient().getFloat(), 2.0f);
+
+                // Gradients are cleared by zeroGradients
+                gc.zeroGradients();
+                Assert.assertEquals(a.getGradient().getFloat(), 0.0f);
+            }
+        }
+    }
+
     /** Tests that the gradients do not accumulate when closing the gradient collector. */
     @Test
     public void testClearGradients() {
         try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
-            NDArray variable = manager.create(0.0f);
-            variable.setRequiresGradient(true);
+            NDArray a = manager.create(0.0f);
+            a.setRequiresGradient(true);
 
-            Engine engine = manager.getEngine();
+            Engine engine = Engine.getEngine(TestUtils.getEngine());
             for (int i = 0; i < 3; i++) {
-                manager.zeroGradients();
                 try (GradientCollector gc = engine.newGradientCollector()) {
-                    NDArray loss = variable.mul(2);
-                    gc.backward(loss);
+                    NDArray b = a.mul(2);
+                    gc.backward(b);
+                }
+                Assert.assertEquals(a.getGradient().getFloat(), 2.0f);
+            }
+        }
+    }
+
+    /** Tests that the gradients do accumulate within the same gradient collector. */
+    @Test
+    public void testAccumulateGradients() {
+        // TODO: MXNet support for accumulating gradients does not currently work
+        TestRequirements.notEngine("MXNet");
+        try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
+            NDArray a = manager.create(0.0f);
+            a.setRequiresGradient(true);
+
+            Engine engine = Engine.getEngine(TestUtils.getEngine());
+            try (GradientCollector gc = engine.newGradientCollector()) {
+                for (int i = 1; i <= 3; i++) {
+                    NDArray b = a.mul(2);
+                    gc.backward(b);
+                    Assert.assertEquals(a.getGradient().getFloat(), 2.0f * i);
                 }
-                Assert.assertEquals(variable.getGradient().getFloat(), 2.0f);
             }
         }
     }
 
+    /**
+     * Ensures that a new gradient collector cannot be started while another one is already open,
+     * because gradient collectors are global.
+     */
+    @Test
+    @SuppressWarnings({"try", "PMD.UseTryWithResources"})
+    public void testMultipleGradientCollectors() {
+        Assert.assertThrows(
+                () -> {
+                    GradientCollector gc2 = null;
+                    Engine engine = Engine.getEngine(TestUtils.getEngine());
+                    try (GradientCollector gc = engine.newGradientCollector()) {
+                        gc2 = engine.newGradientCollector();
+                        gc2.close();
+                    } finally {
+                        if (gc2 != null) {
+                            gc2.close();
+                        }
+                    }
+                });
+    }
+
     @Test
     public void testFreezeParameters() {
         try (Model model = Model.newInstance("model", TestUtils.getEngine())) {
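The testZeroGradients and testClearGradients changes above show the new pattern: the manager.zeroGradients() call is removed from the loop, and gradients are instead cleared through GradientCollector.zeroGradients() while the collector is open. A minimal usage sketch follows; the single scalar "parameter" and the weight.mul(2) "loss" are placeholders for a real model and loss function, and the parameter-update step is omitted.

import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;

public final class ZeroGradientsSketch {

    private ZeroGradientsSketch() {}

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray weight = manager.create(1.0f);
            weight.setRequiresGradient(true);

            Engine engine = Engine.getInstance();
            // One collector reused across iterations; gradients are cleared explicitly.
            try (GradientCollector gc = engine.newGradientCollector()) {
                for (int iteration = 0; iteration < 3; ++iteration) {
                    NDArray loss = weight.mul(2); // stand-in for a real loss
                    gc.backward(loss);

                    // ... apply the gradient to the parameters here ...

                    // Clear the accumulated gradient before the next iteration.
                    gc.zeroGradients();
                }
            }
        }
    }
}

Reusing one collector and zeroing explicitly matches testAccumulateGradients above: within a single collector, repeated backward calls accumulate gradients until they are cleared.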
diff --git a/tools/conf/findbugs-exclude.xml b/tools/conf/findbugs-exclude.xml
index b36584a1714..b882914b622 100644
--- a/tools/conf/findbugs-exclude.xml
+++ b/tools/conf/findbugs-exclude.xml
@@ -31,6 +31,10 @@
- + + + + +