-
Notifications
You must be signed in to change notification settings - Fork 668
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This creates the Patch concept along with some start of usages. There is a more specialized ParamPatch for the standard parameter additive patches and a Scaled, Basic, and LoRA implementation. The patches can be created directly, by comparing models, and from gradients. This is an initial step. Following this, there are a few pieces of work that could be considered: 1. DJL Serving Python engine specific patch implementation 2. LoRA for full training 3. Make BasicParamPatch from Optimizer (including gradients, momentum, and lr) Additionally, I included some changes to the IntegrationTest. I ran into the dumb issue where I made the tests private which makes them unable to run from IntegrationTest. Worse, the exceptions had no cause and therefore they wouldn't print any message or run to give feedback through println or logger. It still runs fine in IntelliJ too, making this issue only show up through gradle. After this change, it would provide a clear exception message which makes this easy to debug in the future.
- Loading branch information
Showing
10 changed files
with
590 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.Model; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.nn.Block; | ||
import ai.djl.nn.Parameter; | ||
import ai.djl.nn.ParameterList; | ||
import ai.djl.training.GradientCollector; | ||
import ai.djl.util.Pair; | ||
|
||
import java.util.Map; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
|
||
/** The basic implementation of a {@link ParamPatch}. */ | ||
public class BasicParamPatch extends ParamPatch { | ||
|
||
Map<String, NDArray> data; | ||
|
||
/** | ||
* Constructs a {@link BasicParamPatch} with patching data. | ||
* | ||
* @param data the patching data | ||
*/ | ||
public BasicParamPatch(Map<String, NDArray> data) { | ||
this.data = data; | ||
} | ||
|
||
/** | ||
* Makes a patch by comparing two models. | ||
* | ||
* @param source the source model | ||
* @param target the target model | ||
* @return a patch that would transform the source model to the target model | ||
*/ | ||
public static BasicParamPatch makePatch(Model source, Model target) { | ||
return BasicParamPatch.makePatch(source.getBlock(), target.getBlock()); | ||
} | ||
|
||
/** | ||
* Makes a patch by comparing two blocks. | ||
* | ||
* @param source the source block | ||
* @param target the target block | ||
* @return a patch that would transform the source block to the target block | ||
*/ | ||
public static BasicParamPatch makePatch(Block source, Block target) { | ||
return BasicParamPatch.makePatch(source.getParameters(), target.getParameters()); | ||
} | ||
|
||
/** | ||
* Makes a patch by comparing two {@link ParameterList}s. | ||
* | ||
* @param source the source {@link ParameterList} | ||
* @param target the target {@link ParameterList} | ||
* @return a patch that would transform the source {@link ParameterList} to the target {@link | ||
* ParameterList}. | ||
*/ | ||
public static BasicParamPatch makePatch(ParameterList source, ParameterList target) { | ||
Map<String, NDArray> data = new ConcurrentHashMap<>(source.size()); | ||
for (Pair<String, Parameter> sourcePair : source) { | ||
String key = sourcePair.getKey(); | ||
NDArray patchValue = target.get(key).getArray().sub(sourcePair.getValue().getArray()); | ||
data.put(key, patchValue); | ||
} | ||
return new BasicParamPatch(data); | ||
} | ||
|
||
/** | ||
* Makes a patch from gradients. | ||
* | ||
* <p>This does not include learning rates or any other data from the {@link | ||
* ai.djl.training.optimizer.Optimizer}. | ||
* | ||
* <p>Making the patch does not modify the existing gradients. After this, you can call {@link | ||
* GradientCollector#zeroGradients()} to clear the gradients. | ||
* | ||
* @param block the block for which to collect gradients | ||
* @param gradientCollector the {@link GradientCollector} of the gradients | ||
* @return the gradients as a {@link BasicParamPatch}. | ||
*/ | ||
public static BasicParamPatch makePatch(Block block, GradientCollector gradientCollector) { | ||
ParameterList params = block.getParameters(); | ||
Map<String, NDArray> data = new ConcurrentHashMap<>(params.size()); | ||
for (Pair<String, Parameter> param : params) { | ||
String key = param.getKey(); | ||
// Get gradient * -1 to account for gradient being subtracted from param | ||
NDArray patchValue = param.getValue().getArray().getGradient().duplicate().mul(-1); | ||
data.put(key, patchValue); | ||
} | ||
return new BasicParamPatch(data); | ||
} | ||
|
||
/** | ||
* Makes a patch from gradients. | ||
* | ||
* <p>This does not include learning rates or any other data from the {@link | ||
* ai.djl.training.optimizer.Optimizer}. | ||
* | ||
* <p>Making the patch does not modify the existing gradients. After this, you can call {@link | ||
* GradientCollector#zeroGradients()} to clear the gradients. | ||
* | ||
* @param model the model for which to collect gradients | ||
* @param gradientCollector the {@link GradientCollector} of the gradients | ||
* @return the gradients as a {@link BasicParamPatch}. | ||
*/ | ||
public static BasicParamPatch makePatch(Model model, GradientCollector gradientCollector) { | ||
return makePatch(model.getBlock(), gradientCollector); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDArray getPatch(String paramName) { | ||
return data.get(paramName).duplicate(); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void close() { | ||
for (NDArray d : data.values()) { | ||
d.close(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.util.Pair; | ||
|
||
import java.util.Map; | ||
|
||
/** | ||
* A {@link ParamPatch} based on low-rank adapters. | ||
* | ||
* <p>Based on the paper <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of | ||
* Large Language Models</a>. | ||
* | ||
* <p>TODO This support for LoRA is still a placeholder and needs effective code for creating and | ||
* training | ||
*/ | ||
public class LoRA extends ParamPatch { | ||
|
||
/** Data of type map from param name to (A, B) pair. */ | ||
Map<String, Pair<NDArray, NDArray>> data; | ||
|
||
/** | ||
* Constructs a {@link LoRA}. | ||
* | ||
* @param data the data to patch with | ||
*/ | ||
public LoRA(Map<String, Pair<NDArray, NDArray>> data) { | ||
this.data = data; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public NDArray getPatch(String paramName) { | ||
Pair<NDArray, NDArray> d = data.get(paramName); | ||
return d.getKey().get(paramName).matMul(d.getValue().get(paramName)); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void close() { | ||
for (Pair<NDArray, NDArray> d : data.values()) { | ||
d.getKey().close(); | ||
d.getValue().close(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.Model; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.nn.Block; | ||
import ai.djl.nn.Parameter; | ||
import ai.djl.nn.ParameterList; | ||
import ai.djl.util.Pair; | ||
|
||
/** | ||
* A standard {@link Patch} that only adds to {@link Parameter}s. | ||
* | ||
* <p>To create a param patch, see {@link BasicParamPatch}. | ||
*/ | ||
public abstract class ParamPatch extends ReversiblePatch { | ||
|
||
/** | ||
* Scales the patch by a scalar multiplier. | ||
* | ||
* @param scale the scalar multiplier for each patch NDArray | ||
* @return a new patch that is a scaled version of this patch | ||
*/ | ||
public ParamPatch scale(float scale) { | ||
return new ScaledParamPatch(scale, this); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public ParamPatch reverse() { | ||
return scale(-1); | ||
} | ||
|
||
/** | ||
* Returns the patch {@link NDArray} for a particular paramName. | ||
* | ||
* @param paramName the parameter path in a {@link ParameterList}. | ||
* @return the patch array | ||
*/ | ||
public abstract NDArray getPatch(String paramName); | ||
|
||
/** | ||
* Applies the part of this patch to a particular {@link Parameter}. | ||
* | ||
* @param paramName the parameter path in a {@link ParameterList}. | ||
* @param param the {@link Parameter} to patch | ||
*/ | ||
public void apply(String paramName, Parameter param) { | ||
NDArray p = getPatch(paramName).duplicate(); | ||
param.getArray().addi(p); | ||
p.close(); | ||
} | ||
|
||
/** | ||
* Applies this patch to a {@link ParameterList}. | ||
* | ||
* @param params the params to patch | ||
*/ | ||
public void apply(ParameterList params) { | ||
for (Pair<String, Parameter> param : params) { | ||
apply(param.getKey(), param.getValue()); | ||
} | ||
} | ||
|
||
/** | ||
* Applies this patch to a {@link Block}. | ||
* | ||
* @param block the block to patch | ||
*/ | ||
public void apply(Block block) { | ||
apply(block.getParameters()); | ||
} | ||
|
||
/** | ||
* Applies this patch to a {@link Model}. | ||
* | ||
* @param model the model to patch | ||
*/ | ||
@Override | ||
public void apply(Model model) { | ||
apply(model.getBlock()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
import ai.djl.Model; | ||
|
||
/** | ||
* A method for modifying a {@link Model}. | ||
* | ||
* <p>The most standard form is the {@link ParamPatch}. | ||
*/ | ||
public abstract class Patch implements AutoCloseable { | ||
|
||
/** | ||
* Applies this patch to a model. | ||
* | ||
* @param model the model to update with the patch | ||
*/ | ||
public abstract void apply(Model model); | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public abstract void close(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.patch; | ||
|
||
/** A {@link Patch} that can be reversed. */ | ||
public abstract class ReversiblePatch extends Patch { | ||
|
||
/** | ||
* Returns a new {@link Patch} that reverses the effect of this one. | ||
* | ||
* <p>For a {@link ParamPatch}, it is equivalent to scaling by -1. | ||
* | ||
* @return a new {@link Patch} that reverses the effect of this one. | ||
*/ | ||
public abstract ParamPatch reverse(); | ||
} |
Oops, something went wrong.