Skip to content

Commit

Permalink
feat: infer column type from runner
Browse files Browse the repository at this point in the history
  • Loading branch information
SmiteDeluxe committed May 27, 2024
1 parent d83c3d4 commit c984a4d
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions packages/safe-ds-vscode/src/extension/eda/apis/runnerApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ export class RunnerApi {
private sdsStringForCorrelationHeatmap(tablePlaceholder: string, newPlaceholderName: string) {
return 'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.plot.correlationHeatmap(); \n';
}

private sdsStringForIsNumeric(tablePlaceholder: string, columnName: string, newPlaceholderName: string) {
return (
'val ' + newPlaceholderName + ' = ' + tablePlaceholder + '.getColumn("' + columnName + '").isNumeric; \n'
);
}
//#endregion

//#region Placeholder handling
Expand All @@ -283,6 +289,9 @@ export class RunnerApi {
return;
}
this.services.runtime.PythonServer.removeMessageCallback('placeholder_value', placeholderValueCallback);
safeDsLogger.debug(
'Got placeholder value: ' + JSON.stringify(message.data.value).slice(0, 100) + '...',
);
resolve(message.data.value);
};

Expand All @@ -304,8 +313,28 @@ export class RunnerApi {
//#region Table fetching
public async getTableByPlaceholder(tableName: string, pipelineExecutionId: string): Promise<Table | undefined> {
safeDsLogger.debug('Getting table by placeholder: ' + tableName);

const pythonTableColumns = await this.getPlaceholderValue(tableName, pipelineExecutionId);
if (pythonTableColumns) {
// Get Column Types
safeDsLogger.debug('Getting column types for table: ' + tableName);
let sdsLines = '';
let placeholderNames: string[] = [];
let columnNameToPlaceholderIsNumericNameMap = new Map<string, string>();
for (const columnName of Object.keys(pythonTableColumns)) {
const newPlaceholderName = this.genPlaceholderName(columnName + '_type');
columnNameToPlaceholderIsNumericNameMap.set(columnName, newPlaceholderName);
placeholderNames.push(newPlaceholderName);
sdsLines += this.sdsStringForIsNumeric(tableName, columnName, newPlaceholderName);
}

await this.addToAndExecutePipeline(pipelineExecutionId, sdsLines, placeholderNames);
const columnIsNumeric = new Map<string, string>();
for (const [columnName, placeholderName] of columnNameToPlaceholderIsNumericNameMap) {
const columnType = await this.getPlaceholderValue(placeholderName, pipelineExecutionId);
columnIsNumeric.set(columnName, columnType as string);
}

const table: Table = {
totalRows: 0,
name: tableName,
Expand All @@ -322,8 +351,7 @@ export class RunnerApi {
currentMax = columnValues.length;
}

const isNumerical = typeof columnValues[0] === 'number';
const columnType = isNumerical ? 'numerical' : 'categorical';
const columnType = columnIsNumeric.get(columnName) ? 'numerical' : 'categorical';

const column: Column = {
name: columnName,
Expand Down

0 comments on commit c984a4d

Please sign in to comment.