Skip to content

Commit

Permalink
generate: add ability to generate UnsafeReadFrom
Browse files Browse the repository at this point in the history
  • Loading branch information
twmb committed May 23, 2022
1 parent 7eb584e commit 852c5a8
Showing 1 changed file with 78 additions and 51 deletions.
129 changes: 78 additions & 51 deletions generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,16 +299,52 @@ func primDecode(name string, l *LineWriter) {
l.Write("v := b.%s()", name)
}

func compactDecode(fromFlexible bool, name, typ string, l *LineWriter) {
if fromFlexible {
l.Write("var v %s", typ)
func unsafeDecode(l *LineWriter, fn func(string)) {
l.Write("if unsafe {")
fn("Unsafe")
l.Write("} else {")
fn("")
l.Write("}")
}

func flexDecode(supports bool, l *LineWriter, fn func(string)) {
if supports {
l.Write("if isFlexible {")
l.Write("v = b.Compact%s()", name)
fn("Compact")
l.Write("} else {")
l.Write("v = b.%s()", name)
l.Write("}")
defer l.Write("}")
}
fn("")
}

func primUnsafeDecode(name string, l *LineWriter) {
l.Write("var v string")
unsafeDecode(l, func(u string) {
l.Write("v = b.%s%s()", u, name)
})
}

func compactDecode(fromFlexible, hasUnsafe bool, name, typ string, l *LineWriter) {
if fromFlexible {
l.Write("var v %s", typ)
fn := func(u string) {
l.Write("if isFlexible {")
l.Write("v = b.%sCompact%s()", u, name)
l.Write("} else {")
l.Write("v = b.%s%s()", u, name)
l.Write("}")
}
if hasUnsafe {
unsafeDecode(l, fn)
} else {
fn("")
}
} else {
l.Write("v := b.%s()", name)
if hasUnsafe {
primUnsafeDecode(name, l)
} else {
primDecode(name, l)
}
}
}

Expand All @@ -322,57 +358,44 @@ func (Float64) WriteDecode(l *LineWriter) { primDecode("Float64", l) }
func (Uint32) WriteDecode(l *LineWriter) { primDecode("Uint32", l) }
func (Varint) WriteDecode(l *LineWriter) { primDecode("Varint", l) }
func (Uuid) WriteDecode(l *LineWriter) { primDecode("Uuid", l) }
func (VarintString) WriteDecode(l *LineWriter) { primDecode("VarintString", l) }
func (VarintString) WriteDecode(l *LineWriter) { primUnsafeDecode("VarintString", l) }
func (VarintBytes) WriteDecode(l *LineWriter) { primDecode("VarintBytes", l) }
func (Throttle) WriteDecode(l *LineWriter) { primDecode("Int32", l) }

func (v String) WriteDecode(l *LineWriter) { compactDecode(v.FromFlexible, "String", "string", l) }
func (v Bytes) WriteDecode(l *LineWriter) { compactDecode(v.FromFlexible, "Bytes", "[]byte", l) }
func (v String) WriteDecode(l *LineWriter) {
compactDecode(v.FromFlexible, true, "String", "string", l)
}

func (v Bytes) WriteDecode(l *LineWriter) {
compactDecode(v.FromFlexible, false, "Bytes", "[]byte", l)
}

func (v NullableBytes) WriteDecode(l *LineWriter) {
compactDecode(v.FromFlexible, "NullableBytes", "[]byte", l)
compactDecode(v.FromFlexible, false, "NullableBytes", "[]byte", l)
}

func (v NullableString) WriteDecode(l *LineWriter) {
// If there is a nullable version, we write a "read string, then set
// pointer" block.
l.Write("var v *string")
if v.NullableVersion > 0 {
l.Write("var v *string")
l.Write("if version < %d {", v.NullableVersion)
l.Write("var vv string")
if v.FromFlexible {
l.Write("if isFlexible {")
l.Write("vv = b.CompactString()")
l.Write("} else {")
l.Write("vv = b.String()")
l.Write("}")
} else {
l.Write("vv = b.String()")
}
flexDecode(v.FromFlexible, l, func(compact string) {
unsafeDecode(l, func(u string) {
l.Write("vv = b.%s%sString()", u, compact)
})
})
l.Write("v = &vv")
l.Write("} else {")
defer l.Write("}")
}

if v.FromFlexible {
// If we had a nullable version, then we already declared v and
// do not need to again.
if v.NullableVersion == 0 {
l.Write("var v *string")
}
l.Write("if isFlexible {")
l.Write("v = b.CompactNullableString()")
l.Write("} else {")
l.Write("v = b.NullableString()")
l.Write("}")
} else {
// If we had a nullable version, v has been declared and we
// reuse it, if not, we declare v.
if v.NullableVersion == 0 {
l.Write("v := b.NullableString()")
} else {
l.Write("v = b.NullableString()")
}
}
flexDecode(v.FromFlexible, l, func(compact string) {
unsafeDecode(l, func(u string) {
l.Write("v = b.%s%sNullableString()", u, compact)
})
})
}

func (f FieldLengthMinusBytes) WriteDecode(l *LineWriter) {
Expand All @@ -389,15 +412,9 @@ func (a Array) WriteDecode(l *LineWriter) {
if a.IsVarintArray {
l.Write("l = b.VarintArrayLen()")
} else {
if a.FromFlexible {
l.Write("if isFlexible {")
l.Write("l = b.CompactArrayLen()")
l.Write("} else {")
l.Write("l = b.ArrayLen()")
l.Write("}")
} else {
l.Write("l = b.ArrayLen()")
}
flexDecode(a.FromFlexible, l, func(compact string) {
l.Write("l = b.%sArrayLen()", compact)
})
if a.IsNullableArray {
l.Write("if version < %d || l == 0 {", a.NullableVersion)
l.Write("a = %s{}", a.TypeName())
Expand All @@ -409,8 +426,10 @@ func (a Array) WriteDecode(l *LineWriter) {
l.Write("return b.Complete()")
l.Write("}")

l.Write("a = a[:0]")

l.Write("if l > 0 {")
l.Write("a = make(%s, l)", a.TypeName())
l.Write("a = append(a, make(%s, l)...)", a.TypeName())
l.Write("}")

l.Write("for i := int32(0); i < l; i++ {")
Expand Down Expand Up @@ -680,6 +699,14 @@ func (s Struct) WriteAppendFunc(l *LineWriter) {

func (s Struct) WriteDecodeFunc(l *LineWriter) {
l.Write("func (v *%s) ReadFrom(src []byte) error {", s.Name)
l.Write("return v.readFrom(src, false)")
l.Write("}")

l.Write("func (v *%s) UnsafeReadFrom(src []byte) error {", s.Name)
l.Write("return v.readFrom(src, true)")
l.Write("}")

l.Write("func (v *%s) readFrom(src []byte, unsafe bool) error {", s.Name)
l.Write("v.Default()")
l.Write("b := kbin.Reader{Src: src}")
if s.WithVersionField {
Expand Down

0 comments on commit 852c5a8

Please sign in to comment.