From c40c9ba315aaaa58eb4dc74ba9d8bbc5058f1dfa Mon Sep 17 00:00:00 2001
From: Doug Fawley <dfawley@google.com>
Date: Tue, 10 Oct 2023 11:42:11 -0700
Subject: [PATCH] server: prohibit more than MaxConcurrentStreams handlers from
 running at once (#6703) (#6705)

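Previously, MaxConcurrentStreams only limited the number of concurrently
open streams: a handler could keep running after the client cancelled its
stream, so the number of concurrently executing handlers was unbounded.
serveStreams now acquires a slot from a per-transport semaphore before
running each handler and releases it when the handler returns, so at most
MaxConcurrentStreams handlers run at once. A value of 0 still means "no
limit" and is normalized to math.MaxUint32, and the stream worker channel
now carries ready-to-run closures instead of *serverWorkerData.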
---
 benchmark/primitives/primitives_test.go | 39 ++++++++++
 internal/transport/http2_server.go      | 11 +--
 internal/transport/transport_test.go    | 35 +++++----
 server.go                               | 71 ++++++++++++------
 server_ext_test.go                      | 99 +++++++++++++++++++++++++
 5 files changed, 210 insertions(+), 45 deletions(-)
 create mode 100644 server_ext_test.go

diff --git a/benchmark/primitives/primitives_test.go b/benchmark/primitives/primitives_test.go
index dbbb313e8dcc..5d7c81090c15 100644
--- a/benchmark/primitives/primitives_test.go
+++ b/benchmark/primitives/primitives_test.go
@@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) {
 		}
 	})
 }
+
+type ifNop interface {
+	nop()
+}
+
+type alwaysNop struct{}
+
+func (alwaysNop) nop() {}
+
+type concreteNop struct {
+	isNop atomic.Bool
+	i     int
+}
+
+func (c *concreteNop) nop() {
+	if c.isNop.Load() {
+		return
+	}
+	c.i++
+}
+
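+// BenchmarkInterfaceNop and BenchmarkConcreteNop compare the per-call cost of
+// invoking a no-op through an interface with that of calling a concrete
+// method that performs an atomic load before returning.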
+func BenchmarkInterfaceNop(b *testing.B) {
+	n := ifNop(alwaysNop{})
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			n.nop()
+		}
+	})
+}
+
+func BenchmarkConcreteNop(b *testing.B) {
+	n := &concreteNop{}
+	n.isNop.Store(true)
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			n.nop()
+		}
+	})
+}
diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go
index 8d3a353c1d58..c06db679d89c 100644
--- a/internal/transport/http2_server.go
+++ b/internal/transport/http2_server.go
@@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		ID:  http2.SettingMaxFrameSize,
 		Val: http2MaxFrameLen,
 	}}
-	// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
-	// permitted in the HTTP2 spec.
-	maxStreams := config.MaxStreams
-	if maxStreams == 0 {
-		maxStreams = math.MaxUint32
-	} else {
+	if config.MaxStreams != math.MaxUint32 {
 		isettings = append(isettings, http2.Setting{
 			ID:  http2.SettingMaxConcurrentStreams,
-			Val: maxStreams,
+			Val: config.MaxStreams,
 		})
 	}
 	dynamicWindow := true
@@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		framer:            framer,
 		readerDone:        make(chan struct{}),
 		writerDone:        make(chan struct{}),
-		maxStreams:        maxStreams,
+		maxStreams:        config.MaxStreams,
 		inTapHandle:       config.InTapHandle,
 		fc:                &trInFlow{limit: uint32(icwz)},
 		state:             reachable,
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
index 258ef7411cf0..bb27a6b63e98 100644
--- a/internal/transport/transport_test.go
+++ b/internal/transport/transport_test.go
@@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 			return
 		}
 		rawConn := conn
+		if serverConfig.MaxStreams == 0 {
+			serverConfig.MaxStreams = math.MaxUint32
+		}
 		transport, err := NewServerTransport(conn, serverConfig)
 		if err != nil {
 			return
@@ -443,8 +446,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
 	return server
 }
 
-func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
-	return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
+func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
+	return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
 }
 
 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
@@ -539,7 +542,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
 
 // Tests that when streamID > MaxStreamId, the current client transport drains.
 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	callHdr := &CallHdr{
@@ -584,7 +587,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 }
 
 func (s) TestClientSendAndReceive(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host:   "localhost",
@@ -624,7 +627,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
 }
 
 func (s) TestClientErrorNotify(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	go server.stop()
 	// ct.reader should detect the error and activate ct.Error().
@@ -658,7 +661,7 @@ func performOneRPC(ct ClientTransport) {
 }
 
 func (s) TestClientMix(t *testing.T) {
-	s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	s, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	time.AfterFunc(time.Second, s.stop)
 	go func(ct ClientTransport) {
@@ -672,7 +675,7 @@ func (s) TestClientMix(t *testing.T) {
 }
 
 func (s) TestLargeMessage(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host:   "localhost",
@@ -807,7 +810,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
 // proceed until they complete naturally, while not allowing creation of new
 // streams during this window.
 func (s) TestGracefulClose(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
+	server, ct, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer func() {
 		// Stop the server's listener to make the server's goroutines terminate
@@ -873,7 +876,7 @@ func (s) TestGracefulClose(t *testing.T) {
 }
 
 func (s) TestLargeMessageSuspension(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host:   "localhost",
@@ -981,7 +984,7 @@ func (s) TestMaxStreams(t *testing.T) {
 }
 
 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host:   "localhost",
@@ -1453,7 +1456,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
 var encodingTestStatus = status.New(codes.Internal, "\n")
 
 func (s) TestEncodingRequiredStatus(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
+	server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host:   "localhost",
@@ -1481,7 +1484,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
 }
 
 func (s) TestInvalidHeaderField(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host:   "localhost",
@@ -1503,7 +1506,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
 }
 
 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	defer server.stop()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
@@ -2171,7 +2174,7 @@ func (s) TestPingPong1MB(t *testing.T) {
 
 // This is a stress-test of flow control logic.
 func runPingPongTest(t *testing.T, msgSize int) {
-	server, client, cancel := setUp(t, 0, 0, pingpong)
+	server, client, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
@@ -2253,7 +2256,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
 		}
 	}()
 
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
 	defer server.stop()
@@ -2612,7 +2615,7 @@ func TestConnectionError_Unwrap(t *testing.T) {
 
 func (s) TestPeerSetInServerContext(t *testing.T) {
 	// create client and server transports.
-	server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, client, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
diff --git a/server.go b/server.go
index 244123c6c5a8..eeae92fbe020 100644
--- a/server.go
+++ b/server.go
@@ -115,12 +115,6 @@ type serviceInfo struct {
 	mdata       any
 }
 
-type serverWorkerData struct {
-	st     transport.ServerTransport
-	wg     *sync.WaitGroup
-	stream *transport.Stream
-}
-
 // Server is a gRPC server to serve RPC requests.
 type Server struct {
 	opts serverOptions
@@ -145,7 +139,7 @@ type Server struct {
 	channelzID *channelz.Identifier
 	czData     *channelzData
 
-	serverWorkerChannel chan *serverWorkerData
+	serverWorkerChannel chan func()
 }
 
 type serverOptions struct {
@@ -179,6 +173,7 @@ type serverOptions struct {
 }
 
 var defaultServerOptions = serverOptions{
+	maxConcurrentStreams:  math.MaxUint32,
 	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
 	maxSendMessageSize:    defaultServerMaxSendMessageSize,
 	connectionTimeout:     120 * time.Second,
@@ -404,6 +399,9 @@ func MaxSendMsgSize(m int) ServerOption {
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
 // of concurrent streams to each ServerTransport.
 func MaxConcurrentStreams(n uint32) ServerOption {
+	// 0 previously signaled "no limit" and was normalized by the transport;
+	// keep that behavior by normalizing it here instead.
+	if n == 0 {
+		n = math.MaxUint32
+	}
 	return newFuncServerOption(func(o *serverOptions) {
 		o.maxConcurrentStreams = n
 	})
@@ -605,24 +603,19 @@ const serverWorkerResetThreshold = 1 << 16
 // [1] https://github.com/golang/go/issues/18138
 func (s *Server) serverWorker() {
 	for completed := 0; completed < serverWorkerResetThreshold; completed++ {
-		data, ok := <-s.serverWorkerChannel
+		f, ok := <-s.serverWorkerChannel
 		if !ok {
 			return
 		}
-		s.handleSingleStream(data)
+		f()
 	}
 	go s.serverWorker()
 }
 
-func (s *Server) handleSingleStream(data *serverWorkerData) {
-	defer data.wg.Done()
-	s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream))
-}
-
 // initServerWorkers creates worker goroutines and a channel to process incoming
 // connections to reduce the time spent overall on runtime.morestack.
 func (s *Server) initServerWorkers() {
-	s.serverWorkerChannel = make(chan *serverWorkerData)
+	s.serverWorkerChannel = make(chan func())
 	for i := uint32(0); i < s.opts.numServerWorkers; i++ {
 		go s.serverWorker()
 	}
@@ -982,21 +975,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 	defer st.Close(errors.New("finished serving streams for the server transport"))
 	var wg sync.WaitGroup
 
+	// streamQuota bounds the number of handlers that may execute concurrently
+	// on this transport; acquire blocks stream dispatch once the limit is hit.
+	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
 	st.HandleStreams(func(stream *transport.Stream) {
 		wg.Add(1)
+
+		streamQuota.acquire()
+		f := func() {
+			defer streamQuota.release()
+			defer wg.Done()
+			s.handleStream(st, stream, s.traceInfo(st, stream))
+		}
+
 		if s.opts.numServerWorkers > 0 {
-			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
 			select {
-			case s.serverWorkerChannel <- data:
+			case s.serverWorkerChannel <- f:
 				return
 			default:
 				// If all stream workers are busy, fallback to the default code path.
 			}
 		}
-		go func() {
-			defer wg.Done()
-			s.handleStream(st, stream, s.traceInfo(st, stream))
-		}()
+		go f()
 	}, func(ctx context.Context, method string) context.Context {
 		if !EnableTracing {
 			return ctx
@@ -2091,3 +2089,34 @@ func validateSendCompressor(name, clientCompressors string) error {
 	}
 	return fmt.Errorf("client does not support compressor %q", name)
 }
+
+// atomicSemaphore implements a blocking, counting semaphore. acquire should be
+// called synchronously; release may be called asynchronously.
+type atomicSemaphore struct {
+	n    atomic.Int64
+	wait chan struct{}
+}
+
+func (q *atomicSemaphore) acquire() {
+	if q.n.Add(-1) < 0 {
+		// We ran out of quota.  Block until a release happens.
+		<-q.wait
+	}
+}
+
+func (q *atomicSemaphore) release() {
+	// N.B. the "<= 0" check below should allow for this to work with multiple
+	// concurrent calls to acquire, but also note that with synchronous calls to
+	// acquire, as our system does, n will never be less than -1.  There are
+	// fairness issues (queuing) to consider if this were to be generalized.
+	if q.n.Add(1) <= 0 {
+		// An acquire was waiting on us.  Unblock it.
+		q.wait <- struct{}{}
+	}
+}
+
+func newHandlerQuota(n uint32) *atomicSemaphore {
+	a := &atomicSemaphore{wait: make(chan struct{}, 1)}
+	a.n.Store(int64(n))
+	return a
+}
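
A minimal standalone sketch of the pattern above, for illustration only (the
quota/newQuota names are not gRPC APIs): a single dispatching goroutine calls
acquire before launching each task, and each task calls release when it
finishes, so no more than "limit" tasks ever run at the same time.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// quota mirrors atomicSemaphore: acquire is only ever called by the
// dispatcher; release may be called from any worker goroutine.
type quota struct {
	n    atomic.Int64
	wait chan struct{}
}

func newQuota(n int64) *quota {
	q := &quota{wait: make(chan struct{}, 1)}
	q.n.Store(n)
	return q
}

func (q *quota) acquire() {
	if q.n.Add(-1) < 0 {
		<-q.wait // out of quota; block until a worker releases
	}
}

func (q *quota) release() {
	if q.n.Add(1) <= 0 {
		q.wait <- struct{}{} // a dispatcher is blocked; wake it
	}
}

func main() {
	const limit = 2
	q := newQuota(limit)
	var wg sync.WaitGroup
	var running atomic.Int64

	for i := 0; i < 10; i++ {
		q.acquire() // blocks while "limit" tasks are already in flight
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer q.release()
			if n := running.Add(1); n > limit {
				panic("more tasks running than the quota allows")
			}
			running.Add(-1)
		}()
	}
	wg.Wait()
	fmt.Println("all tasks finished; concurrency never exceeded", limit)
}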
diff --git a/server_ext_test.go b/server_ext_test.go
new file mode 100644
index 000000000000..df79755f3255
--- /dev/null
+++ b/server_ext_test.go
@@ -0,0 +1,99 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc_test
+
+import (
+	"context"
+	"io"
+	"testing"
+	"time"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/internal/grpcsync"
+	"google.golang.org/grpc/internal/stubserver"
+
+	testgrpc "google.golang.org/grpc/interop/grpc_testing"
+)
+
+// TestServer_MaxHandlers ensures that no more than MaxConcurrentStreams server
+// handlers are active at one time.
+func (s) TestServer_MaxHandlers(t *testing.T) {
+	started := make(chan struct{})
+	blockCalls := grpcsync.NewEvent()
+
+	// This stub server does not properly respect the stream context, so it will
+	// not exit when the context is canceled.
+	ss := stubserver.StubServer{
+		FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
+			started <- struct{}{}
+			<-blockCalls.Done()
+			return nil
+		},
+	}
+	if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}); err != nil {
+		t.Fatal("Error starting server:", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Start one RPC to the server.
+	ctx1, cancel1 := context.WithCancel(ctx)
+	_, err := ss.Client.FullDuplexCall(ctx1)
+	if err != nil {
+		t.Fatal("Error starting call:", err)
+	}
+
+	// Wait for the handler to be invoked.
+	select {
+	case <-started:
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for RPC to start on server.")
+	}
+
+	// Cancel it on the client.  The server handler will still be running.
+	cancel1()
+
+	ctx2, cancel2 := context.WithCancel(ctx)
+	defer cancel2()
+	s, err := ss.Client.FullDuplexCall(ctx2)
+	if err != nil {
+		t.Fatal("Error starting call:", err)
+	}
+
+	// After 100ms, allow the first call to unblock.  That should allow the
+	// second RPC to run and finish.
+	select {
+	case <-started:
+		blockCalls.Fire()
+		t.Fatalf("RPC started unexpectedly.")
+	case <-time.After(100 * time.Millisecond):
+		blockCalls.Fire()
+	}
+
+	select {
+	case <-started:
+	case <-ctx.Done():
+		t.Fatalf("Timed out waiting for second RPC to start on server.")
+	}
+	if _, err := s.Recv(); err != io.EOF {
+		t.Fatal("Received unexpected RPC error:", err)
+	}
+}
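
For reference, nothing changes in how a server opts into the limit; with this
patch the same existing option also bounds concurrently executing handlers
(per connection), as in this small sketch:

package main

import "google.golang.org/grpc"

func main() {
	// At most 100 handlers run concurrently per transport; additional streams
	// wait for a running handler to finish before being dispatched.
	s := grpc.NewServer(grpc.MaxConcurrentStreams(100))
	defer s.Stop()
}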