Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make COPY TO STDOUT robust to errors #331

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pgserver/arrowwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgserver
import (
"os"
"strings"
"sync/atomic"

"github.com/apache/arrow-go/v18/arrow/ipc"
"github.com/apecloud/myduckserver/adapter"
Expand Down Expand Up @@ -62,24 +63,31 @@ func NewArrowWriter(
}, nil
}

func (dw *ArrowWriter) Start() (string, chan CopyToResult, error) {
func (dw *ArrowWriter) Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error) {
// Execute the statement in a separate goroutine.
ch := make(chan CopyToResult, 1)
go func() {
defer os.Remove(dw.pipePath)
defer close(ch)

dw.ctx.GetLogger().Tracef("Executing statement via Arrow interface: %s", dw.duckSQL)
conn, err := adapter.GetConn(dw.ctx)
if err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}

// If there is a global error, return immediately.
if e := globalErr.Load(); e != nil {
ch <- CopyToResult{Err: *e}
return
}

// Open the pipe for writing.
// This operation will block until the reader opens the pipe for reading.
pipe, err := os.OpenFile(dw.pipePath, os.O_WRONLY, os.ModeNamedPipe)
if err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}
Expand Down Expand Up @@ -114,6 +122,7 @@ func (dw *ArrowWriter) Start() (string, chan CopyToResult, error) {
}
return recordReader.Err()
}); err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}
Expand Down
44 changes: 22 additions & 22 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1414,13 +1414,13 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
}
defer writer.Close()

pipePath, ch, err := writer.Start()
var globalErr atomic.Pointer[error]
pipePath, ch, err := writer.Start(&globalErr)
if err != nil {
return err
}

done := make(chan struct{})
var globalErr atomic.Value
var blocked atomic.Bool
blocked.Store(true)
go func() {
Expand All @@ -1431,7 +1431,8 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe)
blocked.Store(false)
if err != nil {
globalErr.Store(fmt.Errorf("failed to open pipe for reading: %w", err))
err = fmt.Errorf("failed to open pipe for reading: %w", err)
globalErr.Store(&err)
cancel()
return
}
Expand Down Expand Up @@ -1465,39 +1466,39 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo

switch format {
case tree.CopyFormatText:
flag := true
responsed := false
reader := bufio.NewReader(pipe)
for {
line, err := reader.ReadSlice('\n')
if err != nil {
if err == io.EOF {
break
}
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
if flag {
flag = false
if !responsed {
responsed = true
count := bytes.Count(line, []byte{'\t'})
err := sendCopyOutResponse(count + 1)
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
}
err = sendCopyData(line)
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
}
default:
err := sendCopyOutResponse(1)
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
Expand All @@ -1509,14 +1510,14 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
if err == io.EOF {
break
}
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
if n > 0 {
err := sendCopyData(buf[:n])
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
Expand All @@ -1528,30 +1529,29 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
select {
case <-ctx.Done(): // Context is canceled
<-done
err, _ := globalErr.Load().(error)
return errors.Join(ctx.Err(), err)
if errPtr := globalErr.Load(); errPtr != nil {
return errors.Join(ctx.Err(), err)
}
return ctx.Err()
case result := <-ch:
if blocked.Load() {
// If the pipe is still opened for reading but the writer has exited,
// then we need to open the pipe for writing again to unblock the reader.
globalErr.Store(errors.Join(
fmt.Errorf("pipe is opened for reading but the writer has exited"),
result.Err,
))
pipe, _ := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe)
<-done
if pipe != nil {
pipe.Close()
}
} else {
<-done
}

<-done

if result.Err != nil {
return fmt.Errorf("failed to copy data: %w", result.Err)
}

if err, ok := globalErr.Load().(error); ok {
return err
if errPtr := globalErr.Load(); errPtr != nil {
return *errPtr
}

// After data is sent and the producer side is finished without errors, send CopyDone
Expand Down
7 changes: 4 additions & 3 deletions pgserver/datawriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"strings"
"sync/atomic"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/backend"
Expand All @@ -13,7 +14,7 @@ import (
)

type DataWriter interface {
Start() (string, chan CopyToResult, error)
Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error)
Close()
}

Expand Down Expand Up @@ -145,18 +146,18 @@ func NewDuckDataWriter(
}, nil
}

func (dw *DuckDataWriter) Start() (string, chan CopyToResult, error) {
func (dw *DuckDataWriter) Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error) {
// Execute the COPY TO statement in a separate goroutine.
ch := make(chan CopyToResult, 1)
go func() {
defer os.Remove(dw.pipePath)
defer close(ch)

dw.ctx.GetLogger().Tracef("Executing COPY TO statement: %s", dw.duckSQL)

// This operation will block until the reader opens the pipe for reading.
result, err := adapter.ExecCatalog(dw.ctx, dw.duckSQL)
if err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}
Expand Down
1 change: 0 additions & 1 deletion test/bats/postgres/copy_tests.bats
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ EOF
}

@test "copy error handling" {
skip
# Test copying from non-existent schema
run psql_exec "\copy nonexistent_schema.t TO STDOUT;"
[ "$status" -ne 0 ]
Expand Down
Loading