Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement "import public" using type aliases. #583

Merged
merged 1 commit into from
Apr 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 28 additions & 196 deletions protoc-gen-go/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"crypto/sha256"
"encoding/hex"
"fmt"
"go/build"
"go/parser"
"go/printer"
"go/token"
Expand Down Expand Up @@ -345,8 +346,7 @@ type symbol interface {
type messageSymbol struct {
sym string
hasExtensions, isMessageSet bool
hasOneof bool
getters []getterSymbol
oneofTypes []string
}

type getterSymbol struct {
Expand All @@ -357,146 +357,10 @@ type getterSymbol struct {
}

func (ms *messageSymbol) GenerateAlias(g *Generator, pkg GoPackageName) {
remoteSym := string(pkg) + "." + ms.sym

g.P("type ", ms.sym, " ", remoteSym)
g.P("func (m *", ms.sym, ") Reset() { (*", remoteSym, ")(m).Reset() }")
g.P("func (m *", ms.sym, ") String() string { return (*", remoteSym, ")(m).String() }")
g.P("func (*", ms.sym, ") ProtoMessage() {}")
g.P("func (m *", ms.sym, ") XXX_Unmarshal(buf []byte) error ",
"{ return (*", remoteSym, ")(m).XXX_Unmarshal(buf) }")
g.P("func (m *", ms.sym, ") XXX_Marshal(b []byte, deterministic bool) ([]byte, error) ",
"{ return (*", remoteSym, ")(m).XXX_Marshal(b, deterministic) }")
g.P("func (m *", ms.sym, ") XXX_Size() int ",
"{ return (*", remoteSym, ")(m).XXX_Size() }")
g.P("func (m *", ms.sym, ") XXX_DiscardUnknown() ",
"{ (*", remoteSym, ")(m).XXX_DiscardUnknown() }")
if ms.hasExtensions {
g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange ",
"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
}
if ms.hasOneof {
// Oneofs and public imports do not mix well.
// We can make them work okay for the binary format,
// but they're going to break weirdly for text/JSON.
enc := "_" + ms.sym + "_OneofMarshaler"
dec := "_" + ms.sym + "_OneofUnmarshaler"
size := "_" + ms.sym + "_OneofSizer"
encSig := "(msg " + g.Pkg["proto"] + ".Message, b *" + g.Pkg["proto"] + ".Buffer) error"
decSig := "(msg " + g.Pkg["proto"] + ".Message, tag, wire int, b *" + g.Pkg["proto"] + ".Buffer) (bool, error)"
sizeSig := "(msg " + g.Pkg["proto"] + ".Message) int"
g.P("func (m *", ms.sym, ") XXX_OneofFuncs() (func", encSig, ", func", decSig, ", func", sizeSig, ", []interface{}) {")
g.P("_, _, _, x := (*", remoteSym, ")(nil).XXX_OneofFuncs()")
g.P("return ", enc, ", ", dec, ", ", size, ", x")
g.P("}")

g.P("func ", enc, encSig, " {")
g.P("m := msg.(*", ms.sym, ")")
g.P("m0 := (*", remoteSym, ")(m)")
g.P("enc, _, _, _ := m0.XXX_OneofFuncs()")
g.P("return enc(m0, b)")
g.P("}")

g.P("func ", dec, decSig, " {")
g.P("m := msg.(*", ms.sym, ")")
g.P("m0 := (*", remoteSym, ")(m)")
g.P("_, dec, _, _ := m0.XXX_OneofFuncs()")
g.P("return dec(m0, tag, wire, b)")
g.P("}")

g.P("func ", size, sizeSig, " {")
g.P("m := msg.(*", ms.sym, ")")
g.P("m0 := (*", remoteSym, ")(m)")
g.P("_, _, size, _ := m0.XXX_OneofFuncs()")
g.P("return size(m0)")
g.P("}")
}
for _, get := range ms.getters {

if get.typeName != "" {
g.RecordTypeUse(get.typeName)
}
typ := get.typ
val := "(*" + remoteSym + ")(m)." + get.name + "()"
if get.genType {
// typ will be "*pkg.T" (message/group) or "pkg.T" (enum)
// or "map[t]*pkg.T" (map to message/enum).
// The first two of those might have a "[]" prefix if it is repeated.
// Drop any package qualifier since we have hoisted the type into this package.
rep := strings.HasPrefix(typ, "[]")
if rep {
typ = typ[2:]
}
isMap := strings.HasPrefix(typ, "map[")
star := typ[0] == '*'
if !isMap { // map types handled lower down
typ = typ[strings.Index(typ, ".")+1:]
}
if star {
typ = "*" + typ
}
if rep {
// Go does not permit conversion between slice types where both
// element types are named. That means we need to generate a bit
// of code in this situation.
// typ is the element type.
// val is the expression to get the slice from the imported type.

ctyp := typ // conversion type expression; "Foo" or "(*Foo)"
if star {
ctyp = "(" + typ + ")"
}

g.P("func (m *", ms.sym, ") ", get.name, "() []", typ, " {")
g.In()
g.P("o := ", val)
g.P("if o == nil {")
g.In()
g.P("return nil")
g.Out()
g.P("}")
g.P("s := make([]", typ, ", len(o))")
g.P("for i, x := range o {")
g.In()
g.P("s[i] = ", ctyp, "(x)")
g.Out()
g.P("}")
g.P("return s")
g.Out()
g.P("}")
continue
}
if isMap {
// Split map[keyTyp]valTyp.
bra, ket := strings.Index(typ, "["), strings.Index(typ, "]")
keyTyp, valTyp := typ[bra+1:ket], typ[ket+1:]
// Drop any package qualifier.
// Only the value type may be foreign.
star := valTyp[0] == '*'
valTyp = valTyp[strings.Index(valTyp, ".")+1:]
if star {
valTyp = "*" + valTyp
}

typ := "map[" + keyTyp + "]" + valTyp
g.P("func (m *", ms.sym, ") ", get.name, "() ", typ, " {")
g.P("o := ", val)
g.P("if o == nil { return nil }")
g.P("s := make(", typ, ", len(o))")
g.P("for k, v := range o {")
g.P("s[k] = (", valTyp, ")(v)")
g.P("}")
g.P("return s")
g.P("}")
continue
}
// Convert imported type into the forwarding type.
val = "(" + typ + ")(" + val + ")"
}

g.P("func (m *", ms.sym, ") ", get.name, "() ", typ, " { return ", val, " }")
g.P("type ", ms.sym, " = ", pkg, ".", ms.sym)
for _, name := range ms.oneofTypes {
g.P("type ", name, " = ", pkg, ".", name)
}

}

type enumSymbol struct {
Expand All @@ -506,14 +370,9 @@ type enumSymbol struct {

func (es enumSymbol) GenerateAlias(g *Generator, pkg GoPackageName) {
s := es.name
g.P("type ", s, " ", pkg, ".", s)
g.P("type ", s, " = ", pkg, ".", s)
g.P("var ", s, "_name = ", pkg, ".", s, "_name")
g.P("var ", s, "_value = ", pkg, ".", s, "_value")
g.P("func (x ", s, ") String() string { return (", pkg, ".", s, ")(x).String() }")
if !es.proto3 {
g.P("func (x ", s, ") Enum() *", s, "{ return (*", s, ")((", pkg, ".", s, ")(x).Enum()) }")
g.P("func (x *", s, ") UnmarshalJSON(data []byte) error { return (*", pkg, ".", s, ")(x).UnmarshalJSON(data) }")
}
}

type constOrVarSymbol struct {
Expand Down Expand Up @@ -1486,20 +1345,18 @@ func (g *Generator) generateImports() {
}

func (g *Generator) generateImported(id *ImportedDescriptor) {
// Don't generate public import symbols for files that we are generating
// code for, since those symbols will already be in this package.
// We can't simply avoid creating the ImportedDescriptor objects,
// because g.genFiles isn't populated at that stage.
tn := id.TypeName()
sn := tn[len(tn)-1]
df := id.o.File()
filename := *df.Name
for _, fd := range g.genFiles {
if *fd.Name == filename {
g.P("// Ignoring public import of ", sn, " from ", filename)
g.P()
return
}
if df.importPath == g.file.importPath {
// Don't generate type aliases for files in the same Go package as this one.
g.P("// Ignoring public import of ", sn, " from ", filename)
g.P()
return
}
if !supportTypeAliases {
g.Fail(fmt.Sprintf("%s: public imports require at least go1.9", filename))
}
g.P("// ", sn, " from public import ", filename)
g.usedPackages[df.importPath] = true
Expand Down Expand Up @@ -2232,6 +2089,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
g.P("}")
}
g.P()
var oneofTypes []string
for i, field := range message.Field {
if field.OneofIndex == nil {
continue
Expand All @@ -2241,6 +2099,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
fieldFullPath := fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, i)
g.P("type ", Annotate(message.file, fieldFullPath, oneofTypeName[field]), " struct{ ", Annotate(message.file, fieldFullPath, fieldNames[field]), " ", fieldTypes[field], " `", tag, "` }")
g.RecordTypeUse(field.GetTypeName())
oneofTypes = append(oneofTypes, oneofTypeName[field])
}
g.P()
for _, field := range message.Field {
Expand All @@ -2261,7 +2120,6 @@ func (g *Generator) generateMessage(message *Descriptor) {
g.P()

// Field getters
var getters []getterSymbol
for i, field := range message.Field {
oneof := field.OneofIndex != nil

Expand All @@ -2278,42 +2136,6 @@ func (g *Generator) generateMessage(message *Descriptor) {
}
fieldFullPath := fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, i)

// Only export getter symbols for basic types,
// and for messages and enums in the same package.
// Groups are not exported.
// Foreign types can't be hoisted through a public import because
// the importer may not already be importing the defining .proto.
// As an example, imagine we have an import tree like this:
// A.proto -> B.proto -> C.proto
// If A publicly imports B, we need to generate the getters from B in A's output,
// but if one such getter returns something from C then we cannot do that
// because A is not importing C already.
var getter, genType bool
switch *field.Type {
case descriptor.FieldDescriptorProto_TYPE_GROUP:
getter = false
case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_ENUM:
// Only export getter if its return type is in the same file.
//
// This should be the same package, not the same file.
// However, code elsewhere assumes that there's a 1-1 relationship
// between packages and files, so that's not safe.
//
// TODO: Tear out all of this complexity and just use type aliases.
getter = g.ObjectNamed(field.GetTypeName()).File() == message.File()
genType = true
default:
getter = true
}
if getter {
getters = append(getters, getterSymbol{
name: mname,
typ: typename,
typeName: field.GetTypeName(),
genType: genType,
})
}

if field.GetOptions().GetDeprecated() {
g.P(deprecationComment)
}
Expand Down Expand Up @@ -2416,8 +2238,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
sym: ccTypeName,
hasExtensions: hasExtensions,
isMessageSet: isMessageSet,
hasOneof: len(message.OneofDecl) > 0,
getters: getters,
oneofTypes: oneofTypes,
}
g.file.addExport(message, ms)
}
Expand Down Expand Up @@ -3094,3 +2915,14 @@ const (
// tag numbers in EnumDescriptorProto
enumValuePath = 2 // value
)

var supportTypeAliases bool

func init() {
for _, tag := range build.Default.ReleaseTags {
if tag == "go1.9" {
supportTypeAliases = true
return
}
}
}
15 changes: 15 additions & 0 deletions protoc-gen-go/golden_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"flag"
"fmt"
"go/build"
"go/parser"
"go/token"
"io/ioutil"
Expand Down Expand Up @@ -38,8 +39,13 @@ func TestGolden(t *testing.T) {

// Find all the proto files we need to compile. We assume that each directory
// contains the files for a single package.
supportTypeAliases := hasReleaseTag("1.9")
packages := map[string][]string{}
err = filepath.Walk("testdata", func(path string, info os.FileInfo, err error) error {
if filepath.Base(path) == "import_public" && !supportTypeAliases {
// Public imports require type alias support.
return filepath.SkipDir
}
if !strings.HasSuffix(path, ".proto") {
return nil
}
Expand Down Expand Up @@ -405,3 +411,12 @@ func protoc(t *testing.T, args []string) {
t.Fatalf("protoc: %v", err)
}
}

func hasReleaseTag(want string) bool {
for _, tag := range build.Default.ReleaseTags {
if tag == want {
return true
}
}
return false
}
Loading