diff --git a/.gitignore b/.gitignore index ba72453bb..55841367d 100644 --- a/.gitignore +++ b/.gitignore @@ -16,12 +16,19 @@ bin/ .*.swp # Other files -*.jar -*.class *.er *.log *.bck *.so +*.patch + +# Binaries +*.jar +*.class + +# Archives +*.gz +*.zip # Serialised models *.ser @@ -29,3 +36,12 @@ bin/ # Temporary stuff junk/* .DS_Store +.ipynb_checkpoints + +# Profiling files +*.jfr +*.iprof +*.jfc + +# Tutorial files +tutorials/*.svm diff --git a/Common/LibSVM/src/main/java/org/tribuo/common/libsvm/LibSVMTrainer.java b/Common/LibSVM/src/main/java/org/tribuo/common/libsvm/LibSVMTrainer.java index 37e28a489..0885d6d13 100644 --- a/Common/LibSVM/src/main/java/org/tribuo/common/libsvm/LibSVMTrainer.java +++ b/Common/LibSVM/src/main/java/org/tribuo/common/libsvm/LibSVMTrainer.java @@ -131,6 +131,7 @@ protected LibSVMTrainer() {} /** * Constructs a LibSVMTrainer from the parameters. * @param parameters The SVM parameters. + * @param seed The RNG seed. */ protected LibSVMTrainer(SVMParameters parameters, long seed) { this.parameters = parameters.getParameters(); diff --git a/Core/src/main/java/org/tribuo/Model.java b/Core/src/main/java/org/tribuo/Model.java index 1f42fbcf8..8c0e3f5d3 100644 --- a/Core/src/main/java/org/tribuo/Model.java +++ b/Core/src/main/java/org/tribuo/Model.java @@ -300,20 +300,21 @@ public String toString() { /** * Casts the model to the specified output type, assuming it is valid. - *

* If it's not valid, throws {@link ClassCastException}. - * @param inputModel The model to cast. + *

+ * This method is intended for use on a deserialized model to restore it's + * generic type in a safe way. * @param outputType The output type to cast to. - * @param The output type. + * @param The output type. * @return The model cast to the correct value. */ - public static > Model castModel(Model inputModel, Class outputType) { - if (inputModel.validate(outputType)) { + public > Model castModel(Class outputType) { + if (validate(outputType)) { @SuppressWarnings("unchecked") // guarded by validate - Model castedModel = (Model) inputModel; + Model castedModel = (Model) this; return castedModel; } else { - throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + inputModel.toString()); + throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + this.toString()); } } diff --git a/Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java b/Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java index 5f20c43e0..765e07dd2 100644 --- a/Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java +++ b/Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java @@ -49,6 +49,7 @@ * "The Elements of Statistical Learning" * Springer 2001. PDF * + * @param The prediction type. */ public class BaggingTrainer> implements Trainer { @@ -177,6 +178,7 @@ public EnsembleModel train(Dataset examples, Map runPr * @param labelIDs The output domain. * @param randInt A random int from an rng instance * @param runProvenance Provenance for this instance. + * @param invocationCount The invocation count for the inner trainer. * @return The trained ensemble member. */ protected Model trainSingleModel(Dataset examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo labelIDs, int randInt, Map runProvenance, int invocationCount) { diff --git a/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java b/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java index af3a11e01..ca79a0eb8 100644 --- a/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java +++ b/Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java @@ -79,9 +79,10 @@ default ONNXNode exportCombiner(ONNXNode input) { * will be required to provide ONNX support. * @param input the node to be ensembled according to this implementation. * @param weight The node of weights for ensembling. + * @param The type of the weights input reference. * @return The leaf node of the graph of operations added to ensemble input. */ - default > ONNXNode exportCombiner(ONNXNode input, T weight) { + default > ONNXNode exportCombiner(ONNXNode input, U weight) { Logger.getLogger(this.getClass().getName()).severe("Tried to export an ensemble combiner to ONNX format, but this is not implemented."); throw new IllegalStateException("This ensemble cannot be exported as the combiner '" + this.getClass() + "' uses the default implementation of EnsembleCombiner.exportCombiner."); } diff --git a/Data/src/main/java/org/tribuo/data/sql/SQLDBConfig.java b/Data/src/main/java/org/tribuo/data/sql/SQLDBConfig.java index a9b5c9907..bfba02f4a 100644 --- a/Data/src/main/java/org/tribuo/data/sql/SQLDBConfig.java +++ b/Data/src/main/java/org/tribuo/data/sql/SQLDBConfig.java @@ -72,7 +72,7 @@ private SQLDBConfig() {} /** * Constructs a SQL database configuration. *

- * Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}. + * Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}. * @param connectionString The connection string. * @param username The username. * @param password The password. @@ -87,7 +87,7 @@ public SQLDBConfig(String connectionString, String username, String password, Ma /** * Constructs a SQL database configuration. *

- * Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}. + * Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}. * @param host The host to connect to. * @param port The port to connect on. * @param db The db name. diff --git a/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModel.java b/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModel.java index 904254dc5..ef633b86f 100644 --- a/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModel.java +++ b/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModel.java @@ -309,6 +309,7 @@ public static ConfigFileAuthenticationDetailsProvider makeAuthProvider(Path conf * @param configFile The OCI configuration file, if null use the default file. * @param endpointURL The endpoint URL. * @param outputConverter The converter for the specified output type. + * @param The output type. * @return An OCIModel ready to score new inputs. */ public static > OCIModel createOCIModel(OutputFactory factory, @@ -332,6 +333,7 @@ public static > OCIModel createOCIModel(OutputFactory * @param profileName The profile name in the OCI configuration file, if null uses the default profile. * @param endpointURL The endpoint URL. * @param outputConverter The converter for the specified output type. + * @param The output type. * @return An OCIModel ready to score new inputs. */ public static > OCIModel createOCIModel(OutputFactory factory, diff --git a/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModelCLI.java b/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModelCLI.java index 6337a8e6d..81db75aa9 100644 --- a/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModelCLI.java +++ b/Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModelCLI.java @@ -64,7 +64,7 @@ private static void createModelAndDeploy(OCIModelOptions options) throws IOExcep // Load the Tribuo model Model

* N.B. This class will be sealed once the library is updated past Java 8. Users should not subclass this class. - * @param + * @param The protobuf type this reference generates. */ public abstract class ONNXRef { // Unfortunately there is no other shared supertype for OnnxML protobufs @@ -44,19 +44,36 @@ public abstract class ONNXRef { private final String baseName; protected final ONNXContext context; - + /** + * Creates an ONNXRef for the specified context, protobuf and name. + * @param context The ONNXContext we're operating in. + * @param backRef The protobuf reference. + * @param baseName The name of this reference. + */ ONNXRef(ONNXContext context, T backRef, String baseName) { this.context = context; this.backRef = backRef; this.baseName = baseName; } + /** + * Gets the output name of this object. + * @return The output name. + */ public abstract String getReference(); + /** + * The name of this object. + * @return The name. + */ public String getBaseName() { return baseName; } + /** + * The context this reference operates in. + * @return The context. + */ public ONNXContext onnxContext() { return context; } @@ -66,7 +83,7 @@ public ONNXContext onnxContext() { * as the first argument to {@code inputs}, with {@code otherInputs} append as subsequent arguments. The other * arguments behave as in the analogous method on ONNXContext. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. - * @param otherInputs A list of {@ONNXRef}s created by this instance of ONNXContext. + * @param otherInputs A list of {@link ONNXRef}s created by this instance of ONNXContext. * @param outputs A list of names that the output nodes of {@code op} should take. * @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. @@ -199,7 +216,7 @@ public > Ret assignTo(Ret output) { /** * Casts this ONNXRef to a different type using the {@link ONNXOperators#CAST} operation, and returning the output * node of that op. Currently supports only float, double, int, and long, which are specified by their respective - * {@link Class} objects (eg. {@link float.class}). Throws {@link IllegalArgumentException} when an unsupported cast + * {@link Class} objects (e.g., {@code float.class}). Throws {@link IllegalArgumentException} when an unsupported cast * is requested. * @param clazz The class object specifying the type to cast to. * @return An ONNXRef representing this object cast into the requested type. diff --git a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java index 4abf7339f..1c2aeedb4 100644 --- a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/package-info.java @@ -15,7 +15,7 @@ */ /** - * Interfaces and utilities for writing ONNX models from Java. *

* Developed to support Tribuo, but can be used to export * other machine learning models from JVM languages. diff --git a/docs/Architecture.md b/docs/Architecture.md index 054db2f22..3bee56c57 100644 --- a/docs/Architecture.md +++ b/docs/Architecture.md @@ -421,3 +421,67 @@ that Tribuo has no knowledge of the true feature names, and the system transparently hashes the inputs. The feature names tend to be particularly sensitive when working with NLP problems. For example, without such hashing, bigrams would appear in the feature domains. + +## ONNX Export + +From v4.2 Tribuo supports exporting some models in the [ONNX](https://onnx.ai) +model format. The ONNX format is a cross-platform model exchange format which +can be loaded in by many different machine learning libraries. Tribuo supports +inference on ONNX models via ONNX Runtime. Models which can be exported +implement the `ONNXExportable` interface, which provides methods for +constructing the ONNX protobuf and serializing it to disk. As of the release of +4.2, a subset of Tribuo's models are supported: linear models, sparse linear +models, LibSVM models, factorization machines, and ensembles thereof. We plan +to expand the set of exportable models in future releases. It is unlikely that +Tribuo will support direct ONNX export of TensorFlow models, however this can +be achieved by saving the Tribuo trained model in TensorFlow Saved Model +format, and then using the Python +[tf2onnx](https://github.com/onnx/tensorflow-onnx) project to convert that into +an onnx file. + +### ONNX and provenance + +Tribuo-exported ONNX files contain the Tribuo model provenance, stored as a +protobuf in the metadata field "TRIBUO\_PROVENANCE". If the model is loaded +back into Tribuo via ONNX Runtime, then the model provenance can be recovered +from the file, allowing the reproducibility system and the model tracking +features to work. + +### ONNX and deployment + +The ONNX format is widely supported in industry and across cloud providers. +Many hardware accelerators and edge computing vendors provide ONNX support for +their inference platforms, and this allows Tribuo-trained models to be widely +deployed after they have been exported. Tribuo provides an interface to [OCI +Data Science Model +Deployment](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm) +which deploys an ONNX model on [Oracle Cloud](https://www.oracle.com/cloud/), +and also can wrap a model deployment REST endpoint so it appears as a Tribuo +Model, allowing cloud deployment and inference from Tribuo. ONNX models are +also supported by [Oracle Machine Learning +Services](https://docs.oracle.com/en/database/oracle/machine-learning/omlss/index.html), +and many other cloud providers also provide ONNX model inference services which +can be used with exported Tribuo ONNX models. + +## Reproducibility + +From v4.2 Tribuo has a built-in reproducibility system for non-sequence Models. +This accepts a `Model` or `ModelProvenance` instance, automatically extracts +the configuration from the instance and then retrains the model, using the +data loading pipeline and training hyperparameters specified in the model provenance. +The system produces a diff of the reproduced model's provenance against the +original provenance, highlighting areas where the new model may behave differently +to the old one (e.g., showing if the number of features differs, or if the data +files have changed). + +This is useful to check the validity of deployed production models, and to allow +easy comparison between a production model and one trained on current data. Over +time we plan to expand this system to support experimenting with different model +hyperparameters and training data configurations, tracking all this information +using the provenance built into Tribuo. + +The reproducibility system requires Java 17, and as such is not included in the +`tribuo-all` Maven Central target. It is designed to be used in a development +environment rather than deployed in a production system like the rest of +Tribuo. As Tribuo migrates to newer versions of Java, we will consider +providing a jlink'd version of this utility. diff --git a/docs/FAQs.md b/docs/FAQs.md index 4882dc276..2fc1847d0 100644 --- a/docs/FAQs.md +++ b/docs/FAQs.md @@ -90,7 +90,7 @@ libraries, and given Python's lax approach to typing, those methods are only part of the API by convention rather that being enforced by the type system. In Tribuo, we've separated training from prediction. Tribuo's fit method is called "train" and lives on the `Trainer` interface, whereas Tribuo's "predict" method -lives on the Model class. Tribuo uses the same predict call to produce both the +lives on the `Model` class. Tribuo uses the same predict call to produce both the outputs and the scores for those outputs. Its predict method is the equivalent of both "predict" and "predict\_proba" in scikit-learn. We made this separation between training and prediction so as to enable the type system to act as a diff --git a/docs/HelperPrograms.md b/docs/HelperPrograms.md index b256f115e..b92019766 100644 --- a/docs/HelperPrograms.md +++ b/docs/HelperPrograms.md @@ -59,6 +59,14 @@ operations best done in user code if they are required, however we consider `StripProvenance` part of the supported API as it performs a complex function and is best expressed as a standalone program. +### OCIModelCLI + +Found in tribuo-oci, `org.tribuo.oci.OCIModelCLI` can deploy a Tribuo +multi-class classification model in OCI Data Science using the model deployment +API. It provides a CLI wrapper around the functions in `org.tribuo.oci.OCIUtil` +which can deploy classification, regression and multi-label classification +models to OCI. + ### PreprocessAndSerialize Found in tribuo-data, `org.tribuo.data.PreprocessAndSerialize` loads in a diff --git a/docs/Internals.md b/docs/Internals.md index 4026dab49..c5c189ce7 100644 --- a/docs/Internals.md +++ b/docs/Internals.md @@ -67,6 +67,8 @@ provenance built into their models and evaluations. ## Tracing a training and evaluation run +This section describes the internal process of a training and evaluation run. + ### DataSource `Example`s are created in a `DataSource`. Preferably they are created with a `Feature` list as this ensures the O(n log n) sort cost is paid once, rather than diff --git a/docs/PackageOverview.md b/docs/PackageOverview.md index 81696141d..b51e76647 100644 --- a/docs/PackageOverview.md +++ b/docs/PackageOverview.md @@ -55,7 +55,7 @@ a math library, and common modules shared across prediction types. are always applied at prediction time. - `util` - Utilities for basic operations such as for working with arrays and random samples. -- Data - (artifactID `tribuo-data`, package root: `org.tribuo.data`) provides classes which deal with sampled data, columnar data, csv +- Data - (artifactID: `tribuo-data`, package root: `org.tribuo.data`) provides classes which deal with sampled data, columnar data, csv files and text inputs. The user is encouraged to provide their own text processing infrastructure implementation, as the one here is fairly basic. - `columnar` - The columnar package provides many useful base classes for @@ -66,9 +66,9 @@ processing infrastructure implementation, as the one here is fairly basic. working with JDBC sources. - `text` - Text processing infrastructure interfaces and an example implementation. -- Json - (artifactID `tribuo-json`, package root: `org.tribuo.json`) provides functionality +- Json - (artifactID: `tribuo-json`, package root: `org.tribuo.json`) provides functionality for loading from json data sources, and for stripping provenance out of a model. -- Math - (artifactID `tribuo-math`, package root: `org.tribuo.math`) provides a linear algebra library for working with both sparse +- Math - (artifactID: `tribuo-math`, package root: `org.tribuo.math`) provides a linear algebra library for working with both sparse and dense vectors and matrices. - `kernel` - a set of kernel functions for use in the SGD package (and elsewhere). - `la` - a linear algebra library containing functions used in the @@ -79,6 +79,20 @@ should be considered the default algorithm since it works best across the widest range of linear SGD problems. - `util` - various util classes for working with arrays, vectors and matrices. +## Util libraries + +There are 3 utility libraries which are used by Tribuo but do not depend +on other parts of it. + +- InformationTheory - (artifactID: `tribuo-util-infotheory`, package root: `org.tribuo.util.infotheory`) provides discrete information theoretic functions suitable +for computing clustering metrics, feature selection and structure learning. +- ONNXExport - (artifactID: `tribuo-util-onnx`, package root: `org.tribuo.util.onnx`) provides infrastructure for building ONNX graphs from Java. +This package is suitable for use in other JVM libraries which want to write ONNX models, and provides additional type safety and usability over +directly writing the protobufs. +- Tokenization - (artifactID: `tribuo-util-tokenization`, package root: `org.tribuo.util.tokens`) provides a tokenization API suitable +for feature extraction or information retrieval, along with several tokenizer implementations, including a wordpiece implementation +suitable for use with models like BERT. + ## Multi-class Classification Multi-class classification is the act of assigning a single label from a set of @@ -93,7 +107,7 @@ labels to a test example. The classification module has several submodules: | LibLinear | `tribuo-classification-liblinear` | `org.tribuo.classification.liblinear` | A wrapper around the LibLinear-java library. This provides linear-SVMs and other l1 or l2 regularised linear classifiers. | | LibSVM | `tribuo-classification-libsvm` | `org.tribuo.classification.libsvm` | A wrapper around the Java version of LibSVM. This provides linear & kernel SVMs with sigmoid, gaussian and polynomial kernels. | | Multinomial Naive Bayes | `tribuo-classification-mnnaivebayes` | `org.tribuo.classification.mnb` | An implementation of a multinomial naive bayes classifier. Since it aims to store a compact in-memory representation of the model, it only keeps track of weights for observed feature/class pairs. | -| SGD | `tribuo-classification-sgd` | `org.tribuo.classification.sgd` | An implementation of stochastic gradient descent based classifiers. It includes a linear package for logistic regression and linear-SVM (using log and hinge losses, respectively), a kernel package for training a kernel-SVM using the Pegasos algorithm, and a crf package for training a linear-chain CRF. These implementations depend upon the stochastic gradient optimisers in the main Math package. The linear and crf packages can use any of the provided gradient optimisers, which enforce various different kinds of regularisation or convergence metrics. This is the preferred package for linear classification and for sequence classification due to the speed and scalability of the SGD approach. | +| SGD | `tribuo-classification-sgd` | `org.tribuo.classification.sgd` | An implementation of stochastic gradient descent based classifiers. It includes a linear package for logistic regression and linear-SVM (using log and hinge losses, respectively), a kernel package for training a kernel-SVM using the Pegasos algorithm, a crf package for training a linear-chain CRF, and a fm package for training pairwise factorization machines. These implementations depend upon the stochastic gradient optimisers in the main Math package. The linear, fm, and crf packages can use any of the provided gradient optimisers, which enforce various different kinds of regularisation or convergence metrics. This is the preferred package for linear classification and for sequence classification due to the speed and scalability of the SGD approach. | | XGBoost | `tribuo-classification-xgboost` | `org.tribuo.classification.xgboost` | A wrapper around the XGBoost Java API. XGBoost requires a C library accessed via JNI. XGBoost is a scalable implementation of gradient boosted trees. | ## Multi-label Classification @@ -111,7 +125,7 @@ convert a classification trainer into a multi-label trainer. | Folder | ArtifactID | Package root | Description | | --- | --- | --- | --- | | Core | `tribuo-multilabel-core` | `org.tribuo.multilabel` | Contains an Output subclass for multi-label prediction, evaluation code for checking the performance of a multi-label model, and a basic implementation of independent binary predictions. It also contains implementations of Classifier Chains and Classifier Chain Ensembles, which are more powerful ensemble techniques for multi-label prediction tasks. | -| SGD | `tribuo-multilabel-sgd` | `org.tribuo.multilabel.sgd` | An implementation of stochastic gradient descent based classifiers. It includes a linear package for independent logistic regression and linear-SVM (using log and hinge losses, respectively) for each output label. These implementations depend upon the stochastic gradient optimisers in the main Math package. The linear package can use any of the provided gradient optimisers, which enforce various different kinds of regularisation or convergence metrics. | +| SGD | `tribuo-multilabel-sgd` | `org.tribuo.multilabel.sgd` | An implementation of stochastic gradient descent based classifiers. It includes a linear package for independent logistic regression and linear-SVM (using log and hinge losses, respectively), along with factorization machines using either loss for each output label. These implementations depend upon the stochastic gradient optimisers in the main Math package. The linear and fm packages can use any of the provided gradient optimisers, which enforce various different kinds of regularisation or convergence metrics. | ## Regression @@ -124,7 +138,7 @@ This package provides several modules: | LibLinear | `tribuo-regression-liblinear` | `org.tribuo.regression.liblinear` | A wrapper around the LibLinear-java library. This provides linear-SVMs and other l1 or l2 regularised linear regressions. | | LibSVM | `tribuo-regression-libsvm` | `org.tribuo.regression.libsvm` | A wrapper around the Java version of LibSVM. This provides linear & kernel SVRs with sigmoid, gaussian and polynomial kernels. | | RegressionTrees | `tribuo-regression-tree` | `org.tribuo.regression.rtree` | An implementation of two types of CART regression trees. The first type builds a separate tree per output dimension, while the second type builds a single tree for all outputs. | -| SGD | `tribuo-regression-sgd` | `org.tribuo.regression.sgd` | An implementation of stochastic gradient descent for linear regression. It uses the main Math package's set of gradient optimisers, which allow for various regularisation and descent algorithms. | +| SGD | `tribuo-regression-sgd` | `org.tribuo.regression.sgd` | An implementation of stochastic gradient descent for linear regression and factorization machine regression. It uses the main Math package's set of gradient optimisers, which allow for various regularisation and descent algorithms. | | SLM | `tribuo-regression-slm` | `org.tribuo.regression.slm` | An implementation of sparse linear models. It includes a co-ordinate descent implementation of ElasticNet, a LARS implementation, a LASSO implementation using LARS, and a couple of sequential forward selection algorithms. | | XGBoost | `tribuo-regression-xgboost` | `org.tribuo.regression.xgboost` | A wrapper around the XGBoost Java API. XGBoost requires a C library accessed via JNI. | @@ -137,6 +151,7 @@ one cluster. This package provides two modules: | Folder | ArtifactID | Package root | Description | | --- | --- | --- | --- | | Core | `tribuo-clustering-core` | `org.tribuo.clustering` | Contains the Output subclass for use with clustering data, as well as the evaluation code for measuring clustering performance. | +| HDBSCAN | `tribuo-clustering-hdbscan` | `org.tribuo.clustering.hdbscan` | An implementation of HDBSCAN, a non-parametric density based clustering algorithm. | | KMeans | `tribuo-clustering-kmeans` | `org.tribuo.clustering.kmeans` | An implementation of K-Means using the Java 8 Stream API for parallelisation, along with the K-Means++ initialization algorithm. | ## Anomaly Detection @@ -165,15 +180,22 @@ Randomized Trees (ExtraTrees). Tribuo supports loading a number of third party models which were trained outside the system (even in other programming languages) and scoring them from Java using Tribuo's infrastructure. Currently, we support loading ONNX, -TensorFlow and XGBoost models. +TensorFlow and XGBoost models. Additionally we support wrapping an +[OCI Data Science](https://www.oracle.com/data-science/cloud-infrastructure-data-science.html) +model deployment in a Tribuo model. +- OCI - Supports deploying Tribuo models to OCI Data Science, and wrapping OCI + Data Science models in Tribuo external models to allow them to be served with +other Tribuo models. - ONNX - [ONNX](https://onnx.ai) (Open Neural Network eXchange) format is used by several deep learning systems as an export format, and there are converters from systems like scikit-learn to the ONNX format. Tribuo provides a wrapper around Microsoft's [ONNX Runtime](https://onnxruntime.ai) that can score ONNX models on both CPU and GPU platforms. ONNX support is found in the `tribuo-onnx` artifact in the `org.tribuo.interop.onnx` package which also -provides a feature extractor that uses BERT embedding models. +provides a feature extractor that uses BERT embedding models. This package can +load Tribuo-exported ONNX models and extract the stored Tribuo provenance +objects from those models. - TensorFlow - Tribuo supports loading [TensorFlow](https://tensorflow.org)'s frozen graphs and saved models and scoring them. - XGBoost - Tribuo supports loading [XGBoost](https://xgboost.ai) @@ -181,8 +203,8 @@ provides a feature extractor that uses BERT embedding models. ## TensorFlow -Tribuo includes experimental support for TensorFlow-Java 0.3.1 (using -TensorFlow 2.4.1) in the `tribuo-tensorflow` artifact in the +Tribuo includes experimental support for TensorFlow-Java 0.4.0 (using +TensorFlow 2.7.0) in the `tribuo-tensorflow` artifact in the `org.tribuo.interop.tensorflow` package. Models can be defined using TensorFlow-Java's graph construction mechanisms, and Tribuo will manage the gradient optimizer output function and loss function. It includes a Java diff --git a/docs/Roadmap.md b/docs/Roadmap.md index c0c2401a0..85f2cae2a 100644 --- a/docs/Roadmap.md +++ b/docs/Roadmap.md @@ -28,7 +28,8 @@ specific operations (though this can be achieved today using `DatasetView` and p - Make `Example`s immutable after they've been added to a `Dataset`. This is likely to be a breaking change. - Add support for global feature transformations, like normalizing to a unit vector, applying PCA and others. - Integrate with a plotting library. -- ONNX format model export. +- ONNX format model export. + - In 4.2 we support exporting linear models, sparse linear models, factorization machines, liblinear, libsvm and ensembles containing the previously listed models. ## Internals @@ -53,22 +54,28 @@ examples, or examples which didn't have suitable features for the model). ## New ML algorithms or parameters -- ~~Add K-Means++ initialisation for K-Means.~~ Integrated in Tribuo 4.1. +- ~~Add K-Means++ initialisation for K-Means.~~ + - Integrated in Tribuo 4.1. - ~~Add extra parameters to the tree trainers to allow for an ExtraTrees style ensemble, and to -specify a minimum purity decrease requirement.~~ Integrated in Tribuo 4.1. +specify a minimum purity decrease requirement.~~ + - Integrated in Tribuo 4.1. - Gaussian Processes. - Vowpal Wabbit interface. - Feature selection. We already have several feature selection algorithms implemented in a Tribuo compatible interface, but the codebase isn't quite ready for release. - Support word embedding features. -- ~~Support contextualised word embeddings (through the ONNX or TensorFlow interfaces).~~ ONNX support for BERT embeddings is integrated in Tribuo 4.1. -- More complex Multi-Label prediction algorithms. +- ~~Support contextualised word embeddings (through the ONNX or TensorFlow interfaces).~~ + - ONNX support for BERT embeddings is integrated in Tribuo 4.1. +- ~~More complex Multi-Label prediction algorithms.~~ - A Multi-Label linear SGD is integrated in Tribuo 4.1. - - Classifier chains and classifier chain ensembles are planned for Tribuo 4.2. + - Multi-label factorization machines are integrated in Tribuo 4.2. + - Classifier chains and classifier chain ensembles are integrated in Tribuo 4.2. - More anomaly detection algorithms. - LibLinear based anomaly detection is integrated in Tribuo 4.1. - More clustering algorithms. -- Factorization machines for classification. + - Added HDBSCAN in Tribuo 4.2. +- ~~Factorization machines for classification and regression.~~ + - Integrated in Tribuo 4.2. ## Performance @@ -84,5 +91,9 @@ in a Tribuo compatible interface, but the codebase isn't quite ready for release ## Documentation -- Fill out the javadoc so it exists for all public and protected methods, including constructors. -- Add more tutorials. Note: Tribuo 4.0.2 adds tutorials for external model loading and columnar data processing, and 4.1 adds tutorials for TensorFlow and document classification +- Fill out the javadoc so it exists for all public and protected methods, including constructors. + - Javadoc for all public methods and fields is present in Tribuo 4.2. +- Add more tutorials. + - Tribuo 4.0.2 adds tutorials for external model loading and columnar data processing. + - Tribuo 4.1 adds tutorials for TensorFlow and document classification. + - Tribuo 4.2 adds tutorials for multi-label classification, ONNX export, and model reproducibility. diff --git a/docs/Security.md b/docs/Security.md index 924aefa5c..8a7b0f641 100644 --- a/docs/Security.md +++ b/docs/Security.md @@ -44,7 +44,7 @@ native code inside an application container like a JavaEE or JakartaEE server. Multiple instances of Tribuo running inside separate containers may cause issues with JNI library loading due to ClassLoader security considerations. -## Configuration +## SecurityManager configuration Tribuo uses [OLCUT](https://github.com/oracle/olcut)'s configuration and provenance systems, which use reflection to construct and inspect classes. Therefore, when running with a Java security manager, you need to give the @@ -52,7 +52,7 @@ OLCUT jar appropriate permissions. We have tested this set of permissions, which allows the configuration and provenance systems to work: // OLCUT permissions - grant codeBase "file:/path/to/olcut/olcut-core-5.1.6.jar" { + grant codeBase "file:/path/to/olcut/olcut-core-5.2.0.jar" { permission java.lang.RuntimePermission "accessDeclaredMembers"; permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; permission java.util.logging.LoggingPermission "control"; @@ -68,7 +68,12 @@ This scope should be narrowed based on your requirements. If you need to save an OLCUT configuration, you will also need to add write permissions for the save location. -Similar file read and write permissions are necessary for Tribuo to be able to +Tribuo uses `ForkJoinPool` for parallelism, which requires the `modifyThread` +and `modifyThreadGroup` privileges when running under a `java.lang.SecurityManager`. +Therefore classes which have parallel execution inside will require those +permissions in addition to the ones listed for OLCUT above. + +File read and write permissions are necessary for Tribuo to be able to load and save models; therefore, you'll need to grant Tribuo those permissions using a similar snippet when running with a security manager. diff --git a/docs/example-configs/all-classification-config.xml b/docs/example-configs/all-classification-config.xml new file mode 100644 index 000000000..a0a9365c5 --- /dev/null +++ b/docs/example-configs/all-classification-config.xml @@ -0,0 +1,173 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs/example-configs/all-multilabel-config.xml b/docs/example-configs/all-multilabel-config.xml new file mode 100644 index 000000000..9e8ee9534 --- /dev/null +++ b/docs/example-configs/all-multilabel-config.xml @@ -0,0 +1,95 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/example-configs/all-regression-config.xml b/docs/example-configs/all-regression-config.xml new file mode 100644 index 000000000..493d51680 --- /dev/null +++ b/docs/example-configs/all-regression-config.xml @@ -0,0 +1,164 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/jep-290-filter.txt b/docs/jep-290-filter.txt index 6de2395b2..3d004a0ca 100644 --- a/docs/jep-290-filter.txt +++ b/docs/jep-290-filter.txt @@ -1 +1 @@ -org.tribuo.**;libsvm.svm_model;libsvm.svm_parameter;libsvm.svm_node;de.bwaldvogel.liblinear.Model;de.bwaldvogel.liblinear.SolverType;java.util.**;java.lang.*;!*; +org.tribuo.**;com.oracle.labs.mlrg.olcut.util.*;com.oracle.labs.mlrg.olcut.provenance.**;com.oracle.labs.mlrg.olcut.config.*;libsvm.svm_model;libsvm.svm_parameter;libsvm.svm_node;de.bwaldvogel.liblinear.Model;de.bwaldvogel.liblinear.SolverType;java.util.**;java.io.File;java.nio.file.Path;java.net.URL;java.time.*;java.lang.*;!* diff --git a/docs/release-notes/tribuo-v4-1-1-release-notes.md b/docs/release-notes/tribuo-v4-1-1-release-notes.md new file mode 100644 index 000000000..55619c68b --- /dev/null +++ b/docs/release-notes/tribuo-v4-1-1-release-notes.md @@ -0,0 +1,48 @@ +# Tribuo v4.1.1 Release Notes + +This is the first patch release for Tribuo v4.1. The main fixes in this release +are to the multi-dimensional output regression support, and to support the use +of KMeans and KNN models when running under a restrictive `SecurityManager`. +Additionally this release pulls in TensorFlow-Java 0.4.0 which upgrades the +TensorFlow native library to 2.7.0 fixing several CVEs. Note those CVEs may not +be applicable to TensorFlow-Java, as many of them relate to Python codepaths +which are not included in TensorFlow-Java. Note the TensorFlow upgrade is a +breaking API change as graph initialization is handled differently in this +release, which causes unavoidable changes in Tribuo's TF API. + +## Multi-dimensional Regression fix + +In Tribuo 4.1.0 and earlier there is a severe bug in multi-dimensional +regression models (i.e., regression tasks with multiple output dimensions). +Models other than `LinearSGDModel` and `SparseLinearModel` (apart from when +using the `ElasticNetCDTrainer`) have a bug in how the output dimension indices +are constructed, and may produce incorrect outputs for all dimensions (as the +output will be for a different dimension than the one named in the `Regressor` +object). This has been fixed, and loading in models trained in earlier versions +of Tribuo will patch the model to rearrange the dimensions appropriately. +Unfortunately this fix cannot be applied to tree based models, and so all +multi-output regression tree based models should be retrained using Tribuo 4.2 +as they are irretrievably corrupt. Additionally when using standardization in +multi-output regression LibSVM models dimensions past the first dimension have +the model improperly stored and will also need to be retrained with Tribuo 4.2. +See [#177](https://github.com/oracle/tribuo/pull/177) for more details. + +## Bug fixes + +- NPE fix for LIME explanations using models which don't support per class weights ([#157](https://github.com/oracle/tribuo/pull/157)). +- Fixing a bug in multi-label evaluation which swapped FP for FN ([#167](https://github.com/oracle/tribuo/pull/167)). +- Fixing LibSVM and LibLinear so they have reproducible behaviour ([#172](https://github.com/oracle/tribuo/pull/172)). +- Provenance fix for TransformTrainer and an extra factory for XGBoostExternalModel so you can make them from an in memory booster ([#176](https://github.com/oracle/tribuo/pull/176)) +- Fix multidimensional regression ([#177](https://github.com/oracle/tribuo/pull/177)) (fixes regression ids, fixes libsvm so it emits correct standardized models, adds support for per dimension feature weights in XGBoostRegressionModel). +- Normalize LibSVMDataSource paths consistently in the provenance ([#181](https://github.com/oracle/tribuo/pull/181)). +- KMeans and KNN now run correctly when using OpenSearch's SecurityManager ([#197](https://github.com/oracle/tribuo/pull/197)). +- TensorFlow-Java 0.4.0 ([#195](https://github.com/oracle/tribuo/pull/195)). + + +## Contributors + +- Adam Pocock ([@Craigacp](https://github.com/Craigacp)) +- Jack Sullivan ([@JackSullivan](https://github.com/JackSullivan)) +- Philip Ogren ([@pogren](https://github.com/pogren)) +- Jeffrey Alexander ([@jhalexand](https://github.com/jhalexand)) + diff --git a/docs/release-notes/tribuo-v4-2-release-notes.md b/docs/release-notes/tribuo-v4-2-release-notes.md new file mode 100644 index 000000000..3f64f4faf --- /dev/null +++ b/docs/release-notes/tribuo-v4-2-release-notes.md @@ -0,0 +1,174 @@ +# Tribuo v4.2 Release Notes + +Tribuo 4.2 adds new models, ONNX export for several types of models, a +reproducibility framework for recreating Tribuo models, easy deployment of +Tribuo models on Oracle Cloud, along with several smaller improvements and bug +fixes. We've added more tutorials covering the new features along with +multi-label classification, and further expanded the javadoc to cover all +public methods. + +In Tribuo 4.1.0 and earlier there is a severe bug in multi-dimensional +regression models (i.e., regression tasks with multiple output dimensions). +Models other than `LinearSGDModel` and `SparseLinearModel` (apart from when +using the `ElasticNetCDTrainer`) have a bug in how the output dimension indices +are constructed, and may produce incorrect outputs for all dimensions (as the +output will be for a different dimension than the one named in the `Regressor` +object). This has been fixed, and loading in models trained in earlier versions +of Tribuo will patch the model to rearrange the dimensions appropriately. +Unfortunately this fix cannot be applied to tree based models, and so all +multi-output regression tree based models should be retrained using Tribuo 4.2 +as they are irretrievably corrupt. Additionally when using standardization in +multi-output regression LibSVM models dimensions past the first dimension have +the model improperly stored and will also need to be retrained with Tribuo 4.2. +See [#177](https://github.com/oracle/tribuo/pull/177) for more details. + +Note the KMeans implementation had several internal changes to support running +with a `java.lang.SecurityManager` which will break any subclasses of `KMeansTrainer`. +In most cases changing the signature of any overridden `mStep` method to match +the new signature, and allowing the `fjp` argument to be null in single threaded +execution will fix the subclass. + +## New models + +In this release we've added [Factorization +Machines](https://www.computer.org/csdl/proceedings-article/icdm/2010/4256a995/12OmNwMFMfl), +[Classifier +Chains](https://link.springer.com/content/pdf/10.1007/s10994-011-5256-5.pdf) +and +[HDBSCAN\*](https://link.springer.com/chapter/10.1007/978-3-642-37456-2_14). +Factorization machines are a powerful non-linear predictor which uses a +factorized approximation to learn a per output feature-feature interaction term +in addition to a linear model. We've added Factorization Machines for +multi-class classification, multi-label classification and regression. +Classifier chains are an ensemble approach to multi-label classification which +given a specific ordering of the labels learns a chain of classifiers where +each classifier gets the features along with the predicted labels from earlier +in the chain. We also added ensembles of randomly ordered classifier chains +which work well in situations when the ground truth label ordering is unknown +(i.e., most of the time). HDBSCAN is a hierarchical density based clustering +algorithm which chooses the number of clusters based on properties of the data +rather than as a hyperparameter. The Tribuo implementation can cluster a +dataset, and then at prediction time it provides the cluster the given +datapoint would be in without modifying the cluster structure. + +- Classifier Chains ([#149](https://github.com/oracle/tribuo/pull/149)), which + also adds the jaccard score as a multi-label evaluation metric, and a +multi-label voting combiner for use in multi-label ensembles. +- Factorization machines ([#179](https://github.com/oracle/tribuo/pull/179)). +- HDBSCAN ([#196](https://github.com/oracle/tribuo/pull/196)). + +## ONNX Export + +The [ONNX](https://onnx.ai) format is a cross-platform and cross-library model +exchange format. Tribuo can already serve ONNX models via its [ONNX +Runtime](https://onnxruntime.ai) interface, and now has the ability to export +models in ONNX format for serving on edge devices, in cloud services, or in +other languages like Python or C#. + +In this release Tribuo supports exporting linear models (multi-class +classification, multi-label classification and regression), sparse linear +regression models, factorization machines (multi-class classification, +multi-label classification and regression), LibLinear models (multi-class +classification and regression), LibSVM models (multi-class classification and +regression), along with ensembles of those models, including arbitrary levels +of ensemble nesting. We plan to expand this coverage to more models over time, +however for TensorFlow we recommend users export those models as a Saved Model +and use the Python tf2onnx converter. + +Tribuo models exported in ONNX format preserve their provenance information in +a metadata field which is accessible when the ONNX model is loaded back into +Tribuo. The provenance is stored as a protobuf so could be read from other +libraries or platforms if necessary. + +The ONNX export support is in a separate module with no dependencies, and could +be used elsewhere on the JVM to support generating ONNX graphs. We welcome +contributions to build out the ONNX support in that module. + +- ONNX export for LinearSGDModels + ([#154](https://github.com/oracle/tribuo/pull/154)), which also adds a +multi-label output transformer for scoring multi-label ONNX models. +- ONNX export for SparseLinearModel ([#163](https://github.com/oracle/tribuo/pull/163)). +- Add provenance to ONNX exported models ([#182](https://github.com/oracle/tribuo/pull/182)). +- Refactor ONNX tensor creation ([#187](https://github.com/oracle/tribuo/pull/187)). +- ONNX ensemble export support ([#186](https://github.com/oracle/tribuo/pull/186)). +- ONNX export for LibSVM and LibLinear ([#191](https://github.com/oracle/tribuo/pull/191)). +- Refactor ONNX support to improve type safety ([#199](https://github.com/oracle/tribuo/pull/199)). +- Extract ONNX support into separate module ([#TBD](https://github.com/oracle/tribuo/pull/)). + +## Reproducibility Framework + +Tribuo has strong model metadata support via its provenance system which +records how models, datasets and evaluations are created. In this release we +enhance this support by adding a push-button reproduction framework which +accepts either a model provenance or a model object and rebuilds the complete +training pipeline, ensuring consistent usage of RNGs and other mutable state. + +This allows Tribuo to easily rebuild models to see if updated datasets could +change performance, or even if the model is actually reproducible (which may be +required for regulatory reasons). Over time we hope to expand this support +into a full experimental framework, allowing models to be rebuilt with +hyperparameter or data changes as part of the data science process or for +debugging models in production. + +This framework was written by Joseph Wonsil and Prof. Margo Seltzer at the +University of British Columbia as part of a collaboration between Prof. Seltzer +and Oracle Labs. We're excited to continue working with Joe, Margo and the rest +of the lab at UBC, as this is excellent work. + +Note the reproducibility framework module requires Java 16 or greater, and is +thus not included in the `tribuo-all` meta-module. + +- Reproducibility framework ([#185](https://github.com/oracle/tribuo/pull/185), with minor changes in [#189](https://github.com/oracle/tribuo/pull/189) and [#190](https://github.com/oracle/tribuo/pull/190)). + +## OCI Data Science Integration + +[Oracle Cloud Data +Science](https://www.oracle.com/data-science/cloud-infrastructure-data-science.html) +is a platform for building and deploying models in Oracle Cloud. The model +deployment functionality wraps a Python runtime and deploys them with an +auto-scaler at a REST endpoint. In this release we've added support for +deploying Tribuo models which are ONNX exportable directly to OCI DS, allowing +scale-out deployments of models from the JVM. We also added a `OCIModel` +wrapper which scores Tribuo `Example` objects using a deployed model's REST +endpoint, allowing easy use of cloud resources for ML on the JVM. + +- Oracle Cloud Data Science integration ([#200](https://github.com/oracle/tribuo/pull/200)). + +## Small improvements + +- Date field processor and locale support in metadata extractors ([#148](https://github.com/oracle/tribuo/pull/148)) +- Multi-output response processor allowing loading different formats of multi-label and multi-dimensional regression datasets ([#150](https://github.com/oracle/tribuo/pull/150)) +- ARM dev profile for compiling Tribuo on ARM platforms ([#152](https://github.com/oracle/tribuo/pull/152)) +- Refactor CSVLoader so it uses CSVDataSource and parses CSV files using RowProcessor, allowing an easy transition to more complex columnar extraction ([#153](https://github.com/oracle/tribuo/pull/153)) +- Configurable anomaly demo data source ([#160](https://github.com/oracle/tribuo/pull/160)) +- Configurable clustering demo data source ([#161](https://github.com/oracle/tribuo/pull/161)) +- Configurable classification demo data source ([#162](https://github.com/oracle/tribuo/pull/162)) +- Multi-Label tutorial and configurable multi-label demo data source ([#166](https://github.com/oracle/tribuo/pull/166)) (also adds a multi-label tutorial) plus fix in [#168](https://github.com/oracle/tribuo/pull/168) after #167 +- Add javadoc for all public methods and fields ([#175](https://github.com/oracle/tribuo/pull/175)) (also fixes a bug in Util.vectorNorm) +- Add hooks for model equality checks to trees and LibSVM models ([#183](https://github.com/oracle/tribuo/pull/183)) (also fixes a bug in liblinear get top features) +- XGBoost 1.5.0 ([#192](https://github.com/oracle/tribuo/pull/192)) +- TensorFlow Java 0.4.0 ([#195](https://github.com/oracle/tribuo/pull/195)) (note this changes Tribuo's TF API slightly as TF-Java 0.4.0 has a different method of initializing the session) +- KMeans now uses dense vectors when appropriate, speeding up training ([#201](https://github.com/oracle/tribuo/pull/201)) +- Documentation updates, ONNX and reproducibility tutorials ([#205](https://github.com/oracle/tribuo/pull/205)) + +## Bug fixes + +- NPE fix for LIME explanations using models which don't support per class weights ([#157](https://github.com/oracle/tribuo/pull/157)) +- Fixing a bug in multi-label evaluation which swapped FP for FN ([#167](https://github.com/oracle/tribuo/pull/167)) +- Persist CSVDataSource headers in the provenance ([#171](https://github.com/oracle/tribuo/pull/171)) +- Fixing LibSVM and LibLinear so they have reproducible behaviour ([#172](https://github.com/oracle/tribuo/pull/172)) +- Provenance fix for TransformTrainer and an extra factory for XGBoostExternalModel so you can make them from an in memory booster ([#176](https://github.com/oracle/tribuo/pull/176)) +- Fix multidimensional regression ([#177](https://github.com/oracle/tribuo/pull/177)) (fixes regression ids, fixes libsvm so it emits correct standardized models, adds support for per dimension feature weights in XGBoostRegressionModel) +- Fix provenance generation for FieldResponseProcessor and BinaryResponseProcessor ([#178](https://github.com/oracle/tribuo/pull/178)) +- Normalize LibSVMDataSource paths consistently in the provenance ([#181](https://github.com/oracle/tribuo/pull/181)) +- KMeans and KNN now run correctly when using OpenSearch's SecurityManager ([#197](https://github.com/oracle/tribuo/pull/197)) + +## Contributors + +- Adam Pocock ([@Craigacp](https://github.com/Craigacp)) +- Jack Sullivan ([@JackSullivan](https://github.com/JackSullivan)) +- Joseph Wonsil ([@jwons](https://github.com/jwons)) +- Philip Ogren ([@pogren](https://github.com/pogren)) +- Jeffrey Alexander ([@jhalexand](https://github.com/jhalexand)) +- Geoff Stewart ([@geoffreydstewart](https://github.com/geoffreydstewart)) + diff --git a/pom.xml b/pom.xml index 06a57076f..aa5b22859 100644 --- a/pom.xml +++ b/pom.xml @@ -214,6 +214,12 @@ maven-javadoc-plugin 3.3.1 + + -Xmaxerrs + 65536 + -Xmaxwarns + 65536 + 8 protected true @@ -288,6 +294,12 @@ site + + -Xmaxerrs + 65536 + -Xmaxwarns + 65536 + ./Core/src/main/javadoc/overview.html Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved. diff --git a/tutorials/README.md b/tutorials/README.md index 0524d6c0a..a51c8334e 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -5,8 +5,9 @@ These tutorials require the [IJava](https://github.com/SpencerPark/IJava) Jupyte The tutorials expect the data and required jars to be in the same directory as the notebooks. The dataset download links are given in the tutorial, and Tribuo's jars are on Maven Central, attached to the GitHub release, or you can build it yourself with `mvn clean package` using Apache Maven. -The code in them should work on Java 8 with the addition of types to replace the use of the `var` keyword -added in Java 10, and replacing the collections factories introduced in Java 9. +In most cases code in them should work on Java 8 with the addition of types to replace the use of the `var` keyword +added in Java 10, and replacing the collections factories introduced in Java 9, with the exception of the reproducibility +tutorial which requires Java 16+ as the reproducibility package uses newer Java features. The tutorials cover: - [Intro classification with Irises](irises-tribuo-v4.ipynb) @@ -20,3 +21,5 @@ The tutorials cover: - [Document classification and extracting features from text](document-classification-tribuo-v4.ipynb) - [Importing third-party models](external-models-tribuo-v4.ipynb) - [Training and deploying TensorFlow models](tensorflow-tribuo-v4.ipynb) +- [ONNX export and deployment](onnx-export-tribuo-v4.ipynb) +- [Model reproducibility](reproducibility-tribuo-v4.ipynb) diff --git a/tutorials/configuration-tribuo-v4.ipynb b/tutorials/configuration-tribuo-v4.ipynb index f54cf227c..0ab70c3d5 100644 --- a/tutorials/configuration-tribuo-v4.ipynb +++ b/tutorials/configuration-tribuo-v4.ipynb @@ -89,6 +89,13 @@ "ConfigurationManager.addFileFormatFactory(new JsonConfigFactory())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "OLCUT supports XML, JSON, [edn](https://github.com/edn-format/edn), and [protobuf](https://developers.google.com/protocol-buffers) format configuration files. It also supports serialization for `Provenance` objects in XML, JSON, and protobuf formats." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -131,7 +138,7 @@ "source": [ "var className = \"org.tribuo.classification.sgd.linear.LinearSGDTrainer\";\n", "var clazz = (Class) Class.forName(className);\n", - "Map map = DescribeConfigurable.generateFieldInfo(clazz);\n", + "var map = DescribeConfigurable.generateFieldInfo(clazz);\n", "\n", "var output = DescribeConfigurable.generateDescription(map);\n", "\n", @@ -183,13 +190,6 @@ "System.out.println(writer.toString(\"UTF-8\"));" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "At the moment using it from the REPL is missing some type information in `DescribeConfigurable.generateFieldInfo`, we'll fix that in the next OLCUT release." - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -774,6 +774,8 @@ "source": [ "We can see that both models perform identically. This is because our provenance system records the RNG seeds used at all points, and Tribuo is scrupulous about how and when it uses PRNGs. If you find a model reconstruction that gives a different answer (unless you're using XGBoost or TensorFlow, both of which have some non-determinism beyond our control) then file an issue on our GitHub as that's a bug.\n", "\n", + "We provide a simple push-button replication facility in the `tribuo-reproducibility` project, see the tutorial on reproducibilty for more details.\n", + "\n", "## What else lives in the Provenance?\n", "\n", "These evaluations have provenance in the same way the models do, and we can use a pretty printer in OLCUT to make it a little more human readable.\n", @@ -1093,7 +1095,7 @@ "metadata": {}, "source": [ "## Conclusion\n", - "We've taken a closer look at Tribuo's configuration and provenance systems, showing how to train a model using a configuration file, how to inspect the model's provenance, extract it's configuration, and finally how to combine that extracted configuration with other programmatic elements of the Tribuo library (in this case the feature transformation system). We saw that the provenance combines both the configuration of the trainer and the datasource, along with runtime information extracted from the dataset itself (e.g., timestamps and file hashes).\n", + "We've taken a closer look at Tribuo's configuration and provenance systems, showing how to train a model using a configuration file, how to inspect the model's provenance, extract it's configuration, and finally how to combine that extracted configuration with other programmatic elements of the Tribuo library (in this case the feature transformation system). We saw that the provenance combines both the configuration of the trainer and the datasource, along with runtime information extracted from the dataset itself (e.g., timestamps and file hashes). Tribuo's provenance objects are also persisted in ONNX model files exported from Tribuo, and these provenances can be recovered later using Tribuo's `ONNXExternalModel` class which provides ONNX model inference. For more details on ONNX export see the ONNX export and deployment tutorial.\n", "\n", "Tribuo's configuration system is integrated into a CLI options/arguments parsing system, which can be used to override elements from the configuration file. The values from the options are then stored in the `ConfigurationManager` and appear in the provenance and downstream configuration objects as expected. Tribuo also provides a redaction system for configuration files (e.g., to ensure a password isn't stored in the provenance) and for provenance objects themselves (e.g., to remove the data provenance from a trained model), which aids model deployment to untrusted or less trusted systems." ] diff --git a/tutorials/external-models-tribuo-v4.ipynb b/tutorials/external-models-tribuo-v4.ipynb index a3d0aa50d..f9cf61689 100644 --- a/tutorials/external-models-tribuo-v4.ipynb +++ b/tutorials/external-models-tribuo-v4.ipynb @@ -5,7 +5,7 @@ "metadata": {}, "source": [ "# Working with external models\n", - "Tribuo can load in models trained in third party systems and deploy them alongside native Tribuo models. In Tribuo 4.1 we support models trained externally in [XGBoost](https://xgboost.ai), [TensorFlow](https://tensorflow.org) frozen graphs & saved models, and models stored in ONNX (Open Neural Network eXchange) format. The latter is particularly interesting for Tribuo as many libraries can export models in ONNX format, such as [scikit-learn](https://scikit-learn.org), [pytorch](https://pytorch.org), TensorFlow among others. For a more complete list of the supported onnx models you can look at the [ONNX website](https://onnx.ai). Tribuo's ONNX support is supplied by [ONNX Runtime](https://microsoft.github.io/onnxruntime/), using the Java interface our group in Oracle Labs contributed to that project.\n", + "Tribuo can load in models trained in third party systems and deploy them alongside native Tribuo models. In Tribuo 4.1+ we support models trained externally in [XGBoost](https://xgboost.ai), [TensorFlow](https://tensorflow.org) frozen graphs & saved models, and models stored in ONNX (Open Neural Network eXchange) format. The latter is particularly interesting for Tribuo as many libraries can export models in ONNX format, such as [scikit-learn](https://scikit-learn.org), [pytorch](https://pytorch.org), TensorFlow among others. For a more complete list of the supported onnx models you can look at the [ONNX website](https://onnx.ai). Tribuo's ONNX support is supplied by [ONNX Runtime](https://microsoft.github.io/onnxruntime/), using the Java interface our group in Oracle Labs contributed to that project. Tribuo 4.2 added support for exporting models in ONNX format, and those models can be loaded back in to Tribuo using our ONNX Runtime interface.\n", "\n", "In this tutorial we'll look at loading in models trained in XGBoost, scikit-learn and pytorch, all for MNIST and we'll deploy them next to a logistic regression model trained in Tribuo. We discuss using external TensorFlow models in the [TensorFlow tutorial](https://github.com/oracle/tribuo/blob/main/tutorials/tensorflow-tribuo-v4.ipynb), as TensorFlow brings it's own complexities. Note these models all depend on native libraries, which are available for x86\\_64 platforms on Windows, Linux and macOS. Both ONNX Runtime and XGBoost support macOS arm64 (i.e., Apple Silicon Macs), but you'll need to compile those from source and add them to Tribuo's class path to make this tutorial run on that platform.\n", "\n", @@ -453,7 +453,7 @@ "## Conclusion\n", "We saw how to load in externally trained models in multiple formats, and how to deploy those models alongside Tribuo's native models. We also looked at how ONNX models can accept different tensor shapes as inputs, and used Tribuo's mechanisms for converting an `Example` into either a vector or a tensor depending on if the external model expected a vector or an image as an input.\n", "\n", - "Given how useful the ONNX model import code is, allowing Tribuo to load in many different kinds of models trained in many different libraries, it's natural to ask what support Tribuo has for exporting ONNX models. At the moment we don't support exporting Tribuo's native models to ONNX format, but we're investigating how to do this purely from Java, and we hope to be able to do this in a future release." + "Given how useful the ONNX model import code is, allowing Tribuo to load in many different kinds of models trained in many different libraries, it's natural to ask what support Tribuo has for exporting ONNX models. As of 4.2 Tribuo can export linear models, sparse linear models, LibLinear, LibSVM, factorization machines, and ensembles thereof. We plan to expand this to cover more of Tribuo's models over time." ] } ], @@ -469,7 +469,7 @@ "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", - "version": "17-ea+22-1964" + "version": "17+35-LTS-2724" } }, "nbformat": 4, diff --git a/tutorials/irises-tribuo-v4.ipynb b/tutorials/irises-tribuo-v4.ipynb index 7f06b253f..fa10fe888 100644 --- a/tutorials/irises-tribuo-v4.ipynb +++ b/tutorials/irises-tribuo-v4.ipynb @@ -336,17 +336,82 @@ "text": [ "TrainTestSplitter(\n", "\tclass-name = org.tribuo.evaluation.TrainTestSplitter\n", - "\tsource = CSVLoader(\n", - "\t\t\tclass-name = org.tribuo.data.csv.CSVLoader\n", + "\tsource = CSVDataSource(\n", + "\t\t\tclass-name = org.tribuo.data.csv.CSVDataSource\n", + "\t\t\theaders = List[\n", + "\t\t\t\tsepalLength\n", + "\t\t\t\tsepalWidth\n", + "\t\t\t\tpetalLength\n", + "\t\t\t\tpetalWidth\n", + "\t\t\t\tspecies\n", + "\t\t\t]\n", + "\t\t\trowProcessor = RowProcessor(\n", + "\t\t\t\t\tclass-name = org.tribuo.data.columnar.RowProcessor\n", + "\t\t\t\t\tmetadataExtractors = List[]\n", + "\t\t\t\t\tfieldProcessorList = List[\n", + "\t\t\t\t\t\tDoubleFieldProcessor(\n", + "\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n", + "\t\t\t\t\t\t\t\t\tfieldName = petalLength\n", + "\t\t\t\t\t\t\t\t\tonlyFieldName = true\n", + "\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n", + "\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n", + "\t\t\t\t\t\t\t\t)\n", + "\t\t\t\t\t\tDoubleFieldProcessor(\n", + "\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n", + "\t\t\t\t\t\t\t\t\tfieldName = petalWidth\n", + "\t\t\t\t\t\t\t\t\tonlyFieldName = true\n", + "\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n", + "\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n", + "\t\t\t\t\t\t\t\t)\n", + "\t\t\t\t\t\tDoubleFieldProcessor(\n", + "\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n", + "\t\t\t\t\t\t\t\t\tfieldName = sepalWidth\n", + "\t\t\t\t\t\t\t\t\tonlyFieldName = true\n", + "\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n", + "\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n", + "\t\t\t\t\t\t\t\t)\n", + "\t\t\t\t\t\tDoubleFieldProcessor(\n", + "\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\n", + "\t\t\t\t\t\t\t\t\tfieldName = sepalLength\n", + "\t\t\t\t\t\t\t\t\tonlyFieldName = true\n", + "\t\t\t\t\t\t\t\t\tthrowOnInvalid = true\n", + "\t\t\t\t\t\t\t\t\thost-short-name = FieldProcessor\n", + "\t\t\t\t\t\t\t\t)\n", + "\t\t\t\t\t]\n", + "\t\t\t\t\tfeatureProcessors = List[]\n", + "\t\t\t\t\tresponseProcessor = FieldResponseProcessor(\n", + "\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.processors.response.FieldResponseProcessor\n", + "\t\t\t\t\t\t\tuppercase = false\n", + "\t\t\t\t\t\t\tfieldNames = List[\n", + "\t\t\t\t\t\t\t\tspecies\n", + "\t\t\t\t\t\t\t]\n", + "\t\t\t\t\t\t\tdefaultValues = List[\n", + "\t\t\t\t\t\t\t\t\n", + "\t\t\t\t\t\t\t]\n", + "\t\t\t\t\t\t\tdisplayField = false\n", + "\t\t\t\t\t\t\toutputFactory = LabelFactory(\n", + "\t\t\t\t\t\t\t\t\tclass-name = org.tribuo.classification.LabelFactory\n", + "\t\t\t\t\t\t\t\t)\n", + "\t\t\t\t\t\t\thost-short-name = ResponseProcessor\n", + "\t\t\t\t\t\t)\n", + "\t\t\t\t\tweightExtractor = FieldExtractor(\n", + "\t\t\t\t\t\t\tclass-name = org.tribuo.data.columnar.FieldExtractor\n", + "\t\t\t\t\t\t)\n", + "\t\t\t\t\treplaceNewlinesWithSpaces = true\n", + "\t\t\t\t\tregexMappingProcessors = Map{}\n", + "\t\t\t\t\thost-short-name = RowProcessor\n", + "\t\t\t\t)\n", + "\t\t\tquote = \"\n", + "\t\t\toutputRequired = true\n", "\t\t\toutputFactory = LabelFactory(\n", "\t\t\t\t\tclass-name = org.tribuo.classification.LabelFactory\n", "\t\t\t\t)\n", - "\t\t\tresponse-name = species\n", "\t\t\tseparator = ,\n", - "\t\t\tquote = \"\n", - "\t\t\tpath = file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\n", - "\t\t\tfile-modified-time = 1999-12-14T15:12:39-05:00\n", + "\t\t\tdataPath = /Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\n", "\t\t\tresource-hash = 0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC\n", + "\t\t\tfile-modified-time = 1999-12-14T15:12:39-05:00\n", + "\t\t\tdatasource-creation-time = 2021-11-01T12:52:18.814629-04:00\n", + "\t\t\thost-short-name = DataSource\n", "\t\t)\n", "\ttrain-proportion = 0.7\n", "\tseed = 1\n", @@ -365,7 +430,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can see the model was trained on a datasource which was split in two, using a specific random seed & split percentage. The original datasource was a CSV file, and the file modified time and SHA-256 hash are recorded too.\n", + "We can see the model was trained on a datasource which was split in two, using a specific random seed & split percentage. The original datasource was a CSV file, and the file modified time and SHA-256 hash are recorded too. As of Tribuo v4.2 `CSVLoader` now generates a `CSVDataSource` allowing simpler migration to more complex columnar processing than the old method, along with producing more accurate provenance information suitable for automatic reproduction of models.\n", "\n", "We can similarly inspect the trainer provenance to find out about the training algorithm." ] @@ -397,7 +462,7 @@ "\t\t\tclass-name = org.tribuo.classification.sgd.objectives.LogMulticlass\n", "\t\t\thost-short-name = LabelObjective\n", "\t\t)\n", - "\ttribuo-version = 4.1.0\n", + "\ttribuo-version = 4.2.0-SNAPSHOT\n", "\ttrain-invocation-count = 0\n", "\tis-sequence = false\n", "\thost-short-name = Trainer\n", @@ -458,7 +523,7 @@ " \"tribuo-version\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"tribuo-version\",\n", - " \"value\" : \"4.1.0\",\n", + " \"value\" : \"4.2.0-SNAPSHOT\",\n", " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : false\n", @@ -466,7 +531,7 @@ " \"java-version\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"java-version\",\n", - " \"value\" : \"17-ea\",\n", + " \"value\" : \"17\",\n", " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : false\n", @@ -490,7 +555,7 @@ " \"trained-at\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"trained-at\",\n", - " \"value\" : \"2021-05-24T12:27:10.387150-04:00\",\n", + " \"value\" : \"2021-11-01T12:52:19.228195-04:00\",\n", " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : false\n", @@ -553,7 +618,7 @@ " \"tribuo-version\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"tribuo-version\",\n", - " \"value\" : \"4.1.0\",\n", + " \"value\" : \"4.2.0-SNAPSHOT\",\n", " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : false\n", @@ -612,7 +677,7 @@ " \"tribuo-version\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"tribuo-version\",\n", - " \"value\" : \"4.1.0\",\n", + " \"value\" : \"4.2.0-SNAPSHOT\",\n", " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : false\n", @@ -731,8 +796,8 @@ " \"source\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"source\",\n", - " \"value\" : \"csvloader-6\",\n", - " \"provenance-class\" : \"org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance\",\n", + " \"value\" : \"csvdatasource-6\",\n", + " \"provenance-class\" : \"org.tribuo.data.csv.CSVDataSource$CSVDataSourceProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : true\n", " },\n", @@ -825,31 +890,70 @@ " }\n", "}, {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", - " \"object-name\" : \"csvloader-6\",\n", - " \"object-class-name\" : \"org.tribuo.data.csv.CSVLoader\",\n", - " \"provenance-class\" : \"org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance\",\n", + " \"object-name\" : \"csvdatasource-6\",\n", + " \"object-class-name\" : \"org.tribuo.data.csv.CSVDataSource\",\n", + " \"provenance-class\" : \"org.tribuo.data.csv.CSVDataSource$CSVDataSourceProvenance\",\n", " \"map\" : {\n", " \"resource-hash\" : {\n", - " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", - " \"key\" : \"resource-hash\",\n" + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ + " \"key\" : \"resource-hash\",\n", " \"value\" : \"0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC\",\n", " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance\",\n", " \"additional\" : \"SHA256\",\n", " \"is-reference\" : false\n", " },\n", - " \"path\" : {\n", + " \"headers\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance\",\n", + " \"list\" : [ {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"headers\",\n", + " \"value\" : \"sepalLength\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"headers\",\n", + " \"value\" : \"sepalWidth\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"headers\",\n", + " \"value\" : \"petalLength\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"headers\",\n", + " \"value\" : \"petalWidth\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"headers\",\n", + " \"value\" : \"species\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " } ]\n", + " },\n", + " \"rowProcessor\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", - " \"key\" : \"path\",\n", - " \"value\" : \"file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\",\n", - " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance\",\n", + " \"key\" : \"rowProcessor\",\n", + " \"value\" : \"rowprocessor-7\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", " \"additional\" : \"\",\n", - " \"is-reference\" : false\n", + " \"is-reference\" : true\n", " },\n", " \"file-modified-time\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", @@ -867,18 +971,26 @@ " \"additional\" : \"\",\n", " \"is-reference\" : false\n", " },\n", - " \"response-name\" : {\n", + " \"outputRequired\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", - " \"key\" : \"response-name\",\n", - " \"value\" : \"species\",\n", - " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"key\" : \"outputRequired\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"datasource-creation-time\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"datasource-creation-time\",\n", + " \"value\" : \"2021-11-01T12:52:18.814629-04:00\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : false\n", " },\n", " \"outputFactory\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"outputFactory\",\n", - " \"value\" : \"labelfactory-7\",\n", + " \"value\" : \"labelfactory-15\",\n", " \"provenance-class\" : \"org.tribuo.classification.LabelFactory$LabelFactoryProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : true\n", @@ -891,10 +1003,123 @@ " \"additional\" : \"\",\n", " \"is-reference\" : false\n", " },\n", + " \"host-short-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"host-short-name\",\n", + " \"value\" : \"DataSource\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"class-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"class-name\",\n", + " \"value\" : \"org.tribuo.data.csv.CSVDataSource\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"dataPath\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"dataPath\",\n", + " \"value\" : \"/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.FileProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }\n", + " }\n", + "}, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", + " \"object-name\" : \"rowprocessor-7\",\n", + " \"object-class-name\" : \"org.tribuo.data.columnar.RowProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"map\" : {\n", + " \"metadataExtractors\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance\",\n", + " \"list\" : [ ]\n", + " },\n", + " \"fieldProcessorList\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance\",\n", + " \"list\" : [ {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldProcessorList\",\n", + " \"value\" : \"doublefieldprocessor-9\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : true\n", + " }, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldProcessorList\",\n", + " \"value\" : \"doublefieldprocessor-10\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : true\n", + " }, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldProcessorList\",\n", + " \"value\" : \"doublefieldprocessor-11\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : true\n", + " }, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldProcessorList\",\n", + " \"value\" : \"doublefieldprocessor-12\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : true\n", + " } ]\n", + " },\n", + " \"featureProcessors\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance\",\n", + " \"list\" : [ ]\n", + " },\n", + " \"responseProcessor\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"responseProcessor\",\n", + " \"value\" : \"fieldresponseprocessor-13\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : true\n", + " },\n", + " \"weightExtractor\" : {\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"weightExtractor\",\n", + " \"value\" : \"fieldextractor-14\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.NullConfiguredProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : true\n", + " },\n", + " \"replaceNewlinesWithSpaces\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"replaceNewlinesWithSpaces\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"regexMappingProcessors\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.MapMarshalledProvenance\",\n", + " \"map\" : { }\n", + " },\n", + " \"host-short-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"host-short-name\",\n", + " \"value\" : \"RowProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", " \"class-name\" : {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", " \"key\" : \"class-name\",\n", - " \"value\" : \"org.tribuo.data.csv.CSVLoader\",\n", + " \"value\" : \"org.tribuo.data.columnar.RowProcessor\",\n", " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", " \"additional\" : \"\",\n", " \"is-reference\" : false\n", @@ -902,7 +1127,7 @@ " }\n", "}, {\n", " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", - " \"object-name\" : \"labelfactory-7\",\n", + " \"object-name\" : \"labelfactory-15\",\n", " \"object-class-name\" : \"org.tribuo.classification.LabelFactory\",\n", " \"provenance-class\" : \"org.tribuo.classification.LabelFactory$LabelFactoryProvenance\",\n", " \"map\" : {\n", @@ -915,6 +1140,284 @@ " \"is-reference\" : false\n", " }\n", " }\n", + "}, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", + " \"object-name\" : \"doublefieldprocessor-9\",\n", + " \"object-class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"map\" : {\n", + " \"fieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldName\",\n", + " \"value\" : \"petalLength\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"onlyFieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"onlyFieldName\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"throwOnInvalid\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"throwOnInvalid\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"host-short-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"host-short-name\",\n", + " \"value\" : \"FieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"class-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"class-name\",\n", + " \"value\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }\n", + " }\n", + "}, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", + " \"object-name\" : \"doublefieldprocessor-10\",\n", + " \"object-class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"map\" : {\n", + " \"fieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldName\",\n", + " \"value\" : \"petalWidth\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"onlyFieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"onlyFieldName\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"throwOnInvalid\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"throwOnInvalid\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"host-short-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"host-short-name\",\n", + " \"value\" : \"FieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"class-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"class-name\",\n", + " \"value\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }\n", + " }\n", + "}, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", + " \"object-name\" : \"doublefieldprocessor-11\",\n", + " \"object-class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"map\" : {\n", + " \"fieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldName\",\n", + " \"value\" : \"sepalWidth\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"onlyFieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"onlyFieldName\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"throwOnInvalid\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"throwOnInvalid\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"host-short-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"host-short-name\",\n", + " \"value\" : \"FieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"class-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"class-name\",\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " \"value\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }\n", + " }\n", + "}, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", + " \"object-name\" : \"doublefieldprocessor-12\",\n", + " \"object-class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"map\" : {\n", + " \"fieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldName\",\n", + " \"value\" : \"sepalLength\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"onlyFieldName\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"onlyFieldName\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"throwOnInvalid\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"throwOnInvalid\",\n", + " \"value\" : \"true\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"host-short-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"host-short-name\",\n", + " \"value\" : \"FieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"class-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"class-name\",\n", + " \"value\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }\n", + " }\n", + "}, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", + " \"object-name\" : \"fieldresponseprocessor-13\",\n", + " \"object-class-name\" : \"org.tribuo.data.columnar.processors.response.FieldResponseProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl\",\n", + " \"map\" : {\n", + " \"uppercase\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"uppercase\",\n", + " \"value\" : \"false\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"fieldNames\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance\",\n", + " \"list\" : [ {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"fieldNames\",\n", + " \"value\" : \"species\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " } ]\n", + " },\n", + " \"defaultValues\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance\",\n", + " \"list\" : [ {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"defaultValues\",\n", + " \"value\" : \"\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " } ]\n", + " },\n", + " \"displayField\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"displayField\",\n", + " \"value\" : \"false\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"outputFactory\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"outputFactory\",\n", + " \"value\" : \"labelfactory-15\",\n", + " \"provenance-class\" : \"org.tribuo.classification.LabelFactory$LabelFactoryProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : true\n", + " },\n", + " \"host-short-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"host-short-name\",\n", + " \"value\" : \"ResponseProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " },\n", + " \"class-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"class-name\",\n", + " \"value\" : \"org.tribuo.data.columnar.processors.response.FieldResponseProcessor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }\n", + " }\n", + "}, {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance\",\n", + " \"object-name\" : \"fieldextractor-14\",\n", + " \"object-class-name\" : \"org.tribuo.data.columnar.FieldExtractor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.impl.NullConfiguredProvenance\",\n", + " \"map\" : {\n", + " \"class-name\" : {\n", + " \"marshalled-class\" : \"com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance\",\n", + " \"key\" : \"class-name\",\n", + " \"value\" : \"org.tribuo.data.columnar.FieldExtractor\",\n", + " \"provenance-class\" : \"com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance\",\n", + " \"additional\" : \"\",\n", + " \"is-reference\" : false\n", + " }\n", + " }\n", "} ]\n" ] } @@ -940,7 +1443,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "linear-sgd-model - Model(class-name=org.tribuo.classification.sgd.linear.LinearSGDModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=SplitDataSourceProvenance(className=org.tribuo.evaluation.TrainTestSplitter,innerSourceProvenance=CSV(class-name=org.tribuo.data.csv.CSVLoader,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),response-name=species,separator=,,quote=\",path=file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data,file-modified-time=1999-12-14T15:12:39-05:00,resource-hash=SHA-256[0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC]),trainProportion=0.7,seed=1,size=150,isTrain=true),transformations=[],is-sequence=false,is-dense=true,num-examples=105,num-features=4,num-outputs=3,tribuo-version=4.1.0),trainer=Trainer(class-name=org.tribuo.classification.sgd.linear.LogisticRegressionTrainer,seed=12345,minibatchSize=1,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=1.0,initialValue=0.0,host-short-name=StochasticGradientOptimiser),loggingInterval=1000,objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),tribuo-version=4.1.0,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2021-05-24T12:27:10.387150-04:00,instance-values={},tribuo-version=4.1.0,java-version=17-ea,os-name=Mac OS X,os-arch=x86_64)\n" + "linear-sgd-model - Model(class-name=org.tribuo.classification.sgd.linear.LinearSGDModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=SplitDataSourceProvenance(className=org.tribuo.evaluation.TrainTestSplitter,innerSourceProvenance=DataSource(class-name=org.tribuo.data.csv.CSVDataSource,headers=[sepalLength, sepalWidth, petalLength, petalWidth, species],rowProcessor=RowProcessor(class-name=org.tribuo.data.columnar.RowProcessor,metadataExtractors=[],fieldProcessorList=[FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=petalLength,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor), FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=petalWidth,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor), FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=sepalWidth,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor), FieldProcessor(class-name=org.tribuo.data.columnar.processors.field.DoubleFieldProcessor,fieldName=sepalLength,onlyFieldName=true,throwOnInvalid=true,host-short-name=FieldProcessor)],featureProcessors=[],responseProcessor=ResponseProcessor(class-name=org.tribuo.data.columnar.processors.response.FieldResponseProcessor,uppercase=false,fieldNames=[species],defaultValues=[],displayField=false,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),host-short-name=ResponseProcessor),weightExtractor=null,replaceNewlinesWithSpaces=true,regexMappingProcessors={},host-short-name=RowProcessor),quote=\",outputRequired=true,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),separator=,,dataPath=/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data,resource-hash=SHA-256[0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC],file-modified-time=1999-12-14T15:12:39-05:00,datasource-creation-time=2021-11-01T12:52:18.814629-04:00,host-short-name=DataSource),trainProportion=0.7,seed=1,size=150,isTrain=true),transformations=[],is-sequence=false,is-dense=true,num-examples=105,num-features=4,num-outputs=3,tribuo-version=4.2.0-SNAPSHOT),trainer=Trainer(class-name=org.tribuo.classification.sgd.linear.LogisticRegressionTrainer,seed=12345,minibatchSize=1,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=1.0,initialValue=0.0,host-short-name=StochasticGradientOptimiser),loggingInterval=1000,objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),tribuo-version=4.2.0-SNAPSHOT,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2021-11-01T12:52:19.228195-04:00,instance-values={},tribuo-version=4.2.0-SNAPSHOT,java-version=17,os-name=Mac OS X,os-arch=x86_64)\n" ] } ], @@ -965,27 +1468,77 @@ "output_type": "stream", "text": [ "{\n", - " \"tribuo-version\" : \"4.1.0\",\n", + " \"tribuo-version\" : \"4.2.0-SNAPSHOT\",\n", " \"dataset-provenance\" : {\n", " \"num-features\" : \"4\",\n", " \"num-examples\" : \"45\",\n", " \"num-outputs\" : \"3\",\n", - " \"tribuo-version\" : \"4.1.0\",\n", + " \"tribuo-version\" : \"4.2.0-SNAPSHOT\",\n", " \"datasource\" : {\n", " \"train-proportion\" : \"0.7\",\n", " \"seed\" : \"1\",\n", " \"size\" : \"150\",\n", " \"source\" : {\n", " \"resource-hash\" : \"0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC\",\n", - " \"path\" : \"file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\",\n", + " \"headers\" : [ \"sepalLength\", \"sepalWidth\", \"petalLength\", \"petalWidth\", \"species\" ],\n", + " \"rowProcessor\" : {\n", + " \"metadataExtractors\" : [ ],\n", + " \"fieldProcessorList\" : [ {\n", + " \"fieldName\" : \"petalLength\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " }, {\n", + " \"fieldName\" : \"petalWidth\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " }, {\n", + " \"fieldName\" : \"sepalWidth\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " }, {\n", + " \"fieldName\" : \"sepalLength\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " } ],\n", + " \"featureProcessors\" : [ ],\n", + " \"responseProcessor\" : {\n", + " \"uppercase\" : \"false\",\n", + " \"fieldNames\" : [ \"species\" ],\n", + " \"defaultValues\" : [ \"\" ],\n", + " \"displayField\" : \"false\",\n", + " \"outputFactory\" : {\n", + " \"class-name\" : \"org.tribuo.classification.LabelFactory\"\n", + " },\n", + " \"host-short-name\" : \"ResponseProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.response.FieldResponseProcessor\"\n", + " },\n", + " \"weightExtractor\" : {\n", + " \"class-name\" : \"org.tribuo.data.columnar.FieldExtractor\"\n", + " },\n", + " \"replaceNewlinesWithSpaces\" : \"true\",\n", + " \"regexMappingProcessors\" : { },\n", + " \"host-short-name\" : \"RowProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.RowProcessor\"\n", + " },\n", " \"file-modified-time\" : \"1999-12-14T15:12:39-05:00\",\n", " \"quote\" : \"\\\"\",\n", - " \"response-name\" : \"species\",\n", + " \"outputRequired\" : \"true\",\n", + " \"datasource-creation-time\" : \"2021-11-01T12:52:18.814629-04:00\",\n", " \"outputFactory\" : {\n", " \"class-name\" : \"org.tribuo.classification.LabelFactory\"\n", " },\n", " \"separator\" : \",\",\n", - " \"class-name\" : \"org.tribuo.data.csv.CSVLoader\"\n", + " \"host-short-name\" : \"DataSource\",\n", + " \"class-name\" : \"org.tribuo.data.csv.CSVDataSource\",\n", + " \"dataPath\" : \"/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\"\n", " },\n", " \"class-name\" : \"org.tribuo.evaluation.TrainTestSplitter\",\n", " \"is-train\" : \"false\"\n", @@ -998,11 +1551,11 @@ " \"class-name\" : \"org.tribuo.provenance.EvaluationProvenance\",\n", " \"model-provenance\" : {\n", " \"instance-values\" : { },\n", - " \"tribuo-version\" : \"4.1.0\",\n", - " \"java-version\" : \"17-ea\",\n", + " \"tribuo-version\" : \"4.2.0-SNAPSHOT\",\n", + " \"java-version\" : \"17\",\n", " \"trainer\" : {\n", " \"seed\" : \"12345\",\n", - " \"tribuo-version\" : \"4.1.0\",\n", + " \"tribuo-version\" : \"4.2.0-SNAPSHOT\",\n", " \"minibatchSize\" : \"1\",\n", " \"train-invocation-count\" : \"0\",\n", " \"is-sequence\" : \"false\",\n", @@ -1024,28 +1577,78 @@ " }\n", " },\n", " \"os-arch\" : \"x86_64\",\n", - " \"trained-at\" : \"2021-05-24T12:27:10.387150-04:00\",\n", + " \"trained-at\" : \"2021-11-01T12:52:19.228195-04:00\",\n", " \"os-name\" : \"Mac OS X\",\n", " \"dataset\" : {\n", " \"num-features\" : \"4\",\n", " \"num-examples\" : \"105\",\n", " \"num-outputs\" : \"3\",\n", - " \"tribuo-version\" : \"4.1.0\",\n", + " \"tribuo-version\" : \"4.2.0-SNAPSHOT\",\n", " \"datasource\" : {\n", " \"train-proportion\" : \"0.7\",\n", " \"seed\" : \"1\",\n", " \"size\" : \"150\",\n", " \"source\" : {\n", " \"resource-hash\" : \"0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC\",\n", - " \"path\" : \"file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\",\n", + " \"headers\" : [ \"sepalLength\", \"sepalWidth\", \"petalLength\", \"petalWidth\", \"species\" ],\n", + " \"rowProcessor\" : {\n", + " \"metadataExtractors\" : [ ],\n", + " \"fieldProcessorList\" : [ {\n", + " \"fieldName\" : \"petalLength\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " }, {\n", + " \"fieldName\" : \"petalWidth\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " }, {\n", + " \"fieldName\" : \"sepalWidth\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " }, {\n", + " \"fieldName\" : \"sepalLength\",\n", + " \"onlyFieldName\" : \"true\",\n", + " \"throwOnInvalid\" : \"true\",\n", + " \"host-short-name\" : \"FieldProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.field.DoubleFieldProcessor\"\n", + " } ],\n", + " \"featureProcessors\" : [ ],\n", + " \"responseProcessor\" : {\n", + " \"uppercase\" : \"false\",\n", + " \"fieldNames\" : [ \"species\" ],\n", + " \"defaultValues\" : [ \"\" ],\n", + " \"displayField\" : \"false\",\n", + " \"outputFactory\" : {\n", + " \"class-name\" : \"org.tribuo.classification.LabelFactory\"\n", + " },\n", + " \"host-short-name\" : \"ResponseProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.processors.response.FieldResponseProcessor\"\n", + " },\n", + " \"weightExtractor\" : {\n", + " \"class-name\" : \"org.tribuo.data.columnar.FieldExtractor\"\n", + " },\n", + " \"replaceNewlinesWithSpaces\" : \"true\",\n", + " \"regexMappingProcessors\" : { },\n", + " \"host-short-name\" : \"RowProcessor\",\n", + " \"class-name\" : \"org.tribuo.data.columnar.RowProcessor\"\n", + " },\n", " \"file-modified-time\" : \"1999-12-14T15:12:39-05:00\",\n", " \"quote\" : \"\\\"\",\n", - " \"response-name\" : \"species\",\n", + " \"outputRequired\" : \"true\",\n", + " \"datasource-creation-time\" : \"2021-11-01T12:52:18.814629-04:00\",\n", " \"outputFactory\" : {\n", " \"class-name\" : \"org.tribuo.classification.LabelFactory\"\n", " },\n", " \"separator\" : \",\",\n", - " \"class-name\" : \"org.tribuo.data.csv.CSVLoader\"\n", + " \"host-short-name\" : \"DataSource\",\n", + " \"class-name\" : \"org.tribuo.data.csv.CSVDataSource\",\n", + " \"dataPath\" : \"/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data\"\n", " },\n", " \"class-name\" : \"org.tribuo.evaluation.TrainTestSplitter\",\n", " \"is-train\" : \"true\"\n", @@ -1080,7 +1683,7 @@ "metadata": {}, "source": [ "## Loading and saving models\n", - "Tribuo uses Java Serialization to save and load models. Models and Datasets are `java.io.Serializable` and can be written to input and output streams in the usual manner. Here we'll go through saving and loading the model we just trained, but the procedure is the same for all other Tribuo models.\n", + "Tribuo uses Java Serialization to save and load models. Models and Datasets are `java.io.Serializable` and can be written to input and output streams in the usual manner. Here we'll go through saving and loading the model we just trained, but the procedure is the same for all other Tribuo models. We're going to save this out into the tutorials directory as this model file is used in the reproducibility tutorial.\n", "\n", "First we save the model out using an `ObjectOutputStream`." ] @@ -1091,7 +1694,7 @@ "metadata": {}, "outputs": [], "source": [ - "File tmpFile = File.createTempFile(\"irisModel\",\"ser\");\n", + "File tmpFile = new File(\"iris-lr-model.ser\");\n", "try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(tmpFile))) {\n", " oos.writeObject(irisModel);\n", "}" @@ -1112,7 +1715,7 @@ }, "outputs": [], "source": [ - "String filterPattern = Files.readAllLines(Paths.get(\"../docs/jep-290-allowlist.txt\")).get(0);\n", + "String filterPattern = Files.readAllLines(Paths.get(\"../docs/jep-290-filter.txt\")).get(0);\n", "ObjectInputFilter filter = ObjectInputFilter.Config.createFilter(filterPattern);\n", "Model loadedModel;\n", "try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(tmpFile)))) {\n", @@ -1125,7 +1728,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As Tribuo's models are generically typed, and Java's generics are erased, this requires an unchecked cast to apply the right type to the model. Tribuo has a mechanism for validating that the type is correct, `model.validate(Class>)` which returns true if the supplied class is the same as the internal output type stored in this model." + "As Tribuo's models are generically typed, and Java's generics are erased, this requires an unchecked cast to apply the right type to the model. Tribuo has a mechanism for validating that the type is correct, `model.validate(Class>)` which returns true if the supplied class is the same as the internal output type stored in this model. There's also `model.castModel(Class>)` which wraps up the validate check and either casts the model appropriately or throws `ClassCastException` if the type is invalid." ] }, { @@ -1199,7 +1802,7 @@ "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", - "version": "17-ea+22-1964" + "version": "17+35-LTS-2724" } }, "nbformat": 4, diff --git a/tutorials/onnx-export-tribuo-v4.ipynb b/tutorials/onnx-export-tribuo-v4.ipynb new file mode 100644 index 000000000..77e25c14b --- /dev/null +++ b/tutorials/onnx-export-tribuo-v4.ipynb @@ -0,0 +1,952 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model export and deployment tutorial\n", + "\n", + "Tribuo works best as a library which provides training and deployment inside the JVM where the application is running, however sometimes you need to deploy models elsewhere, either in another programming environment like Python, or in a cloud service. To support these use cases many of Tribuo's models can be exported as [ONNX](https://onnx.ai) models, a cross-platform model exchange format. ONNX is widely supported across industry, for edge devices, hardware accelerators, and cloud services. Tribuo also supports loading in ONNX models and scoring them as native Tribuo models, for more information on that see the external models tutorial.\n", + "\n", + "This tutorial will show how to export models in ONNX format, how to recover the provenance information from Tribuo-exported ONNX models, and how to deploy an ONNX model in [OCI Data Science](https://www.oracle.com/data-science/cloud-infrastructure-data-science.html) though of course other cloud providers support ONNX models too. We'll show how to export a factorization machine, create an ensemble of a factorization machine along with some other models, export the ensemble, then we'll discuss how to interact with the provenance of an exported model, before concluding with deploying that model to OCI.\n", + "\n", + "## Setup\n", + "\n", + "This tutorial requires ONNX Runtime to score the exported models, so by default will only run on x86\\_64 platforms. ONNX Runtime can be compiled on ARM64 platforms, but that binary is not in the Maven Central jar Tribuo depends on, so will need to be compiled from scratch to run the tutorial on ARM.\n", + "\n", + "We're going to use MNIST as the example dataset for this tutorial, so you'll need to download it if you haven't already.\n", + "\n", + "First the training set:\n", + "\n", + "`wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz`\n", + "\n", + "`wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz`\n", + "\n", + "Then the test set:\n", + "\n", + "`wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz`\n", + "\n", + "`wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz`\n", + "\n", + "As usual we'll load in some jars for classification problems, along with Tribuo's ONNX Runtime and OCI interfaces." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%jars ./tribuo-classification-experiments-4.2.0-SNAPSHOT-jar-with-dependencies.jar\n", + "%jars ./tribuo-oci-4.2.0-SNAPSHOT-jar-with-dependencies.jar\n", + "%jars ./tribuo-onnx-4.2.0-SNAPSHOT-jar-with-dependencies.jar\n", + "%jars ./tribuo-json-4.2.0-SNAPSHOT-jar-with-dependencies.jar" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import java.nio.file.Files;\n", + "import java.nio.file.Paths;\n", + "\n", + "import org.tribuo.*;\n", + "import org.tribuo.classification.*;\n", + "import org.tribuo.classification.ensemble.*;\n", + "import org.tribuo.classification.evaluation.*;\n", + "import org.tribuo.classification.sgd.fm.FMClassificationTrainer;\n", + "import org.tribuo.classification.sgd.linear.*;\n", + "import org.tribuo.classification.sgd.objectives.LogMulticlass;\n", + "import org.tribuo.ensemble.*;\n", + "import org.tribuo.data.csv.CSVLoader;\n", + "import org.tribuo.datasource.IDXDataSource;\n", + "import org.tribuo.evaluation.TrainTestSplitter;\n", + "import org.tribuo.interop.onnx.*;\n", + "import org.tribuo.math.optimisers.*;\n", + "import org.tribuo.interop.oci.*;\n", + "import org.tribuo.util.onnx.*;\n", + "import org.tribuo.util.Util;\n", + "import com.oracle.bmc.ConfigFileReader;\n", + "import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider;\n", + "import com.oracle.bmc.datascience.DataScienceClient;\n", + "import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;\n", + "import com.oracle.labs.mlrg.olcut.util.Pair;\n", + "\n", + "import ai.onnxruntime.*;" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we'll load in MNIST and Wine Quality." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MNIST train size = 60000, number of features = 717, number of classes = 10\n", + "MNIST test size = 10000, number of features = 668, number of classes = 10\n" + ] + } + ], + "source": [ + "var labelFactory = new LabelFactory();\n", + "var labelEvaluator = new LabelEvaluator();\n", + "var mnistTrainSource = new IDXDataSource<>(Paths.get(\"train-images-idx3-ubyte.gz\"),Paths.get(\"train-labels-idx1-ubyte.gz\"),labelFactory);\n", + "var mnistTestSource = new IDXDataSource<>(Paths.get(\"t10k-images-idx3-ubyte.gz\"),Paths.get(\"t10k-labels-idx1-ubyte.gz\"),labelFactory);\n", + "var mnistTrain = new MutableDataset<>(mnistTrainSource);\n", + "var mnistTest = new MutableDataset<>(mnistTestSource);\n", + "System.out.println(String.format(\"MNIST train size = %d, number of features = %d, number of classes = %d\",mnistTrain.size(),mnistTrain.getFeatureMap().size(),mnistTrain.getOutputInfo().size()));\n", + "System.out.println(String.format(\"MNIST test size = %d, number of features = %d, number of classes = %d\",mnistTest.size(),mnistTest.getFeatureMap().size(),mnistTest.getOutputInfo().size()));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exporting a single classification model\n", + "\n", + "We're going to train a multi-class [Factorization Machine](https://ieeexplore.ieee.org/document/5694074), which is a non-linear model that approximates all the non-linear feature interactions with a small per-feature embedding vector. It's similar to a logistic regression with an additional feature-feature interaction term, one per output label. In Tribuo Factorization Machines can be trained using stochastic gradient descent, using the standard SGD algorithms Tribuo uses for other models. We're going to use AdaGrad as it's usually a good baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "var fmLabelTrainer = new FMClassificationTrainer(new LogMulticlass(), // Loss function\n", + " new AdaGrad(0.1,0.1), // Gradient optimiser\n", + " 5, // Number of training epochs\n", + " 30000, // Logging interval\n", + " Trainer.DEFAULT_SEED, // RNG seed\n", + " 6, // Factor size\n", + " 0.1 // Factor initialisation variance\n", + " );" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After defining the model we train it as usual. Factorization machines take a little longer to train than logistic regression does, but not excessively so." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training factorization machine took (00:00:18:816)\n" + ] + } + ], + "source": [ + "var fmStartTime = System.currentTimeMillis();\n", + "var fmMNIST = fmLabelTrainer.train(mnistTrain);\n", + "var fmEndTime = System.currentTimeMillis();\n", + "System.out.println(\"Training factorization machine took \" + Util.formatDuration(fmStartTime,fmEndTime));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And then evaluate it using Tribuo's built in evaluation system." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scoring factorization machine took (00:00:00:475)\n", + "Class n tp fn fp recall prec f1\n", + "0 980 959 21 31 0.979 0.969 0.974\n", + "1 1,135 1,120 15 22 0.987 0.981 0.984\n", + "2 1,032 976 56 57 0.946 0.945 0.945\n", + "3 1,010 952 58 39 0.943 0.961 0.952\n", + "4 982 952 30 49 0.969 0.951 0.960\n", + "5 892 857 35 63 0.961 0.932 0.946\n", + "6 958 920 38 30 0.960 0.968 0.964\n", + "7 1,028 969 59 36 0.943 0.964 0.953\n", + "8 974 916 58 57 0.940 0.941 0.941\n", + "9 1,009 951 58 44 0.943 0.956 0.949\n", + "Total 10,000 9,572 428 428\n", + "Accuracy 0.957\n", + "Micro Average 0.957 0.957 0.957\n", + "Macro Average 0.957 0.957 0.957\n", + "Balanced Error Rate 0.043\n", + " 0 1 2 3 4 5 6 7 8 9\n", + "0 959 0 0 0 1 2 7 4 4 3\n", + "1 0 1,120 4 1 3 0 3 0 4 0\n", + "2 6 5 976 7 7 2 5 8 14 2\n", + "3 0 2 15 952 0 19 1 3 14 4\n", + "4 3 3 7 1 952 0 4 1 1 10\n", + "5 3 1 0 6 1 857 5 5 13 1\n", + "6 8 2 7 2 7 11 920 1 0 0\n", + "7 2 5 13 5 4 4 0 969 4 22\n", + "8 2 1 9 9 11 15 4 5 916 2\n", + "9 7 3 2 8 15 10 1 9 3 951\n", + "\n" + ] + } + ], + "source": [ + "fmStartTime = System.currentTimeMillis();\n", + "var mnistFMEval = labelEvaluator.evaluate(fmMNIST,mnistTest);\n", + "fmEndTime = System.currentTimeMillis();\n", + "System.out.println(\"Scoring factorization machine took \" + Util.formatDuration(fmStartTime,fmEndTime));\n", + "System.out.println(mnistFMEval.toString());\n", + "System.out.println(mnistFMEval.getConfusionMatrix().toString());" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We get about 95% accuracy on MNIST, which is pretty good for a fairly simple model. Now let's export it to ONNX, then we'll load it back in via Tribuo's ONNX Runtime interface and compare the performance. We'll use this model in the reproducibility tutorial so we'll save it to disk in the tutorials folder.\n", + "\n", + "Tribuo `Model`s which support ONNX export implement the `ONNXExportable` interface which defines methods for constructing an ONNX protobuf and saving it to disk." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "var fmMNISTPath = Paths.get(\".\",\"fm-mnist.onnx\");\n", + "fmMNIST.saveONNXModel(\"org.tribuo.tutorials.onnxexport.fm\", // namespace for the model\n", + " 0, // model version number\n", + " fmMNISTPath // path to save the model\n", + " );" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To load an ONNX model we need to define the mapping between Tribuo's feature names and the indices that the ONNX model understands. Fortunately for models exported from Tribuo we already have that information, as it is stored in the feature and output maps. We'll extract it into the general form that `ONNXExternalModel` expects." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "Map mnistFeatureMap = new HashMap<>();\n", + "for (VariableInfo f : fmMNIST.getFeatureIDMap()){\n", + " VariableIDInfo id = (VariableIDInfo) f;\n", + " mnistFeatureMap.put(id.getName(),id.getID());\n", + "}\n", + "Map mnistOutputMap = new HashMap<>();\n", + "for (Pair l : fmMNIST.getOutputIDInfo()) {\n", + " mnistOutputMap.put(l.getB(), l.getA());\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we'll define a test function that compares two sets of predictions, as ONNX Runtime uses single precision for computations, and Tribuo uses double precision so the prediction scores are never bitwise equal." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "public boolean checkPredictions(List> nativePredictions, List> onnxPredictions, double delta) {\n", + " for (int i = 0; i < nativePredictions.size(); i++) {\n", + " Prediction