Skip to content

Commit

Permalink
new: WATM Driver v1
Browse files Browse the repository at this point in the history
v1: rename, reformat, and regulate/standardize the function import/exports.

Signed-off-by: Gaukas Wang <[email protected]>
  • Loading branch information
gaukas committed Apr 5, 2024
1 parent 8fdcf38 commit 3544e16
Show file tree
Hide file tree
Showing 29 changed files with 3,934 additions and 89 deletions.
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type Config struct {
// DialedAddressValidator is an optional field that can be set to validate
// the dialed address. It is only used when WATM specifies the remote
// address to dial.
//
// If not set, all addresses are considered invalid. To allow all addresses,
// simply set this field to a function that always returns nil.
DialedAddressValidator func(network, address string) error

// NetworkListener specifies a net.listener implementation that listens
Expand Down
79 changes: 0 additions & 79 deletions connector.go

This file was deleted.

70 changes: 67 additions & 3 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package water

import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
Expand All @@ -11,7 +14,11 @@ import (
"github.com/refraction-networking/water/internal/log"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental/sys"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"

"github.com/karelbilek/wazero-fs-tools/memfs"
expsysfs "github.com/tetratelabs/wazero/experimental/sysfs"
)

var (
Expand Down Expand Up @@ -110,6 +117,9 @@ type Core interface {
// If the target function is not exported, this function returns an error.
Invoke(funcName string, params ...uint64) (results []uint64, err error)

// ReadIovs reads data from the memory pointed by iovs and writes it to buf.
ReadIovs(iovs, iovsLen int32, buf []byte) (int, error)

// WASIPreview1 enables the WASI preview1 API.
//
// It is recommended that this function only to be invoked if
Expand Down Expand Up @@ -331,10 +341,10 @@ func (c *core) ImportFunction(module, name string, f any) error {
// Unsafe: check if the WebAssembly module really imports this function under
// the given module and name. If not, we warn and skip the import.
if mod, ok := c.ImportedFunctions()[module]; !ok {
log.LDebugf(c.config.Logger(), "water: module %s is not imported.", module)
log.LDebugf(c.config.Logger(), "water: module %s is not imported by the WebAssembly module.", module)
return ErrModuleNotImported
} else if _, ok := mod[name]; !ok {
log.LWarnf(c.config.Logger(), "water: function %s.%s is not imported.", module, name)
log.LWarnf(c.config.Logger(), "water: function %s.%s is not imported by the WebAssembly module.", module, name)
return ErrFuncNotImported
}

Expand Down Expand Up @@ -370,6 +380,28 @@ func (c *core) Instantiate() (err error) {
}
}

// If TransportModuleConfig is set, we pass the config to the runtime.
if c.config.TransportModuleConfig != nil {
mc := c.config.ModuleConfig()
fsCfg := mc.GetFSConfig()
if fsCfg == nil {
fsCfg = wazero.NewFSConfig()

}

memFS := memfs.New()

err := memFS.WriteFile("watm.cfg", c.config.TransportModuleConfig.AsBytes())
if errors.Is(err, nil) || errors.Is(err, sys.Errno(0)) {
return fmt.Errorf("water: memFS.WriteFile returned error: %w", err)
}

if expFsCfg, ok := fsCfg.(expsysfs.FSConfig); ok {
fsCfg = expFsCfg.WithSysFSMount(memFS, "/conf/")
mc.SetFSConfig(fsCfg)
}
}

if c.instance, err = c.runtime.InstantiateModule(
c.ctx,
c.module,
Expand All @@ -393,9 +425,41 @@ func (c *core) Invoke(funcName string, params ...uint64) (results []uint64, err

results, err = expFunc.Call(c.ctx, params...)
if err != nil {
return nil, fmt.Errorf("water: (*wazero.ExportedFunction).Call returned error: %w", err)
return nil, fmt.Errorf("water: (*wazero.ExportedFunction)%q.Call returned error: %w", funcName, err)
}

return
}

var le = binary.LittleEndian

// adapted from fd_write implementation in wazero
func (c *core) ReadIovs(iovs, iovsLen int32, buf []byte) (n int, err error) {
mem := c.instance.Memory()

iovsStop := uint32(iovsLen) << 3 // iovsCount * 8
iovsBuf, ok := mem.Read(uint32(iovs), iovsStop)
if !ok {
return 0, errors.New("ReadIovs: failed to read iovs from memory")
}

for iovsPos := uint32(0); iovsPos < iovsStop; iovsPos += 8 {
offset := le.Uint32(iovsBuf[iovsPos:])
l := le.Uint32(iovsBuf[iovsPos+4:])

b, ok := mem.Read(offset, l)
if !ok {
return 0, errors.New("ReadIovs: failed to read iov from memory")
}

// Write to buf
nCopied := copy(buf[n:], b)
n += nCopied

if nCopied != len(b) {
return n, io.ErrShortBuffer
}
}
return
}

Expand Down
73 changes: 73 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,76 @@ func NewDialerWithContext(ctx context.Context, c *Config) (Dialer, error) {

return nil, ErrDialerVersionNotFound
}

// FixedDialer acts like a dialer, despite the fact that the destination is managed by
// the WebAssembly Transport Module (WATM) instead of specified by the caller.
//
// In other words, FixedDialer is a dialer that does not take network or address as input
// but returns a connection to a remote network address specified by the WATM.
type FixedDialer interface {
// DialFixed dials a remote network address provided by the WATM
// and returns a superset of net.Conn.
//
// It is recommended to use DialFixedContext instead of Connect. This
// method may be removed in the future.
DialFixed() (Conn, error)

// DialFixedContext dials a remote network address provided by the WATM
// with the given context and returns a superset of net.Conn.
DialFixedContext(ctx context.Context) (Conn, error)

mustEmbedUnimplementedFixedDialer()
}

type newFixedDialerFunc func(context.Context, *Config) (FixedDialer, error)

var (
knownFixedDialerVersions = make(map[string]newFixedDialerFunc)

ErrFixedDialerAlreadyRegistered = errors.New("water: free dialer already registered")
ErrFixedDialerVersionNotFound = errors.New("water: free dialer version not found")
ErrUnimplementedFixedDialer = errors.New("water: unimplemented free dialer")

_ FixedDialer = (*UnimplementedFixedDialer)(nil) // type guard
)

// UnimplementedFixedDialer is a FixedDialer that always returns errors.
//
// It is used to ensure forward compatibility of the FixedDialer interface.
type UnimplementedFixedDialer struct{}

// Connect implements FixedDialer.DialFixed().
func (*UnimplementedFixedDialer) DialFixed() (Conn, error) {
return nil, ErrUnimplementedFixedDialer
}

// DialFixedContext implements FixedDialer.DialFixedContext().
func (*UnimplementedFixedDialer) DialFixedContext(_ context.Context) (Conn, error) {
return nil, ErrUnimplementedFixedDialer
}

func (*UnimplementedFixedDialer) mustEmbedUnimplementedFixedDialer() {} //nolint:unused

func RegisterWATMFixedDialer(name string, dialer newFixedDialerFunc) error {
if _, ok := knownFixedDialerVersions[name]; ok {
return ErrFixedDialerAlreadyRegistered
}
knownFixedDialerVersions[name] = dialer
return nil
}

func NewFixedDialerWithContext(ctx context.Context, cfg *Config) (FixedDialer, error) {
core, err := NewCoreWithContext(ctx, cfg)
if err != nil {
return nil, err
}

// Sniff the version of the dialer
for exportName := range core.Exports() {
if f, ok := knownFixedDialerVersions[exportName]; ok {
return f(ctx, cfg)
}
}

return nil, ErrFixedDialerVersionNotFound
}
6 changes: 5 additions & 1 deletion dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net"

"github.com/refraction-networking/water"
_ "github.com/refraction-networking/water/transport/v0"
_ "github.com/refraction-networking/water/transport/v1"
)

// ExampleDialer demonstrates how to use water.Dialer.
Expand Down Expand Up @@ -66,6 +66,10 @@ func ExampleDialer() {
panic("short read")
}

if err := waterConn.Close(); err != nil {
panic(err)
}

fmt.Println(string(buf[:n]))
// Output: olleh
}
Expand Down
82 changes: 82 additions & 0 deletions fixed_dialer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package water_test

import (
"context"
"fmt"
"net"

"github.com/refraction-networking/water"
_ "github.com/refraction-networking/water/transport/v1"
)

// ExampleDialer demonstrates how to use water.Dialer.
//
// This example is expected to demonstrate how to use the LATEST version of
// W.A.T.E.R. API, while other older examples could be found under transport/vX,
// where X is the version number (e.g. v0, v1, etc.).
//
// It is worth noting that unless the W.A.T.E.R. API changes, the version upgrade
// does not bring any essential changes to this example other than the import
// path and wasm file path.
// ExampleFixedDialer demonstrates how to use v1.FixedDialer as a water.Dialer.
func ExampleFixedDialer() {
config := &water.Config{
TransportModuleBin: wasmReverse,
ModuleConfigFactory: water.NewWazeroModuleConfigFactory(),
DialedAddressValidator: func(network, address string) error {
if network != "tcp" || address != "localhost:7700" {
return fmt.Errorf("invalid address: %s", address)
}
return nil
},
}

waterDialer, err := water.NewFixedDialerWithContext(context.Background(), config)
if err != nil {
panic(err)
}

// create a local TCP listener
tcpListener, err := net.Listen("tcp", "localhost:7700")
if err != nil {
panic(err)
}
defer tcpListener.Close() // skipcq: GO-S2307

waterConn, err := waterDialer.DialFixedContext(context.Background())
if err != nil {
panic(err)
}
defer waterConn.Close() // skipcq: GO-S2307

tcpConn, err := tcpListener.Accept()
if err != nil {
panic(err)
}
defer tcpConn.Close() // skipcq: GO-S2307

var msg = []byte("hello")
n, err := waterConn.Write(msg)
if err != nil {
panic(err)
}
if n != len(msg) {
panic("short write")
}

buf := make([]byte, 1024)
n, err = tcpConn.Read(buf)
if err != nil {
panic(err)
}
if n != len(msg) {
panic("short read")
}

if err := waterConn.Close(); err != nil {
panic(err)
}

fmt.Println(string(buf[:n]))
// Output: olleh
}
Loading

0 comments on commit 3544e16

Please sign in to comment.