This tutorial walks through training a NeoML classification model on the well-known News20 data set.
We are going to use the linear classifier, which by default uses the "one versus all" method for multiclass tasks.
We assume that the data set is split into two parts, train and test, and that each part is serialized to a file on disk as a `CMemoryProblem` (a simple implementation of the `IProblem` interface provided in the library).
The library serialization methods can be used to load the data into memory for processing.
```cpp
// Create empty problems to hold the train and test sets
CPtr<CMemoryProblem> trainData = new CMemoryProblem();
CPtr<CMemoryProblem> testData = new CMemoryProblem();

// Load the training set from its file
CArchiveFile trainFile( "news20.train", CArchive::load );
CArchive trainArchive( &trainFile, CArchive::load );
trainArchive >> trainData;

// Load the test set in the same way
CArchiveFile testFile( "news20.test", CArchive::load );
CArchive testArchive( &testFile, CArchive::load );
testArchive >> testData;
```
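To make the shape of the loaded data concrete, here is a toy stand-in for an in-memory problem. It is a hypothetical sketch, not NeoML's `CMemoryProblem`: the `ToyProblem` name and its `Add` method are made up, and only `GetVectorCount`, `GetClass` and `GetVector` mirror names used in this tutorial. It assumes sparse features stored as (index, value) pairs, which suits bag-of-words text data such as News20.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for an in-memory classification problem.
// NOT NeoML's CMemoryProblem; it only mirrors the accessors used in this tutorial.
struct ToyProblem {
    using SparseVector = std::vector<std::pair<int, float>>; // (feature index, value) pairs

    std::vector<SparseVector> vectors; // one feature vector per object
    std::vector<int> classes;          // class label of each object

    // Appends one object together with its class label
    void Add(SparseVector vec, int cls) {
        vectors.push_back(std::move(vec));
        classes.push_back(cls);
    }

    int GetVectorCount() const { return static_cast<int>(vectors.size()); }

    int GetClass(int i) const {
        assert(i >= 0 && i < GetVectorCount());
        return classes[i];
    }

    const SparseVector& GetVector(int i) const {
        assert(i >= 0 && i < GetVectorCount());
        return vectors[i];
    }
};
```

For example, after `p.Add({{0, 1.0f}, {7, 2.0f}}, 0)` on an empty `ToyProblem p`, `p.GetVectorCount()` is 1 and `p.GetClass(0)` is 0.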
The "one versus all" method uses the specified classifier to train one model per class; each model estimates the probability that an object belongs to its class. An input object is then classified by the models' vote.
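The "one versus all" scheme can be sketched in plain C++. This is an illustration under assumptions, not NeoML's `COneVersusAll` code: `BinaryLabelsFor` shows how each per-class training set is relabeled (the class in question versus everything else), and `OneVersusAllPredict` assumes the vote is resolved by picking the class whose model reports the highest probability. Both function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Training side: the model for `positiveClass` sees a binary problem where
// objects of that class are labeled 1 and all other objects are labeled 0.
std::vector<int> BinaryLabelsFor(const std::vector<int>& classes, int positiveClass) {
    std::vector<int> labels;
    labels.reserve(classes.size());
    for (int cls : classes) {
        labels.push_back(cls == positiveClass ? 1 : 0);
    }
    return labels;
}

// Prediction side (assumed argmax rule): scores[c] is the probability that
// model c assigns to "the object belongs to class c". The most confident
// model wins; on ties the lowest class index is kept.
int OneVersusAllPredict(const std::vector<double>& scores) {
    assert(!scores.empty());
    std::size_t best = 0;
    for (std::size_t c = 1; c < scores.size(); ++c) {
        if (scores[c] > scores[best]) {
            best = c;
        }
    }
    return static_cast<int>(best);
}
```

For instance, with per-class probabilities `{0.1, 0.7, 0.2}` the sketch picks class 1, and `BinaryLabelsFor({0, 2, 1, 2}, 2)` yields `{0, 1, 0, 1}`.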
- Create a linear classifier using the `CLinear` class (by default, `COneVersusAll` is used for multiclass tasks). Select the logistic regression loss function (the `EF_LogReg` constant).
- Call the `Train` method, passing the `trainData` training set prepared above. The method trains the model and returns it as an object implementing the `IModel` interface.
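```cpp
// Linear classifier with the logistic regression loss
CLinear linear( EF_LogReg );
// Train on the training set; the result implements IModel
CPtr<IModel> model = linear.Train( *trainData );
```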
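Each per-class model trained with `EF_LogReg` is a logistic regression. The sketch below shows what that loss means for a single binary model, assuming dense features: the model outputs sigmoid(w · x + b), and training lowers the average logistic loss. It uses plain batch gradient descent for simplicity, which is an assumption and not necessarily how `CLinear` optimizes (regularization is omitted as well); `LogRegModel`, `TrainLogReg` and `LogLoss` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical binary logistic-regression model: p(y = 1 | x) = sigmoid(w . x + b)
struct LogRegModel {
    std::vector<double> w;
    double b = 0.0;

    double Probability(const std::vector<double>& x) const {
        assert(x.size() == w.size());
        double z = b;
        for (std::size_t j = 0; j < x.size(); ++j) {
            z += w[j] * x[j];
        }
        return 1.0 / (1.0 + std::exp(-z));
    }
};

// Average logistic loss: -[y log p + (1 - y) log(1 - p)] over all samples (labels 0 or 1)
double LogLoss(const LogRegModel& m, const std::vector<std::vector<double>>& xs,
               const std::vector<int>& ys) {
    assert(!xs.empty() && xs.size() == ys.size());
    double loss = 0.0;
    for (std::size_t i = 0; i < xs.size(); ++i) {
        double p = m.Probability(xs[i]);
        loss -= (ys[i] == 1) ? std::log(p) : std::log(1.0 - p);
    }
    return loss / xs.size();
}

// Plain batch gradient descent on the average logistic loss, starting from zero weights
LogRegModel TrainLogReg(const std::vector<std::vector<double>>& xs, const std::vector<int>& ys,
                        double learningRate, int epochs) {
    assert(!xs.empty() && xs.size() == ys.size());
    LogRegModel m;
    m.w.assign(xs[0].size(), 0.0);
    for (int epoch = 0; epoch < epochs; ++epoch) {
        std::vector<double> gradW(m.w.size(), 0.0);
        double gradB = 0.0;
        for (std::size_t i = 0; i < xs.size(); ++i) {
            double err = m.Probability(xs[i]) - ys[i]; // derivative of the loss w.r.t. z
            for (std::size_t j = 0; j < m.w.size(); ++j) {
                gradW[j] += err * xs[i][j];
            }
            gradB += err;
        }
        for (std::size_t j = 0; j < m.w.size(); ++j) {
            m.w[j] -= learningRate * gradW[j] / xs.size();
        }
        m.b -= learningRate * gradB / xs.size();
    }
    return m;
}
```

On a separable 1-D set such as x = -2, -1, 1, 2 with labels 0, 0, 1, 1, a few hundred epochs with learning rate 0.5 push the loss below its starting value of log 2 and place the decision boundary near x = 0. In the one-versus-all setting, one such model would be trained per class on labels relabeled as "this class versus the rest".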
To check how the trained model performs on the test sample, use the `Classify` method of the `IModel` interface. Call this method for each vector of the `testData` set prepared earlier and count how many predicted classes match the true ones.
```cpp
int correct = 0;
for( int i = 0; i < testData->GetVectorCount(); i++ ) {
    // Classify the i-th test vector
    CClassificationResult result;
    model->Classify( testData->GetVector( i ), result );
    // Count the prediction as correct if it matches the true class
    if( result.PreferredClass == testData->GetClass( i ) ) {
        correct++;
    }
}
// Share of correctly classified test vectors
double totalResult = static_cast<double>( correct ) / testData->GetVectorCount();
printf( "%.3f\n", totalResult );
```
On this test run, 83.3% of the test vectors were classified correctly:

```
0.833
```