Skip to content

Commit

Permalink
[Automatic Import] Fix yarn draw-graphs command (#198229)
Browse files Browse the repository at this point in the history
## Summary

Fixes #196425.

It turns out the reason `yarn draw-graphs` produced the three-box graphs
was because `.withConfig` creates an instance of RunnableInput which does 
not have a good way to draw itself other than as three boxes.

The solution was to makes sure we are calling the original version
without `.withConfig` when drawing the graphs. We still call the new
version when invoking them, as demonstrated by the run names here.

We are now able to generate the correct graphs for all chains.

---------

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
ilyannn and elasticmachine authored Nov 1, 2024
1 parent 5544b1a commit 0ecef0a
Show file tree
Hide file tree
Showing 22 changed files with 59 additions and 32 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 17 additions & 12 deletions x-pack/plugins/integration_assistant/scripts/draw_graphs_script.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import { getEcsGraph, getEcsSubGraph } from '../server/graphs/ecs/graph';
import { getLogFormatDetectionGraph } from '../server/graphs/log_type_detection/graph';
import { getRelatedGraph } from '../server/graphs/related/graph';
import { getKVGraph } from '../server/graphs/kv/graph';
import { getUnstructuredGraph } from '../server/graphs/unstructured';
import { getCelGraph } from '../server/graphs/cel/graph';

// Some mock elements just to get the graph to compile
const model = new FakeLLM({
Expand All @@ -45,17 +47,20 @@ async function drawGraph(compiledGraph: RunnableGraph, graphName: string) {
await saveFile(`${graphName}.png`, buffer);
}

const GRAPH_LIST = {
related_graph: getRelatedGraph,
log_detection_graph: getLogFormatDetectionGraph,
categorization_graph: getCategorizationGraph,
kv_graph: getKVGraph,
ecs_graph: getEcsGraph,
ecs_subgraph: getEcsSubGraph,
unstructured_graph: getUnstructuredGraph,
cel_graph: getCelGraph,
};

export async function drawGraphs() {
const relatedGraph = (await getRelatedGraph({ client, model })).getGraph();
const logFormatDetectionGraph = (await getLogFormatDetectionGraph({ client, model })).getGraph();
const categorizationGraph = (await getCategorizationGraph({ client, model })).getGraph();
const ecsSubGraph = (await getEcsSubGraph({ model })).getGraph();
const ecsGraph = (await getEcsGraph({ model })).getGraph();
const kvGraph = (await getKVGraph({ client, model })).getGraph();
drawGraph(relatedGraph, 'related_graph');
drawGraph(logFormatDetectionGraph, 'log_detection_graph');
drawGraph(categorizationGraph, 'categorization_graph');
drawGraph(ecsSubGraph, 'ecs_subgraph');
drawGraph(ecsGraph, 'ecs_graph');
drawGraph(kvGraph, 'kv_graph');
for (const [name, graph] of Object.entries(GRAPH_LIST)) {
const compiledGraph = (await graph({ client, model })).getGraph();
drawGraph(compiledGraph, name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,6 @@ export async function getCategorizationGraph({ client, model }: CategorizationGr
}
);

const compiledCategorizationGraph = workflow.compile().withConfig({ runName: 'Categorization' });
const compiledCategorizationGraph = workflow.compile();
return compiledCategorizationGraph;
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,6 @@ export async function getCelGraph({ model }: CelInputGraphParams) {
.addEdge('handleGetStateVariables', 'handleGetStateDetails')
.addEdge('handleGetStateDetails', 'modelOutput');

const compiledCelGraph = workflow.compile().withConfig({ runName: 'CEL' });
const compiledCelGraph = workflow.compile();
return compiledCelGraph;
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export async function getEcsSubGraph({ model }: EcsGraphParams) {
})
.addEdge('modelSubOutput', END);

const compiledEcsSubGraph = workflow.compile().withConfig({ runName: 'ECS Mapping (Chunk)' });
const compiledEcsSubGraph = workflow.compile();
return compiledEcsSubGraph;
}

Expand All @@ -96,7 +96,7 @@ export async function getEcsGraph({ model }: EcsGraphParams) {
.addNode('handleMergedSubGraphResponse', (state: EcsMappingState) =>
modelMergedInputFromSubGraph({ state })
)
.addNode('subGraph', subGraph)
.addNode('subGraph', subGraph.withConfig({ runName: 'ECS Mapping (Chunk)' }))
.addEdge(START, 'modelInput')
.addEdge('subGraph', 'handleMergedSubGraphResponse')
.addEdge('handleDuplicates', 'handleValidation')
Expand All @@ -119,6 +119,6 @@ export async function getEcsGraph({ model }: EcsGraphParams) {
})
.addEdge('modelOutput', END);

const compiledEcsGraph = workflow.compile().withConfig({ runName: 'ECS Mapping' });
const compiledEcsGraph = workflow.compile();
return compiledEcsGraph;
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,6 @@ export async function getKVGraph({ model, client }: KVGraphParams) {
})
.addEdge('modelOutput', END);

const compiledKVGraph = workflow.compile().withConfig({ runName: 'Key-Value' });
const compiledKVGraph = workflow.compile();
return compiledKVGraph;
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,14 @@ export async function getLogFormatDetectionGraph({ model, client }: LogDetection
.addNode('handleLogFormatDetection', (state: LogFormatDetectionState) =>
handleLogFormatDetection({ state, model, client })
)
.addNode('handleKVGraph', await getKVGraph({ model, client }))
.addNode('handleUnstructuredGraph', await getUnstructuredGraph({ model, client }))
.addNode(
'handleKVGraph',
(await getKVGraph({ model, client })).withConfig({ runName: 'Key-Value' })
)
.addNode(
'handleUnstructuredGraph',
(await getUnstructuredGraph({ model, client })).withConfig({ runName: 'Unstructured' })
)
.addNode('handleCSV', (state: LogFormatDetectionState) => handleCSV({ state, model, client }))
.addEdge(START, 'modelInput')
.addEdge('modelInput', 'handleLogFormatDetection')
Expand All @@ -138,6 +144,6 @@ export async function getLogFormatDetectionGraph({ model, client }: LogDetection
}
);

const compiledLogFormatDetectionGraph = workflow.compile().withConfig({ runName: 'Log Format' });
const compiledLogFormatDetectionGraph = workflow.compile();
return compiledLogFormatDetectionGraph;
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,6 @@ export async function getRelatedGraph({ client, model }: RelatedGraphParams) {
}
);

const compiledRelatedGraph = workflow.compile().withConfig({ runName: 'Related' });
const compiledRelatedGraph = workflow.compile();
return compiledRelatedGraph;
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,6 @@ export async function getUnstructuredGraph({ model, client }: UnstructuredGraphP
.addEdge('handleUnstructuredError', 'handleUnstructuredValidate')
.addEdge('modelOutput', END);

const compiledUnstructuredGraph = workflow.compile().withConfig({ runName: 'Unstructured' });
const compiledUnstructuredGraph = workflow.compile();
return compiledUnstructuredGraph;
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ export function registerAnalyzeLogsRoutes(
logSamples,
};
const graph = await getLogFormatDetectionGraph({ model, client });
const graphResults = await graph.invoke(logFormatParameters, options);
const graphResults = await graph
.withConfig({ runName: 'Log Format' })
.invoke(logFormatParameters, options);
const graphLogFormat = graphResults.results.samplesFormat.name;
if (graphLogFormat === 'unsupported') {
throw new UnsupportedLogFormatError(GenerationErrorCode.UNSUPPORTED_LOG_SAMPLES_FORMAT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ const mockResult = jest.fn().mockResolvedValue({
jest.mock('../graphs/categorization', () => {
return {
getCategorizationGraph: jest.fn().mockResolvedValue({
invoke: () => mockResult(),
withConfig: () => ({
invoke: () => mockResult(),
}),
}),
};
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ export function registerCategorizationRoutes(
};

const graph = await getCategorizationGraph({ client, model });
const results = await graph.invoke(parameters, options);
const results = await graph
.withConfig({ runName: 'Categorization' })
.invoke(parameters, options);

return res.ok({ body: CategorizationResponse.parse(results) });
} catch (err) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ const mockResult = jest.fn().mockResolvedValue({
jest.mock('../graphs/cel', () => {
return {
getCelGraph: jest.fn().mockResolvedValue({
invoke: () => mockResult(),
withConfig: () => ({
invoke: () => mockResult(),
}),
}),
};
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export function registerCelInputRoutes(router: IRouter<IntegrationAssistantRoute
};

const graph = await getCelGraph({ model });
const results = await graph.invoke(parameters, options);
const results = await graph.withConfig({ runName: 'CEL' }).invoke(parameters, options);

return res.ok({ body: CelInputResponse.parse(results) });
} catch (e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ const mockResult = jest.fn().mockResolvedValue({
jest.mock('../graphs/ecs', () => {
return {
getEcsGraph: jest.fn().mockResolvedValue({
invoke: () => mockResult(),
withConfig: () => ({
invoke: () => mockResult(),
}),
}),
};
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ export function registerEcsRoutes(router: IRouter<IntegrationAssistantRouteHandl
};

const graph = await getEcsGraph({ model });
const results = await graph.invoke(parameters, options);
const results = await graph
.withConfig({ runName: 'ECS Mapping' })
.invoke(parameters, options);

return res.ok({ body: EcsMappingResponse.parse(results) });
} catch (err) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ const mockResult = jest.fn().mockResolvedValue({
jest.mock('../graphs/related', () => {
return {
getRelatedGraph: jest.fn().mockResolvedValue({
invoke: () => mockResult(),
withConfig: () => ({
invoke: () => mockResult(),
}),
}),
};
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ export function registerRelatedRoutes(router: IRouter<IntegrationAssistantRouteH
};

const graph = await getRelatedGraph({ client, model });
const results = await graph.invoke(parameters, options);
const results = await graph
.withConfig({ runName: 'Related' })
.invoke(parameters, options);
return res.ok({ body: RelatedResponse.parse(results) });
} catch (err) {
try {
Expand Down

0 comments on commit 0ecef0a

Please sign in to comment.