Skip to content

Commit

Permalink
Model the return slot as an output parameter (#4432)
Browse files Browse the repository at this point in the history
Also fix `Param` insts to have meaningful names in pretty-printing, to
help clarify relationship with return slot.
  • Loading branch information
geoffromer authored Oct 23, 2024
1 parent 5038218 commit 9266f86
Show file tree
Hide file tree
Showing 463 changed files with 7,401 additions and 4,939 deletions.
8 changes: 4 additions & 4 deletions toolchain/check/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,13 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
}

// If there is a return slot, build storage for the result.
SemIR::InstId return_storage_id = SemIR::InstId::Invalid;
SemIR::InstId return_slot_arg_id = SemIR::InstId::Invalid;
SemIR::ReturnTypeInfo return_info = [&] {
DiagnosticAnnotationScope annotate_diagnostics(
&context.emitter(), [&](auto& builder) {
CARBON_DIAGNOSTIC(IncompleteReturnTypeHere, Note,
"return type declared here");
builder.Note(callable.return_storage_id, IncompleteReturnTypeHere);
builder.Note(callable.return_slot_id, IncompleteReturnTypeHere);
});
return CheckFunctionReturnType(context, callee_id, callable,
*callee_specific_id);
Expand All @@ -190,7 +190,7 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
case SemIR::InitRepr::InPlace:
// Tentatively put storage for a temporary in the function's return slot.
// This will be replaced if necessary when we perform initialization.
return_storage_id = context.AddInst<SemIR::TemporaryStorage>(
return_slot_arg_id = context.AddInst<SemIR::TemporaryStorage>(
loc_id, {.type_id = return_info.type_id});
break;
case SemIR::InitRepr::None:
Expand All @@ -211,7 +211,7 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,

// Convert the arguments to match the parameters.
auto converted_args_id = ConvertCallArgs(
context, loc_id, callee_function.self_id, arg_ids, return_storage_id,
context, loc_id, callee_function.self_id, arg_ids, return_slot_arg_id,
CalleeParamsInfo(callable), *callee_specific_id);
auto call_inst_id =
context.AddInst<SemIR::Call>(loc_id, {.type_id = return_info.type_id,
Expand Down
8 changes: 4 additions & 4 deletions toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ static auto ConvertSelf(Context& context, SemIR::LocId call_loc_id,
auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_refs,
SemIR::InstId return_storage_id,
SemIR::InstId return_slot_arg_id,
const CalleeParamsInfo& callee,
SemIR::SpecificId callee_specific_id)
-> SemIR::InstBlockId {
Expand All @@ -1187,7 +1187,7 @@ auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
// Start building a block to hold the converted arguments.
llvm::SmallVector<SemIR::InstId> args;
args.reserve(implicit_param_patterns.size() + param_patterns.size() +
return_storage_id.is_valid());
return_slot_arg_id.is_valid());

// Check implicit parameters.
for (auto implicit_param_id : implicit_param_patterns) {
Expand Down Expand Up @@ -1241,8 +1241,8 @@ auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
}

// Track the return storage, if present.
if (return_storage_id.is_valid()) {
args.push_back(return_storage_id);
if (return_slot_arg_id.is_valid()) {
args.push_back(return_slot_arg_id);
}

return context.inst_blocks().AddOrEmpty(args);
Expand Down
2 changes: 1 addition & 1 deletion toolchain/check/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ struct CalleeParamsInfo {
auto ConvertCallArgs(Context& context, SemIR::LocId call_loc_id,
SemIR::InstId self_id,
llvm::ArrayRef<SemIR::InstId> arg_refs,
SemIR::InstId return_storage_id,
SemIR::InstId return_slot_arg_id,
const CalleeParamsInfo& callee,
SemIR::SpecificId callee_specific_id)
-> SemIR::InstBlockId;
Expand Down
4 changes: 3 additions & 1 deletion toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,7 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
case SemIR::BindValue::Kind:
case SemIR::Deref::Kind:
case SemIR::ImportRefLoaded::Kind:
case SemIR::ReturnSlot::Kind:
case SemIR::Temporary::Kind:
case SemIR::TemporaryStorage::Kind:
case SemIR::ValueAsRef::Kind:
Expand Down Expand Up @@ -1492,8 +1493,9 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
case SemIR::RequirementEquivalent::Kind:
case SemIR::RequirementImpls::Kind:
case SemIR::RequirementRewrite::Kind:
case SemIR::ReturnExpr::Kind:
case SemIR::Return::Kind:
case SemIR::ReturnExpr::Kind:
case SemIR::ReturnSlotPattern::Kind:
case SemIR::StructLiteral::Kind:
case SemIR::TupleLiteral::Kind:
case SemIR::VarStorage::Kind:
Expand Down
2 changes: 1 addition & 1 deletion toolchain/check/global_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ auto GlobalInit::Finalize() -> void {
.extern_library_id = SemIR::LibraryNameId::Invalid,
.non_owning_decl_id = SemIR::InstId::Invalid,
.first_owning_decl_id = SemIR::InstId::Invalid},
{.return_storage_id = SemIR::InstId::Invalid,
{.return_slot_id = SemIR::InstId::Invalid,
.body_block_ids = {SemIR::InstBlockId::GlobalInit}}}));
}

Expand Down
30 changes: 18 additions & 12 deletions toolchain/check/handle_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,15 @@ auto HandleParseNode(Context& context, Parse::ReturnTypeId node_id) -> bool {
// Propagate the type expression.
auto [type_node_id, type_inst_id] = context.node_stack().PopExprWithNodeId();
auto type_id = ExprAsType(context, type_node_id, type_inst_id).type_id;
// TODO: Use a dedicated instruction rather than VarStorage here.
context.AddInstAndPush<SemIR::VarStorage>(
node_id, {.type_id = type_id, .name_id = SemIR::NameId::ReturnSlot});
auto return_slot_id = context.AddPatternInst<SemIR::ReturnSlotPattern>(
node_id, {.type_id = type_id, .type_inst_id = type_inst_id});
// TODO: Use a separate inst kind here and for the corresponding Param,
// to capture the fact that the corresponding Param is not a value expression.
auto param_pattern_id = context.AddPatternInst<SemIR::ParamPattern>(
node_id, {.type_id = type_id,
.subpattern_id = return_slot_id,
.runtime_index = SemIR::RuntimeParamIndex::Unknown});
context.node_stack().Push(node_id, param_pattern_id);
return true;
}

Expand Down Expand Up @@ -97,7 +103,7 @@ static auto MergeFunctionRedecl(Context& context, SemIRLoc new_loc,
// Track the signature from the definition, so that IDs in the body
// match IDs in the signature.
prev_function.MergeDefinition(new_function);
prev_function.return_storage_id = new_function.return_storage_id;
prev_function.return_slot_id = new_function.return_slot_id;
}
if ((prev_import_ir_id.is_valid() && !new_is_import)) {
ReplacePrevInstForMerge(context, new_function.parent_scope_id,
Expand Down Expand Up @@ -170,14 +176,14 @@ static auto BuildFunctionDecl(Context& context,
Parse::AnyFunctionDeclId node_id,
bool is_definition)
-> std::pair<SemIR::FunctionId, SemIR::InstId> {
auto return_storage_id = SemIR::InstId::Invalid;
if (auto [return_node, maybe_return_storage_id] =
auto return_slot_pattern_id = SemIR::InstId::Invalid;
if (auto [return_node, maybe_return_slot_pattern_id] =
context.node_stack().PopWithNodeIdIf<Parse::NodeKind::ReturnType>();
maybe_return_storage_id) {
return_storage_id = *maybe_return_storage_id;
maybe_return_slot_pattern_id) {
return_slot_pattern_id = *maybe_return_slot_pattern_id;
}

auto name = PopNameComponent(context);
auto name = PopNameComponent(context, return_slot_pattern_id);
if (!name.params_id.is_valid()) {
context.TODO(node_id, "function with positional parameters");
name.params_id = SemIR::InstBlockId::Empty;
Expand Down Expand Up @@ -232,7 +238,7 @@ static auto BuildFunctionDecl(Context& context,
auto function_info =
SemIR::Function{{name_context.MakeEntityWithParamsBase(
name, decl_id, is_extern, introducer.extern_library)},
{.return_storage_id = return_storage_id,
{.return_slot_id = name.return_slot_id,
.virtual_modifier = virtual_modifier}};
if (is_definition) {
function_info.definition_id = decl_id;
Expand Down Expand Up @@ -332,7 +338,7 @@ static auto HandleFunctionDefinitionAfterSignature(
context.AddCurrentCodeBlockToFunction();

// Check the return type is complete.
CheckFunctionReturnType(context, function.return_storage_id, function,
CheckFunctionReturnType(context, function.return_slot_id, function,
SemIR::SpecificId::Invalid);

// Check the parameter types are complete.
Expand Down Expand Up @@ -397,7 +403,7 @@ auto HandleParseNode(Context& context, Parse::FunctionDefinitionId node_id)
// If the `}` of the function is reachable, reject if we need a return value
// and otherwise add an implicit `return;`.
if (context.is_current_position_reachable()) {
if (context.functions().Get(function_id).return_storage_id.is_valid()) {
if (context.functions().Get(function_id).return_slot_id.is_valid()) {
CARBON_DIAGNOSTIC(
MissingReturnStatement, Error,
"missing `return` at end of function with declared return type");
Expand Down
10 changes: 7 additions & 3 deletions toolchain/check/handle_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,12 @@ static auto PopImplIntroducerAndParamsAsNameComponent(

ParameterBlocks parameter_blocks{
.implicit_params_id = SemIR::InstBlockId::Invalid,
.params_id = SemIR::InstBlockId::Invalid};
.params_id = SemIR::InstBlockId::Invalid,
.return_slot_id = SemIR::InstId::Invalid};
if (implicit_param_patterns_id) {
parameter_blocks = CalleePatternMatch(context, *implicit_param_patterns_id,
SemIR::InstBlockId::Invalid);
parameter_blocks =
CalleePatternMatch(context, *implicit_param_patterns_id,
SemIR::InstBlockId::Invalid, SemIR::InstId::Invalid);
}

Parse::NodeId first_param_node_id =
Expand All @@ -229,6 +231,8 @@ static auto PopImplIntroducerAndParamsAsNameComponent(
.params_loc_id = Parse::NodeId::Invalid,
.params_id = SemIR::InstBlockId::Invalid,
.param_patterns_id = SemIR::InstBlockId::Invalid,
.return_slot_pattern_id = SemIR::InstId::Invalid,
.return_slot_id = SemIR::InstId::Invalid,
.pattern_block_id = context.pattern_block_stack().Pop(),
};
}
Expand Down
16 changes: 9 additions & 7 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,9 @@ class ImportRefResolver {

auto new_param_id = context_.AddInstInNoBlock<SemIR::Param>(
AddImportIRInst(param_id),
{.type_id = type_id, .runtime_index = param_inst.runtime_index});
{.type_id = type_id,
.runtime_index = param_inst.runtime_index,
.pretty_name_id = GetLocalNameId(param_inst.pretty_name_id)});
switch (bind_inst.kind) {
case SemIR::BindName::Kind: {
auto entity_name_id = context_.entity_names().Add(
Expand Down Expand Up @@ -1598,7 +1600,7 @@ class ImportRefResolver {
// Start with an incomplete function.
function_decl.function_id = context_.functions().Add(
{GetIncompleteLocalEntityBase(function_decl_id, import_function),
{.return_storage_id = SemIR::InstId::Invalid,
{.return_slot_id = SemIR::InstId::Invalid,
.builtin_function_kind = import_function.builtin_function_kind}});

function_decl.type_id =
Expand Down Expand Up @@ -1643,9 +1645,9 @@ class ImportRefResolver {
}

auto return_type_const_id = SemIR::ConstantId::Invalid;
if (import_function.return_storage_id.is_valid()) {
if (import_function.return_slot_id.is_valid()) {
return_type_const_id = GetLocalConstantId(
import_ir_.insts().Get(import_function.return_storage_id).type_id());
import_ir_.insts().Get(import_function.return_slot_id).type_id());
}
auto parent_scope_id = GetLocalNameScopeId(import_function.parent_scope_id);
LoadLocalParamConstantIds(import_function.implicit_param_refs_id);
Expand All @@ -1672,13 +1674,13 @@ class ImportRefResolver {
SetGenericData(import_function.generic_id, new_function.generic_id,
generic_data);

if (import_function.return_storage_id.is_valid()) {
if (import_function.return_slot_id.is_valid()) {
// Recreate the return slot from scratch.
// TODO: Once we import function definitions, we'll need to make sure we
// use the same return storage variable in the declaration and definition.
new_function.return_storage_id =
new_function.return_slot_id =
context_.AddInstInNoBlock<SemIR::VarStorage>(
AddImportIRInst(import_function.return_storage_id),
AddImportIRInst(import_function.return_slot_id),
{.type_id =
context_.GetTypeIdForTypeConstant(return_type_const_id),
.name_id = SemIR::NameId::ReturnSlot});
Expand Down
10 changes: 7 additions & 3 deletions toolchain/check/name_component.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

namespace Carbon::Check {

auto PopNameComponent(Context& context) -> NameComponent {
auto PopNameComponent(Context& context, SemIR::InstId return_slot_pattern_id)
-> NameComponent {
Parse::NodeId first_param_node_id = Parse::InvalidNodeId();
Parse::NodeId last_param_node_id = Parse::InvalidNodeId();

Expand Down Expand Up @@ -42,8 +43,9 @@ auto PopNameComponent(Context& context) -> NameComponent {
implicit_param_patterns_id = SemIR::InstBlockId::Invalid;
}

auto [implicit_params_id, params_id] = CalleePatternMatch(
context, *implicit_param_patterns_id, *param_patterns_id);
auto [implicit_params_id, params_id, return_slot_id] =
CalleePatternMatch(context, *implicit_param_patterns_id,
*param_patterns_id, return_slot_pattern_id);

auto [name_loc_id, name_id] = context.node_stack().PopNameWithNodeId();
return {
Expand All @@ -57,6 +59,8 @@ auto PopNameComponent(Context& context) -> NameComponent {
.params_loc_id = params_loc_id,
.params_id = params_id,
.param_patterns_id = *param_patterns_id,
.return_slot_pattern_id = return_slot_pattern_id,
.return_slot_id = return_slot_id,
.pattern_block_id = context.pattern_block_stack().Pop(),
};
}
Expand Down
10 changes: 9 additions & 1 deletion toolchain/check/name_component.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,20 @@ struct NameComponent {
SemIR::InstBlockId params_id;
SemIR::InstBlockId param_patterns_id;

// The return slot.
// TODO: These are only used for function declarations. Should they go
// somewhere else?
SemIR::InstId return_slot_pattern_id;
SemIR::InstId return_slot_id;

// The pattern block.
SemIR::InstBlockId pattern_block_id;
};

// Pop a name component from the node stack and pattern block stack.
auto PopNameComponent(Context& context) -> NameComponent;
auto PopNameComponent(Context& context, SemIR::InstId return_slot_pattern_id =
SemIR::InstId::Invalid)
-> NameComponent;

// Pop the name of a declaration from the node stack and pattern block stack,
// and diagnose if it has parameters.
Expand Down
Loading

0 comments on commit 9266f86

Please sign in to comment.