From be182f63e822f7b9c4bc3f0a9d79b770017ee3d5 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Wed, 19 Jul 2023 15:34:59 -0700 Subject: [PATCH] Creates Patch This creates the Patch concept along with some start of usages. There is a more specialized ParamPatch for the standard parameter additive patches and a Scaled, Basic, and LoRA implementation. The patches can be created directly, by comparing models, and from gradients. This is an initial step. Following this, there are a few pieces of work that could be considered: 1. DJL Serving Python engine specific patch implementation 2. LoRA for full training 3. Make BasicParamPatch from Optimizer (including gradients, momentum, and lr) Additionally, I included some changes to the IntegrationTest. I ran into the dumb issue where I made the tests private which makes them unable to run from IntegrationTest. Worse, the exceptions had no cause and therefore they wouldn't print any message or run to give feedback through println or logger. It still runs fine in IntelliJ too, making this issue only show up through gradle. After this change, it would provide a clear exception message which makes this easy to debug in the future. --- .../java/ai/djl/patch/BasicParamPatch.java | 135 +++++++++++++++++ api/src/main/java/ai/djl/patch/LoRA.java | 58 +++++++ .../main/java/ai/djl/patch/ParamPatch.java | 94 ++++++++++++ api/src/main/java/ai/djl/patch/Patch.java | 34 +++++ .../java/ai/djl/patch/ReversiblePatch.java | 26 ++++ .../java/ai/djl/patch/ScaledParamPatch.java | 55 +++++++ .../main/java/ai/djl/patch/package-info.java | 15 ++ .../ai/djl/integration/IntegrationTest.java | 21 ++- .../ai/djl/integration/tests/PatchTest.java | 143 ++++++++++++++++++ .../djl/integration/tests/package-info.java | 15 ++ 10 files changed, 590 insertions(+), 6 deletions(-) create mode 100644 api/src/main/java/ai/djl/patch/BasicParamPatch.java create mode 100644 api/src/main/java/ai/djl/patch/LoRA.java create mode 100644 api/src/main/java/ai/djl/patch/ParamPatch.java create mode 100644 api/src/main/java/ai/djl/patch/Patch.java create mode 100644 api/src/main/java/ai/djl/patch/ReversiblePatch.java create mode 100644 api/src/main/java/ai/djl/patch/ScaledParamPatch.java create mode 100644 api/src/main/java/ai/djl/patch/package-info.java create mode 100644 integration/src/main/java/ai/djl/integration/tests/PatchTest.java create mode 100644 integration/src/main/java/ai/djl/integration/tests/package-info.java diff --git a/api/src/main/java/ai/djl/patch/BasicParamPatch.java b/api/src/main/java/ai/djl/patch/BasicParamPatch.java new file mode 100644 index 00000000000..0a26157456f --- /dev/null +++ b/api/src/main/java/ai/djl/patch/BasicParamPatch.java @@ -0,0 +1,135 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.patch; + +import ai.djl.Model; +import ai.djl.ndarray.NDArray; +import ai.djl.nn.Block; +import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; +import ai.djl.training.GradientCollector; +import ai.djl.util.Pair; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** The basic implementation of a {@link ParamPatch}. */ +public class BasicParamPatch extends ParamPatch { + + Map data; + + /** + * Constructs a {@link BasicParamPatch} with patching data. + * + * @param data the patching data + */ + public BasicParamPatch(Map data) { + this.data = data; + } + + /** + * Makes a patch by comparing two models. + * + * @param source the source model + * @param target the target model + * @return a patch that would transform the source model to the target model + */ + public static BasicParamPatch makePatch(Model source, Model target) { + return BasicParamPatch.makePatch(source.getBlock(), target.getBlock()); + } + + /** + * Makes a patch by comparing two blocks. + * + * @param source the source block + * @param target the target block + * @return a patch that would transform the source block to the target block + */ + public static BasicParamPatch makePatch(Block source, Block target) { + return BasicParamPatch.makePatch(source.getParameters(), target.getParameters()); + } + + /** + * Makes a patch by comparing two {@link ParameterList}s. + * + * @param source the source {@link ParameterList} + * @param target the target {@link ParameterList} + * @return a patch that would transform the source {@link ParameterList} to the target {@link + * ParameterList}. + */ + public static BasicParamPatch makePatch(ParameterList source, ParameterList target) { + Map data = new ConcurrentHashMap<>(source.size()); + for (Pair sourcePair : source) { + String key = sourcePair.getKey(); + NDArray patchValue = target.get(key).getArray().sub(sourcePair.getValue().getArray()); + data.put(key, patchValue); + } + return new BasicParamPatch(data); + } + + /** + * Makes a patch from gradients. + * + *

<p>This does not include learning rates or any other data from the {@link + * ai.djl.training.optimizer.Optimizer}. + + *

<p>Making the patch does not modify the existing gradients. After this, you can call {@link + * GradientCollector#zeroGradients()} to clear the gradients. + + * @param block the block for which to collect gradients + * @param gradientCollector the {@link GradientCollector} of the gradients + * @return the gradients as a {@link BasicParamPatch}. + */ + public static BasicParamPatch makePatch(Block block, GradientCollector gradientCollector) { + ParameterList params = block.getParameters(); + Map<String, NDArray> data = new ConcurrentHashMap<>(params.size()); + for (Pair<String, Parameter> param : params) { + String key = param.getKey(); + // Get gradient * -1 to account for gradient being subtracted from param + NDArray patchValue = param.getValue().getArray().getGradient().duplicate().mul(-1); + data.put(key, patchValue); + } + return new BasicParamPatch(data); + } + + /** + * Makes a patch from gradients. + + *

<p>This does not include learning rates or any other data from the {@link + * ai.djl.training.optimizer.Optimizer}. + + *
<p>
Making the patch does not modify the existing gradients. After this, you can call {@link + * GradientCollector#zeroGradients()} to clear the gradients. + * + * @param model the model for which to collect gradients + * @param gradientCollector the {@link GradientCollector} of the gradients + * @return the gradients as a {@link BasicParamPatch}. + */ + public static BasicParamPatch makePatch(Model model, GradientCollector gradientCollector) { + return makePatch(model.getBlock(), gradientCollector); + } + + /** {@inheritDoc} */ + @Override + public NDArray getPatch(String paramName) { + return data.get(paramName).duplicate(); + } + + /** {@inheritDoc} */ + @Override + public void close() { + for (NDArray d : data.values()) { + d.close(); + } + } +} diff --git a/api/src/main/java/ai/djl/patch/LoRA.java b/api/src/main/java/ai/djl/patch/LoRA.java new file mode 100644 index 00000000000..6ceabafc8fe --- /dev/null +++ b/api/src/main/java/ai/djl/patch/LoRA.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.ndarray.NDArray; +import ai.djl.util.Pair; + +import java.util.Map; + +/** + * A {@link ParamPatch} based on low-rank adapters. + * + *

<p>Based on the paper <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of + * Large Language Models</a>. + + *

<p>TODO This support for LoRA is still a placeholder and needs effective code for creating and + * training. + */ +public class LoRA extends ParamPatch { + + /** Data of type map from param name to (A, B) pair. */ + Map<String, Pair<NDArray, NDArray>> data; + + /** + * Constructs a {@link LoRA}. + * + * @param data the data to patch with + */ + public LoRA(Map<String, Pair<NDArray, NDArray>> data) { + this.data = data; + } + + /** {@inheritDoc} */ + @Override + public NDArray getPatch(String paramName) { + Pair<NDArray, NDArray> d = data.get(paramName); + return d.getKey().get(paramName).matMul(d.getValue().get(paramName)); + } + + /** {@inheritDoc} */ + @Override + public void close() { + for (Pair<NDArray, NDArray> d : data.values()) { + d.getKey().close(); + d.getValue().close(); + } + } +} diff --git a/api/src/main/java/ai/djl/patch/ParamPatch.java b/api/src/main/java/ai/djl/patch/ParamPatch.java new file mode 100644 index 00000000000..9a2fef1c5e5 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/ParamPatch.java @@ -0,0 +1,94 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.Model; +import ai.djl.ndarray.NDArray; +import ai.djl.nn.Block; +import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; +import ai.djl.util.Pair; + +/** + * A standard {@link Patch} that only adds to {@link Parameter}s. + + *
<p>
To create a param patch, see {@link BasicParamPatch}. + */ +public abstract class ParamPatch extends ReversiblePatch { + + /** + * Scales the patch by a scalar multiplier. + * + * @param scale the scalar multiplier for each patch NDArray + * @return a new patch that is a scaled version of this patch + */ + public ParamPatch scale(float scale) { + return new ScaledParamPatch(scale, this); + } + + /** {@inheritDoc} */ + @Override + public ParamPatch reverse() { + return scale(-1); + } + + /** + * Returns the patch {@link NDArray} for a particular paramName. + * + * @param paramName the parameter path in a {@link ParameterList}. + * @return the patch array + */ + public abstract NDArray getPatch(String paramName); + + /** + * Applies the part of this patch to a particular {@link Parameter}. + * + * @param paramName the parameter path in a {@link ParameterList}. + * @param param the {@link Parameter} to patch + */ + public void apply(String paramName, Parameter param) { + NDArray p = getPatch(paramName).duplicate(); + param.getArray().addi(p); + p.close(); + } + + /** + * Applies this patch to a {@link ParameterList}. + * + * @param params the params to patch + */ + public void apply(ParameterList params) { + for (Pair param : params) { + apply(param.getKey(), param.getValue()); + } + } + + /** + * Applies this patch to a {@link Block}. + * + * @param block the block to patch + */ + public void apply(Block block) { + apply(block.getParameters()); + } + + /** + * Applies this patch to a {@link Model}. + * + * @param model the model to patch + */ + @Override + public void apply(Model model) { + apply(model.getBlock()); + } +} diff --git a/api/src/main/java/ai/djl/patch/Patch.java b/api/src/main/java/ai/djl/patch/Patch.java new file mode 100644 index 00000000000..4ec8f3abb0a --- /dev/null +++ b/api/src/main/java/ai/djl/patch/Patch.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.Model; + +/** + * A method for modifying a {@link Model}. + * + *
<p>
The most standard form is the {@link ParamPatch}. + */ +public abstract class Patch implements AutoCloseable { + + /** + * Applies this patch to a model. + * + * @param model the model to update with the patch + */ + public abstract void apply(Model model); + + /** {@inheritDoc} */ + @Override + public abstract void close(); +} diff --git a/api/src/main/java/ai/djl/patch/ReversiblePatch.java b/api/src/main/java/ai/djl/patch/ReversiblePatch.java new file mode 100644 index 00000000000..415d8b0b232 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/ReversiblePatch.java @@ -0,0 +1,26 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +/** A {@link Patch} that can be reversed. */ +public abstract class ReversiblePatch extends Patch { + + /** + * Returns a new {@link Patch} that reverses the effect of this one. + * + *
<p>
For a {@link ParamPatch}, it is equivalent to scaling by -1. + * + * @return a new {@link Patch} that reverses the effect of this one. + */ + public abstract ParamPatch reverse(); +} diff --git a/api/src/main/java/ai/djl/patch/ScaledParamPatch.java b/api/src/main/java/ai/djl/patch/ScaledParamPatch.java new file mode 100644 index 00000000000..fee95ab3338 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/ScaledParamPatch.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.ndarray.NDArray; + +/** + * Constructs a {@link ScaledParamPatch} to scale a {@link ParamPatch} by a scalar multiplier. + * + * @see ParamPatch#scale(float) + */ +public class ScaledParamPatch extends ParamPatch { + + float scale; + ParamPatch base; + + /** + * Constructs a {@link ScaledParamPatch}. + * + * @param scale the scalar multiplier + * @param base the {@link ParamPatch} to scale + */ + public ScaledParamPatch(float scale, ParamPatch base) { + if (base instanceof ScaledParamPatch) { + ScaledParamPatch sbase = (ScaledParamPatch) base; + this.scale = scale * sbase.scale; + this.base = sbase.base; + } else { + this.scale = scale; + this.base = base; + } + } + + /** {@inheritDoc} */ + @Override + public NDArray getPatch(String paramName) { + return base.getPatch(paramName).muli(scale); + } + + /** {@inheritDoc} */ + @Override + public void close() { + base.close(); + } +} diff --git a/api/src/main/java/ai/djl/patch/package-info.java b/api/src/main/java/ai/djl/patch/package-info.java new file mode 100644 index 00000000000..26ee925b46b --- /dev/null +++ b/api/src/main/java/ai/djl/patch/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to modify a model ({@link ai.djl.patch.Patch} and implementations). 
*/ +package ai.djl.patch; diff --git a/integration/src/main/java/ai/djl/integration/IntegrationTest.java b/integration/src/main/java/ai/djl/integration/IntegrationTest.java index 9f9bcf6e199..5f180b86647 100644 --- a/integration/src/main/java/ai/djl/integration/IntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/IntegrationTest.java @@ -311,7 +311,7 @@ public boolean beforeClass() { } return true; } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } return false; } @@ -322,7 +322,7 @@ public void afterClass() { method.invoke(object); } } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } } @@ -333,7 +333,7 @@ public boolean beforeTest() { } return true; } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } return false; } @@ -344,7 +344,7 @@ public void afterTest() { method.invoke(object); } } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } } @@ -370,10 +370,11 @@ public TestResult runTest(int index) { result = TestResult.SKIPPED; } else if (e.getCause() instanceof UnsupportedOperationException) { logger.info("Test {}.{} UNSUPPORTED", getName(), method.getName()); - logger.trace("", e.getCause()); + logExceptionCause(e); result = TestResult.UNSUPPORTED; } else { - logger.error("Test {}.{} FAILED", getName(), method.getName(), e.getCause()); + logger.error("Test {}.{} FAILED", getName(), method.getName()); + logExceptionCause(e); result = TestResult.FAILED; } } finally { @@ -403,6 +404,14 @@ private static boolean expectedException(Method method, Exception e) { } return false; } + + private void logExceptionCause(Exception e) { + if (e.getCause() != null) { + logger.error("", e.getCause()); + } else { + logger.error("", e); + } + } } public enum TestResult { diff --git a/integration/src/main/java/ai/djl/integration/tests/PatchTest.java b/integration/src/main/java/ai/djl/integration/tests/PatchTest.java new file mode 100644 index 00000000000..febacc26118 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/PatchTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.integration.tests; + +import ai.djl.Model; +import ai.djl.basicmodelzoo.basic.Mlp; +import ai.djl.integration.util.TestUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; +import ai.djl.patch.BasicParamPatch; +import ai.djl.patch.ParamPatch; +import ai.djl.training.DefaultTrainingConfig; +import ai.djl.training.GradientCollector; +import ai.djl.training.Trainer; +import ai.djl.training.initializer.Initializer; +import ai.djl.training.loss.Loss; +import ai.djl.training.optimizer.Optimizer; +import ai.djl.training.tracker.Tracker; +import ai.djl.util.Pair; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** Tests for {@link ai.djl.patch.Patch}. */ +public class PatchTest { + + @Test + public void testScaleReverse() { + try (Model model = Model.newInstance("model", TestUtils.getEngine())) { + testMlp(model); + double initialParamSum = paramSum(model); + + // Create patch + Map patchData = new ConcurrentHashMap<>(); + for (Pair param : model.getBlock().getParameters()) { + patchData.put(param.getKey(), param.getValue().getArray().onesLike()); + } + try (BasicParamPatch patch = new BasicParamPatch(patchData)) { + patch.scale(3).apply(model); + Assert.assertEquals(paramSum(model), initialParamSum + 3 * paramSize(model)); + + patch.reverse().apply(model); + Assert.assertEquals(paramSum(model), initialParamSum + 2 * paramSize(model)); + + patch.scale(-2).apply(model); + Assert.assertEquals(paramSum(model), initialParamSum); + } + } + } + + @Test + public void testComparison() { + try (Model model0 = Model.newInstance("m0", TestUtils.getEngine()); + Model model1 = Model.newInstance("m1", TestUtils.getEngine())) { + testMlp(model0, Initializer.ZEROS); + testMlp(model1, Initializer.ONES); + + ParamPatch patch = BasicParamPatch.makePatch(model0, model1); + patch.apply(model0); + Assert.assertEquals(paramSum(model1), paramSum(model0)); + } + } + + @Test + public void testGradients() { + try (Model model = Model.newInstance("model", TestUtils.getEngine())) { + testMlp(model, null); + try (Trainer trainer = + model.newTrainer( + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer( + Optimizer.sgd() + .setLearningRateTracker(Tracker.fixed(1.0f)) + .build()) + .optInitializer(Initializer.ONES, p -> true))) { + trainer.initialize(new Shape(10)); + double initialParamSum = paramSum(model); + ParamPatch patch; + try (GradientCollector collector = trainer.newGradientCollector()) { + NDList preds = + trainer.forward( + new NDList(model.getNDManager().ones(new Shape(10))), + new NDList(model.getNDManager().ones(new Shape(10)))); + NDArray loss = + trainer.getLoss() + .evaluate( + new NDList( + model.getNDManager().full(new Shape(1), 100)), + preds); + collector.backward(loss); + patch = BasicParamPatch.makePatch(trainer.getModel(), collector); + } + trainer.step(); + + Assert.assertNotEquals(paramSum(model), initialParamSum); + // Note that to reverse a gradient update, you must also account for learning rate + patch.reverse().apply(model); + Assert.assertEquals(paramSum(model), initialParamSum); + } + } + } + + private double paramSum(Model model) { + return model.getBlock().getParameters().values().stream() + .mapToDouble(p -> p.getArray().sum().toType(DataType.FLOAT32, true).getFloat()) + .sum(); + } + + private long paramSize(Model model) { + return 
model.getBlock().getParameters().values().stream() + .mapToLong(p -> p.getArray().getShape().size()) + .sum(); + } + + private void testMlp(Model model) { + testMlp(model, Initializer.ONES); + } + + private void testMlp(Model model, Initializer initializer) { + Mlp block = new Mlp(10, 1, new int[] {10}); + if (initializer != null) { + block.setInitializer(initializer, p -> true); + block.initialize(model.getNDManager(), DataType.FLOAT32, new Shape(10)); + } + model.setBlock(block); + } +} diff --git a/integration/src/main/java/ai/djl/integration/tests/package-info.java b/integration/src/main/java/ai/djl/integration/tests/package-info.java new file mode 100644 index 00000000000..61d65b0f695 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests using the engine for {@link ai.djl}. */ +package ai.djl.integration.tests;
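
For reference, the three ways of creating a patch mentioned above — building one directly, comparing two models, and capturing gradients — look roughly as follows. These sketches are condensed from the PatchTest added in this patch; the model, source, target, and trainer variables, the Mlp block, and the SGD setup are illustrative assumptions (with imports as in PatchTest), not part of the API. Building a patch directly from a map of parameter paths to NDArrays:

    // model is an initialized Model; keys are parameter paths from block.getParameters()
    Map<String, NDArray> data = new ConcurrentHashMap<>();
    for (Pair<String, Parameter> param : model.getBlock().getParameters()) {
        data.put(param.getKey(), param.getValue().getArray().onesLike());
    }
    try (BasicParamPatch patch = new BasicParamPatch(data)) {
        patch.scale(3).apply(model);   // adds 3 to every parameter element
        patch.reverse().apply(model);  // subtracts 1 from every parameter element
    }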
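
Creating a patch by comparing two models with the same block structure; the patch stores target minus source for every parameter:

    // source and target are initialized Models sharing an identical block structure
    try (BasicParamPatch patch = BasicParamPatch.makePatch(source, target)) {
        patch.apply(source);           // source parameters now equal target parameters
        patch.reverse().apply(source); // restores the original source parameters
        // patch.scale(0.5f).apply(source) would instead move source halfway toward target
    }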
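
Capturing the gradients of a single training step as a patch. With plain SGD at a fixed learning rate of 1 (as in testGradients), reversing the patch exactly undoes trainer.step(); with any other optimizer or learning rate, the patch only holds the negated raw gradients:

    try (Trainer trainer =
            model.newTrainer(
                    new DefaultTrainingConfig(Loss.l2Loss())
                            .optOptimizer(
                                    Optimizer.sgd()
                                            .setLearningRateTracker(Tracker.fixed(1.0f))
                                            .build()))) {
        trainer.initialize(new Shape(10));
        ParamPatch patch;
        try (GradientCollector collector = trainer.newGradientCollector()) {
            NDList preds =
                    trainer.forward(new NDList(model.getNDManager().ones(new Shape(10))));
            NDArray loss =
                    trainer.getLoss()
                            .evaluate(new NDList(model.getNDManager().ones(new Shape(1))), preds);
            collector.backward(loss);
            patch = BasicParamPatch.makePatch(trainer.getModel(), collector); // -1 * gradients
        }
        trainer.step();               // applies the SGD update to the parameters
        patch.reverse().apply(model); // undoes that update, since the learning rate is 1
        patch.close();
    }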