From b1beaddb289e3baa532b781a7d3da4efed78e12f Mon Sep 17 00:00:00 2001 From: Stefan Majer Date: Mon, 8 Apr 2024 08:02:12 +0200 Subject: [PATCH 1/7] Fix typos --- CHANGELOG.md | 14 +- README.md | 2 +- config-example.yaml | 4 +- docs/exit-node.md | 2 +- docs/faq.md | 2 +- docs/proposals/001-acls.md | 8 +- docs/remote-cli.md | 6 +- docs/reverse-proxy.md | 2 +- docs/running-headscale-openbsd.md | 4 +- docs/web-ui.md.orig | 23 + flake.nix | 2 +- flake.nix.orig | 178 ++++++ hscontrol/app.go | 2 +- hscontrol/db/node.go | 7 +- hscontrol/db/node.go.orig | 772 +++++++++++++++++++++++++ hscontrol/db/node_test.go | 3 +- hscontrol/db/node_test.go.orig | 625 ++++++++++++++++++++ hscontrol/db/preauth_keys.go | 2 +- hscontrol/derp/server/derp_server.go | 2 +- hscontrol/policy/acls_test.go | 18 +- hscontrol/poll.go | 1 + hscontrol/poll.go.orig | 818 +++++++++++++++++++++++++++ integration/general_test.go | 6 +- integration/scenario.go | 2 +- integration/utils.go | 2 +- 25 files changed, 2463 insertions(+), 44 deletions(-) create mode 100644 docs/web-ui.md.orig create mode 100644 flake.nix.orig create mode 100644 hscontrol/db/node.go.orig create mode 100644 hscontrol/db/node_test.go.orig create mode 100644 hscontrol/poll.go.orig diff --git a/CHANGELOG.md b/CHANGELOG.md index a8e15c0cc0..03516fd684 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473) - Change the structure of database configuration, see [config-example.yaml](./config-example.yaml) for the new structure. [#1700](https://github.com/juanfont/headscale/pull/1700) - Old structure has been remove and the configuration _must_ be converted. - - Adds additional configuration for PostgreSQL for setting max open, idle conection and idle connection lifetime. + - Adds additional configuration for PostgreSQL for setting max open, idle connection and idle connection lifetime. - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - The latest supported client is 1.38 @@ -70,7 +70,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ ### Changes - Add environment flags to enable pprof (profiling) [#1382](https://github.com/juanfont/headscale/pull/1382) - - Profiles are continously generated in our integration tests. + - Profiles are continuously generated in our integration tests. - Fix systemd service file location in `.deb` packages [#1391](https://github.com/juanfont/headscale/pull/1391) - Improvements on Noise implementation [#1379](https://github.com/juanfont/headscale/pull/1379) - Replace node filter logic, ensuring nodes with access can see eachother [#1381](https://github.com/juanfont/headscale/pull/1381) @@ -161,7 +161,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - SSH ACLs status: - Support `accept` and `check` (SSH can be enabled and used for connecting and authentication) - Rejecting connections **are not supported**, meaning that if you enable SSH, then assume that _all_ `ssh` connections **will be allowed**. - - If you decied to try this feature, please carefully managed permissions by blocking port `22` with regular ACLs or do _not_ set `--ssh` on your clients. + - If you decided to try this feature, please carefully managed permissions by blocking port `22` with regular ACLs or do _not_ set `--ssh` on your clients. - We are currently improving our testing of the SSH ACLs, help us get an overview by testing and giving feedback. - This feature should be considered dangerous and it is disabled by default. Enable by setting `HEADSCALE_EXPERIMENTAL_FEATURE_SSH=1`. @@ -211,7 +211,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ ### Changes - Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722) -- Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) +- Fix missing group expansion in function `excludeCorrectlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) - Improve registration protocol implementation and switch to NodeKey as main identifier [#725](https://github.com/juanfont/headscale/pull/725) - Add ability to connect to PostgreSQL via unix socket [#734](https://github.com/juanfont/headscale/pull/734) @@ -231,7 +231,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Fix send on closed channel crash in polling [#542](https://github.com/juanfont/headscale/pull/542) - Fixed spurious calls to setLastStateChangeToNow from ephemeral nodes [#566](https://github.com/juanfont/headscale/pull/566) - Add command for moving nodes between namespaces [#362](https://github.com/juanfont/headscale/issues/362) -- Added more configuration parameters for OpenID Connect (scopes, free-form paramters, domain and user allowlist) +- Added more configuration parameters for OpenID Connect (scopes, free-form parameters, domain and user allowlist) - Add command to set tags on a node [#525](https://github.com/juanfont/headscale/issues/525) - Add command to view tags of nodes [#356](https://github.com/juanfont/headscale/issues/356) - Add --all (-a) flag to enable routes command [#360](https://github.com/juanfont/headscale/issues/360) @@ -279,10 +279,10 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Fix a bug were the same IP could be assigned to multiple hosts if joined in quick succession [#346](https://github.com/juanfont/headscale/pull/346) - Simplify the code behind registration of machines [#366](https://github.com/juanfont/headscale/pull/366) - - Nodes are now only written to database if they are registrated successfully + - Nodes are now only written to database if they are registered successfully - Fix a limitation in the ACLs that prevented users to write rules with `*` as source [#374](https://github.com/juanfont/headscale/issues/374) - Reduce the overhead of marshal/unmarshal for Hostinfo, routes and endpoints by using specific types in Machine [#371](https://github.com/juanfont/headscale/pull/371) -- Apply normalization function to FQDN on hostnames when hosts registers and retrieve informations [#363](https://github.com/juanfont/headscale/issues/363) +- Apply normalization function to FQDN on hostnames when hosts registers and retrieve information [#363](https://github.com/juanfont/headscale/issues/363) - Fix a bug that prevented the use of `tailscale logout` with OIDC [#508](https://github.com/juanfont/headscale/issues/508) - Added Tailscale repo HEAD and unstable releases channel to the integration tests targets [#513](https://github.com/juanfont/headscale/pull/513) diff --git a/README.md b/README.md index 3087429639..2ee8f4ebdf 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ Please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file. ### Requirements -To contribute to headscale you would need the lastest version of [Go](https://golang.org) +To contribute to headscale you would need the latest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator). We recommend using [Nix](https://nixos.org/) to setup a development environment. This can diff --git a/config-example.yaml b/config-example.yaml index 0f1c2412b5..867f890330 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -105,7 +105,7 @@ derp: automatically_add_embedded_derp_region: true # For better connection stability (especially when using an Exit-Node and DNS is not working), - # it is possible to optionall add the public IPv4 and IPv6 address to the Derp-Map using: + # it is possible to optionally add the public IPv4 and IPv6 address to the Derp-Map using: ipv4: 1.2.3.4 ipv6: 2001:db8::1 @@ -199,7 +199,7 @@ log: format: text level: info -# Path to a file containg ACL policies. +# Path to a file containing ACL policies. # ACLs can be defined as YAML or HUJSON. # https://tailscale.com/kb/1018/acls/ acl_policy_path: "" diff --git a/docs/exit-node.md b/docs/exit-node.md index 898b7811d5..831652b394 100644 --- a/docs/exit-node.md +++ b/docs/exit-node.md @@ -14,7 +14,7 @@ If the node is already registered, it can advertise exit capabilities like this: $ sudo tailscale set --advertise-exit-node ``` -To use a node as an exit node, IP forwarding must be enabled on the node. Check the official [Tailscale documentation](https://tailscale.com/kb/1019/subnets/?tab=linux#enable-ip-forwarding) for how to enable IP fowarding. +To use a node as an exit node, IP forwarding must be enabled on the node. Check the official [Tailscale documentation](https://tailscale.com/kb/1019/subnets/?tab=linux#enable-ip-forwarding) for how to enable IP forwarding. ## On the control server diff --git a/docs/faq.md b/docs/faq.md index fff9613244..ba30911b1d 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -36,7 +36,7 @@ We don't know. We might be working on it. If you want to help, please send us a Please be aware that there are a number of reasons why we might not accept specific contributions: - It is not possible to implement the feature in a way that makes sense in a self-hosted environment. -- Given that we are reverse-engineering Tailscale to satify our own curiosity, we might be interested in implementing the feature ourselves. +- Given that we are reverse-engineering Tailscale to satisfy our own curiosity, we might be interested in implementing the feature ourselves. - You are not sending unit and integration tests with it. ## Do you support Y method of deploying Headscale? diff --git a/docs/proposals/001-acls.md b/docs/proposals/001-acls.md index 8a02e83658..74bcd13e83 100644 --- a/docs/proposals/001-acls.md +++ b/docs/proposals/001-acls.md @@ -58,12 +58,12 @@ A solution could be to consider a headscale server (in it's entirety) as a tailnet. For personal users the default behavior could either allow all communications -between all namespaces (like tailscale) or dissallow all communications between +between all namespaces (like tailscale) or disallow all communications between namespaces (current behavior). For businesses and organisations, viewing a headscale instance a single tailnet would allow users (namespace) to talk to each other with the ACLs. As described -in tailscale's documentation [[1]], a server should be tagged and personnal +in tailscale's documentation [[1]], a server should be tagged and personal devices should be tied to a user. Translated in headscale's terms each user can have multiple devices and all those devices should be in the same namespace. The servers should be tagged and used as such. @@ -88,7 +88,7 @@ the ability to rules in either format (HuJSON or YAML). Let's build an example use case for a small business (It may be the place where ACL's are the most useful). -We have a small company with a boss, an admin, two developper and an intern. +We have a small company with a boss, an admin, two developer and an intern. The boss should have access to all servers but not to the users hosts. Admin should also have access to all hosts except that their permissions should be @@ -173,7 +173,7 @@ need to add the following ACLs "ports": ["prod:*", "dev:*", "internal:*"] }, - // admin have access to adminstration port (lets only consider port 22 here) + // admin have access to administration port (lets only consider port 22 here) { "action": "accept", "users": ["group:admin"], diff --git a/docs/remote-cli.md b/docs/remote-cli.md index 96a6333a73..3d44eabc25 100644 --- a/docs/remote-cli.md +++ b/docs/remote-cli.md @@ -1,13 +1,13 @@ # Controlling `headscale` with remote CLI -## Prerequisit +## Prerequisite - A workstation to run `headscale` (could be Linux, macOS, other supported platforms) - A `headscale` server (version `0.13.0` or newer) - Access to create API keys (local access to the `headscale` server) - `headscale` _must_ be served over TLS/HTTPS - Remote access does _not_ support unencrypted traffic. -- Port `50443` must be open in the firewall (or port overriden by `grpc_listen_addr` option) +- Port `50443` must be open in the firewall (or port overridden by `grpc_listen_addr` option) ## Goal @@ -97,4 +97,4 @@ Checklist: - Make sure you use version `0.13.0` or newer. - Verify that your TLS certificate is valid and trusted - If you do not have access to a trusted certificate (e.g. from Let's Encrypt), add your self signed certificate to the trust store of your OS or - - Set `HEADSCALE_CLI_INSECURE` to 0 in your environement + - Set `HEADSCALE_CLI_INSECURE` to 0 in your environment diff --git a/docs/reverse-proxy.md b/docs/reverse-proxy.md index 1f417c9bd9..c6fd4b1635 100644 --- a/docs/reverse-proxy.md +++ b/docs/reverse-proxy.md @@ -115,7 +115,7 @@ The following Caddyfile is all that is necessary to use Caddy as a reverse proxy } ``` -Caddy v2 will [automatically](https://caddyserver.com/docs/automatic-https) provision a certficate for your domain/subdomain, force HTTPS, and proxy websockets - no further configuration is necessary. +Caddy v2 will [automatically](https://caddyserver.com/docs/automatic-https) provision a certificate for your domain/subdomain, force HTTPS, and proxy websockets - no further configuration is necessary. For a slightly more complex configuration which utilizes Docker containers to manage Caddy, Headscale, and Headscale-UI, [Guru Computing's guide](https://blog.gurucomputing.com.au/smart-vpns-with-headscale/) is an excellent reference. diff --git a/docs/running-headscale-openbsd.md b/docs/running-headscale-openbsd.md index a490439aaf..e1d8d83fbd 100644 --- a/docs/running-headscale-openbsd.md +++ b/docs/running-headscale-openbsd.md @@ -30,7 +30,7 @@ describing how to make `headscale` run properly in a server environment. cd headscale # optionally checkout a release - # option a. you can find offical relase at https://github.com/juanfont/headscale/releases/latest + # option a. you can find official release at https://github.com/juanfont/headscale/releases/latest # option b. get latest tag, this may be a beta release latestTag=$(git describe --tags `git rev-list --tags --max-count=1`) @@ -57,7 +57,7 @@ describing how to make `headscale` run properly in a server environment. cd headscale # optionally checkout a release - # option a. you can find offical relase at https://github.com/juanfont/headscale/releases/latest + # option a. you can find official release at https://github.com/juanfont/headscale/releases/latest # option b. get latest tag, this may be a beta release latestTag=$(git describe --tags `git rev-list --tags --max-count=1`) diff --git a/docs/web-ui.md.orig b/docs/web-ui.md.orig new file mode 100644 index 0000000000..3175057c1f --- /dev/null +++ b/docs/web-ui.md.orig @@ -0,0 +1,23 @@ +# Headscale web interface + +!!! warning "Community contributions" + + This page contains community contributions. The projects listed here are not + maintained by the Headscale authors and are written by community members. + +<<<<<<< HEAD +| Name | Repository Link | Description | Status | +| --------------- | ------------------------------------------------------- | --------------------------------------------------------------------------- | ------ | +| headscale-webui | [Github](https://github.com/ifargle/headscale-webui) | A simple Headscale web UI for small-scale deployments. | Alpha | +| headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server | Alpha | +| HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend enviroment required | Alpha | +| headscale-admin | [Github](https://github.com/GoodiesHQ/headscale-admin) | Headscale-Admin is meant to be a simple, modern web interface for Headscale | Beta | +======= +| Name | Repository Link | Description | Status | +| --------------- | ------------------------------------------------------- | ------------------------------------------------------------------------- | ------ | +| headscale-webui | [Github](https://github.com/ifargle/headscale-webui) | A simple Headscale web UI for small-scale deployments. | Alpha | +| headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server | Alpha | +| HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend environment required | Alpha | +>>>>>>> cde0b83 (Fix typos) + +You can ask for support on our dedicated [Discord channel](https://discord.com/channels/896711691637780480/1105842846386356294). diff --git a/flake.nix b/flake.nix index f2046dae0e..94ec6150e7 100644 --- a/flake.nix +++ b/flake.nix @@ -30,7 +30,7 @@ checkFlags = ["-short"]; # When updating go.mod or go.sum, a new sha will need to be calculated, - # update this if you have a mismatch after doing a change to thos files. + # update this if you have a mismatch after doing a change to those files. vendorHash = "sha256-wXfKeiJaGe6ahOsONrQhvbuMN8flQ13b0ZjxdbFs1e8="; subPackages = ["cmd/headscale"]; diff --git a/flake.nix.orig b/flake.nix.orig new file mode 100644 index 0000000000..ab17ebed55 --- /dev/null +++ b/flake.nix.orig @@ -0,0 +1,178 @@ +{ + description = "headscale - Open Source Tailscale Control server"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { + self, + nixpkgs, + flake-utils, + ... + }: let + headscaleVersion = + if (self ? shortRev) + then self.shortRev + else "dev"; + in + { + overlay = _: prev: let + pkgs = nixpkgs.legacyPackages.${prev.system}; + in rec { + headscale = pkgs.buildGo122Module rec { + pname = "headscale"; + version = headscaleVersion; + src = pkgs.lib.cleanSource self; + + # Only run unit tests when testing a build + checkFlags = ["-short"]; + + # When updating go.mod or go.sum, a new sha will need to be calculated, +<<<<<<< HEAD + # update this if you have a mismatch after doing a change to thos files. + vendorHash = "sha256-wXfKeiJaGe6ahOsONrQhvbuMN8flQ13b0ZjxdbFs1e8="; +======= + # update this if you have a mismatch after doing a change to those files. + vendorHash = "sha256-Yb5WaN0abPLZ4mPnuJGZoj6EMfoZjaZZ0f344KWva3o="; +>>>>>>> cde0b83 (Fix typos) + + subPackages = ["cmd/headscale"]; + + ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; + }; + + protoc-gen-grpc-gateway = pkgs.buildGoModule rec { + pname = "grpc-gateway"; + version = "2.19.1"; + + src = pkgs.fetchFromGitHub { + owner = "grpc-ecosystem"; + repo = "grpc-gateway"; + rev = "v${version}"; + sha256 = "sha256-CdGQpQfOSimeio8v1lZ7xzE/oAS2qFyu+uN+H9i7vpo="; + }; + + vendorHash = "sha256-no7kZGpf/VOuceC3J+izGFQp5aMS3b+Rn+x4BFZ2zgs="; + + nativeBuildInputs = [pkgs.installShellFiles]; + + subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"]; + }; + }; + } + // flake-utils.lib.eachDefaultSystem + (system: let + pkgs = import nixpkgs { + overlays = [self.overlay]; + inherit system; + }; + buildDeps = with pkgs; [git go_1_22 gnumake]; + devDeps = with pkgs; + buildDeps + ++ [ + golangci-lint + golines + nodePackages.prettier + goreleaser + nfpm + gotestsum + gotests + ksh + ko + yq-go + ripgrep + + # 'dot' is needed for pprof graphs + # go tool pprof -http=: + graphviz + + # Protobuf dependencies + protobuf + protoc-gen-go + protoc-gen-go-grpc + protoc-gen-grpc-gateway + buf + clang-tools # clang-format + ]; + + # Add entry to build a docker image with headscale + # caveat: only works on Linux + # + # Usage: + # nix build .#headscale-docker + # docker load < result + headscale-docker = pkgs.dockerTools.buildLayeredImage { + name = "headscale"; + tag = headscaleVersion; + contents = [pkgs.headscale]; + config.Entrypoint = [(pkgs.headscale + "/bin/headscale")]; + }; + in rec { + # `nix develop` + devShell = pkgs.mkShell { + buildInputs = + devDeps + ++ [ + (pkgs.writeShellScriptBin + "nix-vendor-sri" + '' + set -eu + + OUT=$(mktemp -d -t nar-hash-XXXXXX) + rm -rf "$OUT" + + go mod vendor -o "$OUT" + go run tailscale.com/cmd/nardump --sri "$OUT" + rm -rf "$OUT" + '') + + (pkgs.writeShellScriptBin + "go-mod-update-all" + '' + cat go.mod | ${pkgs.silver-searcher}/bin/ag "\t" | ${pkgs.silver-searcher}/bin/ag -v indirect | ${pkgs.gawk}/bin/awk '{print $1}' | ${pkgs.findutils}/bin/xargs go get -u + go mod tidy + '') + ]; + + shellHook = '' + export PATH="$PWD/result/bin:$PATH" + ''; + }; + + # `nix build` + packages = with pkgs; { + inherit headscale; + inherit headscale-docker; + }; + defaultPackage = pkgs.headscale; + + # `nix run` + apps.headscale = flake-utils.lib.mkApp { + drv = packages.headscale; + }; + apps.default = apps.headscale; + + checks = { + format = + pkgs.runCommand "check-format" + { + buildInputs = with pkgs; [ + gnumake + nixpkgs-fmt + golangci-lint + nodePackages.prettier + golines + clang-tools + ]; + } '' + ${pkgs.nixpkgs-fmt}/bin/nixpkgs-fmt ${./.} + ${pkgs.golangci-lint}/bin/golangci-lint run --fix --timeout 10m + ${pkgs.nodePackages.prettier}/bin/prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' + ${pkgs.golines}/bin/golines --max-len=88 --base-formatter=gofumpt -w ${./.} + ${pkgs.clang-tools}/bin/clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i ${./.} + ''; + }; + }); +} diff --git a/hscontrol/app.go b/hscontrol/app.go index b8eb6f69e2..28211db39d 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -330,7 +330,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, // Check if the request is coming from the on-server client. // This is not secure, but it is to maintain maintainability // with the "legacy" database-based client - // It is also neede for grpc-gateway to be able to connect to + // It is also needed for grpc-gateway to be able to connect to // the server client, _ := peer.FromContext(ctx) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index e9a4ea0405..a5a5c9a622 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -661,7 +661,7 @@ func GenerateGivenName( } func DeleteExpiredEphemeralNodes(tx *gorm.DB, - inactivityThreshhold time.Duration, + inactivityThreshold time.Duration, ) ([]types.NodeID, []types.NodeID) { users, err := ListUsers(tx) if err != nil { @@ -679,7 +679,7 @@ func DeleteExpiredEphemeralNodes(tx *gorm.DB, for idx, node := range nodes { if node.IsEphemeral() && node.LastSeen != nil && time.Now(). - After(node.LastSeen.Add(inactivityThreshhold)) { + After(node.LastSeen.Add(inactivityThreshold)) { expired = append(expired, node.ID) log.Info(). @@ -692,7 +692,7 @@ func DeleteExpiredEphemeralNodes(tx *gorm.DB, log.Error(). Err(err). Str("node", node.Hostname). - Msg("🤮 Cannot delete ephemeral node from the database") + Msg("� Cannot delete ephemeral node from the database") } changedNodes = append(changedNodes, changed...) @@ -725,6 +725,7 @@ func ExpireExpiredNodes(tx *gorm.DB, NodeID: tailcfg.NodeID(node.ID), KeyExpiry: node.Expiry, }) + // and there is no point in notifying twice. } } diff --git a/hscontrol/db/node.go.orig b/hscontrol/db/node.go.orig new file mode 100644 index 0000000000..14bb3481ee --- /dev/null +++ b/hscontrol/db/node.go.orig @@ -0,0 +1,772 @@ +package db + +import ( + "errors" + "fmt" + "net/netip" + "sort" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/patrickmn/go-cache" + "github.com/puzpuzpuz/xsync/v3" + "github.com/rs/zerolog/log" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +const ( + NodeGivenNameHashLength = 8 + NodeGivenNameTrimSize = 2 +) + +var ( + ErrNodeNotFound = errors.New("node not found") + ErrNodeRouteIsNotAvailable = errors.New("route is not available on node") + ErrNodeNotFoundRegistrationCache = errors.New( + "node not found in registration cache", + ) + ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") + ErrDifferentRegisteredUser = errors.New( + "node was previously registered with a different user", + ) +) + +func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListPeers(rx, nodeID) + }) +} + +// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. +func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { + nodes := types.Nodes{} + if err := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + Where("id <> ?", + nodeID).Find(&nodes).Error; err != nil { + return types.Nodes{}, err + } + + sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) + + return nodes, nil +} + +func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListNodes(rx) + }) +} + +func ListNodes(tx *gorm.DB) (types.Nodes, error) { + nodes := types.Nodes{} + if err := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + Find(&nodes).Error; err != nil { + return nil, err + } + + return nodes, nil +} + +func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { + nodes := types.Nodes{} + if err := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + Where("given_name = ?", givenName).Find(&nodes).Error; err != nil { + return nil, err + } + + return nodes, nil +} + +func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return getNode(rx, user, name) + }) +} + +// getNode finds a Node by name and user and returns the Node struct. +func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, user) + if err != nil { + return nil, err + } + + for _, m := range nodes { + if m.Hostname == name { + return m, nil + } + } + + return nil, ErrNodeNotFound +} + +func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByID(rx, id) + }) +} + +// GetNodeByID finds a Node by ID and returns the Node struct. +func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) { + mach := types.Node{} + if result := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + Find(&types.Node{ID: id}).First(&mach); result.Error != nil { + return nil, result.Error + } + + return &mach, nil +} + +func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByMachineKey(rx, machineKey) + }) +} + +// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. +func GetNodeByMachineKey( + tx *gorm.DB, + machineKey key.MachinePublic, +) (*types.Node, error) { + mach := types.Node{} + if result := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil { + return nil, result.Error + } + + return &mach, nil +} + +func (hsdb *HSDatabase) GetNodeByAnyKey( + machineKey key.MachinePublic, + nodeKey key.NodePublic, + oldNodeKey key.NodePublic, +) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey) + }) +} + +// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. +// TODO(kradalby): see if we can remove this. +func GetNodeByAnyKey( + tx *gorm.DB, + machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, +) (*types.Node, error) { + node := types.Node{} + if result := tx. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + First(&node, "machine_key = ? OR node_key = ? OR node_key = ?", + machineKey.String(), + nodeKey.String(), + oldNodeKey.String()); result.Error != nil { + return nil, result.Error + } + + return &node, nil +} + +func (hsdb *HSDatabase) SetTags( + nodeID types.NodeID, + tags []string, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + return SetTags(tx, nodeID, tags) + }) +} + +// SetTags takes a Node struct pointer and update the forced tags. +func SetTags( + tx *gorm.DB, + nodeID types.NodeID, + tags []string, +) error { + if len(tags) == 0 { + // if no tags are provided, we remove all forced tags + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil { + return fmt.Errorf("failed to remove tags for node in the database: %w", err) + } + + return nil + } + + newTags := types.StringList{} + for _, tag := range tags { + if !util.StringOrPrefixListContains(newTags, tag) { + newTags = append(newTags, tag) + } + } + + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil { + return fmt.Errorf("failed to update tags for node in the database: %w", err) + } + + return nil +} + +// RenameNode takes a Node struct and a new GivenName for the nodes +// and renames it. +func RenameNode(tx *gorm.DB, + nodeID uint64, newName string, +) error { + err := util.CheckForFQDNRules( + newName, + ) + if err != nil { + return fmt.Errorf("renaming node: %w", err) + } + + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { + return fmt.Errorf("failed to rename node in the database: %w", err) + } + + return nil +} + +func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error { + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetExpiry(tx, nodeID, expiry) + }) +} + +// NodeSetExpiry takes a Node struct and a new expiry time. +func NodeSetExpiry(tx *gorm.DB, + nodeID types.NodeID, expiry time.Time, +) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error +} + +func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) { + return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) { + return DeleteNode(tx, node, isLikelyConnected) + }) +} + +// DeleteNode deletes a Node from the database. +// Caller is responsible for notifying all of change. +func DeleteNode(tx *gorm.DB, + node *types.Node, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], +) ([]types.NodeID, error) { + changed, err := deleteNodeRoutes(tx, node, isLikelyConnected) + if err != nil { + return changed, err + } + + // Unscoped causes the node to be fully removed from the database. + if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil { + return changed, err + } + + return changed, nil +} + +// SetLastSeen sets a node's last seen field indicating that we +// have recently communicating with this node. +func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error +} + +func RegisterNodeFromAuthCallback( + tx *gorm.DB, + cache *cache.Cache, + mkey key.MachinePublic, + userName string, + nodeExpiry *time.Time, + registrationMethod string, + ipv4 *netip.Addr, + ipv6 *netip.Addr, +) (*types.Node, error) { + log.Debug(). + Str("machine_key", mkey.ShortString()). + Str("userName", userName). + Str("registrationMethod", registrationMethod). + Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). + Msg("Registering node from API/CLI or auth callback") + + if nodeInterface, ok := cache.Get(mkey.String()); ok { + if registrationNode, ok := nodeInterface.(types.Node); ok { + user, err := GetUser(tx, userName) + if err != nil { + return nil, fmt.Errorf( + "failed to find user in register node from auth callback, %w", + err, + ) + } + + // Registration of expired node with different user + if registrationNode.ID != 0 && + registrationNode.UserID != user.ID { + return nil, ErrDifferentRegisteredUser + } + + registrationNode.UserID = user.ID + registrationNode.User = *user + registrationNode.RegisterMethod = registrationMethod + + if nodeExpiry != nil { + registrationNode.Expiry = nodeExpiry + } + + node, err := RegisterNode( + tx, + registrationNode, + ipv4, ipv6, + ) + + if err == nil { + cache.Delete(mkey.String()) + } + + return node, err + } else { + return nil, ErrCouldNotConvertNodeInterface + } + } + + return nil, ErrNodeNotFoundRegistrationCache +} + +func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { + return RegisterNode(tx, node, ipv4, ipv6) + }) +} + +// RegisterNode is executed from the CLI to register a new Node using its MachineKey. +func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { + log.Debug(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()). + Str("user", node.User.Name). + Msg("Registering node") + + // If the node exists and it already has IP(s), we just save it + // so we store the node.Expire and node.Nodekey that has been set when + // adding it to the registrationCache + if node.IPv4 != nil || node.IPv6 != nil { + if err := tx.Save(&node).Error; err != nil { + return nil, fmt.Errorf("failed register existing node in the database: %w", err) + } + + log.Trace(). + Caller(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()). + Str("user", node.User.Name). + Msg("Node authorized again") + + return &node, nil + } + + node.IPv4 = ipv4 + node.IPv6 = ipv6 + + if err := tx.Save(&node).Error; err != nil { + return nil, fmt.Errorf("failed register(save) node in the database: %w", err) + } + + log.Trace(). + Caller(). + Str("node", node.Hostname). + Msg("Node registered with the database") + + return &node, nil +} + +// NodeSetNodeKey sets the node key of a node and saves it to the database. +func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error { + return tx.Model(node).Updates(types.Node{ + NodeKey: nodeKey, + }).Error +} + +func (hsdb *HSDatabase) NodeSetMachineKey( + node *types.Node, + machineKey key.MachinePublic, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetMachineKey(tx, node, machineKey) + }) +} + +// NodeSetMachineKey sets the node key of a node and saves it to the database. +func NodeSetMachineKey( + tx *gorm.DB, + node *types.Node, + machineKey key.MachinePublic, +) error { + return tx.Model(node).Updates(types.Node{ + MachineKey: machineKey, + }).Error +} + +// NodeSave saves a node object to the database, prefer to use a specific save method rather +// than this. It is intended to be used when we are changing or. +// TODO(kradalby): Remove this func, just use Save. +func NodeSave(tx *gorm.DB, node *types.Node) error { + return tx.Save(node).Error +} + +func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetAdvertisedRoutes(rx, node) + }) +} + +// GetAdvertisedRoutes returns the routes that are be advertised by the given node. +func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { + routes := types.Routes{} + + err := tx. + Preload("Node"). + Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("getting advertised routes for node(%d): %w", node.ID, err) + } + + prefixes := []netip.Prefix{} + for _, route := range routes { + prefixes = append(prefixes, netip.Prefix(route.Prefix)) + } + + return prefixes, nil +} + +func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetEnabledRoutes(rx, node) + }) +} + +// GetEnabledRoutes returns the routes that are enabled for the node. +func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { + routes := types.Routes{} + + err := tx. + Preload("Node"). + Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true). + Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("getting enabled routes for node(%d): %w", node.ID, err) + } + + prefixes := []netip.Prefix{} + for _, route := range routes { + prefixes = append(prefixes, netip.Prefix(route.Prefix)) + } + + return prefixes, nil +} + +func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool { + route, err := netip.ParsePrefix(routeStr) + if err != nil { + return false + } + + enabledRoutes, err := GetEnabledRoutes(tx, node) + if err != nil { + return false + } + + for _, enabledRoute := range enabledRoutes { + if route == enabledRoute { + return true + } + } + + return false +} + +func (hsdb *HSDatabase) enableRoutes( + node *types.Node, + routeStrs ...string, +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return enableRoutes(tx, node, routeStrs...) + }) +} + +// enableRoutes enables new routes based on a list of new routes. +func enableRoutes(tx *gorm.DB, + node *types.Node, routeStrs ...string, +) (*types.StateUpdate, error) { + newRoutes := make([]netip.Prefix, len(routeStrs)) + for index, routeStr := range routeStrs { + route, err := netip.ParsePrefix(routeStr) + if err != nil { + return nil, err + } + + newRoutes[index] = route + } + + advertisedRoutes, err := GetAdvertisedRoutes(tx, node) + if err != nil { + return nil, err + } + + for _, newRoute := range newRoutes { + if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { + return nil, fmt.Errorf( + "route (%s) is not available on node %s: %w", + node.Hostname, + newRoute, ErrNodeRouteIsNotAvailable, + ) + } + } + + // Separate loop so we don't leave things in a half-updated state + for _, prefix := range newRoutes { + route := types.Route{} + err := tx.Preload("Node"). + Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). + First(&route).Error + if err == nil { + route.Enabled = true + + // Mark already as primary if there is only this node offering this subnet + // (and is not an exit route) + if !route.IsExitRoute() { + route.IsPrimary = isUniquePrefix(tx, route) + } + + err = tx.Save(&route).Error + if err != nil { + return nil, fmt.Errorf("failed to enable route: %w", err) + } + } else { + return nil, fmt.Errorf("failed to find route: %w", err) + } + } + + // Ensure the node has the latest routes when notifying the other + // nodes + nRoutes, err := GetNodeRoutes(tx, node) + if err != nil { + return nil, fmt.Errorf("failed to read back routes: %w", err) + } + + node.Routes = nRoutes + + log.Trace(). + Caller(). + Str("node", node.Hostname). + Strs("routes", routeStrs). + Msg("enabling routes") + + return &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{node.ID}, + Message: "created in db.enableRoutes", + }, nil +} + +func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { + normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( + suppliedName, + ) + if err != nil { + return "", err + } + + if randomSuffix { + // Trim if a hostname will be longer than 63 chars after adding the hash. + trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize + if len(normalizedHostname) > trimmedHostnameLength { + normalizedHostname = normalizedHostname[:trimmedHostnameLength] + } + + suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength) + if err != nil { + return "", err + } + + normalizedHostname += "-" + suffix + } + + return normalizedHostname, nil +} + +func (hsdb *HSDatabase) GenerateGivenName( + mkey key.MachinePublic, + suppliedName string, +) (string, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { + return GenerateGivenName(rx, mkey, suppliedName) + }) +} + +func GenerateGivenName( + tx *gorm.DB, + mkey key.MachinePublic, + suppliedName string, +) (string, error) { + givenName, err := generateGivenName(suppliedName, false) + if err != nil { + return "", err + } + + // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ + nodes, err := listNodesByGivenName(tx, givenName) + if err != nil { + return "", err + } + + var nodeFound *types.Node + for idx, node := range nodes { + if node.GivenName == givenName { + nodeFound = nodes[idx] + } + } + + if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() { + postfixedName, err := generateGivenName(suppliedName, true) + if err != nil { + return "", err + } + + givenName = postfixedName + } + + return givenName, nil +} + +<<<<<<< HEAD +func DeleteExpiredEphemeralNodes(tx *gorm.DB, + inactivityThreshhold time.Duration, +) ([]types.NodeID, []types.NodeID) { +======= +func ExpireEphemeralNodes(tx *gorm.DB, + inactivityThreshold time.Duration, +) (types.StateUpdate, bool) { +>>>>>>> cde0b83 (Fix typos) + users, err := ListUsers(tx) + if err != nil { + return nil, nil + } + + var expired []types.NodeID + var changedNodes []types.NodeID + for _, user := range users { + nodes, err := ListNodesByUser(tx, user.Name) + if err != nil { + return nil, nil + } + + for idx, node := range nodes { + if node.IsEphemeral() && node.LastSeen != nil && + time.Now(). +<<<<<<< HEAD + After(node.LastSeen.Add(inactivityThreshhold)) { + expired = append(expired, node.ID) +======= + After(node.LastSeen.Add(inactivityThreshold)) { + expired = append(expired, tailcfg.NodeID(node.ID)) +>>>>>>> cde0b83 (Fix typos) + + log.Info(). + Str("node", node.Hostname). + Msg("Ephemeral client removed from database") + + // empty isConnected map as ephemeral nodes are not routes + changed, err := DeleteNode(tx, nodes[idx], nil) + if err != nil { + log.Error(). + Err(err). + Str("node", node.Hostname). + Msg("🤮 Cannot delete ephemeral node from the database") + } + + changedNodes = append(changedNodes, changed...) + } + } + + // TODO(kradalby): needs to be moved out of transaction + } + + return expired, changedNodes +} + +func ExpireExpiredNodes(tx *gorm.DB, + lastCheck time.Time, +) (time.Time, types.StateUpdate, bool) { + // use the time of the start of the function to ensure we + // dont miss some nodes by returning it _after_ we have + // checked everything. + started := time.Now() + + expired := make([]*tailcfg.PeerChange, 0) + + nodes, err := ListNodes(tx) + if err != nil { + return time.Unix(0, 0), types.StateUpdate{}, false + } + for _, node := range nodes { + if node.IsExpired() && node.Expiry.After(lastCheck) { + expired = append(expired, &tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: node.Expiry, + }) +<<<<<<< HEAD +======= + + now := time.Now() + // Do not use setNodeExpiry as that has a notifier hook, which + // can cause a deadlock, we are updating all changed nodes later + // and there is no point in notifying twice. + if err := tx.Model(&nodes[index]).Updates(types.Node{ + Expiry: &now, + }).Error; err != nil { + log.Error(). + Err(err). + Str("node", node.Hostname). + Str("name", node.GivenName). + Msg("🤮 Cannot expire node") + } else { + log.Info(). + Str("node", node.Hostname). + Str("name", node.GivenName). + Msg("Node successfully expired") + } +>>>>>>> cde0b83 (Fix typos) + } + } + + if len(expired) > 0 { + return started, types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: expired, + }, true + } + + return started, types.StateUpdate{}, false +} diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index fa18765345..190e7a57b2 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -312,6 +312,7 @@ func (s *Suite) TestExpireNode(c *check.C) { c.Assert(nodeFromDB.IsExpired(), check.Equals, true) } +func (s *Suite) TestSerdeAddressStringSlice(c *check.C) { func (s *Suite) TestGenerateGivenName(c *check.C) { user1, err := db.CreateUser("user-1") c.Assert(err, check.IsNil) @@ -393,7 +394,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) - // assign duplicat tags, expect no errors but no doubles in DB + // assign duplicate tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} err = db.SetTags(node.ID, eTags) c.Assert(err, check.IsNil) diff --git a/hscontrol/db/node_test.go.orig b/hscontrol/db/node_test.go.orig new file mode 100644 index 0000000000..d64ee0429d --- /dev/null +++ b/hscontrol/db/node_test.go.orig @@ -0,0 +1,625 @@ +package db + +import ( + "fmt" + "net/netip" + "regexp" + "strconv" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/puzpuzpuz/xsync/v3" + "gopkg.in/check.v1" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func (s *Suite) TestGetNode(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.getNode("test", "testnode") + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + pakID := uint(pak.ID) + + node := &types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + } + trx := db.DB.Save(node) + c.Assert(trx.Error, check.IsNil) + + _, err = db.getNode("test", "testnode") + c.Assert(err, check.IsNil) +} + +func (s *Suite) TestGetNodeByID(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetNodeByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + pakID := uint(pak.ID) + node := types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + } + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + + _, err = db.GetNodeByID(0) + c.Assert(err, check.IsNil) +} + +func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetNodeByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + oldNodeKey := key.NewNode() + + machineKey := key.NewMachine() + + pakID := uint(pak.ID) + node := types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + } + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + + _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) + c.Assert(err, check.IsNil) +} + +func (s *Suite) TestHardDeleteNode(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + node := types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode3", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + + _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) + c.Assert(err, check.IsNil) + + _, err = db.getNode(user.Name, "testnode3") + c.Assert(err, check.NotNil) +} + +func (s *Suite) TestListPeers(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.GetNodeByID(0) + c.Assert(err, check.NotNil) + + pakID := uint(pak.ID) + for index := 0; index <= 10; index++ { + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + node := types.Node{ + ID: types.NodeID(index), + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode" + strconv.Itoa(index), + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + } + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + } + + node0ByID, err := db.GetNodeByID(0) + c.Assert(err, check.IsNil) + + peersOfNode0, err := db.ListPeers(node0ByID.ID) + c.Assert(err, check.IsNil) + + c.Assert(len(peersOfNode0), check.Equals, 9) + c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2") + c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7") + c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10") +} + +func (s *Suite) TestGetACLFilteredPeers(c *check.C) { + type base struct { + user *types.User + key *types.PreAuthKey + } + + stor := make([]base, 0) + + for _, name := range []string{"test", "admin"} { + user, err := db.CreateUser(name) + c.Assert(err, check.IsNil) + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + stor = append(stor, base{user, pak}) + } + + _, err := db.GetNodeByID(0) + c.Assert(err, check.NotNil) + + for index := 0; index <= 10; index++ { + nodeKey := key.NewNode() + machineKey := key.NewMachine() + pakID := uint(stor[index%2].key.ID) + + v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) + node := types.Node{ + ID: types.NodeID(index), + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + IPv4: &v4, + Hostname: "testnode" + strconv.Itoa(index), + UserID: stor[index%2].user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + } + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + } + + aclPolicy := &policy.ACLPolicy{ + Groups: map[string][]string{ + "group:test": {"admin"}, + }, + Hosts: map[string]netip.Prefix{}, + TagOwners: map[string][]string{}, + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"admin"}, + Destinations: []string{"*:*"}, + }, + { + Action: "accept", + Sources: []string{"test"}, + Destinations: []string{"test:*"}, + }, + }, + Tests: []policy.ACLTest{}, + } + + adminNode, err := db.GetNodeByID(1) + c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User) + c.Assert(err, check.IsNil) + + testNode, err := db.GetNodeByID(2) + c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User) + c.Assert(err, check.IsNil) + + adminPeers, err := db.ListPeers(adminNode.ID) + c.Assert(err, check.IsNil) + + testPeers, err := db.ListPeers(testNode.ID) + c.Assert(err, check.IsNil) + + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) + c.Assert(err, check.IsNil) + + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) + c.Assert(err, check.IsNil) + + peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) + peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules) + + c.Log(peersOfTestNode) + c.Assert(len(peersOfTestNode), check.Equals, 9) + c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1") + c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3") + c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5") + + c.Log(peersOfAdminNode) + c.Assert(len(peersOfAdminNode), check.Equals, 9) + c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2") + c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4") + c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7") +} + +func (s *Suite) TestExpireNode(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.getNode("test", "testnode") + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + pakID := uint(pak.ID) + + node := &types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + Expiry: &time.Time{}, + } + db.DB.Save(node) + + nodeFromDB, err := db.getNode("test", "testnode") + c.Assert(err, check.IsNil) + c.Assert(nodeFromDB, check.NotNil) + + c.Assert(nodeFromDB.IsExpired(), check.Equals, false) + + now := time.Now() + err = db.NodeSetExpiry(nodeFromDB.ID, now) + c.Assert(err, check.IsNil) + + nodeFromDB, err = db.getNode("test", "testnode") + c.Assert(err, check.IsNil) + + c.Assert(nodeFromDB.IsExpired(), check.Equals, true) +} + +<<<<<<< HEAD +======= +func (s *Suite) TestSerdeAddressStringSlice(c *check.C) { + input := types.NodeAddresses([]netip.Addr{ + netip.MustParseAddr("192.0.2.1"), + netip.MustParseAddr("2001:db8::1"), + }) + serialized, err := input.Value() + c.Assert(err, check.IsNil) + if serial, ok := serialized.(string); ok { + c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") + } + + var deserialized types.NodeAddresses + err = deserialized.Scan(serialized) + c.Assert(err, check.IsNil) + + c.Assert(len(deserialized), check.Equals, len(input)) + for i := range deserialized { + c.Assert(deserialized[i], check.Equals, input[i]) + } +} + +>>>>>>> cde0b83 (Fix typos) +func (s *Suite) TestGenerateGivenName(c *check.C) { + user1, err := db.CreateUser("user-1") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.getNode("user-1", "testnode") + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + machineKey2 := key.NewMachine() + + pakID := uint(pak.ID) + node := &types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "hostname-1", + GivenName: "hostname-1", + UserID: user1.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + } + + trx := db.DB.Save(node) + c.Assert(trx.Error, check.IsNil) + + givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") + comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Equals, "hostname-2", comment) + + givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1") + comment = check.Commentf("Same user, same node, same hostname, no conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Equals, "hostname-1", comment) + + givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1") + comment = check.Commentf("Same user, unique nodes, same hostname, conflict") + c.Assert(err, check.IsNil, comment) + c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment) +} + +func (s *Suite) TestSetTags(c *check.C) { + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = db.getNode("test", "testnode") + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + pakID := uint(pak.ID) + node := &types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "testnode", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + } + + trx := db.DB.Save(node) + c.Assert(trx.Error, check.IsNil) + + // assign simple tags + sTags := []string{"tag:test", "tag:foo"} + err = db.SetTags(node.ID, sTags) + c.Assert(err, check.IsNil) + node, err = db.getNode("test", "testnode") + c.Assert(err, check.IsNil) + c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) + + // assign duplicate tags, expect no errors but no doubles in DB + eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} + err = db.SetTags(node.ID, eTags) + c.Assert(err, check.IsNil) + node, err = db.getNode("test", "testnode") + c.Assert(err, check.IsNil) + c.Assert( + node.ForcedTags, + check.DeepEquals, + types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), + ) + + // test removing tags + err = db.SetTags(node.ID, []string{}) + c.Assert(err, check.IsNil) + node, err = db.getNode("test", "testnode") + c.Assert(err, check.IsNil) + c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{})) +} + +func TestHeadscale_generateGivenName(t *testing.T) { + type args struct { + suppliedName string + randomSuffix bool + } + tests := []struct { + name string + args args + want *regexp.Regexp + wantErr bool + }{ + { + name: "simple node name generation", + args: args{ + suppliedName: "testnode", + randomSuffix: false, + }, + want: regexp.MustCompile("^testnode$"), + wantErr: false, + }, + { + name: "node name with 53 chars", + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", + randomSuffix: false, + }, + want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), + wantErr: false, + }, + { + name: "node name with 63 chars", + args: args{ + suppliedName: "nodeeeeeee12345678901234567890123456789012345678901234567890123", + randomSuffix: false, + }, + want: regexp.MustCompile("^nodeeeeeee12345678901234567890123456789012345678901234567890123$"), + wantErr: false, + }, + { + name: "node name with 64 chars", + args: args{ + suppliedName: "nodeeeeeee123456789012345678901234567890123456789012345678901234", + randomSuffix: false, + }, + want: nil, + wantErr: true, + }, + { + name: "node name with 73 chars", + args: args{ + suppliedName: "nodeeeeeee123456789012345678901234567890123456789012345678901234567890123", + randomSuffix: false, + }, + want: nil, + wantErr: true, + }, + { + name: "node name with random suffix", + args: args{ + suppliedName: "test", + randomSuffix: true, + }, + want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", NodeGivenNameHashLength)), + wantErr: false, + }, + { + name: "node name with 63 chars with random suffix", + args: args{ + suppliedName: "nodeeee12345678901234567890123456789012345678901234567890123", + randomSuffix: true, + }, + want: regexp.MustCompile(fmt.Sprintf("^nodeeee1234567890123456789012345678901234567890123456-[a-z0-9]{%d}$", NodeGivenNameHashLength)), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + if (err != nil) != tt.wantErr { + t.Errorf( + "Headscale.GenerateGivenName() error = %v, wantErr %v", + err, + tt.wantErr, + ) + + return + } + + if tt.want != nil && !tt.want.MatchString(got) { + t.Errorf( + "Headscale.GenerateGivenName() = %v, does not match %v", + tt.want, + got, + ) + } + + if len(got) > util.LabelHostnameLength { + t.Errorf( + "Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", + got, + util.LabelHostnameLength, + ) + } + }) + } +} + +func (s *Suite) TestAutoApproveRoutes(c *check.C) { + acl := []byte(` +{ + "tagOwners": { + "tag:exit": ["test"], + }, + + "groups": { + "group:test": ["test"] + }, + + "acls": [ + {"action": "accept", "users": ["*"], "ports": ["*:*"]}, + ], + + "autoApprovers": { + "exitNode": ["tag:exit"], + "routes": { + "10.10.0.0/16": ["group:test"], + "10.11.0.0/16": ["test"], + } + } +} + `) + + pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") + c.Assert(err, check.IsNil) + c.Assert(pol, check.NotNil) + + user, err := db.CreateUser("test") + c.Assert(err, check.IsNil) + + pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + nodeKey := key.NewNode() + machineKey := key.NewMachine() + + defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0") + defaultRouteV6 := netip.MustParsePrefix("::/0") + route1 := netip.MustParsePrefix("10.10.0.0/16") + // Check if a subprefix of an autoapproved route is approved + route2 := netip.MustParsePrefix("10.11.0.0/24") + + v4 := netip.MustParseAddr("100.64.0.1") + pakID := uint(pak.ID) + node := types.Node{ + ID: 0, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "test", + UserID: user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: &pakID, + Hostinfo: &tailcfg.Hostinfo{ + RequestTags: []string{"tag:exit"}, + RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, + }, + IPv4: &v4, + } + + trx := db.DB.Save(&node) + c.Assert(trx.Error, check.IsNil) + + sendUpdate, err := db.SaveNodeRoutes(&node) + c.Assert(err, check.IsNil) + c.Assert(sendUpdate, check.Equals, false) + + node0ByID, err := db.GetNodeByID(0) + c.Assert(err, check.IsNil) + + // TODO(kradalby): Check state update + err = db.EnableAutoApprovedRoutes(pol, node0ByID) + c.Assert(err, check.IsNil) + + enabledRoutes, err := db.GetEnabledRoutes(node0ByID) + c.Assert(err, check.IsNil) + c.Assert(enabledRoutes, check.HasLen, 4) +} diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 16a8689f7b..adfd289a49 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -83,7 +83,7 @@ func CreatePreAuthKey( if !seenTags[tag] { if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { return nil, fmt.Errorf( - "failed to ceate key tag in the database: %w", + "failed to create key tag in the database: %w", err, ) } diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 52a63e9fd3..0b0c9b16ca 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -204,7 +204,7 @@ func DERPProbeHandler( } } -// DERPBootstrapDNSHandler implements the /bootsrap-dns endpoint +// DERPBootstrapDNSHandler implements the /bootstrap-dns endpoint // Described in https://github.com/tailscale/tailscale/issues/1405, // this endpoint provides a way to help a client when it fails to start up // because its DNS are broken. diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index dd4d95bb36..b0cafe105b 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -532,7 +532,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { "example-host-2:80" ], "deny": [ - "exapmle-host-2:100" + "example-host-2:100" ], }, { @@ -635,7 +635,7 @@ func Test_expandGroup(t *testing.T) { wantErr: false, }, { - name: "InexistantGroup", + name: "InexistentGroup", field: field{ pol: ACLPolicy{ Groups: Groups{ @@ -2604,7 +2604,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "all hosts can talk to each other", args: args{ - nodes: types.Nodes{ // list of all nodess in the database + nodes: types.Nodes{ // list of all nodes in the database &types.Node{ ID: 1, IPv4: iap("100.64.0.1"), @@ -2651,7 +2651,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "One host can talk to another, but not all hosts", args: args{ - nodes: types.Nodes{ // list of all nodess in the database + nodes: types.Nodes{ // list of all nodes in the database &types.Node{ ID: 1, IPv4: iap("100.64.0.1"), @@ -2693,7 +2693,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "host cannot directly talk to destination, but return path is authorized", args: args{ - nodes: types.Nodes{ // list of all nodess in the database + nodes: types.Nodes{ // list of all nodes in the database &types.Node{ ID: 1, IPv4: iap("100.64.0.1"), @@ -2735,7 +2735,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "rules allows all hosts to reach one destination", args: args{ - nodes: types.Nodes{ // list of all nodess in the database + nodes: types.Nodes{ // list of all nodes in the database &types.Node{ ID: 1, IPv4: iap("100.64.0.1"), @@ -2777,7 +2777,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "rules allows all hosts to reach one destination, destination can reach all hosts", args: args{ - nodes: types.Nodes{ // list of all nodess in the database + nodes: types.Nodes{ // list of all nodes in the database &types.Node{ ID: 1, IPv4: iap("100.64.0.1"), @@ -2824,7 +2824,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "rule allows all hosts to reach all destinations", args: args{ - nodes: types.Nodes{ // list of all nodess in the database + nodes: types.Nodes{ // list of all nodes in the database &types.Node{ ID: 1, IPv4: iap("100.64.0.1"), @@ -2871,7 +2871,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { { name: "without rule all communications are forbidden", args: args{ - nodes: types.Nodes{ // list of all nodess in the database + nodes: types.Nodes{ // list of all nodes in the database &types.Node{ ID: 1, IPv4: iap("100.64.0.1"), diff --git a/hscontrol/poll.go b/hscontrol/poll.go index e3137cc6ad..3e66a36832 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -192,6 +192,7 @@ func (m *mapSession) serve() { // start-up before their first real endpoint update. if m.isReadOnlyUpdate() { m.handleReadOnlyRequest() + // update ACLRules with peer information (to update server tags if necessary) return } diff --git a/hscontrol/poll.go.orig b/hscontrol/poll.go.orig new file mode 100644 index 0000000000..c4e279563b --- /dev/null +++ b/hscontrol/poll.go.orig @@ -0,0 +1,818 @@ +package hscontrol + +import ( + "cmp" + "context" + "fmt" + "math/rand/v2" + "net/http" + "net/netip" + "sort" + "strings" + "sync" + "time" + + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/mapper" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" + xslices "golang.org/x/exp/slices" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +const ( + keepAliveInterval = 50 * time.Second +) + +type contextKey string + +const nodeNameContextKey = contextKey("nodeName") + +type sessionManager struct { + mu sync.RWMutex + sess map[types.NodeID]*mapSession +} + +type mapSession struct { + h *Headscale + req tailcfg.MapRequest + ctx context.Context + capVer tailcfg.CapabilityVersion + mapper *mapper.Mapper + + serving bool + servingMu sync.Mutex + + ch chan types.StateUpdate + cancelCh chan struct{} + + keepAliveTicker *time.Ticker + + node *types.Node + w http.ResponseWriter + + warnf func(string, ...any) + infof func(string, ...any) + tracef func(string, ...any) + errf func(error, string, ...any) +} + +func (h *Headscale) newMapSession( + ctx context.Context, + req tailcfg.MapRequest, + w http.ResponseWriter, + node *types.Node, +) *mapSession { + warnf, infof, tracef, errf := logPollFunc(req, node) + + var updateChan chan types.StateUpdate + if req.Stream { + // Use a buffered channel in case a node is not fully ready + // to receive a message to make sure we dont block the entire + // notifier. + updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize) + updateChan <- types.StateUpdate{ + Type: types.StateFullUpdate, + } + } + + return &mapSession{ + h: h, + ctx: ctx, + req: req, + w: w, + node: node, + capVer: req.Version, + mapper: h.mapper, + + // serving indicates if a client is being served. + serving: false, + + ch: updateChan, + cancelCh: make(chan struct{}), + + keepAliveTicker: time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)), + + // Loggers + warnf: warnf, + infof: infof, + tracef: tracef, + errf: errf, + } +} + +func (m *mapSession) close() { + m.servingMu.Lock() + defer m.servingMu.Unlock() + if !m.serving { + return + } + + m.tracef("mapSession (%p) sending message on cancel chan") + m.cancelCh <- struct{}{} + m.tracef("mapSession (%p) sent message on cancel chan") +} + +func (m *mapSession) isStreaming() bool { + return m.req.Stream && !m.req.ReadOnly +} + +func (m *mapSession) isEndpointUpdate() bool { + return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers +} + +func (m *mapSession) isReadOnlyUpdate() bool { + return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly +} + +// handlePoll ensures the node gets the appropriate updates from either +// polling or immediate responses. +// +//nolint:gocyclo +func (m *mapSession) serve() { + // Register with the notifier if this is a streaming + // session + if m.isStreaming() { + // defers are called in reverse order, + // so top one is executed last. + + // Failover the node's routes if any. + defer m.infof("node has disconnected, mapSession: %p", m) + defer m.pollFailoverRoutes("node closing connection", m.node) + + defer m.h.updateNodeOnlineStatus(false, m.node) + defer m.h.nodeNotifier.RemoveNode(m.node.ID) + + defer func() { + m.servingMu.Lock() + defer m.servingMu.Unlock() + + m.serving = false + close(m.cancelCh) + }() + + m.serving = true + + m.h.nodeNotifier.AddNode(m.node.ID, m.ch) + m.h.updateNodeOnlineStatus(true, m.node) + + m.infof("node has connected, mapSession: %p", m) + } + + // TODO(kradalby): A set todos to harden: + // - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true + + // This is the mechanism where the node gives us information about its + // current configuration. + // + // If OmitPeers is true, Stream is false, and ReadOnly is false, + // then te server will let clients update their endpoints without + // breaking existing long-polling (Stream == true) connections. + // In this case, the server can omit the entire response; the client + // only checks the HTTP response status code. + // + // This is what Tailscale calls a Lite update, the client ignores + // the response and just wants a 200. + // !req.stream && !req.ReadOnly && req.OmitPeers + // + // TODO(kradalby): remove ReadOnly when we only support capVer 68+ + if m.isEndpointUpdate() { + m.handleEndpointUpdate() + + return + } + + // ReadOnly is whether the client just wants to fetch the + // MapResponse, without updating their Endpoints. The + // Endpoints field will be ignored and LastSeen will not be + // updated and peers will not be notified of changes. + // + // The intended use is for clients to discover the DERP map at + // start-up before their first real endpoint update. + if m.isReadOnlyUpdate() { + m.handleReadOnlyRequest() + + return + } + + // From version 68, all streaming requests can be treated as read only. + if m.capVer < 68 { + // Error has been handled/written to client in the func + // return + err := m.handleSaveNode() + if err != nil { + mapResponseWriteUpdatesInStream.WithLabelValues("error").Inc() + return + } + mapResponseWriteUpdatesInStream.WithLabelValues("ok").Inc() + } + + // Set up the client stream + m.h.pollNetMapStreamWG.Add(1) + defer m.h.pollNetMapStreamWG.Done() + + m.pollFailoverRoutes("node connected", m.node) + + // Upgrade the writer to a ResponseController + rc := http.NewResponseController(m.w) + + // Longpolling will break if there is a write timeout, + // so it needs to be disabled. + rc.SetWriteDeadline(time.Time{}) + +<<<<<<< HEAD + ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) +======= + return + } + + isConnected := h.nodeNotifier.ConnectedMap() + for _, peer := range peers { + online := isConnected[peer.MachineKey] + peer.IsOnline = &online + } + + mapp := mapper.NewMapper( + node, + peers, + h.DERPMap, + h.cfg.BaseDomain, + h.cfg.DNSConfig, + h.cfg.LogTail.Enabled, + h.cfg.RandomizeClientPort, + ) + + // update ACLRules with peer information (to update server tags if necessary) + if h.ACLPolicy != nil { + // update routes with peer information + // This state update is ignored as it will be sent + // as part of the whole node + // TODO(kradalby): figure out if that is actually correct + _, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) + if err != nil { + logErr(err, "Error running auto approved routes") + } + } + + logTrace("Sending initial map") + + mapResp, err := mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) + if err != nil { + logErr(err, "Failed to create MapResponse") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + + // Send the client an update to make sure we send an initial mapresponse + _, err = writer.Write(mapResp) + if err != nil { + logErr(err, "Could not write the map response") + + return + } + + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } else { + return + } + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from handlePoll -> new node added", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname) + h.nodeNotifier.NotifyWithIgnore( + ctx, + stateUpdate, + node.MachineKey.String()) + } + + if len(node.Routes) > 0 { + go h.pollFailoverRoutes(logErr, "new node", node) + } + + keepAliveTicker := time.NewTicker(keepAliveInterval) + + ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname)) +>>>>>>> cde0b83 (Fix typos) + defer cancel() + + // Loop through updates and continuously send them to the + // client. + for { + // consume channels with update, keep alives or "batch" blocking signals + select { + case <-m.cancelCh: + m.tracef("poll cancelled received") + return + case <-ctx.Done(): + m.tracef("poll context done") + return + + // Consume all updates sent to node + case update := <-m.ch: + m.tracef("received stream update: %s %s", update.Type.String(), update.Message) + mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc() + + var data []byte + var err error + var lastMessage string + + // Ensure the node object is updated, for example, there + // might have been a hostinfo update in a sidechannel + // which contains data needed to generate a map response. + m.node, err = m.h.db.GetNodeByID(m.node.ID) + if err != nil { + m.errf(err, "Could not get machine from db") + + return + } + + updateType := "full" + switch update.Type { + case types.StateFullUpdate: + m.tracef("Sending Full MapResponse") + data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) + case types.StatePeerChanged: + changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) + + for _, nodeID := range update.ChangeNodes { + changed[nodeID] = true + } + + lastMessage = update.Message + m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + updateType = "change" + + case types.StatePeerChangedPatch: + m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) + data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) + updateType = "patch" + case types.StatePeerRemoved: + changed := make(map[types.NodeID]bool, len(update.Removed)) + + for _, nodeID := range update.Removed { + changed[nodeID] = false + } + m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + updateType = "remove" + case types.StateSelfUpdate: + lastMessage = update.Message + m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) + // create the map so an empty (self) update is sent + data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) + updateType = "remove" + case types.StateDERPUpdated: + m.tracef("Sending DERPUpdate MapResponse") + data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap) + updateType = "derp" + } + + if err != nil { + m.errf(err, "Could not get the create map update") + + return + } + + // log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", m.node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response") + + // Only send update if there is change + if data != nil { + startWrite := time.Now() + _, err = m.w.Write(data) + if err != nil { + mapResponseSent.WithLabelValues("error", updateType).Inc() + m.errf(err, "Could not write the map response, for mapSession: %p", m) + return + } + + err = rc.Flush() + if err != nil { + mapResponseSent.WithLabelValues("error", updateType).Inc() + m.errf(err, "flushing the map response to client, for mapSession: %p", m) + return + } + + log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") + + mapResponseSent.WithLabelValues("ok", updateType).Inc() + m.tracef("update sent") + } + + case <-m.keepAliveTicker.C: + data, err := m.mapper.KeepAliveResponse(m.req, m.node) + if err != nil { + m.errf(err, "Error generating the keep alive msg") + mapResponseSent.WithLabelValues("error", "keepalive").Inc() + return + } + _, err = m.w.Write(data) + if err != nil { + m.errf(err, "Cannot write keep alive message") + mapResponseSent.WithLabelValues("error", "keepalive").Inc() + return + } + err = rc.Flush() + if err != nil { + m.errf(err, "flushing keep alive to client, for mapSession: %p", m) + mapResponseSent.WithLabelValues("error", "keepalive").Inc() + return + } + + mapResponseSent.WithLabelValues("ok", "keepalive").Inc() + } + } +} + +func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) { + update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node) + }) + if err != nil { + m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) + + return + } + + if update != nil && !update.Empty() { + ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname) + m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID) + } +} + +// updateNodeOnlineStatus records the last seen status of a node and notifies peers +// about change in their online/offline status. +// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. +func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { + change := &tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + Online: &online, + } + + if !online { + now := time.Now() + + // lastSeen is only relevant if the node is disconnected. + node.LastSeen = &now + change.LastSeen = &now + + err := h.db.Write(func(tx *gorm.DB) error { + return db.SetLastSeen(tx, node.ID, *node.LastSeen) + }) + if err != nil { + log.Error().Err(err).Msg("Cannot update node LastSeen") + + return + } + } + + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname) + h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + change, + }, + }, node.ID) +} + +func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, node, name string) { + log.Trace(). + Str("handler", "PollNetMap"). + Str("node", node). + Str("channel", "Done"). + Msg(fmt.Sprintf("Closing %s channel", name)) + + close(channel) +} + +func (m *mapSession) handleEndpointUpdate() { + m.tracef("received endpoint update") + + change := m.node.PeerChangeFromMapRequest(m.req) + + online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID) + change.Online = &online + + m.node.ApplyPeerChange(&change) + + sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo) + m.node.Hostinfo = m.req.Hostinfo + + logTracePeerChange(m.node.Hostname, sendUpdate, &change) + + // If there is no changes and nothing to save, + // return early. + if peerChangeEmpty(change) && !sendUpdate { + mapResponseEndpointUpdates.WithLabelValues("noop").Inc() + return + } + + // Check if the Hostinfo of the node has changed. + // If it has changed, check if there has been a change to + // the routable IPs of the host and update update them in + // the database. Then send a Changed update + // (containing the whole node object) to peers to inform about + // the route change. + // If the hostinfo has changed, but not the routes, just update + // hostinfo and let the function continue. + if routesChanged { + var err error + _, err = m.h.db.SaveNodeRoutes(m.node) + if err != nil { + m.errf(err, "Error processing node routes") + http.Error(m.w, "", http.StatusInternalServerError) + mapResponseEndpointUpdates.WithLabelValues("error").Inc() + + return + } + + if m.h.ACLPolicy != nil { + // update routes with peer information + err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) + if err != nil { + m.errf(err, "Error running auto approved routes") + mapResponseEndpointUpdates.WithLabelValues("error").Inc() + } + } + + // Send an update to the node itself with to ensure it + // has an updated packetfilter allowing the new route + // if it is defined in the ACL. + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) + m.h.nodeNotifier.NotifyByNodeID( + ctx, + types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: []types.NodeID{m.node.ID}, + }, + m.node.ID) + } + + if err := m.h.db.DB.Save(m.node).Error; err != nil { + m.errf(err, "Failed to persist/update node in the database") + http.Error(m.w, "", http.StatusInternalServerError) + mapResponseEndpointUpdates.WithLabelValues("error").Inc() + + return + } + + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname) + m.h.nodeNotifier.NotifyWithIgnore( + ctx, + types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{m.node.ID}, + Message: "called from handlePoll -> update", + }, + m.node.ID) + + m.w.WriteHeader(http.StatusOK) + mapResponseEndpointUpdates.WithLabelValues("ok").Inc() + + return +} + +// handleSaveNode saves node updates in the maprequest _streaming_ +// path and is mostly the same code as in handleEndpointUpdate. +// It is not attempted to be deduplicated since it will go away +// when we stop supporting older than 68 which removes updates +// when the node is streaming. +func (m *mapSession) handleSaveNode() error { + m.tracef("saving node update from stream session") + + change := m.node.PeerChangeFromMapRequest(m.req) + + // A stream is being set up, the node is Online + online := true + change.Online = &online + + m.node.ApplyPeerChange(&change) + + sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo) + m.node.Hostinfo = m.req.Hostinfo + + // If there is no changes and nothing to save, + // return early. + if peerChangeEmpty(change) || !sendUpdate { + return nil + } + + // Check if the Hostinfo of the node has changed. + // If it has changed, check if there has been a change to + // the routable IPs of the host and update update them in + // the database. Then send a Changed update + // (containing the whole node object) to peers to inform about + // the route change. + // If the hostinfo has changed, but not the routes, just update + // hostinfo and let the function continue. + if routesChanged { + var err error + _, err = m.h.db.SaveNodeRoutes(m.node) + if err != nil { + return err + } + + if m.h.ACLPolicy != nil { + // update routes with peer information + err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) + if err != nil { + return err + } + } + } + + if err := m.h.db.DB.Save(m.node).Error; err != nil { + return err + } + + ctx := types.NotifyCtx(context.Background(), "pre-68-update-while-stream", m.node.Hostname) + m.h.nodeNotifier.NotifyWithIgnore( + ctx, + types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{m.node.ID}, + Message: "called from handlePoll -> pre-68-update-while-stream", + }, + m.node.ID) + + return nil +} + +func (m *mapSession) handleReadOnlyRequest() { + m.tracef("Client asked for a lite update, responding without peers") + + mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy) + if err != nil { + m.errf(err, "Failed to create MapResponse") + http.Error(m.w, "", http.StatusInternalServerError) + mapResponseReadOnly.WithLabelValues("error").Inc() + return + } + + m.w.Header().Set("Content-Type", "application/json; charset=utf-8") + m.w.WriteHeader(http.StatusOK) + _, err = m.w.Write(mapResp) + if err != nil { + m.errf(err, "Failed to write response") + mapResponseReadOnly.WithLabelValues("error").Inc() + return + } + + m.w.WriteHeader(http.StatusOK) + mapResponseReadOnly.WithLabelValues("ok").Inc() + + return +} + +func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) { + trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname) + + if change.Key != nil { + trace = trace.Str("node_key", change.Key.ShortString()) + } + + if change.DiscoKey != nil { + trace = trace.Str("disco_key", change.DiscoKey.ShortString()) + } + + if change.Online != nil { + trace = trace.Bool("online", *change.Online) + } + + if change.Endpoints != nil { + eps := make([]string, len(change.Endpoints)) + for idx, ep := range change.Endpoints { + eps[idx] = ep.String() + } + + trace = trace.Strs("endpoints", eps) + } + + if hostinfoChange { + trace = trace.Bool("hostinfo_changed", hostinfoChange) + } + + if change.DERPRegion != 0 { + trace = trace.Int("derp_region", change.DERPRegion) + } + + trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received") +} + +func peerChangeEmpty(chng tailcfg.PeerChange) bool { + return chng.Key == nil && + chng.DiscoKey == nil && + chng.Online == nil && + chng.Endpoints == nil && + chng.DERPRegion == 0 && + chng.LastSeen == nil && + chng.KeyExpiry == nil +} + +func logPollFunc( + mapRequest tailcfg.MapRequest, + node *types.Node, +) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) { + return func(msg string, a ...any) { + log.Warn(). + Caller(). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Uint64("node.id", node.ID.Uint64()). + Str("node", node.Hostname). + Msgf(msg, a...) + }, + func(msg string, a ...any) { + log.Info(). + Caller(). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Uint64("node.id", node.ID.Uint64()). + Str("node", node.Hostname). + Msgf(msg, a...) + }, + func(msg string, a ...any) { + log.Trace(). + Caller(). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Uint64("node.id", node.ID.Uint64()). + Str("node", node.Hostname). + Msgf(msg, a...) + }, + func(err error, msg string, a ...any) { + log.Error(). + Caller(). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Uint64("node.id", node.ID.Uint64()). + Str("node", node.Hostname). + Err(err). + Msgf(msg, a...) + } +} + +// hostInfoChanged reports if hostInfo has changed in two ways, +// - first bool reports if an update needs to be sent to nodes +// - second reports if there has been changes to routes +// the caller can then use this info to save and update nodes +// and routes as needed. +func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) { + if old.Equal(new) { + return false, false + } + + // Routes + oldRoutes := old.RoutableIPs + newRoutes := new.RoutableIPs + + sort.Slice(oldRoutes, func(i, j int) bool { + return comparePrefix(oldRoutes[i], oldRoutes[j]) > 0 + }) + sort.Slice(newRoutes, func(i, j int) bool { + return comparePrefix(newRoutes[i], newRoutes[j]) > 0 + }) + + if !xslices.Equal(oldRoutes, newRoutes) { + return true, true + } + + // Services is mostly useful for discovery and not critical, + // except for peerapi, which is how nodes talk to eachother. + // If peerapi was not part of the initial mapresponse, we + // need to make sure its sent out later as it is needed for + // Taildrop. + // TODO(kradalby): Length comparison is a bit naive, replace. + if len(old.Services) != len(new.Services) { + return true, false + } + + return false, false +} + +// TODO(kradalby): Remove after go 1.23, will be in stdlib. +// Compare returns an integer comparing two prefixes. +// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. +// Prefixes sort first by validity (invalid before valid), then +// address family (IPv4 before IPv6), then prefix length, then +// address. +func comparePrefix(p, p2 netip.Prefix) int { + if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 { + return c + } + if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 { + return c + } + return p.Addr().Compare(p2.Addr()) +} diff --git a/integration/general_test.go b/integration/general_test.go index 89e0d34238..db9bf83b7c 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -335,14 +335,14 @@ func TestTaildrop(t *testing.T) { IntegrationSkip(t) t.Parallel() - retry := func(times int, sleepInverval time.Duration, doWork func() error) error { + retry := func(times int, sleepInterval time.Duration, doWork func() error) error { var err error for attempts := 0; attempts < times; attempts++ { err = doWork() if err == nil { return nil } - time.Sleep(sleepInverval) + time.Sleep(sleepInterval) } return err @@ -793,7 +793,7 @@ func TestNodeOnlineStatus(t *testing.T) { continue } - // All peers of this nodess are reporting to be + // All peers of this nodes are reporting to be // connected to the control server assert.Truef( t, diff --git a/integration/scenario.go b/integration/scenario.go index 9444d88286..3f0eb7d277 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -450,7 +450,7 @@ func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int) error { return nil } -// CreateHeadscaleEnv is a conventient method returning a complete Headcale +// CreateHeadscaleEnv is a convenient method returning a complete Headcale // test environment with nodes of all versions, joined to the server with X // users. func (s *Scenario) CreateHeadscaleEnv( diff --git a/integration/utils.go b/integration/utils.go index 1e2cfd2cd3..840dbc4c9b 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -331,7 +331,7 @@ func dockertestMaxWait() time.Duration { // return timeout // } -// pingAllNegativeHelper is intended to have 1 or more nodes timeing out from the ping, +// pingAllNegativeHelper is intended to have 1 or more nodes timing out from the ping, // it counts failures instead of successes. // func pingAllNegativeHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { // t.Helper() From ab551cdd6f5803a9970cc941fde71b8fb4d8baee Mon Sep 17 00:00:00 2001 From: ohdearaugustin Date: Sun, 12 May 2024 18:29:55 +0200 Subject: [PATCH 2/7] trigger GitHub actions From 76242411df5f19c05142aab7a425180a1f9e3a3f Mon Sep 17 00:00:00 2001 From: ohdearaugustin Date: Sat, 18 May 2024 12:13:22 +0200 Subject: [PATCH 3/7] remove kdiff3 orig files --- docs/web-ui.md.orig | 23 - flake.nix.orig | 178 ------- hscontrol/db/node.go.orig | 772 ------------------------------- hscontrol/db/node_test.go.orig | 625 ------------------------- hscontrol/poll.go.orig | 818 --------------------------------- 5 files changed, 2416 deletions(-) delete mode 100644 docs/web-ui.md.orig delete mode 100644 flake.nix.orig delete mode 100644 hscontrol/db/node.go.orig delete mode 100644 hscontrol/db/node_test.go.orig delete mode 100644 hscontrol/poll.go.orig diff --git a/docs/web-ui.md.orig b/docs/web-ui.md.orig deleted file mode 100644 index 3175057c1f..0000000000 --- a/docs/web-ui.md.orig +++ /dev/null @@ -1,23 +0,0 @@ -# Headscale web interface - -!!! warning "Community contributions" - - This page contains community contributions. The projects listed here are not - maintained by the Headscale authors and are written by community members. - -<<<<<<< HEAD -| Name | Repository Link | Description | Status | -| --------------- | ------------------------------------------------------- | --------------------------------------------------------------------------- | ------ | -| headscale-webui | [Github](https://github.com/ifargle/headscale-webui) | A simple Headscale web UI for small-scale deployments. | Alpha | -| headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server | Alpha | -| HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend enviroment required | Alpha | -| headscale-admin | [Github](https://github.com/GoodiesHQ/headscale-admin) | Headscale-Admin is meant to be a simple, modern web interface for Headscale | Beta | -======= -| Name | Repository Link | Description | Status | -| --------------- | ------------------------------------------------------- | ------------------------------------------------------------------------- | ------ | -| headscale-webui | [Github](https://github.com/ifargle/headscale-webui) | A simple Headscale web UI for small-scale deployments. | Alpha | -| headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server | Alpha | -| HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend environment required | Alpha | ->>>>>>> cde0b83 (Fix typos) - -You can ask for support on our dedicated [Discord channel](https://discord.com/channels/896711691637780480/1105842846386356294). diff --git a/flake.nix.orig b/flake.nix.orig deleted file mode 100644 index ab17ebed55..0000000000 --- a/flake.nix.orig +++ /dev/null @@ -1,178 +0,0 @@ -{ - description = "headscale - Open Source Tailscale Control server"; - - inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; - flake-utils.url = "github:numtide/flake-utils"; - }; - - outputs = { - self, - nixpkgs, - flake-utils, - ... - }: let - headscaleVersion = - if (self ? shortRev) - then self.shortRev - else "dev"; - in - { - overlay = _: prev: let - pkgs = nixpkgs.legacyPackages.${prev.system}; - in rec { - headscale = pkgs.buildGo122Module rec { - pname = "headscale"; - version = headscaleVersion; - src = pkgs.lib.cleanSource self; - - # Only run unit tests when testing a build - checkFlags = ["-short"]; - - # When updating go.mod or go.sum, a new sha will need to be calculated, -<<<<<<< HEAD - # update this if you have a mismatch after doing a change to thos files. - vendorHash = "sha256-wXfKeiJaGe6ahOsONrQhvbuMN8flQ13b0ZjxdbFs1e8="; -======= - # update this if you have a mismatch after doing a change to those files. - vendorHash = "sha256-Yb5WaN0abPLZ4mPnuJGZoj6EMfoZjaZZ0f344KWva3o="; ->>>>>>> cde0b83 (Fix typos) - - subPackages = ["cmd/headscale"]; - - ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; - }; - - protoc-gen-grpc-gateway = pkgs.buildGoModule rec { - pname = "grpc-gateway"; - version = "2.19.1"; - - src = pkgs.fetchFromGitHub { - owner = "grpc-ecosystem"; - repo = "grpc-gateway"; - rev = "v${version}"; - sha256 = "sha256-CdGQpQfOSimeio8v1lZ7xzE/oAS2qFyu+uN+H9i7vpo="; - }; - - vendorHash = "sha256-no7kZGpf/VOuceC3J+izGFQp5aMS3b+Rn+x4BFZ2zgs="; - - nativeBuildInputs = [pkgs.installShellFiles]; - - subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"]; - }; - }; - } - // flake-utils.lib.eachDefaultSystem - (system: let - pkgs = import nixpkgs { - overlays = [self.overlay]; - inherit system; - }; - buildDeps = with pkgs; [git go_1_22 gnumake]; - devDeps = with pkgs; - buildDeps - ++ [ - golangci-lint - golines - nodePackages.prettier - goreleaser - nfpm - gotestsum - gotests - ksh - ko - yq-go - ripgrep - - # 'dot' is needed for pprof graphs - # go tool pprof -http=: - graphviz - - # Protobuf dependencies - protobuf - protoc-gen-go - protoc-gen-go-grpc - protoc-gen-grpc-gateway - buf - clang-tools # clang-format - ]; - - # Add entry to build a docker image with headscale - # caveat: only works on Linux - # - # Usage: - # nix build .#headscale-docker - # docker load < result - headscale-docker = pkgs.dockerTools.buildLayeredImage { - name = "headscale"; - tag = headscaleVersion; - contents = [pkgs.headscale]; - config.Entrypoint = [(pkgs.headscale + "/bin/headscale")]; - }; - in rec { - # `nix develop` - devShell = pkgs.mkShell { - buildInputs = - devDeps - ++ [ - (pkgs.writeShellScriptBin - "nix-vendor-sri" - '' - set -eu - - OUT=$(mktemp -d -t nar-hash-XXXXXX) - rm -rf "$OUT" - - go mod vendor -o "$OUT" - go run tailscale.com/cmd/nardump --sri "$OUT" - rm -rf "$OUT" - '') - - (pkgs.writeShellScriptBin - "go-mod-update-all" - '' - cat go.mod | ${pkgs.silver-searcher}/bin/ag "\t" | ${pkgs.silver-searcher}/bin/ag -v indirect | ${pkgs.gawk}/bin/awk '{print $1}' | ${pkgs.findutils}/bin/xargs go get -u - go mod tidy - '') - ]; - - shellHook = '' - export PATH="$PWD/result/bin:$PATH" - ''; - }; - - # `nix build` - packages = with pkgs; { - inherit headscale; - inherit headscale-docker; - }; - defaultPackage = pkgs.headscale; - - # `nix run` - apps.headscale = flake-utils.lib.mkApp { - drv = packages.headscale; - }; - apps.default = apps.headscale; - - checks = { - format = - pkgs.runCommand "check-format" - { - buildInputs = with pkgs; [ - gnumake - nixpkgs-fmt - golangci-lint - nodePackages.prettier - golines - clang-tools - ]; - } '' - ${pkgs.nixpkgs-fmt}/bin/nixpkgs-fmt ${./.} - ${pkgs.golangci-lint}/bin/golangci-lint run --fix --timeout 10m - ${pkgs.nodePackages.prettier}/bin/prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' - ${pkgs.golines}/bin/golines --max-len=88 --base-formatter=gofumpt -w ${./.} - ${pkgs.clang-tools}/bin/clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i ${./.} - ''; - }; - }); -} diff --git a/hscontrol/db/node.go.orig b/hscontrol/db/node.go.orig deleted file mode 100644 index 14bb3481ee..0000000000 --- a/hscontrol/db/node.go.orig +++ /dev/null @@ -1,772 +0,0 @@ -package db - -import ( - "errors" - "fmt" - "net/netip" - "sort" - "time" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/patrickmn/go-cache" - "github.com/puzpuzpuz/xsync/v3" - "github.com/rs/zerolog/log" - "gorm.io/gorm" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -const ( - NodeGivenNameHashLength = 8 - NodeGivenNameTrimSize = 2 -) - -var ( - ErrNodeNotFound = errors.New("node not found") - ErrNodeRouteIsNotAvailable = errors.New("route is not available on node") - ErrNodeNotFoundRegistrationCache = errors.New( - "node not found in registration cache", - ) - ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") - ErrDifferentRegisteredUser = errors.New( - "node was previously registered with a different user", - ) -) - -func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListPeers(rx, nodeID) - }) -} - -// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. -func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) { - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("id <> ?", - nodeID).Find(&nodes).Error; err != nil { - return types.Nodes{}, err - } - - sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) - - return nodes, nil -} - -func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { - return ListNodes(rx) - }) -} - -func ListNodes(tx *gorm.DB) (types.Nodes, error) { - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Find(&nodes).Error; err != nil { - return nil, err - } - - return nodes, nil -} - -func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).Find(&nodes).Error; err != nil { - return nil, err - } - - return nodes, nil -} - -func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return getNode(rx, user, name) - }) -} - -// getNode finds a Node by name and user and returns the Node struct. -func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { - nodes, err := ListNodesByUser(tx, user) - if err != nil { - return nil, err - } - - for _, m := range nodes { - if m.Hostname == name { - return m, nil - } - } - - return nil, ErrNodeNotFound -} - -func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return GetNodeByID(rx, id) - }) -} - -// GetNodeByID finds a Node by ID and returns the Node struct. -func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) { - mach := types.Node{} - if result := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Find(&types.Node{ID: id}).First(&mach); result.Error != nil { - return nil, result.Error - } - - return &mach, nil -} - -func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return GetNodeByMachineKey(rx, machineKey) - }) -} - -// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. -func GetNodeByMachineKey( - tx *gorm.DB, - machineKey key.MachinePublic, -) (*types.Node, error) { - mach := types.Node{} - if result := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil { - return nil, result.Error - } - - return &mach, nil -} - -func (hsdb *HSDatabase) GetNodeByAnyKey( - machineKey key.MachinePublic, - nodeKey key.NodePublic, - oldNodeKey key.NodePublic, -) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey) - }) -} - -// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. -// TODO(kradalby): see if we can remove this. -func GetNodeByAnyKey( - tx *gorm.DB, - machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, -) (*types.Node, error) { - node := types.Node{} - if result := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - First(&node, "machine_key = ? OR node_key = ? OR node_key = ?", - machineKey.String(), - nodeKey.String(), - oldNodeKey.String()); result.Error != nil { - return nil, result.Error - } - - return &node, nil -} - -func (hsdb *HSDatabase) SetTags( - nodeID types.NodeID, - tags []string, -) error { - return hsdb.Write(func(tx *gorm.DB) error { - return SetTags(tx, nodeID, tags) - }) -} - -// SetTags takes a Node struct pointer and update the forced tags. -func SetTags( - tx *gorm.DB, - nodeID types.NodeID, - tags []string, -) error { - if len(tags) == 0 { - // if no tags are provided, we remove all forced tags - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil { - return fmt.Errorf("failed to remove tags for node in the database: %w", err) - } - - return nil - } - - newTags := types.StringList{} - for _, tag := range tags { - if !util.StringOrPrefixListContains(newTags, tag) { - newTags = append(newTags, tag) - } - } - - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil { - return fmt.Errorf("failed to update tags for node in the database: %w", err) - } - - return nil -} - -// RenameNode takes a Node struct and a new GivenName for the nodes -// and renames it. -func RenameNode(tx *gorm.DB, - nodeID uint64, newName string, -) error { - err := util.CheckForFQDNRules( - newName, - ) - if err != nil { - return fmt.Errorf("renaming node: %w", err) - } - - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { - return fmt.Errorf("failed to rename node in the database: %w", err) - } - - return nil -} - -func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error { - return hsdb.Write(func(tx *gorm.DB) error { - return NodeSetExpiry(tx, nodeID, expiry) - }) -} - -// NodeSetExpiry takes a Node struct and a new expiry time. -func NodeSetExpiry(tx *gorm.DB, - nodeID types.NodeID, expiry time.Time, -) error { - return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error -} - -func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) { - return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) { - return DeleteNode(tx, node, isLikelyConnected) - }) -} - -// DeleteNode deletes a Node from the database. -// Caller is responsible for notifying all of change. -func DeleteNode(tx *gorm.DB, - node *types.Node, - isLikelyConnected *xsync.MapOf[types.NodeID, bool], -) ([]types.NodeID, error) { - changed, err := deleteNodeRoutes(tx, node, isLikelyConnected) - if err != nil { - return changed, err - } - - // Unscoped causes the node to be fully removed from the database. - if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil { - return changed, err - } - - return changed, nil -} - -// SetLastSeen sets a node's last seen field indicating that we -// have recently communicating with this node. -func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { - return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error -} - -func RegisterNodeFromAuthCallback( - tx *gorm.DB, - cache *cache.Cache, - mkey key.MachinePublic, - userName string, - nodeExpiry *time.Time, - registrationMethod string, - ipv4 *netip.Addr, - ipv6 *netip.Addr, -) (*types.Node, error) { - log.Debug(). - Str("machine_key", mkey.ShortString()). - Str("userName", userName). - Str("registrationMethod", registrationMethod). - Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). - Msg("Registering node from API/CLI or auth callback") - - if nodeInterface, ok := cache.Get(mkey.String()); ok { - if registrationNode, ok := nodeInterface.(types.Node); ok { - user, err := GetUser(tx, userName) - if err != nil { - return nil, fmt.Errorf( - "failed to find user in register node from auth callback, %w", - err, - ) - } - - // Registration of expired node with different user - if registrationNode.ID != 0 && - registrationNode.UserID != user.ID { - return nil, ErrDifferentRegisteredUser - } - - registrationNode.UserID = user.ID - registrationNode.User = *user - registrationNode.RegisterMethod = registrationMethod - - if nodeExpiry != nil { - registrationNode.Expiry = nodeExpiry - } - - node, err := RegisterNode( - tx, - registrationNode, - ipv4, ipv6, - ) - - if err == nil { - cache.Delete(mkey.String()) - } - - return node, err - } else { - return nil, ErrCouldNotConvertNodeInterface - } - } - - return nil, ErrNodeNotFoundRegistrationCache -} - -func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { - return RegisterNode(tx, node, ipv4, ipv6) - }) -} - -// RegisterNode is executed from the CLI to register a new Node using its MachineKey. -func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { - log.Debug(). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Str("node_key", node.NodeKey.ShortString()). - Str("user", node.User.Name). - Msg("Registering node") - - // If the node exists and it already has IP(s), we just save it - // so we store the node.Expire and node.Nodekey that has been set when - // adding it to the registrationCache - if node.IPv4 != nil || node.IPv6 != nil { - if err := tx.Save(&node).Error; err != nil { - return nil, fmt.Errorf("failed register existing node in the database: %w", err) - } - - log.Trace(). - Caller(). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Str("node_key", node.NodeKey.ShortString()). - Str("user", node.User.Name). - Msg("Node authorized again") - - return &node, nil - } - - node.IPv4 = ipv4 - node.IPv6 = ipv6 - - if err := tx.Save(&node).Error; err != nil { - return nil, fmt.Errorf("failed register(save) node in the database: %w", err) - } - - log.Trace(). - Caller(). - Str("node", node.Hostname). - Msg("Node registered with the database") - - return &node, nil -} - -// NodeSetNodeKey sets the node key of a node and saves it to the database. -func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error { - return tx.Model(node).Updates(types.Node{ - NodeKey: nodeKey, - }).Error -} - -func (hsdb *HSDatabase) NodeSetMachineKey( - node *types.Node, - machineKey key.MachinePublic, -) error { - return hsdb.Write(func(tx *gorm.DB) error { - return NodeSetMachineKey(tx, node, machineKey) - }) -} - -// NodeSetMachineKey sets the node key of a node and saves it to the database. -func NodeSetMachineKey( - tx *gorm.DB, - node *types.Node, - machineKey key.MachinePublic, -) error { - return tx.Model(node).Updates(types.Node{ - MachineKey: machineKey, - }).Error -} - -// NodeSave saves a node object to the database, prefer to use a specific save method rather -// than this. It is intended to be used when we are changing or. -// TODO(kradalby): Remove this func, just use Save. -func NodeSave(tx *gorm.DB, node *types.Node) error { - return tx.Save(node).Error -} - -func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { - return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { - return GetAdvertisedRoutes(rx, node) - }) -} - -// GetAdvertisedRoutes returns the routes that are be advertised by the given node. -func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { - routes := types.Routes{} - - err := tx. - Preload("Node"). - Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("getting advertised routes for node(%d): %w", node.ID, err) - } - - prefixes := []netip.Prefix{} - for _, route := range routes { - prefixes = append(prefixes, netip.Prefix(route.Prefix)) - } - - return prefixes, nil -} - -func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { - return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { - return GetEnabledRoutes(rx, node) - }) -} - -// GetEnabledRoutes returns the routes that are enabled for the node. -func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { - routes := types.Routes{} - - err := tx. - Preload("Node"). - Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true). - Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("getting enabled routes for node(%d): %w", node.ID, err) - } - - prefixes := []netip.Prefix{} - for _, route := range routes { - prefixes = append(prefixes, netip.Prefix(route.Prefix)) - } - - return prefixes, nil -} - -func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool { - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return false - } - - enabledRoutes, err := GetEnabledRoutes(tx, node) - if err != nil { - return false - } - - for _, enabledRoute := range enabledRoutes { - if route == enabledRoute { - return true - } - } - - return false -} - -func (hsdb *HSDatabase) enableRoutes( - node *types.Node, - routeStrs ...string, -) (*types.StateUpdate, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { - return enableRoutes(tx, node, routeStrs...) - }) -} - -// enableRoutes enables new routes based on a list of new routes. -func enableRoutes(tx *gorm.DB, - node *types.Node, routeStrs ...string, -) (*types.StateUpdate, error) { - newRoutes := make([]netip.Prefix, len(routeStrs)) - for index, routeStr := range routeStrs { - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return nil, err - } - - newRoutes[index] = route - } - - advertisedRoutes, err := GetAdvertisedRoutes(tx, node) - if err != nil { - return nil, err - } - - for _, newRoute := range newRoutes { - if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { - return nil, fmt.Errorf( - "route (%s) is not available on node %s: %w", - node.Hostname, - newRoute, ErrNodeRouteIsNotAvailable, - ) - } - } - - // Separate loop so we don't leave things in a half-updated state - for _, prefix := range newRoutes { - route := types.Route{} - err := tx.Preload("Node"). - Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). - First(&route).Error - if err == nil { - route.Enabled = true - - // Mark already as primary if there is only this node offering this subnet - // (and is not an exit route) - if !route.IsExitRoute() { - route.IsPrimary = isUniquePrefix(tx, route) - } - - err = tx.Save(&route).Error - if err != nil { - return nil, fmt.Errorf("failed to enable route: %w", err) - } - } else { - return nil, fmt.Errorf("failed to find route: %w", err) - } - } - - // Ensure the node has the latest routes when notifying the other - // nodes - nRoutes, err := GetNodeRoutes(tx, node) - if err != nil { - return nil, fmt.Errorf("failed to read back routes: %w", err) - } - - node.Routes = nRoutes - - log.Trace(). - Caller(). - Str("node", node.Hostname). - Strs("routes", routeStrs). - Msg("enabling routes") - - return &types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{node.ID}, - Message: "created in db.enableRoutes", - }, nil -} - -func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { - normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( - suppliedName, - ) - if err != nil { - return "", err - } - - if randomSuffix { - // Trim if a hostname will be longer than 63 chars after adding the hash. - trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize - if len(normalizedHostname) > trimmedHostnameLength { - normalizedHostname = normalizedHostname[:trimmedHostnameLength] - } - - suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength) - if err != nil { - return "", err - } - - normalizedHostname += "-" + suffix - } - - return normalizedHostname, nil -} - -func (hsdb *HSDatabase) GenerateGivenName( - mkey key.MachinePublic, - suppliedName string, -) (string, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { - return GenerateGivenName(rx, mkey, suppliedName) - }) -} - -func GenerateGivenName( - tx *gorm.DB, - mkey key.MachinePublic, - suppliedName string, -) (string, error) { - givenName, err := generateGivenName(suppliedName, false) - if err != nil { - return "", err - } - - // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - nodes, err := listNodesByGivenName(tx, givenName) - if err != nil { - return "", err - } - - var nodeFound *types.Node - for idx, node := range nodes { - if node.GivenName == givenName { - nodeFound = nodes[idx] - } - } - - if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() { - postfixedName, err := generateGivenName(suppliedName, true) - if err != nil { - return "", err - } - - givenName = postfixedName - } - - return givenName, nil -} - -<<<<<<< HEAD -func DeleteExpiredEphemeralNodes(tx *gorm.DB, - inactivityThreshhold time.Duration, -) ([]types.NodeID, []types.NodeID) { -======= -func ExpireEphemeralNodes(tx *gorm.DB, - inactivityThreshold time.Duration, -) (types.StateUpdate, bool) { ->>>>>>> cde0b83 (Fix typos) - users, err := ListUsers(tx) - if err != nil { - return nil, nil - } - - var expired []types.NodeID - var changedNodes []types.NodeID - for _, user := range users { - nodes, err := ListNodesByUser(tx, user.Name) - if err != nil { - return nil, nil - } - - for idx, node := range nodes { - if node.IsEphemeral() && node.LastSeen != nil && - time.Now(). -<<<<<<< HEAD - After(node.LastSeen.Add(inactivityThreshhold)) { - expired = append(expired, node.ID) -======= - After(node.LastSeen.Add(inactivityThreshold)) { - expired = append(expired, tailcfg.NodeID(node.ID)) ->>>>>>> cde0b83 (Fix typos) - - log.Info(). - Str("node", node.Hostname). - Msg("Ephemeral client removed from database") - - // empty isConnected map as ephemeral nodes are not routes - changed, err := DeleteNode(tx, nodes[idx], nil) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Msg("🤮 Cannot delete ephemeral node from the database") - } - - changedNodes = append(changedNodes, changed...) - } - } - - // TODO(kradalby): needs to be moved out of transaction - } - - return expired, changedNodes -} - -func ExpireExpiredNodes(tx *gorm.DB, - lastCheck time.Time, -) (time.Time, types.StateUpdate, bool) { - // use the time of the start of the function to ensure we - // dont miss some nodes by returning it _after_ we have - // checked everything. - started := time.Now() - - expired := make([]*tailcfg.PeerChange, 0) - - nodes, err := ListNodes(tx) - if err != nil { - return time.Unix(0, 0), types.StateUpdate{}, false - } - for _, node := range nodes { - if node.IsExpired() && node.Expiry.After(lastCheck) { - expired = append(expired, &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: node.Expiry, - }) -<<<<<<< HEAD -======= - - now := time.Now() - // Do not use setNodeExpiry as that has a notifier hook, which - // can cause a deadlock, we are updating all changed nodes later - // and there is no point in notifying twice. - if err := tx.Model(&nodes[index]).Updates(types.Node{ - Expiry: &now, - }).Error; err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Str("name", node.GivenName). - Msg("🤮 Cannot expire node") - } else { - log.Info(). - Str("node", node.Hostname). - Str("name", node.GivenName). - Msg("Node successfully expired") - } ->>>>>>> cde0b83 (Fix typos) - } - } - - if len(expired) > 0 { - return started, types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: expired, - }, true - } - - return started, types.StateUpdate{}, false -} diff --git a/hscontrol/db/node_test.go.orig b/hscontrol/db/node_test.go.orig deleted file mode 100644 index d64ee0429d..0000000000 --- a/hscontrol/db/node_test.go.orig +++ /dev/null @@ -1,625 +0,0 @@ -package db - -import ( - "fmt" - "net/netip" - "regexp" - "strconv" - "testing" - "time" - - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/puzpuzpuz/xsync/v3" - "gopkg.in/check.v1" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func (s *Suite) TestGetNode(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.getNode("test", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - pakID := uint(pak.ID) - - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - } - trx := db.DB.Save(node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.getNode("test", "testnode") - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetNodeByID(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - pakID := uint(pak.ID) - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - oldNodeKey := key.NewNode() - - machineKey := key.NewMachine() - - pakID := uint(pak.ID) - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestHardDeleteNode(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode3", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) - c.Assert(err, check.IsNil) - - _, err = db.getNode(user.Name, "testnode3") - c.Assert(err, check.NotNil) -} - -func (s *Suite) TestListPeers(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - pakID := uint(pak.ID) - for index := 0; index <= 10; index++ { - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: types.NodeID(index), - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode" + strconv.Itoa(index), - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - } - - node0ByID, err := db.GetNodeByID(0) - c.Assert(err, check.IsNil) - - peersOfNode0, err := db.ListPeers(node0ByID.ID) - c.Assert(err, check.IsNil) - - c.Assert(len(peersOfNode0), check.Equals, 9) - c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2") - c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7") - c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10") -} - -func (s *Suite) TestGetACLFilteredPeers(c *check.C) { - type base struct { - user *types.User - key *types.PreAuthKey - } - - stor := make([]base, 0) - - for _, name := range []string{"test", "admin"} { - user, err := db.CreateUser(name) - c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - stor = append(stor, base{user, pak}) - } - - _, err := db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - nodeKey := key.NewNode() - machineKey := key.NewMachine() - pakID := uint(stor[index%2].key.ID) - - v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) - node := types.Node{ - ID: types.NodeID(index), - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - IPv4: &v4, - Hostname: "testnode" + strconv.Itoa(index), - UserID: stor[index%2].user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - } - - aclPolicy := &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:test": {"admin"}, - }, - Hosts: map[string]netip.Prefix{}, - TagOwners: map[string][]string{}, - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"admin"}, - Destinations: []string{"*:*"}, - }, - { - Action: "accept", - Sources: []string{"test"}, - Destinations: []string{"test:*"}, - }, - }, - Tests: []policy.ACLTest{}, - } - - adminNode, err := db.GetNodeByID(1) - c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User) - c.Assert(err, check.IsNil) - - testNode, err := db.GetNodeByID(2) - c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User) - c.Assert(err, check.IsNil) - - adminPeers, err := db.ListPeers(adminNode.ID) - c.Assert(err, check.IsNil) - - testPeers, err := db.ListPeers(testNode.ID) - c.Assert(err, check.IsNil) - - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) - c.Assert(err, check.IsNil) - - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) - c.Assert(err, check.IsNil) - - peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) - peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules) - - c.Log(peersOfTestNode) - c.Assert(len(peersOfTestNode), check.Equals, 9) - c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1") - c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3") - c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5") - - c.Log(peersOfAdminNode) - c.Assert(len(peersOfAdminNode), check.Equals, 9) - c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2") - c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4") - c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7") -} - -func (s *Suite) TestExpireNode(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.getNode("test", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - pakID := uint(pak.ID) - - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - Expiry: &time.Time{}, - } - db.DB.Save(node) - - nodeFromDB, err := db.getNode("test", "testnode") - c.Assert(err, check.IsNil) - c.Assert(nodeFromDB, check.NotNil) - - c.Assert(nodeFromDB.IsExpired(), check.Equals, false) - - now := time.Now() - err = db.NodeSetExpiry(nodeFromDB.ID, now) - c.Assert(err, check.IsNil) - - nodeFromDB, err = db.getNode("test", "testnode") - c.Assert(err, check.IsNil) - - c.Assert(nodeFromDB.IsExpired(), check.Equals, true) -} - -<<<<<<< HEAD -======= -func (s *Suite) TestSerdeAddressStringSlice(c *check.C) { - input := types.NodeAddresses([]netip.Addr{ - netip.MustParseAddr("192.0.2.1"), - netip.MustParseAddr("2001:db8::1"), - }) - serialized, err := input.Value() - c.Assert(err, check.IsNil) - if serial, ok := serialized.(string); ok { - c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") - } - - var deserialized types.NodeAddresses - err = deserialized.Scan(serialized) - c.Assert(err, check.IsNil) - - c.Assert(len(deserialized), check.Equals, len(input)) - for i := range deserialized { - c.Assert(deserialized[i], check.Equals, input[i]) - } -} - ->>>>>>> cde0b83 (Fix typos) -func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := db.CreateUser("user-1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.getNode("user-1", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - machineKey2 := key.NewMachine() - - pakID := uint(pak.ID) - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "hostname-1", - GivenName: "hostname-1", - UserID: user1.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - } - - trx := db.DB.Save(node) - c.Assert(trx.Error, check.IsNil) - - givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") - comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-2", comment) - - givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1") - comment = check.Commentf("Same user, same node, same hostname, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-1", comment) - - givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1") - comment = check.Commentf("Same user, unique nodes, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment) -} - -func (s *Suite) TestSetTags(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.getNode("test", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - pakID := uint(pak.ID) - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - } - - trx := db.DB.Save(node) - c.Assert(trx.Error, check.IsNil) - - // assign simple tags - sTags := []string{"tag:test", "tag:foo"} - err = db.SetTags(node.ID, sTags) - c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") - c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) - - // assign duplicate tags, expect no errors but no doubles in DB - eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = db.SetTags(node.ID, eTags) - c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") - c.Assert(err, check.IsNil) - c.Assert( - node.ForcedTags, - check.DeepEquals, - types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), - ) - - // test removing tags - err = db.SetTags(node.ID, []string{}) - c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") - c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{})) -} - -func TestHeadscale_generateGivenName(t *testing.T) { - type args struct { - suppliedName string - randomSuffix bool - } - tests := []struct { - name string - args args - want *regexp.Regexp - wantErr bool - }{ - { - name: "simple node name generation", - args: args{ - suppliedName: "testnode", - randomSuffix: false, - }, - want: regexp.MustCompile("^testnode$"), - wantErr: false, - }, - { - name: "node name with 53 chars", - args: args{ - suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", - randomSuffix: false, - }, - want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"), - wantErr: false, - }, - { - name: "node name with 63 chars", - args: args{ - suppliedName: "nodeeeeeee12345678901234567890123456789012345678901234567890123", - randomSuffix: false, - }, - want: regexp.MustCompile("^nodeeeeeee12345678901234567890123456789012345678901234567890123$"), - wantErr: false, - }, - { - name: "node name with 64 chars", - args: args{ - suppliedName: "nodeeeeeee123456789012345678901234567890123456789012345678901234", - randomSuffix: false, - }, - want: nil, - wantErr: true, - }, - { - name: "node name with 73 chars", - args: args{ - suppliedName: "nodeeeeeee123456789012345678901234567890123456789012345678901234567890123", - randomSuffix: false, - }, - want: nil, - wantErr: true, - }, - { - name: "node name with random suffix", - args: args{ - suppliedName: "test", - randomSuffix: true, - }, - want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", NodeGivenNameHashLength)), - wantErr: false, - }, - { - name: "node name with 63 chars with random suffix", - args: args{ - suppliedName: "nodeeee12345678901234567890123456789012345678901234567890123", - randomSuffix: true, - }, - want: regexp.MustCompile(fmt.Sprintf("^nodeeee1234567890123456789012345678901234567890123456-[a-z0-9]{%d}$", NodeGivenNameHashLength)), - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) - if (err != nil) != tt.wantErr { - t.Errorf( - "Headscale.GenerateGivenName() error = %v, wantErr %v", - err, - tt.wantErr, - ) - - return - } - - if tt.want != nil && !tt.want.MatchString(got) { - t.Errorf( - "Headscale.GenerateGivenName() = %v, does not match %v", - tt.want, - got, - ) - } - - if len(got) > util.LabelHostnameLength { - t.Errorf( - "Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", - got, - util.LabelHostnameLength, - ) - } - }) - } -} - -func (s *Suite) TestAutoApproveRoutes(c *check.C) { - acl := []byte(` -{ - "tagOwners": { - "tag:exit": ["test"], - }, - - "groups": { - "group:test": ["test"] - }, - - "acls": [ - {"action": "accept", "users": ["*"], "ports": ["*:*"]}, - ], - - "autoApprovers": { - "exitNode": ["tag:exit"], - "routes": { - "10.10.0.0/16": ["group:test"], - "10.11.0.0/16": ["test"], - } - } -} - `) - - pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - c.Assert(pol, check.NotNil) - - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0") - defaultRouteV6 := netip.MustParsePrefix("::/0") - route1 := netip.MustParsePrefix("10.10.0.0/16") - // Check if a subprefix of an autoapproved route is approved - route2 := netip.MustParsePrefix("10.11.0.0/24") - - v4 := netip.MustParseAddr("100.64.0.1") - pakID := uint(pak.ID) - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "test", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: &pakID, - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:exit"}, - RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, - }, - IPv4: &v4, - } - - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - sendUpdate, err := db.SaveNodeRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(sendUpdate, check.Equals, false) - - node0ByID, err := db.GetNodeByID(0) - c.Assert(err, check.IsNil) - - // TODO(kradalby): Check state update - err = db.EnableAutoApprovedRoutes(pol, node0ByID) - c.Assert(err, check.IsNil) - - enabledRoutes, err := db.GetEnabledRoutes(node0ByID) - c.Assert(err, check.IsNil) - c.Assert(enabledRoutes, check.HasLen, 4) -} diff --git a/hscontrol/poll.go.orig b/hscontrol/poll.go.orig deleted file mode 100644 index c4e279563b..0000000000 --- a/hscontrol/poll.go.orig +++ /dev/null @@ -1,818 +0,0 @@ -package hscontrol - -import ( - "cmp" - "context" - "fmt" - "math/rand/v2" - "net/http" - "net/netip" - "sort" - "strings" - "sync" - "time" - - "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/mapper" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/rs/zerolog/log" - xslices "golang.org/x/exp/slices" - "gorm.io/gorm" - "tailscale.com/tailcfg" -) - -const ( - keepAliveInterval = 50 * time.Second -) - -type contextKey string - -const nodeNameContextKey = contextKey("nodeName") - -type sessionManager struct { - mu sync.RWMutex - sess map[types.NodeID]*mapSession -} - -type mapSession struct { - h *Headscale - req tailcfg.MapRequest - ctx context.Context - capVer tailcfg.CapabilityVersion - mapper *mapper.Mapper - - serving bool - servingMu sync.Mutex - - ch chan types.StateUpdate - cancelCh chan struct{} - - keepAliveTicker *time.Ticker - - node *types.Node - w http.ResponseWriter - - warnf func(string, ...any) - infof func(string, ...any) - tracef func(string, ...any) - errf func(error, string, ...any) -} - -func (h *Headscale) newMapSession( - ctx context.Context, - req tailcfg.MapRequest, - w http.ResponseWriter, - node *types.Node, -) *mapSession { - warnf, infof, tracef, errf := logPollFunc(req, node) - - var updateChan chan types.StateUpdate - if req.Stream { - // Use a buffered channel in case a node is not fully ready - // to receive a message to make sure we dont block the entire - // notifier. - updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize) - updateChan <- types.StateUpdate{ - Type: types.StateFullUpdate, - } - } - - return &mapSession{ - h: h, - ctx: ctx, - req: req, - w: w, - node: node, - capVer: req.Version, - mapper: h.mapper, - - // serving indicates if a client is being served. - serving: false, - - ch: updateChan, - cancelCh: make(chan struct{}), - - keepAliveTicker: time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)), - - // Loggers - warnf: warnf, - infof: infof, - tracef: tracef, - errf: errf, - } -} - -func (m *mapSession) close() { - m.servingMu.Lock() - defer m.servingMu.Unlock() - if !m.serving { - return - } - - m.tracef("mapSession (%p) sending message on cancel chan") - m.cancelCh <- struct{}{} - m.tracef("mapSession (%p) sent message on cancel chan") -} - -func (m *mapSession) isStreaming() bool { - return m.req.Stream && !m.req.ReadOnly -} - -func (m *mapSession) isEndpointUpdate() bool { - return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers -} - -func (m *mapSession) isReadOnlyUpdate() bool { - return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly -} - -// handlePoll ensures the node gets the appropriate updates from either -// polling or immediate responses. -// -//nolint:gocyclo -func (m *mapSession) serve() { - // Register with the notifier if this is a streaming - // session - if m.isStreaming() { - // defers are called in reverse order, - // so top one is executed last. - - // Failover the node's routes if any. - defer m.infof("node has disconnected, mapSession: %p", m) - defer m.pollFailoverRoutes("node closing connection", m.node) - - defer m.h.updateNodeOnlineStatus(false, m.node) - defer m.h.nodeNotifier.RemoveNode(m.node.ID) - - defer func() { - m.servingMu.Lock() - defer m.servingMu.Unlock() - - m.serving = false - close(m.cancelCh) - }() - - m.serving = true - - m.h.nodeNotifier.AddNode(m.node.ID, m.ch) - m.h.updateNodeOnlineStatus(true, m.node) - - m.infof("node has connected, mapSession: %p", m) - } - - // TODO(kradalby): A set todos to harden: - // - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true - - // This is the mechanism where the node gives us information about its - // current configuration. - // - // If OmitPeers is true, Stream is false, and ReadOnly is false, - // then te server will let clients update their endpoints without - // breaking existing long-polling (Stream == true) connections. - // In this case, the server can omit the entire response; the client - // only checks the HTTP response status code. - // - // This is what Tailscale calls a Lite update, the client ignores - // the response and just wants a 200. - // !req.stream && !req.ReadOnly && req.OmitPeers - // - // TODO(kradalby): remove ReadOnly when we only support capVer 68+ - if m.isEndpointUpdate() { - m.handleEndpointUpdate() - - return - } - - // ReadOnly is whether the client just wants to fetch the - // MapResponse, without updating their Endpoints. The - // Endpoints field will be ignored and LastSeen will not be - // updated and peers will not be notified of changes. - // - // The intended use is for clients to discover the DERP map at - // start-up before their first real endpoint update. - if m.isReadOnlyUpdate() { - m.handleReadOnlyRequest() - - return - } - - // From version 68, all streaming requests can be treated as read only. - if m.capVer < 68 { - // Error has been handled/written to client in the func - // return - err := m.handleSaveNode() - if err != nil { - mapResponseWriteUpdatesInStream.WithLabelValues("error").Inc() - return - } - mapResponseWriteUpdatesInStream.WithLabelValues("ok").Inc() - } - - // Set up the client stream - m.h.pollNetMapStreamWG.Add(1) - defer m.h.pollNetMapStreamWG.Done() - - m.pollFailoverRoutes("node connected", m.node) - - // Upgrade the writer to a ResponseController - rc := http.NewResponseController(m.w) - - // Longpolling will break if there is a write timeout, - // so it needs to be disabled. - rc.SetWriteDeadline(time.Time{}) - -<<<<<<< HEAD - ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) -======= - return - } - - isConnected := h.nodeNotifier.ConnectedMap() - for _, peer := range peers { - online := isConnected[peer.MachineKey] - peer.IsOnline = &online - } - - mapp := mapper.NewMapper( - node, - peers, - h.DERPMap, - h.cfg.BaseDomain, - h.cfg.DNSConfig, - h.cfg.LogTail.Enabled, - h.cfg.RandomizeClientPort, - ) - - // update ACLRules with peer information (to update server tags if necessary) - if h.ACLPolicy != nil { - // update routes with peer information - // This state update is ignored as it will be sent - // as part of the whole node - // TODO(kradalby): figure out if that is actually correct - _, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) - if err != nil { - logErr(err, "Error running auto approved routes") - } - } - - logTrace("Sending initial map") - - mapResp, err := mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) - if err != nil { - logErr(err, "Failed to create MapResponse") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - - // Send the client an update to make sure we send an initial mapresponse - _, err = writer.Write(mapResp) - if err != nil { - logErr(err, "Could not write the map response") - - return - } - - if flusher, ok := writer.(http.Flusher); ok { - flusher.Flush() - } else { - return - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from handlePoll -> new node added", - } - if stateUpdate.Valid() { - ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname) - h.nodeNotifier.NotifyWithIgnore( - ctx, - stateUpdate, - node.MachineKey.String()) - } - - if len(node.Routes) > 0 { - go h.pollFailoverRoutes(logErr, "new node", node) - } - - keepAliveTicker := time.NewTicker(keepAliveInterval) - - ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname)) ->>>>>>> cde0b83 (Fix typos) - defer cancel() - - // Loop through updates and continuously send them to the - // client. - for { - // consume channels with update, keep alives or "batch" blocking signals - select { - case <-m.cancelCh: - m.tracef("poll cancelled received") - return - case <-ctx.Done(): - m.tracef("poll context done") - return - - // Consume all updates sent to node - case update := <-m.ch: - m.tracef("received stream update: %s %s", update.Type.String(), update.Message) - mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc() - - var data []byte - var err error - var lastMessage string - - // Ensure the node object is updated, for example, there - // might have been a hostinfo update in a sidechannel - // which contains data needed to generate a map response. - m.node, err = m.h.db.GetNodeByID(m.node.ID) - if err != nil { - m.errf(err, "Could not get machine from db") - - return - } - - updateType := "full" - switch update.Type { - case types.StateFullUpdate: - m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) - case types.StatePeerChanged: - changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) - - for _, nodeID := range update.ChangeNodes { - changed[nodeID] = true - } - - lastMessage = update.Message - m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) - updateType = "change" - - case types.StatePeerChangedPatch: - m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) - updateType = "patch" - case types.StatePeerRemoved: - changed := make(map[types.NodeID]bool, len(update.Removed)) - - for _, nodeID := range update.Removed { - changed[nodeID] = false - } - m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) - updateType = "remove" - case types.StateSelfUpdate: - lastMessage = update.Message - m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) - updateType = "remove" - case types.StateDERPUpdated: - m.tracef("Sending DERPUpdate MapResponse") - data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap) - updateType = "derp" - } - - if err != nil { - m.errf(err, "Could not get the create map update") - - return - } - - // log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", m.node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response") - - // Only send update if there is change - if data != nil { - startWrite := time.Now() - _, err = m.w.Write(data) - if err != nil { - mapResponseSent.WithLabelValues("error", updateType).Inc() - m.errf(err, "Could not write the map response, for mapSession: %p", m) - return - } - - err = rc.Flush() - if err != nil { - mapResponseSent.WithLabelValues("error", updateType).Inc() - m.errf(err, "flushing the map response to client, for mapSession: %p", m) - return - } - - log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") - - mapResponseSent.WithLabelValues("ok", updateType).Inc() - m.tracef("update sent") - } - - case <-m.keepAliveTicker.C: - data, err := m.mapper.KeepAliveResponse(m.req, m.node) - if err != nil { - m.errf(err, "Error generating the keep alive msg") - mapResponseSent.WithLabelValues("error", "keepalive").Inc() - return - } - _, err = m.w.Write(data) - if err != nil { - m.errf(err, "Cannot write keep alive message") - mapResponseSent.WithLabelValues("error", "keepalive").Inc() - return - } - err = rc.Flush() - if err != nil { - m.errf(err, "flushing keep alive to client, for mapSession: %p", m) - mapResponseSent.WithLabelValues("error", "keepalive").Inc() - return - } - - mapResponseSent.WithLabelValues("ok", "keepalive").Inc() - } - } -} - -func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) { - update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { - return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node) - }) - if err != nil { - m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) - - return - } - - if update != nil && !update.Empty() { - ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname) - m.h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.ID) - } -} - -// updateNodeOnlineStatus records the last seen status of a node and notifies peers -// about change in their online/offline status. -// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. -func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { - change := &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - Online: &online, - } - - if !online { - now := time.Now() - - // lastSeen is only relevant if the node is disconnected. - node.LastSeen = &now - change.LastSeen = &now - - err := h.db.Write(func(tx *gorm.DB) error { - return db.SetLastSeen(tx, node.ID, *node.LastSeen) - }) - if err != nil { - log.Error().Err(err).Msg("Cannot update node LastSeen") - - return - } - } - - ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname) - h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - change, - }, - }, node.ID) -} - -func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, node, name string) { - log.Trace(). - Str("handler", "PollNetMap"). - Str("node", node). - Str("channel", "Done"). - Msg(fmt.Sprintf("Closing %s channel", name)) - - close(channel) -} - -func (m *mapSession) handleEndpointUpdate() { - m.tracef("received endpoint update") - - change := m.node.PeerChangeFromMapRequest(m.req) - - online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID) - change.Online = &online - - m.node.ApplyPeerChange(&change) - - sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo) - m.node.Hostinfo = m.req.Hostinfo - - logTracePeerChange(m.node.Hostname, sendUpdate, &change) - - // If there is no changes and nothing to save, - // return early. - if peerChangeEmpty(change) && !sendUpdate { - mapResponseEndpointUpdates.WithLabelValues("noop").Inc() - return - } - - // Check if the Hostinfo of the node has changed. - // If it has changed, check if there has been a change to - // the routable IPs of the host and update update them in - // the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the route change. - // If the hostinfo has changed, but not the routes, just update - // hostinfo and let the function continue. - if routesChanged { - var err error - _, err = m.h.db.SaveNodeRoutes(m.node) - if err != nil { - m.errf(err, "Error processing node routes") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseEndpointUpdates.WithLabelValues("error").Inc() - - return - } - - if m.h.ACLPolicy != nil { - // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) - if err != nil { - m.errf(err, "Error running auto approved routes") - mapResponseEndpointUpdates.WithLabelValues("error").Inc() - } - } - - // Send an update to the node itself with to ensure it - // has an updated packetfilter allowing the new route - // if it is defined in the ACL. - ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) - m.h.nodeNotifier.NotifyByNodeID( - ctx, - types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: []types.NodeID{m.node.ID}, - }, - m.node.ID) - } - - if err := m.h.db.DB.Save(m.node).Error; err != nil { - m.errf(err, "Failed to persist/update node in the database") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseEndpointUpdates.WithLabelValues("error").Inc() - - return - } - - ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname) - m.h.nodeNotifier.NotifyWithIgnore( - ctx, - types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{m.node.ID}, - Message: "called from handlePoll -> update", - }, - m.node.ID) - - m.w.WriteHeader(http.StatusOK) - mapResponseEndpointUpdates.WithLabelValues("ok").Inc() - - return -} - -// handleSaveNode saves node updates in the maprequest _streaming_ -// path and is mostly the same code as in handleEndpointUpdate. -// It is not attempted to be deduplicated since it will go away -// when we stop supporting older than 68 which removes updates -// when the node is streaming. -func (m *mapSession) handleSaveNode() error { - m.tracef("saving node update from stream session") - - change := m.node.PeerChangeFromMapRequest(m.req) - - // A stream is being set up, the node is Online - online := true - change.Online = &online - - m.node.ApplyPeerChange(&change) - - sendUpdate, routesChanged := hostInfoChanged(m.node.Hostinfo, m.req.Hostinfo) - m.node.Hostinfo = m.req.Hostinfo - - // If there is no changes and nothing to save, - // return early. - if peerChangeEmpty(change) || !sendUpdate { - return nil - } - - // Check if the Hostinfo of the node has changed. - // If it has changed, check if there has been a change to - // the routable IPs of the host and update update them in - // the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the route change. - // If the hostinfo has changed, but not the routes, just update - // hostinfo and let the function continue. - if routesChanged { - var err error - _, err = m.h.db.SaveNodeRoutes(m.node) - if err != nil { - return err - } - - if m.h.ACLPolicy != nil { - // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) - if err != nil { - return err - } - } - } - - if err := m.h.db.DB.Save(m.node).Error; err != nil { - return err - } - - ctx := types.NotifyCtx(context.Background(), "pre-68-update-while-stream", m.node.Hostname) - m.h.nodeNotifier.NotifyWithIgnore( - ctx, - types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: []types.NodeID{m.node.ID}, - Message: "called from handlePoll -> pre-68-update-while-stream", - }, - m.node.ID) - - return nil -} - -func (m *mapSession) handleReadOnlyRequest() { - m.tracef("Client asked for a lite update, responding without peers") - - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy) - if err != nil { - m.errf(err, "Failed to create MapResponse") - http.Error(m.w, "", http.StatusInternalServerError) - mapResponseReadOnly.WithLabelValues("error").Inc() - return - } - - m.w.Header().Set("Content-Type", "application/json; charset=utf-8") - m.w.WriteHeader(http.StatusOK) - _, err = m.w.Write(mapResp) - if err != nil { - m.errf(err, "Failed to write response") - mapResponseReadOnly.WithLabelValues("error").Inc() - return - } - - m.w.WriteHeader(http.StatusOK) - mapResponseReadOnly.WithLabelValues("ok").Inc() - - return -} - -func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) { - trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname) - - if change.Key != nil { - trace = trace.Str("node_key", change.Key.ShortString()) - } - - if change.DiscoKey != nil { - trace = trace.Str("disco_key", change.DiscoKey.ShortString()) - } - - if change.Online != nil { - trace = trace.Bool("online", *change.Online) - } - - if change.Endpoints != nil { - eps := make([]string, len(change.Endpoints)) - for idx, ep := range change.Endpoints { - eps[idx] = ep.String() - } - - trace = trace.Strs("endpoints", eps) - } - - if hostinfoChange { - trace = trace.Bool("hostinfo_changed", hostinfoChange) - } - - if change.DERPRegion != 0 { - trace = trace.Int("derp_region", change.DERPRegion) - } - - trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received") -} - -func peerChangeEmpty(chng tailcfg.PeerChange) bool { - return chng.Key == nil && - chng.DiscoKey == nil && - chng.Online == nil && - chng.Endpoints == nil && - chng.DERPRegion == 0 && - chng.LastSeen == nil && - chng.KeyExpiry == nil -} - -func logPollFunc( - mapRequest tailcfg.MapRequest, - node *types.Node, -) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) { - return func(msg string, a ...any) { - log.Warn(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Info(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). - Msgf(msg, a...) - }, - func(msg string, a ...any) { - log.Trace(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). - Msgf(msg, a...) - }, - func(err error, msg string, a ...any) { - log.Error(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Uint64("node.id", node.ID.Uint64()). - Str("node", node.Hostname). - Err(err). - Msgf(msg, a...) - } -} - -// hostInfoChanged reports if hostInfo has changed in two ways, -// - first bool reports if an update needs to be sent to nodes -// - second reports if there has been changes to routes -// the caller can then use this info to save and update nodes -// and routes as needed. -func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) { - if old.Equal(new) { - return false, false - } - - // Routes - oldRoutes := old.RoutableIPs - newRoutes := new.RoutableIPs - - sort.Slice(oldRoutes, func(i, j int) bool { - return comparePrefix(oldRoutes[i], oldRoutes[j]) > 0 - }) - sort.Slice(newRoutes, func(i, j int) bool { - return comparePrefix(newRoutes[i], newRoutes[j]) > 0 - }) - - if !xslices.Equal(oldRoutes, newRoutes) { - return true, true - } - - // Services is mostly useful for discovery and not critical, - // except for peerapi, which is how nodes talk to eachother. - // If peerapi was not part of the initial mapresponse, we - // need to make sure its sent out later as it is needed for - // Taildrop. - // TODO(kradalby): Length comparison is a bit naive, replace. - if len(old.Services) != len(new.Services) { - return true, false - } - - return false, false -} - -// TODO(kradalby): Remove after go 1.23, will be in stdlib. -// Compare returns an integer comparing two prefixes. -// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2. -// Prefixes sort first by validity (invalid before valid), then -// address family (IPv4 before IPv6), then prefix length, then -// address. -func comparePrefix(p, p2 netip.Prefix) int { - if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 { - return c - } - if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 { - return c - } - return p.Addr().Compare(p2.Addr()) -} From 973da20704b8a3831f90e6d918439da7051cd6cf Mon Sep 17 00:00:00 2001 From: ohdearaugustin Date: Sat, 18 May 2024 12:22:39 +0200 Subject: [PATCH 4/7] fix unicode --- hscontrol/db/node.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index a5a5c9a622..259a3e13fd 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -692,7 +692,7 @@ func DeleteExpiredEphemeralNodes(tx *gorm.DB, log.Error(). Err(err). Str("node", node.Hostname). - Msg("� Cannot delete ephemeral node from the database") + Msg("🤮 Cannot delete ephemeral node from the database") } changedNodes = append(changedNodes, changed...) From 02a21e4e36d1791e7b4f181d241ff935ce61fb02 Mon Sep 17 00:00:00 2001 From: ohdearaugustin Date: Sat, 18 May 2024 12:23:25 +0200 Subject: [PATCH 5/7] remove unnecessary function call --- hscontrol/db/node_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 190e7a57b2..e95ee4ae33 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -312,7 +312,6 @@ func (s *Suite) TestExpireNode(c *check.C) { c.Assert(nodeFromDB.IsExpired(), check.Equals, true) } -func (s *Suite) TestSerdeAddressStringSlice(c *check.C) { func (s *Suite) TestGenerateGivenName(c *check.C) { user1, err := db.CreateUser("user-1") c.Assert(err, check.IsNil) From f248e0e3d00f7bb16aac857a3bb626a2927b25f7 Mon Sep 17 00:00:00 2001 From: ohdearaugustin Date: Sat, 18 May 2024 12:26:10 +0200 Subject: [PATCH 6/7] remove unnecessary comment --- hscontrol/db/node.go | 1 - 1 file changed, 1 deletion(-) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 259a3e13fd..c675dc7c3c 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -725,7 +725,6 @@ func ExpireExpiredNodes(tx *gorm.DB, NodeID: tailcfg.NodeID(node.ID), KeyExpiry: node.Expiry, }) - // and there is no point in notifying twice. } } From 415b3b29cbd03cde9aff1668566226b31d5253d2 Mon Sep 17 00:00:00 2001 From: ohdearaugustin Date: Sat, 18 May 2024 12:34:06 +0200 Subject: [PATCH 7/7] remove unnecessary comment --- hscontrol/poll.go | 1 - 1 file changed, 1 deletion(-) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 3e66a36832..e3137cc6ad 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -192,7 +192,6 @@ func (m *mapSession) serve() { // start-up before their first real endpoint update. if m.isReadOnlyUpdate() { m.handleReadOnlyRequest() - // update ACLRules with peer information (to update server tags if necessary) return }