diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 90353ce9b5c..54422c82670 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -2797,13 +2797,13 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "if thrift.ServerConnectivityCheckInterval > 0 {" << endl; indent_up(); - f_types_ << indent() << "var cancel context.CancelFunc" << endl; - f_types_ << indent() << "ctx, cancel = context.WithCancel(ctx)" << endl; - f_types_ << indent() << "defer cancel()" << endl; + f_types_ << indent() << "var cancel context.CancelCauseFunc" << endl; + f_types_ << indent() << "ctx, cancel = context.WithCancelCause(ctx)" << endl; + f_types_ << indent() << "defer cancel(nil)" << endl; f_types_ << indent() << "var tickerCtx context.Context" << endl; f_types_ << indent() << "tickerCtx, tickerCancel = context.WithCancel(context.Background())" << endl; f_types_ << indent() << "defer tickerCancel()" << endl; - f_types_ << indent() << "go func(ctx context.Context, cancel context.CancelFunc) {" << endl; + f_types_ << indent() << "go func(ctx context.Context, cancel context.CancelCauseFunc) {" << endl; indent_up(); f_types_ << indent() << "ticker := time.NewTicker(thrift.ServerConnectivityCheckInterval)" << endl; @@ -2821,7 +2821,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* indent_up(); f_types_ << indent() << "if !iprot.Transport().IsOpen() {" << endl; indent_up(); - f_types_ << indent() << "cancel()" << endl; + f_types_ << indent() << "cancel(thrift.ErrAbandonRequest)" << endl; f_types_ << indent() << "return" << endl; indent_down(); f_types_ << indent() << "}" << endl; @@ -2901,6 +2901,15 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "return false, thrift.WrapTException(err2)" << endl; indent_down(); f_types_ << indent() << "}" << endl; + f_types_ << indent() << "if errors.Is(err2, context.Canceled) {" << endl; + indent_up(); + f_types_ << indent() << "if err := context.Cause(ctx); errors.Is(err, thrift.ErrAbandonRequest) {" << endl; + indent_up(); + f_types_ << indent() << "return false, thrift.WrapTException(err)" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; + indent_down(); + f_types_ << indent() << "}" << endl; string exc(tmp("_exc")); f_types_ << indent() << exc << " := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, " diff --git a/lib/go/README.md b/lib/go/README.md index b2cf1df12c2..0aa4f1bc6f2 100644 --- a/lib/go/README.md +++ b/lib/go/README.md @@ -108,13 +108,19 @@ The context object passed into the server handler function will be canceled when the client closes the connection (this is a best effort check, not a guarantee -- there's no guarantee that the context object is always canceled when client closes the connection, but when it's canceled you can always assume the client -closed the connection). When implementing Go Thrift server, you can take -advantage of that to abandon requests that's no longer needed: +closed the connection). The cause of the cancellation (via `context.Cause(ctx)`) +would also be set to `thrift.ErrAbandonRequest`. + +When implementing Go Thrift server, you can take advantage of that to abandon +requests that's no longer needed by returning `thrift.ErrAbandonRequest`: func MyEndpoint(ctx context.Context, req *thriftRequestType) (*thriftResponseType, error) { ... if ctx.Err() == context.Canceled { return nil, thrift.ErrAbandonRequest + // Or just return ctx.Err(), compiler generated processor code will + // handle it for you automatically: + // return nil, ctx.Err() } ... } @@ -155,4 +161,4 @@ will wait for all the client connections to be closed gracefully with zero err time. Otherwise, the stop will wait for all the client connections to be closed gracefully util thrift.ServerStopTimeout is reached, and client connections that are not closed after thrift.ServerStopTimeout -will be closed abruptly which may cause some client errors. \ No newline at end of file +will be closed abruptly which may cause some client errors. diff --git a/lib/go/test/tests/server_connectivity_check_test.go b/lib/go/test/tests/server_connectivity_check_test.go new file mode 100644 index 00000000000..51710eda291 --- /dev/null +++ b/lib/go/test/tests/server_connectivity_check_test.go @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package tests + +import ( + "context" + "runtime/debug" + "testing" + "time" + + "github.com/apache/thrift/lib/go/test/gopath/src/clientmiddlewareexceptiontest" + "github.com/apache/thrift/lib/go/thrift" +) + +func TestServerConnectivityCheck(t *testing.T) { + const ( + // Server will sleep for longer than client is willing to wait + // so client will close the connection. + serverSleep = 50 * time.Millisecond + clientSocketTimeout = time.Millisecond + ) + serverSocket, err := thrift.NewTServerSocket(":0") + if err != nil { + t.Fatalf("failed to create server socket: %v", err) + } + processor := clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestProcessor(fakeClientMiddlewareExceptionTestHandler( + func(ctx context.Context) (*clientmiddlewareexceptiontest.FooResponse, error) { + time.Sleep(serverSleep) + err := ctx.Err() + if err == nil { + t.Error("Expected server ctx to be cancelled, did not happen") + return new(clientmiddlewareexceptiontest.FooResponse), nil + } + return nil, err + }, + )) + server := thrift.NewTSimpleServer2(processor, serverSocket) + if err := server.Listen(); err != nil { + t.Fatalf("failed to listen server: %v", err) + } + server.SetLogger(func(msg string) { + t.Errorf("Server logger called with %q", msg) + t.Errorf("Server logger callstack:\n%s", debug.Stack()) + }) + addr := serverSocket.Addr().String() + go server.Serve() + t.Cleanup(func() { + server.Stop() + }) + + cfg := &thrift.TConfiguration{ + SocketTimeout: clientSocketTimeout, + } + socket := thrift.NewTSocketConf(addr, cfg) + if err := socket.Open(); err != nil { + t.Fatalf("failed to create client connection: %v", err) + } + t.Cleanup(func() { + socket.Close() + }) + inProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) + outProtocol := thrift.NewTBinaryProtocolConf(socket, cfg) + client := thrift.NewTStandardClient(inProtocol, outProtocol) + ctx, cancel := context.WithTimeout(context.Background(), clientSocketTimeout) + defer cancel() + _, err = clientmiddlewareexceptiontest.NewClientMiddlewareExceptionTestClient(client).Foo(ctx) + socket.Close() + if err == nil { + t.Error("Expected client to time out, did not happen") + } +} diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index c5c14feed5f..d4f555ccd51 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "io" + "net" "sync" "sync/atomic" "time" @@ -354,7 +355,13 @@ func (p *TSimpleServer) processRequests(client TTransport) (err error) { ok, err := processor.Process(ctx, inputProtocol, outputProtocol) if errors.Is(err, ErrAbandonRequest) { - return client.Close() + err := client.Close() + if errors.Is(err, net.ErrClosed) { + // In this case, it's kinda expected to get + // net.ErrClosed, treat that as no-error + return nil + } + return err } if errors.As(err, new(TTransportException)) && err != nil { return err