From c984a4deab72e7036d6bdfc531e90ef1c15580d8 Mon Sep 17 00:00:00 2001 From: Jonas B Date: Mon, 27 May 2024 21:03:56 +0200 Subject: [PATCH] feat: infer column type from runner --- .../src/extension/eda/apis/runnerApi.ts | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts b/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts index e55e55ce3..c09eafe0b 100644 --- a/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts +++ b/packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts @@ -263,6 +263,12 @@ export class RunnerApi { private sdsStringForCorrelationHeatmap(tablePlaceholder: string, newPlaceholderName: string) { return 'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.plot.correlationHeatmap(); \n'; } + + private sdsStringForIsNumeric(tablePlaceholder: string, columnName: string, newPlaceholderName: string) { + return ( + 'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.getColumn("' + columnName + '").isNumeric; \n' + ); + } //#endregion //#region Placeholder handling @@ -283,6 +289,9 @@ export class RunnerApi { return; } this.services.runtime.PythonServer.removeMessageCallback('placeholder_value', placeholderValueCallback); + safeDsLogger.debug( + 'Got placeholder value: ' + JSON.stringify(message.data.value).slice(0, 100) + '...', + ); resolve(message.data.value); }; @@ -304,8 +313,28 @@ export class RunnerApi { //#region Table fetching public async getTableByPlaceholder(tableName: string, pipelineExecutionId: string): Promise { safeDsLogger.debug('Getting table by placeholder: ' + tableName); + const pythonTableColumns = await this.getPlaceholderValue(tableName, pipelineExecutionId); if (pythonTableColumns) { + // Get Column Types + safeDsLogger.debug('Getting column types for table: ' + tableName); + let sdsLines = ''; + let placeholderNames: string[] = []; + let columnNameToPlaceholderIsNumericNameMap = new Map(); + for (const columnName of Object.keys(pythonTableColumns)) { + const newPlaceholderName = this.genPlaceholderName(columnName + '_type'); + columnNameToPlaceholderIsNumericNameMap.set(columnName, newPlaceholderName); + placeholderNames.push(newPlaceholderName); + sdsLines += this.sdsStringForIsNumeric(tableName, columnName, newPlaceholderName); + } + + await this.addToAndExecutePipeline(pipelineExecutionId, sdsLines, placeholderNames); + const columnIsNumeric = new Map(); + for (const [columnName, placeholderName] of columnNameToPlaceholderIsNumericNameMap) { + const columnType = await this.getPlaceholderValue(placeholderName, pipelineExecutionId); + columnIsNumeric.set(columnName, columnType as string); + } + const table: Table = { totalRows: 0, name: tableName, @@ -322,8 +351,7 @@ export class RunnerApi { currentMax = columnValues.length; } - const isNumerical = typeof columnValues[0] === 'number'; - const columnType = isNumerical ? 'numerical' : 'categorical'; + const columnType = columnIsNumeric.get(columnName) ? 'numerical' : 'categorical'; const column: Column = { name: columnName,