Skip to content

Commit

Permalink
semaphore: add worker-pool example
Browse files Browse the repository at this point in the history
I've commented several times in various forums that basically every
time I've seen the “worker goroutine” pattern in Go, there has turned
out to be a cleaner implementation using semaphores.

This change adds a simple such example. (For more complex usage, I
would generally pair the semaphore with an errgroup.Group.)

Change-Id: Ibf69ee761d14ba59c1acc6a2d595b4fcf0d8f6d6
Reviewed-on: https://go-review.googlesource.com/75170
Reviewed-by: Ross Light <[email protected]>
  • Loading branch information
Bryan C. Mills committed Nov 1, 2017
1 parent 8e0aa68 commit fd80eb9
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 12 deletions.
7 changes: 4 additions & 3 deletions semaphore/semaphore_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

// +build go1.7

package semaphore
package semaphore_test

import (
"fmt"
"testing"

"golang.org/x/net/context"
"golang.org/x/sync/semaphore"
)

// weighted is an interface matching a subset of *Weighted. It allows
Expand Down Expand Up @@ -85,7 +86,7 @@ func BenchmarkNewSeq(b *testing.B) {
for _, cap := range []int64{1, 128} {
b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = NewWeighted(cap)
_ = semaphore.NewWeighted(cap)
}
})
b.Run(fmt.Sprintf("semChan-%d", cap), func(b *testing.B) {
Expand Down Expand Up @@ -116,7 +117,7 @@ func BenchmarkAcquireSeq(b *testing.B) {
name string
w weighted
}{
{"Weighted", NewWeighted(c.cap)},
{"Weighted", semaphore.NewWeighted(c.cap)},
{"semChan", newSemChan(c.cap)},
} {
b.Run(fmt.Sprintf("%s-acquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) {
Expand Down
84 changes: 84 additions & 0 deletions semaphore/semaphore_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package semaphore_test

import (
"context"
"fmt"
"log"
"runtime"

"golang.org/x/sync/semaphore"
)

// Example_workerPool demonstrates how to use a semaphore to limit the number of
// goroutines working on parallel tasks.
//
// This use of a semaphore mimics a typical “worker pool” pattern, but without
// the need to explicitly shut down idle workers when the work is done.
func Example_workerPool() {
ctx := context.TODO()

var (
maxWorkers = runtime.GOMAXPROCS(0)
sem = semaphore.NewWeighted(int64(maxWorkers))
out = make([]int, 32)
)

// Compute the output using up to maxWorkers goroutines at a time.
for i := range out {
// When maxWorkers goroutines are in flight, Acquire blocks until one of the
// workers finishes.
if err := sem.Acquire(ctx, 1); err != nil {
log.Printf("Failed to acquire semaphore: %v", err)
break
}

go func(i int) {
defer sem.Release(1)
out[i] = collatzSteps(i + 1)
}(i)
}

// Acquire all of the tokens to wait for any remaining workers to finish.
//
// If you are already waiting for the workers by some other means (such as an
// errgroup.Group), you can omit this final Acquire call.
if err := sem.Acquire(ctx, int64(maxWorkers)); err != nil {
log.Printf("Failed to acquire semaphore: %v", err)
}

fmt.Println(out)

// Output:
// [0 1 7 2 5 8 16 3 19 6 14 9 9 17 17 4 12 20 20 7 7 15 15 10 23 10 111 18 18 18 106 5]
}

// collatzSteps computes the number of steps to reach 1 under the Collatz
// conjecture. (See https://en.wikipedia.org/wiki/Collatz_conjecture.)
func collatzSteps(n int) (steps int) {
if n <= 0 {
panic("nonpositive input")
}

for ; n > 1; steps++ {
if steps < 0 {
panic("too many steps")
}

if n%2 == 0 {
n /= 2
continue
}

const maxInt = int(^uint(0) >> 1)
if n > (maxInt-1)/3 {
panic("overflow")
}
n = 3*n + 1
}

return steps
}
19 changes: 10 additions & 9 deletions semaphore/semaphore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package semaphore
package semaphore_test

import (
"math/rand"
Expand All @@ -13,11 +13,12 @@ import (

"golang.org/x/net/context"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)

const maxSleep = 1 * time.Millisecond

func HammerWeighted(sem *Weighted, n int64, loops int) {
func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) {
for i := 0; i < loops; i++ {
sem.Acquire(context.Background(), n)
time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond)
Expand All @@ -30,7 +31,7 @@ func TestWeighted(t *testing.T) {

n := runtime.GOMAXPROCS(0)
loops := 10000 / n
sem := NewWeighted(int64(n))
sem := semaphore.NewWeighted(int64(n))
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
Expand All @@ -51,15 +52,15 @@ func TestWeightedPanic(t *testing.T) {
t.Fatal("release of an unacquired weighted semaphore did not panic")
}
}()
w := NewWeighted(1)
w := semaphore.NewWeighted(1)
w.Release(1)
}

func TestWeightedTryAcquire(t *testing.T) {
t.Parallel()

ctx := context.Background()
sem := NewWeighted(2)
sem := semaphore.NewWeighted(2)
tries := []bool{}
sem.Acquire(ctx, 1)
tries = append(tries, sem.TryAcquire(1))
Expand All @@ -83,7 +84,7 @@ func TestWeightedAcquire(t *testing.T) {
t.Parallel()

ctx := context.Background()
sem := NewWeighted(2)
sem := semaphore.NewWeighted(2)
tryAcquire := func(n int64) bool {
ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel()
Expand Down Expand Up @@ -113,7 +114,7 @@ func TestWeightedDoesntBlockIfTooBig(t *testing.T) {
t.Parallel()

const n = 2
sem := NewWeighted(n)
sem := semaphore.NewWeighted(n)
{
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -132,7 +133,7 @@ func TestWeightedDoesntBlockIfTooBig(t *testing.T) {
})
}
if err := g.Wait(); err != nil {
t.Errorf("NewWeighted(%v) failed to AcquireCtx(_, 1) with AcquireCtx(_, %v) pending", n, n+1)
t.Errorf("semaphore.NewWeighted(%v) failed to AcquireCtx(_, 1) with AcquireCtx(_, %v) pending", n, n+1)
}
}

Expand All @@ -143,7 +144,7 @@ func TestLargeAcquireDoesntStarve(t *testing.T) {

ctx := context.Background()
n := int64(runtime.GOMAXPROCS(0))
sem := NewWeighted(n)
sem := semaphore.NewWeighted(n)
running := true

var wg sync.WaitGroup
Expand Down

0 comments on commit fd80eb9

Please sign in to comment.