diff --git a/api/src/main/java/ai/djl/patch/BasicParamPatch.java b/api/src/main/java/ai/djl/patch/BasicParamPatch.java new file mode 100644 index 000000000000..0a26157456f6 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/BasicParamPatch.java @@ -0,0 +1,135 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.Model; +import ai.djl.ndarray.NDArray; +import ai.djl.nn.Block; +import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; +import ai.djl.training.GradientCollector; +import ai.djl.util.Pair; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** The basic implementation of a {@link ParamPatch}. */ +public class BasicParamPatch extends ParamPatch { + + Map data; + + /** + * Constructs a {@link BasicParamPatch} with patching data. + * + * @param data the patching data + */ + public BasicParamPatch(Map data) { + this.data = data; + } + + /** + * Makes a patch by comparing two models. + * + * @param source the source model + * @param target the target model + * @return a patch that would transform the source model to the target model + */ + public static BasicParamPatch makePatch(Model source, Model target) { + return BasicParamPatch.makePatch(source.getBlock(), target.getBlock()); + } + + /** + * Makes a patch by comparing two blocks. + * + * @param source the source block + * @param target the target block + * @return a patch that would transform the source block to the target block + */ + public static BasicParamPatch makePatch(Block source, Block target) { + return BasicParamPatch.makePatch(source.getParameters(), target.getParameters()); + } + + /** + * Makes a patch by comparing two {@link ParameterList}s. + * + * @param source the source {@link ParameterList} + * @param target the target {@link ParameterList} + * @return a patch that would transform the source {@link ParameterList} to the target {@link + * ParameterList}. + */ + public static BasicParamPatch makePatch(ParameterList source, ParameterList target) { + Map data = new ConcurrentHashMap<>(source.size()); + for (Pair sourcePair : source) { + String key = sourcePair.getKey(); + NDArray patchValue = target.get(key).getArray().sub(sourcePair.getValue().getArray()); + data.put(key, patchValue); + } + return new BasicParamPatch(data); + } + + /** + * Makes a patch from gradients. + * + *

<p>This does not include learning rates or any other data from the {@link + * ai.djl.training.optimizer.Optimizer}. + * + *

<p>Making the patch does not modify the existing gradients. After this, you can call {@link + * GradientCollector#zeroGradients()} to clear the gradients. + * + * @param block the block for which to collect gradients + * @param gradientCollector the {@link GradientCollector} of the gradients + * @return the gradients as a {@link BasicParamPatch}. + */ + public static BasicParamPatch makePatch(Block block, GradientCollector gradientCollector) { + ParameterList params = block.getParameters(); + Map<String, NDArray> data = new ConcurrentHashMap<>(params.size()); + for (Pair<String, Parameter> param : params) { + String key = param.getKey(); + // Get gradient * -1 to account for gradient being subtracted from param + NDArray patchValue = param.getValue().getArray().getGradient().duplicate().mul(-1); + data.put(key, patchValue); + } + return new BasicParamPatch(data); + } + + /** + * Makes a patch from gradients. + * + *

<p>This does not include learning rates or any other data from the {@link + * ai.djl.training.optimizer.Optimizer}. + * + *
<p>
Making the patch does not modify the existing gradients. After this, you can call {@link + * GradientCollector#zeroGradients()} to clear the gradients. + * + * @param model the model for which to collect gradients + * @param gradientCollector the {@link GradientCollector} of the gradients + * @return the gradients as a {@link BasicParamPatch}. + */ + public static BasicParamPatch makePatch(Model model, GradientCollector gradientCollector) { + return makePatch(model.getBlock(), gradientCollector); + } + + /** {@inheritDoc} */ + @Override + public NDArray getPatch(String paramName) { + return data.get(paramName).duplicate(); + } + + /** {@inheritDoc} */ + @Override + public void close() { + for (NDArray d : data.values()) { + d.close(); + } + } +} diff --git a/api/src/main/java/ai/djl/patch/LoRA.java b/api/src/main/java/ai/djl/patch/LoRA.java new file mode 100644 index 000000000000..6ceabafc8fe9 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/LoRA.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.ndarray.NDArray; +import ai.djl.util.Pair; + +import java.util.Map; + +/** + * A {@link ParamPatch} based on low-rank adapters. + * + *

<p>Based on the paper <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of + * Large Language Models</a>. + * + *

<p>TODO This support for LoRA is still a placeholder and needs effective code for creating and + * training. + */ +public class LoRA extends ParamPatch { + + /** Data of type map from param name to an (A, B) pair of low-rank factors. */ + Map<String, Pair<NDArray, NDArray>> data; + + /** + * Constructs a {@link LoRA}. + * + * @param data the data to patch with + */ + public LoRA(Map<String, Pair<NDArray, NDArray>> data) { + this.data = data; + } + + /** {@inheritDoc} */ + @Override + public NDArray getPatch(String paramName) { + Pair<NDArray, NDArray> d = data.get(paramName); + return d.getKey().matMul(d.getValue()); + } + + /** {@inheritDoc} */ + @Override + public void close() { + for (Pair<NDArray, NDArray> d : data.values()) { + d.getKey().close(); + d.getValue().close(); + } + } +} diff --git a/api/src/main/java/ai/djl/patch/ParamPatch.java b/api/src/main/java/ai/djl/patch/ParamPatch.java new file mode 100644 index 000000000000..9a2fef1c5e5b --- /dev/null +++ b/api/src/main/java/ai/djl/patch/ParamPatch.java @@ -0,0 +1,94 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.Model; +import ai.djl.ndarray.NDArray; +import ai.djl.nn.Block; +import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; +import ai.djl.util.Pair; + +/** + * A standard {@link Patch} that only adds to {@link Parameter}s. + * + *
<p>
To create a param patch, see {@link BasicParamPatch}. + */ +public abstract class ParamPatch extends ReversiblePatch { + + /** + * Scales the patch by a scalar multiplier. + * + * @param scale the scalar multiplier for each patch NDArray + * @return a new patch that is a scaled version of this patch + */ + public ParamPatch scale(float scale) { + return new ScaledParamPatch(scale, this); + } + + /** {@inheritDoc} */ + @Override + public ParamPatch reverse() { + return scale(-1); + } + + /** + * Returns the patch {@link NDArray} for a particular paramName. + * + * @param paramName the parameter path in a {@link ParameterList}. + * @return the patch array + */ + public abstract NDArray getPatch(String paramName); + + /** + * Applies the part of this patch to a particular {@link Parameter}. + * + * @param paramName the parameter path in a {@link ParameterList}. + * @param param the {@link Parameter} to patch + */ + public void apply(String paramName, Parameter param) { + NDArray p = getPatch(paramName).duplicate(); + param.getArray().addi(p); + p.close(); + } + + /** + * Applies this patch to a {@link ParameterList}. + * + * @param params the params to patch + */ + public void apply(ParameterList params) { + for (Pair param : params) { + apply(param.getKey(), param.getValue()); + } + } + + /** + * Applies this patch to a {@link Block}. + * + * @param block the block to patch + */ + public void apply(Block block) { + apply(block.getParameters()); + } + + /** + * Applies this patch to a {@link Model}. + * + * @param model the model to patch + */ + @Override + public void apply(Model model) { + apply(model.getBlock()); + } +} diff --git a/api/src/main/java/ai/djl/patch/Patch.java b/api/src/main/java/ai/djl/patch/Patch.java new file mode 100644 index 000000000000..4ec8f3abb0a2 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/Patch.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.Model; + +/** + * A method for modifying a {@link Model}. + * + *
<p>
The most standard form is the {@link ParamPatch}. + */ +public abstract class Patch implements AutoCloseable { + + /** + * Applies this patch to a model. + * + * @param model the model to update with the patch + */ + public abstract void apply(Model model); + + /** {@inheritDoc} */ + @Override + public abstract void close(); +} diff --git a/api/src/main/java/ai/djl/patch/ReversiblePatch.java b/api/src/main/java/ai/djl/patch/ReversiblePatch.java new file mode 100644 index 000000000000..415d8b0b2320 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/ReversiblePatch.java @@ -0,0 +1,26 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +/** A {@link Patch} that can be reversed. */ +public abstract class ReversiblePatch extends Patch { + + /** + * Returns a new {@link Patch} that reverses the effect of this one. + * + *
<p>
For a {@link ParamPatch}, it is equivalent to scaling by -1. + * + * @return a new {@link Patch} that reverses the effect of this one. + */ + public abstract ParamPatch reverse(); +} diff --git a/api/src/main/java/ai/djl/patch/ScaledParamPatch.java b/api/src/main/java/ai/djl/patch/ScaledParamPatch.java new file mode 100644 index 000000000000..fee95ab33386 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/ScaledParamPatch.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.patch; + +import ai.djl.ndarray.NDArray; + +/** + * Constructs a {@link ScaledParamPatch} to scale a {@link ParamPatch} by a scalar multiplier. + * + * @see ParamPatch#scale(float) + */ +public class ScaledParamPatch extends ParamPatch { + + float scale; + ParamPatch base; + + /** + * Constructs a {@link ScaledParamPatch}. + * + * @param scale the scalar multiplier + * @param base the {@link ParamPatch} to scale + */ + public ScaledParamPatch(float scale, ParamPatch base) { + if (base instanceof ScaledParamPatch) { + ScaledParamPatch sbase = (ScaledParamPatch) base; + this.scale = scale * sbase.scale; + this.base = sbase.base; + } else { + this.scale = scale; + this.base = base; + } + } + + /** {@inheritDoc} */ + @Override + public NDArray getPatch(String paramName) { + return base.getPatch(paramName).muli(scale); + } + + /** {@inheritDoc} */ + @Override + public void close() { + base.close(); + } +} diff --git a/api/src/main/java/ai/djl/patch/package-info.java b/api/src/main/java/ai/djl/patch/package-info.java new file mode 100644 index 000000000000..26ee925b46b0 --- /dev/null +++ b/api/src/main/java/ai/djl/patch/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to modify a model ({@link ai.djl.patch.Patch} and implementations). 
*/ +package ai.djl.patch; diff --git a/integration/src/main/java/ai/djl/integration/IntegrationTest.java b/integration/src/main/java/ai/djl/integration/IntegrationTest.java index a55cec9a22c0..cf8387828767 100644 --- a/integration/src/main/java/ai/djl/integration/IntegrationTest.java +++ b/integration/src/main/java/ai/djl/integration/IntegrationTest.java @@ -305,7 +305,7 @@ public boolean beforeClass() { } return true; } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } return false; } @@ -316,7 +316,7 @@ public void afterClass() { method.invoke(object); } } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } } @@ -327,7 +327,7 @@ public boolean beforeTest() { } return true; } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } return false; } @@ -338,7 +338,7 @@ public void afterTest() { method.invoke(object); } } catch (InvocationTargetException | IllegalAccessException e) { - logger.error("", e.getCause()); + logExceptionCause(e); } } @@ -368,7 +368,7 @@ public TestResult runTest(int index) { result = TestResult.UNSUPPORTED; } else { logger.error("Test {}.{} FAILED", getName(), method.getName()); - logger.error("", e.getCause()); + logExceptionCause(e); result = TestResult.FAILED; } } finally { @@ -398,6 +398,14 @@ private static boolean expectedException(Method method, Exception e) { } return false; } + + private void logExceptionCause(Exception e) { + if (e.getCause() != null) { + logger.error("", e.getCause()); + } else { + logger.error("", e); + } + } } public enum TestResult { diff --git a/integration/src/main/java/ai/djl/integration/tests/PatchTest.java b/integration/src/main/java/ai/djl/integration/tests/PatchTest.java new file mode 100644 index 000000000000..febacc261186 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/PatchTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.integration.tests; + +import ai.djl.Model; +import ai.djl.basicmodelzoo.basic.Mlp; +import ai.djl.integration.util.TestUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Parameter; +import ai.djl.patch.BasicParamPatch; +import ai.djl.patch.ParamPatch; +import ai.djl.training.DefaultTrainingConfig; +import ai.djl.training.GradientCollector; +import ai.djl.training.Trainer; +import ai.djl.training.initializer.Initializer; +import ai.djl.training.loss.Loss; +import ai.djl.training.optimizer.Optimizer; +import ai.djl.training.tracker.Tracker; +import ai.djl.util.Pair; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** Tests for {@link ai.djl.patch.Patch}. 
*/ +public class PatchTest { + + @Test + public void testScaleReverse() { + try (Model model = Model.newInstance("model", TestUtils.getEngine())) { + testMlp(model); + double initialParamSum = paramSum(model); + + // Create patch + Map patchData = new ConcurrentHashMap<>(); + for (Pair param : model.getBlock().getParameters()) { + patchData.put(param.getKey(), param.getValue().getArray().onesLike()); + } + try (BasicParamPatch patch = new BasicParamPatch(patchData)) { + patch.scale(3).apply(model); + Assert.assertEquals(paramSum(model), initialParamSum + 3 * paramSize(model)); + + patch.reverse().apply(model); + Assert.assertEquals(paramSum(model), initialParamSum + 2 * paramSize(model)); + + patch.scale(-2).apply(model); + Assert.assertEquals(paramSum(model), initialParamSum); + } + } + } + + @Test + public void testComparison() { + try (Model model0 = Model.newInstance("m0", TestUtils.getEngine()); + Model model1 = Model.newInstance("m1", TestUtils.getEngine())) { + testMlp(model0, Initializer.ZEROS); + testMlp(model1, Initializer.ONES); + + ParamPatch patch = BasicParamPatch.makePatch(model0, model1); + patch.apply(model0); + Assert.assertEquals(paramSum(model1), paramSum(model0)); + } + } + + @Test + public void testGradients() { + try (Model model = Model.newInstance("model", TestUtils.getEngine())) { + testMlp(model, null); + try (Trainer trainer = + model.newTrainer( + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer( + Optimizer.sgd() + .setLearningRateTracker(Tracker.fixed(1.0f)) + .build()) + .optInitializer(Initializer.ONES, p -> true))) { + trainer.initialize(new Shape(10)); + double initialParamSum = paramSum(model); + ParamPatch patch; + try (GradientCollector collector = trainer.newGradientCollector()) { + NDList preds = + trainer.forward( + new NDList(model.getNDManager().ones(new Shape(10))), + new NDList(model.getNDManager().ones(new Shape(10)))); + NDArray loss = + trainer.getLoss() + .evaluate( + new NDList( + model.getNDManager().full(new Shape(1), 100)), + preds); + collector.backward(loss); + patch = BasicParamPatch.makePatch(trainer.getModel(), collector); + } + trainer.step(); + + Assert.assertNotEquals(paramSum(model), initialParamSum); + // Note that to reverse a gradient update, you must also account for learning rate + patch.reverse().apply(model); + Assert.assertEquals(paramSum(model), initialParamSum); + } + } + } + + private double paramSum(Model model) { + return model.getBlock().getParameters().values().stream() + .mapToDouble(p -> p.getArray().sum().toType(DataType.FLOAT32, true).getFloat()) + .sum(); + } + + private long paramSize(Model model) { + return model.getBlock().getParameters().values().stream() + .mapToLong(p -> p.getArray().getShape().size()) + .sum(); + } + + private void testMlp(Model model) { + testMlp(model, Initializer.ONES); + } + + private void testMlp(Model model, Initializer initializer) { + Mlp block = new Mlp(10, 1, new int[] {10}); + if (initializer != null) { + block.setInitializer(initializer, p -> true); + block.initialize(model.getNDManager(), DataType.FLOAT32, new Shape(10)); + } + model.setBlock(block); + } +} diff --git a/integration/src/main/java/ai/djl/integration/tests/package-info.java b/integration/src/main/java/ai/djl/integration/tests/package-info.java new file mode 100644 index 000000000000..61d65b0f6953 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests using the engine for {@link ai.djl}. */ +package ai.djl.integration.tests;
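
Example usage (a minimal sketch, not part of the patch set above): the snippet below exercises the new `ai.djl.patch` API by diffing two models and nudging one toward the other, mirroring `PatchTest#testComparison` and `#testScaleReverse`. The class name, the `Mlp` layer sizes, and reliance on the default engine are illustrative assumptions, not requirements of the API.

```java
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.patch.BasicParamPatch;
import ai.djl.training.initializer.Initializer;

public final class ParamPatchExample {

    public static void main(String[] args) {
        try (Model source = newMlp("source", Initializer.ZEROS);
                Model target = newMlp("target", Initializer.ONES)) {
            // The patch stores (target - source) for every parameter
            try (BasicParamPatch patch = BasicParamPatch.makePatch(source, target)) {
                // Move source halfway towards target
                patch.scale(0.5f).apply(source);
                // Apply the remaining half; source parameters now match target
                patch.scale(0.5f).apply(source);
                // reverse() is equivalent to scale(-1), so this undoes the full patch
                patch.reverse().apply(source);
            }
        }
    }

    // Small MLP used purely for illustration
    private static Model newMlp(String name, Initializer initializer) {
        Model model = Model.newInstance(name);
        Mlp block = new Mlp(10, 1, new int[] {10});
        block.setInitializer(initializer, p -> true);
        block.initialize(model.getNDManager(), DataType.FLOAT32, new Shape(10));
        model.setBlock(block);
        return model;
    }
}
```

Because `scale` composes multiplicatively (a `ScaledParamPatch` of a `ScaledParamPatch` collapses into one) and `reverse()` is just `scale(-1)`, patches can be blended or undone without recomputing the diff.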
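A second sketch, following `PatchTest#testGradients`, captures one optimizer step as a patch and then undoes it. It assumes plain SGD with `Tracker.fixed(1.0f)`; with any other learning rate or optimizer the captured patch holds only the raw negated gradients, so reversing it would not exactly restore the parameters.

```java
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.patch.BasicParamPatch;
import ai.djl.patch.ParamPatch;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;

public final class GradientPatchExample {

    public static void main(String[] args) {
        try (Model model = Model.newInstance("model")) {
            model.setBlock(new Mlp(10, 1, new int[] {10}));
            DefaultTrainingConfig config =
                    new DefaultTrainingConfig(Loss.l2Loss())
                            .optOptimizer(
                                    Optimizer.sgd()
                                            .setLearningRateTracker(Tracker.fixed(1.0f))
                                            .build())
                            .optInitializer(Initializer.ONES, p -> true);
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(10));

                ParamPatch step;
                try (GradientCollector collector = trainer.newGradientCollector()) {
                    NDList data = new NDList(model.getNDManager().ones(new Shape(10)));
                    NDList labels = new NDList(model.getNDManager().ones(new Shape(1)));
                    NDList preds = trainer.forward(data, labels);
                    NDArray loss = trainer.getLoss().evaluate(labels, preds);
                    collector.backward(loss);
                    // Capture -1 * gradients before the optimizer consumes them
                    step = BasicParamPatch.makePatch(model, collector);
                }
                trainer.step();

                // With Tracker.fixed(1.0f), reversing the patch exactly undoes the SGD update
                step.reverse().apply(model);
                step.close();
            }
        }
    }
}
```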