Skip to content

Commit

Permalink
[ML] API Integration tests: adds test for Data Frame Analytics evalua…
Browse files Browse the repository at this point in the history
…te endpoint (#97856)

* wip: add api test for evaluate endpoint

* Add api test for evaluate endpoint

* add tests for view only and unauthorized user

Co-authored-by: Kibana Machine <[email protected]>
  • Loading branch information
alvarezmelissa87 and kibanamachine authored Apr 23, 2021
1 parent ec8ff3a commit d6e0251
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 0 deletions.
188 changes: 188 additions & 0 deletions x-pack/test/api_integration/apis/ml/data_frame_analytics/evaluate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import expect from '@kbn/expect';
import { FtrProviderContext } from '../../../ftr_provider_context';
import { USER } from '../../../../functional/services/ml/security_common';
import { DataFrameAnalyticsConfig } from '../../../../../plugins/ml/public/application/data_frame_analytics/common';
import { DeepPartial } from '../../../../../plugins/ml/common/types/common';
import { COMMON_REQUEST_HEADERS } from '../../../../functional/services/ml/common_api';

export default ({ getService }: FtrProviderContext) => {
const esArchiver = getService('esArchiver');
const supertest = getService('supertestWithoutAuth');
const ml = getService('ml');

const currentTime = `${Date.now()}`;
const generateDestinationIndex = (analyticsId: string) => `user-${analyticsId}`;
const jobEval: any = {
regression: {
index: generateDestinationIndex(`regression_${currentTime}`),
evaluation: {
regression: {
actual_field: 'stab',
predicted_field: 'ml.stab_prediction',
metrics: {
r_squared: {},
mse: {},
msle: {},
huber: {},
},
},
},
},
classification: {
index: generateDestinationIndex(`classification_${currentTime}`),
evaluation: {
classification: {
actual_field: 'y',
predicted_field: 'ml.y_prediction',
metrics: { multiclass_confusion_matrix: {}, accuracy: {}, recall: {} },
},
},
},
};
const jobAnalysis: any = {
classification: {
source: {
index: ['ft_bank_marketing'],
query: {
match_all: {},
},
},
analysis: {
classification: {
dependent_variable: 'y',
training_percent: 20,
},
},
},
regression: {
source: {
index: ['ft_egs_regression'],
query: {
match_all: {},
},
},
analysis: {
regression: {
dependent_variable: 'stab',
training_percent: 20,
},
},
},
};

interface TestConfig {
jobType: string;
config: DeepPartial<DataFrameAnalyticsConfig>;
eval: any;
}

const testJobConfigs: TestConfig[] = ['regression', 'classification'].map((jobType, idx) => {
const analyticsId = `${jobType}_${currentTime}`;
return {
jobType,
config: {
id: analyticsId,
description: `Testing ${jobType} evaluation`,
dest: {
index: generateDestinationIndex(analyticsId),
results_field: 'ml',
},
analyzed_fields: {
includes: [],
excludes: [],
},
model_memory_limit: '60mb',
...jobAnalysis[jobType],
},
eval: jobEval[jobType],
};
});

async function createJobs(mockJobConfigs: TestConfig[]) {
for (const jobConfig of mockJobConfigs) {
await ml.api.createAndRunDFAJob(jobConfig.config as DataFrameAnalyticsConfig);
}
}

describe('POST data_frame/_evaluate', () => {
before(async () => {
await esArchiver.loadIfNeeded('ml/bm_classification');
await esArchiver.loadIfNeeded('ml/egs_regression');
await ml.testResources.setKibanaTimeZoneToUTC();
await createJobs(testJobConfigs);
});

after(async () => {
await ml.api.cleanMlIndices();
});

testJobConfigs.forEach((testConfig) => {
describe(`EvaluateDataFrameAnalytics ${testConfig.jobType}`, async () => {
it(`should evaluate ${testConfig.jobType} analytics job`, async () => {
const { body } = await supertest
.post(`/api/ml/data_frame/_evaluate`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(COMMON_REQUEST_HEADERS)
.send(testConfig.eval)
.expect(200);

if (testConfig.jobType === 'classification') {
const { classification } = body;
expect(body).to.have.property('classification');
expect(classification).to.have.property('recall');
expect(classification).to.have.property('accuracy');
expect(classification).to.have.property('multiclass_confusion_matrix');
} else {
const { regression } = body;
expect(body).to.have.property('regression');
expect(regression).to.have.property('mse');
expect(regression).to.have.property('msle');
expect(regression).to.have.property('r_squared');
}
});

it(`should evaluate ${testConfig.jobType} job for the user with only view permission`, async () => {
const { body } = await supertest
.post(`/api/ml/data_frame/_evaluate`)
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(COMMON_REQUEST_HEADERS)
.send(testConfig.eval)
.expect(200);

if (testConfig.jobType === 'classification') {
const { classification } = body;
expect(body).to.have.property('classification');
expect(classification).to.have.property('recall');
expect(classification).to.have.property('accuracy');
expect(classification).to.have.property('multiclass_confusion_matrix');
} else {
const { regression } = body;
expect(body).to.have.property('regression');
expect(regression).to.have.property('mse');
expect(regression).to.have.property('msle');
expect(regression).to.have.property('r_squared');
}
});

it(`should not allow unauthorized user to evaluate ${testConfig.jobType} job`, async () => {
const { body } = await supertest
.post(`/api/ml/data_frame/_evaluate`)
.auth(USER.ML_UNAUTHORIZED, ml.securityCommon.getPasswordForUser(USER.ML_UNAUTHORIZED))
.set(COMMON_REQUEST_HEADERS)
.send(testConfig.eval)
.expect(403);

expect(body.error).to.eql('Forbidden');
expect(body.message).to.eql('Forbidden');
});
});
});
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export default function ({ loadTestFile }: FtrProviderContext) {
loadTestFile(require.resolve('./get_spaces'));
loadTestFile(require.resolve('./update_spaces'));
loadTestFile(require.resolve('./delete_spaces'));
loadTestFile(require.resolve('./evaluate'));
loadTestFile(require.resolve('./explain'));
});
}

0 comments on commit d6e0251

Please sign in to comment.