diff --git a/mongo/integration/crud_prose_test.go b/mongo/integration/crud_prose_test.go index 8fa17aa92c..4b74e8db41 100644 --- a/mongo/integration/crud_prose_test.go +++ b/mongo/integration/crud_prose_test.go @@ -145,19 +145,17 @@ func TestHintErrors(t *testing.T) { }) } -func TestAggregateSecondaryPreferredReadPreference(t *testing.T) { - // Use secondaryPreferred instead of secondary because sharded clusters started up by mongo-orchestration have - // one-node shards, so a secondary read preference is not satisfiable. - secondaryPrefClientOpts := options.Client(). +func TestAggregatePrimaryPreferredReadPreference(t *testing.T) { + primaryPrefClientOpts := options.Client(). SetWriteConcern(mtest.MajorityWc). - SetReadPreference(readpref.SecondaryPreferred()). + SetReadPreference(readpref.PrimaryPreferred()). SetReadConcern(mtest.MajorityRc) mtOpts := mtest.NewOptions(). - ClientOptions(secondaryPrefClientOpts). + ClientOptions(primaryPrefClientOpts). MinServerVersion("4.1.0") // Consistent with tests in aggregate-out-readConcern.json mt := mtest.New(t, mtOpts) - mt.Run("aggregate $out with read preference secondary", func(mt *mtest.T) { + mt.Run("aggregate $out with non-primary read ppreference", func(mt *mtest.T) { doc, err := bson.Marshal(bson.D{ {"_id", 1}, {"x", 11}, @@ -167,7 +165,7 @@ func TestAggregateSecondaryPreferredReadPreference(t *testing.T) { assert.Nil(mt, err, "InsertOne error: %v", err) mt.ClearEvents() - outputCollName := "aggregate-read-pref-secondary-output" + outputCollName := "aggregate-read-pref-primary-preferred-output" outStage := bson.D{ {"$out", outputCollName}, } diff --git a/mongo/integration/errors_test.go b/mongo/integration/errors_test.go index 8e6c6772fd..822cc2908e 100644 --- a/mongo/integration/errors_test.go +++ b/mongo/integration/errors_test.go @@ -80,14 +80,16 @@ func TestErrors(t *testing.T) { "errors.Is failure: expected error %v to be %v", err, context.DeadlineExceeded) }) - socketTimeoutOpts := options.Client(). - SetSocketTimeout(100 * time.Millisecond) - socketTimeoutMtOpts := mtest.NewOptions(). - ClientOptions(socketTimeoutOpts) - mt.RunOpts("socketTimeoutMS timeouts return network errors", socketTimeoutMtOpts, func(mt *mtest.T) { + mt.Run("socketTimeoutMS timeouts return network errors", func(mt *mtest.T) { _, err := mt.Coll.InsertOne(mtest.Background, bson.D{{"x", 1}}) assert.Nil(mt, err, "InsertOne error: %v", err) + // Reset the test client to have a 100ms socket timeout. We do this here rather than passing it in as a + // test option using mt.RunOpts because that could cause the collection creation or InsertOne to fail. + resetClientOpts := options.Client(). + SetSocketTimeout(100 * time.Millisecond) + mt.ResetClient(resetClientOpts) + mt.ClearEvents() filter := bson.M{ "$where": "function() { sleep(1000); return false; }", diff --git a/mongo/integration/sessions_test.go b/mongo/integration/sessions_test.go index ba419424cf..6d66f699a8 100644 --- a/mongo/integration/sessions_test.go +++ b/mongo/integration/sessions_test.go @@ -252,43 +252,27 @@ func TestSessions(t *testing.T) { deleteID := extractSentSessionID(mt) assert.Equal(mt, findID, deleteID, "expected session ID %v, got %v", findID, deleteID) }) - mt.RunOpts("find and getMore use same ID", noClientOpts, func(mt *mtest.T) { - testCases := []struct { - name string - rp *readpref.ReadPref - topos []mtest.TopologyKind // if nil, all will be used - }{ - {"primary", readpref.Primary(), nil}, - {"primaryPreferred", readpref.PrimaryPreferred(), nil}, - {"secondary", readpref.Secondary(), []mtest.TopologyKind{mtest.ReplicaSet}}, - {"secondaryPreferred", readpref.SecondaryPreferred(), nil}, - {"nearest", readpref.Nearest(), nil}, + mt.Run("find and getMore use same ID", func(mt *mtest.T) { + var docs []interface{} + for i := 0; i < 3; i++ { + docs = append(docs, bson.D{{"x", i}}) } - for _, tc := range testCases { - clientOpts := options.Client().SetReadPreference(tc.rp).SetWriteConcern(mtest.MajorityWc) - mt.RunOpts(tc.name, mtest.NewOptions().ClientOptions(clientOpts).Topologies(tc.topos...), func(mt *mtest.T) { - var docs []interface{} - for i := 0; i < 3; i++ { - docs = append(docs, bson.D{{"x", i}}) - } - _, err := mt.Coll.InsertMany(mtest.Background, docs) - assert.Nil(mt, err, "InsertMany error: %v", err) - - // run a find that will hold onto an implicit session and record the session ID - mt.ClearEvents() - cursor, err := mt.Coll.Find(mtest.Background, bson.D{}, options.Find().SetBatchSize(2)) - assert.Nil(mt, err, "Find error: %v", err) - findID := extractSentSessionID(mt) - assert.NotNil(mt, findID, "expected session ID for find, got nil") - - // iterate over all documents and record the session ID of the getMore - for i := 0; i < 3; i++ { - assert.True(mt, cursor.Next(mtest.Background), "Next returned false on iteration %v", i) - } - getMoreID := extractSentSessionID(mt) - assert.Equal(mt, findID, getMoreID, "expected session ID %v, got %v", findID, getMoreID) - }) + _, err := mt.Coll.InsertMany(mtest.Background, docs) + assert.Nil(mt, err, "InsertMany error: %v", err) + + // run a find that will hold onto an implicit session and record the session ID + mt.ClearEvents() + cursor, err := mt.Coll.Find(mtest.Background, bson.D{}, options.Find().SetBatchSize(2)) + assert.Nil(mt, err, "Find error: %v", err) + findID := extractSentSessionID(mt) + assert.NotNil(mt, findID, "expected session ID for find, got nil") + + // iterate over all documents and record the session ID of the getMore + for i := 0; i < 3; i++ { + assert.True(mt, cursor.Next(mtest.Background), "Next returned false on iteration %v", i) } + getMoreID := extractSentSessionID(mt) + assert.Equal(mt, findID, getMoreID, "expected session ID %v, got %v", findID, getMoreID) }) mt.Run("imperative API", func(mt *mtest.T) { diff --git a/x/mongo/driver/topology/cmap_prose_test.go b/x/mongo/driver/topology/cmap_prose_test.go index 18504eaa57..37c987347f 100644 --- a/x/mongo/driver/topology/cmap_prose_test.go +++ b/x/mongo/driver/topology/cmap_prose_test.go @@ -78,7 +78,15 @@ func TestCMAPProse(t *testing.T) { _, err := pool.get(context.Background()) assert.NotNil(t, err, "expected get() error, got nil") - assertConnectionCounts(t, pool, 1, 1) + + // If the connection doesn't finish connecting before resourcePool gives it back, the error will be + // detected by pool.get and result in a created/closed count of 1. If it does finish connecting, the + // error will be detected by resourcePool, which will return nil. Then, pool will try to create a new + // connection, which will also error. This process will result in a created/closed count of 2. + assert.True(t, len(created) == 1 || len(created) == 2, "expected 1 or 2 opened events, got %d", len(created)) + assert.True(t, len(closed) == 1 || len(closed) == 2, "expected 1 or 2 closed events, got %d", len(closed)) + netCount := len(created) - len(closed) + assert.Equal(t, 0, netCount, "expected net connection count to be 0, got %d", netCount) }) t.Run("pool is empty", func(t *testing.T) { // If a new connection is created during get(), get() should report that error and publish an event. diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 25222ad95e..53c4cb4f2d 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -115,21 +115,65 @@ func TestConnection(t *testing.T) { assert.NotNil(t, err, "expected connect error %v, got nil", want) assert.Equal(t, want, got, "expected error %v, got %v", want, got) }) - t.Run("cancelConnectContext is nil after connect", func(t *testing.T) { - conn, err := newConnection(address.Address("")) - assert.Nil(t, err, "newConnection shouldn't error. got %v; want nil", err) - var wg sync.WaitGroup - wg.Add(1) + t.Run("context is not pinned by connect", func(t *testing.T) { + // connect creates a cancel-able version of the context passed to it and stores the CancelFunc on the + // connection. The CancelFunc must be set to nil once the connection has been established so the driver + // does not pin the memory associated with the context for the connection's lifetime. + + t.Run("connect succeeds", func(t *testing.T) { + // In the case where connect finishes successfully, it unpins the CancelFunc. + + conn, err := newConnection(address.Address(""), + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return &net.TCPConn{}, nil + }) + }), + WithHandshaker(func(Handshaker) Handshaker { + return &testHandshaker{} + }), + ) + assert.Nil(t, err, "newConnection error: %v", err) - go func() { - defer wg.Done() conn.connect(context.Background()) - assert.Nil(t, conn.cancelConnectContext, "expected nil, got context.CancelFunc") - }() - - conn.closeConnectContext() - assert.Nil(t, conn.cancelConnectContext, "expected nil, got context.CancelFunc") - wg.Wait() + err = conn.wait() + assert.Nil(t, err, "error establishing connection: %v", err) + assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared") + }) + t.Run("connect cancelled", func(t *testing.T) { + // In the case where connection establishment is cancelled, the closeConnectContext function + // unpins the CancelFunc. + + // Create a connection that will block in connect until doneChan is closed. This prevents + // connect from succeeding and unpinning the CancelFunc. + doneChan := make(chan struct{}) + conn, err := newConnection(address.Address(""), + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + <-doneChan + return &net.TCPConn{}, nil + }) + }), + WithHandshaker(func(Handshaker) Handshaker { + return &testHandshaker{} + }), + ) + assert.Nil(t, err, "newConnection error: %v", err) + + // Call connect in a goroutine because it will block. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + conn.connect(context.Background()) + }() + + // Simulate cancelling connection establishment and assert that this cleares the CancelFunc. + conn.closeConnectContext() + assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared") + close(doneChan) + wg.Wait() + }) }) }) t.Run("writeWireMessage", func(t *testing.T) {