Skip to content

Commit

Permalink
[api] Adds OnesBlockFactory to make it easy for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 29, 2024
1 parent c68f8a7 commit f418481
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 16 deletions.
62 changes: 62 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -535,4 +537,64 @@ public boolean isRankOne() {
}
return max == ans;
}

/**
 * Parses a string representation of shapes for NDList.
 *
 * <p>Two forms are accepted:
 *
 * <ul>
 *   <li>One or more parenthesized shapes, each optionally followed by a one-letter data type
 *       suffix, e.g. {@code "(1,3,224,224)f,(1, 7)l"}. The suffixes are: {@code s} = FLOAT16,
 *       {@code d} = FLOAT64, {@code u} = UINT8, {@code b} = INT8, {@code i} = INT32, {@code l}
 *       = INT64, {@code B} = BOOLEAN, {@code f} or no suffix = FLOAT32.
 *   <li>A single bare list of dimensions, e.g. {@code "1,2"}, which is read as one FLOAT32
 *       shape.
 * </ul>
 *
 * <p>Dimensions may be separated by commas and/or whitespace. In the parenthesized form, text
 * that does not match the shape pattern is skipped.
 *
 * @param value a string representation of shapes for NDList, may be {@code null}
 * @return a list of Shape and datatype pairs, empty if {@code value} is {@code null}
 * @throws IllegalArgumentException if a data type suffix is not recognized
 * @throws NumberFormatException if a dimension cannot be parsed as a {@code long}
 */
public static PairList<DataType, Shape> parseShapes(String value) {
    PairList<DataType, Shape> inputShapes = new PairList<>();
    if (value == null) {
        return inputShapes;
    }

    // Same separator class the shape pattern below accepts between dimensions.
    String separator = "[,\\s]+";

    if (!value.contains("(")) {
        // Bare form: a single FLOAT32 shape.
        String[] tokens = value.trim().split(separator);
        long[] shapes = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
        inputShapes.add(DataType.FLOAT32, new Shape(shapes));
        return inputShapes;
    }

    Pattern pattern = Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)(\\w?)");
    Matcher matcher = pattern.matcher(value);
    while (matcher.find()) {
        // group(1) may carry leading/trailing blanks and blank-padded separators, all of
        // which the pattern allows; strip them before handing tokens to Long.parseLong.
        String[] tokens = matcher.group(1).trim().split(separator);
        long[] array = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
        DataType dataType;
        String dataTypeStr = matcher.group(4);
        if (dataTypeStr == null || dataTypeStr.isEmpty()) {
            dataType = DataType.FLOAT32;
        } else {
            switch (dataTypeStr) {
                case "s":
                    dataType = DataType.FLOAT16;
                    break;
                case "d":
                    dataType = DataType.FLOAT64;
                    break;
                case "u":
                    dataType = DataType.UINT8;
                    break;
                case "b":
                    dataType = DataType.INT8;
                    break;
                case "i":
                    dataType = DataType.INT32;
                    break;
                case "l":
                    dataType = DataType.INT64;
                    break;
                case "B":
                    dataType = DataType.BOOLEAN;
                    break;
                case "f":
                    dataType = DataType.FLOAT32;
                    break;
                default:
                    throw new IllegalArgumentException("Invalid input-shape: " + value);
            }
        }
        inputShapes.add(dataType, new Shape(array));
    }
    return inputShapes;
}
}
28 changes: 28 additions & 0 deletions api/src/main/java/ai/djl/nn/Blocks.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -87,6 +90,31 @@ public static Block identityBlock() {
return new LambdaBlock(x -> x, "identity");
}

/**
 * Creates a {@link LambdaBlock} that returns an all-ones {@link NDList}.
 *
 * <p>The returned block only uses its input to obtain an {@link NDManager}; one array filled
 * with ones is created per entry of {@code shapes}, in order.
 *
 * @param shapes the data type and shape of each output array
 * @param names the names of the output arrays; they are applied only when there is exactly one
 *     name per shape, otherwise the arrays are left unnamed
 * @return an all-ones {@link Block}
 */
public static Block onesBlock(PairList<DataType, Shape> shapes, String[] names) {
    return new LambdaBlock(
            a -> {
                NDManager manager = a.getManager();
                NDList list = new NDList(shapes.size());
                // Decide once, against the number of requested outputs. Comparing with
                // list.size() would test the partially built list and never name every array.
                boolean applyNames = names != null && names.length == shapes.size();
                int index = 0;
                for (Pair<DataType, Shape> pair : shapes) {
                    DataType dataType = pair.getKey();
                    Shape shape = pair.getValue();
                    NDArray arr = manager.ones(shape, dataType);
                    if (applyNames) {
                        arr.setName(names[index]);
                    }
                    ++index;
                    list.add(arr);
                }
                return list;
            },
            "ones");
}

/**
* Returns a string representation of the passed {@link Block} describing the input axes, output
* axes, and the block's children.
Expand Down
45 changes: 45 additions & 0 deletions api/src/main/java/ai/djl/nn/OnesBlockFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;

import ai.djl.Model;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.util.PairList;
import ai.djl.util.Utils;

import java.nio.file.Path;
import java.util.Map;

/**
 * A {@link BlockFactory} that builds its block through {@link Blocks#onesBlock(PairList,
 * String[])}, configured by the {@code block_shapes} and {@code block_names} arguments.
 */
