Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(gen-ai): remove strict validation of query response for aggregation generation #5858

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ import util from 'util';
import { execFile as callbackExecFile } from 'child_process';
import decomment from 'decomment';

import {
validateAIQueryResponse,
validateAIAggregationResponse,
} from '../../src/atlas-ai-service';
import { loadFixturesToDB } from './fixtures';
import type { Fixtures } from './fixtures';
import { AtlasAPI } from './ai-backend';
Expand Down Expand Up @@ -229,6 +233,10 @@ const runOnce = async (
if (assertResult) {
let cursor;

type === 'query'
? validateAIQueryResponse(response)
: validateAIAggregationResponse(response);

if (
type === 'aggregation' ||
(type === 'query' &&
Expand Down
201 changes: 103 additions & 98 deletions packages/compass-generative-ai/src/atlas-ai-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,107 @@ function buildQueryOrAggregationMessageBody(
return msgBody;
}

function hasExtraneousKeys(obj: any, expectedKeys: string[]) {
return Object.keys(obj).some((key) => !expectedKeys.includes(key));
}

export function validateAIQueryResponse(
response: any
): asserts response is AIQuery {
const { content } = response ?? {};

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (hasExtraneousKeys(content, ['query', 'aggregation'])) {
throw new Error(
'Unexpected keys in response: expected query and aggregation'
);
}

const { query, aggregation } = content;

if (!query && !aggregation) {
throw new Error(
'Unexpected response: expected query or aggregation, got none'
);
}

if (query && typeof query !== 'object') {
throw new Error('Unexpected response: expected query to be an object');
}

if (
hasExtraneousKeys(query, [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
])
) {
throw new Error(
'Unexpected keys in response: expected filter, project, collation, sort, skip, limit, aggregation'
);
}

for (const field of [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
]) {
if (query[field] && typeof query[field] !== 'string') {
throw new Error(
`Unexpected response: expected field ${field} to be a string, got ${JSON.stringify(
query[field],
null,
2
)}`
);
}
}

if (aggregation && typeof aggregation.pipeline !== 'string') {
throw new Error(
`Unexpected response: expected aggregation pipeline to be a string, got ${JSON.stringify(
aggregation,
null,
2
)}`
);
}
}

export function validateAIAggregationResponse(
response: any
): asserts response is AIAggregation {
const { content } = response;

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (hasExtraneousKeys(content, ['aggregation'])) {
throw new Error('Unexpected keys in response: expected aggregation');
}

if (content.aggregation && typeof content.aggregation.pipeline !== 'string') {
// Compared to queries where we will always get the `query` field, for
// aggregations backend deletes the whole `aggregation` key if pipeline is
// empty, so we only validate `pipeline` key if `aggregation` key is present
throw new Error(
`Unexpected response: expected aggregation to be a string, got ${String(
content.aggregation.pipeline
)}`
);
}
}

export class AtlasAiService {
private initPromise: Promise<void> | null = null;

Expand Down Expand Up @@ -240,110 +341,18 @@ export class AtlasAiService {
return this.getQueryOrAggregationFromUserInput(
AGGREGATION_URI,
input,
this.validateAIAggregationResponse.bind(this)
validateAIAggregationResponse
);
}

async getQueryFromUserInput(input: GenerativeAiInput) {
return this.getQueryOrAggregationFromUserInput(
QUERY_URI,
input,
this.validateAIQueryResponse.bind(this)
validateAIQueryResponse
);
}

private validateAIQueryResponse(response: any): asserts response is AIQuery {
const { content } = response ?? {};

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (this.hasExtraneousKeys(content, ['query', 'aggregation'])) {
throw new Error(
'Unexpected keys in response: expected query and aggregation'
);
}

const { query, aggregation } = content;

if (typeof query !== 'object' || query === null) {
throw new Error('Unexpected response: expected query to be an object');
}

if (
this.hasExtraneousKeys(query, [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
])
) {
throw new Error(
'Unexpected keys in response: expected filter, project, collation, sort, skip, limit, aggregation'
);
}

for (const field of [
'filter',
'project',
'collation',
'sort',
'skip',
'limit',
]) {
if (query[field] && typeof query[field] !== 'string') {
throw new Error(
`Unexpected response: expected field ${field} to be a string, got ${JSON.stringify(
query[field],
null,
2
)}`
);
}
}

if (aggregation && typeof aggregation.pipeline !== 'string') {
throw new Error(
`Unexpected response: expected aggregation pipeline to be a string, got ${JSON.stringify(
aggregation,
null,
2
)}`
);
}
}

private validateAIAggregationResponse(
response: any
): asserts response is AIAggregation {
const { content } = response;

if (typeof content !== 'object' || content === null) {
throw new Error('Unexpected response: expected content to be an object');
}

if (this.hasExtraneousKeys(content, ['aggregation'])) {
throw new Error('Unexpected keys in response: expected aggregation');
}

if (
content.aggregation &&
typeof content.aggregation.pipeline !== 'string'
) {
// Compared to queries where we will always get the `query` field, for
// aggregations backend deletes the whole `aggregation` key if pipeline is
// empty, so we only validate `pipeline` key if `aggregation` key is present
throw new Error(
`Unexpected response: expected aggregation to be a string, got ${String(
content.aggregation.pipeline
)}`
);
}
}

private validateAIFeatureEnablementResponse(
response: any
): asserts response is AIFeatureEnablement {
Expand All @@ -352,8 +361,4 @@ export class AtlasAiService {
throw new Error('Unexpected response: expected features to be an object');
}
}

private hasExtraneousKeys(obj: any, expectedKeys: string[]) {
return Object.keys(obj).some((key) => !expectedKeys.includes(key));
}
}
Loading