diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml
index 399bc95..e0d41ae 100644
--- a/.github/workflows/gofmt.yml
+++ b/.github/workflows/gofmt.yml
@@ -18,7 +18,7 @@ jobs:
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: Install goimports
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index b5b8ced..31987db 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -14,7 +14,7 @@ jobs:
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: Build
@@ -24,7 +24,7 @@ jobs:
mv build/*.tar.gz release
- name: Upload artifacts
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: linux-latest
path: release
@@ -37,7 +37,7 @@ jobs:
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: Build
@@ -55,7 +55,7 @@ jobs:
mv dist\windows\wintun build\dist\windows\
- name: Upload artifacts
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: windows-latest
path: build
@@ -64,18 +64,18 @@ jobs:
name: Build Universal Darwin
env:
HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
- runs-on: macos-11
+ runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: Import certificates
if: env.HAS_SIGNING_CREDS == 'true'
- uses: Apple-Actions/import-codesign-certs@v2
+ uses: Apple-Actions/import-codesign-certs@v3
with:
p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}
@@ -104,11 +104,57 @@ jobs:
fi
- name: Upload artifacts
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: darwin-latest
path: ./release/*
+ build-docker:
+ name: Create and Upload Docker Images
+ # Technically we only need build-linux to succeed, but if any platforms fail we'll
+ # want to investigate and restart the build
+ needs: [build-linux, build-darwin, build-windows]
+ runs-on: ubuntu-latest
+ env:
+ HAS_DOCKER_CREDS: ${{ vars.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
+ # XXX It's not possible to write a conditional here, so instead we do it on every step
+ #if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
+ steps:
+ # Be sure to checkout the code before downloading artifacts, or they will
+ # be overwritten
+ - name: Checkout code
+ if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
+ uses: actions/checkout@v4
+
+ - name: Download artifacts
+ if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
+ uses: actions/download-artifact@v4
+ with:
+ name: linux-latest
+ path: artifacts
+
+ - name: Login to Docker Hub
+ if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
+ uses: docker/login-action@v3
+ with:
+ username: ${{ vars.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Set up Docker Buildx
+ if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
+ uses: docker/setup-buildx-action@v3
+
+ - name: Build and push images
+ if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
+ env:
+ DOCKER_IMAGE_REPO: ${{ vars.DOCKER_IMAGE_REPO || 'nebulaoss/nebula' }}
+ DOCKER_IMAGE_TAG: ${{ vars.DOCKER_IMAGE_TAG || 'latest' }}
+ run: |
+ mkdir -p build/linux-{amd64,arm64}
+ tar -zxvf artifacts/nebula-linux-amd64.tar.gz -C build/linux-amd64/
+ tar -zxvf artifacts/nebula-linux-arm64.tar.gz -C build/linux-arm64/
+ docker buildx build . --push -f docker/Dockerfile --platform linux/amd64,linux/arm64 --tag "${DOCKER_IMAGE_REPO}:${DOCKER_IMAGE_TAG}" --tag "${DOCKER_IMAGE_REPO}:${GITHUB_REF#refs/tags/v}"
+
release:
name: Create and Upload Release
needs: [build-linux, build-darwin, build-windows]
@@ -117,7 +163,7 @@ jobs:
- uses: actions/checkout@v4
- name: Download artifacts
- uses: actions/download-artifact@v3
+ uses: actions/download-artifact@v4
with:
path: artifacts
diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml
new file mode 100644
index 0000000..2b5e6e9
--- /dev/null
+++ b/.github/workflows/smoke-extra.yml
@@ -0,0 +1,48 @@
+name: smoke-extra
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [opened, synchronize, labeled, reopened]
+ paths:
+ - '.github/workflows/smoke**'
+ - '**Makefile'
+ - '**.go'
+ - '**.proto'
+ - 'go.mod'
+ - 'go.sum'
+jobs:
+
+ smoke-extra:
+ if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra')
+ name: Run extra smoke tests
+ runs-on: ubuntu-latest
+ steps:
+
+ - uses: actions/checkout@v4
+
+ - uses: actions/setup-go@v5
+ with:
+ go-version-file: 'go.mod'
+ check-latest: true
+
+ - name: install vagrant
+ run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
+
+ - name: freebsd-amd64
+ run: make smoke-vagrant/freebsd-amd64
+
+ - name: openbsd-amd64
+ run: make smoke-vagrant/openbsd-amd64
+
+ - name: netbsd-amd64
+ run: make smoke-vagrant/netbsd-amd64
+
+ - name: linux-386
+ run: make smoke-vagrant/linux-386
+
+ - name: linux-amd64-ipv6disable
+ run: make smoke-vagrant/linux-amd64-ipv6disable
+
+ timeout-minutes: 30
diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml
index f02b1ba..54833bd 100644
--- a/.github/workflows/smoke.yml
+++ b/.github/workflows/smoke.yml
@@ -22,7 +22,7 @@ jobs:
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: build
diff --git a/.github/workflows/smoke/build.sh b/.github/workflows/smoke/build.sh
index 9cbb200..c546653 100755
--- a/.github/workflows/smoke/build.sh
+++ b/.github/workflows/smoke/build.sh
@@ -11,6 +11,11 @@ mkdir ./build
cp ../../../../build/linux-amd64/nebula .
cp ../../../../build/linux-amd64/nebula-cert .
+ if [ "$1" ]
+ then
+ cp "../../../../build/$1/nebula" "$1-nebula"
+ fi
+
HOST="lighthouse1" \
AM_LIGHTHOUSE=true \
../genconfig.sh >lighthouse1.yml
diff --git a/.github/workflows/smoke/genconfig.sh b/.github/workflows/smoke/genconfig.sh
index 373ea5f..16e768e 100755
--- a/.github/workflows/smoke/genconfig.sh
+++ b/.github/workflows/smoke/genconfig.sh
@@ -47,7 +47,7 @@ listen:
port: ${LISTEN_PORT:-4242}
tun:
- dev: ${TUN_DEV:-nebula1}
+ dev: ${TUN_DEV:-tun0}
firewall:
inbound_action: reject
diff --git a/.github/workflows/smoke/smoke-relay.sh b/.github/workflows/smoke/smoke-relay.sh
index 8926091..9c113e1 100755
--- a/.github/workflows/smoke/smoke-relay.sh
+++ b/.github/workflows/smoke/smoke-relay.sh
@@ -76,7 +76,7 @@ docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
-sleep 1
+sleep 5
if [ "$(jobs -r)" ]
then
diff --git a/.github/workflows/smoke/smoke-vagrant.sh b/.github/workflows/smoke/smoke-vagrant.sh
new file mode 100755
index 0000000..76cf72f
--- /dev/null
+++ b/.github/workflows/smoke/smoke-vagrant.sh
@@ -0,0 +1,105 @@
+#!/bin/bash
+
+set -e -x
+
+set -o pipefail
+
+export VAGRANT_CWD="$PWD/vagrant-$1"
+
+mkdir -p logs
+
+cleanup() {
+ echo
+ echo " *** cleanup"
+ echo
+
+ set +e
+ if [ "$(jobs -r)" ]
+ then
+ docker kill lighthouse1 host2
+ fi
+ vagrant destroy -f
+}
+
+trap cleanup EXIT
+
+CONTAINER="nebula:${NAME:-smoke}"
+
+docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
+docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
+
+vagrant up
+vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
+
+docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' &
+sleep 1
+docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' &
+sleep 1
+vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
+sleep 15
+
+# grab tcpdump pcaps for debugging
+docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
+docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
+docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
+docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
+# vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
+# vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
+
+docker exec host2 ncat -nklv 0.0.0.0 2000 &
+vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
+#docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
+#vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
+
+set +x
+echo
+echo " *** Testing ping from lighthouse1"
+echo
+set -x
+docker exec lighthouse1 ping -c1 192.168.100.2
+docker exec lighthouse1 ping -c1 192.168.100.3
+
+set +x
+echo
+echo " *** Testing ping from host2"
+echo
+set -x
+docker exec host2 ping -c1 192.168.100.1
+# Should fail because not allowed by host3 inbound firewall
+! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
+
+set +x
+echo
+echo " *** Testing ncat from host2"
+echo
+set -x
+# Should fail because not allowed by host3 inbound firewall
+#! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
+#! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
+
+set +x
+echo
+echo " *** Testing ping from host3"
+echo
+set -x
+vagrant ssh -c "ping -c1 192.168.100.1"
+vagrant ssh -c "ping -c1 192.168.100.2"
+
+set +x
+echo
+echo " *** Testing ncat from host3"
+echo
+set -x
+#vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
+#vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
+
+vagrant ssh -c "sudo xargs kill </nebula/pid"
+docker exec host2 sh -c 'kill 1'
+docker exec lighthouse1 sh -c 'kill 1'
+sleep 5
+
+if [ "$(jobs -r)" ]
+then
+  echo "nebula still running after test shutdown" >&2
+  exit 1
+fi
diff --git a/.github/workflows/smoke/smoke.sh b/.github/workflows/smoke/smoke.sh
index 3177255..6d04027 100755
--- a/.github/workflows/smoke/smoke.sh
+++ b/.github/workflows/smoke/smoke.sh
@@ -129,7 +129,7 @@ docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1'
docker exec host2 sh -c 'kill 1'
docker exec lighthouse1 sh -c 'kill 1'
-sleep 1
+sleep 5
if [ "$(jobs -r)" ]
then
diff --git a/.github/workflows/smoke/vagrant-freebsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-freebsd-amd64/Vagrantfile
new file mode 100644
index 0000000..c8a4c64
--- /dev/null
+++ b/.github/workflows/smoke/vagrant-freebsd-amd64/Vagrantfile
@@ -0,0 +1,7 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+Vagrant.configure("2") do |config|
+ config.vm.box = "generic/freebsd14"
+
+ config.vm.synced_folder "../build", "/nebula", type: "rsync"
+end
diff --git a/.github/workflows/smoke/vagrant-linux-386/Vagrantfile b/.github/workflows/smoke/vagrant-linux-386/Vagrantfile
new file mode 100644
index 0000000..4b1d0bd
--- /dev/null
+++ b/.github/workflows/smoke/vagrant-linux-386/Vagrantfile
@@ -0,0 +1,7 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+Vagrant.configure("2") do |config|
+ config.vm.box = "ubuntu/xenial32"
+
+ config.vm.synced_folder "../build", "/nebula"
+end
diff --git a/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile
new file mode 100644
index 0000000..89f9477
--- /dev/null
+++ b/.github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile
@@ -0,0 +1,16 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+Vagrant.configure("2") do |config|
+ config.vm.box = "ubuntu/jammy64"
+
+ config.vm.synced_folder "../build", "/nebula"
+
+ config.vm.provision :shell do |shell|
+ shell.inline = <<-EOF
+ sed -i 's/GRUB_CMDLINE_LINUX=""/GRUB_CMDLINE_LINUX="ipv6.disable=1"/' /etc/default/grub
+ update-grub
+ EOF
+ shell.privileged = true
+ shell.reboot = true
+ end
+end
diff --git a/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile
new file mode 100644
index 0000000..14ba2ce
--- /dev/null
+++ b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile
@@ -0,0 +1,7 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+Vagrant.configure("2") do |config|
+ config.vm.box = "generic/netbsd9"
+
+ config.vm.synced_folder "../build", "/nebula", type: "rsync"
+end
diff --git a/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile
new file mode 100644
index 0000000..e4f4104
--- /dev/null
+++ b/.github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile
@@ -0,0 +1,7 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+Vagrant.configure("2") do |config|
+ config.vm.box = "generic/openbsd7"
+
+ config.vm.synced_folder "../build", "/nebula", type: "rsync"
+end
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 34fe5f3..65a6e3e 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -22,7 +22,7 @@ jobs:
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: Build
@@ -40,10 +40,10 @@ jobs:
- name: Build test mobile
run: make build-test-mobile
- - uses: actions/upload-artifact@v3
+ - uses: actions/upload-artifact@v4
with:
- name: e2e packet flow
- path: e2e/mermaid/
+ name: e2e packet flow linux-latest
+ path: e2e/mermaid/linux-latest
if-no-files-found: warn
test-linux-boringcrypto:
@@ -55,7 +55,7 @@ jobs:
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: Build
@@ -72,14 +72,14 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [windows-latest, macos-11]
+ os: [windows-latest, macos-latest]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
- go-version-file: 'go.mod'
+ go-version: '1.22'
check-latest: true
- name: Build nebula
@@ -97,8 +97,8 @@ jobs:
- name: End 2 end
run: make e2evv
- - uses: actions/upload-artifact@v3
+ - uses: actions/upload-artifact@v4
with:
- name: e2e packet flow
- path: e2e/mermaid/
+ name: e2e packet flow ${{ matrix.os }}
+ path: e2e/mermaid/${{ matrix.os }}
if-no-files-found: warn
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 71c3ed4..f763b69 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,92 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+## [1.9.3] - 2024-06-06
+
+### Fixed
+
+- Initialize messageCounter to 2 instead of verifying later. (#1156)
+
+## [1.9.2] - 2024-06-03
+
+### Fixed
+
+- Ensure messageCounter is set before handshake is complete. (#1154)
+
+## [1.9.1] - 2024-05-29
+
+### Fixed
+
+- Fixed a potential deadlock in GetOrHandshake. (#1151)
+
+## [1.9.0] - 2024-05-07
+
+### Deprecated
+
+- This release adds a new setting `default_local_cidr_any` that defaults to
+ true to match previous behavior, but will default to false in the next
+ release (1.10). When set to false, `local_cidr` is matched correctly for
+ firewall rules on hosts acting as unsafe routers, and should be set for any
+ firewall rules you want to allow unsafe route hosts to access. See the issue
+ and example config for more details. (#1071, #1099)
+
+### Added
+
+- Nebula now has an official Docker image `nebulaoss/nebula` that is
+ distroless and contains just the `nebula` and `nebula-cert` binaries. You
+ can find it here: https://hub.docker.com/r/nebulaoss/nebula (#1037)
+
+- Experimental binaries for `loong64` are now provided. (#1003)
+
+- Added example service script for OpenRC. (#711)
+
+- The SSH daemon now supports inlined host keys. (#1054)
+
+- The SSH daemon now supports certificates with `sshd.trusted_cas`. (#1098)
+
+### Changed
+
+- Config setting `tun.unsafe_routes` is now reloadable. (#1083)
+
+- Small documentation and internal improvements. (#1065, #1067, #1069, #1108,
+ #1109, #1111, #1135)
+
+- Various dependency updates. (#1139, #1138, #1134, #1133, #1126, #1123, #1110,
+ #1094, #1092, #1087, #1086, #1085, #1072, #1063, #1059, #1055, #1053, #1047,
+ #1046, #1034, #1022)
+
+### Removed
+
+- Support for the deprecated `local_range` option has been removed. Please
+ change to `preferred_ranges` (which is also now reloadable). (#1043)
+
+- We are now building with go1.22, which means that for Windows you need at
+ least Windows 10 or Windows Server 2016. This is because support for earlier
+ versions was removed in Go 1.21. See https://go.dev/doc/go1.21#windows (#981)
+
+- Removed vagrant example, as it was unmaintained. (#1129)
+
+- Removed Fedora and Arch nebula.service files, as they are maintained in the
+ upstream repos. (#1128, #1132)
+
+- Remove the TCP round trip tracking metrics, as they never had correct data
+ and were an experiment to begin with. (#1114)
+
+### Fixed
+
+- Fixed a potential deadlock introduced in 1.8.1. (#1112)
+
+- Fixed support for Linux when IPv6 has been disabled at the OS level. (#787)
+
+- DNS will return NXDOMAIN now when there are no results. (#845)
+
+- Allow `::` in `lighthouse.dns.host`. (#1115)
+
+- Capitalization of `NotAfter` fixed in DNS TXT response. (#1127)
+
+- Don't log invalid certificates. It is untrusted data and can cause a large
+ volume of logs. (#1116)
+
## [1.8.2] - 2024-01-08
### Fixed
@@ -558,7 +644,11 @@ created.)
- Initial public release.
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.8.2...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.3...HEAD
+[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
+[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
+[1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1
+[1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0
[1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2
[1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1
[1.8.0]: https://github.com/slackhq/nebula/releases/tag/v1.8.0
diff --git a/LOGGING.md b/LOGGING.md
index bd2fdef..e2508c8 100644
--- a/LOGGING.md
+++ b/LOGGING.md
@@ -33,6 +33,5 @@ l.WithError(err).
WithField("vpnIp", IntIp(hostinfo.hostId)).
WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix"}).
- WithField("cert", remoteCert).
Info("Invalid certificate from host")
```
\ No newline at end of file
diff --git a/Makefile b/Makefile
index a02a6ec..0d0943f 100644
--- a/Makefile
+++ b/Makefile
@@ -1,22 +1,14 @@
-GOMINVERSION = 1.20
NEBULA_CMD_PATH = "./cmd/nebula"
-GO111MODULE = on
-export GO111MODULE
CGO_ENABLED = 0
export CGO_ENABLED
# Set up OS specific bits
ifeq ($(OS),Windows_NT)
- #TODO: we should be able to ditch awk as well
- GOVERSION := $(shell go version | awk "{print substr($$3, 3)}")
- GOISMIN := $(shell IF "$(GOVERSION)" GEQ "$(GOMINVERSION)" ECHO 1)
NEBULA_CMD_SUFFIX = .exe
NULL_FILE = nul
# RIO on windows does pointer stuff that makes go vet angry
VET_FLAGS = -unsafeptr=false
else
- GOVERSION := $(shell go version | awk '{print substr($$3, 3)}')
- GOISMIN := $(shell expr "$(GOVERSION)" ">=" "$(GOMINVERSION)")
NEBULA_CMD_SUFFIX =
NULL_FILE = /dev/null
endif
@@ -30,6 +22,9 @@ ifndef BUILD_NUMBER
endif
endif
+DOCKER_IMAGE_REPO ?= nebulaoss/nebula
+DOCKER_IMAGE_TAG ?= latest
+
LDFLAGS = -X main.Build=$(BUILD_NUMBER)
ALL_LINUX = linux-amd64 \
@@ -44,7 +39,8 @@ ALL_LINUX = linux-amd64 \
linux-mips64 \
linux-mips64le \
linux-mips-softfloat \
- linux-riscv64
+ linux-riscv64 \
+ linux-loong64
ALL_FREEBSD = freebsd-amd64 \
freebsd-arm64
@@ -82,8 +78,12 @@ e2evvvv: e2ev
e2e-bench: TEST_FLAGS = -bench=. -benchmem -run=^$
e2e-bench: e2e
+DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert
+
all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)
+docker: docker/linux-$(shell go env GOARCH)
+
release: $(ALL:%=build/nebula-%.tar.gz)
release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)
@@ -156,6 +156,9 @@ build/nebula-%.tar.gz: build/%/nebula build/%/nebula-cert
build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe
cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe
+docker/%: build/%/nebula build/%/nebula-cert
+ docker build . $(DOCKER_BUILD_ARGS) -f docker/Dockerfile --platform "$(subst -,/,$*)" --tag "${DOCKER_IMAGE_REPO}:${DOCKER_IMAGE_TAG}" --tag "${DOCKER_IMAGE_REPO}:$(BUILD_NUMBER)"
+
vet:
go vet $(VET_FLAGS) -v ./...
@@ -219,6 +222,10 @@ smoke-docker-race: BUILD_ARGS = -race
smoke-docker-race: CGO_ENABLED = 1
smoke-docker-race: smoke-docker
+smoke-vagrant/%: bin-docker build/%/nebula
+ cd .github/workflows/smoke/ && ./build.sh $*
+ cd .github/workflows/smoke/ && ./smoke-vagrant.sh $*
+
.FORCE:
-.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html
+.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/%
.DEFAULT_GOAL := bin
diff --git a/README.md b/README.md
index 51e913d..65ea91f 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,11 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
$ brew install nebula
```
+- [Docker](https://hub.docker.com/r/nebulaoss/nebula)
+ ```
+ $ docker pull nebulaoss/nebula
+ ```
+
#### Mobile
- [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200)
diff --git a/allow_list.go b/allow_list.go
index 9186b2f..90e0de2 100644
--- a/allow_list.go
+++ b/allow_list.go
@@ -2,17 +2,16 @@ package nebula
import (
"fmt"
- "net"
+ "net/netip"
"regexp"
- "github.com/slackhq/nebula/cidr"
+ "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
- cidrTree *cidr.Tree6[bool]
+ cidrTree *bart.Table[bool]
}
type RemoteAllowList struct {
@@ -20,7 +19,7 @@ type RemoteAllowList struct {
// Inside Range Specific, keys of this tree are inside CIDRs and values
// are *AllowList
- insideAllowLists *cidr.Tree6[*AllowList]
+ insideAllowLists *bart.Table[*AllowList]
}
type LocalAllowList struct {
@@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}
- tree := cidr.NewTree6[bool]()
+ tree := new(bart.Table[bool])
// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
@@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
}
- _, ipNet, err := net.ParseCIDR(rawCIDR)
+ ipNet, err := netip.ParsePrefix(rawCIDR)
if err != nil {
- return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+ return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
}
+ ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
+
// TODO: should we error on duplicate CIDRs in the config?
- tree.AddCIDR(ipNet, value)
+ tree.Insert(ipNet, value)
- maskBits, maskSize := ipNet.Mask.Size()
+ maskBits := ipNet.Bits()
var rules *allowListRules
- if maskSize == 32 {
+ if ipNet.Addr().Is4() {
rules = &rules4
} else {
rules = &rules6
@@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
if !rules4.defaultSet {
if rules4.allValuesMatch {
- _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
- tree.AddCIDR(zeroCIDR, !rules4.allValues)
+ tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
}
@@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
if !rules6.defaultSet {
if rules6.allValuesMatch {
- _, zeroCIDR, _ := net.ParseCIDR("::/0")
- tree.AddCIDR(zeroCIDR, !rules6.allValues)
+ tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues)
} else {
return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
}
@@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
return nameRules, nil
}
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
+func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
- remoteAllowRanges := cidr.NewTree6[*AllowList]()
+ remoteAllowRanges := new(bart.Table[*AllowList])
rawMap, ok := value.(map[interface{}]interface{})
if !ok {
@@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error
return nil, err
}
- _, ipNet, err := net.ParseCIDR(rawCIDR)
+ ipNet, err := netip.ParsePrefix(rawCIDR)
if err != nil {
- return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+ return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
}
- remoteAllowRanges.AddCIDR(ipNet, allowList)
+ remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList)
}
return remoteAllowRanges, nil
}
-func (al *AllowList) Allow(ip net.IP) bool {
- if al == nil {
- return true
- }
-
- _, result := al.cidrTree.MostSpecificContains(ip)
- return result
-}
-
-func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
- if al == nil {
- return true
- }
-
- _, result := al.cidrTree.MostSpecificContainsIpV4(ip)
- return result
-}
-
-func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
+func (al *AllowList) Allow(ip netip.Addr) bool {
if al == nil {
return true
}
- _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
+ result, _ := al.cidrTree.Lookup(ip)
return result
}
-func (al *LocalAllowList) Allow(ip net.IP) bool {
+func (al *LocalAllowList) Allow(ip netip.Addr) bool {
if al == nil {
return true
}
@@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool {
return !al.nameRules[0].Allow
}
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
+func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
if al == nil {
return true
}
return al.AllowList.Allow(ip)
}
-func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
+func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
if !al.getInsideAllowList(vpnIp).Allow(ip) {
return false
}
return al.AllowList.Allow(ip)
}
-func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
- if al == nil {
- return true
- }
- if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) {
- return false
- }
- return al.AllowList.AllowIpV4(ip)
-}
-
-func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
- if al == nil {
- return true
- }
- if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) {
- return false
- }
- return al.AllowList.AllowIpV6(hi, lo)
-}
-
-func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
+func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
if al.insideAllowLists != nil {
- ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+ inside, ok := al.insideAllowLists.Lookup(vpnIp)
if ok {
return inside
}
diff --git a/allow_list_test.go b/allow_list_test.go
index 334cb60..c8b3d08 100644
--- a/allow_list_test.go
+++ b/allow_list_test.go
@@ -1,11 +1,11 @@
package nebula
import (
- "net"
+ "net/netip"
"regexp"
"testing"
- "github.com/slackhq/nebula/cidr"
+ "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
@@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
"192.168.0.0": true,
}
r, err := newAllowListFromConfig(c, "allowlist", nil)
- assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
+ assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{
@@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) {
}
func TestAllowList_Allow(t *testing.T) {
- assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
-
- tree := cidr.NewTree6[bool]()
- tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
- tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
- tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
- tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
- tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
- tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
- tree.AddCIDR(cidr.Parse("::1/128"), true)
- tree.AddCIDR(cidr.Parse("::2/128"), false)
+ assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
+
+ tree := new(bart.Table[bool])
+ tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
+ tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
+ tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
+ tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
+ tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
+ tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
+ tree.Insert(netip.MustParsePrefix("::1/128"), true)
+ tree.Insert(netip.MustParsePrefix("::2/128"), false)
al := &AllowList{cidrTree: tree}
- assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
- assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4")))
- assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42")))
- assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41")))
- assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1")))
- assert.Equal(t, true, al.Allow(net.ParseIP("::1")))
- assert.Equal(t, false, al.Allow(net.ParseIP("::2")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
+ assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
+ assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
+ assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
+ assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
}
func TestLocalAllowList_AllowName(t *testing.T) {
diff --git a/calculated_remote.go b/calculated_remote.go
index 38f5bea..ae2ed50 100644
--- a/calculated_remote.go
+++ b/calculated_remote.go
@@ -1,41 +1,36 @@
package nebula
import (
+ "encoding/binary"
"fmt"
"math"
"net"
+ "net/netip"
"strconv"
- "github.com/slackhq/nebula/cidr"
+ "github.com/gaissmai/bart"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
// This allows us to "guess" what the remote might be for a host while we wait
// for the lighthouse response. See "lighthouse.calculated_remotes" in the
// example config file.
type calculatedRemote struct {
- ipNet net.IPNet
- maskIP iputil.VpnIp
- mask iputil.VpnIp
- port uint32
+ ipNet netip.Prefix
+ mask netip.Prefix
+ port uint32
}
-func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
- // Ensure this is an IPv4 mask that we expect
- ones, bits := ipNet.Mask.Size()
- if ones == 0 || bits != 32 {
- return nil, fmt.Errorf("invalid mask: %v", ipNet)
- }
+func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+ masked := maskCidr.Masked()
if port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %d", port)
}
return &calculatedRemote{
- ipNet: *ipNet,
- maskIP: iputil.Ip2VpnIp(ipNet.IP),
- mask: iputil.Ip2VpnIp(ipNet.Mask),
- port: uint32(port),
+ ipNet: maskCidr,
+ mask: masked,
+ port: uint32(port),
}, nil
}
@@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string {
return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
}
-func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
+func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
// Combine the masked bytes of the "mask" IP with the unmasked bytes
// of the overlay IP
- masked := (c.maskIP & c.mask) | (ip & ^c.mask)
+ if c.ipNet.Addr().Is4() {
+ return c.apply4(ip)
+ }
+ return c.apply6(ip)
+}
+
+func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
+ //TODO: IPV6-WORK this can be less crappy
+ maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+ mask := binary.BigEndian.Uint32(maskb[:])
+
+ b := c.mask.Addr().As4()
+ maskIp := binary.BigEndian.Uint32(b[:])
+
+ b = ip.As4()
+ intIp := binary.BigEndian.Uint32(b[:])
+
+ return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+}
- return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
+func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
+ //TODO: IPV6-WORK
+ panic("Can not calculate ipv6 remote addresses")
}
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}
- calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
+ calculatedRemotes := new(bart.Table[[]*calculatedRemote])
rawMap, ok := value.(map[any]any)
if !ok {
@@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu
return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
}
- _, ipNet, err := net.ParseCIDR(rawCIDR)
+ cidr, err := netip.ParsePrefix(rawCIDR)
if err != nil {
return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
}
+ //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
entry, err := newCalculatedRemotesListFromConfig(rawValue)
if err != nil {
return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
}
- calculatedRemotes.AddCIDR(ipNet, entry)
+ calculatedRemotes.Insert(cidr, entry)
}
return calculatedRemotes, nil
@@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
if !ok {
return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
}
- _, ipNet, err := net.ParseCIDR(rawMask)
+ maskCidr, err := netip.ParsePrefix(rawMask)
if err != nil {
return nil, fmt.Errorf("invalid mask: %s", rawMask)
}
@@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
}
- return newCalculatedRemote(ipNet, port)
+ return newCalculatedRemote(maskCidr, port)
}
diff --git a/calculated_remote_test.go b/calculated_remote_test.go
index 2ddebca..6ff1cb0 100644
--- a/calculated_remote_test.go
+++ b/calculated_remote_test.go
@@ -1,27 +1,25 @@
package nebula
import (
- "net"
+ "net/netip"
"testing"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCalculatedRemoteApply(t *testing.T) {
- _, ipNet, err := net.ParseCIDR("192.168.1.0/24")
+ ipNet, err := netip.ParsePrefix("192.168.1.0/24")
require.NoError(t, err)
c, err := newCalculatedRemote(ipNet, 4242)
require.NoError(t, err)
- input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182})
+ input, err := netip.ParseAddr("10.0.10.182")
+ assert.NoError(t, err)
- expected := &Ip4AndPort{
- Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})),
- Port: 4242,
- }
+ expected, err := netip.ParseAddr("192.168.1.182")
+ assert.NoError(t, err)
- assert.Equal(t, expected, c.Apply(input))
+ assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
}
diff --git a/cert/cert.go b/cert/cert.go
index 4f1b776..a0164f7 100644
--- a/cert/cert.go
+++ b/cert/cert.go
@@ -324,7 +324,7 @@ func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
return k.Bytes, r, nil
}
-// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert into its
+// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
// protobuf-generated struct.
func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
if len(b) == 0 {
diff --git a/cidr/parse.go b/cidr/parse.go
deleted file mode 100644
index 74367f6..0000000
--- a/cidr/parse.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package cidr
-
-import "net"
-
-// Parse is a convenience function that returns only the IPNet
-// This function ignores errors since it is primarily a test helper, the result could be nil
-func Parse(s string) *net.IPNet {
- _, c, _ := net.ParseCIDR(s)
- return c
-}
diff --git a/cidr/tree4.go b/cidr/tree4.go
deleted file mode 100644
index c5ebe54..0000000
--- a/cidr/tree4.go
+++ /dev/null
@@ -1,203 +0,0 @@
-package cidr
-
-import (
- "net"
-
- "github.com/slackhq/nebula/iputil"
-)
-
-type Node[T any] struct {
- left *Node[T]
- right *Node[T]
- parent *Node[T]
- hasValue bool
- value T
-}
-
-type entry[T any] struct {
- CIDR *net.IPNet
- Value T
-}
-
-type Tree4[T any] struct {
- root *Node[T]
- list []entry[T]
-}
-
-const (
- startbit = iputil.VpnIp(0x80000000)
-)
-
-func NewTree4[T any]() *Tree4[T] {
- tree := new(Tree4[T])
- tree.root = &Node[T]{}
- tree.list = []entry[T]{}
- return tree
-}
-
-func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
- bit := startbit
- node := tree.root
- next := tree.root
-
- ip := iputil.Ip2VpnIp(cidr.IP)
- mask := iputil.Ip2VpnIp(cidr.Mask)
-
- // Find our last ancestor in the tree
- for bit&mask != 0 {
- if ip&bit != 0 {
- next = node.right
- } else {
- next = node.left
- }
-
- if next == nil {
- break
- }
-
- bit = bit >> 1
- node = next
- }
-
- // We already have this range so update the value
- if next != nil {
- addCIDR := cidr.String()
- for i, v := range tree.list {
- if addCIDR == v.CIDR.String() {
- tree.list = append(tree.list[:i], tree.list[i+1:]...)
- break
- }
- }
-
- tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
- node.value = val
- node.hasValue = true
- return
- }
-
- // Build up the rest of the tree we don't already have
- for bit&mask != 0 {
- next = &Node[T]{}
- next.parent = node
-
- if ip&bit != 0 {
- node.right = next
- } else {
- node.left = next
- }
-
- bit >>= 1
- node = next
- }
-
- // Final node marks our cidr, set the value
- node.value = val
- node.hasValue = true
- tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-}
-
-// Contains finds the first match, which may be the least specific
-func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
- bit := startbit
- node := tree.root
-
- for node != nil {
- if node.hasValue {
- return true, node.value
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
-
- }
-
- return false, value
-}
-
-// MostSpecificContains finds the most specific match
-func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
- bit := startbit
- node := tree.root
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- return ok, value
-}
-
-type eachFunc[T any] func(T) bool
-
-// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
-// The final return value will be true if the provided function returned true
-func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
- bit := startbit
- node := tree.root
-
- for node != nil {
- if node.hasValue {
- // If the each func returns true then we can exit the loop
- if each(node.value) {
- return true
- }
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- return false
-}
-
-// GetCIDR returns the entry added by the most recent matching AddCIDR call
-func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
- bit := startbit
- node := tree.root
-
- ip := iputil.Ip2VpnIp(cidr.IP)
- mask := iputil.Ip2VpnIp(cidr.Mask)
-
- // Find our last ancestor in the tree
- for node != nil && bit&mask != 0 {
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit = bit >> 1
- }
-
- if bit&mask == 0 && node != nil {
- value = node.value
- ok = node.hasValue
- }
-
- return ok, value
-}
-
-// List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4[T]) List() []entry[T] {
- return tree.list
-}
diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go
deleted file mode 100644
index cd17be4..0000000
--- a/cidr/tree4_test.go
+++ /dev/null
@@ -1,170 +0,0 @@
-package cidr
-
-import (
- "net"
- "testing"
-
- "github.com/slackhq/nebula/iputil"
- "github.com/stretchr/testify/assert"
-)
-
-func TestCIDRTree_List(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/16"), "1")
- tree.AddCIDR(Parse("1.0.0.0/8"), "2")
- tree.AddCIDR(Parse("1.0.0.0/16"), "3")
- tree.AddCIDR(Parse("1.0.0.0/16"), "4")
- list := tree.List()
- assert.Len(t, list, 2)
- assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
- assert.Equal(t, "2", list[0].Value)
- assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
- assert.Equal(t, "4", list[1].Value)
-}
-
-func TestCIDRTree_Contains(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
- tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "1", "1.0.0.0"},
- {true, "1", "1.255.255.255"},
- {true, "2", "2.1.0.0"},
- {true, "2", "2.1.255.255"},
- {true, "3", "3.1.1.0"},
- {true, "3", "3.1.1.255"},
- {true, "4a", "4.1.1.255"},
- {true, "4a", "4.1.1.1"},
- {true, "5", "240.0.0.0"},
- {true, "5", "255.255.255.255"},
- {false, "", "239.0.0.0"},
- {false, "", "4.1.2.2"},
- }
-
- for _, tt := range tests {
- ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-
- tree = NewTree4[string]()
- tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
- ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-}
-
-func TestCIDRTree_MostSpecificContains(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "1", "1.0.0.0"},
- {true, "1", "1.255.255.255"},
- {true, "2", "2.1.0.0"},
- {true, "2", "2.1.255.255"},
- {true, "3", "3.1.1.0"},
- {true, "3", "3.1.1.255"},
- {true, "4a", "4.1.1.255"},
- {true, "4b", "4.1.1.2"},
- {true, "4c", "4.1.1.1"},
- {true, "5", "240.0.0.0"},
- {true, "5", "255.255.255.255"},
- {false, "", "239.0.0.0"},
- {false, "", "4.1.2.2"},
- }
-
- for _, tt := range tests {
- ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-
- tree = NewTree4[string]()
- tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
- ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-}
-
-func TestTree4_GetCIDR(t *testing.T) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
- tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
- tests := []struct {
- Found bool
- Result interface{}
- IPNet *net.IPNet
- }{
- {true, "1", Parse("1.0.0.0/8")},
- {true, "2", Parse("2.1.0.0/16")},
- {true, "3", Parse("3.1.1.0/24")},
- {true, "4a", Parse("4.1.1.0/24")},
- {true, "4b", Parse("4.1.1.1/32")},
- {true, "4c", Parse("4.1.2.1/32")},
- {true, "5", Parse("254.0.0.0/4")},
- {false, "", Parse("2.0.0.0/8")},
- }
-
- for _, tt := range tests {
- ok, r := tree.GetCIDR(tt.IPNet)
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-}
-
-func BenchmarkCIDRTree_Contains(b *testing.B) {
- tree := NewTree4[string]()
- tree.AddCIDR(Parse("1.1.0.0/16"), "1")
- tree.AddCIDR(Parse("1.2.1.1/32"), "1")
- tree.AddCIDR(Parse("192.2.1.1/32"), "1")
- tree.AddCIDR(Parse("172.2.1.1/32"), "1")
-
- ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
- b.Run("found", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- tree.Contains(ip)
- }
- })
-
- ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
- b.Run("not found", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- tree.Contains(ip)
- }
- })
-}
diff --git a/cidr/tree6.go b/cidr/tree6.go
deleted file mode 100644
index 3f2cd2a..0000000
--- a/cidr/tree6.go
+++ /dev/null
@@ -1,189 +0,0 @@
-package cidr
-
-import (
- "net"
-
- "github.com/slackhq/nebula/iputil"
-)
-
-const startbit6 = uint64(1 << 63)
-
-type Tree6[T any] struct {
- root4 *Node[T]
- root6 *Node[T]
-}
-
-func NewTree6[T any]() *Tree6[T] {
- tree := new(Tree6[T])
- tree.root4 = &Node[T]{}
- tree.root6 = &Node[T]{}
- return tree
-}
-
-func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
- var node, next *Node[T]
-
- cidrIP, ipv4 := isIPV4(cidr.IP)
- if ipv4 {
- node = tree.root4
- next = tree.root4
-
- } else {
- node = tree.root6
- next = tree.root6
- }
-
- for i := 0; i < len(cidrIP); i += 4 {
- ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
- mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
- bit := startbit
-
- // Find our last ancestor in the tree
- for bit&mask != 0 {
- if ip&bit != 0 {
- next = node.right
- } else {
- next = node.left
- }
-
- if next == nil {
- break
- }
-
- bit = bit >> 1
- node = next
- }
-
- // Build up the rest of the tree we don't already have
- for bit&mask != 0 {
- next = &Node[T]{}
- next.parent = node
-
- if ip&bit != 0 {
- node.right = next
- } else {
- node.left = next
- }
-
- bit >>= 1
- node = next
- }
- }
-
- // Final node marks our cidr, set the value
- node.value = val
- node.hasValue = true
-}
-
-// Finds the most specific match
-func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
- var node *Node[T]
-
- wholeIP, ipv4 := isIPV4(ip)
- if ipv4 {
- node = tree.root4
- } else {
- node = tree.root6
- }
-
- for i := 0; i < len(wholeIP); i += 4 {
- ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
- bit := startbit
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if bit == 0 {
- break
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
- }
-
- return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
- bit := startbit
- node := tree.root4
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
- ip := hi
- node := tree.root6
-
- for i := 0; i < 2; i++ {
- bit := startbit6
-
- for node != nil {
- if node.hasValue {
- value = node.value
- ok = true
- }
-
- if bit == 0 {
- break
- }
-
- if ip&bit != 0 {
- node = node.right
- } else {
- node = node.left
- }
-
- bit >>= 1
- }
-
- ip = lo
- }
-
- return ok, value
-}
-
-func isIPV4(ip net.IP) (net.IP, bool) {
- if len(ip) == net.IPv4len {
- return ip, true
- }
-
- if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
- return ip[12:16], true
- }
-
- return ip, false
-}
-
-func isZeros(p net.IP) bool {
- for i := 0; i < len(p); i++ {
- if p[i] != 0 {
- return false
- }
- }
- return true
-}
diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go
deleted file mode 100644
index eb159ec..0000000
--- a/cidr/tree6_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package cidr
-
-import (
- "encoding/binary"
- "net"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
- tree := NewTree6[string]()
- tree.AddCIDR(Parse("1.0.0.0/8"), "1")
- tree.AddCIDR(Parse("2.1.0.0/16"), "2")
- tree.AddCIDR(Parse("3.1.1.0/24"), "3")
- tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
- tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
- tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
- tree.AddCIDR(Parse("254.0.0.0/4"), "5")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "1", "1.0.0.0"},
- {true, "1", "1.255.255.255"},
- {true, "2", "2.1.0.0"},
- {true, "2", "2.1.255.255"},
- {true, "3", "3.1.1.0"},
- {true, "3", "3.1.1.255"},
- {true, "4a", "4.1.1.255"},
- {true, "4b", "4.1.1.2"},
- {true, "4c", "4.1.1.1"},
- {true, "5", "240.0.0.0"},
- {true, "5", "255.255.255.255"},
- {true, "6a", "1:2:0:4:1:1:1:1"},
- {true, "6b", "1:2:0:4:5:1:1:1"},
- {true, "6c", "1:2:0:4:5:0:0:0"},
- {false, "", "239.0.0.0"},
- {false, "", "4.1.2.2"},
- }
-
- for _, tt := range tests {
- ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-
- tree = NewTree6[string]()
- tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
- tree.AddCIDR(Parse("::/0"), "cool6")
- ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
- assert.True(t, ok)
- assert.Equal(t, "cool", r)
-
- ok, r = tree.MostSpecificContains(net.ParseIP("::"))
- assert.True(t, ok)
- assert.Equal(t, "cool6", r)
-
- ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
- assert.True(t, ok)
- assert.Equal(t, "cool6", r)
-}
-
-func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
- tree := NewTree6[string]()
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
- tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
- tests := []struct {
- Found bool
- Result interface{}
- IP string
- }{
- {true, "6a", "1:2:0:4:1:1:1:1"},
- {true, "6b", "1:2:0:4:5:1:1:1"},
- {true, "6c", "1:2:0:4:5:0:0:0"},
- }
-
- for _, tt := range tests {
- ip := net.ParseIP(tt.IP)
- hi := binary.BigEndian.Uint64(ip[:8])
- lo := binary.BigEndian.Uint64(ip[8:])
-
- ok, r := tree.MostSpecificContainsIpV6(hi, lo)
- assert.Equal(t, tt.Found, ok)
- assert.Equal(t, tt.Result, r)
- }
-}
diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go
index 69df4ab..4e5d51d 100644
--- a/cmd/nebula-cert/ca.go
+++ b/cmd/nebula-cert/ca.go
@@ -180,9 +180,15 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
if err != nil {
return fmt.Errorf("error while generating ecdsa keys: %s", err)
}
- // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L60
- rawPriv = key.D.FillBytes(make([]byte, 32))
- pub = elliptic.Marshal(elliptic.P256(), key.X, key.Y)
+
+ // ecdh.PrivateKey lets us get at the encoded bytes, even though
+ // we aren't using ECDH here.
+ eKey, err := key.ECDH()
+ if err != nil {
+ return fmt.Errorf("error while converting ecdsa key: %s", err)
+ }
+ rawPriv = eKey.Bytes()
+ pub = eKey.PublicKey().Bytes()
}
nc := cert.NebulaCertificate{
diff --git a/connection_manager.go b/connection_manager.go
index f5dd594..d2e8616 100644
--- a/connection_manager.go
+++ b/connection_manager.go
@@ -3,6 +3,8 @@ package nebula
import (
"bytes"
"context"
+ "encoding/binary"
+ "net/netip"
"sync"
"time"
@@ -10,8 +12,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
- "github.com/slackhq/nebula/udp"
)
type trafficDecision int
@@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
var index uint32
- var relayFrom iputil.VpnIp
- var relayTo iputil.VpnIp
+ var relayFrom netip.Addr
+ var relayTo netip.Addr
switch {
case ok && existing.State == Established:
// This relay already exists in newhostinfo, then do nothing.
@@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
index = existing.LocalIndex
switch r.Type {
case TerminalType:
- relayFrom = n.intf.myVpnIp
+ relayFrom = n.intf.myVpnNet.Addr()
relayTo = existing.PeerIp
case ForwardingType:
relayFrom = existing.PeerIp
@@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
}
switch r.Type {
case TerminalType:
- relayFrom = n.intf.myVpnIp
+ relayFrom = n.intf.myVpnNet.Addr()
relayTo = r.PeerIp
case ForwardingType:
relayFrom = r.PeerIp
@@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
}
}
+ //TODO: IPV6-WORK
+ relayFromB := relayFrom.As4()
+ relayToB := relayTo.As4()
+
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
- RelayFromIp: uint32(relayFrom),
- RelayToIp: uint32(relayTo),
+ RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(relayToB[:]),
}
msg, err := req.Marshal()
if err != nil {
@@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
} else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(req.RelayFromIp),
- "relayTo": iputil.VpnIp(req.RelayToIp),
+ "relayFrom": req.RelayFromIp,
+ "relayTo": req.RelayToIp,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}).
@@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
- if current.vpnIp < n.intf.myVpnIp {
+ if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
@@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}
if n.punchy.GetTargetEverything() {
- hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
+ hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})
- } else if hostinfo.remote != nil {
+ } else if hostinfo.remote.IsValid() {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
diff --git a/connection_manager_test.go b/connection_manager_test.go
index a2607a2..5f97cad 100644
--- a/connection_manager_test.go
+++ b/connection_manager_test.go
@@ -5,28 +5,26 @@ import (
"crypto/ed25519"
"crypto/rand"
"net"
+ "net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
-var vpnIp iputil.VpnIp
-
func newTestLighthouse() *LightHouse {
lh := &LightHouse{
l: test.NewLogger(),
- addrMap: map[iputil.VpnIp]*RemoteList{},
- queryChan: make(chan iputil.VpnIp, 10),
+ addrMap: map[netip.Addr]*RemoteList{},
+ queryChan: make(chan netip.Addr, 10),
}
- lighthouses := map[iputil.VpnIp]struct{}{}
- staticList := map[iputil.VpnIp]struct{}{}
+ lighthouses := map[netip.Addr]struct{}{}
+ staticList := map[netip.Addr]struct{}{}
lh.lighthouses.Store(&lighthouses)
lh.staticList.Store(&staticList)
@@ -37,13 +35,15 @@ func newTestLighthouse() *LightHouse {
func Test_NewConnectionManagerTest(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
- preferredRanges := []*net.IPNet{localrange}
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ vpnIp := netip.MustParseAddr("172.1.1.2")
+ preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
- hostMap := NewHostMap(l, vpncidr, preferredRanges)
+ hostMap := newHostMap(l, vpncidr)
+ hostMap.preferredRanges.Store(&preferredRanges)
+
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
@@ -118,12 +118,15 @@ func Test_NewConnectionManagerTest(t *testing.T) {
func Test_NewConnectionManagerTest2(t *testing.T) {
l := test.NewLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- preferredRanges := []*net.IPNet{localrange}
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ vpnIp := netip.MustParseAddr("172.1.1.2")
+ preferredRanges := []netip.Prefix{localrange}
// Very incomplete mock objects
- hostMap := NewHostMap(l, vpncidr, preferredRanges)
+ hostMap := newHostMap(l, vpncidr)
+ hostMap.preferredRanges.Store(&preferredRanges)
+
cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
@@ -207,10 +210,12 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
IP: net.IPv4(172, 1, 1, 2),
Mask: net.IPMask{255, 255, 255, 0},
}
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- preferredRanges := []*net.IPNet{localrange}
- hostMap := NewHostMap(l, vpncidr, preferredRanges)
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ vpnIp := netip.MustParseAddr("172.1.1.2")
+ preferredRanges := []netip.Prefix{localrange}
+ hostMap := newHostMap(l, vpncidr)
+ hostMap.preferredRanges.Store(&preferredRanges)
// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
diff --git a/connection_state.go b/connection_state.go
index 8ef8b3a..1dd3c8c 100644
--- a/connection_state.go
+++ b/connection_state.go
@@ -72,6 +72,8 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
window: b,
myCert: certState.Certificate,
}
+ // always start the counter from 2, as packet 1 and packet 2 are handshake packets.
+ ci.messageCounter.Add(2)
return ci
}
diff --git a/control.go b/control.go
index 1e27b0f..54d528d 100644
--- a/control.go
+++ b/control.go
@@ -2,7 +2,7 @@ package nebula
import (
"context"
- "net"
+ "net/netip"
"os"
"os/signal"
"syscall"
@@ -10,9 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
- "github.com/slackhq/nebula/udp"
)
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -21,10 +19,10 @@ import (
type controlEach func(h *HostInfo)
type controlHostLister interface {
- QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
+ QueryVpnIp(vpnIp netip.Addr) *HostInfo
ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach)
- GetPreferredRanges() []*net.IPNet
+ GetPreferredRanges() []netip.Prefix
}
type Control struct {
@@ -39,15 +37,15 @@ type Control struct {
}
type ControlHostInfo struct {
- VpnIp net.IP `json:"vpnIp"`
+ VpnIp netip.Addr `json:"vpnIp"`
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
- RemoteAddrs []*udp.Addr `json:"remoteAddrs"`
+ RemoteAddrs []netip.AddrPort `json:"remoteAddrs"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
- CurrentRemote *udp.Addr `json:"currentRemote"`
- CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"`
- CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"`
+ CurrentRemote netip.AddrPort `json:"currentRemote"`
+ CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"`
+ CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"`
}
// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -131,8 +129,46 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
}
}
+// GetCertByVpnIp returns a single tunnels hostInfo, or nil if not found
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate {
+ if c.f.myVpnNet.Addr() == vpnIp {
+ return c.f.pki.GetCertState().Certificate
+ }
+ hi := c.f.hostMap.QueryVpnIp(vpnIp)
+ if hi == nil {
+ return nil
+ }
+ return hi.GetCert()
+}
+
+// CreateTunnel creates a new tunnel to the given vpn ip.
+func (c *Control) CreateTunnel(vpnIp netip.Addr) {
+ c.f.handshakeManager.StartHandshake(vpnIp, nil)
+}
+
+// PrintTunnel returns a copy of the host info for an existing tunnel to the given vpn ip, or nil if none exists.
+func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
+ hi := c.f.hostMap.QueryVpnIp(vpnIp)
+ if hi == nil {
+ return nil
+ }
+ chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges())
+ return &chi
+}
+
+// QueryLighthouse queries the lighthouse for the given vpn ip and returns a copy of the cached remotes, or nil if unknown.
+func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
+ hi := c.f.lightHouse.Query(vpnIp)
+ if hi == nil {
+ return nil
+ }
+ return hi.CopyCache()
+}
+
// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
-func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
var hl controlHostLister
if pending {
hl = c.f.handshakeManager
@@ -145,24 +181,26 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
return nil
}
- ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
+ ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges())
return &ch
}
// SetRemoteForTunnel forces a tunnel to use a specific remote
-func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return nil
}
- hostInfo.SetRemote(addr.Copy())
- ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
+ hostInfo.SetRemote(addr)
+ ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
return &ch
}
// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
-func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return false
@@ -205,7 +243,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
}
// Learn which hosts are being used as relays, so we can shut them down last.
- relayingHosts := map[iputil.VpnIp]*HostInfo{}
+ relayingHosts := map[netip.Addr]*HostInfo{}
// Grab the hostMap lock to access the Relays map
c.f.hostMap.Lock()
for _, relayingHost := range c.f.hostMap.Relays {
@@ -236,15 +274,16 @@ func (c *Control) Device() overlay.Device {
return c.f.inside
}
-func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
+func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
chi := ControlHostInfo{
- VpnIp: h.vpnIp.ToIP(),
+ VpnIp: h.vpnIp,
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CurrentRelaysToMe: h.relayState.CopyRelayIps(),
CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
+ CurrentRemote: h.remote,
}
if h.ConnectionState != nil {
@@ -255,10 +294,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi.Cert = c.Copy()
}
- if h.remote != nil {
- chi.CurrentRemote = h.remote.Copy()
- }
-
return chi
}
diff --git a/control_test.go b/control_test.go
index 847332b..fbf29c0 100644
--- a/control_test.go
+++ b/control_test.go
@@ -2,15 +2,14 @@ package nebula
import (
"net"
+ "net/netip"
"reflect"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
- "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
@@ -18,16 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
- hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
- remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
- remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
+ hm := newHostMap(l, netip.Prefix{})
+ hm.preferredRanges.Store(&[]netip.Prefix{})
+
+ remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
+ remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
+
ipNet := net.IPNet{
- IP: net.IPv4(1, 2, 3, 4),
+ IP: remote1.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
}
ipNet2 := net.IPNet{
- IP: net.ParseIP("1:2:3:4:5:6:7:8"),
+ IP: remote2.Addr().AsSlice(),
Mask: net.IPMask{255, 255, 255, 0},
}
@@ -48,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}
remotes := NewRemoteList(nil)
- remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
- remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
+ remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
+ remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+
+ vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
+ assert.True(t, ok)
+
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
@@ -58,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: vpnIp,
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
+ vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
+ assert.True(t, ok)
+
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
@@ -74,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
},
remoteIndexId: 200,
localIndexId: 201,
- vpnIp: iputil.Ip2VpnIp(ipNet2.IP),
+ vpnIp: vpnIp2,
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}, &Interface{})
@@ -89,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l: logrus.New(),
}
- thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
+ thi := c.GetHostInfoByVpnIp(vpnIp, false)
expectedInfo := ControlHostInfo{
- VpnIp: net.IPv4(1, 2, 3, 4).To4(),
+ VpnIp: vpnIp,
LocalIndex: 201,
RemoteIndex: 200,
- RemoteAddrs: []*udp.Addr{remote2, remote1},
+ RemoteAddrs: []netip.AddrPort{remote2, remote1},
Cert: crt.Copy(),
MessageCounter: 0,
- CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
- CurrentRelaysToMe: []iputil.VpnIp{},
- CurrentRelaysThroughMe: []iputil.VpnIp{},
+ CurrentRemote: remote1,
+ CurrentRelaysToMe: []netip.Addr{},
+ CurrentRelaysThroughMe: []netip.Addr{},
}
// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
- test.AssertDeepCopyEqual(t, &expectedInfo, thi)
+ assert.EqualValues(t, &expectedInfo, thi)
+ //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
+ //test.AssertDeepCopyEqual(t, &expectedInfo, thi)
// Make sure we don't panic if the host info doesn't have a cert yet
assert.NotPanics(t, func() {
- thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
+ thi = c.GetHostInfoByVpnIp(vpnIp2, false)
})
}
diff --git a/control_tester.go b/control_tester.go
index b786ba3..d46540f 100644
--- a/control_tester.go
+++ b/control_tester.go
@@ -4,14 +4,13 @@
package nebula
import (
- "net"
+ "net/netip"
"github.com/slackhq/nebula/cert"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
@@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
// This is necessary if you did not configure static hosts or are not running a lighthouse
-func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
+func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
c.f.lightHouse.Lock()
- remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+ remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
- iVpnIp := iputil.Ip2VpnIp(vpnIp)
- if v4 := toAddr.IP.To4(); v4 != nil {
- remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
+ if toAddr.Addr().Is4() {
+ remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
} else {
- remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
+ remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
}
}
// InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
// This is necessary to inform an initiator of possible relays for communicating with a responder
-func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) {
+func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
c.f.lightHouse.Lock()
- remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+ remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()
- iVpnIp := iputil.Ip2VpnIp(vpnIp)
- uVpnIp := []uint32{}
- for _, rVPnIp := range relayVpnIps {
- uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp)))
- }
-
- remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp)
+ remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
}
// GetFromTun will pull a packet off the tun side of nebula
@@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
}
// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) {
+func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
+ //TODO: IPV6-WORK
ip := layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
- SrcIP: c.f.inside.Cidr().IP,
- DstIP: toIp,
+ SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(),
+ DstIP: toIp.Unmap().AsSlice(),
}
udp := layers.UDP{
@@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
}
-func (c *Control) GetVpnIp() iputil.VpnIp {
- return c.f.myVpnIp
+func (c *Control) GetVpnIp() netip.Addr {
+ return c.f.myVpnNet.Addr()
}
-func (c *Control) GetUDPAddr() string {
- return c.f.outside.(*udp.TesterConn).Addr.String()
+func (c *Control) GetUDPAddr() netip.AddrPort {
+ return c.f.outside.(*udp.TesterConn).Addr
}
-func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
- hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
+func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
+ hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
if hostinfo == nil {
return false
}
@@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.pki.GetCertState().Certificate
}
-func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
+func (c *Control) ReHandshake(vpnIp netip.Addr) {
c.f.handshakeManager.StartHandshake(vpnIp, nil)
}
diff --git a/dist/arch/nebula.service b/dist/arch/nebula.service
deleted file mode 100644
index 831c71a..0000000
--- a/dist/arch/nebula.service
+++ /dev/null
@@ -1,15 +0,0 @@
-[Unit]
-Description=Nebula overlay networking tool
-Wants=basic.target network-online.target nss-lookup.target time-sync.target
-After=basic.target network.target network-online.target
-
-[Service]
-Type=notify
-NotifyAccess=main
-SyslogIdentifier=nebula
-ExecReload=/bin/kill -HUP $MAINPID
-ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml
-Restart=always
-
-[Install]
-WantedBy=multi-user.target
diff --git a/dist/fedora/nebula.service b/dist/fedora/nebula.service
deleted file mode 100644
index 0f947ea..0000000
--- a/dist/fedora/nebula.service
+++ /dev/null
@@ -1,16 +0,0 @@
-[Unit]
-Description=Nebula overlay networking tool
-Wants=basic.target network-online.target nss-lookup.target time-sync.target
-After=basic.target network.target network-online.target
-Before=sshd.service
-
-[Service]
-Type=notify
-NotifyAccess=main
-SyslogIdentifier=nebula
-ExecReload=/bin/kill -HUP $MAINPID
-ExecStart=/usr/bin/nebula -config /etc/nebula/config.yml
-Restart=always
-
-[Install]
-WantedBy=multi-user.target
diff --git a/dns_server.go b/dns_server.go
index 3109b4c..5fea65c 100644
--- a/dns_server.go
+++ b/dns_server.go
@@ -3,6 +3,7 @@ package nebula
import (
"fmt"
"net"
+ "net/netip"
"strconv"
"strings"
"sync"
@@ -10,7 +11,6 @@ import (
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
// This whole thing should be rewritten to use context
@@ -42,21 +42,23 @@ func (d *dnsRecords) Query(data string) string {
}
func (d *dnsRecords) QueryCert(data string) string {
- ip := net.ParseIP(data[:len(data)-1])
- if ip == nil {
+ ip, err := netip.ParseAddr(data[:len(data)-1])
+ if err != nil {
return ""
}
- iip := iputil.Ip2VpnIp(ip)
- hostinfo := d.hostMap.QueryVpnIp(iip)
+
+ hostinfo := d.hostMap.QueryVpnIp(ip)
if hostinfo == nil {
return ""
}
+
q := hostinfo.GetCert()
if q == nil {
return ""
}
+
cert := q.Details
- c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
+ c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
return c
}
@@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
- b := net.ParseIP(a)
+ b, err := netip.ParseAddr(a)
+ if err != nil {
+ return
+ }
+
// We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
@@ -96,6 +102,10 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
}
}
}
+
+ if len(m.Answer) == 0 {
+ m.Rcode = dns.RcodeNameError
+ }
}
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
@@ -129,7 +139,12 @@ func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
}
func getDnsServerAddr(c *config.C) string {
- return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
+ dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
+ // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
+ if dnsHost == "[::]" {
+ dnsHost = "::"
+ }
+ return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
}
func startDns(l *logrus.Logger, c *config.C) {
diff --git a/dns_server_test.go b/dns_server_test.go
index 830dc8a..69f6ae8 100644
--- a/dns_server_test.go
+++ b/dns_server_test.go
@@ -4,6 +4,8 @@ import (
"testing"
"github.com/miekg/dns"
+ "github.com/slackhq/nebula/config"
+ "github.com/stretchr/testify/assert"
)
func TestParsequery(t *testing.T) {
@@ -17,3 +19,40 @@ func TestParsequery(t *testing.T) {
//parseQuery(m)
}
+
+func Test_getDnsServerAddr(t *testing.T) {
+ c := config.NewC(nil)
+
+ c.Settings["lighthouse"] = map[interface{}]interface{}{
+ "dns": map[interface{}]interface{}{
+ "host": "0.0.0.0",
+ "port": "1",
+ },
+ }
+ assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
+
+ c.Settings["lighthouse"] = map[interface{}]interface{}{
+ "dns": map[interface{}]interface{}{
+ "host": "::",
+ "port": "1",
+ },
+ }
+ assert.Equal(t, "[::]:1", getDnsServerAddr(c))
+
+ c.Settings["lighthouse"] = map[interface{}]interface{}{
+ "dns": map[interface{}]interface{}{
+ "host": "[::]",
+ "port": "1",
+ },
+ }
+ assert.Equal(t, "[::]:1", getDnsServerAddr(c))
+
+ // Make sure whitespace doesn't mess us up
+ c.Settings["lighthouse"] = map[interface{}]interface{}{
+ "dns": map[interface{}]interface{}{
+ "host": "[::] ",
+ "port": "1",
+ },
+ }
+ assert.Equal(t, "[::]:1", getDnsServerAddr(c))
+}
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..400e275
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,11 @@
+FROM gcr.io/distroless/static:latest
+
+ARG TARGETOS TARGETARCH
+COPY build/$TARGETOS-$TARGETARCH/nebula /nebula
+COPY build/$TARGETOS-$TARGETARCH/nebula-cert /nebula-cert
+
+VOLUME ["/config"]
+
+ENTRYPOINT ["/nebula"]
+# Allow users to override the args passed to nebula
+CMD ["-config", "/config/config.yml"]
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000..129744f
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,24 @@
+# NebulaOSS/nebula Docker Image
+
+## Building
+
+From the root of the repository, run `make docker`.
+
+## Running
+
+To run the built image, use the following command:
+
+```
+docker run \
+ --name nebula \
+ --network host \
+ --cap-add NET_ADMIN \
+ --volume ./config:/config \
+ --rm \
+ nebulaoss/nebula
+```
+
+A few notes:
+
+- The `NET_ADMIN` capability is necessary to create the tun adapter on the host (this is unnecessary if the tun device is disabled).
+- `--volume ./config:/config` should point to a directory that contains your `config.yml` and any other necessary files.
diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go
index 59f1d0e..3d42a56 100644
--- a/e2e/handshakes_test.go
+++ b/e2e/handshakes_test.go
@@ -5,7 +5,7 @@ package e2e
import (
"fmt"
- "net"
+ "net/netip"
"testing"
"time"
@@ -13,19 +13,18 @@ import (
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/e2e/router"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
func BenchmarkHotPath(b *testing.B) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
@@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) {
r.CancelFlowLogs()
for n := 0; n < b.N; n++ {
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
@@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) {
}
func TestGoodHandshake(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
t.Log("Have them consume my stage 0 packet. They have a tunnel now")
theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) {
myControl.WaitForType(1, 0, theirControl)
t.Log("Make sure our host infos are correct")
- assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
+ assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
t.Log("Get that cached packet and make sure it looks right")
myCachedPacket := theirControl.GetFromTun(true)
- assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
t.Log("Do a bidirectional tunnel test")
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
@@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) {
}
func TestWrongResponderHandshake(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
// The IPs here are chosen on purpose:
// The current remote handling will sort by preference, public, and then lexically.
// So we need them to have a higher address than evil (we could apply a preference though)
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
- evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
+ evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
// Add their real udp addr, which should be tried after evil.
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl, evilControl)
@@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) {
evilControl.Start()
t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
h := &header.H{}
err := h.Parse(p.Data)
@@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) {
panic(err)
}
- if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+ if p.To == theirUdpAddr && h.Type == 1 {
return router.RouteAndExit
}
@@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) {
t.Log("My cached packet should be received by them")
myCachedPacket := theirControl.GetFromTun(true)
- assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
t.Log("Test the tunnel with them")
- assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Flush all packets from all controllers")
r.FlushAll()
t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
- assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
- assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
+ assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
+ assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
//TODO: assert hostmaps for everyone
@@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) {
// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
// But will eventually collapse down to a single tunnel
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse and vice versa
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake to start on both me and them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1 handshake packets")
myHsForThem := myControl.GetFromUDP(true)
@@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) {
r.Log("Route until they receive a message packet")
myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Their cached packet should be received by me")
theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
- assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
r.Log("Do a bidirectional tunnel test")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myHostmapHosts := myControl.ListHostmapHosts(false)
myHostmapIndexes := myControl.ListHostmapIndexes(false)
@@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) {
r.Log("Spin until connection manager tears down a tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
@@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) {
}
func TestUncleanShutdownRaceLoser(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Nuke my hostmap")
myHostmap := myControl.GetHostmap()
- myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+ myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
p = r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.Log("Assert the tunnel works")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(theirControl.GetHostmap().Indexes)
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
if len(theirControl.GetHostmap().Indexes) < start {
break
}
@@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
}
func TestUncleanShutdownRaceWinner(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
theirControl.Start()
r.Log("Trigger a handshake from me to them")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
r.Log("Nuke my hostmap")
theirHostmap := theirControl.GetHostmap()
- theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+ theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
p = r.RouteForAllUntilTxTun(myControl)
- assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
r.Log("Assert the tunnel works")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.Log("Wait for the dead index to go away")
start := len(myControl.GetHostmap().Indexes)
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
if len(myControl.GetHostmap().Indexes) < start {
break
}
@@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
}
func TestRelays(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -361,31 +360,31 @@ func TestRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
}
func TestStage1RaceRelays(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) {
theirControl.Start()
r.Log("Get a tunnel between me and relay")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
r.Log("Wait for a packet from them to me")
p := r.RouteForAllUntilTxTun(myControl)
@@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) {
func TestStage1RaceRelays2(t *testing.T) {
//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
l := NewTestLogger()
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Get a tunnel between me and relay")
l.Info("Get a tunnel between me and relay")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
r.Log("Get a tunnel between them and relay")
l.Info("Get a tunnel between them and relay")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
r.Log("Trigger a handshake from both them and me via relay to them and me")
l.Info("Trigger a handshake from both them and me via relay to them and me")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
t.Log("Wait until we remove extra tunnels")
l.Info("Wait until we remove extra tunnels")
@@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) {
"theirControl": len(theirControl.GetHostmap().Indexes),
"relayControl": len(relayControl.GetHostmap().Indexes),
}).Info("Waiting for hostinfos to be removed...")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
retries--
@@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) {
r.Log("Assert the tunnel works")
l.Info("Assert the tunnel works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
myControl.Stop()
theirControl.Stop()
@@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) {
//
////TODO: assert hostmaps
}
+
func TestRehandshakingRelays(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) {
for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
- c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+ c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
@@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) {
for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+ c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
@@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) {
}
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) {
func TestRehandshakingRelaysPrimary(t *testing.T) {
// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
- relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
+ relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
- myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
- myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
- relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+ myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+ myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+ relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, relayControl, theirControl)
@@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
theirControl.Start()
t.Log("Trigger a handshake from me to them via the relay")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
p := r.RouteForAllUntilTxTun(theirControl)
r.Log("Assert the tunnel works")
- assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+ assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
- assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
- c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+ c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
@@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
- assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+ assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+ c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
@@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
}
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
// We should have two hostinfos on all sides
for len(myControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(theirControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
for len(relayControl.GetHostmap().Indexes) != 2 {
t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
r.Log("Assert the relay tunnel still works")
- assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+ assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
r.Log("yupitdoes")
time.Sleep(time.Second)
}
@@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
}
func TestRehandshaking(t *testing.T) {
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) {
theirControl.Start()
t.Log("Stand up a tunnel between me and them")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) {
myConfig.ReloadConfigString(string(rc))
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+ c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
// We have a new certificate now
break
@@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) {
r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won
- c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+ c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
assert.Contains(t, c.Cert.Details.Groups, "new group")
// We should only have a single tunnel now on both sides
@@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) {
func TestRehandshakingLoser(t *testing.T) {
// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
// Should be the one with the new certificate
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
// Put their info in our lighthouse and vice versa
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(t, myControl, theirControl)
@@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) {
theirControl.Start()
t.Log("Stand up a tunnel between me and them")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
- tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
- tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+ tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
+ tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) {
theirConfig.ReloadConfigString(string(rc))
for {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
- theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+ theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
if theirNewGroup {
@@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) {
r.Log("Spin until there is only 1 tunnel")
for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
t.Log("Connection manager hasn't ticked yet")
time.Sleep(time.Second)
}
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myFinalHostmapHosts := myControl.ListHostmapHosts(false)
myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
// Make sure the correct tunnel won
- theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+ theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
// We should only have a single tunnel now on both sides
@@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) {
// This test forces stage 1, stage 2, stage 1 to be received by me from them
// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
// caused a cross-linked hostinfo
- ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
- theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+ ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+ theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
- myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
- theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+ myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+ theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
// Start the servers
myControl.Start()
@@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) {
//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
t.Log("Start both handshakes")
- myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
- theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+ myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+ theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
t.Log("Get both stage 1")
myStage1ForThem := myControl.GetFromUDP(true)
@@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) {
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
t.Log("Make sure the tunnel still works")
- assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+ assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
myControl.Stop()
theirControl.Stop()
diff --git a/e2e/helpers.go b/e2e/helpers.go
index 13146ab..71df805 100644
--- a/e2e/helpers.go
+++ b/e2e/helpers.go
@@ -4,6 +4,7 @@ import (
"crypto/rand"
"io"
"net"
+ "net/netip"
"time"
"github.com/slackhq/nebula/cert"
@@ -12,7 +13,7 @@ import (
)
// NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
@@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
}
if len(ips) > 0 {
- nc.Details.Ips = ips
+ nc.Details.Ips = make([]*net.IPNet, len(ips))
+ for i, ip := range ips {
+ nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+ }
}
if len(subnets) > 0 {
- nc.Details.Subnets = subnets
+ nc.Details.Subnets = make([]*net.IPNet, len(subnets))
+ for i, ip := range subnets {
+ nc.Details.Subnets[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+ }
}
if len(groups) > 0 {
@@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
issuer, err := ca.Sha256Sum()
if err != nil {
panic(err)
@@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
}
pub, rawPriv := x25519Keypair()
-
+ ipb := ip.Addr().AsSlice()
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
- Name: name,
- Ips: []*net.IPNet{ip},
- Subnets: subnets,
+ Name: name,
+ Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
+ //Subnets: subnets,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go
index b05c84a..527f55b 100644
--- a/e2e/helpers_test.go
+++ b/e2e/helpers_test.go
@@ -6,7 +6,7 @@ package e2e
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"testing"
"time"
@@ -19,7 +19,6 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/e2e/router"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
@@ -27,15 +26,23 @@ import (
type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
- vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
- copy(vpnIpNet.IP, udpIp)
- vpnIpNet.IP[1] += 128
- udpAddr := net.UDPAddr{
- IP: udpIp,
- Port: 4242,
+ vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
+ if err != nil {
+ panic(err)
+ }
+
+ var udpAddr netip.AddrPort
+ if vpnIpNet.Addr().Is4() {
+ budpIp := vpnIpNet.Addr().As4()
+ budpIp[1] -= 128
+ udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
+ } else {
+ budpIp := vpnIpNet.Addr().As16()
+ budpIp[13] -= 128
+ udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
@@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
// "try_interval": "1s",
//},
"listen": m{
- "host": udpAddr.IP.String(),
- "port": udpAddr.Port,
+ "host": udpAddr.Addr().String(),
+ "port": udpAddr.Port(),
},
"logging": m{
"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
@@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
panic(err)
}
- return control, vpnIpNet, &udpAddr, c
+ return control, vpnIpNet, udpAddr, c
}
type doneCb func()
@@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
}
}
-func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
+func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
}
-func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
// Get both host infos
- hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
+ hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
- hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
+ hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
// Check that both vpn and real addr are correct
assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
- assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
- assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
-
- assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
- assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
+ assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
+ assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
// Check that our indexes match
assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
@@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB
//checkIndexes("hmB", hmB, hAinB)
}
-func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
+func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")
- assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
- assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
+ assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect")
+ assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect")
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
assert.NotNil(t, udp, "No udp data found")
diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go
index 120be69..c14ab2e 100644
--- a/e2e/router/hostmap.go
+++ b/e2e/router/hostmap.go
@@ -5,11 +5,11 @@ package router
import (
"fmt"
+ "net/netip"
"sort"
"strings"
"github.com/slackhq/nebula"
- "github.com/slackhq/nebula/iputil"
)
type edge struct {
@@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
return r, globalLines
}
-func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
- keys := make([]iputil.VpnIp, 0, len(hosts))
+func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr {
+ keys := make([]netip.Addr, 0, len(hosts))
for key := range hosts {
keys = append(keys, key)
}
sort.SliceStable(keys, func(i, j int) bool {
- return keys[i] > keys[j]
+ return keys[i].Compare(keys[j]) > 0
})
return keys
diff --git a/e2e/router/router.go b/e2e/router/router.go
index 730853a..0890570 100644
--- a/e2e/router/router.go
+++ b/e2e/router/router.go
@@ -6,12 +6,11 @@ package router
import (
"context"
"fmt"
- "net"
+ "net/netip"
"os"
"path/filepath"
"reflect"
"sort"
- "strconv"
"strings"
"sync"
"testing"
@@ -21,7 +20,6 @@ import (
"github.com/google/gopacket/layers"
"github.com/slackhq/nebula"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"golang.org/x/exp/maps"
)
@@ -29,18 +27,18 @@ import (
type R struct {
// Simple map of the ip:port registered on a control to the control
// Basically a router, right?
- controls map[string]*nebula.Control
+ controls map[netip.AddrPort]*nebula.Control
// A map for inbound packets for a control that doesn't know about this address
- inNat map[string]*nebula.Control
+ inNat map[netip.AddrPort]*nebula.Control
// A last used map, if an inbound packet hit the inNat map then
// all return packets should use the same last used inbound address for the outbound sender
// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
- outNat map[string]net.UDPAddr
+ outNat map[string]netip.AddrPort
// A map of vpn ip to the nebula control it belongs to
- vpnControls map[iputil.VpnIp]*nebula.Control
+ vpnControls map[netip.Addr]*nebula.Control
ignoreFlows []ignoreFlow
flow []flowEntry
@@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
}
r := &R{
- controls: make(map[string]*nebula.Control),
- vpnControls: make(map[iputil.VpnIp]*nebula.Control),
- inNat: make(map[string]*nebula.Control),
- outNat: make(map[string]net.UDPAddr),
+ controls: make(map[netip.AddrPort]*nebula.Control),
+ vpnControls: make(map[netip.Addr]*nebula.Control),
+ inNat: make(map[netip.AddrPort]*nebula.Control),
+ outNat: make(map[string]netip.AddrPort),
flow: []flowEntry{},
ignoreFlows: []ignoreFlow{},
fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
for _, c := range controls {
addr := c.GetUDPAddr()
if _, ok := r.controls[addr]; ok {
- panic("Duplicate listen address: " + addr)
+ panic("Duplicate listen address: " + addr.String())
}
r.vpnControls[c.GetVpnIp()] = c
@@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
// It does not look at the addr attached to the instance.
// If a route is used, this will behave like a NAT for the return path.
// Rewriting the source ip:port to what was last sent to from the origin
-func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
+func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
r.Lock()
defer r.Unlock()
- inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
+ inAddr := netip.AddrPortFrom(ip, port)
if _, ok := r.inNat[inAddr]; ok {
- panic("Duplicate listen address inNat: " + inAddr)
+ panic("Duplicate listen address inNat: " + inAddr.String())
}
r.inNat[inAddr] = c
}
@@ -198,7 +196,7 @@ func (r *R) renderFlow() {
panic(err)
}
- var participants = map[string]struct{}{}
+ var participants = map[netip.AddrPort]struct{}{}
var participantsVals []string
fmt.Fprintln(f, "```mermaid")
@@ -215,7 +213,7 @@ func (r *R) renderFlow() {
continue
}
participants[addr] = struct{}{}
- sanAddr := strings.Replace(addr, ":", "-", 1)
+ sanAddr := strings.Replace(addr.String(), ":", "-", 1)
participantsVals = append(participantsVals, sanAddr)
fmt.Fprintf(
f, " participant %s as Nebula: %s
UDP: %s\n",
@@ -252,9 +250,9 @@ func (r *R) renderFlow() {
fmt.Fprintf(f,
" %s%s%s: %s(%s), index %v, counter: %v\n",
- strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
+ strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
line,
- strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+ strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
)
}
@@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
func (r *R) renderHostmaps(title string) {
c := maps.Values(r.controls)
sort.SliceStable(c, func(i, j int) bool {
- return c[i].GetVpnIp() > c[j].GetVpnIp()
+ return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
})
s := renderHostmaps(c...)
@@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
// Nope, lets push the sender along
case p := <-udpTx:
- outAddr := sender.GetUDPAddr()
r.Lock()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- c := r.getControl(outAddr, inAddr, p)
+ c := r.getControl(sender.GetUDPAddr(), p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
@@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
} else {
// we are a udp tx, route and continue
p := rx.Interface().(*udp.Packet)
- outAddr := cm[x].GetUDPAddr()
-
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- c := r.getControl(outAddr, inAddr, p)
+ c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
if c == nil {
r.Unlock()
panic("No control for udp tx")
@@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
panic(err)
}
- outAddr := sender.GetUDPAddr()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- receiver := r.getControl(outAddr, inAddr, p)
+ receiver := r.getControl(sender.GetUDPAddr(), p.To, p)
if receiver == nil {
r.Unlock()
- panic("Can't route for host: " + inAddr)
+ panic("Can't RouteExitFunc for host: " + p.To.String())
}
e := whatDo(p, receiver)
@@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet
// RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
// finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
// If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) {
if finish == KeepRouting {
finish = RouteAndExit
}
r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
- if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
+ if p.To == toAddr {
return finish
}
@@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
r.Lock()
p := rx.Interface().(*udp.Packet)
-
- outAddr := cm[x].GetUDPAddr()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- receiver := r.getControl(outAddr, inAddr, p)
+ receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
if receiver == nil {
r.Unlock()
- panic("Can't route for host: " + inAddr)
+ panic("Can't RouteForAllExitFunc for host: " + p.To.String())
}
e := whatDo(p, receiver)
@@ -697,12 +685,10 @@ func (r *R) FlushAll() {
p := rx.Interface().(*udp.Packet)
- outAddr := cm[x].GetUDPAddr()
- inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
- receiver := r.getControl(outAddr, inAddr, p)
+ receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
if receiver == nil {
r.Unlock()
- panic("Can't route for host: " + inAddr)
+ panic("Can't FlushAll for host: " + p.To.String())
}
r.Unlock()
}
@@ -710,28 +696,14 @@ func (r *R) FlushAll() {
// getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
// This is an internal router function, the caller must hold the lock
-func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
- if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
- p.FromIp = newAddr.IP
- p.FromPort = uint16(newAddr.Port)
+func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
+ if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
+ p.From = newAddr
}
c, ok := r.inNat[toAddr]
if ok {
- sHost, sPort, err := net.SplitHostPort(toAddr)
- if err != nil {
- panic(err)
- }
-
- port, err := strconv.Atoi(sPort)
- if err != nil {
- panic(err)
- }
-
- r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
- IP: net.ParseIP(sHost),
- Port: port,
- }
+ r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
return c
}
@@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string {
}
from := "unknown"
- if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok {
- from = c.GetUDPAddr()
+ srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
+ if c, ok := r.vpnControls[srcAddr]; ok {
+ from = c.GetUDPAddr().String()
}
udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
@@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string {
return fmt.Sprintf(
" %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n",
strings.Replace(from, ":", "-", 1),
- strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+ strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
udp.SrcPort,
udp.DstPort,
string(data.Payload()),
diff --git a/examples/config.yml b/examples/config.yml
index ff5b403..c74ffc6 100644
--- a/examples/config.yml
+++ b/examples/config.yml
@@ -167,8 +167,7 @@ punchy:
# Preferred ranges is used to define a hint about the local network ranges, which speeds up discovering the fastest
# path to a network adjacent nebula node.
-# NOTE: the previous option "local_range" only allowed definition of a single range
-# and has been deprecated for "preferred_ranges"
+# This setting is reloadable.
#preferred_ranges: ["172.16.0.0/24"]
# sshd can expose informational and administrative functions via ssh. This can expose informational and administrative
@@ -181,12 +180,15 @@ punchy:
# A file containing the ssh host private key to use
# A decent way to generate one: ssh-keygen -t ed25519 -f ssh_host_ed25519_key -N "" < /dev/null
#host_key: ./ssh_host_ed25519_key
- # A file containing a list of authorized public keys
+ # Authorized users and their public keys
#authorized_users:
#- user: steeeeve
# keys can be an array of strings or single string
#keys:
#- "ssh public key string"
+ # Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access.
+ #trusted_cas:
+ #- "ssh public key string"
# EXPERIMENTAL: relay support for networks that can't establish direct connections.
relay:
@@ -230,6 +232,7 @@ tun:
# `mtu`: will default to tun mtu if this option is not specified
# `metric`: will default to 0 if this option is not specified
# `install`: will default to true, controls whether this route is installed in the systems routing table.
+ # This setting is reloadable.
unsafe_routes:
#- route: 172.16.1.0/24
# via: 192.168.100.99
@@ -244,7 +247,10 @@ tun:
# TODO
# Configure logging level
logging:
- # panic, fatal, error, warning, info, or debug. Default is info
+ # panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
+ #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some
+ # scenarios. Debug logging is also CPU intensive and will decrease performance overall.
+ # Only enable debug logging while actively investigating an issue.
level: info
# json or text formats currently available. Default is text
format: text
diff --git a/examples/quickstart-vagrant/README.md b/examples/quickstart-vagrant/README.md
deleted file mode 100644
index 108de9e..0000000
--- a/examples/quickstart-vagrant/README.md
+++ /dev/null
@@ -1,138 +0,0 @@
-# Quickstart Guide
-
-This guide is intended to bring up a vagrant environment with 1 lighthouse and 2 generic hosts running nebula.
-
-## Creating the virtualenv for ansible
-
-Within the `quickstart/` directory, do the following
-
-```
-# make a virtual environment
-virtualenv venv
-
-# get into the virtualenv
-source venv/bin/activate
-
-# install ansible
-pip install -r requirements.yml
-```
-
-## Bringing up the vagrant environment
-
-A plugin that is used for the Vagrant environment is `vagrant-hostmanager`
-
-To install, run
-
-```
-vagrant plugin install vagrant-hostmanager
-```
-
-All hosts within the Vagrantfile are brought up with
-
-`vagrant up`
-
-Once the boxes are up, go into the `ansible/` directory and deploy the playbook by running
-
-`ansible-playbook playbook.yml -i inventory -u vagrant`
-
-## Testing within the vagrant env
-
-Once the ansible run is done, hop onto a vagrant box
-
-`vagrant ssh generic1.vagrant`
-
-or specifically
-
-`ssh vagrant@` (password for the vagrant user on the boxes is `vagrant`)
-
-See `/etc/nebula/config.yml` on a box for firewall rules.
-
-To see full handshakes and hostmaps, change the logging config of `/etc/nebula/config.yml` on the vagrant boxes from
-info to debug.
-
-You can watch nebula logs by running
-
-```
-sudo journalctl -fu nebula
-```
-
-Refer to the nebula src code directory's README for further instructions on configuring nebula.
-
-## Troubleshooting
-
-### Is nebula up and running?
-
-Run and verify that
-
-```
-ifconfig
-```
-
-shows you an interface with the name `nebula1` being up.
-
-```
-vagrant@generic1:~$ ifconfig nebula1
-nebula1: flags=4305 mtu 1300
- inet 10.168.91.210 netmask 255.128.0.0 destination 10.168.91.210
- inet6 fe80::aeaf:b105:e6dc:936c prefixlen 64 scopeid 0x20
- unspec 00-00-00-00-00-00-00-00-00-00-00-00-00-00-00-00 txqueuelen 500 (UNSPEC)
- RX packets 2 bytes 168 (168.0 B)
- RX errors 0 dropped 0 overruns 0 frame 0
- TX packets 11 bytes 600 (600.0 B)
- TX errors 0 dropped 0 overruns 0 carrier 0 collisions 0
-```
-
-### Connectivity
-
-Are you able to ping other boxes on the private nebula network?
-
-The following are the private nebula ip addresses of the vagrant env
-
-```
-generic1.vagrant [nebula_ip] 10.168.91.210
-generic2.vagrant [nebula_ip] 10.168.91.220
-lighthouse1.vagrant [nebula_ip] 10.168.91.230
-```
-
-Try pinging generic1.vagrant to and from any other box using its nebula ip above.
-
-Double check the nebula firewall rules under /etc/nebula/config.yml to make sure that connectivity is allowed for your use-case if on a specific port.
-
-```
-vagrant@lighthouse1:~$ grep -A21 firewall /etc/nebula/config.yml
-firewall:
- conntrack:
- tcp_timeout: 12m
- udp_timeout: 3m
- default_timeout: 10m
-
- inbound:
- - proto: icmp
- port: any
- host: any
- - proto: any
- port: 22
- host: any
- - proto: any
- port: 53
- host: any
-
- outbound:
- - proto: any
- port: any
- host: any
-```
diff --git a/examples/quickstart-vagrant/Vagrantfile b/examples/quickstart-vagrant/Vagrantfile
deleted file mode 100644
index ab9408f..0000000
--- a/examples/quickstart-vagrant/Vagrantfile
+++ /dev/null
@@ -1,40 +0,0 @@
-Vagrant.require_version ">= 2.2.6"
-
-nodes = [
- { :hostname => 'generic1.vagrant', :ip => '172.11.91.210', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1},
- { :hostname => 'generic2.vagrant', :ip => '172.11.91.220', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1},
- { :hostname => 'lighthouse1.vagrant', :ip => '172.11.91.230', :box => 'bento/ubuntu-18.04', :ram => '512', :cpus => 1},
-]
-
-Vagrant.configure("2") do |config|
-
- config.ssh.insert_key = false
-
- if Vagrant.has_plugin?('vagrant-cachier')
- config.cache.enable :apt
- else
- printf("** Install vagrant-cachier plugin to speedup deploy: `vagrant plugin install vagrant-cachier`.**\n")
- end
-
- if Vagrant.has_plugin?('vagrant-hostmanager')
- config.hostmanager.enabled = true
- config.hostmanager.manage_host = true
- config.hostmanager.include_offline = true
- else
- config.vagrant.plugins = "vagrant-hostmanager"
- end
-
- nodes.each do |node|
- config.vm.define node[:hostname] do |node_config|
- node_config.vm.box = node[:box]
- node_config.vm.hostname = node[:hostname]
- node_config.vm.network :private_network, ip: node[:ip]
- node_config.vm.provider :virtualbox do |vb|
- vb.memory = node[:ram]
- vb.cpus = node[:cpus]
- vb.customize ["modifyvm", :id, "--natdnshostresolver1", "on"]
- vb.customize ['guestproperty', 'set', :id, '/VirtualBox/GuestAdd/VBoxService/--timesync-set-threshold', 10000]
- end
- end
- end
-end
diff --git a/examples/quickstart-vagrant/ansible/ansible.cfg b/examples/quickstart-vagrant/ansible/ansible.cfg
deleted file mode 100644
index 518a4f1..0000000
--- a/examples/quickstart-vagrant/ansible/ansible.cfg
+++ /dev/null
@@ -1,4 +0,0 @@
-[defaults]
-host_key_checking = False
-private_key_file = ~/.vagrant.d/insecure_private_key
-become = yes
diff --git a/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py b/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py
deleted file mode 100644
index a21e82d..0000000
--- a/examples/quickstart-vagrant/ansible/filter_plugins/to_nebula_ip.py
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/python
-
-
-class FilterModule(object):
- def filters(self):
- return {
- 'to_nebula_ip': self.to_nebula_ip,
- 'map_to_nebula_ips': self.map_to_nebula_ips,
- }
-
- def to_nebula_ip(self, ip_str):
- ip_list = list(map(int, ip_str.split(".")))
- ip_list[0] = 10
- ip_list[1] = 168
- ip = '.'.join(map(str, ip_list))
- return ip
-
- def map_to_nebula_ips(self, ip_strs):
- ip_list = [ self.to_nebula_ip(ip_str) for ip_str in ip_strs ]
- ips = ', '.join(ip_list)
- return ips
diff --git a/examples/quickstart-vagrant/ansible/inventory b/examples/quickstart-vagrant/ansible/inventory
deleted file mode 100644
index 0bae407..0000000
--- a/examples/quickstart-vagrant/ansible/inventory
+++ /dev/null
@@ -1,11 +0,0 @@
-[all]
-generic1.vagrant
-generic2.vagrant
-lighthouse1.vagrant
-
-[generic]
-generic1.vagrant
-generic2.vagrant
-
-[lighthouse]
-lighthouse1.vagrant
diff --git a/examples/quickstart-vagrant/ansible/playbook.yml b/examples/quickstart-vagrant/ansible/playbook.yml
deleted file mode 100644
index c3c7d9f..0000000
--- a/examples/quickstart-vagrant/ansible/playbook.yml
+++ /dev/null
@@ -1,23 +0,0 @@
----
-- name: test connection to vagrant boxes
- hosts: all
- tasks:
- - debug: msg=ok
-
-- name: build nebula binaries locally
- connection: local
- hosts: localhost
- tasks:
- - command: chdir=../../../ make build/linux-amd64/"{{ item }}"
- with_items:
- - nebula
- - nebula-cert
- tags:
- - build-nebula
-
-- name: install nebula on all vagrant hosts
- hosts: all
- become: yes
- gather_facts: yes
- roles:
- - nebula
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml
deleted file mode 100644
index f8e7a99..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/defaults/main.yml
+++ /dev/null
@@ -1,3 +0,0 @@
----
-# defaults file for nebula
-nebula_config_directory: "/etc/nebula/"
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service b/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service
deleted file mode 100644
index fd7a067..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/files/systemd.nebula.service
+++ /dev/null
@@ -1,14 +0,0 @@
-[Unit]
-Description=Nebula overlay networking tool
-Wants=basic.target network-online.target nss-lookup.target time-sync.target
-After=basic.target network.target network-online.target
-Before=sshd.service
-
-[Service]
-SyslogIdentifier=nebula
-ExecReload=/bin/kill -HUP $MAINPID
-ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml
-Restart=always
-
-[Install]
-WantedBy=multi-user.target
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt
deleted file mode 100644
index 6148687..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt
+++ /dev/null
@@ -1,5 +0,0 @@
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSB0ZXN0IENBKNXC1NYFMNXIhO0GOiCmVYeZ9tkB4WEnawmkrca+
-hsAg9otUFhpAowZeJ33KVEABEkAORybHQUUyVFbKYzw0JHfVzAQOHA4kwB1yP9IV
-KpiTw9+ADz+wA+R5tn9B+L8+7+Apc+9dem4BQULjA5mRaoYN
------END NEBULA CERTIFICATE-----
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key b/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key
deleted file mode 100644
index 394043c..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key
+++ /dev/null
@@ -1,4 +0,0 @@
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-FEXZKMSmg8CgIODR0ymUeNT3nbnVpMi7nD79UgkCRHWmVYeZ9tkB4WEnawmkrca+
-hsAg9otUFhpAowZeJ33KVA==
------END NEBULA ED25519 PRIVATE KEY-----
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml
deleted file mode 100644
index 0e09599..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/handlers/main.yml
+++ /dev/null
@@ -1,5 +0,0 @@
----
-# handlers file for nebula
-
-- name: restart nebula
- service: name=nebula state=restarted
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml
deleted file mode 100644
index ffc89d5..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/tasks/main.yml
+++ /dev/null
@@ -1,62 +0,0 @@
----
-# tasks file for nebula
-
-- name: get the vagrant network interface and set fact
- set_fact:
- vagrant_ifce: "ansible_{{ ansible_interfaces | difference(['lo',ansible_default_ipv4.alias]) | sort | first }}"
- tags:
- - nebula-conf
-
-- name: install built nebula binary
- copy: src="../../../../../build/linux-amd64/{{ item }}" dest="/usr/local/bin" mode=0755
- with_items:
- - nebula
- - nebula-cert
-
-- name: create nebula config directory
- file: path="{{ nebula_config_directory }}" state=directory mode=0755
-
-- name: temporarily copy over root.crt and root.key to sign
- copy: src={{ item }} dest=/opt/{{ item }}
- with_items:
- - vagrant-test-ca.key
- - vagrant-test-ca.crt
-
-- name: remove previously signed host certificate
- file: dest=/etc/nebula/{{ item }} state=absent
- with_items:
- - host.crt
- - host.key
-
-- name: sign using the root key
- command: nebula-cert sign -ca-crt /opt/vagrant-test-ca.crt -ca-key /opt/vagrant-test-ca.key -duration 4320h -groups vagrant -ip {{ hostvars[inventory_hostname][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}/9 -name {{ ansible_hostname }}.nebula -out-crt /etc/nebula/host.crt -out-key /etc/nebula/host.key
-
-- name: remove root.key used to sign
- file: dest=/opt/{{ item }} state=absent
- with_items:
- - vagrant-test-ca.key
-
-- name: write the content of the trusted ca certificate
- copy: src="vagrant-test-ca.crt" dest="/etc/nebula/vagrant-test-ca.crt"
- notify: restart nebula
-
-- name: Create config directory
- file: path="{{ nebula_config_directory }}" owner=root group=root mode=0755 state=directory
-
-- name: nebula config
- template: src=config.yml.j2 dest="/etc/nebula/config.yml" mode=0644 owner=root group=root
- notify: restart nebula
- tags:
- - nebula-conf
-
-- name: nebula systemd
- copy: src=systemd.nebula.service dest="/etc/systemd/system/nebula.service" mode=0644 owner=root group=root
- register: addconf
- notify: restart nebula
-
-- name: maybe reload systemd
- shell: systemctl daemon-reload
- when: addconf.changed
-
-- name: nebula running
- service: name="nebula" state=started enabled=yes
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2 b/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2
deleted file mode 100644
index a05b1e3..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/templates/config.yml.j2
+++ /dev/null
@@ -1,85 +0,0 @@
-pki:
- ca: /etc/nebula/vagrant-test-ca.crt
- cert: /etc/nebula/host.crt
- key: /etc/nebula/host.key
-
-# Port Nebula will be listening on
-listen:
- host: 0.0.0.0
- port: 4242
-
-# sshd can expose informational and administrative functions via ssh
-sshd:
- # Toggles the feature
- enabled: true
- # Host and port to listen on
- listen: 127.0.0.1:2222
- # A file containing the ssh host private key to use
- host_key: /etc/ssh/ssh_host_ed25519_key
- # A file containing a list of authorized public keys
- authorized_users:
-{% for user in nebula_users %}
- - user: {{ user.name }}
- keys:
-{% for key in user.ssh_auth_keys %}
- - "{{ key }}"
-{% endfor %}
-{% endfor %}
-
-local_range: 10.168.0.0/16
-
-static_host_map:
-# lighthouse
- {{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}: ["{{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address']}}:4242"]
-
-default_route: "0.0.0.0"
-
-lighthouse:
-{% if 'lighthouse' in group_names %}
- am_lighthouse: true
- serve_dns: true
-{% else %}
- am_lighthouse: false
-{% endif %}
- interval: 60
-{% if 'generic' in group_names %}
- hosts:
- - {{ hostvars[groups['lighthouse'][0]][vagrant_ifce]['ipv4']['address'] | to_nebula_ip }}
-{% endif %}
-
-# Configure the private interface
-tun:
- dev: nebula1
- # Sets MTU of the tun dev.
- # MTU of the tun must be smaller than the MTU of the eth0 interface
- mtu: 1300
-
-# TODO
-# Configure logging level
-logging:
- level: info
- format: json
-
-firewall:
- conntrack:
- tcp_timeout: 12m
- udp_timeout: 3m
- default_timeout: 10m
-
- inbound:
- - proto: icmp
- port: any
- host: any
- - proto: any
- port: 22
- host: any
-{% if "lighthouse" in groups %}
- - proto: any
- port: 53
- host: any
-{% endif %}
-
- outbound:
- - proto: any
- port: any
- host: any
diff --git a/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml b/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml
deleted file mode 100644
index 7a3ae5d..0000000
--- a/examples/quickstart-vagrant/ansible/roles/nebula/vars/main.yml
+++ /dev/null
@@ -1,7 +0,0 @@
----
-# vars file for nebula
-
-nebula_users:
- - name: user1
- ssh_auth_keys:
- - "ed25519 place-your-ssh-public-key-here"
diff --git a/examples/quickstart-vagrant/requirements.yml b/examples/quickstart-vagrant/requirements.yml
deleted file mode 100644
index 90d4055..0000000
--- a/examples/quickstart-vagrant/requirements.yml
+++ /dev/null
@@ -1 +0,0 @@
-ansible
diff --git a/examples/service_scripts/nebula.open-rc b/examples/service_scripts/nebula.open-rc
new file mode 100644
index 0000000..2beca66
--- /dev/null
+++ b/examples/service_scripts/nebula.open-rc
@@ -0,0 +1,35 @@
+#!/sbin/openrc-run
+#
+# nebula service for open-rc systems
+
+extra_commands="checkconfig"
+
+: ${NEBULA_CONFDIR:=${RC_PREFIX%/}/etc/nebula}
+: ${NEBULA_CONFIG:=${NEBULA_CONFDIR}/config.yml}
+: ${NEBULA_BINARY:=${NEBULA_BINARY}${RC_PREFIX%/}/usr/local/sbin/nebula}
+
+command="${NEBULA_BINARY}"
+command_args="${NEBULA_OPTS} -config ${NEBULA_CONFIG}"
+
+supervisor="supervise-daemon"
+
+description="A scalable overlay networking tool with a focus on performance, simplicity and security"
+
+required_dirs="${NEBULA_CONFDIR}"
+required_files="${NEBULA_CONFIG}"
+
+checkconfig() {
+ "${command}" -test ${command_args} || return 1
+}
+
+start_pre() {
+ if [ "${RC_CMD}" != "restart" ] ; then
+ checkconfig || return $?
+ fi
+}
+
+stop_pre() {
+ if [ "${RC_CMD}" = "restart" ] ; then
+ checkconfig || return $?
+ fi
+}
diff --git a/firewall.go b/firewall.go
index cf2bc52..8a409d2 100644
--- a/firewall.go
+++ b/firewall.go
@@ -2,37 +2,31 @@ package nebula
import (
"crypto/sha256"
- "encoding/binary"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
- "net"
+ "net/netip"
"reflect"
"strconv"
"strings"
"sync"
"time"
+ "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
)
-const tcpACK = 0x10
-const tcpFIN = 0x01
-
type FirewallInterface interface {
- AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
+ AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
- Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
- Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
// record why the original connection passed the firewall, so we can re-validate
// after ruleset changes. Note, rulesVersion is a uint16 so that these two
@@ -58,16 +52,14 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
- localIps *cidr.Tree4[struct{}]
- assignedCIDR *net.IPNet
+ localIps *bart.Table[struct{}]
+ assignedCIDR netip.Prefix
hasSubnets bool
rules string
rulesVersion uint16
defaultLocalCIDRAny bool
- trackTCPRTT bool
- metricTCPRTT metrics.Histogram
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
@@ -116,7 +108,7 @@ type FirewallRule struct {
Any *firewallLocalCIDR
Hosts map[string]*firewallLocalCIDR
Groups []*firewallGroups
- CIDR *cidr.Tree4[*firewallLocalCIDR]
+ CIDR *bart.Table[*firewallLocalCIDR]
}
type firewallGroups struct {
@@ -130,7 +122,7 @@ type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any bool
- LocalCIDR *cidr.Tree4[struct{}]
+ LocalCIDR *bart.Table[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -152,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout
}
- localIps := cidr.NewTree4[struct{}]()
- var assignedCIDR *net.IPNet
+ localIps := new(bart.Table[struct{}])
+ var assignedCIDR netip.Prefix
+ var assignedSet bool
for _, ip := range c.Details.Ips {
- ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
- localIps.AddCIDR(ipNet, struct{}{})
+ //TODO: IPV6-WORK the unmap is a bit unfortunate
+ nip, _ := netip.AddrFromSlice(ip.IP)
+ nip = nip.Unmap()
+ nprefix := netip.PrefixFrom(nip, nip.BitLen())
+ localIps.Insert(nprefix, struct{}{})
- if assignedCIDR == nil {
+ if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
- assignedCIDR = ipNet
+ assignedCIDR = nprefix
+ assignedSet = true
}
}
for _, n := range c.Details.Subnets {
- localIps.AddCIDR(n, struct{}{})
+ nip, _ := netip.AddrFromSlice(n.IP)
+ ones, _ := n.Mask.Size()
+ nip = nip.Unmap()
+ localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
}
return &Firewall{
@@ -183,7 +183,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
hasSubnets: len(c.Details.Subnets) > 0,
l: l,
- metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
incomingMetrics: firewallMetrics{
droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
@@ -246,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
}
// AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
- if ip != nil {
+ if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
- if localIp != nil {
+ if localIp.IsValid() {
lIp = localIp.String()
}
@@ -391,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
- var cidr *net.IPNet
+ var cidr netip.Prefix
if r.Cidr != "" {
- _, cidr, err = net.ParseCIDR(r.Cidr)
+ cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
}
}
- var localCidr *net.IPNet
+ var localCidr netip.Prefix
if r.LocalCidr != "" {
- _, localCidr, err = net.ParseCIDR(r.LocalCidr)
+ localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
@@ -422,15 +421,16 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
+func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
// Check if we spoke to this tuple, if we did then allow this packet
- if f.inConns(packet, fp, incoming, h, caPool, localCache) {
+ if f.inConns(fp, h, caPool, localCache) {
return nil
}
// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
- ok, _ := remoteCidr.Contains(fp.RemoteIP)
+ //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+ _, ok := remoteCidr.Lookup(fp.RemoteIP)
if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
@@ -444,7 +444,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
}
// Make sure we are supposed to be handling this local ip address
- ok, _ := f.localIps.Contains(fp.LocalIP)
+ //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+ _, ok := f.localIps.Lookup(fp.LocalIP)
if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1)
return ErrInvalidLocalIP
@@ -462,7 +463,7 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
}
// We always want to conntrack since it is a faster operation
- f.addConn(packet, fp, incoming)
+ f.addConn(fp, incoming)
return nil
}
@@ -491,7 +492,7 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
-func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
+func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
@@ -551,11 +552,6 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *
switch fp.Protocol {
case firewall.ProtoTCP:
c.Expires = time.Now().Add(f.TCPTimeout)
- if incoming {
- f.checkTCPRTT(c, packet)
- } else {
- setTCPRTTTracking(c, packet)
- }
case firewall.ProtoUDP:
c.Expires = time.Now().Add(f.UDPTimeout)
default:
@@ -571,16 +567,13 @@ func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *
return true
}
-func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
+func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case firewall.ProtoTCP:
timeout = f.TCPTimeout
- if !incoming {
- setTCPRTTTracking(c, packet)
- }
case firewall.ProtoUDP:
timeout = f.UDPTimeout
default:
@@ -606,7 +599,6 @@ func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet) {
- //TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn?
conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
@@ -650,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false
}
-func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@@ -694,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool)
}
-func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR),
Groups: make([]*firewallGroups, 0),
- CIDR: cidr.NewTree4[*firewallLocalCIDR](),
+ CIDR: new(bart.Table[*firewallLocalCIDR]),
}
}
@@ -757,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c)
}
-func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
- LocalCIDR: cidr.NewTree4[struct{}](),
+ LocalCIDR: new(bart.Table[struct{}]),
}
}
@@ -797,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
fr.Hosts[host] = nlc
}
- if ip != nil {
- _, nlc := fr.CIDR.GetCIDR(ip)
+ if ip.IsValid() {
+ nlc, _ := fr.CIDR.Get(ip)
if nlc == nil {
nlc = flc()
}
@@ -806,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
if err != nil {
return err
}
- fr.CIDR.AddCIDR(ip, nlc)
+ fr.CIDR.Insert(ip, nlc)
}
return nil
}
-func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
- if len(groups) == 0 && host == "" && ip == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
+ if len(groups) == 0 && host == "" && !ip.IsValid() {
return true
}
@@ -827,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return true
}
- if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
+ if ip.IsValid() && ip.Bits() == 0 {
return true
}
@@ -870,22 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
}
- return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
- return flc.match(p, c)
+ matched := false
+ prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
+ fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
+ if prefix.Contains(p.RemoteIP) && val.match(p, c) {
+ matched = true
+ return false
+ }
+ return true
})
+ return matched
}
-func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
- if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
+ if !localIp.IsValid() {
if !f.hasSubnets || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
localIp = f.assignedCIDR
+ } else if localIp.Bits() == 0 {
+ flc.Any = true
}
- flc.LocalCIDR.AddCIDR(localIp, struct{}{})
+ flc.LocalCIDR.Insert(localIp, struct{}{})
return nil
}
@@ -898,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
return true
}
- ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
+ _, ok := flc.LocalCIDR.Lookup(p.LocalIP)
return ok
}
@@ -1015,42 +1016,3 @@ func parsePort(s string) (startPort, endPort int32, err error) {
return
}
-
-// TODO: write tests for these
-func setTCPRTTTracking(c *conn, p []byte) {
- if c.Seq != 0 {
- return
- }
-
- ihl := int(p[0]&0x0f) << 2
-
- // Don't track FIN packets
- if p[ihl+13]&tcpFIN != 0 {
- return
- }
-
- c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8])
- c.Sent = time.Now()
-}
-
-func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
- if c.Seq == 0 {
- return false
- }
-
- ihl := int(p[0]&0x0f) << 2
- if p[ihl+13]&tcpACK == 0 {
- return false
- }
-
- // Deal with wrap around, signed int cuts the ack window in half
- // 0 is a bad ack, no data acknowledged
- // positive number is a bad ack, ack is over half the window away
- if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 {
- return false
- }
-
- f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds())
- c.Seq = 0
- return true
-}
diff --git a/firewall/packet.go b/firewall/packet.go
index 1c4affd..8954f4c 100644
--- a/firewall/packet.go
+++ b/firewall/packet.go
@@ -3,8 +3,7 @@ package firewall
import (
"encoding/json"
"fmt"
-
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
type m map[string]interface{}
@@ -20,8 +19,8 @@ const (
)
type Packet struct {
- LocalIP iputil.VpnIp
- RemoteIP iputil.VpnIp
+ LocalIP netip.Addr
+ RemoteIP netip.Addr
LocalPort uint16
RemotePort uint16
Protocol uint8
diff --git a/firewall_test.go b/firewall_test.go
index 7d65cb5..4d47e78 100644
--- a/firewall_test.go
+++ b/firewall_test.go
@@ -2,18 +2,16 @@ package nebula
import (
"bytes"
- "encoding/binary"
"errors"
"math"
"net"
+ "net/netip"
"testing"
"time"
- "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
@@ -67,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) {
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
- _, ti, _ := net.ParseCIDR("1.2.3.4/32")
+ ti, err := netip.ParsePrefix("1.2.3.4/32")
+ assert.NoError(t, err)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
- ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
+ _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
- ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
+ _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
- assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
+ anyIp, err := netip.ParsePrefix("0.0.0.0/0")
+ assert.NoError(t, err)
+
+ assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
- assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
- assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
+ assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+ assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
@@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@@ -154,53 +155,53 @@ func TestFirewall_Drop(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr("1.2.3.4"),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
- assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
+ assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
- assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
+ assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
- assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
+ assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteIP
- p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
- assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
+ p.RemoteIP = netip.MustParseAddr("1.2.3.10")
+ assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
- assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+ assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
- assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+ assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
- assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+ assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
- assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+ assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@@ -209,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
TCP: firewallPort{},
}
- _, n, _ := net.ParseCIDR("172.1.1.1/32")
- goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
- _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
- _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
+ pfix := netip.MustParsePrefix("172.1.1.1/32")
+ _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
+ _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@@ -233,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
c := &cert.NebulaCertificate{}
- ip, _, _ := net.ParseCIDR("9.254.254.254/32")
- lip := iputil.Ip2VpnIp(ip)
+ ip := netip.MustParsePrefix("9.254.254.254/32")
for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
+ assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
}
})
@@ -264,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
- assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+ assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
}
})
@@ -288,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
},
}
for n := 0; n < b.N; n++ {
- assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+ assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
}
})
@@ -365,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@@ -389,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h.CreateRemoteCIDR(&c)
@@ -408,14 +407,14 @@ func TestFirewall_Drop2(t *testing.T) {
h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
- assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
+ assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
- assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
+ assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@@ -424,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 1,
RemotePort: 1,
Protocol: firewall.ProtoUDP,
@@ -455,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c1,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h1.CreateRemoteCIDR(&c1)
@@ -470,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c2,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h2.CreateRemoteCIDR(&c2)
@@ -485,23 +484,23 @@ func TestFirewall_Drop3(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c3,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
- assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
+ assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
- assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
+ assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
- assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
+ assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -510,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
l.SetOutput(ob)
p := firewall.Packet{
- LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
- RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+ LocalIP: netip.MustParseAddr("1.2.3.4"),
+ RemoteIP: netip.MustParseAddr("1.2.3.4"),
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
@@ -536,39 +535,39 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
ConnectionState: &ConnectionState{
peerCert: &c,
},
- vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+ vpnIp: netip.MustParseAddr(ipNet.IP.String()),
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool()
// Drop outbound
- assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
+ assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
- assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
+ assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack
- assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
+ assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
- assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
+ assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
- assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
+ assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
- assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
+ assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}
func BenchmarkLookup(b *testing.B) {
@@ -727,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
- assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
+ assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, c, conf)
- assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
+ assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
@@ -749,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
- assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr
- cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
+ cidr := netip.MustParsePrefix("10.0.0.0/8")
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
- assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
+ assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
@@ -830,97 +829,6 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}
-func TestTCPRTTTracking(t *testing.T) {
- b := make([]byte, 200)
-
- // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
- b[0] = 15
- b[60+12] = 15 << 4
- f := Firewall{
- metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
- }
-
- // Set SEQ to 1
- binary.BigEndian.PutUint32(b[60+4:60+8], 1)
-
- c := &conn{}
- setTCPRTTTracking(c, b)
- assert.Equal(t, uint32(1), c.Seq)
-
- // Bad ack - no ack flag
- binary.BigEndian.PutUint32(b[60+8:60+12], 80)
- assert.False(t, f.checkTCPRTT(c, b))
-
- // Bad ack, number is too low
- binary.BigEndian.PutUint32(b[60+8:60+12], 0)
- b[60+13] = uint8(0x10)
- assert.False(t, f.checkTCPRTT(c, b))
-
- // Good ack
- binary.BigEndian.PutUint32(b[60+8:60+12], 80)
- assert.True(t, f.checkTCPRTT(c, b))
- assert.Equal(t, uint32(0), c.Seq)
-
- // Set SEQ to 1
- binary.BigEndian.PutUint32(b[60+4:60+8], 1)
- c = &conn{}
- setTCPRTTTracking(c, b)
- assert.Equal(t, uint32(1), c.Seq)
-
- // Good acks
- binary.BigEndian.PutUint32(b[60+8:60+12], 81)
- assert.True(t, f.checkTCPRTT(c, b))
- assert.Equal(t, uint32(0), c.Seq)
-
- // Set SEQ to max uint32 - 20
- binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
- c = &conn{}
- setTCPRTTTracking(c, b)
- assert.Equal(t, ^uint32(0)-20, c.Seq)
-
- // Good acks
- binary.BigEndian.PutUint32(b[60+8:60+12], 81)
- assert.True(t, f.checkTCPRTT(c, b))
- assert.Equal(t, uint32(0), c.Seq)
-
- // Set SEQ to max uint32 / 2
- binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
- c = &conn{}
- setTCPRTTTracking(c, b)
- assert.Equal(t, ^uint32(0)/2, c.Seq)
-
- // Below
- binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
- assert.False(t, f.checkTCPRTT(c, b))
- assert.Equal(t, ^uint32(0)/2, c.Seq)
-
- // Halfway below
- binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
- assert.False(t, f.checkTCPRTT(c, b))
- assert.Equal(t, ^uint32(0)/2, c.Seq)
-
- // Halfway above is ok
- binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
- assert.True(t, f.checkTCPRTT(c, b))
- assert.Equal(t, uint32(0), c.Seq)
-
- // Set SEQ to max uint32
- binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
- c = &conn{}
- setTCPRTTTracking(c, b)
- assert.Equal(t, ^uint32(0), c.Seq)
-
- // Halfway + 1 above
- binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
- assert.False(t, f.checkTCPRTT(c, b))
- assert.Equal(t, ^uint32(0), c.Seq)
-
- // Halfway above
- binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
- assert.True(t, f.checkTCPRTT(c, b))
- assert.Equal(t, uint32(0), c.Seq)
-}
-
func TestFirewall_convertRule(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
@@ -964,8 +872,8 @@ type addRuleCall struct {
endPort int32
groups []string
host string
- ip *net.IPNet
- localIp *net.IPNet
+ ip netip.Prefix
+ localIp netip.Prefix
caName string
caSha string
}
@@ -975,7 +883,7 @@ type mockFirewall struct {
nextCallReturn error
}
-func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
mf.lastCall = addRuleCall{
incoming: incoming,
proto: proto,
diff --git a/go.mod b/go.mod
index abc1134..1da2056 100644
--- a/go.mod
+++ b/go.mod
@@ -1,52 +1,55 @@
module github.com/slackhq/nebula
-go 1.20
+go 1.22.0
+
+toolchain go1.22.2
require (
dario.cat/mergo v1.0.0
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
github.com/armon/go-radix v1.0.0
github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
- github.com/flynn/noise v1.0.1
+ github.com/flynn/noise v1.1.0
github.com/gogo/protobuf v1.3.2
github.com/google/gopacket v1.1.19
github.com/kardianos/service v1.2.2
- github.com/miekg/dns v1.1.58
+ github.com/miekg/dns v1.1.61
github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
- github.com/prometheus/client_golang v1.18.0
+ github.com/prometheus/client_golang v1.19.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
- github.com/stretchr/testify v1.8.4
+ github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.2.1-beta.2
- golang.org/x/crypto v0.18.0
- golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
- golang.org/x/net v0.20.0
- golang.org/x/sync v0.6.0
- golang.org/x/sys v0.16.0
- golang.org/x/term v0.16.0
+ golang.org/x/crypto v0.24.0
+ golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
+ golang.org/x/net v0.26.0
+ golang.org/x/sync v0.7.0
+ golang.org/x/sys v0.21.0
+ golang.org/x/term v0.21.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
golang.zx2c4.com/wireguard/windows v0.5.3
- google.golang.org/protobuf v1.32.0
+ google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v2 v2.4.0
- gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f
+ gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
)
require (
github.com/beorn7/perks v1.0.1 // indirect
+ github.com/bits-and-blooms/bitset v1.13.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
- github.com/google/btree v1.0.1 // indirect
- github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect
+ github.com/gaissmai/bart v0.11.1 // indirect
+ github.com/google/btree v1.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
- github.com/prometheus/common v0.45.0 // indirect
+ github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
- golang.org/x/mod v0.14.0 // indirect
- golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
- golang.org/x/tools v0.17.0 // indirect
+ golang.org/x/mod v0.18.0 // indirect
+ golang.org/x/time v0.5.0 // indirect
+ golang.org/x/tools v0.22.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 6226e15..6db0c4a 100644
--- a/go.sum
+++ b/go.sum
@@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
+github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -22,8 +24,12 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/flynn/noise v1.0.1 h1:vPp/jdQLXC6ppsXSj/pM3W1BIJ5FEHE2TulSJBpb43Y=
-github.com/flynn/noise v1.0.1/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
+github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI=
+github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
+github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
+github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -44,14 +50,15 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
-github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
-github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
+github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
+github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -71,14 +78,13 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg=
-github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k=
-github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
-github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
+github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
+github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
@@ -96,8 +102,8 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk=
-github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA=
+github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
+github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@@ -106,8 +112,8 @@ github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk
github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM=
-github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY=
+github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
+github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
@@ -117,6 +123,7 @@ github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3c
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -132,8 +139,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
@@ -146,16 +153,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
-golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
-golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
-golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
+golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
+golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
+golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
+golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
-golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
+golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -166,8 +173,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
-golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
+golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
+golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -175,8 +182,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
-golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
+golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -194,23 +201,23 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
-golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
+golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.16.0 h1:m+B6fahuftsE9qjo0VWp2FW0mB3MTJvR0BaMQrq0pmE=
-golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
+golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
+golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
-golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
+golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc=
-golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps=
+golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
+golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -229,8 +236,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
-google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
+google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
+google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -246,5 +253,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f h1:8GE2MRjGiFmfpon8dekPI08jEuNMQzSffVHgdupcO4E=
-gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q=
+gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g=
+gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU=
diff --git a/handshake_ix.go b/handshake_ix.go
index 1905c00..8cf5341 100644
--- a/handshake_ix.go
+++ b/handshake_ix.go
@@ -1,13 +1,12 @@
package nebula
import (
+ "net/netip"
"time"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
- "github.com/slackhq/nebula/udp"
)
// NOISE IX Handshakes
@@ -46,7 +45,6 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
}
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
- ci.messageCounter.Add(1)
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
@@ -64,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true
}
-func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
@@ -90,17 +88,36 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
- f.l.WithError(err).WithField("udpAddr", addr).
- WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
- Info("Invalid certificate from host")
+ e := f.l.WithError(err).WithField("udpAddr", addr).
+ WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
+
+ if f.l.Level > logrus.DebugLevel {
+ e = e.WithField("cert", remoteCert)
+ }
+
+ e.Info("Invalid certificate from host")
+ return
+ }
+
+ vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+ if !ok {
+ e := f.l.WithError(err).WithField("udpAddr", addr).
+ WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
+
+ if f.l.Level > logrus.DebugLevel {
+ e = e.WithField("cert", remoteCert)
+ }
+
+ e.Info("Invalid vpn ip from host")
return
}
- vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+
+ vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
- if vpnIp == f.myVpnIp {
+ if vpnIp == f.myVpnNet.Addr() {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
@@ -109,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
return
}
- if addr != nil {
- if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) {
+ if addr.IsValid() {
+ if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
@@ -134,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
@@ -214,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
- if addr != nil {
+ if addr.IsValid() {
err := f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
@@ -280,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
// Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
- if addr != nil {
+ if addr.IsValid() {
err = f.outside.WriteTo(msg, addr)
if err != nil {
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
@@ -316,13 +333,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
}
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
- hostinfo.ConnectionState.messageCounter.Store(2)
+
hostinfo.remotes.ResetBlockedRemotes()
return
}
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
@@ -332,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
defer hh.Unlock()
hostinfo := hh.hostinfo
- if addr != nil {
- if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
+ if addr.IsValid() {
+ if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
@@ -372,15 +389,33 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
if err != nil {
- f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
- WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
- Error("Invalid certificate from host")
+ e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+ WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
+
+ if f.l.Level > logrus.DebugLevel {
+ e = e.WithField("cert", remoteCert)
+ }
+
+ e.Error("Invalid certificate from host")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
- vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+ vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+ if !ok {
+ e := f.l.WithError(err).WithField("udpAddr", addr).
+ WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
+
+ if f.l.Level > logrus.DebugLevel {
+ e = e.WithField("cert", remoteCert)
+ }
+
+ e.Info("Invalid vpn ip from host")
+ return true
+ }
+
+ vpnIp = vpnIp.Unmap()
certName := remoteCert.Details.Name
fingerprint, _ := remoteCert.Sha256Sum()
issuer := remoteCert.Details.Issuer
@@ -406,7 +441,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
- WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
+ WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")
// Swap the packet store to benefit the original intended recipient
@@ -444,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding
- if addr != nil {
+ if addr.IsValid() {
hostinfo.SetRemote(addr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
@@ -457,8 +492,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
- hostinfo.ConnectionState.messageCounter.Store(2)
-
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
}
diff --git a/handshake_manager.go b/handshake_manager.go
index b568cc8..7960435 100644
--- a/handshake_manager.go
+++ b/handshake_manager.go
@@ -6,15 +6,15 @@ import (
"crypto/rand"
"encoding/binary"
"errors"
- "net"
+ "net/netip"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
+ "golang.org/x/exp/slices"
)
const (
@@ -46,14 +46,14 @@ type HandshakeManager struct {
// Mutex for interacting with the vpnIps and indexes maps
sync.RWMutex
- vpnIps map[iputil.VpnIp]*HandshakeHostInfo
+ vpnIps map[netip.Addr]*HandshakeHostInfo
indexes map[uint32]*HandshakeHostInfo
mainHostMap *HostMap
lightHouse *LightHouse
outside udp.Conn
config HandshakeConfig
- OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
+ OutboundHandshakeTimer *LockingTimerWheel[netip.Addr]
messageMetrics *MessageMetrics
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
@@ -61,17 +61,17 @@ type HandshakeManager struct {
l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIp
- trigger chan iputil.VpnIp
+ trigger chan netip.Addr
}
type HandshakeHostInfo struct {
sync.Mutex
- startTime time.Time // Time that we first started trying with this handshake
- ready bool // Is the handshake ready
- counter int // How many attempts have we made so far
- lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt
- packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+ startTime time.Time // Time that we first started trying with this handshake
+ ready bool // Is the handshake ready
+ counter int // How many attempts have we made so far
+ lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
+ packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo
}
@@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
- vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{},
+ vpnIps: map[netip.Addr]*HandshakeHostInfo{},
indexes: map[uint32]*HandshakeHostInfo{},
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
config: config,
- trigger: make(chan iputil.VpnIp, config.triggerBuffer),
- OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+ trigger: make(chan netip.Addr, config.triggerBuffer),
+ OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) {
}
}
-func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp
- if addr != nil {
- if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+ if addr.IsValid() {
+ if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
@@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
}
}
-func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) {
hh := hm.queryVpnIp(vpnIp)
if hh == nil {
return
@@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo := hh.hostinfo
// If we are out of time, clean up
if hh.counter >= hm.config.retries {
- hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
+ hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
WithField("initiatorIndex", hh.hostinfo.localIndexId).
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@@ -211,8 +211,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
}
- remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
- remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
+ remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
+ remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes)
// We only care about a lighthouse trigger if we have new remotes to send to.
// This is a very specific optimization for a fast lighthouse reply.
@@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
}
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
- var sentTo []*udp.Addr
- hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+ var sentTo []netip.AddrPort
+ hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
@@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// Send a RelayRequest to all known Relay IP's
for _, relay := range hostinfo.remotes.relays {
// Don't relay to myself, and don't relay through the host I'm trying to connect to
- if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
+ if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
continue
}
- relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
- if relayHostInfo == nil || relayHostInfo.remote == nil {
+ relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+ if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
- hm.f.Handshake(*relay)
+ hm.f.Handshake(relay)
continue
}
// Check the relay HostInfo to see if we already established a relay through it
@@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
case Requested:
hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+
+ //TODO: IPV6-WORK
+ myVpnIpB := hm.f.myVpnNet.Addr().As4()
+ theirVpnIpB := vpnIp.As4()
+
// Re-send the CreateRelay request, in case the previous one was lost.
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: existingRelay.LocalIndex,
- RelayFromIp: uint32(hm.lightHouse.myVpnIp),
- RelayToIp: uint32(vpnIp),
+ RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
+ RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
msg, err := m.Marshal()
if err != nil {
@@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// This must send over the hostinfo, not over hm.Hosts[ip]
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
- "relayFrom": hm.lightHouse.myVpnIp,
+ "relayFrom": hm.f.myVpnNet.Addr(),
"relayTo": vpnIp,
"initiatorRelayIndex": existingRelay.LocalIndex,
- "relay": *relay}).
+ "relay": relay}).
Info("send CreateRelayRequest")
}
default:
@@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
}
} else {
// No relays exist or requested yet.
- if relayHostInfo.remote != nil {
+ if relayHostInfo.remote.IsValid() {
idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
if err != nil {
hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
}
+ //TODO: IPV6-WORK
+ myVpnIpB := hm.f.myVpnNet.Addr().As4()
+ theirVpnIpB := vpnIp.As4()
+
m := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: idx,
- RelayFromIp: uint32(hm.lightHouse.myVpnIp),
- RelayToIp: uint32(vpnIp),
+ RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]),
+ RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]),
}
msg, err := m.Marshal()
if err != nil {
@@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
} else {
hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
hm.l.WithFields(logrus.Fields{
- "relayFrom": hm.lightHouse.myVpnIp,
+ "relayFrom": hm.f.myVpnNet.Addr(),
"relayTo": vpnIp,
"initiatorRelayIndex": idx,
- "relay": *relay}).
+ "relay": relay}).
Info("send CreateRelayRequest")
}
}
@@ -355,32 +364,32 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
// The 2nd argument will be true if the hostinfo is ready to transmit traffic
-func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
- // Check the main hostmap and maintain a read lock if our host is not there
+func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
hm.mainHostMap.RLock()
- if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok {
- hm.mainHostMap.RUnlock()
+ h, ok := hm.mainHostMap.Hosts[vpnIp]
+ hm.mainHostMap.RUnlock()
+
+ if ok {
// Do not attempt promotion if you are a lighthouse
if !hm.lightHouse.amLighthouse {
- h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
+ h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
}
return h, true
}
- defer hm.mainHostMap.RUnlock()
return hm.StartHandshake(vpnIp, cacheCb), false
}
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock()
- defer hm.Unlock()
if hh, ok := hm.vpnIps[vpnIp]; ok {
// We are already trying to handshake with this vpn ip
if cacheCb != nil {
cacheCb(hh)
}
+ hm.Unlock()
return hh.hostinfo
}
@@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
vpnIp: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
relayState: RelayState{
- relays: map[iputil.VpnIp]struct{}{},
- relayForByIp: map[iputil.VpnIp]*Relay{},
+ relays: map[netip.Addr]struct{}{},
+ relayForByIp: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
@@ -421,6 +430,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
}
}
+ hm.Unlock()
hm.lightHouse.QueryServer(vpnIp)
return hostinfo
}
@@ -554,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
delete(c.vpnIps, hostinfo.vpnIp)
if len(c.vpnIps) == 0 {
- c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
+ c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
}
delete(c.indexes, hostinfo.localIndexId)
@@ -569,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
}
}
-func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
hh := hm.queryVpnIp(vpnIp)
if hh != nil {
return hh.hostinfo
@@ -578,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
}
-func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo {
hm.RLock()
defer hm.RUnlock()
return hm.vpnIps[vpnIp]
@@ -598,8 +608,8 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
return hm.indexes[index]
}
-func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
- return c.mainHostMap.preferredRanges
+func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
+ return c.mainHostMap.GetPreferredRanges()
}
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
diff --git a/handshake_manager_test.go b/handshake_manager_test.go
index 303aa50..a78b45f 100644
--- a/handshake_manager_test.go
+++ b/handshake_manager_test.go
@@ -1,13 +1,12 @@
package nebula
import (
- "net"
+ "net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
@@ -15,11 +14,14 @@ import (
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
l := test.NewLogger()
- _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
- _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
- ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
- preferredRanges := []*net.IPNet{localrange}
- mainHM := NewHostMap(l, vpncidr, preferredRanges)
+ vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+ localrange := netip.MustParsePrefix("10.1.1.1/24")
+ ip := netip.MustParseAddr("172.1.1.2")
+
+ preferredRanges := []netip.Prefix{localrange}
+ mainHM := newHostMap(l, vpncidr)
+ mainHM.preferredRanges.Store(&preferredRanges)
+
lh := newTestLighthouse()
cs := &CertState{
@@ -64,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.vpnIps, ip)
}
-func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
+func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
for _, i := range tw.t.wheel {
n := i.Head
for n != nil {
@@ -78,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
type mockEncWriter struct {
}
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
return
}
@@ -90,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
return
}
-func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}
+func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
diff --git a/hostmap.go b/hostmap.go
index a5adeb9..fb97b76 100644
--- a/hostmap.go
+++ b/hostmap.go
@@ -3,17 +3,17 @@ package nebula
import (
"errors"
"net"
+ "net/netip"
"sync"
"sync/atomic"
"time"
+ "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
- "github.com/slackhq/nebula/cidr"
+ "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
- "github.com/slackhq/nebula/udp"
)
// const ProbeLen = 100
@@ -48,7 +48,7 @@ type Relay struct {
State int
LocalIndex uint32
RemoteIndex uint32
- PeerIp iputil.VpnIp
+ PeerIp netip.Addr
}
type HostMap struct {
@@ -56,10 +56,9 @@ type HostMap struct {
Indexes map[uint32]*HostInfo
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
- Hosts map[iputil.VpnIp]*HostInfo
- preferredRanges []*net.IPNet
- vpnCIDR *net.IPNet
- metricsEnabled bool
+ Hosts map[netip.Addr]*HostInfo
+ preferredRanges atomic.Pointer[[]netip.Prefix]
+ vpnCIDR netip.Prefix
l *logrus.Logger
}
@@ -69,12 +68,12 @@ type HostMap struct {
type RelayState struct {
sync.RWMutex
- relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
- relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
- relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
+ relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
+ relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
+ relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info
}
-func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
+func (rs *RelayState) DeleteRelay(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
delete(rs.relays, ip)
@@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
return ret
}
-func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[ip]
return r, ok
}
-func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) {
+func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
rs.Lock()
defer rs.Unlock()
rs.relays[ip] = struct{}{}
}
-func (rs *RelayState) CopyRelayIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayIps() []netip.Addr {
rs.RLock()
defer rs.RUnlock()
- ret := make([]iputil.VpnIp, 0, len(rs.relays))
+ ret := make([]netip.Addr, 0, len(rs.relays))
for ip := range rs.relays {
ret = append(ret, ip)
}
return ret
}
-func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayForIps() []netip.Addr {
rs.RLock()
defer rs.RUnlock()
- currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp))
+ currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
for relayIp := range rs.relayForByIp {
currentRelays = append(currentRelays, relayIp)
}
@@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
return ret
}
-func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) {
- rs.Lock()
- defer rs.Unlock()
- r, ok := rs.relayForByIdx[localIdx]
- if !ok {
- return iputil.VpnIp(0), false
- }
- delete(rs.relayForByIdx, localIdx)
- delete(rs.relayForByIp, r.PeerIp)
- return r.PeerIp, true
-}
-
-func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
+func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
rs.Lock()
defer rs.Unlock()
r, ok := rs.relayForByIp[vpnIp]
@@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
return &newRelay, true
}
-func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
rs.RLock()
defer rs.RUnlock()
r, ok := rs.relayForByIp[vpnIp]
@@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
return r, ok
}
-func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
+func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
rs.Lock()
defer rs.Unlock()
rs.relayForByIp[ip] = r
@@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
}
type HostInfo struct {
- remote *udp.Addr
+ remote netip.AddrPort
remotes *RemoteList
promoteCounter atomic.Uint32
ConnectionState *ConnectionState
remoteIndexId uint32
localIndexId uint32
- vpnIp iputil.VpnIp
+ vpnIp netip.Addr
recvError atomic.Uint32
- remoteCidr *cidr.Tree4[struct{}]
+ remoteCidr *bart.Table[struct{}]
relayState RelayState
// HandshakePacket records the packets used to create this hostinfo
@@ -227,7 +214,7 @@ type HostInfo struct {
lastHandshakeTime uint64
lastRoam time.Time
- lastRoamRemote *udp.Addr
+ lastRoamRemote netip.AddrPort
// Used to track other hostinfos for this vpn ip since only 1 can be primary
// Synchronised via hostmap lock and not the hostinfo lock.
@@ -254,21 +241,53 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}
-func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
- h := map[iputil.VpnIp]*HostInfo{}
- i := map[uint32]*HostInfo{}
- r := map[uint32]*HostInfo{}
- relays := map[uint32]*HostInfo{}
- m := HostMap{
- Indexes: i,
- Relays: relays,
- RemoteIndexes: r,
- Hosts: h,
- preferredRanges: preferredRanges,
- vpnCIDR: vpnCIDR,
- l: l,
+func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
+ hm := newHostMap(l, vpnCIDR)
+
+ hm.reload(c, true)
+ c.RegisterReloadCallback(func(c *config.C) {
+ hm.reload(c, false)
+ })
+
+ l.WithField("network", hm.vpnCIDR.String()).
+ WithField("preferredRanges", hm.GetPreferredRanges()).
+ Info("Main HostMap created")
+
+ return hm
+}
+
+func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
+ return &HostMap{
+ Indexes: map[uint32]*HostInfo{},
+ Relays: map[uint32]*HostInfo{},
+ RemoteIndexes: map[uint32]*HostInfo{},
+ Hosts: map[netip.Addr]*HostInfo{},
+ vpnCIDR: vpnCIDR,
+ l: l,
+ }
+}
+
+func (hm *HostMap) reload(c *config.C, initial bool) {
+ if initial || c.HasChanged("preferred_ranges") {
+ var preferredRanges []netip.Prefix
+ rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
+
+ for _, rawPreferredRange := range rawPreferredRanges {
+ preferredRange, err := netip.ParsePrefix(rawPreferredRange)
+
+ if err != nil {
+ hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
+ continue
+ }
+
+ preferredRanges = append(preferredRanges, preferredRange)
+ }
+
+ oldRanges := hm.preferredRanges.Swap(&preferredRanges)
+ if !initial {
+ hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
+ }
}
- return &m
}
// EmitStats reports host, index, and relay counts to the stats collection system
@@ -346,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnIp)
if len(hm.Hosts) == 0 {
- hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+ hm.Hosts = map[netip.Addr]*HostInfo{}
}
if hostinfo.next != nil {
@@ -429,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
}
}
-func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
return hm.queryVpnIp(vpnIp, nil)
}
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
hm.RLock()
defer hm.RUnlock()
@@ -451,13 +470,13 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
return nil, nil, errors.New("unable to find host with relay")
}
-func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
hm.RUnlock()
// Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
- h.TryPromoteBest(hm.preferredRanges, promoteIfce)
+ h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce)
}
return h
@@ -503,8 +522,9 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}
}
-func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
- return hm.preferredRanges
+func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
+ //NOTE: if a Load() ever happens before preferredRanges has been stored at least once, this will panic dereferencing a nil pointer
+ return *hm.preferredRanges.Load()
}
func (hm *HostMap) ForEachVpnIp(f controlEach) {
@@ -527,14 +547,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
// TryPromoteBest handles re-querying lighthouses and probing for better paths
// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
-func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
+func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) {
c := i.promoteCounter.Add(1)
if c%ifce.tryPromoteEvery.Load() == 0 {
remote := i.remote
// return early if we are already on a preferred remote
- if remote != nil {
- rIP := remote.IP
+ if remote.IsValid() {
+ rIP := remote.Addr()
for _, l := range preferredRanges {
if l.Contains(rIP) {
return
@@ -542,8 +562,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}
}
- i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
- if remote != nil && (addr == nil || !preferred) {
+ i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
+ if remote.IsValid() && (!addr.IsValid() || !preferred) {
return
}
@@ -572,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
return nil
}
-func (i *HostInfo) SetRemote(remote *udp.Addr) {
+func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object
- if !i.remote.Equals(remote) {
- i.remote = remote.Copy()
- i.remotes.LearnRemote(i.vpnIp, remote.Copy())
+ if i.remote != remote {
+ i.remote = remote
+ i.remotes.LearnRemote(i.vpnIp, remote)
}
}
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
- if newRemote == nil {
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
+ if !newRemote.IsValid() {
// relays have nil udp Addrs
return false
}
currentRemote := i.remote
- if currentRemote == nil {
+ if !currentRemote.IsValid() {
i.SetRemote(newRemote)
return true
}
@@ -596,13 +616,13 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
// NOTE: We do this loop here instead of calling `isPreferred` in
// remote_list.go so that we only have to loop over preferredRanges once.
newIsPreferred := false
- for _, l := range hm.preferredRanges {
+ for _, l := range hm.GetPreferredRanges() {
// return early if we are already on a preferred remote
- if l.Contains(currentRemote.IP) {
+ if l.Contains(currentRemote.Addr()) {
return false
}
- if l.Contains(newRemote.IP) {
+ if l.Contains(newRemote.Addr()) {
newIsPreferred = true
}
}
@@ -610,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
if newIsPreferred {
// Consider this a roaming event
i.lastRoam = time.Now()
- i.lastRoamRemote = currentRemote.Copy()
+ i.lastRoamRemote = currentRemote
i.SetRemote(newRemote)
@@ -633,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
return
}
- remoteCidr := cidr.NewTree4[struct{}]()
+ remoteCidr := new(bart.Table[struct{}])
for _, ip := range c.Details.Ips {
- remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
+ //TODO: IPV6-WORK what to do when ip is invalid?
+ nip, _ := netip.AddrFromSlice(ip.IP)
+ nip = nip.Unmap()
+ bits, _ := ip.Mask.Size()
+ remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
}
for _, n := range c.Details.Subnets {
- remoteCidr.AddCIDR(n, struct{}{})
+ //TODO: IPV6-WORK what to do when ip is invalid?
+ nip, _ := netip.AddrFromSlice(n.IP)
+ nip = nip.Unmap()
+ bits, _ := n.Mask.Size()
+ remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
}
i.remoteCidr = remoteCidr
}
@@ -664,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
// Utility functions
-func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
+func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
//FIXME: This function is pretty garbage
- var ips []net.IP
+ var ips []netip.Addr
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
@@ -688,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
ip = v.IP
}
+ nip, ok := netip.AddrFromSlice(ip)
+ if !ok {
+ if l.Level >= logrus.DebugLevel {
+ l.WithField("localIp", ip).Debug("ip was invalid for netip")
+ }
+ continue
+ }
+ nip = nip.Unmap()
+
//TODO: Filtering out link local for now, this is probably the most correct thing
//TODO: Would be nice to filter out SLAAC MAC based ips as well
- if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() {
- allow := allowList.Allow(ip)
+ if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
+ allow := allowList.Allow(nip)
if l.Level >= logrus.TraceLevel {
- l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow")
+ l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
}
if !allow {
continue
}
- ips = append(ips, ip)
+ ips = append(ips, nip)
}
}
}
- return &ips
+ return ips
}
diff --git a/hostmap_test.go b/hostmap_test.go
index c1c0dce..7e2feb8 100644
--- a/hostmap_test.go
+++ b/hostmap_test.go
@@ -1,30 +1,27 @@
package nebula
import (
- "net"
+ "net/netip"
"testing"
+ "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
- hm := NewHostMap(
+ hm := newHostMap(
l,
- &net.IPNet{
- IP: net.IP{10, 0, 0, 1},
- Mask: net.IPMask{255, 255, 255, 0},
- },
- []*net.IPNet{},
+ netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{}
- h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
- h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
- h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
- h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+ h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+ h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+ h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+ h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
hm.unlockedAddHostInfo(h4, f)
hm.unlockedAddHostInfo(h3, f)
@@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.unlockedAddHostInfo(h1, f)
// Make sure we go h1 -> h2 -> h3 -> h4
- prim := hm.QueryVpnIp(1)
+ prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h3)
// Make sure we go h3 -> h1 -> h2 -> h4
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h3.localIndexId, prim.localIndexId)
assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
hm.MakePrimary(h4)
// Make sure we go h4 -> h3 -> h1 -> h2
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -91,23 +88,19 @@ func TestHostMap_MakePrimary(t *testing.T) {
func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
- hm := NewHostMap(
+ hm := newHostMap(
l,
- &net.IPNet{
- IP: net.IP{10, 0, 0, 1},
- Mask: net.IPMask{255, 255, 255, 0},
- },
- []*net.IPNet{},
+ netip.MustParsePrefix("10.0.0.1/24"),
)
f := &Interface{}
- h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
- h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
- h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
- h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
- h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
- h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
+ h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+ h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+ h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+ h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+ h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
+ h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
hm.unlockedAddHostInfo(h6, f)
hm.unlockedAddHostInfo(h5, f)
@@ -123,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h)
// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
- prim := hm.QueryVpnIp(1)
+ prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h1.localIndexId, prim.localIndexId)
assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -142,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h1.next)
// Make sure we go h2 -> h3 -> h4 -> h5
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -160,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h3.next)
// Make sure we go h2 -> h4 -> h5
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -176,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h5.next)
// Make sure we go h2 -> h4
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h2.localIndexId, prim.localIndexId)
assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
assert.Nil(t, prim.prev)
@@ -190,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h2.next)
// Make sure we only have h4
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Equal(t, h4.localIndexId, prim.localIndexId)
assert.Nil(t, prim.prev)
assert.Nil(t, prim.next)
@@ -202,6 +195,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
assert.Nil(t, h4.next)
// Make sure we have nil
- prim = hm.QueryVpnIp(1)
+ prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
assert.Nil(t, prim)
}
+
+func TestHostMap_reload(t *testing.T) {
+ l := test.NewLogger()
+ c := config.NewC(l)
+
+ hm := NewHostMapFromConfig(
+ l,
+ netip.MustParsePrefix("10.0.0.1/24"),
+ c,
+ )
+
+ toS := func(ipn []netip.Prefix) []string {
+ var s []string
+ for _, n := range ipn {
+ s = append(s, n.String())
+ }
+ return s
+ }
+
+ assert.Empty(t, hm.GetPreferredRanges())
+
+ c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
+ assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
+
+ c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
+ assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
+}
diff --git a/hostmap_tester.go b/hostmap_tester.go
index 0d5d41b..b2d1d1b 100644
--- a/hostmap_tester.go
+++ b/hostmap_tester.go
@@ -5,9 +5,11 @@ package nebula
// This file contains functions used to export information to the e2e testing framework
-import "github.com/slackhq/nebula/iputil"
+import (
+ "net/netip"
+)
-func (i *HostInfo) GetVpnIp() iputil.VpnIp {
+func (i *HostInfo) GetVpnIp() netip.Addr {
return i.vpnIp
}
diff --git a/inside.go b/inside.go
index 6230962..0ccd179 100644
--- a/inside.go
+++ b/inside.go
@@ -1,12 +1,13 @@
package nebula
import (
+ "net/netip"
+
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
- "github.com/slackhq/nebula/udp"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
// Ignore local broadcast packets
- if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast {
+ if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
return
}
- if fwPacket.RemoteIP == f.myVpnIp {
+ if fwPacket.RemoteIP == f.myVpnNet.Addr() {
// Immediately forward packets from self to self.
// This should only happen on Darwin-based and FreeBSD hosts, which
// routes packets from the Nebula IP to the Nebula IP through the Nebula
@@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
- // Ignore broadcast packets
- if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
+ // Ignore multicast packets
+ if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
return
}
@@ -62,9 +63,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
- dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
+ dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
- f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q)
+ f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else {
f.rejectInside(packet, out, q)
@@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
return
}
- f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
+ f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
}
-func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
+func (f *Interface) Handshake(vpnIp netip.Addr) {
f.getOrHandshake(vpnIp, nil)
}
// getOrHandshake returns nil if the vpnIp is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
- if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
+func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+ if !f.myVpnNet.Contains(vpnIp) {
vpnIp = f.inside.RouteFor(vpnIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return nil, false
}
}
@@ -142,7 +143,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
- dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil)
+ dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
@@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
return
}
- f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0)
+ f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
}
// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})
@@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
- f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0)
+ f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
}
-func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
+func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
f.messageMetrics.Tx(t, st, 1)
f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
}
@@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo,
f.connectionManager.RelayUsed(relay.LocalIndex)
}
-func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
if ci.eKey == nil {
//TODO: log warning
return
}
- useRelay := remote == nil && hostinfo.remote == nil
+ useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
fullOut := out
if useRelay {
@@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
return
}
- if remote != nil {
+ if remote.IsValid() {
err = f.writers[q].WriteTo(out, remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
- } else if hostinfo.remote != nil {
+ } else if hostinfo.remote.IsValid() {
err = f.writers[q].WriteTo(out, hostinfo.remote)
if err != nil {
hostinfo.logger(f.l).WithError(err).
@@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
}
}
}
-
-func isMulticast(ip iputil.VpnIp) bool {
- // Class D multicast
- return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
-}
diff --git a/interface.go b/interface.go
index d16348a..f251907 100644
--- a/interface.go
+++ b/interface.go
@@ -2,10 +2,11 @@ package nebula
import (
"context"
+ "encoding/binary"
"errors"
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"runtime"
"sync/atomic"
@@ -16,7 +17,6 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay"
"github.com/slackhq/nebula/udp"
)
@@ -63,8 +63,8 @@ type Interface struct {
serveDns bool
createTime time.Time
lightHouse *LightHouse
- localBroadcast iputil.VpnIp
- myVpnIp iputil.VpnIp
+ myBroadcastAddr netip.Addr
+ myVpnNet netip.Prefix
dropLocalBroadcast bool
dropMulticast bool
routines int
@@ -102,9 +102,9 @@ type EncWriter interface {
out []byte,
nocopy bool,
)
- SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+ SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
- Handshake(vpnIp iputil.VpnIp)
+ Handshake(vpnIp netip.Addr)
}
type sendRecvErrorConfig uint8
@@ -115,10 +115,10 @@ const (
sendRecvErrorPrivate
)
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
switch s {
case sendRecvErrorPrivate:
- return ip.IsPrivate()
+ return ip.Addr().IsPrivate()
case sendRecvErrorAlways:
return true
case sendRecvErrorNever:
@@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
}
certificate := c.pki.GetCertState().Certificate
- myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
+
+ myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+ if !ok {
+ return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
+ }
+
+ myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
+ if !ok {
+ return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
+ }
+
+ myVpnAddr = myVpnAddr.Unmap()
+ myVpnMask = myVpnMask.Unmap()
+
+ if myVpnAddr.BitLen() != myVpnMask.BitLen() {
+ return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
+ }
+
+ ones, _ := certificate.Details.Ips[0].Mask.Size()
+ myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
+
ifce := &Interface{
pki: c.pki,
hostMap: c.HostMap,
@@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
handshakeManager: c.HandshakeManager,
createTime: time.Now(),
lightHouse: c.lightHouse,
- localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
dropLocalBroadcast: c.DropLocalBroadcast,
dropMulticast: c.DropMulticast,
routines: c.routines,
version: c.version,
writers: make([]udp.Conn, c.routines),
readers: make([]io.ReadWriteCloser, c.routines),
- myVpnIp: myVpnIp,
+ myVpnNet: myVpnNet,
relayManager: c.relayManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout,
@@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
l: c.l,
}
+ if myVpnAddr.Is4() {
+ addr := myVpnNet.Masked().Addr().As4()
+ binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
+ ifce.myBroadcastAddr = netip.AddrFrom4(addr)
+ }
+
ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
ifce.reQueryEvery.Store(c.reQueryEvery)
ifce.reQueryWait.Store(int64(c.reQueryWait))
diff --git a/iputil/packet.go b/iputil/packet.go
index b18e524..719e034 100644
--- a/iputil/packet.go
+++ b/iputil/packet.go
@@ -6,6 +6,8 @@ import (
"golang.org/x/net/ipv4"
)
+//TODO: IPV6-WORK can probably delete this
+
const (
// Need 96 bytes for the largest reject packet:
// - 20 byte ipv4 header
diff --git a/iputil/util.go b/iputil/util.go
deleted file mode 100644
index 65f7677..0000000
--- a/iputil/util.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package iputil
-
-import (
- "encoding/binary"
- "fmt"
- "net"
- "net/netip"
-)
-
-type VpnIp uint32
-
-const maxIPv4StringLen = len("255.255.255.255")
-
-func (ip VpnIp) String() string {
- b := make([]byte, maxIPv4StringLen)
-
- n := ubtoa(b, 0, byte(ip>>24))
- b[n] = '.'
- n++
-
- n += ubtoa(b, n, byte(ip>>16&255))
- b[n] = '.'
- n++
-
- n += ubtoa(b, n, byte(ip>>8&255))
- b[n] = '.'
- n++
-
- n += ubtoa(b, n, byte(ip&255))
- return string(b[:n])
-}
-
-func (ip VpnIp) MarshalJSON() ([]byte, error) {
- return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
-}
-
-func (ip VpnIp) ToIP() net.IP {
- nip := make(net.IP, 4)
- binary.BigEndian.PutUint32(nip, uint32(ip))
- return nip
-}
-
-func (ip VpnIp) ToNetIpAddr() netip.Addr {
- var nip [4]byte
- binary.BigEndian.PutUint32(nip[:], uint32(ip))
- return netip.AddrFrom4(nip)
-}
-
-func Ip2VpnIp(ip []byte) VpnIp {
- if len(ip) == 16 {
- return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
- }
- return VpnIp(binary.BigEndian.Uint32(ip))
-}
-
-func ToNetIpAddr(ip net.IP) (netip.Addr, error) {
- addr, ok := netip.AddrFromSlice(ip)
- if !ok {
- return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip)
- }
- return addr, nil
-}
-
-func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) {
- addr, err := ToNetIpAddr(ipNet.IP)
- if err != nil {
- return netip.Prefix{}, err
- }
- ones, bits := ipNet.Mask.Size()
- if ones == 0 && bits == 0 {
- return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet)
- }
- return netip.PrefixFrom(addr, ones), nil
-}
-
-// ubtoa encodes the string form of the integer v to dst[start:] and
-// returns the number of bytes written to dst. The caller must ensure
-// that dst has sufficient length.
-func ubtoa(dst []byte, start int, v byte) int {
- if v < 10 {
- dst[start] = v + '0'
- return 1
- } else if v < 100 {
- dst[start+1] = v%10 + '0'
- dst[start] = v/10 + '0'
- return 2
- }
-
- dst[start+2] = v%10 + '0'
- dst[start+1] = (v/10)%10 + '0'
- dst[start] = v/100 + '0'
- return 3
-}
diff --git a/iputil/util_test.go b/iputil/util_test.go
deleted file mode 100644
index 712d426..0000000
--- a/iputil/util_test.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package iputil
-
-import (
- "net"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestVpnIp_String(t *testing.T) {
- assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
- assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
- assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
- assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
- assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
- assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
-}
diff --git a/lighthouse.go b/lighthouse.go
index aa54c4b..62f4065 100644
--- a/lighthouse.go
+++ b/lighthouse.go
@@ -7,16 +7,16 @@ import (
"fmt"
"net"
"net/netip"
+ "strconv"
"sync"
"sync/atomic"
"time"
+ "github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"github.com/slackhq/nebula/util"
)
@@ -26,25 +26,18 @@ import (
var ErrHostNotKnown = errors.New("host not known")
-type netIpAndPort struct {
- ip net.IP
- port uint16
-}
-
type LightHouse struct {
//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
sync.RWMutex //Because we concurrently read and write to our maps
ctx context.Context
amLighthouse bool
- myVpnIp iputil.VpnIp
- myVpnZeros iputil.VpnIp
- myVpnNet *net.IPNet
+ myVpnNet netip.Prefix
punchConn udp.Conn
punchy *Punchy
// Local cache of answers from light houses
// map of vpn Ip to answers
- addrMap map[iputil.VpnIp]*RemoteList
+ addrMap map[netip.Addr]*RemoteList
// filters remote addresses allowed for each host
// - When we are a lighthouse, this filters what addresses we store and
@@ -57,26 +50,26 @@ type LightHouse struct {
localAllowList atomic.Pointer[LocalAllowList]
// used to trigger the HandshakeManager when we receive HostQueryReply
- handshakeTrigger chan<- iputil.VpnIp
+ handshakeTrigger chan<- netip.Addr
// staticList exists to avoid having a bool in each addrMap entry
// since static should be rare
- staticList atomic.Pointer[map[iputil.VpnIp]struct{}]
- lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
+ staticList atomic.Pointer[map[netip.Addr]struct{}]
+ lighthouses atomic.Pointer[map[netip.Addr]struct{}]
interval atomic.Int64
updateCancel context.CancelFunc
ifce EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
- advertiseAddrs atomic.Pointer[[]netIpAndPort]
+ advertiseAddrs atomic.Pointer[[]netip.AddrPort]
// IP's of relays that can be used by peers to access me
- relaysForMe atomic.Pointer[[]iputil.VpnIp]
+ relaysForMe atomic.Pointer[[]netip.Addr]
- queryChan chan iputil.VpnIp
+ queryChan chan netip.Addr
- calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
+ calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
@@ -85,7 +78,7 @@ type LightHouse struct {
// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
// addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) {
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
nebulaPort := uint32(c.GetInt("listen.port", 0))
if amLighthouse && nebulaPort == 0 {
@@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
if err != nil {
return nil, util.NewContextualError("Failed to get listening port", nil, err)
}
- nebulaPort = uint32(uPort.Port)
+ nebulaPort = uint32(uPort.Port())
}
- ones, _ := myVpnNet.Mask.Size()
h := LightHouse{
ctx: ctx,
amLighthouse: amLighthouse,
- myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP),
- myVpnZeros: iputil.VpnIp(32 - ones),
myVpnNet: myVpnNet,
- addrMap: make(map[iputil.VpnIp]*RemoteList),
+ addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
- queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
+ queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
- lighthouses := make(map[iputil.VpnIp]struct{})
+ lighthouses := make(map[netip.Addr]struct{})
h.lighthouses.Store(&lighthouses)
- staticList := make(map[iputil.VpnIp]struct{})
+ staticList := make(map[netip.Addr]struct{})
h.staticList.Store(&staticList)
if c.GetBool("stats.lighthouse_metrics", false) {
@@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
return &h, nil
}
-func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
return *lh.staticList.Load()
}
-func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
return *lh.lighthouses.Load()
}
@@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
return lh.localAllowList.Load()
}
-func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort {
+func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort {
return *lh.advertiseAddrs.Load()
}
-func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
+func (lh *LightHouse) GetRelaysForMe() []netip.Addr {
return *lh.relaysForMe.Load()
}
-func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
+func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] {
return lh.calculatedRemotes.Load()
}
@@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 {
func (lh *LightHouse) reload(c *config.C, initial bool) error {
if initial || c.HasChanged("lighthouse.advertise_addrs") {
rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{})
- advAddrs := make([]netIpAndPort, 0)
+ advAddrs := make([]netip.AddrPort, 0)
for i, rawAddr := range rawAdvAddrs {
- fIp, fPort, err := udp.ParseIPAndPort(rawAddr)
+ host, sport, err := net.SplitHostPort(rawAddr)
if err != nil {
return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
}
- if fPort == 0 {
- fPort = uint16(lh.nebulaPort)
+ ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host)
+ if err != nil {
+ return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+ }
+ if len(ips) == 0 {
+ return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil)
+ }
+
+ port, err := strconv.Atoi(sport)
+ if err != nil {
+ return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+ }
+
+ if port == 0 {
+ port = int(lh.nebulaPort)
}
- if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) {
+ //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used
+ ip := ips[0].Unmap()
+ if lh.myVpnNet.Contains(ip) {
lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
continue
}
- advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort})
+ advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port)))
}
lh.advertiseAddrs.Store(&advAddrs)
@@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
lh.RUnlock()
}
// Build a new list based on current config.
- staticList := make(map[iputil.VpnIp]struct{})
- err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
+ staticList := make(map[netip.Addr]struct{})
+ err := lh.loadStaticMap(c, staticList)
if err != nil {
return err
}
@@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
if initial || c.HasChanged("lighthouse.hosts") {
- lhMap := make(map[iputil.VpnIp]struct{})
- err := lh.parseLighthouses(c, lh.myVpnNet, lhMap)
+ lhMap := make(map[netip.Addr]struct{})
+ err := lh.parseLighthouses(c, lhMap)
if err != nil {
return err
}
@@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
if len(c.GetStringSlice("relay.relays", nil)) > 0 {
lh.l.Info("Ignoring relays from config because am_relay is true")
}
- relaysForMe := []iputil.VpnIp{}
+ relaysForMe := []netip.Addr{}
lh.relaysForMe.Store(&relaysForMe)
case false:
- relaysForMe := []iputil.VpnIp{}
+ relaysForMe := []netip.Addr{}
for _, v := range c.GetStringSlice("relay.relays", nil) {
lh.l.WithField("relay", v).Info("Read relay from config")
- configRIP := net.ParseIP(v)
- if configRIP != nil {
- relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP))
+ configRIP, err := netip.ParseAddr(v)
+ //TODO: We could print the error here
+ if err == nil {
+ relaysForMe = append(relaysForMe, configRIP)
}
}
lh.relaysForMe.Store(&relaysForMe)
@@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
return nil
}
-func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
lhs := c.GetStringSlice("lighthouse.hosts", []string{})
if lh.amLighthouse && len(lhs) != 0 {
lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
}
for i, host := range lhs {
- ip := net.ParseIP(host)
- if ip == nil {
- return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
+ ip, err := netip.ParseAddr(host)
+ if err != nil {
+ return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
}
- if !tunCidr.Contains(ip) {
- return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
+ if !lh.myVpnNet.Contains(ip) {
+ return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil)
}
- lhMap[iputil.Ip2VpnIp(ip)] = struct{}{}
+ lhMap[ip] = struct{}{}
}
if !lh.amLighthouse && len(lhMap) == 0 {
@@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) {
return network, nil
}
-func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error {
d, err := getStaticMapCadence(c)
if err != nil {
return err
@@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return err
}
- lookup_timeout, err := getStaticMapLookupTimeout(c)
+ lookupTimeout, err := getStaticMapLookupTimeout(c)
if err != nil {
return err
}
@@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
i := 0
for k, v := range shm {
- rip := net.ParseIP(fmt.Sprintf("%v", k))
- if rip == nil {
- return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil)
+ vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k))
+ if err != nil {
+ return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
}
- if !tunCidr.Contains(rip) {
- return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil)
+ if !lh.myVpnNet.Contains(vpnIp) {
+ return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil)
}
- vpnIp := iputil.Ip2VpnIp(rip)
vals, ok := v.([]interface{})
if !ok {
vals = []interface{}{v}
@@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
}
- err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
+ err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList)
if err != nil {
return err
}
@@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return nil
}
-func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) Query(ip netip.Addr) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip)
}
@@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
}
// QueryServer is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
+func (lh *LightHouse) QueryServer(ip netip.Addr) {
// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
if lh.amLighthouse || lh.IsLighthouseIP(ip) {
return
@@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
lh.queryChan <- ip
}
-func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList {
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
lh.RUnlock()
@@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
-func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
+func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) {
lh.RLock()
// Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnIp]; ok {
@@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in
return false, 0, nil
}
-func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
+func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) {
// First we check the static mapping
// and do nothing if it is there
if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
@@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
-func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error {
lh.Lock()
am := lh.unlockedGetRemoteList(vpnIp)
am.Lock()
@@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
am.unlockedSetHostnamesResults(hr)
for _, addrPort := range hr.GetIPs() {
-
+ if !lh.shouldAdd(vpnIp, addrPort.Addr()) {
+ continue
+ }
switch {
case addrPort.Addr().Is4():
- to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
- if !lh.unlockedShouldAddV4(vpnIp, to) {
- continue
- }
- am.unlockedPrependV4(lh.myVpnIp, to)
+ am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
case addrPort.Addr().Is6():
- to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
- if !lh.unlockedShouldAddV6(vpnIp, to) {
- continue
- }
- am.unlockedPrependV6(lh.myVpnIp, to)
+ am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
}
}
@@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
// addCalculatedRemotes adds any calculated remotes based on the
// lighthouse.calculated_remotes configuration. It returns true if any
// calculated remotes were added
-func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool {
tree := lh.getCalculatedRemotes()
if tree == nil {
return false
}
- ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
+ calculatedRemotes, ok := tree.Lookup(vpnIp)
if !ok {
return false
}
@@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
defer am.Unlock()
lh.Unlock()
- am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4)
+ am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4)
return len(calculated) > 0
}
// unlockedGetRemoteList assumes you have the lh lock
-func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList {
am, ok := lh.addrMap[vpnIp]
if !ok {
am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
@@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
return am
}
-func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
- switch {
- case to.Is4():
- ipBytes := to.As4()
- ip := iputil.Ip2VpnIp(ipBytes[:])
- allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
- if lh.l.Level >= logrus.TraceLevel {
- lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
- }
- if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
- return false
- }
- case to.Is6():
- ipBytes := to.As16()
-
- hi := binary.BigEndian.Uint64(ipBytes[:8])
- lo := binary.BigEndian.Uint64(ipBytes[8:])
- allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
- if lh.l.Level >= logrus.TraceLevel {
- lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
- }
-
- // We don't check our vpn network here because nebula does not support ipv6 on the inside
- if !allow {
- return false
- }
+func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool {
+ allow := lh.GetRemoteAllowList().Allow(vpnIp, to)
+ if lh.l.Level >= logrus.TraceLevel {
+ lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
+ }
+ if !allow || lh.myVpnNet.Contains(to) {
+ return false
}
+
return true
}
// unlockedShouldAddV4 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
- allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
+func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool {
+ ip := AddrPortFromIp4AndPort(to)
+ allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
- if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
+ if !allow || lh.myVpnNet.Contains(ip.Addr()) {
return false
}
@@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo
}
// unlockedShouldAddV6 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
- allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo)
+func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool {
+ ip := AddrPortFromIp6AndPort(to)
+ allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
if lh.l.Level >= logrus.TraceLevel {
lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
}
- // We don't check our vpn network here because nebula does not support ipv6 on the inside
- if !allow {
+ if !allow || lh.myVpnNet.Contains(ip.Addr()) {
return false
}
@@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
return ip
}
-func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool {
if _, ok := lh.GetLighthouses()[vpnIp]; ok {
return true
}
return false
}
-func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
+func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta {
+ if vpnIp.Is6() {
+ //TODO: need to support ipv6
+ panic("ipv6 is not yet supported")
+ }
+
+ b := vpnIp.As4()
return &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
- VpnIp: uint32(VpnIp),
+ VpnIp: binary.BigEndian.Uint32(b[:]),
},
}
}
-func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
- ipp := Ip4AndPort{Port: port}
- ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
- return &ipp
+func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort {
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], ip.Ip)
+ return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port))
+}
+
+func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort {
+ b := [16]byte{}
+ binary.BigEndian.PutUint64(b[:8], ip.Hi)
+ binary.BigEndian.PutUint64(b[8:], ip.Lo)
+ return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port))
}
func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
@@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
}
}
-func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
- return &Ip6AndPort{
- Hi: binary.BigEndian.Uint64(ip[:8]),
- Lo: binary.BigEndian.Uint64(ip[8:]),
- Port: port,
- }
-}
-
+// TODO: IPV6-WORK we can delete some more of these
func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
ip6Addr := ip.As16()
return &Ip6AndPort{
@@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
Port: uint32(port),
}
}
-func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
- ip := ipp.Ip
- return udp.NewAddr(
- net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
- uint16(ipp.Port),
- )
-}
-
-func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
- return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
-}
func (lh *LightHouse) startQueryWorker() {
if lh.amLighthouse {
@@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() {
}()
}
-func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
+func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) {
if lh.IsLighthouseIP(ip) {
return
}
@@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() {
var v6 []*Ip6AndPort
for _, e := range lh.GetAdvertiseAddrs() {
- if ip := e.ip.To4(); ip != nil {
- v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port)))
+ if e.Addr().Is4() {
+ v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port()))
} else {
- v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port)))
+ v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port()))
}
}
lal := lh.GetLocalAllowList()
- for _, e := range *localIps(lh.l, lal) {
- if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
+ for _, e := range localIps(lh.l, lal) {
+ if lh.myVpnNet.Contains(e) {
continue
}
// Only add IPs that aren't my VPN/tun IP
- if ip := e.To4(); ip != nil {
- v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort))
+ if e.Is4() {
+ v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort)))
} else {
- v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort))
+ v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort)))
}
}
var relays []uint32
for _, r := range lh.GetRelaysForMe() {
- relays = append(relays, (uint32)(r))
+ //TODO: IPV6-WORK both relays and vpnip need ipv6 support
+ b := r.As4()
+ relays = append(relays, binary.BigEndian.Uint32(b[:]))
}
+ //TODO: IPV6-WORK both relays and vpnip need ipv6 support
+ b := lh.myVpnNet.Addr().As4()
+
m := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
- VpnIp: uint32(lh.myVpnIp),
+ VpnIp: binary.BigEndian.Uint32(b[:]),
Ip4AndPorts: v4,
Ip6AndPorts: v6,
RelayVpnIp: relays,
@@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
}
func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
- return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
+ return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) {
lhh.HandleRequest(rAddr, vpnIp, p, f)
}
}
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
+func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) {
n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
@@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
}
}
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
@@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
//TODO: we can DRY this further
reqVpnIp := n.Details.VpnIp
+
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ queryVpnIp := netip.AddrFrom4(b)
+
//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
- found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
+ found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
n.Details.VpnIp = reqVpnIp
@@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
- n.Details.VpnIp = uint32(vpnIp)
-
+ //TODO: IPV6-WORK
+ b = vpnIp.As4()
+ n.Details.VpnIp = binary.BigEndian.Uint32(b[:])
lhh.coalesceAnswers(c, n)
return n.MarshalTo(lhh.pb)
@@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
}
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
- w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
+
+ //TODO: IPV6-WORK
+ binary.BigEndian.PutUint32(b[:], reqVpnIp)
+ sendTo := netip.AddrFrom4(b)
+ w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
}
if c.relay != nil {
- n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...)
+ //TODO: IPV6-WORK
+ relays := make([]uint32, len(c.relay.relay))
+ b := [4]byte{}
+ for i, _ := range relays {
+ b = c.relay.relay[i].As4()
+ relays[i] = binary.BigEndian.Uint32(b[:])
+ }
+ n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...)
}
}
-func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
+func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
lhh.lh.Lock()
- am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ certVpnIp := netip.AddrFrom4(b)
+ am := lhh.lh.unlockedGetRemoteList(certVpnIp)
am.Lock()
lhh.lh.Unlock()
- certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+ //TODO: IPV6-WORK
am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
- am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+
+ //TODO: IPV6-WORK
+ relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+ for i, _ := range n.Details.RelayVpnIp {
+ binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+ relays[i] = netip.AddrFrom4(b)
+ }
+ am.unlockedSetRelay(vpnIp, certVpnIp, relays)
am.Unlock()
// Non-blocking attempt to trigger, skip if it would block
select {
- case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
+ case lhh.lh.handshakeTrigger <- certVpnIp:
default:
}
}
-func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
}
//Simple check that the host sent this not someone else
- if n.Details.VpnIp != uint32(vpnIp) {
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ detailsVpnIp := netip.AddrFrom4(b)
+ if detailsVpnIp != vpnIp {
if lhh.l.Level >= logrus.DebugLevel {
- lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
+ lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
}
return
}
@@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
am.Lock()
lhh.lh.Unlock()
- certVpnIp := iputil.VpnIp(n.Details.VpnIp)
- am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
- am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
- am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+ am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+ am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+
+ //TODO: IPV6-WORK
+ relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+ for i, _ := range n.Details.RelayVpnIp {
+ binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+ relays[i] = netip.AddrFrom4(b)
+ }
+ am.unlockedSetRelay(vpnIp, detailsVpnIp, relays)
am.Unlock()
n = lhh.resetMeta()
n.Type = NebulaMeta_HostUpdateNotificationAck
- n.Details.VpnIp = uint32(vpnIp)
+
+ //TODO: IPV6-WORK
+ vpnIpB := vpnIp.As4()
+ n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:])
ln, err := n.MarshalTo(lhh.pb)
if err != nil {
@@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
if !lhh.lh.IsLighthouseIP(vpnIp) {
return
}
empty := []byte{0}
- punch := func(vpnPeer *udp.Addr) {
- if vpnPeer == nil {
+ punch := func(vpnPeer netip.AddrPort) {
+ if !vpnPeer.IsValid() {
return
}
@@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
if lhh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
- lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
+ //TODO: IPV6-WORK, make this debug line not suck
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b))
}
}
for _, a := range n.Details.Ip4AndPorts {
- punch(NewUDPAddrFromLH4(a))
+ punch(AddrPortFromIp4AndPort(a))
}
for _, a := range n.Details.Ip6AndPorts {
- punch(NewUDPAddrFromLH6(a))
+ punch(AddrPortFromIp6AndPort(a))
}
// This sends a nebula test packet to the host trying to contact us. In the case
// of a double nat or other difficult scenario, this may help establish
// a tunnel.
if lhh.lh.punchy.GetRespond() {
- queryVpnIp := iputil.VpnIp(n.Details.VpnIp)
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+ queryVpnIp := netip.AddrFrom4(b)
go func() {
time.Sleep(lhh.lh.punchy.GetRespondDelay())
if lhh.l.Level >= logrus.DebugLevel {
@@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
}()
}
}
-
-// ipMaskContains checks if testIp is contained by ip after applying a cidr
-// zeros is 32 - bits from net.IPMask.Size()
-func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
- return (testIp^ip)>>zeros == 0
-}
diff --git a/lighthouse_test.go b/lighthouse_test.go
index 66427e3..2599f5f 100644
--- a/lighthouse_test.go
+++ b/lighthouse_test.go
@@ -2,15 +2,14 @@ package nebula
import (
"context"
+ "encoding/binary"
"fmt"
- "net"
+ "net/netip"
"testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
- "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
@@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) {
var m Ip4AndPort
err := m.Unmarshal(b)
assert.NoError(t, err)
- assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
+ ip := netip.MustParseAddr("10.1.1.1")
+ bp := ip.As4()
+ assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
}
func TestNewLhQuery(t *testing.T) {
- myIp := net.ParseIP("192.1.1.1")
- myIpint := iputil.Ip2VpnIp(myIp)
+ myIp, err := netip.ParseAddr("192.1.1.1")
+ assert.NoError(t, err)
// Generating a new lh query should work
- a := NewLhQueryByInt(myIpint)
+ a := NewLhQueryByInt(myIp)
// The result should be a nebulameta protobuf
assert.IsType(t, &NebulaMeta{}, a)
@@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) {
func Test_lhStaticMapping(t *testing.T) {
l := test.NewLogger()
- _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+ myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
lh1 := "10.128.0.2"
c := config.NewC(l)
@@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) {
func TestReloadLighthouseInterval(t *testing.T) {
l := test.NewLogger()
- _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+ myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
lh1 := "10.128.0.2"
c := config.NewC(l)
@@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) {
lh.ifce = &mockEncWriter{}
// The first one routine is kicked off by main.go currently, lets make sure that one dies
- c.ReloadConfigString("lighthouse:\n interval: 5")
+ assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
assert.Equal(t, int64(5), lh.interval.Load())
// Subsequent calls are killed off by the LightHouse.Reload function
- c.ReloadConfigString("lighthouse:\n interval: 10")
+ assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
assert.Equal(t, int64(10), lh.interval.Load())
// If this completes then nothing is stealing our reload routine
- c.ReloadConfigString("lighthouse:\n interval: 11")
+ assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
assert.Equal(t, int64(11), lh.interval.Load())
}
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := test.NewLogger()
- _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
+ myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
@@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
b.Fatal()
}
- hAddr := udp.NewAddrFromString("4.5.6.7:12345")
- hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
- lh.addrMap[3] = NewRemoteList(nil)
- lh.addrMap[3].unlockedSetV4(
- 3,
- 3,
+ hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
+ hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
+
+ vpnIp3 := netip.MustParseAddr("0.0.0.3")
+ lh.addrMap[vpnIp3] = NewRemoteList(nil)
+ lh.addrMap[vpnIp3].unlockedSetV4(
+ vpnIp3,
+ vpnIp3,
[]*Ip4AndPort{
- NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
- NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
+ NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
+ NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
- rAddr := udp.NewAddrFromString("1.2.2.3:12345")
- rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
- lh.addrMap[2] = NewRemoteList(nil)
- lh.addrMap[2].unlockedSetV4(
- 3,
- 3,
+ rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
+ rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
+ vpnIp2 := netip.MustParseAddr("0.0.0.2")
+ lh.addrMap[vpnIp2] = NewRemoteList(nil)
+ lh.addrMap[vpnIp2].unlockedSetV4(
+ vpnIp3,
+ vpnIp3,
[]*Ip4AndPort{
- NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
- NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
+ NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
+ NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
mw := &mockEncWriter{}
@@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
p, err := req.Marshal()
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
- lhh.HandleRequest(rAddr, 2, p, mw)
+ lhh.HandleRequest(rAddr, vpnIp2, p, mw)
}
})
b.Run("found", func(b *testing.B) {
@@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
assert.NoError(b, err)
for n := 0; n < b.N; n++ {
- lhh.HandleRequest(rAddr, 2, p, mw)
+ lhh.HandleRequest(rAddr, vpnIp2, p, mw)
}
})
}
@@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
func TestLighthouse_Memory(t *testing.T) {
l := test.NewLogger()
- myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
- myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
- myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
- myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
- myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
- myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
- myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
- myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
- myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
- myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
- myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
- myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
- myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
-
- theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
- theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
- theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
- theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
- theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
- theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
+ myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242")
+ myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242")
+ myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242")
+ myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242")
+ myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242")
+ myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243")
+ myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244")
+ myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245")
+ myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246")
+ myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247")
+ myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248")
+ myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249")
+ myVpnIp := netip.MustParseAddr("10.128.0.2")
+
+ theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
+ theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
+ theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
+ theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
+ theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
+ theirVpnIp := netip.MustParseAddr("10.128.0.3")
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
- lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+ lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err)
lhh := lh.NewRequestHandler()
// Test that my first update responds with just that
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
// Ensure we don't accumulate addresses
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
// Grow it back to 2
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
// Update a different host and ask about it
- newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+ newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
@@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) {
newLHHostUpdate(
myUdpAddr0,
myVpnIp,
- []*udp.Addr{
+ []netip.AddrPort{
myUdpAddr1,
myUdpAddr2,
myUdpAddr3,
@@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) {
)
// Make sure we won't add ips in our vpn network
- bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
- bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
- good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
- newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
+ bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
+ bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
+ good := netip.MustParseAddrPort("1.128.0.99:4242")
+ newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
}
@@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) {
c := config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
- lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+ lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
assert.NoError(t, err)
nc := map[interface{}]interface{}{
@@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) {
assert.NoError(t, err)
}
-func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
+func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
+ //TODO: IPV6-WORK
+ bip := queryVpnIp.As4()
req := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{
- VpnIp: uint32(queryVpnIp),
+ VpnIp: binary.BigEndian.Uint32(bip[:]),
},
}
@@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
return w.lastReply
}
-func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
+func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
+ //TODO: IPV6-WORK
+ bip := vpnIp.As4()
req := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
- VpnIp: uint32(vpnIp),
+ VpnIp: binary.BigEndian.Uint32(bip[:]),
Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
},
}
for k, v := range addrs {
- req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
+ req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
}
b, err := req.Marshal()
@@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
// )
//}
-func Test_ipMaskContains(t *testing.T) {
- assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
- assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
- assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
-}
-
type testLhReply struct {
nebType header.MessageType
nebSubType header.MessageSubType
- vpnIp iputil.VpnIp
+ vpnIp netip.Addr
msg *NebulaMeta
}
@@ -414,7 +416,7 @@ type testEncWriter struct {
func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
}
-func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
+func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
}
func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
@@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
}
}
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
msg := &NebulaMeta{}
err := msg.Unmarshal(p)
if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
}
// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
- if !assert.Len(t, have, len(want)) {
- return
- }
-
- for k, w := range want {
- if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
- assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
- }
- }
-}
-
-// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
-func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
+func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) {
return
}
for k, w := range want {
- if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
- assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
+ //TODO: IPV6-WORK
+ h := AddrPortFromIp4AndPort(have[k])
+ if !(h == w) {
+ assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
}
}
}
-
-func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
- addrs := make([]*udp.Addr, len(ips))
- for k, v := range ips {
- addrs[k] = NewUDPAddrFromLH4(v)
- }
- return addrs
-}
diff --git a/main.go b/main.go
index 8c94e80..248f329 100644
--- a/main.go
+++ b/main.go
@@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"net"
+ "net/netip"
"time"
"github.com/sirupsen/logrus"
@@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
- // TODO: make sure mask is 4 bytes
- tunCidr := certificate.Details.Ips[0]
+ ones, _ := certificate.Details.Ips[0].Mask.Size()
+ addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+ if !ok {
+ err = util.NewContextualError(
+ "Invalid ip address in certificate",
+ m{"vpnIp": certificate.Details.Ips[0].IP},
+ nil,
+ )
+ return nil, err
+ }
+ tunCidr := netip.PrefixFrom(addr, ones)
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
if err != nil {
@@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if !configTest {
rawListenHost := c.GetString("listen.host", "0.0.0.0")
- var listenHost *net.IPAddr
+ var listenHost netip.Addr
if rawListenHost == "[::]" {
// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
- listenHost = &net.IPAddr{IP: net.IPv6zero}
+ listenHost = netip.IPv6Unspecified()
} else {
- listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
+ ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
}
+ if len(ips) == 0 {
+ return nil, util.NewContextualError("Failed to resolve listen.host", m{"host": rawListenHost}, nil)
+ }
+ listenHost = ips[0].Unmap()
}
for i := 0; i < routines; i++ {
- l.Infof("listening %q %d", listenHost.IP, port)
- udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
+ l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
+ udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
if err != nil {
return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
@@ -178,57 +192,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
if err != nil {
return nil, util.NewContextualError("Failed to get listening port", nil, err)
}
- port = int(uPort.Port)
- }
- }
- }
-
- // Set up my internal host map
- var preferredRanges []*net.IPNet
- rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
- // First, check if 'preferred_ranges' is set and fallback to 'local_range'
- if len(rawPreferredRanges) > 0 {
- for _, rawPreferredRange := range rawPreferredRanges {
- _, preferredRange, err := net.ParseCIDR(rawPreferredRange)
- if err != nil {
- return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
- }
- preferredRanges = append(preferredRanges, preferredRange)
- }
- }
-
- // local_range was superseded by preferred_ranges. If it is still present,
- // merge the local_range setting into preferred_ranges. We will probably
- // deprecate local_range and remove in the future.
- rawLocalRange := c.GetString("local_range", "")
- if rawLocalRange != "" {
- _, localRange, err := net.ParseCIDR(rawLocalRange)
- if err != nil {
- return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
- }
-
- // Check if the entry for local_range was already specified in
- // preferred_ranges. Don't put it into the slice twice if so.
- var found bool
- for _, r := range preferredRanges {
- if r.String() == localRange.String() {
- found = true
- break
+ port = int(uPort.Port())
}
}
- if !found {
- preferredRanges = append(preferredRanges, localRange)
- }
}
- hostMap := NewHostMap(l, tunCidr, preferredRanges)
- hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
-
- l.
- WithField("network", hostMap.vpnCIDR.String()).
- WithField("preferredRanges", hostMap.preferredRanges).
- Info("Main HostMap created")
-
+ hostMap := NewHostMapFromConfig(l, tunCidr, c)
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
if err != nil {
diff --git a/noise.go b/noise.go
index 91ad2c0..57990a7 100644
--- a/noise.go
+++ b/noise.go
@@ -28,11 +28,11 @@ func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
// EncryptDanger encrypts and authenticates a given payload.
//
// out is a destination slice to hold the output of the EncryptDanger operation.
-// - ad is additional data, which will be authenticated and appended to out, but not encrypted.
-// - plaintext is encrypted, authenticated and appended to out.
-// - n is a nonce value which must never be re-used with this key.
-// - nb is a buffer used for temporary storage in the implementation of this call, which should
-// be re-used by callers to minimize garbage collection.
+// - ad is additional data, which will be authenticated and appended to out, but not encrypted.
+// - plaintext is encrypted, authenticated and appended to out.
+// - n is a nonce value which must never be re-used with this key.
+// - nb is a buffer used for temporary storage in the implementation of this call, which should
+// be re-used by callers to minimize garbage collection.
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
if s != nil {
// TODO: Is this okay now that we have made messageCounter atomic?
diff --git a/noiseutil/nist.go b/noiseutil/nist.go
index 90e77ab..976a274 100644
--- a/noiseutil/nist.go
+++ b/noiseutil/nist.go
@@ -48,7 +48,7 @@ func (c nistCurve) DH(privkey, pubkey []byte) ([]byte, error) {
}
ecdhPrivKey, err := c.curve.NewPrivateKey(privkey)
if err != nil {
- return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
+ return nil, fmt.Errorf("unable to unmarshal private key: %w", err)
}
return ecdhPrivKey.ECDH(ecdhPubKey)
diff --git a/outside.go b/outside.go
index 2918911..be60294 100644
--- a/outside.go
+++ b/outside.go
@@ -4,6 +4,7 @@ import (
"encoding/binary"
"errors"
"fmt"
+ "net/netip"
"time"
"github.com/flynn/noise"
@@ -11,7 +12,6 @@ import (
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4"
"google.golang.org/protobuf/proto"
@@ -21,9 +21,10 @@ const (
minFwPacketLen = 4
)
+// TODO: IPV6-WORK this can likely be removed now
func readOutsidePackets(f *Interface) udp.EncReader {
return func(
- addr *udp.Addr,
+ addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
@@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader {
}
}
-func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// TODO: best if we return this and let caller log
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
- f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
+ f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
- if addr != nil {
- if ip4 := addr.IP.To4(); ip4 != nil {
- if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) {
- if f.l.Level >= logrus.DebugLevel {
- f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
- }
- return
+ if ip.IsValid() {
+ if f.myVpnNet.Contains(ip.Addr()) {
+ if f.l.Level >= logrus.DebugLevel {
+ f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
}
+ return
}
}
@@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
@@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
- f.handleHostRoaming(hostinfo, addr)
+ f.handleHostRoaming(hostinfo, ip)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo.localIndexId)
f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
- f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+ f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
@@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
return
}
- lhf(addr, hostinfo.vpnIp, d)
+ lhf(ip, hostinfo.vpnIp, d)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt test packet")
@@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
- f.handleHostRoaming(hostinfo, addr)
+ f.handleHostRoaming(hostinfo, ip)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
@@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- f.handshakeManager.HandleIncoming(addr, via, packet, h)
+ f.handshakeManager.HandleIncoming(ip, via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- f.handleRecvError(addr, h)
+ f.handleRecvError(ip, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
- hostinfo.logger(f.l).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithField("udpAddr", ip).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
- if !f.handleEncrypted(ci, addr, h) {
+ if !f.handleEncrypted(ci, ip, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
- hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+ hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
return
@@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
- hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
+ hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
return
}
- f.handleHostRoaming(hostinfo, addr)
+ f.handleHostRoaming(hostinfo, ip)
f.connectionManager.In(hostinfo.localIndexId)
}
@@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
- if addr != nil && !hostinfo.remote.Equals(addr) {
- if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
- hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
+ if ip.IsValid() && hostinfo.remote != ip {
+ if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
+ hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
return
}
- if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+ if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
- hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+ hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
- hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+ hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
- hostinfo.SetRemote(addr)
+ hostinfo.SetRemote(ip)
}
}
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
+func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
- if addr != nil {
+ if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
return false
} else {
@@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Firewall packets are locally oriented
if incoming {
- fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
- fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
+ //TODO: IPV6-WORK
+ fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
+ fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
- fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
- fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
+ //TODO: IPV6-WORK
+ fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
+ fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
@@ -404,7 +405,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
- dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
+ dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in
@@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return true
}
-func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {
- if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) {
+func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
+ if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) {
f.sendRecvError(endpoint, index)
}
}
-func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
+func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
f.messageMetrics.Tx(header.RecvError, 0, 1)
//TODO: this should be a signed message so we can trust that we should drop the index
@@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
}
}
-func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
+func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
@@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
return
}
- if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
+ if hostinfo.remote.IsValid() && hostinfo.remote != addr {
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
}
diff --git a/outside_test.go b/outside_test.go
index 682107b..f9d4bfa 100644
--- a/outside_test.go
+++ b/outside_test.go
@@ -2,10 +2,10 @@ package nebula
import (
"net"
+ "net/netip"
"testing"
"github.com/slackhq/nebula/firewall"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
"golang.org/x/net/ipv4"
)
@@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
- assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
- assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
+ assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
+ assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
assert.Equal(t, p.RemotePort, uint16(3))
assert.Equal(t, p.LocalPort, uint16(4))
@@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, p.Protocol, uint8(2))
- assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
- assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
+ assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
+ assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
assert.Equal(t, p.RemotePort, uint16(6))
assert.Equal(t, p.LocalPort, uint16(5))
}
diff --git a/overlay/device.go b/overlay/device.go
index 3f3f2eb..50ad6ad 100644
--- a/overlay/device.go
+++ b/overlay/device.go
@@ -2,16 +2,14 @@ package overlay
import (
"io"
- "net"
-
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
type Device interface {
io.ReadWriteCloser
Activate() error
- Cidr() *net.IPNet
+ Cidr() netip.Prefix
Name() string
- RouteFor(iputil.VpnIp) iputil.VpnIp
+ RouteFor(netip.Addr) netip.Addr
NewMultiQueueReader() (io.ReadWriteCloser, error)
}
diff --git a/overlay/route.go b/overlay/route.go
index 793c8fd..8ccc994 100644
--- a/overlay/route.go
+++ b/overlay/route.go
@@ -4,38 +4,64 @@ import (
"fmt"
"math"
"net"
+ "net/netip"
"runtime"
"strconv"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
type Route struct {
MTU int
Metric int
- Cidr *net.IPNet
- Via *iputil.VpnIp
+ Cidr netip.Prefix
+ Via netip.Addr
Install bool
}
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
- routeTree := cidr.NewTree4[iputil.VpnIp]()
+// Equal determines if a route that could be installed in the system route table is equal to another
+// Via is ignored since that is only consumed within nebula itself
+func (r Route) Equal(t Route) bool {
+ if r.Cidr != t.Cidr {
+ return false
+ }
+ if r.Metric != t.Metric {
+ return false
+ }
+ if r.MTU != t.MTU {
+ return false
+ }
+ if r.Install != t.Install {
+ return false
+ }
+ return true
+}
+
+func (r Route) String() string {
+ s := r.Cidr.String()
+ if r.Metric != 0 {
+ s += fmt.Sprintf(" metric: %v", r.Metric)
+ }
+ return s
+}
+
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
+ routeTree := new(bart.Table[netip.Addr])
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
}
- if r.Via != nil {
- routeTree.AddCIDR(r.Cidr, *r.Via)
+ if r.Via.IsValid() {
+ routeTree.Insert(r.Cidr, r.Via)
}
}
return routeTree, nil
}
-func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.routes")
@@ -86,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
MTU: mtu,
}
- _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+ r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
}
- if !ipWithin(network, r.Cidr) {
+ if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
return nil, fmt.Errorf(
"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
i+1,
@@ -106,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return routes, nil
}
-func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
var err error
r := c.Get("tun.unsafe_routes")
@@ -172,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
}
- nVia := net.ParseIP(via)
- if nVia == nil {
- return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
+ viaVpnIp, err := netip.ParseAddr(via)
+ if err != nil {
+ return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
}
rRoute, ok := m["route"]
@@ -182,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
}
- viaVpnIp := iputil.Ip2VpnIp(nVia)
-
install := true
rInstall, ok := m["install"]
if ok {
@@ -194,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
}
r := Route{
- Via: &viaVpnIp,
+ Via: viaVpnIp,
MTU: mtu,
Metric: metric,
Install: install,
}
- _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+ r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
if err != nil {
return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
}
- if ipWithin(network, r.Cidr) {
+ if network.Contains(r.Cidr.Addr()) {
return nil, fmt.Errorf(
"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
i+1,
diff --git a/overlay/route_test.go b/overlay/route_test.go
index 46fb87c..d791389 100644
--- a/overlay/route_test.go
+++ b/overlay/route_test.go
@@ -2,11 +2,10 @@ package overlay
import (
"fmt"
- "net"
+ "net/netip"
"testing"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)
@@ -14,7 +13,8 @@ import (
func Test_parseRoutes(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
- _, n, _ := net.ParseCIDR("10.0.0.0/24")
+ n, err := netip.ParsePrefix("10.0.0.0/24")
+ assert.NoError(t, err)
// test no routes config
routes, err := parseRoutes(c, n)
@@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, n)
assert.Nil(t, routes)
- assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
+ assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
@@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) {
func Test_parseUnsafeRoutes(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
- _, n, _ := net.ParseCIDR("10.0.0.0/24")
+ n, err := netip.ParsePrefix("10.0.0.0/24")
+ assert.NoError(t, err)
// test no routes config
routes, err := parseUnsafeRoutes(c, n)
@@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
- assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
+ assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
@@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, n)
assert.Nil(t, routes)
- assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
+ assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
@@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
func Test_makeRouteTree(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
- _, n, _ := net.ParseCIDR("10.0.0.0/24")
+ n, err := netip.ParsePrefix("10.0.0.0/24")
+ assert.NoError(t, err)
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
@@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) {
routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err)
- ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
- ok, r := routeTree.MostSpecificContains(ip)
+ ip, err := netip.ParseAddr("1.0.0.2")
+ assert.NoError(t, err)
+ r, ok := routeTree.Lookup(ip)
assert.True(t, ok)
- assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
- ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
- ok, r = routeTree.MostSpecificContains(ip)
+ nip, err := netip.ParseAddr("192.168.0.1")
+ assert.NoError(t, err)
+ assert.Equal(t, nip, r)
+
+ ip, err = netip.ParseAddr("1.0.0.1")
+ assert.NoError(t, err)
+ r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
- assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
- ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
- ok, r = routeTree.MostSpecificContains(ip)
+ nip, err = netip.ParseAddr("192.168.0.2")
+ assert.NoError(t, err)
+ assert.Equal(t, nip, r)
+
+ ip, err = netip.ParseAddr("1.1.0.1")
+ assert.NoError(t, err)
+ r, ok = routeTree.Lookup(ip)
assert.False(t, ok)
}
diff --git a/overlay/tun.go b/overlay/tun.go
index ca1a64a..12460da 100644
--- a/overlay/tun.go
+++ b/overlay/tun.go
@@ -1,7 +1,7 @@
package overlay
import (
- "net"
+ "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@@ -10,60 +10,63 @@ import (
const DefaultMTU = 1300
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
-
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
- routes, err := parseRoutes(c, tunCidr)
- if err != nil {
- return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
- }
-
- unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
- if err != nil {
- return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
- }
- routes = append(routes, unsafeRoutes...)
+// TODO: We may be able to remove routines
+type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
default:
- return newTun(
- l,
- c.GetString("tun.dev", ""),
- tunCidr,
- c.GetInt("tun.mtu", DefaultMTU),
- routes,
- c.GetInt("tun.tx_queue", 500),
- routines > 1,
- c.GetBool("tun.use_system_route_table", false),
- )
+ return newTun(c, l, tunCidr, routines > 1)
}
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
- return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
- routes, err := parseRoutes(c, tunCidr)
- if err != nil {
- return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
- }
+ return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
+ return newTunFromFd(c, l, *fd, tunCidr)
+ }
+}
+
+func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
+ if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
+ return false, nil, nil
+ }
- unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
- if err != nil {
- return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+ routes, err := parseRoutes(c, cidr)
+ if err != nil {
+ return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
+ }
+
+ unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
+ if err != nil {
+ return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+ }
+
+ routes = append(routes, unsafeRoutes...)
+ return true, routes, nil
+}
+
+// findRemovedRoutes will return all routes that are not present in the newRoutes list and would affect the system route table.
+// Via is not used to evaluate since it does not affect the system route table.
+func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
+ var removed []Route
+ has := func(entry Route) bool {
+ for _, check := range newRoutes {
+ if check.Equal(entry) {
+ return true
+ }
}
- routes = append(routes, unsafeRoutes...)
- return newTunFromFd(
- l,
- *fd,
- tunCidr,
- c.GetInt("tun.mtu", DefaultMTU),
- routes,
- c.GetInt("tun.tx_queue", 500),
- c.GetBool("tun.use_system_route_table", false),
- )
+ return false
+ }
+ for _, oldEntry := range oldRoutes {
+ if !has(oldEntry) {
+ removed = append(removed, oldEntry)
+ }
}
+
+ return removed
}
diff --git a/overlay/tun_android.go b/overlay/tun_android.go
index c5c52db..98ad9b4 100644
--- a/overlay/tun_android.go
+++ b/overlay/tun_android.go
@@ -6,47 +6,58 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
+ "sync/atomic"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
)
type tun struct {
io.ReadWriteCloser
fd int
- cidr *net.IPNet
- routeTree *cidr.Tree4[iputil.VpnIp]
+ cidr netip.Prefix
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
}
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
- routeTree, err := makeRouteTree(l, routes, false)
- if err != nil {
- return nil, err
- }
-
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
- return &tun{
+ t := &tun{
ReadWriteCloser: file,
fd: deviceFd,
cidr: cidr,
l: l,
- routeTree: routeTree,
- }, nil
+ }
+
+ err := t.reload(c, true)
+ if err != nil {
+ return nil, err
+ }
+
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
}
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -54,7 +65,28 @@ func (t tun) Activate() error {
return nil
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
+
+ if !initial && !change {
+ return nil
+ }
+
+ routeTree, err := makeRouteTree(t.l, routes, false)
+ if err != nil {
+ return err
+ }
+
+ // Teach nebula how to handle the routes
+ t.Routes.Store(&routes)
+ t.routeTree.Store(routeTree)
+ return nil
+}
+
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go
index caec580..0b573e6 100644
--- a/overlay/tun_darwin.go
+++ b/overlay/tun_darwin.go
@@ -8,13 +8,16 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"os"
+ "sync/atomic"
"syscall"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
)
@@ -22,10 +25,11 @@ import (
type tun struct {
io.ReadWriteCloser
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
DefaultMTU int
- Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
+ linkAddr *netroute.LinkAddr
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -69,12 +73,8 @@ type ifreqMTU struct {
pad [8]byte
}
-func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
- routeTree, err := makeRouteTree(l, routes, false)
- if err != nil {
- return nil, err
- }
-
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+ name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
@@ -142,17 +142,27 @@ func newTun(l *logrus.Logger, name string, cidr *net.IPNet, defaultMTU int, rout
file := os.NewFile(uintptr(fd), "")
- tun := &tun{
+ t := &tun{
ReadWriteCloser: file,
Device: name,
cidr: cidr,
- DefaultMTU: defaultMTU,
- Routes: routes,
- routeTree: routeTree,
+ DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
- return tun, nil
+ err = t.reload(c, true)
+ if err != nil {
+ return nil, err
+ }
+
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
}
func (t *tun) deviceBytes() (o [16]byte) {
@@ -162,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@@ -178,8 +188,13 @@ func (t *tun) Activate() error {
var addr, mask [4]byte
- copy(addr[:], t.cidr.IP.To4())
- copy(mask[:], t.cidr.Mask)
+ if !t.cidr.Addr().Is4() {
+ //TODO: IPV6-WORK
+ panic("need ipv6")
+ }
+
+ addr = t.cidr.Addr().As4()
+ copy(mask[:], prefixToMask(t.cidr))
s, err := unix.Socket(
unix.AF_INET,
@@ -260,6 +275,7 @@ func (t *tun) Activate() error {
if linkAddr == nil {
return fmt.Errorf("unable to discover link_addr for tun interface")
}
+ t.linkAddr = linkAddr
copy(routeAddr.IP[:], addr[:])
copy(maskAddr.IP[:], mask[:])
@@ -278,38 +294,52 @@ func (t *tun) Activate() error {
}
// Unsafe path routes
- for _, r := range t.Routes {
- if r.Via == nil || !r.Install {
- // We don't allow route MTUs so only install routes with a via
- continue
- }
+ return t.addRoutes(false)
+}
+
+func (t *tun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
- copy(routeAddr.IP[:], r.Cidr.IP.To4())
- copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
+ if !initial && !change {
+ return nil
+ }
- err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
+ routeTree, err := makeRouteTree(t.l, routes, false)
+ if err != nil {
+ return err
+ }
+
+ // Teach nebula how to handle the routes before establishing them in the system table
+ oldRoutes := t.Routes.Swap(&routes)
+ t.routeTree.Store(routeTree)
+
+ if !initial {
+ // Remove first, if the system removes a wanted route hopefully it will be re-added next
+ err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
if err != nil {
- if errors.Is(err, unix.EEXIST) {
- t.l.WithField("route", r.Cidr).
- Warnf("unable to add unsafe_route, identical route already exists")
- } else {
- return err
- }
+ util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
- // TODO how to set metric
+ // Ensure any routes we actually want are installed
+ err = t.addRoutes(true)
+ if err != nil {
+ // Catch any stray logs
+ util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
+ }
}
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- ok, r := t.routeTree.MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, ok := t.routeTree.Load().Lookup(ip)
if ok {
return r
}
-
- return 0
+ return netip.Addr{}
}
// Get the LinkAddr for the interface of the given name
@@ -340,6 +370,99 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
return nil, nil
}
+func (t *tun) addRoutes(logErrors bool) error {
+ routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ if err != nil {
+ return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+ }
+
+ defer func() {
+ unix.Shutdown(routeSock, unix.SHUT_RDWR)
+ err := unix.Close(routeSock)
+ if err != nil {
+ t.l.WithError(err).Error("failed to close AF_ROUTE socket")
+ }
+ }()
+
+ routeAddr := &netroute.Inet4Addr{}
+ maskAddr := &netroute.Inet4Addr{}
+ routes := *t.Routes.Load()
+ for _, r := range routes {
+ if !r.Via.IsValid() || !r.Install {
+ // We don't allow route MTUs so only install routes with a via
+ continue
+ }
+
+ if !r.Cidr.Addr().Is4() {
+ //TODO: implement ipv6
+ panic("Cant handle ipv6 routes yet")
+ }
+
+ routeAddr.IP = r.Cidr.Addr().As4()
+ //TODO: we could avoid the copy
+ copy(maskAddr.IP[:], prefixToMask(r.Cidr))
+
+ err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+ if err != nil {
+ if errors.Is(err, unix.EEXIST) {
+ t.l.WithField("route", r.Cidr).
+ Warnf("unable to add unsafe_route, identical route already exists")
+ } else {
+ retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+ if logErrors {
+ retErr.Log(t.l)
+ } else {
+ return retErr
+ }
+ }
+ } else {
+ t.l.WithField("route", r).Info("Added route")
+ }
+ }
+
+ return nil
+}
+
+func (t *tun) removeRoutes(routes []Route) error {
+ routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ if err != nil {
+ return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+ }
+
+ defer func() {
+ unix.Shutdown(routeSock, unix.SHUT_RDWR)
+ err := unix.Close(routeSock)
+ if err != nil {
+ t.l.WithError(err).Error("failed to close AF_ROUTE socket")
+ }
+ }()
+
+ routeAddr := &netroute.Inet4Addr{}
+ maskAddr := &netroute.Inet4Addr{}
+
+ for _, r := range routes {
+ if !r.Install {
+ continue
+ }
+
+ if r.Cidr.Addr().Is6() {
+ //TODO: implement ipv6
+ panic("Cant handle ipv6 routes yet")
+ }
+
+ routeAddr.IP = r.Cidr.Addr().As4()
+ copy(maskAddr.IP[:], prefixToMask(r.Cidr))
+
+ err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+ if err != nil {
+ t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+ } else {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
+ return nil
+}
+
func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
r := netroute.RouteMessage{
Version: unix.RTM_VERSION,
@@ -365,6 +488,30 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
return nil
}
+func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
+ r := netroute.RouteMessage{
+ Version: unix.RTM_VERSION,
+ Type: unix.RTM_DELETE,
+ Seq: 1,
+ Addrs: []netroute.Addr{
+ unix.RTAX_DST: addr,
+ unix.RTAX_GATEWAY: link,
+ unix.RTAX_NETMASK: mask,
+ },
+ }
+
+ data, err := r.Marshal()
+ if err != nil {
+ return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+ }
+ _, err = unix.Write(sock, data[:])
+ if err != nil {
+ return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+ }
+
+ return nil
+}
+
func (t *tun) Read(to []byte) (int, error) {
buf := make([]byte, len(to)+4)
@@ -404,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) {
return n - 4, err
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -415,3 +562,11 @@ func (t *tun) Name() string {
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}
+
+func prefixToMask(prefix netip.Prefix) []byte {
+ pLen := 128
+ if prefix.Addr().Is4() {
+ pLen = 32
+ }
+ return net.CIDRMask(prefix.Bits(), pLen)
+}
diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go
index e1e4ede..130f8f9 100644
--- a/overlay/tun_disabled.go
+++ b/overlay/tun_disabled.go
@@ -3,7 +3,7 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"strings"
"github.com/rcrowley/go-metrics"
@@ -13,7 +13,7 @@ import (
type disabledTun struct {
read chan []byte
- cidr *net.IPNet
+ cidr netip.Prefix
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
@@ -21,7 +21,7 @@ type disabledTun struct {
l *logrus.Logger
}
-func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{
cidr: cidr,
read: make(chan []byte, queueLen),
@@ -43,11 +43,11 @@ func (*disabledTun) Activate() error {
return nil
}
-func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
- return 0
+func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
+ return netip.Addr{}
}
-func (t *disabledTun) Cidr() *net.IPNet {
+func (t *disabledTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go
index 338b8f6..bdfeb58 100644
--- a/overlay/tun_freebsd.go
+++ b/overlay/tun_freebsd.go
@@ -9,16 +9,18 @@ import (
"fmt"
"io"
"io/fs"
- "net"
+ "net/netip"
"os"
"os/exec"
"strconv"
+ "sync/atomic"
"syscall"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
)
const (
@@ -45,10 +47,10 @@ type ifreqDestroy struct {
type tun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
- Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -76,14 +78,15 @@ func (t *tun) Close() error {
return nil
}
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var file *os.File
var err error
+ deviceName := c.GetString("tun.dev", "")
if deviceName != "" {
file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
}
@@ -144,59 +147,97 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr)))
}
- routeTree, err := makeRouteTree(l, routes, false)
- if err != nil {
- return nil, err
- }
-
- return &tun{
+ t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
- MTU: defaultMTU,
- Routes: routes,
- routeTree: routeTree,
+ MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
- }, nil
+ }
+
+ err = t.reload(c, true)
+ if err != nil {
+ return nil, err
+ }
+
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
}
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
- t.l.Debug("command: ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
- if err = exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()).Run(); err != nil {
+ cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+ t.l.Debug("command: ", cmd.String())
+ if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
- t.l.Debug("command: route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
- if err = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device).Run(); err != nil {
+
+ cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
+ t.l.Debug("command: ", cmd.String())
+ if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
- t.l.Debug("command: ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
- if err = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)).Run(); err != nil {
+
+ cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
+ t.l.Debug("command: ", cmd.String())
+ if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
+
// Unsafe path routes
- for _, r := range t.Routes {
- if r.Via == nil || !r.Install {
- // We don't allow route MTUs so only install routes with a via
- continue
+ return t.addRoutes(false)
+}
+
+func (t *tun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
+
+ if !initial && !change {
+ return nil
+ }
+
+ routeTree, err := makeRouteTree(t.l, routes, false)
+ if err != nil {
+ return err
+ }
+
+ // Teach nebula how to handle the routes before establishing them in the system table
+ oldRoutes := t.Routes.Swap(&routes)
+ t.routeTree.Store(routeTree)
+
+ if !initial {
+ // Remove first, if the system removes a wanted route hopefully it will be re-added next
+ err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+ if err != nil {
+ util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
- t.l.Debug("command: route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
- if err = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device).Run(); err != nil {
- return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
+ // Ensure any routes we actually want are installed
+ err = t.addRoutes(true)
+ if err != nil {
+ // Catch any stray logs
+ util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -208,6 +249,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}
+func (t *tun) addRoutes(logErrors bool) error {
+ routes := *t.Routes.Load()
+ for _, r := range routes {
+ if !r.Via.IsValid() || !r.Install {
+ // We don't allow route MTUs so only install routes with a via
+ continue
+ }
+
+ cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
+ t.l.Debug("command: ", cmd.String())
+ if err := cmd.Run(); err != nil {
+ retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
+ if logErrors {
+ retErr.Log(t.l)
+ } else {
+ return retErr
+ }
+ }
+ }
+
+ return nil
+}
+
+func (t *tun) removeRoutes(routes []Route) error {
+ for _, r := range routes {
+ if !r.Install {
+ continue
+ }
+
+ cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
+ t.l.Debug("command: ", cmd.String())
+ if err := cmd.Run(); err != nil {
+ t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+ } else {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
+ return nil
+}
+
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go
index ce65b33..20981f0 100644
--- a/overlay/tun_ios.go
+++ b/overlay/tun_ios.go
@@ -7,46 +7,80 @@ import (
"errors"
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"sync"
+ "sync/atomic"
"syscall"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
)
type tun struct {
io.ReadWriteCloser
- cidr *net.IPNet
- routeTree *cidr.Tree4[iputil.VpnIp]
+ cidr netip.Prefix
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
+ l *logrus.Logger
}
-func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool) (*tun, error) {
- routeTree, err := makeRouteTree(l, routes, false)
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+ file := os.NewFile(uintptr(deviceFd), "/dev/tun")
+ t := &tun{
+ cidr: cidr,
+ ReadWriteCloser: &tunReadCloser{f: file},
+ l: l,
+ }
+
+ err := t.reload(c, true)
if err != nil {
return nil, err
}
- file := os.NewFile(uintptr(deviceFd), "/dev/tun")
- return &tun{
- cidr: cidr,
- ReadWriteCloser: &tunReadCloser{f: file},
- routeTree: routeTree,
- }, nil
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
}
func (t *tun) Activate() error {
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *tun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
+
+ if !initial && !change {
+ return nil
+ }
+
+ routeTree, err := makeRouteTree(t.l, routes, false)
+ if err != nil {
+ return err
+ }
+
+ // Teach nebula how to handle the routes
+ t.Routes.Store(&routes)
+ t.routeTree.Store(routeTree)
+ return nil
+}
+
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -108,7 +142,7 @@ func (tr *tunReadCloser) Close() error {
return tr.f.Close()
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index a576bf3..0e7e20d 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -4,33 +4,36 @@
package overlay
import (
- "bytes"
"fmt"
"io"
"net"
+ "net/netip"
"os"
"strings"
"sync/atomic"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
type tun struct {
io.ReadWriteCloser
- fd int
- Device string
- cidr *net.IPNet
- MaxMTU int
- DefaultMTU int
- TXQueueLen int
-
- Routes []Route
- routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+ fd int
+ Device string
+ cidr netip.Prefix
+ MaxMTU int
+ DefaultMTU int
+ TXQueueLen int
+ deviceIndex int
+ ioctlFd uintptr
+
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeChan chan struct{}
useSystemRoutes bool
@@ -61,33 +64,40 @@ type ifreqQLEN struct {
pad [8]byte
}
-func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, useSystemRoutes bool) (*tun, error) {
- routeTree, err := makeRouteTree(l, routes, true)
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+ file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
+
+ t, err := newTunGeneric(c, l, file, cidr)
if err != nil {
return nil, err
}
- file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
+ t.Device = "tun0"
- t := &tun{
- ReadWriteCloser: file,
- fd: int(file.Fd()),
- Device: "tun0",
- cidr: cidr,
- DefaultMTU: defaultMTU,
- TXQueueLen: txQueueLen,
- Routes: routes,
- useSystemRoutes: useSystemRoutes,
- l: l,
- }
- t.routeTree.Store(routeTree)
return t, nil
}
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, txQueueLen int, multiqueue bool, useSystemRoutes bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
- return nil, err
+ // If /dev/net/tun doesn't exist, try to create it (will happen in docker)
+ if os.IsNotExist(err) {
+ err = os.MkdirAll("/dev/net", 0755)
+ if err != nil {
+ return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
+ }
+ err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
+ }
+
+ fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
+ if err != nil {
+ return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
+ }
+ } else {
+ return nil, err
+ }
}
var req ifReq
@@ -95,46 +105,113 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
if multiqueue {
req.Flags |= unix.IFF_MULTI_QUEUE
}
- copy(req.Name[:], deviceName)
+ copy(req.Name[:], c.GetString("tun.dev", ""))
if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
return nil, err
}
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
-
- maxMTU := defaultMTU
- for _, r := range routes {
- if r.MTU == 0 {
- r.MTU = defaultMTU
- }
-
- if r.MTU > maxMTU {
- maxMTU = r.MTU
- }
- }
-
- routeTree, err := makeRouteTree(l, routes, true)
+ t, err := newTunGeneric(c, l, file, cidr)
if err != nil {
return nil, err
}
+ t.Device = name
+
+ return t, nil
+}
+
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
- Device: name,
cidr: cidr,
- MaxMTU: maxMTU,
- DefaultMTU: defaultMTU,
- TXQueueLen: txQueueLen,
- Routes: routes,
- useSystemRoutes: useSystemRoutes,
+ TXQueueLen: c.GetInt("tun.tx_queue", 500),
+ useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
l: l,
}
- t.routeTree.Store(routeTree)
+
+ err := t.reload(c, true)
+ if err != nil {
+ return nil, err
+ }
+
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
return t, nil
}
+func (t *tun) reload(c *config.C, initial bool) error {
+ routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
+
+ if !initial && !routeChange && !c.HasChanged("tun.mtu") {
+ return nil
+ }
+
+ routeTree, err := makeRouteTree(t.l, routes, true)
+ if err != nil {
+ return err
+ }
+
+ oldDefaultMTU := t.DefaultMTU
+ oldMaxMTU := t.MaxMTU
+ newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
+ newMaxMTU := newDefaultMTU
+ for i, r := range routes {
+ if r.MTU == 0 {
+ routes[i].MTU = newDefaultMTU
+ }
+
+ if r.MTU > t.MaxMTU {
+ newMaxMTU = r.MTU
+ }
+ }
+
+ t.MaxMTU = newMaxMTU
+ t.DefaultMTU = newDefaultMTU
+
+ // Teach nebula how to handle the routes before establishing them in the system table
+ oldRoutes := t.Routes.Swap(&routes)
+ t.routeTree.Store(routeTree)
+
+ if !initial {
+ if oldMaxMTU != newMaxMTU {
+ t.setMTU()
+ t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
+ }
+
+ if oldDefaultMTU != newDefaultMTU {
+ err := t.setDefaultRoute()
+ if err != nil {
+ t.l.Warn(err)
+ } else {
+ t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+ }
+ }
+
+ // Remove first, if the system removes a wanted route hopefully it will be re-added next
+ t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+
+ // Ensure any routes we actually want are installed
+ err = t.addRoutes(true)
+ if err != nil {
+ // This should never be called since addRoutes should log its own errors in a reload condition
+ util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
+ }
+ }
+
+ return nil
+}
+
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
@@ -153,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return file, nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@@ -197,8 +274,10 @@ func (t *tun) Activate() error {
var addr, mask [4]byte
- copy(addr[:], t.cidr.IP.To4())
- copy(mask[:], t.cidr.Mask)
+ //TODO: IPV6-WORK
+ addr = t.cidr.Addr().As4()
+ tmask := net.CIDRMask(t.cidr.Bits(), 32)
+ copy(mask[:], tmask)
s, err := unix.Socket(
unix.AF_INET,
@@ -208,7 +287,7 @@ func (t *tun) Activate() error {
if err != nil {
return err
}
- fd := uintptr(s)
+ t.ioctlFd = uintptr(s)
ifra := ifreqAddr{
Name: devName,
@@ -219,75 +298,114 @@ func (t *tun) Activate() error {
}
// Set the device ip address
- if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
+ if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun address: %s", err)
}
// Set the device network
ifra.Addr.Addr = mask
- if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
+ if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set tun netmask: %s", err)
}
// Set the device name
ifrf := ifReq{Name: devName}
- if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+ if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to set tun device name: %s", err)
}
- // Set the MTU on the device
- ifm := ifreqMTU{Name: devName, MTU: int32(t.MaxMTU)}
- if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
- // This is currently a non fatal condition because the route table must have the MTU set appropriately as well
- t.l.WithError(err).Error("Failed to set tun mtu")
- }
+ // Setup our default MTU
+ t.setMTU()
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
- if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
+ if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
t.l.WithError(err).Error("Failed to set tun tx queue length")
}
// Bring up the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP
- if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+ if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
return fmt.Errorf("failed to bring the tun device up: %s", err)
}
- // Set the routes
link, err := netlink.LinkByName(t.Device)
if err != nil {
return fmt.Errorf("failed to get tun device link: %s", err)
}
+ t.deviceIndex = link.Attrs().Index
+ if err = t.setDefaultRoute(); err != nil {
+ return err
+ }
+
+ // Set the routes
+ if err = t.addRoutes(false); err != nil {
+ return err
+ }
+
+ // Run the interface
+ ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
+ if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+ return fmt.Errorf("failed to run tun device: %s", err)
+ }
+
+ return nil
+}
+
+func (t *tun) setMTU() {
+ // Set the MTU on the device
+ ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
+ if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
+ // This is currently a non fatal condition because the route table must have the MTU set appropriately as well
+ t.l.WithError(err).Error("Failed to set tun mtu")
+ }
+}
+
+func (t *tun) setDefaultRoute() error {
// Default route
- dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
+
+ dr := &net.IPNet{
+ IP: t.cidr.Masked().Addr().AsSlice(),
+ Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+ }
+
nr := netlink.Route{
- LinkIndex: link.Attrs().Index,
+ LinkIndex: t.deviceIndex,
Dst: dr,
MTU: t.DefaultMTU,
AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
- Src: t.cidr.IP,
+ Src: net.IP(t.cidr.Addr().AsSlice()),
Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
- err = netlink.RouteReplace(&nr)
+ err := netlink.RouteReplace(&nr)
if err != nil {
return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
}
+ return nil
+}
+
+func (t *tun) addRoutes(logErrors bool) error {
// Path routes
- for _, r := range t.Routes {
+ routes := *t.Routes.Load()
+ for _, r := range routes {
if !r.Install {
continue
}
+ dr := &net.IPNet{
+ IP: r.Cidr.Masked().Addr().AsSlice(),
+ Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+ }
+
nr := netlink.Route{
- LinkIndex: link.Attrs().Index,
- Dst: r.Cidr,
+ LinkIndex: t.deviceIndex,
+ Dst: dr,
MTU: r.MTU,
AdvMSS: t.advMSS(r),
Scope: unix.RT_SCOPE_LINK,
@@ -297,22 +415,55 @@ func (t *tun) Activate() error {
nr.Priority = r.Metric
}
- err = netlink.RouteAdd(&nr)
+ err := netlink.RouteReplace(&nr)
if err != nil {
- return fmt.Errorf("failed to set mtu %v on route %v; %v", r.MTU, r.Cidr, err)
+ retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+ if logErrors {
+ retErr.Log(t.l)
+ } else {
+ return retErr
+ }
+ } else {
+ t.l.WithField("route", r).Info("Added route")
}
}
- // Run the interface
- ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
- if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
- return fmt.Errorf("failed to run tun device: %s", err)
- }
-
return nil
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) removeRoutes(routes []Route) {
+ for _, r := range routes {
+ if !r.Install {
+ continue
+ }
+
+ dr := &net.IPNet{
+ IP: r.Cidr.Masked().Addr().AsSlice(),
+ Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+ }
+
+ nr := netlink.Route{
+ LinkIndex: t.deviceIndex,
+ Dst: dr,
+ MTU: r.MTU,
+ AdvMSS: t.advMSS(r),
+ Scope: unix.RT_SCOPE_LINK,
+ }
+
+ if r.Metric > 0 {
+ nr.Priority = r.Metric
+ }
+
+ err := netlink.RouteDel(&nr)
+ if err != nil {
+ t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+ } else {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
+}
+
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -364,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
- if !t.cidr.Contains(r.Gw) {
+ //TODO: IPV6-WORK what if not ok?
+ gwAddr, ok := netip.AddrFromSlice(r.Gw)
+ if !ok {
+ t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
+ return
+ }
+
+ gwAddr = gwAddr.Unmap()
+ if !t.cidr.Contains(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
return
@@ -376,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
return
}
- newTree := cidr.NewTree4[iputil.VpnIp]()
- if r.Type == unix.RTM_NEWROUTE {
- for _, oldR := range t.routeTree.Load().List() {
- newTree.AddCIDR(oldR.CIDR, oldR.Value)
- }
+ dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
+ if !ok {
+ t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
+ return
+ }
+ ones, _ := r.Dst.Mask.Size()
+ dst := netip.PrefixFrom(dstAddr, ones)
+
+ newTree := t.routeTree.Load().Clone()
+
+ if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
- newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
+ newTree.Insert(dst, gwAddr)
} else {
- gw := iputil.Ip2VpnIp(r.Gw)
- for _, oldR := range t.routeTree.Load().List() {
- if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
- // This is the record to delete
- t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
- continue
- }
-
- newTree.AddCIDR(oldR.CIDR, oldR.Value)
- }
+ newTree.Delete(dst)
+ t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
}
-
t.routeTree.Store(newTree)
}
@@ -410,5 +566,9 @@ func (t *tun) Close() error {
t.ReadWriteCloser.Close()
}
+ if t.ioctlFd > 0 {
+ os.NewFile(t.ioctlFd, "ioctlFd").Close()
+ }
+
return nil
}
diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go
index b1135fe..24ab24f 100644
--- a/overlay/tun_netbsd.go
+++ b/overlay/tun_netbsd.go
@@ -6,17 +6,19 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"os/exec"
"regexp"
"strconv"
+ "sync/atomic"
"syscall"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
)
type ifreqDestroy struct {
@@ -26,10 +28,10 @@ type ifreqDestroy struct {
type tun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
- Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -56,56 +58,63 @@ func (t *tun) Close() error {
return nil
}
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var file *os.File
var err error
+ deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
if !deviceNameRE.MatchString(deviceName) {
return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
}
- file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+ file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
if err != nil {
return nil, err
}
- routeTree, err := makeRouteTree(l, routes, false)
+ t := &tun{
+ ReadWriteCloser: file,
+ Device: deviceName,
+ cidr: cidr,
+ MTU: c.GetInt("tun.mtu", DefaultMTU),
+ l: l,
+ }
+ err = t.reload(c, true)
if err != nil {
return nil, err
}
- return &tun{
- ReadWriteCloser: file,
- Device: deviceName,
- cidr: cidr,
- MTU: defaultMTU,
- Routes: routes,
- routeTree: routeTree,
- l: l,
- }, nil
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
}
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
- cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
- cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String())
+ cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
@@ -116,29 +125,54 @@ func (t *tun) Activate() error {
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
+
// Unsafe path routes
- for _, r := range t.Routes {
- if r.Via == nil || !r.Install {
- // We don't allow route MTUs so only install routes with a via
- continue
+ return t.addRoutes(false)
+}
+
+func (t *tun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
+
+ if !initial && !change {
+ return nil
+ }
+
+ routeTree, err := makeRouteTree(t.l, routes, false)
+ if err != nil {
+ return err
+ }
+
+ // Teach nebula how to handle the routes before establishing them in the system table
+ oldRoutes := t.Routes.Swap(&routes)
+ t.routeTree.Store(routeTree)
+
+ if !initial {
+ // Remove first, if the system removes a wanted route hopefully it will be re-added next
+ err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+ if err != nil {
+ util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
}
- cmd = exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
- t.l.Debug("command: ", cmd.String())
- if err = cmd.Run(); err != nil {
- return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
+ // Ensure any routes we actually want are installed
+ err = t.addRoutes(true)
+ if err != nil {
+ // Catch any stray logs
+ util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
}
}
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
@@ -150,6 +184,46 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}
+func (t *tun) addRoutes(logErrors bool) error {
+ routes := *t.Routes.Load()
+ for _, r := range routes {
+ if !r.Via.IsValid() || !r.Install {
+ // We don't allow route MTUs so only install routes with a via
+ continue
+ }
+
+ cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
+ t.l.Debug("command: ", cmd.String())
+ if err := cmd.Run(); err != nil {
+ retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
+ if logErrors {
+ retErr.Log(t.l)
+ } else {
+ return retErr
+ }
+ }
+ }
+
+ return nil
+}
+
+func (t *tun) removeRoutes(routes []Route) error {
+ for _, r := range routes {
+ if !r.Install {
+ continue
+ }
+
+ cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
+ t.l.Debug("command: ", cmd.String())
+ if err := cmd.Run(); err != nil {
+ t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+ } else {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
+ return nil
+}
+
func (t *tun) deviceBytes() (o [16]byte) {
for i, c := range t.Device {
o[i] = byte(c)
diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go
index 45c06dc..6463ccb 100644
--- a/overlay/tun_openbsd.go
+++ b/overlay/tun_openbsd.go
@@ -6,24 +6,26 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"os/exec"
"regexp"
"strconv"
+ "sync/atomic"
"syscall"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
)
type tun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
- Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
l *logrus.Logger
io.ReadWriteCloser
@@ -40,13 +42,14 @@ func (t *tun) Close() error {
return nil
}
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+ deviceName := c.GetString("tun.dev", "")
if deviceName == "" {
return nil, fmt.Errorf("a device name in the format of tunN must be specified")
}
@@ -60,26 +63,70 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
return nil, err
}
- routeTree, err := makeRouteTree(l, routes, false)
- if err != nil {
- return nil, err
- }
-
- return &tun{
+ t := &tun{
ReadWriteCloser: file,
Device: deviceName,
cidr: cidr,
- MTU: defaultMTU,
- Routes: routes,
- routeTree: routeTree,
+ MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
- }, nil
+ }
+
+ err = t.reload(c, true)
+ if err != nil {
+ return nil, err
+ }
+
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
+}
+
+func (t *tun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
+
+ if !initial && !change {
+ return nil
+ }
+
+ routeTree, err := makeRouteTree(t.l, routes, false)
+ if err != nil {
+ return err
+ }
+
+ // Teach nebula how to handle the routes before establishing them in the system table
+ oldRoutes := t.Routes.Swap(&routes)
+ t.routeTree.Store(routeTree)
+
+ if !initial {
+ // Remove first, if the system removes a wanted route hopefully it will be re-added next
+ err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+ if err != nil {
+ util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
+ }
+
+ // Ensure any routes we actually want are installed
+ err = t.addRoutes(true)
+ if err != nil {
+ // Catch any stray logs
+ util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
+ }
+ }
+
+ return nil
}
func (t *tun) Activate() error {
var err error
// TODO use syscalls instead of exec.Command
- cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -91,35 +138,62 @@ func (t *tun) Activate() error {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
- cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String())
+ cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
if err = cmd.Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
// Unsafe path routes
- for _, r := range t.Routes {
- if r.Via == nil || !r.Install {
+ return t.addRoutes(false)
+}
+
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
+ return r
+}
+
+func (t *tun) addRoutes(logErrors bool) error {
+ routes := *t.Routes.Load()
+ for _, r := range routes {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
- cmd = exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
+ cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
t.l.Debug("command: ", cmd.String())
- if err = cmd.Run(); err != nil {
- return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.Cidr.String(), err)
+ if err := cmd.Run(); err != nil {
+ retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
+ if logErrors {
+ retErr.Log(t.l)
+ } else {
+ return retErr
+ }
}
}
return nil
}
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
- return r
+func (t *tun) removeRoutes(routes []Route) error {
+ for _, r := range routes {
+ if !r.Install {
+ continue
+ }
+
+ cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+ t.l.Debug("command: ", cmd.String())
+ if err := cmd.Run(); err != nil {
+ t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+ } else {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
+ return nil
}
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go
index 964315a..ba15723 100644
--- a/overlay/tun_tester.go
+++ b/overlay/tun_tester.go
@@ -6,20 +6,20 @@ package overlay
import (
"fmt"
"io"
- "net"
+ "net/netip"
"os"
"sync/atomic"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
)
type TestTun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
+ routeTree *bart.Table[netip.Addr]
l *logrus.Logger
closed atomic.Bool
@@ -27,14 +27,18 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula
}
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes []Route, _ int, _ bool, _ bool) (*TestTun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
+ _, routes, err := getAllRoutesFromConfig(c, cidr, true)
+ if err != nil {
+ return nil, err
+ }
routeTree, err := makeRouteTree(l, routes, false)
if err != nil {
return nil, err
}
return &TestTun{
- Device: deviceName,
+ Device: c.GetString("tun.dev", ""),
cidr: cidr,
Routes: routes,
routeTree: routeTree,
@@ -44,7 +48,7 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, _ int, routes
}, nil
}
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
@@ -82,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
-func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Lookup(ip)
return r
}
@@ -91,7 +95,7 @@ func (t *TestTun) Activate() error {
return nil
}
-func (t *TestTun) Cidr() *net.IPNet {
+func (t *TestTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go
index e27cff2..d78f564 100644
--- a/overlay/tun_water_windows.go
+++ b/overlay/tun_water_windows.go
@@ -4,38 +4,50 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"os/exec"
"strconv"
+ "sync/atomic"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
"github.com/songgao/water"
)
type waterTun struct {
Device string
- cidr *net.IPNet
+ cidr netip.Prefix
MTU int
- Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
-
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
+ l *logrus.Logger
+ f *net.Interface
*water.Interface
}
-func newWaterTun(l *logrus.Logger, cidr *net.IPNet, defaultMTU int, routes []Route) (*waterTun, error) {
- routeTree, err := makeRouteTree(l, routes, false)
+func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
+ // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
+ t := &waterTun{
+ cidr: cidr,
+ MTU: c.GetInt("tun.mtu", DefaultMTU),
+ l: l,
+ }
+
+ err := t.reload(c, true)
if err != nil {
return nil, err
}
- // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
- return &waterTun{
- cidr: cidr,
- MTU: defaultMTU,
- Routes: routes,
- routeTree: routeTree,
- }, nil
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
}
func (t *waterTun) Activate() error {
@@ -58,8 +70,8 @@ func (t *waterTun) Activate() error {
`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
fmt.Sprintf("name=%s", t.Device),
"source=static",
- fmt.Sprintf("addr=%s", t.cidr.IP),
- fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)),
+ fmt.Sprintf("addr=%s", t.cidr.Addr()),
+ fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
"gateway=none",
).Run()
if err != nil {
@@ -74,34 +86,108 @@ func (t *waterTun) Activate() error {
return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
}
- iface, err := net.InterfaceByName(t.Device)
+ t.f, err = net.InterfaceByName(t.Device)
if err != nil {
return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
}
- for _, r := range t.Routes {
- if r.Via == nil || !r.Install {
+ err = t.addRoutes(false)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (t *waterTun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+ if err != nil {
+ return err
+ }
+
+ if !initial && !change {
+ return nil
+ }
+
+ routeTree, err := makeRouteTree(t.l, routes, false)
+ if err != nil {
+ return err
+ }
+
+ // Teach nebula how to handle the routes before establishing them in the system table
+ oldRoutes := t.Routes.Swap(&routes)
+ t.routeTree.Store(routeTree)
+
+ if !initial {
+ // Remove first, if the system removes a wanted route hopefully it will be re-added next
+ t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+
+ // Ensure any routes we actually want are installed
+ err = t.addRoutes(true)
+ if err != nil {
+ // Catch any stray logs
+ util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
+ } else {
+ for _, r := range findRemovedRoutes(routes, *oldRoutes) {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
+ }
+
+ return nil
+}
+
+func (t *waterTun) addRoutes(logErrors bool) error {
+ // Path routes
+ routes := *t.Routes.Load()
+ for _, r := range routes {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
- err = exec.Command(
- "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(iface.Index), "METRIC", strconv.Itoa(r.Metric),
+ err := exec.Command(
+ "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
).Run()
+
if err != nil {
- return fmt.Errorf("failed to add the unsafe_route %s: %v", r.Cidr.String(), err)
+ retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+ if logErrors {
+ retErr.Log(t.l)
+ } else {
+ return retErr
+ }
+ } else {
+ t.l.WithField("route", r).Info("Added route")
}
}
return nil
}
-func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *waterTun) removeRoutes(routes []Route) {
+ for _, r := range routes {
+ if !r.Install {
+ continue
+ }
+
+ err := exec.Command(
+ "C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
+ ).Run()
+ if err != nil {
+ t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+ } else {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
+}
+
+func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *waterTun) Cidr() *net.IPNet {
+func (t *waterTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go
index 57d90cb..3d88309 100644
--- a/overlay/tun_windows.go
+++ b/overlay/tun_windows.go
@@ -5,20 +5,21 @@ package overlay
import (
"fmt"
- "net"
+ "net/netip"
"os"
"path/filepath"
"runtime"
"syscall"
"github.com/sirupsen/logrus"
+ "github.com/slackhq/nebula/config"
)
-func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int, _ bool) (Device, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
-func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route, _ int, _ bool, _ bool) (Device, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
useWintun := true
if err := checkWinTunExists(); err != nil {
l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
@@ -26,14 +27,14 @@ func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int
}
if useWintun {
- device, err := newWinTun(l, deviceName, cidr, defaultMTU, routes)
+ device, err := newWinTun(c, l, cidr, multiqueue)
if err != nil {
return nil, fmt.Errorf("create Wintun interface failed, %w", err)
}
return device, nil
}
- device, err := newWaterTun(l, cidr, defaultMTU, routes)
+ device, err := newWaterTun(c, l, cidr, multiqueue)
if err != nil {
return nil, fmt.Errorf("create wintap driver failed, %w", err)
}
diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go
index 9647024..d010387 100644
--- a/overlay/tun_wintun_windows.go
+++ b/overlay/tun_wintun_windows.go
@@ -4,13 +4,14 @@ import (
"crypto"
"fmt"
"io"
- "net"
"net/netip"
+ "sync/atomic"
"unsafe"
+ "github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/cidr"
- "github.com/slackhq/nebula/iputil"
+ "github.com/slackhq/nebula/config"
+ "github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -20,11 +21,11 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
- cidr *net.IPNet
- prefix netip.Prefix
+ cidr netip.Prefix
MTU int
- Routes []Route
- routeTree *cidr.Tree4[iputil.VpnIp]
+ Routes atomic.Pointer[[]Route]
+ routeTree atomic.Pointer[bart.Table[netip.Addr]]
+ l *logrus.Logger
tun *wintun.NativeTun
}
@@ -48,83 +49,131 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
}
-func newWinTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []Route) (*winTun, error) {
+func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
+ deviceName := c.GetString("tun.dev", "")
guid, err := generateGUIDByDeviceName(deviceName)
if err != nil {
return nil, fmt.Errorf("generate GUID failed: %w", err)
}
+ t := &winTun{
+ Device: deviceName,
+ cidr: cidr,
+ MTU: c.GetInt("tun.mtu", DefaultMTU),
+ l: l,
+ }
+
+ err = t.reload(c, true)
+ if err != nil {
+ return nil, err
+ }
+
var tunDevice wintun.Device
- tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
+ tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
// Trying a second time resolves the issue.
l.WithError(err).Debug("Failed to create wintun device, retrying")
- tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, defaultMTU)
+ tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
if err != nil {
return nil, fmt.Errorf("create TUN device failed: %w", err)
}
}
+ t.tun = tunDevice.(*wintun.NativeTun)
+
+ c.RegisterReloadCallback(func(c *config.C) {
+ err := t.reload(c, false)
+ if err != nil {
+ util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+ }
+ })
+
+ return t, nil
+}
- routeTree, err := makeRouteTree(l, routes, false)
+func (t *winTun) reload(c *config.C, initial bool) error {
+ change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
if err != nil {
- return nil, err
+ return err
+ }
+
+ if !initial && !change {
+ return nil
}
- prefix, err := iputil.ToNetIpPrefix(*cidr)
+ routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
- return nil, err
+ return err
}
- return &winTun{
- Device: deviceName,
- cidr: cidr,
- prefix: prefix,
- MTU: defaultMTU,
- Routes: routes,
- routeTree: routeTree,
+ // Teach nebula how to handle the routes before establishing them in the system table
+ oldRoutes := t.Routes.Swap(&routes)
+ t.routeTree.Store(routeTree)
+
+ if !initial {
+ // Remove first, if the system removes a wanted route hopefully it will be re-added next
+ err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+ if err != nil {
+ util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
+ }
- tun: tunDevice.(*wintun.NativeTun),
- }, nil
+ // Ensure any routes we actually want are installed
+ err = t.addRoutes(true)
+ if err != nil {
+ // Catch any stray logs
+ util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
+ }
+ }
+
+ return nil
}
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
- if err := luid.SetIPAddresses([]netip.Prefix{t.prefix}); err != nil {
+ err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
+ if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
+ err = t.addRoutes(false)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (t *winTun) addRoutes(logErrors bool) error {
+ luid := winipcfg.LUID(t.tun.LUID())
+ routes := *t.Routes.Load()
foundDefault4 := false
- routes := make([]*winipcfg.RouteData, 0, len(t.Routes)+1)
- for _, r := range t.Routes {
- if r.Via == nil || !r.Install {
+ for _, r := range routes {
+ if !r.Via.IsValid() || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
- if !foundDefault4 {
- if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
- foundDefault4 = true
+ // Add our unsafe route
+ err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
+ if err != nil {
+ retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+ if logErrors {
+ retErr.Log(t.l)
+ continue
+ } else {
+ return retErr
}
+ } else {
+ t.l.WithField("route", r).Info("Added route")
}
- prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
- if err != nil {
- return err
+ if !foundDefault4 {
+ if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
+ foundDefault4 = true
+ }
}
-
- // Add our unsafe route
- routes = append(routes, &winipcfg.RouteData{
- Destination: prefix,
- NextHop: r.Via.ToNetIpAddr(),
- Metric: uint32(r.Metric),
- })
- }
-
- if err := luid.AddRoutes(routes); err != nil {
- return fmt.Errorf("failed to add routes: %w", err)
}
ipif, err := luid.IPInterface(windows.AF_INET)
@@ -141,16 +190,33 @@ func (t *winTun) Activate() error {
if err := ipif.Set(); err != nil {
return fmt.Errorf("failed to set ip interface: %w", err)
}
+ return nil
+}
+
+func (t *winTun) removeRoutes(routes []Route) error {
+ luid := winipcfg.LUID(t.tun.LUID())
+ for _, r := range routes {
+ if !r.Install {
+ continue
+ }
+
+ err := luid.DeleteRoute(r.Cidr, r.Via)
+ if err != nil {
+ t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+ } else {
+ t.l.WithField("route", r).Info("Removed route")
+ }
+ }
return nil
}
-func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
- _, r := t.routeTree.MostSpecificContains(ip)
+func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+ r, _ := t.routeTree.Load().Lookup(ip)
return r
}
-func (t *winTun) Cidr() *net.IPNet {
+func (t *winTun) Cidr() netip.Prefix {
return t.cidr
}
diff --git a/overlay/user.go b/overlay/user.go
index 9d819ae..1bb4ef5 100644
--- a/overlay/user.go
+++ b/overlay/user.go
@@ -2,18 +2,17 @@ package overlay
import (
"io"
- "net"
+ "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
- "github.com/slackhq/nebula/iputil"
)
-func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
return NewUserDevice(tunCidr)
}
-func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
+func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
// these pipes guarantee each write/read will match 1:1
or, ow := io.Pipe()
ir, iw := io.Pipe()
@@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
}
type UserDevice struct {
- tunCidr *net.IPNet
+ tunCidr netip.Prefix
outboundReader *io.PipeReader
outboundWriter *io.PipeWriter
@@ -39,9 +38,9 @@ type UserDevice struct {
func (d *UserDevice) Activate() error {
return nil
}
-func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr }
-func (d *UserDevice) Name() string { return "faketun0" }
-func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
+func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr }
+func (d *UserDevice) Name() string { return "faketun0" }
+func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}
diff --git a/pki.go b/pki.go
index 91478ce..ab95a04 100644
--- a/pki.go
+++ b/pki.go
@@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
}
if !initial {
+ //TODO: include check for mask equality as well
+
// did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate
oldIPs := currentCert.Details.Ips
diff --git a/relay_manager.go b/relay_manager.go
index 7aa06cc..1a3a4d4 100644
--- a/relay_manager.go
+++ b/relay_manager.go
@@ -2,14 +2,15 @@ package nebula
import (
"context"
+ "encoding/binary"
"errors"
"fmt"
+ "net/netip"
"sync/atomic"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
)
type relayManager struct {
@@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) {
// AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
// relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
-func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) {
+func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
hm.Lock()
defer hm.Unlock()
for i := 0; i < 32; i++ {
@@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter
func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(m.RelayFromIp),
- "relayTo": iputil.VpnIp(m.RelayToIp),
+ "relayFrom": m.RelayFromIp,
+ "relayTo": m.RelayToIp,
"initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
Info("handleCreateRelayResponse")
- target := iputil.VpnIp(m.RelayToIp)
+ target := m.RelayToIp
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+ targetAddr := netip.AddrFrom4(b)
relay, err := rm.EstablishRelay(h, m)
if err != nil {
@@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
return
}
- peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
+ peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
if !ok {
rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
return
}
if peerRelay.State == PeerRequested {
+ //TODO: IPV6-WORK
+ b = peerHostInfo.vpnIp.As4()
peerRelay.State = Established
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: peerRelay.LocalIndex,
InitiatorRelayIndex: peerRelay.RemoteIndex,
- RelayFromIp: uint32(peerHostInfo.vpnIp),
+ RelayFromIp: binary.BigEndian.Uint32(b[:]),
RelayToIp: uint32(target),
}
msg, err := resp.Marshal()
@@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
} else {
f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(resp.RelayFromIp),
- "relayTo": iputil.VpnIp(resp.RelayToIp),
+ "relayFrom": resp.RelayFromIp,
+ "relayTo": resp.RelayToIp,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": peerHostInfo.vpnIp}).
@@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
}
func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
+ //TODO: IPV6-WORK
+ b := [4]byte{}
+ binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
+ from := netip.AddrFrom4(b)
- from := iputil.VpnIp(m.RelayFromIp)
- target := iputil.VpnIp(m.RelayToIp)
+ binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+ target := netip.AddrFrom4(b)
logMsg := rm.l.WithFields(logrus.Fields{
"relayFrom": from,
@@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
logMsg.Info("handleCreateRelayRequest")
// Is the source of the relay me? This should never happen, but did happen due to
// an issue migrating relays over to newly re-handshaked host info objects.
- if from == f.myVpnIp {
- logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
+ if from == f.myVpnNet.Addr() {
+ logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
return
}
// Is the target of the relay me?
- if target == f.myVpnIp {
+ if target == f.myVpnNet.Addr() {
existingRelay, ok := h.relayState.QueryRelayForByIp(from)
if ok {
switch existingRelay.State {
@@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
return
}
+ //TODO: IPV6-WORK
+ fromB := from.As4()
+ targetB := target.As4()
+
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
- RelayFromIp: uint32(from),
- RelayToIp: uint32(target),
+ RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
msg, err := resp.Marshal()
if err != nil {
@@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(resp.RelayFromIp),
- "relayTo": iputil.VpnIp(resp.RelayToIp),
+ //TODO: IPV6-WORK: log from/target directly for now; restore logging from the resp fields once they carry full addresses
+ "relayFrom": from,
+ "relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
@@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
f.Handshake(target)
return
}
- if peer.remote == nil {
+ if !peer.remote.IsValid() {
// Only create relays to peers for whom I have a direct connection
return
}
@@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
sendCreateRequest = true
}
if sendCreateRequest {
+ //TODO: IPV6-WORK
+ fromB := h.vpnIp.As4()
+ targetB := target.As4()
+
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
- RelayFromIp: uint32(h.vpnIp),
- RelayToIp: uint32(target),
+ RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
msg, err := req.Marshal()
if err != nil {
@@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else {
f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(req.RelayFromIp),
- "relayTo": iputil.VpnIp(req.RelayToIp),
+ //TODO: IPV6-WORK: log vpnIp/target directly for now; restore logging from the req fields once they carry full addresses
+ "relayFrom": h.vpnIp,
+ "relayTo": target,
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": target}).
@@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
return
}
+ //TODO: IPV6-WORK
+ fromB := h.vpnIp.As4()
+ targetB := target.As4()
resp := NebulaControl{
Type: NebulaControl_CreateRelayResponse,
ResponderRelayIndex: relay.LocalIndex,
InitiatorRelayIndex: relay.RemoteIndex,
- RelayFromIp: uint32(h.vpnIp),
- RelayToIp: uint32(target),
+ RelayFromIp: binary.BigEndian.Uint32(fromB[:]),
+ RelayToIp: binary.BigEndian.Uint32(targetB[:]),
}
msg, err := resp.Marshal()
if err != nil {
@@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
} else {
f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
rm.l.WithFields(logrus.Fields{
- "relayFrom": iputil.VpnIp(resp.RelayFromIp),
- "relayTo": iputil.VpnIp(resp.RelayToIp),
+ //TODO: IPV6-WORK: log vpnIp/target directly for now; restore logging from the resp fields once they carry full addresses
+ "relayFrom": h.vpnIp,
+ "relayTo": target,
"initiatorRelayIndex": resp.InitiatorRelayIndex,
"responderRelayIndex": resp.ResponderRelayIndex,
"vpnIp": h.vpnIp}).
@@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
}
}
}
-
-func (rm *relayManager) RemoveRelay(localIdx uint32) {
- rm.hostmap.RemoveRelay(localIdx)
-}
diff --git a/remote_list.go b/remote_list.go
index 60a1afd..fa14f42 100644
--- a/remote_list.go
+++ b/remote_list.go
@@ -1,7 +1,6 @@
package nebula
import (
- "bytes"
"context"
"net"
"net/netip"
@@ -12,16 +11,14 @@ import (
"time"
"github.com/sirupsen/logrus"
- "github.com/slackhq/nebula/iputil"
- "github.com/slackhq/nebula/udp"
)
// forEachFunc is used to benefit folks that want to do work inside the lock
-type forEachFunc func(addr *udp.Addr, preferred bool)
+type forEachFunc func(addr netip.AddrPort, preferred bool)
// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
+type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
// CacheMap is a struct that better represents the lighthouse cache for humans
// The string key is the owners vpnIp
@@ -30,9 +27,9 @@ type CacheMap map[string]*Cache
// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
// We don't reason about ipv4 vs ipv6 here
type Cache struct {
- Learned []*udp.Addr `json:"learned,omitempty"`
- Reported []*udp.Addr `json:"reported,omitempty"`
- Relay []*net.IP `json:"relay"`
+ Learned []netip.AddrPort `json:"learned,omitempty"`
+ Reported []netip.AddrPort `json:"reported,omitempty"`
+ Relay []netip.Addr `json:"relay"`
}
//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@@ -46,7 +43,7 @@ type cache struct {
}
type cacheRelay struct {
- relay []uint32
+ relay []netip.Addr
}
// cacheV4 stores learned and reported ipv4 records under cache
@@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
continue
}
for _, a := range addrs {
- netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
+ netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{}
}
}
origSet := r.ips.Load()
@@ -193,22 +190,22 @@ type RemoteList struct {
sync.RWMutex
// A deduplicated set of addresses. Any accessor should lock beforehand.
- addrs []*udp.Addr
+ addrs []netip.AddrPort
// A set of relay addresses. VpnIp addresses that the remote identified as relays.
- relays []*iputil.VpnIp
+ relays []netip.Addr
// These are maps to store v4 and v6 addresses per lighthouse
// Map key is the vpnIp of the person that told us about this the cached entries underneath.
// For learned addresses, this is the vpnIp that sent the packet
- cache map[iputil.VpnIp]*cache
+ cache map[netip.Addr]*cache
hr *hostnamesResults
shouldAdd func(netip.Addr) bool
// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
// They should not be tried again during a handshake
- badRemotes []*udp.Addr
+ badRemotes []netip.AddrPort
// A flag that the cache may have changed and addrs needs to be rebuilt
shouldRebuild bool
@@ -217,9 +214,9 @@ type RemoteList struct {
// NewRemoteList creates a new empty RemoteList
func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
return &RemoteList{
- addrs: make([]*udp.Addr, 0),
- relays: make([]*iputil.VpnIp, 0),
- cache: make(map[iputil.VpnIp]*cache),
+ addrs: make([]netip.AddrPort, 0),
+ relays: make([]netip.Addr, 0),
+ cache: make(map[netip.Addr]*cache),
shouldAdd: shouldAdd,
}
}
@@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
// Len locks and reports the size of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
+func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
r.Rebuild(preferredRanges)
r.RLock()
defer r.RUnlock()
@@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
// ForEach locks and will call the forEachFunc for every deduplicated address in the list
// The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
+func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) {
r.Rebuild(preferredRanges)
r.RLock()
for _, v := range r.addrs {
- forEach(v, isPreferred(v.IP, preferredRanges))
+ forEach(v, isPreferred(v.Addr(), preferredRanges))
}
r.RUnlock()
}
// CopyAddrs locks and makes a deep copy of the deduplicated address list
// The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
+func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort {
if r == nil {
return nil
}
@@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
r.RLock()
defer r.RUnlock()
- c := make([]*udp.Addr, len(r.addrs))
+ c := make([]netip.AddrPort, len(r.addrs))
for i, v := range r.addrs {
- c[i] = v.Copy()
+ c[i] = v
}
return c
}
@@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
// TODO: this needs to support the allow list list
-func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
+func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
r.Lock()
defer r.Unlock()
- if v4 := addr.IP.To4(); v4 != nil {
- r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
+ if remote.Addr().Is4() {
+ r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
} else {
- r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
+ r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
}
}
@@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap {
c := cm[vpnIp]
if c == nil {
c = &Cache{
- Learned: make([]*udp.Addr, 0),
- Reported: make([]*udp.Addr, 0),
- Relay: make([]*net.IP, 0),
+ Learned: make([]netip.AddrPort, 0),
+ Reported: make([]netip.AddrPort, 0),
+ Relay: make([]netip.Addr, 0),
}
cm[vpnIp] = c
}
@@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap {
if mc.v4 != nil {
if mc.v4.learned != nil {
- c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
+ c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
}
for _, a := range mc.v4.reported {
- c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
+ c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
}
}
if mc.v6 != nil {
if mc.v6.learned != nil {
- c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
+ c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
}
for _, a := range mc.v6.reported {
- c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
+ c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
}
}
if mc.relay != nil {
for _, a := range mc.relay.relay {
- nip := iputil.VpnIp(a).ToIP()
- c.Relay = append(c.Relay, &nip)
+ c.Relay = append(c.Relay, a)
}
}
}
@@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad *udp.Addr) {
- if bad == nil {
+func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
+ if !bad.IsValid() {
// relays can have nil udp Addrs
return
}
@@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) {
}
// We copy here because we are taking something else's memory and we can't trust everything
- r.badRemotes = append(r.badRemotes, bad.Copy())
+ r.badRemotes = append(r.badRemotes, bad)
// Mark the next interaction must recollect/dedupe
r.shouldRebuild = true
}
// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
-func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
+func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
r.RLock()
defer r.RUnlock()
- c := make([]*udp.Addr, len(r.badRemotes))
+ c := make([]netip.AddrPort, len(r.badRemotes))
for i, v := range r.badRemotes {
- c[i] = v.Copy()
+ c[i] = v
}
return c
}
@@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() {
// Rebuild locks and generates the deduplicated address list only if there is work to be done
// There is generally no reason to call this directly but it is safe to do so
-func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
+func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
r.Lock()
defer r.Unlock()
@@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
}
// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
-func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
+func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
for _, v := range r.badRemotes {
- if v.Equals(remote) {
+ if v == remote {
return true
}
}
@@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
}
// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
@@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
}
}
-func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) {
+func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeRelay(ownerVpnIp)
@@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI
// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV4(ownerVpnIp)
@@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort)
// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
// deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
r.shouldRebuild = true
r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
}
// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
// and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
@@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
// This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
r.shouldRebuild = true
c := r.unlockedGetOrMakeV6(ownerVpnIp)
@@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort)
}
}
-func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay {
+func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
@@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay
// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
// The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
@@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
// The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 {
am := r.cache[ownerVpnIp]
if am == nil {
am = &cache{}
@@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() {
for _, c := range r.cache {
if c.v4 != nil {
if c.v4.learned != nil {
- u := NewUDPAddrFromLH4(c.v4.learned)
+ u := AddrPortFromIp4AndPort(c.v4.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v4.reported {
- u := NewUDPAddrFromLH4(v)
+ u := AddrPortFromIp4AndPort(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
@@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() {
if c.v6 != nil {
if c.v6.learned != nil {
- u := NewUDPAddrFromLH6(c.v6.learned)
+ u := AddrPortFromIp6AndPort(c.v6.learned)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
}
for _, v := range c.v6.reported {
- u := NewUDPAddrFromLH6(v)
+ u := AddrPortFromIp6AndPort(v)
if !r.unlockedIsBad(u) {
addrs = append(addrs, u)
}
@@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() {
if c.relay != nil {
for _, v := range c.relay.relay {
- ip := iputil.VpnIp(v)
- relays = append(relays, &ip)
+ relays = append(relays, v)
}
}
}
@@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() {
dnsAddrs := r.hr.GetIPs()
for _, addr := range dnsAddrs {
if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
- v6 := addr.Addr().As16()
- addrs = append(addrs, &udp.Addr{
- IP: v6[:],
- Port: addr.Port(),
- })
+ addrs = append(addrs, addr)
}
}
@@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() {
}
// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
-func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
+func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
n := len(r.addrs)
if n < 2 {
return
@@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
b := r.addrs[j]
// Preferred addresses first
- aPref := isPreferred(a.IP, preferredRanges)
- bPref := isPreferred(b.IP, preferredRanges)
+ aPref := isPreferred(a.Addr(), preferredRanges)
+ bPref := isPreferred(b.Addr(), preferredRanges)
switch {
case aPref && !bPref:
// If i is preferred and j is not, i is less than j
@@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
}
// ipv6 addresses 2nd
- a4 := a.IP.To4()
- b4 := b.IP.To4()
+ a4 := a.Addr().Is4()
+ b4 := b.Addr().Is4()
switch {
- case a4 == nil && b4 != nil:
+ case a4 == false && b4 == true:
// If i is v6 and j is v4, i is less than j
return true
- case a4 != nil && b4 == nil:
+ case a4 == true && b4 == false:
// If j is v6 and i is v4, i is not less than j
return false
- case a4 != nil && b4 != nil:
- // Special case for ipv4, a4 and b4 are not nil
- aPrivate := isPrivateIP(a4)
- bPrivate := isPrivateIP(b4)
+ case a4 == true && b4 == true:
+ // i and j are both ipv4
+ aPrivate := a.Addr().IsPrivate()
+ bPrivate := b.Addr().IsPrivate()
switch {
case !aPrivate && bPrivate:
// If i is a public ip (not private) and j is a private ip, i is less then j
@@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
}
// lexical order of ips 3rd
- c := bytes.Compare(a.IP, b.IP)
+ c := a.Addr().Compare(b.Addr())
if c == 0 {
// Ips are the same, Lexical order of ports 4th
- return a.Port < b.Port
+ return a.Port() < b.Port()
}
// Ip wasn't the same
@@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
// Deduplicate
a, b := 0, 1
for b < n {
- if !r.addrs[a].Equals(r.addrs[b]) {
+ if r.addrs[a] != r.addrs[b] {
a++
if a != b {
r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
@@ -693,7 +684,7 @@ func minInt(a, b int) int {
}
// isPreferred returns true of the ip is contained in the preferredRanges list
-func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
+func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
//TODO: this would be better in a CIDR6Tree
for _, p := range preferredRanges {
if p.Contains(ip) {
@@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
}
return false
}
-
-var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
-var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
-var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
-
-// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
-func isPrivateIP(ip net.IP) bool {
- //TODO: another great cidrtree option
- //TODO: Private for ipv6 or just let it ride?
- return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
-}
diff --git a/remote_list_test.go b/remote_list_test.go
index 49aa171..62a892b 100644
--- a/remote_list_test.go
+++ b/remote_list_test.go
@@ -1,47 +1,47 @@
package nebula
import (
- "net"
+ "encoding/binary"
+ "net/netip"
"testing"
- "github.com/slackhq/nebula/iputil"
"github.com/stretchr/testify/assert"
)
func TestRemoteList_Rebuild(t *testing.T) {
rl := NewRemoteList(nil)
rl.unlockedSetV4(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
+ newIp4AndPortFromString("172.17.0.182:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
+ newIp4AndPortFromString("172.18.0.1:10101"), // this is duped
+ newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe
+ newIp4AndPortFromString("172.19.0.1:10101"),
+ newIp4AndPortFromString("172.31.0.1:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
+ newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
- 1,
- 1,
+ netip.MustParseAddr("0.0.0.1"),
+ netip.MustParseAddr("0.0.0.1"),
[]*Ip6AndPort{
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
- NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
- NewIp6AndPort(net.ParseIP("1:100::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
- NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
+ newIp6AndPortFromString("[1::1]:1"), // this is duped
+ newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
+ newIp6AndPortFromString("[1:100::1]:1"),
+ newIp6AndPortFromString("[1::1]:1"), // this is a dupe
+ newIp6AndPortFromString("[1::1]:2"), // this is a dupe
},
- func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+ func(netip.Addr, *Ip6AndPort) bool { return true },
)
- rl.Rebuild([]*net.IPNet{})
+ rl.Rebuild([]netip.Prefix{})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv6 first, sorted lexically within
@@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
// Now ensure we can hoist ipv4 up
- _, ipNet, err := net.ParseCIDR("0.0.0.0/0")
- assert.NoError(t, err)
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// ipv4 first, public then private, lexically within them
@@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
// Ensure we can hoist a specific ipv4 range over anything else
- _, ipNet, err = net.ParseCIDR("172.17.0.0/16")
- assert.NoError(t, err)
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
// Preferred ipv4 first
@@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) {
func BenchmarkFullRebuild(b *testing.B) {
rl := NewRemoteList(nil)
rl.unlockedSetV4(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+ newIp4AndPortFromString("70.199.182.92:1475"),
+ newIp4AndPortFromString("172.17.0.182:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"),
+ newIp4AndPortFromString("172.18.0.1:10101"),
+ newIp4AndPortFromString("172.19.0.1:10101"),
+ newIp4AndPortFromString("172.31.0.1:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{
- NewIp6AndPort(net.ParseIP("1::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
- NewIp6AndPort(net.ParseIP("1:100::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+ newIp6AndPortFromString("[1::1]:1"),
+ newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+ newIp6AndPortFromString("[1:100::1]:1"),
+ newIp6AndPortFromString("[1::1]:1"), // this is a dupe
},
- func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+ func(netip.Addr, *Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{})
+ rl.Rebuild([]netip.Prefix{})
}
})
- _, ipNet, err := net.ParseCIDR("172.17.0.0/16")
- assert.NoError(b, err)
+ ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{ipNet1})
}
})
- _, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
- assert.NoError(b, err)
+ ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+ rl.Rebuild([]netip.Prefix{ipNet2})
}
})
- _, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
- assert.NoError(b, err)
+ ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
}
})
}
@@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) {
func BenchmarkSortRebuild(b *testing.B) {
rl := NewRemoteList(nil)
rl.unlockedSetV4(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip4AndPort{
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe
- {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+ newIp4AndPortFromString("70.199.182.92:1475"),
+ newIp4AndPortFromString("172.17.0.182:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"),
+ newIp4AndPortFromString("172.18.0.1:10101"),
+ newIp4AndPortFromString("172.19.0.1:10101"),
+ newIp4AndPortFromString("172.31.0.1:10101"),
+ newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe
+ newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
},
- func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+ func(netip.Addr, *Ip4AndPort) bool { return true },
)
rl.unlockedSetV6(
- 0,
- 0,
+ netip.MustParseAddr("0.0.0.0"),
+ netip.MustParseAddr("0.0.0.0"),
[]*Ip6AndPort{
- NewIp6AndPort(net.ParseIP("1::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
- NewIp6AndPort(net.ParseIP("1:100::1"), 1),
- NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+ newIp6AndPortFromString("[1::1]:1"),
+ newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+ newIp6AndPortFromString("[1:100::1]:1"),
+ newIp6AndPortFromString("[1::1]:1"), // this is a dupe
},
- func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+ func(netip.Addr, *Ip6AndPort) bool { return true },
)
b.Run("no preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
rl.shouldRebuild = true
- rl.Rebuild([]*net.IPNet{})
+ rl.Rebuild([]netip.Prefix{})
}
})
- _, ipNet, err := net.ParseCIDR("172.17.0.0/16")
- rl.Rebuild([]*net.IPNet{ipNet})
+ ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
+ rl.Rebuild([]netip.Prefix{ipNet1})
- assert.NoError(b, err)
b.Run("1 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- rl.Rebuild([]*net.IPNet{ipNet})
+ rl.Rebuild([]netip.Prefix{ipNet1})
}
})
- _, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+ ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
- assert.NoError(b, err)
b.Run("2 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
}
})
- _, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+ ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
- assert.NoError(b, err)
b.Run("3 preferred", func(b *testing.B) {
for i := 0; i < b.N; i++ {
- rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+ rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
}
})
}
+
+func newIp4AndPortFromString(s string) *Ip4AndPort {
+ a := netip.MustParseAddrPort(s)
+ v4Addr := a.Addr().As4()
+ return &Ip4AndPort{
+ Ip: binary.BigEndian.Uint32(v4Addr[:]),
+ Port: uint32(a.Port()),
+ }
+}
+
+func newIp6AndPortFromString(s string) *Ip6AndPort {
+ a := netip.MustParseAddrPort(s)
+ v6Addr := a.Addr().As16()
+ return &Ip6AndPort{
+ Hi: binary.BigEndian.Uint64(v6Addr[:8]),
+ Lo: binary.BigEndian.Uint64(v6Addr[8:]),
+ Port: uint32(a.Port()),
+ }
+}
diff --git a/service/service.go b/service/service.go
index 66ce864..50c1d4a 100644
--- a/service/service.go
+++ b/service/service.go
@@ -17,7 +17,7 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
"golang.org/x/sync/errgroup"
- "gvisor.dev/gvisor/pkg/bufferv2"
+ "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -81,7 +81,7 @@ func New(config *config.C) (*Service, error) {
if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
}
- ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
+ ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4)))
s.ipstack.SetRouteTable([]tcpip.Route{
{
Destination: ipv4Subnet,
@@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) {
ipNet := device.Cidr()
pa := tcpip.ProtocolAddress{
- AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(),
+ AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
Protocol: ipv4.ProtocolNumber,
}
if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
@@ -124,7 +124,7 @@ func New(config *config.C) (*Service, error) {
return err
}
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])),
+ Payload: buffer.MakeWithData(bytes.Clone(buf[:n])),
})
linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
@@ -136,7 +136,7 @@ func New(config *config.C) (*Service, error) {
eg.Go(func() error {
for {
packet := linkEP.ReadContext(ctx)
- if packet.IsNil() {
+ if packet == nil {
if err := ctx.Err(); err != nil {
return err
}
@@ -166,7 +166,7 @@ func (s *Service) DialContext(ctx context.Context, network, address string) (net
fullAddr := tcpip.FullAddress{
NIC: nicID,
- Addr: tcpip.Address(addr.IP),
+ Addr: tcpip.AddrFromSlice(addr.IP),
Port: uint16(addr.Port),
}
diff --git a/service/service_test.go b/service/service_test.go
index d1909cd..3176209 100644
--- a/service/service_test.go
+++ b/service/service_test.go
@@ -4,7 +4,7 @@ import (
"bytes"
"context"
"errors"
- "net"
+ "net/netip"
"testing"
"time"
@@ -18,12 +18,8 @@ import (
type m map[string]interface{}
-func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
-
- vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
- copy(vpnIpNet.IP, udpIp)
-
- _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
+ _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
caB, err := caCrt.MarshalToPEM()
if err != nil {
panic(err)
@@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string,
}
func TestService(t *testing.T) {
- ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
- a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
+ ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+ a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
"static_host_map": m{},
"lighthouse": m{
"am_lighthouse": true,
@@ -94,7 +90,7 @@ func TestService(t *testing.T) {
"port": 4243,
},
})
- b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
+ b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
"static_host_map": m{
"10.0.0.1": []string{"localhost:4243"},
},
diff --git a/ssh.go b/ssh.go
index e99205c..2ff0954 100644
--- a/ssh.go
+++ b/ssh.go
@@ -7,6 +7,7 @@ import (
"flag"
"fmt"
"net"
+ "net/netip"
"os"
"reflect"
"runtime"
@@ -18,9 +19,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
- "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/sshd"
- "github.com/slackhq/nebula/udp"
)
type sshListHostMapFlags struct {
@@ -51,6 +50,11 @@ type sshCreateTunnelFlags struct {
Address string
}
+type sshDeviceInfoFlags struct {
+ Json bool
+ Pretty bool
+}
+
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
c.RegisterReloadCallback(func(c *config.C) {
if c.GetBool("sshd.enabled", false) {
@@ -110,6 +114,19 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
return nil, fmt.Errorf("error while adding sshd.host_key: %s", err)
}
+ // Clear existing trusted CAs and authorized keys
+ ssh.ClearTrustedCAs()
+ ssh.ClearAuthorizedKeys()
+
+ rawCAs := c.GetStringSlice("sshd.trusted_cas", []string{})
+ for _, caAuthorizedKey := range rawCAs {
+ err := ssh.AddTrustedCA(caAuthorizedKey)
+ if err != nil {
+ l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
+ continue
+ }
+ }
+
rawKeys := c.Get("sshd.authorized_users")
keys, ok := rawKeys.([]interface{})
if ok {
@@ -231,7 +248,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "start-cpu-profile",
- ShortDescription: "Starts a cpu profile and write output to the provided file",
+ ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`",
Callback: sshStartCpuProfile,
})
@@ -246,7 +263,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "save-heap-profile",
- ShortDescription: "Saves a heap profile to the provided path",
+ ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`",
Callback: sshGetHeapProfile,
})
@@ -258,7 +275,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
ssh.RegisterCommand(&sshd.Command{
Name: "save-mutex-profile",
- ShortDescription: "Saves a mutex profile to the provided path",
+ ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`",
Callback: sshGetMutexProfile,
})
@@ -286,6 +303,21 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
},
})
+ ssh.RegisterCommand(&sshd.Command{
+ Name: "device-info",
+ ShortDescription: "Prints information about the network device.",
+ Flags: func() (*flag.FlagSet, interface{}) {
+ fl := flag.NewFlagSet("", flag.ContinueOnError)
+ s := sshDeviceInfoFlags{}
+ fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
+ fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
+ return fl, &s
+ },
+ Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+ return sshDeviceInfo(f, fs, w)
+ },
+ })
+
ssh.RegisterCommand(&sshd.Command{
Name: "print-cert",
ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip",
@@ -398,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
}
sort.Slice(hm, func(i, j int) bool {
- return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
+ return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
})
if fs.Json || fs.Pretty {
@@ -512,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -541,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -583,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -603,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
}
- var addr *udp.Addr
+ var addr netip.AddrPort
if flags.Address != "" {
- addr = udp.NewAddrFromString(flags.Address)
- if addr == nil {
+ addr, err = netip.ParseAddrPort(flags.Address)
+ if err != nil {
return w.WriteLine("Address could not be parsed")
}
}
hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
- if addr != nil {
+ if addr.IsValid() {
hostInfo.SetRemote(addr)
}
@@ -634,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
return w.WriteLine("No address was provided")
}
- addr := udp.NewAddrFromString(flags.Address)
- if addr == nil {
+ addr, err := netip.ParseAddrPort(flags.Address)
+ if err != nil {
return w.WriteLine("Address could not be parsed")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -759,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
cert := ifce.pki.GetCertState().Certificate
if len(a) > 0 {
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -829,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
Error error
Type string
State string
- PeerIp iputil.VpnIp
+ PeerIp netip.Addr
LocalIndex uint32
RemoteIndex uint32
- RelayedThrough []iputil.VpnIp
+ RelayedThrough []netip.Addr
}
type RelayOutput struct {
- NebulaIp iputil.VpnIp
+ NebulaIp netip.Addr
RelayForIps []RelayFor
}
@@ -919,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
return w.WriteLine("No vpn ip was provided")
}
- parsedIp := net.ParseIP(a[0])
- if parsedIp == nil {
+ vpnIp, err := netip.ParseAddr(a[0])
+ if err != nil {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
- vpnIp := iputil.Ip2VpnIp(parsedIp)
- if vpnIp == 0 {
+ if !vpnIp.IsValid() {
return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
}
@@ -939,7 +965,34 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ")
}
- return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
+ return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
+}
+
+func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
+
+ data := struct {
+ Name string `json:"name"`
+ Cidr string `json:"cidr"`
+ }{
+ Name: ifce.inside.Name(),
+ Cidr: ifce.inside.Cidr().String(),
+ }
+
+ flags, ok := fs.(*sshDeviceInfoFlags)
+ if !ok {
+ return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)
+ }
+
+ if flags.Json || flags.Pretty {
+ js := json.NewEncoder(w.GetWriter())
+ if flags.Pretty {
+ js.SetIndent("", " ")
+ }
+
+ return js.Encode(data)
+ } else {
+ return w.WriteLine(fmt.Sprintf("name=%v cidr=%v", data.Name, data.Cidr))
+ }
}
func sshReload(c *config.C, w sshd.StringWriter) error {
diff --git a/sshd/server.go b/sshd/server.go
index 4a78fdf..9e8c721 100644
--- a/sshd/server.go
+++ b/sshd/server.go
@@ -1,6 +1,7 @@
package sshd
import (
+ "bytes"
"errors"
"fmt"
"net"
@@ -15,8 +16,11 @@ type SSHServer struct {
config *ssh.ServerConfig
l *logrus.Entry
+ certChecker *ssh.CertChecker
+
// Map of user -> authorized keys
trustedKeys map[string]map[string]bool
+ trustedCAs []ssh.PublicKey
// List of available commands
helpCommand *Command
@@ -31,6 +35,7 @@ type SSHServer struct {
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
+
s := &SSHServer{
trustedKeys: make(map[string]map[string]bool),
l: l,
@@ -38,8 +43,43 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
conns: make(map[int]*session),
}
+ cc := ssh.CertChecker{
+ IsUserAuthority: func(auth ssh.PublicKey) bool {
+ for _, ca := range s.trustedCAs {
+ if bytes.Equal(ca.Marshal(), auth.Marshal()) {
+ return true
+ }
+ }
+
+ return false
+ },
+ UserKeyFallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
+ pk := string(pubKey.Marshal())
+ fp := ssh.FingerprintSHA256(pubKey)
+
+ tk, ok := s.trustedKeys[c.User()]
+ if !ok {
+ return nil, fmt.Errorf("unknown user %s", c.User())
+ }
+
+ _, ok = tk[pk]
+ if !ok {
+ return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp)
+ }
+
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "fp": fp,
+ "user": c.User(),
+ },
+ }, nil
+
+ },
+ }
+
s.config = &ssh.ServerConfig{
- PublicKeyCallback: s.matchPubKey,
+ PublicKeyCallback: cc.Authenticate,
//TODO: AuthLogCallback: s.authAttempt,
//TODO: version string
ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
@@ -66,10 +106,26 @@ func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error {
return nil
}
+func (s *SSHServer) ClearTrustedCAs() {
+ s.trustedCAs = []ssh.PublicKey{}
+}
+
func (s *SSHServer) ClearAuthorizedKeys() {
s.trustedKeys = make(map[string]map[string]bool)
}
+// AddTrustedCA adds a trusted CA for user certificates
+func (s *SSHServer) AddTrustedCA(pubKey string) error {
+ pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
+ if err != nil {
+ return err
+ }
+
+ s.trustedCAs = append(s.trustedCAs, pk)
+ s.l.WithField("sshKey", pubKey).Info("Trusted CA key")
+ return nil
+}
+
// AddAuthorizedKey adds an ssh public key for a user
func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error {
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey))
@@ -178,26 +234,3 @@ func (s *SSHServer) closeSessions() {
}
s.connsLock.Unlock()
}
-
-func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
- pk := string(pubKey.Marshal())
- fp := ssh.FingerprintSHA256(pubKey)
-
- tk, ok := s.trustedKeys[c.User()]
- if !ok {
- return nil, fmt.Errorf("unknown user %s", c.User())
- }
-
- _, ok = tk[pk]
- if !ok {
- return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp)
- }
-
- return &ssh.Permissions{
- // Record the public key used for authentication.
- Extensions: map[string]string{
- "fp": fp,
- "user": c.User(),
- },
- }, nil
-}
diff --git a/test/tun.go b/test/tun.go
index 86656c9..fbf5829 100644
--- a/test/tun.go
+++ b/test/tun.go
@@ -3,23 +3,21 @@ package test
import (
"errors"
"io"
- "net"
-
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
type NoopTun struct{}
-func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
- return 0
+func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
+ return netip.Addr{}
}
func (NoopTun) Activate() error {
return nil
}
-func (NoopTun) Cidr() *net.IPNet {
- return nil
+func (NoopTun) Cidr() netip.Prefix {
+ return netip.Prefix{}
}
func (NoopTun) Name() string {
diff --git a/timeout_test.go b/timeout_test.go
index 3f81ff4..4c6364e 100644
--- a/timeout_test.go
+++ b/timeout_test.go
@@ -1,6 +1,7 @@
package nebula
import (
+ "net/netip"
"testing"
"time"
@@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
assert.Equal(t, 0, tw.current)
fps := []firewall.Packet{
- {LocalIP: 1},
- {LocalIP: 2},
- {LocalIP: 3},
- {LocalIP: 4},
+ {LocalIP: netip.MustParseAddr("0.0.0.1")},
+ {LocalIP: netip.MustParseAddr("0.0.0.2")},
+ {LocalIP: netip.MustParseAddr("0.0.0.3")},
+ {LocalIP: netip.MustParseAddr("0.0.0.4")},
}
tw.Add(fps[0], time.Second*1)
diff --git a/udp/conn.go b/udp/conn.go
index a2c24a1..fa4e443 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -1,6 +1,8 @@
package udp
import (
+ "net/netip"
+
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
@@ -9,7 +11,7 @@ import (
const MTU = 9001
type EncReader func(
- addr *Addr,
+ addr netip.AddrPort,
out []byte,
packet []byte,
header *header.H,
@@ -22,9 +24,9 @@ type EncReader func(
type Conn interface {
Rebind() error
- LocalAddr() (*Addr, error)
+ LocalAddr() (netip.AddrPort, error)
ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
- WriteTo(b []byte, addr *Addr) error
+ WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
Close() error
}
@@ -34,13 +36,13 @@ type NoopConn struct{}
func (NoopConn) Rebind() error {
return nil
}
-func (NoopConn) LocalAddr() (*Addr, error) {
- return nil, nil
+func (NoopConn) LocalAddr() (netip.AddrPort, error) {
+ return netip.AddrPort{}, nil
}
func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
return
}
-func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
+func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}
func (NoopConn) ReloadConfig(_ *config.C) {
diff --git a/udp/temp.go b/udp/temp.go
index 2efe31d..b281906 100644
--- a/udp/temp.go
+++ b/udp/temp.go
@@ -1,9 +1,10 @@
package udp
import (
- "github.com/slackhq/nebula/iputil"
+ "net/netip"
)
//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
-type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)
+// TODO: IPV6-WORK this can likely be removed now
+type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)
diff --git a/udp/udp_all.go b/udp/udp_all.go
deleted file mode 100644
index 093bf69..0000000
--- a/udp/udp_all.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package udp
-
-import (
- "encoding/json"
- "fmt"
- "net"
- "strconv"
-)
-
-type m map[string]interface{}
-
-type Addr struct {
- IP net.IP
- Port uint16
-}
-
-func NewAddr(ip net.IP, port uint16) *Addr {
- addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
- copy(addr.IP, ip.To16())
- return &addr
-}
-
-func NewAddrFromString(s string) *Addr {
- ip, port, err := ParseIPAndPort(s)
- //TODO: handle err
- _ = err
- return &Addr{IP: ip.To16(), Port: port}
-}
-
-func (ua *Addr) Equals(t *Addr) bool {
- if t == nil || ua == nil {
- return t == nil && ua == nil
- }
- return ua.IP.Equal(t.IP) && ua.Port == t.Port
-}
-
-func (ua *Addr) String() string {
- if ua == nil {
- return ""
- }
-
- return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
-}
-
-func (ua *Addr) MarshalJSON() ([]byte, error) {
- if ua == nil {
- return nil, nil
- }
-
- return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
-}
-
-func (ua *Addr) Copy() *Addr {
- if ua == nil {
- return nil
- }
-
- nu := Addr{
- Port: ua.Port,
- IP: make(net.IP, len(ua.IP)),
- }
-
- copy(nu.IP, ua.IP)
- return &nu
-}
-
-type AddrSlice []*Addr
-
-func (a AddrSlice) Equal(b AddrSlice) bool {
- if len(a) != len(b) {
- return false
- }
-
- for i := range a {
- if !a[i].Equals(b[i]) {
- return false
- }
- }
-
- return true
-}
-
-func ParseIPAndPort(s string) (net.IP, uint16, error) {
- rIp, sPort, err := net.SplitHostPort(s)
- if err != nil {
- return nil, 0, err
- }
-
- addr, err := net.ResolveIPAddr("ip", rIp)
- if err != nil {
- return nil, 0, err
- }
-
- iPort, err := strconv.Atoi(sPort)
- if err != nil {
- return nil, 0, err
- }
-
- return addr.IP, uint16(iPort), nil
-}
diff --git a/udp/udp_android.go b/udp/udp_android.go
index 8d69074..bb19195 100644
--- a/udp/udp_android.go
+++ b/udp/udp_android.go
@@ -6,13 +6,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go
index 785aa6a..65ef31a 100644
--- a/udp/udp_bsd.go
+++ b/udp/udp_bsd.go
@@ -9,13 +9,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go
index 08e1b6a..183ac7a 100644
--- a/udp/udp_darwin.go
+++ b/udp/udp_darwin.go
@@ -8,13 +8,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_generic.go b/udp/udp_generic.go
index 1dd6d1d..2d84536 100644
--- a/udp/udp_generic.go
+++ b/udp/udp_generic.go
@@ -11,6 +11,7 @@ import (
"context"
"fmt"
"net"
+ "net/netip"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
@@ -25,7 +26,7 @@ type GenericConn struct {
var _ Conn = &GenericConn{}
-func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
if err != nil {
@@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
-func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
- _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
+func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
+ _, err := u.UDPConn.WriteToUDPAddrPort(b, addr)
return err
}
-func (u *GenericConn) LocalAddr() (*Addr, error) {
+func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
a := u.UDPConn.LocalAddr()
switch v := a.(type) {
case *net.UDPAddr:
- addr := &Addr{IP: make([]byte, len(v.IP))}
- copy(addr.IP, v.IP)
- addr.Port = uint16(v.Port)
- return addr, nil
+ addr, ok := netip.AddrFromSlice(v.IP)
+ if !ok {
+ return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
+ }
+ return netip.AddrPortFrom(addr, uint16(v.Port)), nil
default:
- return nil, fmt.Errorf("LocalAddr returned: %#v", a)
+ return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
}
}
@@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
// Just read one packet at a time
- n, rua, err := u.ReadFromUDP(buffer)
+ n, rua, err := u.ReadFromUDPAddrPort(buffer)
if err != nil {
u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
return
}
- udpAddr.IP = rua.IP
- udpAddr.Port = uint16(rua.Port)
- r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(
+ netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
+ plaintext[:0],
+ buffer[:n],
+ h,
+ fwPacket,
+ lhf,
+ nb,
+ q,
+ cache.Get(u.l),
+ )
}
}
diff --git a/udp/udp_linux.go b/udp/udp_linux.go
index 1151c89..ef07243 100644
--- a/udp/udp_linux.go
+++ b/udp/udp_linux.go
@@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"net"
+ "net/netip"
"syscall"
"unsafe"
@@ -27,25 +28,6 @@ type StdConn struct {
batch int
}
-var x int
-
-// From linux/sock_diag.h
-const (
- _SK_MEMINFO_RMEM_ALLOC = iota
- _SK_MEMINFO_RCVBUF
- _SK_MEMINFO_WMEM_ALLOC
- _SK_MEMINFO_SNDBUF
- _SK_MEMINFO_FWD_ALLOC
- _SK_MEMINFO_WMEM_QUEUED
- _SK_MEMINFO_OPTMEM
- _SK_MEMINFO_BACKLOG
- _SK_MEMINFO_DROPS
-
- _SK_MEMINFO_VARS
-)
-
-type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
-
func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4()
if ip4 != nil {
@@ -54,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
return ip, false
}
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
- ipV4, isV4 := maybeIPV4(ip)
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
af := unix.AF_INET6
- if isV4 {
+ if ip.Is4() {
af = unix.AF_INET
}
syscall.ForkLock.RLock()
@@ -80,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//TODO: support multiple listening IPs (for limiting ipv6)
var sa unix.Sockaddr
- if isV4 {
+ if ip.Is4() {
sa4 := &unix.SockaddrInet4{Port: port}
- copy(sa4.Addr[:], ipV4)
+ sa4.Addr = ip.As4()
sa = sa4
} else {
sa6 := &unix.SockaddrInet6{Port: port}
- copy(sa6.Addr[:], ip.To16())
+ sa6.Addr = ip.As16()
sa = sa6
}
if err = unix.Bind(fd, sa); err != nil {
@@ -98,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
- return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err
+ return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
}
func (u *StdConn) Rebind() error {
@@ -121,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) {
return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}
-func (u *StdConn) LocalAddr() (*Addr, error) {
+func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
sa, err := unix.Getsockname(u.sysFd)
if err != nil {
- return nil, err
+ return netip.AddrPort{}, err
}
- addr := &Addr{}
switch sa := sa.(type) {
case *unix.SockaddrInet4:
- addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
- addr.Port = uint16(sa.Port)
+ return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil
+
case *unix.SockaddrInet6:
- addr.IP = sa.Addr[0:]
- addr.Port = uint16(sa.Port)
- }
+ return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil
- return addr, nil
+ default:
+ return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
+ }
}
func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- udpAddr := &Addr{}
+ var ip netip.Addr
nb := make([]byte, 12, 12)
//TODO: should we track this?
@@ -165,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
//metric.Update(int64(n))
for i := 0; i < n; i++ {
if u.isV4 {
- udpAddr.IP = names[i][4:8]
+ ip, _ = netip.AddrFromSlice(names[i][4:8])
+ //TODO: IPV6-WORK what is not ok?
} else {
- udpAddr.IP = names[i][8:24]
+ ip, _ = netip.AddrFromSlice(names[i][8:24])
+ //TODO: IPV6-WORK what is not ok?
}
- udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
- r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(
+ netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
+ plaintext[:0],
+ buffers[i][:msgs[i].Len],
+ h,
+ fwPacket,
+ lhf,
+ nb,
+ q,
+ cache.Get(u.l),
+ )
}
}
}
@@ -216,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
}
}
-func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
+func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
if u.isV4 {
- return u.writeTo4(b, addr)
+ return u.writeTo4(b, ip)
}
- return u.writeTo6(b, addr)
+ return u.writeTo6(b, ip)
}
-func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
+func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
+ rsa.Addr = ip.Addr().As16()
+ port := ip.Port()
// Little Endian -> Network Endian
- rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
- copy(rsa.Addr[:], addr.IP.To16())
+ rsa.Port = (port >> 8) | ((port & 0xff) << 8)
for {
_, _, err := unix.Syscall6(
@@ -251,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
}
}
-func (u *StdConn) writeTo4(b []byte, addr *Addr) error {
- addrV4, isAddrV4 := maybeIPV4(addr.IP)
- if !isAddrV4 {
+func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
+ if !ip.Addr().Is4() {
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}
var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET
+ rsa.Addr = ip.Addr().As4()
+ port := ip.Port()
// Little Endian -> Network Endian
- rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
- copy(rsa.Addr[:], addrV4)
+ rsa.Port = (port >> 8) | ((port & 0xff) << 8)
for {
_, _, err := unix.Syscall6(
@@ -316,8 +308,8 @@ func (u *StdConn) ReloadConfig(c *config.C) {
}
}
-func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error {
- var vallen uint32 = 4 * _SK_MEMINFO_VARS
+func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error {
+ var vallen uint32 = 4 * unix.SK_MEMINFO_VARS
_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
if err != 0 {
return err
@@ -332,12 +324,12 @@ func (u *StdConn) Close() error {
func NewUDPStatsEmitter(udpConns []Conn) func() {
// Check if our kernel supports SO_MEMINFO before registering the gauges
- var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
- var meminfo _SK_MEMINFO
+ var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge
+ var meminfo [unix.SK_MEMINFO_VARS]uint32
if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
- udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
+ udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
for i := range udpConns {
- udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{
+ udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
@@ -354,7 +346,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() {
return func() {
for i, gauges := range udpGauges {
if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
- for j := 0; j < _SK_MEMINFO_VARS; j++ {
+ for j := 0; j < unix.SK_MEMINFO_VARS; j++ {
gauges[j].Update(int64(meminfo[j]))
}
}
diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go
index a54f1df..87a0de7 100644
--- a/udp/udp_linux_64.go
+++ b/udp/udp_linux_64.go
@@ -1,6 +1,6 @@
-//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64) && !android && !e2e_testing
+//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing
// +build linux
-// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64
+// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64
// +build !android
// +build !e2e_testing
diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go
index 3c14fac..3b69159 100644
--- a/udp/udp_netbsd.go
+++ b/udp/udp_netbsd.go
@@ -8,13 +8,14 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
return NewGenericListener(l, ip, port, multi, batch)
}
diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go
index 31c1a55..ee7e1e0 100644
--- a/udp/udp_rio_windows.go
+++ b/udp/udp_rio_windows.go
@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"sync"
"sync/atomic"
"syscall"
@@ -61,16 +62,14 @@ type RIOConn struct {
results [packetsPerRing]winrio.Result
}
-func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
+func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
if !winrio.Initialize() {
return nil, errors.New("could not initialize winrio")
}
u := &RIOConn{l: l}
- addr := [16]byte{}
- copy(addr[:], ip.To16())
- err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
+ err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port})
if err != nil {
return nil, fmt.Errorf("bind: %w", err)
}
@@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
buffer := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- udpAddr := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
@@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
return
}
- udpAddr.IP = rua.Addr[:]
- p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
- p[0] = byte(rua.Port >> 8)
- p[1] = byte(rua.Port)
- r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(
+ netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
+ plaintext[:0],
+ buffer[:n],
+ h,
+ fwPacket,
+ lhf,
+ nb,
+ q,
+ cache.Get(u.l),
+ )
}
}
@@ -231,7 +235,7 @@ retry:
return n, ep, nil
}
-func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
+func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
if !u.isOpen.Load() {
return net.ErrClosed
}
@@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
packet := u.tx.Push()
packet.addr.Family = windows.AF_INET6
- p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
- p[0] = byte(addr.Port >> 8)
- p[1] = byte(addr.Port)
- copy(packet.addr.Addr[:], addr.IP.To16())
+ packet.addr.Addr = ip.Addr().As16()
+ port := ip.Port()
+ packet.addr.Port = (port >> 8) | ((port & 0xff) << 8)
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
@@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
-func (u *RIOConn) LocalAddr() (*Addr, error) {
+func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
sa, err := windows.Getsockname(u.sock)
if err != nil {
- return nil, err
+ return netip.AddrPort{}, err
}
v6 := sa.(*windows.SockaddrInet6)
- return &Addr{
- IP: v6.Addr[:],
- Port: uint16(v6.Port),
- }, nil
+ return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil
+
}
func (u *RIOConn) Rebind() error {
diff --git a/udp/udp_tester.go b/udp/udp_tester.go
index 55985f4..f03a353 100644
--- a/udp/udp_tester.go
+++ b/udp/udp_tester.go
@@ -4,9 +4,8 @@
package udp
import (
- "fmt"
"io"
- "net"
+ "net/netip"
"sync/atomic"
"github.com/sirupsen/logrus"
@@ -16,30 +15,24 @@ import (
)
type Packet struct {
- ToIp net.IP
- ToPort uint16
- FromIp net.IP
- FromPort uint16
- Data []byte
+ To netip.AddrPort
+ From netip.AddrPort
+ Data []byte
}
func (u *Packet) Copy() *Packet {
n := &Packet{
- ToIp: make(net.IP, len(u.ToIp)),
- ToPort: u.ToPort,
- FromIp: make(net.IP, len(u.FromIp)),
- FromPort: u.FromPort,
- Data: make([]byte, len(u.Data)),
+ To: u.To,
+ From: u.From,
+ Data: make([]byte, len(u.Data)),
}
- copy(n.ToIp, u.ToIp)
- copy(n.FromIp, u.FromIp)
copy(n.Data, u.Data)
return n
}
type TesterConn struct {
- Addr *Addr
+ Addr netip.AddrPort
RxPackets chan *Packet // Packets to receive into nebula
TxPackets chan *Packet // Packets transmitted outside by nebula
@@ -48,9 +41,9 @@ type TesterConn struct {
l *logrus.Logger
}
-func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
return &TesterConn{
- Addr: &Addr{ip, uint16(port)},
+ Addr: netip.AddrPortFrom(ip, uint16(port)),
RxPackets: make(chan *Packet, 10),
TxPackets: make(chan *Packet, 10),
l: l,
@@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) {
}
if u.l.Level >= logrus.DebugLevel {
u.l.WithField("header", h).
- WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
+ WithField("udpAddr", packet.From).
WithField("dataLen", len(packet.Data)).
Debug("UDP receiving injected packet")
}
@@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
-func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
+func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
if u.closed.Load() {
return io.ErrClosedPipe
}
p := &Packet{
- Data: make([]byte, len(b), len(b)),
- FromIp: make([]byte, 16),
- FromPort: u.Addr.Port,
- ToIp: make([]byte, 16),
- ToPort: addr.Port,
+ Data: make([]byte, len(b), len(b)),
+ From: u.Addr,
+ To: addr,
}
copy(p.Data, b)
- copy(p.ToIp, addr.IP.To16())
- copy(p.FromIp, u.Addr.IP.To16())
-
u.TxPackets <- p
return nil
}
@@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
- ua := &Addr{IP: make([]byte, 16)}
nb := make([]byte, 12, 12)
for {
@@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
if !ok {
return
}
- ua.Port = p.FromPort
- copy(ua.IP, p.FromIp.To16())
- r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
+ r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
}
@@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() {
return func() {}
}
-func (u *TesterConn) LocalAddr() (*Addr, error) {
+func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
diff --git a/udp/udp_windows.go b/udp/udp_windows.go
index ebcace6..1b777c3 100644
--- a/udp/udp_windows.go
+++ b/udp/udp_windows.go
@@ -6,12 +6,13 @@ package udp
import (
"fmt"
"net"
+ "net/netip"
"syscall"
"github.com/sirupsen/logrus"
)
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
if multi {
//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
// The udp stack would need to be reworked to hide away the implementation differences between