Documentation updates for 4.2 #205

Merged
merged 23 commits on Dec 18, 2021
Changes from all commits
23 commits
0cac128
Updating docs for 4.2.
Craigacp Oct 26, 2021
031fa90
Migrating StripProvenance over to use the ProvenanceSerialization int…
Craigacp Oct 26, 2021
f359e24
Updating the configuration tutorial.
Craigacp Oct 26, 2021
2352f81
Initial readme updates for 4.2
Craigacp Oct 26, 2021
e5a854d
Adding the start of the onnx export tutorial.
Craigacp Oct 26, 2021
d229571
Adding first draft of release notes for 4.2.
Craigacp Oct 29, 2021
56567b0
Fixing the JEP 290 filter.
Craigacp Nov 1, 2021
2a75219
Changing Model.castModel so it's not static.
Craigacp Nov 1, 2021
f016538
Adding a reproducibility tutorial and updating the irises and onnx ex…
Craigacp Nov 1, 2021
5846f36
Updating gitignore file.
Craigacp Nov 1, 2021
34ad29b
Updating reproducibility tutorial with a bigger diff example.
Craigacp Nov 3, 2021
269a1ed
Updating the v4.2 release notes.
Craigacp Nov 12, 2021
41d8518
Adding TF-Java PR number.
Craigacp Nov 28, 2021
98790f8
Updating docs for the HDBSCAN implementation.
Craigacp Dec 2, 2021
a9f6150
Updating 4.2 release notes.
Craigacp Dec 9, 2021
0348805
Adding Tribuo v4.1.1 release notes.
Craigacp Dec 9, 2021
2860ab2
Finishing ONNX export tutorial.
Craigacp Dec 9, 2021
19080df
Updating ONNX export tutorial.
Craigacp Dec 14, 2021
cf9f80c
Docs updates after rebase.
Craigacp Dec 17, 2021
3bd5719
CastModel fix in OCIModelCLI.
Craigacp Dec 17, 2021
5926c3d
Javadoc updates after rebase.
Craigacp Dec 17, 2021
3970371
Fixing the circular PR reference in the 4.2 release notes.
Craigacp Dec 17, 2021
fdd0f1b
Fixing an accidental deletion in the external models tutorial.
Craigacp Dec 18, 2021
20 changes: 18 additions & 2 deletions .gitignore
@@ -16,16 +16,32 @@ bin/
.*.swp

# Other files
*.jar
*.class
*.er
*.log
*.bck
*.so
*.patch

# Binaries
*.jar
*.class

# Archives
*.gz
*.zip

# Serialised models
*.ser

# Temporary stuff
junk/*
.DS_Store
.ipynb_checkpoints

# Profiling files
*.jfr
*.iprof
*.jfc

# Tutorial files
tutorials/*.svm

@@ -131,6 +131,7 @@ protected LibSVMTrainer() {}
/**
* Constructs a LibSVMTrainer from the parameters.
* @param parameters The SVM parameters.
* @param seed The RNG seed.
*/
protected LibSVMTrainer(SVMParameters<T> parameters, long seed) {
this.parameters = parameters.getParameters();
15 changes: 8 additions & 7 deletions Core/src/main/java/org/tribuo/Model.java
@@ -300,20 +300,21 @@ public String toString() {

/**
* Casts the model to the specified output type, assuming it is valid.
* <p>
* If it's not valid, throws {@link ClassCastException}.
* @param inputModel The model to cast.
* <p>
* This method is intended for use on a deserialized model to restore its
* generic type in a safe way.
* @param outputType The output type to cast to.
* @param <T> The output type.
* @param <U> The output type.
* @return The model cast to the correct value.
*/
public static <T extends Output<T>> Model<T> castModel(Model<?> inputModel, Class<T> outputType) {
if (inputModel.validate(outputType)) {
public <U extends Output<U>> Model<U> castModel(Class<U> outputType) {
if (validate(outputType)) {
@SuppressWarnings("unchecked") // guarded by validate
Model<T> castedModel = (Model<T>) inputModel;
Model<U> castedModel = (Model<U>) this;
return castedModel;
} else {
throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + inputModel.toString());
throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + this.toString());
}
}

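For readers following the API change above, here is a minimal sketch of the new instance-method cast; the `model.ser` path and the `Label` output type are illustrative assumptions, not part of this PR.

```java
import java.io.ObjectInputStream;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.tribuo.Model;
import org.tribuo.classification.Label;

public final class CastModelSketch {
    public static void main(String[] args) throws Exception {
        // Deserialize a model whose generic type was erased by Java serialization.
        // The path "model.ser" is a placeholder for illustration.
        try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(Paths.get("model.ser")))) {
            Model<?> wildcardModel = (Model<?>) ois.readObject();
            // The new instance method validates the output domain before casting,
            // throwing ClassCastException if the model does not produce Labels.
            Model<Label> labelModel = wildcardModel.castModel(Label.class);
            System.out.println(labelModel.toString());
        }
    }
}
```

The OCIModelCLI change later in this diff applies the same pattern directly to the deserialized object.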
2 changes: 2 additions & 0 deletions Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java
@@ -49,6 +49,7 @@
* "The Elements of Statistical Learning"
* Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
* </pre>
* @param <T> The prediction type.
*/
public class BaggingTrainer<T extends Output<T>> implements Trainer<T> {

@@ -177,6 +178,7 @@ public EnsembleModel<T> train(Dataset<T> examples, Map<String, Provenance> runPr
* @param labelIDs The output domain.
* @param randInt A random int from an rng instance
* @param runProvenance Provenance for this instance.
* @param invocationCount The invocation count for the inner trainer.
* @return The trained ensemble member.
*/
protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, int randInt, Map<String,Provenance> runProvenance, int invocationCount) {
3 changes: 2 additions & 1 deletion Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java
@@ -79,9 +79,10 @@ default ONNXNode exportCombiner(ONNXNode input) {
* will be required to provide ONNX support.
* @param input the node to be ensembled according to this implementation.
* @param weight The node of weights for ensembling.
* @param <U> The type of the weights input reference.
* @return The leaf node of the graph of operations added to ensemble input.
*/
default <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
default <U extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, U weight) {
Logger.getLogger(this.getClass().getName()).severe("Tried to export an ensemble combiner to ONNX format, but this is not implemented.");
throw new IllegalStateException("This ensemble cannot be exported as the combiner '" + this.getClass() + "' uses the default implementation of EnsembleCombiner.exportCombiner.");
}
4 changes: 2 additions & 2 deletions Data/src/main/java/org/tribuo/data/sql/SQLDBConfig.java
@@ -72,7 +72,7 @@ private SQLDBConfig() {}
/**
* Constructs a SQL database configuration.
* <p>
* Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}.
* Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}.
* @param connectionString The connection string.
* @param username The username.
* @param password The password.
@@ -87,7 +87,7 @@ public SQLDBConfig(String connectionString, String username, String password, Ma
/**
* Constructs a SQL database configuration.
* <p>
* Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}.
* Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}.
* @param host The host to connect to.
* @param port The port to connect on.
* @param db The db name.

@@ -309,6 +309,7 @@ public static ConfigFileAuthenticationDetailsProvider makeAuthProvider(Path conf
* @param configFile The OCI configuration file, if null use the default file.
* @param endpointURL The endpoint URL.
* @param outputConverter The converter for the specified output type.
* @param <T> The output type.
* @return An OCIModel ready to score new inputs.
*/
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory,
@@ -332,6 +333,7 @@ public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T>
* @param profileName The profile name in the OCI configuration file, if null uses the default profile.
* @param endpointURL The endpoint URL.
* @param outputConverter The converter for the specified output type.
* @param <T> The output type.
* @return An OCIModel ready to score new inputs.
*/
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory,

@@ -64,7 +64,7 @@ private static void createModelAndDeploy(OCIModelOptions options) throws IOExcep
// Load the Tribuo model
Model<Label> model;
try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(options.modelPath))) {
model = Model.castModel((Model<?>) ois.readObject(),Label.class);
model = ((Model<?>)ois.readObject()).castModel(Label.class);
}
if (!(model instanceof ONNXExportable)) {
throw new IllegalArgumentException("Model not ONNXExportable, received " + model.toString());

@@ -373,6 +373,7 @@ public static <T extends Output<T>, U extends Model<T> & ONNXExportable> String
/**
* Creates the OCI DS model artifact zip file.
* @param onnxFile The ONNX file to create.
* @param config The model artifact configuration.
* @return The path referring to the zip file.
* @throws IOException If the file could not be created or the ONNX file could not be read.
*/
15 changes: 4 additions & 11 deletions Json/src/main/java/org/tribuo/json/StripProvenance.java
@@ -16,18 +16,15 @@

package org.tribuo.json;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceModule;
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.util.IOUtil;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
@@ -315,11 +312,8 @@ public static <T extends Output<T>> void main(String[] args) {
ModelProvenance oldProvenance = input.getProvenance();

logger.info("Marshalling provenance and creating JSON.");
List<ObjectMarshalledProvenance> list = ProvenanceUtil.marshalProvenance(oldProvenance);
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new JsonProvenanceModule());
mapper.enable(SerializationFeature.INDENT_OUTPUT);
String jsonResult = mapper.writeValueAsString(list);
JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);

logger.info("Hashing JSON file");
MessageDigest digest = o.hashType.getDigest();
@@ -340,8 +334,7 @@

ModelProvenance newProvenance = tuple.provenance;
logger.info("Marshalling provenance and creating JSON.");
List<ObjectMarshalledProvenance> newList = ProvenanceUtil.marshalProvenance(newProvenance);
String newJsonResult = mapper.writeValueAsString(newList);
String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);

logger.info("Old provenance = \n" + jsonResult);
logger.info("New provenance = \n" + newJsonResult);
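As a hedged illustration of the OLCUT helper that replaces the manual ObjectMapper setup above, the sketch below turns a model's provenance into JSON. The boolean constructor argument is assumed to request indented output, mirroring the `SerializationFeature.INDENT_OUTPUT` flag in the removed code; the class and method names other than those shown in the diff are illustrative.

```java
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization;

import org.tribuo.Model;
import org.tribuo.provenance.ModelProvenance;

public final class ProvenanceJsonSketch {
    /**
     * Serializes a model's provenance to a JSON string using the helper
     * this PR switches StripProvenance over to. The boolean flag is assumed
     * to enable pretty-printed (indented) output.
     */
    public static String provenanceToJson(Model<?> model) {
        ModelProvenance provenance = model.getProvenance();
        JsonProvenanceSerialization serialization = new JsonProvenanceSerialization(true);
        return serialization.marshalAndSerialize(provenance);
    }
}
```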
25 changes: 17 additions & 8 deletions README.md
@@ -36,7 +36,9 @@ architectures on Windows 10, macOS and Linux (RHEL/OL/CentOS 7+), as these are
supported platforms for the native libraries with which we interface. If you're
interested in another platform and wish to use one of the native library
interfaces (ONNX Runtime, TensorFlow, and XGBoost), we recommend reaching out
to the developers of those libraries.
to the developers of those libraries. Note the reproducibility package
requires Java 17, and as such is not part of the `tribuo-all` Maven Central
deployment.

## Documentation

@@ -85,6 +87,7 @@ Tribuo has implementations or interfaces for:
|Algorithm|Implementation|Notes|
|---|---|---|
|Linear models|Tribuo|Uses SGD and allows any gradient optimizer|
|Factorization Machines|Tribuo|Uses SGD and allows any gradient optimizer|
|CART|Tribuo||
|SVM-SGD|Tribuo|An implementation of the Pegasos algorithm|
|Adaboost.SAMME|Tribuo|Can use any Tribuo classification trainer as the base learner|
@@ -109,6 +112,7 @@
|Algorithm|Implementation|Notes|
|---|---|---|
|Linear models|Tribuo|Uses SGD and allows any gradient optimizer|
|Factorization Machines|Tribuo|Uses SGD and allows any gradient optimizer|
|CART|Tribuo||
|Lasso|Tribuo|Using the LARS algorithm|
|Elastic Net|Tribuo|Using the co-ordinate descent algorithm|
@@ -124,6 +128,7 @@

|Algorithm|Implementation|Notes|
|---|---|---|
|HDBSCAN\*|Tribuo||
|K-Means|Tribuo|Includes both sequential and parallel backends, and the K-Means++ initialisation algorithm|

### Anomaly Detection
@@ -146,7 +151,9 @@
|Algorithm|Implementation|Notes|
|---|---|---|
|Independent wrapper|Tribuo|Converts a multi-class classification algorithm into a multi-label one by producing a separate classifier for each label|
|Classifier Chains|Tribuo|Provides classifier chains and randomized classifier chain ensembles using any of Tribuo's multi-class classification algorithms|
|Linear models|Tribuo|Uses SGD and allows any gradient optimizer|
|Factorization Machines|Tribuo|Uses SGD and allows any gradient optimizer|

### Interfaces

@@ -158,10 +165,10 @@
Currently we have interfaces to:

* [LibLinear](https://github.com/bwaldvogel/liblinear-java) - via the LibLinear-java port of the original [LibLinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/) (v2.43).
* [LibSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) - using the pure Java transformed version of the C++ implementation (v3.24).
* [ONNX Runtime](https://onnxruntime.ai) - via the Java API contributed by our group (v1.7.0).
* [TensorFlow](https://tensorflow.org) - Using [TensorFlow Java](https://github.com/tensorflow/java) v0.3.1 (based on TensorFlow v2.4.1). This allows the training and deployment of TensorFlow models entirely in Java.
* [XGBoost](https://xgboost.ai) - via the built in XGBoost4J API (v1.4.1).
* [LibSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) - using the pure Java transformed version of the C++ implementation (v3.25).
* [ONNX Runtime](https://onnxruntime.ai) - via the Java API contributed by our group (v1.9.0).
* [TensorFlow](https://tensorflow.org) - Using [TensorFlow Java](https://github.com/tensorflow/java) v0.4.0 (based on TensorFlow v2.7.0). This allows the training and deployment of TensorFlow models entirely in Java.
* [XGBoost](https://xgboost.ai) - via the built in XGBoost4J API (v1.5.0).

## Binaries

@@ -187,7 +194,7 @@ implementation ("org.tribuo:tribuo-all:4.1.0@pom") {
```

The `tribuo-all` dependency is a pom which depends on all the Tribuo
subprojects.
subprojects except for the reproducibility project which requires Java 17.

Most of Tribuo is pure Java and thus cross-platform, however some of the
interfaces link to libraries which use native code. Those interfaces
@@ -197,11 +204,13 @@ are supplied. If you need support for a specific platform, reach out to the
maintainers of those projects. As of the 4.1 release these native packages
all provide x86\_64 binaries for Windows, macOS and Linux. It is also possible
to compile each package for macOS ARM64 (i.e., Apple Silicon), though there are
no binaries available on Maven Central for that platform.
no binaries available on Maven Central for that platform. When developing
on an ARM platform you can select the `arm` profile in Tribuo's pom.xml to
disable the native library tests.

Individual jars are published for each Tribuo module. It is preferable to
depend only on the modules necessary for the specific project. This prevents
your code from unnecessarily pulling in large dependencies like TensorFlow
your code from unnecessarily pulling in large dependencies like TensorFlow.

## Compiling from source

@@ -56,7 +56,7 @@ public ONNXContext() {
* ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. This is the root method for
* constructing ONNXNodes which all other methods on ONNXContext and {@code ONNXRef} call.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputs A list of names that the output nodes of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @param <T> The ONNXRef type of inputs
@@ -82,7 +82,7 @@ public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperators op,
* IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added
* to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputName Name that the output node of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @param <T> The ONNXRef type of inputs
@@ -102,7 +102,7 @@ public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> input
* IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added
* to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputName Name that the output node of {@code op} should take.
* @param <T> The ONNXRef type of inputs
* @return An {@link ONNXNode} that is the output nodes of {@code op}.
25 changes: 21 additions & 4 deletions Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java
@@ -36,27 +36,44 @@
* can thus be passed around without needing to pass their governing context as well.
* <p>
* N.B. This class will be sealed once the library is updated past Java 8. Users should not subclass this class.
* @param <T>
* @param <T> The protobuf type this reference generates.
*/
public abstract class ONNXRef<T extends GeneratedMessageV3> {
// Unfortunately there is no other shared supertype for OnnxML protobufs
protected final T backRef;
private final String baseName;
protected final ONNXContext context;


/**
* Creates an ONNXRef for the specified context, protobuf and name.
* @param context The ONNXContext we're operating in.
* @param backRef The protobuf reference.
* @param baseName The name of this reference.
*/
ONNXRef(ONNXContext context, T backRef, String baseName) {
this.context = context;
this.backRef = backRef;
this.baseName = baseName;
}

/**
* Gets the output name of this object.
* @return The output name.
*/
public abstract String getReference();

/**
* The name of this object.
* @return The name.
*/
public String getBaseName() {
return baseName;
}

/**
* The context this reference operates in.
* @return The context.
*/
public ONNXContext onnxContext() {
return context;
}
@@ -66,7 +83,7 @@ public ONNXContext onnxContext() {
* as the first argument to {@code inputs}, with {@code otherInputs} appended as subsequent arguments. The other
* arguments behave as in the analogous method on ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param otherInputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param otherInputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputs A list of names that the output nodes of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @return a list of {@link ONNXNode}s that are the output nodes of {@code op}.
@@ -199,7 +216,7 @@ public <Ret extends ONNXRef<?>> Ret assignTo(Ret output) {
/**
* Casts this ONNXRef to a different type using the {@link ONNXOperators#CAST} operation, and returning the output
* node of that op. Currently supports only float, double, int, and long, which are specified by their respective
* {@link Class} objects (eg. {@link float.class}). Throws {@link IllegalArgumentException} when an unsupported cast
* {@link Class} objects (e.g., {@code float.class}). Throws {@link IllegalArgumentException} when an unsupported cast
* is requested.
* @param clazz The class object specifying the type to cast to.
* @return An ONNXRef representing this object cast into the requested type.

@@ -15,7 +15,7 @@
*/

/**
* Interfaces and utilities for writing <a href="https://onnx.ai>ONNX</a> models from Java.
* Interfaces and utilities for writing <a href="https://onnx.ai">ONNX</a> models from Java.
* <p>
* Developed to support <a href="https://tribuo.org">Tribuo</a>, but can be used to export
* other machine learning models from JVM languages.