diff --git a/nomad/util.go b/nomad/util.go index f3de97b9e9b..8b8594661df 100644 --- a/nomad/util.go +++ b/nomad/util.go @@ -330,12 +330,21 @@ func dropButLastChannel(sourceCh <-chan bool, shutdownCh <-chan struct{}) chan b dst := make(chan bool) go func() { + // last value received lv := false + // ok source was closed + ok := false + // received message since last delivery to destination + messageReceived := false DEQUE_SOURCE: // wait for first message select { - case lv = <-sourceCh: + case lv, ok = <-sourceCh: + if !ok { + goto SOURCE_CLOSED + } + messageReceived = true goto ENQUEUE_DST case <-shutdownCh: return @@ -345,7 +354,11 @@ func dropButLastChannel(sourceCh <-chan bool, shutdownCh <-chan struct{}) chan b // prioritize draining source first dequeue without blocking for { select { - case lv = <-sourceCh: + case lv, ok = <-sourceCh: + if !ok { + goto SOURCE_CLOSED + } + messageReceived = true default: break ENQUEUE_DST } @@ -353,14 +366,29 @@ func dropButLastChannel(sourceCh <-chan bool, shutdownCh <-chan struct{}) chan b // attempt to enqueue but keep monitoring source channel select { - case lv = <-sourceCh: + case lv, ok = <-sourceCh: + if !ok { + goto SOURCE_CLOSED + } + messageReceived = true goto ENQUEUE_DST case dst <- lv: + messageReceived = false // enqueued value; back to dequeing from source goto DEQUE_SOURCE case <-shutdownCh: return } + + SOURCE_CLOSED: + if messageReceived { + select { + case dst <- lv: + case <-shutdownCh: + return + } + } + close(dst) }() return dst diff --git a/nomad/util_test.go b/nomad/util_test.go index 145ef68d890..67bd5cab93a 100644 --- a/nomad/util_test.go +++ b/nomad/util_test.go @@ -379,3 +379,55 @@ RECEIVE_LOOP: require.Equal(t, 1, receivedFalse) require.LessOrEqual(t, receivedTrue, sentMessages-1) } + +// TestDropButLastChannel_DeliversMessages_Close asserts that last +// message is always delivered, some messages are dropped but never +// introduce new messages, even with a closed signal. +func TestDropButLastChannel_DeliversMessages_Close(t *testing.T) { + sourceCh := make(chan bool) + shutdownCh := make(chan struct{}) + + dstCh := dropButLastChannel(sourceCh, shutdownCh) + + // timeout duration for any channel propagation delay + timeoutDuration := 5 * time.Millisecond + + sentMessages := 100 + go func() { + for i := 0; i < sentMessages-1; i++ { + sourceCh <- true + } + sourceCh <- false + close(sourceCh) + }() + + receivedTrue, receivedFalse := 0, 0 + var lastReceived *bool + +RECEIVE_LOOP: + for { + select { + case v, ok := <-dstCh: + if !ok { + break RECEIVE_LOOP + } + lastReceived = &v + if v { + receivedTrue++ + } else { + receivedFalse++ + } + + case <-time.After(timeoutDuration): + require.Fail(t, "timed out while waiting for messages") + } + } + + t.Logf("receiver got %v out %v true messages, and %v out of %v false messages", + receivedTrue, sentMessages-1, receivedFalse, 1) + + require.NotNil(t, lastReceived) + require.False(t, *lastReceived) + require.Equal(t, 1, receivedFalse) + require.LessOrEqual(t, receivedTrue, sentMessages-1) +}