diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 0000000..24660a8 --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,13 @@ +version = 1 + +[[analyzers]] +name = "test-coverage" + +[[analyzers]] +name = "go" + + [analyzers.meta] + import_root = "github.com/gaukas/water" + +[[transformers]] +name = "gofumpt" \ No newline at end of file diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2e320cb..02941fd 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,16 +14,17 @@ jobs: strategy: fail-fast: false matrix: - os: [ "ubuntu-latest", "windows-latest", "macos-latest" ] - # go: [ "1.20.x", "1.21.x" ] + # os: [ "ubuntu-latest", "windows-latest", "macos-latest" ] # Windows is not supported until net library implements Fd() for Windows + os: [ "ubuntu-latest", "macos-latest" ] + go: [ "1.20.x", "1.21.x" ] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v3 - uses: actions/setup-go@v4 with: - go-version: "1.21.x" + go-version: ${{ matrix.go }} - run: go version - name: Build run: go build -v ./... - name: Test - run: go test ./... + run: go test -failfast ./... diff --git a/README.md b/README.md index 801377b..6f31c75 100644 --- a/README.md +++ b/README.md @@ -5,39 +5,21 @@ W.A.T.E.R. provides a runtime environment for WebAssembly modules to run in and ## API -Currently, W.A.T.E.R. provides a set of APIs relying on **WASI Preview 1 (wasip1)** snapshot. +Currently, W.A.T.E.R. provides a set of APIs based on **WASI Preview 1 (wasip1)** snapshot. ### Config A `Config` is a struct that contains the configuration for a WASI instance. It is used to configure the WASI reactor before starting it. -### RuntimeConn -A `RuntimeConn` is a `Conn` that represents a connection from the local user to a remote peer. Each living `RuntimeConn` encapsulates a running WASI instance. -It process the data sent from the local user and send it to the remote peer, and vice versa. +### Dialer -A `RuntimeConn` interfaces `io.ReadWriteCloser` and is always and only spawned by a `RuntimeConnDialer`. +A `Dialer` could be used to dial a remote address upon `Dial()` and return a `net.Conn` back to the caller once the connection is established. Caller could use the `net.Conn` to read and write data to the remote address and the data will be processed by a WebAssembly instance. -#### RuntimeConnDialer -A `RuntimeConnDialer` is a `Dialer` loaded with a `Config` that can dial for `RuntimeConn` as abstracted connections. Currently, it is just a wrapper around a `Config`. **It does not contain any running WASI instance.** +### Listener -### RuntimeDialer _(TODO)_ -A `RuntimeDialer` is a `Dialer` that dials for `RuntimeDialerConn`. Each living `RuntimeDialer` encapsulates a running WASI instance. It manages multiple `RuntimeDialerConn` instances created upon caller's request. +A `Listener` could be used to listen on a local address. Upon `Accept()`, it returns a `net.Conn` back once an incoming connection is accepted from the wrapped listener. Caller could use the `net.Conn` to read and write data to the remote address and the data will be processed by a WebAssembly instance. -\* Not to be confused with [`RuntimeConnDialer`](#runtimeconndialer), a static dialer which creates `RuntimeConn` instances from `Config`. +### Server -#### RuntimeDialerConn -A `RuntimeDialerConn` is a sub-`Conn` spawned by a `RuntimeDialer` upon caller's request. It is a `Conn` that is dialed by a `RuntimeDialer` and is used to communicate with a remote peer. Multiple `RuntimeDialerConn` instances can be created from a single `RuntimeDialer`, which means they could be related to one single WASI instance. +A `Server` somewhat combines the role of `Dialer` and `Listener`. It could be used to listen on a local address and dial a remote address and automatically `Accept()` the incoming connections, feed them into the WebAssembly instance and `Dial()` the pre-defined remote address. Without any caller interaction, the `Server` will automatically* handle the data transmission between the two ends. -\* Not to be confused with [`RuntimeConn`](#runtimeconn), an `io.ReadWriteCloser` that encapsulates a running WASI instance each. - -## TODOs - -- W.A.T.E.R. API - - [x] `Config` - - [x] `RuntimeConn` - - [x] `RuntimeConnDialer` - - [ ] `RuntimeDialer` - - [ ] `RuntimeDialerConn` -- [x] Minimal W.A.T.E.R. WASI example - - No background worker threads -- [ ] Multi-threaded W.A.T.E.R. WASI example - - [ ] Background worker threads working +***TODO: Server could not be realistic until WASI multi-threading or blocking mainloop is supported** \ No newline at end of file diff --git a/config.go b/config.go index 31355fb..94be183 100644 --- a/config.go +++ b/config.go @@ -1,21 +1,124 @@ package water +import ( + "net" + "os" + + "github.com/gaukas/water/internal/log" + "github.com/gaukas/water/internal/wasm" +) + type Config struct { - // WASI contains the compiled WASI binary in bytes. - WASI []byte + // WATMBin contains the binary format of the WebAssembly Transport Module. + // In a typical use case, this mandatory field is populated by loading + // from a .wasm file, downloaded from a remote target, or generated from + // a .wat (WebAssembly Text Format) file. + WATMBin []byte + + // DialerFunc specifies a func that dials the specified address on the + // named network. This optional field can be set to override the Go + // default dialer func: + // net.Dial(network, address) + DialerFunc func(network, address string) (net.Conn, error) + + // NetworkListener specifies a net.listener implementation that listens + // on the specified address on the named network. This optional field + // will be used to provide (incoming) network connections from a + // presumably remote source to the WASM instance. Required by + // ListenConfig(). + NetworkListener net.Listener + + // Feature specifies a series of experimental features for the WASM + // runtime. + // + // Each feature flag is bit-masked and version-dependent, and flags + // are independent of each other. This means that a particular + // feature flag may be supported in one version of the runtime but + // not in another. If a feature flag is not supported or not recognized + // by the runtime, it will be silently ignored. + Feature Feature + + // WATMConfig optionally provides a configuration file to be pushed into + // the WASM Transport Module. + WATMConfig WATMConfig + + // wasiConfigFactory is used to replicate the WASI config for each WASM + // instance created. This field is for advanced use cases and/or debugging + // purposes only. + // + // Caller is supposed to call c.WASIConfig() to get the pointer to the + // WASIConfigFactory. If the pointer is nil, a new WASIConfigFactory will + // be created and returned. + wasiConfigFactory *wasm.WASIConfigFactory +} - // Dialer is used to dial a network connection. - Dialer Dialer +func (c *Config) Clone() *Config { + if c == nil { + return nil + } + + wasmClone := make([]byte, len(c.WATMBin)) + copy(wasmClone, c.WATMBin) + + return &Config{ + WATMBin: c.WATMBin, + DialerFunc: c.DialerFunc, + NetworkListener: c.NetworkListener, + Feature: c.Feature, + WATMConfig: c.WATMConfig, + wasiConfigFactory: c.wasiConfigFactory.Clone(), + } } -// init() checks if the Config is valid and initializes -// the Config with default values if optional fields are not provided. -func (c *Config) init() { - if len(c.WASI) == 0 { - panic("water: WASI binary is not provided") +func (c *Config) DialerFuncOrDefault() func(network, address string) (net.Conn, error) { + if c.DialerFunc == nil { + return net.Dial } - if c.Dialer == nil { - c.Dialer = DefaultDialer() + return c.DialerFunc +} + +func (c *Config) NetworkListenerOrPanic() net.Listener { + if c.NetworkListener == nil { + panic("water: network listener is not provided in config") } + + return c.NetworkListener +} + +func (c *Config) WATMBinOrPanic() []byte { + if len(c.WATMBin) == 0 { + panic("water: WebAssembly Transport Module binary is not provided in config") + } + + return c.WATMBin +} + +func (c *Config) WASIConfig() *wasm.WASIConfigFactory { + if c.wasiConfigFactory == nil { + c.wasiConfigFactory = wasm.NewWasiConfigFactory() + } + + return c.wasiConfigFactory +} + +// WATMConfig defines the configuration file used by the WebAssembly Transport Module. +type WATMConfig struct { + FilePath string // Path to the config file. +} + +// File opens the config file and returns the file descriptor. +func (c *WATMConfig) File() *os.File { + if c.FilePath == "" { + log.Errorf("water: WASM config file path is not provided in config") + return nil + } + + f, err := os.Open(c.FilePath) + if err != nil { + log.Errorf("water: failed to open WATM config file: %v", err) + return nil + } + + return f } diff --git a/conn_generic.go b/conn_generic.go new file mode 100644 index 0000000..51a6485 --- /dev/null +++ b/conn_generic.go @@ -0,0 +1,80 @@ +package water + +import ( + "fmt" + "net" + "time" +) + +var mapCoreDialContext = make(map[string]func(core *core, network, address string) (Conn, error)) +var mapCoreAccept = make(map[string]func(*core) (Conn, error)) + +// Conn is an abstracted connection interface which encapsulates +// a WASM runtime core. +type Conn interface { + net.Conn + + // For forward compatibility with any new methods added to the + // interface, all Conn implementations MUST embed the + // UnimplementedConn in order to make sure they could be used + // in the future without any code change. + mustEmbedUnimplementedConn() +} + +func RegisterDial(version string, dialContext func(core *core, network, address string) (Conn, error)) error { + if _, ok := mapCoreDialContext[version]; ok { + return fmt.Errorf("water: core dial context already registered for version %s", version) + } + mapCoreDialContext[version] = dialContext + return nil +} + +func RegisterAccept(version string, accept func(*core) (Conn, error)) error { + if _, ok := mapCoreAccept[version]; ok { + return fmt.Errorf("water: core accept already registered for version %s", version) + } + mapCoreAccept[version] = accept + return nil +} + +// UnimplementedConn is used to provide forward compatibility for +// implementations of Conn, such that if new methods are added +// to the interface, old implementations will not be required to implement +// each of them. +type UnimplementedConn struct{} + +func (*UnimplementedConn) Read([]byte) (int, error) { + return 0, fmt.Errorf("water: Read() is not implemented") +} + +func (*UnimplementedConn) Write([]byte) (int, error) { + return 0, fmt.Errorf("water: Write() is not implemented") +} + +func (*UnimplementedConn) Close() error { + return fmt.Errorf("water: Close() is not implemented") +} + +func (*UnimplementedConn) LocalAddr() net.Addr { + return nil +} + +func (*UnimplementedConn) RemoteAddr() net.Addr { + return nil +} + +func (*UnimplementedConn) SetDeadline(_ time.Time) error { + return fmt.Errorf("water: SetDeadline() is not implemented") +} + +func (*UnimplementedConn) SetReadDeadline(_ time.Time) error { + return fmt.Errorf("water: SetReadDeadline() is not implemented") +} + +func (*UnimplementedConn) SetWriteDeadline(_ time.Time) error { + return fmt.Errorf("water: SetWriteDeadline() is not implemented") +} + +func (*UnimplementedConn) mustEmbedUnimplementedConn() {} + +var _ Conn = (*UnimplementedConn)(nil) diff --git a/conn_v0.go b/conn_v0.go new file mode 100644 index 0000000..c9af7bd --- /dev/null +++ b/conn_v0.go @@ -0,0 +1,239 @@ +//go:build !nov0 + +package water + +import ( + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/gaukas/water/internal/socket" + v0 "github.com/gaukas/water/internal/v0" + "github.com/gaukas/water/internal/wasm" +) + +func init() { + RegisterDial("_v0", DialV0) + RegisterAccept("_v0", AcceptV0) +} + +// ConnV0 is the first version of RuntimeConn. +type ConnV0 struct { + networkConn net.Conn // network-facing net.Conn, data written to this connection will be sent on the wire + uoConn net.Conn // user-oriented net.Conn, user Read()/Write() to this connection + + wasm *WASMv0 + + UnimplementedConn // embedded to ensure forward compatibility +} + +// DialV0 dials the network address using through the WASM module +// while using the dialerFunc specified in core.config. +func DialV0(core *core, network, address string) (c Conn, err error) { + wasm := NewWASMv0(core) + conn := &ConnV0{ + wasm: wasm, + } + + dialer := v0.MakeWASIDialer(network, address, core.Config().DialerFuncOrDefault()) + + if err = conn.wasm.LinkNetworkInterface(dialer, nil); err != nil { + return nil, err + } + + if err = conn.wasm.Initialize(); err != nil { + return nil, err + } + + if conn.wasm._dial == nil { + return nil, fmt.Errorf("water: WASM module does not export _dial") + } + + // Initialize WASM module as ReadWriter + if err = conn.wasm.InitializeReadWriter(); err != nil { + return nil, err + } + + var wasmCallerConn net.Conn + wasmCallerConn, conn.uoConn, err = socket.UnixConnPair("") + if err != nil { + return nil, fmt.Errorf("water: socket.UnixConnPair returned error: %w", err) + } + + wasmNetworkConn, err := conn.wasm.DialFrom(wasmCallerConn) + if err != nil { + return nil, err + } + + conn.networkConn = wasmNetworkConn + + return conn, nil +} + +// AcceptV0 accepts the network connection using through the WASM module +// while using the net.Listener specified in core.config. +func AcceptV0(core *core) (c Conn, err error) { + wasm := NewWASMv0(core) + conn := &ConnV0{ + wasm: wasm, + } + + listener := v0.MakeWASIListener(core.Config().NetworkListenerOrPanic()) + + if err = conn.wasm.LinkNetworkInterface(nil, listener); err != nil { + return nil, err + } + + if err = conn.wasm.Initialize(); err != nil { + return nil, err + } + + if conn.wasm._accept == nil { + return nil, fmt.Errorf("water: WASM module does not export _accept") + } + + // Initialize WASM module as ReadWriter + if err = conn.wasm.InitializeReadWriter(); err != nil { + return nil, err + } + + var wasmCallerConn net.Conn + wasmCallerConn, conn.uoConn, err = socket.UnixConnPair("") + if err != nil { + return nil, fmt.Errorf("water: socket.UnixConnPair returned error: %w", err) + } + + wasmNetworkConn, err := conn.wasm.AcceptFor(wasmCallerConn) + if err != nil { + return nil, err + } + + conn.networkConn = wasmNetworkConn + + return conn, nil +} + +// Read implements the net.Conn interface. +// +// It calls to the underlying user-oriented net.Conn's Read() method. +func (c *ConnV0) Read(b []byte) (n int, err error) { + if c.uoConn == nil { + return 0, errors.New("water: cannot read, (*RuntimeConnV0).uoConn is nil") + } + + // call _read + ret, err := c.wasm._read.Call(c.wasm.Store()) + if err != nil { + return 0, fmt.Errorf("water: (*wasmtime.Func).Call returned error: %w", err) + } + + if ret32, ok := ret.(int32); !ok { + return 0, fmt.Errorf("water: (*wasmtime.Func).Call returned non-int32 value") + } else { + if ret32 != 0 { + return 0, wasm.WASMErr(ret32) + } + } + + return c.uoConn.Read(b) +} + +// Write implements the net.Conn interface. +// +// It calls to the underlying user-oriented net.Conn's Write() method. +func (c *ConnV0) Write(b []byte) (n int, err error) { + if c.uoConn == nil { + return 0, errors.New("water: cannot write, (*RuntimeConnV0).uoConn is nil") + } + + n, err = c.uoConn.Write(b) + if err != nil { + return n, fmt.Errorf("uoConn.Write: %w", err) + } + + if n < len(b) { + return n, io.ErrShortWrite + } + + if n > len(b) { + return n, errors.New("invalid write result") // io.errInvalidWrite + } + + // call _write to notify WASM + ret, err := c.wasm._write.Call(c.wasm.Store()) + if err != nil { + return 0, fmt.Errorf("water: (*wasmtime.Func).Call returned error: %w", err) + } + + if ret32, ok := ret.(int32); !ok { + return 0, fmt.Errorf("water: (*wasmtime.Func).Call returned non-int32 value") + } else { + return n, wasm.WASMErr(ret32) + } +} + +// Close implements the net.Conn interface. +// +// It will close both the network connection AND the WASM module, then +// the user-facing net.Conn will be closed. +func (c *ConnV0) Close() error { + err := c.networkConn.Close() + if err != nil { + return fmt.Errorf("water: (*RuntimeConnV0).netConn.Close returned error: %w", err) + } + + _, err = c.wasm._close.Call(c.wasm.Store()) + if err != nil { + return fmt.Errorf("water: (*RuntimeConnV0)._close.Call returned error: %w", err) + } + + c.wasm.DeferAll() + c.wasm.Cleanup() + + if c.uoConn != nil { + c.uoConn.Close() + } + + return nil +} + +// LocalAddr implements the net.Conn interface. +// +// It calls to the underlying network connection's LocalAddr() method. +func (c *ConnV0) LocalAddr() net.Addr { + return c.networkConn.LocalAddr() +} + +// RemoteAddr implements the net.Conn interface. +// +// It calls to the underlying network connection's RemoteAddr() method. +func (c *ConnV0) RemoteAddr() net.Addr { + return c.networkConn.RemoteAddr() +} + +// SetDeadline implements the net.Conn interface. +// +// It calls to the underlying user-oriented connection's SetDeadline() method. +func (c *ConnV0) SetDeadline(t time.Time) error { + return c.uoConn.SetDeadline(t) +} + +// SetReadDeadline implements the net.Conn interface. +// +// It calls to the underlying user-oriented connection's SetReadDeadline() method. +// +// Note: in practice this method should actively be used by the caller. Otherwise +// it is possible for a silently failed network connection to cause the WASM module +// to hang forever on Read(). +func (c *ConnV0) SetReadDeadline(t time.Time) error { + return c.uoConn.SetReadDeadline(t) +} + +// SetWriteDeadline implements the net.Conn interface. +// +// It calls to the underlying user-oriented connection's SetWriteDeadline() method. +func (c *ConnV0) SetWriteDeadline(t time.Time) error { + return c.uoConn.SetWriteDeadline(t) +} diff --git a/conn_v0_test.go b/conn_v0_test.go new file mode 100644 index 0000000..00111b4 --- /dev/null +++ b/conn_v0_test.go @@ -0,0 +1,469 @@ +//go:build unix && !windows && !nov0 + +package water_test + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "net" + "os" + "runtime" + "sync" + "testing" + "time" + + "github.com/gaukas/water" +) + +var hexencoder_v0 []byte +var plain_v0 []byte + +func TestConnV0(t *testing.T) { + // read file into hexencoder_v0 + var err error + hexencoder_v0, err = os.ReadFile("./testdata/hexencoder_v0.wasm") + if err != nil { + t.Fatal(err) + } + t.Run("DialerV0", testDialerV0) + t.Run("ListenerV0", testListenerV0) +} + +func testDialerV0(t *testing.T) { + // t.Parallel() + + // create random TCP listener listening on localhost + tcpLis, err := net.ListenTCP("tcp", nil) + if err != nil { + t.Fatal(err) + } + defer tcpLis.Close() + + // goroutine to accept incoming connections + var lisConn net.Conn + var goroutineErr error + var wg *sync.WaitGroup = new(sync.WaitGroup) + wg.Add(1) + go func() { + defer wg.Done() + lisConn, goroutineErr = tcpLis.Accept() + }() + + // Dial + dialer := &water.Dialer{ + Config: &water.Config{ + WATMBin: hexencoder_v0, + WATMConfig: water.WATMConfig{ + FilePath: "./testdata/hexencoder_v0.dialer.json", + }, + }, + } + dialer.Config.WASIConfig().InheritStdout() + + rConn, err := dialer.Dial("tcp", tcpLis.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer rConn.Close() + + // wait for listener to accept connection + wg.Wait() + if goroutineErr != nil { + t.Fatal(goroutineErr) + } + + runtime.GC() + time.Sleep(10 * time.Millisecond) + + if err = testUppercaseHexencoderConn(rConn, lisConn, []byte("hello"), []byte("world")); err != nil { + t.Fatal(err) + } + + runtime.GC() + time.Sleep(10 * time.Millisecond) + + if err = testUppercaseHexencoderConn(rConn, lisConn, []byte("i'm dialer"), []byte("hello dialer")); err != nil { + t.Fatal(err) + } + + runtime.GC() + time.Sleep(10 * time.Millisecond) + + if err = testUppercaseHexencoderConn(rConn, lisConn, []byte("who are you?"), []byte("I'm listener")); err != nil { + t.Fatal(err) + } +} + +func testListenerV0(t *testing.T) { + // t.Parallel() + + // prepare for listener + config := &water.Config{ + WATMBin: hexencoder_v0, + WATMConfig: water.WATMConfig{ + FilePath: "./testdata/hexencoder_v0.listener.json", + }, + // WASIConfigFactory: wasm.NewWasiConfigFactory(), + } + config.WASIConfig().InheritStdout() + + lis, err := config.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + // goroutine to dial listener + var dialConn net.Conn + var goroutineErr error + var wg *sync.WaitGroup = new(sync.WaitGroup) + wg.Add(1) + go func() { + defer wg.Done() + dialConn, goroutineErr = net.Dial("tcp", lis.Addr().String()) + }() + + // Accept + rConn, err := lis.Accept() + if err != nil { + t.Fatal(err) + } + defer rConn.Close() + + // wait for dialer to dial + wg.Wait() + if goroutineErr != nil { + t.Fatal(goroutineErr) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + if err = testLowercaseHexencoderConn(rConn, dialConn, []byte("hello"), []byte("world")); err != nil { + t.Error(err) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + if err = testLowercaseHexencoderConn(rConn, dialConn, []byte("i'm listener"), []byte("hello listener")); err != nil { + t.Error(err) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + if err = testLowercaseHexencoderConn(rConn, dialConn, []byte("who are you?"), []byte("I'm dialer")); err != nil { + t.Error(err) + } +} + +func testUppercaseHexencoderConn(encoderConn, plainConn net.Conn, dMsg, lMsg []byte) error { + // dConn -> lConn + _, err := encoderConn.Write(dMsg) + if err != nil { + return err + } + + // receive data + buf := make([]byte, 1024) + n, err := plainConn.Read(buf) + if err != nil { + return err + } + + // decode hex + var decoded []byte = make([]byte, 1024) + n, err = hex.Decode(decoded, buf[:n]) + if err != nil { + return err + } + + // compare received bytes with expected bytes + if string(decoded[:n]) != string(dMsg) { + return fmt.Errorf("expected: %s, got: %s", dMsg, decoded[:n]) + } + + // encode hex + var encoded []byte = make([]byte, 1024) + n = hex.Encode(encoded, lMsg) + + // lConn -> dConn + _, err = plainConn.Write(encoded[:n]) + if err != nil { + return err + } + + // receive data + n, err = encoderConn.Read(buf) + if err != nil { + return err + } + + // compare received bytes with expected bytes + var upperLMsg []byte = make([]byte, len(lMsg)) + for i, b := range lMsg { + if b >= 'a' && b <= 'z' { // to uppercase + upperLMsg[i] = b - 32 + } else { + upperLMsg[i] = b + } + } + + if string(buf[:n]) != string(upperLMsg) { + return fmt.Errorf("expected: %s, got: %s", upperLMsg, decoded[:n]) + } + + return nil +} + +func testLowercaseHexencoderConn(encoderConn, plainConn net.Conn, dMsg, lMsg []byte) error { + // dConn -> lConn + _, err := encoderConn.Write(dMsg) + if err != nil { + return err + } + + // receive data + buf := make([]byte, 1024) + n, err := plainConn.Read(buf) + if err != nil { + return err + } + + // decode hex + var decoded []byte = make([]byte, 1024) + n, err = hex.Decode(decoded, buf[:n]) + if err != nil { + return err + } + + // compare received bytes with expected bytes + if string(decoded[:n]) != string(dMsg) { + return fmt.Errorf("expected: %s, got: %s", dMsg, decoded[:n]) + } + + // encode hex + var encoded []byte = make([]byte, 1024) + n = hex.Encode(encoded, lMsg) + + // lConn -> dConn + _, err = plainConn.Write(encoded[:n]) + if err != nil { + return err + } + + // receive data + n, err = encoderConn.Read(buf) + if err != nil { + return err + } + + // compare received bytes with expected bytes + var upperLMsg []byte = make([]byte, len(lMsg)) + for i, b := range lMsg { + if b >= 'A' && b <= 'Z' { // to lowercase + upperLMsg[i] = b + 32 + } else { + upperLMsg[i] = b + } + } + + if string(buf[:n]) != string(upperLMsg) { + return fmt.Errorf("expected: %s, got: %s", upperLMsg, decoded[:n]) + } + + return nil +} + +func BenchmarkConnV0(b *testing.B) { + // read file into plain_v0 + var err error + plain_v0, err = os.ReadFile("./testdata/plain_v0.wasm") + if err != nil { + b.Fatal(err) + } + b.Run("PlainV0-Dialer", benchmarkPlainV0Dialer) + b.Run("PlainV0-Listener", benchmarkPlainV0Listener) + b.Run("RefTCP", benchmarkReferenceTCP) +} + +func benchmarkPlainV0Dialer(b *testing.B) { + // create random TCP listener listening on localhost + tcpLis, err := net.ListenTCP("tcp", nil) + if err != nil { + b.Fatal(err) + } + defer tcpLis.Close() + + // goroutine to accept incoming connections + var lisConn net.Conn + var goroutineErr error + var wg *sync.WaitGroup = new(sync.WaitGroup) + wg.Add(1) + go func() { + defer wg.Done() + lisConn, goroutineErr = tcpLis.Accept() + }() + + // Dial + dialer := &water.Dialer{ + Config: &water.Config{ + WATMBin: plain_v0, + }, + } + + rConn, err := dialer.Dial("tcp", tcpLis.Addr().String()) + if err != nil { + b.Fatal(err) + } + defer rConn.Close() + + // wait for listener to accept connection + wg.Wait() + if goroutineErr != nil { + b.Fatal(goroutineErr) + } + + var sendMsg []byte = make([]byte, 1024) + rand.Read(sendMsg) + + runtime.GC() + time.Sleep(10 * time.Millisecond) + + b.SetBytes(1024) + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, err = rConn.Write(sendMsg) + if err != nil { + b.Logf("Write error, cntr: %d, N: %d", i, b.N) + b.Fatal(err) + } + + buf := make([]byte, 1024+128) + _, err = lisConn.Read(buf) + if err != nil { + b.Logf("Read error, cntr: %d, N: %d", i, b.N) + b.Fatal(err) + } + } + b.StopTimer() + b.Logf("avg bandwidth: %f MB/s (N=%d)", float64(b.N*1024)/time.Since(start).Seconds()/1024/1024, b.N) +} + +func benchmarkPlainV0Listener(b *testing.B) { + // prepare for listener + config := &water.Config{ + WATMBin: plain_v0, + } + + lis, err := config.Listen("tcp", "localhost:0") + if err != nil { + b.Fatal(err) + } + + // goroutine to dial listener + var dialConn net.Conn + var goroutineErr error + var wg *sync.WaitGroup = new(sync.WaitGroup) + wg.Add(1) + go func() { + defer wg.Done() + dialConn, goroutineErr = net.Dial("tcp", lis.Addr().String()) + }() + + // Accept + rConn, err := lis.Accept() + if err != nil { + b.Fatal(err) + } + defer rConn.Close() + + // wait for dialer to dial + wg.Wait() + if goroutineErr != nil { + b.Fatal(goroutineErr) + } + + var sendMsg []byte = make([]byte, 512) + rand.Read(sendMsg) + + b.SetBytes(1024) // we will send 512-byte data and 128-byte will be transmitted on wire due to hex encoding + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, err = rConn.Write(sendMsg) + if err != nil { + b.Logf("Write error, cntr: %d, N: %d", i, b.N) + b.Fatal(err) + } + + // receive data + buf := make([]byte, 1024) + _, err = dialConn.Read(buf) + if err != nil { + b.Logf("Read error, cntr: %d, N: %d", i, b.N) + b.Fatal(err) + } + } + b.StopTimer() + b.Logf("avg bandwidth: %f MB/s (N=%d)", float64(b.N*1024)/time.Since(start).Seconds()/1024/1024, b.N) +} + +func benchmarkReferenceTCP(b *testing.B) { + // create random TCP listener listening on localhost + tcpLis, err := net.ListenTCP("tcp", nil) + if err != nil { + b.Fatal(err) + } + defer tcpLis.Close() + + // goroutine to accept incoming connections + var lisConn net.Conn + var goroutineErr error + var wg *sync.WaitGroup = new(sync.WaitGroup) + wg.Add(1) + go func() { + defer wg.Done() + lisConn, goroutineErr = tcpLis.Accept() + }() + + nConn, err := net.Dial("tcp", tcpLis.Addr().String()) + if err != nil { + b.Fatal(err) + } + defer nConn.Close() + + // wait for listener to accept connection + wg.Wait() + if goroutineErr != nil { + b.Fatal(goroutineErr) + } + + var sendMsg []byte = make([]byte, 1024) + rand.Read(sendMsg) + + b.SetBytes(1024) + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, err = nConn.Write(sendMsg) + if err != nil { + b.Logf("Write error, cntr: %d, N: %d", i, b.N) + b.Fatal(err) + } + + // receive data + buf := make([]byte, 1024) + _, err = lisConn.Read(buf) + if err != nil { + b.Logf("Read error, cntr: %d, N: %d", i, b.N) + b.Fatal(err) + } + + // time.Sleep(10 * time.Microsecond) + } + b.StopTimer() + b.Logf("avg bandwidth: %f MB/s (N=%d)", float64(b.N*1024)/time.Since(start).Seconds()/1024/1024, b.N) +} diff --git a/core.go b/core.go new file mode 100644 index 0000000..b1fd23e --- /dev/null +++ b/core.go @@ -0,0 +1,109 @@ +package water + +import ( + "fmt" + + "github.com/bytecodealliance/wasmtime-go/v13" +) + +// Core provides the WASM runtime base and is an internal struct +// that every RuntimeXxx implementation will embed. +// +// Core is not versioned and is not subject to breaking changes +// unless a severe bug needs to be fixed in a breaking way. +type core struct { + // config + config *Config + + // wasmtime + engine *wasmtime.Engine + module *wasmtime.Module + store *wasmtime.Store // avoid directly accessing store once the instance is created + linker *wasmtime.Linker + instance *wasmtime.Instance +} + +// Core creates a new Core, which is the base of all +// WASM runtime functionalities. +func Core(config *Config) (c *core, err error) { + c = &core{ + config: config, + } + + var wasiConfig *wasmtime.WasiConfig + wasiConfig, err = c.config.WASIConfig().GetConfig() + if err != nil { + err = fmt.Errorf("water: (*WasiConfigFactory).GetConfig returned error: %w", err) + return + } + + c.engine = wasmtime.NewEngine() + c.module, err = wasmtime.NewModule(c.engine, c.config.WATMBinOrPanic()) + if err != nil { + err = fmt.Errorf("water: wasmtime.NewModule returned error: %w", err) + return + } + c.store = wasmtime.NewStore(c.engine) + c.store.SetWasiConfig(wasiConfig) + c.linker = wasmtime.NewLinker(c.engine) + err = c.linker.DefineWasi() + if err != nil { + err = fmt.Errorf("water: (*wasmtime.Linker).DefineWasi returned error: %w", err) + return + } + + return +} + +func (c *core) DialVersion(network, address string) (Conn, error) { + for _, export := range c.module.Exports() { + if f, ok := mapCoreDialContext[export.Name()]; ok { + return f(c, network, address) + } + } + return nil, fmt.Errorf("water: core loaded a WASM module that does not implement any known version") +} + +func (c *core) AcceptVersion() (Conn, error) { + for _, export := range c.module.Exports() { + if f, ok := mapCoreAccept[export.Name()]; ok { + return f(c) + } + } + return nil, fmt.Errorf("water: core loaded a WASM module that does not implement any known version") +} + +// Config returns the Config used to create the Core. +func (c *core) Config() *Config { + return c.config +} + +func (c *core) Engine() *wasmtime.Engine { + return c.engine +} + +func (c *core) Instance() *wasmtime.Instance { + return c.instance +} + +func (c *core) Linker() *wasmtime.Linker { + return c.linker +} + +func (c *core) Module() *wasmtime.Module { + return c.module +} + +func (c *core) Store() *wasmtime.Store { + return c.store +} + +func (c *core) Instantiate() error { + instance, err := c.linker.Instantiate(c.store, c.module) + if err != nil { + return fmt.Errorf("water: (*wasmtime.Linker).Instantiate returned error: %w", err) + } + + c.instance = instance + return nil +} diff --git a/dialer.go b/dialer.go index 27236c3..f6264b0 100644 --- a/dialer.go +++ b/dialer.go @@ -1,76 +1,71 @@ package water import ( - "crypto/tls" + "context" "fmt" - "net" - "strings" ) -type Dialer interface { - // Dial connects to the address on the named network. - Dial(network, address string) (net.Conn, error) +// Dialer dials the given network address upon caller calling +// Dial() and returns a net.Conn which is connected to the +// WASM module. +// +// The structure of a Dialer is as follows: +// +// dial +----------------+ dial +// ----->| Decode |------> +// Caller | WASM Runtime | Remote +// <-----| Decode/Encode |<------ +// +----------------+ +// Dialer +type Dialer struct { + // Config is the configuration for the core. + Config *Config } -type dialer struct { - tlsDialer Dialer // for a tlsDialer, Dial() function should return a *tls.Conn or its equivalent. And Handshake() should be called before returning. -} - -func DefaultDialer() Dialer { - return &dialer{ - tlsDialer: TLSDialerWithConfig(&tls.Config{}), +func (c *Config) Dialer() *Dialer { + return &Dialer{ + Config: c.Clone(), } } -func DialerWithTLS(tlsDialer Dialer) Dialer { - return &dialer{ - tlsDialer: tlsDialer, - } +// Dialer dials the given network address using the specified dialer +// in the config. The returned RuntimeConn implements net.Conn and +// could be seen as the outbound connection with a wrapping transport +// protocol handled by the WASM module. +// +// Internally, DialContext() is called with a background context. +func (d *Dialer) Dial(network, address string) (Conn, error) { + return d.DialContext(context.Background(), network, address) } -func (d *dialer) Dial(network, address string) (net.Conn, error) { - switch network { - case "tls", "tls4", "tls6": - tlsNetwork := strings.ReplaceAll(network, "tls", "tcp") // tls4 -> tcp4, etc. - return d.tlsDialer.Dial(tlsNetwork, address) - default: - return net.Dial(network, address) +// DialContext dials the given network address using the specified dialer +// in the config. The returned RuntimeConn implements net.Conn and +// could be seen as the outbound connection with a wrapping transport +// protocol handled by the WASM module. +// +// If the context expires before the connection is complete, an error is +// returned. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (conn Conn, err error) { + if d.Config == nil { + return nil, fmt.Errorf("water: dialing with nil config is not allowed") } -} -type tlsDialer struct { - tlsConfig *tls.Config -} + ctxReady, dialReady := context.WithCancel(context.Background()) + go func() { + defer dialReady() + var core *core + core, err = Core(d.Config) + if err != nil { + return + } -func TLSDialerWithConfig(config *tls.Config) Dialer { - return &tlsDialer{config.Clone()} -} + conn, err = core.DialVersion(network, address) + }() -func (d *tlsDialer) Dial(network, address string) (net.Conn, error) { - d.tlsConfig.ServerName = strings.Split(address, ":")[0] // "example.com:443" -> "example.com" - tlsConn, err := tls.Dial(network, address, d.tlsConfig) - if err != nil { - return nil, fmt.Errorf("tls.Dial(): %w", err) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ctxReady.Done(): + return conn, err } - return tlsConn, nil -} - -type AddressedDialer interface { - Dial(network string) (net.Conn, error) -} - -type addressedDialer struct { - dialer Dialer - address string -} - -func SetDialerAddress(dialer Dialer, address string) AddressedDialer { - return &addressedDialer{ - dialer: dialer, - address: address, - } -} - -func (d *addressedDialer) Dial(network string) (net.Conn, error) { - return d.dialer.Dial(network, d.address) } diff --git a/errors.go b/errors.go deleted file mode 100644 index 9bd04c0..0000000 --- a/errors.go +++ /dev/null @@ -1,6 +0,0 @@ -package water - -// errno -var ( - ErrIO int = -1 -) diff --git a/feature.go b/feature.go new file mode 100644 index 0000000..5ad5030 --- /dev/null +++ b/feature.go @@ -0,0 +1,14 @@ +package water + +type Feature uint64 + +// Feature is a bit mask of experimental features of WATER. +// +// TODO: implement Feature. +const ( + FEATURE_DUMMY Feature = 1 << iota // a dummy feature that does nothing. + FEATURE_RESERVED // reserved for future use + // ... + FEATURE_CWAL Feature = 0xFFFFFFFFFFFFFFFF // CWAL = Can't Wait Any Longer + FEATURE_NONE Feature = 0 // NONE = No Experimental Features +) diff --git a/gcfix.go b/gcfix.go new file mode 100644 index 0000000..bdbe98a --- /dev/null +++ b/gcfix.go @@ -0,0 +1,12 @@ +//go:build !nogcfix + +package water + +// GCFIX is a workaround to prevent Go GC from incorrectly garbage +// collecting the cloned `*os.File` pushed to WASM with `PushFile()`. +// +// BUG: There is an undocumented GC issue in Go 1.20 and Go 1.21. +// The first `*os.File` pushed to WASM with `PushFile()` from wasmtime +// will be incorrectly garbage collected by Go GC even if it is still +// accessible from Go. +const GCFIX bool = true diff --git a/go.mod b/go.mod index ff67e6a..3408f57 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,10 @@ module github.com/gaukas/water -go 1.21 +go 1.20 -require github.com/bytecodealliance/wasmtime-go/v11 v11.0.0 +replace github.com/bytecodealliance/wasmtime-go/v13 v13.0.0 => github.com/refraction-networking/wasmtime-go/v13 v13.0.0 + +require ( + github.com/bytecodealliance/wasmtime-go/v13 v13.0.0 + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 +) diff --git a/go.sum b/go.sum index 060bd7c..160c503 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,8 @@ -github.com/bytecodealliance/wasmtime-go/v11 v11.0.0 h1:SwLgbjbFpQ1tf5vIbWexaZABezBSL8WmzP+foLyi0lE= -github.com/bytecodealliance/wasmtime-go/v11 v11.0.0/go.mod h1:9btfEuCkOP7EDR9a7LqDXrfQ7dtWeGlDHt3buV5UyjY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/refraction-networking/wasmtime-go/v13 v13.0.0 h1:5Asz7xwxaRW59P9hTwxKjn5gKjf7BziCX0+Y9CIZJPs= +github.com/refraction-networking/wasmtime-go/v13 v13.0.0/go.mod h1:KmsZLdjjzNH/E5wbfoRehqP70tHzKlfNOi730VCAR4E= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/filesocket/bundle.go b/internal/filesocket/bundle.go deleted file mode 100644 index 7c95660..0000000 --- a/internal/filesocket/bundle.go +++ /dev/null @@ -1,75 +0,0 @@ -package filesocket - -import ( - "io" - "net" - "os" -) - -// Bundle is a combination of a FileSocket and a (net).Conn. -// -// Anything received from the net.Conn will be written into the FileSocket, -// and anything received from the FileSocket will be written into the net.Conn. -type Bundle interface { - Start() // start handling data transfer between the net.Conn and the FileSocket - RxFile() *os.File // the file where received data from net.Conn is written to - TxFile() *os.File // the file where data should be written-to to be sent via net.Conn - net.Conn // ONLY for LocalAddr(), RemoteAddr(), SetDeadline(), SetReadDeadline(), SetWriteDeadline() - OnClose(func()) // optional callback to be called when the Close() method is called -} - -// bundle implements Bundle -type bundle struct { - net.Conn - fs FileSocket - onClose func() -} - -func BundleFileSocket(conn net.Conn, fs FileSocket) Bundle { - return &bundle{ - Conn: conn, - fs: fs, - } -} - -// BundleFiles creates a FileSocket from the given files, writing -// received data from net.Conn to the rxFile and send data from the txFile -// to the net.Conn. -func BundleFiles(conn net.Conn, rxFile, txFile *os.File) Bundle { - return &bundle{ - Conn: conn, - fs: NewFileSocket(txFile, rxFile), - } -} - -func (b *bundle) Start() { - go func() { - io.Copy(b.fs, b.Conn) - b.fs.Close() - }() - go func() { - io.Copy(b.Conn, b.fs) - b.Conn.Close() - b.fs.Close() // TODO: is this necessary? now added just to be safe - }() -} - -func (b *bundle) RxFile() *os.File { - return b.fs.(*fileSocket).wrFile -} - -func (b *bundle) TxFile() *os.File { - return b.fs.(*fileSocket).rdFile -} - -func (b *bundle) Close() error { - if b.onClose != nil { - defer b.onClose() - } - b.fs.Close() - return b.Conn.Close() -} - -func (b *bundle) OnClose(f func()) { - b.onClose = f -} diff --git a/internal/filesocket/filesocket.go b/internal/filesocket/filesocket.go deleted file mode 100644 index e2f3c69..0000000 --- a/internal/filesocket/filesocket.go +++ /dev/null @@ -1,92 +0,0 @@ -package filesocket - -import ( - "errors" - "io" - "os" - "sync/atomic" -) - -type FileSocket interface { - io.ReadWriteCloser - - RdFile() *os.File // returns the file where Read() wi;; read from - WrFile() *os.File // returns the file where Write() will write to -} - -type fileSocket struct { - rdFile *os.File // Read() reads from this file - wrFile *os.File // Write() writes to this file, ReadFrom() reads data from a reader into this file - - closed *atomic.Bool -} - -func NewFileSocket(rdFile, wrFile *os.File) FileSocket { - return &fileSocket{rdFile, wrFile, &atomic.Bool{}} -} - -func (fs *fileSocket) Read(p []byte) (n int, err error) { - if fs.closed.Load() { - return 0, os.ErrClosed - } - if fs.rdFile == nil { - return 0, os.ErrInvalid - } - - // if everything new in file has been read, it will return EOF. - // In which case, we want to block on retrying until there is new data. - for n == 0 { - n, err = fs.rdFile.Read(p) - if err != nil && !errors.Is(err, io.EOF) { - if errors.Is(err, os.ErrClosed) && fs.closed.Load() { - err = io.EOF // when closed by caller, return EOF instead - } - return 0, err - } - } - - return n, nil -} - -func (fs *fileSocket) ReadFrom(r io.Reader) (n int64, err error) { - if fs.closed.Load() { - return 0, os.ErrClosed - } - if fs.rdFile == nil { - return 0, os.ErrInvalid - } - return fs.wrFile.ReadFrom(r) // ReadFrom() could have platform-specific benefits -} - -func (fs *fileSocket) Write(p []byte) (n int, err error) { - if fs.closed.Load() { - return 0, os.ErrClosed - } - if fs.wrFile == nil { - return 0, os.ErrInvalid - } - return fs.wrFile.Write(p) -} - -func (fs *fileSocket) Close() error { - if !fs.closed.CompareAndSwap(false, true) { - return os.ErrClosed - } - if fs.rdFile != nil { - fs.rdFile.Close() - os.Remove(fs.rdFile.Name()) - } - if fs.wrFile != nil { - fs.wrFile.Close() - os.Remove(fs.wrFile.Name()) - } - return nil -} - -func (fs *fileSocket) RdFile() *os.File { - return fs.rdFile -} - -func (fs *fileSocket) WrFile() *os.File { - return fs.wrFile -} diff --git a/internal/filesocket/filesocket_test.go b/internal/filesocket/filesocket_test.go deleted file mode 100644 index ae01681..0000000 --- a/internal/filesocket/filesocket_test.go +++ /dev/null @@ -1,155 +0,0 @@ -package filesocket - -import ( - "errors" - "io" - "os" - "sync" - "testing" - "time" -) - -func TestFileSocket(t *testing.T) { - t.Run("testTempFileSocket", testTempFileSocket) -} - -func testTempFileSocket(t *testing.T) { - // Create 2 temp files - rdFile, err := os.CreateTemp("", "rdFile_*.tmp") - if err != nil { - t.Fatalf("error in creating temp rdFile: %v", err) - } - t.Logf("created: %v", rdFile.Name()) - // close and remove the file on exit - defer func() { - rdFile.Close() - os.Remove(rdFile.Name()) - }() - - wrFile, err := os.CreateTemp("", "wrFile_*.tmp") - if err != nil { - t.Fatalf("error in creating temp wrFile: %v", err) - } - t.Logf("created: %v", wrFile.Name()) - // close and remove the file on exit - defer func() { - wrFile.Close() - os.Remove(wrFile.Name()) - }() - - // Create a FileSocket - fs := NewFileSocket(rdFile, wrFile) - if fs == nil { - t.Fatalf("NewFileSocket() returned nil") - } - defer fs.Close() - - wg := &sync.WaitGroup{} - wg.Add(3) // 3 goroutines - - // One goroutine to read from the wrFile, capitalise the string and write to rdFile - - wrFileCopy, err := os.OpenFile(wrFile.Name(), os.O_RDONLY, 0644) - if err != nil { - t.Errorf("error in opening wrFile: %v", err) - return - } - - rdFileCopy, err := os.OpenFile(rdFile.Name(), os.O_WRONLY, 0644) - if err != nil { - t.Errorf("error in opening rdFile: %v", err) - return - } - - go func(wrFile, rdFile *os.File) { - defer wg.Done() - var buf []byte = make([]byte, 1024) - for { - n, err := wrFile.Read(buf) - if err != nil && !errors.Is(err, io.EOF) { - if !errors.Is(err, os.ErrClosed) { - t.Errorf("wrFile.Read() errored: %v", err) - } - return - } - if n == 0 { - // t.Logf("wrFile.Read(): empty read") - continue - } - // t.Logf("wrFile.Read(): %v bytes read: %v", n, string(buf[:n])) - - // Capitalise the string - for i := 0; i < n; i++ { - if buf[i] >= 'a' && buf[i] <= 'z' { - buf[i] = buf[i] - 'a' + 'A' - } - } - - // Write to rdFile - n, err = rdFile.Write(buf[:n]) - if err != nil { - t.Logf("rdFile.Write() errored: %v", err) - return - } - // t.Logf("rdFile.Write(): %v bytes written", n) - - // Sleep for 10 millisecond - time.Sleep(10 * time.Millisecond) - } - }(wrFileCopy, rdFileCopy) - - // one goroutine used to write to the FileSocket - go func() { - defer wg.Done() - defer fs.Close() - defer wrFileCopy.Close() - for i := 0; i < 10; i++ { - // Write to the FileSocket - sendBuf := []byte("hello world") - n, err := fs.Write(sendBuf) - if err != nil { - t.Errorf("fs.Write() errored: %v", err) - return - } - - if n != len(sendBuf) { - t.Errorf("fs.Write() wrote %v bytes, expected %v bytes", n, len(sendBuf)) - return - } - t.Logf("fs.Write(): %v bytes written: %v", n, string(sendBuf)) - - // Sleep for 1 Second - time.Sleep(1 * time.Second) - } - }() - - // one goroutine used to read from the FileSocket - go func() { - defer wg.Done() - defer rdFileCopy.Close() - for { - buf := make([]byte, 1024) - n, err := fs.Read(buf) - if err != nil { - if errors.Is(err, os.ErrClosed) && fs.(*fileSocket).closed.Load() { - t.Logf("fs.Read(): reading from closed socket") - return - } - if errors.Is(err, io.EOF) { - t.Logf("fs.Read(): EOF") - return - } - t.Errorf("fs.Read() errored: %v", err) - return - } - if n == 0 { - // t.Logf("fs.Read(): empty read") - continue - } - t.Logf("fs.Read(): %v bytes read: %v", n, string(buf[:n])) - } - }() - - // Wait for the goroutines to finish - wg.Wait() -} diff --git a/internal/log/README.md b/internal/log/README.md new file mode 100644 index 0000000..2fefbde --- /dev/null +++ b/internal/log/README.md @@ -0,0 +1,3 @@ +# log + +`log` provides a version-independent wrapper around `slog` from standard Go library, where the latter is version dependent to Go version (located at "golang.org/x/exp/slog" for older Go versions before Go 1.21). \ No newline at end of file diff --git a/internal/log/slog.go b/internal/log/slog.go new file mode 100644 index 0000000..0389483 --- /dev/null +++ b/internal/log/slog.go @@ -0,0 +1,30 @@ +//go:build go1.21 + +package log + +import ( + "fmt" + + "log/slog" +) + +func Debugf(format string, args ...any) { + slog.Default().Debug(fmt.Sprintf(format, args...)) +} + +func Infof(format string, args ...any) { + slog.Default().Info(fmt.Sprintf(format, args...)) +} + +func Warningf(format string, args ...any) { + slog.Default().Warn(fmt.Sprintf(format, args...)) +} + +func Errorf(format string, args ...any) { + slog.Default().Error(fmt.Sprintf(format, args...)) +} + +func Fatalf(format string, args ...any) { + slog.Default().Error(fmt.Sprintf(format, args...)) + panic("fatal error occurred") +} diff --git a/internal/log/slog_old.go b/internal/log/slog_old.go new file mode 100644 index 0000000..c25f0d9 --- /dev/null +++ b/internal/log/slog_old.go @@ -0,0 +1,30 @@ +//go:build !go1.21 + +package log + +import ( + "fmt" + + "golang.org/x/exp/slog" +) + +func Debugf(format string, args ...any) { + slog.Default().Debug(fmt.Sprintf(format, args...)) +} + +func Infof(format string, args ...any) { + slog.Default().Info(fmt.Sprintf(format, args...)) +} + +func Warnf(format string, args ...any) { + slog.Default().Warn(fmt.Sprintf(format, args...)) +} + +func Errorf(format string, args ...any) { + slog.Default().Error(fmt.Sprintf(format, args...)) +} + +func Fatalf(format string, args ...any) { + slog.Default().Error(fmt.Sprintf(format, args...)) + panic("fatal error occurred") +} diff --git a/internal/socket/file.go b/internal/socket/file.go new file mode 100644 index 0000000..22cc842 --- /dev/null +++ b/internal/socket/file.go @@ -0,0 +1,43 @@ +package socket + +import ( + "errors" + "fmt" + "io" + "os" + + "github.com/gaukas/water/internal/log" +) + +var ErrNoKnownConversion = errors.New("no known conversion to *os.File") + +type EmbedFile interface { + File() (*os.File, error) +} + +func AsFile(f any) (*os.File, error) { + switch f := f.(type) { + case *os.File: + log.Debugf("%T is already *os.File", f) + return f, nil + // Anything implementing EmbedFile interface, including: + // - *net.TCPConn + // - *net.UDPConn + // - *net.UnixConn + // - *net.TCPListener + // - *net.UnixListener + case EmbedFile: + log.Debugf("%T has implemented File() (*os.File, error)", f) + return f.File() + case io.ReadWriteCloser: // and also net.Conn + log.Debugf("%T implements only ReadWriteCloser and needs wrapping", f) + unixConn, err := UnixConnWrap(f) + if err != nil { + return nil, err + } + return unixConn.File() + default: + log.Debugf("%T has no known conversion to *os.File", f) + return nil, fmt.Errorf("%T: %w", f, ErrNoKnownConversion) + } +} diff --git a/internal/socket/unixconn.go b/internal/socket/unixconn.go new file mode 100644 index 0000000..2d2f346 --- /dev/null +++ b/internal/socket/unixconn.go @@ -0,0 +1,172 @@ +package socket + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "net" + "os" + "sync" + "time" +) + +const () + +// UnixConnWrap wraps an io.Reader/io.Writer/io.ReadWriteCloser +// interface into a UnixConn. +// +// This function spins up either one or two goroutines to copy +// data between the ReadWriteCloser and the UnixConn. Anything +// written to the UnixConn by caller will be written to the +// wrapped object if the object implements io.Writer, and if +// the object implements io.Reader, anything read by goroutine +// from the wrapped object will be readable from the UnixConn +// by caller. +// +// Once this function is invoked, the caller should not perform I/O +// operations on the ReadWriteCloser anymore. +func UnixConnWrap(obj any) (*net.UnixConn, error) { + // randomize the name of the socket + var randName []byte = make([]byte, 8) // 8-byte so 16-char hex string, 64-bit randomness is good enough + if _, err := rand.Read(randName); err != nil { + return nil, err + } + socketName := hex.EncodeToString(randName) + + // listen on the socket + unixAddr, err := net.ResolveUnixAddr("unix", os.TempDir()+"/"+string(socketName)) + if err != nil { + return nil, err + } + unixListener, err := net.ListenUnix("unix", unixAddr) + if err != nil { + return nil, err + } + defer unixListener.Close() // we will no longer need this listener since the name is not recorded anywhere + + // spin up a goroutine to wait for listening + var unixConn *net.UnixConn + var acceptErr error + acceptWg := &sync.WaitGroup{} + acceptWg.Add(1) + go func() { + defer acceptWg.Done() + unixConn, acceptErr = unixListener.AcceptUnix() // so caller will have the accepted connection + if acceptErr != nil { + return + } + }() + + // reverseUnixConn is used to access the unixConn's read/write buffer: + // - writing to reverseUnixConn = save to unixConn's read buffer + // - reading from reverseUnixConn = read from unixConn's write buffer + reverseUnixConn, err := net.DialUnix("unix", nil, unixAddr) + if err != nil { + return nil, err + } + acceptWg.Wait() // wait for the goroutine to accept the connection + if acceptErr != nil { + return nil, acceptErr + } + + // if the object implements io.Reader: read from the object and write to the reverseUnixConn + if reader, ok := obj.(io.Reader); ok { + go func() { + io.Copy(reverseUnixConn, reader) + // when the src is closed, we will close the dst + time.Sleep(1 * time.Millisecond) + reverseUnixConn.Close() + }() + } + + // if the object implements io.Writer: read from the reverseUnixConn and write to the object + if writer, ok := obj.(io.Writer); ok { + go func() { + io.Copy(writer, reverseUnixConn) + // when the src is closed, we will close the dst + if closer, ok := obj.(io.Closer); ok { + time.Sleep(1 * time.Millisecond) + closer.Close() + } + }() + } + + return unixConn, nil +} + +func UnixConnPair(path string) (c1, c2 net.Conn, err error) { + unixPath := path + if path == "" { + // randomize a socket name + randBytes := make([]byte, 16) + if _, err := rand.Read(randBytes); err != nil { + return nil, nil, fmt.Errorf("crypto/rand.Read returned error: %w", err) + } + unixPath = os.TempDir() + string(os.PathSeparator) + fmt.Sprintf("%x", randBytes) + } + + // create a one-time use UnixListener + ul, err := net.Listen("unix", unixPath) + if err != nil { + return nil, nil, fmt.Errorf("net.Listen returned error: %w", err) + } + defer ul.Close() + + var wg *sync.WaitGroup = new(sync.WaitGroup) + var goroutineErr error + wg.Add(1) + go func() { + defer wg.Done() + c2, goroutineErr = ul.Accept() + }() + + // dial the one-time use UnixListener + c1, err = net.Dial("unix", ul.Addr().String()) + if err != nil { + return nil, nil, fmt.Errorf("net.Dial returned error: %w", err) + } + wg.Wait() + + if goroutineErr != nil { + return nil, nil, fmt.Errorf("ul.Accept returned error: %w", goroutineErr) + } + + if c1 == nil || c2 == nil { + return nil, nil, fmt.Errorf("c1 or c2 is nil") + } + + return c1, c2, nil +} + +func TCPConnPair(address string) (c1, c2 net.Conn, err error) { + l, err := net.Listen("tcp", address) + if err != nil { + return nil, nil, fmt.Errorf("net.Listen returned error: %w", err) + } + defer l.Close() + + var wg *sync.WaitGroup = new(sync.WaitGroup) + var goroutineErr error + wg.Add(1) + go func() { + defer wg.Done() + c2, goroutineErr = l.Accept() + }() + + c1, err = net.Dial("tcp", l.Addr().String()) + if err != nil { + return nil, nil, fmt.Errorf("net.Dial returned error: %w", err) + } + wg.Wait() + + if goroutineErr != nil { + return nil, nil, fmt.Errorf("l.Accept returned error: %w", goroutineErr) + } + + if c1 == nil || c2 == nil { + return nil, nil, fmt.Errorf("c1 or c2 is nil") + } + + return c1, c2, nil +} diff --git a/internal/socket/unixconn_test.go b/internal/socket/unixconn_test.go new file mode 100644 index 0000000..8b8d329 --- /dev/null +++ b/internal/socket/unixconn_test.go @@ -0,0 +1,61 @@ +package socket_test + +import ( + "crypto/rand" + "fmt" + "net" + "runtime" + "testing" + "time" + + "github.com/gaukas/water/internal/socket" +) + +func TestUnixConnPair(t *testing.T) { + c1, c2, err := socket.UnixConnPair("") + if err != nil { + t.Fatal(err) + } + + runtime.GC() + time.Sleep(1 * time.Second) + + // test c1 -> c2 + err = testIO(c1, c2, 10000, 1024, 0) + if err != nil { + t.Fatal(err) + } + + runtime.GC() + time.Sleep(1 * time.Second) + + // test c2 -> c1 + err = testIO(c2, c1, 10000, 1024, 0) + if err != nil { + t.Fatal(err) + } +} + +func testIO(wrConn, rdConn net.Conn, N int, sz int, sleep time.Duration) error { + var sendMsg []byte = make([]byte, sz) + rand.Read(sendMsg) + + var err error + for i := 0; i < N; i++ { + _, err = wrConn.Write(sendMsg) + if err != nil { + return fmt.Errorf("Write error: %w, cntr: %d, N: %d", err, i, N) + } + + // receive data + buf := make([]byte, 1024) + _, err = rdConn.Read(buf) + if err != nil { + return fmt.Errorf("Read error: %w, cntr: %d, N: %d", err, i, N) + } + + time.Sleep(sleep) + } + + return nil +} diff --git a/internal/v0/wasi_dialer.go b/internal/v0/wasi_dialer.go new file mode 100644 index 0000000..a99daf4 --- /dev/null +++ b/internal/v0/wasi_dialer.go @@ -0,0 +1,101 @@ +package v0 + +import ( + "fmt" + "net" + "os" + + "github.com/bytecodealliance/wasmtime-go/v13" + "github.com/gaukas/water/internal/socket" + "github.com/gaukas/water/internal/wasm" +) + +// WASIDialer is a convenient wrapper around net.Dialer which +// restricts the dialer to only dialing to a single address on +// a single network. +// +// WASM module will (through WASI) call to the dialer to dial +// for network connections. +type WASIDialer struct { + network string + address string + dialerFunc func(network, address string) (net.Conn, error) + mapFdConn map[int32]net.Conn // saves all the connections created by this WasiDialer by their file descriptors! (So we could close them when needed) + mapFdClonedFile map[int32]*os.File // saves all files so GC won't close them +} + +func MakeWASIDialer( + network, address string, + dialerFunc func(network, address string) (net.Conn, error), +) *WASIDialer { + return &WASIDialer{ + network: network, + address: address, + dialerFunc: dialerFunc, + mapFdConn: make(map[int32]net.Conn), + mapFdClonedFile: make(map[int32]*os.File), + } +} + +func (wd *WASIDialer) WrappedDial() wasm.WASMTIMEStoreIndependentFunction { + return WrapConnectFunc(wd.dial) +} + +// dial(apw i32) -> fd i32 +func (wd *WASIDialer) dial(caller *wasmtime.Caller) (fd int32, err error) { + conn, err := wd.dialerFunc(wd.network, wd.address) + if err != nil { + return wasm.GENERAL_ERROR, fmt.Errorf("dialerFunc: %w", err) + } + + connFile, err := socket.AsFile(conn) + if err != nil { + return wasm.GENERAL_ERROR, fmt.Errorf("socket.AsFile: %w", err) + } + + uintfd, err := caller.PushFile(connFile, wasmtime.READ_WRITE) + if err != nil { + return wasm.WASICTX_ERR, fmt.Errorf("(*wasmtime.Caller).PushFile: %w", err) + } + + wd.mapFdConn[int32(uintfd)] = conn // save the connection by its file descriptor + + // fix: Go GC will close the file descriptor (clone) created by (*net.XxxConn).File() + wd.mapFdClonedFile[int32(uintfd)] = connFile + + return int32(uintfd), nil +} + +func (wd *WASIDialer) GetConnByFd(fd int32) net.Conn { + if wd.mapFdConn == nil { + return nil + } + return wd.mapFdConn[fd] +} + +func (wd *WASIDialer) GetFileByFd(fd int32) *os.File { + if wd.mapFdClonedFile == nil { + return nil + } + return wd.mapFdClonedFile[fd] +} + +func (wd *WASIDialer) CloseAllConn() { + if wd == nil { + return + } + + if wd.mapFdConn != nil { + for k, conn := range wd.mapFdConn { + conn.Close() + delete(wd.mapFdConn, k) + } + } + + if wd.mapFdClonedFile != nil { + for k, file := range wd.mapFdClonedFile { + file.Close() + delete(wd.mapFdClonedFile, k) + } + } +} diff --git a/internal/v0/wasi_listener.go b/internal/v0/wasi_listener.go new file mode 100644 index 0000000..74f7146 --- /dev/null +++ b/internal/v0/wasi_listener.go @@ -0,0 +1,95 @@ +package v0 + +import ( + "fmt" + "net" + "os" + + "github.com/bytecodealliance/wasmtime-go/v13" + "github.com/gaukas/water/internal/socket" + "github.com/gaukas/water/internal/wasm" +) + +type WASIListener struct { + listener net.Listener + mapFdConn map[int32]net.Conn // saves all the connections accepted by this WASIListener by their file descriptors! + mapFdClonedFile map[int32]*os.File // saves all files so GC won't close them +} + +func MakeWASIListener(listener net.Listener) *WASIListener { + if listener == nil { + panic("water: NewWASIListener: listener is nil") + } + + return &WASIListener{ + listener: listener, + mapFdConn: make(map[int32]net.Conn), + mapFdClonedFile: make(map[int32]*os.File), + } +} + +func (wl *WASIListener) WrappedAccept() wasm.WASMTIMEStoreIndependentFunction { + return WrapConnectFunc(wl.accept) +} + +func (wl *WASIListener) accept(caller *wasmtime.Caller) (fd int32, err error) { + conn, err := wl.listener.Accept() + if err != nil { + return -1, fmt.Errorf("listener.Accept: %w", err) + } + + connFile, err := socket.AsFile(conn) + if err != nil { + return -1, fmt.Errorf("socket.AsFile: %w", err) + } + + uintfd, err := caller.PushFile(connFile, wasmtime.READ_WRITE) + if err != nil { + return -1, fmt.Errorf("(*wasmtime.Caller).PushFile: %w", err) + } + + wl.mapFdConn[int32(uintfd)] = conn // save the connection by its file descriptor + + // fix: Go GC will close the file descriptor clone created by (*net.XxxConn).File() + wl.mapFdClonedFile[int32(uintfd)] = connFile + + return int32(uintfd), nil +} + +// Close should not be called if the embedded listener is shared across +// multiple WASM instances or WASIListeners. +func (wl *WASIListener) Close() error { + return wl.listener.Close() +} + +func (wl *WASIListener) GetConnByFd(fd int32) net.Conn { + if wl.mapFdConn == nil { + return nil + } + return wl.mapFdConn[fd] +} + +func (wl *WASIListener) GetFileByFd(fd int32) *os.File { + if wl.mapFdClonedFile == nil { + return nil + } + return wl.mapFdClonedFile[fd] +} + +func (wl *WASIListener) CloseAllConn() { + if wl == nil { + return + } + + if wl.mapFdConn != nil { + for _, conn := range wl.mapFdConn { + conn.Close() + } + } + + if wl.mapFdClonedFile != nil { + for _, file := range wl.mapFdClonedFile { + file.Close() + } + } +} diff --git a/internal/v0/wasi_net.go b/internal/v0/wasi_net.go new file mode 100644 index 0000000..f773ddb --- /dev/null +++ b/internal/v0/wasi_net.go @@ -0,0 +1,41 @@ +package v0 + +import ( + "fmt" + + "github.com/bytecodealliance/wasmtime-go/v13" + "github.com/gaukas/water/internal/wasm" +) + +type WASIConnectFunc = func(caller *wasmtime.Caller) (fd int32, err error) + +var WASIConnectFuncType *wasmtime.FuncType = wasmtime.NewFuncType( + []*wasmtime.ValType{}, + []*wasmtime.ValType{ + wasmtime.NewValType(wasmtime.KindI32), // return: connectionFd + }, +) + +func WrapConnectFunc(f WASIConnectFunc) wasm.WASMTIMEStoreIndependentFunction { + return func(caller *wasmtime.Caller, vals []wasmtime.Val) ([]wasmtime.Val, *wasmtime.Trap) { + if len(vals) != 0 { + return []wasmtime.Val{wasmtime.ValI32(wasm.INVALID_ARGUMENT)}, wasmtime.NewTrap(fmt.Sprintf("v0.WASIConnectFunc expects 0 argument, got %d", len(vals))) + } + + fd, err := f(caller) + if err != nil { // here fd is expected to be an error code (negative) + return []wasmtime.Val{wasmtime.ValI32(fd)}, wasmtime.NewTrap(fmt.Sprintf("v0.WASIConnectFunc: %v", err)) + } + + return []wasmtime.Val{wasmtime.ValI32(fd)}, nil + } +} + +func WrappedNopWASIConnectFunc() wasm.WASMTIMEStoreIndependentFunction { + return WrapConnectFunc(nopWASIConnectFunc) +} + +// nopWASIConnectFunc is a WASIConnectFunc that does nothing. +func nopWASIConnectFunc(caller *wasmtime.Caller) (fd int32, err error) { + return wasm.INVALID_FUNCTION, fmt.Errorf("NOP WASIConnectFunc is called") +} diff --git a/internal/wasm/errors.go b/internal/wasm/errors.go new file mode 100644 index 0000000..a57bdf5 --- /dev/null +++ b/internal/wasm/errors.go @@ -0,0 +1,54 @@ +package wasm + +import "fmt" + +// WASMErrCode is the error code returned by the wasm module +type WASMErrCode = int32 + +// Pre-defined WASMErrCode +const ( + NO_ERROR WASMErrCode = -iota + GENERAL_ERROR + INVALID_ARGUMENT + INVALID_CONFIG + INVALID_FD + INVALID_FUNCTION + DOUBLE_INIT + FAILED_IO + NOT_INITIALIZED + WASICTX_ERR +) + +// Pre-defined WASM Errors +var ( + ErrGeneralError = fmt.Errorf("general error") + ErrInvalidArgument = fmt.Errorf("invalid argument") + ErrInvalidConfig = fmt.Errorf("invalid config") + ErrInvalidFD = fmt.Errorf("invalid file descriptor") + ErrInvalidFunction = fmt.Errorf("invalid function") + ErrDoubleInit = fmt.Errorf("double init") + ErrFailedIO = fmt.Errorf("i/o operation failed") + ErrNotInitialized = fmt.Errorf("not initialized") + ErrWASICTX = fmt.Errorf("wasi ctx error") +) + +var mapWASMErrCode = map[WASMErrCode]error{ + NO_ERROR: nil, + GENERAL_ERROR: ErrGeneralError, + INVALID_ARGUMENT: ErrInvalidArgument, + INVALID_CONFIG: ErrInvalidConfig, + INVALID_FD: ErrInvalidFD, + INVALID_FUNCTION: ErrInvalidFunction, + DOUBLE_INIT: ErrDoubleInit, + FAILED_IO: ErrFailedIO, + NOT_INITIALIZED: ErrNotInitialized, + WASICTX_ERR: ErrWASICTX, +} + +// WASMErr returns the error corresponding to the WASM error code. +func WASMErr(code WASMErrCode) error { + if err, ok := mapWASMErrCode[code]; ok { + return err + } + return fmt.Errorf("unrecognized error (%d)", code) +} diff --git a/internal/wasm/net.go b/internal/wasm/net.go new file mode 100644 index 0000000..0919f8a --- /dev/null +++ b/internal/wasm/net.go @@ -0,0 +1,5 @@ +package wasm + +import "github.com/bytecodealliance/wasmtime-go/v13" + +type WASMTIMEStoreIndependentFunction = func(*wasmtime.Caller, []wasmtime.Val) ([]wasmtime.Val, *wasmtime.Trap) diff --git a/internal/wasm/wasi_config.go b/internal/wasm/wasi_config.go new file mode 100644 index 0000000..dad336f --- /dev/null +++ b/internal/wasm/wasi_config.go @@ -0,0 +1,119 @@ +package wasm + +import "github.com/bytecodealliance/wasmtime-go/v13" + +// WASIConfigFactory creates wasmtime.WasiConfig. +// Since WasiConfig cannot be cloned, we will instead save +// all the repeated setup functions in a slice and call them +// on newly created wasmtime.WasiConfig when needed. +type WASIConfigFactory struct { + setupFuncs []func(*wasmtime.WasiConfig) error // if any of these functions returns an error, the whole setup will fail. +} + +func NewWasiConfigFactory() *WASIConfigFactory { + return &WASIConfigFactory{ + setupFuncs: make([]func(*wasmtime.WasiConfig) error, 0), + } +} + +func (wcf *WASIConfigFactory) Clone() *WASIConfigFactory { + if wcf == nil || wcf.setupFuncs == nil { + return NewWasiConfigFactory() + } + + clone := &WASIConfigFactory{ + setupFuncs: make([]func(*wasmtime.WasiConfig) error, len(wcf.setupFuncs)), + } + copy(clone.setupFuncs, wcf.setupFuncs) + + return clone +} + +// GetConfig sets up and returns the finished wasmtime.WasiConfig. +// +// If the setup fails, it will return nil and an error. +func (wcf *WASIConfigFactory) GetConfig() (*wasmtime.WasiConfig, error) { + wasiConfig := wasmtime.NewWasiConfig() + if wcf != nil && wcf.setupFuncs != nil { + for _, f := range wcf.setupFuncs { + if err := f(wasiConfig); err != nil { + return nil, err + } + } + } + return wasiConfig, nil +} + +func (wcf *WASIConfigFactory) SetArgv(argv []string) { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.SetArgv(argv) + return nil + }) +} + +func (wcf *WASIConfigFactory) InheritArgv() { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.InheritArgv() + return nil + }) +} + +func (wcf *WASIConfigFactory) SetEnv(keys, values []string) { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.SetEnv(keys, values) + return nil + }) +} + +func (wcf *WASIConfigFactory) InheritEnv() { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.InheritEnv() + return nil + }) +} + +func (wcf *WASIConfigFactory) SetStdinFile(path string) { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + return wasiConfig.SetStdinFile(path) + }) +} + +func (wcf *WASIConfigFactory) InheritStdin() { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.InheritStdin() + return nil + }) +} + +func (wcf *WASIConfigFactory) SetStdoutFile(path string) { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + return wasiConfig.SetStdoutFile(path) + }) +} + +func (wcf *WASIConfigFactory) InheritStdout() { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.InheritStdout() + return nil + }) +} + +func (wcf *WASIConfigFactory) SetStderrFile(path string) { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + return wasiConfig.SetStderrFile(path) + }) +} + +func (wcf *WASIConfigFactory) InheritStderr() { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.InheritStderr() + return nil + }) +} + +func (wcf *WASIConfigFactory) SetPreopenDir(path string, guestPath string) { + wcf.setupFuncs = append(wcf.setupFuncs, func(wasiConfig *wasmtime.WasiConfig) error { + wasiConfig.PreopenDir(path, guestPath) + return nil + }) +} diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..66d7f2c --- /dev/null +++ b/listener.go @@ -0,0 +1,110 @@ +package water + +import ( + "fmt" + "net" + "sync/atomic" +) + +// Listener listens on a local network address and upon caller +// calling Accept(), it accepts an incoming connection and +// passes it to the WASM module, which returns a net.Conn to +// caller. +// +// The structure of a Listener is as follows: +// +// +---------------+ accept +---------------+ accept +// ---->| |------->| Decode |-------> +// Source | net.Listener | | WASM Runtime | Caller +// <----| |<-------| Decode/Encode |<------- +// +---------------+ +---------------+ +// \ / +// \------Listener------/ +// +// As shown above, a Listener consists of a net.Listener to accept +// incoming connections and a WASM runtime to handle the incoming +// connections from an external source. The WASM runtime will return +// a net.Conn that caller can Read() from or Write() to. +// +// The WASM module used by a Listener must implement a WASMListener. +type Listener struct { + Config *Config + closed *atomic.Bool +} + +// ListenConfig listens on the network address and returns a Listener +// configured with the given Config. +// +// This is the recommended way to create a Listener, unless there are +// other requirements such as supplying a custom net.Listener. In that +// case, a Listener could be created with WrapListener() with a Config +// specifying a custom net.Listener. +func (c *Config) Listen(network, address string) (net.Listener, error) { + lis, err := net.Listen(network, address) + if err != nil { + return nil, err + } + + config := c.Clone() + config.NetworkListener = lis + + return &Listener{ + Config: config, + closed: new(atomic.Bool), + }, nil +} + +// WrapListener creates a Listener with the given Config. +// +// The Config must specify a custom net.Listener, otherwise the +// Accept() method will fail. +func WrapListener(config *Config) *Listener { + return &Listener{ + Config: config, + closed: new(atomic.Bool), + } +} + +// Accept waits for and returns the next connection after processing +// the data with the WASM module. +// +// The returned net.Conn implements net.Conn and could be seen as +// the inbound connection with a wrapping transport protocol handled +// by the WASM module. +// +// Implements net.Listener. +func (l *Listener) Accept() (net.Conn, error) { + if l.closed.Load() { + return nil, fmt.Errorf("water: listener is closed") + } + + if l.Config == nil { + return nil, fmt.Errorf("water: dialing with nil config is not allowed") + } + + var core *core + var err error + core, err = Core(l.Config) + if err != nil { + return nil, err + } + + return core.AcceptVersion() +} + +// Close closes the listener. +// +// Implements net.Listener. +func (l *Listener) Close() error { + if l.closed.CompareAndSwap(false, true) { + return l.Config.NetworkListener.Close() + } + return nil +} + +// Addr returns the listener's network address. +// +// Implements net.Listener. +func (l *Listener) Addr() net.Addr { + return l.Config.NetworkListener.Addr() +} diff --git a/nogcfix.go b/nogcfix.go new file mode 100644 index 0000000..f925e89 --- /dev/null +++ b/nogcfix.go @@ -0,0 +1,8 @@ +//go:build nogcfix + +package water + +// If the program is compiled with `go build -tags nogcfix`, the +// GC fix mentioned in gcfix.go will not be applied. Unexpected +// GC behavior is expected ;) +const GCFIX bool = false diff --git a/relay.go b/relay.go new file mode 100644 index 0000000..13783a0 --- /dev/null +++ b/relay.go @@ -0,0 +1,29 @@ +//go:build v1 + +package water + +// Relay listens on a local network address and handles requests +// on incoming connections by passing the incoming connection to +// the WASM module and dial corresponding outbound connections +// to the pre-defined destination address, which can either be a +// remote TCP/UDP address or a unix socket. +// +// The structure of a Relay is as follows: +// +// accept +---------------+ +---------------+ dial +// ------->| |----->| Decode |-----> +// Source | net.Listener | | WASM Runtime | Remote +// <-------| |<-----| Decode/Encode |<----- +// +---------------+ +---------------+ +// \ / +// \------Relay-------/ +// +// As shown above, a Relay consists of a net.Listener to accept +// incoming connections and a WASM runtime to handle the incoming +// connections from an external source. The WASM runtime will dial +// the corresponding outbound connections to a pre-defined +// destination address. It requires no further caller interaction +// once it is started. +// +// The WASM module used by a Relay must implement a WASMDialer. +type Relay struct{} diff --git a/runtime_test.go b/runtime_test.go deleted file mode 100644 index d821f25..0000000 --- a/runtime_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package water - -import ( - "bytes" - "crypto/rand" - "net" - "os" - "testing" -) - -func TestRuntimeConnDialer(t *testing.T) { - t.Run("testRuntimeConnDialerNoBG", testRuntimeConnDialerNoBG) - if t.Failed() { - t.Run("testRuntimeConnDialerNoBGGranularity", testRuntimeConnDialerNoBGGranularity) - } -} - -// this testcase directly calls (*RuntimeConnDialer).Dial() and -// fails the entire test suite if the call fails. -// -// It tests a RuntimeConnDialer that spawns no background goroutines. -func testRuntimeConnDialerNoBG(t *testing.T) { - // listen on a local TCP port - tcpListener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - defer tcpListener.Close() - - // accept connections - go func() { - for { - conn, err := tcpListener.Accept() - if err != nil { - return - } - go func(c net.Conn) { - defer func() { - // t.Logf("TCP server connection closing") - c.Close() - }() - for { - // read from conn, write back - buf := make([]byte, 1024) - n, err := c.Read(buf) - if err != nil { - return - } - - t.Logf("TCP server reads: %x", buf[:n]) - - _, err = c.Write(buf[:n]) - if err != nil { - return - } - - t.Logf("TCP server writes: %x", buf[:n]) - } - }(conn) - } - }() - - rd := &RuntimeConnDialer{} - rd.DebugMode() - t.Logf("listening on %s", tcpListener.Addr().String()) - _, err = rd.Dial(tcpListener.Addr().String()) - if err == nil { - t.Fatal("expected error, got nil") - } - - // load WASI binary from testdata - wasi, err := os.ReadFile("testdata/wasi_template.wasi.wasm") - if err != nil { - t.Fatal(err) - } - rd.Config = &Config{ - WASI: wasi, - } - - // dial again - conn, err := rd.Dial(tcpListener.Addr().String()) - if err != nil { - t.Error(err) - t.Fail() - return - } - defer conn.Close() - - // communication test: write 10 random messages and read back - for i := 0; i < 10; i++ { - var msg []byte = make([]byte, 64) - n, err := rand.Read(msg) - if err != nil { - t.Fatal(err) - } - - _, err = conn.Write(msg[:n]) - if err != nil { - t.Fatal(err) - } - - t.Logf("TCP client writes: %x", msg[:n]) - - buf := make([]byte, 1024) - n, err = conn.Read(buf) - if err != nil { - t.Fatal(err) - } - - t.Logf("TCP client reads: %x", buf[:n]) - - if bytes.Equal(msg[:n], buf[:n]) { - t.Log("TCP client: message echoed") - } else { - t.Fatal("TCP client: message not echoed") - } - } - - return -} - -func testRuntimeConnDialerNoBGGranularity(t *testing.T) { - // TODO: implement this for granular testing - t.Skip("not implemented") -} diff --git a/runtimes.go b/runtimes.go deleted file mode 100644 index aed207a..0000000 --- a/runtimes.go +++ /dev/null @@ -1,422 +0,0 @@ -package water - -import ( - "context" - "fmt" - "io" - "log" - "net" - "os" - - "github.com/bytecodealliance/wasmtime-go/v11" - "github.com/gaukas/water/internal/filesocket" -) - -const ( - RUNTIME_VERSION_MAJOR int32 = 0x001aaaaa - RUNTIME_VERSION string = "v0.1-alpha" -) - -type RuntimeConnDialer struct { - Config *Config - - debug bool -} - -func (d *RuntimeConnDialer) DebugMode() { - d.debug = true -} - -func (d *RuntimeConnDialer) Dial(address string) (rc *RuntimeConn, err error) { - return d.DialContext(context.Background(), address) -} - -func (d *RuntimeConnDialer) DialContext(ctx context.Context, address string) (rc *RuntimeConn, err error) { - if d.Config == nil { - return nil, fmt.Errorf("water: dialing with nil config is prohibited") - } - d.Config.init() - - var wasiConfig *wasmtime.WasiConfig - var ok bool - if wasiConfig, ok = ctx.Value("wasi_config").(*wasmtime.WasiConfig); !ok { - wasiConfig = wasmtime.NewWasiConfig() - } - - if d.debug { // bind stdin/stdout/stderr to host - wasiConfig.InheritStdin() - wasiConfig.InheritStdout() - wasiConfig.InheritStderr() - } - - rc = new(RuntimeConn) - if d.debug { - rc.debug = true - } - // preopen the socket directory - err = rc.preopenSocketDir(wasiConfig) - if err != nil { - return nil, fmt.Errorf("water: (*RuntimeConn).preopenSoocketDir retirmed error: %w", err) - } - - // load the WASI module - if rc.engine, ok = ctx.Value("wasm_engine").(*wasmtime.Engine); !ok { - rc.engine = wasmtime.NewEngine() - } - rc.module, err = wasmtime.NewModule(rc.engine, d.Config.WASI) - if err != nil { - return nil, fmt.Errorf("water: wasmtime.NewModule returned error: %w", err) - } - - // create the store - if rc.store, ok = ctx.Value("wasm_store").(*wasmtime.Store); !ok { - rc.store = wasmtime.NewStore(rc.engine) - } - rc.store.SetWasi(wasiConfig) - - // create the linker - if rc.linker, ok = ctx.Value("wasm_linker").(*wasmtime.Linker); !ok { - rc.linker = wasmtime.NewLinker(rc.engine) - } - err = rc.linker.DefineWasi() - if err != nil { - return nil, fmt.Errorf("water: linker.DefineWasi returned error: %w", err) - } - - // link dialer funcs - err = rc.linkDialerFunc(d.Config.Dialer, address) - if err != nil { - return nil, fmt.Errorf("water: (*RuntimeConn).linkDialerFunc returned error: %w", err) - } - - // instantiate the WASI module - rc.instance, err = rc.linker.Instantiate(rc.store, rc.module) - if err != nil { - return nil, fmt.Errorf("water: linker.Instantiate returned error: %w", err) - } - - // check the WASI version - if err = rc._version(); err != nil { - return nil, fmt.Errorf("water: (*RuntimeConn)._version returned error: %w", err) - } - - // run the WASI init function - if err = rc._init(); err != nil { - return nil, fmt.Errorf("water: (*RuntimeConn)._init returned error: %w", err) - } - - // check if the WASI module is single-threaded - if err = rc.finalize(); err != nil { - return nil, fmt.Errorf("water: (*RuntimeConn).finalize returned error: %w", err) - } - - return rc, nil -} - -// RuntimeConn is a net.Conn-like type which runs a WASI module to handle -// one connection. -type RuntimeConn struct { - userWriteDone func(byteLen int) (byteSuccess int, err error) // notify the WASI instance that it should read from user, process the data, and write to net - userWillRead func() (byteLen int, err error) // notify the WASI instance that it should read from net, process the data, and write to user - - debug bool - nonBlockingIO bool // true only if WASI has a blocking-forever loop used for handling I/O - uFs, netFs filesocket.FileSocket - netBundle filesocket.Bundle - deferFunc func() // to be called on Close() - onCloseCallback func() // to be called as Close() returns - - // wasmtime - engine *wasmtime.Engine - module *wasmtime.Module - store *wasmtime.Store - linker *wasmtime.Linker - instance *wasmtime.Instance -} - -func Dial(address string, config *Config) (*RuntimeConn, error) { - rd := &RuntimeConnDialer{ - Config: config, - } - return rd.DialContext(context.Background(), address) -} - -func DialContext(ctx context.Context, address string, config *Config) (*RuntimeConn, error) { - rd := &RuntimeConnDialer{ - Config: config, - } - - return rd.DialContext(ctx, address) -} - -func (rc *RuntimeConn) Write(p []byte) (n int, err error) { - n, err = rc.uFs.Write(p) - if err != nil { - return 0, fmt.Errorf("water: failed to write to WASI module: %w", err) - } - - // single-thread WASI module requires explicit call to write() funcs - if !rc.nonBlockingIO { - nSuccess, err := rc.userWriteDone(n) - if err != nil { - return 0, fmt.Errorf("water: failed to notify WASI module: %w", err) - } - if nSuccess == ErrIO { - return 0, fmt.Errorf("water: WASI module encountered I/O error") - } - - if nSuccess != n { - return 0, fmt.Errorf("water: length written to WASI does not match expected") - } - } - - return -} - -func (rc *RuntimeConn) Read(p []byte) (n int, err error) { - // single-thread WASI module requires explicit call to read() funcs - var nExpect int - if !rc.nonBlockingIO { - nExpect, err = rc.userWillRead() - if err != nil { - return 0, fmt.Errorf("water: failed to notify WASI module: %w", err) - } - - if nExpect == ErrIO { - return 0, fmt.Errorf("water: WASI module encountered I/O error") - } - - if rc.debug { - log.Printf("WASI module expects %d bytes", nExpect) - } - } - - n, err = rc.uFs.Read(p) - if err != nil { - return n, fmt.Errorf("water: failed to read from WASI buffer: %w", err) - } - - // check if short-buffer. No data-loss in this case but still need to notify the caller. - if nExpect > len(p) { - err = io.ErrShortBuffer - } - - return -} - -func (rc *RuntimeConn) Close() error { - if rc.deferFunc != nil { - rc.deferFunc() - } - - if rc.onCloseCallback != nil { - defer rc.onCloseCallback() - } - - if rc.uFs != nil { - rc.uFs.Close() - } - - if rc.netFs != nil { - rc.netFs.Close() - } - - if rc.netBundle != nil { - rc.netBundle.Close() - } - - return nil -} - -// preopenSocketDir preopens a temporary directory on host for the WASI module to -// interact with sockets and creates 4 files in the directory: -// - uin (input from the user, read-only) -// - uout (output to the user, write-only) -// - netrx (net socket RX, read-only) -// - nettx (net socket TX, write-only) -func (rc *RuntimeConn) preopenSocketDir(wasiConfig *wasmtime.WasiConfig) error { - tmpDir, err := os.MkdirTemp("", "water_*") // create a dir with randomized name under os.TempDir() - if err != nil { - return fmt.Errorf("failed to create temporary directory: %w", err) - } - rc.deferFunc = func() { os.RemoveAll(tmpDir) } // remove the temporary directory when wasi expires - - // create the 4 files - uin, err := os.Create(tmpDir + "/uin") - if err != nil { - rc.deferFunc() - return fmt.Errorf("failed to create temporary file: %w", err) - } - uout, err := os.Create(tmpDir + "/uout") - if err != nil { - rc.deferFunc() - return fmt.Errorf("failed to create temporary file: %w", err) - } - rc.uFs = filesocket.NewFileSocket(uout, uin) // user reads from uout, writes to uin to interact with the WASI module - - netrx, err := os.Create(tmpDir + "/netrx") - if err != nil { - rc.deferFunc() - return fmt.Errorf("failed to create temporary file: %w", err) - } - nettx, err := os.Create(tmpDir + "/nettx") - if err != nil { - rc.deferFunc() - return fmt.Errorf("failed to create temporary file: %w", err) - } - rc.netFs = filesocket.NewFileSocket(nettx, netrx) // Runtime reads from nettx, writes to netrx as it relays data to the net.Conn - - // preopen the temporary directory - if rc.debug { - log.Printf("preopening %s as /tmp", tmpDir) - // show what's in the temporary directory - dir, err := os.ReadDir(tmpDir) - if err != nil { - return fmt.Errorf("failed to read temporary directory: %w", err) - } - for _, entry := range dir { - log.Printf("- %s", entry.Name()) - } - } - err = wasiConfig.PreopenDir(tmpDir, "/tmp") - return err -} - -func (rc *RuntimeConn) linkDialerFunc(dialer Dialer, address string) error { - if rc.linker == nil { - return fmt.Errorf("linker not set") - } - - if dialer == nil { - return fmt.Errorf("dialer not set") - } - - var arrNetworks []string = []string{ - "tcp", - "udp", - "tls", // experimental - } - - for _, network := range arrNetworks { - err := func(network string) error { - if err := rc.linker.DefineFunc(rc.store, "env", "connect_"+network, func() int32 { - log.Printf("dialer.Dial(%s, %s)", network, address) - conn, err := dialer.Dial(network, address) - if err != nil { - log.Printf("failed to dial %s: %v", address, err) - return -1 // TODO: remove magic number - } - rc.makeNetBundle(conn) - return 0 // TODO: remove magic number - }); err != nil { - return fmt.Errorf("(*wasmtime.Linker).DefineFunc: %w", err) - } - - return nil - }(network) - - if err != nil { - return err - } - } - - return nil -} - -func (rc *RuntimeConn) makeNetBundle(conn net.Conn) { - rc.netBundle = filesocket.BundleFileSocket(conn, rc.netFs) - // rc.netBundle.OnClose() - rc.netBundle.Start() -} - -func (rc *RuntimeConn) _version() error { - // check the WASI version - versionFunc := rc.instance.GetFunc(rc.store, "_version") - if versionFunc == nil { - return fmt.Errorf("loaded WASI module does not export _version function") - } - version, err := versionFunc.Call(rc.store) - if err != nil { - return err - } - if version, ok := version.(int32); !ok { - return fmt.Errorf("_version function returned non-int32 value") - } else if version != RUNTIME_VERSION_MAJOR { - return fmt.Errorf("WASI module version `v%d` is not compatible with runtime version `%s`!", version, RUNTIME_VERSION) - } - - return nil -} - -func (rc *RuntimeConn) _init() error { - initFunc := rc.instance.GetFunc(rc.store, "_init") - if initFunc == nil { - return fmt.Errorf("loaded WASI module does not export _init function") - } - _, err := initFunc.Call(rc.store) - if err != nil { - return err - } - return nil -} - -func (rc *RuntimeConn) finalize() error { - backgroundWorker := rc.instance.GetFunc(rc.store, "_background_worker") - runBackgroundWorker := rc.instance.GetFunc(rc.store, "_run_background_worker") - if backgroundWorker == nil || runBackgroundWorker == nil { - if rc.debug { - log.Printf("registering callback functions") - } - // single-threaded WASI module, set user_write_ready and user_will_read - // bind instance functions - wasiUserWriteReady := rc.instance.GetFunc(rc.store, "_user_write_done") - if wasiUserWriteReady == nil { - return fmt.Errorf("loaded WASI module does not export either _user_write_ready or _background_worker function") - } - rc.userWriteDone = func(n int) (int, error) { - ret, err := wasiUserWriteReady.Call(rc.store, int32(n)) - if err != nil { - return 0, err - } - return int(ret.(int32)), nil - } - - wasiUserWillRead := rc.instance.GetFunc(rc.store, "_user_will_read") - if wasiUserWillRead == nil { - return fmt.Errorf("loaded WASI module does not export either _user_will_read or _background_worker function") - } - rc.userWillRead = func() (int, error) { - ret, err := wasiUserWillRead.Call(rc.store) - if err != nil { - return 0, err - } - return int(ret.(int32)), nil - } - } else { - if rc.debug { - log.Printf("spawning background workers") - } - // call _background_worker to get the number of background workers needed - bgWorkerNum, err := backgroundWorker.Call(rc.store) - if err != nil { - return fmt.Errorf("errored upon calling _background_worker function: %w", err) - } - if bgWorkerNum, ok := bgWorkerNum.(int32); !ok { - return fmt.Errorf("_background_worker function returned non-int32 value") - } else { - var i int32 - for i = 0; i < bgWorkerNum; i++ { - // spawn thread for background_worker - go func(tid int32) { - _, err := runBackgroundWorker.Call(rc.store, tid) - if err != nil { - panic(fmt.Errorf("errored upon calling _run_background_worker function: %w", err)) - } - }(i) - } - } - rc.nonBlockingIO = true - } - - return nil -} diff --git a/testdata/README.md b/testdata/README.md deleted file mode 100644 index 8ff4b97..0000000 --- a/testdata/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Testdata - -## wasi_template.wasi.wasm - -This is the most-basic WASI module that can be created. It does nothing other than reversing everything sent from the client and write to server, and vice versa. - -Build from branch `wasi-template` of this repo. \ No newline at end of file diff --git a/testdata/hexencoder_v0.dialer.json b/testdata/hexencoder_v0.dialer.json new file mode 100644 index 0000000..b1ee3f0 --- /dev/null +++ b/testdata/hexencoder_v0.dialer.json @@ -0,0 +1,4 @@ +{ + "role": "dialer", + "mode": "uppercase" +} \ No newline at end of file diff --git a/testdata/hexencoder_v0.listener.json b/testdata/hexencoder_v0.listener.json new file mode 100644 index 0000000..c7b49c2 --- /dev/null +++ b/testdata/hexencoder_v0.listener.json @@ -0,0 +1,4 @@ +{ + "role": "listener", + "mode": "lowercase" +} \ No newline at end of file diff --git a/testdata/hexencoder_v0.wasm b/testdata/hexencoder_v0.wasm new file mode 100644 index 0000000..755133d Binary files /dev/null and b/testdata/hexencoder_v0.wasm differ diff --git a/testdata/plain_v0.wasm b/testdata/plain_v0.wasm new file mode 100644 index 0000000..011021a Binary files /dev/null and b/testdata/plain_v0.wasm differ diff --git a/testdata/wasi_template.wasi.wasm b/testdata/wasi_template.wasi.wasm deleted file mode 100644 index f09e1ae..0000000 Binary files a/testdata/wasi_template.wasi.wasm and /dev/null differ diff --git a/wasm_api_v0.go b/wasm_api_v0.go new file mode 100644 index 0000000..d1a3490 --- /dev/null +++ b/wasm_api_v0.go @@ -0,0 +1,364 @@ +//go:build !nov0 + +package water + +import ( + "fmt" + "net" + "os" + "runtime" + "sync" + + "github.com/bytecodealliance/wasmtime-go/v13" + "github.com/gaukas/water/internal/socket" + v0 "github.com/gaukas/water/internal/v0" + "github.com/gaukas/water/internal/wasm" +) + +// WASMv0 is a wrapper around core which provides extended functionalities +// for WASM runtime in V0 spec. +type WASMv0 struct { + *core + + _init *wasmtime.Func // _init() -> i32 + + // _dial: + // - Calls to `env.host_dial() -> fd: i32` to dial a network connection (wrapped with the + // application protocol) and bind it to one of its file descriptors, record the fd as + // `remoteConnFd`. This will be the fd it used to read/write data from/to the remote + // destination. + // - Records the `callerConnFd`. This will be the fd it used to read/write data from/to + // the caller. + // - Returns `remoteConnFd` to the caller to be kept track of. + _dial *wasmtime.Func // _dial(callerConnFd i32) (remoteConnFd i32) + + // _accept: + // - Calls to `env.host_accept() -> fd: i32` to accept a network connection (wrapped with the + // application protocol) and bind it to one of its file descriptors, record the fd as + // `sourceConnFd`. This will be the fd it used to read/write data from/to the source + // address. + // - Records the `callerConnFd`. This will be the fd it used to read/write data from/to + // the caller. + // - Returns `sourceConnFd` to the caller to be kept track of. + _accept *wasmtime.Func // _accept(callerConnFd i32) (sourceConnFd i32) + + // _read: + // - if `callerConnFd` is invalid, this will return an error. + // - if `sourceConnFd` is valid, this will read from `sourceConnFd` and write to `callerConnFd`. + // - if `remoteConnFd` is valid, this will read from `remoteConnFd` and write to `callerConnFd`. + _read *wasmtime.Func // _read() (err int32) + + // _write: + // - if `callerConnFd` is invalid, this will return an error. + // - if `sourceConnFd` is valid, this will read from `callerConnFd` and write to `sourceConnFd`. + // - if `remoteConnFd` is valid, this will read from `callerConnFd` and write to `remoteConnFd`. + _write *wasmtime.Func // _write() (err int32) + + // _close: + // - Closes the all the file descriptors it owns. + // - Cleans up any other resouce it allocated within the WASM module. + // - Calls back to runtime by calling `env.host_defer` for the runtime to self-clean. + _close *wasmtime.Func + + dialer *v0.WASIDialer + listener *v0.WASIListener + + gcfixOnce *sync.Once + pushedConn map[int32]*struct { + conn net.Conn + file *os.File + } + + deferOnce *sync.Once + deferredFuncs []func() +} + +func NewWASMv0(core *core) *WASMv0 { + wasm := &WASMv0{ + core: core, + gcfixOnce: new(sync.Once), + pushedConn: make(map[int32]*struct { + conn net.Conn + file *os.File + }), + deferOnce: new(sync.Once), + deferredFuncs: make([]func(), 0), + } + + runtime.SetFinalizer(wasm, func(w *WASMv0) { + w.DeferAll() + w.Cleanup() + }) + + return wasm +} + +func (w *WASMv0) LinkNetworkInterface(dialer *v0.WASIDialer, listener *v0.WASIListener) error { + if w.Linker() == nil { + return fmt.Errorf("water: linker not set, is Core initialized?") + } + + // import host_dial + if dialer != nil { + if err := w.Linker().FuncNew("env", "host_dial", v0.WASIConnectFuncType, dialer.WrappedDial()); err != nil { + return fmt.Errorf("water: linking WASI dialer, (*wasmtime.Linker).FuncNew: %w", err) + } + } else { + if err := w.Linker().FuncNew("env", "host_dial", v0.WASIConnectFuncType, v0.WrappedNopWASIConnectFunc()); err != nil { + return fmt.Errorf("water: linking NOP dialer, (*wasmtime.Linker).FuncNew: %w", err) + } + } + w.dialer = dialer + + // import host_accept + if listener != nil { + if err := w.Linker().FuncNew("env", "host_accept", v0.WASIConnectFuncType, listener.WrappedAccept()); err != nil { + return fmt.Errorf("water: linking WASI listener, (*wasmtime.Linker).FuncNew: %w", err) + } + } else { + if err := w.Linker().FuncNew("env", "host_accept", v0.WASIConnectFuncType, v0.WrappedNopWASIConnectFunc()); err != nil { + return fmt.Errorf("water: linking NOP listener, (*wasmtime.Linker).FuncNew: %w", err) + } + } + w.listener = listener + + return nil +} + +// Initialize initializes the WASMv0 runtime by getting all the exported functions from +// the WASM module. +// +// All imports must be set before calling this function. +func (w *WASMv0) Initialize() error { + if w.core == nil { + return fmt.Errorf("water: no core loaded") + } + + var err error + // import host_defer function + if err = w.Linker().FuncWrap("env", "host_defer", func() { + w.DeferAll() + }); err != nil { + return fmt.Errorf("water: linking deferh function, (*wasmtime.Linker).FuncWrap: %w", err) + } + + // import pull_config function (it is called pushConfig here in the host) + if err := w.Linker().FuncNew("env", "pull_config", v0.WASIConnectFuncType, v0.WrapConnectFunc(w.pushConfig)); err != nil { + return fmt.Errorf("water: linking pull_config function, (*wasmtime.Linker).FuncNew: %w", err) + } + + // instantiate the WASM module + if err = w.Instantiate(); err != nil { + return err + } + + // _init + w._init = w.Instance().GetFunc(w.Store(), "_init") + if w._init == nil { + return fmt.Errorf("water: WASM module does not export _init") + } + + // _dial + w._dial = w.Instance().GetFunc(w.Store(), "_dial") + // if w._dial == nil { + // return fmt.Errorf("water: WASM module does not export _dial") + // } + + // _accept + w._accept = w.Instance().GetFunc(w.Store(), "_accept") + // if w._accept == nil { + // return fmt.Errorf("water: WASM module does not export _accept") + // } + + // _close + w._close = w.Instance().GetFunc(w.Store(), "_close") + if w._close == nil { + return fmt.Errorf("water: WASM module does not export _close") + } + + // call _init + ret, err := w._init.Call(w.Store()) + if err != nil { + return fmt.Errorf("water: calling _init function returned error: %w", err) + } + + return wasm.WASMErr(ret.(int32)) +} + +// Caller need to make sure anything caller writes to the WASM module is +// readable on the callerConn. +func (w *WASMv0) InitializeReadWriter() error { + // _read + w._read = w.Instance().GetFunc(w.Store(), "_read") + if w._read == nil { + return fmt.Errorf("water: WASM module does not export _read") + } + + // _write + w._write = w.Instance().GetFunc(w.Store(), "_write") + if w._write == nil { + return fmt.Errorf("water: WASM module does not export _write") + } + + return nil +} + +func (w *WASMv0) DialFrom(callerConn net.Conn) (destConn net.Conn, err error) { + callerFd, err := w.PushConn(callerConn) + if err != nil { + return nil, fmt.Errorf("water: pushing caller conn to store failed: %w", err) + } + + ret, err := w._dial.Call(w.Store(), callerFd) + if err != nil { + return nil, fmt.Errorf("water: calling _dial function returned error: %w", err) + } + + if remoteFd, ok := ret.(int32); !ok { + return nil, fmt.Errorf("water: invalid _dial function signature") + } else { + if remoteFd < 0 { + return nil, wasm.WASMErr(remoteFd) + } else { + destConn := w.dialer.GetConnByFd(remoteFd) + if destConn == nil { + return nil, fmt.Errorf("water: failed to look up network connection by fd") + } + return destConn, nil + } + } +} + +func (w *WASMv0) AcceptFor(callerConn net.Conn) (sourceConn net.Conn, err error) { + callerFd, err := w.PushConn(callerConn) + if err != nil { + return nil, fmt.Errorf("water: pushing caller conn to store failed: %w", err) + } + + ret, err := w._accept.Call(w.Store(), callerFd) + if err != nil { + return nil, fmt.Errorf("water: calling _accept function returned error: %w", err) + } + + if sourceFd, ok := ret.(int32); !ok { + return nil, fmt.Errorf("water: invalid _accept function signature") + } else { + if sourceFd < 0 { + return nil, wasm.WASMErr(sourceFd) + } else { + sourceConn := w.listener.GetConnByFd(sourceFd) + if sourceConn == nil { + return nil, fmt.Errorf("water: failed to look up network connection by fd") + } + return sourceConn, nil + } + } +} + +func (w *WASMv0) PushConn(conn net.Conn) (fd int32, err error) { + w.gcfixOnce.Do(func() { + if GCFIX { + // create temp file + var f *os.File + f, err = os.CreateTemp("", "water-gcfix") + if err != nil { + return + } + + // push dummy file + fd, err := w.Store().PushFile(f, wasmtime.READ_ONLY) + if err != nil { + return + } + + // save dummy file to map + w.pushedConn[int32(fd)] = &struct { + conn net.Conn + file *os.File + }{ + conn: nil, + file: f, + } + } + }) + + if err != nil { + return 0, fmt.Errorf("water: creating temp file for GC fix: %w", err) + } + + connFile, err := socket.AsFile(conn) + if err != nil { + return 0, fmt.Errorf("water: converting conn to file failed: %w", err) + } + + fdu32, err := w.store.PushFile(connFile, wasmtime.READ_WRITE) + if err != nil { + return 0, fmt.Errorf("water: pushing conn file to store failed: %w", err) + } + fd = int32(fdu32) + + w.pushedConn[fd] = &struct { + conn net.Conn + file *os.File + }{ + conn: conn, + file: connFile, + } + + return fd, nil +} + +func (w *WASMv0) DeferAll() { + w.deferOnce.Do(func() { // execute all deferred functions if not yet executed + for _, f := range w.deferredFuncs { + f() + } + }) +} + +func (w *WASMv0) Defer(f func()) { + w.deferredFuncs = append(w.deferredFuncs, f) +} + +func (w *WASMv0) Cleanup() { + // clean up pushed files + var keyList []int32 + for k, v := range w.pushedConn { + if v != nil { + if v.file != nil { + v.file.Close() + v.file = nil + } + if v.conn != nil { + v.conn.Close() + v.conn = nil + } + } + keyList = append(keyList, k) + } + for _, k := range keyList { + delete(w.pushedConn, k) + } + + // clean up deferred functions + w.deferredFuncs = nil + + w.dialer.CloseAllConn() + w.listener.CloseAllConn() +} + +func (w *WASMv0) pushConfig(caller *wasmtime.Caller) (int32, error) { + // get config file + configFile := w.Config().WATMConfig.File() + if configFile == nil { + return wasm.INVALID_FD, nil // we don't return error here so no trap is triggered + } + + // push file to WASM + configFd, err := caller.PushFile(configFile, wasmtime.READ_ONLY) + if err != nil { + return wasm.INVALID_FD, err + } + + return int32(configFd), nil +}