diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 2aaef756437..88db8c5e89f 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -42,4 +42,4 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 524f35f6f47..e1e1ff2362e 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -16,7 +16,7 @@ jobs: matrix: arch: [ '386','amd64' ] store: [ 'sqlite', 'postgres'] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7af6d3e4d94..b2e2437e6bb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ concurrency: jobs: release: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: flags: "" steps: diff --git a/README.md b/README.md index aa3ec41e533..270c9ad8707 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) +### NetBird on Lawrence Systems (Video) +[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) ### Key features @@ -62,6 +64,7 @@ | | | | | | | | |
  • - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)
  • | | | | | | | | | + ### Quickstart with NetBird Cloud - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 033d1bb6ab8..d998f9ea9e6 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -38,7 +38,7 @@ func startTestingServices(t *testing.T) string { signalAddr := signalLis.Addr().String() config.Signal.URI = signalAddr - _, mgmLis := startManagement(t, config, "../testdata/store.sqlite") + _, mgmLis := startManagement(t, config, "../testdata/store.sql") mgmAddr := mgmLis.Addr().String() return mgmAddr } @@ -71,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir()) + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c6a96a876cd..c271e592dce 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -21,13 +22,19 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type entry struct { + spec []string + position int +} + type aclManager struct { iptablesClient *iptables.IPTables wgIface iFaceMapper routingFwChainName string - entries map[string][][]string - ipsetStore *ipsetStore + entries map[string][][]string + optionalEntries map[string][]entry + ipsetStore *ipsetStore } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi wgIface: wgIface, routingFwChainName: routingFwChainName, - entries: make(map[string][][]string), - ipsetStore: newIpsetStore(), + entries: make(map[string][][]string), + optionalEntries: make(map[string][]entry), + ipsetStore: newIpsetStore(), } err := ipset.Init() @@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi } m.seedInitialEntries() + m.seedInitialOptionalEntries() err = m.cleanChains() if err != nil { @@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error { } } + ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + if ok { + for _, rule := range m.entries["PREROUTING"] { + err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error { } } + for chainName, entries := range m.optionalEntries { + for _, entry := range entries { + if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil { + log.Errorf("failed to insert optional entry %v: %v", entry.spec, err) + continue + } + m.entries[chainName] = append(m.entries[chainName], entry.spec) + } + } + clear(m.optionalEntries) + return nil } @@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } +func (m *aclManager) seedInitialOptionalEntries() { + m.optionalEntries["FORWARD"] = []entry{ + { + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + position: 2, + }, + } + + m.optionalEntries["PREROUTING"] = []entry{ + { + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + position: 1, + }, + } +} + func (m *aclManager) appendToEntries(chainName string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 6fefd58e67e..94bd2fccfe1 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering( } func (m *Manager) AddRouteFiltering( - sources [] netip.Prefix, + sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 737b207854b..e60c352d5c1 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error { log.Debug("flushing routing related tables") for _, chain := range []string{chainRTFWD, chainRTNAT} { - table := tableFilter - if chain == chainRTNAT { - table = tableNat - } + table := r.getTableForChain(chain) ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { @@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error { func (r *router) createContainers() error { for _, chain := range []string{chainRTFWD, chainRTNAT} { if err := r.createAndSetupChain(chain); err != nil { - return fmt.Errorf("create chain %s: %v", chain, err) + return fmt.Errorf("create chain %s: %w", chain, err) } } if err := r.insertEstablishedRule(chainRTFWD); err != nil { - return fmt.Errorf("insert established rule: %v", err) + return fmt.Errorf("insert established rule: %w", err) + } + + if err := r.addJumpRules(); err != nil { + return fmt.Errorf("add jump rules: %w", err) } - return r.addJumpRules() + return nil } func (r *router) createAndSetupChain(chain string) error { diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index a6185d3708e..556bda0d6b1 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { // GenerateSetName generates a unique name for an ipset based on the given sources. func GenerateSetName(sources []netip.Prefix) string { // sort for consistent naming - sortPrefixes(sources) + SortPrefixes(sources) var sourcesStr strings.Builder for _, src := range sources { @@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { return merged } -// sortPrefixes sorts the given slice of netip.Prefix in place. +// SortPrefixes sorts the given slice of netip.Prefix in place. // It sorts first by IP address, then by prefix length (most specific to least specific). -func sortPrefixes(prefixes []netip.Prefix) { +func SortPrefixes(prefixes []netip.Prefix) { sort.Slice(prefixes, func(i, j int) bool { addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) if addrCmp != 0 { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index eaf7fb6a023..61434f03518 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -11,12 +11,14 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -29,6 +31,7 @@ const ( chainNameInputFilter = "netbird-acl-input-filter" chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" + chainNamePrerouting = "netbird-rt-prerouting" allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) @@ -40,15 +43,14 @@ var ( ) type AclManager struct { - rConn *nftables.Conn - sConn *nftables.Conn - wgIface iFaceMapper - routeingFwChainName string + rConn *nftables.Conn + sConn *nftables.Conn + wgIface iFaceMapper + routingFwChainName string workTable *nftables.Table chainInputRules *nftables.Chain chainOutputRules *nftables.Chain - chainFwFilter *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -61,7 +63,7 @@ type iFaceMapper interface { IsUserspaceBind() bool } -func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { +func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) // and is permanent. Using same connection for both type of operations @@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa } m := &AclManager{ - rConn: &nftables.Conn{}, - sConn: sConn, - wgIface: wgIface, - workTable: table, - routeingFwChainName: routeingFwChainName, + rConn: &nftables.Conn{}, + sConn: sConn, + wgIface: wgIface, + workTable: table, + routingFwChainName: routingFwChainName, ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) { } // netbird-acl-forward-filter - m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to netbird-rt-fwd - m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) + m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd + m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) err = m.rConn.Flush() if err != nil { @@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) { return fmt.Errorf(flushError, err) } + if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { + log.Errorf("failed to allow redirected traffic: %s", err) + } + return nil } -func (m *AclManager) addJumpRulesToRtForward() { +// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) +// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the +// netbird peer IP. +func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { + preroutingChain := m.rConn.AddChain(&nftables.Chain{ + Name: chainNamePrerouting, + Table: m.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + }) + + m.addPreroutingRule(preroutingChain) + + m.addFwmarkToForward(chainFwFilter) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { + m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: preroutingChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Fib{ + Register: 1, + ResultADDRTYPE: true, + FlagDADDR: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + }, + }) +} + +func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { + m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: chainFwFilter, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.chainInputRules.Name, + }, + }, + }) +} + +func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) { expressions := []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() { }, &expr.Verdict{ Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, + Chain: m.routingFwChainName, }, } _ = m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, - Chain: m.chainFwFilter, + Chain: chainFwFilter, Exprs: expressions, }) } @@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain { return chain } -func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { +func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { polAccept := nftables.ChainPolicyAccept chain := &nftables.Chain{ Name: name, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aa61e18585f..9b8fdbda53d 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,7 @@ import ( "net/netip" "strings" + "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -24,7 +25,7 @@ import ( const ( chainNameRoutingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" + chainNameRoutingNat = "netbird-rt-postrouting" chainNameForward = "FORWARD" userDataAcceptForwardRuleIif = "frwacceptiif" @@ -149,7 +150,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { } func (r *router) createContainers() error { - r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, Table: r.workTable, @@ -157,25 +157,26 @@ func (r *router) createContainers() error { insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + prio := *nftables.ChainPriorityNATSource - 1 + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, Table: r.workTable, Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, + Priority: &prio, Type: nftables.ChainTypeNAT, }) r.acceptForwardRules() - err := r.refreshRulesMap() - if err != nil { + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.conn.Flush() - if err != nil { + if err := r.conn.Flush(); err != nil { return fmt.Errorf("nftables: unable to initialize table: %v", err) } + return nil } @@ -188,6 +189,7 @@ func (r *router) AddRouteFiltering( dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) if _, ok := r.rules[string(ruleKey)]; ok { return ruleKey, nil @@ -248,9 +250,18 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - r.rules[string(ruleKey)] = r.conn.AddRule(rule) + rule = r.conn.AddRule(rule) - return ruleKey, r.conn.Flush() + log.Tracef("Adding route rule %s", spew.Sdump(rule)) + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf(flushError, err) + } + + r.rules[string(ruleKey)] = rule + + log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + + return ruleKey, nil } func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { @@ -288,6 +299,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } + if nftRule.Handle == 0 { + return fmt.Errorf("route rule %s has no handle", ruleKey) + } + setName := r.findSetNameInRule(nftRule) if err := r.deleteNftRule(nftRule, ruleKey); err != nil { @@ -658,7 +673,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed rules for %s", pair.Destination) + log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index bbf92f3beaf..25b7587ac67 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -314,6 +314,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) { ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") + t.Cleanup(func() { + require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule") + }) + // Check if the rule is in the internal map rule, ok := r.rules[ruleKey.GetRuleID()] assert.True(t, ok, "Rule not found in internal map") @@ -346,10 +350,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) { // Verify actual nftables rule content verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) - - // Clean up - err = r.DeleteRouteRule(ruleKey) - require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index e27fce439fc..8ce73655d5f 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -1,8 +1,11 @@ package id import ( + "crypto/sha256" + "encoding/hex" "fmt" "net/netip" + "strconv" "github.com/netbirdio/netbird/client/firewall/manager" ) @@ -21,5 +24,41 @@ func GenerateRouteRuleKey( dPort *manager.Port, action manager.Action, ) RuleID { - return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) + manager.SortPrefixes(sources) + + h := sha256.New() + + // Write all fields to the hasher, with delimiters + h.Write([]byte("sources:")) + for _, src := range sources { + h.Write([]byte(src.String())) + h.Write([]byte(",")) + } + + h.Write([]byte("destination:")) + h.Write([]byte(destination.String())) + + h.Write([]byte("proto:")) + h.Write([]byte(proto)) + + h.Write([]byte("sPort:")) + if sPort != nil { + h.Write([]byte(sPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("dPort:")) + if dPort != nil { + h.Write([]byte(dPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("action:")) + h.Write([]byte(strconv.Itoa(int(action)))) + hash := hex.EncodeToString(h.Sum(nil)) + + // prepend destination prefix to be able to identify the rule + return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16])) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 3d1983c6bda..74b10ee44fa 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -832,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) { return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite") + mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql") if err != nil { t.Fatal(err) return @@ -1080,7 +1080,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0d4ad2396b3..1b740388d95 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -82,8 +82,6 @@ type Conn struct { config ConnConfig statusRecorder *Status wgProxyFactory *wgproxy.Factory - wgProxyICE wgproxy.Proxy - wgProxyRelay wgproxy.Proxy signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager @@ -106,7 +104,8 @@ type Conn struct { beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc - endpointRelay *net.UDPAddr + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy // for reconnection operations iCEDisconnected chan bool @@ -257,8 +256,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE.Set(StatusConnected) - - defer conn.updateIceState(iceConnInfo) - if conn.currentConnPriority > priority { + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) return } conn.log.Infof("set ICE to active connection") - endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) - if err != nil { - return + var ( + ep *net.UDPAddr + wgProxy wgproxy.Proxy + err error + ) + if iceConnInfo.RelayedOnLocal { + wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) + if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + return + } + ep = wgProxy.EndpointAddr() + conn.wgProxyICE = wgProxy + } else { + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + if err != nil { + log.Errorf("failed to resolveUDPaddr") + conn.handleConfigurationFailure(err, nil) + return + } + ep = directEp } - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } conn.workerRelay.DisableWgWatcher() - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { - if wgProxy != nil { - if err := wgProxy.CloseConn(); err != nil { - conn.log.Warnf("Failed to close turn connection: %v", err) - } - } - conn.log.Warnf("Failed to update wg peer configuration: %v", err) - return + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Pause() } - wgConfigWorkaround() - if conn.wgProxyICE != nil { - if err := conn.wgProxyICE.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } + if wgProxy != nil { + wgProxy.Work() } - conn.wgProxyICE = wgProxy + if err = conn.configureWGEndpoint(ep); err != nil { + conn.handleConfigurationFailure(err, wgProxy) + return + } + wgConfigWorkaround() conn.currentConnPriority = priority - + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } @@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Tracef("ICE connection state changed to %s", newState) + if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + // switch back to relay connection - if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + if conn.isReadyToUpgrade() { conn.log.Debugf("ICE disconnected, set Relay to active connection") - err := conn.configureWGEndpoint(conn.endpointRelay) - if err != nil { + conn.wgProxyRelay.Work() + + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - select { - case conn.iCEDisconnected <- changed: - default: - } + conn.notifyReconnectLoopICEDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.ctx.Err() != nil { if err := rci.relayedConn.Close(); err != nil { - log.Warnf("failed to close unnecessary relayed connection: %v", err) + conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) } return } - conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay.Set(StatusConnected) + conn.log.Debugf("Relay connection has been established, setup the WireGuard") - wgProxy := conn.wgProxyFactory.GetProxy() - endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + wgProxy, err := conn.newProxy(rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } - conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) - - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.endpointRelay = endpointUdpAddr - conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) - defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE.Get() == StatusConnected { - log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - return - } + if conn.iceP2PIsActive() { + conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + conn.wgProxyRelay = wgProxy + conn.statusRelay.Set(StatusConnected) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + return } - conn.connIDRelay = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { + wgProxy.Work() + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } - conn.log.Errorf("Failed to update wg peer configuration: %v", err) + conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } conn.workerRelay.EnableWgWatcher(conn.ctx) - wgConfigWorkaround() - if conn.wgProxyRelay != nil { - if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyRelay = wgProxy + wgConfigWorkaround() conn.currentConnPriority = connPriorityRelay - + conn.statusRelay.Set(StatusConnected) + conn.wgProxyRelay = wgProxy + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } @@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - log.Debugf("relay connection is disconnected") + conn.log.Debugf("relay connection is disconnected") if conn.currentConnPriority == connPriorityRelay { - log.Debugf("clean up WireGuard config") - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + conn.log.Debugf("clean up WireGuard config") + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } } if conn.wgProxyRelay != nil { - conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil } changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - - select { - case conn.relayDisconnected <- changed: - default: - } + conn.notifyReconnectLoopRelayDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) } } @@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, ip); err != nil { + return err + } + } + return nil +} + func (conn *Conn) freeUpConnID() { if conn.connIDRelay != "" { for _, hook := range conn.afterRemovePeerHooks { @@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() { } } -func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { - if !iceConnInfo.RelayedOnLocal { - return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil - } - conn.log.Debugf("setup ice turn connection") +func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { + conn.log.Debugf("setup proxied WireGuard connection") wgProxy := conn.wgProxyFactory.GetProxy() - ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) - if err != nil { + if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) - if errClose := wgProxy.CloseConn(); errClose != nil { - conn.log.Warnf("failed to close turn proxy connection: %v", errClose) + return nil, err + } + return wgProxy, nil +} + +func (conn *Conn) isReadyToUpgrade() bool { + return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay +} + +func (conn *Conn) iceP2PIsActive() bool { + return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +} + +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + +func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { + select { + case conn.relayDisconnected <- changed: + default: + } +} + +func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { + select { + case conn.iCEDisconnected <- changed: + default: + } +} + +func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if wgProxy != nil { + if ierr := wgProxy.CloseConn(); ierr != nil { + conn.log.Warnf("Failed to close wg proxy: %v", ierr) } - return nil, nil, err } - return ep, wgProxy, nil + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Work() + } } func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 27ede3ef1d0..e850f4533ce 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { @@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return packetConn, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { localhost := net.ParseIP("127.0.0.1") payload := gopacket.Payload(data) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840cc..b6a8ac45228 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -4,8 +4,13 @@ package ebpf import ( "context" + "errors" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,20 +18,55 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread -} + ctx context.Context + cancel context.CancelFunc -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) + wgEndpointAddr *net.UDPAddr + + pausedMu sync.Mutex + paused bool + isStarted bool +} +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() - return nil, fmt.Errorf("add turn conn: %w", err) + return fmt.Errorf("add turn conn: %w", err) + } + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointAddr = addr + return err +} + +func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { + return p.wgEndpointAddr +} + +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) } - e.remoteConn = remoteConn - e.cancel = cancel - return addr, err +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map @@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + + buf := make([]byte, 1500) + for { + n, err := p.readFromRemote(ctx, buf) + if err != nil { + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} + +func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !errors.Is(err, io.EOF) { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + } + return 0, err + } + return n, nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd103..558121cdd5a 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -7,6 +7,9 @@ import ( // Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) error + EndpointAddr() *net.UDPAddr + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go index b09e6be555f..b88ff3f83c1 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/internal/wgproxy/proxy_test.go @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d899..f73500717a9 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err + return err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn - return p.localConn.LocalAddr(), err + return err +} + +func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { + if p.localConn == nil { + return nil + } + endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) + return endpointUdpAddr +} + +// Work starts the proxy or resumes it if it was paused +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +// Pause pauses the proxy from receiving data from the remote peer +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +// if the proxy is paused it will drain the remote conn and drop the packets +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for { n, err := p.remoteConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to write to wg interface conn: %s", err) diff --git a/client/server/server_test.go b/client/server/server_test.go index e534ad7e2d6..61bdaf660d2 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } diff --git a/client/testdata/store.sql b/client/testdata/store.sql new file mode 100644 index 00000000000..ed539548613 --- /dev/null +++ b/client/testdata/store.sql @@ -0,0 +1,36 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); + +COMMIT; diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite deleted file mode 100644 index 118c2bebc9f..00000000000 Binary files a/client/testdata/store.sqlite and /dev/null differ diff --git a/go.mod b/go.mod index e7137ce5bf5..e7e3c17a68a 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.24.0 - golang.org/x/sys v0.21.0 + golang.org/x/crypto v0.28.0 + golang.org/x/sys v0.26.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -38,6 +38,7 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 + github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/v3 v3.1.1 github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.4 @@ -45,7 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/google/nftables v0.2.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -55,12 +56,12 @@ require ( github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 - github.com/mdlayher/socket v0.4.1 + github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible @@ -89,10 +90,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.26.0 + golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.7.0 - golang.org/x/term v0.21.0 + golang.org/x/sync v0.8.0 + golang.org/x/term v0.25.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -133,7 +134,6 @@ require ( github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect @@ -219,7 +219,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.19.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 4563dc9335f..e9bc318d6fd 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= +github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= +github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= -github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811- github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= @@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -871,8 +871,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/management/client/client_test.go b/management/client/client_test.go index 313a67617db..100b3fcaa12 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -4,7 +4,6 @@ import ( "context" "net" "os" - "path/filepath" "sync" "testing" "time" @@ -58,7 +57,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite") + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -514,22 +513,3 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") } - -func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) { - t.Helper() - dataDir := t.TempDir() - err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) - if err != nil { - t.Fatal(err) - } - - store, err := mgmt.NewSqliteStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err - } - - return store, func() { - store.Close(ctx) - os.Remove(filepath.Join(dataDir, "store.db")) - }, nil -} diff --git a/management/cmd/management.go b/management/cmd/management.go index 78b1a8d631f..719d1a78c1a 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} - _, err := util.ReadJson(mgmtConfigPath, loadedConfig) + _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) if err != nil { return nil, err } diff --git a/management/server/account.go b/management/server/account.go index 74981280448..7c84ad1ca1b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "errors" "fmt" "hash/crc32" "math/rand" @@ -50,8 +51,9 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour - DefaultPeerInactivityExpiration = 10 * time.Minute + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -1416,7 +1418,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { return "", err } return account.Id, nil @@ -1431,28 +1433,39 @@ func isNil(i idp.Manager) bool { } // addAccountIDToIDPAppMeta update user's app metadata in idp manager -func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { + accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + if err != nil { + return err + } + cachedAccount := &Account{ + Id: accountID, + Users: make(map[string]*User), + } + for _, user := range accountUsers { + cachedAccount.Users[user.Id] = user + } // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, account) + user, err := am.lookupUserInCache(ctx, userID, cachedAccount) if err != nil { return err } - if user != nil && user.AppMetadata.WTAccountID == account.Id { + if user != nil && user.AppMetadata.WTAccountID == accountID { // it was already set, so we skip the unnecessary update log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", - account.Id, userID) + accountID, userID) return nil } - err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) + err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update - _, err = am.refreshCache(ctx, account.Id) + _, err = am.refreshCache(ctx, accountID) if err != nil { return err } @@ -1676,48 +1689,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) } -// updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, +// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { + if claims.Domain == "" { + log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + return nil + } - if claims.Domain != "" { - account.IsDomainPrimaryAccount = primaryDomain + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlockAccount() - lowerDomain := strings.ToLower(claims.Domain) - userObj := account.Users[claims.UserId] - if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { - account.Domain = lowerDomain - } - // prevent updating category for different domain until admin logs in - if account.Domain == lowerDomain { - account.DomainCategory = claims.DomainCategory - } - } else { - log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return err + } + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return nil } - err := am.Store.SaveAccount(ctx, account) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { + log.WithContext(ctx).Errorf("error getting user: %v", err) return err } - return nil + + newDomain := accountDomain + newCategoty := domainCategory + + lowerDomain := strings.ToLower(claims.Domain) + if accountDomain != lowerDomain && user.HasAdminPower() { + newDomain = lowerDomain + } + + if accountDomain == lowerDomain { + newCategoty = claims.DomainCategory + } + + return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. +// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, +// we compare the account's ID with the domain account ID, and if they don't match, we set the account as +// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain +// was previously unclassified or classified as public so N users that logged int that time, has they own account +// and peers that shouldn't be lost. func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, - existingAcc *Account, - primaryDomain bool, + userAccountID string, + domainAccountID string, claims jwtclaims.AuthorizationClaims, ) error { - err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) + primaryDomain := domainAccountID == "" || userAccountID == domainAccountID + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) if err != nil { return err } @@ -1725,44 +1759,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount( return nil } -// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, +// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } - var ( - account *Account - err error - ) + lowerDomain := strings.ToLower(claims.Domain) - // if domain already has a primary account, add regular user - if domainAcc != nil { - account = domainAcc - account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } else { - account, err = am.newAccount(ctx, claims.UserId, lowerDomain) - if err != nil { - return nil, err - } - err = am.updateAccountDomainAttributes(ctx, account, claims, true) - if err != nil { - return nil, err - } + + newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) + if err != nil { + return "", err + } + + newAccount.Domain = lowerDomain + newAccount.DomainCategory = claims.DomainCategory + newAccount.IsDomainPrimaryAccount = true + + err = am.Store.SaveAccount(ctx, newAccount) + if err != nil { + return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) if err != nil { - return nil, err + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + + return newAccount.Id, nil +} + +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + + usersMap := make(map[string]*User) + usersMap[claims.UserId] = NewRegularUser(claims.UserId) + err := am.Store.SaveUsers(domainAccountID, usersMap) + if err != nil { + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + if err != nil { + return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) - return account, nil + return domainAccountID, nil } // redeemInvite checks whether user has been invited and redeems the invite @@ -1906,7 +1954,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s // GetAccountIDFromToken returns an account ID associated with this token. func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return "", "", fmt.Errorf("user ID is empty") + return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -2092,16 +2140,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. +// if domain is not private or domain is invalid, it will return the account ID by user ID. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // // Use cases: // -// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) +// New user + New account + New domain -> create account, user role = owner (if private domain, index domain) // -// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) +// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin) // -// New user + New account + Existing Public Domain -> create account, user role = admin +// New user + New account + Existing Public Domain -> create account, user role = owner // // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // @@ -2111,98 +2160,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + if claims.UserId == "" { - return "", fmt.Errorf("user ID is empty") + return "", errors.New(emptyUserID) } - // if Account ID is part of the claims - // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - if claims.AccountId != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) - } - return claims.AccountId, nil - } return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) - } else if claims.AccountId != "" { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err != nil { - return "", err - } + } - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) - } + if claims.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + } - domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { + // We checked if the domain has a primary account already + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + if cancel != nil { + defer cancel() + } + if err != nil { + return "", err + } + + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != "" { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { return "", err } - if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { - return userAccountID, nil - } + return userAccountID, nil } - start := time.Now() - unlock := am.Store.AcquireGlobalLock(ctx) - defer unlock() - log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + if domainAccountID != "" { + return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + } + + return am.addNewPrivateAccount(ctx, domainAccountID, claims) +} +func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + if domainAccountID != "" { + return domainAccountID, nil, nil + } + + log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain) + cancel := am.Store.AcquireGlobalLock(ctx) + + // check again if the domain has a primary account because of simultaneous requests + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + return domainAccountID, cancel, nil +} + +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + } + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return "", err + } + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return claims.AccountId, nil + } // We checked if the domain has a primary account already domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", err + } + + err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) if err != nil { - // if NotFound we are good to continue, otherwise return error - e, ok := status.FromError(err) - if !ok || e.Type() != status.NotFound { - return "", err - } + return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) - defer unlockAccount() - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) - if err != nil { - return "", err - } - // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, - // we compare the account's ID with the domain account ID, and if they don't match, we set the account as - // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain - // was previously unclassified or classified as public so N users that logged int that time, has they own account - // and peers that shouldn't be lost. - primaryDomain := domainAccountID == "" || account.Id == domainAccountID - if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { - return "", err - } + return claims.AccountId, nil +} - return account.Id, nil - } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - var domainAccount *Account - if domainAccountID != "" { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return "", err - } - } +func handleNotFound(err error) error { + if err == nil { + return nil + } - account, err := am.handleNewUserAccount(ctx, domainAccount, claims) - if err != nil { - return "", err - } - return account.Id, nil - } else { - // other error - return "", err + e, ok := status.FromError(err) + if !ok || e.Type() != status.NotFound { + return err } + return nil +} + +func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { + return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 4951641dd9c..19514dad181 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -465,22 +465,6 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims - type test struct { - name string - inputClaims jwtclaims.AuthorizationClaims - inputInitUserParams initUserParams - inputUpdateAttrs bool - inputUpdateClaimAccount bool - testingFunc require.ComparisonAssertionFunc - expectedMSG string - expectedUserRole UserRole - expectedDomainCategory string - expectedDomain string - expectedPrimaryDomainStatus bool - expectedCreatedBy string - expectedUsers []string - } - var ( publicDomain = "public.com" privateDomain = "private.com" @@ -492,143 +476,153 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { UserId: "defaultUser", } - testCase1 := test{ - name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: publicDomain, - UserId: "pub-domain-user", - DomainCategory: PublicCategory, - }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomainCategory: "", - expectedDomain: publicDomain, - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pub-domain-user", - expectedUsers: []string{"pub-domain-user"}, - } - initUnknown := defaultInitAccount initUnknown.DomainCategory = UnknownCategory initUnknown.Domain = unknownDomain - testCase2 := test{ - name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: unknownDomain, - UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, - }, - inputInitUserParams: initUnknown, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: unknownDomain, - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "unknown-domain-user", - expectedUsers: []string{"unknown-domain-user"}, - } - - testCase3 := test{ - name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, - }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - privateInitAccount := defaultInitAccount privateInitAccount.Domain = privateDomain privateInitAccount.DomainCategory = PrivateCategory - testCase4 := test{ - name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + testCases := []struct { + name string + inputClaims jwtclaims.AuthorizationClaims + inputInitUserParams initUserParams + inputUpdateAttrs bool + inputUpdateClaimAccount bool + testingFunc require.ComparisonAssertionFunc + expectedMSG string + expectedUserRole UserRole + expectedDomainCategory string + expectedDomain string + expectedPrimaryDomainStatus bool + expectedCreatedBy string + expectedUsers []string + }{ + { + name: "New User With Public Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: publicDomain, + UserId: "pub-domain-user", + DomainCategory: PublicCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomainCategory: "", + expectedDomain: publicDomain, + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pub-domain-user", + expectedUsers: []string{"pub-domain-user"}, + }, + { + name: "New User With Unknown Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: unknownDomain, + UserId: "unknown-domain-user", + DomainCategory: UnknownCategory, + }, + inputInitUserParams: initUnknown, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: unknownDomain, + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "unknown-domain-user", + expectedUsers: []string{"unknown-domain-user"}, + }, + { + name: "New User With Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputUpdateAttrs: true, - inputInitUserParams: privateInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, - } - - testCase5 := test{ - name: "Existing User With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "New Regular User With Existing Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "new-pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputUpdateAttrs: true, + inputInitUserParams: privateInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleUser, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, + }, + { + name: "Existing User With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase6 := test{ - name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing Account Id With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputUpdateClaimAccount: true, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputUpdateClaimAccount: true, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase7 := test{ - name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: "", - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "User With Private Category And Empty Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: "", + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: "", + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: "", - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} { + } + + for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -2738,7 +2732,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 23941495e8b..c7f435b688d 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -210,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index f8ab46d8176..dc8765e197f 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -88,7 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -413,7 +413,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -471,6 +471,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie } func Test_SyncStatusRace(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { t.Skip("Skipping on CI and Postgres store") } @@ -482,9 +483,10 @@ func Test_SyncStatusRace(t *testing.T) { } func testSyncStatusRace(t *testing.T) { t.Helper() + t.Skip() dir := t.TempDir() - mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -627,6 +629,7 @@ func testSyncStatusRace(t *testing.T) { } func Test_LoginPerformance(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" || runtime.GOOS == "windows" { t.Skip("Skipping test on CI or Windows") } @@ -655,7 +658,7 @@ func Test_LoginPerformance(t *testing.T) { t.Helper() dir := t.TempDir() - mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", diff --git a/management/server/management_test.go b/management/server/management_test.go index ba27dc5e885..d53c177d6b8 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -58,7 +58,7 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) config.Datadir = dataDir - s, listener = startServer(config, dataDir, "testdata/store.sqlite") + s, listener = startServer(config, dataDir, "testdata/store.sql") addr = listener.Addr().String() client, conn = createRawClient(addr) @@ -532,7 +532,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir) + store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 7dbd4420c10..8a3fe6eb049 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 49f3b59f40e..c5edb5636ad 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1066,7 +1066,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1131,7 +1131,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1197,7 +1197,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1250,6 +1250,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed) + assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) } diff --git a/management/server/route_test.go b/management/server/route_test.go index fbe0221020a..09cbe53ff53 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1257,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index d056015d823..de3dfa9455e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "strconv" "strings" "sync" "time" @@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr if err != nil { return nil, err } - conns := runtime.NumCPU() - sql.SetMaxOpenConns(conns) // TODO: make it configurable + + conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS")) + if err != nil { + conns = runtime.NumCPU() + } + sql.SetMaxOpenConns(conns) + + log.Infof("Set max open db connections to %d", conns) if err := migrate(ctx, db); err != nil { return nil, fmt.Errorf("migrate: %w", err) @@ -316,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. return nil } +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + accountCopy := Account{ + Domain: domain, + DomainCategory: category, + IsDomainPrimaryAccount: isPrimaryDomain, + } + + fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} + result := s.db.WithContext(ctx).Model(&Account{}). + Select(fieldsToUpdate). + Where(idQueryCondition, accountID). + Updates(&accountCopy) + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account %s", accountID) + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -431,7 +461,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -444,7 +474,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } if key.AccountID == "" { @@ -462,7 +492,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return token.ID, nil @@ -476,7 +506,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if token.UserID == "" { @@ -511,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } +func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + var users []*User + result := s.db.Find(&users, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting users from store") + } + + return users, nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) @@ -560,7 +604,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us @@ -623,7 +667,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if user.AccountID == "" { @@ -640,7 +684,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -658,7 +702,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -676,7 +720,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -689,7 +733,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -702,7 +746,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.NewSetupKeyNotFoundError() + return "", status.NewSetupKeyNotFoundError(result.Error) } if accountID == "" { @@ -723,7 +767,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } - return nil, status.Errorf(status.Internal, "issue getting IPs from store") + return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } // Convert the JSON strings to net.IP objects @@ -751,7 +795,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return nil, status.Errorf(status.NotFound, "no peers found for the account") } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } return labels, nil @@ -764,7 +808,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting network from store") + return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } @@ -776,7 +820,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - return nil, status.Errorf(status.Internal, "issue getting peer from store") + return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } return &peer, nil @@ -788,7 +832,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - return nil, status.Errorf(status.Internal, "issue getting settings from store") + return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil } @@ -904,28 +948,6 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data return store, nil } -// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB. -func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics) - if err != nil { - return nil, err - } - - err = store.SaveInstallationID(ctx, fileStore.InstallationID) - if err != nil { - return nil, err - } - - for _, account := range fileStore.GetAllAccounts(ctx) { - err := store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } - - return store, nil -} - // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewPostgresqlStore(ctx, dsn, metrics) @@ -956,7 +978,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } @@ -988,7 +1010,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } - return status.Errorf(status.Internal, "issue finding group 'All'") + return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1000,7 +1022,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group 'All'") + return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil @@ -1014,7 +1036,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found for account") } - return status.Errorf(status.Internal, "issue finding group") + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1026,7 +1048,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group") + return status.Errorf(status.Internal, "issue updating group: %s", err) } return nil @@ -1039,7 +1061,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { - return status.Errorf(status.Internal, "issue adding peer to account") + return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } return nil @@ -1048,7 +1070,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count") + return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 4eed09c69b6..20e812ea709 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -11,14 +11,13 @@ import ( "testing" "time" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/testutil" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" @@ -31,7 +30,10 @@ func TestSqlite_NewStore(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -39,15 +41,23 @@ func TestSqlite_NewStore(t *testing.T) { } func TestSqlite_SaveAccount_Large(t *testing.T) { - if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" { - t.Skip("skip large test on non-linux OS due to environment restrictions") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } + t.Run("SQLite", func(t *testing.T) { - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) runLargeTest(t, store) }) + // create store outside to have a better time counter for the test - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) t.Run("PostgreSQL", func(t *testing.T) { runLargeTest(t, store) }) @@ -199,7 +209,10 @@ func TestSqlite_SaveAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -213,7 +226,7 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -271,7 +284,10 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -293,7 +309,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -324,7 +340,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -332,7 +348,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -345,11 +361,10 @@ func TestSqlite_GetAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - if err != nil { - t.Fatal(err) - } - defer cleanup() + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -369,11 +384,10 @@ func TestSqlite_SavePeer(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - if err != nil { - t.Fatal(err) - } - defer cleanup() + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -421,11 +435,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -478,11 +491,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -532,11 +545,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + existingDomain := "test.com" account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) @@ -555,11 +568,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -579,11 +592,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" user, err := store.GetUserByTokenID(context.Background(), id) @@ -598,13 +611,18 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { } func TestMigrate(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newSqliteStore(t) + // TODO: figure out why this fails on postgres + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - err := migrate(context.Background(), store.db) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -640,7 +658,7 @@ func TestMigrate(t *testing.T) { }, } - err = store.db.Save(act).Error + err = store.(*SqlStore).db.Save(act).Error require.NoError(t, err, "Failed to insert Gob data") type route struct { @@ -656,16 +674,16 @@ func TestMigrate(t *testing.T) { Route: route2.Route{ID: "route1"}, } - err = store.db.Save(rt).Error + err = store.(*SqlStore).db.Save(rt).Error require.NoError(t, err, "Failed to insert Gob data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") - err = store.db.Delete(rt).Where("id = ?", "route1").Error + err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error require.NoError(t, err, "Failed to delete Gob data") prefix = netip.MustParsePrefix("12.0.0.0/24") @@ -675,13 +693,13 @@ func TestMigrate(t *testing.T) { Peer: "peer-id", } - err = store.db.Save(nRT).Error + err = store.(*SqlStore).db.Save(nRT).Error require.NoError(t, err, "Failed to insert json nil slice data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on json nil slice populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") } @@ -716,63 +734,15 @@ func newAccount(store Store, id int) error { return store.SaveAccount(context.Background(), account) } -func newPostgresqlStore(t *testing.T) *SqlStore { - t.Helper() - - cleanUp, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil) - if err != nil { - t.Fatalf("could not initialize postgresql store: %s", err) - } - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore { - t.Helper() - - store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename) - t.Cleanup(cleanUpQ) - if err != nil { - return nil - } - - cleanUpP, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUpP) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil) - require.NoError(t, err) - require.NotNil(t, store) - - return pstore -} - func TestPostgresql_NewStore(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -780,11 +750,14 @@ func TestPostgresql_NewStore(t *testing.T) { } func TestPostgresql_SaveAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -798,7 +771,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -852,11 +825,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { } func TestPostgresql_DeleteAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -878,7 +854,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -909,7 +885,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -917,7 +893,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -926,11 +902,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } func TestPostgresql_SavePeerStatus(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -965,11 +944,14 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { } func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) existingDomain := "test.com" @@ -982,11 +964,14 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { } func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -997,11 +982,14 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } func TestPostgresql_GetUserByTokenID(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -1011,11 +999,8 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { } func TestSqlite_GetTakenIPs(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { t.Fatal(err) @@ -1059,11 +1044,8 @@ func TestSqlite_GetTakenIPs(t *testing.T) { } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { return } @@ -1104,11 +1086,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { } func TestSqlite_GetAccountNetwork(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1130,10 +1109,8 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { } func TestSqlite_GetSetupKeyBySecret(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1152,11 +1129,8 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1187,11 +1161,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { } func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) } + group := &nbgroup.Group{ ID: "group-id", AccountID: "account-id", @@ -1215,3 +1191,63 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { }) assert.NoError(t, err) } + +func TestSqlite_GetAccoundUsers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + users, err := store.GetAccountUsers(context.Background(), accountID) + require.NoError(t, err) + require.Len(t, users, len(account.Users)) +} + +func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + t.Run("Should update attributes with public domain", func(t *testing.T) { + require.NoError(t, err) + domain := "example.com" + category := "public" + IsDomainPrimaryAccount := false + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should update attributes with private domain", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should fail when account does not exist", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + require.Error(t, err) + }) + +} diff --git a/management/server/status/error.go b/management/server/status/error.go index d7fde35b998..29d185216d8 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError() error { - return Errorf(NotFound, "setup key not found") +func NewSetupKeyNotFoundError(err error) error { + return Errorf(NotFound, "setup key not found: %s", err) +} + +func NewGetAccountFromStoreError(err error) error { + return Errorf(Internal, "issue getting account from store: %s", err) } // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store diff --git a/management/server/store.go b/management/server/store.go index 50bc6afdfd2..131fd8aaab6 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -9,10 +9,12 @@ import ( "os" "path" "path/filepath" + "runtime" "strings" "time" log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" "gorm.io/gorm" "github.com/netbirdio/netbird/dns" @@ -56,9 +58,11 @@ type Store interface { GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error + UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) + GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -240,28 +244,39 @@ func getMigrations(ctx context.Context) []migrationFunc { } } -// NewTestStoreFromSqlite is only used in tests -func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) { - // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE +// NewTestStoreFromSQL is only used in tests. It will create a test database base of the store engine set in env. +// Optionally it can load a SQL file to the database. If the filename is empty it will return an empty database +func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) { kind := getStoreEngineFromEnv() if kind == "" { kind = SqliteStoreEngine } - var store *SqlStore - var err error - var cleanUp func() + storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) + if runtime.GOOS == "windows" { + // Vo avoid `The process cannot access the file because it is being used by another process` on Windows + storeStr = storeSqliteFileName + } + + file := filepath.Join(dataDir, storeStr) + db, err := gorm.Open(sqlite.Open(file), getGormConfig()) + if err != nil { + return nil, nil, err + } - if filename == "" { - store, err = NewSqliteStore(ctx, dataDir, nil) - cleanUp = func() { - store.Close(ctx) + if filename != "" { + err = loadSQL(db, filename) + if err != nil { + return nil, nil, fmt.Errorf("failed to load SQL file: %v", err) } - } else { - store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename) } + + store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create test store: %v", err) + } + cleanUp := func() { + store.Close(ctx) } if kind == PostgresStoreEngine { @@ -284,21 +299,25 @@ func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string return store, cleanUp, nil } -func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) { - err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) +func loadSQL(db *gorm.DB, filepath string) error { + sqlContent, err := os.ReadFile(filepath) if err != nil { - return nil, nil, err + return err } - store, err := NewSqliteStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err + queries := strings.Split(string(sqlContent), ";") + + for _, query := range queries { + query = strings.TrimSpace(query) + if query != "" { + err := db.Exec(query).Error + if err != nil { + return err + } + } } - return store, func() { - store.Close(ctx) - os.Remove(filepath.Join(dataDir, "store.db")) - }, nil + return nil } // MigrateFileStoreToSqlite migrates the file store to the SQLite store. diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql new file mode 100644 index 00000000000..b522741e7e0 --- /dev/null +++ b/management/server/testdata/extended-store.sql @@ -0,0 +1,37 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite deleted file mode 100644 index 81aea8118cc..00000000000 Binary files a/management/server/testdata/extended-store.sqlite and /dev/null differ diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql new file mode 100644 index 00000000000..32a59128bf1 --- /dev/null +++ b/management/server/testdata/store.sql @@ -0,0 +1,33 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store.sqlite b/management/server/testdata/store.sqlite deleted file mode 100644 index 5fc746285f0..00000000000 Binary files a/management/server/testdata/store.sqlite and /dev/null differ diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql new file mode 100644 index 00000000000..a9360e9d65c --- /dev/null +++ b/management/server/testdata/store_policy_migrate.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite deleted file mode 100644 index 0c1a491a68d..00000000000 Binary files a/management/server/testdata/store_policy_migrate.sqlite and /dev/null differ diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql new file mode 100644 index 00000000000..100a6470f43 --- /dev/null +++ b/management/server/testdata/store_with_expired_peers.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_with_expired_peers.sqlite b/management/server/testdata/store_with_expired_peers.sqlite deleted file mode 100644 index ed1133211d2..00000000000 Binary files a/management/server/testdata/store_with_expired_peers.sqlite and /dev/null differ diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql new file mode 100644 index 00000000000..69194d62391 --- /dev/null +++ b/management/server/testdata/storev1.sql @@ -0,0 +1,39 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('auth0|61bf82ddeab084006aa1bccd','','2024-10-02 17:00:54.181873+02:00','','',0,'a443c07a-5765-4a78-97fc-390d9c1d0e49','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('google-oauth2|103201118415301331038','','2024-10-02 17:00:54.225803+02:00','','',0,'b6d0b152-364e-40c1-a8a1-fa7bcac2267f','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','One-off key','one-off','2021-12-24 16:09:45.926075752+01:00','2022-01-23 16:09:45.926075752+01:00','2021-12-24 16:09:45.926075752+01:00',0,1,'2021-12-24 16:12:45.763424077+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0); +INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO installations VALUES(1,''); + diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite deleted file mode 100644 index 9a376698e4d..00000000000 Binary files a/management/server/testdata/storev1.sqlite and /dev/null differ diff --git a/management/server/user.go b/management/server/user.go index 38a8ac0c401..71608ef20e1 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -19,10 +19,11 @@ import ( ) const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -41,6 +42,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleAdmin case "user": return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin default: return UserRoleUnknown } diff --git a/signal/server/signal.go b/signal/server/signal.go index 63cc43bd7ef..305fd052b2e 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -6,6 +6,7 @@ import ( "io" "time" + "github.com/netbirdio/signal-dispatcher/dispatcher" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -13,8 +14,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/netbirdio/signal-dispatcher/dispatcher" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/proto" diff --git a/util/file.go b/util/file.go index 8355488c98a..ecaecd22260 100644 --- a/util/file.go +++ b/util/file.go @@ -1,11 +1,15 @@ package util import ( + "bytes" "context" "encoding/json" + "fmt" "io" "os" "path/filepath" + "strings" + "text/template" log "github.com/sirupsen/logrus" ) @@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution +func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { + envVars := getEnvMap() + + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + bs, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + t, err := template.New("").Parse(string(bs)) + if err != nil { + return nil, fmt.Errorf("error parsing template: %v", err) + } + + var output bytes.Buffer + // Execute the template, substituting environment variables + err = t.Execute(&output, envVars) + if err != nil { + return nil, fmt.Errorf("error executing template: %v", err) + } + + err = json.Unmarshal(output.Bytes(), &res) + if err != nil { + return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err) + } + + return res, nil +} + +// getEnvMap Convert the output of os.Environ() to a map +func getEnvMap() map[string]string { + envMap := make(map[string]string) + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + return envMap +} + // CopyFileContents copies contents of the given src file to the dst file func CopyFileContents(src, dst string) (err error) { in, err := os.Open(src) diff --git a/util/file_suite_test.go b/util/file_suite_test.go new file mode 100644 index 00000000000..3de7db49bdd --- /dev/null +++ b/util/file_suite_test.go @@ -0,0 +1,126 @@ +package util_test + +import ( + "crypto/md5" + "encoding/hex" + "io" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/netbirdio/netbird/util" +) + +var _ = Describe("Client", func() { + + var ( + tmpDir string + ) + + type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int + } + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + Describe("Config", func() { + Context("in JSON format", func() { + It("should be written and read successfully", func() { + + m := make(map[string]string) + m["key1"] = "value1" + m["key2"] = "value2" + + arr := []string{"value1", "value2"} + + written := &TestConfig{ + SomeMap: m, + SomeArray: arr, + SomeField: 99, + } + + err := util.WriteJson(tmpDir+"/testconfig.json", written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) + Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) + Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) + Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + + }) + }) + }) + + Describe("Copying file contents", func() { + Context("from one file to another", func() { + It("should be successful", func() { + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := util.WriteJson(src, []string{"1", "2", "3"}) + Expect(err).NotTo(HaveOccurred()) + + err = util.CopyFileContents(src, dst) + Expect(err).NotTo(HaveOccurred()) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + Expect(err).NotTo(HaveOccurred()) + + dstFile, err := os.Open(dst) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashSrc, srcFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashDst, dstFile) + Expect(err).NotTo(HaveOccurred()) + + err = srcFile.Close() + Expect(err).NotTo(HaveOccurred()) + + err = dstFile.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) + }) + }) + }) + + Describe("Handle config file without full path", func() { + Context("config file handling", func() { + It("should be successful", func() { + written := &TestConfig{ + SomeField: 123, + } + cfgFile := "test_cfg.json" + defer os.Remove(cfgFile) + + err := util.WriteJson(cfgFile, written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(cfgFile, &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + }) + }) + }) +}) diff --git a/util/file_test.go b/util/file_test.go index 3de7db49bdd..1330e738e8d 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,126 +1,198 @@ -package util_test +package util import ( - "crypto/md5" - "encoding/hex" - "io" "os" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" + "reflect" + "strings" + "testing" ) -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int +func TestReadJsonWithEnvSub(t *testing.T) { + type Config struct { + CertFile string `json:"CertFile"` + Credentials string `json:"Credentials"` + NestedOption struct { + URL string `json:"URL"` + } `json:"NestedOption"` } - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) - - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { - - m := make(map[string]string) - m["key1"] = "value1" - m["key2"] = "value2" + type testCase struct { + name string + envVars map[string]string + jsonTemplate string + expectedResult Config + expectError bool + errorContains string + } - arr := []string{"value1", "value2"} + tests := []testCase{ + { + name: "All environment variables set", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://env.testing.com", + }, + }, + expectError: false, + }, + { + name: "Missing environment variable", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + // "URL" is intentionally missing + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "", + }, + }, + expectError: false, + }, + { + name: "Invalid JSON template", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, // Note the missing closing brace in "{{ .CREDENTIALS }" + expectedResult: Config{}, + expectError: true, + errorContains: "unexpected \"}\" in operand", + }, + { + name: "No substitutions", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "/etc/certs/cert.crt", + "Credentials": "admnlknflkdasdf", + "NestedOption" : { + "URL": "https://testing.com" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/cert.crt", + Credentials: "admnlknflkdasdf", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://testing.com", + }, + }, + expectError: false, + }, + { + name: "Should fail when Invalid characters in variables", + envVars: map[string]string{ + "CERT_FILE": `"/etc/certs/"cert".crt"`, + "CREDENTIALS": `env_credentia{ls}`, + "URL": `https://env.testing.com?param={{value}}`, + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: `"/etc/certs/"cert".crt"`, + Credentials: `env_credentia{ls}`, + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: `https://env.testing.com?param={{value}}`, + }, + }, + expectError: true, + }, + } - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for key, value := range tc.envVars { + t.Setenv(key, value) + } + + tempFile, err := os.CreateTemp("", "config*.json") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + defer func() { + err = os.Remove(tempFile.Name()) + if err != nil { + t.Logf("Failed to remove temp file: %v", err) } + }() - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) - - }) - }) - }) - - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) + _, err = tempFile.WriteString(tc.jsonTemplate) + if err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + err = tempFile.Close() + if err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) + var result Config - hashSrc := md5.New() - hashDst := md5.New() + _, err = ReadJsonWithEnvSub(tempFile.Name(), &result) - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, + if tc.expectError { + if err == nil { + t.Fatalf("Expected error but got none") } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) + if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err) + } + } else { + if err != nil { + t.Fatalf("ReadJsonWithEnvSub failed: %v", err) + } + if !reflect.DeepEqual(result, tc.expectedResult) { + t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult) + } + } }) - }) -}) + } +} diff --git a/util/net/net.go b/util/net/net.go index 61b47dbe7d3..035d7552bc7 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,7 +11,8 @@ import ( const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 + NetbirdFwmark = 0x1BD00 + PreroutingFwmark = 0x1BD01 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" )