Skip to content

Commit

Permalink
Track and resolve the specific callee in a call to a generic function (
Browse files Browse the repository at this point in the history
…#4395)

Add a new `specific_function` instruction that represents a generic
function plus its deduced argument list as a callee in a function call.
The new instruction can only appear as the immediate operand of a call
instruction, so we give it a builtin placeholder type.

At the end of each file, require definitions for all specific functions
used in that file. Resolve the generic with the argument list to produce
those specific function definitions as needed, and diagnose if the
generic doesn't have a definition available.

A few tests are updated in cases where they declared and used generic
functions but didn't previously provide a function definition.
  • Loading branch information
zygoloid authored Oct 10, 2024
1 parent bb874f2 commit 1a1bfd2
Show file tree
Hide file tree
Showing 35 changed files with 1,591 additions and 358 deletions.
18 changes: 15 additions & 3 deletions toolchain/check/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "toolchain/check/convert.h"
#include "toolchain/check/deduce.h"
#include "toolchain/check/function.h"
#include "toolchain/sem_ir/builtin_function_kind.h"
#include "toolchain/sem_ir/builtin_inst_kind.h"
#include "toolchain/sem_ir/entity_with_params_base.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/inst.h"
Expand All @@ -24,8 +26,8 @@ namespace Carbon::Check {
// `self_id` and `arg_ids` are the self argument and explicit arguments in the
// call.
//
// Returns a SpecificId for the specific callee, or `nullopt` if an error has
// been diagnosed.
// Returns a `SpecificId` for the specific callee, `SpecificId::Invalid` if the
// callee is not generic, or `nullopt` if an error has been diagnosed.
static auto ResolveCalleeInCall(Context& context, SemIR::LocId loc_id,
const SemIR::EntityWithParamsBase& entity,
llvm::StringLiteral entity_kind_for_diagnostic,
Expand Down Expand Up @@ -141,10 +143,20 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
// for the call.
auto callee_specific_id = ResolveCalleeInCall(
context, loc_id, callable, "function", callable.generic_id,
callee_function.specific_id, callee_function.self_id, arg_ids);
callee_function.enclosing_specific_id, callee_function.self_id, arg_ids);
if (!callee_specific_id) {
return SemIR::InstId::BuiltinError;
}
if (callee_specific_id->is_valid()) {
callee_id =
context.AddInst(context.insts().GetLocId(callee_id),
SemIR::SpecificFunction{
.type_id = context.GetBuiltinType(
SemIR::BuiltinInstKind::SpecificFunctionType),
.callee_id = callee_id,
.specific_id = *callee_specific_id});
context.definitions_required().push_back(callee_id);
}

// If there is a return slot, build storage for the result.
SemIR::InstId return_storage_id = SemIR::InstId::Invalid;
Expand Down
38 changes: 31 additions & 7 deletions toolchain/check/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "toolchain/check/context.h"
#include "toolchain/check/diagnostic_helpers.h"
#include "toolchain/check/function.h"
#include "toolchain/check/generic.h"
#include "toolchain/check/handle.h"
#include "toolchain/check/import.h"
#include "toolchain/check/import_ref.h"
Expand Down Expand Up @@ -783,14 +784,18 @@ auto NodeIdTraversal::Next() -> std::optional<Parse::NodeId> {
}
}

// Emits a diagnostic for each declaration in context.definitions_required()
// that doesn't have a definition.
static auto DiagnoseMissingDefinitions(Context& context,
Context::DiagnosticEmitter& emitter)
// Checks that each required definition is available. If the definition can be
// generated by resolving a specific, does so, otherwise emits a diagnostic for
// each declaration in context.definitions_required() that doesn't have a
// definition.
static auto CheckRequiredDefinitions(Context& context,
Context::DiagnosticEmitter& emitter)
-> void {
CARBON_DIAGNOSTIC(MissingDefinitionInImpl, Error,
"no definition found for declaration in impl file");
for (SemIR::InstId decl_inst_id : context.definitions_required()) {
// Note that more required definitions can be added during this loop.
for (size_t i = 0; i != context.definitions_required().size(); ++i) {
SemIR::InstId decl_inst_id = context.definitions_required()[i];
SemIR::Inst decl_inst = context.insts().Get(decl_inst_id);
CARBON_KIND_SWITCH(context.insts().Get(decl_inst_id)) {
case CARBON_KIND(SemIR::ClassDecl class_decl): {
Expand All @@ -817,6 +822,25 @@ static auto DiagnoseMissingDefinitions(Context& context,
// triggering https://github.com/carbon-language/carbon-lang/issues/4071
CARBON_FATAL("TODO: Support interfaces in DiagnoseMissingDefinitions");
}
case CARBON_KIND(SemIR::SpecificFunction specific_function): {
if (!ResolveSpecificDefinition(context,
specific_function.specific_id)) {
CARBON_DIAGNOSTIC(MissingGenericFunctionDefinition, Error,
"use of undefined generic function");
CARBON_DIAGNOSTIC(MissingGenericFunctionDefinitionHere, Note,
"generic function declared here");
auto generic_decl_id =
context.generics()
.Get(context.specifics()
.Get(specific_function.specific_id)
.generic_id)
.decl_id;
emitter.Build(decl_inst_id, MissingGenericFunctionDefinition)
.Note(generic_decl_id, MissingGenericFunctionDefinitionHere)
.Emit();
}
break;
}
default: {
CARBON_FATAL("Unexpected inst in definitions_required: {0}", decl_inst);
}
Expand Down Expand Up @@ -908,9 +932,9 @@ static auto CheckParseTree(
return;
}

context.Finalize();
CheckRequiredDefinitions(context, emitter);

DiagnoseMissingDefinitions(context, emitter);
context.Finalize();

context.VerifyOnFinish();

Expand Down
1 change: 1 addition & 0 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ class TypeCompleter {
case SemIR::BuiltinInstKind::NamespaceType:
case SemIR::BuiltinInstKind::BoundMethodType:
case SemIR::BuiltinInstKind::WitnessType:
case SemIR::BuiltinInstKind::SpecificFunctionType:
return MakeCopyValueRepr(type_id);

case SemIR::BuiltinInstKind::StringType:
Expand Down
9 changes: 7 additions & 2 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1190,8 +1190,9 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
return RebuildAndValidateIfFieldsAreConstant(
eval_context, inst,
[&](SemIR::IntType result) {
return ValidateIntType(eval_context.context(),
int_type.bit_width_id, result);
return ValidateIntType(
eval_context.context(),
inst_id.is_valid() ? inst_id : int_type.bit_width_id, result);
},
&SemIR::IntType::bit_width_id);
}
Expand All @@ -1207,6 +1208,10 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
},
&SemIR::FloatType::bit_width_id);
}
case SemIR::SpecificFunction::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::SpecificFunction::callee_id,
&SemIR::SpecificFunction::specific_id);
case SemIR::StructType::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::StructType::fields_id);
Expand Down
2 changes: 2 additions & 0 deletions toolchain/check/testdata/basics/builtin_insts.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
// CHECK:STDOUT: instFloatType: {kind: BuiltinInst, arg0: FloatType, type: typeTypeType}
// CHECK:STDOUT: instStringType: {kind: BuiltinInst, arg0: StringType, type: typeTypeType}
// CHECK:STDOUT: instBoundMethodType: {kind: BuiltinInst, arg0: BoundMethodType, type: typeTypeType}
// CHECK:STDOUT: instSpecificFunctionType: {kind: BuiltinInst, arg0: SpecificFunctionType, type: typeTypeType}
// CHECK:STDOUT: instNamespaceType: {kind: BuiltinInst, arg0: NamespaceType, type: typeTypeType}
// CHECK:STDOUT: instWitnessType: {kind: BuiltinInst, arg0: WitnessType, type: typeTypeType}
// CHECK:STDOUT: 'inst+0': {kind: Namespace, arg0: name_scope0, arg1: inst<invalid>, type: type(instNamespaceType)}
Expand All @@ -47,6 +48,7 @@
// CHECK:STDOUT: instFloatType: templateConstant(instFloatType)
// CHECK:STDOUT: instStringType: templateConstant(instStringType)
// CHECK:STDOUT: instBoundMethodType: templateConstant(instBoundMethodType)
// CHECK:STDOUT: instSpecificFunctionType: templateConstant(instSpecificFunctionType)
// CHECK:STDOUT: instNamespaceType: templateConstant(instNamespaceType)
// CHECK:STDOUT: instWitnessType: templateConstant(instWitnessType)
// CHECK:STDOUT: 'inst+0': templateConstant(inst+0)
Expand Down
12 changes: 8 additions & 4 deletions toolchain/check/testdata/class/generic/import.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ class Class(U:! type) {
// CHECK:STDOUT: %.6: type = ptr_type %.2 [template]
// CHECK:STDOUT: %F.type.3: type = fn_type @F.2 [template]
// CHECK:STDOUT: %F.3: %F.type.3 = struct_value () [template]
// CHECK:STDOUT: %.7: <specific function> = specific_function %F.2, @F.1(i32) [template]
// CHECK:STDOUT: %UseField.type: type = fn_type @UseField [template]
// CHECK:STDOUT: %UseField: %UseField.type = struct_value () [template]
// CHECK:STDOUT: }
Expand Down Expand Up @@ -502,9 +503,10 @@ class Class(U:! type) {
// CHECK:STDOUT: %F.call.loc6: init %CompleteClass.3 = call %F.ref.loc6() to %.loc6_7
// CHECK:STDOUT: assign %v.var, %F.call.loc6
// CHECK:STDOUT: %v.ref: ref %CompleteClass.3 = name_ref v, %v
// CHECK:STDOUT: %.loc7_11: %F.type.2 = specific_constant imports.%import_ref.7, @CompleteClass(i32) [template = constants.%F.2]
// CHECK:STDOUT: %F.ref.loc7: %F.type.2 = name_ref F, %.loc7_11 [template = constants.%F.2]
// CHECK:STDOUT: %F.call.loc7: init i32 = call %F.ref.loc7()
// CHECK:STDOUT: %.loc7_11.1: %F.type.2 = specific_constant imports.%import_ref.7, @CompleteClass(i32) [template = constants.%F.2]
// CHECK:STDOUT: %F.ref.loc7: %F.type.2 = name_ref F, %.loc7_11.1 [template = constants.%F.2]
// CHECK:STDOUT: %.loc7_11.2: <specific function> = specific_function %F.ref.loc7, @F.1(i32) [template = constants.%.7]
// CHECK:STDOUT: %F.call.loc7: init i32 = call %.loc7_11.2()
// CHECK:STDOUT: %.loc7_15.1: i32 = value_of_initializer %F.call.loc7
// CHECK:STDOUT: %.loc7_15.2: i32 = converted %F.call.loc7, %.loc7_15.1
// CHECK:STDOUT: return %.loc7_15.2
Expand Down Expand Up @@ -564,7 +566,9 @@ class Class(U:! type) {
// CHECK:STDOUT: %F => constants.%F.2
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: specific @F.1(i32) {}
// CHECK:STDOUT: specific @F.1(i32) {
// CHECK:STDOUT: !definition:
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- fail_generic_arg_mismatch.carbon
// CHECK:STDOUT:
Expand Down
Loading

0 comments on commit 1a1bfd2

Please sign in to comment.