diff --git a/dev/tools/controllerbuilder/pkg/codegen/common.go b/dev/tools/controllerbuilder/pkg/codegen/common.go new file mode 100644 index 00000000000..80a18a0ef8b --- /dev/null +++ b/dev/tools/controllerbuilder/pkg/codegen/common.go @@ -0,0 +1,41 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package codegen + +import "strings" + +// special-case proto messages that are currently not mapped to KRM Go structs +var protoMessagesNotMappedToGoStruct = map[string]string{ + "google.protobuf.Timestamp": "string", + "google.protobuf.Duration": "string", + "google.protobuf.Int64Value": "int64", + "google.protobuf.StringValue": "string", + "google.protobuf.Struct": "map[string]string", +} + +var Acronyms = []string{ + "ID", "HTML", "URL", "HTTP", "HTTPS", "SSH", + "IP", "GB", "FS", "PD", "KMS", "GCE", "VTPM", +} + +// IsAcronym returns true if the given string is an acronym +func IsAcronym(s string) bool { + for _, acronym := range Acronyms { + if strings.EqualFold(s, acronym) { + return true + } + } + return false +} diff --git a/dev/tools/controllerbuilder/pkg/codegen/typegenerator.go b/dev/tools/controllerbuilder/pkg/codegen/typegenerator.go index 33381cdc844..4dae95b3fb0 100644 --- a/dev/tools/controllerbuilder/pkg/codegen/typegenerator.go +++ b/dev/tools/controllerbuilder/pkg/codegen/typegenerator.go @@ -30,15 +30,6 @@ import ( "k8s.io/klog/v2" ) -// Some special-case values that are not obvious how to map in KRM -var protoMessagesNotMappedToGoStruct = map[string]string{ - "google.protobuf.Timestamp": "string", - "google.protobuf.Duration": "string", - "google.protobuf.Int64Value": "int64", - "google.protobuf.StringValue": "string", - "google.protobuf.Struct": "map[string]string", -} - type TypeGenerator struct { generatorBase api *protoapi.Proto @@ -78,7 +69,7 @@ func (g *TypeGenerator) visitMessage(messageDescriptor protoreflect.MessageDescr g.visitedMessages = append(g.visitedMessages, messageDescriptor) - msgs, err := findDependenciesForMessage(messageDescriptor) + msgs, err := FindDependenciesForMessage(messageDescriptor) if err != nil { return err } @@ -123,7 +114,7 @@ func (g *TypeGenerator) WriteVisitedMessages() error { } out := g.getOutputFile(k) - goTypeName := goNameForProtoMessage(msg) + goTypeName := GoNameForProtoMessage(msg) skipGenerated := true goType, err := g.findTypeDeclaration(goTypeName, out.OutputDir(), skipGenerated) if err != nil { @@ -151,7 +142,7 @@ func (g *TypeGenerator) WriteVisitedMessages() error { } func WriteMessage(out io.Writer, msg protoreflect.MessageDescriptor) { - goType := goNameForProtoMessage(msg) + goType := GoNameForProtoMessage(msg) fmt.Fprintf(out, "\n") fmt.Fprintf(out, "// +kcc:proto=%s\n", msg.FullName()) @@ -163,51 +154,58 @@ func WriteMessage(out io.Writer, msg protoreflect.MessageDescriptor) { fmt.Fprintf(out, "}\n") } -func WriteField(out io.Writer, field protoreflect.FieldDescriptor, msg protoreflect.MessageDescriptor, fieldIndex int) { - sourceLocations := msg.ParentFile().SourceLocations().ByDescriptor(field) - - jsonName := getJSONForKRM(field) - goFieldName := goFieldName(field) - goType := "" - +func GoType(field protoreflect.FieldDescriptor) (string, error) { if field.IsMap() { entryMsg := field.Message() keyKind := entryMsg.Fields().ByName("key").Kind() valueKind := entryMsg.Fields().ByName("value").Kind() if keyKind == protoreflect.StringKind && valueKind == protoreflect.StringKind { - goType = "map[string]string" + return "map[string]string", nil } else if keyKind == protoreflect.StringKind && valueKind == protoreflect.Int64Kind { - goType = "map[string]int64" + return "map[string]int64", nil } else { - fmt.Fprintf(out, "\n\t// TODO: map type %v %v for %v\n\n", keyKind, valueKind, field.Name()) - return + return "", fmt.Errorf("unsupported map type with key %v and value %v", keyKind, valueKind) } + } + + var goType string + switch field.Kind() { + case protoreflect.MessageKind: + goType = GoNameForProtoMessage(field.Message()) + case protoreflect.EnumKind: + goType = "string" + default: + goType = goTypeForProtoKind(field.Kind()) + } + + if field.Cardinality() == protoreflect.Repeated { + goType = "[]" + goType } else { - switch field.Kind() { - case protoreflect.MessageKind: - goType = goNameForProtoMessage(field.Message()) + goType = "*" + goType + } - case protoreflect.EnumKind: - goType = "string" //string(field.Enum().Name()) + // Special case for proto "bytes" type + if goType == "*[]byte" { + goType = "[]byte" + } + // Special case for proto "google.protobuf.Struct" type + if goType == "*map[string]string" { + goType = "map[string]string" + } - default: - goType = goTypeForProtoKind(field.Kind()) - } + return goType, nil +} - if field.Cardinality() == protoreflect.Repeated { - goType = "[]" + goType - } else { - goType = "*" + goType - } +func WriteField(out io.Writer, field protoreflect.FieldDescriptor, msg protoreflect.MessageDescriptor, fieldIndex int) { + sourceLocations := msg.ParentFile().SourceLocations().ByDescriptor(field) - // Special case for proto "bytes" type - if goType == "*[]byte" { - goType = "[]byte" - } - // Special case for proto "google.protobuf.Struct" type - if goType == "*map[string]string" { - goType = "map[string]string" - } + jsonName := GetJSONForKRM(field) + GoFieldName := goFieldName(field) + + goType, err := GoType(field) + if err != nil { + fmt.Fprintf(out, "\n\t// TODO: %v\n\n", err) + return } // Blank line between fields for readability @@ -228,7 +226,7 @@ func WriteField(out io.Writer, field protoreflect.FieldDescriptor, msg protorefl fmt.Fprintf(out, "\t// +kcc:proto=%s\n", field.FullName()) fmt.Fprintf(out, "\t%s %s `json:\"%s,omitempty\"`\n", - goFieldName, + GoFieldName, goType, jsonName, ) @@ -253,7 +251,7 @@ func deduplicateAndSort(messages []protoreflect.MessageDescriptor) []protoreflec return messages } -func goNameForProtoMessage(msg protoreflect.MessageDescriptor) string { +func GoNameForProtoMessage(msg protoreflect.MessageDescriptor) string { fullName := string(msg.FullName()) // Some special-case values that are not obvious how to map in KRM @@ -307,16 +305,16 @@ func goTypeForProtoKind(kind protoreflect.Kind) string { return goType } -// getJSONForKRM returns the KRM JSON name for the field, +// GetJSONForKRM returns the KRM JSON name for the field, // honoring KRM conventions -func getJSONForKRM(protoField protoreflect.FieldDescriptor) string { +func GetJSONForKRM(protoField protoreflect.FieldDescriptor) string { tokens := strings.Split(string(protoField.Name()), "_") for i, token := range tokens { if i == 0 { // Do not capitalize first token continue } - if isAcronym(token) { + if IsAcronym(token) { token = strings.ToUpper(token) } else { token = strings.Title(token) @@ -331,7 +329,7 @@ func getJSONForKRM(protoField protoreflect.FieldDescriptor) string { func goFieldName(protoField protoreflect.FieldDescriptor) string { tokens := strings.Split(string(protoField.Name()), "_") for i, token := range tokens { - if isAcronym(token) { + if IsAcronym(token) { token = strings.ToUpper(token) } else { token = strings.Title(token) @@ -341,35 +339,8 @@ func goFieldName(protoField protoreflect.FieldDescriptor) string { return strings.Join(tokens, "") } -func isAcronym(s string) bool { - switch s { - case "id": - return true - case "html", "url": - return true - case "http", "https", "ssh": - return true - case "ip": - return true - case "gb": - return true - case "fs": - return true - case "pd": - return true - case "kms": - return true - case "gce": - return true - case "vtpm": - return true - default: - return false - } -} - -// findDependenciesForMessage recursively explores the dependent proto messages of the given message. -func findDependenciesForMessage(message protoreflect.MessageDescriptor) ([]protoreflect.MessageDescriptor, error) { +// FindDependenciesForMessage recursively explores the dependent proto messages of the given message. +func FindDependenciesForMessage(message protoreflect.MessageDescriptor) ([]protoreflect.MessageDescriptor, error) { msgs := make(map[string]protoreflect.MessageDescriptor) for i := 0; i < message.Fields().Len(); i++ { field := message.Fields().Get(i) diff --git a/dev/tools/controllerbuilder/pkg/commands/updatetypes/insertcommand.go b/dev/tools/controllerbuilder/pkg/commands/updatetypes/insertcommand.go index 32533b9c090..63c9661101a 100644 --- a/dev/tools/controllerbuilder/pkg/commands/updatetypes/insertcommand.go +++ b/dev/tools/controllerbuilder/pkg/commands/updatetypes/insertcommand.go @@ -77,7 +77,7 @@ func bindInsertFlags(cmd *cobra.Command, opt *insertFieldOptions) { cmd.Flags().StringVar(&opt.field, "field", opt.field, "Name of the field to be inserted, e.g. `schedule_options_v2`") } -func runFieldInserter(ctx context.Context, opt *insertFieldOptions) error { +func runFieldInserter(_ context.Context, opt *insertFieldOptions) error { fieldInserter := typeupdater.NewFieldInserter(&typeupdater.InsertFieldOptions{ ProtoSourcePath: opt.GenerateOptions.ProtoSourcePath, ParentMessageFullName: opt.parent, diff --git a/dev/tools/controllerbuilder/pkg/commands/updatetypes/synccommand.go b/dev/tools/controllerbuilder/pkg/commands/updatetypes/synccommand.go new file mode 100644 index 00000000000..56db4a767a8 --- /dev/null +++ b/dev/tools/controllerbuilder/pkg/commands/updatetypes/synccommand.go @@ -0,0 +1,91 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package updatetypes + +import ( + "context" + "fmt" + + "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/typeupdater" + "github.com/spf13/cobra" +) + +type syncProtoPackageOptions struct { + *baseUpdateTypeOptions +} + +func buildSyncCommand(baseOptions *baseUpdateTypeOptions) *cobra.Command { + opt := &syncProtoPackageOptions{ + baseUpdateTypeOptions: baseOptions, + } + + cmd := &cobra.Command{ + Use: "sync", + Short: "sync the KRM types with the proto package", + Long: `Sync the KRM types with the proto package. This command will update the KRM types +to match the proto package. If --message is specified, only the specified message and its +dependent messages will be synced. If --message is not specified, all messages in the proto +package indicated by --service will be synced.`, + PreRunE: validateSyncOptions(opt), + RunE: runSync(opt), + } + + opt.BindFlags(cmd) + + return cmd +} + +func validateSyncOptions(opt *syncProtoPackageOptions) func(*cobra.Command, []string) error { + return func(cmd *cobra.Command, args []string) error { + if err := validateRequiredFlags(opt); err != nil { + return err + } + return nil + } +} + +func validateRequiredFlags(opt *syncProtoPackageOptions) error { + if opt.apiDirectory == "" { + return fmt.Errorf("--api-dir is required") + } + if opt.apiGoPackagePath == "" { + return fmt.Errorf("--api-go-package-path is required") + } + if opt.ServiceName == "" { + return fmt.Errorf("--service is required") + } + return nil +} + +func runSync(opt *syncProtoPackageOptions) func(*cobra.Command, []string) error { + return func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + if err := runPackageSyncer(ctx, opt); err != nil { + return err + } + return nil + } +} + +func runPackageSyncer(ctx context.Context, opt *syncProtoPackageOptions) error { + syncer := typeupdater.NewProtoPackageSyncer(&typeupdater.SyncProtoPackageOptions{ + ServiceName: opt.ServiceName, + APIVersion: opt.APIVersion, + ProtoSourcePath: opt.GenerateOptions.ProtoSourcePath, + APIDirectory: opt.apiDirectory, + GoPackagePath: opt.apiGoPackagePath, + }) + return syncer.Run() +} diff --git a/dev/tools/controllerbuilder/pkg/commands/updatetypes/updatetypescommand.go b/dev/tools/controllerbuilder/pkg/commands/updatetypes/updatetypescommand.go index dfe8a913208..f396655b5f1 100644 --- a/dev/tools/controllerbuilder/pkg/commands/updatetypes/updatetypescommand.go +++ b/dev/tools/controllerbuilder/pkg/commands/updatetypes/updatetypescommand.go @@ -66,6 +66,7 @@ func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command { // subcommands cmd.AddCommand(buildInsertCommand(opt)) + cmd.AddCommand(buildSyncCommand(opt)) return cmd } diff --git a/dev/tools/controllerbuilder/pkg/gocode/messageinfo.go b/dev/tools/controllerbuilder/pkg/gocode/messageinfo.go new file mode 100644 index 00000000000..b804cb63a11 --- /dev/null +++ b/dev/tools/controllerbuilder/pkg/gocode/messageinfo.go @@ -0,0 +1,185 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gocode + +import ( + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" +) + +// MessageInfo contains information about a Go struct parsed from existing types files. +// This struct is used to keep track of existing information about a message in the +// generated and human-edited code. +type MessageInfo struct { + GoName string // The Go struct name + ProtoName string // The proto message name from +kcc:proto annotation + IsVirtual bool // KRM-specific messages that don't map to proto + Comments []string // Original comments + Fields map[string]*FieldInfo // Map of field name to field info + FilePath string // The file path where this Go struct was located +} + +// FieldInfo contains information about a field in a Go struct parsed from existing types files. +// This struct is used to keep track of existing information about a field in the +// generated and human-edited code. +type FieldInfo struct { + GoName string // Field name in Go + ProtoName string // The fully qualified proto field name from +kcc:proto annotation + IsVirtual bool // KRM-specific fields that don't map to proto + IsIgnored bool // Field explicitly marked as not implemented + IsReference bool // Is this a reference field? + RefType string // What type of reference (ProjectRef, etc) + Comments []string // Preserve original comments for reference fields +} + +func ExtractMessageInfoFromGoFiles(dir string) (map[string]MessageInfo, error) { + messages := make(map[string]MessageInfo) + + err := filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error { + if err != nil || d.IsDir() || filepath.Ext(path) != ".go" { + return nil + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + if err != nil { + return err + } + + docMap := NewDocMap(fset, file) + + ast.Inspect(file, func(n ast.Node) bool { + ts, ok := n.(*ast.TypeSpec) + if !ok { + return true + } + st, ok := ts.Type.(*ast.StructType) + if !ok { + return true + } + + msgInfo := NewMessageInfo(ts.Name.Name, path) + msgInfo.ParseComments(ts, docMap) + + // parse fields within the message + for _, field := range st.Fields.List { + if len(field.Names) == 0 { + continue + } + fieldInfo := newFieldInfo(field.Names[0].Name) + fieldInfo.parseComments(field, docMap) + msgInfo.Fields[fieldInfo.GoName] = fieldInfo + } + + messages[msgInfo.GoName] = msgInfo + return true + }) + return nil + }) + + return messages, err +} + +func NewMessageInfo(name, filePath string) MessageInfo { + return MessageInfo{ + GoName: name, + FilePath: filePath, + Fields: make(map[string]*FieldInfo), + } +} + +func (info *MessageInfo) ParseComments(ts *ast.TypeSpec, docMap map[ast.Node]*ast.CommentGroup) { + info.IsVirtual = true + + if comments := docMap[ts]; comments != nil { + info.Comments = make([]string, 0, len(comments.List)) + for _, c := range comments.List { + text := strings.TrimSpace(strings.TrimPrefix(c.Text, "//")) + info.Comments = append(info.Comments, text) + + // check for proto annotation + if strings.HasPrefix(text, "+kcc:proto=") { + protoName := strings.TrimSpace(strings.TrimPrefix(text, "+kcc:proto=")) + info.ProtoName = protoName + info.IsVirtual = false + } + } + } +} + +func newFieldInfo(name string) *FieldInfo { + return &FieldInfo{ + GoName: name, + } +} + +func (info *FieldInfo) parseComments(field *ast.Field, docMap map[ast.Node]*ast.CommentGroup) { + info.IsVirtual = true + + // check if field is a reference field + if expr, ok := field.Type.(*ast.StarExpr); ok { + if sel, ok := expr.X.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "refv1beta1" { // HACK: this is a hack to identify reference fields + info.IsReference = true + info.RefType = sel.Sel.Name + } + } + } + } + + // parse comments to find kcc codegen annotations + if comments := docMap[field]; comments != nil { + info.Comments = make([]string, 0, len(comments.List)) + for _, c := range comments.List { + text := strings.TrimSpace(strings.TrimPrefix(c.Text, "//")) + info.Comments = append(info.Comments, text) + + if strings.HasPrefix(text, "+kcc:proto=") { + protoName := strings.TrimSpace(strings.TrimPrefix(text, "+kcc:proto=")) + info.ProtoName = protoName + info.IsVirtual = false + } + if strings.Contains(text, "NOTYET") || strings.Contains(text, "+kcc:proto:ignore") { + info.IsIgnored = true + } + } + } +} + +// GetSpecialAnnotations extracts special annotations like +required from comment group +// These annotations are manually added to the generated code, we need to preserve them. +func GetSpecialAnnotations(comments []string) []string { + if comments == nil { + return nil + } + + var annotations []string + for _, c := range comments { + if strings.Contains(c, "+genclient") || + strings.Contains(c, "+k8s") || + strings.Contains(c, "+kubebuilder") || + strings.Contains(c, "+required") || + strings.Contains(c, "+optional") || + strings.Contains(c, "Immutable") { + annotations = append(annotations, c) + } + } + return annotations +} diff --git a/dev/tools/controllerbuilder/pkg/typeupdater/common.go b/dev/tools/controllerbuilder/pkg/typeupdater/common.go new file mode 100644 index 00000000000..318cff5cff3 --- /dev/null +++ b/dev/tools/controllerbuilder/pkg/typeupdater/common.go @@ -0,0 +1,88 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package typeupdater + +import ( + "go/ast" + "strings" + "unicode" + + "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/codegen" +) + +const kccProtoPrefix = "+kcc:proto=" + +// commentContains checks if the given comment group contains a target string annotation +func commentContains(cg *ast.CommentGroup, target string) bool { + if cg == nil { + return false + } + for _, c := range cg.List { + trimmed := strings.TrimPrefix(c.Text, "//") + trimmed = strings.TrimSpace(trimmed) + if trimmed == target { + return true + } + } + return false +} + +// getProtoFieldName converts a fully qualified proto field name to a snake_case field name +// e.g. "google.cloud.bigquery.datatransfer.v1.TransferConfig.DisplayName" -> "display_name" +func getProtoFieldName(fullName string) string { + parts := strings.Split(fullName, ".") + if len(parts) == 0 { + return "" + } + lastPart := parts[len(parts)-1] + + // convert from camelCase to snake_case + var result []rune + var i int + for i < len(lastPart) { + // check for acronym sequence + if unicode.IsUpper(rune(lastPart[i])) { + if acronym := extractAcronym(lastPart[i:]); len(acronym) > 0 { + if i > 0 { + result = append(result, '_') + } + result = append(result, []rune(strings.ToLower(acronym))...) + i += len(acronym) + continue + } + } + + // regular camelCase handling + r := rune(lastPart[i]) + if i > 0 && unicode.IsUpper(r) { + result = append(result, '_') + } + result = append(result, unicode.ToLower(r)) + i++ + } + + return string(result) +} + +// extractAcronym checks if the string starts with a known acronym and returns it +func extractAcronym(s string) string { + // try to find the longest acronym starting at this position + for j := len(s); j > 0; j-- { + if codegen.IsAcronym(s[:j]) { + return s[:j] + } + } + return "" +} diff --git a/dev/tools/controllerbuilder/pkg/typeupdater/fieldinserter.go b/dev/tools/controllerbuilder/pkg/typeupdater/fieldinserter.go index 313e008b67b..457dcb77d16 100644 --- a/dev/tools/controllerbuilder/pkg/typeupdater/fieldinserter.go +++ b/dev/tools/controllerbuilder/pkg/typeupdater/fieldinserter.go @@ -30,8 +30,6 @@ import ( "k8s.io/klog" ) -const kccProtoPrefix = "+kcc:proto=" - type InsertFieldOptions struct { ParentMessageFullName string FieldToInsert string diff --git a/dev/tools/controllerbuilder/pkg/typeupdater/fieldupdateplan.go b/dev/tools/controllerbuilder/pkg/typeupdater/fieldupdateplan.go new file mode 100644 index 00000000000..2249d18835a --- /dev/null +++ b/dev/tools/controllerbuilder/pkg/typeupdater/fieldupdateplan.go @@ -0,0 +1,206 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package typeupdater + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + + "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/codegen" + "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/gocode" + "google.golang.org/protobuf/reflect/protoreflect" + "k8s.io/klog" +) + +// FieldUpdatePlan represents a planned update to a specific Go field +type FieldUpdatePlan struct { + filepath string // path to the file containing the field + + structName string // name of the Go struct containing the field + fieldName string // name of the Go field to update + fieldInfo *gocode.FieldInfo // original field info for reference + + protoParentName string // fully qualified name of the proto parent message + protoName string // fully qualified name of the proto field + protoField protoreflect.FieldDescriptor // proto field descriptor + + content []byte // generated field content +} + +func (s *ProtoPackageSyncer) createFieldUpdatePlan(msgInfo gocode.MessageInfo, fieldInfo *gocode.FieldInfo, msgDesc protoreflect.MessageDescriptor) (*FieldUpdatePlan, error) { + if fieldInfo.IsVirtual { + // TODO: we should skip a virtual field in the future, after all fields are properly annotated. for now, we use a hack to fake the proto name + fieldInfo.ProtoName = fmt.Sprintf("%s.%s", msgInfo.ProtoName, fieldInfo.GoName) // HACK: this is a hack to fake the proto name for go fields that are missing the proto name annotation + // klog.Infof("Skipping virtual field %s in %s", fieldInfo.GoName, msgInfo.GoName) + // return nil, nil + } + + // 1. find the proto field + name := getProtoFieldName(fieldInfo.ProtoName) // e.g. "google.cloud.bigquery.datatransfer.v1.TransferConfig.DisplayName" -> "display_name" + protoField := msgDesc.Fields().ByName(protoreflect.Name(name)) + if protoField == nil { + klog.Warningf("proto field %s (full name: %s) not found in message %s", name, fieldInfo.ProtoName, msgInfo.ProtoName) + return nil, nil + } + + // 2. generate Go structs for the field + var buf bytes.Buffer + + // 2.1 special annotations such as "// +required" are manually added to the generated code, we need to preserve them + specialAnnotations := gocode.GetSpecialAnnotations(fieldInfo.Comments) + if len(specialAnnotations) > 0 { + for _, annotation := range specialAnnotations { + fmt.Fprintf(&buf, "\t// %s\n", annotation) + } + } + + // 2.2 regenerate the field content based on the proto field descriptor + if fieldInfo.IsReference { // For reference fields, preserve original comments and reference type + return nil, nil // skip generating reference fields for now since we don't plan to update them + /* for _, comment := range fieldInfo.Comments { + fmt.Fprintf(&buf, "\t// %s\n", comment) + } + jsonName := codegen.GetJSONForKRM(protoField) + fmt.Fprintf(&buf, "\t%s *refv1beta1.%s `json:\"%s,omitempty\"`\n", + fieldInfo.GoName, + fieldInfo.RefType, + jsonName) */ + } else if fieldInfo.IsIgnored { // for ignored fields, generate only the field declaration without comments + goType, err := codegen.GoType(protoField) + if err != nil { + return nil, fmt.Errorf("determining Go type for ignored field %s (proto: %s): %w", fieldInfo.GoName, fieldInfo.ProtoName, err) + } + jsonName := codegen.GetJSONForKRM(protoField) + fmt.Fprintf(&buf, "\t%s %s `json:\"%s,omitempty\"`\n", + fieldInfo.GoName, + goType, + jsonName) + } else { // for regular fields, generate complete field with comments + codegen.WriteField(&buf, protoField, msgDesc, 0) // HACK: use fieldIndex=0 to avoid generating a blank line + } + + // 3. create the update plan to record every information we need to update the field + plan := &FieldUpdatePlan{ + filepath: msgInfo.FilePath, + structName: msgInfo.GoName, + fieldName: fieldInfo.GoName, + fieldInfo: fieldInfo, + protoParentName: msgInfo.ProtoName, + protoName: fieldInfo.ProtoName, + protoField: protoField, + content: buf.Bytes(), + } + + return plan, nil +} + +func (s *ProtoPackageSyncer) applyFieldUpdatePlan(plan FieldUpdatePlan) error { + content, err := os.ReadFile(plan.filepath) + if err != nil { + return fmt.Errorf("reading file %s: %w", plan.filepath, err) + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, plan.filepath, content, parser.ParseComments) + if err != nil { + return fmt.Errorf("parsing file %s: %w", plan.filepath, err) + } + + docMap := gocode.NewDocMap(fset, file) + + // find the target struct and field by matching the proto name + targetMessageName := fmt.Sprintf("%s%s", kccProtoPrefix, plan.protoParentName) + targetFieldName := fmt.Sprintf("%s%s", kccProtoPrefix, plan.protoName) + var fieldNode *ast.Field + var found bool + ast.Inspect(file, func(n ast.Node) bool { + if found { + return false + } + + ts, ok := n.(*ast.TypeSpec) + if !ok { + return true + } + st, ok := ts.Type.(*ast.StructType) + if !ok { + return false + } + if !commentContains(docMap[ts], targetMessageName) { // match by fully qualified proto name annotation + return true + } + + // find the target field + fieldComments := docMap[ts] + for _, field := range st.Fields.List { + if commentContains(fieldComments, targetFieldName) || + (len(field.Names) > 0 && field.Names[0].Name == plan.fieldName) { // TODO: this is a hack to match the field name without proper proto name annotation + fieldNode = field + found = true + return false + } + } + return true + }) + + if !found { + return fmt.Errorf("field %s not found in struct %s", plan.fieldName, plan.structName) + } + + // get the start position (including doc comments if they exist) + var startPos token.Pos + var hasDoc bool + if doc := docMap[fieldNode]; doc != nil { + startPos = doc.Pos() + hasDoc = true + } else { + startPos = fieldNode.Pos() + } + start := fset.Position(startPos) + end := fset.Position(fieldNode.End()) + + if hasDoc { // HACK: remove the leading tabs from the original field content + start.Offset-- + } + + // replace the field content + newContent := make([]byte, 0, len(content)+len(plan.content)) + newContent = append(newContent, content[:start.Offset]...) + newContent = append(newContent, plan.content...) + newContent = append(newContent, content[end.Offset:]...) + + if err := os.WriteFile(plan.filepath, newContent, 0644); err != nil { + return fmt.Errorf("writing file %s: %w", plan.filepath, err) + } + + return nil +} + +func printUpdatePlans(plans []FieldUpdatePlan) { + klog.Infof("Field update plans:") + for _, plan := range plans { + klog.Infof("- File: %s", plan.filepath) + klog.Infof(" Struct: %s", plan.structName) + klog.Infof(" Field: %s", plan.fieldName) + klog.Infof(" Proto: %s", plan.protoName) + klog.Infof(" IsReference: %v", plan.fieldInfo.IsReference) + klog.Infof(" IsIgnored: %v", plan.fieldInfo.IsIgnored) + klog.Infof(" Content: %s", string(plan.content)) + } +} diff --git a/dev/tools/controllerbuilder/pkg/typeupdater/insertfield-ast.go b/dev/tools/controllerbuilder/pkg/typeupdater/insertfield-ast.go index 774dbfdaf04..df2854c92ff 100644 --- a/dev/tools/controllerbuilder/pkg/typeupdater/insertfield-ast.go +++ b/dev/tools/controllerbuilder/pkg/typeupdater/insertfield-ast.go @@ -79,7 +79,7 @@ func (u *FieldInserter) insertGoField() error { } comments := docMap[ts] - if !isTargetStruct(comments, targetComment) { + if !commentContains(comments, targetComment) { return true } @@ -127,17 +127,3 @@ func (u *FieldInserter) insertGoField() error { return nil } - -func isTargetStruct(cg *ast.CommentGroup, target string) bool { - if cg == nil { - return false - } - for _, c := range cg.List { - trimmed := strings.TrimPrefix(c.Text, "//") - trimmed = strings.TrimSpace(trimmed) - if trimmed == target { - return true - } - } - return false -} diff --git a/dev/tools/controllerbuilder/pkg/typeupdater/protopackagesyncer.go b/dev/tools/controllerbuilder/pkg/typeupdater/protopackagesyncer.go new file mode 100644 index 00000000000..8b50549b97b --- /dev/null +++ b/dev/tools/controllerbuilder/pkg/typeupdater/protopackagesyncer.go @@ -0,0 +1,158 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package typeupdater + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/gocode" + "github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/protoapi" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/klog" +) + +type SyncProtoPackageOptions struct { + ServiceName string + APIVersion string + ProtoSourcePath string + APIDirectory string + GoPackagePath string +} + +type ProtoPackageSyncer struct { + opts *SyncProtoPackageOptions + + // holds info about Go structs in existing types files, including both generated and manually edited structs. + // key is the go struct name + existingGoMessages map[string]gocode.MessageInfo + api *protoapi.Proto // Store the loaded proto API +} + +func NewProtoPackageSyncer(opts *SyncProtoPackageOptions) *ProtoPackageSyncer { + return &ProtoPackageSyncer{ + opts: opts, + existingGoMessages: make(map[string]gocode.MessageInfo), + } +} + +func (s *ProtoPackageSyncer) Run() error { + // 1. parse the existing go types + if err := s.parseExistingTypes(); err != nil { + return err + } + + // 2. load the proto package + if err := s.loadProtoPackage(); err != nil { + return err + } + + // 3. create the update plans + plans, err := s.createFieldUpdatePlans() + if err != nil { + return fmt.Errorf("creating update plans: %w", err) + } + + // printUpdatePlans(plans) + + // 4. apply the update plans to update the existing types + for _, plan := range plans { + if err := s.applyFieldUpdatePlan(plan); err != nil { + return fmt.Errorf("applying update plan for field %s in struct %s: %w", + plan.fieldName, plan.structName, err) + } + } + + return nil +} + +func (s *ProtoPackageSyncer) parseExistingTypes() error { + dir, err := typeFilePath(s.opts.APIDirectory, s.opts.APIVersion) + if err != nil { + return fmt.Errorf("getting API directory for %q: %w", s.opts.APIVersion, err) + } + + klog.Infof("Parsing existing types from %q", dir) + messages, err := gocode.ExtractMessageInfoFromGoFiles(dir) + if err != nil { + return err + } + + s.existingGoMessages = messages + return nil +} + +// typeFilePath returns the path to the types.go file for the given API version +func typeFilePath(apiBaseDir, gv string) (string, error) { + groupVersion, err := schema.ParseGroupVersion(gv) + if err != nil { + return "", fmt.Errorf("parsing APIVersion %q: %w", gv, err) + } + + goPackagePath := strings.TrimSuffix(groupVersion.Group, ".cnrm.cloud.google.com") + "/" + groupVersion.Version + packageTokens := strings.Split(goPackagePath, ".") + return filepath.Join(append([]string{apiBaseDir}, packageTokens...)...), nil +} + +func (s *ProtoPackageSyncer) createFieldUpdatePlans() ([]FieldUpdatePlan, error) { + var plans []FieldUpdatePlan + + // for each existing Go message that has a corresponding proto message + for goTypeName, msgInfo := range s.existingGoMessages { + if msgInfo.IsVirtual { + klog.Infof("Skipping virtual type %s", goTypeName) + continue + } + + // find corresponding proto message + desc, err := s.api.Files().FindDescriptorByName(protoreflect.FullName(msgInfo.ProtoName)) + if err != nil && err != protoregistry.NotFound { + return nil, fmt.Errorf("finding proto message %s: %w", msgInfo.ProtoName, err) + } + if desc == nil { + klog.Warningf("No proto message found for %s", msgInfo.ProtoName) + continue + } + msgDesc, ok := desc.(protoreflect.MessageDescriptor) + if !ok { + return nil, fmt.Errorf("unexpected descriptor type for %s: %T", msgInfo.ProtoName, desc) + } + + // for each field in the message, create update plan based on exsiting go types and the matching proto field + for fieldName, fieldInfo := range msgInfo.Fields { + plan, err := s.createFieldUpdatePlan(msgInfo, fieldInfo, msgDesc) + if err != nil { + return nil, fmt.Errorf("creating plan for field %s: %w", fieldName, err) + } + if plan != nil { + plans = append(plans, *plan) + } + } + } + + return plans, nil +} + +func (s *ProtoPackageSyncer) loadProtoPackage() error { + api, err := protoapi.LoadProto(s.opts.ProtoSourcePath) + if err != nil { + return fmt.Errorf("loading proto: %w", err) + } + s.api = api + return nil +} diff --git a/dev/tools/controllerbuilder/update.sh b/dev/tools/controllerbuilder/update.sh index 02aaee4d617..45be6367498 100755 --- a/dev/tools/controllerbuilder/update.sh +++ b/dev/tools/controllerbuilder/update.sh @@ -20,9 +20,15 @@ set -x REPO_ROOT="$(git rev-parse --show-toplevel)" cd ${REPO_ROOT}/dev/tools/controllerbuilder -# example usage +# example usage of inserting a field go run . update-types insert \ --parent "google.monitoring.dashboard.v1.Dashboard" \ --field "row_layout" \ --api-dir ${REPO_ROOT}/apis/monitoring/v1beta1 \ --ignored-fields "google.monitoring.dashboard.v1.PickTimeSeriesFilter.interval" + +# example usage of syncing a message with all of its dependencies from proto package +go run . update-types sync \ + --service google.cloud.bigquery.datatransfer.v1 \ + --api-version bigquerydatatransfer.cnrm.cloud.google.com/v1beta1 \ + --message "google.cloud.bigquery.datatransfer.v1.TransferConfig"