Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: this expression #1111

Merged
merged 16 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
isSdsClass,
isSdsDeclaration,
isSdsEnumVariant,
isSdsExpression,
isSdsExpressionLambda,
isSdsExpressionStatement,
isSdsFunction,
Expand All @@ -48,6 +49,7 @@ import {
isSdsTemplateStringInner,
isSdsTemplateStringPart,
isSdsTemplateStringStart,
isSdsThis,
isSdsTypeCast,
isSdsWildcard,
isSdsYield,
Expand Down Expand Up @@ -115,6 +117,7 @@ import { CODEGEN_PREFIX } from './constants.js';

const LAMBDA_PREFIX = `${CODEGEN_PREFIX}lambda_`;
const BLOCK_LAMBDA_RESULT_PREFIX = `${CODEGEN_PREFIX}block_lambda_result_`;
const RECEIVER_PREFIX = `${CODEGEN_PREFIX}receiver_`;
const YIELD_PREFIX = `${CODEGEN_PREFIX}yield_`;

const RUNNER_PACKAGE = 'safeds_runner';
Expand Down Expand Up @@ -541,35 +544,21 @@ export class SafeDsPythonGenerator {

private generateStatement(statement: SdsStatement, frame: GenerationInfoFrame, generateLambda: boolean): Generated {
const result: Generated[] = [];

if (isSdsAssignment(statement)) {
if (statement.expression) {
for (const node of AstUtils.streamAllContents(statement.expression)) {
if (isSdsBlockLambda(node)) {
result.push(this.generateBlockLambda(node, frame));
} else if (isSdsExpressionLambda(node)) {
result.push(this.generateExpressionLambda(node, frame));
}
}
}
result.push(this.generateAssignment(statement, frame, generateLambda));
return joinTracedToNode(statement)(result, (stmt) => stmt, {
separator: NL,
})!;
const assignment = this.generateAssignment(statement, frame, generateLambda);
result.push(...frame.getExtraStatements(), assignment);
} else if (isSdsExpressionStatement(statement)) {
for (const node of AstUtils.streamAllContents(statement.expression)) {
if (isSdsBlockLambda(node)) {
result.push(this.generateBlockLambda(node, frame));
} else if (isSdsExpressionLambda(node)) {
result.push(this.generateExpressionLambda(node, frame));
}
}
result.push(this.generateExpression(statement.expression, frame));
return joinTracedToNode(statement)(result, (stmt) => stmt, {
separator: NL,
})!;
}
/* c8 ignore next 2 */
throw new Error(`Unknown SdsStatement: ${statement}`);
const expressionStatement = this.generateExpression(statement.expression, frame);
result.push(...frame.getExtraStatements(), expressionStatement);
} /* c8 ignore start */ else {
throw new Error(`Unknown statement: ${statement}`);
} /* c8 ignore stop */

frame.resetExtraStatements();
return joinTracedToNode(statement)(result, {
separator: NL,
});
}

private generateAssignment(
Expand Down Expand Up @@ -656,28 +645,39 @@ export class SafeDsPythonGenerator {
)}`,
);
}
return expandTracedToNode(blockLambda)`def ${frame.getUniqueLambdaName(

const extraStatement = expandTracedToNode(blockLambda)`def ${frame.getUniqueLambdaName(
blockLambda,
)}(${this.generateParameters(blockLambda.parameterList, frame)}):`
.appendNewLine()
.indent({
indentedChildren: [lambdaBlock],
indentation: PYTHON_INDENT,
});
frame.addExtraStatement(blockLambda, extraStatement);

return traceToNode(blockLambda)(frame.getUniqueLambdaName(blockLambda));
}

private generateExpressionLambda(node: SdsExpressionLambda, frame: GenerationInfoFrame): Generated {
const name = frame.getUniqueLambdaName(node);
const parameters = this.generateParameters(node.parameterList, frame);
const result = this.generateExpression(node.result, frame);

return expandTracedToNode(node)`
const extraStatement = expandTracedToNode(node)`
def ${name}(${parameters}):
return ${result}
`;
frame.addExtraStatement(node, extraStatement);

return traceToNode(node)(name);
}

private generateExpression(expression: SdsExpression, frame: GenerationInfoFrame): Generated {
private generateExpression(
expression: SdsExpression,
frame: GenerationInfoFrame,
thisParam?: Generated,
): Generated {
if (isSdsTemplateStringPart(expression)) {
if (isSdsTemplateStringStart(expression)) {
return expandTracedToNode(expression)`${this.formatStringSingleLine(expression.value)}{ `;
Expand Down Expand Up @@ -731,7 +731,7 @@ export class SafeDsPythonGenerator {
{ separator: ', ' },
)}]`;
} else if (isSdsBlockLambda(expression)) {
return traceToNode(expression)(frame.getUniqueLambdaName(expression));
return this.generateBlockLambda(expression, frame);
} else if (isSdsCall(expression)) {
const callable = this.nodeMapper.callToCallable(expression);
const receiver = this.generateExpression(expression.receiver, frame);
Expand All @@ -742,20 +742,20 @@ export class SafeDsPythonGenerator {
if (isSdsFunction(callable)) {
const pythonCall = this.builtinAnnotations.getPythonCall(callable);
if (pythonCall) {
let thisParam: Generated | undefined = undefined;
let newReceiver: SdsExpression | undefined = undefined;
if (isSdsMemberAccess(expression.receiver)) {
thisParam = this.generateExpression(expression.receiver.receiver, frame);
newReceiver = expression.receiver.receiver;
}
const argumentsMap = this.getArgumentsMap(getArguments(expression), frame);
call = this.generatePythonCall(expression, pythonCall, argumentsMap, frame, thisParam);
call = this.generatePythonCall(expression, pythonCall, argumentsMap, frame, newReceiver);
}
}
if (!call && this.isMemoizableCall(expression) && !frame.disableRunnerIntegration) {
let thisParam: Generated | undefined = undefined;
let newReceiver: SdsExpression | undefined = undefined;
if (isSdsMemberAccess(expression.receiver)) {
thisParam = this.generateExpression(expression.receiver.receiver, frame);
newReceiver = expression.receiver.receiver;
}
call = this.generateMemoizedCall(expression, frame, thisParam);
call = this.generateMemoizedCall(expression, frame, newReceiver);
}
}

Expand All @@ -773,7 +773,7 @@ export class SafeDsPythonGenerator {
return call;
}
} else if (isSdsExpressionLambda(expression)) {
return traceToNode(expression)(frame.getUniqueLambdaName(expression));
return this.generateExpressionLambda(expression, frame);
} else if (isSdsInfixOperation(expression)) {
const leftOperand = this.generateExpression(expression.leftOperand, frame);
const rightOperand = this.generateExpression(expression.rightOperand, frame);
Expand Down Expand Up @@ -871,6 +871,8 @@ export class SafeDsPythonGenerator {
const referenceImport = this.createImportDataForReference(expression);
frame.addImport(referenceImport);
return traceToNode(expression)(referenceImport?.alias ?? this.getPythonNameOrDefault(declaration));
} else if (isSdsThis(expression)) {
return thisParam;
} else if (isSdsTypeCast(expression)) {
return traceToNode(expression)(this.generateExpression(expression.expression, frame));
}
Expand All @@ -892,9 +894,17 @@ export class SafeDsPythonGenerator {
pythonCall: string,
argumentsMap: Map<string, Generated>,
frame: GenerationInfoFrame,
thisParam: Generated | undefined = undefined,
receiver: SdsExpression | undefined,
): Generated {
if (thisParam) {
let thisParam: Generated | undefined = undefined;

if (receiver) {
thisParam = frame.getUniqueReceiverName(receiver);
const extraStatement = expandTracedToNode(receiver)`
${thisParam} = ${this.generateExpression(receiver, frame)}
`;
frame.addExtraStatement(receiver, extraStatement);

argumentsMap.set('this', thisParam);
}
const splitRegex = /(\$[_a-zA-Z][_a-zA-Z0-9]*)/gu;
Expand All @@ -919,7 +929,7 @@ export class SafeDsPythonGenerator {
const fullyQualifiedTargetName = this.generateFullyQualifiedFunctionName(expression);
const hiddenParameters = this.getMemoizedCallHiddenParameters(expression, frame);

if (isSdsFunction(callable) && !isStatic(callable) && isSdsMemberAccess(expression.receiver)) {
if (isSdsFunction(callable) && !isStatic(callable) && isSdsMemberAccess(expression.receiver) && thisParam) {
return expandTracedToNode(expression)`
${MEMOIZED_STATIC_CALL}(
"${fullyQualifiedTargetName}",
Expand Down Expand Up @@ -980,12 +990,12 @@ export class SafeDsPythonGenerator {
private generateMemoizedCall(
expression: SdsCall,
frame: GenerationInfoFrame,
thisParam: Generated | undefined = undefined,
receiver: SdsExpression | undefined,
): Generated {
const callable = this.nodeMapper.callToCallable(expression);

if (isSdsFunction(callable) && !isStatic(callable) && isSdsMemberAccess(expression.receiver)) {
return this.generateMemoizedDynamicCall(expression, callable, thisParam, frame);
if (isSdsFunction(callable) && !isStatic(callable) && isSdsExpression(receiver)) {
return this.generateMemoizedDynamicCall(expression, callable, receiver, frame);
} else {
return this.generateMemoizedStaticCall(expression, callable, frame);
}
Expand All @@ -994,19 +1004,24 @@ export class SafeDsPythonGenerator {
private generateMemoizedDynamicCall(
expression: SdsCall,
callable: SdsFunction,
thisParam: Generated,
receiver: SdsExpression,
frame: GenerationInfoFrame,
) {
frame.addImport({ importPath: RUNNER_PACKAGE });

const hiddenParameters = this.getMemoizedCallHiddenParameters(expression, frame);
const thisParam = frame.getUniqueReceiverName(receiver);
const extraStatement = expandTracedToNode(receiver)`
${thisParam} = ${this.generateExpression(receiver, frame)}
`;
frame.addExtraStatement(receiver, extraStatement);

return expandTracedToNode(expression)`
${MEMOIZED_DYNAMIC_CALL}(
${thisParam},
"${this.getPythonNameOrDefault(callable)}",
[${this.generateMemoizedPositionalArgumentList(expression, frame)}],
{${this.generateMemoizedKeywordArgumentList(expression, frame)}},
{${this.generateMemoizedKeywordArgumentList(expression, frame, thisParam)}},
[${joinToNode(hiddenParameters, (param) => param, { separator: ', ' })}]
)
`;
Expand Down Expand Up @@ -1060,7 +1075,11 @@ export class SafeDsPythonGenerator {
);
}

private generateMemoizedKeywordArgumentList(node: SdsCall, frame: GenerationInfoFrame): Generated {
private generateMemoizedKeywordArgumentList(
node: SdsCall,
frame: GenerationInfoFrame,
thisParam?: Generated,
): Generated {
const callable = this.nodeMapper.callToCallable(node);
const parameters = getParameters(callable);
const optionalParameters = getParameters(callable).filter(Parameter.isOptional);
Expand All @@ -1070,7 +1089,7 @@ export class SafeDsPythonGenerator {
optionalParameters,
(parameter) => {
const argument = parametersToArgument.get(parameter);
return expandToNode`"${this.getPythonNameOrDefault(parameter)}": ${this.generateMemoizedArgument(argument, parameter, frame)}`;
return expandToNode`"${this.getPythonNameOrDefault(parameter)}": ${this.generateMemoizedArgument(argument, parameter, frame, thisParam)}`;
},
{
separator: ', ',
Expand All @@ -1082,14 +1101,15 @@ export class SafeDsPythonGenerator {
argument: SdsArgument | undefined,
parameter: SdsParameter,
frame: GenerationInfoFrame,
thisParam?: Generated | undefined,
): Generated {
const value = argument?.value ?? parameter?.defaultValue;
if (!value) {
/* c8 ignore next 2 */
throw new Error(`No value passed for required parameter "${parameter.name}".`);
}

const result = this.generateExpression(value, frame);
const result = this.generateExpression(value, frame, thisParam);
if (!this.isMemoizedPath(parameter)) {
return result;
}
Expand Down Expand Up @@ -1252,13 +1272,14 @@ interface ImportData {
}

class GenerationInfoFrame {
private readonly lambdaManager: IdManager<SdsLambda>;
private readonly idManager: IdManager<SdsExpression>;
private readonly importSet: Map<String, ImportData>;
private readonly utilitySet: Set<UtilityFunction>;
private readonly typeVariableSet: Set<string>;
public readonly isInsidePipeline: boolean;
public readonly targetPlaceholder: string | undefined;
public readonly disableRunnerIntegration: boolean;
private extraStatements = new Map<SdsExpression, Generated>();

constructor(
importSet: Map<String, ImportData> = new Map<String, ImportData>(),
Expand All @@ -1268,7 +1289,7 @@ class GenerationInfoFrame {
targetPlaceholder: string | undefined = undefined,
disableRunnerIntegration: boolean = false,
) {
this.lambdaManager = new IdManager();
this.idManager = new IdManager();
this.importSet = importSet;
this.utilitySet = utilitySet;
this.typeVariableSet = typeVariableSet;
Expand Down Expand Up @@ -1304,8 +1325,26 @@ class GenerationInfoFrame {
}
}

addExtraStatement(node: SdsExpression, statement: Generated): void {
if (!this.extraStatements.has(node)) {
this.extraStatements.set(node, statement);
}
}

resetExtraStatements(): void {
this.extraStatements.clear();
}

getExtraStatements(): Generated[] {
return Array.from(this.extraStatements.values());
}

getUniqueLambdaName(lambda: SdsLambda): string {
return `${LAMBDA_PREFIX}${this.lambdaManager.assignId(lambda)}`;
return `${LAMBDA_PREFIX}${this.idManager.assignId(lambda)}`;
}

getUniqueReceiverName(receiver: SdsExpression): string {
return `${RECEIVER_PREFIX}${this.idManager.assignId(receiver)}`;
}
}

Expand Down
7 changes: 7 additions & 0 deletions packages/safe-ds-lang/src/language/grammar/safe-ds.langium
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,7 @@ SdsPrimaryExpression returns SdsExpression:
| SdsParenthesizedExpression
| SdsReference
| SdsTemplateString
| SdsThis
;

interface SdsLiteral extends SdsExpression {}
Expand Down Expand Up @@ -896,6 +897,12 @@ SdsTemplateStringEnd returns SdsExpression:
value=TEMPLATE_STRING_END
;

interface SdsThis extends SdsExpression {}

SdsThis returns SdsThis:
{SdsThis} 'this'
;


// -----------------------------------------------------------------------------
// Types
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ export class SafeDsCompletionProvider extends DefaultCompletionProvider {
return true;
}

private illegalKeywordsInPipelineFile = new Set(['annotation', 'class', 'enum', 'fun']);
private illegalKeywordsInStubFile = new Set(['pipeline', 'segment']);
private illegalKeywordsInModuleContextOfPipelineFile = new Set(['annotation', 'class', 'enum', 'fun', 'this']);
private illegalKeywordsInModuleContextOfStubFile = new Set(['pipeline', 'segment']);

protected override filterKeyword(context: CompletionContext, keyword: Keyword): boolean {
// Filter out keywords that do not contain any word character
Expand All @@ -84,12 +84,15 @@ export class SafeDsCompletionProvider extends DefaultCompletionProvider {

if ((!context.node || isSdsModule(context.node)) && !getPackageName(context.node)) {
return keyword.value === 'package';
} else if (isSdsModule(context.node) && isInPipelineFile(context.node)) {
return !this.illegalKeywordsInPipelineFile.has(keyword.value);
} else if (isInPipelineFile(context.node)) {
if (isSdsModule(context.node)) {
return !this.illegalKeywordsInModuleContextOfPipelineFile.has(keyword.value);
} else {
return keyword.value !== 'this';
}
} else if (isSdsModule(context.node) && isInStubFile(context.node)) {
return !this.illegalKeywordsInStubFile.has(keyword.value);
return !this.illegalKeywordsInModuleContextOfStubFile.has(keyword.value);
}

return true;
}

Expand Down
Loading