diff --git a/pkg/wekafs/apiclient/utils.go b/pkg/wekafs/apiclient/utils.go index 56d43376b..9eede7da8 100644 --- a/pkg/wekafs/apiclient/utils.go +++ b/pkg/wekafs/apiclient/utils.go @@ -2,6 +2,7 @@ package apiclient import ( "encoding/binary" + "errors" "fmt" "github.com/rs/zerolog/log" "hash/fnv" @@ -9,6 +10,7 @@ import ( "os" "reflect" "strings" + "time" ) // ObjectsAreEqual returns true if both ApiObject have same immutable fields (other fields and nil fields are disregarded) @@ -142,3 +144,37 @@ func GetNodeIpAddress() string { } return "127.0.0.1" } + +func GetNodeIpAddressByRouting(targetHost string) (string, error) { + rAddr, err := net.ResolveUDPAddr("udp", targetHost+":80") + if err != nil { + return "", err + } + + // Create a UDP connection to the resolved IP address + conn, err := net.DialUDP("udp", nil, rAddr) + if err != nil { + return "", err + } + defer conn.Close() + + // Set a deadline for the connection + err = conn.SetDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + return "", err + } + + // Get the local address from the UDP connection + localAddr := conn.LocalAddr() + if localAddr == nil { + return "", errors.New("failed to get local address") + } + + // Extract the IP address from the local address + localIP, _, err := net.SplitHostPort(localAddr.String()) + if err != nil { + return "", err + } + + return localIP, nil +} diff --git a/pkg/wekafs/apiclient/utils_test.go b/pkg/wekafs/apiclient/utils_test.go index ede39534c..d15814dd9 100644 --- a/pkg/wekafs/apiclient/utils_test.go +++ b/pkg/wekafs/apiclient/utils_test.go @@ -1,6 +1,7 @@ package apiclient import ( + "github.com/rs/zerolog/log" "testing" "github.com/stretchr/testify/assert" @@ -27,3 +28,22 @@ func TestHashString(t *testing.T) { }) } } + +func TestGetNodeIpAddressByRouting(t *testing.T) { + testCases := []struct { + targetHost string + }{ + {"8.8.8.8"}, + {"1.1.1.1"}, + {"localhost"}, + } + + for _, tc := range testCases { + t.Run(tc.targetHost, func(t *testing.T) { + ip, err := GetNodeIpAddressByRouting(tc.targetHost) + assert.NoError(t, err) + assert.NotEmpty(t, ip) + log.Info().Str("ip", ip).Msg("Node IP address") + }) + } +} diff --git a/pkg/wekafs/nfsmount.go b/pkg/wekafs/nfsmount.go index edb7c5e94..f14a35c62 100644 --- a/pkg/wekafs/nfsmount.go +++ b/pkg/wekafs/nfsmount.go @@ -131,23 +131,30 @@ func (m *nfsMount) doMount(ctx context.Context, apiClient *apiclient.ApiClient, return errors.New("no API client for mount, cannot do NFS mount") } - nodeIP := apiclient.GetNodeIpAddress() - if apiClient.EnsureNfsPermissions(ctx, nodeIP, m.fsName, m.clientGroupName) != nil { - logger.Error().Msg("Failed to ensure NFS permissions") - return errors.New("failed to ensure NFS permissions") - } - if err := m.ensureMountIpAddress(ctx, apiClient); err != nil { logger.Error().Err(err).Msg("Failed to get mount IP address") return err } + nodeIP, err := apiclient.GetNodeIpAddressByRouting(m.mountIpAddress) + if err != nil { + logger.Error().Err(err).Msg("Failed to get routed node IP address, relying on node IP") + nodeIP = apiclient.GetNodeIpAddress() + } + + if apiClient.EnsureNfsPermissions(ctx, nodeIP, m.fsName, m.clientGroupName) != nil { + logger.Error().Msg("Failed to ensure NFS permissions") + return errors.New("failed to ensure NFS permissions") + } + mountTarget := m.mountIpAddress + ":/" + m.fsName logger.Trace(). Strs("mount_options", m.mountOptions.Strings()). Str("mount_target", mountTarget). + Str("mount_ip_address", m.mountIpAddress). Msg("Performing mount") - err := m.kMounter.MountSensitive(mountTarget, m.mountPoint, "nfs", mountOptions.Strings(), mountOptionsSensitive) + + err = m.kMounter.MountSensitive(mountTarget, m.mountPoint, "nfs", mountOptions.Strings(), mountOptionsSensitive) if err != nil { if os.IsNotExist(err) { logger.Error().Err(err).Msg("Mount target not found")