Skip to content

Commit

Permalink
fix postgres constraints, add postgres testing
Browse files Browse the repository at this point in the history
This commit fixes the constraint syntax so it is both valid for
sqlite and postgres.

To validate this, I've added a new postgres testing library and a
helper that will spin up local postgres, setup a db and use it in
the constraints tests. This should also help testing db stuff in
the future.

postgres has been added to the nix dev shell and is now required
for running the unit tests.

Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby authored and juanfont committed Nov 23, 2024
1 parent 7d9b430 commit f6276ab
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 21 deletions.
3 changes: 2 additions & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to thos files.
vendorHash = "sha256-Qoqu2k4vvnbRFLmT/v8lI+HCEWqJsHFs8uZRfNmwQpo=";
vendorHash = "sha256-4VNiHUblvtcl9UetwiL6ZeVYb0h2e9zhYVsirhAkvOg=";

subPackages = ["cmd/headscale"];

Expand Down Expand Up @@ -102,6 +102,7 @@
ko
yq-go
ripgrep
postgresql

# 'dot' is needed for pprof graphs
# go tool pprof -http=: <source>
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ require (
gorm.io/gorm v1.25.11
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7
zgo.at/zcache/v2 v2.1.0
zombiezen.com/go/postgrestest v1.0.1
)

require (
Expand Down Expand Up @@ -134,6 +135,7 @@ require (
github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/lithammer/fuzzysearch v1.1.8 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4=
Expand Down Expand Up @@ -731,3 +732,5 @@ tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVs
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg=
zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs=
zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ=
10 changes: 5 additions & 5 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ func NewHeadscaleDatabase(
// - A provider_identifier is unique
// - A user name is unique if there is no provider_identifier is not set
for _, idx := range []string{
"DROP INDEX IF EXISTS `idx_provider_identifier`",
"DROP INDEX IF EXISTS `idx_name_provider_identifier`",
"CREATE UNIQUE INDEX IF NOT EXISTS `idx_provider_identifier` ON `users` (`provider_identifier`) WHERE provider_identifier IS NOT NULL;",
"CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_provider_identifier` ON `users` (`name`,`provider_identifier`);",
"CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_no_provider_identifier` ON `users` (`name`) WHERE provider_identifier IS NULL;",
"DROP INDEX IF EXISTS idx_provider_identifier",
"DROP INDEX IF EXISTS idx_name_provider_identifier",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_name_provider_identifier ON users (name,provider_identifier);",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;",
} {
err = tx.Exec(idx).Error
if err != nil {
Expand Down
29 changes: 20 additions & 9 deletions hscontrol/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"slices"
"sort"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -259,6 +260,16 @@ func emptyCache() *zcache.Cache[string, types.Node] {
return zcache.New[string, types.Node](time.Minute, time.Hour)
}

// requireConstraintFailed checks if the error is a constraint failure with
// either SQLite and PostgreSQL error messages.
func requireConstraintFailed(t *testing.T, err error) {
t.Helper()
require.Error(t, err)
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
}
}

func TestConstraints(t *testing.T) {
tests := []struct {
name string
Expand All @@ -270,9 +281,7 @@ func TestConstraints(t *testing.T) {
_, err := CreateUser(db, "user1")
require.NoError(t, err)
_, err = CreateUser(db, "user1")
require.Error(t, err)
assert.Contains(t, err.Error(), "UNIQUE constraint failed:")
// require.Contains(t, err.Error(), "user already exists")
requireConstraintFailed(t, err)
},
},
{
Expand All @@ -294,8 +303,7 @@ func TestConstraints(t *testing.T) {
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}

err = db.Save(&user).Error
require.Error(t, err)
require.Contains(t, err.Error(), "UNIQUE constraint failed:")
requireConstraintFailed(t, err)
},
},
{
Expand All @@ -317,8 +325,7 @@ func TestConstraints(t *testing.T) {
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}

err = db.Save(&user).Error
require.Error(t, err)
require.Contains(t, err.Error(), "UNIQUE constraint failed:")
requireConstraintFailed(t, err)
},
},
{
Expand Down Expand Up @@ -354,8 +361,12 @@ func TestConstraints(t *testing.T) {
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db, err := newTestDB()
t.Run(tt.name+"-postgres", func(t *testing.T) {
db := newPostgresTestDB(t)
tt.run(t, db.DB.Debug())
})
t.Run(tt.name+"-sqlite", func(t *testing.T) {
db, err := newSQLiteTestDB()
if err != nil {
t.Fatalf("creating database: %s", err)
}
Expand Down
6 changes: 3 additions & 3 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ func TestAutoApproveRoutes(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
adb, err := newTestDB()
adb, err := newSQLiteTestDB()
require.NoError(t, err)
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))

Expand Down Expand Up @@ -692,7 +692,7 @@ func generateRandomNumber(t *testing.T, max int64) int64 {
}

func TestListEphemeralNodes(t *testing.T) {
db, err := newTestDB()
db, err := newSQLiteTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
Expand Down Expand Up @@ -748,7 +748,7 @@ func TestListEphemeralNodes(t *testing.T) {
}

func TestRenameNode(t *testing.T) {
db, err := newTestDB()
db, err := newSQLiteTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
Expand Down
63 changes: 60 additions & 3 deletions hscontrol/db/suite_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package db

import (
"context"
"log"
"net/url"
"os"
"strconv"
"strings"
"testing"

"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1"
"zombiezen.com/go/postgrestest"
)

func Test(t *testing.T) {
Expand Down Expand Up @@ -36,13 +41,15 @@ func (s *Suite) ResetDB(c *check.C) {
// }

var err error
db, err = newTestDB()
db, err = newSQLiteTestDB()
if err != nil {
c.Fatal(err)
}
}

func newTestDB() (*HSDatabase, error) {
// TODO(kradalby): make this a t.Helper when we dont depend
// on check test framework.
func newSQLiteTestDB() (*HSDatabase, error) {
var err error
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
if err != nil {
Expand All @@ -53,7 +60,7 @@ func newTestDB() (*HSDatabase, error) {

db, err = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Type: types.DatabaseSqlite,
Sqlite: types.SqliteConfig{
Path: tmpDir + "/headscale_test.db",
},
Expand All @@ -67,3 +74,53 @@ func newTestDB() (*HSDatabase, error) {

return db, nil
}

func newPostgresTestDB(t *testing.T) *HSDatabase {
t.Helper()

var err error
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
if err != nil {
t.Fatal(err)
}

log.Printf("database path: %s", tmpDir+"/headscale_test.db")

ctx := context.Background()
srv, err := postgrestest.Start(ctx)
if err != nil {
t.Fatal(err)
}
t.Cleanup(srv.Cleanup)

u, err := srv.CreateDatabase(ctx)
if err != nil {
t.Fatal(err)
}
t.Logf("created local postgres: %s", u)
pu, _ := url.Parse(u)

pass, _ := pu.User.Password()
port, _ := strconv.Atoi(pu.Port())

db, err = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: types.DatabasePostgres,
Postgres: types.PostgresConfig{
Host: pu.Hostname(),
User: pu.User.Username(),
Name: strings.TrimLeft(pu.Path, "/"),
Pass: pass,
Port: port,
Ssl: "disable",
},
},
"",
emptyCache(),
)
if err != nil {
t.Fatal(err)
}

return db
}

0 comments on commit f6276ab

Please sign in to comment.