From 08ef523f280c86efa024e6feee9c8e9ec268b75d Mon Sep 17 00:00:00 2001 From: Adam Eijdenberg Date: Mon, 5 Feb 2024 10:37:01 +1100 Subject: [PATCH] Enhancement: can now listen on a unix socket --- changelog/unreleased/pull-272 | 10 ++++ cmd/rest-server/listener_unix.go | 20 +++++-- cmd/rest-server/listener_unix_test.go | 75 +++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 changelog/unreleased/pull-272 create mode 100644 cmd/rest-server/listener_unix_test.go diff --git a/changelog/unreleased/pull-272 b/changelog/unreleased/pull-272 new file mode 100644 index 0000000..caf405b --- /dev/null +++ b/changelog/unreleased/pull-272 @@ -0,0 +1,10 @@ +Enhancement: can now listen on a unix socket + +If `--listen unix:/tmp/foo` is passed, the server will listen on a unix socket. This is triggered by the prefix `unix:`. + +This is useful in combination with remote port portforwarding to enable remote server to backup locally, e.g. + +```bash +rest-server --listen unix:/tmp/foo & +ssh -R /tmp/foo:/tmp/foo user@host restic -r rest:http+unix:/tmp/foo:/repo backup +``` \ No newline at end of file diff --git a/cmd/rest-server/listener_unix.go b/cmd/rest-server/listener_unix.go index 9e55d29..8350d41 100644 --- a/cmd/rest-server/listener_unix.go +++ b/cmd/rest-server/listener_unix.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "strings" "github.com/coreos/go-systemd/v22/activation" ) @@ -23,12 +24,23 @@ func findListener(addr string) (listener net.Listener, err error) { switch len(listeners) { case 0: // no listeners found, listen manually - listener, err = net.Listen("tcp", addr) - if err != nil { - return nil, fmt.Errorf("listen on %v failed: %w", addr, err) + if strings.HasPrefix(addr, "unix:") { // if we want to listen on a unix socket + unixAddr, err := net.ResolveUnixAddr("unix", strings.TrimPrefix(addr, "unix:")) + if err != nil { + return nil, fmt.Errorf("unable to understand unix address %s: %w", addr, err) + } + listener, err = net.ListenUnix("unix", unixAddr) + if err != nil { + return nil, fmt.Errorf("listen on %v failed: %w", addr, err) + } + } else { // assume tcp + listener, err = net.Listen("tcp", addr) + if err != nil { + return nil, fmt.Errorf("listen on %v failed: %w", addr, err) + } } - log.Printf("start server on %v", addr) + log.Printf("start server on %v", listener.Addr()) return listener, nil case 1: diff --git a/cmd/rest-server/listener_unix_test.go b/cmd/rest-server/listener_unix_test.go new file mode 100644 index 0000000..a4f32f4 --- /dev/null +++ b/cmd/rest-server/listener_unix_test.go @@ -0,0 +1,75 @@ +//go:build !windows +// +build !windows + +package main + +import ( + "context" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +func TestUnixSocket(t *testing.T) { + td := t.TempDir() + + // this is the socket we'll listen on and connect to + tempSocket := filepath.Join(td, "sock") + + // create some content and parent dirs + if err := os.MkdirAll(filepath.Join(td, "data", "repo1"), 0700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(td, "data", "repo1", "config"), []byte("foo"), 0700); err != nil { + t.Fatal(err) + } + + // run the following twice, to test that the server will + // cleanup its socket file when quitting, which won't happen + // if it doesn't exit gracefully + for i := 0; i < 2; i++ { + err := testServerWithArgs([]string{ + "--no-auth", + "--path", filepath.Join(td, "data"), + "--listen", fmt.Sprintf("unix:%s", tempSocket), + }, time.Second, func(ctx context.Context, _ *restServerApp) error { + // custom client that will talk HTTP to unix socket + client := http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", tempSocket) + }, + }, + } + for _, test := range []struct { + Path string + StatusCode int + }{ + {"/repo1/", http.StatusMethodNotAllowed}, + {"/repo1/config", http.StatusOK}, + {"/repo2/config", http.StatusNotFound}, + } { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://ignored"+test.Path, nil) + if err != nil { + return err + } + resp, err := client.Do(req) + if err != nil { + return err + } + resp.Body.Close() + if resp.StatusCode != test.StatusCode { + return fmt.Errorf("expected %d from server, instead got %d (path %s)", test.StatusCode, resp.StatusCode, test.Path) + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + } +}