From 852c5a80851191eb10d5d9253aaac239eef1d961 Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Sun, 8 May 2022 20:32:46 -0600 Subject: [PATCH] generate: add ability to generate UnsafeReadFrom --- generate/gen.go | 129 +++++++++++++++++++++++++++++------------------- 1 file changed, 78 insertions(+), 51 deletions(-) diff --git a/generate/gen.go b/generate/gen.go index 6282e8c0..d7a193ce 100644 --- a/generate/gen.go +++ b/generate/gen.go @@ -299,16 +299,52 @@ func primDecode(name string, l *LineWriter) { l.Write("v := b.%s()", name) } -func compactDecode(fromFlexible bool, name, typ string, l *LineWriter) { - if fromFlexible { - l.Write("var v %s", typ) +func unsafeDecode(l *LineWriter, fn func(string)) { + l.Write("if unsafe {") + fn("Unsafe") + l.Write("} else {") + fn("") + l.Write("}") +} + +func flexDecode(supports bool, l *LineWriter, fn func(string)) { + if supports { l.Write("if isFlexible {") - l.Write("v = b.Compact%s()", name) + fn("Compact") l.Write("} else {") - l.Write("v = b.%s()", name) - l.Write("}") + defer l.Write("}") + } + fn("") +} + +func primUnsafeDecode(name string, l *LineWriter) { + l.Write("var v string") + unsafeDecode(l, func(u string) { + l.Write("v = b.%s%s()", u, name) + }) +} + +func compactDecode(fromFlexible, hasUnsafe bool, name, typ string, l *LineWriter) { + if fromFlexible { + l.Write("var v %s", typ) + fn := func(u string) { + l.Write("if isFlexible {") + l.Write("v = b.%sCompact%s()", u, name) + l.Write("} else {") + l.Write("v = b.%s%s()", u, name) + l.Write("}") + } + if hasUnsafe { + unsafeDecode(l, fn) + } else { + fn("") + } } else { - l.Write("v := b.%s()", name) + if hasUnsafe { + primUnsafeDecode(name, l) + } else { + primDecode(name, l) + } } } @@ -322,57 +358,44 @@ func (Float64) WriteDecode(l *LineWriter) { primDecode("Float64", l) } func (Uint32) WriteDecode(l *LineWriter) { primDecode("Uint32", l) } func (Varint) WriteDecode(l *LineWriter) { primDecode("Varint", l) } func (Uuid) WriteDecode(l *LineWriter) { primDecode("Uuid", l) } -func (VarintString) WriteDecode(l *LineWriter) { primDecode("VarintString", l) } +func (VarintString) WriteDecode(l *LineWriter) { primUnsafeDecode("VarintString", l) } func (VarintBytes) WriteDecode(l *LineWriter) { primDecode("VarintBytes", l) } func (Throttle) WriteDecode(l *LineWriter) { primDecode("Int32", l) } -func (v String) WriteDecode(l *LineWriter) { compactDecode(v.FromFlexible, "String", "string", l) } -func (v Bytes) WriteDecode(l *LineWriter) { compactDecode(v.FromFlexible, "Bytes", "[]byte", l) } +func (v String) WriteDecode(l *LineWriter) { + compactDecode(v.FromFlexible, true, "String", "string", l) +} + +func (v Bytes) WriteDecode(l *LineWriter) { + compactDecode(v.FromFlexible, false, "Bytes", "[]byte", l) +} + func (v NullableBytes) WriteDecode(l *LineWriter) { - compactDecode(v.FromFlexible, "NullableBytes", "[]byte", l) + compactDecode(v.FromFlexible, false, "NullableBytes", "[]byte", l) } func (v NullableString) WriteDecode(l *LineWriter) { // If there is a nullable version, we write a "read string, then set // pointer" block. + l.Write("var v *string") if v.NullableVersion > 0 { - l.Write("var v *string") l.Write("if version < %d {", v.NullableVersion) l.Write("var vv string") - if v.FromFlexible { - l.Write("if isFlexible {") - l.Write("vv = b.CompactString()") - l.Write("} else {") - l.Write("vv = b.String()") - l.Write("}") - } else { - l.Write("vv = b.String()") - } + flexDecode(v.FromFlexible, l, func(compact string) { + unsafeDecode(l, func(u string) { + l.Write("vv = b.%s%sString()", u, compact) + }) + }) l.Write("v = &vv") l.Write("} else {") defer l.Write("}") } - if v.FromFlexible { - // If we had a nullable version, then we already declared v and - // do not need to again. - if v.NullableVersion == 0 { - l.Write("var v *string") - } - l.Write("if isFlexible {") - l.Write("v = b.CompactNullableString()") - l.Write("} else {") - l.Write("v = b.NullableString()") - l.Write("}") - } else { - // If we had a nullable version, v has been declared and we - // reuse it, if not, we declare v. - if v.NullableVersion == 0 { - l.Write("v := b.NullableString()") - } else { - l.Write("v = b.NullableString()") - } - } + flexDecode(v.FromFlexible, l, func(compact string) { + unsafeDecode(l, func(u string) { + l.Write("v = b.%s%sNullableString()", u, compact) + }) + }) } func (f FieldLengthMinusBytes) WriteDecode(l *LineWriter) { @@ -389,15 +412,9 @@ func (a Array) WriteDecode(l *LineWriter) { if a.IsVarintArray { l.Write("l = b.VarintArrayLen()") } else { - if a.FromFlexible { - l.Write("if isFlexible {") - l.Write("l = b.CompactArrayLen()") - l.Write("} else {") - l.Write("l = b.ArrayLen()") - l.Write("}") - } else { - l.Write("l = b.ArrayLen()") - } + flexDecode(a.FromFlexible, l, func(compact string) { + l.Write("l = b.%sArrayLen()", compact) + }) if a.IsNullableArray { l.Write("if version < %d || l == 0 {", a.NullableVersion) l.Write("a = %s{}", a.TypeName()) @@ -409,8 +426,10 @@ func (a Array) WriteDecode(l *LineWriter) { l.Write("return b.Complete()") l.Write("}") + l.Write("a = a[:0]") + l.Write("if l > 0 {") - l.Write("a = make(%s, l)", a.TypeName()) + l.Write("a = append(a, make(%s, l)...)", a.TypeName()) l.Write("}") l.Write("for i := int32(0); i < l; i++ {") @@ -680,6 +699,14 @@ func (s Struct) WriteAppendFunc(l *LineWriter) { func (s Struct) WriteDecodeFunc(l *LineWriter) { l.Write("func (v *%s) ReadFrom(src []byte) error {", s.Name) + l.Write("return v.readFrom(src, false)") + l.Write("}") + + l.Write("func (v *%s) UnsafeReadFrom(src []byte) error {", s.Name) + l.Write("return v.readFrom(src, true)") + l.Write("}") + + l.Write("func (v *%s) readFrom(src []byte, unsafe bool) error {", s.Name) l.Write("v.Default()") l.Write("b := kbin.Reader{Src: src}") if s.WithVersionField {