From f3a1c378168fa1f05fe6e7a47583c6d79b4c5a0b Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 30 Apr 2024 09:05:40 -0700 Subject: [PATCH] [rust] Support load DJL model for RsModel Allows -1 in shape string presentation --- api/src/main/java/ai/djl/BaseModel.java | 22 +++++++++++++++++++ .../main/java/ai/djl/ndarray/types/Shape.java | 3 ++- api/src/main/java/ai/djl/nn/Blocks.java | 11 ++++++++-- .../java/ai/djl/pytorch/engine/PtModel.java | 19 +--------------- .../main/java/ai/djl/engine/rust/RsModel.java | 10 ++++++--- 5 files changed, 41 insertions(+), 24 deletions(-) diff --git a/api/src/main/java/ai/djl/BaseModel.java b/api/src/main/java/ai/djl/BaseModel.java index db2f0d3dd708..5705d480b38a 100644 --- a/api/src/main/java/ai/djl/BaseModel.java +++ b/api/src/main/java/ai/djl/BaseModel.java @@ -234,6 +234,28 @@ protected void setModelDir(Path modelDir) { this.modelDir = Utils.getNestedModelDir(modelDir); } + protected void loadBlock(String prefix, Map options) + throws IOException, MalformedModelException { + boolean hasParameter = true; + if (options != null) { + String paramOption = (String) options.get("hasParameter"); + if (paramOption != null) { + hasParameter = Boolean.parseBoolean(paramOption); + } + } + if (hasParameter) { + Path paramFile = paramPathResolver(prefix, options); + if (paramFile == null) { + throw new IOException( + "Parameter file not found in: " + + modelDir + + ". If you only specified model path, make sure path name" + + " match your saved model file name."); + } + readParameters(paramFile, options); + } + } + /** {@inheritDoc} */ @Override public void save(Path modelPath, String newModelName) throws IOException { diff --git a/api/src/main/java/ai/djl/ndarray/types/Shape.java b/api/src/main/java/ai/djl/ndarray/types/Shape.java index 51b4d2ad64c0..10ec0bf32921 100644 --- a/api/src/main/java/ai/djl/ndarray/types/Shape.java +++ b/api/src/main/java/ai/djl/ndarray/types/Shape.java @@ -548,7 +548,8 @@ public static PairList parseShapes(String value) { PairList inputShapes = new PairList<>(); if (value != null) { if (value.contains("(")) { - Pattern pattern = Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)(\\w?)"); + Pattern pattern = + Pattern.compile("\\((\\s*([-\\d]+)([,\\s]+[-\\d]+)*\\s*)\\)(\\w?)"); Matcher matcher = pattern.matcher(value); while (matcher.find()) { String[] tokens = matcher.group(1).split(","); diff --git a/api/src/main/java/ai/djl/nn/Blocks.java b/api/src/main/java/ai/djl/nn/Blocks.java index 136c9bd0f57d..d10fefde12a7 100644 --- a/api/src/main/java/ai/djl/nn/Blocks.java +++ b/api/src/main/java/ai/djl/nn/Blocks.java @@ -98,13 +98,20 @@ public static Block identityBlock() { public static Block onesBlock(PairList shapes, String[] names) { return new LambdaBlock( a -> { + Shape[] inShapes = a.getShapes(); NDManager manager = a.getManager(); NDList list = new NDList(shapes.size()); int index = 0; for (Pair pair : shapes) { + long[] shape = pair.getValue().getShape().clone(); + System.out.println(pair.getValue()); + for (int i = 0; i < shape.length; ++i) { + if (shape[i] == -1) { + shape[i] = inShapes[index].get(i); + } + } DataType dataType = pair.getKey(); - Shape shape = pair.getValue(); - NDArray arr = manager.ones(shape, dataType); + NDArray arr = manager.ones(new Shape(shape), dataType); if (names.length == list.size()) { arr.setName(names[index++]); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java index 6b5251c357cb..f7f0a71ea65c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtModel.java @@ -115,24 +115,7 @@ public void load(Path modelPath, String prefix, Map options) */ block.freezeParameters(!trainParam); } else { - boolean hasParameter = true; - if (options != null) { - String paramOption = (String) options.get("hasParameter"); - if (paramOption != null) { - hasParameter = Boolean.parseBoolean(paramOption); - } - } - if (hasParameter) { - Path paramFile = paramPathResolver(prefix, options); - if (paramFile == null) { - throw new IOException( - "Parameter file not found in: " - + modelDir - + ". If you only specified model path, make sure path name" - + " match your saved model file name."); - } - readParameters(paramFile, options); - } + loadBlock(prefix, options); } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java index 4cdb3e9b5cd0..7ab4d209ccfb 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java @@ -48,8 +48,12 @@ public void load(Path modelPath, String prefix, Map options) throw new FileNotFoundException( "Model directory doesn't exist: " + modelPath.toAbsolutePath()); } - modelDir = modelPath.toAbsolutePath(); - long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal()); - block = new RsSymbolBlock((RsNDManager) manager, handle); + setModelDir(modelPath); + if (block == null) { + long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal()); + block = new RsSymbolBlock((RsNDManager) manager, handle); + } else { + loadBlock(prefix, options); + } } }