Skip to content

Commit

Permalink
feat: added proto3 optional support to fastmarshal
Browse files Browse the repository at this point in the history
  • Loading branch information
dsexton authored and wmorgan6796 committed Oct 28, 2024
1 parent ddb73d4 commit 2cf177e
Show file tree
Hide file tree
Showing 12 changed files with 935 additions and 648 deletions.
3 changes: 3 additions & 0 deletions cmd/protoc-gen-fastmarshal/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"text/template"

"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/pluginpb"
)

// options defines the supported configuration options for a generation request
Expand Down Expand Up @@ -120,6 +121,8 @@ func doGenerate(opts *options) func(*protogen.Plugin) error {
opts.apiVersion = protoAPIVersion("v1")
}

plugin.SupportedFeatures |= uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)

for _, protoFile := range plugin.Files {
if !protoFile.Generate {
continue
Expand Down
61 changes: 40 additions & 21 deletions cmd/protoc-gen-fastmarshal/templates/fieldsnippets.tmpl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{{/* SizeOfString - calculate the encoded size of a string field */}}
{{ define "SizeOfString" }}
{{- if ne (.Desc.Cardinality | string) "repeated" -}}
{{- if eq (.Desc.Syntax | string) "proto2" -}}
{{- if or (eq (.Desc.Syntax | string) "proto2") (ne .Desc.ContainingOneof nil) -}}
if m.{{ .GoName | getSafeFieldName }} != nil {
l = len(*m.{{ .GoName | getSafeFieldName }})
sz += csproto.SizeOfTagKey({{.Desc.Number}}) + csproto.SizeOfVarint(uint64(l)) + l
Expand Down Expand Up @@ -39,7 +39,7 @@
{{/* SizeOfInt - calculate the encoded size of an integer field */}}
{{ define "SizeOfInt" }}
{{- if ne (.Desc.Cardinality | string) "repeated" -}}
{{- if eq (.Desc.Syntax | string) "proto2" -}}
{{- if or (eq (.Desc.Syntax | string) "proto2") (ne .Desc.ContainingOneof nil) -}}
if m.{{.GoName | getSafeFieldName}} != nil {
sz += csproto.SizeOfTagKey({{.Desc.Number}}) + csproto.SizeOfVarint(uint64(*m.{{.GoName | getSafeFieldName}}))
}
Expand All @@ -66,7 +66,7 @@
{{/* SizeOfBool - calculate the encoded size of a boolean field */}}
{{ define "SizeOfBool" }}
{{- if ne (.Desc.Cardinality | string) "repeated" -}}
{{- if eq (.Desc.Syntax | string) "proto2" -}}
{{- if or (eq (.Desc.Syntax | string) "proto2") (ne .Desc.ContainingOneof nil) -}}
if m.{{.GoName | getSafeFieldName}} != nil {
sz += csproto.SizeOfTagKey({{.Desc.Number}}) + 1
}
Expand All @@ -88,7 +88,7 @@
{{/* SizeOfSInt - calculate the encoded size of a sint32 or sint64 field */}}
{{ define "SizeOfSInt" }}
{{- if ne (.Desc.Cardinality | string) "repeated" -}}
{{- if eq (.Desc.Syntax | string) "proto2" -}}
{{- if or (eq (.Desc.Syntax | string) "proto2") (ne .Desc.ContainingOneof nil) -}}
if m.{{.GoName | getSafeFieldName}} != nil {
sz += csproto.SizeOfTagKey({{.Desc.Number}}) + csproto.SizeOfZigZag(uint64(*m.{{.GoName | getSafeFieldName}}))
}
Expand All @@ -115,7 +115,7 @@
{{/* SizeOfFixed - calculate the encoded size of a fixed-width field */}}
{{ define "SizeOfFixed" }}
{{- if ne (.Desc.Cardinality | string) "repeated" -}}
{{- if eq (.Desc.Syntax | string) "proto2" -}}
{{- if or (eq (.Desc.Syntax | string) "proto2") (ne .Desc.ContainingOneof nil) -}}
if m.{{.GoName | getSafeFieldName}} != nil {
sz += csproto.SizeOfTagKey({{.Desc.Number}}) + {{if eq (.Desc.Kind | string) "fixed32" "sfixed32" "float"}}4{{else}}8{{end}}
}
Expand Down Expand Up @@ -330,8 +330,12 @@
}
{{- end -}}
{{- else -}}
{{- if eq .Desc.ContainingOneof nil -}}
if m.{{.GoName | getSafeFieldName }} != 0 {
enc.{{$method}}({{.Desc.Number}}, m.{{.GoName | getSafeFieldName}})
{{- else -}}
if m.{{.GoName | getSafeFieldName }} != nil {
{{- end -}}
enc.{{$method}}({{.Desc.Number}}, {{if ne .Desc.ContainingOneof nil}}*{{end}}m.{{.GoName | getSafeFieldName}})
}
{{- end -}}
{{- end -}}
Expand Down Expand Up @@ -362,8 +366,12 @@
}
{{- end -}}
{{- else -}}
{{- if eq .Desc.ContainingOneof nil -}}
if m.{{.GoName | getSafeFieldName}} {
enc.EncodeBool({{.Desc.Number}}, m.{{.GoName | getSafeFieldName}})
{{- else -}}
if m.{{.GoName | getSafeFieldName }} != nil {
{{- end -}}
enc.EncodeBool({{.Desc.Number}}, {{if ne .Desc.ContainingOneof nil}}*{{end}}m.{{.GoName | getSafeFieldName}})
}
{{- end -}}
{{- end -}}
Expand All @@ -388,8 +396,12 @@
}
{{- end -}}
{{- else -}}
{{- if eq .Desc.ContainingOneof nil -}}
if len(m.{{.GoName | getSafeFieldName }}) > 0 {
enc.EncodeString({{.Desc.Number}}, m.{{.GoName | getSafeFieldName}})
{{- else -}}
if m.{{.GoName | getSafeFieldName }} != nil {
{{- end -}}
enc.EncodeString({{.Desc.Number}}, {{if ne .Desc.ContainingOneof nil}}*{{end}}m.{{.GoName | getSafeFieldName}})
}
{{- end -}}
{{- end -}}
Expand Down Expand Up @@ -451,7 +463,7 @@
}
{{- end -}}
{{- else -}}
enc.{{ $method }}({{.Desc.Number}}, {{$cast}}(m.{{.GoName | getSafeFieldName}}))
enc.{{ $method }}({{.Desc.Number}}, {{$cast}}({{if ne .Desc.ContainingOneof nil}}*{{end}}m.{{.GoName | getSafeFieldName}}))
{{- end -}}
{{- end -}}
{{ end }}
Expand Down Expand Up @@ -485,8 +497,12 @@
}
{{- end -}}
{{- else -}}
{{- if eq .Desc.ContainingOneof nil -}}
if m.{{.GoName | getSafeFieldName}} != 0 {
enc.EncodeInt32({{.Desc.Number}}, int32(m.{{.GoName | getSafeFieldName}}))
{{- else -}}
if m.{{.GoName | getSafeFieldName }} != nil {
{{- end -}}
enc.EncodeInt32({{.Desc.Number}}, int32({{if ne .Desc.ContainingOneof nil}}*{{end}}m.{{.GoName | getSafeFieldName}}))
}
{{- end -}}
{{- end -}}
Expand Down Expand Up @@ -685,6 +701,7 @@
{{- $syntax := (.Desc.Syntax | string) -}}
{{- $kind := (.Desc.Kind | string) -}}
{{- $cardinality := (.Desc.Cardinality | string) -}}
{{- $isOptional := or (eq $syntax "proto2") (ne .Desc.ContainingOneof nil) -}}
{{- if eq $kind "bool" -}}
{{- if eq $cardinality "repeated" -}}
switch wt {
Expand All @@ -710,7 +727,7 @@
if v, err := dec.DecodeBool(); err != nil {
return fmt.Errorf("unable to decode boolean value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Bool(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Bool(v){{else}}v{{end}}
}
{{- end -}}
{{- else if eq $kind "int32" -}}
Expand Down Expand Up @@ -738,7 +755,7 @@
if v, err := dec.DecodeInt32(); err != nil {
return fmt.Errorf("unable to decode int32 value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Int32(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Int32(v){{else}}v{{end}}
}
{{- end -}}
{{- else if eq $kind "int64" -}}
Expand Down Expand Up @@ -766,7 +783,7 @@
if v, err := dec.DecodeInt64(); err != nil {
return fmt.Errorf("unable to decode int64 value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Int64(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Int64(v){{else}}v{{end}}
}
{{- end -}}
{{- else if eq $kind "uint32" -}}
Expand Down Expand Up @@ -794,7 +811,7 @@
if v, err := dec.DecodeUInt32(); err != nil {
return fmt.Errorf("unable to decode int32 value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Uint32(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Uint32(v){{else}}v{{end}}
}
{{- end -}}
{{- else if eq $kind "uint64" -}}
Expand Down Expand Up @@ -822,7 +839,7 @@
if v, err := dec.DecodeUInt64(); err != nil {
return fmt.Errorf("unable to decode uint64 value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Uint64(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Uint64(v){{else}}v{{end}}
}
{{- end -}}
{{- else if eq $kind "enum" -}}
Expand Down Expand Up @@ -852,7 +869,7 @@
if v, err := dec.DecodeInt32(); err != nil {
return fmt.Errorf("unable to decode int32 enum value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
{{- if eq $syntax "proto2" -}}
{{- if eq $isOptional true -}}
ev := {{.Enum | getImportPrefix }}{{.Enum.GoIdent.GoName | getSafeFieldName}}(v)
m.{{.GoName | getSafeFieldName}} = &ev
{{- else -}}
Expand Down Expand Up @@ -886,7 +903,7 @@
if v, err := dec.DecodeFixed{{$bitSize}}(); err != nil {
return fmt.Errorf("unable to decode uint{{$bitSize}} value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Uint{{$bitSize}}(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Uint{{$bitSize}}(v){{else}}v{{end}}
}
{{- end -}}
{{- else if eq $kind "sint32" "sint64" -}}
Expand Down Expand Up @@ -915,7 +932,7 @@
if v, err := dec.DecodeSInt{{$bitSize}}(); err != nil {
return fmt.Errorf("unable to decode sint{{$bitSize}} value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Int{{$bitSize}}(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Int{{$bitSize}}(v){{else}}v{{end}}
}
{{- end -}}
{{- else if eq $kind "float" "double" -}}
Expand Down Expand Up @@ -945,7 +962,7 @@
if v, err := dec.DecodeFloat{{$bitSize}}(); err != nil {
return fmt.Errorf("unable to decode {{$kind}} value for field '{{.Desc.Name}}' (tag={{.Desc.Number}}): %w", err)
} else {
m.{{.GoName | getSafeFieldName}} = {{if eq $syntax "proto2"}}csproto.Float{{$bitSize}}(v){{else}}v{{end}}
m.{{.GoName | getSafeFieldName}} = {{if eq $isOptional true}}csproto.Float{{$bitSize}}(v){{else}}v{{end}}
}
{{- end -}}
{{- else -}}
Expand All @@ -956,6 +973,7 @@
{{ define "UnmarshalString" }}
{{- $syntax := (.Desc.Syntax | string) -}}
{{- $cardinality := (.Desc.Cardinality | string) -}}
{{- $isOptional := or (eq $syntax "proto2") (ne .Desc.ContainingOneof nil) -}}
if wt != csproto.WireTypeLengthDelimited {
return fmt.Errorf("incorrect wire type %v for field '{{.Desc.Name}}' (tag={{.Desc.Number}}), expected 2 (length-delimited)", wt)
}
Expand All @@ -964,7 +982,7 @@
} else {
{{- if eq $cardinality "repeated" -}}
m.{{.GoName | getSafeFieldName}} = append(m.{{.GoName | getSafeFieldName}}, s)
{{- else if eq $syntax "proto2" -}}
{{- else if eq $isOptional true -}}
m.{{.GoName | getSafeFieldName}} = csproto.String(s)
{{- else -}}
m.{{.GoName | getSafeFieldName}} = s
Expand All @@ -986,6 +1004,7 @@
{{ define "UnmarshalSFixed" }}
{{- $syntax := (.Desc.Syntax | string) -}}
{{- $cardinality := (.Desc.Cardinality | string) -}}
{{- $isOptional := or (eq $syntax "proto2") (ne .Desc.ContainingOneof nil) -}}
{{- $bitSize := (.Desc.Kind | string | trunc -2) -}}
if wt != csproto.WireTypeFixed{{$bitSize}} {
return fmt.Errorf("incorrect wire type %v for field '{{.Desc.Name}}' (tag={{.Desc.Number}}), expected {{if eq $bitSize "32" }}5 (32-bit){{else}}1 (64-bit){{end}}", wt)
Expand All @@ -995,7 +1014,7 @@
} else {
{{- if eq $cardinality "repeated" -}}
m.{{.GoName | getSafeFieldName}}= append(m.{{.GoName | getSafeFieldName}}, int{{$bitSize}}(v))
{{- else if eq $syntax "proto2" -}}
{{- else if eq $isOptional true -}}
iv := int{{$bitSize}}(v)
m.{{.GoName | getSafeFieldName}}= &iv
{{- else -}}
Expand Down
14 changes: 10 additions & 4 deletions cmd/protoc-gen-fastmarshal/templates/permessage.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ func (m *{{ .Message.GoIdent.GoName }}) Size() int {
var sz, l int
_ = l // avoid unused variable
{{ range .Message.Fields }}
{{- if eq .Desc.ContainingOneof nil }}
{{- if or (eq .Desc.ContainingOneof nil) (eq .Desc.ContainingOneof.IsSynthetic true) }}
// {{ .GoName }} ({{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }})
{{ template "SizeOfField" . }}
{{- end -}}
{{ end }}
{{ range .Message.Oneofs }}
{{- template "SizeOfOneOf" . }}
{{- if eq .Desc.IsSynthetic false }}
{{ template "SizeOfOneOf" . }}
{{ end -}}
{{ end }}
{{- if eq $protoSyntax "proto2" -}}
{{ range getExtensions .Message }}
Expand Down Expand Up @@ -89,14 +91,16 @@ func (m *{{ .Message.GoIdent.GoName }}) MarshalTo(dest []byte) error {
_ = err
_ = extVal
{{ range .Message.Fields }}
{{- if eq .Desc.ContainingOneof nil }}
{{- if or (eq .Desc.ContainingOneof nil) (eq .Desc.ContainingOneof.IsSynthetic true) }}
// {{ .GoName }} ({{.Desc.Number}},{{if .Desc.IsMap}}map{{else}}{{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }}{{end}})
{{ template "MarshalField" . }}
{{- end -}}
{{ end }}
{{ range .Message.Oneofs -}}
{{- if eq .Desc.IsSynthetic false -}}
// {{.GoName}} (oneof)
{{ template "MarshalOneOf" . -}}
{{- end -}}
{{ end }}
{{- if eq $protoSyntax "proto2" -}}
{{ range getExtensions .Message }}
Expand Down Expand Up @@ -124,11 +128,13 @@ func (m *{{ .Message.GoIdent.GoName}}) Unmarshal(p []byte) error {
}
switch tag {
{{- range .Message.Fields -}}
{{ if eq .Desc.ContainingOneof nil }}case {{ .Desc.Number }}: // {{ .GoName }} ({{if .Desc.IsMap}}map{{else}}{{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }}{{end}})
{{ if or (eq .Desc.ContainingOneof nil) (eq .Desc.ContainingOneof.IsSynthetic true) }}case {{ .Desc.Number }}: // {{ .GoName }} ({{if .Desc.IsMap}}map{{else}}{{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }}{{end}})
{{ template "UnmarshalField" . }}{{ end }}
{{ end }}
{{ range .Message.Oneofs -}}
{{- if eq .Desc.IsSynthetic false -}}
{{ template "UnmarshalOneOf" . -}}
{{- end -}}
{{ end }}
{{- if eq $protoSyntax "proto2" -}}
{{ range getExtensions . }}
Expand Down
14 changes: 10 additions & 4 deletions cmd/protoc-gen-fastmarshal/templates/singlefile.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ func (m *{{ .GoIdent.GoName }}) Size() int {
var sz, l int
_ = l // avoid unused variable
{{ range .Fields }}
{{- if eq .Desc.ContainingOneof nil }}
{{- if or (eq .Desc.ContainingOneof nil) (eq .Desc.ContainingOneof.IsSynthetic true) }}
// {{ .GoName }} ({{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }})
{{ template "SizeOfField" . }}
{{- end -}}
{{ end }}
{{ range .Oneofs }}
{{- template "SizeOfOneOf" . }}
{{- if eq .Desc.IsSynthetic false }}
{{ template "SizeOfOneOf" . }}
{{ end -}}
{{ end }}
{{- if eq $protoSyntax "proto2" -}}
{{ range getExtensions . }}
Expand Down Expand Up @@ -91,14 +93,16 @@ func (m *{{ .GoIdent.GoName }}) MarshalTo(dest []byte) error {
_ = err
_ = extVal
{{ range .Fields }}
{{- if eq .Desc.ContainingOneof nil }}
{{- if or (eq .Desc.ContainingOneof nil) (eq .Desc.ContainingOneof.IsSynthetic true) }}
// {{ .GoName }} ({{.Desc.Number}},{{if .Desc.IsMap}}map{{else}}{{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }}{{end}})
{{ template "MarshalField" . }}
{{- end -}}
{{ end }}
{{ range .Oneofs -}}
{{- if eq .Desc.IsSynthetic false -}}
// {{.GoName}} (oneof)
{{ template "MarshalOneOf" . -}}
{{- end -}}
{{ end }}
{{- if eq $protoSyntax "proto2" -}}
{{ range getExtensions . }}
Expand Down Expand Up @@ -126,11 +130,13 @@ func (m *{{ .GoIdent.GoName}}) Unmarshal(p []byte) error {
}
switch tag {
{{- range .Fields -}}
{{ if eq .Desc.ContainingOneof nil }}case {{ .Desc.Number }}: // {{ .GoName }} ({{if .Desc.IsMap}}map{{else}}{{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }}{{end}})
{{ if or (eq .Desc.ContainingOneof nil) (eq .Desc.ContainingOneof.IsSynthetic true) }}case {{ .Desc.Number }}: // {{ .GoName }} ({{if .Desc.IsMap}}map{{else}}{{.Desc.Kind}},{{.Desc.Cardinality}}{{if .Desc.IsPacked }},packed{{ end }}{{end}})
{{ template "UnmarshalField" . }}{{ end }}
{{ end }}
{{ range .Oneofs -}}
{{- if eq .Desc.IsSynthetic false -}}
{{ template "UnmarshalOneOf" . -}}
{{- end -}}
{{ end }}
{{- if eq $protoSyntax "proto2" -}}
{{ range getExtensions . }}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2cf177e

Please sign in to comment.