diff --git a/pkg/sql/pgwire/command_result.go b/pkg/sql/pgwire/command_result.go index 440f9edee10c..e9c1f18a95d8 100644 --- a/pkg/sql/pgwire/command_result.go +++ b/pkg/sql/pgwire/command_result.go @@ -191,7 +191,7 @@ func (r *commandResult) SetError(err error) { // addInternal is the skeleton of AddRow and AddBatch implementations. // bufferData should update rowsAffected and buffer the data accordingly. -func (r *commandResult) addInternal(bufferData func()) error { +func (r *commandResult) addInternal(bufferData func() error) error { r.assertNotReleased() if r.err != nil { panic(errors.AssertionFailedf("can't call AddRow after having set error: %s", @@ -205,30 +205,29 @@ func (r *commandResult) addInternal(bufferData func()) error { panic("can't send row after error") } - bufferData() - - var err error - if r.bufferingDisabled { - err = r.conn.Flush(r.pos) - } else { - _ /* flushed */, err = r.conn.maybeFlush(r.pos) + if err := bufferData(); err != nil { + return err } - return err + + return r.conn.maybeFlush(r.pos, r.bufferingDisabled) } // AddRow is part of the sql.RestrictedCommandResult interface. func (r *commandResult) AddRow(ctx context.Context, row tree.Datums) error { - return r.addInternal(func() { + return r.addInternal(func() error { r.rowsAffected++ r.conn.bufferRow(ctx, row, r.formatCodes, r.conv, r.location, r.types) + return nil }) } // AddBatch is part of the sql.RestrictedCommandResult interface. func (r *commandResult) AddBatch(ctx context.Context, batch coldata.Batch) error { - return r.addInternal(func() { + return r.addInternal(func() error { r.rowsAffected += batch.Length() - r.conn.bufferBatch(ctx, batch, r.formatCodes, r.conv, r.location) + return r.conn.bufferBatch( + ctx, batch, r.formatCodes, r.conv, r.location, r.pos, r.bufferingDisabled, + ) }) } @@ -439,10 +438,7 @@ func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) erro return r.moreResultsNeeded(ctx) } - if _ /* flushed */, err := r.conn.maybeFlush(r.pos); err != nil { - return err - } - return nil + return r.conn.maybeFlush(r.pos, r.bufferingDisabled) } // SupportsAddBatch is part of the sql.RestrictedCommandResult interface. diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 3953c27c219f..0923a73468fd 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -1281,7 +1281,8 @@ func (c *conn) bufferRow( } // bufferBatch serializes a batch and adds all the rows from it to the buffer. -// It is a noop for zero-length batch. +// It is a noop for zero-length batch. Depending on the buffer size limit, +// bufferBatch may flush the buffered data to the connection. // // formatCodes describes the desired encoding for each column. It can be nil, in // which case all columns are encoded using the text encoding. Otherwise, it @@ -1292,7 +1293,9 @@ func (c *conn) bufferBatch( formatCodes []pgwirebase.FormatCode, conv sessiondatapb.DataConversionConfig, sessionLoc *time.Location, -) { + pos sql.CmdPos, + bufferingDisabled bool, +) error { sel := batch.Selection() n := batch.Length() vecs := batch.ColVecs() @@ -1321,7 +1324,11 @@ func (c *conn) bufferBatch( if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { panic(fmt.Sprintf("unexpected err from buffer: %s", err)) } + if err := c.maybeFlush(pos, bufferingDisabled); err != nil { + return err + } } + return nil } func (c *conn) bufferReadyForQuery(txnStatus byte) { @@ -1524,12 +1531,12 @@ func (c *conn) Flush(pos sql.CmdPos) error { } // maybeFlush flushes the buffer to the network connection if it exceeded -// sessionArgs.ConnResultsBufferSize. -func (c *conn) maybeFlush(pos sql.CmdPos) (bool, error) { - if int64(c.writerState.buf.Len()) <= c.sessionArgs.ConnResultsBufferSize { - return false, nil +// sessionArgs.ConnResultsBufferSize or if buffering is disabled. +func (c *conn) maybeFlush(pos sql.CmdPos, bufferingDisabled bool) error { + if !bufferingDisabled && int64(c.writerState.buf.Len()) <= c.sessionArgs.ConnResultsBufferSize { + return nil } - return true, c.Flush(pos) + return c.Flush(pos) } // LockCommunication is part of the ClientComm interface.