From 2be1d37dd7bd49aa8539a3a5f205f410db134366 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Sun, 23 Jun 2024 05:57:55 -0400 Subject: [PATCH] chore(internal/protoveneer): generate support functions Generate the support functions along with the rest of the code, instead of requiring a separate package. --- .../protoveneer/cmd/protoveneer/config.go | 2 - .../protoveneer/cmd/protoveneer/converters.go | 28 +++- .../protoveneer/internal}/support/support.go | 61 +++---- .../internal}/support/support_test.go | 8 +- .../cmd/protoveneer/protoveneer.go | 103 +++++++++--- .../cmd/protoveneer/protoveneer_test.go | 38 ++++- .../protoveneer/testdata/basic/config.yaml | 1 - .../cmd/protoveneer/testdata/basic/golden | 152 +++++++++++++++--- 8 files changed, 306 insertions(+), 87 deletions(-) rename internal/protoveneer/{ => cmd/protoveneer/internal}/support/support.go (52%) rename internal/protoveneer/{ => cmd/protoveneer/internal}/support/support_test.go (92%) diff --git a/internal/protoveneer/cmd/protoveneer/config.go b/internal/protoveneer/cmd/protoveneer/config.go index b70bda3a9bb4..f747af541fdf 100644 --- a/internal/protoveneer/cmd/protoveneer/config.go +++ b/internal/protoveneer/cmd/protoveneer/config.go @@ -26,8 +26,6 @@ import ( type config struct { Package string ProtoImportPath string `yaml:"protoImportPath"` - // Import path for the support package needed by the generated code. - SupportImportPath string `yaml:"supportImportPath"` // The types to process. Only these types and the types they depend // on will be output. diff --git a/internal/protoveneer/cmd/protoveneer/converters.go b/internal/protoveneer/cmd/protoveneer/converters.go index 8f364ee85f6a..0175c099b9b1 100644 --- a/internal/protoveneer/cmd/protoveneer/converters.go +++ b/internal/protoveneer/cmd/protoveneer/converters.go @@ -38,10 +38,18 @@ func (identityConverter) genTransformTo() string { return "" } // A derefConverter converts between T in the veneer and *T in the proto. type derefConverter struct{} -func (derefConverter) genFrom(arg string) string { return fmt.Sprintf("support.DerefOrZero(%s)", arg) } -func (derefConverter) genTo(arg string) string { return fmt.Sprintf("support.AddrOrNil(%s)", arg) } -func (derefConverter) genTransformFrom() string { panic("can't handle deref slices") } -func (derefConverter) genTransformTo() string { panic("can't handle deref slices") } +func (derefConverter) genFrom(arg string) string { + needSupport("pvDerefOrZero") + return fmt.Sprintf("pvDerefOrZero(%s)", arg) +} + +func (derefConverter) genTo(arg string) string { + needSupport("pvAddrOrNil") + return fmt.Sprintf("pvAddrOrNil(%s)", arg) +} + +func (derefConverter) genTransformFrom() string { panic("can't handle deref slices") } +func (derefConverter) genTransformTo() string { panic("can't handle deref slices") } type enumConverter struct { protoName, veneerName string @@ -105,14 +113,16 @@ type sliceConverter struct { func (c sliceConverter) genFrom(arg string) string { if fn := c.eltConverter.genTransformFrom(); fn != "" { - return fmt.Sprintf("support.TransformSlice(%s, %s)", arg, fn) + needSupport("pvTransformSlice") + return fmt.Sprintf("pvTransformSlice(%s, %s)", arg, fn) } return c.eltConverter.genFrom(arg) } func (c sliceConverter) genTo(arg string) string { if fn := c.eltConverter.genTransformTo(); fn != "" { - return fmt.Sprintf("support.TransformSlice(%s, %s)", arg, fn) + needSupport("pvTransformSlice") + return fmt.Sprintf("pvTransformSlice(%s, %s)", arg, fn) } return c.eltConverter.genTo(arg) } @@ -132,14 +142,16 @@ type mapConverter struct { func (c mapConverter) genFrom(arg string) string { if fn := c.valueConverter.genTransformFrom(); fn != "" { - return fmt.Sprintf("support.TransformMapValues(%s, %s)", arg, fn) + needSupport("pvTransformMapValues") + return fmt.Sprintf("pvTransformMapValues(%s, %s)", arg, fn) } return c.valueConverter.genFrom(arg) } func (c mapConverter) genTo(arg string) string { if fn := c.valueConverter.genTransformTo(); fn != "" { - return fmt.Sprintf("support.TransformMapValues(%s, %s)", arg, fn) + needSupport("pvTransformMapValues") + return fmt.Sprintf("pvTransformMapValues(%s, %s)", arg, fn) } return c.valueConverter.genTo(arg) } diff --git a/internal/protoveneer/support/support.go b/internal/protoveneer/cmd/protoveneer/internal/support/support.go similarity index 52% rename from internal/protoveneer/support/support.go rename to internal/protoveneer/cmd/protoveneer/internal/support/support.go index a59c4c825233..22891d18f315 100644 --- a/internal/protoveneer/support/support.go +++ b/internal/protoveneer/cmd/protoveneer/internal/support/support.go @@ -12,7 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package support provides support functions for protoveneer. +// Package support provides support functions for protoveneer. The protoveneer binary +// embeds it and extracts the needed functions when generating code. +// +// This package should not be imported. It is written as an ordinary Go package so +// it can be edited and tested with standard tools. +// +// The symbols begin with "pv" to reduce the chance of collision when the generated +// code is combined with user-written code in the same package. package support import ( @@ -29,9 +36,9 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -// TransformSlice applies f to each element of from and returns +// pvTransformSlice applies f to each element of from and returns // a new slice with the results. -func TransformSlice[From, To any](from []From, f func(From) To) []To { +func pvTransformSlice[From, To any](from []From, f func(From) To) []To { if from == nil { return nil } @@ -42,9 +49,9 @@ func TransformSlice[From, To any](from []From, f func(From) To) []To { return to } -// TransformMapValues applies f to each value of from, returning a new map. +// pvTransformMapValues applies f to each value of from, returning a new map. // It does not change the keys. -func TransformMapValues[K comparable, VFrom, VTo any](from map[K]VFrom, f func(VFrom) VTo) map[K]VTo { +func pvTransformMapValues[K comparable, VFrom, VTo any](from map[K]VFrom, f func(VFrom) VTo) map[K]VTo { if from == nil { return nil } @@ -55,9 +62,9 @@ func TransformMapValues[K comparable, VFrom, VTo any](from map[K]VFrom, f func(V return to } -// AddrOrNil returns nil if x is the zero value for T, +// pvAddrOrNil returns nil if x is the zero value for T, // or &x otherwise. -func AddrOrNil[T comparable](x T) *T { +func pvAddrOrNil[T comparable](x T) *T { var z T if x == z { return nil @@ -65,9 +72,9 @@ func AddrOrNil[T comparable](x T) *T { return &x } -// DerefOrZero returns the zero value for T if x is nil, +// pvDerefOrZero returns the zero value for T if x is nil, // or *x otherwise. -func DerefOrZero[T any](x *T) T { +func pvDerefOrZero[T any](x *T) T { if x == nil { var z T return z @@ -75,8 +82,8 @@ func DerefOrZero[T any](x *T) T { return *x } -// CivilDateToProto converts a civil.Date to a date.Date. -func CivilDateToProto(d civil.Date) *date.Date { +// pvCivilDateToProto converts a civil.Date to a date.Date. +func pvCivilDateToProto(d civil.Date) *date.Date { return &date.Date{ Year: int32(d.Year), Month: int32(d.Month), @@ -84,8 +91,8 @@ func CivilDateToProto(d civil.Date) *date.Date { } } -// CivilDateFromProto converts a date.Date to a civil.Date. -func CivilDateFromProto(p *date.Date) civil.Date { +// pvCivilDateFromProto converts a date.Date to a civil.Date. +func pvCivilDateFromProto(p *date.Date) civil.Date { if p == nil { return civil.Date{} } @@ -96,8 +103,8 @@ func CivilDateFromProto(p *date.Date) civil.Date { } } -// MapToStructPB converts a map into a structpb.Struct. -func MapToStructPB(m map[string]any) *structpb.Struct { +// pvMapToStructPB converts a map into a structpb.Struct. +func pvMapToStructPB(m map[string]any) *structpb.Struct { if m == nil { return nil } @@ -108,40 +115,40 @@ func MapToStructPB(m map[string]any) *structpb.Struct { return s } -// MapFromStructPB converts a structpb.Struct to a map. -func MapFromStructPB(p *structpb.Struct) map[string]any { +// pvMapFromStructPB converts a structpb.Struct to a map. +func pvMapFromStructPB(p *structpb.Struct) map[string]any { if p == nil { return nil } return p.AsMap() } -// TimeToProto converts a time.Time into a Timestamp. -func TimeToProto(t time.Time) *timestamppb.Timestamp { +// pvTimeToProto converts a time.Time into a Timestamp. +func pvTimeToProto(t time.Time) *timestamppb.Timestamp { if t.IsZero() { return nil } return timestamppb.New(t) } -// TimeFromProto converts a Timestamp into a time.Time. -func TimeFromProto(ts *timestamppb.Timestamp) time.Time { +// pvTimeFromProto converts a Timestamp into a time.Time. +func pvTimeFromProto(ts *timestamppb.Timestamp) time.Time { if ts == nil { return time.Time{} } return ts.AsTime() } -// APIErrorToProto converts an APIError to a proto Status. -func APIErrorToProto(ae *apierror.APIError) *spb.Status { +// pvAPIErrorToProto converts an APIError to a proto Status. +func pvAPIErrorToProto(ae *apierror.APIError) *spb.Status { if ae == nil { return nil } return ae.GRPCStatus().Proto() } -// APIErrorFromProto converts a proto Status to an APIError. -func APIErrorFromProto(s *spb.Status) *apierror.APIError { +// pvAPIErrorFromProto converts a proto Status to an APIError. +func pvAPIErrorFromProto(s *spb.Status) *apierror.APIError { err := gstatus.ErrorProto(s) aerr, ok := apierror.ParseError(err, true) if !ok { @@ -151,8 +158,8 @@ func APIErrorFromProto(s *spb.Status) *apierror.APIError { return aerr } -// DurationFromProto converts a Duration proto to a time.Duration. -func DurationFromProto(d *durationpb.Duration) time.Duration { +// pvDurationFromProto converts a Duration proto to a time.Duration. +func pvDurationFromProto(d *durationpb.Duration) time.Duration { if d == nil { return 0 } diff --git a/internal/protoveneer/support/support_test.go b/internal/protoveneer/cmd/protoveneer/internal/support/support_test.go similarity index 92% rename from internal/protoveneer/support/support_test.go rename to internal/protoveneer/cmd/protoveneer/internal/support/support_test.go index 347dd0dc6574..25a1f614900f 100644 --- a/internal/protoveneer/support/support_test.go +++ b/internal/protoveneer/cmd/protoveneer/internal/support/support_test.go @@ -28,12 +28,12 @@ import ( func TestTransformMapValues(t *testing.T) { var from map[string]int - got := TransformMapValues(from, strconv.Itoa) + got := pvTransformMapValues(from, strconv.Itoa) if got != nil { t.Fatalf("got %v, want nil", got) } from = map[string]int{"one": 1, "two": 2} - got = TransformMapValues(from, strconv.Itoa) + got = pvTransformMapValues(from, strconv.Itoa) want := map[string]string{"one": "1", "two": "2"} if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) @@ -57,7 +57,7 @@ func TestAPIError(t *testing.T) { Details: []*anypb.Any{pbany}, } - ae := APIErrorFromProto(s) + ae := pvAPIErrorFromProto(s) if ae == nil { t.Fatal("got nil") } @@ -72,7 +72,7 @@ func TestAPIError(t *testing.T) { t.Errorf("got %q, want %q", g, reason) } - gps := APIErrorToProto(ae) + gps := pvAPIErrorToProto(ae) if !cmp.Equal(gps, s, cmpopts.IgnoreUnexported(spb.Status{}, anypb.Any{})) { t.Errorf("\ngot %s\nwant %s", gps, s) } diff --git a/internal/protoveneer/cmd/protoveneer/protoveneer.go b/internal/protoveneer/cmd/protoveneer/protoveneer.go index 8c6a3713cdb0..34d114c57b72 100644 --- a/internal/protoveneer/cmd/protoveneer/protoveneer.go +++ b/internal/protoveneer/cmd/protoveneer/protoveneer.go @@ -34,14 +34,7 @@ // See the config type in config.go and the config.yaml files in the testdata // subdirectories to understand how to write configuration. // -// # Support functions -// -// protoveneer generates code that relies on a few support functions. These live -// in the support subdirectory. You should copy the contents of this directory -// to a location of your choice, and add "supportImportPath" to your config to -// refer to that directory's import path. -// -// # Unhandled features +// # Unsupported features // // There is no support for oneofs. Omit the oneof type and write custom code. // However, the types of the individual oneof cases can be generated. @@ -56,6 +49,7 @@ package main import ( "bytes" "context" + _ "embed" "errors" "flag" "fmt" @@ -68,6 +62,7 @@ import ( "os" "path" "path/filepath" + "regexp" "sort" "strings" "text/template" @@ -265,6 +260,8 @@ func buildConverterMap(typeInfos []*typeInfo, conf *config) (map[string]converte for _, et := range externalTypes { if et.used && et.convertTo != "" { converters[et.qualifiedName] = customConverter{et.convertTo, et.convertFrom} + needSupport(et.convertTo) + needSupport(et.convertFrom) } } @@ -752,12 +749,14 @@ func write(typeInfos []*typeInfo, conf *config, fset *token.FileSet) ([]byte, er } pr("\n") prn(` pb "%s"`, conf.ProtoImportPath) - if conf.SupportImportPath == "" { - return nil, errors.New("missing supportImportPath in config") - } - prn(` "%s"`, conf.SupportImportPath) for ip := range otherImportPaths { - prn(` "%s"`, ip) + // May be just a path, or "id path". + id, path, found := strings.Cut(ip, " ") + if !found { + prn(` "%s"`, ip) + } else { + prn(` %s "%s"`, id, path) + } } pr(")\n\n") @@ -782,7 +781,9 @@ func write(typeInfos []*typeInfo, conf *config, fset *token.FileSet) ([]byte, er ti.generateConversionMethods(pr) } } - + if err := generateSupportFunctions(&buf, neededSupportFunctions); err != nil { + return nil, err + } return buf.Bytes(), nil } @@ -888,35 +889,40 @@ var externalTypes = []*externalType{ qualifiedName: "civil.Date", replaces: "*date.Date", importPaths: []string{"cloud.google.com/go/civil"}, - convertTo: "support.CivilDateToProto", - convertFrom: "support.CivilDateFromProto", + convertTo: "pvCivilDateToProto", + convertFrom: "pvCivilDateFromProto", }, { qualifiedName: "map[string]any", replaces: "*structpb.Struct", - convertTo: "support.MapToStructPB", - convertFrom: "support.MapFromStructPB", + importPaths: []string{"google.golang.org/protobuf/types/known/structpb"}, + convertTo: "pvMapToStructPB", + convertFrom: "pvMapFromStructPB", }, { qualifiedName: "time.Time", replaces: "*timestamppb.Timestamp", - importPaths: []string{"time"}, - convertTo: "support.TimeToProto", - convertFrom: "support.TimeFromProto", + importPaths: []string{"time", "google.golang.org/protobuf/types/known/timestamppb"}, + convertTo: "pvTimeToProto", + convertFrom: "pvTimeFromProto", }, { qualifiedName: "time.Duration", replaces: "*durationpb.Duration", importPaths: []string{"time", "google.golang.org/protobuf/types/known/durationpb"}, convertTo: "durationpb.New", - convertFrom: "support.DurationFromProto", + convertFrom: "pvDurationFromProto", }, { qualifiedName: "*apierror.APIError", replaces: "*status.Status", - importPaths: []string{"github.com/googleapis/gax-go/v2/apierror"}, - convertTo: "support.APIErrorToProto", - convertFrom: "support.APIErrorFromProto", + importPaths: []string{ + "github.com/googleapis/gax-go/v2/apierror", + "spb google.golang.org/genproto/googleapis/rpc/status", + "gstatus google.golang.org/grpc/status", + }, + convertTo: "pvAPIErrorToProto", + convertFrom: "pvAPIErrorFromProto", }, } @@ -938,6 +944,53 @@ func init() { //////////////////////////////////////////////////////////////// +//go:embed internal/support/support.go +var supportCode string + +var neededSupportFunctions = map[string]bool{} + +// needSupport should be called whenever a support function is needed by the generated code. +// It is OK to call it for functions that are not in the support package. +func needSupport(name string) { neededSupportFunctions[name] = true } + +var ( + // Regexps to match the start and end of top-level functions. + // These assume the file is gofmt'd. + // The "m" flag means that ^ and $ match line starts and ends, respectively. + startFuncRegexp = regexp.MustCompile(`(?m:^func ([A-Za-z0-9_]+))`) + endFuncRegexp = regexp.MustCompile(`(?m:^}$)`) +) + +// generateSupportFunctions writes the support functions needed by the +// generated code to w. +func generateSupportFunctions(w io.Writer, need map[string]bool) error { + // Walk through the file of support functions, + // writing the ones whose names are in need. + code := supportCode + for { + inds := startFuncRegexp.FindStringSubmatchIndex(code) + if inds == nil { + break + } + end := endFuncRegexp.FindStringIndex(code) + if end == nil { + return errors.New("generateSupportFunctions: missing function end") + } + // inds[0] to inds[1]: entire start regexp + // inds[2] to inds[3]: function name + // end[1]: index of newline after '}'. + name := code[inds[2]:inds[3]] + if need[name] { + fmt.Fprintf(w, "\n%s\n", code[inds[0]:end[1]]) + } + // Move past match. + code = code[end[1]:] + } + return nil +} + +//////////////////////////////////////////////////////////////// + var emptyFileSet = token.NewFileSet() // typeString produces a string for a type expression. diff --git a/internal/protoveneer/cmd/protoveneer/protoveneer_test.go b/internal/protoveneer/cmd/protoveneer/protoveneer_test.go index 23c14d6bd21a..4ea3f3762f65 100644 --- a/internal/protoveneer/cmd/protoveneer/protoveneer_test.go +++ b/internal/protoveneer/cmd/protoveneer/protoveneer_test.go @@ -15,6 +15,7 @@ package main import ( + "bytes" "context" "flag" "os" @@ -41,13 +42,15 @@ func TestGeneration(t *testing.T) { dir := filepath.Join("testdata", e.Name()) configFile := filepath.Join(dir, "config.yaml") goldenFile := filepath.Join(dir, "golden") - outFile := filepath.Join(dir, e.Name()+"_veneer.gen.go") + // Don't use t.TempDir, because it will be removed even if -keep is set. + outDir := os.TempDir() + outFile := filepath.Join(outDir, e.Name()+"_veneer.gen.go") if *keep { t.Logf("keeping %s", outFile) } else { defer os.Remove(outFile) } - if err := run(ctx, configFile, dir, dir); err != nil { + if err := run(ctx, configFile, dir, outDir); err != nil { t.Fatal(err) } if *update { @@ -161,3 +164,34 @@ func TestInStdLib(t *testing.T) { } } } + +func TestGenerateSupportFunctions(t *testing.T) { + var buf bytes.Buffer + need := map[string]bool{ + "pvDurationFromProto": true, + "pvAddrOrNil": true, + } + if err := generateSupportFunctions(&buf, need); err != nil { + t.Fatal(err) + } + got := buf.String() + want := ` +func pvAddrOrNil[T comparable](x T) *T { + var z T + if x == z { + return nil + } + return &x +} + +func pvDurationFromProto(d *durationpb.Duration) time.Duration { + if d == nil { + return 0 + } + return d.AsDuration() +} +` + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("mismatch (-want, _got):\n%s", diff) + } +} diff --git a/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml b/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml index 3b2fe3bee8f6..6f38819bae84 100644 --- a/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml +++ b/internal/protoveneer/cmd/protoveneer/testdata/basic/config.yaml @@ -14,7 +14,6 @@ package: basic protoImportPath: example.com/basic -supportImportPath: example.com/protoveneer/support types: HarmCategory: diff --git a/internal/protoveneer/cmd/protoveneer/testdata/basic/golden b/internal/protoveneer/cmd/protoveneer/testdata/basic/golden index 5b8d90211706..307d8603fab8 100644 --- a/internal/protoveneer/cmd/protoveneer/testdata/basic/golden +++ b/internal/protoveneer/cmd/protoveneer/testdata/basic/golden @@ -8,9 +8,12 @@ import ( "cloud.google.com/go/civil" pb "example.com/basic" - "example.com/protoveneer/support" "github.com/googleapis/gax-go/v2/apierror" + spb "google.golang.org/genproto/googleapis/rpc/status" + gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" ) // Blob contains raw media bytes. @@ -59,9 +62,9 @@ func (v *Citation) toProto() *pb.Citation { } return &pb.Citation{ Uri: v.URI, - PublicationDate: support.CivilDateToProto(v.PublicationDate), - Struct: support.MapToStructPB(v.Struct), - CreateTime: support.TimeToProto(v.CreateTime), + PublicationDate: pvCivilDateToProto(v.PublicationDate), + Struct: pvMapToStructPB(v.Struct), + CreateTime: pvTimeToProto(v.CreateTime), } } @@ -71,9 +74,9 @@ func (Citation) fromProto(p *pb.Citation) *Citation { } return &Citation{ URI: p.Uri, - PublicationDate: support.CivilDateFromProto(p.PublicationDate), - Struct: support.MapFromStructPB(p.Struct), - CreateTime: support.TimeFromProto(p.CreateTime), + PublicationDate: pvCivilDateFromProto(p.PublicationDate), + Struct: pvMapFromStructPB(p.Struct), + CreateTime: pvTimeFromProto(p.CreateTime), } } @@ -89,8 +92,8 @@ func (v *CitationMetadata) toProto() *pb.CitationMetadata { return nil } return &pb.CitationMetadata{ - Citations: support.TransformSlice(v.Citations, (*Citation).toProto), - CitMap: support.TransformMapValues(v.CitMap, (*Citation).toProto), + Citations: pvTransformSlice(v.Citations, (*Citation).toProto), + CitMap: pvTransformMapValues(v.CitMap, (*Citation).toProto), } } @@ -99,8 +102,8 @@ func (CitationMetadata) fromProto(p *pb.CitationMetadata) *CitationMetadata { return nil } return &CitationMetadata{ - Citations: support.TransformSlice(p.Citations, (Citation{}).fromProto), - CitMap: support.TransformMapValues(p.CitMap, (Citation{}).fromProto), + Citations: pvTransformSlice(p.Citations, (Citation{}).fromProto), + CitMap: pvTransformMapValues(p.CitMap, (Citation{}).fromProto), } } @@ -115,7 +118,7 @@ func (v *File) toProto() *pb.File { return nil } return &pb.File{ - Error: support.APIErrorToProto(v.Error), + Error: pvAPIErrorToProto(v.Error), Dur: durationpb.New(v.Dur), } } @@ -125,8 +128,8 @@ func (File) fromProto(p *pb.File) *File { return nil } return &File{ - Error: support.APIErrorFromProto(p.Error), - Dur: support.DurationFromProto(p.Dur), + Error: pvAPIErrorFromProto(p.Error), + Dur: pvDurationFromProto(p.Dur), } } @@ -191,8 +194,8 @@ func (v *GenerationConfig) toProto() *pb.GenerationConfig { return nil } return &pb.GenerationConfig{ - Temperature: support.AddrOrNil(v.Temperature), - CandidateCount: support.AddrOrNil(v.CandidateCount), + Temperature: pvAddrOrNil(v.Temperature), + CandidateCount: pvAddrOrNil(v.CandidateCount), StopSequences: v.StopSequences, HarmCat: pb.HarmCategory(v.HarmCat), FinishReason: pb.Candidate_FinishReason(v.FinishReason), @@ -206,8 +209,8 @@ func (GenerationConfig) fromProto(p *pb.GenerationConfig) *GenerationConfig { return nil } return &GenerationConfig{ - Temperature: support.DerefOrZero(p.Temperature), - CandidateCount: support.DerefOrZero(p.CandidateCount), + Temperature: pvDerefOrZero(p.Temperature), + CandidateCount: pvDerefOrZero(p.CandidateCount), StopSequences: p.StopSequences, HarmCat: HarmCategory(p.HarmCat), FinishReason: FinishReason(p.FinishReason), @@ -276,3 +279,116 @@ func (Pop) fromProto(p *pb.Pop) *Pop { popYFrom(v, p) return v } + +func pvTransformSlice[From, To any](from []From, f func(From) To) []To { + if from == nil { + return nil + } + to := make([]To, len(from)) + for i, e := range from { + to[i] = f(e) + } + return to +} + +func pvTransformMapValues[K comparable, VFrom, VTo any](from map[K]VFrom, f func(VFrom) VTo) map[K]VTo { + if from == nil { + return nil + } + to := map[K]VTo{} + for k, v := range from { + to[k] = f(v) + } + return to +} + +func pvAddrOrNil[T comparable](x T) *T { + var z T + if x == z { + return nil + } + return &x +} + +func pvDerefOrZero[T any](x *T) T { + if x == nil { + var z T + return z + } + return *x +} + +func pvCivilDateToProto(d civil.Date) *date.Date { + return &date.Date{ + Year: int32(d.Year), + Month: int32(d.Month), + Day: int32(d.Day), + } +} + +func pvCivilDateFromProto(p *date.Date) civil.Date { + if p == nil { + return civil.Date{} + } + return civil.Date{ + Year: int(p.Year), + Month: time.Month(p.Month), + Day: int(p.Day), + } +} + +func pvMapToStructPB(m map[string]any) *structpb.Struct { + if m == nil { + return nil + } + s, err := structpb.NewStruct(m) + if err != nil { + panic(fmt.Errorf("support.MapToProto: %w", err)) + } + return s +} + +func pvMapFromStructPB(p *structpb.Struct) map[string]any { + if p == nil { + return nil + } + return p.AsMap() +} + +func pvTimeToProto(t time.Time) *timestamppb.Timestamp { + if t.IsZero() { + return nil + } + return timestamppb.New(t) +} + +func pvTimeFromProto(ts *timestamppb.Timestamp) time.Time { + if ts == nil { + return time.Time{} + } + return ts.AsTime() +} + +func pvAPIErrorToProto(ae *apierror.APIError) *spb.Status { + if ae == nil { + return nil + } + return ae.GRPCStatus().Proto() +} + +func pvAPIErrorFromProto(s *spb.Status) *apierror.APIError { + err := gstatus.ErrorProto(s) + aerr, ok := apierror.ParseError(err, true) + if !ok { + // Should be impossible. + return nil + } + return aerr +} + +func pvDurationFromProto(d *durationpb.Duration) time.Duration { + if d == nil { + return 0 + } + return d.AsDuration() +}