diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index 399bc95..e0d41ae 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: Install goimports diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b5b8ced..31987db 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: Build @@ -24,7 +24,7 @@ jobs: mv build/*.tar.gz release - name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: linux-latest path: release @@ -37,7 +37,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: Build @@ -55,7 +55,7 @@ jobs: mv dist\windows\wintun build\dist\windows\ - name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: windows-latest path: build @@ -64,18 +64,18 @@ jobs: name: Build Universal Darwin env: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} - runs-on: macos-11 + runs-on: macos-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v2 + uses: Apple-Actions/import-codesign-certs@v3 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} @@ -104,11 +104,57 @@ jobs: fi - name: Upload artifacts - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: darwin-latest path: ./release/* + build-docker: + name: Create and Upload Docker Images + # Technically we only need build-linux to succeed, but if any platforms fail we'll + # want to investigate and restart the build + needs: [build-linux, build-darwin, build-windows] + runs-on: ubuntu-latest + env: + HAS_DOCKER_CREDS: ${{ vars.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }} + # XXX It's not possible to write a conditional here, so instead we do it on every step + #if: ${{ env.HAS_DOCKER_CREDS == 'true' }} + steps: + # Be sure to checkout the code before downloading artifacts, or they will + # be overwritten + - name: Checkout code + if: ${{ env.HAS_DOCKER_CREDS == 'true' }} + uses: actions/checkout@v4 + + - name: Download artifacts + if: ${{ env.HAS_DOCKER_CREDS == 'true' }} + uses: actions/download-artifact@v4 + with: + name: linux-latest + path: artifacts + + - name: Login to Docker Hub + if: ${{ env.HAS_DOCKER_CREDS == 'true' }} + uses: docker/login-action@v3 + with: + username: ${{ vars.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Set up Docker Buildx + if: ${{ env.HAS_DOCKER_CREDS == 'true' }} + uses: docker/setup-buildx-action@v3 + + - name: Build and push images + if: ${{ env.HAS_DOCKER_CREDS == 'true' }} + env: + DOCKER_IMAGE_REPO: ${{ vars.DOCKER_IMAGE_REPO || 'nebulaoss/nebula' }} + DOCKER_IMAGE_TAG: ${{ vars.DOCKER_IMAGE_TAG || 'latest' }} + run: | + mkdir -p build/linux-{amd64,arm64} + tar -zxvf artifacts/nebula-linux-amd64.tar.gz -C build/linux-amd64/ + tar -zxvf artifacts/nebula-linux-arm64.tar.gz -C build/linux-arm64/ + docker buildx build . --push -f docker/Dockerfile --platform linux/amd64,linux/arm64 --tag "${DOCKER_IMAGE_REPO}:${DOCKER_IMAGE_TAG}" --tag "${DOCKER_IMAGE_REPO}:${GITHUB_REF#refs/tags/v}" + release: name: Create and Upload Release needs: [build-linux, build-darwin, build-windows] @@ -117,7 +163,7 @@ jobs: - uses: actions/checkout@v4 - name: Download artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: path: artifacts diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml new file mode 100644 index 0000000..2b5e6e9 --- /dev/null +++ b/.github/workflows/smoke-extra.yml @@ -0,0 +1,48 @@ +name: smoke-extra +on: + push: + branches: + - master + pull_request: + types: [opened, synchronize, labeled, reopened] + paths: + - '.github/workflows/smoke**' + - '**Makefile' + - '**.go' + - '**.proto' + - 'go.mod' + - 'go.sum' +jobs: + + smoke-extra: + if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') + name: Run extra smoke tests + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + check-latest: true + + - name: install vagrant + run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox + + - name: freebsd-amd64 + run: make smoke-vagrant/freebsd-amd64 + + - name: openbsd-amd64 + run: make smoke-vagrant/openbsd-amd64 + + - name: netbsd-amd64 + run: make smoke-vagrant/netbsd-amd64 + + - name: linux-386 + run: make smoke-vagrant/linux-386 + + - name: linux-amd64-ipv6disable + run: make smoke-vagrant/linux-amd64-ipv6disable + + timeout-minutes: 30 diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index f02b1ba..54833bd 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: build diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh index 9cbb200..c546653 100755 --- a/.github/workflows/smoke/build.sh +++ b/.github/workflows/smoke/build.sh @@ -11,6 +11,11 @@ mkdir ./build cp ../../../../build/linux-amd64/nebula . cp ../../../../build/linux-amd64/nebula-cert . + if [ "$1" ] + then + cp "../../../../build/$1/nebula" "$1-nebula" + fi + HOST="lighthouse1" \ AM_LIGHTHOUSE=true \ ../genconfig.sh >lighthouse1.yml diff --git a/.github/workflows/smoke/genconfig.sh b/.github/workflows/smoke/genconfig.sh index 373ea5f..16e768e 100755 --- a/.github/workflows/smoke/genconfig.sh +++ b/.github/workflows/smoke/genconfig.sh @@ -47,7 +47,7 @@ listen: port: ${LISTEN_PORT:-4242} tun: - dev: ${TUN_DEV:-nebula1} + dev: ${TUN_DEV:-tun0} firewall: inbound_action: reject diff --git a/.github/workflows/smoke/smoke-relay.sh b/.github/workflows/smoke/smoke-relay.sh index 8926091..9c113e1 100755 --- a/.github/workflows/smoke/smoke-relay.sh +++ b/.github/workflows/smoke/smoke-relay.sh @@ -76,7 +76,7 @@ docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 1 +sleep 5 if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh new file mode 100755 index 0000000..76cf72f --- /dev/null +++ b/.github/workflows/smoke/smoke-vagrant.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +set -e -x + +set -o pipefail + +export VAGRANT_CWD="$PWD/vagrant-$1" + +mkdir -p logs + +cleanup() { + echo + echo " *** cleanup" + echo + + set +e + if [ "$(jobs -r)" ] + then + docker kill lighthouse1 host2 + fi + vagrant destroy -f +} + +trap cleanup EXIT + +CONTAINER="nebula:${NAME:-smoke}" + +docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test +docker run --name host2 --rm "$CONTAINER" -config host2.yml -test + +vagrant up +vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" + +docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & +sleep 1 +docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & +sleep 1 +vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" & +sleep 15 + +# grab tcpdump pcaps for debugging +docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & +docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & +docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & +docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & +# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap & +# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap & + +docker exec host2 ncat -nklv 0.0.0.0 2000 & +vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & +#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & +#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" & + +set +x +echo +echo " *** Testing ping from lighthouse1" +echo +set -x +docker exec lighthouse1 ping -c1 192.168.100.2 +docker exec lighthouse1 ping -c1 192.168.100.3 + +set +x +echo +echo " *** Testing ping from host2" +echo +set -x +docker exec host2 ping -c1 192.168.100.1 +# Should fail because not allowed by host3 inbound firewall +! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 + +set +x +echo +echo " *** Testing ncat from host2" +echo +set -x +# Should fail because not allowed by host3 inbound firewall +#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 +#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 + +set +x +echo +echo " *** Testing ping from host3" +echo +set -x +vagrant ssh -c "ping -c1 192.168.100.1" +vagrant ssh -c "ping -c1 192.168.100.2" + +set +x +echo +echo " *** Testing ncat from host3" +echo +set -x +#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000" +#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2 + +vagrant ssh -c "sudo xargs kill &2 + exit 1 +fi diff --git a/.github/workflows/smoke/smoke.sh b/.github/workflows/smoke/smoke.sh index 3177255..6d04027 100755 --- a/.github/workflows/smoke/smoke.sh +++ b/.github/workflows/smoke/smoke.sh @@ -129,7 +129,7 @@ docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' -sleep 1 +sleep 5 if [ "$(jobs -r)" ] then diff --git a/.github/workflows/smoke/vagrant-freebsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-freebsd-amd64/Vagrantfile new file mode 100644 index 0000000..c8a4c64 --- /dev/null +++ b/.github/workflows/smoke/vagrant-freebsd-amd64/Vagrantfile @@ -0,0 +1,7 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +Vagrant.configure("2") do |config| + config.vm.box = "generic/freebsd14" + + config.vm.synced_folder "../build", "/nebula", type: "rsync" +end diff --git a/.github/workflows/smoke/vagrant-linux-386/Vagrantfile b/.github/workflows/smoke/vagrant-linux-386/Vagrantfile new file mode 100644 index 0000000..4b1d0bd --- /dev/null +++ b/.github/workflows/smoke/vagrant-linux-386/Vagrantfile @@ -0,0 +1,7 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +Vagrant.configure("2") do |config| + config.vm.box = "ubuntu/xenial32" + + config.vm.synced_folder "../build", "/nebula" +end diff --git a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile new file mode 100644 index 0000000..89f9477 --- /dev/null +++ b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile @@ -0,0 +1,16 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +Vagrant.configure("2") do |config| + config.vm.box = "ubuntu/jammy64" + + config.vm.synced_folder "../build", "/nebula" + + config.vm.provision :shell do |shell| + shell.inline = <<-EOF + sed -i 's/GRUB_CMDLINE_LINUX=""/GRUB_CMDLINE_LINUX="ipv6.disable=1"/' /etc/default/grub + update-grub + EOF + shell.privileged = true + shell.reboot = true + end +end diff --git a/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile new file mode 100644 index 0000000..14ba2ce --- /dev/null +++ b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile @@ -0,0 +1,7 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +Vagrant.configure("2") do |config| + config.vm.box = "generic/netbsd9" + + config.vm.synced_folder "../build", "/nebula", type: "rsync" +end diff --git a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile new file mode 100644 index 0000000..e4f4104 --- /dev/null +++ b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile @@ -0,0 +1,7 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +Vagrant.configure("2") do |config| + config.vm.box = "generic/openbsd7" + + config.vm.synced_folder "../build", "/nebula", type: "rsync" +end diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 34fe5f3..65a6e3e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: Build @@ -40,10 +40,10 @@ jobs: - name: Build test mobile run: make build-test-mobile - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: - name: e2e packet flow - path: e2e/mermaid/ + name: e2e packet flow linux-latest + path: e2e/mermaid/linux-latest if-no-files-found: warn test-linux-boringcrypto: @@ -55,7 +55,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: Build @@ -72,14 +72,14 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [windows-latest, macos-11] + os: [windows-latest, macos-latest] steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version-file: 'go.mod' + go-version: '1.22' check-latest: true - name: Build nebula @@ -97,8 +97,8 @@ jobs: - name: End 2 end run: make e2evv - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: - name: e2e packet flow - path: e2e/mermaid/ + name: e2e packet flow ${{ matrix.os }} + path: e2e/mermaid/${{ matrix.os }} if-no-files-found: warn diff --git a/CHANGELOG.md b/CHANGELOG.md index 71c3ed4..f763b69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,92 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.9.3] - 2024-06-06 + +### Fixed + +- Initialize messageCounter to 2 instead of verifying later. (#1156) + +## [1.9.2] - 2024-06-03 + +### Fixed + +- Ensure messageCounter is set before handshake is complete. (#1154) + +## [1.9.1] - 2024-05-29 + +### Fixed + +- Fixed a potential deadlock in GetOrHandshake. (#1151) + +## [1.9.0] - 2024-05-07 + +### Deprecated + +- This release adds a new setting `default_local_cidr_any` that defaults to + true to match previous behavior, but will default to false in the next + release (1.10). When set to false, `local_cidr` is matched correctly for + firewall rules on hosts acting as unsafe routers, and should be set for any + firewall rules you want to allow unsafe route hosts to access. See the issue + and example config for more details. (#1071, #1099) + +### Added + +- Nebula now has an official Docker image `nebulaoss/nebula` that is + distroless and contains just the `nebula` and `nebula-cert` binaries. You + can find it here: https://hub.docker.com/r/nebulaoss/nebula (#1037) + +- Experimental binaries for `loong64` are now provided. (#1003) + +- Added example service script for OpenRC. (#711) + +- The SSH daemon now supports inlined host keys. (#1054) + +- The SSH daemon now supports certificates with `sshd.trusted_cas`. (#1098) + +### Changed + +- Config setting `tun.unsafe_routes` is now reloadable. (#1083) + +- Small documentation and internal improvements. (#1065, #1067, #1069, #1108, + #1109, #1111, #1135) + +- Various dependency updates. (#1139, #1138, #1134, #1133, #1126, #1123, #1110, + #1094, #1092, #1087, #1086, #1085, #1072, #1063, #1059, #1055, #1053, #1047, + #1046, #1034, #1022) + +### Removed + +- Support for the deprecated `local_range` option has been removed. Please + change to `preferred_ranges` (which is also now reloadable). (#1043) + +- We are now building with go1.22, which means that for Windows you need at + least Windows 10 or Windows Server 2016. This is because support for earlier + versions was removed in Go 1.21. See https://go.dev/doc/go1.21#windows (#981) + +- Removed vagrant example, as it was unmaintained. (#1129) + +- Removed Fedora and Arch nebula.service files, as they are maintained in the + upstream repos. (#1128, #1132) + +- Remove the TCP round trip tracking metrics, as they never had correct data + and were an experiment to begin with. (#1114) + +### Fixed + +- Fixed a potential deadlock introduced in 1.8.1. (#1112) + +- Fixed support for Linux when IPv6 has been disabled at the OS level. (#787) + +- DNS will return NXDOMAIN now when there are no results. (#845) + +- Allow `::` in `lighthouse.dns.host`. (#1115) + +- Capitalization of `NotAfter` fixed in DNS TXT response. (#1127) + +- Don't log invalid certificates. It is untrusted data and can cause a large + volume of logs. (#1116) + ## [1.8.2] - 2024-01-08 ### Fixed @@ -558,7 +644,11 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.8.2...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.3...HEAD +[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 +[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 +[1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1 +[1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0 [1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2 [1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1 [1.8.0]: https://github.com/slackhq/nebula/releases/tag/v1.8.0 diff --git a/LOGGING.md b/LOGGING.md index bd2fdef..e2508c8 100644 --- a/LOGGING.md +++ b/LOGGING.md @@ -33,6 +33,5 @@ l.WithError(err). WithField("vpnIp", IntIp(hostinfo.hostId)). WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix"}). - WithField("cert", remoteCert). Info("Invalid certificate from host") ``` \ No newline at end of file diff --git a/Makefile b/Makefile index a02a6ec..0d0943f 100644 --- a/Makefile +++ b/Makefile @@ -1,22 +1,14 @@ -GOMINVERSION = 1.20 NEBULA_CMD_PATH = "./cmd/nebula" -GO111MODULE = on -export GO111MODULE CGO_ENABLED = 0 export CGO_ENABLED # Set up OS specific bits ifeq ($(OS),Windows_NT) - #TODO: we should be able to ditch awk as well - GOVERSION := $(shell go version | awk "{print substr($$3, 3)}") - GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1) NEBULA_CMD_SUFFIX = .exe NULL_FILE = nul # RIO on windows does pointer stuff that makes go vet angry VET_FLAGS = -unsafeptr=false else - GOVERSION := $(shell go version | awk '{print substr($$3, 3)}') - GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)") NEBULA_CMD_SUFFIX = NULL_FILE = /dev/null endif @@ -30,6 +22,9 @@ ifndef BUILD_NUMBER endif endif +DOCKER_IMAGE_REPO ?= nebulaoss/nebula +DOCKER_IMAGE_TAG ?= latest + LDFLAGS = -X main.Build=$(BUILD_NUMBER) ALL_LINUX = linux-amd64 \ @@ -44,7 +39,8 @@ ALL_LINUX = linux-amd64 \ linux-mips64 \ linux-mips64le \ linux-mips-softfloat \ - linux-riscv64 + linux-riscv64 \ + linux-loong64 ALL_FREEBSD = freebsd-amd64 \ freebsd-arm64 @@ -82,8 +78,12 @@ e2evvvv: e2ev e2e-bench: TEST_FLAGS = -bench=. -benchmem -run=^$ e2e-bench: e2e +DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert + all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert) +docker: docker/linux-$(shell go env GOARCH) + release: $(ALL:%=build/nebula-%.tar.gz) release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz) @@ -156,6 +156,9 @@ build/nebula-%.tar.gz: build/%/nebula build/%/nebula-cert build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe +docker/%: build/%/nebula build/%/nebula-cert + docker build . $(DOCKER_BUILD_ARGS) -f docker/Dockerfile --platform "$(subst -,/,$*)" --tag "${DOCKER_IMAGE_REPO}:${DOCKER_IMAGE_TAG}" --tag "${DOCKER_IMAGE_REPO}:$(BUILD_NUMBER)" + vet: go vet $(VET_FLAGS) -v ./... @@ -219,6 +222,10 @@ smoke-docker-race: BUILD_ARGS = -race smoke-docker-race: CGO_ENABLED = 1 smoke-docker-race: smoke-docker +smoke-vagrant/%: bin-docker build/%/nebula + cd .github/workflows/smoke/ && ./build.sh $* + cd .github/workflows/smoke/ && ./smoke-vagrant.sh $* + .FORCE: -.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html +.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/% .DEFAULT_GOAL := bin diff --git a/README.md b/README.md index 51e913d..65ea91f 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,11 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for $ brew install nebula ``` +- [Docker](https://hub.docker.com/r/nebulaoss/nebula) + ``` + $ docker pull nebulaoss/nebula + ``` + #### Mobile - [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200) diff --git a/allow_list.go b/allow_list.go index 9186b2f..90e0de2 100644 --- a/allow_list.go +++ b/allow_list.go @@ -2,17 +2,16 @@ package nebula import ( "fmt" - "net" + "net/netip" "regexp" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *cidr.Tree6[bool] + cidrTree *bart.Table[bool] } type RemoteAllowList struct { @@ -20,7 +19,7 @@ type RemoteAllowList struct { // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *cidr.Tree6[*AllowList] + insideAllowLists *bart.Table[*AllowList] } type LocalAllowList struct { @@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } - tree := cidr.NewTree6[bool]() + tree := new(bart.Table[bool]) // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } + ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()) + // TODO: should we error on duplicate CIDRs in the config? - tree.AddCIDR(ipNet, value) + tree.Insert(ipNet, value) - maskBits, maskSize := ipNet.Mask.Size() + maskBits := ipNet.Bits() var rules *allowListRules - if maskSize == 32 { + if ipNet.Addr().Is4() { rules = &rules4 } else { rules = &rules6 @@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in if !rules4.defaultSet { if rules4.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") - tree.AddCIDR(zeroCIDR, !rules4.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) } @@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in if !rules6.defaultSet { if rules6.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("::/0") - tree.AddCIDR(zeroCIDR, !rules6.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) } @@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error return nameRules, nil } -func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { +func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) { value := c.Get(k) if value == nil { return nil, nil } - remoteAllowRanges := cidr.NewTree6[*AllowList]() + remoteAllowRanges := new(bart.Table[*AllowList]) rawMap, ok := value.(map[interface{}]interface{}) if !ok { @@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error return nil, err } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } - remoteAllowRanges.AddCIDR(ipNet, allowList) + remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList) } return remoteAllowRanges, nil } -func (al *AllowList) Allow(ip net.IP) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContains(ip) - return result -} - -func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContainsIpV4(ip) - return result -} - -func (al *AllowList) AllowIpV6(hi, lo uint64) bool { +func (al *AllowList) Allow(ip netip.Addr) bool { if al == nil { return true } - _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) + result, _ := al.cidrTree.Lookup(ip) return result } -func (al *LocalAllowList) Allow(ip net.IP) bool { +func (al *LocalAllowList) Allow(ip netip.Addr) bool { if al == nil { return true } @@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool { return !al.nameRules[0].Allow } -func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { +func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool { if al == nil { return true } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool { +func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool { if !al.getInsideAllowList(vpnIp).Allow(ip) { return false } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) { - return false - } - return al.AllowList.AllowIpV4(ip) -} - -func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) { - return false - } - return al.AllowList.AllowIpV6(hi, lo) -} - -func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { +func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList { if al.insideAllowLists != nil { - ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + inside, ok := al.insideAllowLists.Lookup(vpnIp) if ok { return inside } diff --git a/allow_list_test.go b/allow_list_test.go index 334cb60..c8b3d08 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -1,11 +1,11 @@ package nebula import ( - "net" + "net/netip" "regexp" "testing" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") + assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) { } func TestAllowList_Allow(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) - - tree := cidr.NewTree6[bool]() - tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) - tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) - tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) - tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false) - tree.AddCIDR(cidr.Parse("::1/128"), true) - tree.AddCIDR(cidr.Parse("::2/128"), false) + assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) + + tree := new(bart.Table[bool]) + tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true) + tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false) + tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true) + tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false) + tree.Insert(netip.MustParsePrefix("::1/128"), true) + tree.Insert(netip.MustParsePrefix("::2/128"), false) al := &AllowList{cidrTree: tree} - assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1"))) - assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42"))) - assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1"))) - assert.Equal(t, true, al.Allow(net.ParseIP("::1"))) - assert.Equal(t, false, al.Allow(net.ParseIP("::2"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2"))) } func TestLocalAllowList_AllowName(t *testing.T) { diff --git a/calculated_remote.go b/calculated_remote.go index 38f5bea..ae2ed50 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -1,41 +1,36 @@ package nebula import ( + "encoding/binary" "fmt" "math" "net" + "net/netip" "strconv" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This allows us to "guess" what the remote might be for a host while we wait // for the lighthouse response. See "lighthouse.calculated_remotes" in the // example config file. type calculatedRemote struct { - ipNet net.IPNet - maskIP iputil.VpnIp - mask iputil.VpnIp - port uint32 + ipNet netip.Prefix + mask netip.Prefix + port uint32 } -func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) { - // Ensure this is an IPv4 mask that we expect - ones, bits := ipNet.Mask.Size() - if ones == 0 || bits != 32 { - return nil, fmt.Errorf("invalid mask: %v", ipNet) - } +func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { + masked := maskCidr.Masked() if port < 0 || port > math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) } return &calculatedRemote{ - ipNet: *ipNet, - maskIP: iputil.Ip2VpnIp(ipNet.IP), - mask: iputil.Ip2VpnIp(ipNet.Mask), - port: uint32(port), + ipNet: maskCidr, + mask: masked, + port: uint32(port), }, nil } @@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { +func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { // Combine the masked bytes of the "mask" IP with the unmasked bytes // of the overlay IP - masked := (c.maskIP & c.mask) | (ip & ^c.mask) + if c.ipNet.Addr().Is4() { + return c.apply4(ip) + } + return c.apply6(ip) +} + +func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK this can be less crappy + maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + mask := binary.BigEndian.Uint32(maskb[:]) + + b := c.mask.Addr().As4() + maskIp := binary.BigEndian.Uint32(b[:]) + + b = ip.As4() + intIp := binary.BigEndian.Uint32(b[:]) + + return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} +} - return &Ip4AndPort{Ip: uint32(masked), Port: c.port} +func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK + panic("Can not calculate ipv6 remote addresses") } -func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { value := c.Get(k) if value == nil { return nil, nil } - calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() + calculatedRemotes := new(bart.Table[[]*calculatedRemote]) rawMap, ok := value.(map[any]any) if !ok { @@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + cidr, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } + //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here entry, err := newCalculatedRemotesListFromConfig(rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } - calculatedRemotes.AddCIDR(ipNet, entry) + calculatedRemotes.Insert(cidr, entry) } return calculatedRemotes, nil @@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { if !ok { return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawMask) + maskCidr, err := netip.ParsePrefix(rawMask) if err != nil { return nil, fmt.Errorf("invalid mask: %s", rawMask) } @@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } - return newCalculatedRemote(ipNet, port) + return newCalculatedRemote(maskCidr, port) } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 2ddebca..6ff1cb0 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -1,27 +1,25 @@ package nebula import ( - "net" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCalculatedRemoteApply(t *testing.T) { - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + ipNet, err := netip.ParsePrefix("192.168.1.0/24") require.NoError(t, err) c, err := newCalculatedRemote(ipNet, 4242) require.NoError(t, err) - input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + input, err := netip.ParseAddr("10.0.10.182") + assert.NoError(t, err) - expected := &Ip4AndPort{ - Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), - Port: 4242, - } + expected, err := netip.ParseAddr("192.168.1.182") + assert.NoError(t, err) - assert.Equal(t, expected, c.Apply(input)) + assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) } diff --git a/cert/cert.go b/cert/cert.go index 4f1b776..a0164f7 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -324,7 +324,7 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) { return k.Bytes, r, nil } -// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert into its +// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its // protobuf-generated struct. func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { if len(b) == 0 { diff --git a/cidr/parse.go b/cidr/parse.go deleted file mode 100644 index 74367f6..0000000 --- a/cidr/parse.go +++ /dev/null @@ -1,10 +0,0 @@ -package cidr - -import "net" - -// Parse is a convenience function that returns only the IPNet -// This function ignores errors since it is primarily a test helper, the result could be nil -func Parse(s string) *net.IPNet { - _, c, _ := net.ParseCIDR(s) - return c -} diff --git a/cidr/tree4.go b/cidr/tree4.go deleted file mode 100644 index c5ebe54..0000000 --- a/cidr/tree4.go +++ /dev/null @@ -1,203 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -type Node[T any] struct { - left *Node[T] - right *Node[T] - parent *Node[T] - hasValue bool - value T -} - -type entry[T any] struct { - CIDR *net.IPNet - Value T -} - -type Tree4[T any] struct { - root *Node[T] - list []entry[T] -} - -const ( - startbit = iputil.VpnIp(0x80000000) -) - -func NewTree4[T any]() *Tree4[T] { - tree := new(Tree4[T]) - tree.root = &Node[T]{} - tree.list = []entry[T]{} - return tree -} - -func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { - bit := startbit - node := tree.root - next := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // We already have this range so update the value - if next != nil { - addCIDR := cidr.String() - for i, v := range tree.list { - if addCIDR == v.CIDR.String() { - tree.list = append(tree.list[:i], tree.list[i+1:]...) - break - } - } - - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) - node.value = val - node.hasValue = true - return - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) -} - -// Contains finds the first match, which may be the least specific -func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - return true, node.value - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - - } - - return false, value -} - -// MostSpecificContains finds the most specific match -func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -type eachFunc[T any] func(T) bool - -// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete -// The final return value will be true if the provided function returned true -func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - // If the each func returns true then we can exit the loop - if each(node.value) { - return true - } - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return false -} - -// GetCIDR returns the entry added by the most recent matching AddCIDR call -func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) { - bit := startbit - node := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for node != nil && bit&mask != 0 { - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit = bit >> 1 - } - - if bit&mask == 0 && node != nil { - value = node.value - ok = node.hasValue - } - - return ok, value -} - -// List will return all CIDRs and their current values. Do not modify the contents! -func (tree *Tree4[T]) List() []entry[T] { - return tree.list -} diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go deleted file mode 100644 index cd17be4..0000000 --- a/cidr/tree4_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package cidr - -import ( - "net" - "testing" - - "github.com/slackhq/nebula/iputil" - "github.com/stretchr/testify/assert" -) - -func TestCIDRTree_List(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/16"), "1") - tree.AddCIDR(Parse("1.0.0.0/8"), "2") - tree.AddCIDR(Parse("1.0.0.0/16"), "3") - tree.AddCIDR(Parse("1.0.0.0/16"), "4") - list := tree.List() - assert.Len(t, list, 2) - assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) - assert.Equal(t, "2", list[0].Value) - assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) - assert.Equal(t, "4", list[1].Value) -} - -func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4a", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.0/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestTree4_GetCIDR(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IPNet *net.IPNet - }{ - {true, "1", Parse("1.0.0.0/8")}, - {true, "2", Parse("2.1.0.0/16")}, - {true, "3", Parse("3.1.1.0/24")}, - {true, "4a", Parse("4.1.1.0/24")}, - {true, "4b", Parse("4.1.1.1/32")}, - {true, "4c", Parse("4.1.2.1/32")}, - {true, "5", Parse("254.0.0.0/4")}, - {false, "", Parse("2.0.0.0/8")}, - } - - for _, tt := range tests { - ok, r := tree.GetCIDR(tt.IPNet) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} - -func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.1.0.0/16"), "1") - tree.AddCIDR(Parse("1.2.1.1/32"), "1") - tree.AddCIDR(Parse("192.2.1.1/32"), "1") - tree.AddCIDR(Parse("172.2.1.1/32"), "1") - - ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) - - ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) -} diff --git a/cidr/tree6.go b/cidr/tree6.go deleted file mode 100644 index 3f2cd2a..0000000 --- a/cidr/tree6.go +++ /dev/null @@ -1,189 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -const startbit6 = uint64(1 << 63) - -type Tree6[T any] struct { - root4 *Node[T] - root6 *Node[T] -} - -func NewTree6[T any]() *Tree6[T] { - tree := new(Tree6[T]) - tree.root4 = &Node[T]{} - tree.root6 = &Node[T]{} - return tree -} - -func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { - var node, next *Node[T] - - cidrIP, ipv4 := isIPV4(cidr.IP) - if ipv4 { - node = tree.root4 - next = tree.root4 - - } else { - node = tree.root6 - next = tree.root6 - } - - for i := 0; i < len(cidrIP); i += 4 { - ip := iputil.Ip2VpnIp(cidrIP[i : i+4]) - mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4]) - bit := startbit - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true -} - -// Finds the most specific match -func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { - var node *Node[T] - - wholeIP, ipv4 := isIPV4(ip) - if ipv4 { - node = tree.root4 - } else { - node = tree.root6 - } - - for i := 0; i < len(wholeIP); i += 4 { - ip := iputil.Ip2VpnIp(wholeIP[i : i+4]) - bit := startbit - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root4 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { - ip := hi - node := tree.root6 - - for i := 0; i < 2; i++ { - bit := startbit6 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - ip = lo - } - - return ok, value -} - -func isIPV4(ip net.IP) (net.IP, bool) { - if len(ip) == net.IPv4len { - return ip, true - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16], true - } - - return ip, false -} - -func isZeros(p net.IP) bool { - for i := 0; i < len(p); i++ { - if p[i] != 0 { - return false - } - } - return true -} diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go deleted file mode 100644 index eb159ec..0000000 --- a/cidr/tree6_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package cidr - -import ( - "encoding/binary" - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.1/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree6[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - tree.AddCIDR(Parse("::/0"), "cool6") - ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("::")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) -} - -func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - } - - for _, tt := range tests { - ip := net.ParseIP(tt.IP) - hi := binary.BigEndian.Uint64(ip[:8]) - lo := binary.BigEndian.Uint64(ip[8:]) - - ok, r := tree.MostSpecificContainsIpV6(hi, lo) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 69df4ab..4e5d51d 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -180,9 +180,15 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error if err != nil { return fmt.Errorf("error while generating ecdsa keys: %s", err) } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L60 - rawPriv = key.D.FillBytes(make([]byte, 32)) - pub = elliptic.Marshal(elliptic.P256(), key.X, key.Y) + + // ecdh.PrivateKey lets us get at the encoded bytes, even though + // we aren't using ECDH here. + eKey, err := key.ECDH() + if err != nil { + return fmt.Errorf("error while converting ecdsa key: %s", err) + } + rawPriv = eKey.Bytes() + pub = eKey.PublicKey().Bytes() } nc := cert.NebulaCertificate{ diff --git a/connection_manager.go b/connection_manager.go index f5dd594..d2e8616 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -3,6 +3,8 @@ package nebula import ( "bytes" "context" + "encoding/binary" + "net/netip" "sync" "time" @@ -10,8 +12,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) type trafficDecision int @@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) var index uint32 - var relayFrom iputil.VpnIp - var relayTo iputil.VpnIp + var relayFrom netip.Addr + var relayTo netip.Addr switch { case ok && existing.State == Established: // This relay already exists in newhostinfo, then do nothing. @@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = existing.PeerIp case ForwardingType: relayFrom = existing.PeerIp @@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = r.PeerIp case ForwardingType: relayFrom = r.PeerIp @@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } + //TODO: IPV6-WORK + relayFromB := relayFrom.As4() + relayToB := relayTo.As4() + // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(relayFrom), - RelayToIp: uint32(relayTo), + RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), + RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } msg, err := req.Marshal() if err != nil { @@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + "relayFrom": req.RelayFromIp, + "relayTo": req.RelayToIp, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": newhostinfo.vpnIp}). @@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp < n.intf.myVpnIp { + if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // The remotes vpn ip is lower than mine. I will not flip. @@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) { + hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, addr) }) - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } diff --git a/connection_manager_test.go b/connection_manager_test.go index a2607a2..5f97cad 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -5,28 +5,26 @@ import ( "crypto/ed25519" "crypto/rand" "net" + "net/netip" "testing" "time" "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) -var vpnIp iputil.VpnIp - func newTestLighthouse() *LightHouse { lh := &LightHouse{ l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, - queryChan: make(chan iputil.VpnIp, 10), + addrMap: map[netip.Addr]*RemoteList{}, + queryChan: make(chan netip.Addr, 10), } - lighthouses := map[iputil.VpnIp]struct{}{} - staticList := map[iputil.VpnIp]struct{}{} + lighthouses := map[netip.Addr]struct{}{} + staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) lh.staticList.Store(&staticList) @@ -37,13 +35,15 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) + cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, @@ -118,12 +118,15 @@ func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, vpncidr, preferredRanges) + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) + cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, @@ -207,10 +210,12 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { IP: net.IPv4(172, 1, 1, 2), Mask: net.IPMask{255, 255, 255, 0}, } - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} - hostMap := NewHostMap(l, vpncidr, preferredRanges) + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} + hostMap := newHostMap(l, vpncidr) + hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) diff --git a/connection_state.go b/connection_state.go index 8ef8b3a..1dd3c8c 100644 --- a/connection_state.go +++ b/connection_state.go @@ -72,6 +72,8 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i window: b, myCert: certState.Certificate, } + // always start the counter from 2, as packet 1 and packet 2 are handshake packets. + ci.messageCounter.Add(2) return ci } diff --git a/control.go b/control.go index 1e27b0f..54d528d 100644 --- a/control.go +++ b/control.go @@ -2,7 +2,7 @@ package nebula import ( "context" - "net" + "net/netip" "os" "os/signal" "syscall" @@ -10,9 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" - "github.com/slackhq/nebula/udp" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching @@ -21,10 +19,10 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + QueryVpnIp(vpnIp netip.Addr) *HostInfo ForEachIndex(each controlEach) ForEachVpnIp(each controlEach) - GetPreferredRanges() []*net.IPNet + GetPreferredRanges() []netip.Prefix } type Control struct { @@ -39,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp net.IP `json:"vpnIp"` + VpnIp netip.Addr `json:"vpnIp"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []*udp.Addr `json:"remoteAddrs"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` - CurrentRemote *udp.Addr `json:"currentRemote"` - CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -131,8 +129,46 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } } +// GetCertByVpnIp returns a single tunnels hostInfo, or nil if not found +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate { + if c.f.myVpnNet.Addr() == vpnIp { + return c.f.pki.GetCertState().Certificate + } + hi := c.f.hostMap.QueryVpnIp(vpnIp) + if hi == nil { + return nil + } + return hi.GetCert() +} + +// CreateTunnel creates a new tunnel to the given vpn ip. +func (c *Control) CreateTunnel(vpnIp netip.Addr) { + c.f.handshakeManager.StartHandshake(vpnIp, nil) +} + +// PrintTunnel creates a new tunnel to the given vpn ip. +func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { + hi := c.f.hostMap.QueryVpnIp(vpnIp) + if hi == nil { + return nil + } + chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges()) + return &chi +} + +// QueryLighthouse queries the lighthouse. +func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { + hi := c.f.lightHouse.Query(vpnIp) + if hi == nil { + return nil + } + return hi.CopyCache() +} + // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found -func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -145,24 +181,26 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH return nil } - ch := copyHostInfo(h, c.f.hostMap.preferredRanges) + ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges()) return &ch } // SetRemoteForTunnel forces a tunnel to use a specific remote -func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return nil } - hostInfo.SetRemote(addr.Copy()) - ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges) + hostInfo.SetRemote(addr) + ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. -func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return false @@ -205,7 +243,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { } // Learn which hosts are being used as relays, so we can shut them down last. - relayingHosts := map[iputil.VpnIp]*HostInfo{} + relayingHosts := map[netip.Addr]*HostInfo{} // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { @@ -236,15 +274,16 @@ func (c *Control) Device() overlay.Device { return c.f.inside } -func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { +func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { chi := ControlHostInfo{ - VpnIp: h.vpnIp.ToIP(), + VpnIp: h.vpnIp, LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), + CurrentRemote: h.remote, } if h.ConnectionState != nil { @@ -255,10 +294,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi.Cert = c.Copy() } - if h.remote != nil { - chi.CurrentRemote = h.remote.Copy() - } - return chi } diff --git a/control_test.go b/control_test.go index 847332b..fbf29c0 100644 --- a/control_test.go +++ b/control_test.go @@ -2,15 +2,14 @@ package nebula import ( "net" + "net/netip" "reflect" "testing" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) @@ -18,16 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0)) - remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) - remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) + hm := newHostMap(l, netip.Prefix{}) + hm.preferredRanges.Store(&[]netip.Prefix{}) + + remote1 := netip.MustParseAddrPort("0.0.0.100:4444") + remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") + ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), + IP: remote1.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } ipNet2 := net.IPNet{ - IP: net.ParseIP("1:2:3:4:5:6:7:8"), + IP: remote2.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } @@ -48,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) - remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + + vpnIp, ok := netip.AddrFromSlice(ipNet.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -58,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: vpnIp, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) + vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -74,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet2.IP), + vpnIp: vpnIp2, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -89,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false) + thi := c.GetHostInfoByVpnIp(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: net.IPv4(1, 2, 3, 4).To4(), + VpnIp: vpnIp, LocalIndex: 201, RemoteIndex: 200, - RemoteAddrs: []*udp.Addr{remote2, remote1}, + RemoteAddrs: []netip.AddrPort{remote2, remote1}, Cert: crt.Copy(), MessageCounter: 0, - CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), - CurrentRelaysToMe: []iputil.VpnIp{}, - CurrentRelaysThroughMe: []iputil.VpnIp{}, + CurrentRemote: remote1, + CurrentRelaysToMe: []netip.Addr{}, + CurrentRelaysThroughMe: []netip.Addr{}, } // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) - test.AssertDeepCopyEqual(t, &expectedInfo, thi) + assert.EqualValues(t, &expectedInfo, thi) + //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here + //test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false) + thi = c.GetHostInfoByVpnIp(vpnIp2, false) }) } diff --git a/control_tester.go b/control_tester.go index b786ba3..d46540f 100644 --- a/control_tester.go +++ b/control_tester.go @@ -4,14 +4,13 @@ package nebula import ( - "net" + "net/netip" "github.com/slackhq/nebula/cert" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // This is necessary if you did not configure static hosts or are not running a lighthouse -func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { +func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - if v4 := toAddr.IP.To4(); v4 != nil { - remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) + if toAddr.Addr().Is4() { + remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))) + remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } } // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp // This is necessary to inform an initiator of possible relays for communicating with a responder -func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) { +func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - uVpnIp := []uint32{} - for _, rVPnIp := range relayVpnIps { - uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp))) - } - - remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp) + remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) { +func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { + //TODO: IPV6-WORK ip := layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().IP, - DstIP: toIp, + SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), + DstIP: toIp.Unmap().AsSlice(), } udp := layers.UDP{ @@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() iputil.VpnIp { - return c.f.myVpnIp +func (c *Control) GetVpnIp() netip.Addr { + return c.f.myVpnNet.Addr() } -func (c *Control) GetUDPAddr() string { - return c.f.outside.(*udp.TesterConn).Addr.String() +func (c *Control) GetUDPAddr() netip.AddrPort { + return c.f.outside.(*udp.TesterConn).Addr } -func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) +func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { + hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) if hostinfo == nil { return false } @@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate { return c.f.pki.GetCertState().Certificate } -func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { +func (c *Control) ReHandshake(vpnIp netip.Addr) { c.f.handshakeManager.StartHandshake(vpnIp, nil) } diff --git a/dist/arch/nebula.service b/dist/arch/nebula.service deleted file mode 100644 index 831c71a..0000000 --- a/dist/arch/nebula.service +++ /dev/null @@ -1,15 +0,0 @@ -[Unit] -Description=Nebula overlay networking tool -Wants=basic.target network-online.target nss-lookup.target time-sync.target -After=basic.target network.target network-online.target - -[Service] -Type=notify -NotifyAccess=main -SyslogIdentifier=nebula -ExecReload=/bin/kill -HUP $MAINPID -ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml -Restart=always - -[Install] -WantedBy=multi-user.target diff --git a/dist/fedora/nebula.service b/dist/fedora/nebula.service deleted file mode 100644 index 0f947ea..0000000 --- a/dist/fedora/nebula.service +++ /dev/null @@ -1,16 +0,0 @@ -[Unit] -Description=Nebula overlay networking tool -Wants=basic.target network-online.target nss-lookup.target time-sync.target -After=basic.target network.target network-online.target -Before=sshd.service - -[Service] -Type=notify -NotifyAccess=main -SyslogIdentifier=nebula -ExecReload=/bin/kill -HUP $MAINPID -ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml -Restart=always - -[Install] -WantedBy=multi-user.target diff --git a/dns_server.go b/dns_server.go index 3109b4c..5fea65c 100644 --- a/dns_server.go +++ b/dns_server.go @@ -3,6 +3,7 @@ package nebula import ( "fmt" "net" + "net/netip" "strconv" "strings" "sync" @@ -10,7 +11,6 @@ import ( "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This whole thing should be rewritten to use context @@ -42,21 +42,23 @@ func (d *dnsRecords) Query(data string) string { } func (d *dnsRecords) QueryCert(data string) string { - ip := net.ParseIP(data[:len(data)-1]) - if ip == nil { + ip, err := netip.ParseAddr(data[:len(data)-1]) + if err != nil { return "" } - iip := iputil.Ip2VpnIp(ip) - hostinfo := d.hostMap.QueryVpnIp(iip) + + hostinfo := d.hostMap.QueryVpnIp(ip) if hostinfo == nil { return "" } + q := hostinfo.GetCert() if q == nil { return "" } + cert := q.Details - c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) + c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) return c } @@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } case dns.TypeTXT: a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b := net.ParseIP(a) + b, err := netip.ParseAddr(a) + if err != nil { + return + } + // We don't answer these queries from non nebula nodes or localhost //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { @@ -96,6 +102,10 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } } } + + if len(m.Answer) == 0 { + m.Rcode = dns.RcodeNameError + } } func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { @@ -129,7 +139,12 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { } func getDnsServerAddr(c *config.C) string { - return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)) + dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) + // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve. + if dnsHost == "[::]" { + dnsHost = "::" + } + return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } func startDns(l *logrus.Logger, c *config.C) { diff --git a/dns_server_test.go b/dns_server_test.go index 830dc8a..69f6ae8 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/miekg/dns" + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { @@ -17,3 +19,40 @@ func TestParsequery(t *testing.T) { //parseQuery(m) } + +func Test_getDnsServerAddr(t *testing.T) { + c := config.NewC(nil) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "0.0.0.0", + "port": "1", + }, + } + assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "::", + "port": "1", + }, + } + assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "[::]", + "port": "1", + }, + } + assert.Equal(t, "[::]:1", getDnsServerAddr(c)) + + // Make sure whitespace doesn't mess us up + c.Settings["lighthouse"] = map[interface{}]interface{}{ + "dns": map[interface{}]interface{}{ + "host": "[::] ", + "port": "1", + }, + } + assert.Equal(t, "[::]:1", getDnsServerAddr(c)) +} diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..400e275 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,11 @@ +FROM gcr.io/distroless/static:latest + +ARG TARGETOS TARGETARCH +COPY build/$TARGETOS-$TARGETARCH/nebula /nebula +COPY build/$TARGETOS-$TARGETARCH/nebula-cert /nebula-cert + +VOLUME ["/config"] + +ENTRYPOINT ["/nebula"] +# Allow users to override the args passed to nebula +CMD ["-config", "/config/config.yml"] diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000..129744f --- /dev/null +++ b/docker/README.md @@ -0,0 +1,24 @@ +# NebulaOSS/nebula Docker Image + +## Building + +From the root of the repository, run `make docker`. + +## Running + +To run the built image, use the following command: + +``` +docker run \ + --name nebula \ + --network host \ + --cap-add NET_ADMIN \ + --volume ./config:/config \ + --rm \ + nebulaoss/nebula +``` + +A few notes: + +- The `NET_ADMIN` capability is necessary to create the tun adapter on the host (this is unnecessary if the tun device is disabled.) +- `--volume ./config:/config` should point to a directory that contains your `config.yml` and any other necessary files. diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 59f1d0e..3d42a56 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -5,7 +5,7 @@ package e2e import ( "fmt" - "net" + "net/netip" "testing" "time" @@ -13,19 +13,18 @@ import ( "github.com/slackhq/nebula" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) @@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) { panic(err) } - if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 { + if p.To == theirUdpAddr && h.Type == 1 { return router.RouteAndExit } @@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone @@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() - myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") theirHostmap := theirControl.GetHostmap() - theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -361,31 +360,31 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() @@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) { // ////TODO: assert hostmaps } + func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now break @@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) assert.Contains(t, c.Cert.Details.Groups, "new group") // We should only have a single tunnel now on both sides @@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) - tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) fmt.Println(tt1.LocalIndex, tt2.LocalIndex) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] if theirNewGroup { @@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") // We should only have a single tunnel now on both sides @@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop() diff --git a/e2e/helpers.go b/e2e/helpers.go index 13146ab..71df805 100644 --- a/e2e/helpers.go +++ b/e2e/helpers.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "io" "net" + "net/netip" "time" "github.com/slackhq/nebula/cert" @@ -12,7 +13,7 @@ import ( ) // NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] } if len(ips) > 0 { - nc.Details.Ips = ips + nc.Details.Ips = make([]*net.IPNet, len(ips)) + for i, ip := range ips { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(subnets) > 0 { - nc.Details.Subnets = subnets + nc.Details.Subnets = make([]*net.IPNet, len(subnets)) + for i, ip := range subnets { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(groups) > 0 { @@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { issuer, err := ca.Sha256Sum() if err != nil { panic(err) @@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af } pub, rawPriv := x25519Keypair() - + ipb := ip.Addr().AsSlice() nc := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, + Name: name, + Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}}, + //Subnets: subnets, Groups: groups, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index b05c84a..527f55b 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -6,7 +6,7 @@ package e2e import ( "fmt" "io" - "net" + "net/netip" "os" "testing" "time" @@ -19,7 +19,6 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -27,15 +26,23 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - vpnIpNet.IP[1] += 128 - udpAddr := net.UDPAddr{ - IP: udpIp, - Port: 4242, + vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) + if err != nil { + panic(err) + } + + var udpAddr netip.AddrPort + if vpnIpNet.Addr().Is4() { + budpIp := vpnIpNet.Addr().As4() + budpIp[1] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) + } else { + budpIp := vpnIpNet.Addr().As16() + budpIp[13] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) @@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u // "try_interval": "1s", //}, "listen": m{ - "host": udpAddr.IP.String(), - "port": udpAddr.Port, + "host": udpAddr.Addr().String(), + "port": udpAddr.Port(), }, "logging": m{ "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), @@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, vpnIpNet, &udpAddr, c + return control, vpnIpNet, udpAddr, c } type doneCb func() @@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { } } -func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { +func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) @@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false) + hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") - hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false) + hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") // Check that both vpn and real addr are correct assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") - assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") - assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") - - assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A") - assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B") + assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") + assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") @@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB //checkIndexes("hmB", hmB, hAinB) } -func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) { +func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") - assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect") - assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect") + assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect") udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) assert.NotNil(t, udp, "No udp data found") diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 120be69..c14ab2e 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -5,11 +5,11 @@ package router import ( "fmt" + "net/netip" "sort" "strings" "github.com/slackhq/nebula" - "github.com/slackhq/nebula/iputil" ) type edge struct { @@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { return r, globalLines } -func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { - keys := make([]iputil.VpnIp, 0, len(hosts)) +func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr { + keys := make([]netip.Addr, 0, len(hosts)) for key := range hosts { keys = append(keys, key) } sort.SliceStable(keys, func(i, j int) bool { - return keys[i] > keys[j] + return keys[i].Compare(keys[j]) > 0 }) return keys diff --git a/e2e/router/router.go b/e2e/router/router.go index 730853a..0890570 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -6,12 +6,11 @@ package router import ( "context" "fmt" - "net" + "net/netip" "os" "path/filepath" "reflect" "sort" - "strconv" "strings" "sync" "testing" @@ -21,7 +20,6 @@ import ( "github.com/google/gopacket/layers" "github.com/slackhq/nebula" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/exp/maps" ) @@ -29,18 +27,18 @@ import ( type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? - controls map[string]*nebula.Control + controls map[netip.AddrPort]*nebula.Control // A map for inbound packets for a control that doesn't know about this address - inNat map[string]*nebula.Control + inNat map[netip.AddrPort]*nebula.Control // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]net.UDPAddr + outNat map[string]netip.AddrPort // A map of vpn ip to the nebula control it belongs to - vpnControls map[iputil.VpnIp]*nebula.Control + vpnControls map[netip.Addr]*nebula.Control ignoreFlows []ignoreFlow flow []flowEntry @@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { } r := &R{ - controls: make(map[string]*nebula.Control), - vpnControls: make(map[iputil.VpnIp]*nebula.Control), - inNat: make(map[string]*nebula.Control), - outNat: make(map[string]net.UDPAddr), + controls: make(map[netip.AddrPort]*nebula.Control), + vpnControls: make(map[netip.Addr]*nebula.Control), + inNat: make(map[netip.AddrPort]*nebula.Control), + outNat: make(map[string]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { for _, c := range controls { addr := c.GetUDPAddr() if _, ok := r.controls[addr]; ok { - panic("Duplicate listen address: " + addr) + panic("Duplicate listen address: " + addr.String()) } r.vpnControls[c.GetVpnIp()] = c @@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { // It does not look at the addr attached to the instance. // If a route is used, this will behave like a NAT for the return path. // Rewriting the source ip:port to what was last sent to from the origin -func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { +func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { r.Lock() defer r.Unlock() - inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)) + inAddr := netip.AddrPortFrom(ip, port) if _, ok := r.inNat[inAddr]; ok { - panic("Duplicate listen address inNat: " + inAddr) + panic("Duplicate listen address inNat: " + inAddr.String()) } r.inNat[inAddr] = c } @@ -198,7 +196,7 @@ func (r *R) renderFlow() { panic(err) } - var participants = map[string]struct{}{} + var participants = map[netip.AddrPort]struct{}{} var participantsVals []string fmt.Fprintln(f, "```mermaid") @@ -215,7 +213,7 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr, ":", "-", 1) + sanAddr := strings.Replace(addr.String(), ":", "-", 1) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", @@ -252,9 +250,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), line, - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp() > c[j].GetVpnIp() + return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 }) s := renderHostmaps(c...) @@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: - outAddr := sender.GetUDPAddr() r.Lock() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(sender.GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(cm[x].GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { panic(err) } - outAddr := sender.GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(sender.GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // If the router doesn't have the nebula controller for that address, we panic -func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) { +func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) { if finish == KeepRouting { finish = RouteAndExit } r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { - if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { + if p.To == toAddr { return finish } @@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { r.Lock() p := rx.Interface().(*udp.Packet) - - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteForAllExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -697,12 +685,10 @@ func (r *R) FlushAll() { p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't FlushAll for host: " + p.To.String()) } r.Unlock() } @@ -710,28 +696,14 @@ func (r *R) FlushAll() { // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock -func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { - p.FromIp = newAddr.IP - p.FromPort = uint16(newAddr.Port) +func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { + if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - sHost, sPort, err := net.SplitHostPort(toAddr) - if err != nil { - panic(err) - } - - port, err := strconv.Atoi(sPort) - if err != nil { - panic(err) - } - - r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{ - IP: net.ParseIP(sHost), - Port: port, - } + r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr return c } @@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string { } from := "unknown" - if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok { - from = c.GetUDPAddr() + srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) + if c, ok := r.vpnControls[srcAddr]; ok { + from = c.GetUDPAddr().String() } udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) @@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string { return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), udp.SrcPort, udp.DstPort, string(data.Payload()), diff --git a/examples/config.yml b/examples/config.yml index ff5b403..c74ffc6 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -167,8 +167,7 @@ punchy: # Preferred ranges is used to define a hint about the local network ranges, which speeds up discovering the fastest # path to a network adjacent nebula node. -# NOTE: the previous option "local_range" only allowed definition of a single range -# and has been deprecated for "preferred_ranges" +# This setting is reloadable. #preferred_ranges: ["172.16.0.0/24"] # sshd can expose informational and administrative functions via ssh. This can expose informational and administrative @@ -181,12 +180,15 @@ punchy: # A file containing the ssh host private key to use # A decent way to generate one: ssh-keygen -t ed25519 -f ssh_host_ed25519_key -N "" < /dev/null #host_key: ./ssh_host_ed25519_key - # A file containing a list of authorized public keys + # Authorized users and their public keys #authorized_users: #- user: steeeeve # keys can be an array of strings or single string #keys: #- "ssh public key string" + # Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access. + #trusted_cas: + #- "ssh public key string" # EXPERIMENTAL: relay support for networks that can't establish direct connections. relay: @@ -230,6 +232,7 @@ tun: # `mtu`: will default to tun mtu if this option is not specified # `metric`: will default to 0 if this option is not specified # `install`: will default to true, controls whether this route is installed in the systems routing table. + # This setting is reloadable. unsafe_routes: #- route: 172.16.1.0/24 # via: 192.168.100.99 @@ -244,7 +247,10 @@ tun: # TODO # Configure logging level logging: - # panic, fatal, error, warning, info, or debug. Default is info + # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. + #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some + # scenarios. Debug logging is also CPU intensive and will decrease performance overall. + # Only enable debug logging while actively investigating an issue. level: info # json or text formats currently available. Default is text format: text diff --git a/examples/quickstart-vagrant/README.md b/examples/quickstart-vagrant/README.md deleted file mode 100644 index 108de9e..0000000 --- a/examples/quickstart-vagrant/README.md +++ /dev/null @@ -1,138 +0,0 @@ -# Quickstart Guide - -This guide is intended to bring up a vagrant environment with 1 lighthouse and 2 generic hosts running nebula. - -## Creating the virtualenv for ansible - -Within the `quickstart/` directory, do the following - -``` -# make a virtual environment -virtualenv venv - -# get into the virtualenv -source venv/bin/activate - -# install ansible -pip install -r requirements.yml -``` - -## Bringing up the vagrant environment - -A plugin that is used for the Vagrant environment is `vagrant-hostmanager` - -To install, run - -``` -vagrant plugin install vagrant-hostmanager -``` - -All hosts within the Vagrantfile are brought up with - -`vagrant up` - -Once the boxes are up, go into the `ansible/` directory and deploy the playbook by running - -`ansible-playbook playbook.yml -i inventory -u vagrant` - -## Testing within the vagrant env - -Once the ansible run is done, hop onto a vagrant box - -`vagrant ssh generic1.vagrant` - -or specifically - -`ssh vagrant@` (password for the vagrant user on the boxes is `vagrant`) - -See `/etc/nebula/config.yml` on a box for firewall rules. - -To see full handshakes and hostmaps, change the logging config of `/etc/nebula/config.yml` on the vagrant boxes from -info to debug. - -You can watch nebula logs by running - -``` -sudo journalctl -fu nebula -``` - -Refer to the nebula src code directory's README for further instructions on configuring nebula. - -## Troubleshooting - -### Is nebula up and running? - -Run and verify that - -``` -ifconfig -``` - -shows you an interface with the name `nebula1` being up. - -``` -vagrant@generic1:~$ ifconfig nebula1 -nebula1: flags=4305 mtu 1300 - inet 10.168.91.210 netmask 255.128.0.0 destination 10.168.91.210 - inet6 fe80::aeaf:b105:e6dc:936c prefixlen 64 scopeid 0x20 - unspec 00-00-00-00-00-00-00-00-00-00-00-00-00-00-00-00 txqueuelen 500 (UNSPEC) - RX packets 2 bytes 168 (168.0 B) - RX errors 0 dropped 0 overruns 0 frame 0 - TX packets 11 bytes 600 (600.0 B) - TX errors 0 dropped 0 overruns 0 carrier 0 collisions 0 -``` - -### Connectivity - -Are you able to ping other boxes on the private nebula network? - -The following are the private nebula ip addresses of the vagrant env - -``` -generic1.vagrant [nebula_ip] 10.168.91.210 -generic2.vagrant [nebula_ip] 10.168.91.220 -lighthouse1.vagrant [nebula_ip] 10.168.91.230 -``` - -Try pinging generic1.vagrant to and from any other box using its nebula ip above. - -Double check the nebula firewall rules under /etc/nebula/config.yml to make sure that connectivity is allowed for your use-case if on a specific port. - -``` -vagrant@lighthouse1:~$ grep -A21 firewall /etc/nebula/config.yml -firewall: - conntrack: - tcp_timeout: 12m - udp_timeout: 3m - default_timeout: 10m - - inbound: - - proto: icmp - port: any - host: any - - proto: any - port: 22 - host: any - - proto: any - port: 53 - host: any - - outbound: - - proto: any - port: any - host: any -``` diff --git a/examples/quickstart-vagrant/Vagrantfile b/examples/quickstart-vagrant/Vagrantfile deleted file mode 100644 index ab9408f..0000000 --- a/examples/quickstart-vagrant/Vagrantfile +++ /dev/null @@ -1,40 +0,0 @@ -Vagrant.require_version ">= 2.2.6" - -nodes = [ - { :hostname => 'generic1.vagrant', :ip => '172.11.91.210', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1}, - { :hostname => 'generic2.vagrant', :ip => '172.11.91.220', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1}, - { :hostname => 'lighthouse1.vagrant', :ip => '172.11.91.230', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1}, -] - -Vagrant.configure("2") do |config| - - config.ssh.insert_key = false - - if Vagrant.has_plugin?('vagrant-cachier') - config.cache.enable :apt - else - printf("** Install vagrant-cachier plugin to speedup deploy: `vagrant plugin install vagrant-cachier`.**\n") - end - - if Vagrant.has_plugin?('vagrant-hostmanager') - config.hostmanager.enabled = true - config.hostmanager.manage_host = true - config.hostmanager.include_offline = true - else - config.vagrant.plugins = "vagrant-hostmanager" - end - - nodes.each do |node| - config.vm.define node[:hostname] do |node_config| - node_config.vm.box = node[:box] - node_config.vm.hostname = node[:hostname] - node_config.vm.network :private_network, ip: node[:ip] - node_config.vm.provider :virtualbox do |vb| - vb.memory = node[:ram] - vb.cpus = node[:cpus] - vb.customize ["modifyvm", :id, "--natdnshostresolver1", "on"] - vb.customize ['guestproperty', 'set', :id, '/VirtualBox/GuestAdd/VBoxService/--timesync-set-threshold', 10000] - end - end - end -end diff --git a/examples/quickstart-vagrant/ansible/ansible.cfg b/examples/quickstart-vagrant/ansible/ansible.cfg deleted file mode 100644 index 518a4f1..0000000 --- a/examples/quickstart-vagrant/ansible/ansible.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[defaults] -host_key_checking = False -private_key_file = ~/.vagrant.d/insecure_private_key -become = yes diff --git a/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py b/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py deleted file mode 100644 index a21e82d..0000000 --- a/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/python - - -class FilterModule(object): - def filters(self): - return { - 'to_nebula_ip': self.to_nebula_ip, - 'map_to_nebula_ips': self.map_to_nebula_ips, - } - - def to_nebula_ip(self, ip_str): - ip_list = list(map(int, ip_str.split("."))) - ip_list[0] = 10 - ip_list[1] = 168 - ip = '.'.join(map(str, ip_list)) - return ip - - def map_to_nebula_ips(self, ip_strs): - ip_list = [ self.to_nebula_ip(ip_str) for ip_str in ip_strs ] - ips = ', '.join(ip_list) - return ips diff --git a/examples/quickstart-vagrant/ansible/inventory b/examples/quickstart-vagrant/ansible/inventory deleted file mode 100644 index 0bae407..0000000 --- a/examples/quickstart-vagrant/ansible/inventory +++ /dev/null @@ -1,11 +0,0 @@ -[all] -generic1.vagrant -generic2.vagrant -lighthouse1.vagrant - -[generic] -generic1.vagrant -generic2.vagrant - -[lighthouse] -lighthouse1.vagrant diff --git a/examples/quickstart-vagrant/ansible/playbook.yml b/examples/quickstart-vagrant/ansible/playbook.yml deleted file mode 100644 index c3c7d9f..0000000 --- a/examples/quickstart-vagrant/ansible/playbook.yml +++ /dev/null @@ -1,23 +0,0 @@ ---- -- name: test connection to vagrant boxes - hosts: all - tasks: - - debug: msg=ok - -- name: build nebula binaries locally - connection: local - hosts: localhost - tasks: - - command: chdir=../../../ make build/linux-amd64/"{{ item }}" - with_items: - - nebula - - nebula-cert - tags: - - build-nebula - -- name: install nebula on all vagrant hosts - hosts: all - become: yes - gather_facts: yes - roles: - - nebula diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml deleted file mode 100644 index f8e7a99..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml +++ /dev/null @@ -1,3 +0,0 @@ ---- -# defaults file for nebula -nebula_config_directory: "/etc/nebula/" diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service b/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service deleted file mode 100644 index fd7a067..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service +++ /dev/null @@ -1,14 +0,0 @@ -[Unit] -Description=Nebula overlay networking tool -Wants=basic.target network-online.target nss-lookup.target time-sync.target -After=basic.target network.target network-online.target -Before=sshd.service - -[Service] -SyslogIdentifier=nebula -ExecReload=/bin/kill -HUP $MAINPID -ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml -Restart=always - -[Install] -WantedBy=multi-user.target diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt deleted file mode 100644 index 6148687..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt +++ /dev/null @@ -1,5 +0,0 @@ ------BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSB0ZXN0IENBKNXC1NYFMNXIhO0GOiCmVYeZ9tkB4WEnawmkrca+ -hsAg9otUFhpAowZeJ33KVEABEkAORybHQUUyVFbKYzw0JHfVzAQOHA4kwB1yP9IV -KpiTw9+ADz+wA+R5tn9B+L8+7+Apc+9dem4BQULjA5mRaoYN ------END NEBULA CERTIFICATE----- diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key deleted file mode 100644 index 394043c..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key +++ /dev/null @@ -1,4 +0,0 @@ ------BEGIN NEBULA ED25519 PRIVATE KEY----- -FEXZKMSmg8CgIODR0ymUeNT3nbnVpMi7nD79UgkCRHWmVYeZ9tkB4WEnawmkrca+ -hsAg9otUFhpAowZeJ33KVA== ------END NEBULA ED25519 PRIVATE KEY----- diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml deleted file mode 100644 index 0e09599..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml +++ /dev/null @@ -1,5 +0,0 @@ ---- -# handlers file for nebula - -- name: restart nebula - service: name=nebula state=restarted diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml deleted file mode 100644 index ffc89d5..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml +++ /dev/null @@ -1,62 +0,0 @@ ---- -# tasks file for nebula - -- name: get the vagrant network interface and set fact - set_fact: - vagrant_ifce: "ansible_{{ ansible_interfaces | difference(['lo',ansible_default_ipv4.alias]) | sort | first }}" - tags: - - nebula-conf - -- name: install built nebula binary - copy: src="../../../../../build/linux-amd64/{{ item }}" dest="/usr/local/bin" mode=0755 - with_items: - - nebula - - nebula-cert - -- name: create nebula config directory - file: path="{{ nebula_config_directory }}" state=directory mode=0755 - -- name: temporarily copy over root.crt and root.key to sign - copy: src={{ item }} dest=/opt/{{ item }} - with_items: - - vagrant-test-ca.key - - vagrant-test-ca.crt - -- name: remove previously signed host certificate - file: dest=/etc/nebula/{{ item }} state=absent - with_items: - - host.crt - - host.key - -- name: sign using the root key - command: nebula-cert sign -ca-crt /opt/vagrant-test-ca.crt -ca-key /opt/vagrant-test-ca.key -duration 4320h -groups vagrant -ip {{ hostvars[inventory_hostname][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}/9 -name {{ ansible_hostname }}.nebula -out-crt /etc/nebula/host.crt -out-key /etc/nebula/host.key - -- name: remove root.key used to sign - file: dest=/opt/{{ item }} state=absent - with_items: - - vagrant-test-ca.key - -- name: write the content of the trusted ca certificate - copy: src="vagrant-test-ca.crt" dest="/etc/nebula/vagrant-test-ca.crt" - notify: restart nebula - -- name: Create config directory - file: path="{{ nebula_config_directory }}" owner=root group=root mode=0755 state=directory - -- name: nebula config - template: src=config.yml.j2 dest="/etc/nebula/config.yml" mode=0644 owner=root group=root - notify: restart nebula - tags: - - nebula-conf - -- name: nebula systemd - copy: src=systemd.nebula.service dest="/etc/systemd/system/nebula.service" mode=0644 owner=root group=root - register: addconf - notify: restart nebula - -- name: maybe reload systemd - shell: systemctl daemon-reload - when: addconf.changed - -- name: nebula running - service: name="nebula" state=started enabled=yes diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 b/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 deleted file mode 100644 index a05b1e3..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 +++ /dev/null @@ -1,85 +0,0 @@ -pki: - ca: /etc/nebula/vagrant-test-ca.crt - cert: /etc/nebula/host.crt - key: /etc/nebula/host.key - -# Port Nebula will be listening on -listen: - host: 0.0.0.0 - port: 4242 - -# sshd can expose informational and administrative functions via ssh -sshd: - # Toggles the feature - enabled: true - # Host and port to listen on - listen: 127.0.0.1:2222 - # A file containing the ssh host private key to use - host_key: /etc/ssh/ssh_host_ed25519_key - # A file containing a list of authorized public keys - authorized_users: -{% for user in nebula_users %} - - user: {{ user.name }} - keys: -{% for key in user.ssh_auth_keys %} - - "{{ key }}" -{% endfor %} -{% endfor %} - -local_range: 10.168.0.0/16 - -static_host_map: -# lighthouse - {{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}: ["{{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address']}}:4242"] - -default_route: "0.0.0.0" - -lighthouse: -{% if 'lighthouse' in group_names %} - am_lighthouse: true - serve_dns: true -{% else %} - am_lighthouse: false -{% endif %} - interval: 60 -{% if 'generic' in group_names %} - hosts: - - {{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }} -{% endif %} - -# Configure the private interface -tun: - dev: nebula1 - # Sets MTU of the tun dev. - # MTU of the tun must be smaller than the MTU of the eth0 interface - mtu: 1300 - -# TODO -# Configure logging level -logging: - level: info - format: json - -firewall: - conntrack: - tcp_timeout: 12m - udp_timeout: 3m - default_timeout: 10m - - inbound: - - proto: icmp - port: any - host: any - - proto: any - port: 22 - host: any -{% if "lighthouse" in groups %} - - proto: any - port: 53 - host: any -{% endif %} - - outbound: - - proto: any - port: any - host: any diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml deleted file mode 100644 index 7a3ae5d..0000000 --- a/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml +++ /dev/null @@ -1,7 +0,0 @@ ---- -# vars file for nebula - -nebula_users: - - name: user1 - ssh_auth_keys: - - "ed25519 place-your-ssh-public-key-here" diff --git a/examples/quickstart-vagrant/requirements.yml b/examples/quickstart-vagrant/requirements.yml deleted file mode 100644 index 90d4055..0000000 --- a/examples/quickstart-vagrant/requirements.yml +++ /dev/null @@ -1 +0,0 @@ -ansible diff --git a/examples/service_scripts/nebula.open-rc b/examples/service_scripts/nebula.open-rc new file mode 100644 index 0000000..2beca66 --- /dev/null +++ b/examples/service_scripts/nebula.open-rc @@ -0,0 +1,35 @@ +#!/sbin/openrc-run +# +# nebula service for open-rc systems + +extra_commands="checkconfig" + +: ${NEBULA_CONFDIR:=${RC_PREFIX%/}/etc/nebula} +: ${NEBULA_CONFIG:=${NEBULA_CONFDIR}/config.yml} +: ${NEBULA_BINARY:=${NEBULA_BINARY}${RC_PREFIX%/}/usr/local/sbin/nebula} + +command="${NEBULA_BINARY}" +command_args="${NEBULA_OPTS} -config ${NEBULA_CONFIG}" + +supervisor="supervise-daemon" + +description="A scalable overlay networking tool with a focus on performance, simplicity and security" + +required_dirs="${NEBULA_CONFDIR}" +required_files="${NEBULA_CONFIG}" + +checkconfig() { + "${command}" -test ${command_args} || return 1 +} + +start_pre() { + if [ "${RC_CMD}" != "restart" ] ; then + checkconfig || return $? + fi +} + +stop_pre() { + if [ "${RC_CMD}" = "restart" ] ; then + checkconfig || return $? + fi +} diff --git a/firewall.go b/firewall.go index cf2bc52..8a409d2 100644 --- a/firewall.go +++ b/firewall.go @@ -2,37 +2,31 @@ package nebula import ( "crypto/sha256" - "encoding/binary" "encoding/hex" "errors" "fmt" "hash/fnv" - "net" + "net/netip" "reflect" "strconv" "strings" "sync" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" ) -const tcpACK = 0x10 -const tcpFIN = 0x01 - type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error } type conn struct { Expires time.Time // Time when this conntrack entry will expire - Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set - Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack // record why the original connection passed the firewall, so we can re-validate // after ruleset changes. Note, rulesVersion is a uint16 so that these two @@ -58,16 +52,14 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4[struct{}] - assignedCIDR *net.IPNet + localIps *bart.Table[struct{}] + assignedCIDR netip.Prefix hasSubnets bool rules string rulesVersion uint16 defaultLocalCIDRAny bool - trackTCPRTT bool - metricTCPRTT metrics.Histogram incomingMetrics firewallMetrics outgoingMetrics firewallMetrics @@ -116,7 +108,7 @@ type FirewallRule struct { Any *firewallLocalCIDR Hosts map[string]*firewallLocalCIDR Groups []*firewallGroups - CIDR *cidr.Tree4[*firewallLocalCIDR] + CIDR *bart.Table[*firewallLocalCIDR] } type firewallGroups struct { @@ -130,7 +122,7 @@ type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool - LocalCIDR *cidr.Tree4[struct{}] + LocalCIDR *bart.Table[struct{}] } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. @@ -152,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := cidr.NewTree4[struct{}]() - var assignedCIDR *net.IPNet + localIps := new(bart.Table[struct{}]) + var assignedCIDR netip.Prefix + var assignedSet bool for _, ip := range c.Details.Ips { - ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} - localIps.AddCIDR(ipNet, struct{}{}) + //TODO: IPV6-WORK the unmap is a bit unfortunate + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + nprefix := netip.PrefixFrom(nip, nip.BitLen()) + localIps.Insert(nprefix, struct{}{}) - if assignedCIDR == nil { + if !assignedSet { // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = ipNet + assignedCIDR = nprefix + assignedSet = true } } for _, n := range c.Details.Subnets { - localIps.AddCIDR(n, struct{}{}) + nip, _ := netip.AddrFromSlice(n.IP) + ones, _ := n.Mask.Size() + nip = nip.Unmap() + localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{}) } return &Firewall{ @@ -183,7 +183,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D hasSubnets: len(c.Details.Subnets) > 0, l: l, - metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)), incomingMetrics: firewallMetrics{ droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), @@ -246,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // https://github.com/golang/go/issues/14131 sIp := "" - if ip != nil { + if ip.IsValid() { sIp = ip.String() } lIp := "" - if localIp != nil { + if localIp.IsValid() { lIp = localIp.String() } @@ -391,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } - var cidr *net.IPNet + var cidr netip.Prefix if r.Cidr != "" { - _, cidr, err = net.ParseCIDR(r.Cidr) + cidr, err = netip.ParsePrefix(r.Cidr) if err != nil { return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) } } - var localCidr *net.IPNet + var localCidr netip.Prefix if r.LocalCidr != "" { - _, localCidr, err = net.ParseCIDR(r.LocalCidr) + localCidr, err = netip.ParsePrefix(r.LocalCidr) if err != nil { return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) } @@ -422,15 +421,16 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(packet, fp, incoming, h, caPool, localCache) { + if f.inConns(fp, h, caPool, localCache) { return nil } // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - ok, _ := remoteCidr.Contains(fp.RemoteIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := remoteCidr.Lookup(fp.RemoteIP) if !ok { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP @@ -444,7 +444,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos } // Make sure we are supposed to be handling this local ip address - ok, _ := f.localIps.Contains(fp.LocalIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := f.localIps.Lookup(fp.LocalIP) if !ok { f.metrics(incoming).droppedLocalIP.Inc(1) return ErrInvalidLocalIP @@ -462,7 +463,7 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos } // We always want to conntrack since it is a faster operation - f.addConn(packet, fp, incoming) + f.addConn(fp, incoming) return nil } @@ -491,7 +492,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true @@ -551,11 +552,6 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h * switch fp.Protocol { case firewall.ProtoTCP: c.Expires = time.Now().Add(f.TCPTimeout) - if incoming { - f.checkTCPRTT(c, packet) - } else { - setTCPRTTTracking(c, packet) - } case firewall.ProtoUDP: c.Expires = time.Now().Add(f.UDPTimeout) default: @@ -571,16 +567,13 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h * return true } -func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) { +func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { var timeout time.Duration c := &conn{} switch fp.Protocol { case firewall.ProtoTCP: timeout = f.TCPTimeout - if !incoming { - setTCPRTTTracking(c, packet) - } case firewall.ProtoUDP: timeout = f.UDPTimeout default: @@ -606,7 +599,6 @@ func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) { // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! func (f *Firewall) evict(p firewall.Packet) { - //TODO: report a stat if the tcp rtt tracking was never resolved? // Are we still tracking this conn? conntrack := f.Conntrack t, ok := conntrack.Conns[p] @@ -650,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -694,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]*firewallLocalCIDR), Groups: make([]*firewallGroups, 0), - CIDR: cidr.NewTree4[*firewallLocalCIDR](), + CIDR: new(bart.Table[*firewallLocalCIDR]), } } @@ -757,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return fc.CANames[s.Details.Name].match(p, c) } -func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ - LocalCIDR: cidr.NewTree4[struct{}](), + LocalCIDR: new(bart.Table[struct{}]), } } @@ -797,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n fr.Hosts[host] = nlc } - if ip != nil { - _, nlc := fr.CIDR.GetCIDR(ip) + if ip.IsValid() { + nlc, _ := fr.CIDR.Get(ip) if nlc == nil { nlc = flc() } @@ -806,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n if err != nil { return err } - fr.CIDR.AddCIDR(ip, nlc) + fr.CIDR.Insert(ip, nlc) } return nil } -func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { - if len(groups) == 0 && host == "" && ip == nil { +func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool { + if len(groups) == 0 && host == "" && !ip.IsValid() { return true } @@ -827,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return true } - if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) { + if ip.IsValid() && ip.Bits() == 0 { return true } @@ -870,22 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } } - return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { - return flc.match(p, c) + matched := false + prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) + fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { + if prefix.Contains(p.RemoteIP) && val.match(p, c) { + matched = true + return false + } + return true }) + return matched } -func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { - if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) { +func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { + if !localIp.IsValid() { if !f.hasSubnets || f.defaultLocalCIDRAny { flc.Any = true return nil } localIp = f.assignedCIDR + } else if localIp.Bits() == 0 { + flc.Any = true } - flc.LocalCIDR.AddCIDR(localIp, struct{}{}) + flc.LocalCIDR.Insert(localIp, struct{}{}) return nil } @@ -898,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate return true } - ok, _ := flc.LocalCIDR.Contains(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalIP) return ok } @@ -1015,42 +1016,3 @@ func parsePort(s string) (startPort, endPort int32, err error) { return } - -// TODO: write tests for these -func setTCPRTTTracking(c *conn, p []byte) { - if c.Seq != 0 { - return - } - - ihl := int(p[0]&0x0f) << 2 - - // Don't track FIN packets - if p[ihl+13]&tcpFIN != 0 { - return - } - - c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8]) - c.Sent = time.Now() -} - -func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool { - if c.Seq == 0 { - return false - } - - ihl := int(p[0]&0x0f) << 2 - if p[ihl+13]&tcpACK == 0 { - return false - } - - // Deal with wrap around, signed int cuts the ack window in half - // 0 is a bad ack, no data acknowledged - // positive number is a bad ack, ack is over half the window away - if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 { - return false - } - - f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds()) - c.Seq = 0 - return true -} diff --git a/firewall/packet.go b/firewall/packet.go index 1c4affd..8954f4c 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -3,8 +3,7 @@ package firewall import ( "encoding/json" "fmt" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type m map[string]interface{} @@ -20,8 +19,8 @@ const ( ) type Packet struct { - LocalIP iputil.VpnIp - RemoteIP iputil.VpnIp + LocalIP netip.Addr + RemoteIP netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 diff --git a/firewall_test.go b/firewall_test.go index 7d65cb5..4d47e78 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -2,18 +2,16 @@ package nebula import ( "bytes" - "encoding/binary" "errors" "math" "net" + "net/netip" "testing" "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -67,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) { assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) - _, ti, _ := net.ParseCIDR("1.2.3.4/32") + ti, err := netip.ParsePrefix("1.2.3.4/32") + assert.NoError(t, err) - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) + _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) + _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) + anyIp, err := netip.ParsePrefix("0.0.0.0/0") + assert.NoError(t, err) + + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) + assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -154,53 +155,53 @@ func TestFirewall_Drop(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr("1.2.3.4"), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteIP - p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP) + p.RemoteIP = netip.MustParseAddr("1.2.3.10") + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -209,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { TCP: firewallPort{}, } - _, n, _ := net.ParseCIDR("172.1.1.1/32") - goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) - _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") - _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") + pfix := netip.MustParsePrefix("172.1.1.1/32") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -233,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { c := &cert.NebulaCertificate{} - ip, _, _ := net.ParseCIDR("9.254.254.254/32") - lip := iputil.Ip2VpnIp(ip) + ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) } }) @@ -264,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -288,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -365,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -389,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) @@ -408,14 +407,14 @@ func TestFirewall_Drop2(t *testing.T) { h1.CreateRemoteCIDR(&c1) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule) + assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -424,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, @@ -455,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h1.CreateRemoteCIDR(&c1) @@ -470,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h2.CreateRemoteCIDR(&c2) @@ -485,23 +484,23 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h3.CreateRemoteCIDR(&c3) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -510,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -536,39 +535,39 @@ func TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) + assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) } func BenchmarkLookup(b *testing.B) { @@ -727,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) { conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) @@ -749,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr - cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} + cidr := netip.MustParsePrefix("10.0.0.0/8") conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error conf = config.NewC(l) @@ -830,97 +829,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } -func TestTCPRTTTracking(t *testing.T) { - b := make([]byte, 200) - - // Max ip IHL (60 bytes) and tcp IHL (60 bytes) - b[0] = 15 - b[60+12] = 15 << 4 - f := Firewall{ - metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)), - } - - // Set SEQ to 1 - binary.BigEndian.PutUint32(b[60+4:60+8], 1) - - c := &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, uint32(1), c.Seq) - - // Bad ack - no ack flag - binary.BigEndian.PutUint32(b[60+8:60+12], 80) - assert.False(t, f.checkTCPRTT(c, b)) - - // Bad ack, number is too low - binary.BigEndian.PutUint32(b[60+8:60+12], 0) - b[60+13] = uint8(0x10) - assert.False(t, f.checkTCPRTT(c, b)) - - // Good ack - binary.BigEndian.PutUint32(b[60+8:60+12], 80) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to 1 - binary.BigEndian.PutUint32(b[60+4:60+8], 1) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, uint32(1), c.Seq) - - // Good acks - binary.BigEndian.PutUint32(b[60+8:60+12], 81) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to max uint32 - 20 - binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, ^uint32(0)-20, c.Seq) - - // Good acks - binary.BigEndian.PutUint32(b[60+8:60+12], 81) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to max uint32 / 2 - binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, ^uint32(0)/2, c.Seq) - - // Below - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1) - assert.False(t, f.checkTCPRTT(c, b)) - assert.Equal(t, ^uint32(0)/2, c.Seq) - - // Halfway below - binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0)) - assert.False(t, f.checkTCPRTT(c, b)) - assert.Equal(t, ^uint32(0)/2, c.Seq) - - // Halfway above is ok - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) - - // Set SEQ to max uint32 - binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)) - c = &conn{} - setTCPRTTTracking(c, b) - assert.Equal(t, ^uint32(0), c.Seq) - - // Halfway + 1 above - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1) - assert.False(t, f.checkTCPRTT(c, b)) - assert.Equal(t, ^uint32(0), c.Seq) - - // Halfway above - binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2) - assert.True(t, f.checkTCPRTT(c, b)) - assert.Equal(t, uint32(0), c.Seq) -} - func TestFirewall_convertRule(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} @@ -964,8 +872,8 @@ type addRuleCall struct { endPort int32 groups []string host string - ip *net.IPNet - localIp *net.IPNet + ip netip.Prefix + localIp netip.Prefix caName string caSha string } @@ -975,7 +883,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto, diff --git a/go.mod b/go.mod index abc1134..1da2056 100644 --- a/go.mod +++ b/go.mod @@ -1,52 +1,55 @@ module github.com/slackhq/nebula -go 1.20 +go 1.22.0 + +toolchain go1.22.2 require ( dario.cat/mergo v1.0.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 - github.com/flynn/noise v1.0.1 + github.com/flynn/noise v1.1.0 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.58 + github.com/miekg/dns v1.1.61 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.18.0 + github.com/prometheus/client_golang v1.19.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.18.0 - golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 - golang.org/x/net v0.20.0 - golang.org/x/sync v0.6.0 - golang.org/x/sys v0.16.0 - golang.org/x/term v0.16.0 + golang.org/x/crypto v0.24.0 + golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 + golang.org/x/net v0.26.0 + golang.org/x/sync v0.7.0 + golang.org/x/sys v0.21.0 + golang.org/x/term v0.21.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.32.0 + google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v2 v2.4.0 - gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f + gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) require ( github.com/beorn7/perks v1.0.1 // indirect + github.com/bits-and-blooms/bitset v1.13.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/btree v1.0.1 // indirect - github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect + github.com/gaissmai/bart v0.11.1 // indirect + github.com/google/btree v1.1.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect - github.com/prometheus/common v0.45.0 // indirect + github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect - golang.org/x/mod v0.14.0 // indirect - golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect - golang.org/x/tools v0.17.0 // indirect + golang.org/x/mod v0.18.0 // indirect + golang.org/x/time v0.5.0 // indirect + golang.org/x/tools v0.22.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6226e15..6db0c4a 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= +github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -22,8 +24,12 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go. github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/flynn/noise v1.0.1 h1:vPp/jdQLXC6ppsXSj/pM3W1BIJ5FEHE2TulSJBpb43Y= -github.com/flynn/noise v1.0.1/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= +github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI= +github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= +github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= +github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -44,14 +50,15 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= -github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= @@ -71,14 +78,13 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= -github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= -github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= +github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs= +github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -96,8 +102,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= -github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -106,8 +112,8 @@ github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= -github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= @@ -117,6 +123,7 @@ github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3c github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= @@ -132,8 +139,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= @@ -146,16 +153,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o= -golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -166,8 +173,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -175,8 +182,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -194,23 +201,23 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE= -golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= +golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= -golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -229,8 +236,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= -google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -246,5 +253,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f h1:8GE2MRjGiFmfpon8dekPI08jEuNMQzSffVHgdupcO4E= -gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q= +gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g= +gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= diff --git a/handshake_ix.go b/handshake_ix.go index 1905c00..8cf5341 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,13 +1,12 @@ package nebula import ( + "net/netip" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // NOISE IX Handshakes @@ -46,7 +45,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1) - ci.messageCounter.Add(1) msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { @@ -64,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return true } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { certState := f.pki.GetCertState() ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed @@ -90,17 +88,36 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). - Info("Invalid certificate from host") + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid certificate from host") + return + } + + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") return } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer - if vpnIp == f.myVpnIp { + if vpnIp == f.myVpnNet.Addr() { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -109,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -134,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -214,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). @@ -280,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). @@ -316,13 +333,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by } f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.ConnectionState.messageCounter.Store(2) + hostinfo.remotes.ResetBlockedRemotes() return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -332,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha defer hh.Unlock() hostinfo := hh.hostinfo - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } @@ -372,15 +389,33 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). - WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Error("Invalid certificate from host") + e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Error("Invalid certificate from host") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") + return true + } + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer @@ -406,7 +441,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). - WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)). + WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). Info("Blocked addresses for handshakes") // Swap the packet store to benefit the original intended recipient @@ -444,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha ci.eKey = NewNebulaCipherState(eKey) // Make sure the current udpAddr being used is set for responding - if addr != nil { + if addr.IsValid() { hostinfo.SetRemote(addr) } else { hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) @@ -457,8 +492,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) - hostinfo.ConnectionState.messageCounter.Store(2) - if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) } diff --git a/handshake_manager.go b/handshake_manager.go index b568cc8..7960435 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,15 +6,15 @@ import ( "crypto/rand" "encoding/binary" "errors" - "net" + "net/netip" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" + "golang.org/x/exp/slices" ) const ( @@ -46,14 +46,14 @@ type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps sync.RWMutex - vpnIps map[iputil.VpnIp]*HandshakeHostInfo + vpnIps map[netip.Addr]*HandshakeHostInfo indexes map[uint32]*HandshakeHostInfo mainHostMap *HostMap lightHouse *LightHouse outside udp.Conn config HandshakeConfig - OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] + OutboundHandshakeTimer *LockingTimerWheel[netip.Addr] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter @@ -61,17 +61,17 @@ type HandshakeManager struct { l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp - trigger chan iputil.VpnIp + trigger chan netip.Addr } type HandshakeHostInfo struct { sync.Mutex - startTime time.Time // Time that we first started trying with this handshake - ready bool // Is the handshake ready - counter int // How many attempts have we made so far - lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt - packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int // How many attempts have we made so far + lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } @@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, config: config, - trigger: make(chan iputil.VpnIp, config.triggerBuffer), - OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + trigger: make(chan netip.Addr, config.triggerBuffer), + OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) { } } -func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp - if addr != nil { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + if addr.IsValid() { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { } } -func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { +func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) { hh := hm.queryVpnIp(vpnIp) if hh == nil { return @@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { - hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)). + hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). WithField("initiatorIndex", hh.hostinfo.localIndexId). WithField("remoteIndex", hh.hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). @@ -211,8 +211,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) } - remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges) - remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) + remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) + remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. @@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply - var sentTo []*udp.Addr - hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { + var sentTo []netip.AddrPort + hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { @@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { + if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) - if relayHostInfo == nil || relayHostInfo.remote == nil { + relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - hm.f.Handshake(*relay) + hm.f.Handshake(relay) continue } // Check the relay HostInfo to see if we already established a relay through it @@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } default: @@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } } else { // No relays exist or requested yet. - if relayHostInfo.remote != nil { + if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": idx, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } } @@ -355,32 +364,32 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic -func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { - // Check the main hostmap and maintain a read lock if our host is not there +func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { hm.mainHostMap.RLock() - if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok { - hm.mainHostMap.RUnlock() + h, ok := hm.mainHostMap.Hosts[vpnIp] + hm.mainHostMap.RUnlock() + + if ok { // Do not attempt promotion if you are a lighthouse if !hm.lightHouse.amLighthouse { - h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f) + h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f) } return h, true } - defer hm.mainHostMap.RUnlock() return hm.StartHandshake(vpnIp, cacheCb), false } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() - defer hm.Unlock() if hh, ok := hm.vpnIps[vpnIp]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) } + hm.Unlock() return hh.hostinfo } @@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -421,6 +430,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han } } + hm.Unlock() hm.lightHouse.QueryServer(vpnIp) return hostinfo } @@ -554,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { delete(c.vpnIps, hostinfo.vpnIp) if len(c.vpnIps) == 0 { - c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} + c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } delete(c.indexes, hostinfo.localIndexId) @@ -569,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -578,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { } -func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { +func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo { hm.RLock() defer hm.RUnlock() return hm.vpnIps[vpnIp] @@ -598,8 +608,8 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { - return c.mainHostMap.preferredRanges +func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { + return c.mainHostMap.GetPreferredRanges() } func (c *HandshakeManager) ForEachVpnIp(f controlEach) { diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 303aa50..a78b45f 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -1,13 +1,12 @@ package nebula import ( - "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -15,11 +14,14 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} - mainHM := NewHostMap(l, vpncidr, preferredRanges) + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + ip := netip.MustParseAddr("172.1.1.2") + + preferredRanges := []netip.Prefix{localrange} + mainHM := newHostMap(l, vpncidr) + mainHM.preferredRanges.Store(&preferredRanges) + lh := newTestLighthouse() cs := &CertState{ @@ -64,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.vpnIps, ip) } -func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { +func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { for _, i := range tw.t.wheel { n := i.Head for n != nil { @@ -78,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { return } @@ -90,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M return } -func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {} +func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} diff --git a/hostmap.go b/hostmap.go index a5adeb9..fb97b76 100644 --- a/hostmap.go +++ b/hostmap.go @@ -3,17 +3,17 @@ package nebula import ( "errors" "net" + "net/netip" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // const ProbeLen = 100 @@ -48,7 +48,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp iputil.VpnIp + PeerIp netip.Addr } type HostMap struct { @@ -56,10 +56,9 @@ type HostMap struct { Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo - Hosts map[iputil.VpnIp]*HostInfo - preferredRanges []*net.IPNet - vpnCIDR *net.IPNet - metricsEnabled bool + Hosts map[netip.Addr]*HostInfo + preferredRanges atomic.Pointer[[]netip.Prefix] + vpnCIDR netip.Prefix l *logrus.Logger } @@ -69,12 +68,12 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer + relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } -func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { +func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() delete(rs.relays, ip) @@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[ip] return r, ok } -func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) { +func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() rs.relays[ip] = struct{}{} } -func (rs *RelayState) CopyRelayIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - ret := make([]iputil.VpnIp, 0, len(rs.relays)) + ret := make([]netip.Addr, 0, len(rs.relays)) for ip := range rs.relays { ret = append(ret, ip) } return ret } -func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp)) + currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) for relayIp := range rs.relayForByIp { currentRelays = append(currentRelays, relayIp) } @@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { return ret } -func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) { - rs.Lock() - defer rs.Unlock() - r, ok := rs.relayForByIdx[localIdx] - if !ok { - return iputil.VpnIp(0), false - } - delete(rs.relayForByIdx, localIdx) - delete(rs.relayForByIp, r.PeerIp) - return r.PeerIp, true -} - -func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool { +func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() r, ok := rs.relayForByIp[vpnIp] @@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re return &newRelay, true } -func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[vpnIp] @@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { return r, ok } -func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { +func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() rs.relayForByIp[ip] = r @@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { } type HostInfo struct { - remote *udp.Addr + remote netip.AddrPort remotes *RemoteList promoteCounter atomic.Uint32 ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp iputil.VpnIp + vpnIp netip.Addr recvError atomic.Uint32 - remoteCidr *cidr.Tree4[struct{}] + remoteCidr *bart.Table[struct{}] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -227,7 +214,7 @@ type HostInfo struct { lastHandshakeTime uint64 lastRoam time.Time - lastRoamRemote *udp.Addr + lastRoamRemote netip.AddrPort // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. @@ -254,21 +241,53 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { - h := map[iputil.VpnIp]*HostInfo{} - i := map[uint32]*HostInfo{} - r := map[uint32]*HostInfo{} - relays := map[uint32]*HostInfo{} - m := HostMap{ - Indexes: i, - Relays: relays, - RemoteIndexes: r, - Hosts: h, - preferredRanges: preferredRanges, - vpnCIDR: vpnCIDR, - l: l, +func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { + hm := newHostMap(l, vpnCIDR) + + hm.reload(c, true) + c.RegisterReloadCallback(func(c *config.C) { + hm.reload(c, false) + }) + + l.WithField("network", hm.vpnCIDR.String()). + WithField("preferredRanges", hm.GetPreferredRanges()). + Info("Main HostMap created") + + return hm +} + +func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { + return &HostMap{ + Indexes: map[uint32]*HostInfo{}, + Relays: map[uint32]*HostInfo{}, + RemoteIndexes: map[uint32]*HostInfo{}, + Hosts: map[netip.Addr]*HostInfo{}, + vpnCIDR: vpnCIDR, + l: l, + } +} + +func (hm *HostMap) reload(c *config.C, initial bool) { + if initial || c.HasChanged("preferred_ranges") { + var preferredRanges []netip.Prefix + rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) + + for _, rawPreferredRange := range rawPreferredRanges { + preferredRange, err := netip.ParsePrefix(rawPreferredRange) + + if err != nil { + hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") + continue + } + + preferredRanges = append(preferredRanges, preferredRange) + } + + oldRanges := hm.preferredRanges.Swap(&preferredRanges) + if !initial { + hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") + } } - return &m } // EmitStats reports host, index, and relay counts to the stats collection system @@ -346,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it delete(hm.Hosts, hostinfo.vpnIp) if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} + hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { @@ -429,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { return hm.queryVpnIp(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -451,13 +470,13 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() // Do not attempt promotion if you are a lighthouse if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { - h.TryPromoteBest(hm.preferredRanges, promoteIfce) + h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce) } return h @@ -503,8 +522,9 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } -func (hm *HostMap) GetPreferredRanges() []*net.IPNet { - return hm.preferredRanges +func (hm *HostMap) GetPreferredRanges() []netip.Prefix { + //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer + return *hm.preferredRanges.Load() } func (hm *HostMap) ForEachVpnIp(f controlEach) { @@ -527,14 +547,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) { // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! -func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { +func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { remote := i.remote // return early if we are already on a preferred remote - if remote != nil { - rIP := remote.IP + if remote.IsValid() { + rIP := remote.Addr() for _, l := range preferredRanges { if l.Contains(rIP) { return @@ -542,8 +562,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } - i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) { - if remote != nil && (addr == nil || !preferred) { + i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { + if remote.IsValid() && (!addr.IsValid() || !preferred) { return } @@ -572,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate { return nil } -func (i *HostInfo) SetRemote(remote *udp.Addr) { +func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object - if !i.remote.Equals(remote) { - i.remote = remote.Copy() - i.remotes.LearnRemote(i.vpnIp, remote.Copy()) + if i.remote != remote { + i.remote = remote + i.remotes.LearnRemote(i.vpnIp, remote) } } // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. -func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { - if newRemote == nil { +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { + if !newRemote.IsValid() { // relays have nil udp Addrs return false } currentRemote := i.remote - if currentRemote == nil { + if !currentRemote.IsValid() { i.SetRemote(newRemote) return true } @@ -596,13 +616,13 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { // NOTE: We do this loop here instead of calling `isPreferred` in // remote_list.go so that we only have to loop over preferredRanges once. newIsPreferred := false - for _, l := range hm.preferredRanges { + for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote - if l.Contains(currentRemote.IP) { + if l.Contains(currentRemote.Addr()) { return false } - if l.Contains(newRemote.IP) { + if l.Contains(newRemote.Addr()) { newIsPreferred = true } } @@ -610,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { if newIsPreferred { // Consider this a roaming event i.lastRoam = time.Now() - i.lastRoamRemote = currentRemote.Copy() + i.lastRoamRemote = currentRemote i.SetRemote(newRemote) @@ -633,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := cidr.NewTree4[struct{}]() + remoteCidr := new(bart.Table[struct{}]) for _, ip := range c.Details.Ips { - remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + bits, _ := ip.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } for _, n := range c.Details.Subnets { - remoteCidr.AddCIDR(n, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? + nip, _ := netip.AddrFromSlice(n.IP) + nip = nip.Unmap() + bits, _ := n.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } i.remoteCidr = remoteCidr } @@ -664,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { +func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []net.IP + var ips []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -688,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { ip = v.IP } + nip, ok := netip.AddrFromSlice(ip) + if !ok { + if l.Level >= logrus.DebugLevel { + l.WithField("localIp", ip).Debug("ip was invalid for netip") + } + continue + } + nip = nip.Unmap() + //TODO: Filtering out link local for now, this is probably the most correct thing //TODO: Would be nice to filter out SLAAC MAC based ips as well - if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() { - allow := allowList.Allow(ip) + if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { + allow := allowList.Allow(nip) if l.Level >= logrus.TraceLevel { - l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow") + l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") } if !allow { continue } - ips = append(ips, ip) + ips = append(ips, nip) } } } - return &ips + return ips } diff --git a/hostmap_test.go b/hostmap_test.go index c1c0dce..7e2feb8 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -1,30 +1,27 @@ package nebula import ( - "net" + "net/netip" "testing" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() - hm := NewHostMap( + hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - []*net.IPNet{}, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -91,23 +88,19 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() - hm := NewHostMap( + hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - []*net.IPNet{}, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} - h5 := &HostInfo{vpnIp: 1, localIndexId: 5} - h6 := &HostInfo{vpnIp: 1, localIndexId: 6} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} + h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -123,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -142,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -160,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -176,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -190,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -202,6 +195,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } + +func TestHostMap_reload(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + + hm := NewHostMapFromConfig( + l, + netip.MustParsePrefix("10.0.0.1/24"), + c, + ) + + toS := func(ipn []netip.Prefix) []string { + var s []string + for _, n := range ipn { + s = append(s, n.String()) + } + return s + } + + assert.Empty(t, hm.GetPreferredRanges()) + + c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") + assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) + + c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") + assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) +} diff --git a/hostmap_tester.go b/hostmap_tester.go index 0d5d41b..b2d1d1b 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -5,9 +5,11 @@ package nebula // This file contains functions used to export information to the e2e testing framework -import "github.com/slackhq/nebula/iputil" +import ( + "net/netip" +) -func (i *HostInfo) GetVpnIp() iputil.VpnIp { +func (i *HostInfo) GetVpnIp() netip.Addr { return i.vpnIp } diff --git a/inside.go b/inside.go index 6230962..0ccd179 100644 --- a/inside.go +++ b/inside.go @@ -1,12 +1,13 @@ package nebula import ( + "net/netip" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" - "github.com/slackhq/nebula/udp" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast { + if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { return } - if fwPacket.RemoteIP == f.myVpnIp { + if fwPacket.RemoteIP == f.myVpnNet.Addr() { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula IP to the Nebula IP through the Nebula @@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - // Ignore broadcast packets - if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { return } @@ -62,9 +63,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { f.rejectInside(packet, out, q) @@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp iputil.VpnIp) { +func (f *Interface) Handshake(vpnIp netip.Addr) { f.getOrHandshake(vpnIp, nil) } // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { +func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + if !f.myVpnNet.Contains(vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return nil, false } } @@ -142,7 +143,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). @@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) @@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) { +func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } @@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo, f.connectionManager.RelayUsed(relay.LocalIndex) } -func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { if ci.eKey == nil { //TODO: log warning return } - useRelay := remote == nil && hostinfo.remote == nil + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() fullOut := out if useRelay { @@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType return } - if remote != nil { + if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { hostinfo.logger(f.l).WithError(err). @@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } - -func isMulticast(ip iputil.VpnIp) bool { - // Class D multicast - return (((ip >> 24) & 0xff) & 0xf0) == 0xe0 -} diff --git a/interface.go b/interface.go index d16348a..f251907 100644 --- a/interface.go +++ b/interface.go @@ -2,10 +2,11 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" "io" - "net" + "net/netip" "os" "runtime" "sync/atomic" @@ -16,7 +17,6 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -63,8 +63,8 @@ type Interface struct { serveDns bool createTime time.Time lightHouse *LightHouse - localBroadcast iputil.VpnIp - myVpnIp iputil.VpnIp + myBroadcastAddr netip.Addr + myVpnNet netip.Prefix dropLocalBroadcast bool dropMulticast bool routines int @@ -102,9 +102,9 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) + Handshake(vpnIp netip.Addr) } type sendRecvErrorConfig uint8 @@ -115,10 +115,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.IsPrivate() + return ip.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { } certificate := c.pki.GetCertState().Certificate - myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) + + myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP) + } + + myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask) + if !ok { + return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask) + } + + myVpnAddr = myVpnAddr.Unmap() + myVpnMask = myVpnMask.Unmap() + + if myVpnAddr.BitLen() != myVpnMask.BitLen() { + return nil, fmt.Errorf("ip address and mask are different lengths in certificate") + } + + ones, _ := certificate.Details.Ips[0].Mask.Size() + myVpnNet := netip.PrefixFrom(myVpnAddr, ones) + ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, @@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - myVpnIp: myVpnIp, + myVpnNet: myVpnNet, relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } + if myVpnAddr.Is4() { + addr := myVpnNet.Masked().Addr().As4() + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + ifce.myBroadcastAddr = netip.AddrFrom4(addr) + } + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) diff --git a/iputil/packet.go b/iputil/packet.go index b18e524..719e034 100644 --- a/iputil/packet.go +++ b/iputil/packet.go @@ -6,6 +6,8 @@ import ( "golang.org/x/net/ipv4" ) +//TODO: IPV6-WORK can probably delete this + const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header diff --git a/iputil/util.go b/iputil/util.go deleted file mode 100644 index 65f7677..0000000 --- a/iputil/util.go +++ /dev/null @@ -1,93 +0,0 @@ -package iputil - -import ( - "encoding/binary" - "fmt" - "net" - "net/netip" -) - -type VpnIp uint32 - -const maxIPv4StringLen = len("255.255.255.255") - -func (ip VpnIp) String() string { - b := make([]byte, maxIPv4StringLen) - - n := ubtoa(b, 0, byte(ip>>24)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>16&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>8&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip&255)) - return string(b[:n]) -} - -func (ip VpnIp) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil -} - -func (ip VpnIp) ToIP() net.IP { - nip := make(net.IP, 4) - binary.BigEndian.PutUint32(nip, uint32(ip)) - return nip -} - -func (ip VpnIp) ToNetIpAddr() netip.Addr { - var nip [4]byte - binary.BigEndian.PutUint32(nip[:], uint32(ip)) - return netip.AddrFrom4(nip) -} - -func Ip2VpnIp(ip []byte) VpnIp { - if len(ip) == 16 { - return VpnIp(binary.BigEndian.Uint32(ip[12:16])) - } - return VpnIp(binary.BigEndian.Uint32(ip)) -} - -func ToNetIpAddr(ip net.IP) (netip.Addr, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip) - } - return addr, nil -} - -func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) { - addr, err := ToNetIpAddr(ipNet.IP) - if err != nil { - return netip.Prefix{}, err - } - ones, bits := ipNet.Mask.Size() - if ones == 0 && bits == 0 { - return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet) - } - return netip.PrefixFrom(addr, ones), nil -} - -// ubtoa encodes the string form of the integer v to dst[start:] and -// returns the number of bytes written to dst. The caller must ensure -// that dst has sufficient length. -func ubtoa(dst []byte, start int, v byte) int { - if v < 10 { - dst[start] = v + '0' - return 1 - } else if v < 100 { - dst[start+1] = v%10 + '0' - dst[start] = v/10 + '0' - return 2 - } - - dst[start+2] = v%10 + '0' - dst[start+1] = (v/10)%10 + '0' - dst[start] = v/100 + '0' - return 3 -} diff --git a/iputil/util_test.go b/iputil/util_test.go deleted file mode 100644 index 712d426..0000000 --- a/iputil/util_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package iputil - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestVpnIp_String(t *testing.T) { - assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String()) - assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String()) - assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String()) - assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String()) - assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String()) - assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String()) -} diff --git a/lighthouse.go b/lighthouse.go index aa54c4b..62f4065 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,16 +7,16 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -26,25 +26,18 @@ import ( var ErrHostNotKnown = errors.New("host not known") -type netIpAndPort struct { - ip net.IP - port uint16 -} - type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnIp iputil.VpnIp - myVpnZeros iputil.VpnIp - myVpnNet *net.IPNet + myVpnNet netip.Prefix punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses // map of vpn Ip to answers - addrMap map[iputil.VpnIp]*RemoteList + addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and @@ -57,26 +50,26 @@ type LightHouse struct { localAllowList atomic.Pointer[LocalAllowList] // used to trigger the HandshakeManager when we receive HostQueryReply - handshakeTrigger chan<- iputil.VpnIp + handshakeTrigger chan<- netip.Addr // staticList exists to avoid having a bool in each addrMap entry // since static should be rare - staticList atomic.Pointer[map[iputil.VpnIp]struct{}] - lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] + staticList atomic.Pointer[map[netip.Addr]struct{}] + lighthouses atomic.Pointer[map[netip.Addr]struct{}] interval atomic.Int64 updateCancel context.CancelFunc ifce EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - advertiseAddrs atomic.Pointer[[]netIpAndPort] + advertiseAddrs atomic.Pointer[[]netip.AddrPort] // IP's of relays that can be used by peers to access me - relaysForMe atomic.Pointer[[]iputil.VpnIp] + relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan iputil.VpnIp + queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -85,7 +78,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - nebulaPort = uint32(uPort.Port) + nebulaPort = uint32(uPort.Port()) } - ones, _ := myVpnNet.Mask.Size() h := LightHouse{ ctx: ctx, amLighthouse: amLighthouse, - myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), - myVpnZeros: iputil.VpnIp(32 - ones), myVpnNet: myVpnNet, - addrMap: make(map[iputil.VpnIp]*RemoteList), + addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, punchConn: pc, punchy: p, - queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } - lighthouses := make(map[iputil.VpnIp]struct{}) + lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) - staticList := make(map[iputil.VpnIp]struct{}) + staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) if c.GetBool("stats.lighthouse_metrics", false) { @@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, return &h, nil } -func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} { return *lh.staticList.Load() } -func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { return *lh.lighthouses.Load() } @@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { return lh.localAllowList.Load() } -func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { +func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort { return *lh.advertiseAddrs.Load() } -func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { +func (lh *LightHouse) GetRelaysForMe() []netip.Addr { return *lh.relaysForMe.Load() } -func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { +func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } @@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 { func (lh *LightHouse) reload(c *config.C, initial bool) error { if initial || c.HasChanged("lighthouse.advertise_addrs") { rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) - advAddrs := make([]netIpAndPort, 0) + advAddrs := make([]netip.AddrPort, 0) for i, rawAddr := range rawAdvAddrs { - fIp, fPort, err := udp.ParseIPAndPort(rawAddr) + host, sport, err := net.SplitHostPort(rawAddr) if err != nil { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if fPort == 0 { - fPort = uint16(lh.nebulaPort) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) + if err != nil { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + if len(ips) == 0 { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) + } + + port, err := strconv.Atoi(sport) + if err != nil { + return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + + if port == 0 { + port = int(lh.nebulaPort) } - if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) { + //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used + ip := ips[0].Unmap() + if lh.myVpnNet.Contains(ip) { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) + advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.RUnlock() } // Build a new list based on current config. - staticList := make(map[iputil.VpnIp]struct{}) - err := lh.loadStaticMap(c, lh.myVpnNet, staticList) + staticList := make(map[netip.Addr]struct{}) + err := lh.loadStaticMap(c, staticList) if err != nil { return err } @@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.hosts") { - lhMap := make(map[iputil.VpnIp]struct{}) - err := lh.parseLighthouses(c, lh.myVpnNet, lhMap) + lhMap := make(map[netip.Addr]struct{}) + err := lh.parseLighthouses(c, lhMap) if err != nil { return err } @@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if len(c.GetStringSlice("relay.relays", nil)) > 0 { lh.l.Info("Ignoring relays from config because am_relay is true") } - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} lh.relaysForMe.Store(&relaysForMe) case false: - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { lh.l.WithField("relay", v).Info("Read relay from config") - configRIP := net.ParseIP(v) - if configRIP != nil { - relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) + configRIP, err := netip.ParseAddr(v) + //TODO: We could print the error here + if err == nil { + relaysForMe = append(relaysForMe, configRIP) } } lh.relaysForMe.Store(&relaysForMe) @@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return nil } -func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { lhs := c.GetStringSlice("lighthouse.hosts", []string{}) if lh.amLighthouse && len(lhs) != 0 { lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } for i, host := range lhs { - ip := net.ParseIP(host) - if ip == nil { - return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) + ip, err := netip.ParseAddr(host) + if err != nil { + return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !tunCidr.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + if !lh.myVpnNet.Contains(ip) { + return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) } - lhMap[iputil.Ip2VpnIp(ip)] = struct{}{} + lhMap[ip] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) { return network, nil } -func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error { d, err := getStaticMapCadence(c) if err != nil { return err @@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return err } - lookup_timeout, err := getStaticMapLookupTimeout(c) + lookupTimeout, err := getStaticMapLookupTimeout(c) if err != nil { return err } @@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList i := 0 for k, v := range shm { - rip := net.ParseIP(fmt.Sprintf("%v", k)) - if rip == nil { - return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil) + vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + if err != nil { + return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !tunCidr.Contains(rip) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil) + if !lh.myVpnNet.Contains(vpnIp) { + return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) } - vpnIp := iputil.Ip2VpnIp(rip) vals, ok := v.([]interface{}) if !ok { vals = []interface{}{v} @@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList) if err != nil { return err } @@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip) } @@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { +func (lh *LightHouse) QueryServer(ip netip.Addr) { // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses if lh.amLighthouse || lh.IsLighthouseIP(ip) { return @@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { lh.queryChan <- ip } -func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.RLock() if v, ok := lh.addrMap[ip]; ok { lh.RUnlock() @@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? if v, ok := lh.addrMap[vpnIp]; ok { @@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { +func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { // First we check the static mapping // and do nothing if it is there if _, ok := lh.GetStaticHostList()[vpnIp]; ok { @@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() @@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetIPs() { - + if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + continue + } switch { case addrPort.Addr().Is4(): - to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV4(vpnIp, to) { - continue - } - am.unlockedPrependV4(lh.myVpnIp, to) + am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV6(vpnIp, to) { - continue - } - am.unlockedPrependV6(lh.myVpnIp, to) + am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) } } @@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnIp) if !ok { return false } @@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) + am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) return len(calculated) > 0 } // unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { +func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { am, ok := lh.addrMap[vpnIp] if !ok { am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) @@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { return am } -func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { - switch { - case to.Is4(): - ipBytes := to.As4() - ip := iputil.Ip2VpnIp(ipBytes[:]) - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { - return false - } - case to.Is6(): - ipBytes := to.As16() - - hi := binary.BigEndian.Uint64(ipBytes[:8]) - lo := binary.BigEndian.Uint64(ipBytes[8:]) - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { - return false - } +func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnIp, to) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + if !allow || lh.myVpnNet.Contains(to) { + return false } + return true } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { + ip := AddrPortFromIp4AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } @@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo) +func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { + ip := AddrPortFromIp6AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } @@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP { return ip } -func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { if _, ok := lh.GetLighthouses()[vpnIp]; ok { return true } return false } -func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta { +func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { + if vpnIp.Is6() { + //TODO: need to support ipv6 + panic("ipv6 is not yet supported") + } + + b := vpnIp.As4() return &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(VpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), }, } } -func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { - ipp := Ip4AndPort{Port: port} - ipp.Ip = uint32(iputil.Ip2VpnIp(ip)) - return &ipp +func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ip.Ip) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) +} + +func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ip.Hi) + binary.BigEndian.PutUint64(b[8:], ip.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) } func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { @@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { } } -func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip[:8]), - Lo: binary.BigEndian.Uint64(ip[8:]), - Port: port, - } -} - +// TODO: IPV6-WORK we can delete some more of these func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { ip6Addr := ip.As16() return &Ip6AndPort{ @@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { Port: uint32(port), } } -func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { - ip := ipp.Ip - return udp.NewAddr( - net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)), - uint16(ipp.Port), - ) -} - -func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { - return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) -} func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { @@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() { }() } -func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { +func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { if lh.IsLighthouseIP(ip) { return } @@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() { var v6 []*Ip6AndPort for _, e := range lh.GetAdvertiseAddrs() { - if ip := e.ip.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port))) + if e.Addr().Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port))) + v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range *localIps(lh.l, lal) { - if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { + for _, e := range localIps(lh.l, lal) { + if lh.myVpnNet.Contains(e) { continue } // Only add IPs that aren't my VPN/tun IP - if ip := e.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort)) + if e.Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort)) + v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) } } var relays []uint32 for _, r := range lh.GetRelaysForMe() { - relays = append(relays, (uint32)(r)) + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) } + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := lh.myVpnNet.Addr().As4() + m := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(lh.myVpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), Ip4AndPorts: v4, Ip6AndPorts: v6, RelayVpnIp: relays, @@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { } func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { lhh.HandleRequest(rAddr, vpnIp, p, f) } } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, //TODO: we can DRY this further reqVpnIp := n.Details.VpnIp + + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) + //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply n.Details.VpnIp = reqVpnIp @@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - n.Details.VpnIp = uint32(vpnIp) - + //TODO: IPV6-WORK + b = vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) lhh.coalesceAnswers(c, n) return n.MarshalTo(lhh.pb) @@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0]) + + //TODO: IPV6-WORK + binary.BigEndian.PutUint32(b[:], reqVpnIp) + sendTo := netip.AddrFrom4(b) + w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { @@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { } if c.relay != nil { - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...) + //TODO: IPV6-WORK + relays := make([]uint32, len(c.relay.relay)) + b := [4]byte{} + for i, _ := range relays { + b = c.relay.relay[i].As4() + relays[i] = binary.BigEndian.Uint32(b[:]) + } + n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + certVpnIp := netip.AddrFrom4(b) + am := lhh.lh.unlockedGetRemoteList(certVpnIp) am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp): + case lhh.lh.handshakeTrigger <- certVpnIp: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) @@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp } //Simple check that the host sent this not someone else - if n.Details.VpnIp != uint32(vpnIp) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + detailsVpnIp := netip.AddrFrom4(b) + if detailsVpnIp != vpnIp { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } @@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - n.Details.VpnIp = uint32(vpnIp) + + //TODO: IPV6-WORK + vpnIpB := vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) ln, err := n.MarshalTo(lhh.pb) if err != nil { @@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } empty := []byte{0} - punch := func(vpnPeer *udp.Addr) { - if vpnPeer == nil { + punch := func(vpnPeer netip.AddrPort) { + if !vpnPeer.IsValid() { return } @@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i if lhh.l.Level >= logrus.DebugLevel { //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK, make this debug line not suck + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) } } for _, a := range n.Details.Ip4AndPorts { - punch(NewUDPAddrFromLH4(a)) + punch(AddrPortFromIp4AndPort(a)) } for _, a := range n.Details.Ip6AndPorts { - punch(NewUDPAddrFromLH6(a)) + punch(AddrPortFromIp6AndPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - queryVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { @@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i }() } } - -// ipMaskContains checks if testIp is contained by ip after applying a cidr -// zeros is 32 - bits from net.IPMask.Size() -func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool { - return (testIp^ip)>>zeros == 0 -} diff --git a/lighthouse_test.go b/lighthouse_test.go index 66427e3..2599f5f 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -2,15 +2,14 @@ package nebula import ( "context" + "encoding/binary" "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) { var m Ip4AndPort err := m.Unmarshal(b) assert.NoError(t, err) - assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String()) + ip := netip.MustParseAddr("10.1.1.1") + bp := ip.As4() + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) } func TestNewLhQuery(t *testing.T) { - myIp := net.ParseIP("192.1.1.1") - myIpint := iputil.Ip2VpnIp(myIp) + myIp, err := netip.ParseAddr("192.1.1.1") + assert.NoError(t, err) // Generating a new lh query should work - a := NewLhQueryByInt(myIpint) + a := NewLhQueryByInt(myIp) // The result should be a nebulameta protobuf assert.IsType(t, &NebulaMeta{}, a) @@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) { lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - c.ReloadConfigString("lighthouse:\n interval: 5") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload function - c.ReloadConfigString("lighthouse:\n interval: 10") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - c.ReloadConfigString("lighthouse:\n interval: 11") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") + myVpnNet := netip.MustParsePrefix("10.128.0.1/0") c := config.NewC(l) lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) @@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { b.Fatal() } - hAddr := udp.NewAddrFromString("4.5.6.7:12345") - hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList(nil) - lh.addrMap[3].unlockedSetV4( - 3, - 3, + hAddr := netip.MustParseAddrPort("4.5.6.7:12345") + hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") + + vpnIp3 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), - NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), + NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), + NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) - rAddr := udp.NewAddrFromString("1.2.2.3:12345") - rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList(nil) - lh.addrMap[2].unlockedSetV4( - 3, - 3, + rAddr := netip.MustParseAddrPort("1.2.2.3:12345") + rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") + vpnIp2 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), - NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), + NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), + NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) mw := &mockEncWriter{} @@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) } @@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { func TestLighthouse_Memory(t *testing.T) { l := test.NewLogger() - myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} - myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} - myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242} - myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242} - myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242} - myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243} - myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244} - myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245} - myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246} - myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247} - myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248} - myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249} - myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2")) - - theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242} - theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242} - theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242} - theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242} - theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} - theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) + myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242") + myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242") + myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242") + myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242") + myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242") + myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243") + myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244") + myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245") + myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246") + myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247") + myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248") + myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249") + myVpnIp := netip.MustParseAddr("10.128.0.2") + + theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242") + theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242") + theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242") + theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242") + theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242") + theirVpnIp := netip.MustParseAddr("10.128.0.3") c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) // Grow it back to 2 - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it - newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) + newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) @@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) { newLHHostUpdate( myUdpAddr0, myVpnIp, - []*udp.Addr{ + []netip.AddrPort{ myUdpAddr1, myUdpAddr2, myUdpAddr3, @@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) { ) // Make sure we won't add ips in our vpn network - bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242} - bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242} - good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242} - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh) + bad1 := netip.MustParseAddrPort("10.128.0.99:4242") + bad2 := netip.MustParseAddrPort("10.128.0.100:4242") + good := netip.MustParseAddrPort("1.128.0.99:4242") + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) } @@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) { assert.NoError(t, err) } -func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { +func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { + //TODO: IPV6-WORK + bip := queryVpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(queryVpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), }, } @@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh return w.lastReply } -func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) { +func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { + //TODO: IPV6-WORK + bip := vpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(vpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), }, } for k, v := range addrs { - req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)} + req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) } b, err := req.Marshal() @@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, // ) //} -func Test_ipMaskContains(t *testing.T) { - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255")))) - assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) -} - type testLhReply struct { nebType header.MessageType nebSubType header.MessageSubType - vpnIp iputil.VpnIp + vpnIp netip.Addr msg *NebulaMeta } @@ -414,7 +416,7 @@ type testEncWriter struct { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } -func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { +func (tw *testEncWriter) Handshake(vpnIp netip.Addr) { } func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { @@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { - if !assert.Len(t, have, len(want)) { - return - } - - for k, w := range want { - if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) - } - } -} - -// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match -func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) { +func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { - if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have)) + //TODO: IPV6-WORK + h := AddrPortFromIp4AndPort(have[k]) + if !(h == w) { + assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } - -func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr { - addrs := make([]*udp.Addr, len(ips)) - for k, v := range ips { - addrs[k] = NewUDPAddrFromLH4(v) - } - return addrs -} diff --git a/main.go b/main.go index 8c94e80..248f329 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "time" "github.com/sirupsen/logrus" @@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - // TODO: make sure mask is 4 bytes - tunCidr := certificate.Details.Ips[0] + ones, _ := certificate.Details.Ips[0].Mask.Size() + addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + err = util.NewContextualError( + "Invalid ip address in certificate", + m{"vpnIp": certificate.Details.Ips[0].IP}, + nil, + ) + return nil, err + } + tunCidr := netip.PrefixFrom(addr, ones) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { @@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { rawListenHost := c.GetString("listen.host", "0.0.0.0") - var listenHost *net.IPAddr + var listenHost netip.Addr if rawListenHost == "[::]" { // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. - listenHost = &net.IPAddr{IP: net.IPv6zero} + listenHost = netip.IPv6Unspecified() } else { - listenHost, err = net.ResolveIPAddr("ip", rawListenHost) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } + if len(ips) == 0 { + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) + } + listenHost = ips[0].Unmap() } for i := 0; i < routines; i++ { - l.Infof("listening %q %d", listenHost.IP, port) - udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) + l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } @@ -178,57 +192,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - port = int(uPort.Port) - } - } - } - - // Set up my internal host map - var preferredRanges []*net.IPNet - rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) - // First, check if 'preferred_ranges' is set and fallback to 'local_range' - if len(rawPreferredRanges) > 0 { - for _, rawPreferredRange := range rawPreferredRanges { - _, preferredRange, err := net.ParseCIDR(rawPreferredRange) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err) - } - preferredRanges = append(preferredRanges, preferredRange) - } - } - - // local_range was superseded by preferred_ranges. If it is still present, - // merge the local_range setting into preferred_ranges. We will probably - // deprecate local_range and remove in the future. - rawLocalRange := c.GetString("local_range", "") - if rawLocalRange != "" { - _, localRange, err := net.ParseCIDR(rawLocalRange) - if err != nil { - return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err) - } - - // Check if the entry for local_range was already specified in - // preferred_ranges. Don't put it into the slice twice if so. - var found bool - for _, r := range preferredRanges { - if r.String() == localRange.String() { - found = true - break + port = int(uPort.Port()) } } - if !found { - preferredRanges = append(preferredRanges, localRange) - } } - hostMap := NewHostMap(l, tunCidr, preferredRanges) - hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false) - - l. - WithField("network", hostMap.vpnCIDR.String()). - WithField("preferredRanges", hostMap.preferredRanges). - Info("Main HostMap created") - + hostMap := NewHostMapFromConfig(l, tunCidr, c) punchy := NewPunchyFromConfig(l, c) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) if err != nil { diff --git a/noise.go b/noise.go index 91ad2c0..57990a7 100644 --- a/noise.go +++ b/noise.go @@ -28,11 +28,11 @@ func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { // EncryptDanger encrypts and authenticates a given payload. // // out is a destination slice to hold the output of the EncryptDanger operation. -// - ad is additional data, which will be authenticated and appended to out, but not encrypted. -// - plaintext is encrypted, authenticated and appended to out. -// - n is a nonce value which must never be re-used with this key. -// - nb is a buffer used for temporary storage in the implementation of this call, which should -// be re-used by callers to minimize garbage collection. +// - ad is additional data, which will be authenticated and appended to out, but not encrypted. +// - plaintext is encrypted, authenticated and appended to out. +// - n is a nonce value which must never be re-used with this key. +// - nb is a buffer used for temporary storage in the implementation of this call, which should +// be re-used by callers to minimize garbage collection. func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { if s != nil { // TODO: Is this okay now that we have made messageCounter atomic? diff --git a/noiseutil/nist.go b/noiseutil/nist.go index 90e77ab..976a274 100644 --- a/noiseutil/nist.go +++ b/noiseutil/nist.go @@ -48,7 +48,7 @@ func (c nistCurve) DH(privkey, pubkey []byte) ([]byte, error) { } ecdhPrivKey, err := c.curve.NewPrivateKey(privkey) if err != nil { - return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err) + return nil, fmt.Errorf("unable to unmarshal private key: %w", err) } return ecdhPrivKey.ECDH(ecdhPubKey) diff --git a/outside.go b/outside.go index 2918911..be60294 100644 --- a/outside.go +++ b/outside.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "net/netip" "time" "github.com/flynn/noise" @@ -11,7 +12,6 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" "google.golang.org/protobuf/proto" @@ -21,9 +21,10 @@ const ( minFwPacketLen = 4 ) +// TODO: IPV6-WORK this can likely be removed now func readOutsidePackets(f *Interface) udp.EncReader { return func( - addr *udp.Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader { } } -func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) } return } //l.Error("in packet ", header, packet[HeaderLen:]) - if addr != nil { - if ip4 := addr.IP.To4(); ip4 != nil { - if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") - } - return + if ip.IsValid() { + if f.myVpnNet.Contains(ip.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } + return } } @@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt switch h.Type { case header.Message: // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } @@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") @@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt return } - lhf(addr, hostinfo.vpnIp, d) + lhf(ip, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt test packet") @@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } @@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(addr, via, packet, h) + f.handshakeManager.HandleIncoming(ip, via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(addr, h) + f.handleRecvError(ip, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } - hostinfo.logger(f.l).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", ip). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt Control packet") return @@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) return } - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.connectionManager.In(hostinfo.localIndexId) } @@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { - if addr != nil && !hostinfo.remote.Equals(addr) { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { - hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { + if ip.IsValid() && hostinfo.remote != ip { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(addr) + hostinfo.SetRemote(ip) } } -func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool { +func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { // If connectionstate exists and the replay protector allows, process packet // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { - if addr != nil { + if addr.IsValid() { f.maybeSendRecvError(addr, h.RemoteIndex) return false } else { @@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Firewall packets are locally oriented if incoming { - fp.RemoteIP = iputil.Ip2VpnIp(data[12:16]) - fp.LocalIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - fp.LocalIP = iputil.Ip2VpnIp(data[12:16]) - fp.RemoteIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -404,7 +405,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in @@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return true } -func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { - if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) { +func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { + if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) { f.sendRecvError(endpoint, index) } } -func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { +func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) //TODO: this should be a signed message so we can trust that we should drop the index @@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { } } -func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { +func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). @@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { return } - if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { + if hostinfo.remote.IsValid() && hostinfo.remote != addr { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } diff --git a/outside_test.go b/outside_test.go index 682107b..f9d4bfa 100644 --- a/outside_test.go +++ b/outside_test.go @@ -2,10 +2,10 @@ package nebula import ( "net" + "net/netip" "testing" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" ) @@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.LocalPort, uint16(4)) @@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) } diff --git a/overlay/device.go b/overlay/device.go index 3f3f2eb..50ad6ad 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -2,16 +2,14 @@ package overlay import ( "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type Device interface { io.ReadWriteCloser Activate() error - Cidr() *net.IPNet + Cidr() netip.Prefix Name() string - RouteFor(iputil.VpnIp) iputil.VpnIp + RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index 793c8fd..8ccc994 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -4,38 +4,64 @@ import ( "fmt" "math" "net" + "net/netip" "runtime" "strconv" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type Route struct { MTU int Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + Cidr netip.Prefix + Via netip.Addr Install bool } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { - routeTree := cidr.NewTree4[iputil.VpnIp]() +// Equal determines if a route that could be installed in the system route table is equal to another +// Via is ignored since that is only consumed within nebula itself +func (r Route) Equal(t Route) bool { + if r.Cidr != t.Cidr { + return false + } + if r.Metric != t.Metric { + return false + } + if r.MTU != t.MTU { + return false + } + if r.Install != t.Install { + return false + } + return true +} + +func (r Route) String() string { + s := r.Cidr.String() + if r.Metric != 0 { + s += fmt.Sprintf(" metric: %v", r.Metric) + } + return s +} + +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { + routeTree := new(bart.Table[netip.Addr]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via != nil { - routeTree.AddCIDR(r.Cidr, *r.Via) + if r.Via.IsValid() { + routeTree.Insert(r.Cidr, r.Via) } } return routeTree, nil } -func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -86,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { MTU: mtu, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !ipWithin(network, r.Cidr) { + if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { return nil, fmt.Errorf( "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", i+1, @@ -106,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -172,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) } - nVia := net.ParseIP(via) - if nVia == nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via) + viaVpnIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) } rRoute, ok := m["route"] @@ -182,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) } - viaVpnIp := iputil.Ip2VpnIp(nVia) - install := true rInstall, ok := m["install"] if ok { @@ -194,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - Via: &viaVpnIp, + Via: viaVpnIp, MTU: mtu, Metric: metric, Install: install, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if ipWithin(network, r.Cidr) { + if network.Contains(r.Cidr.Addr()) { return nil, fmt.Errorf( "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", i+1, diff --git a/overlay/route_test.go b/overlay/route_test.go index 46fb87c..d791389 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -2,11 +2,10 @@ package overlay import ( "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -14,7 +13,8 @@ import ( func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseRoutes(c, n) @@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} @@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, n) @@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope") + assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} @@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} @@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, @@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) { routeTree, err := makeRouteTree(l, routes, true) assert.NoError(t, err) - ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - ok, r := routeTree.MostSpecificContains(ip) + ip, err := netip.ParseAddr("1.0.0.2") + assert.NoError(t, err) + r, ok := routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err := netip.ParseAddr("192.168.0.1") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.0.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err = netip.ParseAddr("192.168.0.2") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.1.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.False(t, ok) } diff --git a/overlay/tun.go b/overlay/tun.go index ca1a64a..12460da 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,7 +1,7 @@ package overlay import ( - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -10,60 +10,63 @@ import ( const DefaultMTU = 1300 -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) - -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { - routes, err := parseRoutes(c, tunCidr) - if err != nil { - return nil, util.NewContextualError("Could not parse tun.routes", nil, err) - } - - unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) - if err != nil { - return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) - } - routes = append(routes, unsafeRoutes...) +// TODO: We may be able to remove routines +type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun( - l, - c.GetString("tun.dev", ""), - tunCidr, - c.GetInt("tun.mtu", DefaultMTU), - routes, - c.GetInt("tun.tx_queue", 500), - routines > 1, - c.GetBool("tun.use_system_route_table", false), - ) + return newTun(c, l, tunCidr, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { - routes, err := parseRoutes(c, tunCidr) - if err != nil { - return nil, util.NewContextualError("Could not parse tun.routes", nil, err) - } + return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, tunCidr) + } +} + +func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { + if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { + return false, nil, nil + } - unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr) - if err != nil { - return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) + routes, err := parseRoutes(c, cidr) + if err != nil { + return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) + } + + unsafeRoutes, err := parseUnsafeRoutes(c, cidr) + if err != nil { + return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) + } + + routes = append(routes, unsafeRoutes...) + return true, routes, nil +} + +// findRemovedRoutes will return all routes that are not present in the newRoutes list and would affect the system route table. +// Via is not used to evaluate since it does not affect the system route table. +func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route { + var removed []Route + has := func(entry Route) bool { + for _, check := range newRoutes { + if check.Equal(entry) { + return true + } } - routes = append(routes, unsafeRoutes...) - return newTunFromFd( - l, - *fd, - tunCidr, - c.GetInt("tun.mtu", DefaultMTU), - routes, - c.GetInt("tun.tx_queue", 500), - c.GetBool("tun.use_system_route_table", false), - ) + return false + } + for _, oldEntry := range oldRoutes { + if !has(oldEntry) { + removed = append(removed, oldEntry) + } } + + return removed } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index c5c52db..98ad9b4 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,47 +6,58 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" + "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser fd int - cidr *net.IPNet - routeTree *cidr.Tree4[iputil.VpnIp] + cidr netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, false) - if err != nil { - return nil, err - } - +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - return &tun{ + t := &tun{ ReadWriteCloser: file, fd: deviceFd, cidr: cidr, l: l, - routeTree: routeTree, - }, nil + } + + err := t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -54,7 +65,28 @@ func (t tun) Activate() error { return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes + t.Routes.Store(&routes) + t.routeTree.Store(routeTree) + return nil +} + +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index caec580..0b573e6 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -8,13 +8,16 @@ import ( "fmt" "io" "net" + "net/netip" "os" + "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" ) @@ -22,10 +25,11 @@ import ( type tun struct { io.ReadWriteCloser Device string - cidr *net.IPNet + cidr netip.Prefix DefaultMTU int - Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + linkAddr *netroute.LinkAddr l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata @@ -69,12 +73,8 @@ type ifreqMTU struct { pad [8]byte } -func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, false) - if err != nil { - return nil, err - } - +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { + name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { _, err := fmt.Sscanf(name, "utun%d", &ifIndex) @@ -142,17 +142,27 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout file := os.NewFile(uintptr(fd), "") - tun := &tun{ + t := &tun{ ReadWriteCloser: file, Device: name, cidr: cidr, - DefaultMTU: defaultMTU, - Routes: routes, - routeTree: routeTree, + DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } - return tun, nil + err = t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } func (t *tun) deviceBytes() (o [16]byte) { @@ -162,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -178,8 +188,13 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + if !t.cidr.Addr().Is4() { + //TODO: IPV6-WORK + panic("need ipv6") + } + + addr = t.cidr.Addr().As4() + copy(mask[:], prefixToMask(t.cidr)) s, err := unix.Socket( unix.AF_INET, @@ -260,6 +275,7 @@ func (t *tun) Activate() error { if linkAddr == nil { return fmt.Errorf("unable to discover link_addr for tun interface") } + t.linkAddr = linkAddr copy(routeAddr.IP[:], addr[:]) copy(maskAddr.IP[:], mask[:]) @@ -278,38 +294,52 @@ func (t *tun) Activate() error { } // Unsafe path routes - for _, r := range t.Routes { - if r.Via == nil || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } + return t.addRoutes(false) +} + +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + if !initial && !change { + return nil + } - err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) if err != nil { - if errors.Is(err, unix.EEXIST) { - t.l.WithField("route", r.Cidr). - Warnf("unable to add unsafe_route, identical route already exists") - } else { - return err - } + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } - // TODO how to set metric + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } } return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - ok, r := t.routeTree.MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - - return 0 + return netip.Addr{} } // Get the LinkAddr for the interface of the given name @@ -340,6 +370,99 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { return nil, nil } +func (t *tun) addRoutes(logErrors bool) error { + routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + + defer func() { + unix.Shutdown(routeSock, unix.SHUT_RDWR) + err := unix.Close(routeSock) + if err != nil { + t.l.WithError(err).Error("failed to close AF_ROUTE socket") + } + }() + + routeAddr := &netroute.Inet4Addr{} + maskAddr := &netroute.Inet4Addr{} + routes := *t.Routes.Load() + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + if !r.Cidr.Addr().Is4() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + //TODO: we could avoid the copy + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) + + err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + if err != nil { + if errors.Is(err, unix.EEXIST) { + t.l.WithField("route", r.Cidr). + Warnf("unable to add unsafe_route, identical route already exists") + } else { + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } + } else { + t.l.WithField("route", r).Info("Added route") + } + } + + return nil +} + +func (t *tun) removeRoutes(routes []Route) error { + routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + + defer func() { + unix.Shutdown(routeSock, unix.SHUT_RDWR) + err := unix.Close(routeSock) + if err != nil { + t.l.WithError(err).Error("failed to close AF_ROUTE socket") + } + }() + + routeAddr := &netroute.Inet4Addr{} + maskAddr := &netroute.Inet4Addr{} + + for _, r := range routes { + if !r.Install { + continue + } + + if r.Cidr.Addr().Is6() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) + + err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { r := netroute.RouteMessage{ Version: unix.RTM_VERSION, @@ -365,6 +488,30 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) return nil } +func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { + r := netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_DELETE, + Seq: 1, + Addrs: []netroute.Addr{ + unix.RTAX_DST: addr, + unix.RTAX_GATEWAY: link, + unix.RTAX_NETMASK: mask, + }, + } + + data, err := r.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + _, err = unix.Write(sock, data[:]) + if err != nil { + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + func (t *tun) Read(to []byte) (int, error) { buf := make([]byte, len(to)+4) @@ -404,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -415,3 +562,11 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +func prefixToMask(prefix netip.Prefix) []byte { + pLen := 128 + if prefix.Addr().Is4() { + pLen = 32 + } + return net.CIDRMask(prefix.Bits(), pLen) +} diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index e1e4ede..130f8f9 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -3,7 +3,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "strings" "github.com/rcrowley/go-metrics" @@ -13,7 +13,7 @@ import ( type disabledTun struct { read chan []byte - cidr *net.IPNet + cidr netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,7 +21,7 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ cidr: cidr, read: make(chan []byte, queueLen), @@ -43,11 +43,11 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } -func (t *disabledTun) Cidr() *net.IPNet { +func (t *disabledTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 338b8f6..bdfeb58 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,16 +9,18 @@ import ( "fmt" "io" "io/fs" - "net" + "net/netip" "os" "os/exec" "strconv" + "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) const ( @@ -45,10 +47,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int - Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -76,14 +78,15 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error + deviceName := c.GetString("tun.dev", "") if deviceName != "" { file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) } @@ -144,59 +147,97 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) } - routeTree, err := makeRouteTree(l, routes, false) - if err != nil { - return nil, err - } - - return &tun{ + t := &tun{ ReadWriteCloser: file, Device: deviceName, cidr: cidr, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, + MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, - }, nil + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) - if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil { + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) - if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil { + + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) } - t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) - if err = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)).Run(); err != nil { + + cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) + t.l.Debug("command: ", cmd.String()) + if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } + // Unsafe path routes - for _, r := range t.Routes { - if r.Via == nil || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue + return t.addRoutes(false) +} + +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } - t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) - if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil { - return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) } } return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -208,6 +249,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } +func (t *tun) addRoutes(logErrors bool) error { + routes := *t.Routes.Load() + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } + } + + return nil +} + +func (t *tun) removeRoutes(routes []Route) error { + for _, r := range routes { + if !r.Install { + continue + } + + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index ce65b33..20981f0 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,46 +7,80 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "os" "sync" + "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser - cidr *net.IPNet - routeTree *cidr.Tree4[iputil.VpnIp] + cidr netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, false) +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { + file := os.NewFile(uintptr(deviceFd), "/dev/tun") + t := &tun{ + cidr: cidr, + ReadWriteCloser: &tunReadCloser{f: file}, + l: l, + } + + err := t.reload(c, true) if err != nil { return nil, err } - file := os.NewFile(uintptr(deviceFd), "/dev/tun") - return &tun{ - cidr: cidr, - ReadWriteCloser: &tunReadCloser{f: file}, - routeTree: routeTree, - }, nil + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } func (t *tun) Activate() error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes + t.Routes.Store(&routes) + t.routeTree.Store(routeTree) + return nil +} + +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -108,7 +142,7 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index a576bf3..0e7e20d 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,33 +4,36 @@ package overlay import ( - "bytes" "fmt" "io" "net" + "net/netip" "os" "strings" "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) type tun struct { io.ReadWriteCloser - fd int - Device string - cidr *net.IPNet - MaxMTU int - DefaultMTU int - TXQueueLen int - - Routes []Route - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + fd int + Device string + cidr netip.Prefix + MaxMTU int + DefaultMTU int + TXQueueLen int + deviceIndex int + ioctlFd uintptr + + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] routeChan chan struct{} useSystemRoutes bool @@ -61,33 +64,40 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) { - routeTree, err := makeRouteTree(l, routes, true) +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { + file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") + + t, err := newTunGeneric(c, l, file, cidr) if err != nil { return nil, err } - file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") + t.Device = "tun0" - t := &tun{ - ReadWriteCloser: file, - fd: int(file.Fd()), - Device: "tun0", - cidr: cidr, - DefaultMTU: defaultMTU, - TXQueueLen: txQueueLen, - Routes: routes, - useSystemRoutes: useSystemRoutes, - l: l, - } - t.routeTree.Store(routeTree) return t, nil } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { - return nil, err + // If /dev/net/tun doesn't exist, try to create it (will happen in docker) + if os.IsNotExist(err) { + err = os.MkdirAll("/dev/net", 0755) + if err != nil { + return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) + } + err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) + if err != nil { + return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) + } + + fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) + if err != nil { + return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) + } + } else { + return nil, err + } } var req ifReq @@ -95,46 +105,113 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } - copy(req.Name[:], deviceName) + copy(req.Name[:], c.GetString("tun.dev", "")) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { return nil, err } name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - - maxMTU := defaultMTU - for _, r := range routes { - if r.MTU == 0 { - r.MTU = defaultMTU - } - - if r.MTU > maxMTU { - maxMTU = r.MTU - } - } - - routeTree, err := makeRouteTree(l, routes, true) + t, err := newTunGeneric(c, l, file, cidr) if err != nil { return nil, err } + t.Device = name + + return t, nil +} + +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - Device: name, cidr: cidr, - MaxMTU: maxMTU, - DefaultMTU: defaultMTU, - TXQueueLen: txQueueLen, - Routes: routes, - useSystemRoutes: useSystemRoutes, + TXQueueLen: c.GetInt("tun.tx_queue", 500), + useSystemRoutes: c.GetBool("tun.use_system_route_table", false), l: l, } - t.routeTree.Store(routeTree) + + err := t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + return t, nil } +func (t *tun) reload(c *config.C, initial bool) error { + routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !routeChange && !c.HasChanged("tun.mtu") { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, true) + if err != nil { + return err + } + + oldDefaultMTU := t.DefaultMTU + oldMaxMTU := t.MaxMTU + newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU) + newMaxMTU := newDefaultMTU + for i, r := range routes { + if r.MTU == 0 { + routes[i].MTU = newDefaultMTU + } + + if r.MTU > t.MaxMTU { + newMaxMTU = r.MTU + } + } + + t.MaxMTU = newMaxMTU + t.DefaultMTU = newDefaultMTU + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + if oldMaxMTU != newMaxMTU { + t.setMTU() + t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) + } + + if oldDefaultMTU != newDefaultMTU { + err := t.setDefaultRoute() + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } + } + + // Remove first, if the system removes a wanted route hopefully it will be re-added next + t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // This should never be called since addRoutes should log its own errors in a reload condition + util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l) + } + } + + return nil +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { @@ -153,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -197,8 +274,10 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + //TODO: IPV6-WORK + addr = t.cidr.Addr().As4() + tmask := net.CIDRMask(t.cidr.Bits(), 32) + copy(mask[:], tmask) s, err := unix.Socket( unix.AF_INET, @@ -208,7 +287,7 @@ func (t *tun) Activate() error { if err != nil { return err } - fd := uintptr(s) + t.ioctlFd = uintptr(s) ifra := ifreqAddr{ Name: devName, @@ -219,75 +298,114 @@ func (t *tun) Activate() error { } // Set the device ip address - if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { return fmt.Errorf("failed to set tun address: %s", err) } // Set the device network ifra.Addr.Addr = mask - if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { return fmt.Errorf("failed to set tun netmask: %s", err) } // Set the device name ifrf := ifReq{Name: devName} - if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } - // Set the MTU on the device - ifm := ifreqMTU{Name: devName, MTU: int32(t.MaxMTU)} - if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { - // This is currently a non fatal condition because the route table must have the MTU set appropriately as well - t.l.WithError(err).Error("Failed to set tun mtu") - } + // Setup our default MTU + t.setMTU() // Set the transmit queue length ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} - if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { // If we can't set the queue length nebula will still work but it may lead to packet loss t.l.WithError(err).Error("Failed to set tun tx queue length") } // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - // Set the routes link, err := netlink.LinkByName(t.Device) if err != nil { return fmt.Errorf("failed to get tun device link: %s", err) } + t.deviceIndex = link.Attrs().Index + if err = t.setDefaultRoute(); err != nil { + return err + } + + // Set the routes + if err = t.addRoutes(false); err != nil { + return err + } + + // Run the interface + ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING + if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to run tun device: %s", err) + } + + return nil +} + +func (t *tun) setMTU() { + // Set the MTU on the device + ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)} + if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { + // This is currently a non fatal condition because the route table must have the MTU set appropriately as well + t.l.WithError(err).Error("Failed to set tun mtu") + } +} + +func (t *tun) setDefaultRoute() error { // Default route - dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask} + + dr := &net.IPNet{ + IP: t.cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + } + nr := netlink.Route{ - LinkIndex: link.Attrs().Index, + LinkIndex: t.deviceIndex, Dst: dr, MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: t.cidr.IP, + Src: net.IP(t.cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } - err = netlink.RouteReplace(&nr) + err := netlink.RouteReplace(&nr) if err != nil { return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) } + return nil +} + +func (t *tun) addRoutes(logErrors bool) error { // Path routes - for _, r := range t.Routes { + routes := *t.Routes.Load() + for _, r := range routes { if !r.Install { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: r.Cidr, + LinkIndex: t.deviceIndex, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -297,22 +415,55 @@ func (t *tun) Activate() error { nr.Priority = r.Metric } - err = netlink.RouteAdd(&nr) + err := netlink.RouteReplace(&nr) if err != nil { - return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err) + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") } } - // Run the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to run tun device: %s", err) - } - return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) removeRoutes(routes []Route) { + for _, r := range routes { + if !r.Install { + continue + } + + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + + nr := netlink.Route{ + LinkIndex: t.deviceIndex, + Dst: dr, + MTU: r.MTU, + AdvMSS: t.advMSS(r), + Scope: unix.RT_SCOPE_LINK, + } + + if r.Metric > 0 { + nr.Priority = r.Metric + } + + err := netlink.RouteDel(&nr) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } +} + +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -364,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - if !t.cidr.Contains(r.Gw) { + //TODO: IPV6-WORK what if not ok? + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + return + } + + gwAddr = gwAddr.Unmap() + if !t.cidr.Contains(gwAddr) { // Gateway isn't in our overlay network, ignore t.l.WithField("route", r).Debug("Ignoring route update, not in our network") return @@ -376,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - newTree := cidr.NewTree4[iputil.VpnIp]() - if r.Type == unix.RTM_NEWROUTE { - for _, oldR := range t.routeTree.Load().List() { - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + return + } + ones, _ := r.Dst.Mask.Size() + dst := netip.PrefixFrom(dstAddr, ones) + + newTree := t.routeTree.Load().Clone() + + if r.Type == unix.RTM_NEWROUTE { t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) + newTree.Insert(dst, gwAddr) } else { - gw := iputil.Ip2VpnIp(r.Gw) - for _, oldR := range t.routeTree.Load().List() { - if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { - // This is the record to delete - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") - continue - } - - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + newTree.Delete(dst) + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") } - t.routeTree.Store(newTree) } @@ -410,5 +566,9 @@ func (t *tun) Close() error { t.ReadWriteCloser.Close() } + if t.ioctlFd > 0 { + os.NewFile(t.ioctlFd, "ioctlFd").Close() + } + return nil } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index b1135fe..24ab24f 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,17 +6,19 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" "strconv" + "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) type ifreqDestroy struct { @@ -26,10 +28,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int - Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -56,56 +58,63 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error + deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") } if !deviceNameRE.MatchString(deviceName) { return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") } - file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) if err != nil { return nil, err } - routeTree, err := makeRouteTree(l, routes, false) + t := &tun{ + ReadWriteCloser: file, + Device: deviceName, + cidr: cidr, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + err = t.reload(c, true) if err != nil { return nil, err } - return &tun{ - ReadWriteCloser: file, - Device: deviceName, - cidr: cidr, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, - l: l, - }, nil + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -116,29 +125,54 @@ func (t *tun) Activate() error { if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } + // Unsafe path routes - for _, r := range t.Routes { - if r.Via == nil || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue + return t.addRoutes(false) +} + +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) } } return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -150,6 +184,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") } +func (t *tun) addRoutes(logErrors bool) error { + routes := *t.Routes.Load() + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } + } + + return nil +} + +func (t *tun) removeRoutes(routes []Route) error { + for _, r := range routes { + if !r.Install { + continue + } + + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 45c06dc..6463ccb 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,24 +6,26 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" "strconv" + "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" ) type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int - Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -40,13 +42,14 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { + deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") } @@ -60,26 +63,70 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int return nil, err } - routeTree, err := makeRouteTree(l, routes, false) - if err != nil { - return nil, err - } - - return &tun{ + t := &tun{ ReadWriteCloser: file, Device: deviceName, cidr: cidr, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, + MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, - }, nil + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} + +func (t *tun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil } func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -91,35 +138,62 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) } // Unsafe path routes - for _, r := range t.Routes { - if r.Via == nil || !r.Install { + return t.addRoutes(false) +} + +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) + return r +} + +func (t *tun) addRoutes(logErrors bool) error { + routes := *t.Routes.Load() + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err) + if err := cmd.Run(); err != nil { + retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } } } return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) - return r +func (t *tun) removeRoutes(routes []Route) error { + for _, r := range routes { + if !r.Install { + continue + } + + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + t.l.Debug("command: ", cmd.String()) + if err := cmd.Run(); err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 964315a..ba15723 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -6,20 +6,20 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" ) type TestTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + routeTree *bart.Table[netip.Addr] l *logrus.Logger closed atomic.Bool @@ -27,14 +27,18 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { + _, routes, err := getAllRoutesFromConfig(c, cidr, true) + if err != nil { + return nil, err + } routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err } return &TestTun{ - Device: deviceName, + Device: c.GetString("tun.dev", ""), cidr: cidr, Routes: routes, routeTree: routeTree, @@ -44,7 +48,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes }, nil } -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -82,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Lookup(ip) return r } @@ -91,7 +95,7 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() *net.IPNet { +func (t *TestTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index e27cff2..d78f564 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -4,38 +4,50 @@ import ( "fmt" "io" "net" + "net/netip" "os/exec" "strconv" + "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" "github.com/songgao/water" ) type waterTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int - Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] - + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger + f *net.Interface *water.Interface } -func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) { - routeTree, err := makeRouteTree(l, routes, false) +func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { + // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() + t := &waterTun{ + cidr: cidr, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err := t.reload(c, true) if err != nil { return nil, err } - // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() - return &waterTun{ - cidr: cidr, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, - }, nil + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil } func (t *waterTun) Activate() error { @@ -58,8 +70,8 @@ func (t *waterTun) Activate() error { `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", fmt.Sprintf("name=%s", t.Device), "source=static", - fmt.Sprintf("addr=%s", t.cidr.IP), - fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)), + fmt.Sprintf("addr=%s", t.cidr.Addr()), + fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), "gateway=none", ).Run() if err != nil { @@ -74,34 +86,108 @@ func (t *waterTun) Activate() error { return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) } - iface, err := net.InterfaceByName(t.Device) + t.f, err = net.InterfaceByName(t.Device) if err != nil { return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) } - for _, r := range t.Routes { - if r.Via == nil || !r.Install { + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *waterTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + if err != nil { + return err + } + + if !initial && !change { + return nil + } + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to set routes", err, t.l) + } else { + for _, r := range findRemovedRoutes(routes, *oldRoutes) { + t.l.WithField("route", r).Info("Removed route") + } + } + } + + return nil +} + +func (t *waterTun) addRoutes(logErrors bool) error { + // Path routes + routes := *t.Routes.Load() + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - err = exec.Command( - "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric), + err := exec.Command( + "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), ).Run() + if err != nil { - return fmt.Errorf("failed to add the unsafe_route %s: %v", r.Cidr.String(), err) + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") } } return nil } -func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *waterTun) removeRoutes(routes []Route) { + for _, r := range routes { + if !r.Install { + continue + } + + err := exec.Command( + "C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), + ).Run() + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } +} + +func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *waterTun) Cidr() *net.IPNet { +func (t *waterTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 57d90cb..3d88309 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -5,20 +5,21 @@ package overlay import ( "fmt" - "net" + "net/netip" "os" "path/filepath" "runtime" "syscall" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" ) -func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") @@ -26,14 +27,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int } if useWintun { - device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes) + device, err := newWinTun(c, l, cidr, multiqueue) if err != nil { return nil, fmt.Errorf("create Wintun interface failed, %w", err) } return device, nil } - device, err := newWaterTun(l, cidr, defaultMTU, routes) + device, err := newWaterTun(c, l, cidr, multiqueue) if err != nil { return nil, fmt.Errorf("create wintap driver failed, %w", err) } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 9647024..d010387 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -4,13 +4,14 @@ import ( "crypto" "fmt" "io" - "net" "net/netip" + "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" - "github.com/slackhq/nebula/iputil" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -20,11 +21,11 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { Device string - cidr *net.IPNet - prefix netip.Prefix + cidr netip.Prefix MTU int - Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger tun *wintun.NativeTun } @@ -48,83 +49,131 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) { +func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { + deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { return nil, fmt.Errorf("generate GUID failed: %w", err) } + t := &winTun{ + Device: deviceName, + cidr: cidr, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) + if err != nil { + return nil, err + } + var tunDevice wintun.Device - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU) + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. // Trying a second time resolves the issue. l.WithError(err).Debug("Failed to create wintun device, retrying") - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU) + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { return nil, fmt.Errorf("create TUN device failed: %w", err) } } + t.tun = tunDevice.(*wintun.NativeTun) + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} - routeTree, err := makeRouteTree(l, routes, false) +func (t *winTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) if err != nil { - return nil, err + return err + } + + if !initial && !change { + return nil } - prefix, err := iputil.ToNetIpPrefix(*cidr) + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { - return nil, err + return err } - return &winTun{ - Device: deviceName, - cidr: cidr, - prefix: prefix, - MTU: defaultMTU, - Routes: routes, - routeTree: routeTree, + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } - tun: tunDevice.(*wintun.NativeTun), - }, nil + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil } func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - if err := luid.SetIPAddresses([]netip.Prefix{t.prefix}); err != nil { + err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) + if err != nil { return fmt.Errorf("failed to set address: %w", err) } + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *winTun) addRoutes(logErrors bool) error { + luid := winipcfg.LUID(t.tun.LUID()) + routes := *t.Routes.Load() foundDefault4 := false - routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1) - for _, r := range t.Routes { - if r.Via == nil || !r.Install { + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - if !foundDefault4 { - if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { - foundDefault4 = true + // Add our unsafe route + err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + continue + } else { + return retErr } + } else { + t.l.WithField("route", r).Info("Added route") } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - return err + if !foundDefault4 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { + foundDefault4 = true + } } - - // Add our unsafe route - routes = append(routes, &winipcfg.RouteData{ - Destination: prefix, - NextHop: r.Via.ToNetIpAddr(), - Metric: uint32(r.Metric), - }) - } - - if err := luid.AddRoutes(routes); err != nil { - return fmt.Errorf("failed to add routes: %w", err) } ipif, err := luid.IPInterface(windows.AF_INET) @@ -141,16 +190,33 @@ func (t *winTun) Activate() error { if err := ipif.Set(); err != nil { return fmt.Errorf("failed to set ip interface: %w", err) } + return nil +} + +func (t *winTun) removeRoutes(routes []Route) error { + luid := winipcfg.LUID(t.tun.LUID()) + for _, r := range routes { + if !r.Install { + continue + } + + err := luid.DeleteRoute(r.Cidr, r.Via) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } return nil } -func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *winTun) Cidr() *net.IPNet { +func (t *winTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/user.go b/overlay/user.go index 9d819ae..1bb4ef5 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,18 +2,17 @@ package overlay import ( "io" - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return NewUserDevice(tunCidr) } -func NewUserDevice(tunCidr *net.IPNet) (Device, error) { +func NewUserDevice(tunCidr netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() @@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) { } type UserDevice struct { - tunCidr *net.IPNet + tunCidr netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -39,9 +38,9 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } +func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/pki.go b/pki.go index 91478ce..ab95a04 100644 --- a/pki.go +++ b/pki.go @@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { } if !initial { + //TODO: include check for mask equality as well + // did IP in cert change? if so, don't set currentCert := p.cs.Load().Certificate oldIPs := currentCert.Details.Ips diff --git a/relay_manager.go b/relay_manager.go index 7aa06cc..1a3a4d4 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -2,14 +2,15 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" ) type relayManager struct { @@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for i := 0; i < 32; i++ { @@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(m.RelayFromIp), - "relayTo": iputil.VpnIp(m.RelayToIp), + "relayFrom": m.RelayFromIp, + "relayTo": m.RelayToIp, "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, "vpnIp": h.vpnIp}). Info("handleCreateRelayResponse") - target := iputil.VpnIp(m.RelayToIp) + target := m.RelayToIp + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + targetAddr := netip.AddrFrom4(b) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") return } - peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) + peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") return } if peerRelay.State == PeerRequested { + //TODO: IPV6-WORK + b = peerHostInfo.vpnIp.As4() peerRelay.State = Established resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: uint32(peerHostInfo.vpnIp), + RelayFromIp: binary.BigEndian.Uint32(b[:]), RelayToIp: uint32(target), } msg, err := resp.Marshal() @@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + "relayFrom": resp.RelayFromIp, + "relayTo": resp.RelayToIp, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": peerHostInfo.vpnIp}). @@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayFromIp) + from := netip.AddrFrom4(b) - from := iputil.VpnIp(m.RelayFromIp) - target := iputil.VpnIp(m.RelayToIp) + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + target := netip.AddrFrom4(b) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, @@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnIp { - logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") + if from == f.myVpnNet.Addr() { + logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } // Is the target of the relay me? - if target == f.myVpnIp { + if target == f.myVpnNet.Addr() { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N return } + //TODO: IPV6-WORK + fromB := from.As4() + targetB := target.As4() + resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(from), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now + "relayFrom": from, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": h.vpnIp}). @@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N f.Handshake(target) return } - if peer.remote == nil { + if !peer.remote.IsValid() { // Only create relays to peers for whom I have a direct connection return } @@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N sendCreateRequest = true } if sendCreateRequest { + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() + // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := req.Marshal() if err != nil { @@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + //TODO: IPV6-WORK another lazy used to use the req object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": target}). @@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK more lazy, used to use resp object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": h.vpnIp}). @@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } } } - -func (rm *relayManager) RemoveRelay(localIdx uint32) { - rm.hostmap.RemoveRelay(localIdx) -} diff --git a/remote_list.go b/remote_list.go index 60a1afd..fa14f42 100644 --- a/remote_list.go +++ b/remote_list.go @@ -1,7 +1,6 @@ package nebula import ( - "bytes" "context" "net" "net/netip" @@ -12,16 +11,14 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // forEachFunc is used to benefit folks that want to do work inside the lock -type forEachFunc func(addr *udp.Addr, preferred bool) +type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -30,9 +27,9 @@ type CacheMap map[string]*Cache // Cache is the other part of CacheMap to better represent the lighthouse cache for humans // We don't reason about ipv4 vs ipv6 here type Cache struct { - Learned []*udp.Addr `json:"learned,omitempty"` - Reported []*udp.Addr `json:"reported,omitempty"` - Relay []*net.IP `json:"relay"` + Learned []netip.AddrPort `json:"learned,omitempty"` + Reported []netip.AddrPort `json:"reported,omitempty"` + Relay []netip.Addr `json:"relay"` } //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion @@ -46,7 +43,7 @@ type cache struct { } type cacheRelay struct { - relay []uint32 + relay []netip.Addr } // cacheV4 stores learned and reported ipv4 records under cache @@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, continue } for _, a := range addrs { - netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{} } } origSet := r.ips.Load() @@ -193,22 +190,22 @@ type RemoteList struct { sync.RWMutex // A deduplicated set of addresses. Any accessor should lock beforehand. - addrs []*udp.Addr + addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. - relays []*iputil.VpnIp + relays []netip.Addr // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet - cache map[iputil.VpnIp]*cache + cache map[netip.Addr]*cache hr *hostnamesResults shouldAdd func(netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake - badRemotes []*udp.Addr + badRemotes []netip.AddrPort // A flag that the cache may have changed and addrs needs to be rebuilt shouldRebuild bool @@ -217,9 +214,9 @@ type RemoteList struct { // NewRemoteList creates a new empty RemoteList func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { return &RemoteList{ - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + addrs: make([]netip.AddrPort, 0), + relays: make([]netip.Addr, 0), + cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } } @@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { +func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { r.Rebuild(preferredRanges) r.RLock() defer r.RUnlock() @@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { // ForEach locks and will call the forEachFunc for every deduplicated address in the list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) { +func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) { r.Rebuild(preferredRanges) r.RLock() for _, v := range r.addrs { - forEach(v, isPreferred(v.IP, preferredRanges)) + forEach(v, isPreferred(v.Addr(), preferredRanges)) } r.RUnlock() } // CopyAddrs locks and makes a deep copy of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { +func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort { if r == nil { return nil } @@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.addrs)) + c := make([]netip.AddrPort, len(r.addrs)) for i, v := range r.addrs { - c[i] = v.Copy() + c[i] = v } return c } @@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available // TODO: this needs to support the allow list list -func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { +func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() - if v4 := addr.IP.To4(); v4 != nil { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port))) + if remote.Addr().Is4() { + r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port))) + r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) } } @@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap { c := cm[vpnIp] if c == nil { c = &Cache{ - Learned: make([]*udp.Addr, 0), - Reported: make([]*udp.Addr, 0), - Relay: make([]*net.IP, 0), + Learned: make([]netip.AddrPort, 0), + Reported: make([]netip.AddrPort, 0), + Relay: make([]netip.Addr, 0), } cm[vpnIp] = c } @@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned)) + c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH4(a)) + c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned)) + c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH6(a)) + c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) } } if mc.relay != nil { for _, a := range mc.relay.relay { - nip := iputil.VpnIp(a).ToIP() - c.Relay = append(c.Relay, &nip) + c.Relay = append(c.Relay, a) } } } @@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad *udp.Addr) { - if bad == nil { +func (r *RemoteList) BlockRemote(bad netip.AddrPort) { + if !bad.IsValid() { // relays can have nil udp Addrs return } @@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) { } // We copy here because we are taking something else's memory and we can't trust everything - r.badRemotes = append(r.badRemotes, bad.Copy()) + r.badRemotes = append(r.badRemotes, bad) // Mark the next interaction must recollect/dedupe r.shouldRebuild = true } // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list -func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr { +func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.badRemotes)) + c := make([]netip.AddrPort, len(r.badRemotes)) for i, v := range r.badRemotes { - c[i] = v.Copy() + c[i] = v } return c } @@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() { // Rebuild locks and generates the deduplicated address list only if there is work to be done // There is generally no reason to call this directly but it is safe to do so -func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { +func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { r.Lock() defer r.Unlock() @@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { } // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list -func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { +func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { for _, v := range r.badRemotes { - if v.Equals(remote) { + if v == remote { return true } } @@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) } } -func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay { +func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { +func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 { +func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := NewUDPAddrFromLH4(c.v4.learned) + u := AddrPortFromIp4AndPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := NewUDPAddrFromLH4(v) + u := AddrPortFromIp4AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := NewUDPAddrFromLH6(c.v6.learned) + u := AddrPortFromIp6AndPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := NewUDPAddrFromLH6(v) + u := AddrPortFromIp6AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() { if c.relay != nil { for _, v := range c.relay.relay { - ip := iputil.VpnIp(v) - relays = append(relays, &ip) + relays = append(relays, v) } } } @@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() { dnsAddrs := r.hr.GetIPs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - v6 := addr.Addr().As16() - addrs = append(addrs, &udp.Addr{ - IP: v6[:], - Port: addr.Port(), - }) + addrs = append(addrs, addr) } } @@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() { } // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list -func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { +func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { n := len(r.addrs) if n < 2 { return @@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { b := r.addrs[j] // Preferred addresses first - aPref := isPreferred(a.IP, preferredRanges) - bPref := isPreferred(b.IP, preferredRanges) + aPref := isPreferred(a.Addr(), preferredRanges) + bPref := isPreferred(b.Addr(), preferredRanges) switch { case aPref && !bPref: // If i is preferred and j is not, i is less than j @@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // ipv6 addresses 2nd - a4 := a.IP.To4() - b4 := b.IP.To4() + a4 := a.Addr().Is4() + b4 := b.Addr().Is4() switch { - case a4 == nil && b4 != nil: + case a4 == false && b4 == true: // If i is v6 and j is v4, i is less than j return true - case a4 != nil && b4 == nil: + case a4 == true && b4 == false: // If j is v6 and i is v4, i is not less than j return false - case a4 != nil && b4 != nil: - // Special case for ipv4, a4 and b4 are not nil - aPrivate := isPrivateIP(a4) - bPrivate := isPrivateIP(b4) + case a4 == true && b4 == true: + // i and j are both ipv4 + aPrivate := a.Addr().IsPrivate() + bPrivate := b.Addr().IsPrivate() switch { case !aPrivate && bPrivate: // If i is a public ip (not private) and j is a private ip, i is less then j @@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // lexical order of ips 3rd - c := bytes.Compare(a.IP, b.IP) + c := a.Addr().Compare(b.Addr()) if c == 0 { // Ips are the same, Lexical order of ports 4th - return a.Port < b.Port + return a.Port() < b.Port() } // Ip wasn't the same @@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { // Deduplicate a, b := 0, 1 for b < n { - if !r.addrs[a].Equals(r.addrs[b]) { + if r.addrs[a] != r.addrs[b] { a++ if a != b { r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a] @@ -693,7 +684,7 @@ func minInt(a, b int) int { } // isPreferred returns true of the ip is contained in the preferredRanges list -func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { +func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { //TODO: this would be better in a CIDR6Tree for _, p := range preferredRanges { if p.Contains(ip) { @@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { } return false } - -var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8") -var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12") -var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16") - -// isPrivateIP returns true if the ip is contained by a rfc 1918 private range -func isPrivateIP(ip net.IP) bool { - //TODO: another great cidrtree option - //TODO: Private for ipv6 or just let it ride? - return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) -} diff --git a/remote_list_test.go b/remote_list_test.go index 49aa171..62a892b 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -1,47 +1,47 @@ package nebula import ( - "net" + "encoding/binary" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" ) func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe + newIp4AndPortFromString("70.199.182.92:1475"), // this is duped + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 1, - 1, + netip.MustParseAddr("0.0.0.1"), + netip.MustParseAddr("0.0.0.1"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped - NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe - NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), // this is duped + newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe + newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv6 first, sorted lexically within @@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) // Now ensure we can hoist ipv4 up - _, ipNet, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv4 first, public then private, lexically within them @@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) // Ensure we can hoist a specific ipv4 range over anything else - _, ipNet, err = net.ParseCIDR("172.17.0.0/16") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // Preferred ipv4 first @@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) { func BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - assert.NoError(b, err) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - assert.NoError(b, err) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(b, err) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } @@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - rl.Rebuild([]*net.IPNet{ipNet}) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") + rl.Rebuild([]netip.Prefix{ipNet1}) - assert.NoError(b, err) b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) - assert.NoError(b, err) b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) - assert.NoError(b, err) b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } + +func newIp4AndPortFromString(s string) *Ip4AndPort { + a := netip.MustParseAddrPort(s) + v4Addr := a.Addr().As4() + return &Ip4AndPort{ + Ip: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(a.Port()), + } +} + +func newIp6AndPortFromString(s string) *Ip6AndPort { + a := netip.MustParseAddrPort(s) + v6Addr := a.Addr().As16() + return &Ip6AndPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(a.Port()), + } +} diff --git a/service/service.go b/service/service.go index 66ce864..50c1d4a 100644 --- a/service/service.go +++ b/service/service.go @@ -17,7 +17,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "golang.org/x/sync/errgroup" - "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -81,7 +81,7 @@ func New(config *config.C) (*Service, error) { if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) } - ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4))) + ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4))) s.ipstack.SetRouteTable([]tcpip.Route{ { Destination: ipv4Subnet, @@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) { ipNet := device.Cidr() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ @@ -124,7 +124,7 @@ func New(config *config.C) (*Service, error) { return err } packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])), + Payload: buffer.MakeWithData(bytes.Clone(buf[:n])), }) linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) @@ -136,7 +136,7 @@ func New(config *config.C) (*Service, error) { eg.Go(func() error { for { packet := linkEP.ReadContext(ctx) - if packet.IsNil() { + if packet == nil { if err := ctx.Err(); err != nil { return err } @@ -166,7 +166,7 @@ func (s *Service) DialContext(ctx context.Context, network, address string) (net fullAddr := tcpip.FullAddress{ NIC: nicID, - Addr: tcpip.Address(addr.IP), + Addr: tcpip.AddrFromSlice(addr.IP), Port: uint16(addr.Port), } diff --git a/service/service_test.go b/service/service_test.go index d1909cd..3176209 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "errors" - "net" + "net/netip" "testing" "time" @@ -18,12 +18,8 @@ import ( type m map[string]interface{} -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { - - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { panic(err) @@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -94,7 +90,7 @@ func TestService(t *testing.T) { "port": 4243, }, }) - b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{"localhost:4243"}, }, diff --git a/ssh.go b/ssh.go index e99205c..2ff0954 100644 --- a/ssh.go +++ b/ssh.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "net" + "net/netip" "os" "reflect" "runtime" @@ -18,9 +19,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" - "github.com/slackhq/nebula/udp" ) type sshListHostMapFlags struct { @@ -51,6 +50,11 @@ type sshCreateTunnelFlags struct { Address string } +type sshDeviceInfoFlags struct { + Json bool + Pretty bool +} + func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { c.RegisterReloadCallback(func(c *config.C) { if c.GetBool("sshd.enabled", false) { @@ -110,6 +114,19 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return nil, fmt.Errorf("error while adding sshd.host_key: %s", err) } + // Clear existing trusted CAs and authorized keys + ssh.ClearTrustedCAs() + ssh.ClearAuthorizedKeys() + + rawCAs := c.GetStringSlice("sshd.trusted_cas", []string{}) + for _, caAuthorizedKey := range rawCAs { + err := ssh.AddTrustedCA(caAuthorizedKey) + if err != nil { + l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring") + continue + } + } + rawKeys := c.Get("sshd.authorized_users") keys, ok := rawKeys.([]interface{}) if ok { @@ -231,7 +248,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "start-cpu-profile", - ShortDescription: "Starts a cpu profile and write output to the provided file", + ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`", Callback: sshStartCpuProfile, }) @@ -246,7 +263,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-heap-profile", - ShortDescription: "Saves a heap profile to the provided path", + ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`", Callback: sshGetHeapProfile, }) @@ -258,7 +275,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "save-mutex-profile", - ShortDescription: "Saves a mutex profile to the provided path", + ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`", Callback: sshGetMutexProfile, }) @@ -286,6 +303,21 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter }, }) + ssh.RegisterCommand(&sshd.Command{ + Name: "device-info", + ShortDescription: "Prints information about the network device.", + Flags: func() (*flag.FlagSet, interface{}) { + fl := flag.NewFlagSet("", flag.ContinueOnError) + s := sshDeviceInfoFlags{} + fl.BoolVar(&s.Json, "json", false, "outputs as json with more information") + fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json") + return fl, &s + }, + Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { + return sshDeviceInfo(f, fs, w) + }, + }) + ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", @@ -398,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 + return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 }) if fs.Json || fs.Pretty { @@ -512,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -541,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -583,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -603,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } - var addr *udp.Addr + var addr netip.AddrPort if flags.Address != "" { - addr = udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err = netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } } hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) - if addr != nil { + if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -634,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No address was provided") } - addr := udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err := netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -759,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -829,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp iputil.VpnIp + PeerIp netip.Addr LocalIndex uint32 RemoteIndex uint32 - RelayedThrough []iputil.VpnIp + RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp iputil.VpnIp + NebulaIp netip.Addr RelayForIps []RelayFor } @@ -919,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -939,7 +965,34 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr enc.SetIndent("", " ") } - return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges)) + return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) +} + +func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { + + data := struct { + Name string `json:"name"` + Cidr string `json:"cidr"` + }{ + Name: ifce.inside.Name(), + Cidr: ifce.inside.Cidr().String(), + } + + flags, ok := fs.(*sshDeviceInfoFlags) + if !ok { + return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) + } + + if flags.Json || flags.Pretty { + js := json.NewEncoder(w.GetWriter()) + if flags.Pretty { + js.SetIndent("", " ") + } + + return js.Encode(data) + } else { + return w.WriteLine(fmt.Sprintf("name=%v cidr=%v", data.Name, data.Cidr)) + } } func sshReload(c *config.C, w sshd.StringWriter) error { diff --git a/sshd/server.go b/sshd/server.go index 4a78fdf..9e8c721 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -1,6 +1,7 @@ package sshd import ( + "bytes" "errors" "fmt" "net" @@ -15,8 +16,11 @@ type SSHServer struct { config *ssh.ServerConfig l *logrus.Entry + certChecker *ssh.CertChecker + // Map of user -> authorized keys trustedKeys map[string]map[string]bool + trustedCAs []ssh.PublicKey // List of available commands helpCommand *Command @@ -31,6 +35,7 @@ type SSHServer struct { // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { + s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, @@ -38,8 +43,43 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { conns: make(map[int]*session), } + cc := ssh.CertChecker{ + IsUserAuthority: func(auth ssh.PublicKey) bool { + for _, ca := range s.trustedCAs { + if bytes.Equal(ca.Marshal(), auth.Marshal()) { + return true + } + } + + return false + }, + UserKeyFallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + pk := string(pubKey.Marshal()) + fp := ssh.FingerprintSHA256(pubKey) + + tk, ok := s.trustedKeys[c.User()] + if !ok { + return nil, fmt.Errorf("unknown user %s", c.User()) + } + + _, ok = tk[pk] + if !ok { + return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp) + } + + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "fp": fp, + "user": c.User(), + }, + }, nil + + }, + } + s.config = &ssh.ServerConfig{ - PublicKeyCallback: s.matchPubKey, + PublicKeyCallback: cc.Authenticate, //TODO: AuthLogCallback: s.authAttempt, //TODO: version string ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), @@ -66,10 +106,26 @@ func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error { return nil } +func (s *SSHServer) ClearTrustedCAs() { + s.trustedCAs = []ssh.PublicKey{} +} + func (s *SSHServer) ClearAuthorizedKeys() { s.trustedKeys = make(map[string]map[string]bool) } +// AddTrustedCA adds a trusted CA for user certificates +func (s *SSHServer) AddTrustedCA(pubKey string) error { + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) + if err != nil { + return err + } + + s.trustedCAs = append(s.trustedCAs, pk) + s.l.WithField("sshKey", pubKey).Info("Trusted CA key") + return nil +} + // AddAuthorizedKey adds an ssh public key for a user func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) @@ -178,26 +234,3 @@ func (s *SSHServer) closeSessions() { } s.connsLock.Unlock() } - -func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { - pk := string(pubKey.Marshal()) - fp := ssh.FingerprintSHA256(pubKey) - - tk, ok := s.trustedKeys[c.User()] - if !ok { - return nil, fmt.Errorf("unknown user %s", c.User()) - } - - _, ok = tk[pk] - if !ok { - return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp) - } - - return &ssh.Permissions{ - // Record the public key used for authentication. - Extensions: map[string]string{ - "fp": fp, - "user": c.User(), - }, - }, nil -} diff --git a/test/tun.go b/test/tun.go index 86656c9..fbf5829 100644 --- a/test/tun.go +++ b/test/tun.go @@ -3,23 +3,21 @@ package test import ( "errors" "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type NoopTun struct{} -func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (NoopTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() *net.IPNet { - return nil +func (NoopTun) Cidr() netip.Prefix { + return netip.Prefix{} } func (NoopTun) Name() string { diff --git a/timeout_test.go b/timeout_test.go index 3f81ff4..4c6364e 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -1,6 +1,7 @@ package nebula import ( + "net/netip" "testing" "time" @@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: 1}, - {LocalIP: 2}, - {LocalIP: 3}, - {LocalIP: 4}, + {LocalIP: netip.MustParseAddr("0.0.0.1")}, + {LocalIP: netip.MustParseAddr("0.0.0.2")}, + {LocalIP: netip.MustParseAddr("0.0.0.3")}, + {LocalIP: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) diff --git a/udp/conn.go b/udp/conn.go index a2c24a1..fa4e443 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -1,6 +1,8 @@ package udp import ( + "net/netip" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -9,7 +11,7 @@ import ( const MTU = 9001 type EncReader func( - addr *Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -22,9 +24,9 @@ type EncReader func( type Conn interface { Rebind() error - LocalAddr() (*Addr, error) + LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) - WriteTo(b []byte, addr *Addr) error + WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error } @@ -34,13 +36,13 @@ type NoopConn struct{} func (NoopConn) Rebind() error { return nil } -func (NoopConn) LocalAddr() (*Addr, error) { - return nil, nil +func (NoopConn) LocalAddr() (netip.AddrPort, error) { + return netip.AddrPort{}, nil } func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { return } -func (NoopConn) WriteTo(_ []byte, _ *Addr) error { +func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } func (NoopConn) ReloadConfig(_ *config.C) { diff --git a/udp/temp.go b/udp/temp.go index 2efe31d..b281906 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -1,9 +1,10 @@ package udp import ( - "github.com/slackhq/nebula/iputil" + "net/netip" ) //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) +// TODO: IPV6-WORK this can likely be removed now +type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) diff --git a/udp/udp_all.go b/udp/udp_all.go deleted file mode 100644 index 093bf69..0000000 --- a/udp/udp_all.go +++ /dev/null @@ -1,100 +0,0 @@ -package udp - -import ( - "encoding/json" - "fmt" - "net" - "strconv" -) - -type m map[string]interface{} - -type Addr struct { - IP net.IP - Port uint16 -} - -func NewAddr(ip net.IP, port uint16) *Addr { - addr := Addr{IP: make([]byte, net.IPv6len), Port: port} - copy(addr.IP, ip.To16()) - return &addr -} - -func NewAddrFromString(s string) *Addr { - ip, port, err := ParseIPAndPort(s) - //TODO: handle err - _ = err - return &Addr{IP: ip.To16(), Port: port} -} - -func (ua *Addr) Equals(t *Addr) bool { - if t == nil || ua == nil { - return t == nil && ua == nil - } - return ua.IP.Equal(t.IP) && ua.Port == t.Port -} - -func (ua *Addr) String() string { - if ua == nil { - return "" - } - - return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port)) -} - -func (ua *Addr) MarshalJSON() ([]byte, error) { - if ua == nil { - return nil, nil - } - - return json.Marshal(m{"ip": ua.IP, "port": ua.Port}) -} - -func (ua *Addr) Copy() *Addr { - if ua == nil { - return nil - } - - nu := Addr{ - Port: ua.Port, - IP: make(net.IP, len(ua.IP)), - } - - copy(nu.IP, ua.IP) - return &nu -} - -type AddrSlice []*Addr - -func (a AddrSlice) Equal(b AddrSlice) bool { - if len(a) != len(b) { - return false - } - - for i := range a { - if !a[i].Equals(b[i]) { - return false - } - } - - return true -} - -func ParseIPAndPort(s string) (net.IP, uint16, error) { - rIp, sPort, err := net.SplitHostPort(s) - if err != nil { - return nil, 0, err - } - - addr, err := net.ResolveIPAddr("ip", rIp) - if err != nil { - return nil, 0, err - } - - iPort, err := strconv.Atoi(sPort) - if err != nil { - return nil, 0, err - } - - return addr.IP, uint16(iPort), nil -} diff --git a/udp/udp_android.go b/udp/udp_android.go index 8d69074..bb19195 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -6,13 +6,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 785aa6a..65ef31a 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -9,13 +9,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 08e1b6a..183ac7a 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 1dd6d1d..2d84536 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -11,6 +11,7 @@ import ( "context" "fmt" "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -25,7 +26,7 @@ type GenericConn struct { var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { - _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) +func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { + _, err := u.UDPConn.WriteToUDPAddrPort(b, addr) return err } -func (u *GenericConn) LocalAddr() (*Addr, error) { +func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: - addr := &Addr{IP: make([]byte, len(v.IP))} - copy(addr.IP, v.IP) - addr.Port = uint16(v.Port) - return addr, nil + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return nil, fmt.Errorf("LocalAddr returned: %#v", a) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } @@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { // Just read one packet at a time - n, rua, err := u.ReadFromUDP(buffer) + n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } - udpAddr.IP = rua.IP - udpAddr.Port = uint16(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 1151c89..ef07243 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "syscall" "unsafe" @@ -27,25 +28,6 @@ type StdConn struct { batch int } -var x int - -// From linux/sock_diag.h -const ( - _SK_MEMINFO_RMEM_ALLOC = iota - _SK_MEMINFO_RCVBUF - _SK_MEMINFO_WMEM_ALLOC - _SK_MEMINFO_SNDBUF - _SK_MEMINFO_FWD_ALLOC - _SK_MEMINFO_WMEM_QUEUED - _SK_MEMINFO_OPTMEM - _SK_MEMINFO_BACKLOG - _SK_MEMINFO_DROPS - - _SK_MEMINFO_VARS -) - -type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 - func maybeIPV4(ip net.IP) (net.IP, bool) { ip4 := ip.To4() if ip4 != nil { @@ -54,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) { return ip, false } -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { - ipV4, isV4 := maybeIPV4(ip) +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { af := unix.AF_INET6 - if isV4 { + if ip.Is4() { af = unix.AF_INET } syscall.ForkLock.RLock() @@ -80,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //TODO: support multiple listening IPs (for limiting ipv6) var sa unix.Sockaddr - if isV4 { + if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} - copy(sa4.Addr[:], ipV4) + sa4.Addr = ip.As4() sa = sa4 } else { sa6 := &unix.SockaddrInet6{Port: port} - copy(sa6.Addr[:], ip.To16()) + sa6.Addr = ip.As16() sa = sa6 } if err = unix.Bind(fd, sa); err != nil { @@ -98,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err + return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -121,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *StdConn) LocalAddr() (*Addr, error) { +func (u *StdConn) LocalAddr() (netip.AddrPort, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { - return nil, err + return netip.AddrPort{}, err } - addr := &Addr{} switch sa := sa.(type) { case *unix.SockaddrInet4: - addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() - addr.Port = uint16(sa.Port) + return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil + case *unix.SockaddrInet6: - addr.IP = sa.Addr[0:] - addr.Port = uint16(sa.Port) - } + return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil - return addr, nil + default: + return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + } } func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{} + var ip netip.Addr nb := make([]byte, 12, 12) //TODO: should we track this? @@ -165,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew //metric.Update(int64(n)) for i := 0; i < n; i++ { if u.isV4 { - udpAddr.IP = names[i][4:8] + ip, _ = netip.AddrFromSlice(names[i][4:8]) + //TODO: IPV6-WORK what is not ok? } else { - udpAddr.IP = names[i][8:24] + ip, _ = netip.AddrFromSlice(names[i][8:24]) + //TODO: IPV6-WORK what is not ok? } - udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), + plaintext[:0], + buffers[i][:msgs[i].Len], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } } @@ -216,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *StdConn) WriteTo(b []byte, addr *Addr) error { +func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { if u.isV4 { - return u.writeTo4(b, addr) + return u.writeTo4(b, ip) } - return u.writeTo6(b, addr) + return u.writeTo6(b, ip) } -func (u *StdConn) writeTo6(b []byte, addr *Addr) error { +func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 + rsa.Addr = ip.Addr().As16() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addr.IP.To16()) + rsa.Port = (port >> 8) | ((port & 0xff) << 8) for { _, _, err := unix.Syscall6( @@ -251,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error { } } -func (u *StdConn) writeTo4(b []byte, addr *Addr) error { - addrV4, isAddrV4 := maybeIPV4(addr.IP) - if !isAddrV4 { +func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { + if !ip.Addr().Is4() { return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") } var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET + rsa.Addr = ip.Addr().As4() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addrV4) + rsa.Port = (port >> 8) | ((port & 0xff) << 8) for { _, _, err := unix.Syscall6( @@ -316,8 +308,8 @@ func (u *StdConn) ReloadConfig(c *config.C) { } } -func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { - var vallen uint32 = 4 * _SK_MEMINFO_VARS +func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { + var vallen uint32 = 4 * unix.SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) if err != 0 { return err @@ -332,12 +324,12 @@ func (u *StdConn) Close() error { func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges - var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge - var meminfo _SK_MEMINFO + var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge + var meminfo [unix.SK_MEMINFO_VARS]uint32 if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { - udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) + udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) for i := range udpConns { - udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{ + udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{ metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil), @@ -354,7 +346,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() { for i, gauges := range udpGauges { if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { - for j := 0; j < _SK_MEMINFO_VARS; j++ { + for j := 0; j < unix.SK_MEMINFO_VARS; j++ { gauges[j].Update(int64(meminfo[j])) } } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index a54f1df..87a0de7 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -1,6 +1,6 @@ -//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64) && !android && !e2e_testing +//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing // +build linux -// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 +// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64 // +build !android // +build !e2e_testing diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 3c14fac..3b69159 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 31c1a55..ee7e1e0 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" "syscall" @@ -61,16 +62,14 @@ type RIOConn struct { results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { +func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } u := &RIOConn{l: l} - addr := [16]byte{} - copy(addr[:], ip.To16()) - err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port}) if err != nil { return nil, fmt.Errorf("bind: %w", err) } @@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { @@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - udpAddr.IP = rua.Addr[:] - p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) - p[0] = byte(rua.Port >> 8) - p[1] = byte(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } @@ -231,7 +235,7 @@ retry: return n, ep, nil } -func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { +func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { if !u.isOpen.Load() { return net.ErrClosed } @@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { packet := u.tx.Push() packet.addr.Family = windows.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) - p[0] = byte(addr.Port >> 8) - p[1] = byte(addr.Port) - copy(packet.addr.Addr[:], addr.IP.To16()) + packet.addr.Addr = ip.Addr().As16() + port := ip.Port() + packet.addr.Port = (port >> 8) | ((port & 0xff) << 8) copy(packet.data[:], buf) dataBuffer := &winrio.Buffer{ @@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (u *RIOConn) LocalAddr() (*Addr, error) { +func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { - return nil, err + return netip.AddrPort{}, err } v6 := sa.(*windows.SockaddrInet6) - return &Addr{ - IP: v6.Addr[:], - Port: uint16(v6.Port), - }, nil + return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil + } func (u *RIOConn) Rebind() error { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 55985f4..f03a353 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,9 +4,8 @@ package udp import ( - "fmt" "io" - "net" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" @@ -16,30 +15,24 @@ import ( ) type Packet struct { - ToIp net.IP - ToPort uint16 - FromIp net.IP - FromPort uint16 - Data []byte + To netip.AddrPort + From netip.AddrPort + Data []byte } func (u *Packet) Copy() *Packet { n := &Packet{ - ToIp: make(net.IP, len(u.ToIp)), - ToPort: u.ToPort, - FromIp: make(net.IP, len(u.FromIp)), - FromPort: u.FromPort, - Data: make([]byte, len(u.Data)), + To: u.To, + From: u.From, + Data: make([]byte, len(u.Data)), } - copy(n.ToIp, u.ToIp) - copy(n.FromIp, u.FromIp) copy(n.Data, u.Data) return n } type TesterConn struct { - Addr *Addr + Addr netip.AddrPort RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula @@ -48,9 +41,9 @@ type TesterConn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ - Addr: &Addr{ip, uint16(port)}, + Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, @@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) { } if u.l.Level >= logrus.DebugLevel { u.l.WithField("header", h). - WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). + WithField("udpAddr", packet.From). WithField("dataLen", len(packet.Data)). Debug("UDP receiving injected packet") } @@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { +func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { if u.closed.Load() { return io.ErrClosedPipe } p := &Packet{ - Data: make([]byte, len(b), len(b)), - FromIp: make([]byte, 16), - FromPort: u.Addr.Port, - ToIp: make([]byte, 16), - ToPort: addr.Port, + Data: make([]byte, len(b), len(b)), + From: u.Addr, + To: addr, } copy(p.Data, b) - copy(p.ToIp, addr.IP.To16()) - copy(p.FromIp, u.Addr.IP.To16()) - u.TxPackets <- p return nil } @@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - ua := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { @@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi if !ok { return } - ua.Port = p.FromPort - copy(ua.IP, p.FromIp.To16()) - r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } @@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() { return func() {} } -func (u *TesterConn) LocalAddr() (*Addr, error) { +func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index ebcace6..1b777c3 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -6,12 +6,13 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between