diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a118476..a2c2a87 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -36,9 +36,6 @@ jobs: matrix: # a list of all the targets include: - - TARGET: i686-unknown-linux-musl # test in an alpine container on a mac - OS: ubuntu-latest - FEATURES: normal,web - TARGET: x86_64-unknown-linux-musl # test in an alpine container on a mac OS: ubuntu-latest FEATURES: ring-cipher,web diff --git a/Cargo.lock b/Cargo.lock index 09488a2..276ec1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,7 +53,7 @@ dependencies = [ "actix-service", "actix-utils", "ahash", - "base64", + "base64 0.21.7", "bitflags 2.5.0", "brotli", "bytes", @@ -63,7 +63,7 @@ dependencies = [ "flate2", "futures-core", "h2", - "http", + "http 0.2.12", "httparse", "httpdate", "itoa", @@ -98,7 +98,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d22475596539443685426b6bdadb926ad0ecaefdfc5fb05e5e3441f15463c511" dependencies = [ "bytestring", - "http", + "http 0.2.12", "regex", "serde", "tracing", @@ -368,12 +368,24 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "base64ct" version = "1.6.0" @@ -392,6 +404,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -401,6 +422,28 @@ dependencies = [ "generic-array", ] +[[package]] +name = "boringtun" +version = "0.6.0" +dependencies = [ + "aead", + "base64 0.13.1", + "blake2", + "chacha20poly1305", + "hex", + "hmac", + "ip_network", + "ip_network_table", + "libc", + "nix", + "parking_lot", + "rand_core", + "ring", + "tracing", + "untrusted", + "x25519-dalek", +] + [[package]] name = "brotli" version = "3.5.0" @@ -466,6 +509,30 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chacha20" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "chacha20poly1305" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" +dependencies = [ + "aead", + "chacha20", + "cipher", + "poly1305", + "zeroize", +] + [[package]] name = "change-detection" version = "1.2.0" @@ -498,6 +565,7 @@ checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" dependencies = [ "crypto-common", "inout", + "zeroize", ] [[package]] @@ -670,24 +738,57 @@ dependencies = [ "cipher", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "dashmap" -version = "5.5.3" +version = "6.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28" dependencies = [ "cfg-if", + "crossbeam-utils", "hashbrown 0.14.3", "lock_api", "once_cell", "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" -version = "0.6.1" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1a467a65c5e759bce6e65eaf91cc29f466cdc57cb65777bd646872a8a1fd4de" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" dependencies = [ "const-oid", "pem-rfc7468", @@ -742,6 +843,7 @@ dependencies = [ "block-buffer", "const-oid", "crypto-common", + "subtle", ] [[package]] @@ -802,6 +904,12 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "flate2" version = "1.0.29" @@ -870,6 +978,7 @@ checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-core", "futures-macro", + "futures-sink", "futures-task", "pin-project-lite", "pin-utils", @@ -930,7 +1039,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.12", "indexmap 2.2.6", "slab", "tokio", @@ -966,6 +1075,21 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.9" @@ -986,6 +1110,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-range" version = "0.1.5" @@ -1072,6 +1207,37 @@ dependencies = [ "generic-array", ] +[[package]] +name = "ip_network" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2f047c0a98b2f299aa5d6d7088443570faae494e9ae1305e48be000c9e0eb1" + +[[package]] +name = "ip_network_table" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4099b7cfc5c5e2fe8c5edf3f6f7adf7a714c9cc697534f63a5a5da30397cb2c0" +dependencies = [ + "ip_network", + "ip_network_table-deps-treebitmap", +] + +[[package]] +name = "ip_network_table-deps-treebitmap" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e537132deb99c0eb4b752f0346b6a836200eaaa3516dd7e5514b63930a09e5d" + +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "is-terminal" version = "0.4.12" @@ -1282,6 +1448,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "nix" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f346ff70e7dbfd675fe90590b92d59ef2de15a8779ae305ebcbfd3f0caf59be4" +dependencies = [ + "autocfg", + "bitflags 1.3.2", + "cfg-if", + "libc", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -1440,9 +1618,9 @@ checksum = "498a099351efa4becc6a19c72aa9270598e8fd274ca47052e37455241c88b696" [[package]] name = "pem-rfc7468" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d159833a9105500e0398934e205e0773f0b27529557134ecfc51c27646adac" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" dependencies = [ "base64ct", ] @@ -1467,21 +1645,20 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkcs1" -version = "0.4.1" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eff33bdbdfc54cc98a2eca766ebdec3e1b8fb7387523d5c9c9a2891da856f719" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" dependencies = [ "der", "pkcs8", "spki", - "zeroize", ] [[package]] name = "pkcs8" -version = "0.9.0" +version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ "der", "spki", @@ -1493,6 +1670,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "poly1305" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" +dependencies = [ + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "polyval" version = "0.6.2" @@ -1756,21 +1944,20 @@ dependencies = [ [[package]] name = "rsa" -version = "0.7.2" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "094052d5470cbcef561cb848a7209968c9f12dfa6d668f4bca048ac5de51099c" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" dependencies = [ - "byteorder", + "const-oid", "digest", "num-bigint-dig", "num-integer", - "num-iter", "num-traits", "pkcs1", "pkcs8", "rand_core", "signature", - "smallvec", + "spki", "subtle", "zeroize", ] @@ -1920,9 +2107,9 @@ dependencies = [ [[package]] name = "signature" -version = "1.6.4" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest", "rand_core", @@ -1967,9 +2154,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "spki" -version = "0.6.0" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", "der", @@ -2154,6 +2341,18 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "tokio-tungstenite" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6989540ced10490aaf14e6bad2e3d33728a2813310a0c71d1574304c49631cd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.10" @@ -2180,9 +2379,21 @@ checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "tracing-core" version = "0.1.32" @@ -2198,6 +2409,24 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" +[[package]] +name = "tungstenite" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e2ce1e47ed2994fd43b04c8f618008d4cabdd5ee34027cf14f9d918edd9c8" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typemap-ors" version = "1.0.0" @@ -2285,6 +2514,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "uuid" version = "1.8.0" @@ -2308,13 +2543,16 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "vnts" -version = "1.2.9" +version = "1.2.12" dependencies = [ "actix-files", "actix-web", "actix-web-static-files", "aes-gcm", + "anyhow", "async-trait", + "base64 0.22.1", + "boringtun", "chrono", "clap", "colored", @@ -2323,6 +2561,7 @@ dependencies = [ "dashmap", "dirs", "futures-util", + "ipnetwork", "lazy_static", "log", "log4rs", @@ -2342,6 +2581,7 @@ dependencies = [ "static-files", "thiserror", "tokio", + "tokio-tungstenite", "tokio-util", "uuid", ] @@ -2597,6 +2837,18 @@ version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core", + "serde", + "zeroize", +] + [[package]] name = "zerocopy" version = "0.7.32" @@ -2622,6 +2874,20 @@ name = "zeroize" version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] [[package]] name = "zstd" diff --git a/Cargo.toml b/Cargo.toml index e94a28e..27eca9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vnts" -version = "1.2.9" +version = "1.2.12" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -13,15 +13,16 @@ log4rs = "1.3" dirs = "5" crossbeam = "0.8" parking_lot = "0.12" -dashmap = "5.5" +dashmap = "6.0.1" -rsa = { version = "0.7.2", features = [] } -spki = { version = "0.6.0", features = ["fingerprint", "alloc"] } +rsa = { version = "0.9.6", features = [] } +spki = { version = "0.7.3", features = ["fingerprint", "alloc", "base64"] } aes-gcm = { version = "0.10.2", optional = true } ring = { version = "0.17", optional = true } rand = "0.8" sha2 = { version = "0.10", features = ["oid"] } colored = "2.1" +anyhow = "1.0.82" thiserror = "1" chrono = "0.4" @@ -36,6 +37,11 @@ socket2 = { version = "0.5", features = ["all"] } actix-web = { version = "4.5", optional = true } actix-files = { version = "0.6", optional = true } actix-web-static-files = { version = "4.0.1", optional = true } +tokio-tungstenite = "0.23.1" + +boringtun = { path = "lib/boringtun", features = [] } +ipnetwork = "0.20.0" +base64 = "0.22.1" serde = { version = "1", features = ["derive"] } crossbeam-utils = "0.8" diff --git a/README.md b/README.md index 4c9b450..d77ca5f 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,18 @@ # vnts -[vnt](https://github.com/lbl8603/vnt)的服务端 +[vnt](https://github.com/vnt-dev/vnt)的服务端 查看参数 ``` Options: - --port 指定端口,默认29872 - --white-token token白名单,例如 --white-token 1234 --white-token 123 - --gateway 网关,例如 --gateway 10.10.0.1 - --netmask 子网掩码,例如 --netmask 255.255.255.0 - --finger 开启指纹校验,开启后只会转发指纹正确的客户端数据包,增强安全性,这会损失一部分性能 - --log-path log路径,默认为当前程序路径,为/dev/null时表示不输出log - --web-port web后台端口,默认29870,如果设置为0则表示不启动web后台 - --username web后台用户名,默认为admin - --password web后台用户密码,默认为admin + -p, --port 指定端口,默认29872 + -w, --white-token token白名单,例如 --white-token 1234 --white-token 123 + -g, --gateway 网关,例如 --gateway 10.10.0.1 + -m, --netmask 子网掩码,例如 --netmask 255.255.255.0 + -f, --finger 开启指纹校验,开启后只会转发指纹正确的客户端数据包,增强安全性,这会损失一部分性能 + -l, --log-path log路径,默认为当前程序路径,为/dev/null时表示不输出log + --wg wg私钥,使用base64编码 -h, --help Print help information -V, --version Print version information ``` diff --git a/lib/boringtun/Cargo.toml b/lib/boringtun/Cargo.toml new file mode 100644 index 0000000..1c528e7 --- /dev/null +++ b/lib/boringtun/Cargo.toml @@ -0,0 +1,64 @@ +[package] +name = "boringtun" +description = "an implementation of the WireGuard® protocol designed for portability and speed" +version = "0.6.0" +authors = [ + "Noah Kennedy ", + "Andy Grover ", + "Jeff Hiner ", +] +license = "BSD-3-Clause" +repository = "https://github.com/cloudflare/boringtun" +documentation = "https://docs.rs/boringtun/0.5.2/boringtun/" +edition = "2018" + +[features] +default = [] +device = ["socket2", "thiserror"] +jni-bindings = ["ffi-bindings", "jni"] +ffi-bindings = ["tracing-subscriber"] +# mocks std::time::Instant with mock_instant +mock-instant = ["mock_instant"] + +[dependencies] +base64 = "0.13" +hex = "0.4" +untrusted = "0.9.0" +libc = "0.2" +parking_lot = "0.12" +tracing = "0.1.40" +tracing-subscriber = { version = "0.3", features = ["fmt"], optional = true } +ip_network = "0.4.1" +ip_network_table = "0.2.0" +ring = "0.17" +x25519-dalek = { version = "2.0.0", features = [ + "reusable_secrets", + "static_secrets", +] } +rand_core = { version = "0.6.4", features = ["getrandom"] } +chacha20poly1305 = "0.10.0-pre.1" +aead = "0.5.0-pre.2" +blake2 = "0.10" +hmac = "0.12" +jni = { version = "0.19.0", optional = true } +mock_instant = { version = "0.3", optional = true } +socket2 = { version = "0.4.7", features = ["all"], optional = true } +thiserror = { version = "1", optional = true } + +[target.'cfg(unix)'.dependencies] +nix = { version = "0.25", default-features = false, features = [ + "time", + "user", +] } + +[dev-dependencies] +etherparse = "0.13" +tracing-subscriber = "0.3" +criterion = { version = "0.3.5", features = ["html_reports"] } + +[lib] +crate-type = ["staticlib", "cdylib", "rlib"] + +[[bench]] +name = "crypto_benches" +harness = false diff --git a/lib/boringtun/benches/crypto_benches/blake2s_benching.rs b/lib/boringtun/benches/crypto_benches/blake2s_benching.rs new file mode 100644 index 0000000..3698172 --- /dev/null +++ b/lib/boringtun/benches/crypto_benches/blake2s_benching.rs @@ -0,0 +1,90 @@ +use blake2::digest::{FixedOutput, KeyInit}; +use blake2::{Blake2s256, Blake2sMac, Digest}; +use criterion::{BenchmarkId, Criterion, Throughput}; +use ring::rand::{SecureRandom, SystemRandom}; + +pub fn bench_blake2s_hash(c: &mut Criterion) { + let mut group = c.benchmark_group("blake2s_hash"); + + group.sample_size(1000); + + for size in [32, 64, 128] { + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| { + let buf_in = vec![0u8; size]; + + b.iter(|| { + let mut hasher = Blake2s256::new(); + hasher.update(&buf_in); + hasher.finalize(); + }); + }); + } + + group.finish(); +} + +pub fn bench_blake2s_hmac(c: &mut Criterion) { + let mut group = c.benchmark_group("blake2s_hmac"); + + group.sample_size(1000); + + for size in [16, 32] { + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| { + let buf_in = vec![0u8; size]; + let rng = SystemRandom::new(); + + b.iter_batched( + || { + let mut key = [0u8; 32]; + rng.fill(&mut key).unwrap(); + key + }, + |key| { + use blake2::digest::Update; + type HmacBlake2s = hmac::SimpleHmac; + let mut hmac = HmacBlake2s::new_from_slice(&key).unwrap(); + hmac.update(&buf_in); + hmac.finalize_fixed(); + }, + criterion::BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +pub fn bench_blake2s_keyed(c: &mut Criterion) { + let mut group = c.benchmark_group("blake2s_keyed_mac"); + + group.sample_size(1000); + + for size in [128, 1024] { + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input(BenchmarkId::new("blake2s_crate", size), &size, |b, _| { + let buf_in = vec![0u8; size]; + let rng = SystemRandom::new(); + + b.iter_batched( + || { + let mut key = [0u8; 16]; + rng.fill(&mut key).unwrap(); + key + }, + |key| -> [u8; 16] { + let mut hmac = Blake2sMac::new_from_slice(&key).unwrap(); + blake2::digest::Update::update(&mut hmac, &buf_in); + hmac.finalize_fixed().into() + }, + criterion::BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} diff --git a/lib/boringtun/benches/crypto_benches/chacha20poly1305_benching.rs b/lib/boringtun/benches/crypto_benches/chacha20poly1305_benching.rs new file mode 100644 index 0000000..ed857a7 --- /dev/null +++ b/lib/boringtun/benches/crypto_benches/chacha20poly1305_benching.rs @@ -0,0 +1,79 @@ +use aead::{AeadInPlace, KeyInit}; +use criterion::{BenchmarkId, Criterion, Throughput}; +use rand_core::{OsRng, RngCore}; +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; + +fn chacha20poly1305_ring(key_bytes: &[u8], buf: &mut [u8]) { + let len = buf.len(); + let n = len - 16; + + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key_bytes).unwrap()); + + let tag = key + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key([0u8; 12]), + Aad::from(&[]), + &mut buf[..n], + ) + .unwrap(); + + buf[n..].copy_from_slice(tag.as_ref()) +} + +fn chacha20poly1305_non_ring(key_bytes: &[u8], buf: &mut [u8]) { + let len = buf.len(); + let n = len - 16; + + let aead = chacha20poly1305::ChaCha20Poly1305::new_from_slice(key_bytes).unwrap(); + let nonce = chacha20poly1305::Nonce::default(); + + let tag = aead + .encrypt_in_place_detached(&nonce, &[], &mut buf[..n]) + .unwrap(); + + buf[n..].copy_from_slice(tag.as_ref()); +} + +pub fn bench_chacha20poly1305(c: &mut Criterion) { + let mut group = c.benchmark_group("chacha20poly1305"); + + group.sample_size(1000); + + for size in [128, 192, 1400, 8192] { + group.throughput(Throughput::Bytes(size as u64)); + + group.bench_with_input( + BenchmarkId::new("chacha20poly1305_ring", size), + &size, + |b, i| { + let mut key = [0; 32]; + let mut buf = vec![0; i + 16]; + + let mut rng = OsRng::default(); + + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut buf); + + b.iter(|| chacha20poly1305_ring(&key, &mut buf)); + }, + ); + + group.bench_with_input( + BenchmarkId::new("chacha20poly1305_non_ring", size), + &size, + |b, i| { + let mut key = [0; 32]; + let mut buf = vec![0; i + 16]; + + let mut rng = OsRng::default(); + + rng.fill_bytes(&mut key); + rng.fill_bytes(&mut buf); + + b.iter(|| chacha20poly1305_non_ring(&key, &mut buf)); + }, + ); + } + + group.finish(); +} diff --git a/lib/boringtun/benches/crypto_benches/main.rs b/lib/boringtun/benches/crypto_benches/main.rs new file mode 100644 index 0000000..13349a1 --- /dev/null +++ b/lib/boringtun/benches/crypto_benches/main.rs @@ -0,0 +1,20 @@ +use blake2s_benching::{bench_blake2s_hash, bench_blake2s_hmac, bench_blake2s_keyed}; +use chacha20poly1305_benching::bench_chacha20poly1305; +use x25519_public_key_benching::bench_x25519_public_key; +use x25519_shared_key_benching::bench_x25519_shared_key; + +mod blake2s_benching; +mod chacha20poly1305_benching; +mod x25519_public_key_benching; +mod x25519_shared_key_benching; + +criterion::criterion_group!( + crypto_benches, + bench_chacha20poly1305, + bench_blake2s_hash, + bench_blake2s_hmac, + bench_blake2s_keyed, + bench_x25519_shared_key, + bench_x25519_public_key +); +criterion::criterion_main!(crypto_benches); diff --git a/lib/boringtun/benches/crypto_benches/x25519_public_key_benching.rs b/lib/boringtun/benches/crypto_benches/x25519_public_key_benching.rs new file mode 100644 index 0000000..7e25759 --- /dev/null +++ b/lib/boringtun/benches/crypto_benches/x25519_public_key_benching.rs @@ -0,0 +1,30 @@ +use criterion::Criterion; +use rand_core::OsRng; + +pub fn bench_x25519_public_key(c: &mut Criterion) { + let mut group = c.benchmark_group("x25519_public_key"); + + group.sample_size(1000); + + group.bench_function("x25519_public_key_dalek", |b| { + b.iter(|| { + let secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let public_key = x25519_dalek::PublicKey::from(&secret_key); + + (secret_key, public_key) + }); + }); + + group.bench_function("x25519_public_key_ring", |b| { + let rng = ring::rand::SystemRandom::new(); + + b.iter(|| { + let my_private_key = + ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng) + .unwrap(); + my_private_key.compute_public_key().unwrap() + }); + }); + + group.finish(); +} diff --git a/lib/boringtun/benches/crypto_benches/x25519_shared_key_benching.rs b/lib/boringtun/benches/crypto_benches/x25519_shared_key_benching.rs new file mode 100644 index 0000000..a3c1145 --- /dev/null +++ b/lib/boringtun/benches/crypto_benches/x25519_shared_key_benching.rs @@ -0,0 +1,48 @@ +use criterion::{BatchSize, Criterion}; +use rand_core::OsRng; + +pub fn bench_x25519_shared_key(c: &mut Criterion) { + let mut group = c.benchmark_group("x25519_shared_key"); + + group.sample_size(1000); + + group.bench_function("x25519_shared_key_dalek", |b| { + let public_key = + x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random_from_rng(OsRng)); + + b.iter_batched( + || x25519_dalek::StaticSecret::random_from_rng(OsRng), + |secret_key| secret_key.diffie_hellman(&public_key), + BatchSize::SmallInput, + ); + }); + + group.bench_function("x25519_shared_key_ring", |b| { + let rng = ring::rand::SystemRandom::new(); + + let peer_public_key = { + let peer_private_key = + ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng) + .unwrap(); + peer_private_key.compute_public_key().unwrap() + }; + let peer_public_key_alg = &ring::agreement::X25519; + + let my_public_key = + ring::agreement::UnparsedPublicKey::new(peer_public_key_alg, &peer_public_key); + + b.iter_batched( + || { + ring::agreement::EphemeralPrivateKey::generate(&ring::agreement::X25519, &rng) + .unwrap() + }, + |my_private_key| { + ring::agreement::agree_ephemeral(my_private_key, &my_public_key, |_key_material| ()) + .unwrap() + }, + BatchSize::SmallInput, + ); + }); + + group.finish(); +} diff --git a/lib/boringtun/src/device/allowed_ips.rs b/lib/boringtun/src/device/allowed_ips.rs new file mode 100644 index 0000000..65600b2 --- /dev/null +++ b/lib/boringtun/src/device/allowed_ips.rs @@ -0,0 +1,389 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use crate::device::peer::AllowedIP; + +use ip_network::IpNetwork; +use ip_network_table::IpNetworkTable; + +use std::collections::VecDeque; +use std::iter::FromIterator; +use std::net::IpAddr; + +/// A trie of IP/cidr addresses +#[derive(Default)] +pub struct AllowedIps { + ips: IpNetworkTable, +} + +impl<'a, D> FromIterator<(&'a AllowedIP, D)> for AllowedIps { + fn from_iter>(iter: I) -> Self { + let mut allowed_ips = AllowedIps::new(); + + for (ip, data) in iter { + allowed_ips.insert(ip.addr, ip.cidr as u32, data); + } + + allowed_ips + } +} + +impl AllowedIps { + pub fn new() -> Self { + Self { + ips: IpNetworkTable::new(), + } + } + + pub fn clear(&mut self) { + self.ips = IpNetworkTable::new(); + } + + pub fn insert(&mut self, key: IpAddr, cidr: u32, data: D) -> Option { + // These are networks, it doesn't make sense for host bits to be set, so + // use new_truncate(). + self.ips.insert( + IpNetwork::new_truncate(key, cidr as u8).expect("cidr is valid length"), + data, + ) + } + + pub fn find(&self, key: IpAddr) -> Option<&D> { + self.ips.longest_match(key).map(|(_net, data)| data) + } + + pub fn remove(&mut self, predicate: &dyn Fn(&D) -> bool) { + self.ips.retain(|_, v| !predicate(v)); + } + + pub fn iter(&self) -> Iter { + Iter( + self.ips + .iter() + .map(|(ipa, d)| (d, ipa.network_address(), ipa.netmask())) + .collect(), + ) + } +} + +pub struct Iter<'a, D: 'a>(VecDeque<(&'a D, IpAddr, u8)>); + +impl<'a, D> Iterator for Iter<'a, D> { + type Item = (&'a D, IpAddr, u8); + fn next(&mut self) -> Option { + self.0.pop_front() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn build_allowed_ips() -> AllowedIps { + let mut map: AllowedIps = Default::default(); + map.insert(IpAddr::from([127, 0, 0, 1]), 32, '1'); + map.insert(IpAddr::from([45, 25, 15, 1]), 30, '6'); + map.insert(IpAddr::from([127, 0, 15, 1]), 16, '2'); + map.insert(IpAddr::from([127, 1, 15, 1]), 24, '3'); + map.insert(IpAddr::from([255, 1, 15, 1]), 24, '4'); + map.insert(IpAddr::from([60, 25, 15, 1]), 32, '5'); + map.insert(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0]), 128, '7'); + map + } + + #[test] + fn test_allowed_ips_insert_find() { + let map = build_allowed_ips(); + assert_eq!(map.find(IpAddr::from([127, 0, 0, 1])), Some(&'1')); + assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2')); + assert_eq!(map.find(IpAddr::from([127, 1, 255, 255])), None); + assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2')); + assert_eq!(map.find(IpAddr::from([127, 1, 15, 255])), Some(&'3')); + assert_eq!(map.find(IpAddr::from([127, 0, 255, 255])), Some(&'2')); + assert_eq!(map.find(IpAddr::from([127, 1, 15, 255])), Some(&'3')); + assert_eq!(map.find(IpAddr::from([255, 1, 15, 2])), Some(&'4')); + assert_eq!(map.find(IpAddr::from([60, 25, 15, 1])), Some(&'5')); + assert_eq!(map.find(IpAddr::from([20, 0, 0, 100])), None); + assert_eq!( + map.find(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0])), + Some(&'7') + ); + assert_eq!(map.find(IpAddr::from([553, 0, 0, 1, 0, 0, 0, 1])), None); + assert_eq!(map.find(IpAddr::from([45, 25, 15, 1])), Some(&'6')); + } + + #[test] + fn test_allowed_ips_remove() { + let mut map = build_allowed_ips(); + map.remove(&|c| *c == '5' || *c == '1' || *c == '7'); + + let mut map_iter = map.iter(); + assert_eq!( + map_iter.next(), + Some((&'6', IpAddr::from([45, 25, 15, 0]), 30)) + ); + assert_eq!( + map_iter.next(), + Some((&'2', IpAddr::from([127, 0, 0, 0]), 16)) + ); + assert_eq!( + map_iter.next(), + Some((&'3', IpAddr::from([127, 1, 15, 0]), 24)) + ); + assert_eq!( + map_iter.next(), + Some((&'4', IpAddr::from([255, 1, 15, 0]), 24)) + ); + assert_eq!(map_iter.next(), None); + } + + #[test] + fn test_allowed_ips_iter() { + let map = build_allowed_ips(); + let mut map_iter = map.iter(); + assert_eq!( + map_iter.next(), + Some((&'6', IpAddr::from([45, 25, 15, 0]), 30)) + ); + assert_eq!( + map_iter.next(), + Some((&'5', IpAddr::from([60, 25, 15, 1]), 32)) + ); + assert_eq!( + map_iter.next(), + Some((&'2', IpAddr::from([127, 0, 0, 0]), 16)) + ); + assert_eq!( + map_iter.next(), + Some((&'1', IpAddr::from([127, 0, 0, 1]), 32)) + ); + assert_eq!( + map_iter.next(), + Some((&'3', IpAddr::from([127, 1, 15, 0]), 24)) + ); + assert_eq!( + map_iter.next(), + Some((&'4', IpAddr::from([255, 1, 15, 0]), 24)) + ); + assert_eq!( + map_iter.next(), + Some((&'7', IpAddr::from([553, 0, 0, 1, 0, 0, 0, 0]), 128)) + ); + assert_eq!(map_iter.next(), None); + } + + #[test] + fn test_allowed_ips_v4_kernel_compatibility() { + // Test case from wireguard-go + let mut map: AllowedIps = Default::default(); + + map.insert(IpAddr::from([192, 168, 4, 0]), 24, 'a'); + map.insert(IpAddr::from([192, 168, 4, 4]), 32, 'b'); + map.insert(IpAddr::from([192, 168, 0, 0]), 16, 'c'); + map.insert(IpAddr::from([192, 95, 5, 64]), 27, 'd'); + map.insert(IpAddr::from([192, 95, 5, 65]), 27, 'c'); + map.insert(IpAddr::from([0, 0, 0, 0]), 0, 'e'); + map.insert(IpAddr::from([64, 15, 112, 0]), 20, 'g'); + map.insert(IpAddr::from([64, 15, 123, 211]), 25, 'h'); + map.insert(IpAddr::from([10, 0, 0, 0]), 25, 'a'); + map.insert(IpAddr::from([10, 0, 0, 128]), 25, 'b'); + map.insert(IpAddr::from([10, 1, 0, 0]), 30, 'a'); + map.insert(IpAddr::from([10, 1, 0, 4]), 30, 'b'); + map.insert(IpAddr::from([10, 1, 0, 8]), 29, 'c'); + map.insert(IpAddr::from([10, 1, 0, 16]), 29, 'd'); + + assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 168, 4, 20]))); + assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 168, 4, 0]))); + assert_eq!(Some(&'b'), map.find(IpAddr::from([192, 168, 4, 4]))); + assert_eq!(Some(&'c'), map.find(IpAddr::from([192, 168, 200, 182]))); + assert_eq!(Some(&'c'), map.find(IpAddr::from([192, 95, 5, 68]))); + assert_eq!(Some(&'e'), map.find(IpAddr::from([192, 95, 5, 96]))); + assert_eq!(Some(&'g'), map.find(IpAddr::from([64, 15, 116, 26]))); + assert_eq!(Some(&'g'), map.find(IpAddr::from([64, 15, 127, 3]))); + + map.insert(IpAddr::from([1, 0, 0, 0]), 32, 'a'); + map.insert(IpAddr::from([64, 0, 0, 0]), 32, 'a'); + map.insert(IpAddr::from([128, 0, 0, 0]), 32, 'a'); + map.insert(IpAddr::from([192, 0, 0, 0]), 32, 'a'); + map.insert(IpAddr::from([255, 0, 0, 0]), 32, 'a'); + + assert_eq!(Some(&'a'), map.find(IpAddr::from([1, 0, 0, 0]))); + assert_eq!(Some(&'a'), map.find(IpAddr::from([64, 0, 0, 0]))); + assert_eq!(Some(&'a'), map.find(IpAddr::from([128, 0, 0, 0]))); + assert_eq!(Some(&'a'), map.find(IpAddr::from([192, 0, 0, 0]))); + assert_eq!(Some(&'a'), map.find(IpAddr::from([255, 0, 0, 0]))); + + map.remove(&|c| *c == 'a'); + + assert_ne!(Some(&'a'), map.find(IpAddr::from([1, 0, 0, 0]))); + assert_ne!(Some(&'a'), map.find(IpAddr::from([64, 0, 0, 0]))); + assert_ne!(Some(&'a'), map.find(IpAddr::from([128, 0, 0, 0]))); + assert_ne!(Some(&'a'), map.find(IpAddr::from([192, 0, 0, 0]))); + assert_ne!(Some(&'a'), map.find(IpAddr::from([255, 0, 0, 0]))); + + map.clear(); + + map.insert(IpAddr::from([192, 168, 0, 0]), 16, 'a'); + map.insert(IpAddr::from([192, 168, 0, 0]), 24, 'a'); + + map.remove(&|c| *c == 'a'); + + assert_ne!(Some(&'a'), map.find(IpAddr::from([192, 168, 0, 1]))); + } + + #[test] + fn test_allowed_ips_v6_kernel_compatibility() { + // Test case from wireguard-go + let mut map: AllowedIps = Default::default(); + + map.insert( + IpAddr::from([ + 0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0xc05f, 0x0543, + ]), + 128, + 'd', + ); + map.insert( + IpAddr::from([ + 0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0x0000, 0x0000, + ]), + 64, + 'c', + ); + map.insert( + IpAddr::from([ + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + ]), + 0, + 'e', + ); + map.insert( + IpAddr::from([ + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + ]), + 0, + 'f', + ); + map.insert( + IpAddr::from([ + 0x2404, 0x6800, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + ]), + 32, + 'g', + ); + map.insert( + IpAddr::from([ + 0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef, + ]), + 64, + 'h', + ); + map.insert( + IpAddr::from([ + 0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef, + ]), + 128, + 'a', + ); + map.insert( + IpAddr::from([ + 0x2444, 0x6800, 0x40e4, 0x0800, 0xdeae, 0xbeef, 0x0def, 0xbeef, + ]), + 128, + 'c', + ); + map.insert( + IpAddr::from([ + 0x2444, 0x6800, 0xf0e4, 0x0800, 0xeeae, 0xbeef, 0x0000, 0x0000, + ]), + 98, + 'b', + ); + + assert_eq!( + Some(&'d'), + map.find(IpAddr::from([ + 0x2607, 0x5300, 0x6000, 0x6b00, 0x0000, 0x0000, 0xc05f, 0x0543 + ])) + ); + assert_eq!( + Some(&'c'), + map.find(IpAddr::from([ + 0x2607, 0x5300, 0x6000, 0x6b00, 0, 0, 0xc02e, 0x01ee + ])) + ); + assert_eq!( + Some(&'f'), + map.find(IpAddr::from([0x2607, 0x5300, 0x6000, 0x6b01, 0, 0, 0, 0])) + ); + assert_eq!( + Some(&'g'), + map.find(IpAddr::from([ + 0x2404, 0x6800, 0x4004, 0x0806, 0, 0, 0, 0x1006 + ])) + ); + assert_eq!( + Some(&'g'), + map.find(IpAddr::from([ + 0x2404, 0x6800, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678 + ])) + ); + assert_eq!( + Some(&'f'), + map.find(IpAddr::from([ + 0x2404, 0x67ff, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678 + ])) + ); + assert_eq!( + Some(&'f'), + map.find(IpAddr::from([ + 0x2404, 0x6801, 0x4004, 0x0806, 0, 0x1234, 0, 0x5678 + ])) + ); + assert_eq!( + Some(&'h'), + map.find(IpAddr::from([ + 0x2404, 0x6800, 0x4004, 0x0800, 0, 0x1234, 0, 0x5678 + ])) + ); + assert_eq!( + Some(&'h'), + map.find(IpAddr::from([0x2404, 0x6800, 0x4004, 0x0800, 0, 0, 0, 0])) + ); + assert_eq!( + Some(&'h'), + map.find(IpAddr::from([ + 0x2404, 0x6800, 0x4004, 0x0800, 0x1010, 0x1010, 0x1010, 0x1010 + ])) + ); + assert_eq!( + Some(&'a'), + map.find(IpAddr::from([ + 0x2404, 0x6800, 0x4004, 0x0800, 0xdead, 0xbeef, 0xdead, 0xbeef + ])) + ); + } + + #[test] + fn test_allowed_ips_iter_zero_leaf_bits() { + let mut map: AllowedIps = Default::default(); + map.insert(IpAddr::from([10, 111, 0, 1]), 32, '1'); + map.insert(IpAddr::from([10, 111, 0, 2]), 32, '2'); + map.insert(IpAddr::from([10, 111, 0, 3]), 32, '3'); + + let mut map_iter = map.iter(); + assert_eq!( + map_iter.next(), + Some((&'1', IpAddr::from([10, 111, 0, 1]), 32)) + ); + assert_eq!( + map_iter.next(), + Some((&'2', IpAddr::from([10, 111, 0, 2]), 32)) + ); + assert_eq!( + map_iter.next(), + Some((&'3', IpAddr::from([10, 111, 0, 3]), 32)) + ); + assert_eq!(map_iter.next(), None); + } +} diff --git a/lib/boringtun/src/device/api.rs b/lib/boringtun/src/device/api.rs new file mode 100644 index 0000000..0486de6 --- /dev/null +++ b/lib/boringtun/src/device/api.rs @@ -0,0 +1,368 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::dev_lock::LockReadGuard; +use super::drop_privileges::get_saved_ids; +use super::{AllowedIP, Device, Error, SocketAddr}; +use crate::device::Action; +use crate::serialization::KeyBytes; +use crate::x25519; +use hex::encode as encode_hex; +use libc::*; +use std::fs::{create_dir, remove_file}; +use std::io::{BufRead, BufReader, BufWriter, Write}; +use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::sync::atomic::Ordering; + +const SOCK_DIR: &str = "/var/run/wireguard/"; + +fn create_sock_dir() { + let _ = create_dir(SOCK_DIR); // Create the directory if it does not exist + + if let Ok((saved_uid, saved_gid)) = get_saved_ids() { + unsafe { + let c_path = std::ffi::CString::new(SOCK_DIR).unwrap(); + // The directory is under the root user, but we want to be able to + // delete the files there when we exit, so we need to change the owner + chown( + c_path.as_bytes_with_nul().as_ptr() as _, + saved_uid, + saved_gid, + ); + } + } +} + +impl Device { + /// Register the api handler for this Device. The api handler receives stream connections on a Unix socket + /// with a known path: /var/run/wireguard/{tun_name}.sock. + pub fn register_api_handler(&mut self) -> Result<(), Error> { + let path = format!("{}/{}.sock", SOCK_DIR, self.iface.name()?); + + create_sock_dir(); + + let _ = remove_file(&path); // Attempt to remove the socket if already exists + + let api_listener = UnixListener::bind(&path).map_err(Error::ApiSocket)?; // Bind a new socket to the path + + self.cleanup_paths.push(path.clone()); + + self.queue.new_event( + api_listener.as_raw_fd(), + Box::new(move |d, _| { + // This is the closure that listens on the api unix socket + let (api_conn, _) = match api_listener.accept() { + Ok(conn) => conn, + _ => return Action::Continue, + }; + + let mut reader = BufReader::new(&api_conn); + let mut writer = BufWriter::new(&api_conn); + let mut cmd = String::new(); + if reader.read_line(&mut cmd).is_ok() { + cmd.pop(); // pop the new line character + let status = match cmd.as_ref() { + // Only two commands are legal according to the protocol, get=1 and set=1. + "get=1" => api_get(&mut writer, d), + "set=1" => api_set(&mut reader, d), + _ => EIO, + }; + // The protocol requires to return an error code as the response, or zero on success + writeln!(writer, "errno={}\n", status).ok(); + } + Action::Continue // Indicates the worker thread should continue as normal + }), + )?; + + self.register_monitor(path)?; + self.register_api_signal_handlers() + } + + pub fn register_api_fd(&mut self, fd: i32) -> Result<(), Error> { + let io_file = unsafe { UnixStream::from_raw_fd(fd) }; + + self.queue.new_event( + io_file.as_raw_fd(), + Box::new(move |d, _| { + // This is the closure that listens on the api file descriptor + + let mut reader = BufReader::new(&io_file); + let mut writer = BufWriter::new(&io_file); + let mut cmd = String::new(); + if reader.read_line(&mut cmd).is_ok() { + cmd.pop(); // pop the new line character + let status = match cmd.as_ref() { + // Only two commands are legal according to the protocol, get=1 and set=1. + "get=1" => api_get(&mut writer, d), + "set=1" => api_set(&mut reader, d), + _ => EIO, + }; + // The protocol requires to return an error code as the response, or zero on success + writeln!(writer, "errno={}\n", status).ok(); + } else { + // The remote side is likely closed; we should trigger an exit. + d.trigger_exit(); + return Action::Exit; + } + + Action::Continue // Indicates the worker thread should continue as normal + }), + )?; + + Ok(()) + } + + fn register_monitor(&self, path: String) -> Result<(), Error> { + self.queue.new_periodic_event( + Box::new(move |d, _| { + // This is not a very nice hack to detect if the control socket was removed + // and exiting nicely as a result. We check every 3 seconds in a loop if the + // file was deleted by stating it. + // The problem is that on linux inotify can be used quite beautifully to detect + // deletion, and kqueue EVFILT_VNODE can be used for the same purpose, but that + // will require introducing new events, for no measurable benefit. + // TODO: Could this be an issue if we restart the service too quickly? + let path = std::path::Path::new(&path); + if !path.exists() { + d.trigger_exit(); + return Action::Exit; + } + + // Periodically read the mtu of the interface in case it changes + if let Ok(mtu) = d.iface.mtu() { + d.mtu.store(mtu, Ordering::Relaxed); + } + + Action::Continue + }), + std::time::Duration::from_millis(1000), + )?; + + Ok(()) + } + + fn register_api_signal_handlers(&self) -> Result<(), Error> { + self.queue + .new_signal_event(SIGINT, Box::new(move |_, _| Action::Exit))?; + + self.queue + .new_signal_event(SIGTERM, Box::new(move |_, _| Action::Exit))?; + + Ok(()) + } +} + +#[allow(unused_must_use)] +fn api_get(writer: &mut BufWriter<&UnixStream>, d: &Device) -> i32 { + // get command requires an empty line, but there is no reason to be religious about it + if let Some(ref k) = d.key_pair { + writeln!(writer, "own_public_key={}", encode_hex(k.1.as_bytes())); + } + + if d.listen_port != 0 { + writeln!(writer, "listen_port={}", d.listen_port); + } + + if let Some(fwmark) = d.fwmark { + writeln!(writer, "fwmark={}", fwmark); + } + + for (k, p) in d.peers.iter() { + let p = p.lock(); + writeln!(writer, "public_key={}", encode_hex(k.as_bytes())); + + if let Some(ref key) = p.preshared_key() { + writeln!(writer, "preshared_key={}", encode_hex(key)); + } + + if let Some(keepalive) = p.persistent_keepalive() { + writeln!(writer, "persistent_keepalive_interval={}", keepalive); + } + + if let Some(ref addr) = p.endpoint().addr { + writeln!(writer, "endpoint={}", addr); + } + + for (ip, cidr) in p.allowed_ips() { + writeln!(writer, "allowed_ip={}/{}", ip, cidr); + } + + if let Some(time) = p.time_since_last_handshake() { + writeln!(writer, "last_handshake_time_sec={}", time.as_secs()); + writeln!(writer, "last_handshake_time_nsec={}", time.subsec_nanos()); + } + + let (_, tx_bytes, rx_bytes, ..) = p.tunnel.stats(); + + writeln!(writer, "rx_bytes={}", rx_bytes); + writeln!(writer, "tx_bytes={}", tx_bytes); + } + 0 +} + +fn api_set(reader: &mut BufReader<&UnixStream>, d: &mut LockReadGuard) -> i32 { + d.try_writeable( + |device| device.trigger_yield(), + |device| { + device.cancel_yield(); + + let mut cmd = String::new(); + + while reader.read_line(&mut cmd).is_ok() { + cmd.pop(); // remove newline if any + if cmd.is_empty() { + return 0; // Done + } + { + let parsed_cmd: Vec<&str> = cmd.split('=').collect(); + if parsed_cmd.len() != 2 { + return EPROTO; + } + + let (key, val) = (parsed_cmd[0], parsed_cmd[1]); + + match key { + "private_key" => match val.parse::() { + Ok(key_bytes) => { + device.set_key(x25519::StaticSecret::from(key_bytes.0)) + } + Err(_) => return EINVAL, + }, + "listen_port" => match val.parse::() { + Ok(port) => match device.open_listen_socket(port) { + Ok(()) => {} + Err(_) => return EADDRINUSE, + }, + Err(_) => return EINVAL, + }, + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "linux" + ))] + "fwmark" => match val.parse::() { + Ok(mark) => match device.set_fwmark(mark) { + Ok(()) => {} + Err(_) => return EADDRINUSE, + }, + Err(_) => return EINVAL, + }, + "replace_peers" => match val.parse::() { + Ok(true) => device.clear_peers(), + Ok(false) => {} + Err(_) => return EINVAL, + }, + "public_key" => match val.parse::() { + // Indicates a new peer section + Ok(key_bytes) => { + return api_set_peer( + reader, + device, + x25519::PublicKey::from(key_bytes.0), + ) + } + Err(_) => return EINVAL, + }, + _ => return EINVAL, + } + } + cmd.clear(); + } + + 0 + }, + ) + .unwrap_or(EIO) +} + +fn api_set_peer( + reader: &mut BufReader<&UnixStream>, + d: &mut Device, + pub_key: x25519::PublicKey, +) -> i32 { + let mut cmd = String::new(); + + let mut remove = false; + let mut replace_ips = false; + let mut endpoint = None; + let mut keepalive = None; + let mut public_key = pub_key; + let mut preshared_key = None; + let mut allowed_ips: Vec = vec![]; + while reader.read_line(&mut cmd).is_ok() { + cmd.pop(); // remove newline if any + if cmd.is_empty() { + d.update_peer( + public_key, + remove, + replace_ips, + endpoint, + allowed_ips.as_slice(), + keepalive, + preshared_key, + ); + allowed_ips.clear(); //clear the vector content after update + return 0; // Done + } + { + let parsed_cmd: Vec<&str> = cmd.splitn(2, '=').collect(); + if parsed_cmd.len() != 2 { + return EPROTO; + } + let (key, val) = (parsed_cmd[0], parsed_cmd[1]); + match key { + "remove" => match val.parse::() { + Ok(true) => remove = true, + Ok(false) => remove = false, + Err(_) => return EINVAL, + }, + "preshared_key" => match val.parse::() { + Ok(key_bytes) => preshared_key = Some(key_bytes.0), + Err(_) => return EINVAL, + }, + "endpoint" => match val.parse::() { + Ok(addr) => endpoint = Some(addr), + Err(_) => return EINVAL, + }, + "persistent_keepalive_interval" => match val.parse::() { + Ok(interval) => keepalive = Some(interval), + Err(_) => return EINVAL, + }, + "replace_allowed_ips" => match val.parse::() { + Ok(true) => replace_ips = true, + Ok(false) => replace_ips = false, + Err(_) => return EINVAL, + }, + "allowed_ip" => match val.parse::() { + Ok(ip) => allowed_ips.push(ip), + Err(_) => return EINVAL, + }, + "public_key" => { + // Indicates a new peer section. Commit changes for current peer, and continue to next peer + d.update_peer( + public_key, + remove, + replace_ips, + endpoint, + allowed_ips.as_slice(), + keepalive, + preshared_key, + ); + allowed_ips.clear(); //clear the vector content after update + match val.parse::() { + Ok(key_bytes) => public_key = key_bytes.0.into(), + Err(_) => return EINVAL, + } + } + "protocol_version" => match val.parse::() { + Ok(1) => {} // Only version 1 is legal + _ => return EINVAL, + }, + _ => return EINVAL, + } + } + cmd.clear(); + } + 0 +} diff --git a/lib/boringtun/src/device/dev_lock.rs b/lib/boringtun/src/device/dev_lock.rs new file mode 100644 index 0000000..1a700fa --- /dev/null +++ b/lib/boringtun/src/device/dev_lock.rs @@ -0,0 +1,108 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard}; +use std::ops::Deref; + +/// A special type of read/write lock, that makes the following assumptions: +/// a) Read access is frequent, and has to be very fast, so we want to hold it indefinitely +/// b) Write access is very rare (think less than once per second) and can be a bit slower +/// c) A thread that holds a read lock, can ask for an upgrade to a write lock, cooperatively asking other threads to yield their locks +pub struct Lock { + wants_write: (Mutex, Condvar), + inner: RwLock, // Although parking lot lock is upgradable, it does not allow a two staged mark + lock upgrade +} + +impl Lock { + /// New lock + pub fn new(user_data: T) -> Lock { + Lock { + wants_write: (Mutex::new(false), Condvar::new()), + inner: RwLock::new(user_data), + } + } +} + +impl Lock { + /// Acquire a read lock + pub fn read(&self) -> LockReadGuard { + let (ref lock, ref cvar) = &self.wants_write; + let mut wants_write = lock.lock(); + while *wants_write { + // We have a writer and we want to wait for it to go away + cvar.wait(&mut wants_write); + } + + LockReadGuard { + wants_write: &self.wants_write, + inner: self.inner.read(), + } + } +} + +pub struct LockReadGuard<'a, T: 'a + ?Sized> { + wants_write: &'a (Mutex, Condvar), + inner: RwLockReadGuard<'a, T>, +} + +impl<'a, T: ?Sized> LockReadGuard<'a, T> { + /// Perform a closure on a mutable reference of the inner locked value. + /// + /// # Parameters + /// + /// `prep_func` - A closure that will run once, after the lock marks its intention to write, + /// this can be used to tell other threads to yield their read locks temporarily. It will be passed + /// an immutable reference to the inner value. + /// + /// `mut_func` - A closure that will run once write access is gained. It iwll be passed a mutable reference + /// to the inner value. + /// + pub fn try_writeable U>( + &mut self, + prep_func: P, + mut_func: F, + ) -> Option { + // First tell everyone that we want to write now, this will prevent any new reader from starting until we are done. + { + let &(ref lock, cvar) = &self.wants_write; + let mut wants_write = lock.lock(); + + RwLockReadGuard::unlocked(&mut self.inner, move || { + while *wants_write { + // We have a writer and we want to wait for it to go away + cvar.wait(&mut wants_write); + } + + *wants_write = true; + }); + } + + // Second stage is to run the prep function + prep_func(&*self.inner); + + let lock = RwLockReadGuard::rwlock(&self.inner); + + // Third stage is to perform our op under a write lock + let ret = Some(RwLockReadGuard::unlocked(&mut self.inner, move || { + // There is no race here because wants_write blocks other threads + let mut write_access = lock.write(); + mut_func(&mut *write_access) + })); + + // Finally signal other threads + let (ref lock, ref cvar) = &self.wants_write; + let mut wants_write = lock.lock(); + *wants_write = false; + cvar.notify_all(); + + ret + } +} + +impl<'a, T: ?Sized> Deref for LockReadGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &T { + &self.inner + } +} diff --git a/lib/boringtun/src/device/drop_privileges.rs b/lib/boringtun/src/device/drop_privileges.rs new file mode 100644 index 0000000..621c4d4 --- /dev/null +++ b/lib/boringtun/src/device/drop_privileges.rs @@ -0,0 +1,75 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use crate::device::Error; +use libc::{gid_t, setgid, setuid, uid_t}; +use std::io; + +#[cfg(target_os = "macos")] +use nix::unistd::User; + +pub fn get_saved_ids() -> Result<(uid_t, gid_t), Error> { + // Get the user name of the sudoer + #[cfg(target_os = "macos")] + match std::env::var("USER") { + Ok(uname) => match User::from_name(&uname) { + Ok(Some(user)) => Ok((uid_t::from(user.uid), gid_t::from(user.gid))), + Err(e) => Err(Error::DropPrivileges(format!( + "Failed parse user; err: {:?}", + e + ))), + Ok(None) => Err(Error::DropPrivileges("Failed to find user".to_owned())), + }, + Err(e) => Err(Error::DropPrivileges(format!( + "Could not get environment variable for user; err: {:?}", + e + ))), + } + #[cfg(not(target_os = "macos"))] + { + use libc::{getlogin, getpwnam}; + + let uname = unsafe { getlogin() }; + if uname.is_null() { + return Err(Error::DropPrivileges("NULL from getlogin".to_owned())); + } + let userinfo = unsafe { getpwnam(uname) }; + if userinfo.is_null() { + return Err(Error::DropPrivileges("NULL from getpwnam".to_owned())); + } + + // Saved group ID + let saved_gid = unsafe { (*userinfo).pw_gid }; + // Saved user ID + let saved_uid = unsafe { (*userinfo).pw_uid }; + + Ok((saved_uid, saved_gid)) + } +} + +pub fn drop_privileges() -> Result<(), Error> { + let (saved_uid, saved_gid) = get_saved_ids()?; + + if -1 == unsafe { setgid(saved_gid) } { + // Set real and effective group ID + return Err(Error::DropPrivileges( + io::Error::last_os_error().to_string(), + )); + } + + if -1 == unsafe { setuid(saved_uid) } { + // Set real and effective user ID + return Err(Error::DropPrivileges( + io::Error::last_os_error().to_string(), + )); + } + + // Validated we can't get sudo back again + if unsafe { (setgid(0) != -1) || (setuid(0) != -1) } { + Err(Error::DropPrivileges( + "Failed to permanently drop privileges".to_owned(), + )) + } else { + Ok(()) + } +} diff --git a/lib/boringtun/src/device/epoll.rs b/lib/boringtun/src/device/epoll.rs new file mode 100644 index 0000000..b6ecaf0 --- /dev/null +++ b/lib/boringtun/src/device/epoll.rs @@ -0,0 +1,416 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::Error; +use libc::*; +use parking_lot::Mutex; +use std::io; +use std::ops::Deref; +use std::os::unix::io::RawFd; +use std::ptr::null_mut; +use std::time::Duration; + +/// A return type for the EventPoll::wait() function +pub enum WaitResult<'a, H> { + /// Event triggered normally + Ok(EventGuard<'a, H>), + /// Event triggered due to End of File conditions + EoF(EventGuard<'a, H>), + /// There was an error + Error(String), +} + +/// Implements a registry of pollable events +pub struct EventPoll { + events: Mutex>>>>, + epoll: RawFd, // The OS epoll +} + +/// A type that hold a reference to a triggered Event +/// While an EventGuard exists for a given Event, it will not be triggered by any other thread +/// Once the EventGuard goes out of scope, the underlying Event will be re-enabled +pub struct EventGuard<'a, H> { + epoll: RawFd, + event: &'a mut Event, + poll: &'a EventPoll, +} + +/// A reference to a single event in an EventPoll +pub struct EventRef { + trigger: RawFd, +} + +struct Event { + event: epoll_event, // The epoll event description + fd: RawFd, // The associated fd + handler: H, // The associated data + notifier: bool, // Is a notification event + needs_read: bool, // This event needs to be read to be cleared +} + +impl Drop for EventPoll { + fn drop(&mut self) { + unsafe { close(self.epoll) }; + } +} + +impl EventPoll { + /// Create a new event registry + pub fn new() -> Result, Error> { + let epoll = match unsafe { epoll_create(1) } { + -1 => return Err(Error::EventQueue(io::Error::last_os_error())), + epoll => epoll, + }; + + Ok(EventPoll { + events: Mutex::new(vec![]), + epoll, + }) + } + + /// Add and enable a new event with the factory. + /// The event is triggered when a Read operation on the provided trigger becomes available + /// If the trigger fd is closed, the event won't be triggered anymore, but it's data won't be + /// automatically released. + /// The safe way to delete an event, is using the cancel method of an EventGuard. + /// If the same trigger is used with multiple events in the same EventPoll, the last added + /// event overrides all previous events. In case the same trigger is used with multiple polls, + /// each event will be triggered independently. + /// The event will keep triggering until a Read operation is no longer possible on the trigger. + /// When triggered, one of the threads waiting on the poll will receive the handler via an + /// appropriate EventGuard. It is guaranteed that only a single thread can have a reference to + /// the handler at any given time. + pub fn new_event(&self, trigger: RawFd, handler: H) -> Result { + // Create an event descriptor + let flags = EPOLLIN | EPOLLONESHOT; + let ev = Event { + event: epoll_event { + events: flags as _, + u64: 0, + }, + fd: trigger, + handler, + notifier: false, + needs_read: false, + }; + + self.register_event(ev) + } + + /// Add and enable a new write event with the factory. + /// The event is triggered when a Write operation on the provided trigger becomes possible + /// For TCP sockets it means that the socket was succesfully connected + #[allow(dead_code)] + pub fn new_write_event(&self, trigger: RawFd, handler: H) -> Result { + // Create an event descriptor + let flags = EPOLLOUT | EPOLLET | EPOLLONESHOT; + let ev = Event { + event: epoll_event { + events: flags as _, + u64: 0, + }, + fd: trigger, + handler, + notifier: false, + needs_read: false, + }; + + self.register_event(ev) + } + + /// Add and enable a new timed event with the factory. + /// The even will be triggered for the first time after period time, and henceforth triggered + /// every period time. Period is counted from the moment the appropriate EventGuard is released. + pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result { + // The periodic event on Linux uses the timerfd + let tfd = match unsafe { timerfd_create(CLOCK_BOOTTIME, TFD_NONBLOCK) } { + -1 => match unsafe { timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK) } { + // A fallback for kernels < 3.15 + -1 => return Err(Error::Timer(io::Error::last_os_error())), + efd => efd, + }, + efd => efd, + }; + + let ts = timespec { + tv_sec: period.as_secs() as _, + tv_nsec: i64::from(period.subsec_nanos()) as _, + }; + + let spec = itimerspec { + it_value: ts, + it_interval: ts, + }; + + if unsafe { timerfd_settime(tfd, 0, &spec, std::ptr::null_mut()) } == -1 { + unsafe { close(tfd) }; + return Err(Error::Timer(io::Error::last_os_error())); + } + + let ev = Event { + event: epoll_event { + events: (EPOLLIN | EPOLLONESHOT) as _, + u64: 0, + }, + fd: tfd, + handler, + notifier: false, + needs_read: true, + }; + + self.register_event(ev) + } + + /// Add and enable a new notification event with the factory. + /// The event can only be triggered manually, using the trigger_notification method. + /// The event will remain in a triggered state until the stop_notification method is + /// called. Both methods should only be called with the producing EventPoll. + pub fn new_notifier(&self, handler: H) -> Result { + // The notifier on Linux uses the eventfd for notifications. + // The way it works is when a non zero value is written into the eventfd it will trigger + // the EPOLLIN event. Since we don't enable ONESHOT it will keep triggering until + // canceled. + // When we want to stop the event, we read something once from the file descriptor. + let efd = match unsafe { eventfd(0, EFD_NONBLOCK) } { + -1 => return Err(Error::EventQueue(io::Error::last_os_error())), + efd => efd, + }; + + let ev = Event { + event: epoll_event { + events: (EPOLLIN) as _, + u64: 0, + }, + fd: efd, + handler, + notifier: true, + needs_read: false, + }; + + self.register_event(ev) + } + + /// Add and enable a new signal handler + pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result { + let sfd = match unsafe { + let mut sigset = std::mem::zeroed(); + sigemptyset(&mut sigset); + sigaddset(&mut sigset, signal); + sigprocmask(SIG_BLOCK, &sigset, null_mut()); + signalfd(-1, &sigset, SFD_NONBLOCK) + } { + -1 => return Err(Error::EventQueue(io::Error::last_os_error())), + sfd => sfd, + }; + + let ev = Event { + event: epoll_event { + events: (EPOLLIN | EPOLLONESHOT) as _, + u64: 0, + }, + fd: sfd, + handler, + notifier: false, + needs_read: true, + }; + + self.register_event(ev) + } + + /// Wait until one of the registered events becomes triggered. Once an event + /// is triggered, a single caller thread gets the handler for that event. + /// In case a notifier is triggered, all waiting threads will receive the same + /// handler. + pub fn wait(&self) -> WaitResult<'_, H> { + let mut event = epoll_event { events: 0, u64: 0 }; + match unsafe { epoll_wait(self.epoll, &mut event, 1, -1) } { + -1 => return WaitResult::Error(io::Error::last_os_error().to_string()), + 1 => {} + _ => return WaitResult::Error("unexpected number of events returned".to_string()), + } + + let event_data = unsafe { (event.u64 as *mut Event).as_mut().unwrap() }; + + let guard = EventGuard { + epoll: self.epoll, + event: event_data, + poll: self, + }; + + if event.events & EPOLLHUP as u32 != 0 { + // End of file flag + WaitResult::EoF(guard) + } else { + WaitResult::Ok(guard) + } + } + + // Register an event with this poll. + fn register_event(&self, ev: Event) -> Result { + // To register an event we + // * Create a reference to self in the inner event + // * Store the Event in the events vector + // * Dispose of a previous Event under same fd if any + // * Add the Event to epoll + let trigger = ev.fd; + let mut ev = Box::new(ev); + // The inner event points back to the wrapper + ev.event.u64 = ev.as_mut() as *mut Event as _; + let mut event_desc = ev.event; + // Now add the pointer to the events vector, this is a place from which we can drop the event + self.insert_at(trigger as _, ev); + // Add the event to epoll + if unsafe { epoll_ctl(self.epoll, EPOLL_CTL_ADD, trigger, &mut event_desc) } == -1 { + return Err(Error::EventQueue(io::Error::last_os_error())); + } + + Ok(EventRef { trigger }) + } + + // Insert an event into the events vector + fn insert_at(&self, index: usize, data: Box>) { + let mut events = self.events.lock(); + while events.len() <= index { + // Resize the vector to be able to fit the new index + // We trust the OS to allocate file descriptors in a sane order + events.push(None); // resize doesn't work because Clone is not satisfied + } + + if events[index].take().is_some() { + // Properly remove the previous event first + unsafe { + epoll_ctl(self.epoll, EPOLL_CTL_DEL, index as _, null_mut()); + }; + } + + events[index] = Some(data); + } + + /// Trigger a notification + pub fn trigger_notification(&self, notification_event: &EventRef) { + let events = self.events.lock(); + + let event_ref = &(*events)[notification_event.trigger as usize]; + let event_data = event_ref.as_ref().expect("Expected an event"); + + if !event_data.notifier { + panic!("Can only trigger a notification event"); + } + + // Write some data to the eventfd to trigger an EPOLLIN event + unsafe { + write( + notification_event.trigger, + &(std::u64::MAX - 1).to_ne_bytes()[0] as *const u8 as _, + 8, + ) + }; + } + + /// Stop a notification + pub fn stop_notification(&self, notification_event: &EventRef) { + let events = self.events.lock(); + + let event_ref = &(*events)[notification_event.trigger as usize]; + let event_data = event_ref.as_ref().expect("Expected an event"); + + if !event_data.notifier { + panic!("Can only trigger a notification event"); + } + + let mut buf = [0u8; 8]; + unsafe { + read( + notification_event.trigger, + buf.as_mut_ptr() as _, + buf.len() as _, + ) + }; + } +} + +impl EventPoll { + /// Disable and remove the event and associated handler, using the fd that + /// was used to register it. + /// + /// # Safety + /// + /// This function is only safe to call when the event loop is not running, + /// otherwise the memory of the handler may get freed while in use. + pub unsafe fn clear_event_by_fd(&self, index: RawFd) { + let mut events = self.events.lock(); + assert!(index >= 0); + if events[index as usize].take().is_some() { + epoll_ctl(self.epoll, EPOLL_CTL_DEL, index, null_mut()); + } + } +} + +impl<'a, H> Deref for EventGuard<'a, H> { + type Target = H; + fn deref(&self) -> &H { + &self.event.handler + } +} + +impl<'a, H> Drop for EventGuard<'a, H> { + fn drop(&mut self) { + if self.event.needs_read { + // Must read from the event to reset it before we enable it + let mut buf: [std::mem::MaybeUninit; 256] = + unsafe { std::mem::MaybeUninit::uninit().assume_init() }; + while unsafe { read(self.event.fd, buf.as_mut_ptr() as _, buf.len() as _) } != -1 {} + } + + unsafe { + epoll_ctl( + self.epoll, + EPOLL_CTL_MOD, + self.event.fd, + &mut self.event.event, + ); + } + } +} + +impl<'a, H> EventGuard<'a, H> { + /// Get a mutable reference to the stored value + #[allow(dead_code)] + pub fn get_mut(&mut self) -> &mut H { + &mut self.event.handler + } + + /// Cancel and remove the event referenced by this guard + pub fn cancel(self) { + unsafe { self.poll.clear_event_by_fd(self.event.fd) }; + std::mem::forget(self); // Don't call the regular drop that would enable the event + } + + pub fn fd(&self) -> i32 { + self.event.fd + } + + /// Change the event flags to enable or disable notifying when the fd is writable + pub fn notify_writable(&mut self, enabled: bool) { + let flags = if enabled { + EPOLLOUT | EPOLLIN | EPOLLET | EPOLLONESHOT + } else { + EPOLLIN | EPOLLONESHOT + }; + self.event.event.events = flags as _; + } +} + +pub fn block_signal(signal: c_int) -> Result { + unsafe { + let mut sigset = std::mem::zeroed(); + sigemptyset(&mut sigset); + if sigaddset(&mut sigset, signal) == -1 { + return Err(io::Error::last_os_error().to_string()); + } + if sigprocmask(SIG_BLOCK, &sigset, null_mut()) == -1 { + return Err(io::Error::last_os_error().to_string()); + } + Ok(sigset) + } +} diff --git a/lib/boringtun/src/device/integration_tests/mod.rs b/lib/boringtun/src/device/integration_tests/mod.rs new file mode 100644 index 0000000..b4e360c --- /dev/null +++ b/lib/boringtun/src/device/integration_tests/mod.rs @@ -0,0 +1,849 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +// This module contains some integration tests for boringtun +// Those tests require docker and sudo privileges to run +#[cfg(all(test, not(target_os = "macos")))] +mod tests { + use crate::device::{DeviceConfig, DeviceHandle}; + use crate::x25519::{PublicKey, StaticSecret}; + use base64::encode as base64encode; + use hex::encode; + use rand_core::OsRng; + use ring::rand::{SecureRandom, SystemRandom}; + use std::fmt::Write as _; + use std::io::{BufRead, BufReader, Read, Write}; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::os::unix::net::UnixStream; + use std::process::Command; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::thread; + + static NEXT_IFACE_IDX: AtomicUsize = AtomicUsize::new(100); // utun 100+ should be vacant during testing on CI + static NEXT_PORT: AtomicUsize = AtomicUsize::new(61111); // Use ports starting with 61111, hoping we don't run into a taken port 🤷 + static NEXT_IP: AtomicUsize = AtomicUsize::new(0xc0000200); // Use 192.0.2.0/24 for those tests, we might use more than 256 addresses though, usize must be >=32 bits on all supported platforms + static NEXT_IP_V6: AtomicUsize = AtomicUsize::new(0); // Use the 2001:db8:: address space, append this atomic counter for bottom 32 bits + + fn next_ip() -> IpAddr { + IpAddr::V4(Ipv4Addr::from( + NEXT_IP.fetch_add(1, Ordering::Relaxed) as u32 + )) + } + + fn next_ip_v6() -> IpAddr { + let addr = 0x2001_0db8_0000_0000_0000_0000_0000_0000_u128 + + u128::from(NEXT_IP_V6.fetch_add(1, Ordering::Relaxed) as u32); + + IpAddr::V6(Ipv6Addr::from(addr)) + } + + fn next_port() -> u16 { + NEXT_PORT.fetch_add(1, Ordering::Relaxed) as u16 + } + + /// Represents an allowed IP and cidr for a peer + struct AllowedIp { + ip: IpAddr, + cidr: u8, + } + + /// Represents a single peer running in a container + struct Peer { + key: StaticSecret, + endpoint: SocketAddr, + allowed_ips: Vec, + container_name: Option, + } + + /// Represents a single WireGuard interface on local machine + struct WGHandle { + _device: DeviceHandle, + name: String, + addr_v4: IpAddr, + addr_v6: IpAddr, + started: bool, + peers: Vec>, + } + + impl Drop for Peer { + fn drop(&mut self) { + if let Some(name) = &self.container_name { + Command::new("docker") + .args([ + "stop", // Run docker + &name[5..], + ]) + .status() + .ok(); + + std::fs::remove_file(name).ok(); + std::fs::remove_file(format!("{}.ngx", name)).ok(); + } + } + } + + impl Peer { + /// Create a new peer with a given endpoint and a list of allowed IPs + fn new(endpoint: SocketAddr, allowed_ips: Vec) -> Peer { + Peer { + key: StaticSecret::random_from_rng(OsRng), + endpoint, + allowed_ips, + container_name: None, + } + } + + /// Creates a new configuration file that can be used by wg-quick + fn gen_wg_conf( + &self, + local_key: &PublicKey, + local_addr: &IpAddr, + local_port: u16, + ) -> String { + let mut conf = String::from("[Interface]\n"); + // Each allowed ip, becomes a possible address in the config + for ip in &self.allowed_ips { + let _ = writeln!(conf, "Address = {}/{}", ip.ip, ip.cidr); + } + + // The local endpoint port is the remote listen port + let _ = writeln!(conf, "ListenPort = {}", self.endpoint.port()); + // HACK: this should consume the key so it can't be reused instead of cloning and serializing + let _ = writeln!(conf, "PrivateKey = {}", base64encode(self.key.to_bytes())); + + // We are the peer + let _ = writeln!(conf, "[Peer]"); + let _ = writeln!(conf, "PublicKey = {}", base64encode(local_key.as_bytes())); + let _ = writeln!(conf, "AllowedIPs = {}", local_addr); + let _ = write!(conf, "Endpoint = 127.0.0.1:{}", local_port); + + conf + } + + /// Create a simple nginx config, that will respond with the peer public key + fn gen_nginx_conf(&self) -> String { + format!( + "server {{\n\ + listen 80;\n\ + listen [::]:80;\n\ + location / {{\n\ + return 200 '{}';\n\ + }}\n\ + }}", + encode(PublicKey::from(&self.key).as_bytes()) + ) + } + + fn start_in_container( + &mut self, + local_key: &PublicKey, + local_addr: &IpAddr, + local_port: u16, + ) { + let peer_config = self.gen_wg_conf(local_key, local_addr, local_port); + let peer_config_file = temp_path(); + std::fs::write(&peer_config_file, peer_config).unwrap(); + let nginx_config = self.gen_nginx_conf(); + let nginx_config_file = format!("{}.ngx", peer_config_file); + std::fs::write(&nginx_config_file, nginx_config).unwrap(); + + Command::new("docker") + .args([ + "run", // Run docker + "-d", // In detached mode + "--cap-add=NET_ADMIN", // Grant permissions to open a tunnel + "--device=/dev/net/tun", + "--sysctl", // Enable ipv6 + "net.ipv6.conf.all.disable_ipv6=0", + "--sysctl", + "net.ipv6.conf.default.disable_ipv6=0", + "-p", // Open port for the endpoint + &format!("{0}:{0}/udp", self.endpoint.port()), + "-v", // Map the generated WireGuard config file + &format!("{}:/wireguard/wg.conf", peer_config_file), + "-v", // Map the nginx config file + &format!("{}:/etc/nginx/conf.d/default.conf", nginx_config_file), + "--rm", // Cleanup + "--name", + &peer_config_file[5..], + "vkrasnov/wireguard-test", + ]) + .status() + .expect("Failed to run docker"); + + self.container_name = Some(peer_config_file); + } + + fn connect(&self) -> std::net::TcpStream { + let http_addr = SocketAddr::new(self.allowed_ips[0].ip, 80); + for _i in 0..5 { + let res = std::net::TcpStream::connect(http_addr); + if let Err(err) = res { + println!("failed to connect: {:?}", err); + std::thread::sleep(std::time::Duration::from_millis(100)); + continue; + } + + return res.unwrap(); + } + + panic!("failed to connect"); + } + + fn get_request(&self) -> String { + let mut tcp_conn = self.connect(); + + write!( + tcp_conn, + "GET / HTTP/1.1\nHost: localhost\nAccept: */*\nConnection: close\n\n" + ) + .unwrap(); + + tcp_conn + .set_read_timeout(Some(std::time::Duration::from_secs(60))) + .ok(); + + let mut reader = BufReader::new(tcp_conn); + let mut line = String::new(); + let mut response = String::new(); + let mut len = 0usize; + + // Read response code + if reader.read_line(&mut line).is_ok() && !line.starts_with("HTTP/1.1 200") { + return response; + } + line.clear(); + + // Read headers + while reader.read_line(&mut line).is_ok() { + if line.trim() == "" { + break; + } + + { + let parsed_line: Vec<&str> = line.split(':').collect(); + if parsed_line.len() < 2 { + return response; + } + + let (key, val) = (parsed_line[0], parsed_line[1]); + if key.to_lowercase() == "content-length" { + len = match val.trim().parse() { + Err(_) => return response, + Ok(len) => len, + }; + } + } + line.clear(); + } + + // Read body + let mut buf = [0u8; 256]; + while len > 0 { + let to_read = len.min(buf.len()); + if reader.read_exact(&mut buf[..to_read]).is_err() { + return response; + } + response.push_str(&String::from_utf8_lossy(&buf[..to_read])); + len -= to_read; + } + + response + } + } + + impl WGHandle { + /// Create a new interface for the tunnel with the given address + fn init(addr_v4: IpAddr, addr_v6: IpAddr) -> WGHandle { + WGHandle::init_with_config( + addr_v4, + addr_v6, + DeviceConfig { + n_threads: 2, + use_connected_socket: true, + #[cfg(target_os = "linux")] + use_multi_queue: true, + #[cfg(target_os = "linux")] + uapi_fd: -1, + }, + ) + } + + /// Create a new interface for the tunnel with the given address + fn init_with_config(addr_v4: IpAddr, addr_v6: IpAddr, config: DeviceConfig) -> WGHandle { + // Generate a new name, utun100+ should work on macOS and Linux + let name = format!("utun{}", NEXT_IFACE_IDX.fetch_add(1, Ordering::Relaxed)); + let _device = DeviceHandle::new(&name, config).unwrap(); + WGHandle { + _device, + name, + addr_v4, + addr_v6, + started: false, + peers: vec![], + } + } + + #[cfg(target_os = "macos")] + /// Starts the tunnel + fn start(&mut self) { + // Assign the ipv4 address to the interface + Command::new("ifconfig") + .args(&[ + &self.name, + &self.addr_v4.to_string(), + &self.addr_v4.to_string(), + "alias", + ]) + .status() + .expect("failed to assign ip to tunnel"); + + // Assign the ipv6 address to the interface + Command::new("ifconfig") + .args(&[ + &self.name, + "inet6", + &self.addr_v6.to_string(), + "prefixlen", + "128", + "alias", + ]) + .status() + .expect("failed to assign ipv6 to tunnel"); + + // Start the tunnel + Command::new("ifconfig") + .args(&[&self.name, "up"]) + .status() + .expect("failed to start the tunnel"); + + self.started = true; + + // Add each peer to the routing table + for p in &self.peers { + for r in &p.allowed_ips { + let inet_flag = match r.ip { + IpAddr::V4(_) => "-inet", + IpAddr::V6(_) => "-inet6", + }; + + Command::new("route") + .args(&[ + "-q", + "-n", + "add", + inet_flag, + &format!("{}/{}", r.ip, r.cidr), + "-interface", + &self.name, + ]) + .status() + .expect("failed to add route"); + } + } + } + + #[cfg(target_os = "linux")] + /// Starts the tunnel + fn start(&mut self) { + Command::new("ip") + .args([ + "address", + "add", + &self.addr_v4.to_string(), + "dev", + &self.name, + ]) + .status() + .expect("failed to assign ip to tunnel"); + + Command::new("ip") + .args([ + "address", + "add", + &self.addr_v6.to_string(), + "dev", + &self.name, + ]) + .status() + .expect("failed to assign ipv6 to tunnel"); + + // Start the tunnel + Command::new("ip") + .args(["link", "set", "mtu", "1400", "up", "dev", &self.name]) + .status() + .expect("failed to start the tunnel"); + + self.started = true; + + // Add each peer to the routing table + for p in &self.peers { + for r in &p.allowed_ips { + Command::new("ip") + .args([ + "route", + "add", + &format!("{}/{}", r.ip, r.cidr), + "dev", + &self.name, + ]) + .status() + .expect("failed to add route"); + } + } + } + + /// Issue a get command on the interface + fn wg_get(&self) -> String { + let path = format!("/var/run/wireguard/{}.sock", self.name); + + let mut socket = UnixStream::connect(path).unwrap(); + write!(socket, "get=1\n\n").unwrap(); + + let mut ret = String::new(); + socket.read_to_string(&mut ret).unwrap(); + ret + } + + /// Issue a set command on the interface + fn wg_set(&self, setting: &str) -> String { + let path = format!("/var/run/wireguard/{}.sock", self.name); + let mut socket = UnixStream::connect(path).unwrap(); + write!(socket, "set=1\n{}\n\n", setting).unwrap(); + + let mut ret = String::new(); + socket.read_to_string(&mut ret).unwrap(); + ret + } + + /// Assign a listen_port to the interface + fn wg_set_port(&self, port: u16) -> String { + self.wg_set(&format!("listen_port={}", port)) + } + + /// Assign a private_key to the interface + fn wg_set_key(&self, key: StaticSecret) -> String { + self.wg_set(&format!("private_key={}", encode(key.to_bytes()))) + } + + /// Assign a peer to the interface (with public_key, endpoint and a series of nallowed_ip) + fn wg_set_peer( + &self, + key: &PublicKey, + ep: &SocketAddr, + allowed_ips: &[AllowedIp], + ) -> String { + let mut req = format!("public_key={}\nendpoint={}", encode(key.as_bytes()), ep); + for AllowedIp { ip, cidr } in allowed_ips { + let _ = write!(req, "\nallowed_ip={}/{}", ip, cidr); + } + + self.wg_set(&req) + } + + /// Add a new known peer + fn add_peer(&mut self, peer: Arc) { + self.wg_set_peer( + &PublicKey::from(&peer.key), + &peer.endpoint, + &peer.allowed_ips, + ); + self.peers.push(peer); + } + } + + /// Create a new filename in the /tmp dir + fn temp_path() -> String { + let mut path = String::from("/tmp/"); + let mut buf = [0u8; 32]; + SystemRandom::new().fill(&mut buf[..]).unwrap(); + path.push_str(&encode(buf)); + path + } + + #[test] + #[ignore] + /// Test if wireguard starts and creates a unix socket that we can read from + fn test_wireguard_get() { + let wg = WGHandle::init("192.0.2.0".parse().unwrap(), "::2".parse().unwrap()); + let response = wg.wg_get(); + assert!(response.ends_with("errno=0\n\n")); + } + + #[test] + #[ignore] + /// Test if wireguard starts and creates a unix socket that we can use to set settings + fn test_wireguard_set() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let own_public_key = PublicKey::from(&private_key); + + let wg = WGHandle::init("192.0.2.0".parse().unwrap(), "::2".parse().unwrap()); + assert!(wg.wg_get().ends_with("errno=0\n\n")); + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + // Check that the response matches what we expect + assert_eq!( + wg.wg_get(), + format!( + "own_public_key={}\nlisten_port={}\nerrno=0\n\n", + encode(own_public_key.as_bytes()), + port + ) + ); + + let peer_key = StaticSecret::random_from_rng(OsRng); + let peer_pub_key = PublicKey::from(&peer_key); + let endpoint = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 0, 0, 1)), 50001); + let allowed_ips = [ + AllowedIp { + ip: IpAddr::V4(Ipv4Addr::new(172, 0, 0, 2)), + cidr: 32, + }, + AllowedIp { + ip: IpAddr::V6(Ipv6Addr::new(0xf120, 0, 0, 2, 2, 2, 0, 0)), + cidr: 100, + }, + ]; + + assert_eq!( + wg.wg_set_peer(&peer_pub_key, &endpoint, &allowed_ips), + "errno=0\n\n" + ); + + // Check that the response matches what we expect + assert_eq!( + wg.wg_get(), + format!( + "own_public_key={}\n\ + listen_port={}\n\ + public_key={}\n\ + endpoint={}\n\ + allowed_ip={}/{}\n\ + allowed_ip={}/{}\n\ + rx_bytes=0\n\ + tx_bytes=0\n\ + errno=0\n\n", + encode(own_public_key.as_bytes()), + port, + encode(peer_pub_key.as_bytes()), + endpoint, + allowed_ips[0].ip, + allowed_ips[0].cidr, + allowed_ips[1].ip, + allowed_ips[1].cidr + ) + ); + } + + /// Test if wireguard can handle simple ipv4 connections, don't use a connected socket + #[test] + #[ignore] + fn test_wg_start_ipv4_non_connected() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let public_key = PublicKey::from(&private_key); + let addr_v4 = next_ip(); + let addr_v6 = next_ip_v6(); + + let mut wg = WGHandle::init_with_config( + addr_v4, + addr_v6, + DeviceConfig { + n_threads: 2, + use_connected_socket: false, + #[cfg(target_os = "linux")] + use_multi_queue: true, + #[cfg(target_os = "linux")] + uapi_fd: -1, + }, + ); + + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + // Create a new peer whose endpoint is on this machine + let mut peer = Peer::new( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()), + vec![AllowedIp { + ip: next_ip(), + cidr: 32, + }], + ); + + peer.start_in_container(&public_key, &addr_v4, port); + + let peer = Arc::new(peer); + + wg.add_peer(Arc::clone(&peer)); + wg.start(); + + let response = peer.get_request(); + + assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes())); + } + + /// Test if wireguard can handle simple ipv4 connections + #[test] + #[ignore] + fn test_wg_start_ipv4() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let public_key = PublicKey::from(&private_key); + let addr_v4 = next_ip(); + let addr_v6 = next_ip_v6(); + + let mut wg = WGHandle::init(addr_v4, addr_v6); + + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + // Create a new peer whose endpoint is on this machine + let mut peer = Peer::new( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()), + vec![AllowedIp { + ip: next_ip(), + cidr: 32, + }], + ); + + peer.start_in_container(&public_key, &addr_v4, port); + + let peer = Arc::new(peer); + + wg.add_peer(Arc::clone(&peer)); + wg.start(); + + let response = peer.get_request(); + + assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes())); + } + + #[test] + #[ignore] + /// Test if wireguard can handle simple ipv6 connections + fn test_wg_start_ipv6() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let public_key = PublicKey::from(&private_key); + let addr_v4 = next_ip(); + let addr_v6 = next_ip_v6(); + + let mut wg = WGHandle::init(addr_v4, addr_v6); + + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + let mut peer = Peer::new( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()), + vec![AllowedIp { + ip: next_ip_v6(), + cidr: 128, + }], + ); + + peer.start_in_container(&public_key, &addr_v6, port); + + let peer = Arc::new(peer); + + wg.add_peer(Arc::clone(&peer)); + wg.start(); + + let response = peer.get_request(); + + assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes())); + } + + /// Test if wireguard can handle connection with an ipv6 endpoint + #[test] + #[ignore] + #[cfg(target_os = "linux")] // Can't make docker work with ipv6 on macOS ATM + fn test_wg_start_ipv6_endpoint() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let public_key = PublicKey::from(&private_key); + let addr_v4 = next_ip(); + let addr_v6 = next_ip_v6(); + + let mut wg = WGHandle::init(addr_v4, addr_v6); + + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + let mut peer = Peer::new( + SocketAddr::new( + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + next_port(), + ), + vec![AllowedIp { + ip: next_ip_v6(), + cidr: 128, + }], + ); + + peer.start_in_container(&public_key, &addr_v6, port); + + let peer = Arc::new(peer); + + wg.add_peer(Arc::clone(&peer)); + wg.start(); + + let response = peer.get_request(); + + assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes())); + } + + /// Test if wireguard can handle connection with an ipv6 endpoint + #[test] + #[ignore] + #[cfg(target_os = "linux")] // Can't make docker work with ipv6 on macOS ATM + fn test_wg_start_ipv6_endpoint_not_connected() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let public_key = PublicKey::from(&private_key); + let addr_v4 = next_ip(); + let addr_v6 = next_ip_v6(); + + let mut wg = WGHandle::init_with_config( + addr_v4, + addr_v6, + DeviceConfig { + n_threads: 2, + use_connected_socket: false, + #[cfg(target_os = "linux")] + use_multi_queue: true, + #[cfg(target_os = "linux")] + uapi_fd: -1, + }, + ); + + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + let mut peer = Peer::new( + SocketAddr::new( + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + next_port(), + ), + vec![AllowedIp { + ip: next_ip_v6(), + cidr: 128, + }], + ); + + peer.start_in_container(&public_key, &addr_v6, port); + + let peer = Arc::new(peer); + + wg.add_peer(Arc::clone(&peer)); + wg.start(); + + let response = peer.get_request(); + + assert_eq!(response, encode(PublicKey::from(&peer.key).as_bytes())); + } + + /// Test many concurrent connections + #[test] + #[ignore] + fn test_wg_concurrent() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let public_key = PublicKey::from(&private_key); + let addr_v4 = next_ip(); + let addr_v6 = next_ip_v6(); + + let mut wg = WGHandle::init(addr_v4, addr_v6); + + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + for _ in 0..5 { + // Create a new peer whose endpoint is on this machine + let mut peer = Peer::new( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()), + vec![AllowedIp { + ip: next_ip(), + cidr: 32, + }], + ); + + peer.start_in_container(&public_key, &addr_v4, port); + + let peer = Arc::new(peer); + + wg.add_peer(Arc::clone(&peer)); + } + + wg.start(); + + let mut threads = vec![]; + + for p in wg.peers { + let pub_key = PublicKey::from(&p.key); + threads.push(thread::spawn(move || { + for _ in 0..100 { + let response = p.get_request(); + assert_eq!(response, encode(pub_key.as_bytes())); + } + })); + } + + for t in threads { + t.join().unwrap(); + } + } + + /// Test many concurrent connections + #[test] + #[ignore] + fn test_wg_concurrent_v6() { + let port = next_port(); + let private_key = StaticSecret::random_from_rng(OsRng); + let public_key = PublicKey::from(&private_key); + let addr_v4 = next_ip(); + let addr_v6 = next_ip_v6(); + + let mut wg = WGHandle::init(addr_v4, addr_v6); + + assert_eq!(wg.wg_set_port(port), "errno=0\n\n"); + assert_eq!(wg.wg_set_key(private_key), "errno=0\n\n"); + + for _ in 0..5 { + // Create a new peer whose endpoint is on this machine + let mut peer = Peer::new( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), next_port()), + vec![AllowedIp { + ip: next_ip_v6(), + cidr: 128, + }], + ); + + peer.start_in_container(&public_key, &addr_v6, port); + + let peer = Arc::new(peer); + + wg.add_peer(Arc::clone(&peer)); + } + + wg.start(); + + let mut threads = vec![]; + + for p in wg.peers { + let pub_key = PublicKey::from(&p.key); + threads.push(thread::spawn(move || { + for _ in 0..100 { + let response = p.get_request(); + assert_eq!(response, encode(pub_key.as_bytes())); + } + })); + } + + for t in threads { + t.join().unwrap(); + } + } +} diff --git a/lib/boringtun/src/device/kqueue.rs b/lib/boringtun/src/device/kqueue.rs new file mode 100644 index 0000000..638665e --- /dev/null +++ b/lib/boringtun/src/device/kqueue.rs @@ -0,0 +1,337 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::Error; +use libc::*; +use parking_lot::Mutex; +use std::io; +use std::ops::Deref; +use std::os::unix::io::RawFd; +use std::ptr::{null, null_mut}; +use std::time::Duration; + +/// A return type for the EventPoll::wait() function +pub enum WaitResult<'a, H> { + /// Event triggered normally + Ok(EventGuard<'a, H>), + /// Event triggered due to End of File conditions + EoF(EventGuard<'a, H>), + /// There was an error + Error(String), +} + +/// Implements a registry of pollable events +pub struct EventPoll { + events: Mutex>>>>, // Events with a file descriptor + custom: Mutex>>>>, // Other events (i.e. timers & notifiers) + signals: Mutex>>>>, // Signal handlers + kqueue: RawFd, // The OS kqueue +} + +/// A type that hold a reference to a triggered Event +/// While an EventGuard exists for a given Event, it will not be triggered by any other thread +/// Once the EventGuard goes out of scope, the underlying Event will be re-enabled +pub struct EventGuard<'a, H> { + kqueue: RawFd, + event: &'a Event, + poll: &'a EventPoll, +} + +/// A reference to a single event in an EventPoll +pub struct EventRef { + trigger: RawFd, +} + +#[derive(PartialEq)] +enum EventKind { + FD, + Notifier, + Signal, + Timer, +} + +// A single event +struct Event { + event: kevent, // The kqueue event description + handler: H, // The associated data + kind: EventKind, +} + +impl Drop for EventPoll { + fn drop(&mut self) { + unsafe { close(self.kqueue) }; + } +} + +unsafe impl Send for EventPoll {} +unsafe impl Sync for EventPoll {} + +impl EventPoll { + /// Create a new event registry + pub fn new() -> Result, Error> { + let kqueue = match unsafe { kqueue() } { + -1 => return Err(Error::EventQueue(io::Error::last_os_error())), + kqueue => kqueue, + }; + + Ok(EventPoll { + events: Mutex::new(vec![]), + custom: Mutex::new(vec![]), + signals: Mutex::new(vec![]), + kqueue, + }) + } + + /// Add and enable a new event with the factory. + /// The event is triggered when a Read operation on the provided trigger becomes available + /// If the trigger fd is closed, the event won't be triggered anymore, but it's data won't be + /// automatically released. + /// The safe way to delete an event, is using the cancel method of an EventGuard. + /// If the same trigger is used with multiple events in the same EventPoll, the last added + /// event overrides all previous events. In case the same trigger is used with multiple polls, + /// each event will be triggered independently. + /// The event will keep triggering until a Read operation is no longer possible on the trigger. + /// When triggered, one of the threads waiting on the poll will receive the handler via an + /// appropriate EventGuard. It is guaranteed that only a single thread can have a reference to + /// the handler at any given time. + pub fn new_event(&self, trigger: RawFd, handler: H) -> Result { + // Create an event descriptor + let flags = EV_ENABLE | EV_DISPATCH; + + let ev = Event { + event: kevent { + ident: trigger as _, + filter: EVFILT_READ, + flags, + fflags: 0, + data: 0, + udata: null_mut(), + }, + handler, + kind: EventKind::FD, + }; + + self.register_event(ev) + } + + pub fn new_periodic_event(&self, handler: H, period: Duration) -> Result { + // The periodic event in BSD uses EVFILT_TIMER + let ev = Event { + event: kevent { + ident: 0, + filter: EVFILT_TIMER, + flags: EV_ENABLE | EV_DISPATCH, + fflags: NOTE_NSECONDS, + data: period + .as_secs() + .checked_mul(1_000_000_000) + .unwrap() + .checked_add(u64::from(period.subsec_nanos())) + .unwrap() as _, + udata: null_mut(), + }, + handler, + kind: EventKind::Timer, + }; + + self.register_event(ev) + } + + pub fn new_notifier(&self, handler: H) -> Result { + // The notifier in BSD uses EVFILT_USER for notifications. + let ev = Event { + event: kevent { + ident: 0, + filter: EVFILT_USER, + flags: EV_ENABLE, + fflags: 0, + data: 0, + udata: null_mut(), + }, + handler, + kind: EventKind::Notifier, + }; + + self.register_event(ev) + } + + /// Add and enable a new signal handler + pub fn new_signal_event(&self, signal: c_int, handler: H) -> Result { + let ev = Event { + event: kevent { + ident: signal as _, + filter: EVFILT_SIGNAL, + flags: EV_ENABLE | EV_DISPATCH, + fflags: 0, + data: 0, + udata: null_mut(), + }, + handler, + kind: EventKind::Signal, + }; + + self.register_event(ev) + } + + /// Wait until one of the registered events becomes triggered. Once an event + /// is triggered, a single caller thread gets the handler for that event. + /// In case a notifier is triggered, all waiting threads will receive the same + /// handler. + pub fn wait(&'_ self) -> WaitResult<'_, H> { + let mut event = kevent { + ident: 0, + filter: 0, + flags: 0, + fflags: 0, + data: 0, + udata: null_mut(), + }; + + if unsafe { kevent(self.kqueue, null(), 0, &mut event, 1, null()) } == -1 { + return WaitResult::Error(io::Error::last_os_error().to_string()); + } + + let event_data = unsafe { (event.udata as *mut Event).as_ref().unwrap() }; + + let guard = EventGuard { + kqueue: self.kqueue, + event: event_data, + poll: self, + }; + + if event.flags & EV_EOF != 0 { + WaitResult::EoF(guard) + } else { + WaitResult::Ok(guard) + } + } + + // Register an event with this poll. + fn register_event(&self, ev: Event) -> Result { + let mut events = match ev.kind { + EventKind::FD => self.events.lock(), + EventKind::Timer | EventKind::Notifier => self.custom.lock(), + EventKind::Signal => self.signals.lock(), + }; + + let (trigger, index) = match ev.kind { + EventKind::FD | EventKind::Signal => (ev.event.ident as RawFd, ev.event.ident as usize), + EventKind::Timer | EventKind::Notifier => (-(events.len() as RawFd) - 1, events.len()), // Custom events get negative identifiers, hopefully we will never have more than 2^31 events of each type + }; + + // Expand events vector if needed + while events.len() <= index { + // Resize the vector to be able to fit the new index + // We trust the OS to allocate file descriptors in a sane order + events.push(None); // resize doesn't work because Clone is not satisfied + } + + let mut ev = Box::new(ev); + // The inner event points back to the wrapper + ev.event.ident = trigger as _; + ev.event.udata = ev.as_mut() as *mut Event as _; + + let mut kev = ev.event; + kev.flags |= EV_ADD; + + if unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) } == -1 { + return Err(Error::EventQueue(io::Error::last_os_error())); + } + + if let Some(mut event) = events[index].take() { + // Properly remove any previous event first + event.event.flags = EV_DELETE; + unsafe { kevent(self.kqueue, &event.event, 1, null_mut(), 0, null()) }; + } + + if ev.kind == EventKind::Signal { + // Mask the signal if successfully added to kqueue + unsafe { signal(trigger, SIG_IGN) }; + } + + events[index] = Some(ev); + + Ok(EventRef { trigger }) + } + + pub fn trigger_notification(&self, notification_event: &EventRef) { + let events = self.custom.lock(); + let ev_index = -notification_event.trigger - 1; // Custom events have negative index from -1 + + let event_ref = &(*events)[ev_index as usize]; + let event_data = event_ref.as_ref().expect("Expected an event"); + + if event_data.kind != EventKind::Notifier { + panic!("Can only trigger a notification event"); + } + + let mut kev = event_data.event; + kev.fflags = NOTE_TRIGGER; + + unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) }; + } + + pub fn stop_notification(&self, notification_event: &EventRef) { + let events = self.custom.lock(); + let ev_index = -notification_event.trigger - 1; // Custom events have negative index from -1 + + let event_ref = &(*events)[ev_index as usize]; + let event_data = event_ref.as_ref().expect("Expected an event"); + + if event_data.kind != EventKind::Notifier { + panic!("Can only stop a notification event"); + } + + let mut kev = event_data.event; + kev.flags = EV_DISABLE; + kev.fflags = 0; + + unsafe { kevent(self.kqueue, &kev, 1, null_mut(), 0, null()) }; + } +} + +impl EventPoll { + // This function is only safe to call when the event loop is not running + pub unsafe fn clear_event_by_fd(&self, index: RawFd) { + let (mut events, index) = if index >= 0 { + (self.events.lock(), index as usize) + } else { + (self.custom.lock(), (-index - 1) as usize) + }; + + if let Some(mut event) = events[index].take() { + // Properly remove any previous event first + event.event.flags = EV_DELETE; + kevent(self.kqueue, &event.event, 1, null_mut(), 0, null()); + } + } +} + +impl<'a, H> Deref for EventGuard<'a, H> { + type Target = H; + fn deref(&self) -> &H { + &self.event.handler + } +} + +impl<'a, H> Drop for EventGuard<'a, H> { + fn drop(&mut self) { + unsafe { + // Re-enable the event once EventGuard goes out of scope + kevent(self.kqueue, &self.event.event, 1, null_mut(), 0, null()); + } + } +} + +impl<'a, H> EventGuard<'a, H> { + /// Cancel and remove the event represented by this guard + pub fn cancel(self) { + unsafe { self.poll.clear_event_by_fd(self.event.event.ident as RawFd) }; + std::mem::forget(self); // Don't call the regular drop that would enable the event + } + + /// Stub: only used for Linux-specific features. + pub fn fd(&self) -> i32 { + -1 + } +} diff --git a/lib/boringtun/src/device/mod.rs b/lib/boringtun/src/device/mod.rs new file mode 100644 index 0000000..b250f5e --- /dev/null +++ b/lib/boringtun/src/device/mod.rs @@ -0,0 +1,884 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +pub mod allowed_ips; +pub mod api; +mod dev_lock; +pub mod drop_privileges; +#[cfg(test)] +mod integration_tests; +pub mod peer; + +#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))] +#[path = "kqueue.rs"] +pub mod poll; + +#[cfg(target_os = "linux")] +#[path = "epoll.rs"] +pub mod poll; + +#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))] +#[path = "tun_darwin.rs"] +pub mod tun; + +#[cfg(target_os = "linux")] +#[path = "tun_linux.rs"] +pub mod tun; + +use std::collections::HashMap; +use std::io::{self, Write as _}; +use std::mem::MaybeUninit; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::os::unix::io::AsRawFd; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; +use std::thread::JoinHandle; + +use crate::noise::errors::WireGuardError; +use crate::noise::handshake::parse_handshake_anon; +use crate::noise::rate_limiter::RateLimiter; +use crate::noise::{Packet, Tunn, TunnResult}; +use crate::x25519; +use allowed_ips::AllowedIps; +use parking_lot::Mutex; +use peer::{AllowedIP, Peer}; +use poll::{EventPoll, EventRef, WaitResult}; +use rand_core::{OsRng, RngCore}; +use socket2::{Domain, Protocol, Type}; +use tun::TunSocket; + +use dev_lock::{Lock, LockReadGuard}; + +const HANDSHAKE_RATE_LIMIT: u64 = 100; // The number of handshakes per second we can tolerate before using cookies + +const MAX_UDP_SIZE: usize = (1 << 16) - 1; +const MAX_ITR: usize = 100; // Number of packets to handle per handler call + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("i/o error: {0}")] + IoError(#[from] io::Error), + #[error("{0}")] + Socket(io::Error), + #[error("{0}")] + Bind(String), + #[error("{0}")] + FCntl(io::Error), + #[error("{0}")] + EventQueue(io::Error), + #[error("{0}")] + IOCtl(io::Error), + #[error("{0}")] + Connect(String), + #[error("{0}")] + SetSockOpt(String), + #[error("Invalid tunnel name")] + InvalidTunnelName, + #[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))] + #[error("{0}")] + GetSockOpt(io::Error), + #[error("{0}")] + GetSockName(String), + #[cfg(target_os = "linux")] + #[error("{0}")] + Timer(io::Error), + #[error("iface read: {0}")] + IfaceRead(io::Error), + #[error("{0}")] + DropPrivileges(String), + #[error("API socket error: {0}")] + ApiSocket(io::Error), +} + +// What the event loop should do after a handler returns +enum Action { + Continue, // Continue the loop + Yield, // Yield the read lock and acquire it again + Exit, // Stop the loop +} + +// Event handler function +type Handler = Box, &mut ThreadData) -> Action + Send + Sync>; + +pub struct DeviceHandle { + device: Arc>, // The interface this handle owns + threads: Vec>, +} + +#[derive(Debug, Clone, Copy)] +pub struct DeviceConfig { + pub n_threads: usize, + pub use_connected_socket: bool, + #[cfg(target_os = "linux")] + pub use_multi_queue: bool, + #[cfg(target_os = "linux")] + pub uapi_fd: i32, +} + +impl Default for DeviceConfig { + fn default() -> Self { + DeviceConfig { + n_threads: 4, + use_connected_socket: true, + #[cfg(target_os = "linux")] + use_multi_queue: true, + #[cfg(target_os = "linux")] + uapi_fd: -1, + } + } +} + +pub struct Device { + key_pair: Option<(x25519::StaticSecret, x25519::PublicKey)>, + queue: Arc>, + + listen_port: u16, + fwmark: Option, + + iface: Arc, + udp4: Option, + udp6: Option, + + yield_notice: Option, + exit_notice: Option, + + peers: HashMap>>, + peers_by_ip: AllowedIps>>, + peers_by_idx: HashMap>>, + next_index: IndexLfsr, + + config: DeviceConfig, + + cleanup_paths: Vec, + + mtu: AtomicUsize, + + rate_limiter: Option>, + + #[cfg(target_os = "linux")] + uapi_fd: i32, +} + +struct ThreadData { + iface: Arc, + src_buf: [u8; MAX_UDP_SIZE], + dst_buf: [u8; MAX_UDP_SIZE], +} + +impl DeviceHandle { + pub fn new(name: &str, config: DeviceConfig) -> Result { + let n_threads = config.n_threads; + let mut wg_interface = Device::new(name, config)?; + wg_interface.open_listen_socket(0)?; // Start listening on a random port + + let interface_lock = Arc::new(Lock::new(wg_interface)); + + let mut threads = vec![]; + + for i in 0..n_threads { + threads.push({ + let dev = Arc::clone(&interface_lock); + thread::spawn(move || DeviceHandle::event_loop(i, &dev)) + }); + } + + Ok(DeviceHandle { + device: interface_lock, + threads, + }) + } + + pub fn wait(&mut self) { + while let Some(thread) = self.threads.pop() { + thread.join().unwrap(); + } + } + + pub fn clean(&mut self) { + for path in &self.device.read().cleanup_paths { + // attempt to remove any file we created in the work dir + let _ = std::fs::remove_file(path); + } + } + + fn event_loop(_i: usize, device: &Lock) { + #[cfg(target_os = "linux")] + let mut thread_local = ThreadData { + src_buf: [0u8; MAX_UDP_SIZE], + dst_buf: [0u8; MAX_UDP_SIZE], + iface: if _i == 0 || !device.read().config.use_multi_queue { + // For the first thread use the original iface + Arc::clone(&device.read().iface) + } else { + // For for the rest create a new iface queue + let iface_local = Arc::new( + TunSocket::new(&device.read().iface.name().unwrap()) + .unwrap() + .set_non_blocking() + .unwrap(), + ); + + device + .read() + .register_iface_handler(Arc::clone(&iface_local)) + .ok(); + + iface_local + }, + }; + + #[cfg(not(target_os = "linux"))] + let mut thread_local = ThreadData { + src_buf: [0u8; MAX_UDP_SIZE], + dst_buf: [0u8; MAX_UDP_SIZE], + iface: Arc::clone(&device.read().iface), + }; + + #[cfg(not(target_os = "linux"))] + let uapi_fd = -1; + #[cfg(target_os = "linux")] + let uapi_fd = device.read().uapi_fd; + + loop { + // The event loop keeps a read lock on the device, because we assume write access is rarely needed + let mut device_lock = device.read(); + let queue = Arc::clone(&device_lock.queue); + + loop { + match queue.wait() { + WaitResult::Ok(handler) => { + let action = (*handler)(&mut device_lock, &mut thread_local); + match action { + Action::Continue => {} + Action::Yield => break, + Action::Exit => { + device_lock.trigger_exit(); + return; + } + } + } + WaitResult::EoF(handler) => { + if uapi_fd >= 0 && uapi_fd == handler.fd() { + device_lock.trigger_exit(); + return; + } + handler.cancel(); + } + WaitResult::Error(e) => tracing::error!(message = "Poll error", error = ?e), + } + } + } + } +} + +impl Drop for DeviceHandle { + fn drop(&mut self) { + self.device.read().trigger_exit(); + self.clean(); + } +} + +impl Device { + fn next_index(&mut self) -> u32 { + self.next_index.next() + } + + fn remove_peer(&mut self, pub_key: &x25519::PublicKey) { + if let Some(peer) = self.peers.remove(pub_key) { + // Found a peer to remove, now purge all references to it: + { + let p = peer.lock(); + p.shutdown_endpoint(); // close open udp socket and free the closure + self.peers_by_idx.remove(&p.index()); + } + self.peers_by_ip + .remove(&|p: &Arc>| Arc::ptr_eq(&peer, p)); + + tracing::info!("Peer removed"); + } + } + + #[allow(clippy::too_many_arguments)] + fn update_peer( + &mut self, + pub_key: x25519::PublicKey, + remove: bool, + _replace_ips: bool, + endpoint: Option, + allowed_ips: &[AllowedIP], + keepalive: Option, + preshared_key: Option<[u8; 32]>, + ) { + if remove { + // Completely remove a peer + return self.remove_peer(&pub_key); + } + + // Update an existing peer + if self.peers.get(&pub_key).is_some() { + // We already have a peer, we need to merge the existing config into the newly created one + panic!("Modifying existing peers is not yet supported. Remove and add again instead."); + } + + let next_index = self.next_index(); + let device_key_pair = self + .key_pair + .as_ref() + .expect("Private key must be set first"); + + let tunn = Tunn::new( + device_key_pair.0.clone(), + pub_key, + preshared_key, + keepalive, + next_index, + None, + ); + + let peer = Peer::new(tunn, next_index, endpoint, allowed_ips, preshared_key); + + let peer = Arc::new(Mutex::new(peer)); + self.peers.insert(pub_key, Arc::clone(&peer)); + self.peers_by_idx.insert(next_index, Arc::clone(&peer)); + + for AllowedIP { addr, cidr } in allowed_ips { + self.peers_by_ip + .insert(*addr, *cidr as _, Arc::clone(&peer)); + } + + tracing::info!("Peer added"); + } + + pub fn new(name: &str, config: DeviceConfig) -> Result { + let poll = EventPoll::::new()?; + + // Create a tunnel device + let iface = Arc::new(TunSocket::new(name)?.set_non_blocking()?); + let mtu = iface.mtu()?; + + #[cfg(not(target_os = "linux"))] + let uapi_fd = -1; + #[cfg(target_os = "linux")] + let uapi_fd = config.uapi_fd; + + let mut device = Device { + queue: Arc::new(poll), + iface, + config, + exit_notice: Default::default(), + yield_notice: Default::default(), + fwmark: Default::default(), + key_pair: Default::default(), + listen_port: Default::default(), + next_index: Default::default(), + peers: Default::default(), + peers_by_idx: Default::default(), + peers_by_ip: AllowedIps::new(), + udp4: Default::default(), + udp6: Default::default(), + cleanup_paths: Default::default(), + mtu: AtomicUsize::new(mtu), + rate_limiter: None, + #[cfg(target_os = "linux")] + uapi_fd, + }; + + if uapi_fd >= 0 { + device.register_api_fd(uapi_fd)?; + } else { + device.register_api_handler()?; + } + device.register_iface_handler(Arc::clone(&device.iface))?; + device.register_notifiers()?; + device.register_timers()?; + + #[cfg(target_os = "macos")] + { + // Only for macOS write the actual socket name into WG_TUN_NAME_FILE + if let Ok(name_file) = std::env::var("WG_TUN_NAME_FILE") { + if name == "utun" { + std::fs::write(&name_file, device.iface.name().unwrap().as_bytes()).unwrap(); + device.cleanup_paths.push(name_file); + } + } + } + + Ok(device) + } + + fn open_listen_socket(&mut self, mut port: u16) -> Result<(), Error> { + // Binds the network facing interfaces + // First close any existing open socket, and remove them from the event loop + if let Some(s) = self.udp4.take() { + unsafe { + // This is safe because the event loop is not running yet + self.queue.clear_event_by_fd(s.as_raw_fd()) + } + }; + + if let Some(s) = self.udp6.take() { + unsafe { self.queue.clear_event_by_fd(s.as_raw_fd()) }; + } + + for peer in self.peers.values() { + peer.lock().shutdown_endpoint(); + } + + // Then open new sockets and bind to the port + let udp_sock4 = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + udp_sock4.set_reuse_address(true)?; + udp_sock4.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into())?; + udp_sock4.set_nonblocking(true)?; + + if port == 0 { + // Random port was assigned + port = udp_sock4.local_addr()?.as_socket().unwrap().port(); + } + + let udp_sock6 = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + udp_sock6.set_reuse_address(true)?; + udp_sock6.bind(&SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into())?; + udp_sock6.set_nonblocking(true)?; + + self.register_udp_handler(udp_sock4.try_clone().unwrap())?; + self.register_udp_handler(udp_sock6.try_clone().unwrap())?; + self.udp4 = Some(udp_sock4); + self.udp6 = Some(udp_sock6); + + self.listen_port = port; + + Ok(()) + } + + fn set_key(&mut self, private_key: x25519::StaticSecret) { + let public_key = x25519::PublicKey::from(&private_key); + let key_pair = Some((private_key.clone(), public_key)); + + // x25519 (rightly) doesn't let us expose secret keys for comparison. + // If the public keys are the same, then the private keys are the same. + if Some(&public_key) == self.key_pair.as_ref().map(|p| &p.1) { + return; + } + + let rate_limiter = Arc::new(RateLimiter::new(&public_key, HANDSHAKE_RATE_LIMIT)); + + for peer in self.peers.values_mut() { + peer.lock().tunnel.set_static_private( + private_key.clone(), + public_key, + Some(Arc::clone(&rate_limiter)), + ) + } + + self.key_pair = key_pair; + self.rate_limiter = Some(rate_limiter); + } + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + fn set_fwmark(&mut self, mark: u32) -> Result<(), Error> { + self.fwmark = Some(mark); + + // First set fwmark on listeners + if let Some(ref sock) = self.udp4 { + sock.set_mark(mark)?; + } + + if let Some(ref sock) = self.udp6 { + sock.set_mark(mark)?; + } + + // Then on all currently connected sockets + for peer in self.peers.values() { + if let Some(ref sock) = peer.lock().endpoint().conn { + sock.set_mark(mark)? + } + } + + Ok(()) + } + + fn clear_peers(&mut self) { + self.peers.clear(); + self.peers_by_idx.clear(); + self.peers_by_ip.clear(); + } + + fn register_notifiers(&mut self) -> Result<(), Error> { + let yield_ev = self + .queue + // The notification event handler simply returns Action::Yield + .new_notifier(Box::new(|_, _| Action::Yield))?; + self.yield_notice = Some(yield_ev); + + let exit_ev = self + .queue + // The exit event handler simply returns Action::Exit + .new_notifier(Box::new(|_, _| Action::Exit))?; + self.exit_notice = Some(exit_ev); + Ok(()) + } + + fn register_timers(&self) -> Result<(), Error> { + self.queue.new_periodic_event( + // Reset the rate limiter every second give or take + Box::new(|d, _| { + if let Some(r) = d.rate_limiter.as_ref() { + r.reset_count() + } + Action::Continue + }), + std::time::Duration::from_secs(1), + )?; + + self.queue.new_periodic_event( + // Execute the timed function of every peer in the list + Box::new(|d, t| { + let peer_map = &d.peers; + + let (udp4, udp6) = match (d.udp4.as_ref(), d.udp6.as_ref()) { + (Some(udp4), Some(udp6)) => (udp4, udp6), + _ => return Action::Continue, + }; + + // Go over each peer and invoke the timer function + for peer in peer_map.values() { + let mut p = peer.lock(); + let endpoint_addr = match p.endpoint().addr { + Some(addr) => addr, + None => continue, + }; + + match p.update_timers(&mut t.dst_buf[..]) { + TunnResult::Done => {} + TunnResult::Err(WireGuardError::ConnectionExpired) => { + p.shutdown_endpoint(); // close open udp socket + } + TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e), + TunnResult::WriteToNetwork(packet) => { + match endpoint_addr { + SocketAddr::V4(_) => { + udp4.send_to(packet, &endpoint_addr.into()).ok() + } + SocketAddr::V6(_) => { + udp6.send_to(packet, &endpoint_addr.into()).ok() + } + }; + } + _ => panic!("Unexpected result from update_timers"), + }; + } + Action::Continue + }), + std::time::Duration::from_millis(250), + )?; + Ok(()) + } + + pub(crate) fn trigger_yield(&self) { + self.queue + .trigger_notification(self.yield_notice.as_ref().unwrap()) + } + + pub(crate) fn trigger_exit(&self) { + self.queue + .trigger_notification(self.exit_notice.as_ref().unwrap()) + } + + pub(crate) fn cancel_yield(&self) { + self.queue + .stop_notification(self.yield_notice.as_ref().unwrap()) + } + + fn register_udp_handler(&self, udp: socket2::Socket) -> Result<(), Error> { + self.queue.new_event( + udp.as_raw_fd(), + Box::new(move |d, t| { + // Handler that handles anonymous packets over UDP + let mut iter = MAX_ITR; + let (private_key, public_key) = d.key_pair.as_ref().expect("Key not set"); + + let rate_limiter = d.rate_limiter.as_ref().unwrap(); + + // Loop while we have packets on the anonymous connection + + // Safety: the `recv_from` implementation promises not to write uninitialised + // bytes to the buffer, so this casting is safe. + let src_buf = + unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit]) }; + while let Ok((packet_len, addr)) = udp.recv_from(src_buf) { + let packet = &t.src_buf[..packet_len]; + // The rate limiter initially checks mac1 and mac2, and optionally asks to send a cookie + let parsed_packet = match rate_limiter.verify_packet( + Some(addr.as_socket().unwrap().ip()), + packet, + &mut t.dst_buf, + ) { + Ok(packet) => packet, + Err(TunnResult::WriteToNetwork(cookie)) => { + let _: Result<_, _> = udp.send_to(cookie, &addr); + continue; + } + Err(_) => continue, + }; + + let peer = match &parsed_packet { + Packet::HandshakeInit(p) => { + parse_handshake_anon(private_key, public_key, p) + .ok() + .and_then(|hh| { + d.peers.get(&x25519::PublicKey::from(hh.peer_static_public)) + }) + } + Packet::HandshakeResponse(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)), + Packet::PacketCookieReply(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)), + Packet::PacketData(p) => d.peers_by_idx.get(&(p.receiver_idx >> 8)), + }; + + let peer = match peer { + None => continue, + Some(peer) => peer, + }; + + let mut p = peer.lock(); + + // We found a peer, use it to decapsulate the message+ + let mut flush = false; // Are there packets to send from the queue? + match p + .tunnel + .handle_verified_packet(parsed_packet, &mut t.dst_buf[..]) + { + TunnResult::Done => {} + TunnResult::Err(_) => continue, + TunnResult::WriteToNetwork(packet) => { + flush = true; + let _: Result<_, _> = udp.send_to(packet, &addr); + } + TunnResult::WriteToTunnelV4(packet, addr) => { + if p.is_allowed_ip(addr) { + t.iface.write4(packet); + } + } + TunnResult::WriteToTunnelV6(packet, addr) => { + if p.is_allowed_ip(addr) { + t.iface.write6(packet); + } + } + }; + + if flush { + // Flush pending queue + while let TunnResult::WriteToNetwork(packet) = + p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..]) + { + let _: Result<_, _> = udp.send_to(packet, &addr); + } + } + + // This packet was OK, that means we want to create a connected socket for this peer + let addr = addr.as_socket().unwrap(); + let ip_addr = addr.ip(); + p.set_endpoint(addr); + if d.config.use_connected_socket { + if let Ok(sock) = p.connect_endpoint(d.listen_port, d.fwmark) { + d.register_conn_handler(Arc::clone(peer), sock, ip_addr) + .unwrap(); + } + } + + iter -= 1; + if iter == 0 { + break; + } + } + Action::Continue + }), + )?; + Ok(()) + } + + fn register_conn_handler( + &self, + peer: Arc>, + udp: socket2::Socket, + peer_addr: IpAddr, + ) -> Result<(), Error> { + self.queue.new_event( + udp.as_raw_fd(), + Box::new(move |_, t| { + // The conn_handler handles packet received from a connected UDP socket, associated + // with a known peer, this saves us the hustle of finding the right peer. If another + // peer gets the same ip, it will be ignored until the socket does not expire. + let iface = &t.iface; + let mut iter = MAX_ITR; + + // Safety: the `recv_from` implementation promises not to write uninitialised + // bytes to the buffer, so this casting is safe. + let src_buf = + unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit]) }; + + while let Ok(read_bytes) = udp.recv(src_buf) { + let mut flush = false; + let mut p = peer.lock(); + match p.tunnel.decapsulate( + Some(peer_addr), + &t.src_buf[..read_bytes], + &mut t.dst_buf[..], + ) { + TunnResult::Done => {} + TunnResult::Err(e) => eprintln!("Decapsulate error {:?}", e), + TunnResult::WriteToNetwork(packet) => { + flush = true; + let _: Result<_, _> = udp.send(packet); + } + TunnResult::WriteToTunnelV4(packet, addr) => { + if p.is_allowed_ip(addr) { + iface.write4(packet); + } + } + TunnResult::WriteToTunnelV6(packet, addr) => { + if p.is_allowed_ip(addr) { + iface.write6(packet); + } + } + }; + + if flush { + // Flush pending queue + while let TunnResult::WriteToNetwork(packet) = + p.tunnel.decapsulate(None, &[], &mut t.dst_buf[..]) + { + let _: Result<_, _> = udp.send(packet); + } + } + + iter -= 1; + if iter == 0 { + break; + } + } + Action::Continue + }), + )?; + Ok(()) + } + + fn register_iface_handler(&self, iface: Arc) -> Result<(), Error> { + self.queue.new_event( + iface.as_raw_fd(), + Box::new(move |d, t| { + // The iface_handler handles packets received from the WireGuard virtual network + // interface. The flow is as follows: + // * Read a packet + // * Determine peer based on packet destination ip + // * Encapsulate the packet for the given peer + // * Send encapsulated packet to the peer's endpoint + let mtu = d.mtu.load(Ordering::Relaxed); + + let udp4 = d.udp4.as_ref().expect("Not connected"); + let udp6 = d.udp6.as_ref().expect("Not connected"); + + let peers = &d.peers_by_ip; + for _ in 0..MAX_ITR { + let src = match iface.read(&mut t.src_buf[..mtu]) { + Ok(src) => src, + Err(Error::IfaceRead(e)) => { + let ek = e.kind(); + if ek == io::ErrorKind::Interrupted || ek == io::ErrorKind::WouldBlock { + break; + } + eprintln!("Fatal read error on tun interface: {:?}", e); + return Action::Exit; + } + Err(e) => { + eprintln!("Unexpected error on tun interface: {:?}", e); + return Action::Exit; + } + }; + + let dst_addr = match Tunn::dst_address(src) { + Some(addr) => addr, + None => continue, + }; + + let mut peer = match peers.find(dst_addr) { + Some(peer) => peer.lock(), + None => continue, + }; + + match peer.tunnel.encapsulate(src, &mut t.dst_buf[..]) { + TunnResult::Done => {} + TunnResult::Err(e) => { + tracing::error!(message = "Encapsulate error", error = ?e) + } + TunnResult::WriteToNetwork(packet) => { + let mut endpoint = peer.endpoint_mut(); + if let Some(conn) = endpoint.conn.as_mut() { + // Prefer to send using the connected socket + let _: Result<_, _> = conn.write(packet); + } else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr { + let _: Result<_, _> = udp4.send_to(packet, &addr.into()); + } else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr { + let _: Result<_, _> = udp6.send_to(packet, &addr.into()); + } else { + tracing::error!("No endpoint"); + } + } + _ => panic!("Unexpected result from encapsulate"), + }; + } + Action::Continue + }), + )?; + Ok(()) + } +} + +/// A basic linear-feedback shift register implemented as xorshift, used to +/// distribute peer indexes across the 24-bit address space reserved for peer +/// identification. +/// The purpose is to obscure the total number of peers using the system and to +/// ensure it requires a non-trivial amount of processing power and/or samples +/// to guess other peers' indices. Anything more ambitious than this is wasted +/// with only 24 bits of space. +struct IndexLfsr { + initial: u32, + lfsr: u32, + mask: u32, +} + +impl IndexLfsr { + /// Generate a random 24-bit nonzero integer + fn random_index() -> u32 { + const LFSR_MAX: u32 = 0xffffff; // 24-bit seed + loop { + let i = OsRng.next_u32() & LFSR_MAX; + if i > 0 { + // LFSR seed must be non-zero + return i; + } + } + } + + /// Generate the next value in the pseudorandom sequence + fn next(&mut self) -> u32 { + // 24-bit polynomial for randomness. This is arbitrarily chosen to + // inject bitflips into the value. + const LFSR_POLY: u32 = 0xd80000; // 24-bit polynomial + let value = self.lfsr - 1; // lfsr will never have value of 0 + self.lfsr = (self.lfsr >> 1) ^ ((0u32.wrapping_sub(self.lfsr & 1u32)) & LFSR_POLY); + assert!(self.lfsr != self.initial, "Too many peers created"); + value ^ self.mask + } +} + +impl Default for IndexLfsr { + fn default() -> Self { + let seed = Self::random_index(); + IndexLfsr { + initial: seed, + lfsr: seed, + mask: Self::random_index(), + } + } +} diff --git a/lib/boringtun/src/device/peer.rs b/lib/boringtun/src/device/peer.rs new file mode 100644 index 0000000..d7f2c22 --- /dev/null +++ b/lib/boringtun/src/device/peer.rs @@ -0,0 +1,170 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use parking_lot::RwLock; +use socket2::{Domain, Protocol, Type}; + +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::str::FromStr; + +use crate::device::{AllowedIps, Error}; +use crate::noise::{Tunn, TunnResult}; + +#[derive(Default, Debug)] +pub struct Endpoint { + pub addr: Option, + pub conn: Option, +} + +pub struct Peer { + /// The associated tunnel struct + pub(crate) tunnel: Tunn, + /// The index the tunnel uses + index: u32, + endpoint: RwLock, + allowed_ips: AllowedIps<()>, + preshared_key: Option<[u8; 32]>, +} + +#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] +pub struct AllowedIP { + pub addr: IpAddr, + pub cidr: u8, +} + +impl FromStr for AllowedIP { + type Err = String; + + fn from_str(s: &str) -> Result { + let ip: Vec<&str> = s.split('/').collect(); + if ip.len() != 2 { + return Err("Invalid IP format".to_owned()); + } + + let (addr, cidr) = (ip[0].parse::(), ip[1].parse::()); + match (addr, cidr) { + (Ok(addr @ IpAddr::V4(_)), Ok(cidr)) if cidr <= 32 => Ok(AllowedIP { addr, cidr }), + (Ok(addr @ IpAddr::V6(_)), Ok(cidr)) if cidr <= 128 => Ok(AllowedIP { addr, cidr }), + _ => Err("Invalid IP format".to_owned()), + } + } +} + +impl Peer { + pub fn new( + tunnel: Tunn, + index: u32, + endpoint: Option, + allowed_ips: &[AllowedIP], + preshared_key: Option<[u8; 32]>, + ) -> Peer { + Peer { + tunnel, + index, + endpoint: RwLock::new(Endpoint { + addr: endpoint, + conn: None, + }), + allowed_ips: allowed_ips.iter().map(|ip| (ip, ())).collect(), + preshared_key, + } + } + + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + self.tunnel.update_timers(dst) + } + + pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> { + self.endpoint.read() + } + + pub(crate) fn endpoint_mut(&self) -> parking_lot::RwLockWriteGuard<'_, Endpoint> { + self.endpoint.write() + } + + pub fn shutdown_endpoint(&self) { + if let Some(conn) = self.endpoint.write().conn.take() { + tracing::info!("Disconnecting from endpoint"); + conn.shutdown(Shutdown::Both).unwrap(); + } + } + + pub fn set_endpoint(&self, addr: SocketAddr) { + let mut endpoint = self.endpoint.write(); + if endpoint.addr != Some(addr) { + // We only need to update the endpoint if it differs from the current one + if let Some(conn) = endpoint.conn.take() { + conn.shutdown(Shutdown::Both).unwrap(); + } + + endpoint.addr = Some(addr); + } + } + + pub fn connect_endpoint( + &self, + port: u16, + fwmark: Option, + ) -> Result { + let mut endpoint = self.endpoint.write(); + + if endpoint.conn.is_some() { + return Err(Error::Connect("Connected".to_owned())); + } + + let addr = endpoint + .addr + .expect("Attempt to connect to undefined endpoint"); + + let udp_conn = + socket2::Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::UDP))?; + udp_conn.set_reuse_address(true)?; + let bind_addr = if addr.is_ipv4() { + SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into() + } else { + SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into() + }; + udp_conn.bind(&bind_addr)?; + udp_conn.connect(&addr.into())?; + udp_conn.set_nonblocking(true)?; + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + if let Some(fwmark) = fwmark { + udp_conn.set_mark(fwmark)?; + } + + tracing::info!( + message="Connected endpoint", + port=port, + endpoint=?endpoint.addr.unwrap() + ); + + endpoint.conn = Some(udp_conn.try_clone().unwrap()); + + Ok(udp_conn) + } + + pub fn is_allowed_ip>(&self, addr: I) -> bool { + self.allowed_ips.find(addr.into()).is_some() + } + + pub fn allowed_ips(&self) -> impl Iterator + '_ { + self.allowed_ips.iter().map(|(_, ip, cidr)| (ip, cidr)) + } + + pub fn time_since_last_handshake(&self) -> Option { + self.tunnel.time_since_last_handshake() + } + + pub fn persistent_keepalive(&self) -> Option { + self.tunnel.persistent_keepalive() + } + + pub fn preshared_key(&self) -> Option<&[u8; 32]> { + self.preshared_key.as_ref() + } + + pub fn index(&self) -> u32 { + self.index + } +} diff --git a/lib/boringtun/src/device/tun_darwin.rs b/lib/boringtun/src/device/tun_darwin.rs new file mode 100644 index 0000000..0732e95 --- /dev/null +++ b/lib/boringtun/src/device/tun_darwin.rs @@ -0,0 +1,256 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::Error; +use libc::*; +use std::io; +use std::mem::size_of; +use std::mem::size_of_val; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::ptr::null_mut; + +const CTRL_NAME: &[u8] = b"com.apple.net.utun_control"; + +#[repr(C)] +pub struct ctl_info { + pub ctl_id: u32, + pub ctl_name: [c_uchar; 96], +} + +#[repr(C)] +union IfrIfru { + ifru_addr: sockaddr, + ifru_addr_v4: sockaddr_in, + ifru_addr_v6: sockaddr_in, + ifru_dstaddr: sockaddr, + ifru_broadaddr: sockaddr, + ifru_flags: c_short, + ifru_metric: c_int, + ifru_mtu: c_int, + ifru_phys: c_int, + ifru_media: c_int, + ifru_intval: c_int, + //ifru_data: caddr_t, + //ifru_devmtu: ifdevmtu, + //ifru_kpi: ifkpi, + ifru_wake_flags: u32, + ifru_route_refcnt: u32, + ifru_cap: [c_int; 2], + ifru_functional_type: u32, +} + +#[repr(C)] +pub struct ifreq { + ifr_name: [c_uchar; IF_NAMESIZE], + ifr_ifru: IfrIfru, +} + +const CTLIOCGINFO: u64 = 0x0000_0000_c064_4e03; +const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933; + +#[derive(Default, Debug)] +pub struct TunSocket { + fd: RawFd, +} + +impl Drop for TunSocket { + fn drop(&mut self) { + unsafe { close(self.fd) }; + } +} + +impl AsRawFd for TunSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +// On Darwin tunnel can only be named utunXXX +pub fn parse_utun_name(name: &str) -> Result { + if !name.starts_with("utun") { + return Err(Error::InvalidTunnelName); + } + + match name.get(4..) { + None | Some("") => { + // The name is simply "utun" + Ok(0) + } + Some(idx) => { + // Everything past utun should represent an integer index + idx.parse::() + .map_err(|_| Error::InvalidTunnelName) + .map(|x| x + 1) + } + } +} + +impl TunSocket { + fn write(&self, src: &[u8], af: u8) -> usize { + let mut hdr = [0u8, 0u8, 0u8, af as u8]; + let mut iov = [ + iovec { + iov_base: hdr.as_mut_ptr() as _, + iov_len: hdr.len(), + }, + iovec { + iov_base: src.as_ptr() as _, + iov_len: src.len(), + }, + ]; + + let msg_hdr = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iov[0], + msg_iovlen: iov.len() as _, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + match unsafe { sendmsg(self.fd, &msg_hdr, 0) } { + -1 => 0, + n => n as usize, + } + } + + pub fn new(name: &str) -> Result { + let idx = parse_utun_name(name)?; + + let fd = match unsafe { socket(PF_SYSTEM, SOCK_DGRAM, SYSPROTO_CONTROL) } { + -1 => return Err(Error::Socket(io::Error::last_os_error())), + fd => fd, + }; + + let mut info = ctl_info { + ctl_id: 0, + ctl_name: [0u8; 96], + }; + info.ctl_name[..CTRL_NAME.len()].copy_from_slice(CTRL_NAME); + + if unsafe { ioctl(fd, CTLIOCGINFO, &mut info as *mut ctl_info) } < 0 { + unsafe { close(fd) }; + return Err(Error::IOCtl(io::Error::last_os_error())); + } + + let addr = sockaddr_ctl { + sc_len: size_of::() as u8, + sc_family: AF_SYSTEM as u8, + ss_sysaddr: AF_SYS_CONTROL as u16, + sc_id: info.ctl_id, + sc_unit: idx, + sc_reserved: Default::default(), + }; + + if unsafe { + connect( + fd, + &addr as *const sockaddr_ctl as _, + size_of_val(&addr) as _, + ) + } < 0 + { + unsafe { close(fd) }; + let mut err_string = io::Error::last_os_error().to_string(); + err_string.push_str("(did you run with sudo?)"); + return Err(Error::Connect(err_string)); + } + + Ok(TunSocket { fd }) + } + + pub fn set_non_blocking(self) -> Result { + match unsafe { fcntl(self.fd, F_GETFL) } { + -1 => Err(Error::FCntl(io::Error::last_os_error())), + flags => match unsafe { fcntl(self.fd, F_SETFL, flags | O_NONBLOCK) } { + -1 => Err(Error::FCntl(io::Error::last_os_error())), + _ => Ok(self), + }, + } + } + + pub fn name(&self) -> Result { + let mut tunnel_name = [0u8; 256]; + let mut tunnel_name_len: socklen_t = tunnel_name.len() as u32; + if unsafe { + getsockopt( + self.fd, + SYSPROTO_CONTROL, + UTUN_OPT_IFNAME, + tunnel_name.as_mut_ptr() as _, + &mut tunnel_name_len, + ) + } < 0 + || tunnel_name_len == 0 + { + return Err(Error::GetSockOpt(io::Error::last_os_error())); + } + + Ok(String::from_utf8_lossy(&tunnel_name[..(tunnel_name_len - 1) as usize]).to_string()) + } + + /// Get the current MTU value + pub fn mtu(&self) -> Result { + let fd = match unsafe { socket(AF_INET, SOCK_STREAM, IPPROTO_IP) } { + -1 => return Err(Error::Socket(io::Error::last_os_error())), + fd => fd, + }; + + let name = self.name()?; + let iface_name: &[u8] = name.as_ref(); + let mut ifr = ifreq { + ifr_name: [0; IF_NAMESIZE], + ifr_ifru: IfrIfru { ifru_mtu: 0 }, + }; + + ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name); + + if unsafe { ioctl(fd, SIOCGIFMTU, &ifr) } < 0 { + return Err(Error::IOCtl(io::Error::last_os_error())); + } + + unsafe { close(fd) }; + + Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _) + } + + pub fn write4(&self, src: &[u8]) -> usize { + self.write(src, AF_INET as u8) + } + + pub fn write6(&self, src: &[u8]) -> usize { + self.write(src, AF_INET6 as u8) + } + + pub fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error> { + let mut hdr = [0u8; 4]; + + let mut iov = [ + iovec { + iov_base: hdr.as_mut_ptr() as _, + iov_len: hdr.len(), + }, + iovec { + iov_base: dst.as_mut_ptr() as _, + iov_len: dst.len(), + }, + ]; + + let mut msg_hdr = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iov[0], + msg_iovlen: iov.len() as _, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + match unsafe { recvmsg(self.fd, &mut msg_hdr, 0) } { + -1 => Err(Error::IfaceRead(io::Error::last_os_error())), + 0..=4 => Ok(&mut dst[..0]), + n => Ok(&mut dst[..(n - 4) as usize]), + } + } +} diff --git a/lib/boringtun/src/device/tun_linux.rs b/lib/boringtun/src/device/tun_linux.rs new file mode 100644 index 0000000..dee2999 --- /dev/null +++ b/lib/boringtun/src/device/tun_linux.rs @@ -0,0 +1,159 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::Error; +use libc::*; +use std::io; +use std::os::unix::io::{AsRawFd, RawFd}; + +const TUNSETIFF: u64 = 0x4004_54ca; + +#[repr(C)] +union IfrIfru { + ifru_addr: sockaddr, + ifru_addr_v4: sockaddr_in, + ifru_addr_v6: sockaddr_in, + ifru_dstaddr: sockaddr, + ifru_broadaddr: sockaddr, + ifru_flags: c_short, + ifru_metric: c_int, + ifru_mtu: c_int, + ifru_phys: c_int, + ifru_media: c_int, + ifru_intval: c_int, + //ifru_data: caddr_t, + //ifru_devmtu: ifdevmtu, + //ifru_kpi: ifkpi, + ifru_wake_flags: u32, + ifru_route_refcnt: u32, + ifru_cap: [c_int; 2], + ifru_functional_type: u32, +} + +#[repr(C)] +pub struct ifreq { + ifr_name: [c_uchar; IFNAMSIZ], + ifr_ifru: IfrIfru, +} + +#[derive(Default, Debug)] +pub struct TunSocket { + fd: RawFd, + name: String, +} + +impl Drop for TunSocket { + fn drop(&mut self) { + unsafe { close(self.fd) }; + } +} + +impl AsRawFd for TunSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl TunSocket { + fn write(&self, buf: &[u8]) -> usize { + match unsafe { write(self.fd, buf.as_ptr() as _, buf.len() as _) } { + -1 => 0, + n => n as usize, + } + } + + pub fn new(name: &str) -> Result { + // If the provided name appears to be a FD, use that. + let provided_fd = name.parse::(); + if let Ok(fd) = provided_fd { + return Ok(TunSocket { + fd, + name: name.to_string(), + }); + } + + let fd = match unsafe { open(b"/dev/net/tun\0".as_ptr() as _, O_RDWR) } { + -1 => return Err(Error::Socket(io::Error::last_os_error())), + fd => fd, + }; + let iface_name = name.as_bytes(); + let mut ifr = ifreq { + ifr_name: [0; IFNAMSIZ], + ifr_ifru: IfrIfru { + ifru_flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as _, + }, + }; + + if iface_name.len() >= ifr.ifr_name.len() { + return Err(Error::InvalidTunnelName); + } + + ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name); + + if unsafe { ioctl(fd, TUNSETIFF as _, &ifr) } < 0 { + return Err(Error::IOCtl(io::Error::last_os_error())); + } + + let name = name.to_string(); + Ok(TunSocket { fd, name }) + } + + pub fn set_non_blocking(self) -> Result { + match unsafe { fcntl(self.fd, F_GETFL) } { + -1 => Err(Error::FCntl(io::Error::last_os_error())), + flags => match unsafe { fcntl(self.fd, F_SETFL, flags | O_NONBLOCK) } { + -1 => Err(Error::FCntl(io::Error::last_os_error())), + _ => Ok(self), + }, + } + } + + pub fn name(&self) -> Result { + Ok(self.name.clone()) + } + + /// Get the current MTU value + pub fn mtu(&self) -> Result { + let provided_fd = self.name.parse::(); + if provided_fd.is_ok() { + return Ok(1500); + } + + let fd = match unsafe { socket(AF_INET, SOCK_STREAM, IPPROTO_IP) } { + -1 => return Err(Error::Socket(io::Error::last_os_error())), + fd => fd, + }; + + let name = self.name()?; + let iface_name: &[u8] = name.as_ref(); + let mut ifr = ifreq { + ifr_name: [0; IF_NAMESIZE], + ifr_ifru: IfrIfru { ifru_mtu: 0 }, + }; + + ifr.ifr_name[..iface_name.len()].copy_from_slice(iface_name); + + if unsafe { ioctl(fd, SIOCGIFMTU as _, &ifr) } < 0 { + return Err(Error::IOCtl(io::Error::last_os_error())); + } + + unsafe { close(fd) }; + + Ok(unsafe { ifr.ifr_ifru.ifru_mtu } as _) + } + + pub fn write4(&self, src: &[u8]) -> usize { + self.write(src) + } + + pub fn write6(&self, src: &[u8]) -> usize { + self.write(src) + } + + pub fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error> { + match unsafe { read(self.fd, dst.as_mut_ptr() as _, dst.len()) } { + -1 => Err(Error::IfaceRead(io::Error::last_os_error())), + n => Ok(&mut dst[..n as usize]), + } + } +} diff --git a/lib/boringtun/src/ffi/mod.rs b/lib/boringtun/src/ffi/mod.rs new file mode 100644 index 0000000..1e5a2a9 --- /dev/null +++ b/lib/boringtun/src/ffi/mod.rs @@ -0,0 +1,397 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +// Requiring explicit per-fn "Safety" docs not worth it. Just pass in valid +// pointers and buffers/lengths to these, ok? +#![allow(clippy::missing_safety_doc)] + +//! C bindings for the BoringTun library +use super::noise::{Tunn, TunnResult}; +use crate::x25519::{PublicKey, StaticSecret}; +use base64::{decode, encode}; +use hex::encode as encode_hex; +use libc::{raise, SIGSEGV}; +use parking_lot::Mutex; +use rand_core::OsRng; +use tracing; +use tracing_subscriber::fmt; + +use crate::serialization::KeyBytes; +use std::ffi::{CStr, CString}; +use std::io::{Error, ErrorKind, Write}; +use std::os::raw::c_char; +use std::panic; +use std::ptr; +use std::ptr::null_mut; +use std::slice; +use std::sync::Once; + +static PANIC_HOOK: Once = Once::new(); + +#[allow(non_camel_case_types)] +#[repr(C)] +/// Indicates the operation required from the caller +pub enum result_type { + /// No operation is required. + WIREGUARD_DONE = 0, + /// Write dst buffer to network. Size indicates the number of bytes to write. + WRITE_TO_NETWORK = 1, + /// Some error occurred, no operation is required. Size indicates error code. + WIREGUARD_ERROR = 2, + /// Write dst buffer to the interface as an ipv4 packet. Size indicates the number of bytes to write. + WRITE_TO_TUNNEL_IPV4 = 4, + /// Write dst buffer to the interface as an ipv6 packet. Size indicates the number of bytes to write. + WRITE_TO_TUNNEL_IPV6 = 6, +} + +/// The return type of WireGuard functions +#[repr(C)] +pub struct wireguard_result { + /// The operation to be performed by the caller + pub op: result_type, + /// Additional information, required to perform the operation + pub size: usize, +} + +#[repr(C)] +pub struct stats { + pub time_since_last_handshake: i64, + pub tx_bytes: usize, + pub rx_bytes: usize, + pub estimated_loss: f32, + pub estimated_rtt: i32, + reserved: [u8; 56], // Make sure to add new fields in this space, keeping total size constant +} + +impl<'a> From> for wireguard_result { + fn from(res: TunnResult<'a>) -> wireguard_result { + match res { + TunnResult::Done => wireguard_result { + op: result_type::WIREGUARD_DONE, + size: 0, + }, + TunnResult::Err(e) => wireguard_result { + op: result_type::WIREGUARD_ERROR, + size: e as _, + }, + TunnResult::WriteToNetwork(b) => wireguard_result { + op: result_type::WRITE_TO_NETWORK, + size: b.len(), + }, + TunnResult::WriteToTunnelV4(b, _) => wireguard_result { + op: result_type::WRITE_TO_TUNNEL_IPV4, + size: b.len(), + }, + TunnResult::WriteToTunnelV6(b, _) => wireguard_result { + op: result_type::WRITE_TO_TUNNEL_IPV6, + size: b.len(), + }, + } + } +} + +#[repr(C)] +pub struct x25519_key { + pub key: [u8; 32], +} + +/// Generates a new x25519 secret key. +#[no_mangle] +pub extern "C" fn x25519_secret_key() -> x25519_key { + x25519_key { + key: StaticSecret::random_from_rng(OsRng).to_bytes(), + } +} + +/// Computes a public x25519 key from a secret key. +#[no_mangle] +pub extern "C" fn x25519_public_key(private_key: x25519_key) -> x25519_key { + let private = StaticSecret::from(private_key.key); + let public = PublicKey::from(&private); + x25519_key { + key: public.to_bytes(), + } +} + +/// Returns the base64 encoding of a key as a UTF8 C-string. +/// +/// The memory has to be freed by calling `x25519_key_to_str_free` +#[no_mangle] +pub extern "C" fn x25519_key_to_base64(key: x25519_key) -> *const c_char { + let encoded_key = encode(key.key); + CString::into_raw(CString::new(encoded_key).unwrap()) +} + +/// Returns the hex encoding of a key as a UTF8 C-string. +/// +/// The memory has to be freed by calling `x25519_key_to_str_free` +#[no_mangle] +pub extern "C" fn x25519_key_to_hex(key: x25519_key) -> *const c_char { + let encoded_key = encode_hex(key.key); + CString::into_raw(CString::new(encoded_key).unwrap()) +} + +/// Frees memory of the string given by `x25519_key_to_hex` or `x25519_key_to_base64` +#[no_mangle] +pub unsafe extern "C" fn x25519_key_to_str_free(stringified_key: *mut c_char) { + drop(CString::from_raw(stringified_key)); +} + +/// Check if the input C-string represents a valid base64 encoded x25519 key. +/// Return 1 if valid 0 otherwise. +#[no_mangle] +pub unsafe extern "C" fn check_base64_encoded_x25519_key(key: *const c_char) -> i32 { + let c_str = CStr::from_ptr(key); + let utf8_key = match c_str.to_str() { + Err(_) => return 0, + Ok(string) => string, + }; + + if let Ok(key) = decode(utf8_key) { + let len = key.len(); + let mut zero = 0u8; + for b in key { + zero |= b + } + if len == 32 && zero != 0 { + 1 + } else { + 0 + } + } else { + 0 + } +} + +/// Custom tracing_subscriber writer to an external function pointer +struct FFIFunctionPointerWriter { + log_func: unsafe extern "C" fn(*const c_char), +} + +/// Implements Write trait for use with tracing_subscriber +impl Write for FFIFunctionPointerWriter { + fn write(&mut self, buf: &[u8]) -> Result { + let out_str = String::from_utf8_lossy(buf).to_string(); + if let Ok(c_string) = CString::new(out_str) { + unsafe { (self.log_func)(c_string.as_ptr()) } + Ok(buf.len()) + } else { + Err(Error::new( + ErrorKind::Other, + "Failed to create CString from buffer.", + )) + } + } + + fn flush(&mut self) -> Result<(), std::io::Error> { + // no-op + Ok(()) + } +} + +/// Sets the default tracing_subscriber to write to `log_func`. +/// +/// Uses Compact format without level, target, thread ids, thread names, or ansi control characters. +/// Subscribes to TRACE level events. +/// +/// This function should only be called once as setting the default tracing_subscriber +/// more than once will result in an error. +/// +/// Returns false on failure. +/// +/// # Safety +/// +/// `c_char` will be freed by the library after calling `log_func`. If the value needs +/// to be stored then `log_func` needs to create a copy, e.g. `strcpy`. +#[no_mangle] +pub unsafe extern "C" fn set_logging_function( + log_func: unsafe extern "C" fn(*const c_char), +) -> bool { + let result = std::panic::catch_unwind(|| -> bool { + let writer = FFIFunctionPointerWriter { log_func }; + let format = fmt::format() + // don't include levels in formatted output + .with_level(false) + // don't include targets + .with_target(false) + // don't 'include the thread ID of the current thread + .with_thread_ids(false) + // don't 'include the name of the current thread + .with_thread_names(false) + // use the `Compact` formatting style. + .compact() + // disable terminal escape codes + .with_ansi(false); + + fmt() + .event_format(format) + .with_writer(std::sync::Mutex::new(writer)) + .with_max_level(tracing::Level::TRACE) + .with_ansi(false) + .try_init() + .is_ok() + }); + if let Ok(value) = result { + value + } else { + false + } +} + +/// Allocate a new tunnel, return NULL on failure. +/// Keys must be valid base64 encoded 32-byte keys. +#[no_mangle] +pub unsafe extern "C" fn new_tunnel( + static_private: *const c_char, + server_static_public: *const c_char, + preshared_key: *const c_char, + keep_alive: u16, + index: u32, +) -> *mut Mutex { + let c_str = CStr::from_ptr(static_private); + let static_private = match c_str.to_str() { + Err(_) => return ptr::null_mut(), + Ok(string) => string, + }; + + let c_str = CStr::from_ptr(server_static_public); + let server_static_public = match c_str.to_str() { + Err(_) => return ptr::null_mut(), + Ok(string) => string, + }; + + let preshared_key = if preshared_key.is_null() { + None + } else { + let c_str = CStr::from_ptr(preshared_key); + + if let Ok(string) = c_str.to_str() { + if let Ok(key) = string.parse::() { + Some(key.0) + } else { + return null_mut(); + } + } else { + return null_mut(); + } + }; + + let private_key = match static_private.parse::() { + Err(_) => return ptr::null_mut(), + Ok(key) => StaticSecret::from(key.0), + }; + + let public_key = match server_static_public.parse::() { + Err(_) => return ptr::null_mut(), + Ok(key) => PublicKey::from(key.0), + }; + + let keep_alive = if keep_alive == 0 { + None + } else { + Some(keep_alive) + }; + + let tunnel = Box::new(Mutex::new(Tunn::new( + private_key, + public_key, + preshared_key, + keep_alive, + index, + None, + ))); + + PANIC_HOOK.call_once(|| { + // FFI won't properly unwind on panic, but it will if we cause a segmentation fault + panic::set_hook(Box::new(move |_| { + raise(SIGSEGV); + })); + }); + + Box::into_raw(tunnel) +} + +/// Drops the Tunn object +#[no_mangle] +pub unsafe extern "C" fn tunnel_free(tunnel: *mut Mutex) { + drop(Box::from_raw(tunnel)); +} + +/// Write an IP packet from the tunnel interface. +/// For more details check noise::tunnel_to_network functions. +#[no_mangle] +pub unsafe extern "C" fn wireguard_write( + tunnel: *const Mutex, + src: *const u8, + src_size: u32, + dst: *mut u8, + dst_size: u32, +) -> wireguard_result { + let mut tunnel = tunnel.as_ref().unwrap().lock(); + // Slices are not owned, and therefore will not be freed by Rust + let src = slice::from_raw_parts(src, src_size as usize); + let dst = slice::from_raw_parts_mut(dst, dst_size as usize); + wireguard_result::from(tunnel.encapsulate(src, dst)) +} + +/// Read a UDP packet from the server. +/// For more details check noise::network_to_tunnel functions. +#[no_mangle] +pub unsafe extern "C" fn wireguard_read( + tunnel: *const Mutex, + src: *const u8, + src_size: u32, + dst: *mut u8, + dst_size: u32, +) -> wireguard_result { + let mut tunnel = tunnel.as_ref().unwrap().lock(); + // Slices are not owned, and therefore will not be freed by Rust + let src = slice::from_raw_parts(src, src_size as usize); + let dst = slice::from_raw_parts_mut(dst, dst_size as usize); + wireguard_result::from(tunnel.decapsulate(None, src, dst)) +} + +/// This is a state keeping function, that need to be called periodically. +/// Recommended interval: 100ms. +#[no_mangle] +pub unsafe extern "C" fn wireguard_tick( + tunnel: *const Mutex, + dst: *mut u8, + dst_size: u32, +) -> wireguard_result { + let mut tunnel = tunnel.as_ref().unwrap().lock(); + // Slices are not owned, and therefore will not be freed by Rust + let dst = slice::from_raw_parts_mut(dst, dst_size as usize); + wireguard_result::from(tunnel.update_timers(dst)) +} + +/// Force the tunnel to initiate a new handshake, dst buffer must be at least 148 byte long. +#[no_mangle] +pub unsafe extern "C" fn wireguard_force_handshake( + tunnel: *const Mutex, + dst: *mut u8, + dst_size: u32, +) -> wireguard_result { + let mut tunnel = tunnel.as_ref().unwrap().lock(); + // Slices are not owned, and therefore will not be freed by Rust + let dst = slice::from_raw_parts_mut(dst, dst_size as usize); + wireguard_result::from(tunnel.format_handshake_initiation(dst, true)) +} + +/// Returns stats from the tunnel: +/// Time of last handshake in seconds (or -1 if no handshake occurred) +/// Number of data bytes encapsulated +/// Number of data bytes decapsulated +#[no_mangle] +pub unsafe extern "C" fn wireguard_stats(tunnel: *const Mutex) -> stats { + let tunnel = tunnel.as_ref().unwrap().lock(); + let (time, tx_bytes, rx_bytes, estimated_loss, estimated_rtt) = tunnel.stats(); + stats { + time_since_last_handshake: time.map(|t| t.as_secs() as i64).unwrap_or(-1), + tx_bytes, + rx_bytes, + estimated_loss, + estimated_rtt: estimated_rtt.map(|r| r as i32).unwrap_or(-1), + reserved: [0u8; 56], + } +} diff --git a/lib/boringtun/src/jni.rs b/lib/boringtun/src/jni.rs new file mode 100644 index 0000000..7bc2bdc --- /dev/null +++ b/lib/boringtun/src/jni.rs @@ -0,0 +1,271 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +// temporary, we need to do some verification around these bindings later +#![allow(clippy::missing_safety_doc)] + +/// JNI bindings for BoringTun library +use std::os::raw::c_char; +use std::ptr; + +use jni::objects::{JByteBuffer, JClass, JString}; +use jni::strings::JNIStr; +use jni::sys::{jbyteArray, jint, jlong, jshort, jstring}; +use jni::JNIEnv; +use parking_lot::Mutex; + +use crate::ffi::new_tunnel; +use crate::ffi::wireguard_read; +use crate::ffi::wireguard_result; +use crate::ffi::wireguard_tick; +use crate::ffi::wireguard_write; +use crate::ffi::x25519_key; +use crate::ffi::x25519_key_to_base64; +use crate::ffi::x25519_key_to_hex; +use crate::ffi::x25519_public_key; +use crate::ffi::x25519_secret_key; + +use crate::noise::Tunn; + +pub extern "C" fn log_print(_log_string: *const c_char) { + /* + XXX: + Define callback function in app. + */ +} + +/// Generates new x25519 secret key and converts into java byte array. +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1secret_1key"] +pub extern "C" fn generate_secret_key(env: JNIEnv, _class: JClass) -> jbyteArray { + match env.byte_array_from_slice(&x25519_secret_key().key) { + Ok(v) => v, + Err(_) => ptr::null_mut(), + } +} + +/// Computes public x25519 key from secret key and converts into java byte array. +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1public_1key"] +pub unsafe extern "C" fn generate_public_key1( + env: JNIEnv, + _class: JClass, + arg_secret_key: jbyteArray, +) -> jbyteArray { + let mut key_inner = [0; 32]; + + if env + .get_byte_array_region(arg_secret_key, 0, &mut key_inner) + .is_err() + { + return ptr::null_mut(); + } + + let secret_key = x25519_key { + key: std::mem::transmute::<[i8; 32], [u8; 32]>(key_inner), + }; + + match env.byte_array_from_slice(&x25519_public_key(secret_key).key) { + Ok(v) => v, + Err(_) => ptr::null_mut(), + } +} + +/// Converts x25519 key to hex string. +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1key_1to_1hex"] +pub unsafe extern "C" fn convert_x25519_key_to_hex( + env: JNIEnv, + _class: JClass, + arg_key: jbyteArray, +) -> jstring { + let mut key = [0; 32]; + + if env.get_byte_array_region(arg_key, 0, &mut key).is_err() { + return ptr::null_mut(); + } + + let x25519_key = x25519_key { + key: std::mem::transmute::<[i8; 32], [u8; 32]>(key), + }; + + let output = match env.new_string(JNIStr::from_ptr(x25519_key_to_hex(x25519_key)).to_owned()) { + Ok(v) => v, + Err(_) => return ptr::null_mut(), + }; + + output.into_inner() +} + +/// Converts x25519 key to base64 string. +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_x25519_1key_1to_1base64"] +pub unsafe extern "C" fn convert_x25519_key_to_base64( + env: JNIEnv, + _class: JClass, + arg_key: jbyteArray, +) -> jstring { + let mut key = [0; 32]; + + if env.get_byte_array_region(arg_key, 0, &mut key).is_err() { + return ptr::null_mut(); + } + + let x25519_key = x25519_key { + key: std::mem::transmute::<[i8; 32], [u8; 32]>(key), + }; + + let output = match env.new_string(JNIStr::from_ptr(x25519_key_to_base64(x25519_key)).to_owned()) + { + Ok(v) => v, + Err(_) => return ptr::null_mut(), + }; + + output.into_inner() +} + +/// Creates new tunnel +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_new_1tunnel"] +pub unsafe extern "C" fn create_new_tunnel( + env: JNIEnv, + _class: JClass, + arg_secret_key: JString, + arg_public_key: JString, + arg_preshared_key: JString, + keep_alive: jshort, + index: jint, +) -> jlong { + let secret_key = match env.get_string_utf_chars(arg_secret_key) { + Ok(v) => v, + Err(_) => return 0, + }; + + let public_key = match env.get_string_utf_chars(arg_public_key) { + Ok(v) => v, + Err(_) => return 0, + }; + + let preshared_key = if arg_preshared_key.is_null() { + ptr::null_mut() + } else { + match env.get_string_utf_chars(arg_preshared_key) { + Ok(v) => v, + Err(_) => return 0, + } + }; + + let tunnel = new_tunnel( + secret_key, + public_key, + preshared_key, + keep_alive as u16, + index as u32, + ); + + if tunnel.is_null() { + return 0; + } + + tunnel as jlong +} + +/// Encrypts raw IP packets into WG formatted packets. +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1write"] +pub unsafe extern "C" fn encrypt_raw_packet( + env: JNIEnv, + _class: JClass, + tunnel: jlong, + src: jbyteArray, + src_size: jint, + dst: JByteBuffer, + dst_size: jint, + op: JByteBuffer, +) -> jint { + let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) { + Ok(v) => v.as_mut_ptr(), + Err(_) => return 0, + }; + + let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) { + Ok(v) => v.as_mut_ptr(), + Err(_) => return 0, + }; + + let output: wireguard_result = wireguard_write( + tunnel as *const Mutex, + env.convert_byte_array(src).unwrap().as_mut_ptr(), + src_size as u32, + dst_ptr, + dst_size as u32, + ); + *op_ptr = output.op as u8; + + output.size as i32 +} + +/// Decrypts WG formatted packets into raw IP packets. +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1read"] +pub unsafe extern "C" fn decrypt_to_raw_packet( + env: JNIEnv, + _class: JClass, + tunnel: jlong, + src: jbyteArray, + src_size: jint, + dst: JByteBuffer, + dst_size: jint, + op: JByteBuffer, +) -> jint { + let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) { + Ok(v) => v.as_mut_ptr(), + Err(_) => return 0, + }; + + let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) { + Ok(v) => v.as_mut_ptr(), + Err(_) => return 0, + }; + + let output: wireguard_result = wireguard_read( + tunnel as *const Mutex, + env.convert_byte_array(src).unwrap().as_mut_ptr(), + src_size as u32, + dst_ptr, + dst_size as u32, + ); + + *op_ptr = output.op as u8; + + output.size as i32 +} + +/// Periodic function that writes WG formatted packets into destination buffer +#[no_mangle] +#[export_name = "Java_com_cloudflare_app_boringtun_BoringTunJNI_wireguard_1tick"] +pub unsafe extern "C" fn run_periodic_task( + env: JNIEnv, + _class: JClass, + tunnel: jlong, + dst: JByteBuffer, + dst_size: jint, + op: JByteBuffer, +) -> jint { + let dst_ptr: *mut u8 = match env.get_direct_buffer_address(dst) { + Ok(v) => v.as_mut_ptr(), + Err(_) => return 0, + }; + + let op_ptr: *mut u8 = match env.get_direct_buffer_address(op) { + Ok(v) => v.as_mut_ptr(), + Err(_) => return 0, + }; + + let output: wireguard_result = + wireguard_tick(tunnel as *const Mutex, dst_ptr, dst_size as u32); + + *op_ptr = output.op as u8; + + output.size as i32 +} diff --git a/lib/boringtun/src/lib.rs b/lib/boringtun/src/lib.rs new file mode 100644 index 0000000..6ab410d --- /dev/null +++ b/lib/boringtun/src/lib.rs @@ -0,0 +1,27 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Simple implementation of the client-side of the WireGuard protocol. +//! +//! git clone https://github.com/cloudflare/boringtun.git + +#[cfg(feature = "device")] +pub mod device; + +#[cfg(feature = "ffi-bindings")] +pub mod ffi; +#[cfg(feature = "jni-bindings")] +pub mod jni; +pub mod noise; + +#[cfg(not(feature = "mock-instant"))] +pub(crate) mod sleepyinstant; + +pub(crate) mod serialization; + +/// Re-export of the x25519 types +pub mod x25519 { + pub use x25519_dalek::{ + EphemeralSecret, PublicKey, ReusableSecret, SharedSecret, StaticSecret, + }; +} diff --git a/lib/boringtun/src/noise/errors.rs b/lib/boringtun/src/noise/errors.rs new file mode 100644 index 0000000..10513ae --- /dev/null +++ b/lib/boringtun/src/noise/errors.rs @@ -0,0 +1,23 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +#[derive(Debug)] +pub enum WireGuardError { + DestinationBufferTooSmall, + IncorrectPacketLength, + UnexpectedPacket, + WrongPacketType, + WrongIndex, + WrongKey, + InvalidTai64nTimestamp, + WrongTai64nTimestamp, + InvalidMac, + InvalidAeadTag, + InvalidCounter, + DuplicateCounter, + InvalidPacket, + NoCurrentSession, + LockFailed, + ConnectionExpired, + UnderLoad, +} diff --git a/lib/boringtun/src/noise/handshake.rs b/lib/boringtun/src/noise/handshake.rs new file mode 100644 index 0000000..40ed803 --- /dev/null +++ b/lib/boringtun/src/noise/handshake.rs @@ -0,0 +1,940 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::{HandshakeInit, HandshakeResponse, PacketCookieReply}; +use crate::noise::errors::WireGuardError; +use crate::noise::session::Session; +#[cfg(not(feature = "mock-instant"))] +use crate::sleepyinstant::Instant; +use crate::x25519; +use aead::{Aead, Payload}; +use blake2::digest::{FixedOutput, KeyInit}; +use blake2::{Blake2s256, Blake2sMac, Digest}; +use chacha20poly1305::XChaCha20Poly1305; +use rand_core::OsRng; +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use std::convert::TryInto; +use std::time::{Duration, SystemTime}; + +#[cfg(feature = "mock-instant")] +use mock_instant::Instant; + +pub(crate) const LABEL_MAC1: &[u8; 8] = b"mac1----"; +pub(crate) const LABEL_COOKIE: &[u8; 8] = b"cookie--"; +const KEY_LEN: usize = 32; +const TIMESTAMP_LEN: usize = 12; + +// initiator.chaining_key = HASH(CONSTRUCTION) +const INITIAL_CHAIN_KEY: [u8; KEY_LEN] = [ + 96, 226, 109, 174, 243, 39, 239, 192, 46, 195, 53, 226, 160, 37, 210, 208, 22, 235, 66, 6, 248, + 114, 119, 245, 45, 56, 209, 152, 139, 120, 205, 54, +]; + +// initiator.chaining_hash = HASH(initiator.chaining_key || IDENTIFIER) +const INITIAL_CHAIN_HASH: [u8; KEY_LEN] = [ + 34, 17, 179, 97, 8, 26, 197, 102, 105, 18, 67, 219, 69, 138, 213, 50, 45, 156, 108, 102, 34, + 147, 232, 183, 14, 225, 156, 101, 186, 7, 158, 243, +]; + +#[inline] +pub(crate) fn b2s_hash(data1: &[u8], data2: &[u8]) -> [u8; 32] { + let mut hash = Blake2s256::new(); + hash.update(data1); + hash.update(data2); + hash.finalize().into() +} + +#[inline] +/// RFC 2401 HMAC+Blake2s, not to be confused with *keyed* Blake2s +pub(crate) fn b2s_hmac(key: &[u8], data1: &[u8]) -> [u8; 32] { + use blake2::digest::Update; + type HmacBlake2s = hmac::SimpleHmac; + let mut hmac = HmacBlake2s::new_from_slice(key).unwrap(); + hmac.update(data1); + hmac.finalize_fixed().into() +} + +#[inline] +/// Like b2s_hmac, but chain data1 and data2 together +pub(crate) fn b2s_hmac2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 32] { + use blake2::digest::Update; + type HmacBlake2s = hmac::SimpleHmac; + let mut hmac = HmacBlake2s::new_from_slice(key).unwrap(); + hmac.update(data1); + hmac.update(data2); + hmac.finalize_fixed().into() +} + +#[inline] +pub(crate) fn b2s_keyed_mac_16(key: &[u8], data1: &[u8]) -> [u8; 16] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + hmac.finalize_fixed().into() +} + +#[inline] +pub(crate) fn b2s_keyed_mac_16_2(key: &[u8], data1: &[u8], data2: &[u8]) -> [u8; 16] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + blake2::digest::Update::update(&mut hmac, data2); + hmac.finalize_fixed().into() +} + +pub(crate) fn b2s_mac_24(key: &[u8], data1: &[u8]) -> [u8; 24] { + let mut hmac = Blake2sMac::new_from_slice(key).unwrap(); + blake2::digest::Update::update(&mut hmac, data1); + hmac.finalize_fixed().into() +} + +#[inline] +/// This wrapper involves an extra copy and MAY BE SLOWER +fn aead_chacha20_seal(ciphertext: &mut [u8], key: &[u8], counter: u64, data: &[u8], aad: &[u8]) { + let mut nonce: [u8; 12] = [0; 12]; + nonce[4..12].copy_from_slice(&counter.to_le_bytes()); + + aead_chacha20_seal_inner(ciphertext, key, nonce, data, aad) +} + +#[inline] +fn aead_chacha20_seal_inner( + ciphertext: &mut [u8], + key: &[u8], + nonce: [u8; 12], + data: &[u8], + aad: &[u8], +) { + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap()); + + ciphertext[..data.len()].copy_from_slice(data); + + let tag = key + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut ciphertext[..data.len()], + ) + .unwrap(); + + ciphertext[data.len()..].copy_from_slice(tag.as_ref()); +} + +#[inline] +/// This wrapper involves an extra copy and MAY BE SLOWER +fn aead_chacha20_open( + buffer: &mut [u8], + key: &[u8], + counter: u64, + data: &[u8], + aad: &[u8], +) -> Result<(), WireGuardError> { + let mut nonce: [u8; 12] = [0; 12]; + nonce[4..].copy_from_slice(&counter.to_le_bytes()); + + aead_chacha20_open_inner(buffer, key, nonce, data, aad) + .map_err(|_| WireGuardError::InvalidAeadTag)?; + Ok(()) +} + +#[inline] +fn aead_chacha20_open_inner( + buffer: &mut [u8], + key: &[u8], + nonce: [u8; 12], + data: &[u8], + aad: &[u8], +) -> Result<(), ring::error::Unspecified> { + let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, key).unwrap()); + + let mut inner_buffer = data.to_owned(); + + let plaintext = key.open_in_place( + Nonce::assume_unique_for_key(nonce), + Aad::from(aad), + &mut inner_buffer, + )?; + + buffer.copy_from_slice(plaintext); + + Ok(()) +} + +#[derive(Debug)] +/// This struct represents a 12 byte [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp +struct Tai64N { + secs: u64, + nano: u32, +} + +#[derive(Debug)] +/// This struct computes a [Tai64N](https://cr.yp.to/libtai/tai64.html) timestamp from current system time +struct TimeStamper { + duration_at_start: Duration, + instant_at_start: Instant, +} + +impl TimeStamper { + /// Create a new TimeStamper + pub fn new() -> TimeStamper { + TimeStamper { + duration_at_start: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap(), + instant_at_start: Instant::now(), + } + } + + /// Take time reading and generate a 12 byte timestamp + pub fn stamp(&self) -> [u8; 12] { + const TAI64_BASE: u64 = (1u64 << 62) + 37; + let mut ext_stamp = [0u8; 12]; + let stamp = Instant::now().duration_since(self.instant_at_start) + self.duration_at_start; + ext_stamp[0..8].copy_from_slice(&(stamp.as_secs() + TAI64_BASE).to_be_bytes()); + ext_stamp[8..12].copy_from_slice(&stamp.subsec_nanos().to_be_bytes()); + ext_stamp + } +} + +impl Tai64N { + /// A zeroed out timestamp + fn zero() -> Tai64N { + Tai64N { secs: 0, nano: 0 } + } + + /// Parse a timestamp from a 12 byte u8 slice + fn parse(buf: &[u8; 12]) -> Result { + if buf.len() < 12 { + return Err(WireGuardError::InvalidTai64nTimestamp); + } + + let (sec_bytes, nano_bytes) = buf.split_at(std::mem::size_of::()); + let secs = u64::from_be_bytes(sec_bytes.try_into().unwrap()); + let nano = u32::from_be_bytes(nano_bytes.try_into().unwrap()); + + // WireGuard does not actually expect tai64n timestamp, just monotonically increasing one + //if secs < (1u64 << 62) || secs >= (1u64 << 63) { + // return Err(WireGuardError::InvalidTai64nTimestamp); + //}; + //if nano >= 1_000_000_000 { + // return Err(WireGuardError::InvalidTai64nTimestamp); + //} + + Ok(Tai64N { secs, nano }) + } + + /// Check if this timestamp represents a time that is chronologically after the time represented + /// by the other timestamp + pub fn after(&self, other: &Tai64N) -> bool { + (self.secs > other.secs) || ((self.secs == other.secs) && (self.nano > other.nano)) + } +} + +/// Parameters used by the noise protocol +struct NoiseParams { + /// Our static public key + static_public: x25519::PublicKey, + /// Our static private key + static_private: x25519::StaticSecret, + /// Static public key of the other party + peer_static_public: x25519::PublicKey, + /// A shared key = DH(static_private, peer_static_public) + static_shared: x25519::SharedSecret, + /// A pre-computation of HASH("mac1----", peer_static_public) for this peer + sending_mac1_key: [u8; KEY_LEN], + /// An optional preshared key + preshared_key: Option<[u8; KEY_LEN]>, +} + +impl std::fmt::Debug for NoiseParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NoiseParams") + .field("static_public", &self.static_public) + .field("static_private", &"") + .field("peer_static_public", &self.peer_static_public) + .field("static_shared", &"") + .field("sending_mac1_key", &self.sending_mac1_key) + .field("preshared_key", &self.preshared_key) + .finish() + } +} + +struct HandshakeInitSentState { + local_index: u32, + hash: [u8; KEY_LEN], + chaining_key: [u8; KEY_LEN], + ephemeral_private: x25519::ReusableSecret, + time_sent: Instant, +} + +impl std::fmt::Debug for HandshakeInitSentState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HandshakeInitSentState") + .field("local_index", &self.local_index) + .field("hash", &self.hash) + .field("chaining_key", &self.chaining_key) + .field("ephemeral_private", &"") + .field("time_sent", &self.time_sent) + .finish() + } +} + +#[derive(Debug)] +enum HandshakeState { + /// No handshake in process + None, + /// We initiated the handshake + InitSent(HandshakeInitSentState), + /// Handshake initiated by peer + InitReceived { + hash: [u8; KEY_LEN], + chaining_key: [u8; KEY_LEN], + peer_ephemeral_public: x25519::PublicKey, + peer_index: u32, + }, + /// Handshake was established too long ago (implies no handshake is in progress) + Expired, +} + +pub struct Handshake { + params: NoiseParams, + /// Index of the next session + next_index: u32, + /// Allow to have two outgoing handshakes in flight, because sometimes we may receive a delayed response to a handshake with bad networks + previous: HandshakeState, + /// Current handshake state + state: HandshakeState, + cookies: Cookies, + /// The timestamp of the last handshake we received + last_handshake_timestamp: Tai64N, + // TODO: make TimeStamper a singleton + stamper: TimeStamper, + pub(super) last_rtt: Option, +} + +#[derive(Default)] +struct Cookies { + last_mac1: Option<[u8; 16]>, + index: u32, + write_cookie: Option<[u8; 16]>, +} + +#[derive(Debug)] +pub struct HalfHandshake { + pub peer_index: u32, + pub peer_static_public: [u8; 32], +} + +pub fn parse_handshake_anon( + static_private: &x25519::StaticSecret, + static_public: &x25519::PublicKey, + packet: &HandshakeInit, +) -> Result { + let peer_index = packet.sender_idx; + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, static_public.as_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac( + &b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()), + &[0x01], + ); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public)) + let ephemeral_shared = static_private.diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + + let mut peer_static_public = [0u8; KEY_LEN]; + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_open( + &mut peer_static_public, + &key, + 0, + packet.encrypted_static, + &hash, + )?; + + Ok(HalfHandshake { + peer_index, + peer_static_public, + }) +} + +impl NoiseParams { + /// New noise params struct from our secret key, peers public key, and optional preshared key + fn new( + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + peer_static_public: x25519::PublicKey, + preshared_key: Option<[u8; 32]>, + ) -> NoiseParams { + let static_shared = static_private.diffie_hellman(&peer_static_public); + + let initial_sending_mac_key = b2s_hash(LABEL_MAC1, peer_static_public.as_bytes()); + + NoiseParams { + static_public, + static_private, + peer_static_public, + static_shared, + sending_mac1_key: initial_sending_mac_key, + preshared_key, + } + } + + /// Set a new private key + fn set_static_private( + &mut self, + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + ) { + // Check that the public key indeed matches the private key + let check_key = x25519::PublicKey::from(&static_private); + assert_eq!(check_key.as_bytes(), static_public.as_bytes()); + + self.static_private = static_private; + self.static_public = static_public; + + self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public); + } +} + +impl Handshake { + pub(crate) fn new( + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + peer_static_public: x25519::PublicKey, + global_idx: u32, + preshared_key: Option<[u8; 32]>, + ) -> Handshake { + let params = NoiseParams::new( + static_private, + static_public, + peer_static_public, + preshared_key, + ); + + Handshake { + params, + next_index: global_idx, + previous: HandshakeState::None, + state: HandshakeState::None, + last_handshake_timestamp: Tai64N::zero(), + stamper: TimeStamper::new(), + cookies: Default::default(), + last_rtt: None, + } + } + + pub(crate) fn is_in_progress(&self) -> bool { + !matches!(self.state, HandshakeState::None | HandshakeState::Expired) + } + + pub(crate) fn timer(&self) -> Option { + match self.state { + HandshakeState::InitSent(HandshakeInitSentState { time_sent, .. }) => Some(time_sent), + _ => None, + } + } + + pub(crate) fn set_expired(&mut self) { + self.previous = HandshakeState::Expired; + self.state = HandshakeState::Expired; + } + + pub(crate) fn is_expired(&self) -> bool { + matches!(self.state, HandshakeState::Expired) + } + + pub(crate) fn has_cookie(&self) -> bool { + self.cookies.write_cookie.is_some() + } + + pub(crate) fn clear_cookie(&mut self) { + self.cookies.write_cookie = None; + } + + // The index used is 24 bits for peer index, allowing for 16M active peers per server and 8 bits for cyclic session index + fn inc_index(&mut self) -> u32 { + let index = self.next_index; + let idx8 = index as u8; + self.next_index = (index & !0xff) | u32::from(idx8.wrapping_add(1)); + self.next_index + } + + pub(crate) fn set_static_private( + &mut self, + private_key: x25519::StaticSecret, + public_key: x25519::PublicKey, + ) { + self.params.set_static_private(private_key, public_key) + } + + pub(super) fn receive_handshake_initialization<'a>( + &mut self, + packet: HandshakeInit, + dst: &'a mut [u8], + ) -> Result<(&'a mut [u8], Session), WireGuardError> { + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, self.params.static_public.as_bytes()); + // msg.sender_index = little_endian(initiator.sender_index) + let peer_index = packet.sender_idx; + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + let peer_ephemeral_public = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, peer_ephemeral_public.as_bytes()); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac( + &b2s_hmac(&chaining_key, peer_ephemeral_public.as_bytes()), + &[0x01], + ); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public)) + let ephemeral_shared = self + .params + .static_private + .diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + + let mut peer_static_public_decrypted = [0u8; KEY_LEN]; + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_open( + &mut peer_static_public_decrypted, + &key, + 0, + packet.encrypted_static, + &hash, + )?; + + ring::constant_time::verify_slices_are_equal( + self.params.peer_static_public.as_bytes(), + &peer_static_public_decrypted, + ) + .map_err(|_| WireGuardError::WrongKey)?; + + // initiator.hash = HASH(initiator.hash || msg.encrypted_static) + hash = b2s_hash(&hash, packet.encrypted_static); + // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public)) + let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) + let mut timestamp = [0u8; TIMESTAMP_LEN]; + aead_chacha20_open(&mut timestamp, &key, 0, packet.encrypted_timestamp, &hash)?; + + let timestamp = Tai64N::parse(×tamp)?; + if !timestamp.after(&self.last_handshake_timestamp) { + // Possibly a replay + return Err(WireGuardError::WrongTai64nTimestamp); + } + self.last_handshake_timestamp = timestamp; + + // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) + hash = b2s_hash(&hash, packet.encrypted_timestamp); + + self.previous = std::mem::replace( + &mut self.state, + HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + }, + ); + + self.format_handshake_response(dst) + } + + pub(super) fn receive_handshake_response( + &mut self, + packet: HandshakeResponse, + ) -> Result { + // Check if there is a handshake awaiting a response and return the correct one + let (state, is_previous) = match (&self.state, &self.previous) { + (HandshakeState::InitSent(s), _) if s.local_index == packet.receiver_idx => (s, false), + (_, HandshakeState::InitSent(s)) if s.local_index == packet.receiver_idx => (s, true), + _ => return Err(WireGuardError::UnexpectedPacket), + }; + + let peer_index = packet.sender_idx; + let local_index = state.local_index; + + let unencrypted_ephemeral = x25519::PublicKey::from(*packet.unencrypted_ephemeral); + // msg.unencrypted_ephemeral = DH_PUBKEY(responder.ephemeral_private) + // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) + let mut hash = b2s_hash(&state.hash, unencrypted_ephemeral.as_bytes()); + // temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral) + let temp = b2s_hmac(&state.chaining_key, unencrypted_ephemeral.as_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + let mut chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public)) + let ephemeral_shared = state + .ephemeral_private + .diffie_hellman(&unencrypted_ephemeral); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public)) + let temp = b2s_hmac( + &chaining_key, + &self + .params + .static_private + .diffie_hellman(&unencrypted_ephemeral) + .to_bytes(), + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, preshared_key) + let temp = b2s_hmac( + &chaining_key, + &self.params.preshared_key.unwrap_or([0u8; 32])[..], + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp2 = HMAC(temp, responder.chaining_key || 0x2) + let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // key = HMAC(temp, temp2 || 0x3) + let key = b2s_hmac2(&temp, &temp2, &[0x03]); + // responder.hash = HASH(responder.hash || temp2) + hash = b2s_hash(&hash, &temp2); + // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) + aead_chacha20_open(&mut [], &key, 0, packet.encrypted_nothing, &hash)?; + + // responder.hash = HASH(responder.hash || msg.encrypted_nothing) + // hash = b2s_hash(hash, buf[ENC_NOTHING_OFF..ENC_NOTHING_OFF + ENC_NOTHING_SZ]); + + // Derive keys + // temp1 = HMAC(initiator.chaining_key, [empty]) + // temp2 = HMAC(temp1, 0x1) + // temp3 = HMAC(temp1, temp2 || 0x2) + // initiator.sending_key = temp2 + // initiator.receiving_key = temp3 + // initiator.sending_key_counter = 0 + // initiator.receiving_key_counter = 0 + let temp1 = b2s_hmac(&chaining_key, &[]); + let temp2 = b2s_hmac(&temp1, &[0x01]); + let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]); + + let rtt_time = Instant::now().duration_since(state.time_sent); + self.last_rtt = Some(rtt_time.as_millis() as u32); + + if is_previous { + self.previous = HandshakeState::None; + } else { + self.state = HandshakeState::None; + } + Ok(Session::new(local_index, peer_index, temp3, temp2)) + } + + pub(super) fn receive_cookie_reply( + &mut self, + packet: PacketCookieReply, + ) -> Result<(), WireGuardError> { + let mac1 = match self.cookies.last_mac1 { + Some(mac) => mac, + None => { + return Err(WireGuardError::UnexpectedPacket); + } + }; + + let local_index = self.cookies.index; + if packet.receiver_idx != local_index { + return Err(WireGuardError::WrongIndex); + } + // msg.encrypted_cookie = XAEAD(HASH(LABEL_COOKIE || responder.static_public), msg.nonce, cookie, last_received_msg.mac1) + let key = b2s_hash(LABEL_COOKIE, self.params.peer_static_public.as_bytes()); // TODO: pre-compute + + let payload = Payload { + aad: &mac1[0..16], + msg: packet.encrypted_cookie, + }; + let plaintext = XChaCha20Poly1305::new_from_slice(&key) + .unwrap() + .decrypt(packet.nonce.into(), payload) + .map_err(|_| WireGuardError::InvalidAeadTag)?; + + let cookie = plaintext + .try_into() + .map_err(|_| WireGuardError::InvalidPacket)?; + self.cookies.write_cookie = Some(cookie); + Ok(()) + } + + // Compute and append mac1 and mac2 to a handshake message + fn append_mac1_and_mac2<'a>( + &mut self, + local_index: u32, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + let mac1_off = dst.len() - 32; + let mac2_off = dst.len() - 16; + + // msg.mac1 = MAC(HASH(LABEL_MAC1 || responder.static_public), msg[0:offsetof(msg.mac1)]) + let msg_mac1 = b2s_keyed_mac_16(&self.params.sending_mac1_key, &dst[..mac1_off]); + + dst[mac1_off..mac2_off].copy_from_slice(&msg_mac1[..]); + + //msg.mac2 = MAC(initiator.last_received_cookie, msg[0:offsetof(msg.mac2)]) + let msg_mac2: [u8; 16] = if let Some(cookie) = self.cookies.write_cookie { + b2s_keyed_mac_16(&cookie, &dst[..mac2_off]) + } else { + [0u8; 16] + }; + + dst[mac2_off..].copy_from_slice(&msg_mac2[..]); + + self.cookies.index = local_index; + self.cookies.last_mac1 = Some(msg_mac1); + Ok(dst) + } + + pub(super) fn format_handshake_initiation<'a>( + &mut self, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + if dst.len() < super::HANDSHAKE_INIT_SZ { + return Err(WireGuardError::DestinationBufferTooSmall); + } + + let (message_type, rest) = dst.split_at_mut(4); + let (sender_index, rest) = rest.split_at_mut(4); + let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); + let (encrypted_static, rest) = rest.split_at_mut(32 + 16); + let (encrypted_timestamp, _) = rest.split_at_mut(12 + 16); + + let local_index = self.inc_index(); + + // initiator.chaining_key = HASH(CONSTRUCTION) + let mut chaining_key = INITIAL_CHAIN_KEY; + // initiator.hash = HASH(HASH(initiator.chaining_key || IDENTIFIER) || responder.static_public) + let mut hash = INITIAL_CHAIN_HASH; + hash = b2s_hash(&hash, self.params.peer_static_public.as_bytes()); + // initiator.ephemeral_private = DH_GENERATE() + let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); + // msg.message_type = 1 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::HANDSHAKE_INIT.to_le_bytes()); + // msg.sender_index = little_endian(initiator.sender_index) + sender_index.copy_from_slice(&local_index.to_le_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + unencrypted_ephemeral + .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); + // initiator.hash = HASH(initiator.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, unencrypted_ephemeral); + // temp = HMAC(initiator.chaining_key, msg.unencrypted_ephemeral) + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&b2s_hmac(&chaining_key, unencrypted_ephemeral), &[0x01]); + // temp = HMAC(initiator.chaining_key, DH(initiator.ephemeral_private, responder.static_public)) + let ephemeral_shared = ephemeral_private.diffie_hellman(&self.params.peer_static_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_static = AEAD(key, 0, initiator.static_public, initiator.hash) + aead_chacha20_seal( + encrypted_static, + &key, + 0, + self.params.static_public.as_bytes(), + &hash, + ); + // initiator.hash = HASH(initiator.hash || msg.encrypted_static) + hash = b2s_hash(&hash, encrypted_static); + // temp = HMAC(initiator.chaining_key, DH(initiator.static_private, responder.static_public)) + let temp = b2s_hmac(&chaining_key, self.params.static_shared.as_bytes()); + // initiator.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // key = HMAC(temp, initiator.chaining_key || 0x2) + let key = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // msg.encrypted_timestamp = AEAD(key, 0, TAI64N(), initiator.hash) + let timestamp = self.stamper.stamp(); + aead_chacha20_seal(encrypted_timestamp, &key, 0, ×tamp, &hash); + // initiator.hash = HASH(initiator.hash || msg.encrypted_timestamp) + hash = b2s_hash(&hash, encrypted_timestamp); + + let time_now = Instant::now(); + self.previous = std::mem::replace( + &mut self.state, + HandshakeState::InitSent(HandshakeInitSentState { + local_index, + chaining_key, + hash, + ephemeral_private, + time_sent: time_now, + }), + ); + + self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_INIT_SZ]) + } + + fn format_handshake_response<'a>( + &mut self, + dst: &'a mut [u8], + ) -> Result<(&'a mut [u8], Session), WireGuardError> { + if dst.len() < super::HANDSHAKE_RESP_SZ { + return Err(WireGuardError::DestinationBufferTooSmall); + } + + let state = std::mem::replace(&mut self.state, HandshakeState::None); + let (mut chaining_key, mut hash, peer_ephemeral_public, peer_index) = match state { + HandshakeState::InitReceived { + chaining_key, + hash, + peer_ephemeral_public, + peer_index, + } => (chaining_key, hash, peer_ephemeral_public, peer_index), + _ => { + panic!("Unexpected attempt to call send_handshake_response"); + } + }; + + let (message_type, rest) = dst.split_at_mut(4); + let (sender_index, rest) = rest.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (unencrypted_ephemeral, rest) = rest.split_at_mut(32); + let (encrypted_nothing, _) = rest.split_at_mut(16); + + // responder.ephemeral_private = DH_GENERATE() + let ephemeral_private = x25519::ReusableSecret::random_from_rng(OsRng); + let local_index = self.inc_index(); + // msg.message_type = 2 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::HANDSHAKE_RESP.to_le_bytes()); + // msg.sender_index = little_endian(responder.sender_index) + sender_index.copy_from_slice(&local_index.to_le_bytes()); + // msg.receiver_index = little_endian(initiator.sender_index) + receiver_index.copy_from_slice(&peer_index.to_le_bytes()); + // msg.unencrypted_ephemeral = DH_PUBKEY(initiator.ephemeral_private) + unencrypted_ephemeral + .copy_from_slice(x25519::PublicKey::from(&ephemeral_private).as_bytes()); + // responder.hash = HASH(responder.hash || msg.unencrypted_ephemeral) + hash = b2s_hash(&hash, unencrypted_ephemeral); + // temp = HMAC(responder.chaining_key, msg.unencrypted_ephemeral) + let temp = b2s_hmac(&chaining_key, unencrypted_ephemeral); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.ephemeral_public)) + let ephemeral_shared = ephemeral_private.diffie_hellman(&peer_ephemeral_public); + let temp = b2s_hmac(&chaining_key, &ephemeral_shared.to_bytes()); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, DH(responder.ephemeral_private, initiator.static_public)) + let temp = b2s_hmac( + &chaining_key, + &ephemeral_private + .diffie_hellman(&self.params.peer_static_public) + .to_bytes(), + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp = HMAC(responder.chaining_key, preshared_key) + let temp = b2s_hmac( + &chaining_key, + &self.params.preshared_key.unwrap_or([0u8; 32])[..], + ); + // responder.chaining_key = HMAC(temp, 0x1) + chaining_key = b2s_hmac(&temp, &[0x01]); + // temp2 = HMAC(temp, responder.chaining_key || 0x2) + let temp2 = b2s_hmac2(&temp, &chaining_key, &[0x02]); + // key = HMAC(temp, temp2 || 0x3) + let key = b2s_hmac2(&temp, &temp2, &[0x03]); + // responder.hash = HASH(responder.hash || temp2) + hash = b2s_hash(&hash, &temp2); + // msg.encrypted_nothing = AEAD(key, 0, [empty], responder.hash) + aead_chacha20_seal(encrypted_nothing, &key, 0, &[], &hash); + + // Derive keys + // temp1 = HMAC(initiator.chaining_key, [empty]) + // temp2 = HMAC(temp1, 0x1) + // temp3 = HMAC(temp1, temp2 || 0x2) + // initiator.sending_key = temp2 + // initiator.receiving_key = temp3 + // initiator.sending_key_counter = 0 + // initiator.receiving_key_counter = 0 + let temp1 = b2s_hmac(&chaining_key, &[]); + let temp2 = b2s_hmac(&temp1, &[0x01]); + let temp3 = b2s_hmac2(&temp1, &temp2, &[0x02]); + + let dst = self.append_mac1_and_mac2(local_index, &mut dst[..super::HANDSHAKE_RESP_SZ])?; + + Ok((dst, Session::new(local_index, peer_index, temp2, temp3))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn chacha20_seal_rfc7530_test_vector() { + let plaintext = b"Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; + let aad: [u8; 12] = [ + 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, + ]; + let key: [u8; 32] = [ + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, + 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, + 0x9c, 0x9d, 0x9e, 0x9f, + ]; + let nonce: [u8; 12] = [ + 0x07, 0x00, 0x00, 0x00, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + ]; + let mut buffer = vec![0; plaintext.len() + 16]; + + aead_chacha20_seal_inner(&mut buffer, &key, nonce, plaintext, &aad); + + const EXPECTED_CIPHERTEXT: [u8; 114] = [ + 0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, 0x7b, 0x86, 0xaf, 0xbc, 0x53, 0xef, + 0x7e, 0xc2, 0xa4, 0xad, 0xed, 0x51, 0x29, 0x6e, 0x08, 0xfe, 0xa9, 0xe2, 0xb5, 0xa7, + 0x36, 0xee, 0x62, 0xd6, 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12, 0x82, 0xfa, + 0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, 0x1a, 0x71, 0xde, 0x0a, 0x9e, 0x06, 0x0b, 0x29, + 0x05, 0xd6, 0xa5, 0xb6, 0x7e, 0xcd, 0x3b, 0x36, 0x92, 0xdd, 0xbd, 0x7f, 0x2d, 0x77, + 0x8b, 0x8c, 0x98, 0x03, 0xae, 0xe3, 0x28, 0x09, 0x1b, 0x58, 0xfa, 0xb3, 0x24, 0xe4, + 0xfa, 0xd6, 0x75, 0x94, 0x55, 0x85, 0x80, 0x8b, 0x48, 0x31, 0xd7, 0xbc, 0x3f, 0xf4, + 0xde, 0xf0, 0x8e, 0x4b, 0x7a, 0x9d, 0xe5, 0x76, 0xd2, 0x65, 0x86, 0xce, 0xc6, 0x4b, + 0x61, 0x16, + ]; + const EXPECTED_TAG: [u8; 16] = [ + 0x1a, 0xe1, 0x0b, 0x59, 0x4f, 0x09, 0xe2, 0x6a, 0x7e, 0x90, 0x2e, 0xcb, 0xd0, 0x60, + 0x06, 0x91, + ]; + + assert_eq!(buffer[..plaintext.len()], EXPECTED_CIPHERTEXT); + assert_eq!(buffer[plaintext.len()..], EXPECTED_TAG); + } + + #[test] + fn symmetric_chacha20_seal_open() { + let aad: [u8; 32] = Default::default(); + let key: [u8; 32] = Default::default(); + let counter = 0; + + let mut encrypted_nothing: [u8; 16] = Default::default(); + + aead_chacha20_seal(&mut encrypted_nothing, &key, counter, &[], &aad); + + eprintln!("encrypted_nothing: {:?}", encrypted_nothing); + + aead_chacha20_open(&mut [], &key, counter, &encrypted_nothing, &aad) + .expect("Should open what we just sealed"); + } +} diff --git a/lib/boringtun/src/noise/mod.rs b/lib/boringtun/src/noise/mod.rs new file mode 100644 index 0000000..ebc99bc --- /dev/null +++ b/lib/boringtun/src/noise/mod.rs @@ -0,0 +1,794 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +pub mod errors; +pub mod handshake; +pub mod rate_limiter; + +mod session; +mod timers; + +use crate::noise::errors::WireGuardError; +use crate::noise::handshake::Handshake; +use crate::noise::rate_limiter::RateLimiter; +use crate::noise::timers::{TimerName, Timers}; +use crate::x25519; + +use std::collections::VecDeque; +use std::convert::{TryFrom, TryInto}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::sync::Arc; +use std::time::Duration; + +/// The default value to use for rate limiting, when no other rate limiter is defined +const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10; + +const IPV4_MIN_HEADER_SIZE: usize = 20; +const IPV4_LEN_OFF: usize = 2; +const IPV4_SRC_IP_OFF: usize = 12; +const IPV4_DST_IP_OFF: usize = 16; +const IPV4_IP_SZ: usize = 4; + +const IPV6_MIN_HEADER_SIZE: usize = 40; +const IPV6_LEN_OFF: usize = 4; +const IPV6_SRC_IP_OFF: usize = 8; +const IPV6_DST_IP_OFF: usize = 24; +const IPV6_IP_SZ: usize = 16; + +const IP_LEN_SZ: usize = 2; + +const MAX_QUEUE_DEPTH: usize = 256; +/// number of sessions in the ring, better keep a PoT +const N_SESSIONS: usize = 8; + +#[derive(Debug)] +pub enum TunnResult<'a> { + Done, + Err(WireGuardError), + WriteToNetwork(&'a mut [u8]), + WriteToTunnelV4(&'a mut [u8], Ipv4Addr), + WriteToTunnelV6(&'a mut [u8], Ipv6Addr), +} + +impl<'a> From for TunnResult<'a> { + fn from(err: WireGuardError) -> TunnResult<'a> { + TunnResult::Err(err) + } +} + +/// Tunnel represents a point-to-point WireGuard connection +pub struct Tunn { + /// The handshake currently in progress + handshake: handshake::Handshake, + /// The N_SESSIONS most recent sessions, index is session id modulo N_SESSIONS + sessions: [Option; N_SESSIONS], + /// Index of most recently used session + current: usize, + /// Queue to store blocked packets + packet_queue: VecDeque>, + /// Keeps tabs on the expiring timers + timers: timers::Timers, + tx_bytes: usize, + rx_bytes: usize, + rate_limiter: Arc, +} + +type MessageType = u32; +const HANDSHAKE_INIT: MessageType = 1; +const HANDSHAKE_RESP: MessageType = 2; +const COOKIE_REPLY: MessageType = 3; +const DATA: MessageType = 4; + +const HANDSHAKE_INIT_SZ: usize = 148; +const HANDSHAKE_RESP_SZ: usize = 92; +const COOKIE_REPLY_SZ: usize = 64; +const DATA_OVERHEAD_SZ: usize = 32; + +#[derive(Debug)] +pub struct HandshakeInit<'a> { + sender_idx: u32, + pub unencrypted_ephemeral: &'a [u8; 32], + encrypted_static: &'a [u8], + encrypted_timestamp: &'a [u8], +} + +#[derive(Debug)] +pub struct HandshakeResponse<'a> { + sender_idx: u32, + pub receiver_idx: u32, + pub unencrypted_ephemeral: &'a [u8; 32], + encrypted_nothing: &'a [u8], +} + +#[derive(Debug)] +pub struct PacketCookieReply<'a> { + pub receiver_idx: u32, + nonce: &'a [u8], + encrypted_cookie: &'a [u8], +} + +#[derive(Debug)] +pub struct PacketData<'a> { + pub receiver_idx: u32, + counter: u64, + encrypted_encapsulated_packet: &'a [u8], +} + +/// Describes a packet from network +#[derive(Debug)] +pub enum Packet<'a> { + HandshakeInit(HandshakeInit<'a>), + HandshakeResponse(HandshakeResponse<'a>), + PacketCookieReply(PacketCookieReply<'a>), + PacketData(PacketData<'a>), +} + +impl Tunn { + #[inline(always)] + pub fn parse_incoming_packet(src: &[u8]) -> Result { + if src.len() < 4 { + return Err(WireGuardError::InvalidPacket); + } + + // Checks the type, as well as the reserved zero fields + let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap()); + + Ok(match (packet_type, src.len()) { + (HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit { + sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40]) + .expect("length already checked above"), + encrypted_static: &src[40..88], + encrypted_timestamp: &src[88..116], + }), + (HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse { + sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()), + unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44]) + .expect("length already checked above"), + encrypted_nothing: &src[44..60], + }), + (COOKIE_REPLY, COOKIE_REPLY_SZ) => Packet::PacketCookieReply(PacketCookieReply { + receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + nonce: &src[8..32], + encrypted_cookie: &src[32..64], + }), + (DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData { + receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()), + counter: u64::from_le_bytes(src[8..16].try_into().unwrap()), + encrypted_encapsulated_packet: &src[16..], + }), + _ => return Err(WireGuardError::InvalidPacket), + }) + } + + pub fn is_expired(&self) -> bool { + self.handshake.is_expired() + } + + pub fn dst_address(packet: &[u8]) -> Option { + if packet.is_empty() { + return None; + } + + match packet[0] >> 4 { + 4 if packet.len() >= IPV4_MIN_HEADER_SIZE => { + let addr_bytes: [u8; IPV4_IP_SZ] = packet + [IPV4_DST_IP_OFF..IPV4_DST_IP_OFF + IPV4_IP_SZ] + .try_into() + .unwrap(); + Some(IpAddr::from(addr_bytes)) + } + 6 if packet.len() >= IPV6_MIN_HEADER_SIZE => { + let addr_bytes: [u8; IPV6_IP_SZ] = packet + [IPV6_DST_IP_OFF..IPV6_DST_IP_OFF + IPV6_IP_SZ] + .try_into() + .unwrap(); + Some(IpAddr::from(addr_bytes)) + } + _ => None, + } + } + + /// Create a new tunnel using own private key and the peer public key + pub fn new( + static_private: x25519::StaticSecret, + peer_static_public: x25519::PublicKey, + preshared_key: Option<[u8; 32]>, + persistent_keepalive: Option, + index: u32, + rate_limiter: Option>, + ) -> Self { + let static_public = x25519::PublicKey::from(&static_private); + + Tunn { + handshake: Handshake::new( + static_private, + static_public, + peer_static_public, + index << 8, + preshared_key, + ), + sessions: Default::default(), + current: Default::default(), + tx_bytes: Default::default(), + rx_bytes: Default::default(), + + packet_queue: VecDeque::new(), + timers: Timers::new(persistent_keepalive, rate_limiter.is_none()), + + rate_limiter: rate_limiter.unwrap_or_else(|| { + Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT)) + }), + } + } + + /// Update the private key and clear existing sessions + pub fn set_static_private( + &mut self, + static_private: x25519::StaticSecret, + static_public: x25519::PublicKey, + rate_limiter: Option>, + ) { + self.timers.should_reset_rr = rate_limiter.is_none(); + self.rate_limiter = rate_limiter.unwrap_or_else(|| { + Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT)) + }); + self.handshake + .set_static_private(static_private, static_public); + for s in &mut self.sessions { + *s = None; + } + } + + /// Encapsulate a single packet from the tunnel interface. + /// Returns TunnResult. + /// + /// # Panics + /// Panics if dst buffer is too small. + /// Size of dst should be at least src.len() + 32, and no less than 148 bytes. + pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> { + let current = self.current; + if let Some(ref session) = self.sessions[current % N_SESSIONS] { + // Send the packet using an established session + let packet = session.format_packet_data(src, dst); + self.timer_tick(TimerName::TimeLastPacketSent); + // Exclude Keepalive packets from timer update. + if !src.is_empty() { + self.timer_tick(TimerName::TimeLastDataPacketSent); + } + self.tx_bytes += src.len(); + return TunnResult::WriteToNetwork(packet); + } + + // If there is no session, queue the packet for future retry + self.queue_packet(src); + // Initiate a new handshake if none is in progress + self.format_handshake_initiation(dst, false) + } + + /// Receives a UDP datagram from the network and parses it. + /// Returns TunnResult. + /// + /// If the result is of type TunnResult::WriteToNetwork, should repeat the call with empty datagram, + /// until TunnResult::Done is returned. If batch processing packets, it is OK to defer until last + /// packet is processed. + pub fn decapsulate<'a>( + &mut self, + src_addr: Option, + datagram: &[u8], + dst: &'a mut [u8], + ) -> TunnResult<'a> { + if datagram.is_empty() { + // Indicates a repeated call + return self.send_queued_packet(dst); + } + + let mut cookie = [0u8; COOKIE_REPLY_SZ]; + let packet = match self + .rate_limiter + .verify_packet(src_addr, datagram, &mut cookie) + { + Ok(packet) => packet, + Err(TunnResult::WriteToNetwork(cookie)) => { + dst[..cookie.len()].copy_from_slice(cookie); + return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]); + } + Err(TunnResult::Err(e)) => return TunnResult::Err(e), + _ => unreachable!(), + }; + + self.handle_verified_packet(packet, dst) + } + + pub(crate) fn handle_verified_packet<'a>( + &mut self, + packet: Packet, + dst: &'a mut [u8], + ) -> TunnResult<'a> { + match packet { + Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst), + Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst), + Packet::PacketCookieReply(p) => self.handle_cookie_reply(p), + Packet::PacketData(p) => self.handle_data(p, dst), + } + .unwrap_or_else(TunnResult::from) + } + + fn handle_handshake_init<'a>( + &mut self, + p: HandshakeInit, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received handshake_initiation", + remote_idx = p.sender_idx + ); + + let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?; + + // Store new session in ring buffer + let index = session.local_index(); + self.sessions[index % N_SESSIONS] = Some(session); + + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick(TimerName::TimeLastPacketSent); + self.timer_tick_session_established(false, index); // New session established, we are not the initiator + + tracing::debug!(message = "Sending handshake_response", local_idx = index); + + Ok(TunnResult::WriteToNetwork(packet)) + } + + fn handle_handshake_response<'a>( + &mut self, + p: HandshakeResponse, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received handshake_response", + local_idx = p.receiver_idx, + remote_idx = p.sender_idx + ); + + let session = self.handshake.receive_handshake_response(p)?; + + let keepalive_packet = session.format_packet_data(&[], dst); + // Store new session in ring buffer + let l_idx = session.local_index(); + let index = l_idx % N_SESSIONS; + self.sessions[index] = Some(session); + + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick_session_established(true, index); // New session established, we are the initiator + self.set_current_session(l_idx); + + tracing::debug!("Sending keepalive"); + + Ok(TunnResult::WriteToNetwork(keepalive_packet)) // Send a keepalive as a response + } + + fn handle_cookie_reply<'a>( + &mut self, + p: PacketCookieReply, + ) -> Result, WireGuardError> { + tracing::debug!( + message = "Received cookie_reply", + local_idx = p.receiver_idx + ); + + self.handshake.receive_cookie_reply(p)?; + self.timer_tick(TimerName::TimeLastPacketReceived); + self.timer_tick(TimerName::TimeCookieReceived); + + tracing::debug!("Did set cookie"); + + Ok(TunnResult::Done) + } + + /// Update the index of the currently used session, if needed + fn set_current_session(&mut self, new_idx: usize) { + let cur_idx = self.current; + if cur_idx == new_idx { + // There is nothing to do, already using this session, this is the common case + return; + } + if self.sessions[cur_idx % N_SESSIONS].is_none() + || self.timers.session_timers[new_idx % N_SESSIONS] + >= self.timers.session_timers[cur_idx % N_SESSIONS] + { + self.current = new_idx; + tracing::debug!(message = "New session", session = new_idx); + } + } + + /// Decrypts a data packet, and stores the decapsulated packet in dst. + fn handle_data<'a>( + &mut self, + packet: PacketData, + dst: &'a mut [u8], + ) -> Result, WireGuardError> { + let r_idx = packet.receiver_idx as usize; + let idx = r_idx % N_SESSIONS; + + // Get the (probably) right session + let decapsulated_packet = { + let session = self.sessions[idx].as_ref(); + let session = session.ok_or_else(|| { + tracing::trace!(message = "No current session available", remote_idx = r_idx); + WireGuardError::NoCurrentSession + })?; + session.receive_packet_data(packet, dst)? + }; + + self.set_current_session(r_idx); + + self.timer_tick(TimerName::TimeLastPacketReceived); + + Ok(self.validate_decapsulated_packet(decapsulated_packet)) + } + + /// Formats a new handshake initiation message and store it in dst. If force_resend is true will send + /// a new handshake, even if a handshake is already in progress (for example when a handshake times out) + pub fn format_handshake_initiation<'a>( + &mut self, + dst: &'a mut [u8], + force_resend: bool, + ) -> TunnResult<'a> { + if self.handshake.is_in_progress() && !force_resend { + return TunnResult::Done; + } + + if self.handshake.is_expired() { + self.timers.clear(); + } + + let starting_new_handshake = !self.handshake.is_in_progress(); + + match self.handshake.format_handshake_initiation(dst) { + Ok(packet) => { + tracing::debug!("Sending handshake_initiation"); + + if starting_new_handshake { + self.timer_tick(TimerName::TimeLastHandshakeStarted); + } + self.timer_tick(TimerName::TimeLastPacketSent); + TunnResult::WriteToNetwork(packet) + } + Err(e) => TunnResult::Err(e), + } + } + + /// Check if an IP packet is v4 or v6, truncate to the length indicated by the length field + /// Returns the truncated packet and the source IP as TunnResult + fn validate_decapsulated_packet<'a>(&mut self, packet: &'a mut [u8]) -> TunnResult<'a> { + let (computed_len, src_ip_address) = match packet.len() { + 0 => return TunnResult::Done, // This is keepalive, and not an error + _ if packet[0] >> 4 == 4 && packet.len() >= IPV4_MIN_HEADER_SIZE => { + let len_bytes: [u8; IP_LEN_SZ] = packet[IPV4_LEN_OFF..IPV4_LEN_OFF + IP_LEN_SZ] + .try_into() + .unwrap(); + let addr_bytes: [u8; IPV4_IP_SZ] = packet + [IPV4_SRC_IP_OFF..IPV4_SRC_IP_OFF + IPV4_IP_SZ] + .try_into() + .unwrap(); + ( + u16::from_be_bytes(len_bytes) as usize, + IpAddr::from(addr_bytes), + ) + } + _ if packet[0] >> 4 == 6 && packet.len() >= IPV6_MIN_HEADER_SIZE => { + let len_bytes: [u8; IP_LEN_SZ] = packet[IPV6_LEN_OFF..IPV6_LEN_OFF + IP_LEN_SZ] + .try_into() + .unwrap(); + let addr_bytes: [u8; IPV6_IP_SZ] = packet + [IPV6_SRC_IP_OFF..IPV6_SRC_IP_OFF + IPV6_IP_SZ] + .try_into() + .unwrap(); + ( + u16::from_be_bytes(len_bytes) as usize + IPV6_MIN_HEADER_SIZE, + IpAddr::from(addr_bytes), + ) + } + _ => return TunnResult::Err(WireGuardError::InvalidPacket), + }; + + if computed_len > packet.len() { + return TunnResult::Err(WireGuardError::InvalidPacket); + } + + self.timer_tick(TimerName::TimeLastDataPacketReceived); + self.rx_bytes += computed_len; + + match src_ip_address { + IpAddr::V4(addr) => TunnResult::WriteToTunnelV4(&mut packet[..computed_len], addr), + IpAddr::V6(addr) => TunnResult::WriteToTunnelV6(&mut packet[..computed_len], addr), + } + } + + /// Get a packet from the queue, and try to encapsulate it + fn send_queued_packet<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + if let Some(packet) = self.dequeue_packet() { + match self.encapsulate(&packet, dst) { + TunnResult::Err(_) => { + // On error, return packet to the queue + self.requeue_packet(packet); + } + r => return r, + } + } + TunnResult::Done + } + + /// Push packet to the back of the queue + fn queue_packet(&mut self, packet: &[u8]) { + if self.packet_queue.len() < MAX_QUEUE_DEPTH { + // Drop if too many are already in queue + self.packet_queue.push_back(packet.to_vec()); + } + } + + /// Push packet to the front of the queue + fn requeue_packet(&mut self, packet: Vec) { + if self.packet_queue.len() < MAX_QUEUE_DEPTH { + // Drop if too many are already in queue + self.packet_queue.push_front(packet); + } + } + + fn dequeue_packet(&mut self) -> Option> { + self.packet_queue.pop_front() + } + + fn estimate_loss(&self) -> f32 { + let session_idx = self.current; + + let mut weight = 9.0; + let mut cur_avg = 0.0; + let mut total_weight = 0.0; + + for i in 0..N_SESSIONS { + if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] { + let (expected, received) = session.current_packet_cnt(); + + let loss = if expected == 0 { + 0.0 + } else { + 1.0 - received as f32 / expected as f32 + }; + + cur_avg += loss * weight; + total_weight += weight; + weight /= 3.0; + } + } + + if total_weight == 0.0 { + 0.0 + } else { + cur_avg / total_weight + } + } + + /// Return stats from the tunnel: + /// * Time since last handshake in seconds + /// * Data bytes sent + /// * Data bytes received + pub fn stats(&self) -> (Option, usize, usize, f32, Option) { + let time = self.time_since_last_handshake(); + let tx_bytes = self.tx_bytes; + let rx_bytes = self.rx_bytes; + let loss = self.estimate_loss(); + let rtt = self.handshake.last_rtt; + + (time, tx_bytes, rx_bytes, loss, rtt) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "mock-instant")] + use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT}; + + use super::*; + use rand_core::{OsRng, RngCore}; + + fn create_two_tuns() -> (Tunn, Tunn) { + let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key); + let my_idx = OsRng.next_u32(); + + let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key); + let their_idx = OsRng.next_u32(); + + let my_tun = Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None); + + let their_tun = Tunn::new(their_secret_key, my_public_key, None, None, their_idx, None); + + (my_tun, their_tun) + } + + fn create_handshake_init(tun: &mut Tunn) -> Vec { + let mut dst = vec![0u8; 2048]; + let handshake_init = tun.format_handshake_initiation(&mut dst, false); + assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_))); + let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init { + sent + } else { + unreachable!(); + }; + + handshake_init.into() + } + + fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec { + let mut dst = vec![0u8; 2048]; + let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst); + assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_))); + + let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp { + sent + } else { + unreachable!(); + }; + + handshake_resp.into() + } + + fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec { + let mut dst = vec![0u8; 2048]; + let keepalive = tun.decapsulate(None, handshake_resp, &mut dst); + assert!(matches!(keepalive, TunnResult::WriteToNetwork(_))); + + let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive { + sent + } else { + unreachable!(); + }; + + keepalive.into() + } + + fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) { + let mut dst = vec![0u8; 2048]; + let keepalive = tun.decapsulate(None, keepalive, &mut dst); + assert!(matches!(keepalive, TunnResult::Done)); + } + + fn create_two_tuns_and_handshake() -> (Tunn, Tunn) { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, &init); + let keepalive = parse_handshake_resp(&mut my_tun, &resp); + parse_keepalive(&mut their_tun, &keepalive); + + (my_tun, their_tun) + } + + fn create_ipv4_udp_packet() -> Vec { + let header = + etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23); + let payload = [0, 1, 2, 3]; + let mut packet = Vec::::with_capacity(header.size(payload.len())); + header.write(&mut packet, &payload).unwrap(); + packet + } + + #[cfg(feature = "mock-instant")] + fn update_timer_results_in_handshake(tun: &mut Tunn) { + let mut dst = vec![0u8; 2048]; + let result = tun.update_timers(&mut dst); + assert!(matches!(result, TunnResult::WriteToNetwork(_))); + let packet_data = if let TunnResult::WriteToNetwork(data) = result { + data + } else { + unreachable!(); + }; + let packet = Tunn::parse_incoming_packet(packet_data).unwrap(); + assert!(matches!(packet, Packet::HandshakeInit(_))); + } + + #[test] + fn create_two_tunnels_linked_to_eachother() { + let (_my_tun, _their_tun) = create_two_tuns(); + } + + #[test] + fn handshake_init() { + let (mut my_tun, _their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let packet = Tunn::parse_incoming_packet(&init).unwrap(); + assert!(matches!(packet, Packet::HandshakeInit(_))); + } + + #[test] + fn handshake_init_and_response() { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, &init); + let packet = Tunn::parse_incoming_packet(&resp).unwrap(); + assert!(matches!(packet, Packet::HandshakeResponse(_))); + } + + #[test] + fn full_handshake() { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, &init); + let keepalive = parse_handshake_resp(&mut my_tun, &resp); + let packet = Tunn::parse_incoming_packet(&keepalive).unwrap(); + assert!(matches!(packet, Packet::PacketData(_))); + } + + #[test] + fn full_handshake_plus_timers() { + let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); + // Time has not yet advanced so their is nothing to do + assert!(matches!(my_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + } + + #[test] + #[cfg(feature = "mock-instant")] + fn new_handshake_after_two_mins() { + let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); + let mut my_dst = [0u8; 1024]; + + // Advance time 1 second and "send" 1 packet so that we send a handshake + // after the timeout + mock_instant::MockClock::advance(Duration::from_secs(1)); + assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + assert!(matches!( + my_tun.update_timers(&mut my_dst), + TunnResult::Done + )); + let sent_packet_buf = create_ipv4_udp_packet(); + let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); + assert!(matches!(data, TunnResult::WriteToNetwork(_))); + + //Advance to timeout + mock_instant::MockClock::advance(REKEY_AFTER_TIME); + assert!(matches!(their_tun.update_timers(&mut []), TunnResult::Done)); + update_timer_results_in_handshake(&mut my_tun); + } + + #[test] + #[cfg(feature = "mock-instant")] + fn handshake_no_resp_rekey_timeout() { + let (mut my_tun, _their_tun) = create_two_tuns(); + + let init = create_handshake_init(&mut my_tun); + let packet = Tunn::parse_incoming_packet(&init).unwrap(); + assert!(matches!(packet, Packet::HandshakeInit(_))); + + mock_instant::MockClock::advance(REKEY_TIMEOUT); + update_timer_results_in_handshake(&mut my_tun) + } + + #[test] + fn one_ip_packet() { + let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake(); + let mut my_dst = [0u8; 1024]; + let mut their_dst = [0u8; 1024]; + + let sent_packet_buf = create_ipv4_udp_packet(); + + let data = my_tun.encapsulate(&sent_packet_buf, &mut my_dst); + assert!(matches!(data, TunnResult::WriteToNetwork(_))); + let data = if let TunnResult::WriteToNetwork(sent) = data { + sent + } else { + unreachable!(); + }; + + let data = their_tun.decapsulate(None, data, &mut their_dst); + assert!(matches!(data, TunnResult::WriteToTunnelV4(..))); + let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data { + recv + } else { + unreachable!(); + }; + assert_eq!(sent_packet_buf, recv_packet_buf); + } +} diff --git a/lib/boringtun/src/noise/rate_limiter.rs b/lib/boringtun/src/noise/rate_limiter.rs new file mode 100644 index 0000000..e0c5530 --- /dev/null +++ b/lib/boringtun/src/noise/rate_limiter.rs @@ -0,0 +1,193 @@ +use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24}; +use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1}; +use crate::noise::{HandshakeInit, HandshakeResponse, Packet, Tunn, TunnResult, WireGuardError}; + +#[cfg(feature = "mock-instant")] +use mock_instant::Instant; +use std::net::IpAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; + +#[cfg(not(feature = "mock-instant"))] +use crate::sleepyinstant::Instant; + +use aead::generic_array::GenericArray; +use aead::{AeadInPlace, KeyInit}; +use chacha20poly1305::{Key, XChaCha20Poly1305}; +use parking_lot::Mutex; +use rand_core::{OsRng, RngCore}; +use ring::constant_time::verify_slices_are_equal; + +const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division +const COOKIE_SIZE: usize = 16; +const COOKIE_NONCE_SIZE: usize = 24; + +/// How often should reset count in seconds +const RESET_PERIOD: u64 = 1; + +type Cookie = [u8; COOKIE_SIZE]; + +/// There are two places where WireGuard requires "randomness" for cookies +/// * The 24 byte nonce in the cookie massage - here the only goal is to avoid nonce reuse +/// * A secret value that changes every two minutes +/// Because the main goal of the cookie is simply for a party to prove ownership of an IP address +/// we can relax the randomness definition a bit, in order to avoid locking, because using less +/// resources is the main goal of any DoS prevention mechanism. +/// In order to avoid locking and calls to rand we derive pseudo random values using the AEAD and +/// some counters. +pub struct RateLimiter { + /// The key we use to derive the nonce + nonce_key: [u8; 32], + /// The key we use to derive the cookie + secret_key: [u8; 16], + start_time: Instant, + /// A single 64 bit counter (should suffice for many years) + nonce_ctr: AtomicUsize, + mac1_key: [u8; 32], + cookie_key: Key, + limit: usize, + /// The counter since last reset + count: AtomicUsize, + /// The time last reset was performed on this rate limiter + last_reset: Mutex, +} + +impl RateLimiter { + pub fn new(public_key: &crate::x25519::PublicKey, limit: u64) -> Self { + let mut secret_key = [0u8; 16]; + OsRng.fill_bytes(&mut secret_key); + RateLimiter { + nonce_key: Self::rand_bytes(), + secret_key, + start_time: Instant::now(), + nonce_ctr: AtomicUsize::new(0), + mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()), + cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(), + limit: limit as _, + count: AtomicUsize::new(0), + last_reset: Mutex::new(Instant::now()), + } + } + + fn rand_bytes() -> [u8; 32] { + let mut key = [0u8; 32]; + OsRng.fill_bytes(&mut key); + key + } + + /// Reset packet count (ideally should be called with a period of 1 second) + pub fn reset_count(&self) { + // The rate limiter is not very accurate, but at the scale we care about it doesn't matter much + let current_time = Instant::now(); + let mut last_reset_time = self.last_reset.lock(); + if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD { + self.count.store(0, Ordering::SeqCst); + *last_reset_time = current_time; + } + } + + /// Compute the correct cookie value based on the current secret value and the source IP + fn current_cookie(&self, addr: IpAddr) -> Cookie { + let mut addr_bytes = [0u8; 16]; + + match addr { + IpAddr::V4(a) => addr_bytes[..4].copy_from_slice(&a.octets()[..]), + IpAddr::V6(a) => addr_bytes[..].copy_from_slice(&a.octets()[..]), + } + + // The current cookie for a given IP is the MAC(responder.changing_secret_every_two_minutes, initiator.ip_address) + // First we derive the secret from the current time, the value of cur_counter would change with time. + let cur_counter = Instant::now().duration_since(self.start_time).as_secs() / COOKIE_REFRESH; + + // Next we derive the cookie + b2s_keyed_mac_16_2(&self.secret_key, &cur_counter.to_le_bytes(), &addr_bytes) + } + + fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] { + let ctr = self.nonce_ctr.fetch_add(1, Ordering::Relaxed); + + b2s_mac_24(&self.nonce_key, &ctr.to_le_bytes()) + } + + fn is_under_load(&self) -> bool { + self.count.fetch_add(1, Ordering::SeqCst) >= self.limit + } + + pub(crate) fn format_cookie_reply<'a>( + &self, + idx: u32, + cookie: Cookie, + mac1: &[u8], + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + if dst.len() < super::COOKIE_REPLY_SZ { + return Err(WireGuardError::DestinationBufferTooSmall); + } + + let (message_type, rest) = dst.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (nonce, rest) = rest.split_at_mut(24); + let (encrypted_cookie, _) = rest.split_at_mut(16 + 16); + + // msg.message_type = 3 + // msg.reserved_zero = { 0, 0, 0 } + message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes()); + // msg.receiver_index = little_endian(initiator.sender_index) + receiver_index.copy_from_slice(&idx.to_le_bytes()); + nonce.copy_from_slice(&self.nonce()[..]); + + let cipher = XChaCha20Poly1305::new(&self.cookie_key); + + let iv = GenericArray::from_slice(nonce); + + encrypted_cookie[..16].copy_from_slice(&cookie); + let tag = cipher + .encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16]) + .map_err(|_| WireGuardError::DestinationBufferTooSmall)?; + + encrypted_cookie[16..].copy_from_slice(&tag); + + Ok(&mut dst[..super::COOKIE_REPLY_SZ]) + } + + /// Verify the MAC fields on the datagram, and apply rate limiting if needed + pub fn verify_packet<'a, 'b>( + &self, + src_addr: Option, + src: &'a [u8], + dst: &'b mut [u8], + ) -> Result, TunnResult<'b>> { + let packet = Tunn::parse_incoming_packet(src)?; + + // Verify and rate limit handshake messages only + if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. }) + | Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet + { + let (msg, macs) = src.split_at(src.len() - 32); + let (mac1, mac2) = macs.split_at(16); + + let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg); + verify_slices_are_equal(&computed_mac1[..16], mac1) + .map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?; + + if self.is_under_load() { + let addr = match src_addr { + None => return Err(TunnResult::Err(WireGuardError::UnderLoad)), + Some(addr) => addr, + }; + + // Only given an address can we validate mac2 + let cookie = self.current_cookie(addr); + let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1); + + if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() { + let cookie_packet = self + .format_cookie_reply(sender_idx, cookie, mac1, dst) + .map_err(TunnResult::Err)?; + return Err(TunnResult::WriteToNetwork(cookie_packet)); + } + } + } + + Ok(packet) + } +} diff --git a/lib/boringtun/src/noise/session.rs b/lib/boringtun/src/noise/session.rs new file mode 100644 index 0000000..0d05b95 --- /dev/null +++ b/lib/boringtun/src/noise/session.rs @@ -0,0 +1,329 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::PacketData; +use crate::noise::errors::WireGuardError; +use parking_lot::Mutex; +use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +pub struct Session { + pub(crate) receiving_index: u32, + sending_index: u32, + receiver: LessSafeKey, + sender: LessSafeKey, + sending_key_counter: AtomicUsize, + receiving_key_counter: Mutex, +} + +impl std::fmt::Debug for Session { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Session: {}<- ->{}", + self.receiving_index, self.sending_index + ) + } +} + +/// Where encrypted data resides in a data packet +const DATA_OFFSET: usize = 16; +/// The overhead of the AEAD +const AEAD_SIZE: usize = 16; + +// Receiving buffer constants +const WORD_SIZE: u64 = 64; +const N_WORDS: u64 = 16; // Suffice to reorder 64*16 = 1024 packets; can be increased at will +const N_BITS: u64 = WORD_SIZE * N_WORDS; + +#[derive(Debug, Clone, Default)] +struct ReceivingKeyCounterValidator { + /// In order to avoid replays while allowing for some reordering of the packets, we keep a + /// bitmap of received packets, and the value of the highest counter + next: u64, + /// Used to estimate packet loss + receive_cnt: u64, + bitmap: [u64; N_WORDS as usize], +} + +impl ReceivingKeyCounterValidator { + #[inline(always)] + fn set_bit(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + self.bitmap[word] |= 1 << bit; + } + + #[inline(always)] + fn clear_bit(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + self.bitmap[word] &= !(1u64 << bit); + } + + /// Clear the word that contains idx + #[inline(always)] + fn clear_word(&mut self, idx: u64) { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + self.bitmap[word] = 0; + } + + /// Returns true if bit is set, false otherwise + #[inline(always)] + fn check_bit(&self, idx: u64) -> bool { + let bit_idx = idx % N_BITS; + let word = (bit_idx / WORD_SIZE) as usize; + let bit = (bit_idx % WORD_SIZE) as usize; + ((self.bitmap[word] >> bit) & 1) == 1 + } + + /// Returns true if the counter was not yet received, and is not too far back + #[inline(always)] + fn will_accept(&self, counter: u64) -> Result<(), WireGuardError> { + if counter >= self.next { + // As long as the counter is growing no replay took place for sure + return Ok(()); + } + if counter + N_BITS < self.next { + // Drop if too far back + return Err(WireGuardError::InvalidCounter); + } + if !self.check_bit(counter) { + Ok(()) + } else { + Err(WireGuardError::DuplicateCounter) + } + } + + /// Marks the counter as received, and returns true if it is still good (in case during + /// decryption something changed) + #[inline(always)] + fn mark_did_receive(&mut self, counter: u64) -> Result<(), WireGuardError> { + if counter + N_BITS < self.next { + // Drop if too far back + return Err(WireGuardError::InvalidCounter); + } + if counter == self.next { + // Usually the packets arrive in order, in that case we simply mark the bit and + // increment the counter + self.set_bit(counter); + self.next += 1; + return Ok(()); + } + if counter < self.next { + // A packet arrived out of order, check if it is valid, and mark + if self.check_bit(counter) { + return Err(WireGuardError::InvalidCounter); + } + self.set_bit(counter); + return Ok(()); + } + // Packets where dropped, or maybe reordered, skip them and mark unused + if counter - self.next >= N_BITS { + // Too far ahead, clear all the bits + for c in self.bitmap.iter_mut() { + *c = 0; + } + } else { + let mut i = self.next; + while i % WORD_SIZE != 0 && i < counter { + // Clear until i aligned to word size + self.clear_bit(i); + i += 1; + } + while i + WORD_SIZE < counter { + // Clear whole word at a time + self.clear_word(i); + i = (i + WORD_SIZE) & 0u64.wrapping_sub(WORD_SIZE); + } + while i < counter { + // Clear any remaining bits + self.clear_bit(i); + i += 1; + } + } + self.set_bit(counter); + self.next = counter + 1; + Ok(()) + } +} + +impl Session { + pub(super) fn new( + local_index: u32, + peer_index: u32, + receiving_key: [u8; 32], + sending_key: [u8; 32], + ) -> Session { + Session { + receiving_index: local_index, + sending_index: peer_index, + receiver: LessSafeKey::new( + UnboundKey::new(&CHACHA20_POLY1305, &receiving_key).unwrap(), + ), + sender: LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &sending_key).unwrap()), + sending_key_counter: AtomicUsize::new(0), + receiving_key_counter: Mutex::new(Default::default()), + } + } + + pub(super) fn local_index(&self) -> usize { + self.receiving_index as usize + } + + /// Returns true if receiving counter is good to use + fn receiving_counter_quick_check(&self, counter: u64) -> Result<(), WireGuardError> { + let counter_validator = self.receiving_key_counter.lock(); + counter_validator.will_accept(counter) + } + + /// Returns true if receiving counter is good to use, and marks it as used { + fn receiving_counter_mark(&self, counter: u64) -> Result<(), WireGuardError> { + let mut counter_validator = self.receiving_key_counter.lock(); + let ret = counter_validator.mark_did_receive(counter); + if ret.is_ok() { + counter_validator.receive_cnt += 1; + } + ret + } + + /// src - an IP packet from the interface + /// dst - pre-allocated space to hold the encapsulating UDP packet to send over the network + /// returns the size of the formatted packet + pub(super) fn format_packet_data<'a>(&self, src: &[u8], dst: &'a mut [u8]) -> &'a mut [u8] { + if dst.len() < src.len() + super::DATA_OVERHEAD_SZ { + panic!("The destination buffer is too small"); + } + + let sending_key_counter = self.sending_key_counter.fetch_add(1, Ordering::Relaxed) as u64; + + let (message_type, rest) = dst.split_at_mut(4); + let (receiver_index, rest) = rest.split_at_mut(4); + let (counter, data) = rest.split_at_mut(8); + + message_type.copy_from_slice(&super::DATA.to_le_bytes()); + receiver_index.copy_from_slice(&self.sending_index.to_le_bytes()); + counter.copy_from_slice(&sending_key_counter.to_le_bytes()); + + // TODO: spec requires padding to 16 bytes, but actually works fine without it + let n = { + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&sending_key_counter.to_le_bytes()); + data[..src.len()].copy_from_slice(src); + self.sender + .seal_in_place_separate_tag( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + &mut data[..src.len()], + ) + .map(|tag| { + data[src.len()..src.len() + AEAD_SIZE].copy_from_slice(tag.as_ref()); + src.len() + AEAD_SIZE + }) + .unwrap() + }; + + &mut dst[..DATA_OFFSET + n] + } + + /// packet - a data packet we received from the network + /// dst - pre-allocated space to hold the encapsulated IP packet, to send to the interface + /// dst will always take less space than src + /// return the size of the encapsulated packet on success + pub(super) fn receive_packet_data<'a>( + &self, + packet: PacketData, + dst: &'a mut [u8], + ) -> Result<&'a mut [u8], WireGuardError> { + let ct_len = packet.encrypted_encapsulated_packet.len(); + if dst.len() < ct_len { + // This is a very incorrect use of the library, therefore panic and not error + panic!("The destination buffer is too small"); + } + if packet.receiver_idx != self.receiving_index { + return Err(WireGuardError::WrongIndex); + } + // Don't reuse counters, in case this is a replay attack we want to quickly check the counter without running expensive decryption + self.receiving_counter_quick_check(packet.counter)?; + + let ret = { + let mut nonce = [0u8; 12]; + nonce[4..12].copy_from_slice(&packet.counter.to_le_bytes()); + dst[..ct_len].copy_from_slice(packet.encrypted_encapsulated_packet); + self.receiver + .open_in_place( + Nonce::assume_unique_for_key(nonce), + Aad::from(&[]), + &mut dst[..ct_len], + ) + .map_err(|_| WireGuardError::InvalidAeadTag)? + }; + + // After decryption is done, check counter again, and mark as received + self.receiving_counter_mark(packet.counter)?; + Ok(ret) + } + + /// Returns the estimated downstream packet loss for this session + pub(super) fn current_packet_cnt(&self) -> (u64, u64) { + let counter_validator = self.receiving_key_counter.lock(); + (counter_validator.next, counter_validator.receive_cnt) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_replay_counter() { + let mut c: ReceivingKeyCounterValidator = Default::default(); + + assert!(c.mark_did_receive(0).is_ok()); + assert!(c.mark_did_receive(0).is_err()); + assert!(c.mark_did_receive(1).is_ok()); + assert!(c.mark_did_receive(1).is_err()); + assert!(c.mark_did_receive(63).is_ok()); + assert!(c.mark_did_receive(63).is_err()); + assert!(c.mark_did_receive(15).is_ok()); + assert!(c.mark_did_receive(15).is_err()); + + for i in 64..N_BITS + 128 { + assert!(c.mark_did_receive(i).is_ok()); + assert!(c.mark_did_receive(i).is_err()); + } + + assert!(c.mark_did_receive(N_BITS * 3).is_ok()); + for i in 0..=N_BITS * 2 { + assert!(matches!( + c.will_accept(i), + Err(WireGuardError::InvalidCounter) + )); + assert!(c.mark_did_receive(i).is_err()); + } + for i in N_BITS * 2 + 1..N_BITS * 3 { + assert!(c.will_accept(i).is_ok()); + } + assert!(matches!( + c.will_accept(N_BITS * 3), + Err(WireGuardError::DuplicateCounter) + )); + + for i in (N_BITS * 2 + 1..N_BITS * 3).rev() { + assert!(c.mark_did_receive(i).is_ok()); + assert!(c.mark_did_receive(i).is_err()); + } + + assert!(c.mark_did_receive(N_BITS * 3 + 70).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 71).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 72).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 72 + 125).is_ok()); + assert!(c.mark_did_receive(N_BITS * 3 + 63).is_ok()); + + assert!(c.mark_did_receive(N_BITS * 3 + 70).is_err()); + assert!(c.mark_did_receive(N_BITS * 3 + 71).is_err()); + assert!(c.mark_did_receive(N_BITS * 3 + 72).is_err()); + } +} diff --git a/lib/boringtun/src/noise/timers.rs b/lib/boringtun/src/noise/timers.rs new file mode 100644 index 0000000..6b91d57 --- /dev/null +++ b/lib/boringtun/src/noise/timers.rs @@ -0,0 +1,335 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::errors::WireGuardError; +use crate::noise::{Tunn, TunnResult}; +use std::mem; +use std::ops::{Index, IndexMut}; + +use std::time::Duration; + +#[cfg(feature = "mock-instant")] +use mock_instant::Instant; + +#[cfg(not(feature = "mock-instant"))] +use crate::sleepyinstant::Instant; + +// Some constants, represent time in seconds +// https://www.wireguard.com/papers/wireguard.pdf#page=14 +pub(crate) const REKEY_AFTER_TIME: Duration = Duration::from_secs(120); +const REJECT_AFTER_TIME: Duration = Duration::from_secs(180); +const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90); +pub(crate) const REKEY_TIMEOUT: Duration = Duration::from_secs(5); +const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); +const COOKIE_EXPIRATION_TIME: Duration = Duration::from_secs(120); + +#[derive(Debug)] +pub enum TimerName { + /// Current time, updated each call to `update_timers` + TimeCurrent, + /// Time when last handshake was completed + TimeSessionEstablished, + /// Time the last attempt for a new handshake began + TimeLastHandshakeStarted, + /// Time we last received and authenticated a packet + TimeLastPacketReceived, + /// Time we last send a packet + TimeLastPacketSent, + /// Time we last received and authenticated a DATA packet + TimeLastDataPacketReceived, + /// Time we last send a DATA packet + TimeLastDataPacketSent, + /// Time we last received a cookie + TimeCookieReceived, + /// Time we last sent persistent keepalive + TimePersistentKeepalive, + Top, +} + +use self::TimerName::*; + +#[derive(Debug)] +pub struct Timers { + /// Is the owner of the timer the initiator or the responder for the last handshake? + is_initiator: bool, + /// Start time of the tunnel + time_started: Instant, + timers: [Duration; TimerName::Top as usize], + pub(super) session_timers: [Duration; super::N_SESSIONS], + /// Did we receive data without sending anything back? + want_keepalive: bool, + /// Did we send data without hearing back? + want_handshake: bool, + persistent_keepalive: usize, + /// Should this timer call reset rr function (if not a shared rr instance) + pub(super) should_reset_rr: bool, +} + +impl Timers { + pub(super) fn new(persistent_keepalive: Option, reset_rr: bool) -> Timers { + Timers { + is_initiator: false, + time_started: Instant::now(), + timers: Default::default(), + session_timers: Default::default(), + want_keepalive: Default::default(), + want_handshake: Default::default(), + persistent_keepalive: usize::from(persistent_keepalive.unwrap_or(0)), + should_reset_rr: reset_rr, + } + } + + fn is_initiator(&self) -> bool { + self.is_initiator + } + + // We don't really clear the timers, but we set them to the current time to + // so the reference time frame is the same + pub(super) fn clear(&mut self) { + let now = Instant::now().duration_since(self.time_started); + for t in &mut self.timers[..] { + *t = now; + } + self.want_handshake = false; + self.want_keepalive = false; + } +} + +impl Index for Timers { + type Output = Duration; + fn index(&self, index: TimerName) -> &Duration { + &self.timers[index as usize] + } +} + +impl IndexMut for Timers { + fn index_mut(&mut self, index: TimerName) -> &mut Duration { + &mut self.timers[index as usize] + } +} + +impl Tunn { + pub(super) fn timer_tick(&mut self, timer_name: TimerName) { + match timer_name { + TimeLastPacketReceived => { + self.timers.want_keepalive = true; + self.timers.want_handshake = false; + } + TimeLastPacketSent => { + self.timers.want_handshake = true; + self.timers.want_keepalive = false; + } + _ => {} + } + + let time = self.timers[TimeCurrent]; + self.timers[timer_name] = time; + } + + pub(super) fn timer_tick_session_established( + &mut self, + is_initiator: bool, + session_idx: usize, + ) { + self.timer_tick(TimeSessionEstablished); + self.timers.session_timers[session_idx % crate::noise::N_SESSIONS] = + self.timers[TimeCurrent]; + self.timers.is_initiator = is_initiator; + } + + // We don't really clear the timers, but we set them to the current time to + // so the reference time frame is the same + fn clear_all(&mut self) { + for session in &mut self.sessions { + *session = None; + } + + self.packet_queue.clear(); + + self.timers.clear(); + } + + fn update_session_timers(&mut self, time_now: Duration) { + let timers = &mut self.timers; + + for (i, t) in timers.session_timers.iter_mut().enumerate() { + if time_now - *t > REJECT_AFTER_TIME { + if let Some(session) = self.sessions[i].take() { + tracing::debug!( + message = "SESSION_EXPIRED(REJECT_AFTER_TIME)", + session = session.receiving_index + ); + } + *t = time_now; + } + } + } + + pub fn update_timers<'a>(&mut self, dst: &'a mut [u8]) -> TunnResult<'a> { + let mut handshake_initiation_required = false; + let mut keepalive_required = false; + + let time = Instant::now(); + + if self.timers.should_reset_rr { + self.rate_limiter.reset_count(); + } + + // All the times are counted from tunnel initiation, for efficiency our timers are rounded + // to a second, as there is no real benefit to having highly accurate timers. + let now = time.duration_since(self.timers.time_started); + self.timers[TimeCurrent] = now; + + self.update_session_timers(now); + + // Load timers only once: + let session_established = self.timers[TimeSessionEstablished]; + let handshake_started = self.timers[TimeLastHandshakeStarted]; + let aut_packet_received = self.timers[TimeLastPacketReceived]; + let aut_packet_sent = self.timers[TimeLastPacketSent]; + let data_packet_received = self.timers[TimeLastDataPacketReceived]; + let data_packet_sent = self.timers[TimeLastDataPacketSent]; + let persistent_keepalive = self.timers.persistent_keepalive; + + { + if self.handshake.is_expired() { + return TunnResult::Err(WireGuardError::ConnectionExpired); + } + + // Clear cookie after COOKIE_EXPIRATION_TIME + if self.handshake.has_cookie() + && now - self.timers[TimeCookieReceived] >= COOKIE_EXPIRATION_TIME + { + self.handshake.clear_cookie(); + } + + // All ephemeral private keys and symmetric session keys are zeroed out after + // (REJECT_AFTER_TIME * 3) ms if no new keys have been exchanged. + if now - session_established >= REJECT_AFTER_TIME * 3 { + tracing::error!("CONNECTION_EXPIRED(REJECT_AFTER_TIME * 3)"); + self.handshake.set_expired(); + self.clear_all(); + return TunnResult::Err(WireGuardError::ConnectionExpired); + } + + if let Some(time_init_sent) = self.handshake.timer() { + // Handshake Initiation Retransmission + if now - handshake_started >= REKEY_ATTEMPT_TIME { + // After REKEY_ATTEMPT_TIME ms of trying to initiate a new handshake, + // the retries give up and cease, and clear all existing packets queued + // up to be sent. If a packet is explicitly queued up to be sent, then + // this timer is reset. + tracing::error!("CONNECTION_EXPIRED(REKEY_ATTEMPT_TIME)"); + self.handshake.set_expired(); + self.clear_all(); + return TunnResult::Err(WireGuardError::ConnectionExpired); + } + + if time_init_sent.elapsed() >= REKEY_TIMEOUT { + // We avoid using `time` here, because it can be earlier than `time_init_sent`. + // Once `checked_duration_since` is stable we can use that. + // A handshake initiation is retried after REKEY_TIMEOUT + jitter ms, + // if a response has not been received, where jitter is some random + // value between 0 and 333 ms. + tracing::warn!("HANDSHAKE(REKEY_TIMEOUT)"); + handshake_initiation_required = true; + } + } else { + if self.timers.is_initiator() { + // After sending a packet, if the sender was the original initiator + // of the handshake and if the current session key is REKEY_AFTER_TIME + // ms old, we initiate a new handshake. If the sender was the original + // responder of the handshake, it does not re-initiate a new handshake + // after REKEY_AFTER_TIME ms like the original initiator does. + if session_established < data_packet_sent + && now - session_established >= REKEY_AFTER_TIME + { + tracing::debug!("HANDSHAKE(REKEY_AFTER_TIME (on send))"); + handshake_initiation_required = true; + } + + // After receiving a packet, if the receiver was the original initiator + // of the handshake and if the current session key is REJECT_AFTER_TIME + // - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT ms old, we initiate a new + // handshake. + if session_established < data_packet_received + && now - session_established + >= REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT + { + tracing::warn!( + "HANDSHAKE(REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - \ + REKEY_TIMEOUT \ + (on receive))" + ); + handshake_initiation_required = true; + } + } + + // If we have sent a packet to a given peer but have not received a + // packet after from that peer for (KEEPALIVE + REKEY_TIMEOUT) ms, + // we initiate a new handshake. + if data_packet_sent > aut_packet_received + && now - aut_packet_received >= KEEPALIVE_TIMEOUT + REKEY_TIMEOUT + && mem::replace(&mut self.timers.want_handshake, false) + { + tracing::warn!("HANDSHAKE(KEEPALIVE + REKEY_TIMEOUT)"); + handshake_initiation_required = true; + } + + if !handshake_initiation_required { + // If a packet has been received from a given peer, but we have not sent one back + // to the given peer in KEEPALIVE ms, we send an empty packet. + if data_packet_received > aut_packet_sent + && now - aut_packet_sent >= KEEPALIVE_TIMEOUT + && mem::replace(&mut self.timers.want_keepalive, false) + { + tracing::debug!("KEEPALIVE(KEEPALIVE_TIMEOUT)"); + keepalive_required = true; + } + + // Persistent KEEPALIVE + if persistent_keepalive > 0 + && (now - self.timers[TimePersistentKeepalive] + >= Duration::from_secs(persistent_keepalive as _)) + { + tracing::debug!("KEEPALIVE(PERSISTENT_KEEPALIVE)"); + self.timer_tick(TimePersistentKeepalive); + keepalive_required = true; + } + } + } + } + + if handshake_initiation_required { + return self.format_handshake_initiation(dst, true); + } + + if keepalive_required { + return self.encapsulate(&[], dst); + } + + TunnResult::Done + } + + pub fn time_since_last_handshake(&self) -> Option { + let current_session = self.current; + if self.sessions[current_session % super::N_SESSIONS].is_some() { + let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started); + let duration_since_session_established = self.timers[TimeSessionEstablished]; + + Some(duration_since_tun_start - duration_since_session_established) + } else { + None + } + } + + pub fn persistent_keepalive(&self) -> Option { + let keepalive = self.timers.persistent_keepalive; + + if keepalive > 0 { + Some(keepalive as u16) + } else { + None + } + } +} diff --git a/lib/boringtun/src/serialization.rs b/lib/boringtun/src/serialization.rs new file mode 100644 index 0000000..e6920f8 --- /dev/null +++ b/lib/boringtun/src/serialization.rs @@ -0,0 +1,33 @@ +pub(crate) struct KeyBytes(pub [u8; 32]); + +impl std::str::FromStr for KeyBytes { + type Err = &'static str; + + /// Can parse a secret key from a hex or base64 encoded string. + fn from_str(s: &str) -> Result { + let mut internal = [0u8; 32]; + + match s.len() { + 64 => { + // Try to parse as hex + for i in 0..32 { + internal[i] = u8::from_str_radix(&s[i * 2..=i * 2 + 1], 16) + .map_err(|_| "Illegal character in key")?; + } + } + 43 | 44 => { + // Try to parse as base64 + if let Ok(decoded_key) = base64::decode(s) { + if decoded_key.len() == internal.len() { + internal[..].copy_from_slice(&decoded_key); + } else { + return Err("Illegal character in key"); + } + } + } + _ => return Err("Illegal key size"), + } + + Ok(KeyBytes(internal)) + } +} diff --git a/lib/boringtun/src/sleepyinstant/mod.rs b/lib/boringtun/src/sleepyinstant/mod.rs new file mode 100644 index 0000000..542beea --- /dev/null +++ b/lib/boringtun/src/sleepyinstant/mod.rs @@ -0,0 +1,77 @@ +#![forbid(unsafe_code)] +//! Attempts to provide the same functionality as std::time::Instant, except it +//! uses a timer which accounts for time when the system is asleep +use std::time::Duration; + +#[cfg(target_os = "windows")] +mod windows; +#[cfg(target_os = "windows")] +use windows as inner; + +#[cfg(unix)] +mod unix; +#[cfg(unix)] +use unix as inner; + +/// A measurement of a monotonically nondecreasing clock. +/// Opaque and useful only with [`Duration`]. +/// +/// Instants are always guaranteed, barring [platform bugs], to be no less than any previously +/// measured instant when created, and are often useful for tasks such as measuring +/// benchmarks or timing how long an operation takes. +/// +/// Note, however, that instants are **not** guaranteed to be **steady**. In other +/// words, each tick of the underlying clock might not be the same length (e.g. +/// some seconds may be longer than others). An instant may jump forwards or +/// experience time dilation (slow down or speed up), but it will never go +/// backwards. +/// +/// Instants are opaque types that can only be compared to one another. There is +/// no method to get "the number of seconds" from an instant. Instead, it only +/// allows measuring the duration between two instants (or comparing two +/// instants). +/// +/// The size of an `Instant` struct may vary depending on the target operating +/// system. +/// +#[derive(Clone, Copy, Debug)] +pub struct Instant { + t: inner::Instant, +} + +impl Instant { + /// Returns an instant corresponding to "now". + pub fn now() -> Self { + Self { + t: inner::Instant::now(), + } + } + + /// Returns the amount of time elapsed from another instant to this one, + /// or zero duration if that instant is later than this one. + /// + /// # Panics + /// + /// panics when `earlier` was later than `self`. + pub fn duration_since(&self, earlier: Instant) -> Duration { + self.t.duration_since(earlier.t) + } + + /// Returns the amount of time elapsed since this instant was created. + pub fn elapsed(&self) -> Duration { + Self::now().duration_since(*self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn time_increments_after_sleep() { + let sleep_time = Duration::from_millis(10); + let start = Instant::now(); + std::thread::sleep(sleep_time); + assert!(start.elapsed() >= sleep_time); + } +} diff --git a/lib/boringtun/src/sleepyinstant/unix.rs b/lib/boringtun/src/sleepyinstant/unix.rs new file mode 100644 index 0000000..8488c2f --- /dev/null +++ b/lib/boringtun/src/sleepyinstant/unix.rs @@ -0,0 +1,48 @@ +use std::time::Duration; + +use nix::sys::time::TimeSpec; +use nix::time::{clock_gettime, ClockId}; + +#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))] +const CLOCK_ID: ClockId = ClockId::CLOCK_MONOTONIC; +#[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "tvos")))] +const CLOCK_ID: ClockId = ClockId::CLOCK_BOOTTIME; + +#[derive(Clone, Copy, Debug)] +pub(crate) struct Instant { + t: TimeSpec, +} + +impl Instant { + pub(crate) fn now() -> Self { + // std::time::Instant unwraps as well, so feel safe doing so here + let t = clock_gettime(CLOCK_ID).unwrap(); + Self { t } + } + + fn checked_duration_since(&self, earlier: Instant) -> Option { + const NANOSECOND: nix::libc::c_long = 1_000_000_000; + let (tv_sec, tv_nsec) = if self.t.tv_nsec() < earlier.t.tv_nsec() { + ( + self.t.tv_sec() - earlier.t.tv_sec() - 1, + self.t.tv_nsec() - earlier.t.tv_nsec() + NANOSECOND, + ) + } else { + ( + self.t.tv_sec() - earlier.t.tv_sec(), + self.t.tv_nsec() - earlier.t.tv_nsec(), + ) + }; + + if tv_sec < 0 { + None + } else { + Some(Duration::new(tv_sec as _, tv_nsec as _)) + } + } + + pub(crate) fn duration_since(&self, earlier: Instant) -> Duration { + self.checked_duration_since(earlier) + .unwrap_or(Duration::ZERO) + } +} diff --git a/lib/boringtun/src/sleepyinstant/windows.rs b/lib/boringtun/src/sleepyinstant/windows.rs new file mode 100644 index 0000000..ac85229 --- /dev/null +++ b/lib/boringtun/src/sleepyinstant/windows.rs @@ -0,0 +1 @@ +pub(crate) use std::time::Instant; diff --git a/lib/boringtun/src/wireguard_ffi.h b/lib/boringtun/src/wireguard_ffi.h new file mode 100644 index 0000000..5cdd901 --- /dev/null +++ b/lib/boringtun/src/wireguard_ffi.h @@ -0,0 +1,106 @@ +// Copyright (c) 2019 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +#pragma once + +#include +#include + +struct wireguard_tunnel; // This corresponds to the Rust type + +enum +{ + MAX_WIREGUARD_PACKET_SIZE = 65536 + 64, +}; + +enum result_type +{ + WIREGUARD_DONE = 0, + WRITE_TO_NETWORK = 1, + WIREGUARD_ERROR = 2, + WRITE_TO_TUNNEL_IPV4 = 4, + WRITE_TO_TUNNEL_IPV6 = 6, +}; + +struct wireguard_result +{ + enum result_type op; + size_t size; +}; + +struct stats +{ + int64_t time_since_last_handshake; + size_t tx_bytes; + size_t rx_bytes; + float estimated_loss; + int32_t estimated_rtt; // rtt estimated on time it took to complete latest initiated handshake in ms + uint8_t reserved[56]; // decrement appropriately when adding new fields +}; + +struct x25519_key +{ + uint8_t key[32]; +}; + +// Generates a fresh x25519 secret key +struct x25519_key x25519_secret_key(); +// Computes an x25519 public key from a secret key +struct x25519_key x25519_public_key(struct x25519_key private_key); +// Encodes a public or private x25519 key to base64. Must be freed with x25519_key_to_str_free. +const char *x25519_key_to_base64(struct x25519_key key); +// Encodes a public or private x25519 key to hex. Must be freed with x25519_key_to_str_free. +const char *x25519_key_to_hex(struct x25519_key key); +// Free string pointer obtained from either x25519_key_to_base64 or x25519_key_to_hex +void x25519_key_to_str_free(const char *key_str); +// Check if a null terminated string represents a valid x25519 key +// Returns 0 if not +int check_base64_encoded_x25519_key(const char *key); + +/// Sets the default tracing_subscriber to write to `log_func`. +/// +/// Uses Compact format without level, target, thread ids, thread names, or ansi control characters. +/// Subscribes to TRACE level events. +/// +/// This function should only be called once as setting the default tracing_subscriber +/// more than once will result in an error. +/// +/// Returns false on failure. +/// +/// # Safety +/// +/// `c_char` will be freed by the library after calling `log_func`. If the value needs +/// to be stored then `log_func` needs to create a copy, e.g. `strcpy`. +bool set_logging_function(void (*log_func)(const char *)); + +// Allocate a new tunnel +struct wireguard_tunnel *new_tunnel(const char *static_private, + const char *server_static_public, + const char *preshared_key, + uint16_t keep_alive, // Keep alive interval in seconds + uint32_t index); // The 24bit index prefix to be used for session indexes + +// Deallocate the tunnel +void tunnel_free(struct wireguard_tunnel *); + +struct wireguard_result wireguard_write(const struct wireguard_tunnel *tunnel, + const uint8_t *src, + uint32_t src_size, + uint8_t *dst, + uint32_t dst_size); + +struct wireguard_result wireguard_read(const struct wireguard_tunnel *tunnel, + const uint8_t *src, + uint32_t src_size, + uint8_t *dst, + uint32_t dst_size); + +struct wireguard_result wireguard_tick(const struct wireguard_tunnel *tunnel, + uint8_t *dst, + uint32_t dst_size); + +struct wireguard_result wireguard_force_handshake(const struct wireguard_tunnel *tunnel, + uint8_t *dst, + uint32_t dst_size); + +struct stats wireguard_stats(const struct wireguard_tunnel *tunnel); diff --git a/proto/message.proto b/proto/message.proto index b0073d6..197c884 100644 --- a/proto/message.proto +++ b/proto/message.proto @@ -24,6 +24,7 @@ message RegistrationRequest { fixed32 virtual_ip = 6; bool allow_ip_change = 7; bool client_secret = 8; + bytes client_secret_hash = 9; } message RegistrationResponse { @@ -41,6 +42,8 @@ message DeviceInfo { fixed32 virtual_ip = 2; uint32 device_status = 3; bool client_secret = 4; + bytes client_secret_hash = 5; + bool wireguard = 6; } message DeviceList { diff --git a/src/cipher/rsa_cipher.rs b/src/cipher/rsa_cipher.rs index 4052fce..ba7762b 100644 --- a/src/cipher/rsa_cipher.rs +++ b/src/cipher/rsa_cipher.rs @@ -89,7 +89,7 @@ impl RsaCipher { }) } pub fn finger_(public_key_der: &[u8]) -> io::Result { - match rsa::pkcs8::SubjectPublicKeyInfo::from_der(public_key_der) { + match spki::SubjectPublicKeyInfoOwned::from_der(public_key_der) { Ok(spki) => match spki.fingerprint_base64() { Ok(finger) => Ok(finger), Err(e) => Err(io::Error::new( @@ -120,7 +120,7 @@ impl RsaCipher { match self .inner .private_key - .decrypt(rsa::PaddingScheme::PKCS1v15Encrypt, net_packet.payload()) + .decrypt(rsa::Pkcs1v15Encrypt, net_packet.payload()) { Ok(rs) => { let mut nonce_raw = [0; 12]; diff --git a/src/core/entity/mod.rs b/src/core/entity/mod.rs index 8ec0ccc..742885d 100644 --- a/src/core/entity/mod.rs +++ b/src/core/entity/mod.rs @@ -1,10 +1,23 @@ -use chrono::{DateTime, Local}; use std::collections::HashMap; use std::net::{Ipv4Addr, SocketAddr}; + +use chrono::{DateTime, Local}; use tokio::sync::mpsc::Sender; +#[derive(Clone, Debug)] +pub struct WireGuardConfig { + pub vnts_endpoint: String, + pub vnts_allowed_ips: String, + pub group_id: String, + pub device_id: String, + pub ip: Ipv4Addr, + pub prefix: u8, + pub persistent_keepalive: u16, + pub secret_key: [u8; 32], + pub public_key: [u8; 32], +} /// 网段信息 -#[derive(Default)] +#[derive(Default, Debug)] pub struct NetworkInfo { // 组网编号 // pub group: String, @@ -33,6 +46,7 @@ impl NetworkInfo { } /// 客户端信息 +#[derive(Debug)] pub struct ClientInfo { // 设备ID pub device_id: String, @@ -42,6 +56,8 @@ pub struct ClientInfo { pub name: String, // 客户端间是否加密 pub client_secret: bool, + // 加密hash + pub client_secret_hash: Vec, // 和服务端是否加密 pub server_secret: bool, // 链接服务器的来源地址 @@ -52,11 +68,51 @@ pub struct ClientInfo { pub virtual_ip: u32, // 建立的tcp连接发送端 pub tcp_sender: Option>>, + // wireguard客户端公钥 + pub wireguard: Option<[u8; 32]>, + pub wg_sender: Option, Ipv4Addr)>>, pub client_status: Option, pub last_join_time: DateTime, pub timestamp: i64, } - +/// 客户端简要信息 +#[derive(Debug)] +pub struct SimpleClientInfo { + // 分配的ip + pub virtual_ip: u32, + // 版本 + pub version: String, + // 名称 + pub name: String, + // 客户端间是否加密 + pub client_secret: bool, + // 加密hash + pub client_secret_hash: Vec, + // 和服务端是否加密 + pub server_secret: bool, + // 是否在线 + pub online: bool, + // 是wg客户端 + pub wireguard: bool, +} +impl From<&ClientInfo> for SimpleClientInfo { + fn from(value: &ClientInfo) -> Self { + Self { + virtual_ip: value.virtual_ip, + version: value.version.clone(), + name: value.name.clone(), + client_secret: value.client_secret, + client_secret_hash: if value.online { + value.client_secret_hash.clone() + } else { + vec![] + }, + server_secret: value.server_secret, + online: value.online, + wireguard: value.wireguard.is_some(), + } + } +} impl Default for ClientInfo { fn default() -> Self { Self { @@ -64,18 +120,21 @@ impl Default for ClientInfo { version: "".to_string(), name: "".to_string(), client_secret: false, + client_secret_hash: vec![], server_secret: false, address: "0.0.0.0:0".parse().unwrap(), online: false, virtual_ip: 0, tcp_sender: None, + wireguard: None, + wg_sender: None, client_status: None, last_join_time: Local::now(), timestamp: 0, } } } - +#[derive(Debug)] pub struct ClientStatusInfo { pub p2p_list: Vec, pub up_stream: u64, diff --git a/src/core/server/mod.rs b/src/core/server/mod.rs index 8ef90fe..90286b5 100644 --- a/src/core/server/mod.rs +++ b/src/core/server/mod.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use tokio::net::{TcpListener, UdpSocket}; use crate::cipher::RsaCipher; +use crate::core::server::wire_guard::WireGuardGroup; use crate::core::service::PacketHandler; use crate::core::store::cache::AppCache; use crate::ConfigInfo; @@ -12,6 +13,8 @@ mod tcp; mod udp; #[cfg(feature = "web")] mod web; +mod websocket; +mod wire_guard; pub async fn start( udp: std::net::UdpSocket, @@ -28,8 +31,9 @@ pub async fn start( rsa_cipher.clone(), udp.clone(), ); + let wg = WireGuardGroup::new(cache.clone(), config.clone(), udp.clone()); let tcp_handle = tokio::spawn(tcp::start(TcpListener::from_std(tcp)?, handler.clone())); - let udp_handle = tokio::spawn(udp::start(udp, handler.clone())); + let udp_handle = tokio::spawn(udp::start(udp, handler.clone(), wg)); #[cfg(not(feature = "web"))] let _ = tokio::try_join!(tcp_handle, udp_handle); #[cfg(feature = "web")] diff --git a/src/core/server/tcp.rs b/src/core/server/tcp.rs index 7f8c610..f66b607 100644 --- a/src/core/server/tcp.rs +++ b/src/core/server/tcp.rs @@ -1,4 +1,5 @@ use crate::core::service::PacketHandler; +use crate::core::store::cache::VntContext; use crate::protocol::NetPacket; use std::io; use std::net::SocketAddr; @@ -7,6 +8,8 @@ use tokio::net::tcp::OwnedReadHalf; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::{channel, Sender}; +const TCP_MAX_PACKET_SIZE: usize = (1 << 24) - 1; + pub async fn start(tcp: TcpListener, handler: PacketHandler) { if let Err(e) = accept(tcp, handler).await { log::error!("accept {:?}", e); @@ -17,24 +20,47 @@ async fn accept(tcp: TcpListener, handler: PacketHandler) -> io::Result<()> { loop { let (stream, addr) = tcp.accept().await?; let _ = stream.set_nodelay(true); - stream_handle(stream, addr, handler.clone()).await; + tokio::spawn(stream_handle(stream, addr, handler.clone())); } } async fn stream_handle(stream: TcpStream, addr: SocketAddr, handler: PacketHandler) { + { + let mut buf = [0u8; 1]; + match stream.peek(&mut buf).await { + Ok(len) => { + if len == 0 { + log::warn!("数据流读取失败 {}", addr); + return; + } + if buf[0] != 0 { + //可能是ws协议 + crate::core::server::websocket::handle_websocket_connection( + stream, addr, handler, + ) + .await; + return; + } + } + Err(e) => { + log::warn!("数据流读取失败 {:?} {}", e, addr); + return; + } + } + } + let (r, mut w) = stream.into_split(); let (sender, mut receiver) = channel::>(100); tokio::spawn(async move { while let Some(data) = receiver.recv().await { let len = data.len(); + if len > TCP_MAX_PACKET_SIZE { + log::warn!("超过了tcp的最大长度传输 地址{}", addr); + return; + } if let Err(e) = w - .write_all(&[ - (len >> 24) as u8, - (len >> 16) as u8, - (len >> 8) as u8, - len as u8, - ]) + .write_all(&[0, (len >> 16) as u8, (len >> 8) as u8, len as u8]) .await { log::info!("发送失败,链接终止:{:?},{:?}", addr, e); @@ -48,27 +74,36 @@ async fn stream_handle(stream: TcpStream, addr: SocketAddr, handler: PacketHandl let _ = w.shutdown().await; }); tokio::spawn(async move { - if let Err(e) = tcp_read(r, addr, sender, handler).await { + let mut context = VntContext { + link_context: None, + server_cipher: None, + link_address: addr, + }; + if let Err(e) = tcp_read(&mut context, r, addr, sender, &handler).await { log::warn!("tcp_read {:?}", e) } + handler.leave(context).await; }); } async fn tcp_read( + context: &mut VntContext, mut read: OwnedReadHalf, addr: SocketAddr, sender: Sender>, - handler: PacketHandler, + handler: &PacketHandler, ) -> io::Result<()> { let mut head = [0; 4]; let mut buf = [0; 65536]; let sender = Some(sender); + loop { read.read_exact(&mut head).await?; - let len = ((head[0] as usize) << 24) - | ((head[1] as usize) << 16) - | ((head[2] as usize) << 8) - | head[3] as usize; + if head[0] != 0 { + log::warn!("tcp数据流错误 来源地址 {}", addr); + return Ok(()); + } + let len = ((head[1] as usize) << 16) | ((head[2] as usize) << 8) | head[3] as usize; if len < 12 || len > buf.len() { return Err(io::Error::new( io::ErrorKind::InvalidData, @@ -77,7 +112,7 @@ async fn tcp_read( } read.read_exact(&mut buf[..len]).await?; let packet = NetPacket::new0(len, &mut buf)?; - if let Some(rs) = handler.handle(packet, addr, &sender).await { + if let Some(rs) = handler.handle(context, packet, addr, &sender).await { if sender .as_ref() .unwrap() diff --git a/src/core/server/udp.rs b/src/core/server/udp.rs index 402d21b..39d935b 100644 --- a/src/core/server/udp.rs +++ b/src/core/server/udp.rs @@ -1,21 +1,91 @@ +use crate::core::server::wire_guard::WireGuardGroup; +use crate::core::service::PacketHandler; +use crate::core::store::cache::VntContext; +use crate::protocol::NetPacket; +use parking_lot::Mutex; +use std::collections::HashMap; +use std::net::SocketAddr; use std::sync::Arc; - +use std::time::Duration; use tokio::net::UdpSocket; +use tokio::sync::mpsc::{channel, Sender}; -use crate::core::service::PacketHandler; -use crate::protocol::NetPacket; +pub async fn start(main_udp: Arc, handler: PacketHandler, mut wg: WireGuardGroup) { + let mut udp_group = UdpGroup::new(main_udp.clone(), handler); + let mut buf = [0u8; 65536]; -pub async fn start(main_udp: Arc, handler: PacketHandler) { loop { - let mut buf = vec![0u8; 65536]; match main_udp.recv_from(&mut buf).await { Ok((len, addr)) => { - let handler = handler.clone(); - let udp = main_udp.clone(); - tokio::spawn(async move { - match NetPacket::new(&mut buf[..len]) { + if len == 0 { + log::warn!("UnexpectedEof {}", addr); + continue; + } + let buf = buf[..len].to_vec(); + if WireGuardGroup::maybe_wg(&buf) { + // 可能是wg协议 + wg.handle(buf, addr); + continue; + } + if let Err(e) = udp_group.handle(buf, addr) { + log::error!("{} {:?}", addr, e); + } + } + #[cfg(windows)] + Err(ref e) if e.kind() == std::io::ErrorKind::ConnectionReset => { + // 忽略 ConnectionReset 错误 + } + Err(e) => { + log::error!("{:?}", e) + } + } + } +} + +pub struct UdpGroup { + data_channel_map: Arc>>>>, + udp: Arc, + handler: PacketHandler, +} + +impl UdpGroup { + pub fn new(udp: Arc, handler: PacketHandler) -> Self { + Self { + data_channel_map: Default::default(), + udp, + handler, + } + } + pub fn handle(&mut self, buf: Vec, addr: SocketAddr) -> anyhow::Result<()> { + if let Some(sender) = self.data_channel_map.lock().get(&addr) { + sender.try_send(buf)?; + return Ok(()); + } + let (udp_sender, mut udp_receiver) = channel(64); + udp_sender.try_send(buf)?; + let data_channel_map = self.data_channel_map.clone(); + data_channel_map.lock().insert(addr, udp_sender); + let handler = self.handler.clone(); + let udp = self.udp.clone(); + tokio::spawn(async move { + let mut context = VntContext { + link_context: None, + server_cipher: None, + link_address: addr, + }; + loop { + let data = match tokio::time::timeout(Duration::from_secs(60), udp_receiver.recv()) + .await + { + Ok(data) => data, + Err(_) => break, + }; + if let Some(data) = data { + match NetPacket::new(data) { Ok(net_packet) => { - if let Some(rs) = handler.handle(net_packet, addr, &None).await { + if let Some(rs) = + handler.handle(&mut context, net_packet, addr, &None).await + { if let Err(e) = udp.send_to(rs.buffer(), addr).await { log::error!("{:?} {}", e, addr) } @@ -25,11 +95,13 @@ pub async fn start(main_udp: Arc, handler: PacketHandler) { log::error!("{:?} {}", e, addr) } } - }); + } else { + break; + } } - Err(e) => { - log::error!("{:?}", e) - } - } + handler.leave(context).await; + data_channel_map.lock().remove(&addr); + }); + Ok(()) } } diff --git a/src/core/server/web/mod.rs b/src/core/server/web/mod.rs index e7b870e..8926bbb 100644 --- a/src/core/server/web/mod.rs +++ b/src/core/server/web/mod.rs @@ -1,6 +1,5 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::net; -use std::sync::Arc; use actix_web::dev::Service; use actix_web::web::Data; @@ -9,7 +8,9 @@ use actix_web::{middleware, post, web, App, HttpRequest, HttpResponse, HttpServe use actix_web_static_files::ResourceFiles; use crate::core::server::web::service::VntsWebService; -use crate::core::server::web::vo::{LoginData, ResponseMessage}; +use crate::core::server::web::vo::req::{CreateWGData, LoginData, RemoveClientReq}; + +use crate::core::server::web::vo::ResponseMessage; use crate::core::store::cache::AppCache; use crate::ConfigInfo; @@ -18,7 +19,7 @@ mod vo; include!(concat!(env!("OUT_DIR"), "/generated.rs")); -#[post("/login")] +#[post("/api/login")] async fn login(service: Data, data: web::Json) -> HttpResponse { match service.login(data.0).await { Ok(auth) => HttpResponse::Ok().json(ResponseMessage::success(auth)), @@ -26,13 +27,37 @@ async fn login(service: Data, data: web::Json) -> Htt } } -#[post("/group_list")] +#[post("/api/group_list")] async fn group_list(_req: HttpRequest, service: Data) -> HttpResponse { let info = service.group_list(); HttpResponse::Ok().json(ResponseMessage::success(info)) } - -#[post("/group_info")] +#[post("/api/remove_client")] +async fn remove_client( + _req: HttpRequest, + service: Data, + data: web::Json, +) -> HttpResponse { + service.remove_client(data.0); + HttpResponse::Ok().json(ResponseMessage::success("success")) +} +#[post("/api/private_key")] +async fn private_key(_req: HttpRequest, service: Data) -> HttpResponse { + let private_key = service.gen_wg_private_key(); + HttpResponse::Ok().json(ResponseMessage::success(private_key)) +} +#[post("/api/create_wg_config")] +async fn create_wg_config( + _req: HttpRequest, + service: Data, + data: web::Json, +) -> HttpResponse { + match service.create_wg_config(data.0).await { + Ok(wg_config) => HttpResponse::Ok().json(ResponseMessage::success(wg_config)), + Err(e) => HttpResponse::Ok().json(ResponseMessage::fail(e.to_string())), + } +} +#[post("/api/group_info")] async fn group_info( _req: HttpRequest, service: Data, @@ -46,36 +71,19 @@ async fn group_info( } } -#[derive(Clone)] -struct AuthApi { - api_set: Arc>, -} - -fn auth_api_set() -> AuthApi { - let mut api_set = HashSet::new(); - api_set.insert("/group_info".to_string()); - api_set.insert("/group_list".to_string()); - AuthApi { - api_set: Arc::new(api_set), - } -} - pub async fn start( lst: net::TcpListener, cache: AppCache, config: ConfigInfo, ) -> std::io::Result<()> { let web_service = VntsWebService::new(cache, config); - let auth_api = auth_api_set(); HttpServer::new(move || { let generated = generate(); App::new() .app_data(Data::new(web_service.clone())) - .app_data(Data::new(auth_api.clone())) .wrap_fn(|request, srv| { - let auth_api: &Data = request.app_data().unwrap(); let path = request.path(); - if path == "/login" || !auth_api.api_set.contains(path) { + if path == "/api/login" || !path.contains("/api/") { return srv.call(request); } let service: &Data = request.app_data().unwrap(); @@ -96,6 +104,9 @@ pub async fn start( }) .wrap(middleware::Compress::default()) .service(login) + .service(remove_client) + .service(private_key) + .service(create_wg_config) .service(group_list) .service(group_info) .service(ResourceFiles::new("/", generated)) diff --git a/src/core/server/web/service/mod.rs b/src/core/server/web/service/mod.rs index 9b00d39..0525667 100644 --- a/src/core/server/web/service/mod.rs +++ b/src/core/server/web/service/mod.rs @@ -1,11 +1,22 @@ -use crossbeam_utils::atomic::AtomicCell; -use std::net::{SocketAddr, SocketAddrV4}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs}; +use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant}; -use crate::core::server::web::vo::{ - ClientInfo, ClientStatusInfo, GroupList, LoginData, NetworkInfo, +use anyhow::{anyhow, Context}; +use base64::engine::general_purpose; +use base64::Engine; +use crossbeam_utils::atomic::AtomicCell; +use ipnetwork::Ipv4Network; +use rsa::rand_core::RngCore; + +use crate::core::entity::WireGuardConfig; + +use crate::core::server::web::vo::req::{CreateWGData, CreateWgConfig, LoginData, RemoveClientReq}; +use crate::core::server::web::vo::res::{ + ClientInfo, ClientStatusInfo, GroupList, NetworkInfo, WGData, WgConfig, }; +use crate::core::service::server::{generate_ip, RegisterClientRequest}; use crate::core::store::cache::AppCache; use crate::ConfigInfo; @@ -60,6 +71,120 @@ impl VntsWebService { .collect(); GroupList { group_list } } + pub fn remove_client(&self, req: RemoveClientReq) { + if let Some(ip) = req.virtual_ip { + if let Some(network_info) = self.cache.virtual_network.get(&req.group_id) { + if let Some(client_info) = network_info.write().clients.remove(&ip.into()) { + if let Some(key) = client_info.wireguard { + self.cache.wg_group_map.remove(&key); + } + } + } + } else { + if let Some(network_info) = self.cache.virtual_network.remove(&req.group_id) { + for (_, client_info) in network_info.write().clients.drain() { + if let Some(key) = client_info.wireguard { + self.cache.wg_group_map.remove(&key); + } + } + } + } + } + pub fn gen_wg_private_key(&self) -> String { + let mut bytes = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut bytes); + return general_purpose::STANDARD.encode(bytes); + } + pub async fn create_wg_config(&self, wg_data: CreateWGData) -> anyhow::Result { + let device_id = wg_data.device_id.trim().to_string(); + let group_id = wg_data.group_id.trim().to_string(); + if group_id.is_empty() { + Err(anyhow!("组网id不能为空"))?; + } + if device_id.is_empty() { + Err(anyhow!("设备id不能为空"))?; + } + let cache = &self.cache; + let (secret_key, public_key) = Self::check_wg_config(&wg_data.config)?; + let gateway = self.config.gateway; + let netmask = self.config.netmask; + let network = Ipv4Network::with_netmask(gateway, netmask)?; + let network = Ipv4Network::with_netmask(network.network(), netmask)?; + let virtual_ip = if wg_data.virtual_ip.trim().is_empty() { + Ipv4Addr::UNSPECIFIED + } else { + Ipv4Addr::from_str(&wg_data.virtual_ip).context("虚拟IP错误")? + }; + let register_client_request = RegisterClientRequest { + group_id: group_id.clone(), + virtual_ip, + gateway, + netmask, + allow_ip_change: false, + device_id: device_id.clone(), + version: String::from("wg"), + name: wg_data.name.clone(), + client_secret: true, + client_secret_hash: vec![], + server_secret: true, + address: SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), + tcp_sender: None, + online: false, + wireguard: Some(public_key), + }; + let response = generate_ip(cache, register_client_request).await?; + let wireguard_config = WireGuardConfig { + vnts_endpoint: wg_data.config.vnts_endpoint.clone(), + vnts_allowed_ips: network.to_string(), + group_id: group_id.clone(), + device_id: device_id.clone(), + ip: response.virtual_ip, + prefix: network.prefix(), + persistent_keepalive: wg_data.config.persistent_keepalive, + secret_key, + public_key, + }; + cache.wg_group_map.insert(public_key, wireguard_config); + let config = WgConfig { + vnts_endpoint: wg_data.config.vnts_endpoint, + vnts_public_key: general_purpose::STANDARD.encode(&self.config.wg_public_key), + vnts_allowed_ips: network.to_string(), + public_key: general_purpose::STANDARD.encode(public_key), + private_key: general_purpose::STANDARD.encode(secret_key), + ip: response.virtual_ip, + prefix: network.prefix(), + persistent_keepalive: wg_data.config.persistent_keepalive, + }; + let wg_data = WGData { + group_id, + virtual_ip: response.virtual_ip, + device_id, + name: wg_data.name, + config, + }; + Ok(wg_data) + } + fn check_wg_config(config: &CreateWgConfig) -> anyhow::Result<([u8; 32], [u8; 32])> { + match config.vnts_endpoint.to_socket_addrs() { + Ok(mut addr) => { + if let Some(addr) = addr.next() { + if addr.ip().is_unspecified() || addr.port() == 0 { + Err(anyhow!("服务端地址错误"))? + } + } + } + Err(e) => Err(anyhow!("服务端地址解析失败:{}", e))?, + } + + let private_key = general_purpose::STANDARD + .decode(&config.private_key) + .context("私钥错误")?; + let private_key: [u8; 32] = private_key.try_into().map_err(|_| anyhow!("私钥错误"))?; + let secret_key = boringtun::x25519::StaticSecret::from(private_key); + let public_key = *boringtun::x25519::PublicKey::from(&secret_key).as_bytes(); + + Ok((private_key, public_key)) + } pub fn group_info(&self, group: String) -> Option { if let Some(info) = self.cache.virtual_network.get(&group) { let guard = info.read(); @@ -67,19 +192,20 @@ impl VntsWebService { guard.network_ip.into(), guard.mask_ip.into(), guard.gateway_ip.into(), + general_purpose::STANDARD.encode(&self.config.wg_public_key), ); - for into in guard.clients.values() { - let address = match into.address { - SocketAddr::V4(_) => into.address, + for info in guard.clients.values() { + let address = match info.address { + SocketAddr::V4(_) => info.address, SocketAddr::V6(ipv6) => { if let Some(ipv4) = ipv6.ip().to_ipv4_mapped() { SocketAddr::V4(SocketAddrV4::new(ipv4, ipv6.port())) } else { - into.address + info.address } } }; - let status_info = if let Some(client_status) = &into.client_status { + let status_info = if let Some(client_status) = &info.client_status { Some(ClientStatusInfo { p2p_list: client_status.p2p_list.clone(), up_stream: client_status.up_stream, @@ -93,18 +219,24 @@ impl VntsWebService { } else { None }; - + let mut wg_config = None; + if let Some(key) = &info.wireguard { + if let Some(v) = self.cache.wg_group_map.get(key) { + wg_config.replace(v.clone()); + } + } let client_info = ClientInfo { - device_id: into.device_id.clone(), - version: into.version.clone(), - name: into.name.clone(), - client_secret: into.client_secret, - server_secret: into.server_secret, + device_id: info.device_id.clone(), + version: info.version.clone(), + name: info.name.clone(), + client_secret: info.client_secret, + server_secret: info.server_secret, address, - online: into.online, - virtual_ip: into.virtual_ip.into(), + online: info.online, + virtual_ip: info.virtual_ip.into(), status_info, - last_join_time: into.last_join_time.format("%Y-%m-%d %H:%M:%S").to_string(), + last_join_time: info.last_join_time.format("%Y-%m-%d %H:%M:%S").to_string(), + wg_config: wg_config.map(|v| v.into()), }; network.clients.push(client_info); } @@ -116,32 +248,4 @@ impl VntsWebService { None } } - // pub fn groups_info(&self) -> GroupsInfo { - // let mut data = GroupsInfo::new(); - // for (group, info) in self.cache.virtual_network.key_values() { - // let guard = info.read(); - // let mut network = NetworkInfo::new( - // guard.network_ip.into(), - // guard.mask_ip.into(), - // guard.gateway_ip.into(), - // ); - // for (_ip, into) in guard.clients.iter() { - // let client_info = ClientInfo { - // device_id: into.device_id.clone(), - // name: into.name.clone(), - // client_secret: into.client_secret, - // server_secret: into.server_secret.is_some(), - // address: into.address, - // online: into.online, - // virtual_ip: into.virtual_ip.into(), - // }; - // network.clients.push(client_info); - // } - // network - // .clients - // .sort_by(|v1, v2| v1.virtual_ip.cmp(&v2.virtual_ip)); - // data.data.insert(group.to_string(), network); - // } - // data - // } } diff --git a/src/core/server/web/vo/mod.rs b/src/core/server/web/vo/mod.rs index 72630cb..8d9fcc4 100644 --- a/src/core/server/web/vo/mod.rs +++ b/src/core/server/web/vo/mod.rs @@ -1,8 +1,7 @@ -use std::collections::HashMap; -use std::net::{Ipv4Addr, SocketAddr}; - use serde::{Deserialize, Serialize}; +pub mod req; +pub mod res; #[derive(Debug, Serialize, Deserialize)] pub struct ResponseMessage { data: V, @@ -38,73 +37,3 @@ impl ResponseMessage> { } } } - -#[derive(Debug, Serialize, Deserialize)] -pub struct ClientInfo { - // 设备ID - pub device_id: String, - // 客户端版本 - pub version: String, - // 名称 - pub name: String, - // 客户端间是否加密 - pub client_secret: bool, - // 客户端和服务端是否加密 - pub server_secret: bool, - // 链接服务器的来源地址 - pub address: SocketAddr, - // 是否在线 - pub online: bool, - // 分配的ip - pub virtual_ip: Ipv4Addr, - pub status_info: Option, - pub last_join_time: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ClientStatusInfo { - pub p2p_list: Vec, - pub up_stream: u64, - pub down_stream: u64, - pub is_cone: bool, - pub update_time: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct NetworkInfo { - // 网段 - pub network_ip: Ipv4Addr, - // 掩码 - pub mask_ip: Ipv4Addr, - // 网关 - pub gateway_ip: Ipv4Addr, - // 网段下的客户端列表 - pub clients: Vec, -} - -impl NetworkInfo { - pub fn new(network_ip: Ipv4Addr, mask_ip: Ipv4Addr, gateway_ip: Ipv4Addr) -> Self { - Self { - network_ip, - mask_ip, - gateway_ip, - clients: Default::default(), - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct GroupList { - pub group_list: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct GroupsInfo { - pub data: HashMap, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct LoginData { - pub username: String, - pub password: String, -} diff --git a/src/core/server/web/vo/req.rs b/src/core/server/web/vo/req.rs new file mode 100644 index 0000000..721efff --- /dev/null +++ b/src/core/server/web/vo/req.rs @@ -0,0 +1,29 @@ +use serde::{Deserialize, Serialize}; +use std::net::Ipv4Addr; + +#[derive(Debug, Serialize, Deserialize)] +pub struct CreateWGData { + pub group_id: String, + pub virtual_ip: String, + pub device_id: String, + pub name: String, + pub config: CreateWgConfig, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CreateWgConfig { + pub vnts_endpoint: String, + pub private_key: String, + pub persistent_keepalive: u16, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginData { + pub username: String, + pub password: String, +} +#[derive(Debug, Serialize, Deserialize)] +pub struct RemoveClientReq { + pub group_id: String, + pub virtual_ip: Option, +} diff --git a/src/core/server/web/vo/res.rs b/src/core/server/web/vo/res.rs new file mode 100644 index 0000000..4d46dab --- /dev/null +++ b/src/core/server/web/vo/res.rs @@ -0,0 +1,123 @@ +use crate::core::entity::WireGuardConfig; +use base64::engine::general_purpose; +use base64::Engine; +use serde::{Deserialize, Serialize}; +use std::net::{Ipv4Addr, SocketAddr}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct WGData { + pub group_id: String, + pub virtual_ip: Ipv4Addr, + pub device_id: String, + pub name: String, + pub config: WgConfig, +} +#[derive(Debug, Serialize, Deserialize)] +pub struct WgConfig { + pub vnts_endpoint: String, + pub vnts_public_key: String, + pub vnts_allowed_ips: String, + + pub public_key: String, + pub private_key: String, + // 合一起是 Address = ip/prefix + pub ip: Ipv4Addr, + pub prefix: u8, + pub persistent_keepalive: u16, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ClientInfo { + // 设备ID + pub device_id: String, + // 客户端版本 + pub version: String, + // 名称 + pub name: String, + // 客户端间是否加密 + pub client_secret: bool, + // 客户端和服务端是否加密 + pub server_secret: bool, + // 链接服务器的来源地址 + pub address: SocketAddr, + // 是否在线 + pub online: bool, + // 分配的ip + pub virtual_ip: Ipv4Addr, + pub status_info: Option, + pub last_join_time: String, + // wg配置 + pub wg_config: Option, +} +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct WireGuardConfigRes { + pub vnts_endpoint: String, + pub vnts_allowed_ips: String, + pub group_id: String, + pub device_id: String, + pub ip: Ipv4Addr, + pub prefix: u8, + pub persistent_keepalive: u16, + pub secret_key: String, + pub public_key: String, +} +impl From for WireGuardConfigRes { + fn from(value: WireGuardConfig) -> Self { + Self { + vnts_endpoint: value.vnts_endpoint, + vnts_allowed_ips: value.vnts_allowed_ips, + group_id: value.group_id, + device_id: value.device_id, + ip: value.ip, + prefix: value.prefix, + persistent_keepalive: value.persistent_keepalive, + secret_key: general_purpose::STANDARD.encode(&value.secret_key), + public_key: general_purpose::STANDARD.encode(&value.public_key), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ClientStatusInfo { + pub p2p_list: Vec, + pub up_stream: u64, + pub down_stream: u64, + pub is_cone: bool, + pub update_time: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct NetworkInfo { + // 网段 + pub network_ip: Ipv4Addr, + // 掩码 + pub mask_ip: Ipv4Addr, + // 网关 + pub gateway_ip: Ipv4Addr, + // vnts的公钥 + pub vnts_public_key: String, + // 网段下的客户端列表 + pub clients: Vec, +} + +impl NetworkInfo { + pub fn new( + network_ip: Ipv4Addr, + mask_ip: Ipv4Addr, + gateway_ip: Ipv4Addr, + vnts_public_key: String, + ) -> Self { + Self { + network_ip, + mask_ip, + gateway_ip, + vnts_public_key, + clients: Default::default(), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GroupList { + pub group_list: Vec, +} diff --git a/src/core/server/websocket/mod.rs b/src/core/server/websocket/mod.rs new file mode 100644 index 0000000..40894ba --- /dev/null +++ b/src/core/server/websocket/mod.rs @@ -0,0 +1,78 @@ +use crate::core::service::PacketHandler; +use crate::core::store::cache::VntContext; +use crate::protocol::NetPacket; +use anyhow::Context; +use futures_util::{SinkExt, StreamExt}; +use std::net::SocketAddr; +use tokio::net::TcpStream; +use tokio::sync::mpsc::channel; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message; + +pub async fn handle_websocket_connection( + stream: TcpStream, + addr: SocketAddr, + handler: PacketHandler, +) { + tokio::spawn(async move { + let mut context = VntContext { + link_context: None, + server_cipher: None, + link_address: addr, + }; + if let Err(e) = handle_websocket_connection0(&mut context, stream, addr, &handler).await { + log::warn!("websocket err {:?} {}", e, addr); + } + handler.leave(context).await; + }); +} + +async fn handle_websocket_connection0( + context: &mut VntContext, + stream: TcpStream, + addr: SocketAddr, + handler: &PacketHandler, +) -> anyhow::Result<()> { + let ws_stream = accept_async(stream) + .await + .with_context(|| format!("Error during WebSocket handshake {}", addr))?; + + let (mut ws_write, mut ws_read) = ws_stream.split(); + + let (sender, mut receiver) = channel::>(100); + tokio::spawn(async move { + while let Some(data) = receiver.recv().await { + if let Err(e) = ws_write.send(Message::Binary(data)).await { + log::warn!("websocket err {:?} {}", e, addr); + break; + } + } + let _ = ws_write.close().await; + }); + let sender = Some(sender); + + while let Some(msg) = ws_read.next().await { + let msg = msg.with_context(|| format!("Error during WebSocket read {}", addr))?; + match msg { + Message::Text(txt) => log::info!("Received text message: {} {}", txt, addr), + Message::Binary(mut data) => { + let packet = NetPacket::new0(data.len(), &mut data)?; + if let Some(rs) = handler.handle(context, packet, addr, &sender).await { + if sender + .as_ref() + .unwrap() + .send(rs.buffer().to_vec()) + .await + .is_err() + { + break; + } + } + } + Message::Ping(_) | Message::Pong(_) => (), + Message::Close(_) => break, + _ => {} + } + } + return Ok(()); +} diff --git a/src/core/server/wire_guard/mod.rs b/src/core/server/wire_guard/mod.rs new file mode 100644 index 0000000..7d7b8d5 --- /dev/null +++ b/src/core/server/wire_guard/mod.rs @@ -0,0 +1,470 @@ +use crate::core::entity::{NetworkInfo, WireGuardConfig}; +use crate::core::store::cache::AppCache; +use crate::protocol::{ip_turn_packet, NetPacket, Protocol, HEAD_LEN, MAX_TTL}; +use crate::ConfigInfo; +use anyhow::{anyhow, Context}; +use boringtun::noise::errors::WireGuardError; +use boringtun::noise::{handshake, Packet, Tunn, TunnResult}; +use boringtun::x25519::StaticSecret; +use chrono::Local; +use packet::icmp::{icmp, Kind}; +use packet::ip::ipv4; +use packet::ip::ipv4::packet::IpV4Packet; +use parking_lot::{Mutex, RwLock}; +use rand::RngCore; +use std::collections::HashMap; +use std::net::{Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::time::Duration; +use tokio::net::UdpSocket; +use tokio::sync::mpsc::{channel, Receiver, Sender}; + +pub struct WireGuardGroup { + cache: AppCache, + config: ConfigInfo, + udp: Arc, + data_channel_map: Arc>>>>, +} + +impl WireGuardGroup { + pub fn new(cache: AppCache, config: ConfigInfo, udp: Arc) -> Self { + Self { + cache, + config, + udp, + data_channel_map: Default::default(), + } + } + pub fn handle(&mut self, buf: Vec, addr: SocketAddr) { + if let Err(e) = self.handle0(buf, addr) { + log::warn!("{},{}", addr, e); + } + } + fn handle0(&mut self, buf: Vec, addr: SocketAddr) -> anyhow::Result<()> { + if let Some(sender) = self.data_channel_map.lock().get(&addr) { + sender.try_send(buf)?; + return Ok(()); + } + let config = self.handshake(&buf)?; + let network_info = self + .cache + .virtual_network + .get(&config.group_id) + .context("wg配置已过期")?; + let (network_receiver, broadcast_ip, mask_ip, gateway_ip) = { + let mut guard = network_info.write(); + let broadcast_ip = guard.network_ip | (!guard.mask_ip); + + let client_info = guard + .clients + .get_mut(&config.ip.into()) + .context("wg配置已过期")?; + if client_info.wireguard.is_none() { + Err(anyhow!("不是wg配置"))?; + } + let (network_sender, network_receiver) = channel(64); + client_info.wg_sender = Some(network_sender); + client_info.last_join_time = Local::now(); + client_info.timestamp = client_info.last_join_time.timestamp(); + client_info.address = addr; + client_info.online = true; + guard.epoch += 1; + ( + network_receiver, + broadcast_ip, + guard.mask_ip, + guard.gateway_ip, + ) + }; + let wg = WireGuard::new( + network_info.clone(), + broadcast_ip.into(), + mask_ip.into(), + gateway_ip.into(), + self.cache.clone(), + self.config.wg_secret_key.clone(), + self.udp.clone(), + addr, + config, + self.data_channel_map.clone(), + ); + let (udp_sender, udp_receiver) = channel(64); + udp_sender.try_send(buf)?; + self.data_channel_map.lock().insert(addr, udp_sender); + tokio::spawn(wg.start(udp_receiver, network_receiver)); + Ok(()) + } + #[inline] + pub fn maybe_wg(buf: &[u8]) -> bool { + if buf.len() < 4 { + return false; + } + + // Checks the type, as well as the reserved zero fields + let packet_type = u32::from_le_bytes(buf[0..4].try_into().unwrap()); + (1..=4).contains(&packet_type) + } + pub fn handshake(&mut self, buf: &[u8]) -> anyhow::Result { + let packet = match Tunn::parse_incoming_packet(buf) { + Ok(packet) => packet, + Err(e) => Err(anyhow!("{:?}", e))?, + }; + match packet { + Packet::HandshakeInit(data) => { + let half_handshake = handshake::parse_handshake_anon( + &self.config.wg_secret_key, + &self.config.wg_public_key, + &data, + ) + .map_err(|e| anyhow!("HandshakeInit {:?}", e))?; + let config = self + .cache + .wg_group_map + .get(&half_handshake.peer_static_public) + .context("需要先在vnts配置wg信息")? + .clone(); + Ok(config) + } + _ => Err(anyhow!("非握手包")), + } + } +} + +pub struct WireGuard { + network_info: Arc>, + ip: Ipv4Addr, + broadcast_ip: Ipv4Addr, + mask_ip: Ipv4Addr, + gateway_ip: Ipv4Addr, + + group_id: String, + tunn: Tunn, + cache: AppCache, + wg_source_addr: SocketAddr, + udp: Arc, + data_channel_map: Arc>>>>, +} + +impl WireGuard { + pub fn new( + network_info: Arc>, + broadcast_ip: Ipv4Addr, + mask_ip: Ipv4Addr, + gateway_ip: Ipv4Addr, + cache: AppCache, + vnts_secret_key: StaticSecret, + udp: Arc, + wg_source_addr: SocketAddr, + config: WireGuardConfig, + data_channel_map: Arc>>>>, + ) -> Self { + let tunn = Tunn::new( + vnts_secret_key, + config.public_key.into(), + None, + Some(config.persistent_keepalive), + rand::thread_rng().next_u32(), + None, + ); + Self { + network_info, + ip: config.ip, + broadcast_ip, + mask_ip, + gateway_ip, + group_id: config.group_id, + tunn, + cache, + wg_source_addr, + udp, + data_channel_map, + } + } + pub async fn start( + mut self, + udp_receiver: Receiver>, + ipv4_receiver: Receiver<(Vec, Ipv4Addr)>, + ) { + if let Err(e) = self.start0(udp_receiver, ipv4_receiver).await { + log::warn!( + "wg连接异常断开 {:?},{:?},{:?},{:?}", + self.group_id, + self.ip, + self.wg_source_addr, + e + ); + } + self.offline(); + } + fn offline(&self) { + if let Some(v) = self.cache.virtual_network.get(&self.group_id) { + if let Some(v) = v.write().clients.get_mut(&self.ip.into()) { + if v.address == self.wg_source_addr { + v.online = false; + v.wg_sender = None; + } + } + } + self.data_channel_map.lock().remove(&self.wg_source_addr); + } + pub async fn start0( + &mut self, + mut udp_receiver: Receiver>, + mut ipv4_receiver: Receiver<(Vec, Ipv4Addr)>, + ) -> anyhow::Result<()> { + let mut interval = tokio::time::interval(Duration::from_millis(200)); + let mut dst_buf = [0; 65535]; + let mut dst_buf2 = [0; 65535]; + log::info!( + "处理wg链接 {},{}/{},{}", + self.group_id, + self.ip, + self.mask_ip, + self.wg_source_addr + ); + loop { + tokio::select! { + rs = udp_receiver.recv()=>{ + if let Some(mut data) = rs{ + self.handle_wg_data(&mut data,&mut dst_buf,&mut dst_buf2).await?; + }else{ + break; + } + } + rs = ipv4_receiver.recv()=>{ + if let Some((data,ip)) = rs{ + if let Err(e) = self.handle_ipv4_data(&data,&mut dst_buf).await{ + log::warn!("来源{},发送到wg失败,{:?}",ip,e) + } + }else{ + break; + } + } + _ = interval.tick()=>{ + self.update_timers(&mut dst_buf,&mut dst_buf2).await? + } + } + } + Ok(()) + } + pub async fn handle_ipv4_data(&mut self, buf: &[u8], dst_buf: &mut [u8]) -> anyhow::Result<()> { + let result = self.tunn.encapsulate(buf, dst_buf); + match result { + TunnResult::Done => {} + TunnResult::WriteToNetwork(data) => { + self.udp.send_to(data, self.wg_source_addr).await?; + } + e => Err(anyhow!("{:?}", e))?, + } + Ok(()) + } + + pub async fn handle_wg_data( + &mut self, + mut buf: &mut [u8], + dst_buf: &mut [u8], + dst_buf2: &mut [u8], + ) -> anyhow::Result<()> { + loop { + let mut result = self.tunn.decapsulate(None, buf, dst_buf); + if !self.handle_tunn_result(&mut result, dst_buf2).await? { + break; + } + buf = &mut []; + } + + Ok(()) + } + async fn handle_tunn_result( + &mut self, + result: &mut TunnResult<'_>, + dst_buf: &mut [u8], + ) -> anyhow::Result { + match result { + TunnResult::Done => {} + TunnResult::Err(WireGuardError::ConnectionExpired) => { + // 超时了直接断开,vnts不重连,等对端重连 + return Err(anyhow!("链接超时")); + } + TunnResult::Err(e) => { + log::warn!("WireGuard数据异常 {:?}", e); + } + TunnResult::WriteToNetwork(data) => { + self.udp.send_to(data, self.wg_source_addr).await?; + return Ok(true); + } + TunnResult::WriteToTunnelV4(data, _source_ip) => { + let mut packet = IpV4Packet::new(data)?; + let source_ip = packet.source_ip(); + let destination_ip = packet.destination_ip(); + if let Err(e) = self + .turn_data(source_ip, destination_ip, &mut packet.buffer, dst_buf) + .await + { + log::warn!("wg数据转发失败 {}->{} {:?}", source_ip, destination_ip, e); + } + } + TunnResult::WriteToTunnelV6(_packet, ip) => { + return Err(anyhow!("不支持ipv6连接 {:?}", ip)) + } + } + Ok(false) + } + /// from 'wireguard_tick': + /// This is a state keeping function, that need to be called periodically. + /// Recommended interval: 100ms. + pub async fn update_timers( + &mut self, + dst_buf: &mut [u8], + dst_buf2: &mut [u8], + ) -> anyhow::Result<()> { + let mut result = self.tunn.update_timers(dst_buf); + self.handle_tunn_result(&mut result, dst_buf2).await?; + Ok(()) + } + async fn turn_data( + &mut self, + src_ip: Ipv4Addr, + dest_ip: Ipv4Addr, + data: &mut [u8], + dst_buf: &mut [u8], + ) -> anyhow::Result<()> { + if dest_ip == self.gateway_ip { + if self.ping(data, src_ip, dest_ip).is_ok() { + if let Err(e) = self.handle_ipv4_data(&data, dst_buf).await { + log::warn!("发送ping回应到wg失败,{:?}", e) + } + } + return Ok(()); + } + if dest_ip.is_broadcast() || dest_ip == self.broadcast_ip { + // 广播 + let x: Vec<_> = self + .network_info + .read() + .clients + .values() + .filter(|v| v.online && v.virtual_ip != u32::from(self.ip)) + .map(|v| { + ( + v.address, + v.tcp_sender.clone(), + v.server_secret, + v.wg_sender.clone(), + ) + }) + .collect(); + for (peer_addr, peer_tcp_sender, server_secret, peer_wg_sender) in x { + if let Err(e) = self + .send_one( + peer_addr, + peer_tcp_sender, + peer_wg_sender, + server_secret, + src_ip, + dest_ip, + data, + dst_buf, + ) + .await + { + log::warn!("wg广播失败 {} {} {:?}", src_ip, peer_addr, e); + } + } + return Ok(()); + } + + let (server_secret, peer_addr, peer_tcp_sender, peer_wg_sender) = { + let guard = self.network_info.read(); + if let Some(dest_client_info) = guard.clients.get(&dest_ip.into()) { + if !dest_client_info.online { + Err(anyhow!("目标不在线"))? + } + if !dest_client_info.virtual_ip == u32::from(self.ip) { + Err(anyhow!("阻止回路"))? + } + let dest_link_addr = dest_client_info.address; + let server_secret = dest_client_info.server_secret; + ( + server_secret, + dest_link_addr, + dest_client_info.tcp_sender.clone(), + dest_client_info.wg_sender.clone(), + ) + } else { + Err(anyhow!("目标未注册"))? + } + }; + + self.send_one( + peer_addr, + peer_tcp_sender, + peer_wg_sender, + server_secret, + src_ip, + dest_ip, + data, + dst_buf, + ) + .await?; + Ok(()) + } + async fn send_one( + &self, + peer_addr: SocketAddr, + peer_tcp_sender: Option>>, + peer_wg_sender: Option, Ipv4Addr)>>, + server_secret: bool, + src_ip: Ipv4Addr, + dest_ip: Ipv4Addr, + data: &mut [u8], + dst_buf: &mut [u8], + ) -> anyhow::Result<()> { + if let Some(peer_wg_sender) = peer_wg_sender { + if let Err(e) = peer_wg_sender.send((data.to_vec(), self.ip)).await { + Err(anyhow!("发送到对端wg失败 {}", e))? + } + return Ok(()); + } + let mut net_packet = NetPacket::new0(HEAD_LEN + data.len(), dst_buf)?; + net_packet.set_default_version(); + // 把wg的转发当成是服务端来源的数据,因为服务端没有客户端密钥对数据进行加密 + net_packet.set_gateway_flag(true); + net_packet.set_protocol(Protocol::IpTurn); + net_packet.set_transport_protocol(ip_turn_packet::Protocol::WGIpv4.into()); + net_packet.first_set_ttl(MAX_TTL); + net_packet.set_source(src_ip); + net_packet.set_destination(dest_ip); + net_packet.set_payload(data)?; + if server_secret { + let cipher = self + .cache + .cipher_session + .get(&peer_addr) + .context("加密信息不存在")?; + cipher.encrypt_ipv4(&mut net_packet)?; + } + if let Some(tcp_sender) = peer_tcp_sender { + tcp_sender.send(net_packet.buffer().to_vec()).await?; + } else { + self.udp.send_to(net_packet.buffer(), peer_addr).await?; + } + Ok(()) + } + fn ping(&self, data: &mut [u8], src_ip: Ipv4Addr, dest_ip: Ipv4Addr) -> anyhow::Result<()> { + let mut ipv4 = IpV4Packet::new(data)?; + if let ipv4::protocol::Protocol::Icmp = ipv4.protocol() { + let mut icmp_packet = icmp::IcmpPacket::new(ipv4.payload_mut())?; + if icmp_packet.kind() == Kind::EchoRequest { + //开启ping + icmp_packet.set_kind(Kind::EchoReply); + icmp_packet.update_checksum(); + ipv4.set_source_ip(dest_ip); + ipv4.set_destination_ip(src_ip); + ipv4.update_checksum(); + return Ok(()); + } + } + Err(anyhow!("非ping Echo 不处理")) + } +} diff --git a/src/core/service/client.rs b/src/core/service/client.rs index b00994c..95e008d 100644 --- a/src/core/service/client.rs +++ b/src/core/service/client.rs @@ -3,14 +3,13 @@ use std::net::SocketAddr; use std::sync::Arc; -use tokio::net::UdpSocket; - use crate::cipher::RsaCipher; -use crate::core::entity::ClientInfo; -use crate::core::store::cache::{AppCache, Context}; +use crate::core::store::cache::{AppCache, LinkVntContext, VntContext}; use crate::error::*; use crate::protocol::NetPacket; use crate::ConfigInfo; +use tokio::net::UdpSocket; +use tokio::sync::mpsc::Sender; #[derive(Clone)] pub struct ClientPacketHandler { @@ -37,25 +36,26 @@ impl ClientPacketHandler { } impl ClientPacketHandler { - pub fn handle + AsMut<[u8]>>( + pub async fn handle + AsMut<[u8]>>( &self, + context: &VntContext, net_packet: NetPacket, - addr: SocketAddr, + _addr: SocketAddr, ) -> Result<()> { - if let Some(context) = self.cache.get_context(&addr) { - self.handle0(net_packet, context) + if let Some(context) = &context.link_context { + self.handle0(context, net_packet).await } else { - Err(Error::Disconnect) + Err(Error::Disconnect)? } } } impl ClientPacketHandler { /// 转发到目标地址 - fn handle0 + AsMut<[u8]>>( + async fn handle0 + AsMut<[u8]>>( &self, + context: &LinkVntContext, mut net_packet: NetPacket, - context: Context, ) -> Result<()> { if net_packet.incr_ttl() > 1 { if self.config.check_finger { @@ -65,33 +65,65 @@ impl ClientPacketHandler { let destination = net_packet.destination(); if destination.is_broadcast() || self.config.broadcast == destination { //处理广播 - broadcast(&self.udp, context, net_packet); - } else if let Some(client_info) = - context.network_info.read().clients.get(&destination.into()) - { - send_one(&self.udp, client_info, &net_packet); + broadcast(context, &self.udp, net_packet).await; + } else { + let is_encrypt = net_packet.is_encrypt(); + let source_ip = u32::from(net_packet.source()); + let rs = context + .network_info + .read() + .clients + .get(&destination.into()) + .filter(|v| { + v.wireguard.is_none() + && v.online + && v.client_secret == is_encrypt + && v.virtual_ip != source_ip + }) + .map(|v| (v.address, v.tcp_sender.clone())); + if let Some((peer_addr, peer_tcp_sender)) = rs { + send_one(&self.udp, peer_addr, peer_tcp_sender, &net_packet).await; + } } } Ok(()) } } -fn broadcast>(udp_socket: &UdpSocket, context: Context, net_packet: NetPacket) { - for client_info in context.network_info.read().clients.values() { - send_one(udp_socket, client_info, &net_packet); +async fn broadcast>( + context: &LinkVntContext, + udp_socket: &UdpSocket, + net_packet: NetPacket, +) { + let is_encrypt = net_packet.is_encrypt(); + let source_ip = u32::from(net_packet.source()); + let x: Vec<_> = context + .network_info + .read() + .clients + .values() + .filter(|v| { + v.wireguard.is_none() + && v.online + && v.client_secret == is_encrypt + && v.virtual_ip != source_ip + }) + .map(|v| (v.address, v.tcp_sender.clone())) + .collect(); + for (peer_addr, peer_tcp_sender) in x { + send_one(udp_socket, peer_addr, peer_tcp_sender, &net_packet).await; } } -fn send_one>( +async fn send_one>( udp_socket: &UdpSocket, - client_info: &ClientInfo, + peer_addr: SocketAddr, + peer_tcp_sender: Option>>, net_packet: &NetPacket, ) { - if client_info.online && client_info.client_secret == net_packet.is_encrypt() { - if let Some(sender) = &client_info.tcp_sender { - let _ = sender.try_send(net_packet.buffer().to_vec()); - } else { - let _ = udp_socket.try_send_to(net_packet.buffer(), client_info.address); - } + if let Some(sender) = &peer_tcp_sender { + let _ = sender.send(net_packet.buffer().to_vec()).await; + } else { + let _ = udp_socket.send_to(net_packet.buffer(), peer_addr).await; } } diff --git a/src/core/service/mod.rs b/src/core/service/mod.rs index b73d3c7..4e68f7f 100644 --- a/src/core/service/mod.rs +++ b/src/core/service/mod.rs @@ -7,7 +7,7 @@ use tokio::sync::mpsc::Sender; use crate::cipher::RsaCipher; use crate::core::service::client::ClientPacketHandler; use crate::core::service::server::ServerPacketHandler; -use crate::core::store::cache::AppCache; +use crate::core::store::cache::{AppCache, VntContext}; use crate::error::*; use crate::protocol::NetPacket; use crate::ConfigInfo; @@ -41,13 +41,17 @@ impl PacketHandler { } impl PacketHandler { + pub async fn leave(&self, context: VntContext) { + self.server.leave(context).await; + } pub async fn handle + AsMut<[u8]>>( &self, + context: &mut VntContext, net_packet: NetPacket, addr: SocketAddr, tcp_sender: &Option>>, ) -> Option>> { - self.handle0(net_packet, addr, tcp_sender) + self.handle0(context, net_packet, addr, tcp_sender) .await .unwrap_or_else(|e| { log::error!("addr={},{:?}", addr, e); @@ -56,14 +60,17 @@ impl PacketHandler { } async fn handle0 + AsMut<[u8]>>( &self, + context: &mut VntContext, net_packet: NetPacket, addr: SocketAddr, tcp_sender: &Option>>, ) -> Result>>> { if net_packet.is_gateway() { - self.server.handle(net_packet, addr, tcp_sender).await + self.server + .handle(context, net_packet, addr, tcp_sender) + .await } else { - self.client.handle(net_packet, addr)?; + self.client.handle(context, net_packet, addr).await?; Ok(None) } } diff --git a/src/core/service/server.rs b/src/core/service/server.rs index 599ad00..781c020 100644 --- a/src/core/service/server.rs +++ b/src/core/service/server.rs @@ -1,20 +1,20 @@ +use anyhow::{anyhow, Context}; use chrono::Local; use packet::icmp::{icmp, Kind}; use packet::ip::ipv4; use packet::ip::ipv4::packet::IpV4Packet; +use protobuf::Message; use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; use std::time::Duration; use std::{io, result}; - -use protobuf::Message; use tokio::net::UdpSocket; use tokio::sync::mpsc::Sender; use crate::cipher::{Aes256GcmCipher, Finger, RsaCipher}; -use crate::core::entity::{ClientInfo, ClientStatusInfo, NetworkInfo}; -use crate::core::store::cache::{AppCache, Context}; +use crate::core::entity::{ClientInfo, ClientStatusInfo, NetworkInfo, SimpleClientInfo}; +use crate::core::store::cache::{AppCache, LinkVntContext, VntContext}; use crate::error::*; use crate::proto::message; use crate::proto::message::{DeviceList, RegistrationRequest, RegistrationResponse}; @@ -48,8 +48,12 @@ impl ServerPacketHandler { } impl ServerPacketHandler { + pub async fn leave(&self, context: VntContext) { + context.leave(&self.cache).await; + } pub async fn handle + AsMut<[u8]>>( &self, + context: &mut VntContext, mut net_packet: NetPacket, addr: SocketAddr, tcp_sender: &Option>>, @@ -66,26 +70,24 @@ impl ServerPacketHandler { } service_packet::Protocol::SecretHandshakeRequest => { // 加密握手 - let rs = self.secret_handshake(net_packet, addr).await?; + let rs = self.secret_handshake(context, net_packet, addr).await?; return Ok(Some(rs)); } _ => {} } } // 解密 - let aes = if net_packet.is_encrypt() { - if let Some(aes) = self.cache.cipher_session.get(&addr) { + let server_secret = net_packet.is_encrypt(); + if server_secret { + if let Some(aes) = &context.server_cipher { aes.decrypt_ipv4(&mut net_packet)?; - Some(aes) } else { log::info!("没有密钥:{},head={:?}", addr, net_packet.head()); - return Ok(Some(self.handle_err(addr, source, Error::NoKey)?)); + return Ok(Some(self.handle_err(addr, source, &Error::NoKey)?)); } - } else { - None - }; + } let mut packet = match self - .handle0(net_packet, addr, tcp_sender, aes.is_some()) + .handle0(context, net_packet, addr, tcp_sender, server_secret) .await { Ok(rs) => { @@ -95,11 +97,13 @@ impl ServerPacketHandler { return Ok(None); } } - Err(e) => self.handle_err(addr, source, e)?, + Err(e) => self.handle_anyhow_err(addr, source, e)?, }; self.common_param(&mut packet, source); - if let Some(aes) = aes { - aes.encrypt_ipv4(&mut packet)?; + if server_secret { + if let Some(aes) = &context.server_cipher { + aes.encrypt_ipv4(&mut packet)?; + } } Ok(Some(packet)) } @@ -115,20 +119,28 @@ impl ServerPacketHandler { net_packet.first_set_ttl(MAX_TTL); net_packet.set_gateway_flag(true); } + fn handle_anyhow_err( + &self, + addr: SocketAddr, + source: Ipv4Addr, + e: anyhow::Error, + ) -> Result>> { + if let Some(e) = e.downcast_ref() { + self.handle_err(addr, source, e) + } else { + self.handle_err(addr, source, &Error::Other(format!("{}", e))) + } + } fn handle_err( &self, addr: SocketAddr, source: Ipv4Addr, - e: Error, + e: &Error, ) -> Result>> { log::warn!("addr={},source={},{:?}", addr, source, e); let rs = vec![0u8; 12 + ENCRYPTION_RESERVED]; let mut packet = NetPacket::new_encrypt(rs)?; match e { - Error::Io(_) => {} - Error::Channel(_) => {} - Error::Protobuf(_) => {} - Error::AddressExhausted => { packet.set_transport_protocol(error_packet::Protocol::AddressExhausted.into()); } @@ -161,6 +173,7 @@ impl ServerPacketHandler { } async fn handle0 + AsMut<[u8]>>( &self, + context: &mut VntContext, net_packet: NetPacket, addr: SocketAddr, tcp_sender: &Option>>, @@ -168,7 +181,7 @@ impl ServerPacketHandler { ) -> Result>>> { // 处理不需要连接上下文的请求 let mut net_packet = match self - .not_context(net_packet, addr, tcp_sender, server_secret) + .not_context(context, net_packet, addr, tcp_sender, server_secret) .await { Ok(rs) => { @@ -177,10 +190,10 @@ impl ServerPacketHandler { Err(net_packet) => net_packet, }; // 需要连接的上下文 - let context = if let Some(context) = self.cache.get_context(&addr) { - context + let link_context = if let Some(link_context) = &context.link_context { + link_context } else { - return Err(Error::Disconnect); + return Err(Error::Disconnect)?; }; match net_packet.protocol() { @@ -188,13 +201,13 @@ impl ServerPacketHandler { match protocol::service_packet::Protocol::from(net_packet.transport_protocol()) { service_packet::Protocol::PullDeviceList => { //拉取网段设备信息 - return self.poll_device_list(net_packet, addr, &context); + return self.poll_device_list(net_packet, addr, &link_context); } service_packet::Protocol::ClientStatusInfo => { //客户端上报信息 let client_status_info = message::ClientStatusInfo::parse_from_bytes(net_packet.payload())?; - self.up_client_status_info(client_status_info, &context); + self.up_client_status_info(client_status_info, &link_context); return Ok(None); } _ => {} @@ -205,17 +218,22 @@ impl ServerPacketHandler { if let control_packet::Protocol::Ping = protocol::control_packet::Protocol::from(net_packet.transport_protocol()) { - return self.control_ping(net_packet, &context); + return self.control_ping(net_packet, &link_context); } } Protocol::IpTurn => { match protocol::ip_turn_packet::Protocol::from(net_packet.transport_protocol()) { + protocol::ip_turn_packet::Protocol::WGIpv4 => { + //wg数据转发 + self.wg_ipv4(&link_context, net_packet).await?; + return Ok(None); + } protocol::ip_turn_packet::Protocol::Ipv4Broadcast => { //处理选择性广播,进过网关还原成原始广播 let broadcast_packet = BroadcastPacket::new(net_packet.payload())?; let exclude = broadcast_packet.addresses(); let broadcast_net_packet = NetPacket::new(broadcast_packet.data()?)?; - self.broadcast(&context, broadcast_net_packet, &exclude)?; + self.broadcast(&link_context, broadcast_net_packet, &exclude)?; return Ok(None); } protocol::ip_turn_packet::Protocol::Ipv4 => { @@ -258,6 +276,7 @@ impl ServerPacketHandler { impl ServerPacketHandler { async fn not_context>( &self, + context: &mut VntContext, net_packet: NetPacket, addr: SocketAddr, tcp_sender: &Option>>, @@ -269,7 +288,7 @@ impl ServerPacketHandler { { //注册 return Ok(self - .register(net_packet, addr, tcp_sender, server_secret) + .register(context, net_packet, addr, tcp_sender, server_secret) .await); } } else if net_packet.protocol() == Protocol::Control { @@ -287,7 +306,7 @@ impl ServerPacketHandler { fn control_ping>( &self, net_packet: NetPacket, - context: &Context, + context: &LinkVntContext, ) -> Result>>> { let vec = vec![0u8; 12 + 4 + ENCRYPTION_RESERVED]; let mut packet = NetPacket::new_encrypt(vec)?; @@ -324,13 +343,13 @@ impl ServerPacketHandler { impl ServerPacketHandler { async fn register>( &self, + context: &mut VntContext, net_packet: NetPacket, addr: SocketAddr, tcp_sender: &Option>>, server_secret: bool, ) -> Result>>> { let config = &self.config; - let cache = &self.cache; let request = RegistrationRequest::parse_from_bytes(net_packet.payload())?; check_reg(&request)?; log::info!( @@ -346,6 +365,8 @@ impl ServerPacketHandler { tcp_sender.is_some() ); let group_id = request.token.clone(); + let gateway = config.gateway; + let netmask = config.netmask; if let Some(white_token) = &config.white_token { if !white_token.contains(&group_id) { log::info!( @@ -353,7 +374,7 @@ impl ServerPacketHandler { white_token, group_id ); - return Err(Error::TokenError); + Err(Error::TokenError)? } } let mut response = RegistrationResponse::new(); @@ -371,119 +392,46 @@ impl ServerPacketHandler { } } } - //固定网段 - let gateway: u32 = config.gateway.into(); - let netmask: u32 = config.netmask.into(); - let network: u32 = gateway & netmask; - - response.virtual_netmask = netmask; - response.virtual_gateway = gateway; - - let v = cache - .virtual_network - .optionally_get_with(group_id.clone(), || { - ( - Duration::from_secs(7 * 24 * 3600), - Arc::new(parking_lot::const_rwlock(NetworkInfo::new( - network, netmask, gateway, - ))), - ) - }) - .await; - let mut virtual_ip = request.virtual_ip; - // 可分配的ip段 - let ip_range = network + 1..gateway | (!netmask); - let timestamp = Local::now().timestamp(); - { - let mut lock = v.write(); - let mut insert = true; - if virtual_ip != 0 { - if u32::from(config.gateway) == virtual_ip - || u32::from(config.broadcast) == virtual_ip - || !ip_range.contains(&virtual_ip) - { - log::warn!("手动指定的ip无效: {:?}", request); - return Err(Error::InvalidIp); - } - //指定了ip - if let Some(info) = lock.clients.get_mut(&request.virtual_ip) { - if info.device_id != request.device_id { - //ip被占用了,并且不能更改ip - if !request.allow_ip_change { - log::warn!("手动指定的ip已经存在:{:?}", request); - return Err(Error::IpAlreadyExists); - } - // 重新挑选ip - virtual_ip = 0; - } else { - insert = false; - } - } - } - let mut old_ip = 0; - if insert { - // 找到上一次用的ip - for (ip, x) in &lock.clients { - if x.device_id == request.device_id { - if virtual_ip == 0 { - virtual_ip = *ip; - } else { - old_ip = *ip; - } - break; - } - } - } + let register_client_request = RegisterClientRequest { + group_id: group_id.clone(), + virtual_ip: request.virtual_ip.into(), + gateway, + netmask, + allow_ip_change: request.allow_ip_change, + device_id: request.device_id, + version: request.version, + name: request.name, + client_secret: request.client_secret, + client_secret_hash: request.client_secret_hash, + server_secret, + address: addr, + tcp_sender: tcp_sender.clone(), + online: true, + wireguard: None, + }; + let register_response = generate_ip(&self.cache, register_client_request).await?; + let virtual_ip = register_response.virtual_ip.into(); + response.virtual_gateway = gateway.into(); + response.virtual_netmask = netmask.into(); + response.virtual_ip = virtual_ip; + response.epoch = register_response.epoch as u32; + response.device_info_list = register_response + .client_list + .into_iter() + .map(|v| v.into()) + .collect(); + context.link_context.replace(LinkVntContext { + network_info: self + .cache + .virtual_network + .get(&group_id) + .context("virtual_network is none")?, + group: group_id.clone(), + virtual_ip, + broadcast: config.broadcast, + timestamp: register_response.timestamp, + }); - if virtual_ip == 0 { - // 从小到大找一个未使用的ip - for ip in ip_range { - if ip == lock.gateway_ip { - continue; - } - if !lock.clients.contains_key(&ip) { - virtual_ip = ip; - break; - } - } - } - if virtual_ip == 0 { - log::error!("地址使用完:{:?}", request); - return Err(Error::AddressExhausted); - } - let info = if old_ip == 0 { - lock.clients - .entry(virtual_ip) - .or_insert_with(ClientInfo::default) - } else { - let client_info = lock.clients.remove(&old_ip).unwrap(); - lock.clients - .entry(virtual_ip) - .or_insert_with(|| client_info) - }; - info.name = request.name; - info.device_id = request.device_id; - info.version = request.version; - info.client_secret = request.client_secret; - info.server_secret = server_secret; - info.address = addr; - info.online = true; - info.virtual_ip = virtual_ip; - info.tcp_sender = tcp_sender.clone(); - info.last_join_time = Local::now(); - info.timestamp = timestamp; - lock.epoch += 1; - response.virtual_ip = virtual_ip; - response.epoch = lock.epoch as u32; - response.device_info_list = Self::clients_info(&lock.clients, virtual_ip); - drop(lock); - } - cache - .insert_ip_session((group_id.clone(), virtual_ip), addr) - .await; - cache - .insert_addr_session(addr, (group_id, virtual_ip, timestamp)) - .await; let bytes = response.write_to_bytes()?; let rs = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED]; let mut packet = NetPacket::new_encrypt(rs)?; @@ -496,13 +444,16 @@ impl ServerPacketHandler { fn check_reg(request: &RegistrationRequest) -> Result<()> { if request.token.is_empty() || request.token.len() > 128 { - return Err(Error::Other("group length error".into())); + Err(anyhow!("group length error"))? } if request.device_id.is_empty() || request.device_id.len() > 128 { - return Err(Error::Other("device_id length error".into())); + Err(anyhow!("device_id length error"))? } if request.name.is_empty() || request.name.len() > 128 { - return Err(Error::Other("name length error".into())); + Err(anyhow!("name length error"))? + } + if request.client_secret_hash.len() > 128 { + Err(anyhow!("client_secret_hash length error"))? } Ok(()) } @@ -535,6 +486,7 @@ impl ServerPacketHandler { } async fn secret_handshake>( &self, + context: &mut VntContext, net_packet: NetPacket, addr: SocketAddr, ) -> Result>> { @@ -545,10 +497,7 @@ impl ServerPacketHandler { let sync_secret = message::SecretHandshakeRequest::parse_from_bytes(rsa_secret_body.data())?; let c = Aes256GcmCipher::new( - sync_secret - .key - .try_into() - .map_err(|_| Error::Other("key err".into()))?, + sync_secret.key.try_into().map_err(|_| anyhow!("key err"))?, Finger::new(&sync_secret.token), ); let rs = vec![0u8; 12 + ENCRYPTION_RESERVED]; @@ -557,10 +506,11 @@ impl ServerPacketHandler { packet.set_transport_protocol(service_packet::Protocol::SecretHandshakeResponse.into()); self.common_param(&mut packet, source); c.encrypt_ipv4(&mut packet)?; + context.server_cipher.replace(c.clone()); self.cache.insert_cipher_session(addr, c).await; return Ok(packet); } - Err(Error::Other("no encryption".into())) + Err(anyhow!("no encryption")) } } @@ -569,15 +519,15 @@ impl ServerPacketHandler { &self, _net_packet: NetPacket, _addr: SocketAddr, - context: &Context, + context: &LinkVntContext, ) -> Result>>> { let guard = context.network_info.read(); - let ips = Self::clients_info(&guard.clients, context.virtual_ip); + let ips = clients_info(&guard.clients, context.virtual_ip); let epoch = guard.epoch; drop(guard); let mut device_list = DeviceList::new(); device_list.epoch = epoch as u32; - device_list.device_info_list = ips; + device_list.device_info_list = ips.into_iter().map(|v| v.into()).collect(); let bytes = device_list.write_to_bytes()?; let vec = vec![0u8; 12 + bytes.len() + ENCRYPTION_RESERVED]; let mut device_list_packet = NetPacket::new_encrypt(vec)?; @@ -589,7 +539,7 @@ impl ServerPacketHandler { fn up_client_status_info( &self, client_status_info: message::ClientStatusInfo, - context: &Context, + context: &LinkVntContext, ) { let mut status_info = ClientStatusInfo::default(); let iplist = &mut status_info.p2p_list; @@ -612,34 +562,55 @@ impl ServerPacketHandler { v.client_status = Some(status_info); } } - fn clients_info( - clients: &HashMap, - current_ip: u32, - ) -> Vec { - clients - .iter() - .filter(|&(_, dev)| dev.virtual_ip != current_ip) - .map(|(_, device_info)| { - let mut dev = message::DeviceInfo::new(); - dev.virtual_ip = device_info.virtual_ip; - dev.name = device_info.name.clone(); - dev.device_status = if device_info.online { 0 } else { 1 }; - dev.client_secret = device_info.client_secret; - dev - }) - .collect() + async fn wg_ipv4>( + &self, + context: &LinkVntContext, + net_packet: NetPacket, + ) -> anyhow::Result<()> { + let source = net_packet.source(); + let dest = net_packet.destination(); + let destination = u32::from(dest); + if destination == context.virtual_ip { + return Ok(()); + } + if dest.is_broadcast() || dest == context.broadcast { + // 广播 + for peer in context.network_info.read().clients.values() { + if !peer.online || destination == peer.virtual_ip { + continue; + } + if let Some(sender) = &peer.wg_sender { + if let Err(e) = sender.try_send((net_packet.payload().to_vec(), source)) { + log::info!("广播到对端wg失败 {}->{},{}", source, dest, e); + } + } + } + } else if let Some(peer) = context.network_info.read().clients.get(&destination) { + // 点对点 + if peer.online { + if let Some(sender) = &peer.wg_sender { + if let Err(e) = sender.try_send((net_packet.payload().to_vec(), source)) { + log::info!("发送到对端wg失败 {}->{},{}", source, dest, e); + } + } + } + } + Ok(()) } fn broadcast>( &self, - context: &Context, + context: &LinkVntContext, net_packet: NetPacket, exclude: &[Ipv4Addr], ) -> io::Result<()> { let client_secret = net_packet.is_encrypt(); + let destination = u32::from(net_packet.destination()); for (ip, client_info) in &context.network_info.read().clients { if client_info.online - && !exclude.contains(&(*ip).into()) + && destination != *ip && client_info.client_secret == client_secret + && client_info.wireguard.is_none() + && !exclude.contains(&(*ip).into()) { if let Some(sender) = &client_info.tcp_sender { let _ = sender.try_send(net_packet.buffer().to_vec()); @@ -653,3 +624,171 @@ impl ServerPacketHandler { Ok(()) } } + +pub struct RegisterClientRequest { + pub group_id: String, + // ip 0表示自动分配 + pub virtual_ip: Ipv4Addr, + pub gateway: Ipv4Addr, + pub netmask: Ipv4Addr, + + // 允许分配不一样的ip + pub allow_ip_change: bool, + // 设备ID + pub device_id: String, + // 版本 + pub version: String, + // 名称 + pub name: String, + // 客户端间是否加密 + pub client_secret: bool, + // 加密hash + pub client_secret_hash: Vec, + // 和服务端是否加密 + pub server_secret: bool, + // 链接服务器的来源地址 + pub address: SocketAddr, + pub tcp_sender: Option>>, + // 是否在线 + pub online: bool, + // wireguard客户端公钥 + pub wireguard: Option<[u8; 32]>, +} + +pub struct RegisterClientResponse { + timestamp: i64, + pub virtual_ip: Ipv4Addr, + // 纪元号 + pub epoch: u64, + pub client_list: Vec, +} + +pub async fn generate_ip( + cache: &AppCache, + register_request: RegisterClientRequest, +) -> anyhow::Result { + let gateway: u32 = register_request.gateway.into(); + let netmask: u32 = register_request.netmask.into(); + let network: u32 = gateway & netmask; + let mut virtual_ip: u32 = register_request.virtual_ip.into(); + let device_id = register_request.device_id; + let allow_ip_change = register_request.allow_ip_change; + let group_id = register_request.group_id; + let v = cache + .virtual_network + .optionally_get_with(group_id, || { + ( + Duration::from_secs(7 * 24 * 3600), + Arc::new(parking_lot::const_rwlock(NetworkInfo::new( + network, netmask, gateway, + ))), + ) + }) + .await; + // 可分配的ip段 + let ip_range = network + 1..gateway | (!netmask); + let timestamp = Local::now().timestamp(); + let mut lock = v.write(); + let mut insert = true; + if virtual_ip != 0 { + if gateway == virtual_ip || !ip_range.contains(&virtual_ip) { + Err(Error::InvalidIp)? + } + //指定了ip + if let Some(info) = lock.clients.get_mut(&virtual_ip) { + if info.device_id != device_id { + //ip被占用了,并且不能更改ip + if !allow_ip_change { + Err(Error::IpAlreadyExists)? + } + // 重新挑选ip + virtual_ip = 0; + } else { + insert = false; + } + } + } + let mut old_ip = 0; + if insert { + // 找到上一次用的ip + for (ip, x) in &lock.clients { + if x.device_id == device_id { + if virtual_ip == 0 { + virtual_ip = *ip; + } else { + old_ip = *ip; + } + break; + } + } + } + + if virtual_ip == 0 { + // 从小到大找一个未使用的ip + for ip in ip_range { + if ip == lock.gateway_ip { + continue; + } + if !lock.clients.contains_key(&ip) { + virtual_ip = ip; + break; + } + } + } + if virtual_ip == 0 { + log::error!("地址使用完:{:?}", lock); + Err(Error::AddressExhausted)? + } + let info = if old_ip == 0 { + lock.clients + .entry(virtual_ip) + .or_insert_with(ClientInfo::default) + } else { + let client_info = lock.clients.remove(&old_ip).unwrap(); + lock.clients + .entry(virtual_ip) + .or_insert_with(|| client_info) + }; + info.name = register_request.name; + info.device_id = device_id; + info.version = register_request.version; + info.client_secret = register_request.client_secret; + info.client_secret_hash = register_request.client_secret_hash; + info.server_secret = register_request.server_secret; + info.address = register_request.address; + info.online = register_request.online; + info.wireguard = register_request.wireguard; + info.virtual_ip = virtual_ip; + info.tcp_sender = register_request.tcp_sender; + info.last_join_time = Local::now(); + info.timestamp = timestamp; + lock.epoch += 1; + let response = RegisterClientResponse { + timestamp, + virtual_ip: virtual_ip.into(), + epoch: lock.epoch, + client_list: clients_info(&lock.clients, virtual_ip), + }; + Ok(response) +} +fn clients_info(clients: &HashMap, current_ip: u32) -> Vec { + clients + .iter() + .filter(|&(_, dev)| dev.virtual_ip != current_ip) + .map(|(_, device_info)| device_info.into()) + .collect() +} +impl From for message::DeviceInfo { + fn from(value: SimpleClientInfo) -> Self { + let mut dev = message::DeviceInfo::new(); + dev.virtual_ip = value.virtual_ip; + dev.name = value.name; + dev.device_status = if value.online { 0 } else { 1 }; + dev.client_secret = value.client_secret; + if value.online { + dev.client_secret_hash = value.client_secret_hash; + } + dev.wireguard = value.wireguard; + dev + } +} diff --git a/src/core/store/cache.rs b/src/core/store/cache.rs index 6068a29..8f50210 100644 --- a/src/core/store/cache.rs +++ b/src/core/store/cache.rs @@ -1,130 +1,121 @@ +use dashmap::DashMap; +use parking_lot::RwLock; use std::net::{Ipv4Addr, SocketAddr}; use std::sync::Arc; use std::time::Duration; -use parking_lot::RwLock; - use crate::cipher::Aes256GcmCipher; -use crate::core::entity::NetworkInfo; +use crate::core::entity::{NetworkInfo, WireGuardConfig}; use crate::core::store::expire_map::ExpireMap; #[derive(Clone)] pub struct AppCache { // group -> NetworkInfo pub virtual_network: ExpireMap>>, - // (group,ip) -> addr + // (group,ip) -> addr 用于客户端过期,只有客户端离线才设置 pub ip_session: ExpireMap<(String, u32), SocketAddr>, - // addr -> (group,ip) - pub addr_session: ExpireMap, - pub cipher_session: ExpireMap>, + // 加密密钥 + pub cipher_session: Arc>>, + // web登录状态 pub auth_map: ExpireMap, + // wg公钥 -> wg配置 + pub wg_group_map: Arc>, } -pub struct Context { +pub struct VntContext { + pub link_context: Option, + pub server_cipher: Option, + pub link_address: SocketAddr, +} +pub struct LinkVntContext { pub network_info: Arc>, pub group: String, pub virtual_ip: u32, + pub broadcast: Ipv4Addr, + pub timestamp: i64, +} +impl VntContext { + pub async fn leave(self, cache: &AppCache) { + if self.server_cipher.is_some() { + cache.cipher_session.remove(&self.link_address); + } + if let Some(context) = self.link_context { + if let Some(network_info) = cache.virtual_network.get(&context.group) { + { + let mut guard = network_info.write(); + if let Some(client_info) = guard.clients.get_mut(&context.virtual_ip) { + if client_info.address != self.link_address + && client_info.timestamp != context.timestamp + { + return; + } + client_info.online = false; + client_info.tcp_sender = None; + guard.epoch += 1; + } + drop(guard); + } + cache + .insert_ip_session((context.group, context.virtual_ip), self.link_address) + .await; + } + } + } } impl AppCache { pub fn new() -> Self { + let wg_group_map: Arc> = Default::default(); // 网段7天未使用则回收 let virtual_network: ExpireMap>> = - ExpireMap::new(|_k, _v| {}); - let virtual_network_ = virtual_network.clone(); - // ip一天未使用则回收 - let ip_session: ExpireMap<(String, u32), SocketAddr> = - ExpireMap::new(move |(group_id, ip), addr: SocketAddr| { - log::info!( - "ip_session eviction group_id={},ip={},addr={}", - group_id, - Ipv4Addr::from(ip), - addr - ); - if let Some(v) = virtual_network_.get(&group_id) { - let mut lock = v.write(); - if let Some(dev) = lock.clients.get(&ip) { - if dev.address == addr { - lock.clients.remove(&ip); - lock.epoch += 1; - } - } + ExpireMap::new(|_k, v: &Arc>| { + let lock = v.read(); + if !lock.clients.is_empty() { + // 存在客户端的不过期 + return Some(Duration::from_secs(7 * 24 * 3600)); } + None }); let virtual_network_ = virtual_network.clone(); - // 20秒钟没有收到消息则判定为掉线 - let addr_session = ExpireMap::new( - move |addr: SocketAddr, (group, virtual_ip, timestamp)| { - log::info!( - "addr_session eviction group={},virtual_ip={},addr={},timestamp={}", - group, - Ipv4Addr::from(virtual_ip), - addr, - timestamp - ); - - if let Some(v) = virtual_network_.get(&group) { - let mut lock = v.write(); - if let Some(item) = lock.clients.get_mut(&virtual_ip) { - if item.address != addr || item.timestamp != timestamp { - log::info!( - "无效信息 addr_session eviction group={},virtual_ip={},addr={},timestamp={}", - group, - Ipv4Addr::from(virtual_ip), - addr, - timestamp - ); - return; - } - item.online = false; + // ip一天未使用则回收 + let ip_session: ExpireMap<(String, u32), SocketAddr> = ExpireMap::new(move |key, addr| { + let (group_id, ip) = &key; + log::info!( + "ip_session eviction group_id={},ip={},addr={}", + group_id, + Ipv4Addr::from(*ip), + addr + ); + if let Some(v) = virtual_network_.get(group_id) { + let mut lock = v.write(); + if let Some(dev) = lock.clients.get(ip) { + if !dev.online && &dev.address == addr { + lock.clients.remove(ip); lock.epoch += 1; } } - }, - ); - let cipher_session = ExpireMap::new(|_k, _v| {}); - let auth_map = ExpireMap::new(|_k, _v| {}); + } + None + }); + + let auth_map = ExpireMap::new(|_k, _v| None); Self { virtual_network, ip_session, - addr_session, - cipher_session, + cipher_session: Default::default(), auth_map, + wg_group_map, } } } impl AppCache { - pub fn get_context(&self, addr: &SocketAddr) -> Option { - if let Some((group, virtual_ip, _)) = self.addr_session.get(addr) { - let k = (group, virtual_ip); - self.ip_session.get(&k)?; - let (group, virtual_ip) = k; - return self - .virtual_network - .get(&group) - .map(|network_info| Context { - network_info, - group, - virtual_ip, - }); - } - None - } - pub async fn insert_cipher_session(&self, key: SocketAddr, value: Aes256GcmCipher) { - self.cipher_session - .insert(key, Arc::new(value), Duration::from_secs(120)) - .await + self.cipher_session.insert(key, Arc::new(value)); } pub async fn insert_ip_session(&self, key: (String, u32), value: SocketAddr) { self.ip_session .insert(key, value, Duration::from_secs(24 * 3600)) .await } - pub async fn insert_addr_session(&self, key: SocketAddr, value: (String, u32, i64)) { - self.addr_session - .insert(key, value, Duration::from_secs(20)) - .await - } } diff --git a/src/core/store/expire_map.rs b/src/core/store/expire_map.rs index d120390..406cd95 100644 --- a/src/core/store/expire_map.rs +++ b/src/core/store/expire_map.rs @@ -26,7 +26,7 @@ struct Value { impl ExpireMap { pub fn new(call: F) -> ExpireMap where - F: Fn(K, V) + Send + 'static, + F: Fn(&K, &V) -> Option + Send + 'static, K: Eq + Hash + Clone + Sync + Send + 'static, V: Clone + Sync + Send + 'static, { @@ -66,6 +66,14 @@ where .await .unwrap(); } + /// remove出去的不会执行过期回调 + pub fn remove(&self, k: &K) -> Option { + if let Some(v) = self.base.write().remove(k) { + Some(v.val) + } else { + None + } + } pub fn get(&self, k: &K) -> Option { if let Some(v) = self.base.read().get(k) { // 延长过期时间 @@ -78,7 +86,10 @@ where pub fn get_val(&self, k: &K) -> Option { self.base.read().get(k).map(|v| v.val.clone()) } - fn expire_call(&self, k: &K) -> Op { + fn expire_call(&self, k: &K, f: &F) -> Op + where + F: Fn(&K, &V) -> Option, + { let mut write_guard = self.base.write(); if let Some(v) = write_guard.get(k) { let now = Instant::now(); @@ -87,10 +98,14 @@ where // 过期时间更新了 return Op::Reset(instant); } else { - //删除key - if let Some((k, v)) = write_guard.remove_entry(k) { - //执行回调 - return Op::Remove(k, v.val); + //执行过期回调 + if let Some(v) = f(k, &v.val) { + return Op::Reset(now.add(v)); + } else { + //删除key + if let Some((k, v)) = write_guard.remove_entry(k) { + return Op::Remove(k, v.val); + } } } } @@ -142,7 +157,7 @@ async fn expire_task(mut receiver: Receiver>, map: Expir where K: Eq + Hash + Clone, V: Clone, - F: Fn(K, V), + F: Fn(&K, &V) -> Option, { let mut binary_heap = BinaryHeap::>::with_capacity(32); loop { @@ -165,17 +180,13 @@ where } } else if let Some(mut task) = binary_heap.pop() { //执行过期逻辑 - match map.expire_call(&task.k) { + match map.expire_call(&task.k, &f) { Op::Reset(time) => { //没有过期,重新加入监听 - task.time = time; binary_heap.push(task); } - Op::Remove(k, v) => { - //执行回调 - f(k, v) - } + Op::Remove(_, _) => {} Op::None => {} } } diff --git a/src/error/mod.rs b/src/error/mod.rs index 61b77f8..e88300a 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -1,18 +1,9 @@ #![allow(dead_code, clippy::enum_variant_names)] -use std::io; - -use crossbeam::channel::RecvError; use thiserror::Error; #[derive(Error, Debug)] pub enum Error { - #[error("Io error")] - Io(#[from] io::Error), - #[error("Channel error")] - Channel(#[from] RecvError), - #[error("Protobuf error")] - Protobuf(#[from] protobuf::Error), #[error("Disconnect")] Disconnect, #[error("No Key")] @@ -29,4 +20,4 @@ pub enum Error { Other(String), } -pub type Result = std::result::Result; +pub type Result = anyhow::Result; diff --git a/src/main.rs b/src/main.rs index e3fa18f..dd93d2e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,16 @@ +use aes_gcm::aead::rand_core::RngCore; +use anyhow::{anyhow, Context}; +use base64::engine::general_purpose; +use base64::Engine; +use boringtun::x25519::{PublicKey, StaticSecret}; +use clap::Parser; use std::collections::HashSet; -use std::fmt::Display; +use std::fmt::{Debug, Display, Formatter}; use std::io; use std::io::Write; use std::net::Ipv4Addr; use std::path::PathBuf; -use clap::Parser; - use crate::cipher::RsaCipher; mod cipher; @@ -15,6 +19,7 @@ mod error; mod generated_serial_number; mod proto; mod protocol; + pub const VNT_VERSION: &str = env!("CARGO_PKG_VERSION"); /// 默认网关信息 @@ -56,9 +61,12 @@ pub struct StartArgs { /// web后台用户密码,默认为admin #[arg(short = 'W', long)] password: Option, + /// wg私钥,使用base64编码 + #[arg(long = "wg")] + wg_secret_key: Option, } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ConfigInfo { pub port: u16, pub white_token: Option>, @@ -70,6 +78,28 @@ pub struct ConfigInfo { pub username: String, #[cfg(feature = "web")] pub password: String, + pub wg_secret_key: StaticSecret, + pub wg_public_key: PublicKey, +} +impl Debug for ConfigInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConfigInfo") + .field("port", &self.port) + .field("white_token", &self.white_token) + .field("gateway", &self.gateway) + .field("broadcast", &self.broadcast) + .field("netmask", &self.netmask) + .field("check_finger", &self.check_finger) + .field( + "wg_secret_key", + &general_purpose::STANDARD.encode(&self.wg_secret_key), + ) + .field( + "wg_public_key", + &general_purpose::STANDARD.encode(&self.wg_public_key), + ) + .finish() + } } fn log_init(root_path: PathBuf, log_path: Option) { @@ -231,6 +261,22 @@ async fn main() { if check_finger { println!("转发校验数据指纹,客户端必须增加--finger参数"); } + let wg_secret_key: [u8; 32] = if let Some(wg_secret_key) = args.wg_secret_key { + let wg_secret_key = general_purpose::STANDARD + .decode(wg_secret_key) + .context("wg私钥错误") + .unwrap(); + wg_secret_key + .try_into() + .map_err(|_| anyhow!("wg私钥错误")) + .unwrap() + } else { + let mut wg_secret_key = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut wg_secret_key); + wg_secret_key + }; + let wg_secret_key = boringtun::x25519::StaticSecret::from(wg_secret_key); + let wg_public_key = boringtun::x25519::PublicKey::from(&wg_secret_key); let config = ConfigInfo { port, white_token, @@ -242,6 +288,8 @@ async fn main() { username: args.username.unwrap_or_else(|| "admin".into()), #[cfg(feature = "web")] password: args.password.unwrap_or_else(|| "admin".into()), + wg_secret_key, + wg_public_key, }; let rsa = match RsaCipher::new(root_path) { Ok(rsa) => { @@ -258,8 +306,8 @@ async fn main() { log::info!("监听udp端口: {:?}", port); println!("监听udp端口: {:?}", port); let tcp = create_tcp(port).unwrap(); - log::info!("监听tcp端口: {:?}", port); - println!("监听tcp端口: {:?}", port); + log::info!("监听tcp/ws端口: {:?}", port); + println!("监听tcp/ws端口: {:?}", port); #[cfg(feature = "web")] let http = if web_port != 0 { let http = create_tcp(web_port).unwrap(); diff --git a/src/protocol/body.rs b/src/protocol/body.rs index 7395be5..07d39c0 100644 --- a/src/protocol/body.rs +++ b/src/protocol/body.rs @@ -1,10 +1,264 @@ -#![allow(dead_code)] use std::{fmt, io}; pub const ENCRYPTION_RESERVED: usize = 16 + 32 + 12; pub const AES_GCM_ENCRYPTION_RESERVED: usize = 32; pub const RSA_ENCRYPTION_RESERVED: usize = 32; +pub const RANDOM_RESERVED: usize = 4; +pub const FINGER_RESERVED: usize = 12; +pub const TAG_RESERVED: usize = 16; + +/* ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| random(32) | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| finger(32) | +| finger(32) | +| finger(32) | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +pub trait SecretTail { + fn buffer(&self) -> &[u8]; + fn exist_finger(&self) -> bool; + fn random_buf(&self) -> &[u8] { + let buf = self.buffer(); + let mut end = buf.len(); + if self.exist_finger() { + end -= FINGER_RESERVED; + } + &buf[end - RANDOM_RESERVED..end] + } + fn finger(&self) -> &[u8] { + if self.exist_finger() { + let buf = self.buffer(); + let end = buf.len(); + &buf[end - FINGER_RESERVED..end] + } else { + &[] + } + } +} + +pub trait SecretTailMut: SecretTail { + fn buffer_mut(&mut self) -> &mut [u8]; + fn set_random(&mut self, random: &[u8]) { + let f = self.exist_finger(); + let buf = self.buffer_mut(); + let mut end = buf.len(); + if f { + end -= FINGER_RESERVED; + } + buf[end - RANDOM_RESERVED..end].copy_from_slice(random); + } + fn set_finger(&mut self, finger: &[u8]) -> io::Result<()> { + if self.exist_finger() { + if finger.len() != FINGER_RESERVED { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "finger.len != 12", + )); + } + let buf = self.buffer_mut(); + let end = buf.len(); + buf[end - FINGER_RESERVED..end].copy_from_slice(finger); + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "not exist finger", + )) + } + } +} + +/* aead加密数据体 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 数据体 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | tag(32) | + | tag(32) | + | tag(32) | + | tag(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | random(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | finger(32) | + | finger(32) | + | finger(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + 注:finger用于快速校验数据是否被修改,上层可使用token、协议头参与计算finger, + 确保服务端和客户端都能感知修改(服务端不能解密也能校验指纹) +*/ +pub struct AEADSecretBody { + buffer: B, + exist_finger: bool, +} + +impl> AEADSecretBody { + pub fn new(buffer: B, exist_finger: bool) -> io::Result> { + let len = buffer.as_ref().len(); + let min_len = if exist_finger { + TAG_RESERVED + RANDOM_RESERVED + FINGER_RESERVED + } else { + TAG_RESERVED + RANDOM_RESERVED + }; + // 不能大于udp最大载荷长度 + if len < min_len || len > 65535 - 20 - 8 - 12 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("AEADSecretBody length overflow {}", len), + )); + } + Ok(AEADSecretBody { + buffer, + exist_finger, + }) + } + pub fn data(&self) -> &[u8] { + let mut end = self.buffer.as_ref().len() - TAG_RESERVED - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &self.buffer.as_ref()[..end] + } + pub fn tag(&self) -> &[u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &self.buffer.as_ref()[end - TAG_RESERVED..end] + } +} + +impl> SecretTail for AEADSecretBody { + #[inline] + fn buffer(&self) -> &[u8] { + self.buffer.as_ref() + } + #[inline] + fn exist_finger(&self) -> bool { + self.exist_finger + } +} + +impl + AsMut<[u8]>> SecretTailMut for AEADSecretBody { + #[inline] + fn buffer_mut(&mut self) -> &mut [u8] { + self.buffer.as_mut() + } +} + +impl + AsMut<[u8]>> AEADSecretBody { + /// 数据部分 + pub fn data_mut(&mut self) -> &mut [u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED - TAG_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &mut self.buffer.as_mut()[..end] + } + /// 数据和tag部分 + pub fn data_tag_mut(&mut self) -> &mut [u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &mut self.buffer.as_mut()[..end] + } + pub fn set_tag(&mut self, tag: &[u8]) -> io::Result<()> { + if tag.len() != 16 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "tag.len != 16")); + } + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + self.buffer.as_mut()[end - TAG_RESERVED..end].copy_from_slice(tag); + Ok(()) + } +} + +/* 带随机数的加密数据体 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 数据体 | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | random(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | finger(32) | + | finger(32) | + | finger(32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + 注:finger用于快速校验数据是否被修改,上层可使用token、协议头参与计算finger, + 确保服务端和客户端都能感知修改(服务端不能解密也能校验指纹) +*/ +pub struct IVSecretBody { + buffer: B, + exist_finger: bool, +} + +impl> IVSecretBody { + pub fn new(buffer: B, exist_finger: bool) -> io::Result> { + let len = buffer.as_ref().len(); + let min_len = if exist_finger { + FINGER_RESERVED + RANDOM_RESERVED + } else { + RANDOM_RESERVED + }; + // 不能大于udp最大载荷长度 + if len < min_len || len > 65535 - 20 - 8 - 12 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("IVSecretBody length overflow {}", len), + )); + } + Ok(IVSecretBody { + buffer, + exist_finger, + }) + } + pub fn data(&self) -> &[u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &self.buffer.as_ref()[..end] + } +} + +impl + AsMut<[u8]>> IVSecretBody { + pub fn data_mut(&mut self) -> &mut [u8] { + let mut end = self.buffer.as_ref().len() - RANDOM_RESERVED; + if self.exist_finger { + end -= FINGER_RESERVED; + } + &mut self.buffer.as_mut()[..end] + } +} + +impl> SecretTail for IVSecretBody { + #[inline] + fn buffer(&self) -> &[u8] { + self.buffer.as_ref() + } + #[inline] + fn exist_finger(&self) -> bool { + self.exist_finger + } +} + +impl + AsMut<[u8]>> SecretTailMut for IVSecretBody { + #[inline] + fn buffer_mut(&mut self) -> &mut [u8] { + self.buffer.as_mut() + } +} + /* aes_gcm加密数据体 0 15 31 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -280,7 +534,7 @@ impl> RsaSecretBody { pub fn new(buffer: B) -> io::Result> { let len = buffer.as_ref().len(); // 不能大于udp最大载荷长度 - if !(32..=65535 - 20 - 8 - 12).contains(&len) { + if len < 32 || len > 65535 - 20 - 8 - 12 { return Err(io::Error::new( io::ErrorKind::InvalidData, "length overflow", @@ -305,7 +559,7 @@ impl> RsaSecretBody { &self.buffer.as_ref()[end..] } pub fn buffer(&self) -> &[u8] { - self.buffer.as_ref() + &self.buffer.as_ref() } } diff --git a/src/protocol/control_packet.rs b/src/protocol/control_packet.rs index 3ae37b5..dbd1bd6 100644 --- a/src/protocol/control_packet.rs +++ b/src/protocol/control_packet.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] use std::net::Ipv4Addr; use std::{fmt, io}; @@ -6,14 +5,20 @@ use std::{fmt, io}; pub enum Protocol { /// ping请求 /* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | time | echo | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | time | echo | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ Ping, - /// 维持连接,内容同ping + /* + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | time | echo | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ Pong, /// 打洞请求 PunchRequest, @@ -39,9 +44,9 @@ impl From for Protocol { } } -impl From for u8 { - fn from(val: Protocol) -> Self { - match val { +impl Into for Protocol { + fn into(self) -> u8 { + match self { Protocol::Ping => 1, Protocol::Pong => 2, Protocol::PunchRequest => 3, @@ -86,8 +91,8 @@ pub type PongPacket = PingPacket; impl> PingPacket { pub fn new(buffer: B) -> io::Result> { let len = buffer.as_ref().len(); - if len != 4 { - return Err(io::Error::new(io::ErrorKind::InvalidData, "len != 4")); + if len < 4 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "len < 4")); } Ok(PingPacket { buffer }) } @@ -127,8 +132,8 @@ pub struct AddrPacket { impl> AddrPacket { pub fn new(buffer: B) -> io::Result> { let len = buffer.as_ref().len(); - if len != 6 { - return Err(io::Error::new(io::ErrorKind::InvalidData, "len != 6")); + if len < 6 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "len < 6")); } Ok(AddrPacket { buffer }) } diff --git a/src/protocol/error_packet.rs b/src/protocol/error_packet.rs index 018ca14..8014f44 100644 --- a/src/protocol/error_packet.rs +++ b/src/protocol/error_packet.rs @@ -1,6 +1,4 @@ -#![allow(dead_code)] - -use tokio::io; +use std::io; #[derive(Eq, PartialEq, Copy, Clone, Debug)] pub enum Protocol { diff --git a/src/protocol/extension.rs b/src/protocol/extension.rs new file mode 100644 index 0000000..1e43e06 --- /dev/null +++ b/src/protocol/extension.rs @@ -0,0 +1,141 @@ +/* 扩展协议 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 扩展数据(n) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 扩展数据(n) | type(8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 注:扩展数据的长度由type决定 +*/ + +use anyhow::anyhow; +use std::io; + +use crate::protocol::NetPacket; + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum ExtensionTailType { + Compression, + Unknown(u8), +} + +impl From for ExtensionTailType { + fn from(value: u8) -> Self { + if value == 0 { + ExtensionTailType::Compression + } else { + ExtensionTailType::Unknown(value) + } + } +} + +pub enum ExtensionTailPacket { + Compression(CompressionExtensionTail), + Unknown, +} + +impl + AsMut<[u8]>> NetPacket { + /// 分离尾部数据 + pub fn split_tail_packet(&mut self) -> anyhow::Result> { + if self.is_extension() { + let payload = self.payload(); + if let Some(v) = payload.last() { + return match ExtensionTailType::from(*v) { + ExtensionTailType::Compression => { + let data_len = self.data_len - 4; + self.set_data_len(data_len)?; + self.set_extension_flag(false); + Ok(ExtensionTailPacket::Compression( + CompressionExtensionTail::new( + &self.raw_buffer()[data_len..data_len + 4], + ), + )) + } + ExtensionTailType::Unknown(e) => Err(anyhow!("unknown extension {}", e)), + }; + } + } + Err(anyhow!("not extension")) + } + /// 追加压缩扩展 + pub fn append_compression_extension_tail( + &mut self, + ) -> io::Result> { + let len = self.data_len; + //增加数据长度 + self.set_data_len(self.data_len + 4)?; + self.set_extension_flag(true); + let mut tail = CompressionExtensionTail::new(&mut self.buffer_mut()[len..]); + tail.init(); + return Ok(tail); + } +} + +/* 扩展协议 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | algorithm(8) | | type(8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 注:扩展数据的长度由type决定 +*/ +/// 压缩扩展 +pub struct CompressionExtensionTail { + buffer: B, +} + +impl> CompressionExtensionTail { + pub fn new(buffer: B) -> CompressionExtensionTail { + assert_eq!(buffer.as_ref().len(), 4); + CompressionExtensionTail { buffer } + } +} + +impl> CompressionExtensionTail { + pub fn algorithm(&self) -> CompressionAlgorithm { + self.buffer.as_ref()[0].into() + } +} + +impl + AsMut<[u8]>> CompressionExtensionTail { + pub fn init(&mut self) { + self.buffer.as_mut().fill(0); + } + pub fn set_algorithm(&mut self, algorithm: CompressionAlgorithm) { + self.buffer.as_mut()[0] = algorithm.into() + } +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum CompressionAlgorithm { + #[cfg(feature = "lz4_compress")] + Lz4, + #[cfg(feature = "zstd_compress")] + Zstd, + Unknown(u8), +} + +impl From for CompressionAlgorithm { + fn from(value: u8) -> Self { + match value { + #[cfg(feature = "lz4_compress")] + 1 => CompressionAlgorithm::Lz4, + #[cfg(feature = "zstd_compress")] + 2 => CompressionAlgorithm::Zstd, + v => CompressionAlgorithm::Unknown(v), + } + } +} + +impl From for u8 { + fn from(value: CompressionAlgorithm) -> Self { + match value { + #[cfg(feature = "lz4_compress")] + CompressionAlgorithm::Lz4 => 1, + #[cfg(feature = "zstd_compress")] + CompressionAlgorithm::Zstd => 2, + CompressionAlgorithm::Unknown(val) => val, + } + } +} diff --git a/src/protocol/ip_turn_packet.rs b/src/protocol/ip_turn_packet.rs index 03ef556..edccdd5 100644 --- a/src/protocol/ip_turn_packet.rs +++ b/src/protocol/ip_turn_packet.rs @@ -6,6 +6,7 @@ use std::net::Ipv4Addr; #[derive(Copy, Clone, Eq, PartialEq, Debug)] pub enum Protocol { Ipv4, + WGIpv4, Ipv4Broadcast, Unknown(u8), } @@ -14,6 +15,7 @@ impl From for Protocol { fn from(value: u8) -> Self { match value { 4 => Protocol::Ipv4, + 5 => Protocol::WGIpv4, 201 => Protocol::Ipv4Broadcast, val => Protocol::Unknown(val), } @@ -24,6 +26,7 @@ impl From for u8 { fn from(val: Protocol) -> Self { match val { Protocol::Ipv4 => 4, + Protocol::WGIpv4 => 5, Protocol::Ipv4Broadcast => 201, Protocol::Unknown(val) => val, } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 4edc6d3..4fbc95a 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,7 +1,6 @@ #![allow(dead_code)] use crate::protocol::body::ENCRYPTION_RESERVED; -use std::fmt::Formatter; use std::net::Ipv4Addr; use std::{fmt, io}; @@ -9,7 +8,7 @@ use std::{fmt, io}; 0 15 31 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |e |s |u |u| 版本(4) | 协议(8) | 上层协议(8) | 初始ttl(4) | 生存时间(4) | + |e |s |x |u| 版本(4) | 协议(8) | 上层协议(8) | 初始ttl(4) | 生存时间(4) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | 源ip地址(32) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -17,13 +16,14 @@ use std::{fmt, io}; +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | 数据体 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - 注:e为是否加密标志,s为服务端通信包标志,u未使用 + 注:e为是否加密标志,s为服务端通信包标志,x扩展标志,u未使用 */ pub const HEAD_LEN: usize = 12; pub mod body; pub mod control_packet; pub mod error_packet; +pub mod extension; pub mod ip_turn_packet; pub mod other_turn_packet; pub mod service_packet; @@ -44,9 +44,9 @@ impl From for Version { } } -impl From for u8 { - fn from(val: Version) -> Self { - match val { +impl Into for Version { + fn into(self) -> u8 { + match self { Version::V2 => 2, Version::Unknown(val) => val, } @@ -81,9 +81,9 @@ impl From for Protocol { } } -impl From for u8 { - fn from(val: Protocol) -> Self { - match val { +impl Into for Protocol { + fn into(self) -> u8 { + match self { Protocol::Service => 1, Protocol::Error => 2, Protocol::Control => 3, @@ -104,6 +104,10 @@ pub struct NetPacket { } impl> NetPacket { + pub fn unchecked(buffer: B) -> Self { + let data_len = buffer.as_ref().len(); + Self { data_len, buffer } + } pub fn new(buffer: B) -> io::Result> { let data_len = buffer.as_ref().len(); Self::new0(data_len, buffer) @@ -126,14 +130,15 @@ impl> NetPacket { "length overflow", )); } - if HEAD_LEN > data_len { + if data_len < 12 { return Err(io::Error::new( io::ErrorKind::InvalidData, - "length overflow", + "data_len too short", )); } Ok(NetPacket { data_len, buffer }) } + #[inline] pub fn buffer(&self) -> &[u8] { &self.buffer.as_ref()[..self.data_len] } @@ -160,6 +165,10 @@ impl> NetPacket { pub fn is_gateway(&self) -> bool { self.buffer.as_ref()[0] & 0x40 == 0x40 } + /// 扩展协议 + pub fn is_extension(&self) -> bool { + self.buffer.as_ref()[0] & 0x20 == 0x20 + } pub fn version(&self) -> Version { Version::from(self.buffer.as_ref()[0] & 0x0F) } @@ -192,6 +201,9 @@ impl> NetPacket { } impl + AsMut<[u8]>> NetPacket { + pub fn head_mut(&mut self) -> &mut [u8] { + &mut self.buffer.as_mut()[..12] + } pub fn buffer_mut(&mut self) -> &mut [u8] { &mut self.buffer.as_mut()[..self.data_len] } @@ -204,12 +216,18 @@ impl + AsMut<[u8]>> NetPacket { } pub fn set_gateway_flag(&mut self, is_gateway: bool) { if is_gateway { - // 后面的版本再改为0x40,改了之后不兼容1.2.5之前的版本 - self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x50 + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x40 } else { self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0xBF }; } + pub fn set_extension_flag(&mut self, is_extension: bool) { + if is_extension { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x20 + } else { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0xDF + }; + } pub fn set_default_version(&mut self) { let v: u8 = Version::V2.into(); self.buffer.as_mut()[0] = (self.buffer.as_ref()[0] & 0xF0) | (0x0F & v); @@ -266,6 +284,10 @@ impl + AsMut<[u8]>> NetPacket { self.data_len = data_len; Ok(()) } + pub fn set_payload_len(&mut self, payload_len: usize) -> io::Result<()> { + let data_len = HEAD_LEN + payload_len; + self.set_data_len(data_len) + } pub fn set_data_len_max(&mut self) { self.data_len = self.buffer.as_ref().len(); } @@ -287,19 +309,3 @@ impl> fmt::Debug for NetPacket { .finish() } } - -impl> fmt::Display for NetPacket { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("NetPacket") - .field("version", &self.version()) - .field("gateway", &self.is_gateway()) - .field("encrypt", &self.is_encrypt()) - .field("protocol", &self.protocol()) - .field("transport_protocol", &self.transport_protocol()) - .field("ttl", &self.ttl()) - .field("source_ttl", &self.source_ttl()) - .field("source", &self.source()) - .field("destination", &self.destination()) - .finish() - } -} diff --git a/static/css/index.css b/static/css/index.css index e69de29..0be0101 100644 --- a/static/css/index.css +++ b/static/css/index.css @@ -0,0 +1,174 @@ +.option-cell { + display: flex; + gap: 10px; /* 间隔 */ +} + +.option-cell button { + padding: 5px 10px; + font-size: 14px; + border: none; + border-radius: 4px; + cursor: pointer; + transition: background-color 0.3s ease; +} + +.option-cell button.delete-button { + background-color: #f44336; /* 红色 */ + color: white; +} + +.option-cell button.delete-button:hover { + background-color: #d32f2f; /* 深红色 */ +} + +.option-cell button.view-button { + background-color: #4CAF50; /* 绿色 */ + color: white; +} + +.option-cell button.view-button:hover { + background-color: #388E3C; /* 深绿色 */ +} + +/* wg弹窗样式 */ +.modal { + display: none; /* 默认隐藏 */ + position: fixed; + z-index: 1; + left: 0; + top: 0; + width: 100%; + height: 100%; + overflow: auto; + background-color: rgb(0,0,0); + background-color: rgba(0,0,0,0.4); + padding-top: 60px; +} + +.modal-content { + position: relative; + background-color: #fefefe; + width: 380px; + height: 380px; + margin: 5% auto; + padding: 50px; + border: 1px solid #888; + text-align: center; + box-sizing: border-box; + border-radius: 5px; +} +.add-modal-content{ + position: relative; + background-color: #fefefe; + margin: 5% auto; + padding: 20px; + border: 1px solid #888; + width: 80%; + max-width: 600px; + box-sizing: border-box; + border-radius: 5px; +} +.form-group { + display: flex; + align-items: center; + margin: 10px 0; +} + +.form-group label { + flex: 1; + margin-right: 10px; +} + +.form-group input { + flex: 2; + padding: 10px; + box-sizing: border-box; +} + +.modal button { + padding: 10px 20px; + margin: 10px 5px; +} + +.button-container { + text-align: center; +} +.button-container button { + padding: 10px 20px; + margin: 10px 5px; + border: none; + border-radius: 5px; + cursor: pointer; + font-size: 16px; + transition: background-color 0.3s, box-shadow 0.3s; +} + +.error { + color: red; + font-size: 14px; +} + +#confirmButton { + background-color: #4CAF50; /* 绿色背景 */ + color: white; +} + +#confirmButton:hover { + background-color: #45a049; /* 鼠标悬停时变暗 */ + box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */ +} + +#cancelButton { + background-color: #f44336; /* 红色背景 */ + color: white; +} + +#cancelButton:hover { + background-color: #e53935; /* 鼠标悬停时变暗 */ + box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */ +} +.modal-content .title{ + position: absolute; + left: 20px; + top: 10px; +} + +.close { + position: absolute; + right: 20px; + top: 10px; + color: #aaa; + font-size: 28px; + font-weight: bold; +} + +.close:hover, +.close:focus { + color: black; + text-decoration: none; + cursor: pointer; +} + +.hidden { + display: none; +} + +.visible { + display: block; +} +#qrcode { + margin: 0 auto; /* 居中 */ + width: 260px; /* 固定宽度 */ + height: 260px; /* 固定高度 */ +} +pre { + text-align: left; /* 左对齐 */ + white-space: pre-wrap; /* 自动换行 */ + word-break: break-all; + user-select: auto; +} +#toggleButton{ + position: absolute; + bottom: 10px; + left: 116px; +} \ No newline at end of file diff --git a/static/css/select.css b/static/css/select.css index 892ede3..f149533 100644 --- a/static/css/select.css +++ b/static/css/select.css @@ -123,7 +123,27 @@ body { padding-top: 20px; margin-left: 300px; } +button{ + padding: 10px 20px; + margin: 10px 5px; + border: none; + border-radius: 5px; + cursor: pointer; + font-size: 16px; + transition: background-color 0.3s, box-shadow 0.3s; +} +#addWireGuard { + background-color: #4CAF50; /* 绿色背景 */ + color: white; + position: absolute; + right: 50px; + top:20px; +} +#addWireGuard:hover { + background-color: #45a049; /* 鼠标悬停时变暗 */ + box-shadow: 0 0 10px rgba(0, 0, 0, 0.2); /* 阴影效果 */ +} /* 下拉菜单按钮 */ diff --git a/static/index.html b/static/index.html index 472448b..80996db 100644 --- a/static/index.html +++ b/static/index.html @@ -4,9 +4,11 @@ + + vnts-web