Skip to content

Commit

Permalink
#1562, fix tests, fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
vsychov committed Sep 27, 2023
1 parent 074d2fa commit eefd9ee
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 54 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
Restore foreign keys and add constraints [#1562](https://github.com/juanfont/headscale/pull/1562)

## 0.22.3 (2023-05-12)

### Changes
Expand Down
11 changes: 9 additions & 2 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,11 @@ func (h *Headscale) handleAuthKey(
Msg("node was already registered before, refreshing with new auth key")

node.NodeKey = nodeKey
node.AuthKeyID = uint(pak.ID)
pakID := uint(pak.ID)
if pakID != 0 {
node.AuthKeyID = &pakID
}

err := h.db.NodeSetExpiry(node, registerRequest.Expiry)
if err != nil {
log.Error().
Expand Down Expand Up @@ -364,10 +368,13 @@ func (h *Headscale) handleAuthKey(
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
ForcedTags: pak.Proto().AclTags,
}

pakID := uint(pak.ID)
if pakID != 0 {
nodeToRegister.AuthKeyID = &pakID
}
node, err = h.db.RegisterNode(
nodeToRegister,
)
Expand Down
9 changes: 6 additions & 3 deletions hscontrol/db/addresses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -41,7 +42,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
IPAddresses: ips,
}
db.db.Save(&node)
Expand Down Expand Up @@ -81,6 +82,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := types.Node{
ID: uint64(index),
MachineKey: "foo",
Expand All @@ -89,7 +91,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
IPAddresses: ips,
}
db.db.Save(&node)
Expand Down Expand Up @@ -171,6 +173,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -179,7 +182,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)

Expand Down
63 changes: 42 additions & 21 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func (s *Suite) TestGetNode(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -33,9 +34,10 @@ func (s *Suite) TestGetNode(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(node)
trx := db.db.Save(node)
c.Assert(trx.Error, check.IsNil)

_, err = db.GetNode("test", "testnode")
c.Assert(err, check.IsNil)
Expand All @@ -51,6 +53,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -59,9 +62,10 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

_, err = db.GetNodeByID(0)
c.Assert(err, check.IsNil)
Expand All @@ -80,6 +84,7 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
Expand All @@ -88,9 +93,10 @@ func (s *Suite) TestGetNodeByNodeKey(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

_, err = db.GetNodeByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil)
Expand All @@ -111,6 +117,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {

machineKey := key.NewMachine()

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
Expand All @@ -119,9 +126,10 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
Expand All @@ -138,9 +146,9 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
Hostname: "testnode3",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1),
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

err = db.DeleteNode(&node)
c.Assert(err, check.IsNil)
Expand All @@ -159,6 +167,7 @@ func (s *Suite) TestListPeers(c *check.C) {
_, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
for index := 0; index <= 10; index++ {
node := types.Node{
ID: uint64(index),
Expand All @@ -168,9 +177,10 @@ func (s *Suite) TestListPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)
}

node0ByID, err := db.GetNodeByID(0)
Expand Down Expand Up @@ -205,6 +215,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(err, check.NotNil)

for index := 0; index <= 10; index++ {
pakID := uint(stor[index%2].key.ID)
node := types.Node{
ID: uint64(index),
MachineKey: "foo" + strconv.Itoa(index),
Expand All @@ -216,9 +227,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index),
UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)
}

aclPolicy := &policy.ACLPolicy{
Expand Down Expand Up @@ -288,6 +300,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -296,7 +309,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
Expiry: &time.Time{},
}
db.db.Save(node)
Expand Down Expand Up @@ -345,6 +358,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
_, err = db.GetNode("user-1", "testnode")
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: "node-key-1",
Expand All @@ -354,9 +368,11 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
GivenName: "hostname-1",
UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(node)

trx := db.db.Save(node)
c.Assert(trx.Error, check.IsNil)

givenName, err := db.GenerateGivenName("node-key-2", "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
Expand Down Expand Up @@ -389,6 +405,7 @@ func (s *Suite) TestSetTags(c *check.C) {
_, err = db.GetNode("test", "testnode")
c.Assert(err, check.NotNil)

pakID := uint(pak.ID)
node := &types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -397,9 +414,11 @@ func (s *Suite) TestSetTags(c *check.C) {
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(node)

trx := db.db.Save(node)
c.Assert(trx.Error, check.IsNil)

// assign simple tags
sTags := []string{"tag:test", "tag:foo"}
Expand Down Expand Up @@ -572,6 +591,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
// Check if a subprefix of an autoapproved route is approved
route2 := netip.MustParsePrefix("10.11.0.0/24")

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -580,15 +600,16 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
HostInfo: types.HostInfo{
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
},
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
}

db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

err = db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
Expand Down
3 changes: 2 additions & 1 deletion hscontrol/db/preauth_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error)
}

nodes := types.Nodes{}
pakID := uint(pak.ID)
if err := hsdb.db.
Preload("AuthKey").
Where(&types.Node{AuthKeyID: uint(pak.ID)}).
Where(&types.Node{AuthKeyID: &pakID}).
Find(&nodes).Error; err != nil {
return nil, err
}
Expand Down
18 changes: 12 additions & 6 deletions hscontrol/db/preauth_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)

pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -83,9 +84,10 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
Expand All @@ -99,6 +101,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil)

pakID := uint(pak.ID)
node := types.Node{
ID: 1,
MachineKey: "foo",
Expand All @@ -107,9 +110,10 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil)
Expand All @@ -136,6 +140,7 @@ func (*Suite) TestEphemeralKey(c *check.C) {
c.Assert(err, check.IsNil)

now := time.Now().Add(-time.Second * 30)
pakID := uint(pak.ID)
node := types.Node{
ID: 0,
MachineKey: "foo",
Expand All @@ -145,9 +150,10 @@ func (*Suite) TestEphemeralKey(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
AuthKeyID: &pakID,
}
db.db.Save(&node)
trx := db.db.Save(&node)
c.Assert(trx.Error, check.IsNil)

_, err = db.ValidatePreAuthKey(pak.Key)
// Ephemeral keys are by definition reusable
Expand Down
Loading

0 comments on commit eefd9ee

Please sign in to comment.