diff --git a/cmd/demoserver/main_test.go b/cmd/demoserver/main_test.go index 5b22b5a..d643d2d 100644 --- a/cmd/demoserver/main_test.go +++ b/cmd/demoserver/main_test.go @@ -20,12 +20,12 @@ import ( "io" "net/http" "net/http/httptest" - "sync" "testing" "connectrpc.com/connect" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" elizav1 "connect-examples-go/internal/gen/connectrpc/eliza/v1" "connect-examples-go/internal/gen/connectrpc/eliza/v1/elizav1connect" @@ -66,43 +66,37 @@ func TestElizaServer(t *testing.T) { for _, client := range clients { sendValues := []string{"Hello!", "How are you doing?", "I have an issue with my bike", "bye"} var receivedValues []string - stream := client.Converse(context.Background()) - var wg sync.WaitGroup - wg.Add(2) - errs := make(chan error, 4) - go func() { - defer wg.Done() + grp, ctx := errgroup.WithContext(context.Background()) + stream := client.Converse(ctx) + grp.Go(func() error { for _, sentence := range sendValues { err := stream.Send(&elizav1.ConverseRequest{Sentence: sentence}) - errs <- err + if err != nil { + return err + } } err := stream.CloseRequest() - errs <- err - }() - go func() { - defer wg.Done() + if err != nil { + return err + } + return nil + }) + grp.Go(func() error { for { msg, err := stream.Receive() if errors.Is(err, io.EOF) { break } - errs <- err assert.NotEmpty(t, msg.GetSentence()) receivedValues = append(receivedValues, msg.GetSentence()) } err := stream.CloseResponse() - errs <- err - }() - // close errs once all children finish. - go func() { - wg.Wait() - close(errs) - }() - for err := range errs { if err != nil { - t.Fatal(err) + return err } - } + return nil + }) + require.NoError(t, grp.Wait()) assert.Equal(t, len(receivedValues), len(sendValues)) } }) diff --git a/go.mod b/go.mod index 1fa407f..1e24637 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/text v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7016f85..b409c3e 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=