Skip to content

Commit

Permalink
[internal] Add plugin hook for Go codegen support (#14707)
Browse files Browse the repository at this point in the history
Closes #14258. As described there, codegen for compiled languages is more complex because the generated code must be _compiled_, unlike Python where the code only needs to be present.

We still use the `GenerateSourcesRequest` plugin hook to generate the raw `.go` files and so that integrations like `export-codegen` goal still work. But that alone is not powerful enough to know how to compile the Go code. 

So, we add a new Go-specific plugin hook. Plugin implementations return back the standardized type `BuildGoPackageRequest`, which is all the information needed to compile a particular package, including by compiling its transitive dependencies. That allows for complex codegen modeling such as Protobuf needing the Protobuf third-party Go package compiled first, or Protobufs depending on other Protobufs.

Rule authors can then directly tell Pants to compile that codegen (#14705), or it can be loaded via a `dependency` on a normal `go_package`.

[ci skip-rust]
  • Loading branch information
Eric-Arellano authored Mar 4, 2022
1 parent 5db4b72 commit 8db4f81
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 49 deletions.
32 changes: 20 additions & 12 deletions src/python/pants/backend/codegen/protobuf/go/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pants.backend.codegen.protobuf.protoc import Protoc
from pants.backend.codegen.protobuf.target_types import ProtobufGrpcToggleField, ProtobufSourceField
from pants.backend.go.target_types import GoPackageSourcesField
from pants.backend.go.util_rules.build_pkg import BuildGoPackageRequest
from pants.backend.go.util_rules.build_pkg_target import GoCodegenBuildRequest
from pants.backend.go.util_rules.sdk import GoSdkProcess
from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest
from pants.core.util_rules.source_files import SourceFilesRequest
Expand Down Expand Up @@ -38,21 +40,32 @@
from pants.util.logging import LogLevel


class GoCodegenBuildProtobufRequest(GoCodegenBuildRequest):
generate_from = ProtobufSourceField


class GenerateGoFromProtobufRequest(GenerateSourcesRequest):
input = ProtobufSourceField
output = GoPackageSourcesField


@dataclass(frozen=True)
class SetupGoProtocPlugin:
class _SetupGoProtocPlugin:
digest: Digest


@rule(desc="Generate Go from Protobuf", level=LogLevel.DEBUG)
@rule
async def setup_build_go_package_request_for_protobuf(
_: GoCodegenBuildProtobufRequest,
) -> BuildGoPackageRequest:
raise NotImplementedError()


@rule(desc="Generate Go source files from Protobuf", level=LogLevel.DEBUG)
async def generate_go_from_protobuf(
request: GenerateGoFromProtobufRequest,
protoc: Protoc,
go_protoc_plugin: SetupGoProtocPlugin,
go_protoc_plugin: _SetupGoProtocPlugin,
) -> GeneratedSources:
output_dir = "_generated_files"
protoc_relpath = "__protoc"
Expand Down Expand Up @@ -81,13 +94,7 @@ async def generate_go_from_protobuf(
)

input_digest = await Get(
Digest,
MergeDigests(
[
all_sources_stripped.snapshot.digest,
empty_output_dir,
]
),
Digest, MergeDigests([all_sources_stripped.snapshot.digest, empty_output_dir])
)

maybe_grpc_plugin_args = []
Expand Down Expand Up @@ -174,7 +181,7 @@ async def generate_go_from_protobuf(


@rule
async def setup_go_protoc_plugin(platform: Platform) -> SetupGoProtocPlugin:
async def setup_go_protoc_plugin(platform: Platform) -> _SetupGoProtocPlugin:
go_mod_digest = await Get(
Digest,
CreateDigest(
Expand Down Expand Up @@ -241,11 +248,12 @@ async def setup_go_protoc_plugin(platform: Platform) -> SetupGoProtocPlugin:
),
)
plugin_digest = await Get(Digest, RemovePrefix(merged_output_digests, "gopath/bin"))
return SetupGoProtocPlugin(plugin_digest)
return _SetupGoProtocPlugin(plugin_digest)


def rules():
return (
*collect_rules(),
UnionRule(GenerateSourcesRequest, GenerateGoFromProtobufRequest),
UnionRule(GoCodegenBuildRequest, GoCodegenBuildProtobufRequest),
)
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def assert_files_generated(
HydratedSources, [HydrateSourcesRequest(tgt[ProtobufSourceField])]
)
generated_sources = rule_runner.request(
GeneratedSources,
[GenerateGoFromProtobufRequest(protocol_sources.snapshot, tgt)],
GeneratedSources, [GenerateGoFromProtobufRequest(protocol_sources.snapshot, tgt)]
)
assert set(generated_sources.snapshot.files) == set(expected_files)

Expand Down
72 changes: 68 additions & 4 deletions src/python/pants/backend/go/util_rules/build_pkg_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import dataclasses
from dataclasses import dataclass
from typing import ClassVar, Type, cast

from pants.backend.go.target_types import (
GoImportPathField,
Expand All @@ -29,16 +30,27 @@
)
from pants.build_graph.address import Address
from pants.engine.engine_aware import EngineAwareParameter
from pants.engine.internals.graph import AmbiguousCodegenImplementationsException
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.rules import collect_rules, rule
from pants.engine.target import Dependencies, DependenciesRequest, UnexpandedTargets, WrappedTarget
from pants.engine.target import (
Dependencies,
DependenciesRequest,
SourcesField,
Target,
UnexpandedTargets,
WrappedTarget,
)
from pants.engine.unions import UnionMembership, union
from pants.util.logging import LogLevel
from pants.util.ordered_set import FrozenOrderedSet
from pants.util.strutil import bullet_list


@dataclass(frozen=True)
class BuildGoPackageTargetRequest(EngineAwareParameter):
"""Build a `go_package` or `go_third_party_package` target and its dependencies as `__pkg__.a`
files."""
"""Build a `go_package`, `go_third_party_package`, or Go codegen target and its dependencies as
`__pkg__.a` files."""

address: Address
is_main: bool = False
Expand All @@ -48,15 +60,66 @@ def debug_hint(self) -> str:
return str(self.address)


@union
@dataclass(frozen=True)
class GoCodegenBuildRequest:
"""The plugin hook to build/compile Go code.
Note that you should still use the normal `GenerateSourcesRequest` plugin hook from
`pants.engine.target` too, which is necessary for integrations like the `export-codegen` goal.
However, that is only helpful to generate the raw `.go` files; you also need to use this
plugin hook so that Pants knows how to compile those generated `.go` files.
Subclass this and set the class property `generate_from`. Define a rule that goes from your
subclass to `BuildGoPackageRequest` - the request must result in valid compilation, which you
should test for by using `rule_runner.request(BuiltGoPackage, BuildGoPackageRequest)` in your
tests. For example, make sure to set up any third-party packages needed by the generated code.
Finally, register `UnionRule(GoCodegenBuildRequest, MySubclass)`.
"""

target: Target

generate_from: ClassVar[type[SourcesField]]


def maybe_get_codegen_request_type(
tgt: Target, union_membership: UnionMembership
) -> GoCodegenBuildRequest | None:
if not tgt.has_field(SourcesField):
return None
generate_request_types = cast(
FrozenOrderedSet[Type[GoCodegenBuildRequest]], union_membership.get(GoCodegenBuildRequest)
)
sources_field = tgt[SourcesField]
relevant_requests = [
req for req in generate_request_types if isinstance(sources_field, req.generate_from)
]
if len(relevant_requests) > 1:
generate_from_sources = relevant_requests[0].generate_from.__name__
raise AmbiguousCodegenImplementationsException(
f"Multiple registered code generators from {GoCodegenBuildRequest.__name__} can "
f"generate from {generate_from_sources}. It is ambiguous which implementation to "
f"use.\n\n"
f"Possible implementations:\n\n"
f"{bullet_list(sorted(generator.__name__ for generator in relevant_requests))}"
)
return relevant_requests[0](tgt) if relevant_requests else None


# NB: We must have a description for the streaming of this rule to work properly
# (triggered by `FallibleBuildGoPackageRequest` subclassing `EngineAwareReturnType`).
@rule(desc="Set up Go compilation request", level=LogLevel.DEBUG)
async def setup_build_go_package_target_request(
request: BuildGoPackageTargetRequest,
request: BuildGoPackageTargetRequest, union_membership: UnionMembership
) -> FallibleBuildGoPackageRequest:
wrapped_target = await Get(WrappedTarget, Address, request.address)
target = wrapped_target.target

codegen_request = maybe_get_codegen_request_type(target, union_membership)
if codegen_request:
codegen_result = await Get(BuildGoPackageRequest, GoCodegenBuildRequest, codegen_request)
return FallibleBuildGoPackageRequest(codegen_result, codegen_result.import_path)

embed_config: EmbedConfig | None = None
if target.has_field(GoPackageSourcesField):
_maybe_first_party_pkg_analysis, _maybe_first_party_pkg_digest = await MultiGet(
Expand Down Expand Up @@ -137,6 +200,7 @@ async def setup_build_go_package_target_request(
if (
tgt.has_field(GoPackageSourcesField)
or tgt.has_field(GoThirdPartyPackageDependenciesField)
or bool(maybe_get_codegen_request_type(tgt, union_membership))
)
)
direct_dependencies = []
Expand Down
125 changes: 120 additions & 5 deletions src/python/pants/backend/go/util_rules/build_pkg_target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,61 @@
FallibleBuildGoPackageRequest,
FallibleBuiltGoPackage,
)
from pants.backend.go.util_rules.build_pkg_target import BuildGoPackageTargetRequest
from pants.engine.addresses import Address
from pants.engine.fs import Snapshot
from pants.engine.rules import QueryRule
from pants.backend.go.util_rules.build_pkg_target import (
BuildGoPackageTargetRequest,
GoCodegenBuildRequest,
)
from pants.core.target_types import FileSourceField, FileTarget
from pants.engine.addresses import Address, Addresses
from pants.engine.fs import CreateDigest, Digest, FileContent, Snapshot
from pants.engine.rules import Get, QueryRule, rule
from pants.engine.target import Dependencies, DependenciesRequest
from pants.engine.unions import UnionRule
from pants.testutil.rule_runner import RuleRunner
from pants.util.strutil import path_safe


# Set up a semi-complex codegen plugin. Note that we cyclically call into the
# `BuildGoPackageTargetRequest` rule to set up a dependency on a third-party package, as this
# is common for codegen plugins to need to do.
class GoCodegenBuildFilesRequest(GoCodegenBuildRequest):
generate_from = FileSourceField


@rule
async def generate_from_file(request: GoCodegenBuildFilesRequest) -> BuildGoPackageRequest:
content = dedent(
"""\
package gen
import "fmt"
import "github.com/google/uuid"
func Quote(s string) string {
uuid.SetClockSequence(-1) // A trivial line to use uuid.
return fmt.Sprintf(">> %s <<", s)
}
"""
)
digest = await Get(Digest, CreateDigest([FileContent("codegen/f.go", content.encode())]))

deps = await Get(Addresses, DependenciesRequest(request.target[Dependencies]))
assert len(deps) == 1
assert deps[0].generated_name == "github.com/google/uuid"
thirdparty_dep = await Get(FallibleBuildGoPackageRequest, BuildGoPackageTargetRequest(deps[0]))
assert thirdparty_dep.request is not None

return BuildGoPackageRequest(
import_path="codegen.com/gen",
digest=digest,
dir_path="codegen",
go_file_names=("f.go",),
s_file_names=(),
direct_dependencies=(thirdparty_dep.request,),
minimum_go_version=None,
)


@pytest.fixture
def rule_runner() -> RuleRunner:
rule_runner = RuleRunner(
Expand All @@ -49,12 +96,14 @@ def rule_runner() -> RuleRunner:
*first_party_pkg.rules(),
*third_party_pkg.rules(),
*target_type_rules.rules(),
generate_from_file,
QueryRule(BuiltGoPackage, [BuildGoPackageRequest]),
QueryRule(FallibleBuiltGoPackage, [BuildGoPackageRequest]),
QueryRule(BuildGoPackageRequest, [BuildGoPackageTargetRequest]),
QueryRule(FallibleBuildGoPackageRequest, [BuildGoPackageTargetRequest]),
UnionRule(GoCodegenBuildRequest, GoCodegenBuildFilesRequest),
],
target_types=[GoModTarget, GoPackageTarget],
target_types=[GoModTarget, GoPackageTarget, FileTarget],
)
rule_runner.set_options([], env_inherit={"PATH"})
return rule_runner
Expand Down Expand Up @@ -353,3 +402,69 @@ def test_build_invalid_target(rule_runner: RuleRunner) -> None:
assert dep_build_request.request is None
assert dep_build_request.exit_code == 1
assert "dep/f.go:1:1: expected 'package', found invalid\n" in (dep_build_request.stderr or "")


def test_build_codegen_target(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"go.mod": dedent(
"""\
module example.com/greeter
go 1.17
require github.com/google/uuid v1.3.0
"""
),
"go.sum": dedent(
"""\
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
"""
),
"generate_from_me.txt": "",
"greeter.go": dedent(
"""\
package greeter
import "fmt"
import "codegen.com/gen"
func Hello() {
fmt.Println(gen.Quote("Hello world!"))
}
"""
),
"BUILD": dedent(
"""\
go_mod(name='mod')
go_package(name='pkg', dependencies=[":gen"])
file(
name='gen',
source='generate_from_me.txt',
dependencies=[':mod#github.com/google/uuid'],
)
"""
),
}
)

# Running directly on a codegen target should work.
assert_pkg_target_built(
rule_runner,
Address("", target_name="gen"),
expected_import_path="codegen.com/gen",
expected_dir_path="codegen",
expected_go_file_names=["f.go"],
expected_direct_dependency_import_paths=["github.com/google/uuid"],
expected_transitive_dependency_import_paths=[],
)

# Direct dependencies on codegen targets must be propagated.
assert_pkg_target_built(
rule_runner,
Address("", target_name="pkg"),
expected_import_path="example.com/greeter",
expected_dir_path="",
expected_go_file_names=["greeter.go"],
expected_direct_dependency_import_paths=["codegen.com/gen"],
expected_transitive_dependency_import_paths=["github.com/google/uuid"],
)
Loading

0 comments on commit 8db4f81

Please sign in to comment.