From 4dca38ec991f0ea0cca4246f0e53ce32b8f5ed44 Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Fri, 12 Jun 2020 22:22:35 -0700 Subject: [PATCH] THRIFT-5233: Handle I/O timeouts in go library Client: go As discussed in the JIRA ticket, this commit changes how we handle I/O timeouts in the go library. This is a breaking change that adds context to all Read*, Write*, and Skip functions to TProtocol, along with the compiler change to support that, and also adds context to TStandardClient.Recv, TDeserializer, TStruct, and a few others. Along with the function signature changes, this commit also implements context cancellation check in the following TProtocol's ReadMessageBegin implementations: - TBinaryProtocol - TCompactProtocol - THeaderProtocol In those ReadMessageBegin implementations, if the passed in context object has a deadline attached, it will keep retrying the I/O timeout errors, until the deadline on the context object passed. They won't retry I/O timeout errors if the passed in context does not have a deadline attached (still return on the first error). --- CHANGES.md | 2 + .../cpp/src/thrift/generate/t_go_generator.cc | 140 +++++----- lib/go/thrift/application_exception.go | 48 ++-- lib/go/thrift/binary_protocol.go | 185 +++++++------- lib/go/thrift/client.go | 20 +- lib/go/thrift/compact_protocol.go | 117 +++++---- lib/go/thrift/debug_protocol.go | 168 ++++++------ lib/go/thrift/deserializer.go | 17 +- lib/go/thrift/header_protocol.go | 182 ++++++------- lib/go/thrift/header_transport.go | 34 ++- lib/go/thrift/header_transport_test.go | 6 +- lib/go/thrift/json_protocol.go | 158 ++++++------ lib/go/thrift/json_protocol_test.go | 62 ++--- lib/go/thrift/multiplexed_protocol.go | 10 +- lib/go/thrift/protocol.go | 134 +++++----- lib/go/thrift/protocol_test.go | 82 +++--- lib/go/thrift/serializer.go | 8 +- lib/go/thrift/serializer_test.go | 10 +- lib/go/thrift/serializer_types_test.go | 241 +++++++++--------- lib/go/thrift/simple_json_protocol.go | 147 +++++------ lib/go/thrift/simple_json_protocol_test.go | 66 ++--- lib/go/thrift/simple_server.go | 2 +- lib/go/thrift/transport_exception.go | 14 + lib/go/thrift/transport_exception_test.go | 31 ++- 24 files changed, 982 insertions(+), 902 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index fbaf35dffc8..ceb8f8b6f89 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -14,6 +14,7 @@ - [THRIFT-5138](https://issues.apache.org/jira/browse/THRIFT-5138) - Swift generator does not escape keywords properly - [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - In Go library TProcessor interface now includes ProcessorMap and AddToProcessorMap functions. - [THRIFT-5186](https://issues.apache.org/jira/browse/THRIFT-5186) - cpp: use all getaddrinfo() results when retrying failed bind() in T{Nonblocking,}ServerSocket +- [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - go: Now all Read*, Write* and Skip functions in TProtocol accept context arg ### Java @@ -24,6 +25,7 @@ - [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - Add TSerializerPool and TDeserializerPool, which are thread-safe versions of TSerializer and TDeserializer. - [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - Add ClientMiddleware function type and WrapClient function to support wrapping a TClient with middleware functions. - [THRIFT-5164](https://issues.apache.org/jira/browse/THRIFT-5164) - Add ProcessorMiddleware function type and WrapProcessor function to support wrapping a TProcessor with middleware functions. +- [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - Add context deadline check to ReadMessageBegin in TBinaryProtocol, TCompactProtocol, and THeaderProtocol. ## 0.13.0 diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 4f871515991..b89052b5e03 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -888,7 +888,7 @@ string t_go_generator::render_fastbinary_includes() { } /** - * Autogen'd comment. The different text is necessary due to + * Autogen'd comment. The different text is necessary due to * https://github.com/golang/go/issues/13560#issuecomment-288457920 */ string t_go_generator::go_autogen_comment() { @@ -1585,10 +1585,10 @@ void t_go_generator::generate_go_struct_reader(ostream& out, const vector& fields = tstruct->get_members(); vector::const_iterator f_iter; string escaped_tstruct_name(escape_string(tstruct->get_name())); - out << indent() << "func (p *" << tstruct_name << ") " << read_method_name_ << "(iprot thrift.TProtocol) error {" + out << indent() << "func (p *" << tstruct_name << ") " << read_method_name_ << "(ctx context.Context, iprot thrift.TProtocol) error {" << endl; indent_up(); - out << indent() << "if _, err := iprot.ReadStructBegin(); err != nil {" << endl; + out << indent() << "if _, err := iprot.ReadStructBegin(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T read error: \", p), err)" << endl; out << indent() << "}" << endl << endl; @@ -1606,7 +1606,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out, indent(out) << "for {" << endl; indent_up(); // Read beginning field marker - out << indent() << "_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()" << endl; + out << indent() << "_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx)" << endl; out << indent() << "if err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(" "\"%T field %d read error: \", p, fieldId), err)" << endl; @@ -1646,7 +1646,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out, } out << indent() << "if fieldTypeId == " << thriftFieldTypeId << " {" << endl; - out << indent() << " if err := p." << field_method_prefix << field_method_suffix << "(iprot); err != nil {" + out << indent() << " if err := p." << field_method_prefix << field_method_suffix << "(ctx, iprot); err != nil {" << endl; out << indent() << " return err" << endl; out << indent() << " }" << endl; @@ -1658,7 +1658,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out, } out << indent() << "} else {" << endl; - out << indent() << " if err := iprot.Skip(fieldTypeId); err != nil {" << endl; + out << indent() << " if err := iprot.Skip(ctx, fieldTypeId); err != nil {" << endl; out << indent() << " return err" << endl; out << indent() << " }" << endl; out << indent() << "}" << endl; @@ -1674,7 +1674,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out, } // Skip unknown fields in either case - out << indent() << "if err := iprot.Skip(fieldTypeId); err != nil {" << endl; + out << indent() << "if err := iprot.Skip(ctx, fieldTypeId); err != nil {" << endl; out << indent() << " return err" << endl; out << indent() << "}" << endl; @@ -1685,12 +1685,12 @@ void t_go_generator::generate_go_struct_reader(ostream& out, } // Read field end marker - out << indent() << "if err := iprot.ReadFieldEnd(); err != nil {" << endl; + out << indent() << "if err := iprot.ReadFieldEnd(ctx); err != nil {" << endl; out << indent() << " return err" << endl; out << indent() << "}" << endl; indent_down(); out << indent() << "}" << endl; - out << indent() << "if err := iprot.ReadStructEnd(); err != nil {" << endl; + out << indent() << "if err := iprot.ReadStructEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(" "\"%T read struct end error: \", p), err)" << endl; out << indent() << "}" << endl; @@ -1723,7 +1723,7 @@ void t_go_generator::generate_go_struct_reader(ostream& out, } out << indent() << "func (p *" << tstruct_name << ") " << field_method_prefix << field_method_suffix - << "(iprot thrift.TProtocol) error {" << endl; + << "(ctx context.Context, iprot thrift.TProtocol) error {" << endl; indent_up(); generate_deserialize_field(out, *f_iter, false, "p."); indent_down(); @@ -1741,7 +1741,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out, string name(tstruct->get_name()); const vector& fields = tstruct->get_sorted_members(); vector::const_iterator f_iter; - indent(out) << "func (p *" << tstruct_name << ") " << write_method_name_ << "(oprot thrift.TProtocol) error {" << endl; + indent(out) << "func (p *" << tstruct_name << ") " << write_method_name_ << "(ctx context.Context, oprot thrift.TProtocol) error {" << endl; indent_up(); if (tstruct->is_union() && uses_countsetfields) { std::string tstruct_name(publicize(tstruct->get_name())); @@ -1750,7 +1750,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out, << " return fmt.Errorf(\"%T write union: exactly one field must be set (%d set).\", p, c)" << endl << indent() << "}" << endl; } - out << indent() << "if err := oprot.WriteStructBegin(\"" << name << "\"); err != nil {" << endl; + out << indent() << "if err := oprot.WriteStructBegin(ctx, \"" << name << "\"); err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(" "\"%T write struct begin error: \", p), err) }" << endl; @@ -1776,16 +1776,16 @@ void t_go_generator::generate_go_struct_writer(ostream& out, } out << indent() << "if err := p." << field_method_prefix << field_method_suffix - << "(oprot); err != nil { return err }" << endl; + << "(ctx, oprot); err != nil { return err }" << endl; } indent_down(); out << indent() << "}" << endl; // Write the struct map - out << indent() << "if err := oprot.WriteFieldStop(); err != nil {" << endl; + out << indent() << "if err := oprot.WriteFieldStop(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"write field stop error: \", err) }" << endl; - out << indent() << "if err := oprot.WriteStructEnd(); err != nil {" << endl; + out << indent() << "if err := oprot.WriteStructEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"write struct stop error: \", err) }" << endl; out << indent() << "return nil" << endl; indent_down(); @@ -1806,7 +1806,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out, } out << indent() << "func (p *" << tstruct_name << ") " << field_method_prefix << field_method_suffix - << "(oprot thrift.TProtocol) (err error) {" << endl; + << "(ctx context.Context, oprot thrift.TProtocol) (err error) {" << endl; indent_up(); if (field_required == t_field::T_OPTIONAL) { @@ -1814,7 +1814,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out, indent_up(); } - out << indent() << "if err := oprot.WriteFieldBegin(\"" << escape_field_name << "\", " + out << indent() << "if err := oprot.WriteFieldBegin(ctx, \"" << escape_field_name << "\", " << type_to_enum((*f_iter)->get_type()) << ", " << field_id << "); err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T write field begin error " << field_id << ":" << escape_field_name << ": \", p), err) }" << endl; @@ -1823,7 +1823,7 @@ void t_go_generator::generate_go_struct_writer(ostream& out, generate_serialize_field(out, *f_iter, "p."); // Write field closer - out << indent() << "if err := oprot.WriteFieldEnd(); err != nil {" << endl; + out << indent() << "if err := oprot.WriteFieldEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T write field end error " << field_id << ":" << escape_field_name << ": \", p), err) }" << endl; @@ -2503,7 +2503,7 @@ void t_go_generator::generate_service_remote(t_service* tservice) { << endl; f_remote << indent() << "argvalue" << i << " := " << tstruct_module << ".New" << tstruct_name << "()" << endl; - f_remote << indent() << err2 << " := argvalue" << i << "." << read_method_name_ << "(" << jsProt << ")" << endl; + f_remote << indent() << err2 << " := argvalue" << i << "." << read_method_name_ << "(context.Background(), " << jsProt << ")" << endl; f_remote << indent() << "if " << err2 << " != nil {" << endl; f_remote << indent() << " Usage()" << endl; f_remote << indent() << " return" << endl; @@ -2531,7 +2531,7 @@ void t_go_generator::generate_service_remote(t_service* tservice) { << endl; f_remote << indent() << "containerStruct" << i << " := " << package_name_aliased << ".New" << argumentsName << "()" << endl; - f_remote << indent() << err2 << " := containerStruct" << i << ".ReadField" << (i + 1) << "(" + f_remote << indent() << err2 << " := containerStruct" << i << ".ReadField" << (i + 1) << "(context.Background(), " << jsProt << ")" << endl; f_remote << indent() << "if " << err2 << " != nil {" << endl; f_remote << indent() << " Usage()" << endl; @@ -2698,19 +2698,19 @@ void t_go_generator::generate_service_server(t_service* tservice) { f_types_ << indent() << "func (p *" << serviceName << "Processor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err " "thrift.TException) {" << endl; - f_types_ << indent() << " name, _, seqId, err := iprot.ReadMessageBegin()" << endl; + f_types_ << indent() << " name, _, seqId, err := iprot.ReadMessageBegin(ctx)" << endl; f_types_ << indent() << " if err != nil { return false, err }" << endl; f_types_ << indent() << " if processor, ok := p.GetProcessorFunction(name); ok {" << endl; f_types_ << indent() << " return processor.Process(ctx, seqId, iprot, oprot)" << endl; f_types_ << indent() << " }" << endl; - f_types_ << indent() << " iprot.Skip(thrift.STRUCT)" << endl; - f_types_ << indent() << " iprot.ReadMessageEnd()" << endl; + f_types_ << indent() << " iprot.Skip(ctx, thrift.STRUCT)" << endl; + f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl; f_types_ << indent() << " " << x << " := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, \"Unknown function " "\" + name)" << endl; - f_types_ << indent() << " oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)" << endl; - f_types_ << indent() << " " << x << ".Write(oprot)" << endl; - f_types_ << indent() << " oprot.WriteMessageEnd()" << endl; + f_types_ << indent() << " oprot.WriteMessageBegin(ctx, name, thrift.EXCEPTION, seqId)" << endl; + f_types_ << indent() << " " << x << ".Write(ctx, oprot)" << endl; + f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; f_types_ << indent() << " oprot.Flush(ctx)" << endl; f_types_ << indent() << " return false, " << x << endl; f_types_ << indent() << "" << endl; @@ -2765,21 +2765,21 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* "thrift.TException) {" << endl; indent_up(); f_types_ << indent() << "args := " << argsname << "{}" << endl; - f_types_ << indent() << "if err = args." << read_method_name_ << "(iprot); err != nil {" << endl; - f_types_ << indent() << " iprot.ReadMessageEnd()" << endl; + f_types_ << indent() << "if err = args." << read_method_name_ << "(ctx, iprot); err != nil {" << endl; + f_types_ << indent() << " iprot.ReadMessageEnd(ctx)" << endl; if (!tfunction->is_oneway()) { f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())" << endl; - f_types_ << indent() << " oprot.WriteMessageBegin(\"" << escape_string(tfunction->get_name()) + f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.EXCEPTION, seqId)" << endl; - f_types_ << indent() << " x.Write(oprot)" << endl; - f_types_ << indent() << " oprot.WriteMessageEnd()" << endl; + f_types_ << indent() << " x.Write(ctx, oprot)" << endl; + f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; f_types_ << indent() << " oprot.Flush(ctx)" << endl; } f_types_ << indent() << " return false, err" << endl; f_types_ << indent() << "}" << endl << endl; - f_types_ << indent() << "iprot.ReadMessageEnd()" << endl; + f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl; if (!tfunction->is_oneway()) { f_types_ << indent() << "result := " << resultname << "{}" << endl; @@ -2839,10 +2839,10 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << " x := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, " "\"Internal error processing " << escape_string(tfunction->get_name()) << ": \" + err2.Error())" << endl; - f_types_ << indent() << " oprot.WriteMessageBegin(\"" << escape_string(tfunction->get_name()) + f_types_ << indent() << " oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.EXCEPTION, seqId)" << endl; - f_types_ << indent() << " x.Write(oprot)" << endl; - f_types_ << indent() << " oprot.WriteMessageEnd()" << endl; + f_types_ << indent() << " x.Write(ctx, oprot)" << endl; + f_types_ << indent() << " oprot.WriteMessageEnd(ctx)" << endl; f_types_ << indent() << " oprot.Flush(ctx)" << endl; } @@ -2868,15 +2868,15 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* } else { f_types_ << endl; } - f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(\"" + f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \"" << escape_string(tfunction->get_name()) << "\", thrift.REPLY, seqId); err2 != nil {" << endl; f_types_ << indent() << " err = err2" << endl; f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = result." << write_method_name_ << "(oprot); err == nil && err2 != nil {" << endl; + f_types_ << indent() << "if err2 = result." << write_method_name_ << "(ctx, oprot); err == nil && err2 != nil {" << endl; f_types_ << indent() << " err = err2" << endl; f_types_ << indent() << "}" << endl; - f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(); err == nil && err2 != nil {" + f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil && err2 != nil {" << endl; f_types_ << indent() << " err = err2" << endl; f_types_ << indent() << "}" << endl; @@ -2945,42 +2945,42 @@ void t_go_generator::generate_deserialize_field(ostream& out, case t_base_type::TYPE_STRING: if (type->is_binary() && !inkey) { - out << "ReadBinary()"; + out << "ReadBinary(ctx)"; } else { - out << "ReadString()"; + out << "ReadString(ctx)"; } break; case t_base_type::TYPE_BOOL: - out << "ReadBool()"; + out << "ReadBool(ctx)"; break; case t_base_type::TYPE_I8: - out << "ReadByte()"; + out << "ReadByte(ctx)"; break; case t_base_type::TYPE_I16: - out << "ReadI16()"; + out << "ReadI16(ctx)"; break; case t_base_type::TYPE_I32: - out << "ReadI32()"; + out << "ReadI32(ctx)"; break; case t_base_type::TYPE_I64: - out << "ReadI64()"; + out << "ReadI64(ctx)"; break; case t_base_type::TYPE_DOUBLE: - out << "ReadDouble()"; + out << "ReadDouble(ctx)"; break; default: throw "compiler error: no Go name for base type " + t_base_type::t_base_name(tbase); } } else if (type->is_enum()) { - out << "ReadI32()"; + out << "ReadI32(ctx)"; } out << "; err != nil {" << endl; @@ -3023,7 +3023,7 @@ void t_go_generator::generate_deserialize_struct(ostream& out, out << indent() << prefix << eq << (pointer_field ? "&" : ""); generate_go_struct_initializer(out, tstruct); - out << indent() << "if err := " << prefix << "." << read_method_name_ << "(iprot); err != nil {" << endl; + out << indent() << "if err := " << prefix << "." << read_method_name_ << "(ctx, iprot); err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T error reading struct: \", " << prefix << "), err)" << endl; out << indent() << "}" << endl; @@ -3047,21 +3047,21 @@ void t_go_generator::generate_deserialize_container(ostream& out, // Declare variables, read header if (ttype->is_map()) { - out << indent() << "_, _, size, err := iprot.ReadMapBegin()" << endl; + out << indent() << "_, _, size, err := iprot.ReadMapBegin(ctx)" << endl; out << indent() << "if err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error reading map begin: \", err)" << endl; out << indent() << "}" << endl; out << indent() << "tMap := make(" << type_to_go_type(orig_type) << ", size)" << endl; out << indent() << prefix << eq << " " << (pointer_field ? "&" : "") << "tMap" << endl; } else if (ttype->is_set()) { - out << indent() << "_, size, err := iprot.ReadSetBegin()" << endl; + out << indent() << "_, size, err := iprot.ReadSetBegin(ctx)" << endl; out << indent() << "if err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error reading set begin: \", err)" << endl; out << indent() << "}" << endl; out << indent() << "tSet := make(" << type_to_go_type(orig_type) << ", 0, size)" << endl; out << indent() << prefix << eq << " " << (pointer_field ? "&" : "") << "tSet" << endl; } else if (ttype->is_list()) { - out << indent() << "_, size, err := iprot.ReadListBegin()" << endl; + out << indent() << "_, size, err := iprot.ReadListBegin(ctx)" << endl; out << indent() << "if err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error reading list begin: \", err)" << endl; out << indent() << "}" << endl; @@ -3092,15 +3092,15 @@ void t_go_generator::generate_deserialize_container(ostream& out, // Read container end if (ttype->is_map()) { - out << indent() << "if err := iprot.ReadMapEnd(); err != nil {" << endl; + out << indent() << "if err := iprot.ReadMapEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error reading map end: \", err)" << endl; out << indent() << "}" << endl; } else if (ttype->is_set()) { - out << indent() << "if err := iprot.ReadSetEnd(); err != nil {" << endl; + out << indent() << "if err := iprot.ReadSetEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error reading set end: \", err)" << endl; out << indent() << "}" << endl; } else if (ttype->is_list()) { - out << indent() << "if err := iprot.ReadListEnd(); err != nil {" << endl; + out << indent() << "if err := iprot.ReadListEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error reading list end: \", err)" << endl; out << indent() << "}" << endl; } @@ -3194,42 +3194,42 @@ void t_go_generator::generate_serialize_field(ostream& out, case t_base_type::TYPE_STRING: if (type->is_binary() && !inkey) { - out << "WriteBinary(" << name << ")"; + out << "WriteBinary(ctx, " << name << ")"; } else { - out << "WriteString(string(" << name << "))"; + out << "WriteString(ctx, string(" << name << "))"; } break; case t_base_type::TYPE_BOOL: - out << "WriteBool(bool(" << name << "))"; + out << "WriteBool(ctx, bool(" << name << "))"; break; case t_base_type::TYPE_I8: - out << "WriteByte(int8(" << name << "))"; + out << "WriteByte(ctx, int8(" << name << "))"; break; case t_base_type::TYPE_I16: - out << "WriteI16(int16(" << name << "))"; + out << "WriteI16(ctx, int16(" << name << "))"; break; case t_base_type::TYPE_I32: - out << "WriteI32(int32(" << name << "))"; + out << "WriteI32(ctx, int32(" << name << "))"; break; case t_base_type::TYPE_I64: - out << "WriteI64(int64(" << name << "))"; + out << "WriteI64(ctx, int64(" << name << "))"; break; case t_base_type::TYPE_DOUBLE: - out << "WriteDouble(float64(" << name << "))"; + out << "WriteDouble(ctx, float64(" << name << "))"; break; default: throw "compiler error: no Go name for base type " + t_base_type::t_base_name(tbase); } } else if (type->is_enum()) { - out << "WriteI32(int32(" << name << "))"; + out << "WriteI32(ctx, int32(" << name << "))"; } out << "; err != nil {" << endl; @@ -3250,7 +3250,7 @@ void t_go_generator::generate_serialize_field(ostream& out, */ void t_go_generator::generate_serialize_struct(ostream& out, t_struct* tstruct, string prefix) { (void)tstruct; - out << indent() << "if err := " << prefix << "." << write_method_name_ << "(oprot); err != nil {" << endl; + out << indent() << "if err := " << prefix << "." << write_method_name_ << "(ctx, oprot); err != nil {" << endl; out << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T error writing struct: \", " << prefix << "), err)" << endl; out << indent() << "}" << endl; @@ -3264,20 +3264,20 @@ void t_go_generator::generate_serialize_container(ostream& out, prefix = "*" + prefix; } if (ttype->is_map()) { - out << indent() << "if err := oprot.WriteMapBegin(" + out << indent() << "if err := oprot.WriteMapBegin(ctx, " << type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << "len(" << prefix << ")); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error writing map begin: \", err)" << endl; out << indent() << "}" << endl; } else if (ttype->is_set()) { - out << indent() << "if err := oprot.WriteSetBegin(" + out << indent() << "if err := oprot.WriteSetBegin(ctx, " << type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << "len(" << prefix << ")); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error writing set begin: \", err)" << endl; out << indent() << "}" << endl; } else if (ttype->is_list()) { - out << indent() << "if err := oprot.WriteListBegin(" + out << indent() << "if err := oprot.WriteListBegin(ctx, " << type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << "len(" << prefix << ")); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error writing list begin: \", err)" << endl; @@ -3323,15 +3323,15 @@ void t_go_generator::generate_serialize_container(ostream& out, } if (ttype->is_map()) { - out << indent() << "if err := oprot.WriteMapEnd(); err != nil {" << endl; + out << indent() << "if err := oprot.WriteMapEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error writing map end: \", err)" << endl; out << indent() << "}" << endl; } else if (ttype->is_set()) { - out << indent() << "if err := oprot.WriteSetEnd(); err != nil {" << endl; + out << indent() << "if err := oprot.WriteSetEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error writing set end: \", err)" << endl; out << indent() << "}" << endl; } else if (ttype->is_list()) { - out << indent() << "if err := oprot.WriteListEnd(); err != nil {" << endl; + out << indent() << "if err := oprot.WriteListEnd(ctx); err != nil {" << endl; out << indent() << " return thrift.PrependError(\"error writing list end: \", err)" << endl; out << indent() << "}" << endl; } diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go index 0023c57cf1e..6de37ee73f7 100644 --- a/lib/go/thrift/application_exception.go +++ b/lib/go/thrift/application_exception.go @@ -19,6 +19,10 @@ package thrift +import ( + "context" +) + const ( UNKNOWN_APPLICATION_EXCEPTION = 0 UNKNOWN_METHOD = 1 @@ -51,8 +55,8 @@ var defaultApplicationExceptionMessage = map[int32]string{ type TApplicationException interface { TException TypeId() int32 - Read(iprot TProtocol) error - Write(oprot TProtocol) error + Read(ctx context.Context, iprot TProtocol) error + Write(ctx context.Context, oprot TProtocol) error } type tApplicationException struct { @@ -75,9 +79,9 @@ func (p *tApplicationException) TypeId() int32 { return p.type_ } -func (p *tApplicationException) Read(iprot TProtocol) error { +func (p *tApplicationException) Read(ctx context.Context, iprot TProtocol) error { // TODO: this should really be generated by the compiler - _, err := iprot.ReadStructBegin() + _, err := iprot.ReadStructBegin(ctx) if err != nil { return err } @@ -86,7 +90,7 @@ func (p *tApplicationException) Read(iprot TProtocol) error { type_ := int32(UNKNOWN_APPLICATION_EXCEPTION) for { - _, ttype, id, err := iprot.ReadFieldBegin() + _, ttype, id, err := iprot.ReadFieldBegin(ctx) if err != nil { return err } @@ -96,34 +100,34 @@ func (p *tApplicationException) Read(iprot TProtocol) error { switch id { case 1: if ttype == STRING { - if message, err = iprot.ReadString(); err != nil { + if message, err = iprot.ReadString(ctx); err != nil { return err } } else { - if err = SkipDefaultDepth(iprot, ttype); err != nil { + if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil { return err } } case 2: if ttype == I32 { - if type_, err = iprot.ReadI32(); err != nil { + if type_, err = iprot.ReadI32(ctx); err != nil { return err } } else { - if err = SkipDefaultDepth(iprot, ttype); err != nil { + if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil { return err } } default: - if err = SkipDefaultDepth(iprot, ttype); err != nil { + if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil { return err } } - if err = iprot.ReadFieldEnd(); err != nil { + if err = iprot.ReadFieldEnd(ctx); err != nil { return err } } - if err := iprot.ReadStructEnd(); err != nil { + if err := iprot.ReadStructEnd(ctx); err != nil { return err } @@ -133,38 +137,38 @@ func (p *tApplicationException) Read(iprot TProtocol) error { return nil } -func (p *tApplicationException) Write(oprot TProtocol) (err error) { - err = oprot.WriteStructBegin("TApplicationException") +func (p *tApplicationException) Write(ctx context.Context, oprot TProtocol) (err error) { + err = oprot.WriteStructBegin(ctx, "TApplicationException") if len(p.Error()) > 0 { - err = oprot.WriteFieldBegin("message", STRING, 1) + err = oprot.WriteFieldBegin(ctx, "message", STRING, 1) if err != nil { return } - err = oprot.WriteString(p.Error()) + err = oprot.WriteString(ctx, p.Error()) if err != nil { return } - err = oprot.WriteFieldEnd() + err = oprot.WriteFieldEnd(ctx) if err != nil { return } } - err = oprot.WriteFieldBegin("type", I32, 2) + err = oprot.WriteFieldBegin(ctx, "type", I32, 2) if err != nil { return } - err = oprot.WriteI32(p.type_) + err = oprot.WriteI32(ctx, p.type_) if err != nil { return } - err = oprot.WriteFieldEnd() + err = oprot.WriteFieldEnd(ctx) if err != nil { return } - err = oprot.WriteFieldStop() + err = oprot.WriteFieldStop(ctx) if err != nil { return } - err = oprot.WriteStructEnd() + err = oprot.WriteStructEnd(ctx) return } diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go index 93ae898cf5e..c87d23a1bb1 100644 --- a/lib/go/thrift/binary_protocol.go +++ b/lib/go/thrift/binary_protocol.go @@ -72,146 +72,146 @@ func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol { * Writing Methods */ -func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { +func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error { if p.strictWrite { version := uint32(VERSION_1) | uint32(typeId) - e := p.WriteI32(int32(version)) + e := p.WriteI32(ctx, int32(version)) if e != nil { return e } - e = p.WriteString(name) + e = p.WriteString(ctx, name) if e != nil { return e } - e = p.WriteI32(seqId) + e = p.WriteI32(ctx, seqId) return e } else { - e := p.WriteString(name) + e := p.WriteString(ctx, name) if e != nil { return e } - e = p.WriteByte(int8(typeId)) + e = p.WriteByte(ctx, int8(typeId)) if e != nil { return e } - e = p.WriteI32(seqId) + e = p.WriteI32(ctx, seqId) return e } return nil } -func (p *TBinaryProtocol) WriteMessageEnd() error { +func (p *TBinaryProtocol) WriteMessageEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) WriteStructBegin(name string) error { +func (p *TBinaryProtocol) WriteStructBegin(ctx context.Context, name string) error { return nil } -func (p *TBinaryProtocol) WriteStructEnd() error { +func (p *TBinaryProtocol) WriteStructEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { - e := p.WriteByte(int8(typeId)) +func (p *TBinaryProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { + e := p.WriteByte(ctx, int8(typeId)) if e != nil { return e } - e = p.WriteI16(id) + e = p.WriteI16(ctx, id) return e } -func (p *TBinaryProtocol) WriteFieldEnd() error { +func (p *TBinaryProtocol) WriteFieldEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) WriteFieldStop() error { - e := p.WriteByte(STOP) +func (p *TBinaryProtocol) WriteFieldStop(ctx context.Context) error { + e := p.WriteByte(ctx, STOP) return e } -func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { - e := p.WriteByte(int8(keyType)) +func (p *TBinaryProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { + e := p.WriteByte(ctx, int8(keyType)) if e != nil { return e } - e = p.WriteByte(int8(valueType)) + e = p.WriteByte(ctx, int8(valueType)) if e != nil { return e } - e = p.WriteI32(int32(size)) + e = p.WriteI32(ctx, int32(size)) return e } -func (p *TBinaryProtocol) WriteMapEnd() error { +func (p *TBinaryProtocol) WriteMapEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error { - e := p.WriteByte(int8(elemType)) +func (p *TBinaryProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { + e := p.WriteByte(ctx, int8(elemType)) if e != nil { return e } - e = p.WriteI32(int32(size)) + e = p.WriteI32(ctx, int32(size)) return e } -func (p *TBinaryProtocol) WriteListEnd() error { +func (p *TBinaryProtocol) WriteListEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error { - e := p.WriteByte(int8(elemType)) +func (p *TBinaryProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { + e := p.WriteByte(ctx, int8(elemType)) if e != nil { return e } - e = p.WriteI32(int32(size)) + e = p.WriteI32(ctx, int32(size)) return e } -func (p *TBinaryProtocol) WriteSetEnd() error { +func (p *TBinaryProtocol) WriteSetEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) WriteBool(value bool) error { +func (p *TBinaryProtocol) WriteBool(ctx context.Context, value bool) error { if value { - return p.WriteByte(1) + return p.WriteByte(ctx, 1) } - return p.WriteByte(0) + return p.WriteByte(ctx, 0) } -func (p *TBinaryProtocol) WriteByte(value int8) error { +func (p *TBinaryProtocol) WriteByte(ctx context.Context, value int8) error { e := p.trans.WriteByte(byte(value)) return NewTProtocolException(e) } -func (p *TBinaryProtocol) WriteI16(value int16) error { +func (p *TBinaryProtocol) WriteI16(ctx context.Context, value int16) error { v := p.buffer[0:2] binary.BigEndian.PutUint16(v, uint16(value)) _, e := p.trans.Write(v) return NewTProtocolException(e) } -func (p *TBinaryProtocol) WriteI32(value int32) error { +func (p *TBinaryProtocol) WriteI32(ctx context.Context, value int32) error { v := p.buffer[0:4] binary.BigEndian.PutUint32(v, uint32(value)) _, e := p.trans.Write(v) return NewTProtocolException(e) } -func (p *TBinaryProtocol) WriteI64(value int64) error { +func (p *TBinaryProtocol) WriteI64(ctx context.Context, value int64) error { v := p.buffer[0:8] binary.BigEndian.PutUint64(v, uint64(value)) _, err := p.trans.Write(v) return NewTProtocolException(err) } -func (p *TBinaryProtocol) WriteDouble(value float64) error { - return p.WriteI64(int64(math.Float64bits(value))) +func (p *TBinaryProtocol) WriteDouble(ctx context.Context, value float64) error { + return p.WriteI64(ctx, int64(math.Float64bits(value))) } -func (p *TBinaryProtocol) WriteString(value string) error { - e := p.WriteI32(int32(len(value))) +func (p *TBinaryProtocol) WriteString(ctx context.Context, value string) error { + e := p.WriteI32(ctx, int32(len(value))) if e != nil { return e } @@ -219,8 +219,8 @@ func (p *TBinaryProtocol) WriteString(value string) error { return NewTProtocolException(err) } -func (p *TBinaryProtocol) WriteBinary(value []byte) error { - e := p.WriteI32(int32(len(value))) +func (p *TBinaryProtocol) WriteBinary(ctx context.Context, value []byte) error { + e := p.WriteI32(ctx, int32(len(value))) if e != nil { return e } @@ -232,8 +232,8 @@ func (p *TBinaryProtocol) WriteBinary(value []byte) error { * Reading methods */ -func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { - size, e := p.ReadI32() +func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) { + size, e := p.ReadI32(ctx) if e != nil { return "", typeId, 0, NewTProtocolException(e) } @@ -243,11 +243,11 @@ func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, if version != VERSION_1 { return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin")) } - name, e = p.ReadString() + name, e = p.ReadString(ctx) if e != nil { return name, typeId, seqId, NewTProtocolException(e) } - seqId, e = p.ReadI32() + seqId, e = p.ReadI32(ctx) if e != nil { return name, typeId, seqId, NewTProtocolException(e) } @@ -260,62 +260,62 @@ func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, if e2 != nil { return name, typeId, seqId, e2 } - b, e3 := p.ReadByte() + b, e3 := p.ReadByte(ctx) if e3 != nil { return name, typeId, seqId, e3 } typeId = TMessageType(b) - seqId, e4 := p.ReadI32() + seqId, e4 := p.ReadI32(ctx) if e4 != nil { return name, typeId, seqId, e4 } return name, typeId, seqId, nil } -func (p *TBinaryProtocol) ReadMessageEnd() error { +func (p *TBinaryProtocol) ReadMessageEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) { +func (p *TBinaryProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { return } -func (p *TBinaryProtocol) ReadStructEnd() error { +func (p *TBinaryProtocol) ReadStructEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) { - t, err := p.ReadByte() +func (p *TBinaryProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, seqId int16, err error) { + t, err := p.ReadByte(ctx) typeId = TType(t) if err != nil { return name, typeId, seqId, err } if t != STOP { - seqId, err = p.ReadI16() + seqId, err = p.ReadI16(ctx) } return name, typeId, seqId, err } -func (p *TBinaryProtocol) ReadFieldEnd() error { +func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error { return nil } var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length")) -func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) { - k, e := p.ReadByte() +func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) { + k, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } kType = TType(k) - v, e := p.ReadByte() + v, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } vType = TType(v) - size32, e := p.ReadI32() + size32, e := p.ReadI32(ctx) if e != nil { err = NewTProtocolException(e) return @@ -328,18 +328,18 @@ func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err erro return kType, vType, size, nil } -func (p *TBinaryProtocol) ReadMapEnd() error { +func (p *TBinaryProtocol) ReadMapEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) { - b, e := p.ReadByte() +func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) { + b, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } elemType = TType(b) - size32, e := p.ReadI32() + size32, e := p.ReadI32(ctx) if e != nil { err = NewTProtocolException(e) return @@ -353,18 +353,18 @@ func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) return } -func (p *TBinaryProtocol) ReadListEnd() error { +func (p *TBinaryProtocol) ReadListEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { - b, e := p.ReadByte() +func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) { + b, e := p.ReadByte(ctx) if e != nil { err = NewTProtocolException(e) return } elemType = TType(b) - size32, e := p.ReadI32() + size32, e := p.ReadI32(ctx) if e != nil { err = NewTProtocolException(e) return @@ -377,12 +377,12 @@ func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) { return elemType, size, nil } -func (p *TBinaryProtocol) ReadSetEnd() error { +func (p *TBinaryProtocol) ReadSetEnd(ctx context.Context) error { return nil } -func (p *TBinaryProtocol) ReadBool() (bool, error) { - b, e := p.ReadByte() +func (p *TBinaryProtocol) ReadBool(ctx context.Context) (bool, error) { + b, e := p.ReadByte(ctx) v := true if b != 1 { v = false @@ -390,41 +390,41 @@ func (p *TBinaryProtocol) ReadBool() (bool, error) { return v, e } -func (p *TBinaryProtocol) ReadByte() (int8, error) { +func (p *TBinaryProtocol) ReadByte(ctx context.Context) (int8, error) { v, err := p.trans.ReadByte() return int8(v), err } -func (p *TBinaryProtocol) ReadI16() (value int16, err error) { +func (p *TBinaryProtocol) ReadI16(ctx context.Context) (value int16, err error) { buf := p.buffer[0:2] - err = p.readAll(buf) + err = p.readAll(ctx, buf) value = int16(binary.BigEndian.Uint16(buf)) return value, err } -func (p *TBinaryProtocol) ReadI32() (value int32, err error) { +func (p *TBinaryProtocol) ReadI32(ctx context.Context) (value int32, err error) { buf := p.buffer[0:4] - err = p.readAll(buf) + err = p.readAll(ctx, buf) value = int32(binary.BigEndian.Uint32(buf)) return value, err } -func (p *TBinaryProtocol) ReadI64() (value int64, err error) { +func (p *TBinaryProtocol) ReadI64(ctx context.Context) (value int64, err error) { buf := p.buffer[0:8] - err = p.readAll(buf) + err = p.readAll(ctx, buf) value = int64(binary.BigEndian.Uint64(buf)) return value, err } -func (p *TBinaryProtocol) ReadDouble() (value float64, err error) { +func (p *TBinaryProtocol) ReadDouble(ctx context.Context) (value float64, err error) { buf := p.buffer[0:8] - err = p.readAll(buf) + err = p.readAll(ctx, buf) value = math.Float64frombits(binary.BigEndian.Uint64(buf)) return value, err } -func (p *TBinaryProtocol) ReadString() (value string, err error) { - size, e := p.ReadI32() +func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err error) { + size, e := p.ReadI32(ctx) if e != nil { return "", e } @@ -436,8 +436,8 @@ func (p *TBinaryProtocol) ReadString() (value string, err error) { return p.readStringBody(size) } -func (p *TBinaryProtocol) ReadBinary() ([]byte, error) { - size, e := p.ReadI32() +func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) { + size, e := p.ReadI32(ctx) if e != nil { return nil, e } @@ -455,16 +455,27 @@ func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) { return NewTProtocolException(p.trans.Flush(ctx)) } -func (p *TBinaryProtocol) Skip(fieldType TType) (err error) { - return SkipDefaultDepth(p, fieldType) +func (p *TBinaryProtocol) Skip(ctx context.Context, fieldType TType) (err error) { + return SkipDefaultDepth(ctx, p, fieldType) } func (p *TBinaryProtocol) Transport() TTransport { return p.origTransport } -func (p *TBinaryProtocol) readAll(buf []byte) error { - _, err := io.ReadFull(p.trans, buf) +func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) { + var read int + _, deadlineSet := ctx.Deadline() + for { + read, err = io.ReadFull(p.trans, buf) + if deadlineSet && read == 0 && isTimeoutError(err) && ctx.Err() == nil { + // This is I/O timeout without anything read, + // and we still have time left, keep retrying. + continue + } + // For anything else, don't retry + break + } return NewTProtocolException(err) } diff --git a/lib/go/thrift/client.go b/lib/go/thrift/client.go index b073a952d92..1c5705d5522 100644 --- a/lib/go/thrift/client.go +++ b/lib/go/thrift/client.go @@ -34,20 +34,20 @@ func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32 } } - if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil { + if err := oprot.WriteMessageBegin(ctx, method, CALL, seqId); err != nil { return err } - if err := args.Write(oprot); err != nil { + if err := args.Write(ctx, oprot); err != nil { return err } - if err := oprot.WriteMessageEnd(); err != nil { + if err := oprot.WriteMessageEnd(ctx); err != nil { return err } return oprot.Flush(ctx) } -func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error { - rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin() +func (p *TStandardClient) Recv(ctx context.Context, iprot TProtocol, seqId int32, method string, result TStruct) error { + rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin(ctx) if err != nil { return err } @@ -58,11 +58,11 @@ func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, resu return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method)) } else if rTypeId == EXCEPTION { var exception tApplicationException - if err := exception.Read(iprot); err != nil { + if err := exception.Read(ctx, iprot); err != nil { return err } - if err := iprot.ReadMessageEnd(); err != nil { + if err := iprot.ReadMessageEnd(ctx); err != nil { return err } @@ -71,11 +71,11 @@ func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, resu return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method)) } - if err := result.Read(iprot); err != nil { + if err := result.Read(ctx, iprot); err != nil { return err } - return iprot.ReadMessageEnd() + return iprot.ReadMessageEnd(ctx) } func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error { @@ -91,5 +91,5 @@ func (p *TStandardClient) Call(ctx context.Context, method string, args, result return nil } - return p.Recv(p.iprot, seqId, method, result) + return p.Recv(ctx, p.iprot, seqId, method, result) } diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go index 1900d50c3b1..468935781de 100644 --- a/lib/go/thrift/compact_protocol.go +++ b/lib/go/thrift/compact_protocol.go @@ -125,7 +125,7 @@ func NewTCompactProtocol(trans TTransport) *TCompactProtocol { // Write a message header to the wire. Compact Protocol messages contain the // protocol version so we can migrate forwards in the future if need be. -func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { +func (p *TCompactProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error { err := p.writeByteDirect(COMPACT_PROTOCOL_ID) if err != nil { return NewTProtocolException(err) @@ -138,17 +138,17 @@ func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, s if err != nil { return NewTProtocolException(err) } - e := p.WriteString(name) + e := p.WriteString(ctx, name) return e } -func (p *TCompactProtocol) WriteMessageEnd() error { return nil } +func (p *TCompactProtocol) WriteMessageEnd(ctx context.Context) error { return nil } // Write a struct begin. This doesn't actually put anything on the wire. We // use it as an opportunity to put special placeholder markers on the field // stack so we can get the field id deltas correct. -func (p *TCompactProtocol) WriteStructBegin(name string) error { +func (p *TCompactProtocol) WriteStructBegin(ctx context.Context, name string) error { p.lastField = append(p.lastField, p.lastFieldId) p.lastFieldId = 0 return nil @@ -157,26 +157,26 @@ func (p *TCompactProtocol) WriteStructBegin(name string) error { // Write a struct end. This doesn't actually put anything on the wire. We use // this as an opportunity to pop the last field from the current struct off // of the field stack. -func (p *TCompactProtocol) WriteStructEnd() error { +func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error { p.lastFieldId = p.lastField[len(p.lastField)-1] p.lastField = p.lastField[:len(p.lastField)-1] return nil } -func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { +func (p *TCompactProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { if typeId == BOOL { // we want to possibly include the value, so we'll wait. p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true return nil } - _, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF) + _, err := p.writeFieldBeginInternal(ctx, name, typeId, id, 0xFF) return NewTProtocolException(err) } // The workhorse of writeFieldBegin. It has the option of doing a // 'type override' of the type header. This is used specifically in the // boolean field case. -func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) { +func (p *TCompactProtocol) writeFieldBeginInternal(ctx context.Context, name string, typeId TType, id int16, typeOverride byte) (int, error) { // short lastField = lastField_.pop(); // if there's a type override, use that. @@ -201,7 +201,7 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id if err != nil { return 0, err } - err = p.WriteI16(id) + err = p.WriteI16(ctx, id) written = 1 + 2 if err != nil { return 0, err @@ -213,14 +213,14 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id return written, nil } -func (p *TCompactProtocol) WriteFieldEnd() error { return nil } +func (p *TCompactProtocol) WriteFieldEnd(ctx context.Context) error { return nil } -func (p *TCompactProtocol) WriteFieldStop() error { +func (p *TCompactProtocol) WriteFieldStop(ctx context.Context) error { err := p.writeByteDirect(STOP) return NewTProtocolException(err) } -func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { +func (p *TCompactProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { if size == 0 { err := p.writeByteDirect(0) return NewTProtocolException(err) @@ -233,32 +233,32 @@ func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size in return NewTProtocolException(err) } -func (p *TCompactProtocol) WriteMapEnd() error { return nil } +func (p *TCompactProtocol) WriteMapEnd(ctx context.Context) error { return nil } // Write a list header. -func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error { +func (p *TCompactProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { _, err := p.writeCollectionBegin(elemType, size) return NewTProtocolException(err) } -func (p *TCompactProtocol) WriteListEnd() error { return nil } +func (p *TCompactProtocol) WriteListEnd(ctx context.Context) error { return nil } // Write a set header. -func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error { +func (p *TCompactProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { _, err := p.writeCollectionBegin(elemType, size) return NewTProtocolException(err) } -func (p *TCompactProtocol) WriteSetEnd() error { return nil } +func (p *TCompactProtocol) WriteSetEnd(ctx context.Context) error { return nil } -func (p *TCompactProtocol) WriteBool(value bool) error { +func (p *TCompactProtocol) WriteBool(ctx context.Context, value bool) error { v := byte(COMPACT_BOOLEAN_FALSE) if value { v = byte(COMPACT_BOOLEAN_TRUE) } if p.booleanFieldPending { // we haven't written the field header yet - _, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v) + _, err := p.writeFieldBeginInternal(ctx, p.booleanFieldName, BOOL, p.booleanFieldId, v) p.booleanFieldPending = false return NewTProtocolException(err) } @@ -268,31 +268,31 @@ func (p *TCompactProtocol) WriteBool(value bool) error { } // Write a byte. Nothing to see here! -func (p *TCompactProtocol) WriteByte(value int8) error { +func (p *TCompactProtocol) WriteByte(ctx context.Context, value int8) error { err := p.writeByteDirect(byte(value)) return NewTProtocolException(err) } // Write an I16 as a zigzag varint. -func (p *TCompactProtocol) WriteI16(value int16) error { +func (p *TCompactProtocol) WriteI16(ctx context.Context, value int16) error { _, err := p.writeVarint32(p.int32ToZigzag(int32(value))) return NewTProtocolException(err) } // Write an i32 as a zigzag varint. -func (p *TCompactProtocol) WriteI32(value int32) error { +func (p *TCompactProtocol) WriteI32(ctx context.Context, value int32) error { _, err := p.writeVarint32(p.int32ToZigzag(value)) return NewTProtocolException(err) } // Write an i64 as a zigzag varint. -func (p *TCompactProtocol) WriteI64(value int64) error { +func (p *TCompactProtocol) WriteI64(ctx context.Context, value int64) error { _, err := p.writeVarint64(p.int64ToZigzag(value)) return NewTProtocolException(err) } // Write a double to the wire as 8 bytes. -func (p *TCompactProtocol) WriteDouble(value float64) error { +func (p *TCompactProtocol) WriteDouble(ctx context.Context, value float64) error { buf := p.buffer[0:8] binary.LittleEndian.PutUint64(buf, math.Float64bits(value)) _, err := p.trans.Write(buf) @@ -300,7 +300,7 @@ func (p *TCompactProtocol) WriteDouble(value float64) error { } // Write a string to the wire with a varint size preceding. -func (p *TCompactProtocol) WriteString(value string) error { +func (p *TCompactProtocol) WriteString(ctx context.Context, value string) error { _, e := p.writeVarint32(int32(len(value))) if e != nil { return NewTProtocolException(e) @@ -312,7 +312,7 @@ func (p *TCompactProtocol) WriteString(value string) error { } // Write a byte array, using a varint for the size. -func (p *TCompactProtocol) WriteBinary(bin []byte) error { +func (p *TCompactProtocol) WriteBinary(ctx context.Context, bin []byte) error { _, e := p.writeVarint32(int32(len(bin))) if e != nil { return NewTProtocolException(e) @@ -329,9 +329,20 @@ func (p *TCompactProtocol) WriteBinary(bin []byte) error { // // Read a message header. -func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { +func (p *TCompactProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) { + var protocolId byte - protocolId, err := p.readByteDirect() + _, deadlineSet := ctx.Deadline() + for { + protocolId, err = p.readByteDirect() + if deadlineSet && isTimeoutError(err) && ctx.Err() == nil { + // keep retrying I/O timeout errors since we still have + // time left + continue + } + // For anything else, don't retry + break + } if err != nil { return } @@ -358,15 +369,15 @@ func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, err = NewTProtocolException(e) return } - name, err = p.ReadString() + name, err = p.ReadString(ctx) return } -func (p *TCompactProtocol) ReadMessageEnd() error { return nil } +func (p *TCompactProtocol) ReadMessageEnd(ctx context.Context) error { return nil } // Read a struct begin. There's nothing on the wire for this, but it is our // opportunity to push a new struct begin marker onto the field stack. -func (p *TCompactProtocol) ReadStructBegin() (name string, err error) { +func (p *TCompactProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { p.lastField = append(p.lastField, p.lastFieldId) p.lastFieldId = 0 return @@ -374,7 +385,7 @@ func (p *TCompactProtocol) ReadStructBegin() (name string, err error) { // Doesn't actually consume any wire data, just removes the last field for // this struct from the field stack. -func (p *TCompactProtocol) ReadStructEnd() error { +func (p *TCompactProtocol) ReadStructEnd(ctx context.Context) error { // consume the last field we read off the wire. p.lastFieldId = p.lastField[len(p.lastField)-1] p.lastField = p.lastField[:len(p.lastField)-1] @@ -382,7 +393,7 @@ func (p *TCompactProtocol) ReadStructEnd() error { } // Read a field header off the wire. -func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) { +func (p *TCompactProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) { t, err := p.readByteDirect() if err != nil { return @@ -397,7 +408,7 @@ func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16 modifier := int16((t & 0xf0) >> 4) if modifier == 0 { // not a delta. look ahead for the zigzag varint field id. - id, err = p.ReadI16() + id, err = p.ReadI16(ctx) if err != nil { return } @@ -423,12 +434,12 @@ func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16 return } -func (p *TCompactProtocol) ReadFieldEnd() error { return nil } +func (p *TCompactProtocol) ReadFieldEnd(ctx context.Context) error { return nil } // Read a map header off the wire. If the size is zero, skip reading the key // and value type. This means that 0-length maps will yield TMaps without the // "correct" types. -func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { +func (p *TCompactProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) { size32, e := p.readVarint32() if e != nil { err = NewTProtocolException(e) @@ -452,13 +463,13 @@ func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size return } -func (p *TCompactProtocol) ReadMapEnd() error { return nil } +func (p *TCompactProtocol) ReadMapEnd(ctx context.Context) error { return nil } // Read a list header off the wire. If the list size is 0-14, the size will // be packed into the element type header. If it's a longer list, the 4 MSB // of the element type header will be 0xF, and a varint will follow with the // true size. -func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) { +func (p *TCompactProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) { size_and_type, err := p.readByteDirect() if err != nil { return @@ -484,22 +495,22 @@ func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) return } -func (p *TCompactProtocol) ReadListEnd() error { return nil } +func (p *TCompactProtocol) ReadListEnd(ctx context.Context) error { return nil } // Read a set header off the wire. If the set size is 0-14, the size will // be packed into the element type header. If it's a longer set, the 4 MSB // of the element type header will be 0xF, and a varint will follow with the // true size. -func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) { - return p.ReadListBegin() +func (p *TCompactProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) { + return p.ReadListBegin(ctx) } -func (p *TCompactProtocol) ReadSetEnd() error { return nil } +func (p *TCompactProtocol) ReadSetEnd(ctx context.Context) error { return nil } // Read a boolean off the wire. If this is a boolean field, the value should // already have been read during readFieldBegin, so we'll just consume the // pre-stored value. Otherwise, read a byte. -func (p *TCompactProtocol) ReadBool() (value bool, err error) { +func (p *TCompactProtocol) ReadBool(ctx context.Context) (value bool, err error) { if p.boolValueIsNotNull { p.boolValueIsNotNull = false return p.boolValue, nil @@ -509,7 +520,7 @@ func (p *TCompactProtocol) ReadBool() (value bool, err error) { } // Read a single byte off the wire. Nothing interesting here. -func (p *TCompactProtocol) ReadByte() (int8, error) { +func (p *TCompactProtocol) ReadByte(ctx context.Context) (int8, error) { v, err := p.readByteDirect() if err != nil { return 0, NewTProtocolException(err) @@ -518,13 +529,13 @@ func (p *TCompactProtocol) ReadByte() (int8, error) { } // Read an i16 from the wire as a zigzag varint. -func (p *TCompactProtocol) ReadI16() (value int16, err error) { - v, err := p.ReadI32() +func (p *TCompactProtocol) ReadI16(ctx context.Context) (value int16, err error) { + v, err := p.ReadI32(ctx) return int16(v), err } // Read an i32 from the wire as a zigzag varint. -func (p *TCompactProtocol) ReadI32() (value int32, err error) { +func (p *TCompactProtocol) ReadI32(ctx context.Context) (value int32, err error) { v, e := p.readVarint32() if e != nil { return 0, NewTProtocolException(e) @@ -534,7 +545,7 @@ func (p *TCompactProtocol) ReadI32() (value int32, err error) { } // Read an i64 from the wire as a zigzag varint. -func (p *TCompactProtocol) ReadI64() (value int64, err error) { +func (p *TCompactProtocol) ReadI64(ctx context.Context) (value int64, err error) { v, e := p.readVarint64() if e != nil { return 0, NewTProtocolException(e) @@ -544,7 +555,7 @@ func (p *TCompactProtocol) ReadI64() (value int64, err error) { } // No magic here - just read a double off the wire. -func (p *TCompactProtocol) ReadDouble() (value float64, err error) { +func (p *TCompactProtocol) ReadDouble(ctx context.Context) (value float64, err error) { longBits := p.buffer[0:8] _, e := io.ReadFull(p.trans, longBits) if e != nil { @@ -554,7 +565,7 @@ func (p *TCompactProtocol) ReadDouble() (value float64, err error) { } // Reads a []byte (via readBinary), and then UTF-8 decodes it. -func (p *TCompactProtocol) ReadString() (value string, err error) { +func (p *TCompactProtocol) ReadString(ctx context.Context) (value string, err error) { length, e := p.readVarint32() if e != nil { return "", NewTProtocolException(e) @@ -577,7 +588,7 @@ func (p *TCompactProtocol) ReadString() (value string, err error) { } // Read a []byte from the wire. -func (p *TCompactProtocol) ReadBinary() (value []byte, err error) { +func (p *TCompactProtocol) ReadBinary(ctx context.Context) (value []byte, err error) { length, e := p.readVarint32() if e != nil { return nil, NewTProtocolException(e) @@ -598,8 +609,8 @@ func (p *TCompactProtocol) Flush(ctx context.Context) (err error) { return NewTProtocolException(p.trans.Flush(ctx)) } -func (p *TCompactProtocol) Skip(fieldType TType) (err error) { - return SkipDefaultDepth(p, fieldType) +func (p *TCompactProtocol) Skip(ctx context.Context, fieldType TType) (err error) { + return SkipDefaultDepth(ctx, p, fieldType) } func (p *TCompactProtocol) Transport() TTransport { diff --git a/lib/go/thrift/debug_protocol.go b/lib/go/thrift/debug_protocol.go index c33fba879ad..a0920031a65 100644 --- a/lib/go/thrift/debug_protocol.go +++ b/lib/go/thrift/debug_protocol.go @@ -70,214 +70,214 @@ func (tdp *TDebugProtocol) logf(format string, v ...interface{}) { fallbackLogger(tdp.Logger)(fmt.Sprintf(format, v...)) } -func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { - err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid) +func (tdp *TDebugProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error { + err := tdp.Delegate.WriteMessageBegin(ctx, name, typeId, seqid) tdp.logf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err) return err } -func (tdp *TDebugProtocol) WriteMessageEnd() error { - err := tdp.Delegate.WriteMessageEnd() +func (tdp *TDebugProtocol) WriteMessageEnd(ctx context.Context) error { + err := tdp.Delegate.WriteMessageEnd(ctx) tdp.logf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err) return err } -func (tdp *TDebugProtocol) WriteStructBegin(name string) error { - err := tdp.Delegate.WriteStructBegin(name) +func (tdp *TDebugProtocol) WriteStructBegin(ctx context.Context, name string) error { + err := tdp.Delegate.WriteStructBegin(ctx, name) tdp.logf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err) return err } -func (tdp *TDebugProtocol) WriteStructEnd() error { - err := tdp.Delegate.WriteStructEnd() +func (tdp *TDebugProtocol) WriteStructEnd(ctx context.Context) error { + err := tdp.Delegate.WriteStructEnd(ctx) tdp.logf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err) return err } -func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { - err := tdp.Delegate.WriteFieldBegin(name, typeId, id) +func (tdp *TDebugProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { + err := tdp.Delegate.WriteFieldBegin(ctx, name, typeId, id) tdp.logf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err) return err } -func (tdp *TDebugProtocol) WriteFieldEnd() error { - err := tdp.Delegate.WriteFieldEnd() +func (tdp *TDebugProtocol) WriteFieldEnd(ctx context.Context) error { + err := tdp.Delegate.WriteFieldEnd(ctx) tdp.logf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err) return err } -func (tdp *TDebugProtocol) WriteFieldStop() error { - err := tdp.Delegate.WriteFieldStop() +func (tdp *TDebugProtocol) WriteFieldStop(ctx context.Context) error { + err := tdp.Delegate.WriteFieldStop(ctx) tdp.logf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err) return err } -func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { - err := tdp.Delegate.WriteMapBegin(keyType, valueType, size) +func (tdp *TDebugProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { + err := tdp.Delegate.WriteMapBegin(ctx, keyType, valueType, size) tdp.logf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err) return err } -func (tdp *TDebugProtocol) WriteMapEnd() error { - err := tdp.Delegate.WriteMapEnd() +func (tdp *TDebugProtocol) WriteMapEnd(ctx context.Context) error { + err := tdp.Delegate.WriteMapEnd(ctx) tdp.logf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err) return err } -func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error { - err := tdp.Delegate.WriteListBegin(elemType, size) +func (tdp *TDebugProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { + err := tdp.Delegate.WriteListBegin(ctx, elemType, size) tdp.logf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) return err } -func (tdp *TDebugProtocol) WriteListEnd() error { - err := tdp.Delegate.WriteListEnd() +func (tdp *TDebugProtocol) WriteListEnd(ctx context.Context) error { + err := tdp.Delegate.WriteListEnd(ctx) tdp.logf("%sWriteListEnd() => %#v", tdp.LogPrefix, err) return err } -func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error { - err := tdp.Delegate.WriteSetBegin(elemType, size) +func (tdp *TDebugProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { + err := tdp.Delegate.WriteSetBegin(ctx, elemType, size) tdp.logf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err) return err } -func (tdp *TDebugProtocol) WriteSetEnd() error { - err := tdp.Delegate.WriteSetEnd() +func (tdp *TDebugProtocol) WriteSetEnd(ctx context.Context) error { + err := tdp.Delegate.WriteSetEnd(ctx) tdp.logf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err) return err } -func (tdp *TDebugProtocol) WriteBool(value bool) error { - err := tdp.Delegate.WriteBool(value) +func (tdp *TDebugProtocol) WriteBool(ctx context.Context, value bool) error { + err := tdp.Delegate.WriteBool(ctx, value) tdp.logf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) WriteByte(value int8) error { - err := tdp.Delegate.WriteByte(value) +func (tdp *TDebugProtocol) WriteByte(ctx context.Context, value int8) error { + err := tdp.Delegate.WriteByte(ctx, value) tdp.logf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) WriteI16(value int16) error { - err := tdp.Delegate.WriteI16(value) +func (tdp *TDebugProtocol) WriteI16(ctx context.Context, value int16) error { + err := tdp.Delegate.WriteI16(ctx, value) tdp.logf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) WriteI32(value int32) error { - err := tdp.Delegate.WriteI32(value) +func (tdp *TDebugProtocol) WriteI32(ctx context.Context, value int32) error { + err := tdp.Delegate.WriteI32(ctx, value) tdp.logf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) WriteI64(value int64) error { - err := tdp.Delegate.WriteI64(value) +func (tdp *TDebugProtocol) WriteI64(ctx context.Context, value int64) error { + err := tdp.Delegate.WriteI64(ctx, value) tdp.logf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) WriteDouble(value float64) error { - err := tdp.Delegate.WriteDouble(value) +func (tdp *TDebugProtocol) WriteDouble(ctx context.Context, value float64) error { + err := tdp.Delegate.WriteDouble(ctx, value) tdp.logf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) WriteString(value string) error { - err := tdp.Delegate.WriteString(value) +func (tdp *TDebugProtocol) WriteString(ctx context.Context, value string) error { + err := tdp.Delegate.WriteString(ctx, value) tdp.logf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) WriteBinary(value []byte) error { - err := tdp.Delegate.WriteBinary(value) +func (tdp *TDebugProtocol) WriteBinary(ctx context.Context, value []byte) error { + err := tdp.Delegate.WriteBinary(ctx, value) tdp.logf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err) return err } -func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) { - name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin() +func (tdp *TDebugProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) { + name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin(ctx) tdp.logf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err) return } -func (tdp *TDebugProtocol) ReadMessageEnd() (err error) { - err = tdp.Delegate.ReadMessageEnd() +func (tdp *TDebugProtocol) ReadMessageEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadMessageEnd(ctx) tdp.logf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err) return } -func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) { - name, err = tdp.Delegate.ReadStructBegin() +func (tdp *TDebugProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { + name, err = tdp.Delegate.ReadStructBegin(ctx) tdp.logf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err) return } -func (tdp *TDebugProtocol) ReadStructEnd() (err error) { - err = tdp.Delegate.ReadStructEnd() +func (tdp *TDebugProtocol) ReadStructEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadStructEnd(ctx) tdp.logf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err) return } -func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) { - name, typeId, id, err = tdp.Delegate.ReadFieldBegin() +func (tdp *TDebugProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) { + name, typeId, id, err = tdp.Delegate.ReadFieldBegin(ctx) tdp.logf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err) return } -func (tdp *TDebugProtocol) ReadFieldEnd() (err error) { - err = tdp.Delegate.ReadFieldEnd() +func (tdp *TDebugProtocol) ReadFieldEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadFieldEnd(ctx) tdp.logf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err) return } -func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { - keyType, valueType, size, err = tdp.Delegate.ReadMapBegin() +func (tdp *TDebugProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) { + keyType, valueType, size, err = tdp.Delegate.ReadMapBegin(ctx) tdp.logf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err) return } -func (tdp *TDebugProtocol) ReadMapEnd() (err error) { - err = tdp.Delegate.ReadMapEnd() +func (tdp *TDebugProtocol) ReadMapEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadMapEnd(ctx) tdp.logf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err) return } -func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) { - elemType, size, err = tdp.Delegate.ReadListBegin() +func (tdp *TDebugProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadListBegin(ctx) tdp.logf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err) return } -func (tdp *TDebugProtocol) ReadListEnd() (err error) { - err = tdp.Delegate.ReadListEnd() +func (tdp *TDebugProtocol) ReadListEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadListEnd(ctx) tdp.logf("%sReadListEnd() err=%#v", tdp.LogPrefix, err) return } -func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) { - elemType, size, err = tdp.Delegate.ReadSetBegin() +func (tdp *TDebugProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) { + elemType, size, err = tdp.Delegate.ReadSetBegin(ctx) tdp.logf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err) return } -func (tdp *TDebugProtocol) ReadSetEnd() (err error) { - err = tdp.Delegate.ReadSetEnd() +func (tdp *TDebugProtocol) ReadSetEnd(ctx context.Context) (err error) { + err = tdp.Delegate.ReadSetEnd(ctx) tdp.logf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err) return } -func (tdp *TDebugProtocol) ReadBool() (value bool, err error) { - value, err = tdp.Delegate.ReadBool() +func (tdp *TDebugProtocol) ReadBool(ctx context.Context) (value bool, err error) { + value, err = tdp.Delegate.ReadBool(ctx) tdp.logf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) ReadByte() (value int8, err error) { - value, err = tdp.Delegate.ReadByte() +func (tdp *TDebugProtocol) ReadByte(ctx context.Context) (value int8, err error) { + value, err = tdp.Delegate.ReadByte(ctx) tdp.logf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) ReadI16() (value int16, err error) { - value, err = tdp.Delegate.ReadI16() +func (tdp *TDebugProtocol) ReadI16(ctx context.Context) (value int16, err error) { + value, err = tdp.Delegate.ReadI16(ctx) tdp.logf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) ReadI32() (value int32, err error) { - value, err = tdp.Delegate.ReadI32() +func (tdp *TDebugProtocol) ReadI32(ctx context.Context) (value int32, err error) { + value, err = tdp.Delegate.ReadI32(ctx) tdp.logf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) ReadI64() (value int64, err error) { - value, err = tdp.Delegate.ReadI64() +func (tdp *TDebugProtocol) ReadI64(ctx context.Context) (value int64, err error) { + value, err = tdp.Delegate.ReadI64(ctx) tdp.logf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) { - value, err = tdp.Delegate.ReadDouble() +func (tdp *TDebugProtocol) ReadDouble(ctx context.Context) (value float64, err error) { + value, err = tdp.Delegate.ReadDouble(ctx) tdp.logf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) ReadString() (value string, err error) { - value, err = tdp.Delegate.ReadString() +func (tdp *TDebugProtocol) ReadString(ctx context.Context) (value string, err error) { + value, err = tdp.Delegate.ReadString(ctx) tdp.logf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) { - value, err = tdp.Delegate.ReadBinary() +func (tdp *TDebugProtocol) ReadBinary(ctx context.Context) (value []byte, err error) { + value, err = tdp.Delegate.ReadBinary(ctx) tdp.logf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err) return } -func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) { - err = tdp.Delegate.Skip(fieldType) +func (tdp *TDebugProtocol) Skip(ctx context.Context, fieldType TType) (err error) { + err = tdp.Delegate.Skip(ctx, fieldType) tdp.logf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err) return } diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go index 2ab82145ec0..e1203a868e1 100644 --- a/lib/go/thrift/deserializer.go +++ b/lib/go/thrift/deserializer.go @@ -20,6 +20,7 @@ package thrift import ( + "context" "sync" ) @@ -38,27 +39,27 @@ func NewTDeserializer() *TDeserializer { protocol} } -func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) { +func (t *TDeserializer) ReadString(ctx context.Context, msg TStruct, s string) (err error) { t.Transport.Reset() err = nil if _, err = t.Transport.Write([]byte(s)); err != nil { return } - if err = msg.Read(t.Protocol); err != nil { + if err = msg.Read(ctx, t.Protocol); err != nil { return } return } -func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) { +func (t *TDeserializer) Read(ctx context.Context, msg TStruct, b []byte) (err error) { t.Transport.Reset() err = nil if _, err = t.Transport.Write(b); err != nil { return } - if err = msg.Read(t.Protocol); err != nil { + if err = msg.Read(ctx, t.Protocol); err != nil { return } return @@ -85,14 +86,14 @@ func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool { } } -func (t *TDeserializerPool) ReadString(msg TStruct, s string) error { +func (t *TDeserializerPool) ReadString(ctx context.Context, msg TStruct, s string) error { d := t.pool.Get().(*TDeserializer) defer t.pool.Put(d) - return d.ReadString(msg, s) + return d.ReadString(ctx, msg, s) } -func (t *TDeserializerPool) Read(msg TStruct, b []byte) error { +func (t *TDeserializerPool) Read(ctx context.Context, msg TStruct, b []byte) error { d := t.pool.Get().(*TDeserializer) defer t.pool.Put(d) - return d.Read(msg, b) + return d.Read(ctx, msg, b) } diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go index 99deaf7dcff..9ee703645c0 100644 --- a/lib/go/thrift/header_protocol.go +++ b/lib/go/thrift/header_protocol.go @@ -95,106 +95,106 @@ func (p *THeaderProtocol) Flush(ctx context.Context) error { return p.transport.Flush(ctx) } -func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error { +func (p *THeaderProtocol) WriteMessageBegin(ctx context.Context, name string, typeID TMessageType, seqID int32) error { newProto, err := p.transport.Protocol().GetProtocol(p.transport) if err != nil { return err } p.protocol = newProto p.transport.SequenceID = seqID - return p.protocol.WriteMessageBegin(name, typeID, seqID) + return p.protocol.WriteMessageBegin(ctx, name, typeID, seqID) } -func (p *THeaderProtocol) WriteMessageEnd() error { - if err := p.protocol.WriteMessageEnd(); err != nil { +func (p *THeaderProtocol) WriteMessageEnd(ctx context.Context) error { + if err := p.protocol.WriteMessageEnd(ctx); err != nil { return err } return p.transport.Flush(context.Background()) } -func (p *THeaderProtocol) WriteStructBegin(name string) error { - return p.protocol.WriteStructBegin(name) +func (p *THeaderProtocol) WriteStructBegin(ctx context.Context, name string) error { + return p.protocol.WriteStructBegin(ctx, name) } -func (p *THeaderProtocol) WriteStructEnd() error { - return p.protocol.WriteStructEnd() +func (p *THeaderProtocol) WriteStructEnd(ctx context.Context) error { + return p.protocol.WriteStructEnd(ctx) } -func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error { - return p.protocol.WriteFieldBegin(name, typeID, id) +func (p *THeaderProtocol) WriteFieldBegin(ctx context.Context, name string, typeID TType, id int16) error { + return p.protocol.WriteFieldBegin(ctx, name, typeID, id) } -func (p *THeaderProtocol) WriteFieldEnd() error { - return p.protocol.WriteFieldEnd() +func (p *THeaderProtocol) WriteFieldEnd(ctx context.Context) error { + return p.protocol.WriteFieldEnd(ctx) } -func (p *THeaderProtocol) WriteFieldStop() error { - return p.protocol.WriteFieldStop() +func (p *THeaderProtocol) WriteFieldStop(ctx context.Context) error { + return p.protocol.WriteFieldStop(ctx) } -func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { - return p.protocol.WriteMapBegin(keyType, valueType, size) +func (p *THeaderProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { + return p.protocol.WriteMapBegin(ctx, keyType, valueType, size) } -func (p *THeaderProtocol) WriteMapEnd() error { - return p.protocol.WriteMapEnd() +func (p *THeaderProtocol) WriteMapEnd(ctx context.Context) error { + return p.protocol.WriteMapEnd(ctx) } -func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error { - return p.protocol.WriteListBegin(elemType, size) +func (p *THeaderProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { + return p.protocol.WriteListBegin(ctx, elemType, size) } -func (p *THeaderProtocol) WriteListEnd() error { - return p.protocol.WriteListEnd() +func (p *THeaderProtocol) WriteListEnd(ctx context.Context) error { + return p.protocol.WriteListEnd(ctx) } -func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error { - return p.protocol.WriteSetBegin(elemType, size) +func (p *THeaderProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { + return p.protocol.WriteSetBegin(ctx, elemType, size) } -func (p *THeaderProtocol) WriteSetEnd() error { - return p.protocol.WriteSetEnd() +func (p *THeaderProtocol) WriteSetEnd(ctx context.Context) error { + return p.protocol.WriteSetEnd(ctx) } -func (p *THeaderProtocol) WriteBool(value bool) error { - return p.protocol.WriteBool(value) +func (p *THeaderProtocol) WriteBool(ctx context.Context, value bool) error { + return p.protocol.WriteBool(ctx, value) } -func (p *THeaderProtocol) WriteByte(value int8) error { - return p.protocol.WriteByte(value) +func (p *THeaderProtocol) WriteByte(ctx context.Context, value int8) error { + return p.protocol.WriteByte(ctx, value) } -func (p *THeaderProtocol) WriteI16(value int16) error { - return p.protocol.WriteI16(value) +func (p *THeaderProtocol) WriteI16(ctx context.Context, value int16) error { + return p.protocol.WriteI16(ctx, value) } -func (p *THeaderProtocol) WriteI32(value int32) error { - return p.protocol.WriteI32(value) +func (p *THeaderProtocol) WriteI32(ctx context.Context, value int32) error { + return p.protocol.WriteI32(ctx, value) } -func (p *THeaderProtocol) WriteI64(value int64) error { - return p.protocol.WriteI64(value) +func (p *THeaderProtocol) WriteI64(ctx context.Context, value int64) error { + return p.protocol.WriteI64(ctx, value) } -func (p *THeaderProtocol) WriteDouble(value float64) error { - return p.protocol.WriteDouble(value) +func (p *THeaderProtocol) WriteDouble(ctx context.Context, value float64) error { + return p.protocol.WriteDouble(ctx, value) } -func (p *THeaderProtocol) WriteString(value string) error { - return p.protocol.WriteString(value) +func (p *THeaderProtocol) WriteString(ctx context.Context, value string) error { + return p.protocol.WriteString(ctx, value) } -func (p *THeaderProtocol) WriteBinary(value []byte) error { - return p.protocol.WriteBinary(value) +func (p *THeaderProtocol) WriteBinary(ctx context.Context, value []byte) error { + return p.protocol.WriteBinary(ctx, value) } // ReadFrame calls underlying THeaderTransport's ReadFrame function. -func (p *THeaderProtocol) ReadFrame() error { - return p.transport.ReadFrame() +func (p *THeaderProtocol) ReadFrame(ctx context.Context) error { + return p.transport.ReadFrame(ctx) } -func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) { - if err = p.transport.ReadFrame(); err != nil { +func (p *THeaderProtocol) ReadMessageBegin(ctx context.Context) (name string, typeID TMessageType, seqID int32, err error) { + if err = p.transport.ReadFrame(ctx); err != nil { return } @@ -205,103 +205,103 @@ func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, if !ok { return } - if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil { + if e := p.protocol.WriteMessageBegin(ctx, "", EXCEPTION, seqID); e != nil { return } - if e := tAppExc.Write(p.protocol); e != nil { + if e := tAppExc.Write(ctx, p.protocol); e != nil { return } - if e := p.protocol.WriteMessageEnd(); e != nil { + if e := p.protocol.WriteMessageEnd(ctx); e != nil { return } - if e := p.transport.Flush(context.Background()); e != nil { + if e := p.transport.Flush(ctx); e != nil { return } return } p.protocol = newProto - return p.protocol.ReadMessageBegin() + return p.protocol.ReadMessageBegin(ctx) } -func (p *THeaderProtocol) ReadMessageEnd() error { - return p.protocol.ReadMessageEnd() +func (p *THeaderProtocol) ReadMessageEnd(ctx context.Context) error { + return p.protocol.ReadMessageEnd(ctx) } -func (p *THeaderProtocol) ReadStructBegin() (name string, err error) { - return p.protocol.ReadStructBegin() +func (p *THeaderProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { + return p.protocol.ReadStructBegin(ctx) } -func (p *THeaderProtocol) ReadStructEnd() error { - return p.protocol.ReadStructEnd() +func (p *THeaderProtocol) ReadStructEnd(ctx context.Context) error { + return p.protocol.ReadStructEnd(ctx) } -func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) { - return p.protocol.ReadFieldBegin() +func (p *THeaderProtocol) ReadFieldBegin(ctx context.Context) (name string, typeID TType, id int16, err error) { + return p.protocol.ReadFieldBegin(ctx) } -func (p *THeaderProtocol) ReadFieldEnd() error { - return p.protocol.ReadFieldEnd() +func (p *THeaderProtocol) ReadFieldEnd(ctx context.Context) error { + return p.protocol.ReadFieldEnd(ctx) } -func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) { - return p.protocol.ReadMapBegin() +func (p *THeaderProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) { + return p.protocol.ReadMapBegin(ctx) } -func (p *THeaderProtocol) ReadMapEnd() error { - return p.protocol.ReadMapEnd() +func (p *THeaderProtocol) ReadMapEnd(ctx context.Context) error { + return p.protocol.ReadMapEnd(ctx) } -func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) { - return p.protocol.ReadListBegin() +func (p *THeaderProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) { + return p.protocol.ReadListBegin(ctx) } -func (p *THeaderProtocol) ReadListEnd() error { - return p.protocol.ReadListEnd() +func (p *THeaderProtocol) ReadListEnd(ctx context.Context) error { + return p.protocol.ReadListEnd(ctx) } -func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) { - return p.protocol.ReadSetBegin() +func (p *THeaderProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) { + return p.protocol.ReadSetBegin(ctx) } -func (p *THeaderProtocol) ReadSetEnd() error { - return p.protocol.ReadSetEnd() +func (p *THeaderProtocol) ReadSetEnd(ctx context.Context) error { + return p.protocol.ReadSetEnd(ctx) } -func (p *THeaderProtocol) ReadBool() (value bool, err error) { - return p.protocol.ReadBool() +func (p *THeaderProtocol) ReadBool(ctx context.Context) (value bool, err error) { + return p.protocol.ReadBool(ctx) } -func (p *THeaderProtocol) ReadByte() (value int8, err error) { - return p.protocol.ReadByte() +func (p *THeaderProtocol) ReadByte(ctx context.Context) (value int8, err error) { + return p.protocol.ReadByte(ctx) } -func (p *THeaderProtocol) ReadI16() (value int16, err error) { - return p.protocol.ReadI16() +func (p *THeaderProtocol) ReadI16(ctx context.Context) (value int16, err error) { + return p.protocol.ReadI16(ctx) } -func (p *THeaderProtocol) ReadI32() (value int32, err error) { - return p.protocol.ReadI32() +func (p *THeaderProtocol) ReadI32(ctx context.Context) (value int32, err error) { + return p.protocol.ReadI32(ctx) } -func (p *THeaderProtocol) ReadI64() (value int64, err error) { - return p.protocol.ReadI64() +func (p *THeaderProtocol) ReadI64(ctx context.Context) (value int64, err error) { + return p.protocol.ReadI64(ctx) } -func (p *THeaderProtocol) ReadDouble() (value float64, err error) { - return p.protocol.ReadDouble() +func (p *THeaderProtocol) ReadDouble(ctx context.Context) (value float64, err error) { + return p.protocol.ReadDouble(ctx) } -func (p *THeaderProtocol) ReadString() (value string, err error) { - return p.protocol.ReadString() +func (p *THeaderProtocol) ReadString(ctx context.Context) (value string, err error) { + return p.protocol.ReadString(ctx) } -func (p *THeaderProtocol) ReadBinary() (value []byte, err error) { - return p.protocol.ReadBinary() +func (p *THeaderProtocol) ReadBinary(ctx context.Context) (value []byte, err error) { + return p.protocol.ReadBinary(ctx) } -func (p *THeaderProtocol) Skip(fieldType TType) error { - return p.protocol.Skip(fieldType) +func (p *THeaderProtocol) Skip(ctx context.Context, fieldType TType) error { + return p.protocol.Skip(ctx, fieldType) } // GetResponseHeadersFromClient is a helper function to get the read THeaderMap diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go index 85d296d6b4c..3affeb59310 100644 --- a/lib/go/thrift/header_transport.go +++ b/lib/go/thrift/header_transport.go @@ -297,18 +297,34 @@ func (t *THeaderTransport) IsOpen() bool { // ReadFrame tries to read the frame header, guess the client type, and handle // unframed clients. -func (t *THeaderTransport) ReadFrame() error { +func (t *THeaderTransport) ReadFrame(ctx context.Context) error { if !t.needReadFrame() { // No need to read frame, skipping. return nil } + // Peek and handle the first 32 bits. // They could either be the length field of a framed message, // or the first bytes of an unframed message. - buf, err := t.reader.Peek(size32) + var buf []byte + var err error + // This is also usually the first read from a connection, + // so handle retries around socket timeouts. + _, deadlineSet := ctx.Deadline() + for { + buf, err = t.reader.Peek(size32) + if deadlineSet && isTimeoutError(err) && ctx.Err() == nil { + // This is I/O timeout and we still have time, + // continue trying + continue + } + // For anything else, do not retry + break + } if err != nil { return err } + frameSize := binary.BigEndian.Uint32(buf) if frameSize&VERSION_MASK == VERSION_1 { t.clientType = clientUnframedBinary @@ -341,7 +357,7 @@ func (t *THeaderTransport) ReadFrame() error { version := binary.BigEndian.Uint32(buf) if version&THeaderHeaderMask == THeaderHeaderMagic { t.clientType = clientHeaders - return t.parseHeaders(frameSize) + return t.parseHeaders(ctx, frameSize) } if version&VERSION_MASK == VERSION_1 { t.clientType = clientFramedBinary @@ -371,7 +387,7 @@ func (t *THeaderTransport) endOfFrame() error { return t.frameReader.Close() } -func (t *THeaderTransport) parseHeaders(frameSize uint32) error { +func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) error { if t.clientType != clientHeaders { return nil } @@ -451,11 +467,11 @@ func (t *THeaderTransport) parseHeaders(frameSize uint32) error { return err } for i := 0; i < int(count); i++ { - key, err := hp.ReadString() + key, err := hp.ReadString(ctx) if err != nil { return err } - value, err := hp.ReadString() + value, err := hp.ReadString(ctx) if err != nil { return err } @@ -485,7 +501,7 @@ func (t *THeaderTransport) needReadFrame() bool { } func (t *THeaderTransport) Read(p []byte) (read int, err error) { - err = t.ReadFrame() + err = t.ReadFrame(context.Background()) if err != nil { return } @@ -557,10 +573,10 @@ func (t *THeaderTransport) Flush(ctx context.Context) error { return NewTTransportExceptionFromError(err) } for key, value := range t.writeHeaders { - if err := hp.WriteString(key); err != nil { + if err := hp.WriteString(ctx, key); err != nil { return NewTTransportExceptionFromError(err) } - if err := hp.WriteString(value); err != nil { + if err := hp.WriteString(ctx, value); err != nil { return NewTTransportExceptionFromError(err) } } diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go index e3ae41b02d8..cee1cadc370 100644 --- a/lib/go/thrift/header_transport_test.go +++ b/lib/go/thrift/header_transport_test.go @@ -78,17 +78,17 @@ func TestTHeaderHeadersReadWrite(t *testing.T) { // Read // Make sure multiple calls to ReadFrame is fine. - if err := reader.ReadFrame(); err != nil { + if err := reader.ReadFrame(context.Background()); err != nil { t.Errorf("reader.ReadFrame returned error: %v", err) } - if err := reader.ReadFrame(); err != nil { + if err := reader.ReadFrame(context.Background()); err != nil { t.Errorf("reader.ReadFrame returned error: %v", err) } read, err := ioutil.ReadAll(reader) if err != nil { t.Errorf("Read returned error: %v", err) } - if err := reader.ReadFrame(); err != nil && err != io.EOF { + if err := reader.ReadFrame(context.Background()); err != nil && err != io.EOF { t.Errorf("reader.ReadFrame returned error: %v", err) } if string(read) != payload1+payload2 { diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go index 800ac22c7bc..9a9328dc7dc 100644 --- a/lib/go/thrift/json_protocol.go +++ b/lib/go/thrift/json_protocol.go @@ -57,43 +57,43 @@ func NewTJSONProtocolFactory() *TJSONProtocolFactory { return &TJSONProtocolFactory{} } -func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { +func (p *TJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error { p.resetContextStack() // THRIFT-3735 if e := p.OutputListBegin(); e != nil { return e } - if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil { + if e := p.WriteI32(ctx, THRIFT_JSON_PROTOCOL_VERSION); e != nil { return e } - if e := p.WriteString(name); e != nil { + if e := p.WriteString(ctx, name); e != nil { return e } - if e := p.WriteByte(int8(typeId)); e != nil { + if e := p.WriteByte(ctx, int8(typeId)); e != nil { return e } - if e := p.WriteI32(seqId); e != nil { + if e := p.WriteI32(ctx, seqId); e != nil { return e } return nil } -func (p *TJSONProtocol) WriteMessageEnd() error { +func (p *TJSONProtocol) WriteMessageEnd(ctx context.Context) error { return p.OutputListEnd() } -func (p *TJSONProtocol) WriteStructBegin(name string) error { +func (p *TJSONProtocol) WriteStructBegin(ctx context.Context, name string) error { if e := p.OutputObjectBegin(); e != nil { return e } return nil } -func (p *TJSONProtocol) WriteStructEnd() error { +func (p *TJSONProtocol) WriteStructEnd(ctx context.Context) error { return p.OutputObjectEnd() } -func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { - if e := p.WriteI16(id); e != nil { +func (p *TJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { + if e := p.WriteI16(ctx, id); e != nil { return e } if e := p.OutputObjectBegin(); e != nil { @@ -103,19 +103,19 @@ func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) err if e1 != nil { return e1 } - if e := p.WriteString(s); e != nil { + if e := p.WriteString(ctx, s); e != nil { return e } return nil } -func (p *TJSONProtocol) WriteFieldEnd() error { +func (p *TJSONProtocol) WriteFieldEnd(ctx context.Context) error { return p.OutputObjectEnd() } -func (p *TJSONProtocol) WriteFieldStop() error { return nil } +func (p *TJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil } -func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { +func (p *TJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } @@ -123,77 +123,77 @@ func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) if e1 != nil { return e1 } - if e := p.WriteString(s); e != nil { + if e := p.WriteString(ctx, s); e != nil { return e } s, e1 = p.TypeIdToString(valueType) if e1 != nil { return e1 } - if e := p.WriteString(s); e != nil { + if e := p.WriteString(ctx, s); e != nil { return e } - if e := p.WriteI64(int64(size)); e != nil { + if e := p.WriteI64(ctx, int64(size)); e != nil { return e } return p.OutputObjectBegin() } -func (p *TJSONProtocol) WriteMapEnd() error { +func (p *TJSONProtocol) WriteMapEnd(ctx context.Context) error { if e := p.OutputObjectEnd(); e != nil { return e } return p.OutputListEnd() } -func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error { +func (p *TJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } -func (p *TJSONProtocol) WriteListEnd() error { +func (p *TJSONProtocol) WriteListEnd(ctx context.Context) error { return p.OutputListEnd() } -func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error { +func (p *TJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } -func (p *TJSONProtocol) WriteSetEnd() error { +func (p *TJSONProtocol) WriteSetEnd(ctx context.Context) error { return p.OutputListEnd() } -func (p *TJSONProtocol) WriteBool(b bool) error { +func (p *TJSONProtocol) WriteBool(ctx context.Context, b bool) error { if b { - return p.WriteI32(1) + return p.WriteI32(ctx, 1) } - return p.WriteI32(0) + return p.WriteI32(ctx, 0) } -func (p *TJSONProtocol) WriteByte(b int8) error { - return p.WriteI32(int32(b)) +func (p *TJSONProtocol) WriteByte(ctx context.Context, b int8) error { + return p.WriteI32(ctx, int32(b)) } -func (p *TJSONProtocol) WriteI16(v int16) error { - return p.WriteI32(int32(v)) +func (p *TJSONProtocol) WriteI16(ctx context.Context, v int16) error { + return p.WriteI32(ctx, int32(v)) } -func (p *TJSONProtocol) WriteI32(v int32) error { +func (p *TJSONProtocol) WriteI32(ctx context.Context, v int32) error { return p.OutputI64(int64(v)) } -func (p *TJSONProtocol) WriteI64(v int64) error { +func (p *TJSONProtocol) WriteI64(ctx context.Context, v int64) error { return p.OutputI64(int64(v)) } -func (p *TJSONProtocol) WriteDouble(v float64) error { +func (p *TJSONProtocol) WriteDouble(ctx context.Context, v float64) error { return p.OutputF64(v) } -func (p *TJSONProtocol) WriteString(v string) error { +func (p *TJSONProtocol) WriteString(ctx context.Context, v string) error { return p.OutputString(v) } -func (p *TJSONProtocol) WriteBinary(v []byte) error { +func (p *TJSONProtocol) WriteBinary(ctx context.Context, v []byte) error { // JSON library only takes in a string, // not an arbitrary byte array, to ensure bytes are transmitted // efficiently we must convert this into a valid JSON string @@ -219,12 +219,12 @@ func (p *TJSONProtocol) WriteBinary(v []byte) error { } // Reading methods. -func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { +func (p *TJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) { p.resetContextStack() // THRIFT-3735 if isNull, err := p.ParseListBegin(); isNull || err != nil { return name, typeId, seqId, err } - version, err := p.ReadI32() + version, err := p.ReadI32(ctx) if err != nil { return name, typeId, seqId, err } @@ -233,47 +233,47 @@ func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, se return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e) } - if name, err = p.ReadString(); err != nil { + if name, err = p.ReadString(ctx); err != nil { return name, typeId, seqId, err } - bTypeId, err := p.ReadByte() + bTypeId, err := p.ReadByte(ctx) typeId = TMessageType(bTypeId) if err != nil { return name, typeId, seqId, err } - if seqId, err = p.ReadI32(); err != nil { + if seqId, err = p.ReadI32(ctx); err != nil { return name, typeId, seqId, err } return name, typeId, seqId, nil } -func (p *TJSONProtocol) ReadMessageEnd() error { +func (p *TJSONProtocol) ReadMessageEnd(ctx context.Context) error { err := p.ParseListEnd() return err } -func (p *TJSONProtocol) ReadStructBegin() (name string, err error) { +func (p *TJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { _, err = p.ParseObjectStart() return "", err } -func (p *TJSONProtocol) ReadStructEnd() error { +func (p *TJSONProtocol) ReadStructEnd(ctx context.Context) error { return p.ParseObjectEnd() } -func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { +func (p *TJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) { b, _ := p.reader.Peek(1) if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] { return "", STOP, -1, nil } - fieldId, err := p.ReadI16() + fieldId, err := p.ReadI16(ctx) if err != nil { return "", STOP, fieldId, err } if _, err = p.ParseObjectStart(); err != nil { return "", STOP, fieldId, err } - sType, err := p.ReadString() + sType, err := p.ReadString(ctx) if err != nil { return "", STOP, fieldId, err } @@ -281,17 +281,17 @@ func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { return "", fType, fieldId, err } -func (p *TJSONProtocol) ReadFieldEnd() error { +func (p *TJSONProtocol) ReadFieldEnd(ctx context.Context) error { return p.ParseObjectEnd() } -func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) { +func (p *TJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, VOID, 0, e } // read keyType - sKeyType, e := p.ReadString() + sKeyType, e := p.ReadString(ctx) if e != nil { return keyType, valueType, size, e } @@ -301,7 +301,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int } // read valueType - sValueType, e := p.ReadString() + sValueType, e := p.ReadString(ctx) if e != nil { return keyType, valueType, size, e } @@ -311,7 +311,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int } // read size - iSize, e := p.ReadI64() + iSize, e := p.ReadI64(ctx) if e != nil { return keyType, valueType, size, e } @@ -321,7 +321,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int return keyType, valueType, size, e } -func (p *TJSONProtocol) ReadMapEnd() error { +func (p *TJSONProtocol) ReadMapEnd(ctx context.Context) error { e := p.ParseObjectEnd() if e != nil { return e @@ -329,53 +329,53 @@ func (p *TJSONProtocol) ReadMapEnd() error { return p.ParseListEnd() } -func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { +func (p *TJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) { return p.ParseElemListBegin() } -func (p *TJSONProtocol) ReadListEnd() error { +func (p *TJSONProtocol) ReadListEnd(ctx context.Context) error { return p.ParseListEnd() } -func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { +func (p *TJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) { return p.ParseElemListBegin() } -func (p *TJSONProtocol) ReadSetEnd() error { +func (p *TJSONProtocol) ReadSetEnd(ctx context.Context) error { return p.ParseListEnd() } -func (p *TJSONProtocol) ReadBool() (bool, error) { - value, err := p.ReadI32() +func (p *TJSONProtocol) ReadBool(ctx context.Context) (bool, error) { + value, err := p.ReadI32(ctx) return (value != 0), err } -func (p *TJSONProtocol) ReadByte() (int8, error) { - v, err := p.ReadI64() +func (p *TJSONProtocol) ReadByte(ctx context.Context) (int8, error) { + v, err := p.ReadI64(ctx) return int8(v), err } -func (p *TJSONProtocol) ReadI16() (int16, error) { - v, err := p.ReadI64() +func (p *TJSONProtocol) ReadI16(ctx context.Context) (int16, error) { + v, err := p.ReadI64(ctx) return int16(v), err } -func (p *TJSONProtocol) ReadI32() (int32, error) { - v, err := p.ReadI64() +func (p *TJSONProtocol) ReadI32(ctx context.Context) (int32, error) { + v, err := p.ReadI64(ctx) return int32(v), err } -func (p *TJSONProtocol) ReadI64() (int64, error) { +func (p *TJSONProtocol) ReadI64(ctx context.Context) (int64, error) { v, _, err := p.ParseI64() return v, err } -func (p *TJSONProtocol) ReadDouble() (float64, error) { +func (p *TJSONProtocol) ReadDouble(ctx context.Context) (float64, error) { v, _, err := p.ParseF64() return v, err } -func (p *TJSONProtocol) ReadString() (string, error) { +func (p *TJSONProtocol) ReadString(ctx context.Context) (string, error) { var v string if err := p.ParsePreValue(); err != nil { return v, err @@ -405,7 +405,7 @@ func (p *TJSONProtocol) ReadString() (string, error) { return v, p.ParsePostValue() } -func (p *TJSONProtocol) ReadBinary() ([]byte, error) { +func (p *TJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) { var v []byte if err := p.ParsePreValue(); err != nil { return nil, err @@ -444,8 +444,8 @@ func (p *TJSONProtocol) Flush(ctx context.Context) (err error) { return NewTProtocolException(err) } -func (p *TJSONProtocol) Skip(fieldType TType) (err error) { - return SkipDefaultDepth(p, fieldType) +func (p *TJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) { + return SkipDefaultDepth(ctx, p, fieldType) } func (p *TJSONProtocol) Transport() TTransport { @@ -460,10 +460,10 @@ func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error { if e1 != nil { return e1 } - if e := p.WriteString(s); e != nil { + if e := p.OutputString(s); e != nil { return e } - if e := p.WriteI64(int64(size)); e != nil { + if e := p.OutputI64(int64(size)); e != nil { return e } return nil @@ -473,7 +473,11 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, 0, e } - sElemType, err := p.ReadString() + // We don't really use the ctx in ReadString implementation, + // so this is safe for now. + // We might want to add context to ParseElemListBegin if we start to use + // ctx in ReadString implementation in the future. + sElemType, err := p.ReadString(context.Background()) if err != nil { return VOID, size, err } @@ -481,7 +485,7 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) if err != nil { return elemType, size, err } - nSize, err2 := p.ReadI64() + nSize, _, err2 := p.ParseI64() size = int(nSize) return elemType, size, err2 } @@ -490,7 +494,11 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, 0, e } - sElemType, err := p.ReadString() + // We don't really use the ctx in ReadString implementation, + // so this is safe for now. + // We might want to add context to ParseElemListBegin if we start to use + // ctx in ReadString implementation in the future. + sElemType, err := p.ReadString(context.Background()) if err != nil { return VOID, size, err } @@ -498,7 +506,7 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) if err != nil { return elemType, size, err } - nSize, err2 := p.ReadI64() + nSize, _, err2 := p.ParseI64() size = int(nSize) return elemType, size, err2 } diff --git a/lib/go/thrift/json_protocol_test.go b/lib/go/thrift/json_protocol_test.go index 59c4d64a260..07afa9683b7 100644 --- a/lib/go/thrift/json_protocol_test.go +++ b/lib/go/thrift/json_protocol_test.go @@ -34,7 +34,7 @@ func TestWriteJSONProtocolBool(t *testing.T) { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range BOOL_VALUES { - if e := p.WriteBool(value); e != nil { + if e := p.WriteBool(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -71,7 +71,7 @@ func TestReadJSONProtocolBool(t *testing.T) { } trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadBool() + v, e := p.ReadBool(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -92,7 +92,7 @@ func TestWriteJSONProtocolByte(t *testing.T) { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range BYTE_VALUES { - if e := p.WriteByte(value); e != nil { + if e := p.WriteByte(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -119,7 +119,7 @@ func TestReadJSONProtocolByte(t *testing.T) { trans.WriteString(strconv.Itoa(int(value))) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadByte() + v, e := p.ReadByte(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -139,7 +139,7 @@ func TestWriteJSONProtocolI16(t *testing.T) { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range INT16_VALUES { - if e := p.WriteI16(value); e != nil { + if e := p.WriteI16(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -166,7 +166,7 @@ func TestReadJSONProtocolI16(t *testing.T) { trans.WriteString(strconv.Itoa(int(value))) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI16() + v, e := p.ReadI16(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -186,7 +186,7 @@ func TestWriteJSONProtocolI32(t *testing.T) { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range INT32_VALUES { - if e := p.WriteI32(value); e != nil { + if e := p.WriteI32(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -213,7 +213,7 @@ func TestReadJSONProtocolI32(t *testing.T) { trans.WriteString(strconv.Itoa(int(value))) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI32() + v, e := p.ReadI32(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -233,7 +233,7 @@ func TestWriteJSONProtocolI64(t *testing.T) { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range INT64_VALUES { - if e := p.WriteI64(value); e != nil { + if e := p.WriteI64(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -260,7 +260,7 @@ func TestReadJSONProtocolI64(t *testing.T) { trans.WriteString(strconv.FormatInt(value, 10)) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI64() + v, e := p.ReadI64(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -280,7 +280,7 @@ func TestWriteJSONProtocolDouble(t *testing.T) { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range DOUBLE_VALUES { - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -322,7 +322,7 @@ func TestReadJSONProtocolDouble(t *testing.T) { trans.WriteString(n.String()) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadDouble() + v, e := p.ReadDouble(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -356,7 +356,7 @@ func TestWriteJSONProtocolString(t *testing.T) { trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) for _, value := range STRING_VALUES { - if e := p.WriteString(value); e != nil { + if e := p.WriteString(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -383,7 +383,7 @@ func TestReadJSONProtocolString(t *testing.T) { trans.WriteString(jsonQuote(value)) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadString() + v, e := p.ReadString(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -407,7 +407,7 @@ func TestWriteJSONProtocolBinary(t *testing.T) { b64String := string(b64value) trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) - if e := p.WriteBinary(value); e != nil { + if e := p.WriteBinary(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -418,7 +418,7 @@ func TestWriteJSONProtocolBinary(t *testing.T) { if s != expectedString { t.Fatalf("Bad value for %s %v\n wrote: \"%v\"\nexpected: \"%v\"", thetype, value, s, expectedString) } - v1, err := p.ReadBinary() + v1, err := p.ReadBinary(context.Background()) if err != nil { t.Fatalf("Unable to read binary: %s", err.Error()) } @@ -444,7 +444,7 @@ func TestReadJSONProtocolBinary(t *testing.T) { trans.WriteString(jsonQuote(b64String)) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadBinary() + v, e := p.ReadBinary(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -468,13 +468,13 @@ func TestWriteJSONProtocolList(t *testing.T) { thetype := "list" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) - p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + p.WriteListBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } - p.WriteListEnd() + p.WriteListEnd(context.Background()) if e := p.Flush(context.Background()); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } @@ -522,13 +522,13 @@ func TestWriteJSONProtocolSet(t *testing.T) { thetype := "set" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) - p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + p.WriteSetBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } - p.WriteSetEnd() + p.WriteSetEnd(context.Background()) if e := p.Flush(context.Background()); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } @@ -576,16 +576,16 @@ func TestWriteJSONProtocolMap(t *testing.T) { thetype := "map" trans := NewTMemoryBuffer() p := NewTJSONProtocol(trans) - p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) + p.WriteMapBegin(context.Background(), TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) for k, value := range DOUBLE_VALUES { - if e := p.WriteI32(int32(k)); e != nil { + if e := p.WriteI32(context.Background(), int32(k)); e != nil { t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) } - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) } } - p.WriteMapEnd() + p.WriteMapEnd(context.Background()) if e := p.Flush(context.Background()); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } @@ -593,7 +593,7 @@ func TestWriteJSONProtocolMap(t *testing.T) { if str[0] != '[' || str[len(str)-1] != ']' { t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES) } - expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin() + expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin(context.Background()) if err != nil { t.Fatalf("Error while reading map begin: %s", err.Error()) } @@ -607,14 +607,14 @@ func TestWriteJSONProtocolMap(t *testing.T) { t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", expectedSize) } for k, value := range DOUBLE_VALUES { - ik, err := p.ReadI32() + ik, err := p.ReadI32(context.Background()) if err != nil { t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, ik, string(k), err.Error()) } if int(ik) != k { t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v", thetype, k, ik, k) } - dv, err := p.ReadDouble() + dv, err := p.ReadDouble(context.Background()) if err != nil { t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, dv, value, err.Error()) } @@ -642,7 +642,7 @@ func TestWriteJSONProtocolMap(t *testing.T) { } } } - err = p.ReadMapEnd() + err = p.ReadMapEnd(context.Background()) if err != nil { t.Fatalf("Error while reading map end: %s", err.Error()) } diff --git a/lib/go/thrift/multiplexed_protocol.go b/lib/go/thrift/multiplexed_protocol.go index 9db59c4c91c..2f7997e7795 100644 --- a/lib/go/thrift/multiplexed_protocol.go +++ b/lib/go/thrift/multiplexed_protocol.go @@ -68,11 +68,11 @@ func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplex } } -func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { +func (t *TMultiplexedProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error { if typeId == CALL || typeId == ONEWAY { - return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid) + return t.TProtocol.WriteMessageBegin(ctx, t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid) } else { - return t.TProtocol.WriteMessageBegin(name, typeId, seqid) + return t.TProtocol.WriteMessageBegin(ctx, name, typeId, seqid) } } @@ -190,7 +190,7 @@ func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProces } func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) { - name, typeId, seqid, err := in.ReadMessageBegin() + name, typeId, seqid, err := in.ReadMessageBegin(ctx) if err != nil { return false, err } @@ -226,6 +226,6 @@ func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageTy return &storedMessageProtocol{protocol, name, typeId, seqid} } -func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) { +func (s *storedMessageProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) { return s.name, s.typeId, s.seqid, nil } diff --git a/lib/go/thrift/protocol.go b/lib/go/thrift/protocol.go index 2e6bc4b1619..0a69bd4162d 100644 --- a/lib/go/thrift/protocol.go +++ b/lib/go/thrift/protocol.go @@ -31,50 +31,50 @@ const ( ) type TProtocol interface { - WriteMessageBegin(name string, typeId TMessageType, seqid int32) error - WriteMessageEnd() error - WriteStructBegin(name string) error - WriteStructEnd() error - WriteFieldBegin(name string, typeId TType, id int16) error - WriteFieldEnd() error - WriteFieldStop() error - WriteMapBegin(keyType TType, valueType TType, size int) error - WriteMapEnd() error - WriteListBegin(elemType TType, size int) error - WriteListEnd() error - WriteSetBegin(elemType TType, size int) error - WriteSetEnd() error - WriteBool(value bool) error - WriteByte(value int8) error - WriteI16(value int16) error - WriteI32(value int32) error - WriteI64(value int64) error - WriteDouble(value float64) error - WriteString(value string) error - WriteBinary(value []byte) error + WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error + WriteMessageEnd(ctx context.Context) error + WriteStructBegin(ctx context.Context, name string) error + WriteStructEnd(ctx context.Context) error + WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error + WriteFieldEnd(ctx context.Context) error + WriteFieldStop(ctx context.Context) error + WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error + WriteMapEnd(ctx context.Context) error + WriteListBegin(ctx context.Context, elemType TType, size int) error + WriteListEnd(ctx context.Context) error + WriteSetBegin(ctx context.Context, elemType TType, size int) error + WriteSetEnd(ctx context.Context) error + WriteBool(ctx context.Context, value bool) error + WriteByte(ctx context.Context, value int8) error + WriteI16(ctx context.Context, value int16) error + WriteI32(ctx context.Context, value int32) error + WriteI64(ctx context.Context, value int64) error + WriteDouble(ctx context.Context, value float64) error + WriteString(ctx context.Context, value string) error + WriteBinary(ctx context.Context, value []byte) error - ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) - ReadMessageEnd() error - ReadStructBegin() (name string, err error) - ReadStructEnd() error - ReadFieldBegin() (name string, typeId TType, id int16, err error) - ReadFieldEnd() error - ReadMapBegin() (keyType TType, valueType TType, size int, err error) - ReadMapEnd() error - ReadListBegin() (elemType TType, size int, err error) - ReadListEnd() error - ReadSetBegin() (elemType TType, size int, err error) - ReadSetEnd() error - ReadBool() (value bool, err error) - ReadByte() (value int8, err error) - ReadI16() (value int16, err error) - ReadI32() (value int32, err error) - ReadI64() (value int64, err error) - ReadDouble() (value float64, err error) - ReadString() (value string, err error) - ReadBinary() (value []byte, err error) + ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) + ReadMessageEnd(ctx context.Context) error + ReadStructBegin(ctx context.Context) (name string, err error) + ReadStructEnd(ctx context.Context) error + ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) + ReadFieldEnd(ctx context.Context) error + ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) + ReadMapEnd(ctx context.Context) error + ReadListBegin(ctx context.Context) (elemType TType, size int, err error) + ReadListEnd(ctx context.Context) error + ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) + ReadSetEnd(ctx context.Context) error + ReadBool(ctx context.Context) (value bool, err error) + ReadByte(ctx context.Context) (value int8, err error) + ReadI16(ctx context.Context) (value int16, err error) + ReadI32(ctx context.Context) (value int32, err error) + ReadI64(ctx context.Context) (value int64, err error) + ReadDouble(ctx context.Context) (value float64, err error) + ReadString(ctx context.Context) (value string, err error) + ReadBinary(ctx context.Context) (value []byte, err error) - Skip(fieldType TType) (err error) + Skip(ctx context.Context, fieldType TType) (err error) Flush(ctx context.Context) (err error) Transport() TTransport @@ -84,12 +84,12 @@ type TProtocol interface { const DEFAULT_RECURSION_DEPTH = 64 // Skips over the next data element from the provided input TProtocol object. -func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) { - return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH) +func SkipDefaultDepth(ctx context.Context, prot TProtocol, typeId TType) (err error) { + return Skip(ctx, prot, typeId, DEFAULT_RECURSION_DEPTH) } // Skips over the next data element from the provided input TProtocol object. -func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { +func Skip(ctx context.Context, self TProtocol, fieldType TType, maxDepth int) (err error) { if maxDepth <= 0 { return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded")) @@ -97,79 +97,79 @@ func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { switch fieldType { case BOOL: - _, err = self.ReadBool() + _, err = self.ReadBool(ctx) return case BYTE: - _, err = self.ReadByte() + _, err = self.ReadByte(ctx) return case I16: - _, err = self.ReadI16() + _, err = self.ReadI16(ctx) return case I32: - _, err = self.ReadI32() + _, err = self.ReadI32(ctx) return case I64: - _, err = self.ReadI64() + _, err = self.ReadI64(ctx) return case DOUBLE: - _, err = self.ReadDouble() + _, err = self.ReadDouble(ctx) return case STRING: - _, err = self.ReadString() + _, err = self.ReadString(ctx) return case STRUCT: - if _, err = self.ReadStructBegin(); err != nil { + if _, err = self.ReadStructBegin(ctx); err != nil { return err } for { - _, typeId, _, _ := self.ReadFieldBegin() + _, typeId, _, _ := self.ReadFieldBegin(ctx) if typeId == STOP { break } - err := Skip(self, typeId, maxDepth-1) + err := Skip(ctx, self, typeId, maxDepth-1) if err != nil { return err } - self.ReadFieldEnd() + self.ReadFieldEnd(ctx) } - return self.ReadStructEnd() + return self.ReadStructEnd(ctx) case MAP: - keyType, valueType, size, err := self.ReadMapBegin() + keyType, valueType, size, err := self.ReadMapBegin(ctx) if err != nil { return err } for i := 0; i < size; i++ { - err := Skip(self, keyType, maxDepth-1) + err := Skip(ctx, self, keyType, maxDepth-1) if err != nil { return err } - self.Skip(valueType) + self.Skip(ctx, valueType) } - return self.ReadMapEnd() + return self.ReadMapEnd(ctx) case SET: - elemType, size, err := self.ReadSetBegin() + elemType, size, err := self.ReadSetBegin(ctx) if err != nil { return err } for i := 0; i < size; i++ { - err := Skip(self, elemType, maxDepth-1) + err := Skip(ctx, self, elemType, maxDepth-1) if err != nil { return err } } - return self.ReadSetEnd() + return self.ReadSetEnd(ctx) case LIST: - elemType, size, err := self.ReadListBegin() + elemType, size, err := self.ReadListBegin(ctx) if err != nil { return err } for i := 0; i < size; i++ { - err := Skip(self, elemType, maxDepth-1) + err := Skip(ctx, self, elemType, maxDepth-1) if err != nil { return err } } - return self.ReadListEnd() + return self.ReadListEnd(ctx) default: return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType))) } diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go index 944055c0b4a..c1c67e8ca08 100644 --- a/lib/go/thrift/protocol_test.go +++ b/lib/go/thrift/protocol_test.go @@ -222,22 +222,22 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(BOOL) thelen := len(BOOL_VALUES) - err := p.WriteListBegin(thetype, thelen) + err := p.WriteListBegin(context.Background(), thetype, thelen) if err != nil { t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteBool", p, trans, err, thetype) } for k, v := range BOOL_VALUES { - err = p.WriteBool(v) + err = p.WriteBool(context.Background(), v) if err != nil { t.Errorf("%s: %T %T %v Error writing bool in list at index %v: %v", "ReadWriteBool", p, trans, err, k, v) } } - p.WriteListEnd() + p.WriteListEnd(context.Background()) if err != nil { t.Errorf("%s: %T %T %v Error writing list end: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES) } p.Flush(context.Background()) - thetype2, thelen2, err := p.ReadListBegin() + thetype2, thelen2, err := p.ReadListBegin(context.Background()) if err != nil { t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES) } @@ -251,7 +251,7 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { } } for k, v := range BOOL_VALUES { - value, err := p.ReadBool() + value, err := p.ReadBool(context.Background()) if err != nil { t.Errorf("%s: %T %T %v Error reading bool at index %v: %v", "ReadWriteBool", p, trans, err, k, v) } @@ -259,7 +259,7 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { t.Errorf("%s: index %v %v %v %v != %v", "ReadWriteBool", k, p, trans, v, value) } } - err = p.ReadListEnd() + err = p.ReadListEnd(context.Background()) if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteBool", p, trans, err) } @@ -268,17 +268,17 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(BYTE) thelen := len(BYTE_VALUES) - err := p.WriteListBegin(thetype, thelen) + err := p.WriteListBegin(context.Background(), thetype, thelen) if err != nil { t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteByte", p, trans, err, thetype) } for k, v := range BYTE_VALUES { - err = p.WriteByte(v) + err = p.WriteByte(context.Background(), v) if err != nil { t.Errorf("%s: %T %T %q Error writing byte in list at index %d: %q", "ReadWriteByte", p, trans, err, k, v) } } - err = p.WriteListEnd() + err = p.WriteListEnd(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) } @@ -286,7 +286,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { if err != nil { t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) } - thetype2, thelen2, err := p.ReadListBegin() + thetype2, thelen2, err := p.ReadListBegin(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES) } @@ -300,7 +300,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { } } for k, v := range BYTE_VALUES { - value, err := p.ReadByte() + value, err := p.ReadByte(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading byte at index %d: %q", "ReadWriteByte", p, trans, err, k, v) } @@ -308,7 +308,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { t.Errorf("%s: %T %T %d != %d", "ReadWriteByte", p, trans, v, value) } } - err = p.ReadListEnd() + err = p.ReadListEnd(context.Background()) if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteByte", p, trans, err) } @@ -317,13 +317,13 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I16) thelen := len(INT16_VALUES) - p.WriteListBegin(thetype, thelen) + p.WriteListBegin(context.Background(), thetype, thelen) for _, v := range INT16_VALUES { - p.WriteI16(v) + p.WriteI16(context.Background(), v) } - p.WriteListEnd() + p.WriteListEnd(context.Background()) p.Flush(context.Background()) - thetype2, thelen2, err := p.ReadListBegin() + thetype2, thelen2, err := p.ReadListBegin(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES) } @@ -337,7 +337,7 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { } } for k, v := range INT16_VALUES { - value, err := p.ReadI16() + value, err := p.ReadI16(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading int16 at index %d: %q", "ReadWriteI16", p, trans, err, k, v) } @@ -345,7 +345,7 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { t.Errorf("%s: %T %T %d != %d", "ReadWriteI16", p, trans, v, value) } } - err = p.ReadListEnd() + err = p.ReadListEnd(context.Background()) if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI16", p, trans, err) } @@ -354,13 +354,13 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I32) thelen := len(INT32_VALUES) - p.WriteListBegin(thetype, thelen) + p.WriteListBegin(context.Background(), thetype, thelen) for _, v := range INT32_VALUES { - p.WriteI32(v) + p.WriteI32(context.Background(), v) } - p.WriteListEnd() + p.WriteListEnd(context.Background()) p.Flush(context.Background()) - thetype2, thelen2, err := p.ReadListBegin() + thetype2, thelen2, err := p.ReadListBegin(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES) } @@ -374,7 +374,7 @@ func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) { } } for k, v := range INT32_VALUES { - value, err := p.ReadI32() + value, err := p.ReadI32(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading int32 at index %d: %q", "ReadWriteI32", p, trans, err, k, v) } @@ -390,13 +390,13 @@ func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) { func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I64) thelen := len(INT64_VALUES) - p.WriteListBegin(thetype, thelen) + p.WriteListBegin(context.Background(), thetype, thelen) for _, v := range INT64_VALUES { - p.WriteI64(v) + p.WriteI64(context.Background(), v) } - p.WriteListEnd() + p.WriteListEnd(context.Background()) p.Flush(context.Background()) - thetype2, thelen2, err := p.ReadListBegin() + thetype2, thelen2, err := p.ReadListBegin(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES) } @@ -410,7 +410,7 @@ func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) { } } for k, v := range INT64_VALUES { - value, err := p.ReadI64() + value, err := p.ReadI64(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading int64 at index %d: %q", "ReadWriteI64", p, trans, err, k, v) } @@ -426,13 +426,13 @@ func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) { func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(DOUBLE) thelen := len(DOUBLE_VALUES) - p.WriteListBegin(thetype, thelen) + p.WriteListBegin(context.Background(), thetype, thelen) for _, v := range DOUBLE_VALUES { - p.WriteDouble(v) + p.WriteDouble(context.Background(), v) } - p.WriteListEnd() + p.WriteListEnd(context.Background()) p.Flush(context.Background()) - thetype2, thelen2, err := p.ReadListBegin() + thetype2, thelen2, err := p.ReadListBegin(context.Background()) if err != nil { t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES) } @@ -443,7 +443,7 @@ func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { t.Errorf("%s: %T %T len %v != len %v", "ReadWriteDouble", p, trans, thelen, thelen2) } for k, v := range DOUBLE_VALUES { - value, err := p.ReadDouble() + value, err := p.ReadDouble(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading double at index %d: %v", "ReadWriteDouble", p, trans, err, k, v) } @@ -455,7 +455,7 @@ func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { t.Errorf("%s: %T %T %v != %v", "ReadWriteDouble", p, trans, v, value) } } - err = p.ReadListEnd() + err = p.ReadListEnd(context.Background()) if err != nil { t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteDouble", p, trans, err) } @@ -464,13 +464,13 @@ func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(STRING) thelen := len(STRING_VALUES) - p.WriteListBegin(thetype, thelen) + p.WriteListBegin(context.Background(), thetype, thelen) for _, v := range STRING_VALUES { - p.WriteString(v) + p.WriteString(context.Background(), v) } - p.WriteListEnd() + p.WriteListEnd(context.Background()) p.Flush(context.Background()) - thetype2, thelen2, err := p.ReadListBegin() + thetype2, thelen2, err := p.ReadListBegin(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES) } @@ -484,7 +484,7 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) { } } for k, v := range STRING_VALUES { - value, err := p.ReadString() + value, err := p.ReadString(context.Background()) if err != nil { t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v) } @@ -499,9 +499,9 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) { func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) { v := protocol_bdata - p.WriteBinary(v) + p.WriteBinary(context.Background(), v) p.Flush(context.Background()) - value, err := p.ReadBinary() + value, err := p.ReadBinary(context.Background()) if err != nil { t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error()) } diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go index d85d2049de7..b1b8061d629 100644 --- a/lib/go/thrift/serializer.go +++ b/lib/go/thrift/serializer.go @@ -30,8 +30,8 @@ type TSerializer struct { } type TStruct interface { - Write(p TProtocol) error - Read(p TProtocol) error + Write(ctx context.Context, p TProtocol) error + Read(ctx context.Context, p TProtocol) error } func NewTSerializer() *TSerializer { @@ -46,7 +46,7 @@ func NewTSerializer() *TSerializer { func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) { t.Transport.Reset() - if err = msg.Write(t.Protocol); err != nil { + if err = msg.Write(ctx, t.Protocol); err != nil { return } @@ -63,7 +63,7 @@ func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, e func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) { t.Transport.Reset() - if err = msg.Write(t.Protocol); err != nil { + if err = msg.Write(ctx, t.Protocol); err != nil { return } diff --git a/lib/go/thrift/serializer_test.go b/lib/go/thrift/serializer_test.go index 52ebdca890b..2e37ea2b195 100644 --- a/lib/go/thrift/serializer_test.go +++ b/lib/go/thrift/serializer_test.go @@ -80,7 +80,7 @@ type serializer interface { } type deserializer interface { - ReadString(TStruct, string) error + ReadString(context.Context, TStruct, string) error } func plainSerializer(pf ProtocolFactory) serializer { @@ -158,7 +158,7 @@ func ProtocolTest1(t *testing.T, pf ProtocolFactory) { t1 := impl.Deserializer(pf) var m1 MyTestStruct - if err = t1.ReadString(&m1, s); err != nil { + if err = t1.ReadString(context.Background(), &m1, s); err != nil { test.Fatalf("Unable to Deserialize struct: %v", err) } @@ -199,7 +199,7 @@ func ProtocolTest2(t *testing.T, pf ProtocolFactory) { t1 := impl.Deserializer(pf) var m1 MyTestStruct - if err = t1.ReadString(&m1, s); err != nil { + if err = t1.ReadString(context.Background(), &m1, s); err != nil { test.Fatalf("Unable to Deserialize struct: %v", err) } @@ -264,7 +264,7 @@ func TestSerializerPoolAsync(t *testing.T) { t.Fatal("serialize:", err) } var m1 MyTestStruct - if err = d.ReadString(&m1, str); err != nil { + if err = d.ReadString(context.Background(), &m1, str); err != nil { t.Fatal("deserialize:", err) } @@ -335,7 +335,7 @@ func BenchmarkSerializer(b *testing.B) { str, _ := s.WriteString(context.Background(), &m) var m1 MyTestStruct d := c.Deserializer() - d.ReadString(&m1, str) + d.ReadString(context.Background(), &m1, str) } }, ) diff --git a/lib/go/thrift/serializer_types_test.go b/lib/go/thrift/serializer_types_test.go index e5472bbffc3..4d1e992ae42 100644 --- a/lib/go/thrift/serializer_types_test.go +++ b/lib/go/thrift/serializer_types_test.go @@ -48,6 +48,7 @@ struct MyTestStruct { */ import ( + "context" "fmt" ) @@ -162,12 +163,12 @@ func (p *MyTestStruct) GetStringSet() map[string]struct{} { func (p *MyTestStruct) GetE() MyTestEnum { return p.E } -func (p *MyTestStruct) Read(iprot TProtocol) error { - if _, err := iprot.ReadStructBegin(); err != nil { +func (p *MyTestStruct) Read(ctx context.Context, iprot TProtocol) error { + if _, err := iprot.ReadStructBegin(ctx); err != nil { return PrependError(fmt.Sprintf("%T read error: ", p), err) } for { - _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin() + _, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx) if err != nil { return PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err) } @@ -176,70 +177,70 @@ func (p *MyTestStruct) Read(iprot TProtocol) error { } switch fieldId { case 1: - if err := p.readField1(iprot); err != nil { + if err := p.readField1(ctx, iprot); err != nil { return err } case 2: - if err := p.readField2(iprot); err != nil { + if err := p.readField2(ctx, iprot); err != nil { return err } case 3: - if err := p.readField3(iprot); err != nil { + if err := p.readField3(ctx, iprot); err != nil { return err } case 4: - if err := p.readField4(iprot); err != nil { + if err := p.readField4(ctx, iprot); err != nil { return err } case 5: - if err := p.readField5(iprot); err != nil { + if err := p.readField5(ctx, iprot); err != nil { return err } case 6: - if err := p.readField6(iprot); err != nil { + if err := p.readField6(ctx, iprot); err != nil { return err } case 7: - if err := p.readField7(iprot); err != nil { + if err := p.readField7(ctx, iprot); err != nil { return err } case 8: - if err := p.readField8(iprot); err != nil { + if err := p.readField8(ctx, iprot); err != nil { return err } case 9: - if err := p.readField9(iprot); err != nil { + if err := p.readField9(ctx, iprot); err != nil { return err } case 10: - if err := p.readField10(iprot); err != nil { + if err := p.readField10(ctx, iprot); err != nil { return err } case 11: - if err := p.readField11(iprot); err != nil { + if err := p.readField11(ctx, iprot); err != nil { return err } case 12: - if err := p.readField12(iprot); err != nil { + if err := p.readField12(ctx, iprot); err != nil { return err } default: - if err := iprot.Skip(fieldTypeId); err != nil { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err } } - if err := iprot.ReadFieldEnd(); err != nil { + if err := iprot.ReadFieldEnd(ctx); err != nil { return err } } - if err := iprot.ReadStructEnd(); err != nil { + if err := iprot.ReadStructEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } return nil } -func (p *MyTestStruct) readField1(iprot TProtocol) error { - if v, err := iprot.ReadBool(); err != nil { +func (p *MyTestStruct) readField1(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadBool(ctx); err != nil { return PrependError("error reading field 1: ", err) } else { p.On = v @@ -247,8 +248,8 @@ func (p *MyTestStruct) readField1(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField2(iprot TProtocol) error { - if v, err := iprot.ReadByte(); err != nil { +func (p *MyTestStruct) readField2(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadByte(ctx); err != nil { return PrependError("error reading field 2: ", err) } else { temp := int8(v) @@ -257,8 +258,8 @@ func (p *MyTestStruct) readField2(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField3(iprot TProtocol) error { - if v, err := iprot.ReadI16(); err != nil { +func (p *MyTestStruct) readField3(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadI16(ctx); err != nil { return PrependError("error reading field 3: ", err) } else { p.Int16 = v @@ -266,8 +267,8 @@ func (p *MyTestStruct) readField3(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField4(iprot TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { +func (p *MyTestStruct) readField4(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadI32(ctx); err != nil { return PrependError("error reading field 4: ", err) } else { p.Int32 = v @@ -275,8 +276,8 @@ func (p *MyTestStruct) readField4(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField5(iprot TProtocol) error { - if v, err := iprot.ReadI64(); err != nil { +func (p *MyTestStruct) readField5(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadI64(ctx); err != nil { return PrependError("error reading field 5: ", err) } else { p.Int64 = v @@ -284,8 +285,8 @@ func (p *MyTestStruct) readField5(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField6(iprot TProtocol) error { - if v, err := iprot.ReadDouble(); err != nil { +func (p *MyTestStruct) readField6(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadDouble(ctx); err != nil { return PrependError("error reading field 6: ", err) } else { p.D = v @@ -293,8 +294,8 @@ func (p *MyTestStruct) readField6(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField7(iprot TProtocol) error { - if v, err := iprot.ReadString(); err != nil { +func (p *MyTestStruct) readField7(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 7: ", err) } else { p.St = v @@ -302,8 +303,8 @@ func (p *MyTestStruct) readField7(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField8(iprot TProtocol) error { - if v, err := iprot.ReadBinary(); err != nil { +func (p *MyTestStruct) readField8(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadBinary(ctx); err != nil { return PrependError("error reading field 8: ", err) } else { p.Bin = v @@ -311,8 +312,8 @@ func (p *MyTestStruct) readField8(iprot TProtocol) error { return nil } -func (p *MyTestStruct) readField9(iprot TProtocol) error { - _, _, size, err := iprot.ReadMapBegin() +func (p *MyTestStruct) readField9(ctx context.Context, iprot TProtocol) error { + _, _, size, err := iprot.ReadMapBegin(ctx) if err != nil { return PrependError("error reading map begin: ", err) } @@ -320,27 +321,27 @@ func (p *MyTestStruct) readField9(iprot TProtocol) error { p.StringMap = tMap for i := 0; i < size; i++ { var _key0 string - if v, err := iprot.ReadString(); err != nil { + if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 0: ", err) } else { _key0 = v } var _val1 string - if v, err := iprot.ReadString(); err != nil { + if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 0: ", err) } else { _val1 = v } p.StringMap[_key0] = _val1 } - if err := iprot.ReadMapEnd(); err != nil { + if err := iprot.ReadMapEnd(ctx); err != nil { return PrependError("error reading map end: ", err) } return nil } -func (p *MyTestStruct) readField10(iprot TProtocol) error { - _, size, err := iprot.ReadListBegin() +func (p *MyTestStruct) readField10(ctx context.Context, iprot TProtocol) error { + _, size, err := iprot.ReadListBegin(ctx) if err != nil { return PrependError("error reading list begin: ", err) } @@ -348,21 +349,21 @@ func (p *MyTestStruct) readField10(iprot TProtocol) error { p.StringList = tSlice for i := 0; i < size; i++ { var _elem2 string - if v, err := iprot.ReadString(); err != nil { + if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 0: ", err) } else { _elem2 = v } p.StringList = append(p.StringList, _elem2) } - if err := iprot.ReadListEnd(); err != nil { + if err := iprot.ReadListEnd(ctx); err != nil { return PrependError("error reading list end: ", err) } return nil } -func (p *MyTestStruct) readField11(iprot TProtocol) error { - _, size, err := iprot.ReadSetBegin() +func (p *MyTestStruct) readField11(ctx context.Context, iprot TProtocol) error { + _, size, err := iprot.ReadSetBegin(ctx) if err != nil { return PrependError("error reading set begin: ", err) } @@ -370,21 +371,21 @@ func (p *MyTestStruct) readField11(iprot TProtocol) error { p.StringSet = tSet for i := 0; i < size; i++ { var _elem3 string - if v, err := iprot.ReadString(); err != nil { + if v, err := iprot.ReadString(ctx); err != nil { return PrependError("error reading field 0: ", err) } else { _elem3 = v } p.StringSet[_elem3] = struct{}{} } - if err := iprot.ReadSetEnd(); err != nil { + if err := iprot.ReadSetEnd(ctx); err != nil { return PrependError("error reading set end: ", err) } return nil } -func (p *MyTestStruct) readField12(iprot TProtocol) error { - if v, err := iprot.ReadI32(); err != nil { +func (p *MyTestStruct) readField12(ctx context.Context, iprot TProtocol) error { + if v, err := iprot.ReadI32(ctx); err != nil { return PrependError("error reading field 12: ", err) } else { temp := MyTestEnum(v) @@ -393,233 +394,233 @@ func (p *MyTestStruct) readField12(iprot TProtocol) error { return nil } -func (p *MyTestStruct) Write(oprot TProtocol) error { - if err := oprot.WriteStructBegin("MyTestStruct"); err != nil { +func (p *MyTestStruct) Write(ctx context.Context, oprot TProtocol) error { + if err := oprot.WriteStructBegin(ctx, "MyTestStruct"); err != nil { return PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } - if err := p.writeField1(oprot); err != nil { + if err := p.writeField1(ctx, oprot); err != nil { return err } - if err := p.writeField2(oprot); err != nil { + if err := p.writeField2(ctx, oprot); err != nil { return err } - if err := p.writeField3(oprot); err != nil { + if err := p.writeField3(ctx, oprot); err != nil { return err } - if err := p.writeField4(oprot); err != nil { + if err := p.writeField4(ctx, oprot); err != nil { return err } - if err := p.writeField5(oprot); err != nil { + if err := p.writeField5(ctx, oprot); err != nil { return err } - if err := p.writeField6(oprot); err != nil { + if err := p.writeField6(ctx, oprot); err != nil { return err } - if err := p.writeField7(oprot); err != nil { + if err := p.writeField7(ctx, oprot); err != nil { return err } - if err := p.writeField8(oprot); err != nil { + if err := p.writeField8(ctx, oprot); err != nil { return err } - if err := p.writeField9(oprot); err != nil { + if err := p.writeField9(ctx, oprot); err != nil { return err } - if err := p.writeField10(oprot); err != nil { + if err := p.writeField10(ctx, oprot); err != nil { return err } - if err := p.writeField11(oprot); err != nil { + if err := p.writeField11(ctx, oprot); err != nil { return err } - if err := p.writeField12(oprot); err != nil { + if err := p.writeField12(ctx, oprot); err != nil { return err } - if err := oprot.WriteFieldStop(); err != nil { + if err := oprot.WriteFieldStop(ctx); err != nil { return PrependError("write field stop error: ", err) } - if err := oprot.WriteStructEnd(); err != nil { + if err := oprot.WriteStructEnd(ctx); err != nil { return PrependError("write struct stop error: ", err) } return nil } -func (p *MyTestStruct) writeField1(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("on", BOOL, 1); err != nil { +func (p *MyTestStruct) writeField1(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "on", BOOL, 1); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 1:on: ", p), err) } - if err := oprot.WriteBool(bool(p.On)); err != nil { + if err := oprot.WriteBool(ctx, bool(p.On)); err != nil { return PrependError(fmt.Sprintf("%T.on (1) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 1:on: ", p), err) } return err } -func (p *MyTestStruct) writeField2(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("b", BYTE, 2); err != nil { +func (p *MyTestStruct) writeField2(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "b", BYTE, 2); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 2:b: ", p), err) } - if err := oprot.WriteByte(int8(p.B)); err != nil { + if err := oprot.WriteByte(ctx, int8(p.B)); err != nil { return PrependError(fmt.Sprintf("%T.b (2) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 2:b: ", p), err) } return err } -func (p *MyTestStruct) writeField3(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("int16", I16, 3); err != nil { +func (p *MyTestStruct) writeField3(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "int16", I16, 3); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 3:int16: ", p), err) } - if err := oprot.WriteI16(int16(p.Int16)); err != nil { + if err := oprot.WriteI16(ctx, int16(p.Int16)); err != nil { return PrependError(fmt.Sprintf("%T.int16 (3) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 3:int16: ", p), err) } return err } -func (p *MyTestStruct) writeField4(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("int32", I32, 4); err != nil { +func (p *MyTestStruct) writeField4(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "int32", I32, 4); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 4:int32: ", p), err) } - if err := oprot.WriteI32(int32(p.Int32)); err != nil { + if err := oprot.WriteI32(ctx, int32(p.Int32)); err != nil { return PrependError(fmt.Sprintf("%T.int32 (4) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 4:int32: ", p), err) } return err } -func (p *MyTestStruct) writeField5(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("int64", I64, 5); err != nil { +func (p *MyTestStruct) writeField5(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "int64", I64, 5); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 5:int64: ", p), err) } - if err := oprot.WriteI64(int64(p.Int64)); err != nil { + if err := oprot.WriteI64(ctx, int64(p.Int64)); err != nil { return PrependError(fmt.Sprintf("%T.int64 (5) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 5:int64: ", p), err) } return err } -func (p *MyTestStruct) writeField6(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("d", DOUBLE, 6); err != nil { +func (p *MyTestStruct) writeField6(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "d", DOUBLE, 6); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 6:d: ", p), err) } - if err := oprot.WriteDouble(float64(p.D)); err != nil { + if err := oprot.WriteDouble(ctx, float64(p.D)); err != nil { return PrependError(fmt.Sprintf("%T.d (6) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 6:d: ", p), err) } return err } -func (p *MyTestStruct) writeField7(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("st", STRING, 7); err != nil { +func (p *MyTestStruct) writeField7(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "st", STRING, 7); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 7:st: ", p), err) } - if err := oprot.WriteString(string(p.St)); err != nil { + if err := oprot.WriteString(ctx, string(p.St)); err != nil { return PrependError(fmt.Sprintf("%T.st (7) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 7:st: ", p), err) } return err } -func (p *MyTestStruct) writeField8(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("bin", STRING, 8); err != nil { +func (p *MyTestStruct) writeField8(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "bin", STRING, 8); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 8:bin: ", p), err) } - if err := oprot.WriteBinary(p.Bin); err != nil { + if err := oprot.WriteBinary(ctx, p.Bin); err != nil { return PrependError(fmt.Sprintf("%T.bin (8) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 8:bin: ", p), err) } return err } -func (p *MyTestStruct) writeField9(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("stringMap", MAP, 9); err != nil { +func (p *MyTestStruct) writeField9(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "stringMap", MAP, 9); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 9:stringMap: ", p), err) } - if err := oprot.WriteMapBegin(STRING, STRING, len(p.StringMap)); err != nil { + if err := oprot.WriteMapBegin(ctx, STRING, STRING, len(p.StringMap)); err != nil { return PrependError("error writing map begin: ", err) } for k, v := range p.StringMap { - if err := oprot.WriteString(string(k)); err != nil { + if err := oprot.WriteString(ctx, string(k)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } - if err := oprot.WriteString(string(v)); err != nil { + if err := oprot.WriteString(ctx, string(v)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } } - if err := oprot.WriteMapEnd(); err != nil { + if err := oprot.WriteMapEnd(ctx); err != nil { return PrependError("error writing map end: ", err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 9:stringMap: ", p), err) } return err } -func (p *MyTestStruct) writeField10(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("stringList", LIST, 10); err != nil { +func (p *MyTestStruct) writeField10(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "stringList", LIST, 10); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 10:stringList: ", p), err) } - if err := oprot.WriteListBegin(STRING, len(p.StringList)); err != nil { + if err := oprot.WriteListBegin(ctx, STRING, len(p.StringList)); err != nil { return PrependError("error writing list begin: ", err) } for _, v := range p.StringList { - if err := oprot.WriteString(string(v)); err != nil { + if err := oprot.WriteString(ctx, string(v)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } } - if err := oprot.WriteListEnd(); err != nil { + if err := oprot.WriteListEnd(ctx); err != nil { return PrependError("error writing list end: ", err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 10:stringList: ", p), err) } return err } -func (p *MyTestStruct) writeField11(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("stringSet", SET, 11); err != nil { +func (p *MyTestStruct) writeField11(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "stringSet", SET, 11); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 11:stringSet: ", p), err) } - if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil { + if err := oprot.WriteSetBegin(ctx, STRING, len(p.StringSet)); err != nil { return PrependError("error writing set begin: ", err) } for v := range p.StringSet { - if err := oprot.WriteString(string(v)); err != nil { + if err := oprot.WriteString(ctx, string(v)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } } - if err := oprot.WriteSetEnd(); err != nil { + if err := oprot.WriteSetEnd(ctx); err != nil { return PrependError("error writing set end: ", err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 11:stringSet: ", p), err) } return err } -func (p *MyTestStruct) writeField12(oprot TProtocol) (err error) { - if err := oprot.WriteFieldBegin("e", I32, 12); err != nil { +func (p *MyTestStruct) writeField12(ctx context.Context, oprot TProtocol) (err error) { + if err := oprot.WriteFieldBegin(ctx, "e", I32, 12); err != nil { return PrependError(fmt.Sprintf("%T write field begin error 12:e: ", p), err) } - if err := oprot.WriteI32(int32(p.E)); err != nil { + if err := oprot.WriteI32(ctx, int32(p.E)); err != nil { return PrependError(fmt.Sprintf("%T.e (12) field write error: ", p), err) } - if err := oprot.WriteFieldEnd(); err != nil { + if err := oprot.WriteFieldEnd(ctx); err != nil { return PrependError(fmt.Sprintf("%T write field end error 12:e: ", p), err) } return err diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go index f5e0c05d189..d101b993c17 100644 --- a/lib/go/thrift/simple_json_protocol.go +++ b/lib/go/thrift/simple_json_protocol.go @@ -156,114 +156,113 @@ func mismatch(expected, actual string) error { return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual) } -func (p *TSimpleJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { +func (p *TSimpleJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error { p.resetContextStack() // THRIFT-3735 if e := p.OutputListBegin(); e != nil { return e } - if e := p.WriteString(name); e != nil { + if e := p.WriteString(ctx, name); e != nil { return e } - if e := p.WriteByte(int8(typeId)); e != nil { + if e := p.WriteByte(ctx, int8(typeId)); e != nil { return e } - if e := p.WriteI32(seqId); e != nil { + if e := p.WriteI32(ctx, seqId); e != nil { return e } return nil } -func (p *TSimpleJSONProtocol) WriteMessageEnd() error { +func (p *TSimpleJSONProtocol) WriteMessageEnd(ctx context.Context) error { return p.OutputListEnd() } -func (p *TSimpleJSONProtocol) WriteStructBegin(name string) error { +func (p *TSimpleJSONProtocol) WriteStructBegin(ctx context.Context, name string) error { if e := p.OutputObjectBegin(); e != nil { return e } return nil } -func (p *TSimpleJSONProtocol) WriteStructEnd() error { +func (p *TSimpleJSONProtocol) WriteStructEnd(ctx context.Context) error { return p.OutputObjectEnd() } -func (p *TSimpleJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { - if e := p.WriteString(name); e != nil { +func (p *TSimpleJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error { + if e := p.WriteString(ctx, name); e != nil { return e } return nil } -func (p *TSimpleJSONProtocol) WriteFieldEnd() error { - //return p.OutputListEnd() +func (p *TSimpleJSONProtocol) WriteFieldEnd(ctx context.Context) error { return nil } -func (p *TSimpleJSONProtocol) WriteFieldStop() error { return nil } +func (p *TSimpleJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil } -func (p *TSimpleJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { +func (p *TSimpleJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } - if e := p.WriteByte(int8(keyType)); e != nil { + if e := p.WriteByte(ctx, int8(keyType)); e != nil { return e } - if e := p.WriteByte(int8(valueType)); e != nil { + if e := p.WriteByte(ctx, int8(valueType)); e != nil { return e } - return p.WriteI32(int32(size)) + return p.WriteI32(ctx, int32(size)) } -func (p *TSimpleJSONProtocol) WriteMapEnd() error { +func (p *TSimpleJSONProtocol) WriteMapEnd(ctx context.Context) error { return p.OutputListEnd() } -func (p *TSimpleJSONProtocol) WriteListBegin(elemType TType, size int) error { +func (p *TSimpleJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } -func (p *TSimpleJSONProtocol) WriteListEnd() error { +func (p *TSimpleJSONProtocol) WriteListEnd(ctx context.Context) error { return p.OutputListEnd() } -func (p *TSimpleJSONProtocol) WriteSetBegin(elemType TType, size int) error { +func (p *TSimpleJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error { return p.OutputElemListBegin(elemType, size) } -func (p *TSimpleJSONProtocol) WriteSetEnd() error { +func (p *TSimpleJSONProtocol) WriteSetEnd(ctx context.Context) error { return p.OutputListEnd() } -func (p *TSimpleJSONProtocol) WriteBool(b bool) error { +func (p *TSimpleJSONProtocol) WriteBool(ctx context.Context, b bool) error { return p.OutputBool(b) } -func (p *TSimpleJSONProtocol) WriteByte(b int8) error { - return p.WriteI32(int32(b)) +func (p *TSimpleJSONProtocol) WriteByte(ctx context.Context, b int8) error { + return p.WriteI32(ctx, int32(b)) } -func (p *TSimpleJSONProtocol) WriteI16(v int16) error { - return p.WriteI32(int32(v)) +func (p *TSimpleJSONProtocol) WriteI16(ctx context.Context, v int16) error { + return p.WriteI32(ctx, int32(v)) } -func (p *TSimpleJSONProtocol) WriteI32(v int32) error { +func (p *TSimpleJSONProtocol) WriteI32(ctx context.Context, v int32) error { return p.OutputI64(int64(v)) } -func (p *TSimpleJSONProtocol) WriteI64(v int64) error { +func (p *TSimpleJSONProtocol) WriteI64(ctx context.Context, v int64) error { return p.OutputI64(int64(v)) } -func (p *TSimpleJSONProtocol) WriteDouble(v float64) error { +func (p *TSimpleJSONProtocol) WriteDouble(ctx context.Context, v float64) error { return p.OutputF64(v) } -func (p *TSimpleJSONProtocol) WriteString(v string) error { +func (p *TSimpleJSONProtocol) WriteString(ctx context.Context, v string) error { return p.OutputString(v) } -func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error { +func (p *TSimpleJSONProtocol) WriteBinary(ctx context.Context, v []byte) error { // JSON library only takes in a string, // not an arbitrary byte array, to ensure bytes are transmitted // efficiently we must convert this into a valid JSON string @@ -289,39 +288,39 @@ func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error { } // Reading methods. -func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { +func (p *TSimpleJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) { p.resetContextStack() // THRIFT-3735 if isNull, err := p.ParseListBegin(); isNull || err != nil { return name, typeId, seqId, err } - if name, err = p.ReadString(); err != nil { + if name, err = p.ReadString(ctx); err != nil { return name, typeId, seqId, err } - bTypeId, err := p.ReadByte() + bTypeId, err := p.ReadByte(ctx) typeId = TMessageType(bTypeId) if err != nil { return name, typeId, seqId, err } - if seqId, err = p.ReadI32(); err != nil { + if seqId, err = p.ReadI32(ctx); err != nil { return name, typeId, seqId, err } return name, typeId, seqId, nil } -func (p *TSimpleJSONProtocol) ReadMessageEnd() error { +func (p *TSimpleJSONProtocol) ReadMessageEnd(ctx context.Context) error { return p.ParseListEnd() } -func (p *TSimpleJSONProtocol) ReadStructBegin() (name string, err error) { +func (p *TSimpleJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) { _, err = p.ParseObjectStart() return "", err } -func (p *TSimpleJSONProtocol) ReadStructEnd() error { +func (p *TSimpleJSONProtocol) ReadStructEnd(ctx context.Context) error { return p.ParseObjectEnd() } -func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { +func (p *TSimpleJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) { if err := p.ParsePreValue(); err != nil { return "", STOP, 0, err } @@ -340,21 +339,6 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { return name, STOP, 0, err } return name, STOP, -1, p.ParsePostValue() - /* - if err = p.ParsePostValue(); err != nil { - return name, STOP, 0, err - } - if isNull, err := p.ParseListBegin(); isNull || err != nil { - return name, STOP, 0, err - } - bType, err := p.ReadByte() - thetype := TType(bType) - if err != nil { - return name, thetype, 0, err - } - id, err := p.ReadI16() - return name, thetype, id, err - */ } e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b)) return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e) @@ -362,57 +346,56 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { return "", STOP, 0, NewTProtocolException(io.EOF) } -func (p *TSimpleJSONProtocol) ReadFieldEnd() error { +func (p *TSimpleJSONProtocol) ReadFieldEnd(ctx context.Context) error { return nil - //return p.ParseListEnd() } -func (p *TSimpleJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) { +func (p *TSimpleJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, VOID, 0, e } // read keyType - bKeyType, e := p.ReadByte() + bKeyType, e := p.ReadByte(ctx) keyType = TType(bKeyType) if e != nil { return keyType, valueType, size, e } // read valueType - bValueType, e := p.ReadByte() + bValueType, e := p.ReadByte(ctx) valueType = TType(bValueType) if e != nil { return keyType, valueType, size, e } // read size - iSize, err := p.ReadI64() + iSize, err := p.ReadI64(ctx) size = int(iSize) return keyType, valueType, size, err } -func (p *TSimpleJSONProtocol) ReadMapEnd() error { +func (p *TSimpleJSONProtocol) ReadMapEnd(ctx context.Context) error { return p.ParseListEnd() } -func (p *TSimpleJSONProtocol) ReadListBegin() (elemType TType, size int, e error) { +func (p *TSimpleJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) { return p.ParseElemListBegin() } -func (p *TSimpleJSONProtocol) ReadListEnd() error { +func (p *TSimpleJSONProtocol) ReadListEnd(ctx context.Context) error { return p.ParseListEnd() } -func (p *TSimpleJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) { +func (p *TSimpleJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) { return p.ParseElemListBegin() } -func (p *TSimpleJSONProtocol) ReadSetEnd() error { +func (p *TSimpleJSONProtocol) ReadSetEnd(ctx context.Context) error { return p.ParseListEnd() } -func (p *TSimpleJSONProtocol) ReadBool() (bool, error) { +func (p *TSimpleJSONProtocol) ReadBool(ctx context.Context) (bool, error) { var value bool if err := p.ParsePreValue(); err != nil { @@ -467,32 +450,32 @@ func (p *TSimpleJSONProtocol) ReadBool() (bool, error) { return value, p.ParsePostValue() } -func (p *TSimpleJSONProtocol) ReadByte() (int8, error) { - v, err := p.ReadI64() +func (p *TSimpleJSONProtocol) ReadByte(ctx context.Context) (int8, error) { + v, err := p.ReadI64(ctx) return int8(v), err } -func (p *TSimpleJSONProtocol) ReadI16() (int16, error) { - v, err := p.ReadI64() +func (p *TSimpleJSONProtocol) ReadI16(ctx context.Context) (int16, error) { + v, err := p.ReadI64(ctx) return int16(v), err } -func (p *TSimpleJSONProtocol) ReadI32() (int32, error) { - v, err := p.ReadI64() +func (p *TSimpleJSONProtocol) ReadI32(ctx context.Context) (int32, error) { + v, err := p.ReadI64(ctx) return int32(v), err } -func (p *TSimpleJSONProtocol) ReadI64() (int64, error) { +func (p *TSimpleJSONProtocol) ReadI64(ctx context.Context) (int64, error) { v, _, err := p.ParseI64() return v, err } -func (p *TSimpleJSONProtocol) ReadDouble() (float64, error) { +func (p *TSimpleJSONProtocol) ReadDouble(ctx context.Context) (float64, error) { v, _, err := p.ParseF64() return v, err } -func (p *TSimpleJSONProtocol) ReadString() (string, error) { +func (p *TSimpleJSONProtocol) ReadString(ctx context.Context) (string, error) { var v string if err := p.ParsePreValue(); err != nil { return v, err @@ -522,7 +505,7 @@ func (p *TSimpleJSONProtocol) ReadString() (string, error) { return v, p.ParsePostValue() } -func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) { +func (p *TSimpleJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) { var v []byte if err := p.ParsePreValue(); err != nil { return nil, err @@ -557,8 +540,8 @@ func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) { return NewTProtocolException(p.writer.Flush()) } -func (p *TSimpleJSONProtocol) Skip(fieldType TType) (err error) { - return SkipDefaultDepth(p, fieldType) +func (p *TSimpleJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) { + return SkipDefaultDepth(ctx, p, fieldType) } func (p *TSimpleJSONProtocol) Transport() TTransport { @@ -740,10 +723,10 @@ func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) erro if e := p.OutputListBegin(); e != nil { return e } - if e := p.WriteByte(int8(elemType)); e != nil { + if e := p.OutputI64(int64(elemType)); e != nil { return e } - if e := p.WriteI64(int64(size)); e != nil { + if e := p.OutputI64(int64(size)); e != nil { return e } return nil @@ -1039,12 +1022,12 @@ func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, 0, e } - bElemType, err := p.ReadByte() + bElemType, _, err := p.ParseI64() elemType = TType(bElemType) if err != nil { return elemType, size, err } - nSize, err2 := p.ReadI64() + nSize, _, err2 := p.ParseI64() size = int(nSize) return elemType, size, err2 } diff --git a/lib/go/thrift/simple_json_protocol_test.go b/lib/go/thrift/simple_json_protocol_test.go index 0126da0a8e4..951389a60d2 100644 --- a/lib/go/thrift/simple_json_protocol_test.go +++ b/lib/go/thrift/simple_json_protocol_test.go @@ -35,7 +35,7 @@ func TestWriteSimpleJSONProtocolBool(t *testing.T) { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range BOOL_VALUES { - if e := p.WriteBool(value); e != nil { + if e := p.WriteBool(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -66,7 +66,7 @@ func TestReadSimpleJSONProtocolBool(t *testing.T) { } trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadBool() + v, e := p.ReadBool(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -86,7 +86,7 @@ func TestWriteSimpleJSONProtocolByte(t *testing.T) { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range BYTE_VALUES { - if e := p.WriteByte(value); e != nil { + if e := p.WriteByte(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -113,7 +113,7 @@ func TestReadSimpleJSONProtocolByte(t *testing.T) { trans.WriteString(strconv.Itoa(int(value))) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadByte() + v, e := p.ReadByte(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -133,7 +133,7 @@ func TestWriteSimpleJSONProtocolI16(t *testing.T) { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range INT16_VALUES { - if e := p.WriteI16(value); e != nil { + if e := p.WriteI16(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -160,7 +160,7 @@ func TestReadSimpleJSONProtocolI16(t *testing.T) { trans.WriteString(strconv.Itoa(int(value))) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI16() + v, e := p.ReadI16(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -180,7 +180,7 @@ func TestWriteSimpleJSONProtocolI32(t *testing.T) { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range INT32_VALUES { - if e := p.WriteI32(value); e != nil { + if e := p.WriteI32(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -207,7 +207,7 @@ func TestReadSimpleJSONProtocolI32(t *testing.T) { trans.WriteString(strconv.Itoa(int(value))) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI32() + v, e := p.ReadI32(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -231,7 +231,7 @@ func TestReadSimpleJSONProtocolI32Null(t *testing.T) { trans.WriteString(value) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI32() + v, e := p.ReadI32(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) @@ -248,7 +248,7 @@ func TestWriteSimpleJSONProtocolI64(t *testing.T) { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range INT64_VALUES { - if e := p.WriteI64(value); e != nil { + if e := p.WriteI64(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -275,7 +275,7 @@ func TestReadSimpleJSONProtocolI64(t *testing.T) { trans.WriteString(strconv.FormatInt(value, 10)) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI64() + v, e := p.ReadI64(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -299,7 +299,7 @@ func TestReadSimpleJSONProtocolI64Null(t *testing.T) { trans.WriteString(value) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadI64() + v, e := p.ReadI64(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) @@ -316,7 +316,7 @@ func TestWriteSimpleJSONProtocolDouble(t *testing.T) { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range DOUBLE_VALUES { - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -358,7 +358,7 @@ func TestReadSimpleJSONProtocolDouble(t *testing.T) { trans.WriteString(n.String()) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadDouble() + v, e := p.ReadDouble(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -392,7 +392,7 @@ func TestWriteSimpleJSONProtocolString(t *testing.T) { trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) for _, value := range STRING_VALUES { - if e := p.WriteString(value); e != nil { + if e := p.WriteString(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -419,7 +419,7 @@ func TestReadSimpleJSONProtocolString(t *testing.T) { trans.WriteString(jsonQuote(value)) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadString() + v, e := p.ReadString(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -443,7 +443,7 @@ func TestReadSimpleJSONProtocolStringNull(t *testing.T) { trans.WriteString(value) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadString() + v, e := p.ReadString(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -462,7 +462,7 @@ func TestWriteSimpleJSONProtocolBinary(t *testing.T) { b64String := string(b64value) trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) - if e := p.WriteBinary(value); e != nil { + if e := p.WriteBinary(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } if e := p.Flush(context.Background()); e != nil { @@ -490,7 +490,7 @@ func TestReadSimpleJSONProtocolBinary(t *testing.T) { trans.WriteString(jsonQuote(b64String)) trans.Flush(context.Background()) s := trans.String() - v, e := p.ReadBinary() + v, e := p.ReadBinary(context.Background()) if e != nil { t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error()) } @@ -519,7 +519,7 @@ func TestReadSimpleJSONProtocolBinaryNull(t *testing.T) { trans.WriteString(value) trans.Flush(context.Background()) s := trans.String() - b, e := p.ReadBinary() + b, e := p.ReadBinary(context.Background()) v := string(b) if e != nil { @@ -536,13 +536,13 @@ func TestWriteSimpleJSONProtocolList(t *testing.T) { thetype := "list" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) - p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + p.WriteListBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } - p.WriteListEnd() + p.WriteListEnd(context.Background()) if e := p.Flush(context.Background()); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } @@ -590,13 +590,13 @@ func TestWriteSimpleJSONProtocolSet(t *testing.T) { thetype := "set" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) - p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES)) + p.WriteSetBegin(context.Background(), TType(DOUBLE), len(DOUBLE_VALUES)) for _, value := range DOUBLE_VALUES { - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error()) } } - p.WriteSetEnd() + p.WriteSetEnd(context.Background()) if e := p.Flush(context.Background()); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } @@ -644,16 +644,16 @@ func TestWriteSimpleJSONProtocolMap(t *testing.T) { thetype := "map" trans := NewTMemoryBuffer() p := NewTSimpleJSONProtocol(trans) - p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) + p.WriteMapBegin(context.Background(), TType(I32), TType(DOUBLE), len(DOUBLE_VALUES)) for k, value := range DOUBLE_VALUES { - if e := p.WriteI32(int32(k)); e != nil { + if e := p.WriteI32(context.Background(), int32(k)); e != nil { t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error()) } - if e := p.WriteDouble(value); e != nil { + if e := p.WriteDouble(context.Background(), value); e != nil { t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error()) } } - p.WriteMapEnd() + p.WriteMapEnd(context.Background()) if e := p.Flush(context.Background()); e != nil { t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error()) } @@ -720,17 +720,17 @@ func TestWriteSimpleJSONProtocolSafePeek(t *testing.T) { p := NewTSimpleJSONProtocol(trans) trans.Write([]byte{'a', 'b'}) trans.Flush(context.Background()) - + test1 := p.safePeekContains([]byte{'a', 'b'}) if !test1 { t.Fatalf("Should match at test 1") } - + test2 := p.safePeekContains([]byte{'a', 'b', 'c', 'd'}) if test2 { t.Fatalf("Should not match at test 2") } - + test3 := p.safePeekContains([]byte{'x', 'y'}) if test3 { t.Fatalf("Should not match at test 3") diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index 5a9c9c9e127..85baa4ee1a4 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -285,7 +285,7 @@ func (p *TSimpleServer) processRequests(client TTransport) (err error) { // ReadFrame is safe to be called multiple times so it // won't break when it's called again later when we // actually start to read the message. - if err := headerProtocol.ReadFrame(); err != nil { + if err := headerProtocol.ReadFrame(ctx); err != nil { return err } ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders()) diff --git a/lib/go/thrift/transport_exception.go b/lib/go/thrift/transport_exception.go index d2283ea2e5c..16193ee8661 100644 --- a/lib/go/thrift/transport_exception.go +++ b/lib/go/thrift/transport_exception.go @@ -64,6 +64,10 @@ func (p *tTransportException) Unwrap() error { return p.err } +func (p *tTransportException) Timeout() bool { + return p.typeId == TIMED_OUT +} + func NewTTransportException(t int, e string) TTransportException { return &tTransportException{typeId: t, err: errors.New(e)} } @@ -92,3 +96,13 @@ func NewTTransportExceptionFromError(e error) TTransportException { return &tTransportException{typeId: UNKNOWN_TRANSPORT_EXCEPTION, err: e} } + +// isTimeoutError returns true when err is a timeout error. +// +// Note that this also includes TTransportException wrapped timeout errors. +func isTimeoutError(err error) bool { + if t, ok := err.(timeoutable); ok { + return t.Timeout() + } + return false +} diff --git a/lib/go/thrift/transport_exception_test.go b/lib/go/thrift/transport_exception_test.go index cf26258d92c..fb1dc2602aa 100644 --- a/lib/go/thrift/transport_exception_test.go +++ b/lib/go/thrift/transport_exception_test.go @@ -20,9 +20,9 @@ package thrift import ( + "errors" "fmt" "io" - "testing" ) @@ -78,3 +78,32 @@ func TestTExceptionEOF(t *testing.T) { t.Errorf("Unwrapped exception did not match: expected %v, got %v", io.EOF, e.Unwrap()) } } + +func TestIsTimeoutError(t *testing.T) { + te := &timeout{true} + if !isTimeoutError(te) { + t.Error("isTimeoutError expected true, got false") + } + e := NewTTransportExceptionFromError(te) + if !isTimeoutError(e) { + t.Error("isTimeoutError on wrapped TTransportException expected true, got false") + } + + te = &timeout{false} + if isTimeoutError(te) { + t.Error("isTimeoutError expected false, got true") + } + e = NewTTransportExceptionFromError(te) + if isTimeoutError(e) { + t.Error("isTimeoutError on wrapped TTransportException expected false, got true") + } + + err := errors.New("foo") + if isTimeoutError(err) { + t.Error("isTimeoutError expected false, got true") + } + e = NewTTransportExceptionFromError(err) + if isTimeoutError(e) { + t.Error("isTimeoutError on wrapped TTransportException expected false, got true") + } +}