Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Adds OnesBlockFactory to make it easy for testing #3140

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -535,4 +537,64 @@ public boolean isRankOne() {
}
return max == ans;
}

/**
* Parses a string representation of shapes for NDList.
*
* @param value a string representation of shapes for NDList
* @return a list of Shape and datatype pairs
*/
public static PairList<DataType, Shape> parseShapes(String value) {
PairList<DataType, Shape> inputShapes = new PairList<>();
if (value != null) {
if (value.contains("(")) {
Pattern pattern = Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)(\\w?)");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String[] tokens = matcher.group(1).split(",");
long[] array = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
DataType dataType;
String dataTypeStr = matcher.group(4);
if (dataTypeStr == null || dataTypeStr.isEmpty()) {
dataType = DataType.FLOAT32;
} else {
switch (dataTypeStr) {
case "s":
dataType = DataType.FLOAT16;
break;
case "d":
dataType = DataType.FLOAT64;
break;
case "u":
dataType = DataType.UINT8;
break;
case "b":
dataType = DataType.INT8;
break;
case "i":
dataType = DataType.INT32;
break;
case "l":
dataType = DataType.INT64;
break;
case "B":
dataType = DataType.BOOLEAN;
break;
case "f":
dataType = DataType.FLOAT32;
break;
default:
throw new IllegalArgumentException("Invalid input-shape: " + value);
}
}
inputShapes.add(dataType, new Shape(array));
}
} else {
String[] tokens = value.split(",");
long[] shapes = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
inputShapes.add(DataType.FLOAT32, new Shape(shapes));
}
}
return inputShapes;
}
}
28 changes: 28 additions & 0 deletions api/src/main/java/ai/djl/nn/Blocks.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -87,6 +90,31 @@ public static Block identityBlock() {
return new LambdaBlock(x -> x, "identity");
}

/**
* Creates a {@link LambdaBlock} that return all-ones NDList.
*
* @return an all-ones {@link Block}
*/
public static Block onesBlock(PairList<DataType, Shape> shapes, String[] names) {
return new LambdaBlock(
a -> {
NDManager manager = a.getManager();
NDList list = new NDList(shapes.size());
int index = 0;
for (Pair<DataType, Shape> pair : shapes) {
DataType dataType = pair.getKey();
Shape shape = pair.getValue();
NDArray arr = manager.ones(shape, dataType);
if (names.length == list.size()) {
arr.setName(names[index++]);
}
list.add(arr);
}
return list;
},
"ones");
}

/**
* Returns a string representation of the passed {@link Block} describing the input axes, output
* axes, and the block's children.
Expand Down
45 changes: 45 additions & 0 deletions api/src/main/java/ai/djl/nn/OnesBlockFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;

import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.util.PairList;
import ai.djl.util.Utils;

import java.nio.file.Path;
import java.util.Map;

/** A {@link BlockFactory} class that creates LambdaBlock. */
public class OnesBlockFactory implements BlockFactory {

private static final long serialVersionUID = 1L;

/** {@inheritDoc} */
@Override
public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
String shapes = ArgumentsUtil.stringValue(arguments, "block_shapes");
String blockNames = ArgumentsUtil.stringValue(arguments, "block_names");
PairList<DataType, Shape> pairs = Shape.parseShapes(shapes);
String[] names;
if (blockNames != null) {
names = blockNames.split(",");
} else {
names = Utils.EMPTY_ARRAY;
}

return Blocks.onesBlock(pairs, names);
}
}
38 changes: 38 additions & 0 deletions api/src/test/java/ai/djl/nn/BlockFactoryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
package ai.djl.nn;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@SuppressWarnings("PMD.TestClassWithoutTestCases")
public class BlockFactoryTest {
Expand All @@ -32,4 +39,35 @@ public void testIdentityBlockFactory() {
Assert.assertEquals(((LambdaBlock) block).getName(), "identity");
}
}

@Test
public void testOnesBlockFactory() throws ModelException, IOException {
OnesBlockFactory factory = new OnesBlockFactory();
Path path = Paths.get("build");
Criteria<NDList, NDList> criteria =
Criteria.builder()
.setTypes(NDList.class, NDList.class)
.optModelPath(path)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1)s,(1)d,(1)u,(1)b,(1)i,(1)l,(1)B,(1)f,(1)")
.optArgument("block_names", "1,2,3,4,5,6,7,8,9")
.optOption("hasParameter", "false")
.build();

try (ZooModel<NDList, NDList> model = criteria.loadModel()) {
Block block = model.getBlock();
Assert.assertTrue(block instanceof LambdaBlock);

Map<String, String> args = new ConcurrentHashMap<>();
args.put("block_shapes", "1,2");
block = factory.newBlock(model, path, args);
Assert.assertTrue(block instanceof LambdaBlock);

args.put("block_shapes", "(1)a");
Assert.assertThrows(
() -> {
factory.newBlock(model, path, args);
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.LambdaBlock;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.Assertions;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;
import ai.djl.util.Utils;

import org.testng.Assert;
import org.testng.annotations.Test;
Expand All @@ -50,15 +54,6 @@ public void testTextEmbeddingTranslator()
throws ModelException, IOException, TranslateException {
String text = "This is an example sentence";

Block block =
new LambdaBlock(
a -> {
NDManager manager = a.getManager();
NDArray arr = manager.ones(new Shape(1, 7, 384));
arr.setName("last_hidden_state");
return new NDList(arr);
},
"model");
Path modelDir = Paths.get("build/model");
Files.createDirectories(modelDir);
try (NDManager manager = NDManager.newBaseManager("Rust")) {
Expand All @@ -84,8 +79,10 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optEngine("PyTorch")
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optArgument("tokenizer", "bert-base-uncased")
.optOption("hasParameter", "false")
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
Expand All @@ -103,7 +100,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "max")
Expand All @@ -123,7 +122,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "mean_sqrt_len")
Expand All @@ -143,7 +144,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "weightedmean")
Expand All @@ -163,7 +166,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("dense", "linear.safetensors")
Expand All @@ -183,7 +188,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(Input.class, Output.class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "cls")
Expand Down Expand Up @@ -212,7 +219,9 @@ public void testTextEmbeddingTranslator()
}

try (Model model = Model.newInstance("test")) {
model.setBlock(block);
PairList<DataType, Shape> pairs = new PairList<>();
pairs.add(DataType.FLOAT32, new Shape(1, 7, 384));
model.setBlock(Blocks.onesBlock(pairs, Utils.EMPTY_ARRAY));
Map<String, String> options = new HashMap<>();
options.put("hasParameter", "false");
model.load(modelDir, "test", options);
Expand Down
Loading