Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rust] Support load DJL model for RsModel #3147

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions api/src/main/java/ai/djl/BaseModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,28 @@ protected void setModelDir(Path modelDir) {
this.modelDir = Utils.getNestedModelDir(modelDir);
}

protected void loadBlock(String prefix, Map<String, ?> options)
throws IOException, MalformedModelException {
boolean hasParameter = true;
if (options != null) {
String paramOption = (String) options.get("hasParameter");
if (paramOption != null) {
hasParameter = Boolean.parseBoolean(paramOption);
}
}
if (hasParameter) {
Path paramFile = paramPathResolver(prefix, options);
if (paramFile == null) {
throw new IOException(
"Parameter file not found in: "
+ modelDir
+ ". If you only specified model path, make sure path name"
+ " match your saved model file name.");
}
readParameters(paramFile, options);
}
}

/** {@inheritDoc} */
@Override
public void save(Path modelPath, String newModelName) throws IOException {
Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ public static PairList<DataType, Shape> parseShapes(String value) {
PairList<DataType, Shape> inputShapes = new PairList<>();
if (value != null) {
if (value.contains("(")) {
Pattern pattern = Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)(\\w?)");
Pattern pattern =
Pattern.compile("\\((\\s*([-\\d]+)([,\\s]+[-\\d]+)*\\s*)\\)(\\w?)");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String[] tokens = matcher.group(1).split(",");
Expand Down
10 changes: 8 additions & 2 deletions api/src/main/java/ai/djl/nn/Blocks.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,19 @@ public static Block identityBlock() {
public static Block onesBlock(PairList<DataType, Shape> shapes, String[] names) {
return new LambdaBlock(
a -> {
Shape[] inShapes = a.getShapes();
NDManager manager = a.getManager();
NDList list = new NDList(shapes.size());
int index = 0;
for (Pair<DataType, Shape> pair : shapes) {
long[] shape = pair.getValue().getShape().clone();
for (int i = 0; i < shape.length; ++i) {
if (shape[i] == -1) {
shape[i] = inShapes[index].get(i);
}
}
DataType dataType = pair.getKey();
Shape shape = pair.getValue();
NDArray arr = manager.ones(shape, dataType);
NDArray arr = manager.ones(new Shape(shape), dataType);
if (names.length == list.size()) {
arr.setName(names[index++]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,7 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
*/
block.freezeParameters(!trainParam);
} else {
boolean hasParameter = true;
if (options != null) {
String paramOption = (String) options.get("hasParameter");
if (paramOption != null) {
hasParameter = Boolean.parseBoolean(paramOption);
}
}
if (hasParameter) {
Path paramFile = paramPathResolver(prefix, options);
if (paramFile == null) {
throw new IOException(
"Parameter file not found in: "
+ modelDir
+ ". If you only specified model path, make sure path name"
+ " match your saved model file name.");
}
readParameters(paramFile, options);
}
loadBlock(prefix, options);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
throw new FileNotFoundException(
"Model directory doesn't exist: " + modelPath.toAbsolutePath());
}
modelDir = modelPath.toAbsolutePath();
long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal());
block = new RsSymbolBlock((RsNDManager) manager, handle);
setModelDir(modelPath);
if (block == null) {
long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal());
block = new RsSymbolBlock((RsNDManager) manager, handle);
} else {
loadBlock(prefix, options);
}
}
}
Loading