Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(gen-ai): remove strict validation of query response for aggregation generation #5858

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
chore: use the same response validation in ai accuracy tests as we do…
… in the code
Anemy committed Jun 3, 2024
commit e546cc1dd6cbbcee8a16d525e37159d4a9935fb1
Original file line number Diff line number Diff line change
@@ -29,6 +29,10 @@ import util from 'util';
import { execFile as callbackExecFile } from 'child_process';
import decomment from 'decomment';

import {
validateAIQueryResponse,
validateAIAggregationResponse,
} from '../../src/atlas-ai-service';
import { loadFixturesToDB } from './fixtures';
import type { Fixtures } from './fixtures';
import { AtlasAPI } from './ai-backend';
@@ -229,6 +233,10 @@ const runOnce = async (
if (assertResult) {
let cursor;

type === 'query'
? validateAIQueryResponse(response)
: validateAIAggregationResponse(response);

if (
type === 'aggregation' ||
(type === 'query' &&
207 changes: 103 additions & 104 deletions packages/compass-generative-ai/src/atlas-ai-service.ts
Original file line number Diff line number Diff line change
@@ -94,6 +94,107 @@ function buildQueryOrAggregationMessageBody(
return msgBody;
}

function hasExtraneousKeys(obj: any, expectedKeys: string[]) {
return Object.keys(obj).some((key) => !expectedKeys.includes(key));
}

export function validateAIQueryResponse(
response: any
): asserts response is AIQuery {
const { content } = response ?? {};

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (hasExtraneousKeys(content, ['query', 'aggregation'])) {
throw new Error(
'Unexpected keys in response: expected query and aggregation'
);
}

const { query, aggregation } = content;

if (!query && !aggregation) {
throw new Error(
'Unexpected response: expected query or aggregation, got none'
);
}

if (query && typeof query !== 'object') {
throw new Error('Unexpected response: expected query to be an object');
}

if (
hasExtraneousKeys(query, [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
])
) {
throw new Error(
'Unexpected keys in response: expected filter, project, collation, sort, skip, limit, aggregation'
);
}

for (const field of [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
]) {
if (query[field] && typeof query[field] !== 'string') {
throw new Error(
`Unexpected response: expected field ${field} to be a string, got ${JSON.stringify(
query[field],
null,
2
)}`
);
}
}

if (aggregation && typeof aggregation.pipeline !== 'string') {
throw new Error(
`Unexpected response: expected aggregation pipeline to be a string, got ${JSON.stringify(
aggregation,
null,
2
)}`
);
}
}

export function validateAIAggregationResponse(
response: any
): asserts response is AIAggregation {
const { content } = response;

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (hasExtraneousKeys(content, ['aggregation'])) {
throw new Error('Unexpected keys in response: expected aggregation');
}

if (content.aggregation && typeof content.aggregation.pipeline !== 'string') {
// Compared to queries where we will always get the `query` field, for
// aggregations backend deletes the whole `aggregation` key if pipeline is
// empty, so we only validate `pipeline` key if `aggregation` key is present
throw new Error(
`Unexpected response: expected aggregation to be a string, got ${String(
content.aggregation.pipeline
)}`
);
}
}

export class AtlasAiService {
private initPromise: Promise<void> | null = null;

@@ -240,116 +341,18 @@ export class AtlasAiService {
return this.getQueryOrAggregationFromUserInput(
AGGREGATION_URI,
input,
this.validateAIAggregationResponse.bind(this)
validateAIAggregationResponse
);
}

async getQueryFromUserInput(input: GenerativeAiInput) {
return this.getQueryOrAggregationFromUserInput(
QUERY_URI,
input,
this.validateAIQueryResponse.bind(this)
validateAIQueryResponse
);
}

private validateAIQueryResponse(response: any): asserts response is AIQuery {
const { content } = response ?? {};

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (this.hasExtraneousKeys(content, ['query', 'aggregation'])) {
throw new Error(
'Unexpected keys in response: expected query and aggregation'
);
}

const { query, aggregation } = content;

if (!query && !aggregation) {
throw new Error(
'Unexpected response: expected query or aggregation, got none'
);
}

if (query && typeof query !== 'object') {
throw new Error('Unexpected response: expected query to be an object');
}

if (
this.hasExtraneousKeys(query, [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
])
) {
throw new Error(
'Unexpected keys in response: expected filter, project, collation, sort, skip, limit, aggregation'
);
}

for (const field of [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
]) {
if (query[field] && typeof query[field] !== 'string') {
throw new Error(
`Unexpected response: expected field ${field} to be a string, got ${JSON.stringify(
query[field],
null,
2
)}`
);
}
}

if (aggregation && typeof aggregation.pipeline !== 'string') {
throw new Error(
`Unexpected response: expected aggregation pipeline to be a string, got ${JSON.stringify(
aggregation,
null,
2
)}`
);
}
}

private validateAIAggregationResponse(
response: any
): asserts response is AIAggregation {
const { content } = response;

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (this.hasExtraneousKeys(content, ['aggregation'])) {
throw new Error('Unexpected keys in response: expected aggregation');
}

if (
content.aggregation &&
typeof content.aggregation.pipeline !== 'string'
) {
// Compared to queries where we will always get the `query` field, for
// aggregations backend deletes the whole `aggregation` key if pipeline is
// empty, so we only validate `pipeline` key if `aggregation` key is present
throw new Error(
`Unexpected response: expected aggregation to be a string, got ${String(
content.aggregation.pipeline
)}`
);
}
}

private validateAIFeatureEnablementResponse(
response: any
): asserts response is AIFeatureEnablement {
@@ -358,8 +361,4 @@ export class AtlasAiService {
throw new Error('Unexpected response: expected features to be an object');
}
}

private hasExtraneousKeys(obj: any, expectedKeys: string[]) {
return Object.keys(obj).some((key) => !expectedKeys.includes(key));
}
}