public class OnesBlockFactory implements BlockFactory {

    private static final long serialVersionUID = 1L;

    /** {@inheritDoc} */
    @Override
    public Block newBlock(Model model, Path modelPath, Map<String, ?> arguments) {
        String shapeSpec = ArgumentsUtil.stringValue(arguments, "block_shapes");
        String nameSpec = ArgumentsUtil.stringValue(arguments, "block_names");

        PairList<DataType, Shape> outputs = Shape.parseShapes(shapeSpec);
        // A missing "block_names" argument means no output names.
        String[] outputNames = nameSpec == null ? Utils.EMPTY_ARRAY : nameSpec.split(",");

        return Blocks.onesBlock(outputs, outputNames);
    }
}
38 changes: 38 additions & 0 deletions api/src/test/java/ai/djl/nn/BlockFactoryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
package ai.djl.nn;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@SuppressWarnings("PMD.TestClassWithoutTestCases")
public class BlockFactoryTest {
Expand All @@ -32,4 +39,35 @@ public void testIdentityBlockFactory() {
Assert.assertEquals(((LambdaBlock) block).getName(), "identity");
}
}

/**
 * Checks that {@link OnesBlockFactory} yields a {@link LambdaBlock} both when loaded through
 * {@link Criteria} and when called directly, and that an unknown data type suffix is rejected.
 */
@Test
public void testOnesBlockFactory() throws ModelException, IOException {
    OnesBlockFactory factory = new OnesBlockFactory();
    Path path = Paths.get("build");

    // One shape per supported data type suffix, plus one without a suffix.
    Criteria<NDList, NDList> onesCriteria =
            Criteria.builder()
                    .setTypes(NDList.class, NDList.class)
                    .optModelPath(path)
                    .optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
                    .optArgument("block_shapes", "(1)s,(1)d,(1)u,(1)b,(1)i,(1)l,(1)B,(1)f,(1)")
                    .optArgument("block_names", "1,2,3,4,5,6,7,8,9")
                    .optOption("hasParameter", "false")
                    .build();

    try (ZooModel<NDList, NDList> model = onesCriteria.loadModel()) {
        Block loaded = model.getBlock();
        Assert.assertTrue(loaded instanceof LambdaBlock);

        // Bare dimension list without parentheses.
        Map<String, String> arguments = new ConcurrentHashMap<>();
        arguments.put("block_shapes", "1,2");
        Block direct = factory.newBlock(model, path, arguments);
        Assert.assertTrue(direct instanceof LambdaBlock);

        // "a" is not a recognized data type suffix.
        arguments.put("block_shapes", "(1)a");
        Assert.assertThrows(() -> factory.newBlock(model, path, arguments));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,18 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.LambdaBlock;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.testing.Assertions;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;
import ai.djl.util.Utils;

import org.testng.Assert;
import org.testng.annotations.Test;
Expand All @@ -50,15 +54,6 @@ public void testTextEmbeddingTranslator()
throws ModelException, IOException, TranslateException {
String text = "This is an example sentence";

Block block =
new LambdaBlock(
a -> {
NDManager manager = a.getManager();
NDArray arr = manager.ones(new Shape(1, 7, 384));
arr.setName("last_hidden_state");
return new NDList(arr);
},
"model");
Path modelDir = Paths.get("build/model");
Files.createDirectories(modelDir);
try (NDManager manager = NDManager.newBaseManager("Rust")) {
Expand All @@ -84,8 +79,10 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optEngine("PyTorch")
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optArgument("tokenizer", "bert-base-uncased")
.optOption("hasParameter", "false")
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
Expand All @@ -103,7 +100,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "max")
Expand All @@ -123,7 +122,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "mean_sqrt_len")
Expand All @@ -143,7 +144,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "weightedmean")
Expand All @@ -163,7 +166,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("dense", "linear.safetensors")
Expand All @@ -183,7 +188,9 @@ public void testTextEmbeddingTranslator()
Criteria.builder()
.setTypes(Input.class, Output.class)
.optModelPath(modelDir)
.optBlock(block)
.optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory")
.optArgument("block_shapes", "(1,7,384)")
.optArgument("block_names", "last_hidden_state")
.optEngine("PyTorch")
.optArgument("tokenizer", "bert-base-uncased")
.optArgument("pooling", "cls")
Expand Down Expand Up @@ -212,7 +219,9 @@ public void testTextEmbeddingTranslator()
}

try (Model model = Model.newInstance("test")) {
model.setBlock(block);
PairList<DataType, Shape> pairs = new PairList<>();
pairs.add(DataType.FLOAT32, new Shape(1, 7, 384));
model.setBlock(Blocks.onesBlock(pairs, Utils.EMPTY_ARRAY));
Map<String, String> options = new HashMap<>();
options.put("hasParameter", "false");
model.load(modelDir, "test", options);
Expand Down

0 comments on commit f418481

Please sign in to comment.