Getting Started Guide
The VisRec API JSR #381 is a software development standard recognized by the Java Community Process (JCP) that simplifies and standardizes a set of APIs familiar to Java developers for classifying and recognizing objects in images using machine learning.
There are two types of Java developers that may be interested in VisRec JSR #381: application developers interested in creating apps that use the VisRec API, and library developers interested in implementing the VisRec API over a given AI/ML engine.
This document is specifically designed for application developers. The VisRec examples use well-known, publicly available “clean” datasets. In production ML applications, it is important to ensure that all data go through a “data cleaning” phase to ensure correctness, consistency, and freedom from bias.
- Java JDK 8+ - https://www.oracle.com/technetwork/java/javase/downloads/index.html
- Git - https://git-scm.com/downloads
- Maven - https://maven.apache.org/install.html
Note: No GPU is required.
git clone https://github.com/JavaVisRec/jsr381-examples.git
cd jsr381-examples
mvn clean install
mvn exec:java -Dexec.mainClass=jsr381.example.ImplementationExample
The easiest way to build the examples (or your own applications) is to include the necessary dependencies in your Maven POM or Gradle build script for automated building. The files that satisfy the external dependencies are typically found in popular Java repositories such as Maven Central. Alternatively, you can manually clone and install the dependencies before proceeding to the examples.
In your pom.xml file, add the following dependencies to the <dependencies /> section:
...
<dependency>
<groupId>javax.visrec</groupId>
<artifactId>visrec-api</artifactId>
<version>1.0.1</version>
</dependency>
<dependency>
<groupId>javax.visrec</groupId>
<artifactId>visrec-ri</artifactId>
<version>1.0.2</version>
</dependency>
...
And add the Sonatype Snapshot repository to the <repositories /> section:
...
<repository>
<id>snapshots</id>
<url>https://oss.sonatype.org/content/groups/public/</url>
<snapshots>
<enabled>true</enabled>
<updatePolicy>always</updatePolicy>
</snapshots>
<releases>
<enabled>false</enabled>
</releases>
</repository>
...
If you are using Gradle instead of Maven, here are the equivalent dependencies:
...
dependencies {
implementation 'javax.visrec:visrec-api:1.0-SNAPSHOT'
implementation 'javax.visrec:visrec-ri:1.0-SNAPSHOT'
}
...
And add the Sonatype Snapshot repository:
...
repositories {
maven {
url "https://oss.sonatype.org/content/groups/public/"
mavenContent {
snapshotsOnly()
}
}
}
...
If you are using Maven or Gradle, you can skip this section and go right to the Examples. If you need to manually install the mandatory dependencies for VisRec, follow the instructions in this article on GitHub. Essentially you need to clone four repos:
- VisRec API
- VisRec Reference Implementation (or another implementation)
- Deep Netts Community Edition (or another ML engine)
- Examples
- API Specification: https://github.com/JavaVisRec/visrec-api
- Reference implementation: https://github.com/JavaVisRec/visrec-ri/
- Examples: https://github.com/JavaVisRec/jsr381-examples
All the examples described here can also be found in the VisRec examples repository on GitHub: http://github.com/JavaVisRec/jsr381-examples. The first example, “Hello World”, confirms that you have properly installed all the necessary dependencies and components. The next two examples, “Simple Linear Regression” and “Logistic Regression”, cover fundamental ML concepts that Java developers need to understand. The general principles of linear regression and parameter tuning apply to all other ML algorithms. To understand and appreciate visual recognition applications (and ML in general), Java developers need to understand these basics.
Experienced ML developers may want to skip the first few examples and go directly to Example 5, which specifically covers image classification.
Running this example prints the API implementation that is in use and confirms that you have everything you need to run the following VisRec examples.
System.out.println("VisRec API (JSR 381) implementation: "
+ ServiceProvider.current().getImplementationService().toString());
Simple Linear Regression is a basic machine learning algorithm. It is based on a statistical method that can be used to estimate the relationship between two variables, assuming that the relationship can be roughly approximated by a straight line (a linear function). Once the model is built, you can use it to estimate the dependent variable for a given value of the input variable.
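For comparison, the statistical method behind simple linear regression, ordinary least squares, can compute the best-fitting straight line directly instead of iteratively. The plain-Java sketch below is not part of the VisRec API. The class name OlsLine and the sample points are made up for illustration. It returns the slope and intercept that minimize the sum of squared errors.

```java
import java.util.Arrays;

/** Sketch (not VisRec API): closed-form ordinary least squares fit of y = slope * x + intercept. */
final class OlsLine {

    /** Returns {slope, intercept} minimizing the sum of squared errors. */
    static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two (x, y) pairs of equal length");
        }
        double meanX = Arrays.stream(x).average().getAsDouble();
        double meanY = Arrays.stream(y).average().getAsDouble();
        double covXY = 0; // sum of (x - meanX) * (y - meanY)
        double varX = 0;  // sum of (x - meanX)^2
        for (int i = 0; i < x.length; i++) {
            covXY += (x[i] - meanX) * (y[i] - meanY);
            varX += (x[i] - meanX) * (x[i] - meanX);
        }
        if (varX == 0) {
            throw new IllegalArgumentException("all x values are equal; slope is undefined");
        }
        double slope = covXY / varX;
        double intercept = meanY - slope * meanX; // the fitted line passes through (meanX, meanY)
        return new double[] {slope, intercept};
    }

    public static void main(String[] args) {
        // Made-up points that lie exactly on y = 2x + 1
        double[] x = {1, 2, 3, 4};
        double[] y = {3, 5, 7, 9};
        double[] line = fit(x, y);
        System.out.println("y = " + line[0] + " * x + " + line[1]); // prints y = 2.0 * x + 1.0
    }
}
```

The iterative training used in this example should approach the same line. The closed form is shown only to make clear what "best fit" means.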
Simple Linear Regression introduces a basic iterative error minimization procedure for a given data set, which is the basis for more advanced algorithms in the VisRec API. For more theory, see Simple Linear Regression on Wikipedia.
Example 2 uses a Swedish Auto Insurance Dataset to predict the total payment for all auto insurance claims (in thousands of Swedish Krona), given the total number of claims.
After Example 2 successfully runs, try changing the learning rate or maximum error parameters to see the impact on the result.
STEP 1. LOAD DATA SET
Get a prepared instance of the data set from the library of example datasets included for testing purposes. The data set contains two values per row: the number of claims and the total payment for all the claims.
// Get data set from the library of data set examples
DataSet dataSet = DataSetExamples.getSwedishAutoInsuranceDataSet();
STEP 2. BUILD THE MODEL
Build a machine learning model using the builder from SimpleLinearRegressionNetwork. This creates a simple feed-forward neural network that performs simple linear regression and implements the SimpleLinearRegression interface.
// Build the model
SimpleLinearRegression linReg = SimpleLinearRegressionNetwork.builder()
.trainingSet(dataSet)
.learningRate(0.1f)
.maxError(0.01f)
.build();
Model building is a critical step in the overall ML workflow. It is an iterative process in which the estimation error is gradually lowered with each iteration. The maxError setting is the error level at which training stops, and the learningRate setting determines the step size of each change to the model's internal parameters.
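The roles of the two settings can be illustrated with a plain-Java sketch of batch gradient descent for a straight line. This is a generic illustration, not the training code used by SimpleLinearRegressionNetwork. It assumes the error is the mean squared error, and the class and method names are hypothetical. Each epoch moves slope and intercept a step proportional to learningRate against the error gradient, and training stops once the error drops to maxError. A maxEpochs cap guards against never reaching it.

```java
/**
 * Sketch (hypothetical, not Deep Netts / VisRec RI code): simple linear regression
 * trained by batch gradient descent on the mean squared error.
 */
final class GradientDescentLine {

    /** Returns {slope, intercept, epochsUsed}. */
    static double[] train(double[] x, double[] y, double learningRate, double maxError, int maxEpochs) {
        double slope = 0;
        double intercept = 0;
        int n = x.length;
        int epoch = 0;
        while (epoch < maxEpochs) {
            epoch++;
            double mse = 0;
            double gradSlope = 0;
            double gradIntercept = 0;
            for (int i = 0; i < n; i++) {
                double error = slope * x[i] + intercept - y[i]; // prediction minus target
                mse += error * error / n;
                gradSlope += 2 * error * x[i] / n;   // d(MSE)/d(slope)
                gradIntercept += 2 * error / n;      // d(MSE)/d(intercept)
            }
            if (mse <= maxError) {
                break; // error is small enough: stop training
            }
            // Step each parameter against its gradient, scaled by the learning rate
            slope -= learningRate * gradSlope;
            intercept -= learningRate * gradIntercept;
        }
        return new double[] {slope, intercept, epoch};
    }

    public static void main(String[] args) {
        // Made-up, already scaled data lying on y = 2x + 1
        double[] x = {0, 0.25, 0.5, 0.75, 1.0};
        double[] y = {1, 1.5, 2, 2.5, 3};
        double[] result = train(x, y, 0.1, 1e-6, 100_000);
        System.out.println("y = " + result[0] + " * x + " + result[1] + " after " + (int) result[2] + " epochs");
    }
}
```

A smaller learning rate needs more epochs to reach the same error. A learning rate that is too large can make the error grow instead of shrink, which is why it is worth experimenting with both settings.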
STEP 3. DISPLAY TRAINED MODEL DETAILS
Display information about the trained model:
float slope = linReg.getSlope();
float intercept = linReg.getIntercept();
System.out.println("Trained Model y = " + slope + " * x + " + intercept);
The trained linear regression model is a straight line described by the well-known formula y = slope * x + intercept. The parameters slope and intercept determine the position of the line: intercept is the point where the line crosses the y-axis, and slope determines the angle between the line and the x-axis, that is, how much y changes when x increases by one.
STEP 4. RUN MODEL ON NEW DATA AND PREDICT RESULT
Give the trained model some new input and let it predict the outcome:
float someInput = 0.10483871f; // scaled input: 13 claims / 124
Float prediction = linReg.predict(someInput);
// Multiply by the scaling factors to convert back to claims and thousands of Krona
System.out.println("Predicted output for " + (someInput * 124) + " is: " + (prediction * 422.2));
Note that the input and output are multiplied by constants (124 and 422.2) to convert them back to their original units, because all values in the data set were scaled when the data set was prepared for training.
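The scaling can be sketched in plain Java. This sketch assumes each column was scaled by dividing by its maximum value, and it takes 124 and 422.2 from the multipliers used in the code above. The actual preparation inside DataSetExamples may differ. The class MaxScaling is hypothetical.

```java
/** Sketch (hypothetical): scale values into [0, 1] by dividing by a column maximum, and undo it. */
final class MaxScaling {
    // Assumed column maxima, taken from the multipliers used in Example 2
    static final float MAX_CLAIMS = 124f;
    static final float MAX_PAYMENT = 422.2f;

    /** Scales a raw value into [0, 1], assuming 0 <= value <= max. */
    static float scale(float value, float max) {
        return value / max;
    }

    /** Converts a scaled value back to the original units. */
    static float unscale(float scaled, float max) {
        return scaled * max;
    }

    public static void main(String[] args) {
        float scaledInput = scale(13f, MAX_CLAIMS);       // about 0.10483871, the input used above
        float claims = unscale(scaledInput, MAX_CLAIMS);  // back to about 13 claims
        System.out.println(scaledInput + " -> " + claims);
    }
}
```

Keeping inputs in a small, similar range like [0, 1] is a common reason for scaling, since it tends to make gradient-based training more stable.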
Full source code of this example is available here
Logistic regression is a basic binary classification algorithm that learns to assign an input to one of two possible categories (classes), typically for YES/NO kinds of tasks. It also gives the probability that an input belongs to a specific category. An article by Suresh explains why linear regression can't be used for this type of problem, and how logistic regression solves it by using the so-called sigmoid (S-shaped) function. Logistic regression tweaks this function's parameters to fit the given data. For more theory, see Logistic Regression on Wikipedia or this post.
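In logistic regression, the inputs are combined into a weighted sum plus a bias. That sum is passed through the sigmoid function 1 / (1 + e^(-z)), which maps any real number into the range (0, 1) so the result can be read as the probability of the positive class. The plain-Java sketch below shows only this prediction step. The weights are made-up values, whereas in this example they are learned during training, and the class Sigmoid is hypothetical.

```java
/** Sketch (hypothetical): the prediction step of logistic regression. */
final class Sigmoid {

    /** The logistic (sigmoid) function: maps any real z into (0, 1). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Probability of the positive ("true") class for input x, given weights and bias. */
    static double probability(double[] weights, double bias, double[] x) {
        double z = bias;
        for (int i = 0; i < x.length; i++) {
            z += weights[i] * x[i]; // weighted sum of the inputs
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        double[] weights = {1.5, -2.0}; // made-up weights for two inputs
        double p = probability(weights, 0.5, new double[] {1.0, 0.25});
        // The negative class gets the remaining probability, 1 - p
        System.out.println("true=" + p + ", false=" + (1 - p));
    }
}
```

The two probabilities always sum to 1, which matches the shape of the result map printed in Step 4 below.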
This example shows how to train a simple neural network that performs logistic regression on 60 inputs to classify sonar signals in order to detect mines. The sonar data set contains patterns obtained by bouncing sonar signals off mines and rocks, together with the corresponding classification for each pattern. See more about the original dataset here.
STEP 1. Get the sonar data set from the example dataset library:
DataSet dataSet = DataSetExamples.getSonarDataSet();
STEP 2. Create the logistic regressor based on the training set and parameters:
LogisticRegression<FeedForwardNetwork> logReg = LogisticRegressionNetwork.builder()
.inputsNum(60)
.trainingSet(dataSet)
.learningRate(0.01f)
.maxError(0.03f)
.maxEpochs(1500)
.build();
STEP 3. Prepare some input for the regressor:
float[] someInput = new float[]{ 0.02f,0.0371f,0.0428f,0.0207f,0.0954f,0.0986f,0.1539f,0.1601f,0.3109f,0.2111f,0.1609f,0.1582f,0.2238f,
0.0645f,0.066f,0.2273f,0.31f,0.2999f,0.5078f,0.4797f,0.5783f,0.5071f,0.4328f,0.555f,0.6711f,0.6415f, 0.7104f,0.808f,0.6791f,0.3857f,0.1307f,0.2604f,0.5121f,0.7547f,0.8537f,0.8507f,0.6692f,0.6097f,0.4943f, 0.2744f,0.051f,0.2834f,0.2825f,0.4256f,0.2641f,0.1386f,0.1051f,0.1343f,0.0383f,0.0324f,0.0232f,0.0027f, 0.0065f,0.0159f,0.0072f,0.0167f,0.018f,0.0084f,0.009f,0.0032f};
STEP 4. Classify the input with the regressor and print the results:
Map<Boolean, Float> result = logReg.classify(someInput);
System.out.println(result);
The training output and results:
...
------------------------------------------------------------------------
TRAINING NEURAL NETWORK
------------------------------------------------------------------------
Epoch:1, Time:100ms, TrainError:0.08329139, TestError:0.0, TrainErrorChange:0.08329139, Accuracy: 0.5362319
Epoch:2, Time:0ms, TrainError:0.14647661, TestError:0.0, TrainErrorChange:0.14647661, Accuracy: 0.5362319
Epoch:3, Time:0ms, TrainError:0.14412558, TestError:0.0, TrainErrorChange:0.14412558, Accuracy: 0.5362319
Epoch:4, Time:0ms, TrainError:0.14199357, TestError:0.0, TrainErrorChange:0.14199357, Accuracy: 0.5362319
… many epochs suppressed ...
Epoch:496, Time:0ms, TrainError:0.13132669, TestError:0.0, TrainErrorChange:0.13132669, Accuracy: 0.5555556
Epoch:497, Time:0ms, TrainError:0.1313421, TestError:0.0, TrainErrorChange:0.1313421, Accuracy: 0.5555556
Epoch:498, Time:0ms, TrainError:0.13135755, TestError:0.0, TrainErrorChange:0.13135755, Accuracy: 0.5555556
Epoch:499, Time:0ms, TrainError:0.13137314, TestError:0.0, TrainErrorChange:0.13137314, Accuracy: 0.5555556
Epoch:500, Time:0ms, TrainError:0.13138853, TestError:0.0, TrainErrorChange:0.13138853, Accuracy: 0.5555556
TRAINING COMPLETED
Total Training Time: 714ms
------------------------------------------------------------------------
{false=2.2679567E-4, true=0.9997732}
...
The Iris Flower Classification is the “hello world” example of multi-class classification in machine learning. Multi-class classification is the task of assigning an input to one of several predefined categories. The task in Iris flower classification is to assign flowers to one of three categories based on 4 input values, which represent petal and sepal dimensions (width and length). For more info about the data set, see https://en.wikipedia.org/wiki/Iris_flower_data_set
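A common way for a neural network classifier to produce one probability per class is a softmax output. Softmax exponentiates each raw output score and normalizes the results so they sum to 1. Whether MultiClassClassifierNetwork uses exactly this is an assumption. The plain-Java sketch below only illustrates the idea, and its scores and label names are made up.

```java
import java.util.Arrays;

/** Sketch (hypothetical): softmax turns raw class scores into probabilities that sum to 1. */
final class Softmax {

    static double[] softmax(double[] scores) {
        // Subtracting the maximum score avoids overflow in Math.exp without changing the result
        double max = Arrays.stream(scores).max().getAsDouble();
        double[] probs = new double[scores.length];
        double sum = 0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum; // normalize so that all probabilities sum to 1
        }
        return probs;
    }

    public static void main(String[] args) {
        String[] labels = {"setosa", "versicolor", "virginica"}; // illustrative labels
        double[] probs = softmax(new double[] {2.0, 1.0, 0.1});  // made-up raw scores
        for (int i = 0; i < labels.length; i++) {
            System.out.println(labels[i] + ": " + probs[i]);
        }
    }
}
```

Softmax keeps the ranking of the raw scores, so the class with the highest score also gets the highest probability.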
STEP 1. GET DATA SET AND SPLIT INTO TRAINING AND TEST SET
// Load iris data set
DataSet dataSet = DataSetExamples.getIrisClassificationDataSet();
DataSet[] trainTest = DataSets.trainTestSplit(dataSet, 0.7);
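Conceptually, a train/test split shuffles the rows and sets part of them aside for evaluation. With 0.7, about 70% of the rows go into the first (training) part and the remaining 30% into the test part. The sketch below illustrates this on a plain Java List. It is not the implementation of DataSets.trainTestSplit, whose exact behavior (for example, whether and how it shuffles) is an assumption here, and the class TrainTestSplit is hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Sketch (hypothetical): shuffle rows, then split them into training and test parts. */
final class TrainTestSplit {

    /** Returns {train, test}, where train holds about trainFraction of the rows. */
    static <T> List<List<T>> split(List<T> rows, double trainFraction, long seed) {
        if (trainFraction <= 0 || trainFraction >= 1) {
            throw new IllegalArgumentException("trainFraction must be between 0 and 1");
        }
        List<T> shuffled = new ArrayList<>(rows);           // copy, so the input is not modified
        Collections.shuffle(shuffled, new Random(seed));     // fixed seed makes the split repeatable
        int trainSize = (int) Math.round(shuffled.size() * trainFraction);
        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, trainSize)));
        parts.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 150; i++) {
            rows.add(i); // stand-ins for the 150 samples of the Iris data set
        }
        List<List<Integer>> parts = split(rows, 0.7, 42L);
        System.out.println("train=" + parts.get(0).size() + ", test=" + parts.get(1).size());
    }
}
```

Shuffling first matters when the source rows are ordered by class, as is common for the Iris data. Otherwise the test part could contain only one species.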
STEP 2. BUILD THE MODEL
Create and train the multi-class classifier using the builder:
MultiClassClassifier<float[], String> irisClassifier = MultiClassClassifierNetwork.builder()
.inputsNum(4)
.hiddenLayers(16)
.outputsNum(3)
.maxEpochs(9000)
.maxError(0.03f)
.learningRate(0.01f)
.trainingSet(trainTest[0])
.build();
STEP 3. RUN MODEL ON NEW DATA AND PREDICT RESULT
Feed the classifier with data about the item to classify and get the probabilities for the possible classes:
Map<String, Float> results = irisClassifier.classify(new float[] {0.1f, 0.2f, 0.3f, 0.4f});
Then use Maven to run the example:
mvn exec:java -Dexec.mainClass=jsr381.example.IrisFlowersClassificationExample
Source code for this example is available at https://github.com/JavaVisRec/jsr381-examples/blob/master/src/main/java/jsr381/example/IrisFlowersClassificationExample.java
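The classify calls in these examples return a map from each class label to its probability. To obtain a single predicted label, a common approach is to pick the entry with the highest probability. The helper below is a hypothetical plain-Java sketch, not part of the VisRec API. It works for any such map, for example a Map<String, Float> as in the Iris example or the Map<Boolean, Float> of the logistic regression example. The labels in its usage are made up.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/** Sketch (hypothetical): pick the most likely class from a label-to-probability map. */
final class TopClass {

    /** Returns the key with the highest probability; throws NoSuchElementException if empty. */
    static <K> K mostLikely(Map<K, Float> probabilities) {
        return Collections.max(probabilities.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    public static void main(String[] args) {
        Map<String, Float> results = new HashMap<>(); // made-up classifier output
        results.put("setosa", 0.05f);
        results.put("versicolor", 0.80f);
        results.put("virginica", 0.15f);
        System.out.println("Predicted class: " + mostLikely(results)); // prints Predicted class: versicolor
    }
}
```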
This example uses a more complex dataset than the Iris Flower Classification, and the dataset is also downloaded from another repository at runtime. To run this example, please check out the examples repository as described in Quick Start.
STEP 1. DOWNLOAD THE DATA SET
Download and instantiate the data set. It is downloaded to the directory given by System.getProperty("java.io.tmpdir"), which is the system's temporary directory.
DataSetExamples.MnistDataSet dataSet = DataSetExamples.getMnistDataSet();
STEP 2. BUILD THE MODEL
Create and train the image classifier using the builder:
ImageClassifier<BufferedImage> classifier = NeuralNetImageClassifier.builder()
.inputClass(BufferedImage.class)
.imageHeight(28)
.imageWidth(28)
.labelsFile(dataSet.getLabelsFile())
.trainingFile(dataSet.getTrainingFile())
.networkArchitecture(dataSet.getNetworkArchitectureFile())
.modelFile(new File("mnist.dnet"))
.maxError(1.4f)
.maxEpochs(100)
.learningRate(0.01f)
.build();
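The imageWidth(28) and imageHeight(28) settings mean that the network works with 28 x 28 pixel inputs, the size of MNIST digit images. A plausible preprocessing step is to resize each image and turn its pixels into an array of numbers. The sketch below does this with standard Java 2D calls, converting to grayscale and scaling pixel values to [0, 1]. Whether NeuralNetImageClassifier preprocesses images this way internally is an assumption, and the class ImageToInput is hypothetical.

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;

/** Sketch (hypothetical): convert an image into a flat array of grayscale values in [0, 1]. */
final class ImageToInput {

    /** Resizes to width x height, converts to grayscale, and returns pixel values row by row. */
    static float[] toGrayscaleInput(BufferedImage source, int width, int height) {
        BufferedImage gray = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(source, 0, 0, width, height, null); // scales and converts to grayscale
        g.dispose();

        Raster raster = gray.getRaster();
        float[] input = new float[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int value = raster.getSample(x, y, 0); // gray level 0..255
                input[y * width + x] = value / 255f;
            }
        }
        return input;
    }

    public static void main(String[] args) {
        // A made-up 100 x 80 white image stands in for a real digit image
        BufferedImage image = new BufferedImage(100, 80, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = image.createGraphics();
        g.setColor(Color.WHITE);
        g.fillRect(0, 0, 100, 80);
        g.dispose();
        float[] input = toGrayscaleInput(image, 28, 28);
        System.out.println("values: " + input.length + ", first pixel: " + input[0]);
    }
}
```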
STEP 3. USE CLASSIFIER TO CLASSIFY IMAGE
// 'input' is the location of the test image; see the full example source for its definition
BufferedImage image = ImageIO.read(new File(input.getFile()));
Map<String, Float> results = classifier.classify(image);
Then use Maven to run the example:
mvn exec:java -Dexec.mainClass=jsr381.example.MnistDemo
Full source code for this example is available at https://github.com/JavaVisRec/jsr381-examples/blob/master/src/main/java/jsr381/example/MnistExample.java
When you build applications using VisRec JSR 381, you need 4 primary components:
- Application: This is the code of your application. Normally, this code is the responsibility of the application developer. For the purposes of this document, we use the VisRec examples described in this document.
- VisRec API: The VisRec API is the open-source software (OSS) at the heart of the VisRec JSR. It provides the interface functionality described in the VisRec design document.
- VisRec Implementation: An implementation provides a layer from the VisRec API to the underlying AI/ML engine. As with all JSRs, a pre-built open-source Reference Implementation (RI) is provided as a convenience for developers, and it is one of several possible implementations. With the help of the open-source community, we expect many implementations, e.g., VisRec/TensorFlow, VisRec/DeepNetts, VisRec/Watson, etc.
- AI/ML Engine: This software provides the core AI/ML functionality for a particular VisRec API implementation. This low-level engine can be either open source or proprietary, and the code in this layer is not the responsibility of this JSR. Many AI/ML engines are available, e.g., TensorFlow, CNTK, DeepNetts, Watson, DL4J, etc. Each implementation must be used with its specific associated AI/ML engine. The examples in this document use the DeepNetts Community Edition as the underlying AI/ML engine.
To implement a new machine learning model for the VisRec API, you need to:
- Provide an implementation of the DataSet interface for your data set class
- Implement the Classifier or Regressor interface
- Provide a static builder() method that returns the corresponding Builder object for building the machine learning model
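These steps can be illustrated with a deliberately trivial model. The sketch below uses hypothetical, self-contained stand-ins (SimpleRegressor, MeanRegressor, and a plain float[] instead of a DataSet) rather than the actual VisRec interfaces, whose method signatures differ. It only shows the overall shape: a class that implements a regressor interface and exposes a static builder() method, where build() trains the model and returns it.

```java
/**
 * Sketch of the builder pattern described above, using HYPOTHETICAL types.
 * The real VisRec interfaces (DataSet, Classifier, Regressor) have their own signatures.
 */
final class MeanRegressorExample {

    /** Hypothetical minimal regressor interface, standing in for the VisRec one. */
    interface SimpleRegressor {
        float predict(float input);
    }

    /** A trivial "model" that always predicts the mean of its training targets. */
    static final class MeanRegressor implements SimpleRegressor {
        private final float mean;

        private MeanRegressor(float mean) {
            this.mean = mean;
        }

        @Override
        public float predict(float input) {
            return mean; // ignores the input; a real model would use it
        }

        /** Static factory method returning the Builder for this model. */
        static Builder builder() {
            return new Builder();
        }

        static final class Builder {
            private float[] targets;

            Builder trainingTargets(float[] targets) {
                this.targets = targets.clone();
                return this; // return the builder to allow method chaining
            }

            /** "Trains" the model on the configured data and returns it. */
            MeanRegressor build() {
                if (targets == null || targets.length == 0) {
                    throw new IllegalStateException("no training data");
                }
                float sum = 0;
                for (float t : targets) {
                    sum += t;
                }
                return new MeanRegressor(sum / targets.length);
            }
        }
    }

    public static void main(String[] args) {
        SimpleRegressor model = MeanRegressor.builder()
                .trainingTargets(new float[] {1f, 2f, 3f})
                .build();
        System.out.println(model.predict(42f)); // prints 2.0
    }
}
```

The chained builder calls mirror how the examples above configure SimpleLinearRegressionNetwork, LogisticRegressionNetwork, and MultiClassClassifierNetwork before calling build().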