Skip to content

Commit

Permalink
config.go: pull unified Config.apply() out of createNewConfig() and u…
Browse files Browse the repository at this point in the history
…pdate()

as a bonus it ensures returned Config object doesn't have any configuration
values missing
  • Loading branch information
nazarewk committed Feb 15, 2024
1 parent c9a29eb commit a49d1d7
Showing 1 changed file with 114 additions and 103 deletions.
217 changes: 114 additions & 103 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"net/url"
"os"
"reflect"
"strings"

log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
Expand Down Expand Up @@ -35,15 +37,16 @@ var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun",

// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
AdminURL string
ConfigPath string
PreSharedKey *string
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
InterfaceName *string
WireguardPort *int
ManagementURL string
AdminURL string
ConfigPath string
PreSharedKey *string
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
InterfaceName *string
InterfaceBlacklist []string
WireguardPort *int
}

// Config Configuration type
Expand Down Expand Up @@ -88,10 +91,13 @@ func ReadConfig(configPath string) (*Config, error) {
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
}
if _, err := config.init(); err != nil {
return nil, err
}
return config, nil
}

cfg, err := createNewConfig(ConfigInput{ConfigPath: configPath})
cfg, err := createNewConfig(ConfigInput{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -139,155 +145,160 @@ func WriteOutConfig(path string, config *Config) error {

// createNewConfig creates a new config generating a new Wireguard key and saving to file
func createNewConfig(input ConfigInput) (*Config, error) {
wgKey := generateKey()
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}

config := &Config{
SSHKey: string(pem),
PrivateKey: wgKey,
IFaceBlackList: []string{},
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
CustomDNSAddress: string(input.CustomDNSAddress),
}
config := &Config{}

defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
if _, err := config.apply(input); err != nil {
return nil, err
}

config.ManagementURL = defaultManagementURL
if input.ManagementURL != "" {
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = URL
}

config.WgPort = iface.DefaultWgPort
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
}

config.WgIface = iface.WgInterfaceDefault
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
}
return config, nil
}

if input.PreSharedKey != nil {
config.PreSharedKey = *input.PreSharedKey
}
func update(input ConfigInput) (*Config, error) {
config := &Config{}

if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
return nil, err
}

defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
updated, err := config.apply(input)
if err != nil {
return nil, err
}

config.AdminURL = defaultAdminURL
if input.AdminURL != "" {
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
if updated {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err
}
config.AdminURL = newURL
}

config.IFaceBlackList = defaultInterfaceBlacklist
return config, nil
}

func update(input ConfigInput) (*Config, error) {
config := &Config{}
func (config *Config) init() (updated bool, err error) {
return config.apply(ConfigInput{})
}

if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
return nil, err
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}

refresh := false

if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
log.Infof("new Management URL provided, updated to %s (old value %s)",
input.ManagementURL, config.ManagementURL)
newURL, err := parseURL("Management URL", input.ManagementURL)
if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
log.Infof("new Management URL provided, updated to %#v (old value %#v)",
input.ManagementURL, config.ManagementURL.String())
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
return false, err
}
config.ManagementURL = URL
updated = true
} else if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
config.ManagementURL = newURL
refresh = true
}

if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
input.AdminURL, config.AdminURL)
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err
}
}
if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
input.AdminURL, config.AdminURL.String())
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
return updated, err
}
config.AdminURL = newURL
refresh = true
updated = true
}

if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
refresh = true
if config.PrivateKey == "" {
log.Infof("generated new Wireguard key")
config.PrivateKey = generateKey()
updated = true
}

if config.SSHKey == "" {
log.Infof("generated new SSH key")
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
return false, err
}
config.SSHKey = string(pem)
refresh = true
updated = true
}

if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
refresh = true
}

if input.WireguardPort != nil {
if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
log.Infof("updating Wireguard port %d (old value %d)",
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
refresh = true
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}

if input.InterfaceName != nil {
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
log.Infof("updating Wireguard interface %#v (old value %#v)",
*input.InterfaceName, config.WgIface)
config.WgIface = *input.InterfaceName
refresh = true
updated = true
} else if config.WgIface == "" {
config.WgIface = iface.WgInterfaceDefault
log.Infof("using default Wireguard interface %s", config.WgIface)
updated = true
}

if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
if input.NATExternalIPs != nil && reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
strings.Join(input.NATExternalIPs, " "),
strings.Join(config.NATExternalIPs, " "))
config.NATExternalIPs = input.NATExternalIPs
refresh = true
updated = true
}

if input.CustomDNSAddress != nil {
config.CustomDNSAddress = string(input.CustomDNSAddress)
refresh = true
if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
updated = true
}

if input.RosenpassEnabled != nil {
if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
config.RosenpassEnabled = *input.RosenpassEnabled
refresh = true
updated = true
}

if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err
}
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress)
config.CustomDNSAddress = string(input.CustomDNSAddress)
updated = true
}

return config, nil
if input.InterfaceBlacklist != nil && reflect.DeepEqual(input.InterfaceBlacklist, config.IFaceBlackList) {
log.Infof("updating interface blacklist [ %s ] (old value: [ %s ])",
strings.Join(input.InterfaceBlacklist, " "),
strings.Join(config.IFaceBlackList, " "))
config.IFaceBlackList = input.InterfaceBlacklist
updated = true
} else if config.IFaceBlackList == nil {
config.IFaceBlackList = defaultInterfaceBlacklist
updated = true
}
return updated, nil
}

// parseURL parses and validates a service URL
Expand Down

0 comments on commit a49d1d7

Please sign in to comment.