Skip to content

Commit

Permalink
Infer importpath if not set explicitly (#3705)
Browse files Browse the repository at this point in the history
* Infer importpath if not set explicitely

* Add test with explicit and implicit importpath

---------

Co-authored-by: Zhongpeng Lin <[email protected]>
  • Loading branch information
mering and linzhp authored Sep 27, 2023
1 parent 69d0fc8 commit d1da1bb
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 18 deletions.
10 changes: 5 additions & 5 deletions go/private/context.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -357,17 +357,17 @@ def _check_importpaths(ctx):
if ":" in p:
fail("import path '%s' contains invalid character :" % p)

def _infer_importpath(ctx):
def _infer_importpath(ctx, attr):
DEFAULT_LIB = "go_default_library"
VENDOR_PREFIX = "/vendor/"

# Check if paths were explicitly set, either in this rule or in an
# embedded rule.
attr_importpath = getattr(ctx.attr, "importpath", "")
attr_importmap = getattr(ctx.attr, "importmap", "")
attr_importpath = getattr(attr, "importpath", "")
attr_importmap = getattr(attr, "importmap", "")
embed_importpath = ""
embed_importmap = ""
for embed in getattr(ctx.attr, "embed", []):
for embed in getattr(attr, "embed", []):
if GoLibrary not in embed:
continue
lib = embed[GoLibrary]
Expand Down Expand Up @@ -504,7 +504,7 @@ def go_context(ctx, attr = None):
toolchain.sdk.tools)

_check_importpaths(ctx)
importpath, importmap, pathtype = _infer_importpath(ctx)
importpath, importmap, pathtype = _infer_importpath(ctx, attr)
importpath_aliases = tuple(getattr(attr, "importpath_aliases", ()))

return struct(
Expand Down
3 changes: 0 additions & 3 deletions go/private/rules/library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ load(
load(
"//go/private:providers.bzl",
"GoLibrary",
"INFERRED_PATH",
)
load(
"//go/private/rules:transition.bzl",
Expand All @@ -39,8 +38,6 @@ load(
def _go_library_impl(ctx):
"""Implements the go_library() rule."""
go = go_context(ctx)
if go.pathtype == INFERRED_PATH:
fail("importpath must be specified in this library or one of its embedded libraries")
library = go.new_library(go)
source = go.library_to_source(go, ctx.attr, library, ctx.coverage_instrumented())
archive = go.archive(go, source)
Expand Down
16 changes: 6 additions & 10 deletions proto/def.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ load(
"//go/private:go_toolchain.bzl",
"GO_TOOLCHAIN",
)
load(
"//go/private:providers.bzl",
"INFERRED_PATH",
)
load(
"//go/private/rules:transition.bzl",
"non_go_tool_transition",
Expand All @@ -46,7 +42,7 @@ load(

GoProtoImports = provider()

def get_imports(attr):
def get_imports(attr, importpath):
proto_deps = []

# ctx.attr.proto is a one-element array since there is a Starlark transition attached to it.
Expand All @@ -60,7 +56,7 @@ def get_imports(attr):
direct = dict()
for dep in proto_deps:
for src in dep[ProtoInfo].check_deps_sources.to_list():
direct["{}={}".format(proto_path(src, dep[ProtoInfo]), attr.importpath)] = True
direct["{}={}".format(proto_path(src, dep[ProtoInfo]), importpath)] = True

deps = getattr(attr, "deps", []) + getattr(attr, "embed", [])
transitive = [
Expand All @@ -71,7 +67,8 @@ def get_imports(attr):
return depset(direct = direct.keys(), transitive = transitive)

def _go_proto_aspect_impl(_target, ctx):
imports = get_imports(ctx.rule.attr)
go = go_context(ctx, ctx.rule.attr)
imports = get_imports(ctx.rule.attr, go.importpath)
return [GoProtoImports(imports = imports)]

_go_proto_aspect = aspect(
Expand All @@ -80,6 +77,7 @@ _go_proto_aspect = aspect(
"deps",
"embed",
],
toolchains = [GO_TOOLCHAIN],
)

def _proto_library_to_source(_go, attr, source, merge):
Expand All @@ -93,8 +91,6 @@ def _proto_library_to_source(_go, attr, source, merge):

def _go_proto_library_impl(ctx):
go = go_context(ctx)
if go.pathtype == INFERRED_PATH:
fail("importpath must be specified in this library or one of its embedded libraries")
if ctx.attr.compiler:
#TODO: print("DEPRECATED: compiler attribute on {}, use compilers instead".format(ctx.label))
compilers = [ctx.attr.compiler]
Expand Down Expand Up @@ -124,7 +120,7 @@ def _go_proto_library_impl(ctx):
go,
compiler = compiler,
protos = [d[ProtoInfo] for d in proto_deps],
imports = get_imports(ctx.attr),
imports = get_imports(ctx.attr, go.importpath),
importpath = go.importpath,
))
library = go.new_library(
Expand Down
36 changes: 36 additions & 0 deletions tests/core/go_proto_library_importpath/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@rules_proto//proto:defs.bzl", "proto_library")

# Common rules
proto_library(
name = "foo_proto",
srcs = ["foo.proto"],
)

go_proto_library(
name = "foo_go_proto",
importpath = "path/to/foo_go",
proto = ":foo_proto",
)

proto_library(
name = "bar_proto",
srcs = ["bar.proto"],
deps = [":foo_proto"],
)

go_proto_library(
name = "bar_go_proto",
proto = ":bar_proto",
deps = [":foo_go_proto"],
)

go_test(
name = "importpath_test",
srcs = ["importpath_test.go"],
deps = [
":bar_go_proto",
":foo_go_proto",
],
)
9 changes: 9 additions & 0 deletions tests/core/go_proto_library_importpath/bar.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

package tests.core.go_proto_library_importpath.bar;

import "tests/core/go_proto_library_importpath/foo.proto";

message Bar {
foo.Foo value = 1;
}
7 changes: 7 additions & 0 deletions tests/core/go_proto_library_importpath/foo.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
syntax = "proto3";

package tests.core.go_proto_library_importpath.foo;

message Foo {
int64 value = 1;
}
22 changes: 22 additions & 0 deletions tests/core/go_proto_library_importpath/importpath_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package importpath_test

import (
"fmt"
"testing"

bar_proto "tests/core/go_proto_library_importpath/bar_go_proto"
foo_proto "path/to/foo_go"
)

func Test(t *testing.T) {
bar := &bar_proto.Bar{}
bar.Value = &foo_proto.Foo{}
bar.Value.Value = 5

var expected int64 = 5
if bar.Value.Value != expected {
t.Errorf(fmt.Sprintf("Not equal: \n"+
"expected: %s\n"+
"actual : %s", expected, bar.Value.Value))
}
}

0 comments on commit d1da1bb

Please sign in to comment.