Skip to content

Commit

Permalink
[rust] Support load DJL model for RsModel
Browse files Browse the repository at this point in the history
Allows -1 in shape string presentation
  • Loading branch information
frankfliu committed May 1, 2024
1 parent 1d3613a commit f3a1c37
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 24 deletions.
22 changes: 22 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,28 @@ protected void setModelDir(Path modelDir) {
this.modelDir = Utils.getNestedModelDir(modelDir);
}

protected void loadBlock(String prefix, Map<String, ?> options)
throws IOException, MalformedModelException {
boolean hasParameter = true;
if (options != null) {
String paramOption = (String) options.get("hasParameter");
if (paramOption != null) {
hasParameter = Boolean.parseBoolean(paramOption);
}
}
if (hasParameter) {
Path paramFile = paramPathResolver(prefix, options);
if (paramFile == null) {
throw new IOException(
"Parameter file not found in: "
+ modelDir
+ ". If you only specified model path, make sure path name"
+ " match your saved model file name.");
}
readParameters(paramFile, options);
}
}

/** {@inheritDoc} */
@Override
public void save(Path modelPath, String newModelName) throws IOException {
Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ public static PairList<DataType, Shape> parseShapes(String value) {
PairList<DataType, Shape> inputShapes = new PairList<>();
if (value != null) {
if (value.contains("(")) {
Pattern pattern = Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)(\\w?)");
Pattern pattern =
Pattern.compile("\\((\\s*([-\\d]+)([,\\s]+[-\\d]+)*\\s*)\\)(\\w?)");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String[] tokens = matcher.group(1).split(",");
Expand Down
11 changes: 9 additions & 2 deletions api/src/main/java/ai/djl/nn/Blocks.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,20 @@ public static Block identityBlock() {
public static Block onesBlock(PairList<DataType, Shape> shapes, String[] names) {
return new LambdaBlock(
a -> {
Shape[] inShapes = a.getShapes();
NDManager manager = a.getManager();
NDList list = new NDList(shapes.size());
int index = 0;
for (Pair<DataType, Shape> pair : shapes) {
long[] shape = pair.getValue().getShape().clone();
System.out.println(pair.getValue());
for (int i = 0; i < shape.length; ++i) {
if (shape[i] == -1) {
shape[i] = inShapes[index].get(i);
}
}
DataType dataType = pair.getKey();
Shape shape = pair.getValue();
NDArray arr = manager.ones(shape, dataType);
NDArray arr = manager.ones(new Shape(shape), dataType);
if (names.length == list.size()) {
arr.setName(names[index++]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
*/
block.freezeParameters(!trainParam);
} else {
boolean hasParameter = true;
if (options != null) {
String paramOption = (String) options.get("hasParameter");
if (paramOption != null) {
hasParameter = Boolean.parseBoolean(paramOption);
}
}
if (hasParameter) {
Path paramFile = paramPathResolver(prefix, options);
if (paramFile == null) {
throw new IOException(
"Parameter file not found in: "
+ modelDir
+ ". If you only specified model path, make sure path name"
+ " match your saved model file name.");
}
readParameters(paramFile, options);
}
loadBlock(prefix, options);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
throw new FileNotFoundException(
"Model directory doesn't exist: " + modelPath.toAbsolutePath());
}
modelDir = modelPath.toAbsolutePath();
long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal());
block = new RsSymbolBlock((RsNDManager) manager, handle);
setModelDir(modelPath);
if (block == null) {
long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal());
block = new RsSymbolBlock((RsNDManager) manager, handle);
} else {
loadBlock(prefix, options);
}
}
}

0 comments on commit f3a1c37

Please sign in to comment.