From 3f6ace84a3f7e06ccede49fe2fa112ab14f6b392 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 13 Sep 2022 12:23:37 -0400 Subject: [PATCH 1/7] Roughing out tree protos. --- .../main/resources/protos/tribuo-tree.proto | 71 +++++++++++++++++++ .../protos/tribuo-regression-tree.proto | 49 +++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 Common/Trees/src/main/resources/protos/tribuo-tree.proto create mode 100644 Regression/RegressionTree/src/main/resources/protos/tribuo-regression-tree.proto diff --git a/Common/Trees/src/main/resources/protos/tribuo-tree.proto b/Common/Trees/src/main/resources/protos/tribuo-tree.proto new file mode 100644 index 000000000..d89c93a5e --- /dev/null +++ b/Common/Trees/src/main/resources/protos/tribuo-tree.proto @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +/* + * Protos for serializing Tribuo trees. + */ +package tribuo.tree; + +option java_multiple_files = true; +option java_package = "org.tribuo.common.tree.protos"; + +// We use any to encode polymorphism +import "google/protobuf/any.proto"; + +// Import Tribuo's core protos +import "tribuo-core.proto"; + +/** +Tree Node proto + */ +message TreeNodeProto { + int32 version = 1; + string class_name = 2; + google.protobuf.Any serialized_data = 3; +} + +/* +SplitNode proto + */ +message SplitNodeProto { + int32 parent_idx = 1; + int32 greater_than_idx = 2; + int32 less_than_or_equal_idx = 3; + int32 split_feature_idx = 4; + double split_value = 5; + double impurity = 6; +} + +/* +LeafNode proto + */ +message LeafNodeProto { + int32 parent_idx = 1; + double impurity = 2; + tribuo.core.OutputProto output = 3; + map score = 4; + bool generates_probabilities = 5; +} + +/* +TreeModel proto + */ +message TreeModelProto { + tribuo.core.ModelDataProto metadata = 1; + repeated TreeNodeProto nodes = 2; +} diff --git a/Regression/RegressionTree/src/main/resources/protos/tribuo-regression-tree.proto b/Regression/RegressionTree/src/main/resources/protos/tribuo-regression-tree.proto new file mode 100644 index 000000000..16cc280bd --- /dev/null +++ b/Regression/RegressionTree/src/main/resources/protos/tribuo-regression-tree.proto @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +/* + * Protos for serializing Tribuo regression trees. + */ +package tribuo.regression.tree; + +option java_multiple_files = true; +option java_package = "org.tribuo.regression.tree.protos"; + +// We use any to encode polymorphism +import "google/protobuf/any.proto"; + +// Import Tribuo's core protos +import "tribuo-core.proto"; + +// Import Tribuo's tree protos +import "tribuo-tree.proto"; + +/* +Carrier for a list of tree nodes. + */ +message TreeNodeListProto { + repeated TreeNodeProto nodes = 3; +} + +/* +IndependentRegressionTreeModel proto + */ +message IndependentRegressionTreeModelProto { + tribuo.core.ModelDataProto metadata = 1; + map nodes = 2; +} \ No newline at end of file From 46a5cc96293b9bf2bb06dd00984e9f22c11ed2b2 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 14 Sep 2022 15:23:32 -0400 Subject: [PATCH 2/7] Filling out tree serialization. Still missing the main TreeModel serialize method. --- .../tribuo/classification/dtree/TestCART.java | 3 +- .../java/org/tribuo/common/tree/LeafNode.java | 107 ++ .../java/org/tribuo/common/tree/Node.java | 3 +- .../org/tribuo/common/tree/SplitNode.java | 120 ++ .../org/tribuo/common/tree/TreeModel.java | 166 ++- .../common/tree/protos/LeafNodeProto.java | 1163 +++++++++++++++++ .../tree/protos/LeafNodeProtoOrBuilder.java | 82 ++ .../common/tree/protos/SplitNodeProto.java | 879 +++++++++++++ .../tree/protos/SplitNodeProtoOrBuilder.java | 51 + .../common/tree/protos/TreeModelProto.java | 966 ++++++++++++++ .../tree/protos/TreeModelProtoOrBuilder.java | 48 + .../common/tree/protos/TreeNodeProto.java | 819 ++++++++++++ .../tree/protos/TreeNodeProtoOrBuilder.java | 42 + .../tribuo/common/tree/protos/TribuoTree.java | 114 ++ .../main/resources/protos/tribuo-tree.proto | 22 +- .../rtree/IndependentRegressionTreeModel.java | 78 +- .../IndependentRegressionTreeModelProto.java | 904 +++++++++++++ ...dentRegressionTreeModelProtoOrBuilder.java | 58 + .../rtree/protos/TreeNodeListProto.java | 778 +++++++++++ .../protos/TreeNodeListProtoOrBuilder.java | 33 + .../rtree/protos/TribuoRegressionTree.java | 83 ++ .../protos/tribuo-regression-tree.proto | 7 +- .../rtree/TestCARTJointRegressionTrainer.java | 4 +- .../rtree/TestCARTRegressionTrainer.java | 4 +- 24 files changed, 6510 insertions(+), 24 deletions(-) create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/LeafNodeProto.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/LeafNodeProtoOrBuilder.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/SplitNodeProto.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/SplitNodeProtoOrBuilder.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/TreeModelProto.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/TreeModelProtoOrBuilder.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/TreeNodeProto.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/TreeNodeProtoOrBuilder.java create mode 100644 Common/Trees/src/main/java/org/tribuo/common/tree/protos/TribuoTree.java create mode 100644 Regression/RegressionTree/src/main/java/org/tribuo/regression/rtree/protos/IndependentRegressionTreeModelProto.java create mode 100644 Regression/RegressionTree/src/main/java/org/tribuo/regression/rtree/protos/IndependentRegressionTreeModelProtoOrBuilder.java create mode 100644 Regression/RegressionTree/src/main/java/org/tribuo/regression/rtree/protos/TreeNodeListProto.java create mode 100644 Regression/RegressionTree/src/main/java/org/tribuo/regression/rtree/protos/TreeNodeListProtoOrBuilder.java create mode 100644 Regression/RegressionTree/src/main/java/org/tribuo/regression/rtree/protos/TribuoRegressionTree.java diff --git a/Classification/DecisionTree/src/test/java/org/tribuo/classification/dtree/TestCART.java b/Classification/DecisionTree/src/test/java/org/tribuo/classification/dtree/TestCART.java index 5f21db086..347491aaa 100644 --- a/Classification/DecisionTree/src/test/java/org/tribuo/classification/dtree/TestCART.java +++ b/Classification/DecisionTree/src/test/java/org/tribuo/classification/dtree/TestCART.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,6 +54,7 @@ public static Model